cubecl_convolution/components/global/layout/
spatial.rs

1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3use cubecl_std::tensor::{
4    layout::{Coordinates, Coords1d, Layout, LayoutExpand},
5    r#virtual::VirtualTensor,
6};
7
8use crate::components::Dimensionality;
9
/// Coordinates into an NHWC-style tensor: one batch index, one signed index per spatial
/// dimension, and one channel index.
#[derive(CubeType, Clone)]
pub struct NhwcCoords {
    /// Index along the batch dimension.
    pub batch: u32,
    /// One index per spatial dimension. The indices are signed, so negative positions can be
    /// represented (presumably to express padding around the tensor; confirm against callers).
    pub spatial: Sequence<i32>,
    /// Index along the channel dimension.
    pub channel: u32,
}
16
/// Tuple form of [`NhwcCoords`], in `(batch, spatial, channel)` order.
type NhwcTuple = (u32, Sequence<i32>, u32);
18
#[cube]
impl NhwcCoords {
    /// Builds coordinates from a batch index, per-dimension spatial indices and a channel index.
    pub fn new(batch: u32, spatial: Sequence<i32>, channel: u32) -> Self {
        NhwcCoords {
            batch,
            spatial,
            channel,
        }
    }

    /// Consumes the coordinates and returns them as a `(batch, spatial, channel)` tuple.
    fn into_tuple(self) -> NhwcTuple {
        (self.batch, self.spatial, self.channel)
    }

    /// Rebuilds coordinates from a `(batch, spatial, channel)` tuple.
    fn from_tuple(tuple: NhwcTuple) -> Self {
        NhwcCoords::new(tuple.0, tuple.1, tuple.2)
    }
}
37
38impl NhwcCoordsExpand {
39    pub fn __expand_clone_method(&self, _scope: &mut Scope) -> Self {
40        NhwcCoordsExpand {
41            batch: self.batch.clone(),
42            spatial: self.spatial.clone(),
43            channel: self.channel.clone(),
44        }
45    }
46}
47
/// `Coordinates` implementation for [`NhwcCoords`].
///
/// Every operation converts its operands to [`NhwcTuple`], delegates to the tuple's own
/// `Coordinates` implementation, and (where a coordinate is returned) converts the result back.
#[cube]
impl Coordinates for NhwcCoords {
    /// Adds two coordinates by delegating to the tuple implementation.
    fn add(this: Self, other: Self) -> Self {
        let tuple = NhwcTuple::add(this.into_tuple(), other.into_tuple());
        NhwcCoords::from_tuple(tuple)
    }

    /// Subtracts `other` from `this` by delegating to the tuple implementation.
    fn sub(this: Self, other: Self) -> Self {
        let tuple = NhwcTuple::sub(this.into_tuple(), other.into_tuple());
        NhwcCoords::from_tuple(tuple)
    }

    /// Minimum of two coordinates, as defined by the tuple implementation.
    fn min(this: Self, other: Self) -> Self {
        let tuple = <NhwcTuple as Coordinates>::min(this.into_tuple(), other.into_tuple());
        NhwcCoords::from_tuple(tuple)
    }

    /// Maximum of two coordinates, as defined by the tuple implementation.
    fn max(this: Self, other: Self) -> Self {
        let tuple = <NhwcTuple as Coordinates>::max(this.into_tuple(), other.into_tuple());
        NhwcCoords::from_tuple(tuple)
    }

    /// Returns whether `pos` is within `bounds`, as defined by the tuple implementation.
    fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
        // `into_tuple` consumes its receiver, so the borrowed operands are cloned first.
        NhwcTuple::is_in_bounds(&pos.clone().into_tuple(), &bounds.clone().into_tuple())
    }

    /// Builds coordinates from the comptime integer `value` via the tuple implementation,
    /// passing the tuple form of `this` along with it.
    fn from_int(this: &Self, #[comptime] value: i64) -> Self {
        // `into_tuple` consumes its receiver, so the borrowed operand is cloned first.
        let tuple = NhwcTuple::from_int(&this.clone().into_tuple(), value);
        NhwcCoords::from_tuple(tuple)
    }
}
79
/// Layout for a spatial (i.e. NHWC) tensor. Bounds check only applies to spatial dimensions, not
/// channel or batch (because these are implicitly checked in the layouts used with spatial tensors).
#[derive(CubeType, CubeLaunch, Clone)]
pub struct NhwcLayout {
    /// Stride for N
    pub stride_batch: u32,
    /// Strides for DHW, one entry per spatial dimension
    pub strides_spatial: Sequence<u32>,
    /// Stride for C
    pub stride_channel: u32,

    /// Shape of N
    pub shape_batch: u32,
    /// Shape of DHW, one entry per spatial dimension
    pub shapes_spatial: Sequence<u32>,
    /// Shape of C
    pub shape_channel: u32,

    /// Comptime line size; the offset computed by `to_source_pos` is divided by this value.
    #[cube(comptime)]
    pub line_size: u32,
    /// Comptime switch for the spatial bounds check. When `false`, `is_in_bounds` always
    /// reports `true`.
    #[cube(comptime)]
    pub check_spatial: bool,
}
103
#[cube]
impl NhwcLayout {
    /// Builds a layout from the strides and shapes of `tensor`.
    ///
    /// Dimension `0` of the tensor is read as the batch, dimensions `1..=dim.num_dims()` as the
    /// spatial dimensions, and dimension `dim.num_dims() + 1` as the channel. The line size is
    /// taken from `tensor.line_size()`, and `check_spatial` is stored unchanged.
    pub fn new<E: Numeric, IO: Clone>(
        tensor: VirtualTensor<E, IO>,
        #[comptime] dim: Dimensionality,
        #[comptime] check_spatial: bool,
    ) -> Self {
        let spatial_dims = comptime![dim.num_dims()];
        let mut strides_spatial = Sequence::new();
        let mut shapes_spatial = Sequence::new();

        #[unroll]
        for i in 0..spatial_dims {
            // Spatial dimensions start right after the batch dimension, hence the `+ 1`.
            strides_spatial.push(tensor.stride(i + 1));
            shapes_spatial.push(tensor.shape(i + 1));
        }

        let stride_batch = tensor.stride(0);
        // The channel dimension follows the last spatial dimension.
        let stride_channel = tensor.stride(spatial_dims + 1);

        let shape_batch = tensor.shape(0);
        let shape_channel = tensor.shape(spatial_dims + 1);

        NhwcLayout {
            stride_batch,
            strides_spatial,
            stride_channel,
            shape_batch,
            shapes_spatial,
            shape_channel,
            line_size: tensor.line_size(),
            check_spatial,
        }
    }
}
139
/// Maps [`NhwcCoords`] to a 1D source position using the strides stored in the layout.
#[cube]
impl Layout for NhwcLayout {
    type Coordinates = NhwcCoords;
    type SourceCoordinates = Coords1d;

    /// Converts NHWC coordinates into a 1D source position.
    ///
    /// The result is the sum of every coordinate multiplied by its stride, divided by
    /// `line_size`. No bounds check is performed here.
    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        let NhwcCoords {
            batch,
            spatial,
            channel,
        } = pos;

        let spatial_dims = self.shapes_spatial.len();
        let mut read_pos = batch * self.stride_batch + channel * self.stride_channel;

        #[unroll]
        for i in 0..spatial_dims {
            // Spatial coordinates are signed and are cast to `u32` without a range check.
            // NOTE(review): negative coordinates presumably rely on the caller consulting
            // `is_in_bounds` before using the resulting position; confirm.
            read_pos += *spatial.index(i) as u32 * *self.strides_spatial.index(i);
        }

        // Scale the accumulated offset down by the comptime line size.
        read_pos / self.line_size
    }

    /// Returns the source position together with the result of `is_in_bounds` for `pos`.
    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        (self.to_source_pos(pos.clone()), self.is_in_bounds(pos))
    }

    /// Checks the spatial coordinates of `pos` against the spatial shapes.
    ///
    /// When `check_spatial` is enabled, the result is `true` only if every spatial coordinate
    /// is non-negative and smaller than the matching spatial shape. Batch and channel are not
    /// inspected. When `check_spatial` is disabled, the result is always `true`.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        if comptime![self.check_spatial] {
            let spatial_dims = self.shapes_spatial.len();
            let mut spatial_in_bounds = true;

            #[unroll]
            for i in 0..spatial_dims {
                let pos = *pos.spatial.index(i);
                // The sign is tested before the unsigned comparison against the shape.
                spatial_in_bounds &= pos >= 0 && (pos as u32) < *self.shapes_spatial.index(i);
            }

            spatial_in_bounds
        } else {
            // Presumably converts the constant into a runtime value so both branches yield the
            // same kind of `bool`; confirm against the CubeCL docs.
            true.runtime()
        }
    }

    /// Returns the tensor shape as coordinates: the batch shape, the spatial shapes converted
    /// to the signed element type used by `NhwcCoords::spatial`, and the channel shape.
    fn shape(&self) -> Self::Coordinates {
        NhwcCoords {
            batch: self.shape_batch,
            spatial: cast_seq(self.shapes_spatial.clone()),
            channel: self.shape_channel,
        }
    }
}
192
/// Converts a sequence element by element with `To::cast_from`, returning a new sequence with
/// the same length and element order.
#[cube]
pub(crate) fn cast_seq<From: CubePrimitive, To: CubePrimitive>(
    seq: Sequence<From>,
) -> Sequence<To> {
    let num_elems = seq.len();
    let mut out_seq = Sequence::new();
    #[unroll]
    for i in 0..num_elems {
        let elem = To::cast_from(*seq.index(i));
        out_seq.push(elem);
    }
    out_seq
}
206
207impl<'a, R: Runtime> NhwcLayoutLaunch<'a, R> {
208    pub fn from_handle(
209        handle: &TensorHandleRef<'a, R>,
210        line_size: u32,
211        check_spatial: bool,
212    ) -> Self {
213        let rank = handle.shape.len();
214        let dim_c = rank - 1;
215
216        let stride_batch = ScalarArg::new(handle.strides[0] as u32);
217        let strides_spatial = handle.strides[1..dim_c]
218            .iter()
219            .map(|s| ScalarArg::new(*s as u32))
220            .collect();
221        let stride_channel = ScalarArg::new(handle.strides[dim_c] as u32);
222
223        let shape_batch = ScalarArg::new(handle.shape[0] as u32);
224        let shapes_spatial = handle.shape[1..dim_c]
225            .iter()
226            .map(|s| ScalarArg::new(*s as u32))
227            .collect();
228        let shape_channel = ScalarArg::new(handle.shape[dim_c] as u32);
229
230        Self::new(
231            stride_batch,
232            strides_spatial,
233            stride_channel,
234            shape_batch,
235            shapes_spatial,
236            shape_channel,
237            line_size,
238            check_spatial,
239        )
240    }
241}