mdarray/index/
axis.rs

1use core::fmt::Debug;
2use core::hash::Hash;
3
4use crate::dim::{Const, Dim, Dyn};
5use crate::layout::Layout;
6use crate::mapping::{DenseMapping, Mapping};
7use crate::shape::{DynRank, Shape};
8
9/// Array axis trait, for subarray shapes.
10pub trait Axis: Copy + Debug + Default + Hash + Ord + Send + Sync {
11    /// Corresponding dimension.
12    type Dim<S: Shape>: Dim;
13
14    /// Shape for the previous dimensions excluding the current dimension.
15    type Init<S: Shape>: Shape;
16
17    /// Shape for the next dimensions excluding the current dimension.
18    type Rest<S: Shape>: Shape;
19
20    /// Remove the dimension from the shape.
21    type Remove<S: Shape>: Shape;
22
23    /// Insert the dimension into the shape.
24    type Insert<D: Dim, S: Shape>: Shape;
25
26    /// Returns the dimension index.
27    fn index(self, rank: usize) -> usize;
28
29    #[doc(hidden)]
30    fn get<M: Mapping>(
31        self,
32        mapping: &M,
33    ) -> <Keep<Self, M::Shape, M::Layout> as Layout>::Mapping<(Self::Dim<M::Shape>,)> {
34        let index = self.index(mapping.rank());
35
36        Mapping::prepend_dim(&DenseMapping::new(()), mapping.dim(index), mapping.stride(index))
37    }
38
39    #[doc(hidden)]
40    fn remove<M: Mapping>(
41        self,
42        mapping: &M,
43    ) -> <Split<Self, M::Shape, M::Layout> as Layout>::Mapping<Self::Remove<M::Shape>> {
44        Mapping::remove_dim::<M>(mapping, self.index(mapping.rank()))
45    }
46
47    #[doc(hidden)]
48    fn resize<M: Mapping>(
49        self,
50        mapping: &M,
51        new_size: usize,
52    ) -> <Split<Self, M::Shape, M::Layout> as Layout>::Mapping<Resize<Self, M::Shape>> {
53        Mapping::resize_dim::<M>(mapping, self.index(mapping.rank()), new_size)
54    }
55}
56
/// Column axis type, for the second last dimension.
///
/// For an array of rank `r`, this axis refers to dimension index `r - 2`.
#[derive(Clone, Copy, Debug, Default)]
#[derive(Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Cols;
60
/// Row axis type, for the last dimension.
///
/// For an array of rank `r`, this axis refers to dimension index `r - 1`.
#[derive(Clone, Copy, Debug, Default)]
#[derive(Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Rows;
64
65//
66// These types are public to improve documentation, but hidden since
67// they are not considered part of the API.
68//
69
70#[doc(hidden)]
71pub type Resize<A, S> = <A as Axis>::Insert<Dyn, <A as Axis>::Remove<S>>;
72
73#[doc(hidden)]
74pub type Keep<A, S, L> = <<A as Axis>::Rest<S> as Shape>::Layout<L>;
75
76#[doc(hidden)]
77pub type Split<A, S, L> = <<A as Axis>::Init<S> as Shape>::Layout<L>;
78
79//
80// The tables below give the resulting layout depending on the rank and axis.
81//
82// Keep<A, S, L>:
83//
84// Rank \ Axis  0           1           2           ...         Dyn
85// -------------------------------------------------------------------------
86// 1            L           -           -           -           Strided
87// 2            Strided     L           -           -           Strided
88// 3            Strided     Strided     L           -           Strided
89// ...
90// DynRank      Strided     Strided     Strided     ...         Strided
91//
92// Split<A, S, L>:
93//
94// Rank \ Axis  0           1           2           ...         Dyn
95// -------------------------------------------------------------------------
96// 1            L           -           -           -           Strided
97// 2            L           Strided     -           -           Strided
98// 3            L           Strided     Strided     -           Strided
99// ...
100// DynRank      L           Strided     Strided     ...         Strided
101//
102
103impl Axis for Const<0> {
104    type Dim<S: Shape> = S::Head;
105
106    type Init<S: Shape> = ();
107    type Rest<S: Shape> = S::Tail;
108
109    type Remove<S: Shape> = S::Tail;
110    type Insert<D: Dim, S: Shape> = S::Prepend<D>;
111
112    fn index(self, rank: usize) -> usize {
113        assert!(rank > 0, "invalid dimension");
114
115        0
116    }
117}
118
119macro_rules! impl_axis {
120    (($($n:tt),*), ($($k:tt),*)) => {
121        $(
122            impl Axis for Const<$n> {
123                type Dim<S: Shape> = <Const<$k> as Axis>::Dim<S::Tail>;
124
125                type Init<S: Shape> =
126                    <<Const<$k> as Axis>::Init<S::Tail> as Shape>::Prepend<S::Head>;
127                type Rest<S: Shape> = <Const<$k> as Axis>::Rest<S::Tail>;
128
129                type Remove<S: Shape> =
130                    <<Const<$k> as Axis>::Remove<S::Tail> as Shape>::Prepend<S::Head>;
131                type Insert<D: Dim, S: Shape> =
132                    <<Const<$k> as Axis>::Insert<D, S::Tail> as Shape>::Prepend<S::Head>;
133
134                fn index(self, rank: usize) -> usize {
135                    assert!(rank > $n, "invalid dimension");
136
137                    $n
138                }
139            }
140        )*
141    };
142}
143
144impl_axis!((1, 2, 3, 4, 5), (0, 1, 2, 3, 4));
145
146macro_rules! impl_cols_rows {
147    ($name:tt, $n:tt) => {
148        impl Axis for $name {
149            type Dim<S: Shape> = <Const<$n> as Axis>::Dim<S::Reverse>;
150
151            type Init<S: Shape> = <<Const<$n> as Axis>::Rest<S::Reverse> as Shape>::Reverse;
152            type Rest<S: Shape> = <<Const<$n> as Axis>::Init<S::Reverse> as Shape>::Reverse;
153
154            type Remove<S: Shape> = <<Const<$n> as Axis>::Remove<S::Reverse> as Shape>::Reverse;
155            type Insert<D: Dim, S: Shape> =
156                <<Const<$n> as Axis>::Insert<D, S::Reverse> as Shape>::Reverse;
157
158            fn index(self, rank: usize) -> usize {
159                assert!(rank > $n, "invalid dimension");
160
161                rank - $n - 1
162            }
163        }
164    };
165}
166
167impl_cols_rows!(Cols, 1);
168impl_cols_rows!(Rows, 0);
169
170impl Axis for Dyn {
171    type Dim<S: Shape> = Dyn;
172
173    type Init<S: Shape> = DynRank;
174    type Rest<S: Shape> = DynRank;
175
176    type Remove<S: Shape> = <S::Tail as Shape>::Dyn;
177    type Insert<D: Dim, S: Shape> = <S::Dyn as Shape>::Prepend<Dyn>;
178
179    fn index(self, rank: usize) -> usize {
180        assert!(self < rank, "invalid dimension");
181
182        self
183    }
184}