Skip to main content

cubek_convolution/components/global/layout/
spatial.rs

1use cubecl::std::tensor::{
2    layout::{
3        Coordinates, Coords1d, Layout, LayoutExpand,
4        as_dyn::{IntoDyn, IntoDynExpand},
5    },
6    r#virtual::VirtualTensor,
7};
8use cubecl::{
9    prelude::*,
10    std::tensor::launch::{BufferArg, ViewLayoutLaunchArg},
11};
12use enumset::{EnumSet, EnumSetType};
13
14use crate::components::Dimensionality;
15
16#[derive(CubeType, CubeLaunch, Clone)]
17pub struct NhwcCoords {
18    pub batch: u32,
19    pub spatial: Sequence<i32>,
20    pub channel: u32,
21}
22
23#[cube]
24impl IntoDyn for NhwcCoords {
25    fn into_dyn(self) -> Sequence<i32> {
26        let mut seq = Sequence::new();
27        seq.push(self.batch as i32);
28        for x in self.spatial {
29            seq.push(x);
30        }
31        seq.push(self.channel as i32);
32        seq
33    }
34}
35
36type NhwcTuple = (u32, Sequence<i32>, u32);
37
38#[cube]
39impl NhwcCoords {
40    pub fn new(batch: u32, spatial: Sequence<i32>, channel: u32) -> Self {
41        NhwcCoords {
42            batch,
43            spatial,
44            channel,
45        }
46    }
47
48    fn into_tuple(self) -> NhwcTuple {
49        (self.batch, self.spatial, self.channel)
50    }
51
52    fn from_tuple(tuple: NhwcTuple) -> Self {
53        NhwcCoords::new(tuple.0, tuple.1, tuple.2)
54    }
55}
56
57#[cube]
58impl Coordinates for NhwcCoords {
59    fn add(this: Self, other: Self) -> Self {
60        let tuple = NhwcTuple::add(this.into_tuple(), other.into_tuple());
61        NhwcCoords::from_tuple(tuple)
62    }
63
64    fn sub(this: Self, other: Self) -> Self {
65        let tuple = NhwcTuple::sub(this.into_tuple(), other.into_tuple());
66        NhwcCoords::from_tuple(tuple)
67    }
68
69    fn min(this: Self, other: Self) -> Self {
70        let tuple = <NhwcTuple as Coordinates>::min(this.into_tuple(), other.into_tuple());
71        NhwcCoords::from_tuple(tuple)
72    }
73
74    fn max(this: Self, other: Self) -> Self {
75        let tuple = <NhwcTuple as Coordinates>::max(this.into_tuple(), other.into_tuple());
76        NhwcCoords::from_tuple(tuple)
77    }
78
79    fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
80        NhwcTuple::is_in_bounds(&pos.clone().into_tuple(), &bounds.clone().into_tuple())
81    }
82
83    fn from_int(this: &Self, #[comptime] value: i64) -> Self {
84        let tuple = NhwcTuple::from_int(&this.clone().into_tuple(), value);
85        NhwcCoords::from_tuple(tuple)
86    }
87}
88
89#[derive(EnumSetType, Debug, Hash)]
90pub enum NhwcCheck {
91    Batch,
92    Spatial,
93    Channel,
94}
95
96/// Layout for a spatial (i.e. NHWC) tensor. Bounds check only applies to spatial dimensions, not
97/// channel or batch (because these are implicitly checked in the layouts used with spatial tensors).
98#[derive(CubeType, Clone)]
99pub struct NhwcLayout {
100    /// Stride for N
101    pub stride_batch: usize,
102    /// Strides for DHW
103    pub strides_spatial: Sequence<usize>,
104    /// Stride for C
105    pub stride_channel: usize,
106
107    /// Shape of N
108    pub shape_batch: u32,
109    /// Shape of DHW
110    pub shapes_spatial: Sequence<u32>,
111    /// Shape of C
112    pub shape_channel: u32,
113
114    #[cube(comptime)]
115    pub vector_size: VectorSize,
116    #[cube(comptime)]
117    pub checks: EnumSet<NhwcCheck>,
118}
119
120#[cube]
121impl NhwcLayout {
122    pub fn new<E: Numeric, N: Size, IO: Clone>(
123        tensor: VirtualTensor<E, N, IO>,
124        #[comptime] dim: Dimensionality,
125        #[comptime] checks: EnumSet<NhwcCheck>,
126    ) -> Self {
127        let spatial_dims = dim.num_dims();
128        let mut strides_spatial = Sequence::new();
129        let mut shapes_spatial = Sequence::new();
130
131        #[unroll]
132        for i in 0..spatial_dims {
133            strides_spatial.push(tensor.stride(i + 1));
134            shapes_spatial.push(tensor.shape(i + 1) as u32);
135        }
136
137        let stride_batch = tensor.stride(0);
138        let stride_channel = tensor.stride(spatial_dims + 1);
139
140        let shape_batch = tensor.shape(0) as u32;
141        let shape_channel = tensor.shape(spatial_dims + 1) as u32;
142
143        NhwcLayout {
144            stride_batch,
145            strides_spatial,
146            stride_channel,
147            shape_batch,
148            shapes_spatial,
149            shape_channel,
150            vector_size: tensor.vector_size(),
151            checks,
152        }
153    }
154}
155
156#[cube]
157impl Layout for NhwcLayout {
158    type Coordinates = NhwcCoords;
159    type SourceCoordinates = Coords1d;
160
161    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
162        let NhwcCoords {
163            batch,
164            spatial,
165            channel,
166        } = pos;
167
168        let spatial_dims = self.shapes_spatial.len();
169        let mut read_pos =
170            batch as usize * self.stride_batch + channel as usize * self.stride_channel;
171
172        #[unroll]
173        for i in 0..spatial_dims {
174            read_pos += spatial[i] as usize * self.strides_spatial[i];
175        }
176
177        read_pos / self.vector_size
178    }
179
180    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
181        (self.to_source_pos(pos.clone()), self.is_in_bounds(pos))
182    }
183
184    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
185        let mut in_bounds = true.runtime();
186        if self.checks.comptime().contains(NhwcCheck::Batch) {
187            in_bounds &= pos.batch < self.shape_batch;
188        }
189        if self.checks.comptime().contains(NhwcCheck::Spatial) {
190            let spatial_dims = self.shapes_spatial.len();
191
192            #[unroll]
193            for i in 0..spatial_dims {
194                let pos = pos.spatial[i];
195                in_bounds &= pos >= 0 && (pos as u32) < self.shapes_spatial[i];
196            }
197        }
198        if self.checks.comptime().contains(NhwcCheck::Channel) {
199            in_bounds &= pos.channel < self.shape_channel;
200        }
201
202        in_bounds
203    }
204
205    fn shape(&self) -> Self::Coordinates {
206        NhwcCoords {
207            batch: self.shape_batch,
208            spatial: cast_seq(self.shapes_spatial.clone()),
209            channel: self.shape_channel,
210        }
211    }
212}
213
214#[cube]
215pub(crate) fn cast_seq<From: CubePrimitive, To: CubePrimitive>(
216    seq: Sequence<From>,
217) -> Sequence<To> {
218    let num_elems = seq.len();
219    let mut out_seq = Sequence::new();
220    #[unroll]
221    for i in 0..num_elems {
222        let elem = To::cast_from(seq[i]);
223        out_seq.push(elem);
224    }
225    out_seq
226}
227
228pub struct NhwcLayoutLaunch {
229    checks: EnumSet<NhwcCheck>,
230}
231
232impl NhwcLayoutLaunch {
233    pub fn checked(checks: EnumSet<NhwcCheck>) -> Self {
234        Self { checks }
235    }
236
237    pub fn unchecked() -> Self {
238        Self {
239            checks: EnumSet::empty(),
240        }
241    }
242}
243
244#[derive_cube_comptime]
245pub struct NhwcLayoutCompilationArg {
246    pub spatial_rank: usize,
247    pub checks: EnumSet<NhwcCheck>,
248}
249
250impl ViewLayoutLaunchArg for NhwcLayout {
251    type RuntimeArg<R: Runtime> = NhwcLayoutLaunch;
252    type CompilationArg = NhwcLayoutCompilationArg;
253
254    fn register<R: Runtime, B: BufferArg>(
255        arg: Self::RuntimeArg<R>,
256        buffer: &B,
257        _: Type,
258        launcher: &mut KernelLauncher<R>,
259    ) -> Self::CompilationArg {
260        let shape = buffer.shape();
261        let strides = buffer.strides();
262
263        let rank = shape.len();
264        let dim_c = rank - 1;
265
266        let stride_batch = strides[0];
267        let strides_spatial = strides[1..dim_c].iter().copied().collect();
268        let stride_channel = strides[dim_c];
269
270        let shape_batch = shape[0] as u32;
271        let shapes_spatial = shape[1..dim_c].iter().map(|s| *s as u32).collect();
272        let shape_channel = shape[dim_c] as u32;
273
274        <usize as LaunchArg>::register(stride_batch, launcher);
275        <Sequence<usize> as LaunchArg>::register(strides_spatial, launcher);
276        <usize as LaunchArg>::register(stride_channel, launcher);
277        <u32 as LaunchArg>::register(shape_batch, launcher);
278        <Sequence<u32> as LaunchArg>::register(shapes_spatial, launcher);
279        <u32 as LaunchArg>::register(shape_channel, launcher);
280
281        NhwcLayoutCompilationArg {
282            spatial_rank: buffer.shape().len() - 2,
283            checks: arg.checks,
284        }
285    }
286
287    fn expand(
288        arg: &Self::CompilationArg,
289        ty: Type,
290        builder: &mut KernelBuilder,
291    ) -> <Self as CubeType>::ExpandType {
292        let strides_comp_arg = (0..arg.spatial_rank).map(|_| ()).collect();
293        let shape_comp_arg = (0..arg.spatial_rank).map(|_| ()).collect();
294        NhwcLayoutExpand {
295            stride_batch: <usize as LaunchArg>::expand(&(), builder),
296            strides_spatial: <Sequence<usize> as LaunchArg>::expand(&strides_comp_arg, builder),
297            stride_channel: <usize as LaunchArg>::expand(&(), builder),
298            shape_batch: <u32 as LaunchArg>::expand(&(), builder),
299            shapes_spatial: <Sequence<u32> as LaunchArg>::expand(&shape_comp_arg, builder),
300            shape_channel: <u32 as LaunchArg>::expand(&(), builder),
301            vector_size: ty.vector_size(),
302            checks: arg.checks,
303        }
304    }
305}