cubecl_convolution/components/global/layout/
spatial.rs

1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3use cubecl_std::tensor::{
4    layout::{
5        Coordinates, Coords1d, Layout, LayoutExpand,
6        as_dyn::{IntoDyn, IntoDynExpand},
7    },
8    r#virtual::VirtualTensor,
9};
10
11use crate::components::Dimensionality;
12
13#[derive(CubeType, CubeLaunch, Clone)]
14pub struct NhwcCoords {
15    pub batch: u32,
16    pub spatial: Sequence<i32>,
17    pub channel: u32,
18}
19
20#[cube]
21impl IntoDyn for NhwcCoords {
22    fn into_dyn(self) -> Sequence<i32> {
23        let mut seq = Sequence::new();
24        seq.push(self.batch as i32);
25        for x in self.spatial {
26            seq.push(x);
27        }
28        seq.push(self.channel as i32);
29        seq
30    }
31}
32
33type NhwcTuple = (u32, Sequence<i32>, u32);
34
35#[cube]
36impl NhwcCoords {
37    pub fn new(batch: u32, spatial: Sequence<i32>, channel: u32) -> Self {
38        NhwcCoords {
39            batch,
40            spatial,
41            channel,
42        }
43    }
44
45    fn into_tuple(self) -> NhwcTuple {
46        (self.batch, self.spatial, self.channel)
47    }
48
49    fn from_tuple(tuple: NhwcTuple) -> Self {
50        NhwcCoords::new(tuple.0, tuple.1, tuple.2)
51    }
52}
53
54#[cube]
55impl Coordinates for NhwcCoords {
56    fn add(this: Self, other: Self) -> Self {
57        let tuple = NhwcTuple::add(this.into_tuple(), other.into_tuple());
58        NhwcCoords::from_tuple(tuple)
59    }
60
61    fn sub(this: Self, other: Self) -> Self {
62        let tuple = NhwcTuple::sub(this.into_tuple(), other.into_tuple());
63        NhwcCoords::from_tuple(tuple)
64    }
65
66    fn min(this: Self, other: Self) -> Self {
67        let tuple = <NhwcTuple as Coordinates>::min(this.into_tuple(), other.into_tuple());
68        NhwcCoords::from_tuple(tuple)
69    }
70
71    fn max(this: Self, other: Self) -> Self {
72        let tuple = <NhwcTuple as Coordinates>::max(this.into_tuple(), other.into_tuple());
73        NhwcCoords::from_tuple(tuple)
74    }
75
76    fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
77        NhwcTuple::is_in_bounds(&pos.clone().into_tuple(), &bounds.clone().into_tuple())
78    }
79
80    fn from_int(this: &Self, #[comptime] value: i64) -> Self {
81        let tuple = NhwcTuple::from_int(&this.clone().into_tuple(), value);
82        NhwcCoords::from_tuple(tuple)
83    }
84}
85
86/// Layout for a spatial (i.e. NHWC) tensor. Bounds check only applies to spatial dimensions, not
87/// channel or batch (because these are implicitly checked in the layouts used with spatial tensors).
88#[derive(CubeType, CubeLaunch, Clone)]
89pub struct NhwcLayout {
90    /// Stride for N
91    pub stride_batch: u32,
92    /// Strides for DHW
93    pub strides_spatial: Sequence<u32>,
94    /// Stride for C
95    pub stride_channel: u32,
96
97    /// Shape of N
98    pub shape_batch: u32,
99    /// Shape of DHW
100    pub shapes_spatial: Sequence<u32>,
101    /// Shape of C
102    pub shape_channel: u32,
103
104    #[cube(comptime)]
105    pub line_size: u32,
106    #[cube(comptime)]
107    pub check_spatial: bool,
108    #[cube(comptime)]
109    pub check_channel: bool,
110}
111
112#[cube]
113impl NhwcLayout {
114    pub fn new<E: Numeric, IO: Clone>(
115        tensor: VirtualTensor<E, IO>,
116        #[comptime] dim: Dimensionality,
117        #[comptime] check_spatial: bool,
118        #[comptime] check_channel: bool,
119    ) -> Self {
120        let spatial_dims = comptime![dim.num_dims()];
121        let mut strides_spatial = Sequence::new();
122        let mut shapes_spatial = Sequence::new();
123
124        #[unroll]
125        for i in 0..spatial_dims {
126            strides_spatial.push(tensor.stride(i + 1));
127            shapes_spatial.push(tensor.shape(i + 1));
128        }
129
130        let stride_batch = tensor.stride(0);
131        let stride_channel = tensor.stride(spatial_dims + 1);
132
133        let shape_batch = tensor.shape(0);
134        let shape_channel = tensor.shape(spatial_dims + 1);
135
136        NhwcLayout {
137            stride_batch,
138            strides_spatial,
139            stride_channel,
140            shape_batch,
141            shapes_spatial,
142            shape_channel,
143            line_size: tensor.line_size(),
144            check_spatial,
145            check_channel,
146        }
147    }
148}
149
150#[cube]
151impl Layout for NhwcLayout {
152    type Coordinates = NhwcCoords;
153    type SourceCoordinates = Coords1d;
154
155    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
156        let NhwcCoords {
157            batch,
158            spatial,
159            channel,
160        } = pos;
161
162        let spatial_dims = self.shapes_spatial.len();
163        let mut read_pos = batch * self.stride_batch + channel * self.stride_channel;
164
165        #[unroll]
166        for i in 0..spatial_dims {
167            read_pos += *spatial.index(i) as u32 * *self.strides_spatial.index(i);
168        }
169
170        read_pos / self.line_size
171    }
172
173    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
174        (self.to_source_pos(pos.clone()), self.is_in_bounds(pos))
175    }
176
177    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
178        let mut in_bounds = true.runtime();
179        if comptime![self.check_spatial] {
180            let spatial_dims = self.shapes_spatial.len();
181
182            #[unroll]
183            for i in 0..spatial_dims {
184                let pos = *pos.spatial.index(i);
185                in_bounds &= pos >= 0 && (pos as u32) < *self.shapes_spatial.index(i);
186            }
187        }
188        if comptime![self.check_channel] {
189            in_bounds &= pos.channel < self.shape_channel;
190        }
191
192        in_bounds
193    }
194
195    fn shape(&self) -> Self::Coordinates {
196        NhwcCoords {
197            batch: self.shape_batch,
198            spatial: cast_seq(self.shapes_spatial.clone()),
199            channel: self.shape_channel,
200        }
201    }
202}
203
204#[cube]
205pub(crate) fn cast_seq<From: CubePrimitive, To: CubePrimitive>(
206    seq: Sequence<From>,
207) -> Sequence<To> {
208    let num_elems = seq.len();
209    let mut out_seq = Sequence::new();
210    #[unroll]
211    for i in 0..num_elems {
212        let elem = To::cast_from(*seq.index(i));
213        out_seq.push(elem);
214    }
215    out_seq
216}
217
218impl<'a, R: Runtime> NhwcLayoutLaunch<'a, R> {
219    pub fn from_handle(
220        handle: &TensorHandleRef<'a, R>,
221        line_size: u32,
222        check_spatial: bool,
223        check_channel: bool,
224    ) -> Self {
225        let rank = handle.shape.len();
226        let dim_c = rank - 1;
227
228        let stride_batch = ScalarArg::new(handle.strides[0] as u32);
229        let strides_spatial = handle.strides[1..dim_c]
230            .iter()
231            .map(|s| ScalarArg::new(*s as u32))
232            .collect();
233        let stride_channel = ScalarArg::new(handle.strides[dim_c] as u32);
234
235        let shape_batch = ScalarArg::new(handle.shape[0] as u32);
236        let shapes_spatial = handle.shape[1..dim_c]
237            .iter()
238            .map(|s| ScalarArg::new(*s as u32))
239            .collect();
240        let shape_channel = ScalarArg::new(handle.shape[dim_c] as u32);
241
242        Self::new(
243            stride_batch,
244            strides_spatial,
245            stride_channel,
246            shape_batch,
247            shapes_spatial,
248            shape_channel,
249            line_size,
250            check_spatial,
251            check_channel,
252        )
253    }
254}