cubek_convolution/components/global/layout/
spatial.rs

1use cubecl::prelude::*;
2use cubecl::std::tensor::{
3    layout::{
4        Coordinates, Coords1d, Layout, LayoutExpand,
5        as_dyn::{IntoDyn, IntoDynExpand},
6    },
7    r#virtual::VirtualTensor,
8};
9use enumset::{EnumSet, EnumSetType};
10
11use crate::components::Dimensionality;
12
13#[derive(CubeType, CubeLaunch, Clone)]
14pub struct NhwcCoords {
15    pub batch: u32,
16    pub spatial: Sequence<i32>,
17    pub channel: u32,
18}
19
20#[cube]
21impl IntoDyn for NhwcCoords {
22    fn into_dyn(self) -> Sequence<i32> {
23        let mut seq = Sequence::new();
24        seq.push(self.batch as i32);
25        for x in self.spatial {
26            seq.push(x);
27        }
28        seq.push(self.channel as i32);
29        seq
30    }
31}
32
33type NhwcTuple = (u32, Sequence<i32>, u32);
34
35#[cube]
36impl NhwcCoords {
37    pub fn new(batch: u32, spatial: Sequence<i32>, channel: u32) -> Self {
38        NhwcCoords {
39            batch,
40            spatial,
41            channel,
42        }
43    }
44
45    fn into_tuple(self) -> NhwcTuple {
46        (self.batch, self.spatial, self.channel)
47    }
48
49    fn from_tuple(tuple: NhwcTuple) -> Self {
50        NhwcCoords::new(tuple.0, tuple.1, tuple.2)
51    }
52}
53
54#[cube]
55impl Coordinates for NhwcCoords {
56    fn add(this: Self, other: Self) -> Self {
57        let tuple = NhwcTuple::add(this.into_tuple(), other.into_tuple());
58        NhwcCoords::from_tuple(tuple)
59    }
60
61    fn sub(this: Self, other: Self) -> Self {
62        let tuple = NhwcTuple::sub(this.into_tuple(), other.into_tuple());
63        NhwcCoords::from_tuple(tuple)
64    }
65
66    fn min(this: Self, other: Self) -> Self {
67        let tuple = <NhwcTuple as Coordinates>::min(this.into_tuple(), other.into_tuple());
68        NhwcCoords::from_tuple(tuple)
69    }
70
71    fn max(this: Self, other: Self) -> Self {
72        let tuple = <NhwcTuple as Coordinates>::max(this.into_tuple(), other.into_tuple());
73        NhwcCoords::from_tuple(tuple)
74    }
75
76    fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
77        NhwcTuple::is_in_bounds(&pos.clone().into_tuple(), &bounds.clone().into_tuple())
78    }
79
80    fn from_int(this: &Self, #[comptime] value: i64) -> Self {
81        let tuple = NhwcTuple::from_int(&this.clone().into_tuple(), value);
82        NhwcCoords::from_tuple(tuple)
83    }
84}
85
86#[derive(EnumSetType, Debug, Hash)]
87pub enum NhwcCheck {
88    Batch,
89    Spatial,
90    Channel,
91}
92
93/// Layout for a spatial (i.e. NHWC) tensor. Bounds check only applies to spatial dimensions, not
94/// channel or batch (because these are implicitly checked in the layouts used with spatial tensors).
95#[derive(CubeType, CubeLaunch, Clone)]
96pub struct NhwcLayout {
97    /// Stride for N
98    pub stride_batch: usize,
99    /// Strides for DHW
100    pub strides_spatial: Sequence<usize>,
101    /// Stride for C
102    pub stride_channel: usize,
103
104    /// Shape of N
105    pub shape_batch: u32,
106    /// Shape of DHW
107    pub shapes_spatial: Sequence<u32>,
108    /// Shape of C
109    pub shape_channel: u32,
110
111    #[cube(comptime)]
112    pub line_size: LineSize,
113    #[cube(comptime)]
114    pub checks: EnumSet<NhwcCheck>,
115}
116
117#[cube]
118impl NhwcLayout {
119    pub fn new<E: Numeric, IO: Clone>(
120        tensor: VirtualTensor<E, IO>,
121        #[comptime] dim: Dimensionality,
122        #[comptime] checks: EnumSet<NhwcCheck>,
123    ) -> Self {
124        let spatial_dims = dim.num_dims();
125        let mut strides_spatial = Sequence::new();
126        let mut shapes_spatial = Sequence::new();
127
128        #[unroll]
129        for i in 0..spatial_dims {
130            strides_spatial.push(tensor.stride(i + 1));
131            shapes_spatial.push(tensor.shape(i + 1) as u32);
132        }
133
134        let stride_batch = tensor.stride(0);
135        let stride_channel = tensor.stride(spatial_dims + 1);
136
137        let shape_batch = tensor.shape(0) as u32;
138        let shape_channel = tensor.shape(spatial_dims + 1) as u32;
139
140        NhwcLayout {
141            stride_batch,
142            strides_spatial,
143            stride_channel,
144            shape_batch,
145            shapes_spatial,
146            shape_channel,
147            line_size: tensor.line_size(),
148            checks,
149        }
150    }
151}
152
153#[cube]
154impl Layout for NhwcLayout {
155    type Coordinates = NhwcCoords;
156    type SourceCoordinates = Coords1d;
157
158    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
159        let NhwcCoords {
160            batch,
161            spatial,
162            channel,
163        } = pos;
164
165        let spatial_dims = self.shapes_spatial.len();
166        let mut read_pos =
167            batch as usize * self.stride_batch + channel as usize * self.stride_channel;
168
169        #[unroll]
170        for i in 0..spatial_dims {
171            read_pos += spatial[i] as usize * self.strides_spatial[i];
172        }
173
174        read_pos / self.line_size
175    }
176
177    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
178        (self.to_source_pos(pos.clone()), self.is_in_bounds(pos))
179    }
180
181    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
182        let mut in_bounds = true.runtime();
183        if self.checks.comptime().contains(NhwcCheck::Batch) {
184            in_bounds &= pos.batch < self.shape_batch;
185        }
186        if self.checks.comptime().contains(NhwcCheck::Spatial) {
187            let spatial_dims = self.shapes_spatial.len();
188
189            #[unroll]
190            for i in 0..spatial_dims {
191                let pos = pos.spatial[i];
192                in_bounds &= pos >= 0 && (pos as u32) < self.shapes_spatial[i];
193            }
194        }
195        if self.checks.comptime().contains(NhwcCheck::Channel) {
196            in_bounds &= pos.channel < self.shape_channel;
197        }
198
199        in_bounds
200    }
201
202    fn shape(&self) -> Self::Coordinates {
203        NhwcCoords {
204            batch: self.shape_batch,
205            spatial: cast_seq(self.shapes_spatial.clone()),
206            channel: self.shape_channel,
207        }
208    }
209}
210
211#[cube]
212pub(crate) fn cast_seq<From: CubePrimitive, To: CubePrimitive>(
213    seq: Sequence<From>,
214) -> Sequence<To> {
215    let num_elems = seq.len();
216    let mut out_seq = Sequence::new();
217    #[unroll]
218    for i in 0..num_elems {
219        let elem = To::cast_from(seq[i]);
220        out_seq.push(elem);
221    }
222    out_seq
223}
224
225impl<'a, R: Runtime> NhwcLayoutLaunch<'a, R> {
226    pub fn from_handle(
227        handle: &TensorHandleRef<'a, R>,
228        line_size: LineSize,
229        checks: EnumSet<NhwcCheck>,
230    ) -> Self {
231        let rank = handle.shape.len();
232        let dim_c = rank - 1;
233
234        let stride_batch = ScalarArg::new(handle.strides[0]);
235        let strides_spatial = handle.strides[1..dim_c]
236            .iter()
237            .map(|s| ScalarArg::new(*s))
238            .collect();
239        let stride_channel = ScalarArg::new(handle.strides[dim_c]);
240
241        let shape_batch = ScalarArg::new(handle.shape[0] as u32);
242        let shapes_spatial = handle.shape[1..dim_c]
243            .iter()
244            .map(|s| ScalarArg::new(*s as u32))
245            .collect();
246        let shape_channel = ScalarArg::new(handle.shape[dim_c] as u32);
247
248        Self::new(
249            stride_batch,
250            strides_spatial,
251            stride_channel,
252            shape_batch,
253            shapes_spatial,
254            shape_channel,
255            line_size,
256            checks,
257        )
258    }
259}