1use core::fmt::Debug;
2use core::hash::Hash;
3
4use crate::dim::{Const, Dim, Dyn};
5use crate::layout::Layout;
6use crate::mapping::{DenseMapping, Mapping};
7use crate::shape::{DynRank, Shape};
8
9pub trait Axis: Copy + Debug + Default + Hash + Ord + Send + Sync {
11 type Dim<S: Shape>: Dim;
13
14 type Init<S: Shape>: Shape;
16
17 type Rest<S: Shape>: Shape;
19
20 type Remove<S: Shape>: Shape;
22
23 type Resize<D: Dim, S: Shape>: Shape;
25
26 fn index(self, rank: usize) -> usize;
28
29 #[doc(hidden)]
30 #[inline]
31 fn get<M: Mapping>(
32 self,
33 mapping: &M,
34 ) -> <Keep<Self, M::Shape, M::Layout> as Layout>::Mapping<(Self::Dim<M::Shape>,)> {
35 let index = self.index(mapping.rank());
36
37 Mapping::prepend_dim(&DenseMapping::new(()), mapping.dim(index), mapping.stride(index))
38 }
39
40 #[doc(hidden)]
41 #[inline]
42 fn remove<M: Mapping>(
43 self,
44 mapping: &M,
45 ) -> <Split<Self, M::Shape, M::Layout> as Layout>::Mapping<Self::Remove<M::Shape>> {
46 Mapping::remove_dim::<M>(mapping, self.index(mapping.rank()))
47 }
48
49 #[doc(hidden)]
50 #[inline]
51 fn resize<M: Mapping>(
52 self,
53 mapping: &M,
54 new_size: usize,
55 ) -> <Split<Self, M::Shape, M::Layout> as Layout>::Mapping<Self::Resize<Dyn, M::Shape>> {
56 Mapping::resize_dim::<M>(mapping, self.index(mapping.rank()), new_size)
57 }
58}
59
/// Marker type for the column axis of an array.
///
/// The position of this axis is counted from the last dimension of the shape (see the
/// `Axis` implementation generated by `impl_cols_rows!`).
#[derive(Clone, Copy, Debug, Default)]
#[derive(Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Cols;

/// Marker type for the row axis of an array.
///
/// The position of this axis is counted from the last dimension of the shape (see the
/// `Axis` implementation generated by `impl_cols_rows!`).
#[derive(Clone, Copy, Debug, Default)]
#[derive(Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Rows;
67
/// Layout used for the single-axis mapping returned by `Axis::get`.
///
/// It is the `Shape::Layout<L>` projection of `A::Rest<S>`, the shape of the dimensions
/// that follow axis `A` in `S`, applied to the parent layout `L`.
#[doc(hidden)]
pub type Keep<A, S, L> = <<A as Axis>::Rest<S> as Shape>::Layout<L>;

/// Layout used for the mappings returned by `Axis::remove` and `Axis::resize`.
///
/// It is the `Shape::Layout<L>` projection of `A::Init<S>`, the shape of the dimensions
/// that precede axis `A` in `S`, applied to the parent layout `L`.
#[doc(hidden)]
pub type Split<A, S, L> = <<A as Axis>::Init<S> as Shape>::Layout<L>;
78
79impl Axis for Const<0> {
104 type Dim<S: Shape> = S::Head;
105
106 type Init<S: Shape> = ();
107 type Rest<S: Shape> = S::Tail;
108
109 type Remove<S: Shape> = S::Tail;
110 type Resize<D: Dim, S: Shape> = <S::Tail as Shape>::Prepend<D>;
111
112 #[inline]
113 fn index(self, rank: usize) -> usize {
114 assert!(rank > 0, "invalid dimension");
115
116 0
117 }
118}
119
120macro_rules! impl_axis {
121 (($($n:tt),*), ($($k:tt),*)) => {
122 $(
123 impl Axis for Const<$n> {
124 type Dim<S: Shape> = <Const<$k> as Axis>::Dim<S::Tail>;
125
126 type Init<S: Shape> =
127 <<Const<$k> as Axis>::Init<S::Tail> as Shape>::Prepend<S::Head>;
128 type Rest<S: Shape> = <Const<$k> as Axis>::Rest<S::Tail>;
129
130 type Remove<S: Shape> =
131 <<Const<$k> as Axis>::Remove<S::Tail> as Shape>::Prepend<S::Head>;
132 type Resize<D: Dim, S: Shape> =
133 <<Const<$k> as Axis>::Resize<D, S::Tail> as Shape>::Prepend<S::Head>;
134
135 #[inline]
136 fn index(self, rank: usize) -> usize {
137 assert!(rank > $n, "invalid dimension");
138
139 $n
140 }
141 }
142 )*
143 };
144}
145
146impl_axis!((1, 2, 3, 4, 5), (0, 1, 2, 3, 4));
147
/// Implements `Axis` for a named axis marker located `$from_end` positions before the last
/// dimension (`0` selects the last dimension itself).
///
/// The shape is reversed, the work is delegated to `Const<$from_end>`, and the result is
/// reversed back.
macro_rules! impl_cols_rows {
    ($name:tt, $from_end:tt) => {
        impl Axis for $name {
            type Dim<S: Shape> = <Const<$from_end> as Axis>::Dim<S::Reverse>;

            // Reversing the shape swaps "before" and "after", so `Init` comes from the
            // reversed `Rest` and `Rest` comes from the reversed `Init`.
            type Init<S: Shape> =
                <<Const<$from_end> as Axis>::Rest<S::Reverse> as Shape>::Reverse;
            type Rest<S: Shape> =
                <<Const<$from_end> as Axis>::Init<S::Reverse> as Shape>::Reverse;

            type Remove<S: Shape> =
                <<Const<$from_end> as Axis>::Remove<S::Reverse> as Shape>::Reverse;
            type Resize<D: Dim, S: Shape> =
                <<Const<$from_end> as Axis>::Resize<D, S::Reverse> as Shape>::Reverse;

            #[inline]
            fn index(self, rank: usize) -> usize {
                if rank <= $from_end {
                    panic!("invalid dimension");
                }

                // Convert the position counted from the end into one counted from the start.
                rank - 1 - $from_end
            }
        }
    };
}
169
170impl_cols_rows!(Cols, 1);
171impl_cols_rows!(Rows, 0);
172
173impl Axis for Dyn {
174 type Dim<S: Shape> = Dyn;
175
176 type Init<S: Shape> = DynRank;
177 type Rest<S: Shape> = DynRank;
178
179 type Remove<S: Shape> = <S::Tail as Shape>::Dyn;
180 type Resize<D: Dim, S: Shape> = S::Dyn;
181
182 #[inline]
183 fn index(self, rank: usize) -> usize {
184 assert!(self < rank, "invalid dimension");
185
186 self
187 }
188}