rstsr_common/layout/
dim.rs

1use crate::prelude_dev::*;
2use core::ops::IndexMut;
3
/* #region basic definitions */

/// Fixed size dimension: a shape whose number of axes `N` is known at compile
/// time, stored inline as an array of `N` axis lengths.
pub type Ix<const N: usize> = [usize; N];
/// Fixed size dimension with 0 axes.
pub type Ix0 = Ix<0>;
/// Fixed size dimension with 1 axis.
pub type Ix1 = Ix<1>;
/// Fixed size dimension with 2 axes.
pub type Ix2 = Ix<2>;
/// Fixed size dimension with 3 axes.
pub type Ix3 = Ix<3>;
/// Fixed size dimension with 4 axes.
pub type Ix4 = Ix<4>;
/// Fixed size dimension with 5 axes.
pub type Ix5 = Ix<5>;
/// Fixed size dimension with 6 axes.
pub type Ix6 = Ix<6>;
/// Fixed size dimension with 7 axes.
pub type Ix7 = Ix<7>;
/// Fixed size dimension with 8 axes.
pub type Ix8 = Ix<8>;
/// Fixed size dimension with 9 axes.
pub type Ix9 = Ix<9>;
/// Dynamic size dimension: the number of axes is only known at runtime, and
/// the axis lengths are stored in a `Vec`.
pub type IxD = Vec<usize>;

/// Dynamic size dimension (alias of [`IxD`]).
pub type IxDyn = IxD;
22
/// Common interface of the dimension (shape) types in this module.
///
/// It is implemented for the fixed-size arrays [`Ix`] and for the dynamic
/// vector [`IxD`]. A dimension holds one `usize` length per axis. Its
/// associated [`Stride`](DimBaseAPI::Stride) type holds one `isize` value per
/// axis.
pub trait DimBaseAPI:
    AsMut<[usize]>
    + AsRef<[usize]>
    + IndexMut<usize, Output = usize>
    + Debug
    + PartialEq
    + Clone
    + TryFrom<Vec<usize>>
    + Into<Vec<usize>>
    + Send
    + Sync
    + PartialOrd
{
    /// Stride type paired with this dimension type (one `isize` per axis).
    type Stride: AsMut<[isize]>
        + AsRef<[isize]>
        + IndexMut<usize, Output = isize>
        + Debug
        + PartialEq
        + Clone
        + TryFrom<Vec<isize>>
        + Into<Vec<isize>>;

    /// Number of dimensions (axes) of `self`.
    fn ndim(&self) -> usize;

    /// Number of dimensions if it is fixed by the type (`Some(N)` for static
    /// dimensions), or `None` for dynamic dimensions.
    fn const_ndim() -> Option<usize>;

    /// Creates a new shape with the same number of dimensions as `self`.
    fn new_shape(&self) -> Self;

    /// Creates a new stride with the same number of dimensions as `self`.
    fn new_stride(&self) -> Self::Stride;
}
58
59impl<const N: usize> DimBaseAPI for Ix<N> {
60    type Stride = [isize; N];
61
62    #[inline]
63    fn ndim(&self) -> usize {
64        N
65    }
66
67    #[inline]
68    fn const_ndim() -> Option<usize> {
69        Some(N)
70    }
71
72    #[inline]
73    fn new_shape(&self) -> Self {
74        [0; N]
75    }
76
77    #[inline]
78    fn new_stride(&self) -> Self::Stride {
79        [0; N]
80    }
81}
82
83impl DimBaseAPI for IxD {
84    type Stride = Vec<isize>;
85
86    #[inline]
87    fn ndim(&self) -> usize {
88        self.len()
89    }
90
91    #[inline]
92    fn const_ndim() -> Option<usize> {
93        None
94    }
95
96    #[inline]
97    fn new_shape(&self) -> Self {
98        vec![0; self.len()]
99    }
100
101    #[inline]
102    fn new_stride(&self) -> Self::Stride {
103        vec![0; self.len()]
104    }
105}
106
107/* #endregion */
108
109/* #region dimension relative eq */
110
111// Trait for defining smaller dimension by one.
112#[doc(hidden)]
113pub trait DimSmallerOneAPI: DimBaseAPI {
114    type SmallerOne: DimBaseAPI;
115}
116
117// Trait for defining larger dimension by one.
118#[doc(hidden)]
119pub trait DimLargerOneAPI: DimBaseAPI {
120    type LargerOne: DimBaseAPI;
121}
122
123impl DimSmallerOneAPI for IxD {
124    type SmallerOne = IxD;
125}
126
127impl DimLargerOneAPI for IxD {
128    type LargerOne = IxD;
129}
130
131macro_rules! impl_dim_smaller_one {
132    ($(($N:literal, $N1:literal)),*) => {
133        $(
134            impl DimSmallerOneAPI for Ix<$N> {
135                type SmallerOne = Ix<$N1>;
136            }
137        )*
138    };
139}
140
141impl_dim_smaller_one!((1, 0), (2, 1), (3, 2), (4, 3), (5, 4), (6, 5), (7, 6), (8, 7), (9, 8));
142
143macro_rules! impl_dim_larger_one {
144    ($(($N:literal, $N1:literal)),*) => {
145        $(
146            impl DimLargerOneAPI for Ix<$N> {
147                type LargerOne = Ix<$N1>;
148            }
149        )*
150    };
151}
152
153impl_dim_larger_one!((0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9));
154
155/// Trait for comparing two dimensions.
156///
157/// This trait is used to broadcast two tensors.
158#[doc(hidden)]
159pub trait DimMaxAPI<D2>
160where
161    D2: DimBaseAPI,
162{
163    // This type will be used in many cases outside layout module.
164    // So we use `DimAPI` instead of `DimBaseAPI`, being convenient for
165    // functions outside this module.
166    type Max: DimAPI;
167}
168
169impl DimMaxAPI<IxD> for IxD {
170    type Max = IxD;
171}
172
173macro_rules! impl_dim_max_dyn {
174    ($($N:literal),*) => {
175        $(
176            impl DimMaxAPI<IxD> for Ix<$N> {
177                type Max = IxD;
178            }
179
180            impl DimMaxAPI<Ix<$N>> for IxD {
181                type Max = IxD;
182            }
183        )*
184    };
185}
186
187impl_dim_max_dyn!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
188
// Implements `DimMaxAPI` between two static dimensions.
// Each triple `(rhs, lhs, max)` expands to
// `impl DimMaxAPI<Ix<rhs>> for Ix<lhs> { type Max = Ix<max>; }`.
// The macro cannot compute the maximum itself, so each caller-supplied
// `max` must really be the larger of `rhs` and `lhs`.
macro_rules! impl_dim_max {
    ($(($rhs:literal, $lhs:literal, $max:literal)),*) => {
        $(
            impl DimMaxAPI<Ix<$rhs>> for Ix<$lhs> {
                type Max = Ix<$max>;
            }
        )*
    };
}
198
199impl_dim_max!(
200    (0, 0, 0),
201    (0, 1, 1),
202    (0, 2, 2),
203    (0, 3, 3),
204    (0, 4, 4),
205    (0, 5, 5),
206    (0, 6, 6),
207    (0, 7, 7),
208    (0, 8, 8),
209    (0, 9, 9)
210);
211impl_dim_max!(
212    (1, 0, 1),
213    (1, 1, 1),
214    (1, 2, 2),
215    (1, 3, 3),
216    (1, 4, 4),
217    (1, 5, 5),
218    (1, 6, 6),
219    (1, 7, 7),
220    (1, 8, 8),
221    (1, 9, 9)
222);
223impl_dim_max!(
224    (2, 0, 2),
225    (2, 1, 2),
226    (2, 2, 2),
227    (2, 3, 3),
228    (2, 4, 4),
229    (2, 5, 5),
230    (2, 6, 6),
231    (2, 7, 7),
232    (2, 8, 8),
233    (2, 9, 9)
234);
235impl_dim_max!(
236    (3, 0, 3),
237    (3, 1, 3),
238    (3, 2, 3),
239    (3, 3, 3),
240    (3, 4, 4),
241    (3, 5, 5),
242    (3, 6, 6),
243    (3, 7, 7),
244    (3, 8, 8),
245    (3, 9, 9)
246);
247impl_dim_max!(
248    (4, 0, 4),
249    (4, 1, 4),
250    (4, 2, 4),
251    (4, 3, 4),
252    (4, 4, 4),
253    (4, 5, 5),
254    (4, 6, 6),
255    (4, 7, 7),
256    (4, 8, 8),
257    (4, 9, 9)
258);
259impl_dim_max!(
260    (5, 0, 5),
261    (5, 1, 5),
262    (5, 2, 5),
263    (5, 3, 5),
264    (5, 4, 5),
265    (5, 5, 5),
266    (5, 6, 6),
267    (5, 7, 7),
268    (5, 8, 8),
269    (5, 9, 9)
270);
271impl_dim_max!(
272    (6, 0, 6),
273    (6, 1, 6),
274    (6, 2, 6),
275    (6, 3, 6),
276    (6, 4, 6),
277    (6, 5, 6),
278    (6, 6, 6),
279    (6, 7, 7),
280    (6, 8, 8),
281    (6, 9, 9)
282);
283impl_dim_max!(
284    (7, 0, 7),
285    (7, 1, 7),
286    (7, 2, 7),
287    (7, 3, 7),
288    (7, 4, 7),
289    (7, 5, 7),
290    (7, 6, 7),
291    (7, 7, 7),
292    (7, 8, 8),
293    (7, 9, 9)
294);
295impl_dim_max!(
296    (8, 0, 8),
297    (8, 1, 8),
298    (8, 2, 8),
299    (8, 3, 8),
300    (8, 4, 8),
301    (8, 5, 8),
302    (8, 6, 8),
303    (8, 7, 8),
304    (8, 8, 8),
305    (8, 9, 9)
306);
307impl_dim_max!(
308    (9, 0, 9),
309    (9, 1, 9),
310    (9, 2, 9),
311    (9, 3, 9),
312    (9, 4, 9),
313    (9, 5, 9),
314    (9, 6, 9),
315    (9, 7, 9),
316    (9, 8, 9),
317    (9, 9, 9)
318);
319
320/* #endregion */