cubek_convolution/components/global/layout/
spatial.rs1use cubecl::std::tensor::{
2 layout::{
3 Coordinates, Coords1d, Layout, LayoutExpand,
4 as_dyn::{IntoDyn, IntoDynExpand},
5 },
6 r#virtual::VirtualTensor,
7};
8use cubecl::{
9 prelude::*,
10 std::tensor::launch::{BufferArg, ViewLayoutLaunchArg},
11};
12use enumset::{EnumSet, EnumSetType};
13
14use crate::components::Dimensionality;
15
16#[derive(CubeType, CubeLaunch, Clone)]
17pub struct NhwcCoords {
18 pub batch: u32,
19 pub spatial: Sequence<i32>,
20 pub channel: u32,
21}
22
23#[cube]
24impl IntoDyn for NhwcCoords {
25 fn into_dyn(self) -> Sequence<i32> {
26 let mut seq = Sequence::new();
27 seq.push(self.batch as i32);
28 for x in self.spatial {
29 seq.push(x);
30 }
31 seq.push(self.channel as i32);
32 seq
33 }
34}
35
36type NhwcTuple = (u32, Sequence<i32>, u32);
37
38#[cube]
39impl NhwcCoords {
40 pub fn new(batch: u32, spatial: Sequence<i32>, channel: u32) -> Self {
41 NhwcCoords {
42 batch,
43 spatial,
44 channel,
45 }
46 }
47
48 fn into_tuple(self) -> NhwcTuple {
49 (self.batch, self.spatial, self.channel)
50 }
51
52 fn from_tuple(tuple: NhwcTuple) -> Self {
53 NhwcCoords::new(tuple.0, tuple.1, tuple.2)
54 }
55}
56
57#[cube]
58impl Coordinates for NhwcCoords {
59 fn add(this: Self, other: Self) -> Self {
60 let tuple = NhwcTuple::add(this.into_tuple(), other.into_tuple());
61 NhwcCoords::from_tuple(tuple)
62 }
63
64 fn sub(this: Self, other: Self) -> Self {
65 let tuple = NhwcTuple::sub(this.into_tuple(), other.into_tuple());
66 NhwcCoords::from_tuple(tuple)
67 }
68
69 fn min(this: Self, other: Self) -> Self {
70 let tuple = <NhwcTuple as Coordinates>::min(this.into_tuple(), other.into_tuple());
71 NhwcCoords::from_tuple(tuple)
72 }
73
74 fn max(this: Self, other: Self) -> Self {
75 let tuple = <NhwcTuple as Coordinates>::max(this.into_tuple(), other.into_tuple());
76 NhwcCoords::from_tuple(tuple)
77 }
78
79 fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
80 NhwcTuple::is_in_bounds(&pos.clone().into_tuple(), &bounds.clone().into_tuple())
81 }
82
83 fn from_int(this: &Self, #[comptime] value: i64) -> Self {
84 let tuple = NhwcTuple::from_int(&this.clone().into_tuple(), value);
85 NhwcCoords::from_tuple(tuple)
86 }
87}
88
89#[derive(EnumSetType, Debug, Hash)]
90pub enum NhwcCheck {
91 Batch,
92 Spatial,
93 Channel,
94}
95
96#[derive(CubeType, Clone)]
99pub struct NhwcLayout {
100 pub stride_batch: usize,
102 pub strides_spatial: Sequence<usize>,
104 pub stride_channel: usize,
106
107 pub shape_batch: u32,
109 pub shapes_spatial: Sequence<u32>,
111 pub shape_channel: u32,
113
114 #[cube(comptime)]
115 pub vector_size: VectorSize,
116 #[cube(comptime)]
117 pub checks: EnumSet<NhwcCheck>,
118}
119
120#[cube]
121impl NhwcLayout {
122 pub fn new<E: Numeric, N: Size, IO: Clone>(
123 tensor: VirtualTensor<E, N, IO>,
124 #[comptime] dim: Dimensionality,
125 #[comptime] checks: EnumSet<NhwcCheck>,
126 ) -> Self {
127 let spatial_dims = dim.num_dims();
128 let mut strides_spatial = Sequence::new();
129 let mut shapes_spatial = Sequence::new();
130
131 #[unroll]
132 for i in 0..spatial_dims {
133 strides_spatial.push(tensor.stride(i + 1));
134 shapes_spatial.push(tensor.shape(i + 1) as u32);
135 }
136
137 let stride_batch = tensor.stride(0);
138 let stride_channel = tensor.stride(spatial_dims + 1);
139
140 let shape_batch = tensor.shape(0) as u32;
141 let shape_channel = tensor.shape(spatial_dims + 1) as u32;
142
143 NhwcLayout {
144 stride_batch,
145 strides_spatial,
146 stride_channel,
147 shape_batch,
148 shapes_spatial,
149 shape_channel,
150 vector_size: tensor.vector_size(),
151 checks,
152 }
153 }
154}
155
156#[cube]
157impl Layout for NhwcLayout {
158 type Coordinates = NhwcCoords;
159 type SourceCoordinates = Coords1d;
160
161 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
162 let NhwcCoords {
163 batch,
164 spatial,
165 channel,
166 } = pos;
167
168 let spatial_dims = self.shapes_spatial.len();
169 let mut read_pos =
170 batch as usize * self.stride_batch + channel as usize * self.stride_channel;
171
172 #[unroll]
173 for i in 0..spatial_dims {
174 read_pos += spatial[i] as usize * self.strides_spatial[i];
175 }
176
177 read_pos / self.vector_size
178 }
179
180 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
181 (self.to_source_pos(pos.clone()), self.is_in_bounds(pos))
182 }
183
184 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
185 let mut in_bounds = true.runtime();
186 if self.checks.comptime().contains(NhwcCheck::Batch) {
187 in_bounds &= pos.batch < self.shape_batch;
188 }
189 if self.checks.comptime().contains(NhwcCheck::Spatial) {
190 let spatial_dims = self.shapes_spatial.len();
191
192 #[unroll]
193 for i in 0..spatial_dims {
194 let pos = pos.spatial[i];
195 in_bounds &= pos >= 0 && (pos as u32) < self.shapes_spatial[i];
196 }
197 }
198 if self.checks.comptime().contains(NhwcCheck::Channel) {
199 in_bounds &= pos.channel < self.shape_channel;
200 }
201
202 in_bounds
203 }
204
205 fn shape(&self) -> Self::Coordinates {
206 NhwcCoords {
207 batch: self.shape_batch,
208 spatial: cast_seq(self.shapes_spatial.clone()),
209 channel: self.shape_channel,
210 }
211 }
212}
213
214#[cube]
215pub(crate) fn cast_seq<From: CubePrimitive, To: CubePrimitive>(
216 seq: Sequence<From>,
217) -> Sequence<To> {
218 let num_elems = seq.len();
219 let mut out_seq = Sequence::new();
220 #[unroll]
221 for i in 0..num_elems {
222 let elem = To::cast_from(seq[i]);
223 out_seq.push(elem);
224 }
225 out_seq
226}
227
228pub struct NhwcLayoutLaunch {
229 checks: EnumSet<NhwcCheck>,
230}
231
232impl NhwcLayoutLaunch {
233 pub fn checked(checks: EnumSet<NhwcCheck>) -> Self {
234 Self { checks }
235 }
236
237 pub fn unchecked() -> Self {
238 Self {
239 checks: EnumSet::empty(),
240 }
241 }
242}
243
244#[derive_cube_comptime]
245pub struct NhwcLayoutCompilationArg {
246 pub spatial_rank: usize,
247 pub checks: EnumSet<NhwcCheck>,
248}
249
250impl ViewLayoutLaunchArg for NhwcLayout {
251 type RuntimeArg<R: Runtime> = NhwcLayoutLaunch;
252 type CompilationArg = NhwcLayoutCompilationArg;
253
254 fn register<R: Runtime, B: BufferArg>(
255 arg: Self::RuntimeArg<R>,
256 buffer: &B,
257 _: Type,
258 launcher: &mut KernelLauncher<R>,
259 ) -> Self::CompilationArg {
260 let shape = buffer.shape();
261 let strides = buffer.strides();
262
263 let rank = shape.len();
264 let dim_c = rank - 1;
265
266 let stride_batch = strides[0];
267 let strides_spatial = strides[1..dim_c].iter().copied().collect();
268 let stride_channel = strides[dim_c];
269
270 let shape_batch = shape[0] as u32;
271 let shapes_spatial = shape[1..dim_c].iter().map(|s| *s as u32).collect();
272 let shape_channel = shape[dim_c] as u32;
273
274 <usize as LaunchArg>::register(stride_batch, launcher);
275 <Sequence<usize> as LaunchArg>::register(strides_spatial, launcher);
276 <usize as LaunchArg>::register(stride_channel, launcher);
277 <u32 as LaunchArg>::register(shape_batch, launcher);
278 <Sequence<u32> as LaunchArg>::register(shapes_spatial, launcher);
279 <u32 as LaunchArg>::register(shape_channel, launcher);
280
281 NhwcLayoutCompilationArg {
282 spatial_rank: buffer.shape().len() - 2,
283 checks: arg.checks,
284 }
285 }
286
287 fn expand(
288 arg: &Self::CompilationArg,
289 ty: Type,
290 builder: &mut KernelBuilder,
291 ) -> <Self as CubeType>::ExpandType {
292 let strides_comp_arg = (0..arg.spatial_rank).map(|_| ()).collect();
293 let shape_comp_arg = (0..arg.spatial_rank).map(|_| ()).collect();
294 NhwcLayoutExpand {
295 stride_batch: <usize as LaunchArg>::expand(&(), builder),
296 strides_spatial: <Sequence<usize> as LaunchArg>::expand(&strides_comp_arg, builder),
297 stride_channel: <usize as LaunchArg>::expand(&(), builder),
298 shape_batch: <u32 as LaunchArg>::expand(&(), builder),
299 shapes_spatial: <Sequence<u32> as LaunchArg>::expand(&shape_comp_arg, builder),
300 shape_channel: <u32 as LaunchArg>::expand(&(), builder),
301 vector_size: ty.vector_size(),
302 checks: arg.checks,
303 }
304 }
305}