1#[cfg(not(feature = "std"))]
2use alloc::boxed::Box;
3#[cfg(not(feature = "std"))]
4use alloc::vec;
5
6use core::fmt::{self, Debug, Formatter};
7use core::hash::Hash;
8
9use crate::shape::Shape;
10use crate::tensor::Tensor;
11use crate::traits::Owned;
12
13pub trait Dim: Copy + Debug + Default + Hash + Ord + Send + Sync {
15 type Merge<D: Dim>: Dim;
17
18 #[doc(hidden)]
19 type Owned<T, S: Shape>: Owned<T, S::Prepend<Self>>;
20
21 const SIZE: Option<usize>;
23
24 fn from_size(size: usize) -> Self;
30
31 fn size(self) -> usize;
33}
34
/// Storage for a list of per-dimension values of type `T`.
///
/// Implementors must be viewable as a slice of `T` (shared and mutable) and must be
/// constructible from a slice through `TryFrom`. This module implements the trait for
/// the arrays `[T; 0]` through `[T; 6]` and for `Box<[T]>`.
// NOTE(review): the reason for silencing `unreachable_pub` is not visible in this file;
// presumably the trait has to be `pub` because it appears in public signatures - confirm.
#[allow(unreachable_pub)]
pub trait Dims<T>:
    AsMut<[T]>
    + AsRef<[T]>
    + Clone
    + Debug
    + Default
    + Eq
    + Hash
    + Send
    + Sync
    + for<'a> TryFrom<&'a [T], Error: Debug>
where
    T: Copy + Debug + Default + Eq + Hash + Send + Sync,
{
    /// Creates storage for `len` values.
    ///
    /// # Panics
    ///
    /// The fixed-size array implementations in this module panic when `len` differs
    /// from the array length.
    fn new(len: usize) -> Self;
}
50
/// Dimension whose size `N` is part of the type.
///
/// The struct has no fields, so all values of a given `Const<N>` compare equal.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Const<const N: usize>;
54
/// Dimension whose size is only known at run time; an alias for `usize`, so the value
/// itself is the size.
pub type Dyn = usize;
57
58impl<const N: usize> Debug for Const<N> {
59 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
60 f.debug_tuple("Const").field(&N).finish()
61 }
62}
63
64impl<const N: usize> Dim for Const<N> {
65 type Merge<D: Dim> = Self;
66 type Owned<T, S: Shape> = <S::Owned<T> as Owned<T, S>>::WithConst<N>;
67
68 const SIZE: Option<usize> = Some(N);
69
70 #[inline]
71 fn from_size(size: usize) -> Self {
72 assert!(size == N, "invalid size");
73
74 Self
75 }
76
77 #[inline]
78 fn size(self) -> usize {
79 N
80 }
81}
82
83impl Dim for Dyn {
84 type Merge<D: Dim> = D;
85 type Owned<T, S: Shape> = Tensor<T, S::Prepend<Self>>;
86
87 const SIZE: Option<usize> = None;
88
89 #[inline]
90 fn from_size(size: usize) -> Self {
91 size
92 }
93
94 #[inline]
95 fn size(self) -> usize {
96 self
97 }
98}
99
100macro_rules! impl_dims {
101 ($($n:tt),+) => {
102 $(
103 impl<T: Copy + Debug + Default + Eq + Hash + Send + Sync> Dims<T> for [T; $n] {
104 #[inline]
105 fn new(len: usize) -> Self {
106 assert!(len == $n, "invalid length");
107
108 Self::default()
109 }
110 }
111 )+
112 };
113}
114
115impl_dims!(0, 1, 2, 3, 4, 5, 6);
116
117impl<T: Copy + Debug + Default + Eq + Hash + Send + Sync> Dims<T> for Box<[T]> {
118 #[inline]
119 fn new(len: usize) -> Self {
120 vec![T::default(); len].into()
121 }
122}
123
124impl<const N: usize> From<Const<N>> for Dyn {
125 #[inline]
126 fn from(_: Const<N>) -> Self {
127 N
128 }
129}
130
131impl<const N: usize> TryFrom<Dyn> for Const<N> {
132 type Error = Dyn;
133
134 #[inline]
135 fn try_from(value: Dyn) -> Result<Self, Self::Error> {
136 if value.size() == N { Ok(Self) } else { Err(value) }
137 }
138}