// ndshape/runtime_shape.rs

1use crate::Shape;
2
/// An `N`-dimensional shape whose extents are only known at runtime.
///
/// Coordinates are laid out with axis 0 varying fastest: the `new` constructors
/// set the stride of axis `k` to the product of the extents of axes `0..k`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RuntimeShape<C, const N: usize> {
    /// Extent of each axis.
    array: [C; N],
    /// Linear-index stride of each axis.
    strides: [C; N],
    /// Total element count (product of all extents).
    size: C,
}

10macro_rules! impl_shape2 {
11    ($scalar:ident) => {
12        impl RuntimeShape<$scalar, 2> {
13            pub fn new([x, y]: [$scalar; 2]) -> Self {
14                Self {
15                    array: [x, y],
16                    strides: [1, x],
17                    size: x * y,
18                }
19            }
20        }
21
22        impl Shape<2> for RuntimeShape<$scalar, 2> {
23            type Coord = $scalar;
24
25            #[inline]
26            fn as_array(&self) -> [$scalar; 2] {
27                self.array
28            }
29
30            #[inline]
31            fn size(&self) -> $scalar {
32                self.size
33            }
34
35            #[inline]
36            fn usize(&self) -> usize {
37                self.size as usize
38            }
39
40            #[inline]
41            fn linearize(&self, p: [$scalar; 2]) -> $scalar {
42                p[0] + self.strides[1].wrapping_mul(p[1])
43            }
44
45            #[inline]
46            fn delinearize(&self, i: $scalar) -> [$scalar; 2] {
47                let y = i / self.strides[1];
48                let x = i % self.strides[1];
49                [x, y]
50            }
51        }
52    };
53}
54
55impl_shape2!(u8);
56impl_shape2!(u16);
57impl_shape2!(u32);
58impl_shape2!(u64);
59impl_shape2!(usize);
60
61impl_shape2!(i8);
62impl_shape2!(i16);
63impl_shape2!(i32);
64impl_shape2!(i64);
65
66macro_rules! impl_shape3 {
67    ($scalar:ident) => {
68        impl RuntimeShape<$scalar, 3> {
69            pub fn new([x, y, z]: [$scalar; 3]) -> Self {
70                Self {
71                    array: [x, y, z],
72                    strides: [1, x, x * y],
73                    size: x * y * z,
74                }
75            }
76        }
77
78        impl Shape<3> for RuntimeShape<$scalar, 3> {
79            type Coord = $scalar;
80
81            #[inline]
82            fn as_array(&self) -> [$scalar; 3] {
83                self.array
84            }
85
86            #[inline]
87            fn size(&self) -> $scalar {
88                self.size
89            }
90
91            #[inline]
92            fn usize(&self) -> usize {
93                self.size as usize
94            }
95
96            #[inline]
97            fn linearize(&self, p: [$scalar; 3]) -> $scalar {
98                p[0] + self.strides[1].wrapping_mul(p[1]) + self.strides[2].wrapping_mul(p[2])
99            }
100
101            #[inline]
102            fn delinearize(&self, mut i: $scalar) -> [$scalar; 3] {
103                let z = i / self.strides[2];
104                i -= z * self.strides[2];
105                let y = i / self.strides[1];
106                let x = i % self.strides[1];
107                [x, y, z]
108            }
109        }
110    };
111}
112
113impl_shape3!(u8);
114impl_shape3!(u16);
115impl_shape3!(u32);
116impl_shape3!(u64);
117impl_shape3!(usize);
118
119impl_shape3!(i8);
120impl_shape3!(i16);
121impl_shape3!(i32);
122impl_shape3!(i64);
123
124macro_rules! impl_shape4 {
125    ($scalar:ident) => {
126        impl RuntimeShape<$scalar, 4> {
127            pub fn new([x, y, z, w]: [$scalar; 4]) -> Self {
128                Self {
129                    array: [x, y, z, w],
130                    strides: [1, x, x * y, x * y * z],
131                    size: x * y * z * w,
132                }
133            }
134        }
135
136        impl Shape<4> for RuntimeShape<$scalar, 4> {
137            type Coord = $scalar;
138
139            #[inline]
140            fn as_array(&self) -> [$scalar; 4] {
141                self.array
142            }
143
144            #[inline]
145            fn size(&self) -> $scalar {
146                self.size
147            }
148
149            #[inline]
150            fn usize(&self) -> usize {
151                self.size as usize
152            }
153
154            #[inline]
155            fn linearize(&self, p: [$scalar; 4]) -> $scalar {
156                p[0] + self.strides[1].wrapping_mul(p[1])
157                    + self.strides[2].wrapping_mul(p[2])
158                    + self.strides[3].wrapping_mul(p[3])
159            }
160
161            #[inline]
162            fn delinearize(&self, mut i: $scalar) -> [$scalar; 4] {
163                let w = i / self.strides[3];
164                i -= w * self.strides[3];
165                let z = i / self.strides[2];
166                i -= z * self.strides[2];
167                let y = i / self.strides[1];
168                let x = i % self.strides[1];
169                [x, y, z, w]
170            }
171        }
172    };
173}
174
175impl_shape4!(u8);
176impl_shape4!(u16);
177impl_shape4!(u32);
178impl_shape4!(u64);
179impl_shape4!(usize);
180
181impl_shape4!(i8);
182impl_shape4!(i16);
183impl_shape4!(i32);
184impl_shape4!(i64);
185
/// An `N`-dimensional shape whose extents are runtime-chosen powers of two.
///
/// Because every extent is a power of two, each axis owns a contiguous bit field
/// of the linear index, so (de)linearization uses shifts and masks instead of
/// multiplication and division.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RuntimePow2Shape<C, const N: usize> {
    /// Extent of each axis (a power of two).
    array: [C; N],
    /// Bit offset of each axis's field within a linear index.
    shifts: [C; N],
    /// Mask selecting each axis's field within a linear index (already shifted).
    masks: [C; N],
    /// Total element count.
    size: C,
}

194macro_rules! impl_pow2_shape2 {
195    ($scalar:ty) => {
196        impl RuntimePow2Shape<$scalar, 2> {
197            pub fn new([x, y]: [$scalar; 2]) -> Self {
198                let y_shift = x;
199                Self {
200                    array: [1 << x, 1 << y],
201                    shifts: [0, y_shift],
202                    size: 1 << x + y,
203                    masks: [!(!0 << x), !(!0 << y) << y_shift],
204                }
205            }
206        }
207
208        impl Shape<2> for RuntimePow2Shape<$scalar, 2> {
209            type Coord = $scalar;
210
211            #[inline]
212            fn as_array(&self) -> [$scalar; 2] {
213                self.array
214            }
215
216            #[inline]
217            fn size(&self) -> $scalar {
218                self.size
219            }
220
221            #[inline]
222            fn usize(&self) -> usize {
223                self.size as usize
224            }
225
226            #[inline]
227            fn linearize(&self, p: [$scalar; 2]) -> $scalar {
228                (p[1] << self.shifts[1]) | p[0]
229            }
230
231            #[inline]
232            fn delinearize(&self, i: $scalar) -> [$scalar; 2] {
233                [i & self.masks[0], (i & self.masks[1]) >> self.shifts[1]]
234            }
235        }
236    };
237}
238
239impl_pow2_shape2!(u8);
240impl_pow2_shape2!(u16);
241impl_pow2_shape2!(u32);
242impl_pow2_shape2!(u64);
243impl_pow2_shape2!(usize);
244
245impl_pow2_shape2!(i8);
246impl_pow2_shape2!(i16);
247impl_pow2_shape2!(i32);
248impl_pow2_shape2!(i64);
249
250macro_rules! impl_pow2_shape3 {
251    ($scalar:ty) => {
252        impl RuntimePow2Shape<$scalar, 3> {
253            pub fn new([x, y, z]: [$scalar; 3]) -> Self {
254                let y_shift = x;
255                let z_shift = x + y;
256                Self {
257                    array: [1 << x, 1 << y, 1 << z],
258                    shifts: [0, y_shift, z_shift],
259                    masks: [!(!0 << x), !(!0 << y) << y_shift, !(!0 << z) << z_shift],
260                    size: 1 << x + y + z,
261                }
262            }
263        }
264
265        impl Shape<3> for RuntimePow2Shape<$scalar, 3> {
266            type Coord = $scalar;
267
268            #[inline]
269            fn as_array(&self) -> [$scalar; 3] {
270                self.array
271            }
272
273            #[inline]
274            fn size(&self) -> $scalar {
275                self.size
276            }
277
278            #[inline]
279            fn usize(&self) -> usize {
280                self.size as usize
281            }
282
283            #[inline]
284            fn linearize(&self, p: [$scalar; 3]) -> $scalar {
285                (p[2] << self.shifts[2]) | (p[1] << self.shifts[1]) | p[0]
286            }
287
288            #[inline]
289            fn delinearize(&self, i: $scalar) -> [$scalar; 3] {
290                [
291                    i & self.masks[0],
292                    (i & self.masks[1]) >> self.shifts[1],
293                    (i & self.masks[2]) >> self.shifts[2],
294                ]
295            }
296        }
297    };
298}
299
300impl_pow2_shape3!(u8);
301impl_pow2_shape3!(u16);
302impl_pow2_shape3!(u32);
303impl_pow2_shape3!(u64);
304impl_pow2_shape3!(usize);
305
306impl_pow2_shape3!(i8);
307impl_pow2_shape3!(i16);
308impl_pow2_shape3!(i32);
309impl_pow2_shape3!(i64);
310
311macro_rules! impl_pow2_shape4 {
312    ($scalar:ty) => {
313        impl RuntimePow2Shape<$scalar, 4> {
314            pub fn new([x, y, z, w]: [$scalar; 4]) -> Self {
315                let y_shift = x;
316                let z_shift = x + y;
317                let w_shift = x + y + z;
318                Self {
319                    array: [1 << x, 1 << y, 1 << z, 1 << w],
320                    size: 1 << (x + y + z + w),
321                    shifts: [0, y_shift, z_shift, w_shift],
322                    masks: [
323                        !(!0 << x),
324                        !(!0 << y) << y_shift,
325                        !(!0 << z) << z_shift,
326                        !(!0 << w) << w_shift,
327                    ],
328                }
329            }
330        }
331
332        impl Shape<4> for RuntimePow2Shape<$scalar, 4> {
333            type Coord = $scalar;
334
335            #[inline]
336            fn as_array(&self) -> [$scalar; 4] {
337                self.array
338            }
339
340            #[inline]
341            fn size(&self) -> $scalar {
342                self.size
343            }
344
345            #[inline]
346            fn usize(&self) -> usize {
347                self.size as usize
348            }
349
350            #[inline]
351            fn linearize(&self, p: [$scalar; 4]) -> $scalar {
352                (p[2] << self.shifts[2]) | (p[1] << self.shifts[1]) | p[0]
353            }
354
355            #[inline]
356            fn delinearize(&self, i: $scalar) -> [$scalar; 4] {
357                [
358                    i & self.masks[0],
359                    (i & self.masks[1]) >> self.shifts[1],
360                    (i & self.masks[2]) >> self.shifts[2],
361                    (i & self.masks[3]) >> self.shifts[3],
362                ]
363            }
364        }
365    };
366}
367
368impl_pow2_shape4!(u8);
369impl_pow2_shape4!(u16);
370impl_pow2_shape4!(u32);
371impl_pow2_shape4!(u64);
372impl_pow2_shape4!(usize);
373
374impl_pow2_shape4!(i8);
375impl_pow2_shape4!(i16);
376impl_pow2_shape4!(i32);
377impl_pow2_shape4!(i64);