mdarray/
dim.rs

1#[cfg(not(feature = "std"))]
2use alloc::boxed::Box;
3#[cfg(not(feature = "std"))]
4use alloc::vec;
5
6use core::fmt::{self, Debug, Formatter};
7use core::hash::Hash;
8
9use crate::shape::Shape;
10use crate::tensor::Tensor;
11use crate::traits::Owned;
12
13/// Array dimension trait.
14pub trait Dim: Copy + Debug + Default + Hash + Ord + Send + Sync {
15    /// Merge dimensions, where constant size is preferred over dynamic.
16    type Merge<D: Dim>: Dim;
17
18    #[doc(hidden)]
19    type Owned<T, S: Shape>: Owned<T, S::Prepend<Self>>;
20
21    /// Dimension size if known statically, or `None` if dynamic.
22    const SIZE: Option<usize>;
23
24    /// Creates an array dimension with the given size.
25    ///
26    /// # Panics
27    ///
28    /// Panics if the size is not matching a constant-sized dimension.
29    fn from_size(size: usize) -> Self;
30
31    /// Returns the number of elements in the dimension.
32    fn size(self) -> usize;
33}
34
/// Storage for a sequence of `T` values that can be viewed as a slice.
///
/// Implemented in this module for the arrays `[T; 0]` through `[T; 6]` and
/// for `Box<[T]>`.
// NOTE(review): `unreachable_pub` is presumably silenced because the trait must
// be `pub` to be named in public bounds while not being exported itself; confirm.
#[allow(unreachable_pub)]
pub trait Dims<T>:
    AsMut<[T]>
    + AsRef<[T]>
    + Clone
    + Debug
    + Default
    + Eq
    + Hash
    + Send
    + Sync
    + for<'a> TryFrom<&'a [T], Error: Debug>
where
    T: Copy + Debug + Default + Eq + Hash + Send + Sync,
{
    /// Creates storage for `len` values.
    ///
    /// The array implementations in this module panic unless `len` equals the
    /// array length; the boxed-slice implementation accepts any `len`.
    fn new(len: usize) -> Self;
}
50
/// Type-level constant: a zero-sized marker that carries the value `N` in its
/// type rather than in a field.
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Const<const N: usize>;

/// Dynamically-sized dimension type, represented as a plain `usize`.
pub type Dyn = usize;
57
58impl<const N: usize> Debug for Const<N> {
59    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
60        f.debug_tuple("Const").field(&N).finish()
61    }
62}
63
64impl<const N: usize> Dim for Const<N> {
65    type Merge<D: Dim> = Self;
66    type Owned<T, S: Shape> = <S::Owned<T> as Owned<T, S>>::WithConst<N>;
67
68    const SIZE: Option<usize> = Some(N);
69
70    #[inline]
71    fn from_size(size: usize) -> Self {
72        assert!(size == N, "invalid size");
73
74        Self
75    }
76
77    #[inline]
78    fn size(self) -> usize {
79        N
80    }
81}
82
83impl Dim for Dyn {
84    type Merge<D: Dim> = D;
85    type Owned<T, S: Shape> = Tensor<T, S::Prepend<Self>>;
86
87    const SIZE: Option<usize> = None;
88
89    #[inline]
90    fn from_size(size: usize) -> Self {
91        size
92    }
93
94    #[inline]
95    fn size(self) -> usize {
96        self
97    }
98}
99
100macro_rules! impl_dims {
101    ($($n:tt),+) => {
102        $(
103            impl<T: Copy + Debug + Default + Eq + Hash + Send + Sync> Dims<T> for [T; $n] {
104                #[inline]
105                fn new(len: usize) -> Self {
106                    assert!(len == $n, "invalid length");
107
108                    Self::default()
109                }
110            }
111        )+
112    };
113}
114
115impl_dims!(0, 1, 2, 3, 4, 5, 6);
116
117impl<T: Copy + Debug + Default + Eq + Hash + Send + Sync> Dims<T> for Box<[T]> {
118    #[inline]
119    fn new(len: usize) -> Self {
120        vec![T::default(); len].into()
121    }
122}
123
124impl<const N: usize> From<Const<N>> for Dyn {
125    #[inline]
126    fn from(_: Const<N>) -> Self {
127        N
128    }
129}
130
131impl<const N: usize> TryFrom<Dyn> for Const<N> {
132    type Error = Dyn;
133
134    #[inline]
135    fn try_from(value: Dyn) -> Result<Self, Self::Error> {
136        if value.size() == N { Ok(Self) } else { Err(value) }
137    }
138}