Skip to main content

mdarray/
shape.rs

1#[cfg(feature = "nightly")]
2use alloc::alloc::Allocator;
3#[cfg(not(feature = "std"))]
4use alloc::boxed::Box;
5#[cfg(not(feature = "std"))]
6use alloc::vec::Vec;
7
8use core::cmp::Ordering;
9use core::fmt::Debug;
10use core::hash::{Hash, Hasher};
11use core::slice;
12
13#[cfg(not(feature = "nightly"))]
14use crate::allocator::Allocator;
15use crate::buffer::{DynBuffer, Owned, StaticBuffer};
16use crate::dim::{Const, Dim, Dims, Dyn};
17use crate::layout::{Layout, Strided};
18
19/// Array shape trait.
20pub trait Shape: Clone + Debug + Default + Hash + Ord + Send + Sync {
21    /// First dimension.
22    type Head: Dim;
23
24    /// Shape excluding the first dimension.
25    type Tail: Shape;
26
27    /// Shape with the reverse ordering of dimensions.
28    type Reverse: Shape;
29
30    /// Prepend the dimension to the shape.
31    type Prepend<D: Dim>: Shape;
32
33    /// Corresponding shape with dynamically-sized dimensions.
34    type Dyn: Shape;
35
36    /// Merge each dimension pair, where constant size is preferred over dynamic.
37    /// The result has dynamic rank if at least one of the inputs has dynamic rank.
38    type Merge<S: Shape>: Shape;
39
40    /// Array buffer type.
41    type Buffer<T, A: Allocator>: Owned<Item = T, Shape = Self, Alloc = A>;
42
43    /// Select layout `L` for rank 0, or `Strided` for rank >0 or dynamic.
44    type Layout<L: Layout>: Layout;
45
46    #[doc(hidden)]
47    type Dims<T: Copy + Debug + Default + Hash + Ord + Send + Sync>: Dims<T>;
48
49    /// Array rank if known statically, or `None` if dynamic.
50    const RANK: Option<usize>;
51
52    /// Returns the number of elements in the specified dimension.
53    ///
54    /// # Panics
55    ///
56    /// Panics if the dimension is out of bounds.
57    #[inline]
58    fn dim(&self, index: usize) -> usize {
59        assert!(index < self.rank(), "invalid dimension");
60
61        self.with_dims(|dims| dims[index])
62    }
63
64    /// Creates an array shape with the given dimensions.
65    ///
66    /// # Panics
67    ///
68    /// Panics if the dimensions are not matching static rank or constant-sized dimensions.
69    #[inline]
70    fn from_dims(dims: &[usize]) -> Self {
71        let mut shape = Self::new(dims.len());
72
73        shape.with_mut_dims(|dst| dst.copy_from_slice(dims));
74        shape
75    }
76
77    /// Returns `true` if the array contains no elements.
78    #[inline]
79    fn is_empty(&self) -> bool {
80        self.len() == 0
81    }
82
83    /// Returns the number of elements in the array.
84    #[inline]
85    fn len(&self) -> usize {
86        self.with_dims(|dims| dims.iter().product())
87    }
88
89    /// Returns the array rank, i.e. the number of dimensions.
90    #[inline]
91    fn rank(&self) -> usize {
92        self.with_dims(|dims| dims.len())
93    }
94
95    #[doc(hidden)]
96    fn new(rank: usize) -> Self;
97
98    #[doc(hidden)]
99    fn with_dims<T, F: FnOnce(&[usize]) -> T>(&self, f: F) -> T;
100
101    #[doc(hidden)]
102    fn with_mut_dims<T, F: FnOnce(&mut [usize]) -> T>(&mut self, f: F) -> T;
103
104    #[doc(hidden)]
105    #[inline]
106    fn checked_len(&self) -> Option<usize> {
107        self.with_dims(|dims| dims.iter().try_fold(1usize, |acc, &x| acc.checked_mul(x)))
108    }
109
110    #[doc(hidden)]
111    #[inline]
112    fn prepend_dim<S: Shape>(&self, size: usize) -> S {
113        let mut shape = S::new(self.rank() + 1);
114
115        shape.with_mut_dims(|dims| {
116            dims[0] = size;
117            self.with_dims(|src| dims[1..].copy_from_slice(src));
118        });
119
120        shape
121    }
122
123    #[doc(hidden)]
124    #[inline]
125    fn remove_dim<S: Shape>(&self, index: usize) -> S {
126        assert!(index < self.rank(), "invalid dimension");
127
128        let mut shape = S::new(self.rank() - 1);
129
130        shape.with_mut_dims(|dims| {
131            self.with_dims(|src| {
132                dims[..index].copy_from_slice(&src[..index]);
133                dims[index..].copy_from_slice(&src[index + 1..]);
134            });
135        });
136
137        shape
138    }
139
140    #[doc(hidden)]
141    #[inline]
142    fn reshape<S: Shape>(&self, mut new_shape: S) -> S {
143        let mut inferred = None;
144
145        new_shape.with_mut_dims(|dims| {
146            for i in 0..dims.len() {
147                if dims[i] == usize::MAX {
148                    assert!(inferred.is_none(), "at most one dimension can be inferred");
149
150                    dims[i] = 1;
151                    inferred = Some(i);
152                }
153            }
154        });
155
156        let old_len = self.len();
157        let new_len = new_shape.checked_len().expect("invalid length");
158
159        if let Some(i) = inferred {
160            assert!(old_len.is_multiple_of(new_len), "length not divisible by the new dimensions");
161
162            new_shape.with_mut_dims(|dims| dims[i] = old_len / new_len);
163        } else {
164            assert!(new_len == old_len, "length must not change");
165        }
166
167        new_shape
168    }
169
170    #[doc(hidden)]
171    #[inline]
172    fn resize_dim<S: Shape>(&self, index: usize, new_size: usize) -> S {
173        assert!(index < self.rank(), "invalid dimension");
174
175        let mut shape = S::new(self.rank());
176
177        shape.with_mut_dims(|dims| {
178            self.with_dims(|src| dims[..].copy_from_slice(src));
179            dims[index] = new_size;
180        });
181
182        shape
183    }
184
185    #[doc(hidden)]
186    #[inline]
187    fn reverse(&self) -> Self::Reverse {
188        let mut shape = Self::Reverse::new(self.rank());
189
190        shape.with_mut_dims(|dims| {
191            self.with_dims(|src| dims.copy_from_slice(src));
192            dims.reverse();
193        });
194
195        shape
196    }
197}
198
199/// Trait for array shapes where all dimensions are constant-sized.
200pub trait ConstShape: Copy + Shape {
201    #[doc(hidden)]
202    type Inner<T>;
203
204    #[doc(hidden)]
205    type WithConst<T, const N: usize, A: Allocator>: Owned<Item = T, Shape = Self::Prepend<Const<N>>, Alloc = A>;
206}
207
208/// Conversion trait into an array shape.
209pub trait IntoShape {
210    /// Which kind of array shape are we turning this into?
211    type IntoShape: Shape;
212
213    /// Creates an array shape from a value.
214    fn into_shape(self) -> Self::IntoShape;
215
216    #[doc(hidden)]
217    fn into_dims<T, F: FnOnce(&[usize]) -> T>(self, f: F) -> T;
218}
219
/// Array shape type with dynamic rank.
///
/// If the rank is 0 or 1, no heap allocation is necessary. The default value
/// will have rank 1 and contain no elements.
///
/// A rank-1 shape may be held by either variant. Equality, ordering, hashing and
/// `Debug` output depend only on the dimension sizes, so both representations of
/// the same shape compare equal.
pub enum DynRank {
    /// Shape variant with dynamic rank; the boxed slice holds one size per dimension.
    Dyn(Box<[usize]>),
    /// Shape variant with rank 1, storing the single dimension size inline.
    One(usize),
}
230
/// Array shape type with static rank `N` and dynamically-sized dimensions.
///
/// The type is selected by the `IntoShape` implementation for `[usize; N]`, which
/// exists for ranks 0 through 6: `()` for rank 0, and otherwise a tuple of `N`
/// `Dyn` dimensions.
pub type Rank<const N: usize> = <[usize; N] as IntoShape>::IntoShape;
233
234impl DynRank {
235    /// Returns the number of elements in each dimension.
236    #[inline]
237    pub fn dims(&self) -> &[usize] {
238        match self {
239            Self::Dyn(dims) => dims,
240            Self::One(size) => slice::from_ref(size),
241        }
242    }
243}
244
245impl Clone for DynRank {
246    #[inline]
247    fn clone(&self) -> Self {
248        match self {
249            Self::Dyn(dims) => {
250                if dims.len() == 1 {
251                    Self::One(dims[0])
252                } else {
253                    Self::Dyn(dims.clone())
254                }
255            }
256            Self::One(dim) => Self::One(*dim),
257        }
258    }
259
260    #[inline]
261    fn clone_from(&mut self, source: &Self) {
262        if let Self::Dyn(dims) = self
263            && let Self::Dyn(src) = source
264            && dims.len() == src.len()
265        {
266            dims.clone_from_slice(src);
267
268            return;
269        }
270
271        *self = source.clone();
272    }
273}
274
275impl Debug for DynRank {
276    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
277        self.with_dims(|dims| f.debug_tuple("DynRank").field(&dims).finish())
278    }
279}
280
281impl Default for DynRank {
282    #[inline]
283    fn default() -> Self {
284        Self::One(0)
285    }
286}
287
288impl Eq for DynRank {}
289
290impl Hash for DynRank {
291    #[inline]
292    fn hash<H: Hasher>(&self, state: &mut H) {
293        self.with_dims(|dims| dims.hash(state))
294    }
295}
296
297impl Ord for DynRank {
298    #[inline]
299    fn cmp(&self, other: &Self) -> Ordering {
300        self.with_dims(|dims| other.with_dims(|other| dims.cmp(other)))
301    }
302}
303
304impl PartialEq for DynRank {
305    #[inline]
306    fn eq(&self, other: &Self) -> bool {
307        self.with_dims(|dims| other.with_dims(|other| dims.eq(other)))
308    }
309}
310
311impl PartialOrd for DynRank {
312    #[inline]
313    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
314        Some(self.cmp(other))
315    }
316}
317
318impl Shape for DynRank {
319    type Head = Dyn;
320    type Tail = Self;
321
322    type Reverse = Self;
323    type Prepend<D: Dim> = Self;
324
325    type Dyn = Self;
326    type Merge<S: Shape> = Self;
327
328    type Buffer<T, A: Allocator> = DynBuffer<T, Self, A>;
329    type Layout<L: Layout> = Strided;
330
331    type Dims<T: Copy + Debug + Default + Hash + Ord + Send + Sync> = Box<[T]>;
332
333    const RANK: Option<usize> = None;
334
335    #[inline]
336    fn new(rank: usize) -> Self {
337        if rank == 1 { Self::One(0) } else { Self::Dyn(Dims::new(rank)) }
338    }
339
340    #[inline]
341    fn with_dims<T, F: FnOnce(&[usize]) -> T>(&self, f: F) -> T {
342        let dims = match self {
343            Self::Dyn(dims) => dims,
344            Self::One(size) => slice::from_ref(size),
345        };
346
347        f(dims)
348    }
349
350    #[inline]
351    fn with_mut_dims<T, F: FnOnce(&mut [usize]) -> T>(&mut self, f: F) -> T {
352        let dims = match self {
353            Self::Dyn(dims) => dims,
354            Self::One(size) => slice::from_mut(size),
355        };
356
357        f(dims)
358    }
359}
360
361impl Shape for () {
362    type Head = Dyn;
363    type Tail = Self;
364
365    type Reverse = Self;
366    type Prepend<D: Dim> = (D,);
367
368    type Dyn = Self;
369    type Merge<S: Shape> = S;
370
371    type Buffer<T, A: Allocator> = StaticBuffer<T, Self, A>;
372    type Layout<L: Layout> = L;
373
374    type Dims<T: Copy + Debug + Default + Hash + Ord + Send + Sync> = [T; 0];
375
376    const RANK: Option<usize> = Some(0);
377
378    #[inline]
379    fn new(rank: usize) {
380        assert!(rank == 0, "invalid rank");
381    }
382
383    #[inline]
384    fn with_dims<T, F: FnOnce(&[usize]) -> T>(&self, f: F) -> T {
385        f(&[])
386    }
387
388    #[inline]
389    fn with_mut_dims<T, F: FnOnce(&mut [usize]) -> T>(&mut self, f: F) -> T {
390        f(&mut [])
391    }
392}
393
394impl<X: Dim> Shape for (X,) {
395    type Head = X;
396    type Tail = ();
397
398    type Reverse = Self;
399    type Prepend<D: Dim> = (D, X);
400
401    type Dyn = (Dyn,);
402    type Merge<S: Shape> = <S::Tail as Shape>::Prepend<X::Merge<S::Head>>;
403
404    type Buffer<T, A: Allocator> = X::Buffer<T, (), A>;
405    type Layout<L: Layout> = Strided;
406
407    type Dims<T: Copy + Debug + Default + Hash + Ord + Send + Sync> = [T; 1];
408
409    const RANK: Option<usize> = Some(1);
410
411    #[inline]
412    fn new(rank: usize) -> Self {
413        assert!(rank == 1, "invalid rank");
414
415        Self::default()
416    }
417
418    #[inline]
419    fn with_dims<T, F: FnOnce(&[usize]) -> T>(&self, f: F) -> T {
420        f(&[self.0.size()])
421    }
422
423    #[inline]
424    fn with_mut_dims<T, F: FnOnce(&mut [usize]) -> T>(&mut self, f: F) -> T {
425        let mut dims = [self.0.size()];
426        let value = f(&mut dims);
427
428        *self = (X::from_size(dims[0]),);
429
430        value
431    }
432}
433
// Expands to the `Dyn` associated type of a tuple shape: a shape of the same rank
// with every dimension replaced by `Dyn`. `$yz` lists the tuple's dimensions after
// the first; only the number of entries matters.
//
// Stable version: built recursively by prepending `Dyn` to the tail's `Dyn` shape.
#[cfg(not(feature = "nightly"))]
macro_rules! dyn_shape {
    ($($yz:tt),+) => {
        <<Self::Tail as Shape>::Dyn as Shape>::Prepend<Dyn>
    };
}

// Nightly version: spells the tuple out directly. `${ignore($yz)}` drives the
// repetition once per entry of `$yz` without emitting the entry itself.
#[cfg(feature = "nightly")]
macro_rules! dyn_shape {
    ($($yz:tt),+) => {
        (Dyn $(,${ignore($yz)} Dyn)+)
    };
}
447
// Implements `Shape` for a tuple shape of rank `$n`.
//
// Arguments:
// - `$n`: the rank.
// - `($($jk),+)`: tuple field indices of the dimensions after the first (`1..$n`).
// - `($($yz),+)`: type parameters of the dimensions after the first; the first
//   dimension is always `X`.
// - `$reverse`: the tuple type with the dimensions in reverse order.
// - `$prepend`: the shape obtained by prepending a dimension `D`. For rank 6 this
//   is `DynRank`, because the invocations below stop at rank 6.
macro_rules! impl_shape {
    ($n:tt, ($($jk:tt),+), ($($yz:tt),+), $reverse:tt, $prepend:tt) => {
        impl<X: Dim $(,$yz: Dim)+> Shape for (X $(,$yz)+) {
            type Head = X;
            type Tail = ($($yz,)+);

            type Reverse = $reverse;
            type Prepend<D: Dim> = $prepend;

            type Dyn = dyn_shape!($($yz),+);
            // Merge the head dimensions, merge the tails recursively, then recombine.
            type Merge<S: Shape> =
                <<Self::Tail as Shape>::Merge<S::Tail> as Shape>::Prepend<X::Merge<S::Head>>;

            type Buffer<T, A: Allocator> = X::Buffer<T, Self::Tail, A>;
            type Layout<L: Layout> = Strided;

            type Dims<T: Copy + Debug + Default + Hash + Ord + Send + Sync> = [T; $n];

            const RANK: Option<usize> = Some($n);

            #[inline]
            fn new(rank: usize) -> Self {
                assert!(rank == $n, "invalid rank");

                Self::default()
            }

            #[inline]
            fn with_dims<T, F: FnOnce(&[usize]) -> T>(&self, f: F) -> T {
                f(&[self.0.size() $(,self.$jk.size())+])
            }

            // The sizes are copied into a temporary array for `f`. Each dimension is then
            // rebuilt from the (possibly modified) array with its type's `from_size`.
            #[inline]
            fn with_mut_dims<T, F: FnOnce(&mut [usize]) -> T>(&mut self, f: F) -> T {
                let mut dims = [self.0.size() $(,self.$jk.size())+];
                let value = f(&mut dims);

                *self = (X::from_size(dims[0]) $(,$yz::from_size(dims[$jk]))+);

                value
            }
        }
    };
}

// Tuple shapes of rank 2 through 6 (ranks 0 and 1 are implemented by hand above).
impl_shape!(2, (1), (Y), (Y, X), (D, X, Y));
impl_shape!(3, (1, 2), (Y, Z), (Z, Y, X), (D, X, Y, Z));
impl_shape!(4, (1, 2, 3), (Y, Z, W), (W, Z, Y, X), (D, X, Y, Z, W));
impl_shape!(5, (1, 2, 3, 4), (Y, Z, W, U), (U, W, Z, Y, X), (D, X, Y, Z, W, U));
impl_shape!(6, (1, 2, 3, 4, 5), (Y, Z, W, U, V), (V, U, W, Z, Y, X), DynRank);
498
// Implements `ConstShape` for a tuple of `Const` dimensions.
//
// Arguments:
// - `($($xyz),*)`: names of the const generic sizes, one per dimension (may be empty).
// - `$inner`: the nested array type with the same sizes, outermost dimension first.
// - `$with_const`: the buffer type for this shape with one more leading `Const<N>`.
//   Rank 6 uses `DynBuffer`, since prepending to a rank-6 tuple shape gives `DynRank`.
macro_rules! impl_const_shape {
    (($($xyz:tt),*), $inner:ty, $with_const:tt) => {
        impl<$(const $xyz: usize),*> ConstShape for ($(Const<$xyz>,)*) {
            type Inner<T> = $inner;
            type WithConst<T, const N: usize, A: Allocator> =
                $with_const<T, Self::Prepend<Const<N>>, A>;
        }
    };
}

// Constant-sized tuple shapes of rank 0 through 6.
impl_const_shape!((), T, StaticBuffer);
impl_const_shape!((X), [T; X], StaticBuffer);
impl_const_shape!((X, Y), [[T; Y]; X], StaticBuffer);
impl_const_shape!((X, Y, Z), [[[T; Z]; Y]; X], StaticBuffer);
impl_const_shape!((X, Y, Z, W), [[[[T; W]; Z]; Y]; X], StaticBuffer);
impl_const_shape!((X, Y, Z, W, U), [[[[[T; U]; W]; Z]; Y]; X], StaticBuffer);
impl_const_shape!((X, Y, Z, W, U, V), [[[[[[T; V]; U]; W]; Z]; Y]; X], DynBuffer);
516
517impl<S: Shape> IntoShape for S {
518    type IntoShape = S;
519
520    #[inline]
521    fn into_shape(self) -> S {
522        self
523    }
524
525    #[inline]
526    fn into_dims<T, F: FnOnce(&[usize]) -> T>(self, f: F) -> T {
527        self.with_dims(f)
528    }
529}
530
531impl<const N: usize> IntoShape for &[usize; N] {
532    type IntoShape = DynRank;
533
534    #[inline]
535    fn into_shape(self) -> DynRank {
536        Shape::from_dims(self)
537    }
538
539    #[inline]
540    fn into_dims<T, F: FnOnce(&[usize]) -> T>(self, f: F) -> T {
541        f(self)
542    }
543}
544
545impl IntoShape for &[usize] {
546    type IntoShape = DynRank;
547
548    #[inline]
549    fn into_shape(self) -> DynRank {
550        Shape::from_dims(self)
551    }
552
553    #[inline]
554    fn into_dims<T, F: FnOnce(&[usize]) -> T>(self, f: F) -> T {
555        f(self)
556    }
557}
558
559impl IntoShape for Box<[usize]> {
560    type IntoShape = DynRank;
561
562    #[inline]
563    fn into_shape(self) -> DynRank {
564        DynRank::Dyn(self)
565    }
566
567    #[inline]
568    fn into_dims<T, F: FnOnce(&[usize]) -> T>(self, f: F) -> T {
569        f(&self)
570    }
571}
572
573impl<const N: usize> IntoShape for Const<N> {
574    type IntoShape = (Self,);
575
576    #[inline]
577    fn into_shape(self) -> Self::IntoShape {
578        (self,)
579    }
580
581    #[inline]
582    fn into_dims<T, F: FnOnce(&[usize]) -> T>(self, f: F) -> T {
583        f(&[N])
584    }
585}
586
587impl IntoShape for Dyn {
588    type IntoShape = (Self,);
589
590    #[inline]
591    fn into_shape(self) -> Self::IntoShape {
592        (self,)
593    }
594
595    #[inline]
596    fn into_dims<T, F: FnOnce(&[usize]) -> T>(self, f: F) -> T {
597        f(&[self])
598    }
599}
600
601impl IntoShape for Vec<usize> {
602    type IntoShape = DynRank;
603
604    #[inline]
605    fn into_shape(self) -> DynRank {
606        DynRank::Dyn(self.into())
607    }
608
609    #[inline]
610    fn into_dims<T, F: FnOnce(&[usize]) -> T>(self, f: F) -> T {
611        f(&self)
612    }
613}
614
615macro_rules! impl_into_shape {
616    ($n:tt, $shape:ty) => {
617        impl IntoShape for [usize; $n] {
618            type IntoShape = $shape;
619
620            #[inline]
621            fn into_shape(self) -> Self::IntoShape {
622                Shape::from_dims(&self)
623            }
624
625            #[inline]
626            fn into_dims<T, F: FnOnce(&[usize]) -> T>(self, f: F) -> T {
627                f(&self)
628            }
629        }
630    };
631}
632
633impl_into_shape!(0, ());
634impl_into_shape!(1, (Dyn,));
635impl_into_shape!(2, (Dyn, Dyn));
636impl_into_shape!(3, (Dyn, Dyn, Dyn));
637impl_into_shape!(4, (Dyn, Dyn, Dyn, Dyn));
638impl_into_shape!(5, (Dyn, Dyn, Dyn, Dyn, Dyn));
639impl_into_shape!(6, (Dyn, Dyn, Dyn, Dyn, Dyn, Dyn));