cubecl_convolution/components/global/layout/
spatial.rs1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3use cubecl_std::tensor::{
4 layout::{Coordinates, Coords1d, Layout, LayoutExpand},
5 r#virtual::VirtualTensor,
6};
7
8use crate::components::Dimensionality;
9
10#[derive(CubeType, Clone)]
11pub struct NhwcCoords {
12 pub batch: u32,
13 pub spatial: Sequence<i32>,
14 pub channel: u32,
15}
16
17type NhwcTuple = (u32, Sequence<i32>, u32);
18
19#[cube]
20impl NhwcCoords {
21 pub fn new(batch: u32, spatial: Sequence<i32>, channel: u32) -> Self {
22 NhwcCoords {
23 batch,
24 spatial,
25 channel,
26 }
27 }
28
29 fn into_tuple(self) -> NhwcTuple {
30 (self.batch, self.spatial, self.channel)
31 }
32
33 fn from_tuple(tuple: NhwcTuple) -> Self {
34 NhwcCoords::new(tuple.0, tuple.1, tuple.2)
35 }
36}
37
38impl NhwcCoordsExpand {
39 pub fn __expand_clone_method(&self, _scope: &mut Scope) -> Self {
40 NhwcCoordsExpand {
41 batch: self.batch.clone(),
42 spatial: self.spatial.clone(),
43 channel: self.channel.clone(),
44 }
45 }
46}
47
48#[cube]
49impl Coordinates for NhwcCoords {
50 fn add(this: Self, other: Self) -> Self {
51 let tuple = NhwcTuple::add(this.into_tuple(), other.into_tuple());
52 NhwcCoords::from_tuple(tuple)
53 }
54
55 fn sub(this: Self, other: Self) -> Self {
56 let tuple = NhwcTuple::sub(this.into_tuple(), other.into_tuple());
57 NhwcCoords::from_tuple(tuple)
58 }
59
60 fn min(this: Self, other: Self) -> Self {
61 let tuple = <NhwcTuple as Coordinates>::min(this.into_tuple(), other.into_tuple());
62 NhwcCoords::from_tuple(tuple)
63 }
64
65 fn max(this: Self, other: Self) -> Self {
66 let tuple = <NhwcTuple as Coordinates>::max(this.into_tuple(), other.into_tuple());
67 NhwcCoords::from_tuple(tuple)
68 }
69
70 fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
71 NhwcTuple::is_in_bounds(&pos.clone().into_tuple(), &bounds.clone().into_tuple())
72 }
73
74 fn from_int(this: &Self, #[comptime] value: i64) -> Self {
75 let tuple = NhwcTuple::from_int(&this.clone().into_tuple(), value);
76 NhwcCoords::from_tuple(tuple)
77 }
78}
79
80#[derive(CubeType, CubeLaunch, Clone)]
83pub struct NhwcLayout {
84 pub stride_batch: u32,
86 pub strides_spatial: Sequence<u32>,
88 pub stride_channel: u32,
90
91 pub shape_batch: u32,
93 pub shapes_spatial: Sequence<u32>,
95 pub shape_channel: u32,
97
98 #[cube(comptime)]
99 pub line_size: u32,
100 #[cube(comptime)]
101 pub check_spatial: bool,
102}
103
104#[cube]
105impl NhwcLayout {
106 pub fn new<E: Numeric, IO: Clone>(
107 tensor: VirtualTensor<E, IO>,
108 #[comptime] dim: Dimensionality,
109 #[comptime] check_spatial: bool,
110 ) -> Self {
111 let spatial_dims = comptime![dim.num_dims()];
112 let mut strides_spatial = Sequence::new();
113 let mut shapes_spatial = Sequence::new();
114
115 #[unroll]
116 for i in 0..spatial_dims {
117 strides_spatial.push(tensor.stride(i + 1));
118 shapes_spatial.push(tensor.shape(i + 1));
119 }
120
121 let stride_batch = tensor.stride(0);
122 let stride_channel = tensor.stride(spatial_dims + 1);
123
124 let shape_batch = tensor.shape(0);
125 let shape_channel = tensor.shape(spatial_dims + 1);
126
127 NhwcLayout {
128 stride_batch,
129 strides_spatial,
130 stride_channel,
131 shape_batch,
132 shapes_spatial,
133 shape_channel,
134 line_size: tensor.line_size(),
135 check_spatial,
136 }
137 }
138}
139
140#[cube]
141impl Layout for NhwcLayout {
142 type Coordinates = NhwcCoords;
143 type SourceCoordinates = Coords1d;
144
145 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
146 let NhwcCoords {
147 batch,
148 spatial,
149 channel,
150 } = pos;
151
152 let spatial_dims = self.shapes_spatial.len();
153 let mut read_pos = batch * self.stride_batch + channel * self.stride_channel;
154
155 #[unroll]
156 for i in 0..spatial_dims {
157 read_pos += *spatial.index(i) as u32 * *self.strides_spatial.index(i);
158 }
159
160 read_pos / self.line_size
161 }
162
163 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
164 (self.to_source_pos(pos.clone()), self.is_in_bounds(pos))
165 }
166
167 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
168 if comptime![self.check_spatial] {
169 let spatial_dims = self.shapes_spatial.len();
170 let mut spatial_in_bounds = true;
171
172 #[unroll]
173 for i in 0..spatial_dims {
174 let pos = *pos.spatial.index(i);
175 spatial_in_bounds &= pos >= 0 && (pos as u32) < *self.shapes_spatial.index(i);
176 }
177
178 spatial_in_bounds
179 } else {
180 true.runtime()
181 }
182 }
183
184 fn shape(&self) -> Self::Coordinates {
185 NhwcCoords {
186 batch: self.shape_batch,
187 spatial: cast_seq(self.shapes_spatial.clone()),
188 channel: self.shape_channel,
189 }
190 }
191}
192
193#[cube]
194pub(crate) fn cast_seq<From: CubePrimitive, To: CubePrimitive>(
195 seq: Sequence<From>,
196) -> Sequence<To> {
197 let num_elems = seq.len();
198 let mut out_seq = Sequence::new();
199 #[unroll]
200 for i in 0..num_elems {
201 let elem = To::cast_from(*seq.index(i));
202 out_seq.push(elem);
203 }
204 out_seq
205}
206
207impl<'a, R: Runtime> NhwcLayoutLaunch<'a, R> {
208 pub fn from_handle(
209 handle: &TensorHandleRef<'a, R>,
210 line_size: u32,
211 check_spatial: bool,
212 ) -> Self {
213 let rank = handle.shape.len();
214 let dim_c = rank - 1;
215
216 let stride_batch = ScalarArg::new(handle.strides[0] as u32);
217 let strides_spatial = handle.strides[1..dim_c]
218 .iter()
219 .map(|s| ScalarArg::new(*s as u32))
220 .collect();
221 let stride_channel = ScalarArg::new(handle.strides[dim_c] as u32);
222
223 let shape_batch = ScalarArg::new(handle.shape[0] as u32);
224 let shapes_spatial = handle.shape[1..dim_c]
225 .iter()
226 .map(|s| ScalarArg::new(*s as u32))
227 .collect();
228 let shape_channel = ScalarArg::new(handle.shape[dim_c] as u32);
229
230 Self::new(
231 stride_batch,
232 strides_spatial,
233 stride_channel,
234 shape_batch,
235 shapes_spatial,
236 shape_channel,
237 line_size,
238 check_spatial,
239 )
240 }
241}