cubek_convolution/components/global/layout/
spatial.rs1use cubecl::prelude::*;
2use cubecl::std::tensor::{
3 layout::{
4 Coordinates, Coords1d, Layout, LayoutExpand,
5 as_dyn::{IntoDyn, IntoDynExpand},
6 },
7 r#virtual::VirtualTensor,
8};
9use enumset::{EnumSet, EnumSetType};
10
11use crate::components::Dimensionality;
12
13#[derive(CubeType, CubeLaunch, Clone)]
14pub struct NhwcCoords {
15 pub batch: u32,
16 pub spatial: Sequence<i32>,
17 pub channel: u32,
18}
19
20#[cube]
21impl IntoDyn for NhwcCoords {
22 fn into_dyn(self) -> Sequence<i32> {
23 let mut seq = Sequence::new();
24 seq.push(self.batch as i32);
25 for x in self.spatial {
26 seq.push(x);
27 }
28 seq.push(self.channel as i32);
29 seq
30 }
31}
32
33type NhwcTuple = (u32, Sequence<i32>, u32);
34
35#[cube]
36impl NhwcCoords {
37 pub fn new(batch: u32, spatial: Sequence<i32>, channel: u32) -> Self {
38 NhwcCoords {
39 batch,
40 spatial,
41 channel,
42 }
43 }
44
45 fn into_tuple(self) -> NhwcTuple {
46 (self.batch, self.spatial, self.channel)
47 }
48
49 fn from_tuple(tuple: NhwcTuple) -> Self {
50 NhwcCoords::new(tuple.0, tuple.1, tuple.2)
51 }
52}
53
54#[cube]
55impl Coordinates for NhwcCoords {
56 fn add(this: Self, other: Self) -> Self {
57 let tuple = NhwcTuple::add(this.into_tuple(), other.into_tuple());
58 NhwcCoords::from_tuple(tuple)
59 }
60
61 fn sub(this: Self, other: Self) -> Self {
62 let tuple = NhwcTuple::sub(this.into_tuple(), other.into_tuple());
63 NhwcCoords::from_tuple(tuple)
64 }
65
66 fn min(this: Self, other: Self) -> Self {
67 let tuple = <NhwcTuple as Coordinates>::min(this.into_tuple(), other.into_tuple());
68 NhwcCoords::from_tuple(tuple)
69 }
70
71 fn max(this: Self, other: Self) -> Self {
72 let tuple = <NhwcTuple as Coordinates>::max(this.into_tuple(), other.into_tuple());
73 NhwcCoords::from_tuple(tuple)
74 }
75
76 fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
77 NhwcTuple::is_in_bounds(&pos.clone().into_tuple(), &bounds.clone().into_tuple())
78 }
79
80 fn from_int(this: &Self, #[comptime] value: i64) -> Self {
81 let tuple = NhwcTuple::from_int(&this.clone().into_tuple(), value);
82 NhwcCoords::from_tuple(tuple)
83 }
84}
85
86#[derive(EnumSetType, Debug, Hash)]
87pub enum NhwcCheck {
88 Batch,
89 Spatial,
90 Channel,
91}
92
93#[derive(CubeType, CubeLaunch, Clone)]
96pub struct NhwcLayout {
97 pub stride_batch: usize,
99 pub strides_spatial: Sequence<usize>,
101 pub stride_channel: usize,
103
104 pub shape_batch: u32,
106 pub shapes_spatial: Sequence<u32>,
108 pub shape_channel: u32,
110
111 #[cube(comptime)]
112 pub line_size: LineSize,
113 #[cube(comptime)]
114 pub checks: EnumSet<NhwcCheck>,
115}
116
117#[cube]
118impl NhwcLayout {
119 pub fn new<E: Numeric, IO: Clone>(
120 tensor: VirtualTensor<E, IO>,
121 #[comptime] dim: Dimensionality,
122 #[comptime] checks: EnumSet<NhwcCheck>,
123 ) -> Self {
124 let spatial_dims = dim.num_dims();
125 let mut strides_spatial = Sequence::new();
126 let mut shapes_spatial = Sequence::new();
127
128 #[unroll]
129 for i in 0..spatial_dims {
130 strides_spatial.push(tensor.stride(i + 1));
131 shapes_spatial.push(tensor.shape(i + 1) as u32);
132 }
133
134 let stride_batch = tensor.stride(0);
135 let stride_channel = tensor.stride(spatial_dims + 1);
136
137 let shape_batch = tensor.shape(0) as u32;
138 let shape_channel = tensor.shape(spatial_dims + 1) as u32;
139
140 NhwcLayout {
141 stride_batch,
142 strides_spatial,
143 stride_channel,
144 shape_batch,
145 shapes_spatial,
146 shape_channel,
147 line_size: tensor.line_size(),
148 checks,
149 }
150 }
151}
152
153#[cube]
154impl Layout for NhwcLayout {
155 type Coordinates = NhwcCoords;
156 type SourceCoordinates = Coords1d;
157
158 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
159 let NhwcCoords {
160 batch,
161 spatial,
162 channel,
163 } = pos;
164
165 let spatial_dims = self.shapes_spatial.len();
166 let mut read_pos =
167 batch as usize * self.stride_batch + channel as usize * self.stride_channel;
168
169 #[unroll]
170 for i in 0..spatial_dims {
171 read_pos += spatial[i] as usize * self.strides_spatial[i];
172 }
173
174 read_pos / self.line_size
175 }
176
177 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
178 (self.to_source_pos(pos.clone()), self.is_in_bounds(pos))
179 }
180
181 fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
182 let mut in_bounds = true.runtime();
183 if self.checks.comptime().contains(NhwcCheck::Batch) {
184 in_bounds &= pos.batch < self.shape_batch;
185 }
186 if self.checks.comptime().contains(NhwcCheck::Spatial) {
187 let spatial_dims = self.shapes_spatial.len();
188
189 #[unroll]
190 for i in 0..spatial_dims {
191 let pos = pos.spatial[i];
192 in_bounds &= pos >= 0 && (pos as u32) < self.shapes_spatial[i];
193 }
194 }
195 if self.checks.comptime().contains(NhwcCheck::Channel) {
196 in_bounds &= pos.channel < self.shape_channel;
197 }
198
199 in_bounds
200 }
201
202 fn shape(&self) -> Self::Coordinates {
203 NhwcCoords {
204 batch: self.shape_batch,
205 spatial: cast_seq(self.shapes_spatial.clone()),
206 channel: self.shape_channel,
207 }
208 }
209}
210
211#[cube]
212pub(crate) fn cast_seq<From: CubePrimitive, To: CubePrimitive>(
213 seq: Sequence<From>,
214) -> Sequence<To> {
215 let num_elems = seq.len();
216 let mut out_seq = Sequence::new();
217 #[unroll]
218 for i in 0..num_elems {
219 let elem = To::cast_from(seq[i]);
220 out_seq.push(elem);
221 }
222 out_seq
223}
224
225impl<'a, R: Runtime> NhwcLayoutLaunch<'a, R> {
226 pub fn from_handle(
227 handle: &TensorHandleRef<'a, R>,
228 line_size: LineSize,
229 checks: EnumSet<NhwcCheck>,
230 ) -> Self {
231 let rank = handle.shape.len();
232 let dim_c = rank - 1;
233
234 let stride_batch = ScalarArg::new(handle.strides[0]);
235 let strides_spatial = handle.strides[1..dim_c]
236 .iter()
237 .map(|s| ScalarArg::new(*s))
238 .collect();
239 let stride_channel = ScalarArg::new(handle.strides[dim_c]);
240
241 let shape_batch = ScalarArg::new(handle.shape[0] as u32);
242 let shapes_spatial = handle.shape[1..dim_c]
243 .iter()
244 .map(|s| ScalarArg::new(*s as u32))
245 .collect();
246 let shape_channel = ScalarArg::new(handle.shape[dim_c] as u32);
247
248 Self::new(
249 stride_batch,
250 strides_spatial,
251 stride_channel,
252 shape_batch,
253 shapes_spatial,
254 shape_channel,
255 line_size,
256 checks,
257 )
258 }
259}