cubek_convolution/components/global/layout/
spatial.rs

1use cubecl::prelude::*;
2use cubecl::std::tensor::{
3    layout::{
4        Coordinates, Coords1d, Layout, LayoutExpand,
5        as_dyn::{IntoDyn, IntoDynExpand},
6    },
7    r#virtual::VirtualTensor,
8};
9
10use crate::components::Dimensionality;
11
/// Coordinates into an NHWC-style tensor: one batch index, one index per spatial dimension and
/// one channel index.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct NhwcCoords {
    /// Index along the batch dimension.
    pub batch: u32,
    /// One index per spatial dimension. Signed, so positions below zero are representable;
    /// `NhwcLayout::is_in_bounds` rejects them when spatial checking is enabled.
    pub spatial: Sequence<i32>,
    /// Index along the channel dimension.
    pub channel: u32,
}
18
19#[cube]
20impl IntoDyn for NhwcCoords {
21    fn into_dyn(self) -> Sequence<i32> {
22        let mut seq = Sequence::new();
23        seq.push(self.batch as i32);
24        for x in self.spatial {
25            seq.push(x);
26        }
27        seq.push(self.channel as i32);
28        seq
29    }
30}
31
/// Tuple form of `NhwcCoords`, ordered `(batch, spatial, channel)`. The `Coordinates` operations
/// on `NhwcCoords` are implemented by converting to this tuple and delegating to its impl.
type NhwcTuple = (u32, Sequence<i32>, u32);
33
34#[cube]
35impl NhwcCoords {
36    pub fn new(batch: u32, spatial: Sequence<i32>, channel: u32) -> Self {
37        NhwcCoords {
38            batch,
39            spatial,
40            channel,
41        }
42    }
43
44    fn into_tuple(self) -> NhwcTuple {
45        (self.batch, self.spatial, self.channel)
46    }
47
48    fn from_tuple(tuple: NhwcTuple) -> Self {
49        NhwcCoords::new(tuple.0, tuple.1, tuple.2)
50    }
51}
52
53#[cube]
54impl Coordinates for NhwcCoords {
55    fn add(this: Self, other: Self) -> Self {
56        let tuple = NhwcTuple::add(this.into_tuple(), other.into_tuple());
57        NhwcCoords::from_tuple(tuple)
58    }
59
60    fn sub(this: Self, other: Self) -> Self {
61        let tuple = NhwcTuple::sub(this.into_tuple(), other.into_tuple());
62        NhwcCoords::from_tuple(tuple)
63    }
64
65    fn min(this: Self, other: Self) -> Self {
66        let tuple = <NhwcTuple as Coordinates>::min(this.into_tuple(), other.into_tuple());
67        NhwcCoords::from_tuple(tuple)
68    }
69
70    fn max(this: Self, other: Self) -> Self {
71        let tuple = <NhwcTuple as Coordinates>::max(this.into_tuple(), other.into_tuple());
72        NhwcCoords::from_tuple(tuple)
73    }
74
75    fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
76        NhwcTuple::is_in_bounds(&pos.clone().into_tuple(), &bounds.clone().into_tuple())
77    }
78
79    fn from_int(this: &Self, #[comptime] value: i64) -> Self {
80        let tuple = NhwcTuple::from_int(&this.clone().into_tuple(), value);
81        NhwcCoords::from_tuple(tuple)
82    }
83}
84
/// Layout for a spatial (i.e. NHWC) tensor. Bounds checks are selected at compile time:
/// `check_spatial` enables the check on the spatial dimensions and `check_channel` enables the
/// check on the channel dimension. The batch dimension is never checked by this layout (it is
/// expected to be implicitly checked by the layouts used with spatial tensors).
#[derive(CubeType, CubeLaunch, Clone)]
pub struct NhwcLayout {
    /// Stride for N
    pub stride_batch: u32,
    /// Strides for DHW, one entry per spatial dimension
    pub strides_spatial: Sequence<u32>,
    /// Stride for C
    pub stride_channel: u32,

    /// Shape of N
    pub shape_batch: u32,
    /// Shape of DHW, one entry per spatial dimension
    pub shapes_spatial: Sequence<u32>,
    /// Shape of C
    pub shape_channel: u32,

    /// Line size of the tensor; `to_source_pos` divides the computed offset by this value.
    #[cube(comptime)]
    pub line_size: u32,
    /// Whether `is_in_bounds` checks each spatial coordinate against `shapes_spatial`.
    #[cube(comptime)]
    pub check_spatial: bool,
    /// Whether `is_in_bounds` checks the channel coordinate against `shape_channel`.
    #[cube(comptime)]
    pub check_channel: bool,
}
110
111#[cube]
112impl NhwcLayout {
113    pub fn new<E: Numeric, IO: Clone>(
114        tensor: VirtualTensor<E, IO>,
115        #[comptime] dim: Dimensionality,
116        #[comptime] check_spatial: bool,
117        #[comptime] check_channel: bool,
118    ) -> Self {
119        let spatial_dims = comptime![dim.num_dims()];
120        let mut strides_spatial = Sequence::new();
121        let mut shapes_spatial = Sequence::new();
122
123        #[unroll]
124        for i in 0..spatial_dims {
125            strides_spatial.push(tensor.stride(i + 1));
126            shapes_spatial.push(tensor.shape(i + 1));
127        }
128
129        let stride_batch = tensor.stride(0);
130        let stride_channel = tensor.stride(spatial_dims + 1);
131
132        let shape_batch = tensor.shape(0);
133        let shape_channel = tensor.shape(spatial_dims + 1);
134
135        NhwcLayout {
136            stride_batch,
137            strides_spatial,
138            stride_channel,
139            shape_batch,
140            shapes_spatial,
141            shape_channel,
142            line_size: tensor.line_size(),
143            check_spatial,
144            check_channel,
145        }
146    }
147}
148
149#[cube]
150impl Layout for NhwcLayout {
151    type Coordinates = NhwcCoords;
152    type SourceCoordinates = Coords1d;
153
154    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
155        let NhwcCoords {
156            batch,
157            spatial,
158            channel,
159        } = pos;
160
161        let spatial_dims = self.shapes_spatial.len();
162        let mut read_pos = batch * self.stride_batch + channel * self.stride_channel;
163
164        #[unroll]
165        for i in 0..spatial_dims {
166            read_pos += *spatial.index(i) as u32 * *self.strides_spatial.index(i);
167        }
168
169        read_pos / self.line_size
170    }
171
172    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
173        (self.to_source_pos(pos.clone()), self.is_in_bounds(pos))
174    }
175
176    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
177        let mut in_bounds = true.runtime();
178        if comptime![self.check_spatial] {
179            let spatial_dims = self.shapes_spatial.len();
180
181            #[unroll]
182            for i in 0..spatial_dims {
183                let pos = *pos.spatial.index(i);
184                in_bounds &= pos >= 0 && (pos as u32) < *self.shapes_spatial.index(i);
185            }
186        }
187        if comptime![self.check_channel] {
188            in_bounds &= pos.channel < self.shape_channel;
189        }
190
191        in_bounds
192    }
193
194    fn shape(&self) -> Self::Coordinates {
195        NhwcCoords {
196            batch: self.shape_batch,
197            spatial: cast_seq(self.shapes_spatial.clone()),
198            channel: self.shape_channel,
199        }
200    }
201}
202
203#[cube]
204pub(crate) fn cast_seq<From: CubePrimitive, To: CubePrimitive>(
205    seq: Sequence<From>,
206) -> Sequence<To> {
207    let num_elems = seq.len();
208    let mut out_seq = Sequence::new();
209    #[unroll]
210    for i in 0..num_elems {
211        let elem = To::cast_from(*seq.index(i));
212        out_seq.push(elem);
213    }
214    out_seq
215}
216
217impl<'a, R: Runtime> NhwcLayoutLaunch<'a, R> {
218    pub fn from_handle(
219        handle: &TensorHandleRef<'a, R>,
220        line_size: u32,
221        check_spatial: bool,
222        check_channel: bool,
223    ) -> Self {
224        let rank = handle.shape.len();
225        let dim_c = rank - 1;
226
227        let stride_batch = ScalarArg::new(handle.strides[0] as u32);
228        let strides_spatial = handle.strides[1..dim_c]
229            .iter()
230            .map(|s| ScalarArg::new(*s as u32))
231            .collect();
232        let stride_channel = ScalarArg::new(handle.strides[dim_c] as u32);
233
234        let shape_batch = ScalarArg::new(handle.shape[0] as u32);
235        let shapes_spatial = handle.shape[1..dim_c]
236            .iter()
237            .map(|s| ScalarArg::new(*s as u32))
238            .collect();
239        let shape_channel = ScalarArg::new(handle.shape[dim_c] as u32);
240
241        Self::new(
242            stride_batch,
243            strides_spatial,
244            stride_channel,
245            shape_batch,
246            shapes_spatial,
247            shape_channel,
248            line_size,
249            check_spatial,
250            check_channel,
251        )
252    }
253}