cubecl_convolution/components/global/layout/
spatial.rs1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3use cubecl_std::tensor::{
4 layout::{
5 Coordinates, Coords1d, Layout, LayoutExpand,
6 as_dyn::{IntoDyn, IntoDynExpand},
7 },
8 r#virtual::VirtualTensor,
9};
10
11use crate::components::Dimensionality;
12
/// A position in a tensor whose dimensions are ordered `[batch, spatial..., channel]`.
///
/// Spatial indices are signed so that positions before the start of a spatial dimension can
/// be represented; `NhwcLayout::is_in_bounds` treats a negative spatial index as out of bounds.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct NhwcCoords {
    /// Index along the batch dimension (the first dimension).
    pub batch: u32,
    /// One index per spatial dimension, in tensor dimension order.
    pub spatial: Sequence<i32>,
    /// Index along the channel dimension (the last dimension).
    pub channel: u32,
}
18}
19
20#[cube]
21impl IntoDyn for NhwcCoords {
22 fn into_dyn(self) -> Sequence<i32> {
23 let mut seq = Sequence::new();
24 seq.push(self.batch as i32);
25 for x in self.spatial {
26 seq.push(x);
27 }
28 seq.push(self.channel as i32);
29 seq
30 }
31}
32
/// Tuple form of [`NhwcCoords`], laid out as `(batch, spatial, channel)`.
///
/// `NhwcCoords` converts to and from this tuple so that it can delegate its `Coordinates`
/// operations to the tuple's implementation.
type NhwcTuple = (u32, Sequence<i32>, u32);
34
35#[cube]
36impl NhwcCoords {
37 pub fn new(batch: u32, spatial: Sequence<i32>, channel: u32) -> Self {
38 NhwcCoords {
39 batch,
40 spatial,
41 channel,
42 }
43 }
44
45 fn into_tuple(self) -> NhwcTuple {
46 (self.batch, self.spatial, self.channel)
47 }
48
49 fn from_tuple(tuple: NhwcTuple) -> Self {
50 NhwcCoords::new(tuple.0, tuple.1, tuple.2)
51 }
52}
53
54#[cube]
55impl Coordinates for NhwcCoords {
56 fn add(this: Self, other: Self) -> Self {
57 let tuple = NhwcTuple::add(this.into_tuple(), other.into_tuple());
58 NhwcCoords::from_tuple(tuple)
59 }
60
61 fn sub(this: Self, other: Self) -> Self {
62 let tuple = NhwcTuple::sub(this.into_tuple(), other.into_tuple());
63 NhwcCoords::from_tuple(tuple)
64 }
65
66 fn min(this: Self, other: Self) -> Self {
67 let tuple = <NhwcTuple as Coordinates>::min(this.into_tuple(), other.into_tuple());
68 NhwcCoords::from_tuple(tuple)
69 }
70
71 fn max(this: Self, other: Self) -> Self {
72 let tuple = <NhwcTuple as Coordinates>::max(this.into_tuple(), other.into_tuple());
73 NhwcCoords::from_tuple(tuple)
74 }
75
76 fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
77 NhwcTuple::is_in_bounds(&pos.clone().into_tuple(), &bounds.clone().into_tuple())
78 }
79
80 fn from_int(this: &Self, #[comptime] value: i64) -> Self {
81 let tuple = NhwcTuple::from_int(&this.clone().into_tuple(), value);
82 NhwcCoords::from_tuple(tuple)
83 }
84}
85
/// Maps [`NhwcCoords`] to a linear offset, counted in lines, into a tensor whose dimensions
/// are ordered `[batch, spatial..., channel]`.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct NhwcLayout {
    /// Stride of the batch dimension (dimension 0).
    pub stride_batch: u32,
    /// Strides of the spatial dimensions, in dimension order.
    pub strides_spatial: Sequence<u32>,
    /// Stride of the channel dimension (the last dimension).
    pub stride_channel: u32,

    /// Size of the batch dimension.
    pub shape_batch: u32,
    /// Sizes of the spatial dimensions, in dimension order.
    pub shapes_spatial: Sequence<u32>,
    /// Size of the channel dimension.
    pub shape_channel: u32,

    /// Number of elements per line; the stride-weighted offset is divided by this value.
    #[cube(comptime)]
    pub line_size: u32,
    /// When set, `is_in_bounds` checks every spatial index against `[0, shapes_spatial[i])`.
    #[cube(comptime)]
    pub check_spatial: bool,
    /// When set, `is_in_bounds` checks the channel index against `shape_channel`.
    #[cube(comptime)]
    pub check_channel: bool,
}
111
112#[cube]
113impl NhwcLayout {
114 pub fn new<E: Numeric, IO: Clone>(
115 tensor: VirtualTensor<E, IO>,
116 #[comptime] dim: Dimensionality,
117 #[comptime] check_spatial: bool,
118 #[comptime] check_channel: bool,
119 ) -> Self {
120 let spatial_dims = comptime![dim.num_dims()];
121 let mut strides_spatial = Sequence::new();
122 let mut shapes_spatial = Sequence::new();
123
124 #[unroll]
125 for i in 0..spatial_dims {
126 strides_spatial.push(tensor.stride(i + 1));
127 shapes_spatial.push(tensor.shape(i + 1));
128 }
129
130 let stride_batch = tensor.stride(0);
131 let stride_channel = tensor.stride(spatial_dims + 1);
132
133 let shape_batch = tensor.shape(0);
134 let shape_channel = tensor.shape(spatial_dims + 1);
135
136 NhwcLayout {
137 stride_batch,
138 strides_spatial,
139 stride_channel,
140 shape_batch,
141 shapes_spatial,
142 shape_channel,
143 line_size: tensor.line_size(),
144 check_spatial,
145 check_channel,
146 }
147 }
148}
149
150#[cube]
151impl Layout for NhwcLayout {
152 type Coordinates = NhwcCoords;
153 type SourceCoordinates = Coords1d;
154
155 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
156 let NhwcCoords {
157 batch,
158 spatial,
159 channel,
160 } = pos;
161
162 let spatial_dims = self.shapes_spatial.len();
163 let mut read_pos = batch * self.stride_batch + channel * self.stride_channel;
164
165 #[unroll]
166 for i in 0..spatial_dims {
167 read_pos += *spatial.index(i) as u32 * *self.strides_spatial.index(i);
168 }
169
170 read_pos / self.line_size
171 }
172
173 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
174 (self.to_source_pos(pos.clone()), self.is_in_bounds(pos))
175 }
176
177 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
178 let mut in_bounds = true.runtime();
179 if comptime![self.check_spatial] {
180 let spatial_dims = self.shapes_spatial.len();
181
182 #[unroll]
183 for i in 0..spatial_dims {
184 let pos = *pos.spatial.index(i);
185 in_bounds &= pos >= 0 && (pos as u32) < *self.shapes_spatial.index(i);
186 }
187 }
188 if comptime![self.check_channel] {
189 in_bounds &= pos.channel < self.shape_channel;
190 }
191
192 in_bounds
193 }
194
195 fn shape(&self) -> Self::Coordinates {
196 NhwcCoords {
197 batch: self.shape_batch,
198 spatial: cast_seq(self.shapes_spatial.clone()),
199 channel: self.shape_channel,
200 }
201 }
202}
203
204#[cube]
205pub(crate) fn cast_seq<From: CubePrimitive, To: CubePrimitive>(
206 seq: Sequence<From>,
207) -> Sequence<To> {
208 let num_elems = seq.len();
209 let mut out_seq = Sequence::new();
210 #[unroll]
211 for i in 0..num_elems {
212 let elem = To::cast_from(*seq.index(i));
213 out_seq.push(elem);
214 }
215 out_seq
216}
217
218impl<'a, R: Runtime> NhwcLayoutLaunch<'a, R> {
219 pub fn from_handle(
220 handle: &TensorHandleRef<'a, R>,
221 line_size: u32,
222 check_spatial: bool,
223 check_channel: bool,
224 ) -> Self {
225 let rank = handle.shape.len();
226 let dim_c = rank - 1;
227
228 let stride_batch = ScalarArg::new(handle.strides[0] as u32);
229 let strides_spatial = handle.strides[1..dim_c]
230 .iter()
231 .map(|s| ScalarArg::new(*s as u32))
232 .collect();
233 let stride_channel = ScalarArg::new(handle.strides[dim_c] as u32);
234
235 let shape_batch = ScalarArg::new(handle.shape[0] as u32);
236 let shapes_spatial = handle.shape[1..dim_c]
237 .iter()
238 .map(|s| ScalarArg::new(*s as u32))
239 .collect();
240 let shape_channel = ScalarArg::new(handle.shape[dim_c] as u32);
241
242 Self::new(
243 stride_batch,
244 strides_spatial,
245 stride_channel,
246 shape_batch,
247 shapes_spatial,
248 shape_channel,
249 line_size,
250 check_spatial,
251 check_channel,
252 )
253 }
254}