1use crate::Shape;
2
/// An N-dimensional shape whose extents are chosen at runtime.
///
/// Alongside the extents it caches the linear-index stride of each axis and the total
/// number of elements; the per-dimension `Shape` implementations read these to linearize
/// and delinearize points.
#[derive(Clone, Debug)]
pub struct RuntimeShape<C, const N: usize> {
    /// Extent of each axis.
    array: [C; N],
    /// Linear-index stride of each axis.
    strides: [C; N],
    /// Product of all extents.
    size: C,
}
9
macro_rules! impl_shape2 {
    ($scalar:ident) => {
        impl RuntimeShape<$scalar, 2> {
            /// Creates a 2D shape with extents `[x, y]`.
            ///
            /// The X axis has stride 1 and the Y axis has stride `x`, so X varies fastest in
            /// the linear index space. The size `x * y` uses ordinary multiplication, which
            /// panics on overflow in debug builds.
            pub fn new([x, y]: [$scalar; 2]) -> Self {
                Self {
                    array: [x, y],
                    strides: [1, x],
                    size: x * y,
                }
            }
        }

        impl Shape<2> for RuntimeShape<$scalar, 2> {
            type Coord = $scalar;

            #[inline]
            fn as_array(&self) -> [$scalar; 2] {
                self.array
            }

            #[inline]
            fn size(&self) -> $scalar {
                self.size
            }

            #[inline]
            fn usize(&self) -> usize {
                self.size as usize
            }

            /// Maps `p` to `p[0] + x * p[1]` using wrapping arithmetic.
            ///
            /// Both the multiplication and the addition wrap, so the whole expression has
            /// consistent modular semantics. For example, an unsigned point `[MAX, 1]`
            /// (an offset of `[-1, +1]`) yields `x - 1` instead of panicking on overflow in
            /// debug builds.
            #[inline]
            fn linearize(&self, p: [$scalar; 2]) -> $scalar {
                p[0].wrapping_add(self.strides[1].wrapping_mul(p[1]))
            }

            /// Inverse of `linearize` for indices in `0..size`.
            ///
            /// Panics (division by zero) if the X extent is zero.
            #[inline]
            fn delinearize(&self, i: $scalar) -> [$scalar; 2] {
                let y = i / self.strides[1];
                let x = i % self.strides[1];
                [x, y]
            }
        }
    };
}
54
55impl_shape2!(u8);
56impl_shape2!(u16);
57impl_shape2!(u32);
58impl_shape2!(u64);
59impl_shape2!(usize);
60
61impl_shape2!(i8);
62impl_shape2!(i16);
63impl_shape2!(i32);
64impl_shape2!(i64);
65
macro_rules! impl_shape3 {
    ($scalar:ident) => {
        impl RuntimeShape<$scalar, 3> {
            /// Creates a 3D shape with extents `[x, y, z]`.
            ///
            /// Strides are `[1, x, x * y]`, so X varies fastest and Z slowest in the linear
            /// index space. The size `x * y * z` uses ordinary multiplication, which panics
            /// on overflow in debug builds.
            pub fn new([x, y, z]: [$scalar; 3]) -> Self {
                Self {
                    array: [x, y, z],
                    strides: [1, x, x * y],
                    size: x * y * z,
                }
            }
        }

        impl Shape<3> for RuntimeShape<$scalar, 3> {
            type Coord = $scalar;

            #[inline]
            fn as_array(&self) -> [$scalar; 3] {
                self.array
            }

            #[inline]
            fn size(&self) -> $scalar {
                self.size
            }

            #[inline]
            fn usize(&self) -> usize {
                self.size as usize
            }

            /// Maps `p` to `p[0] + s1 * p[1] + s2 * p[2]` using wrapping arithmetic.
            ///
            /// Every product and sum wraps, so wrapped "negative" offsets in unsigned
            /// coordinate types linearize consistently instead of panicking on overflow in
            /// debug builds.
            #[inline]
            fn linearize(&self, p: [$scalar; 3]) -> $scalar {
                p[0].wrapping_add(self.strides[1].wrapping_mul(p[1]))
                    .wrapping_add(self.strides[2].wrapping_mul(p[2]))
            }

            /// Inverse of `linearize` for indices in `0..size`.
            ///
            /// Panics (division by zero) if the X or Y extent is zero.
            #[inline]
            fn delinearize(&self, mut i: $scalar) -> [$scalar; 3] {
                let z = i / self.strides[2];
                i -= z * self.strides[2];
                let y = i / self.strides[1];
                let x = i % self.strides[1];
                [x, y, z]
            }
        }
    };
}
112
113impl_shape3!(u8);
114impl_shape3!(u16);
115impl_shape3!(u32);
116impl_shape3!(u64);
117impl_shape3!(usize);
118
119impl_shape3!(i8);
120impl_shape3!(i16);
121impl_shape3!(i32);
122impl_shape3!(i64);
123
macro_rules! impl_shape4 {
    ($scalar:ident) => {
        impl RuntimeShape<$scalar, 4> {
            /// Creates a 4D shape with extents `[x, y, z, w]`.
            ///
            /// Strides are `[1, x, x * y, x * y * z]`, so X varies fastest and W slowest in
            /// the linear index space. The size uses ordinary multiplication, which panics
            /// on overflow in debug builds.
            pub fn new([x, y, z, w]: [$scalar; 4]) -> Self {
                Self {
                    array: [x, y, z, w],
                    strides: [1, x, x * y, x * y * z],
                    size: x * y * z * w,
                }
            }
        }

        impl Shape<4> for RuntimeShape<$scalar, 4> {
            type Coord = $scalar;

            #[inline]
            fn as_array(&self) -> [$scalar; 4] {
                self.array
            }

            #[inline]
            fn size(&self) -> $scalar {
                self.size
            }

            #[inline]
            fn usize(&self) -> usize {
                self.size as usize
            }

            /// Maps `p` to `p[0] + s1 * p[1] + s2 * p[2] + s3 * p[3]` using wrapping
            /// arithmetic.
            ///
            /// Every product and sum wraps, so wrapped "negative" offsets in unsigned
            /// coordinate types linearize consistently instead of panicking on overflow in
            /// debug builds.
            #[inline]
            fn linearize(&self, p: [$scalar; 4]) -> $scalar {
                p[0].wrapping_add(self.strides[1].wrapping_mul(p[1]))
                    .wrapping_add(self.strides[2].wrapping_mul(p[2]))
                    .wrapping_add(self.strides[3].wrapping_mul(p[3]))
            }

            /// Inverse of `linearize` for indices in `0..size`.
            ///
            /// Panics (division by zero) if the X, Y or Z extent is zero.
            #[inline]
            fn delinearize(&self, mut i: $scalar) -> [$scalar; 4] {
                let w = i / self.strides[3];
                i -= w * self.strides[3];
                let z = i / self.strides[2];
                i -= z * self.strides[2];
                let y = i / self.strides[1];
                let x = i % self.strides[1];
                [x, y, z, w]
            }
        }
    };
}
174
175impl_shape4!(u8);
176impl_shape4!(u16);
177impl_shape4!(u32);
178impl_shape4!(u64);
179impl_shape4!(usize);
180
181impl_shape4!(i8);
182impl_shape4!(i16);
183impl_shape4!(i32);
184impl_shape4!(i64);
185
/// An N-dimensional shape whose extents are runtime-chosen powers of two.
///
/// Each axis occupies a contiguous bit field of the linear index, so the per-dimension
/// `Shape` implementations linearize and delinearize with shifts, ORs and masks only.
#[derive(Clone, Debug)]
pub struct RuntimePow2Shape<C, const N: usize> {
    /// Extent of each axis (a power of two).
    array: [C; N],
    /// Bit offset of each axis's field within a linear index.
    shifts: [C; N],
    /// In-place mask (not shifted down) selecting each axis's field within a linear index.
    masks: [C; N],
    /// Product of all extents.
    size: C,
}
193
macro_rules! impl_pow2_shape2 {
    ($scalar:ty) => {
        impl RuntimePow2Shape<$scalar, 2> {
            /// Creates a 2D shape with extents `[1 << x, 1 << y]`.
            ///
            /// `x` and `y` are the base-2 logarithms of the extents. X occupies the low `x`
            /// bits of a linear index and Y the `y` bits directly above it.
            pub fn new([x, y]: [$scalar; 2]) -> Self {
                let y_shift = x;
                Self {
                    array: [1 << x, 1 << y],
                    shifts: [0, y_shift],
                    // `<<` binds looser than `+`; the parentheses make that explicit.
                    size: 1 << (x + y),
                    masks: [!(!0 << x), !(!0 << y) << y_shift],
                }
            }
        }

        impl Shape<2> for RuntimePow2Shape<$scalar, 2> {
            type Coord = $scalar;

            #[inline]
            fn as_array(&self) -> [$scalar; 2] {
                self.array
            }

            #[inline]
            fn size(&self) -> $scalar {
                self.size
            }

            #[inline]
            fn usize(&self) -> usize {
                self.size as usize
            }

            /// Packs `p` into the per-axis bit fields of a linear index.
            ///
            /// Coordinates are not masked, so values outside an axis's extent spill into
            /// the neighbouring field.
            #[inline]
            fn linearize(&self, p: [$scalar; 2]) -> $scalar {
                (p[1] << self.shifts[1]) | p[0]
            }

            /// Extracts each axis's bit field from `i` and shifts it down.
            #[inline]
            fn delinearize(&self, i: $scalar) -> [$scalar; 2] {
                [i & self.masks[0], (i & self.masks[1]) >> self.shifts[1]]
            }
        }
    };
}
238
239impl_pow2_shape2!(u8);
240impl_pow2_shape2!(u16);
241impl_pow2_shape2!(u32);
242impl_pow2_shape2!(u64);
243impl_pow2_shape2!(usize);
244
245impl_pow2_shape2!(i8);
246impl_pow2_shape2!(i16);
247impl_pow2_shape2!(i32);
248impl_pow2_shape2!(i64);
249
macro_rules! impl_pow2_shape3 {
    ($scalar:ty) => {
        impl RuntimePow2Shape<$scalar, 3> {
            /// Creates a 3D shape with extents `[1 << x, 1 << y, 1 << z]`.
            ///
            /// Arguments are base-2 logarithms of the extents. X occupies the low `x` bits of
            /// a linear index, Y the next `y` bits and Z the `z` bits above those.
            pub fn new([x, y, z]: [$scalar; 3]) -> Self {
                let y_shift = x;
                let z_shift = x + y;
                Self {
                    array: [1 << x, 1 << y, 1 << z],
                    shifts: [0, y_shift, z_shift],
                    masks: [
                        !(!0 << x),
                        !(!0 << y) << y_shift,
                        !(!0 << z) << z_shift,
                    ],
                    // `<<` binds looser than `+`; the parentheses make that explicit.
                    size: 1 << (x + y + z),
                }
            }
        }

        impl Shape<3> for RuntimePow2Shape<$scalar, 3> {
            type Coord = $scalar;

            #[inline]
            fn as_array(&self) -> [$scalar; 3] {
                self.array
            }

            #[inline]
            fn size(&self) -> $scalar {
                self.size
            }

            #[inline]
            fn usize(&self) -> usize {
                self.size as usize
            }

            /// Packs `p` into the per-axis bit fields of a linear index.
            ///
            /// Coordinates are not masked, so values outside an axis's extent spill into
            /// the neighbouring field.
            #[inline]
            fn linearize(&self, p: [$scalar; 3]) -> $scalar {
                (p[2] << self.shifts[2]) | (p[1] << self.shifts[1]) | p[0]
            }

            /// Extracts each axis's bit field from `i` and shifts it down.
            #[inline]
            fn delinearize(&self, i: $scalar) -> [$scalar; 3] {
                [
                    i & self.masks[0],
                    (i & self.masks[1]) >> self.shifts[1],
                    (i & self.masks[2]) >> self.shifts[2],
                ]
            }
        }
    };
}
299
300impl_pow2_shape3!(u8);
301impl_pow2_shape3!(u16);
302impl_pow2_shape3!(u32);
303impl_pow2_shape3!(u64);
304impl_pow2_shape3!(usize);
305
306impl_pow2_shape3!(i8);
307impl_pow2_shape3!(i16);
308impl_pow2_shape3!(i32);
309impl_pow2_shape3!(i64);
310
macro_rules! impl_pow2_shape4 {
    ($scalar:ty) => {
        impl RuntimePow2Shape<$scalar, 4> {
            /// Creates a 4D shape with extents `[1 << x, 1 << y, 1 << z, 1 << w]`.
            ///
            /// Arguments are base-2 logarithms of the extents. The axes occupy consecutive
            /// bit fields of a linear index, X lowest and W highest.
            pub fn new([x, y, z, w]: [$scalar; 4]) -> Self {
                let y_shift = x;
                let z_shift = x + y;
                let w_shift = x + y + z;
                Self {
                    array: [1 << x, 1 << y, 1 << z, 1 << w],
                    size: 1 << (x + y + z + w),
                    shifts: [0, y_shift, z_shift, w_shift],
                    masks: [
                        !(!0 << x),
                        !(!0 << y) << y_shift,
                        !(!0 << z) << z_shift,
                        !(!0 << w) << w_shift,
                    ],
                }
            }
        }

        impl Shape<4> for RuntimePow2Shape<$scalar, 4> {
            type Coord = $scalar;

            #[inline]
            fn as_array(&self) -> [$scalar; 4] {
                self.array
            }

            #[inline]
            fn size(&self) -> $scalar {
                self.size
            }

            #[inline]
            fn usize(&self) -> usize {
                self.size as usize
            }

            /// Packs all four coordinates of `p` into the per-axis bit fields of a linear
            /// index.
            ///
            /// Coordinates are not masked, so values outside an axis's extent spill into
            /// the neighbouring field.
            #[inline]
            fn linearize(&self, p: [$scalar; 4]) -> $scalar {
                (p[3] << self.shifts[3])
                    | (p[2] << self.shifts[2])
                    | (p[1] << self.shifts[1])
                    | p[0]
            }

            /// Extracts each axis's bit field from `i` and shifts it down.
            #[inline]
            fn delinearize(&self, i: $scalar) -> [$scalar; 4] {
                [
                    i & self.masks[0],
                    (i & self.masks[1]) >> self.shifts[1],
                    (i & self.masks[2]) >> self.shifts[2],
                    (i & self.masks[3]) >> self.shifts[3],
                ]
            }
        }
    };
}
367
368impl_pow2_shape4!(u8);
369impl_pow2_shape4!(u16);
370impl_pow2_shape4!(u32);
371impl_pow2_shape4!(u64);
372impl_pow2_shape4!(usize);
373
374impl_pow2_shape4!(i8);
375impl_pow2_shape4!(i16);
376impl_pow2_shape4!(i32);
377impl_pow2_shape4!(i64);