Skip to main content

mdarray/index/
axis.rs

1use core::fmt::Debug;
2use core::hash::Hash;
3
4use crate::dim::{Const, Dim, Dyn};
5use crate::layout::Layout;
6use crate::mapping::{DenseMapping, Mapping};
7use crate::shape::{DynRank, Shape};
8
9/// Array axis trait, for subarray shapes.
10pub trait Axis: Copy + Debug + Default + Hash + Ord + Send + Sync {
11    /// Corresponding dimension.
12    type Dim<S: Shape>: Dim;
13
14    /// Shape for the previous dimensions excluding the current dimension.
15    type Init<S: Shape>: Shape;
16
17    /// Shape for the next dimensions excluding the current dimension.
18    type Rest<S: Shape>: Shape;
19
20    /// Remove the dimension from the shape.
21    type Remove<S: Shape>: Shape;
22
23    /// Resize the dimension in the shape.
24    type Resize<D: Dim, S: Shape>: Shape;
25
26    /// Returns the dimension index.
27    fn index(self, rank: usize) -> usize;
28
29    #[doc(hidden)]
30    #[inline]
31    fn get<M: Mapping>(
32        self,
33        mapping: &M,
34    ) -> <Keep<Self, M::Shape, M::Layout> as Layout>::Mapping<(Self::Dim<M::Shape>,)> {
35        let index = self.index(mapping.rank());
36
37        Mapping::prepend_dim(&DenseMapping::new(()), mapping.dim(index), mapping.stride(index))
38    }
39
40    #[doc(hidden)]
41    #[inline]
42    fn remove<M: Mapping>(
43        self,
44        mapping: &M,
45    ) -> <Split<Self, M::Shape, M::Layout> as Layout>::Mapping<Self::Remove<M::Shape>> {
46        Mapping::remove_dim::<M>(mapping, self.index(mapping.rank()))
47    }
48
49    #[doc(hidden)]
50    #[inline]
51    fn resize<M: Mapping>(
52        self,
53        mapping: &M,
54        new_size: usize,
55    ) -> <Split<Self, M::Shape, M::Layout> as Layout>::Mapping<Self::Resize<Dyn, M::Shape>> {
56        Mapping::resize_dim::<M>(mapping, self.index(mapping.rank()), new_size)
57    }
58}
59
/// Column axis type, for the second last dimension.
///
/// For a shape of rank `r`, this selects the dimension at index `r - 2`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Cols;

/// Row axis type, for the last dimension.
///
/// For a shape of rank `r`, this selects the dimension at index `r - 1`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Rows;
67
//
// These type aliases are public to improve the documentation, but hidden
// since they are not considered part of the API.
//
72
73#[doc(hidden)]
74pub type Keep<A, S, L> = <<A as Axis>::Rest<S> as Shape>::Layout<L>;
75
76#[doc(hidden)]
77pub type Split<A, S, L> = <<A as Axis>::Init<S> as Shape>::Layout<L>;
78
//
// The tables below give the resulting layout depending on the rank and axis,
// where `L` is the layout of the input mapping and `-` marks an axis that is
// out of range for the rank.
//
// Keep<A, S, L>:
//
// Rank \ Axis  0           1           2           ...         Dyn
// -------------------------------------------------------------------------
// 1            L           -           -           -           Strided
// 2            Strided     L           -           -           Strided
// 3            Strided     Strided     L           -           Strided
// ...
// DynRank      Strided     Strided     Strided     ...         Strided
//
// Split<A, S, L>:
//
// Rank \ Axis  0           1           2           ...         Dyn
// -------------------------------------------------------------------------
// 1            L           -           -           -           Strided
// 2            L           Strided     -           -           Strided
// 3            L           Strided     Strided     -           Strided
// ...
// DynRank      L           Strided     Strided     ...         Strided
//
102
103impl Axis for Const<0> {
104    type Dim<S: Shape> = S::Head;
105
106    type Init<S: Shape> = ();
107    type Rest<S: Shape> = S::Tail;
108
109    type Remove<S: Shape> = S::Tail;
110    type Resize<D: Dim, S: Shape> = <S::Tail as Shape>::Prepend<D>;
111
112    #[inline]
113    fn index(self, rank: usize) -> usize {
114        assert!(rank > 0, "invalid dimension");
115
116        0
117    }
118}
119
120macro_rules! impl_axis {
121    (($($n:tt),*), ($($k:tt),*)) => {
122        $(
123            impl Axis for Const<$n> {
124                type Dim<S: Shape> = <Const<$k> as Axis>::Dim<S::Tail>;
125
126                type Init<S: Shape> =
127                    <<Const<$k> as Axis>::Init<S::Tail> as Shape>::Prepend<S::Head>;
128                type Rest<S: Shape> = <Const<$k> as Axis>::Rest<S::Tail>;
129
130                type Remove<S: Shape> =
131                    <<Const<$k> as Axis>::Remove<S::Tail> as Shape>::Prepend<S::Head>;
132                type Resize<D: Dim, S: Shape> =
133                    <<Const<$k> as Axis>::Resize<D, S::Tail> as Shape>::Prepend<S::Head>;
134
135                #[inline]
136                fn index(self, rank: usize) -> usize {
137                    assert!(rank > $n, "invalid dimension");
138
139                    $n
140                }
141            }
142        )*
143    };
144}
145
146impl_axis!((1, 2, 3, 4, 5), (0, 1, 2, 3, 4));
147
148macro_rules! impl_cols_rows {
149    ($name:tt, $n:tt) => {
150        impl Axis for $name {
151            type Dim<S: Shape> = <Const<$n> as Axis>::Dim<S::Reverse>;
152
153            type Init<S: Shape> = <<Const<$n> as Axis>::Rest<S::Reverse> as Shape>::Reverse;
154            type Rest<S: Shape> = <<Const<$n> as Axis>::Init<S::Reverse> as Shape>::Reverse;
155
156            type Remove<S: Shape> = <<Const<$n> as Axis>::Remove<S::Reverse> as Shape>::Reverse;
157            type Resize<D: Dim, S: Shape> =
158                <<Const<$n> as Axis>::Resize<D, S::Reverse> as Shape>::Reverse;
159
160            #[inline]
161            fn index(self, rank: usize) -> usize {
162                assert!(rank > $n, "invalid dimension");
163
164                rank - $n - 1
165            }
166        }
167    };
168}
169
170impl_cols_rows!(Cols, 1);
171impl_cols_rows!(Rows, 0);
172
173impl Axis for Dyn {
174    type Dim<S: Shape> = Dyn;
175
176    type Init<S: Shape> = DynRank;
177    type Rest<S: Shape> = DynRank;
178
179    type Remove<S: Shape> = <S::Tail as Shape>::Dyn;
180    type Resize<D: Dim, S: Shape> = S::Dyn;
181
182    #[inline]
183    fn index(self, rank: usize) -> usize {
184        assert!(self < rank, "invalid dimension");
185
186        self
187    }
188}