cubek_convolution/components/global/layout/
spatial.rs1use cubecl::prelude::*;
2use cubecl::std::tensor::{
3 layout::{
4 Coordinates, Coords1d, Layout, LayoutExpand,
5 as_dyn::{IntoDyn, IntoDynExpand},
6 },
7 r#virtual::VirtualTensor,
8};
9
10use crate::components::Dimensionality;
11
/// Position inside a channels-last tensor: a batch index, one position per
/// spatial axis, and a channel index.
///
/// Spatial positions are signed (`i32`) so that positions before the start of an
/// axis can be represented; `NhwcLayout::is_in_bounds` treats negative spatial
/// positions as out of bounds when its spatial check is enabled.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct NhwcCoords {
    /// Index along the batch axis (the first tensor axis).
    pub batch: u32,
    /// Position along each spatial axis, in axis order.
    pub spatial: Sequence<i32>,
    /// Index along the channel axis (the last tensor axis).
    pub channel: u32,
}
18
19#[cube]
20impl IntoDyn for NhwcCoords {
21 fn into_dyn(self) -> Sequence<i32> {
22 let mut seq = Sequence::new();
23 seq.push(self.batch as i32);
24 for x in self.spatial {
25 seq.push(x);
26 }
27 seq.push(self.channel as i32);
28 seq
29 }
30}
31
/// Tuple form of [`NhwcCoords`], ordered `(batch, spatial, channel)`.
///
/// The `Coordinates` implementation for `NhwcCoords` converts to this tuple and
/// delegates to the tuple's own `Coordinates` implementation.
type NhwcTuple = (u32, Sequence<i32>, u32);
33
34#[cube]
35impl NhwcCoords {
36 pub fn new(batch: u32, spatial: Sequence<i32>, channel: u32) -> Self {
37 NhwcCoords {
38 batch,
39 spatial,
40 channel,
41 }
42 }
43
44 fn into_tuple(self) -> NhwcTuple {
45 (self.batch, self.spatial, self.channel)
46 }
47
48 fn from_tuple(tuple: NhwcTuple) -> Self {
49 NhwcCoords::new(tuple.0, tuple.1, tuple.2)
50 }
51}
52
53#[cube]
54impl Coordinates for NhwcCoords {
55 fn add(this: Self, other: Self) -> Self {
56 let tuple = NhwcTuple::add(this.into_tuple(), other.into_tuple());
57 NhwcCoords::from_tuple(tuple)
58 }
59
60 fn sub(this: Self, other: Self) -> Self {
61 let tuple = NhwcTuple::sub(this.into_tuple(), other.into_tuple());
62 NhwcCoords::from_tuple(tuple)
63 }
64
65 fn min(this: Self, other: Self) -> Self {
66 let tuple = <NhwcTuple as Coordinates>::min(this.into_tuple(), other.into_tuple());
67 NhwcCoords::from_tuple(tuple)
68 }
69
70 fn max(this: Self, other: Self) -> Self {
71 let tuple = <NhwcTuple as Coordinates>::max(this.into_tuple(), other.into_tuple());
72 NhwcCoords::from_tuple(tuple)
73 }
74
75 fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
76 NhwcTuple::is_in_bounds(&pos.clone().into_tuple(), &bounds.clone().into_tuple())
77 }
78
79 fn from_int(this: &Self, #[comptime] value: i64) -> Self {
80 let tuple = NhwcTuple::from_int(&this.clone().into_tuple(), value);
81 NhwcCoords::from_tuple(tuple)
82 }
83}
84
/// Layout mapping [`NhwcCoords`] onto a linear (line-indexed) position in a
/// channels-last tensor, with optional comptime-selected bounds checks.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct NhwcLayout {
    /// Stride of the batch axis (tensor axis 0).
    pub stride_batch: u32,
    /// Strides of the spatial axes (tensor axes 1 through the second-to-last), in order.
    pub strides_spatial: Sequence<u32>,
    /// Stride of the channel axis (the last tensor axis).
    pub stride_channel: u32,

    /// Extent of the batch axis.
    pub shape_batch: u32,
    /// Extents of the spatial axes, in the same order as `strides_spatial`.
    pub shapes_spatial: Sequence<u32>,
    /// Extent of the channel axis.
    pub shape_channel: u32,

    /// Line (vector) width; `to_source_pos` divides the element offset by this
    /// to produce a line index.
    #[cube(comptime)]
    pub line_size: u32,
    /// When true, `is_in_bounds` checks every spatial position against
    /// `shapes_spatial` (negative positions are out of bounds).
    #[cube(comptime)]
    pub check_spatial: bool,
    /// When true, `is_in_bounds` checks the channel index against `shape_channel`.
    #[cube(comptime)]
    pub check_channel: bool,
}
110
111#[cube]
112impl NhwcLayout {
113 pub fn new<E: Numeric, IO: Clone>(
114 tensor: VirtualTensor<E, IO>,
115 #[comptime] dim: Dimensionality,
116 #[comptime] check_spatial: bool,
117 #[comptime] check_channel: bool,
118 ) -> Self {
119 let spatial_dims = comptime![dim.num_dims()];
120 let mut strides_spatial = Sequence::new();
121 let mut shapes_spatial = Sequence::new();
122
123 #[unroll]
124 for i in 0..spatial_dims {
125 strides_spatial.push(tensor.stride(i + 1));
126 shapes_spatial.push(tensor.shape(i + 1));
127 }
128
129 let stride_batch = tensor.stride(0);
130 let stride_channel = tensor.stride(spatial_dims + 1);
131
132 let shape_batch = tensor.shape(0);
133 let shape_channel = tensor.shape(spatial_dims + 1);
134
135 NhwcLayout {
136 stride_batch,
137 strides_spatial,
138 stride_channel,
139 shape_batch,
140 shapes_spatial,
141 shape_channel,
142 line_size: tensor.line_size(),
143 check_spatial,
144 check_channel,
145 }
146 }
147}
148
149#[cube]
150impl Layout for NhwcLayout {
151 type Coordinates = NhwcCoords;
152 type SourceCoordinates = Coords1d;
153
154 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
155 let NhwcCoords {
156 batch,
157 spatial,
158 channel,
159 } = pos;
160
161 let spatial_dims = self.shapes_spatial.len();
162 let mut read_pos = batch * self.stride_batch + channel * self.stride_channel;
163
164 #[unroll]
165 for i in 0..spatial_dims {
166 read_pos += *spatial.index(i) as u32 * *self.strides_spatial.index(i);
167 }
168
169 read_pos / self.line_size
170 }
171
172 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
173 (self.to_source_pos(pos.clone()), self.is_in_bounds(pos))
174 }
175
176 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
177 let mut in_bounds = true.runtime();
178 if comptime![self.check_spatial] {
179 let spatial_dims = self.shapes_spatial.len();
180
181 #[unroll]
182 for i in 0..spatial_dims {
183 let pos = *pos.spatial.index(i);
184 in_bounds &= pos >= 0 && (pos as u32) < *self.shapes_spatial.index(i);
185 }
186 }
187 if comptime![self.check_channel] {
188 in_bounds &= pos.channel < self.shape_channel;
189 }
190
191 in_bounds
192 }
193
194 fn shape(&self) -> Self::Coordinates {
195 NhwcCoords {
196 batch: self.shape_batch,
197 spatial: cast_seq(self.shapes_spatial.clone()),
198 channel: self.shape_channel,
199 }
200 }
201}
202
203#[cube]
204pub(crate) fn cast_seq<From: CubePrimitive, To: CubePrimitive>(
205 seq: Sequence<From>,
206) -> Sequence<To> {
207 let num_elems = seq.len();
208 let mut out_seq = Sequence::new();
209 #[unroll]
210 for i in 0..num_elems {
211 let elem = To::cast_from(*seq.index(i));
212 out_seq.push(elem);
213 }
214 out_seq
215}
216
217impl<'a, R: Runtime> NhwcLayoutLaunch<'a, R> {
218 pub fn from_handle(
219 handle: &TensorHandleRef<'a, R>,
220 line_size: u32,
221 check_spatial: bool,
222 check_channel: bool,
223 ) -> Self {
224 let rank = handle.shape.len();
225 let dim_c = rank - 1;
226
227 let stride_batch = ScalarArg::new(handle.strides[0] as u32);
228 let strides_spatial = handle.strides[1..dim_c]
229 .iter()
230 .map(|s| ScalarArg::new(*s as u32))
231 .collect();
232 let stride_channel = ScalarArg::new(handle.strides[dim_c] as u32);
233
234 let shape_batch = ScalarArg::new(handle.shape[0] as u32);
235 let shapes_spatial = handle.shape[1..dim_c]
236 .iter()
237 .map(|s| ScalarArg::new(*s as u32))
238 .collect();
239 let shape_channel = ScalarArg::new(handle.shape[dim_c] as u32);
240
241 Self::new(
242 stride_batch,
243 strides_spatial,
244 stride_channel,
245 shape_batch,
246 shapes_spatial,
247 shape_channel,
248 line_size,
249 check_spatial,
250 check_channel,
251 )
252 }
253}