1use core::fmt::Debug;
2use core::hash::Hash;
3
4use crate::dim::{Const, Dim, Dyn};
5use crate::layout::Layout;
6use crate::mapping::{DenseMapping, Mapping};
7use crate::shape::{DynRank, Shape};
8
9pub trait Axis: Copy + Debug + Default + Hash + Ord + Send + Sync {
11 type Dim<S: Shape>: Dim;
13
14 type Init<S: Shape>: Shape;
16
17 type Rest<S: Shape>: Shape;
19
20 type Remove<S: Shape>: Shape;
22
23 type Insert<D: Dim, S: Shape>: Shape;
25
26 fn index(self, rank: usize) -> usize;
28
29 #[doc(hidden)]
30 fn get<M: Mapping>(
31 self,
32 mapping: &M,
33 ) -> <Keep<Self, M::Shape, M::Layout> as Layout>::Mapping<(Self::Dim<M::Shape>,)> {
34 let index = self.index(mapping.rank());
35
36 Mapping::prepend_dim(&DenseMapping::new(()), mapping.dim(index), mapping.stride(index))
37 }
38
39 #[doc(hidden)]
40 fn remove<M: Mapping>(
41 self,
42 mapping: &M,
43 ) -> <Split<Self, M::Shape, M::Layout> as Layout>::Mapping<Self::Remove<M::Shape>> {
44 Mapping::remove_dim::<M>(mapping, self.index(mapping.rank()))
45 }
46
47 #[doc(hidden)]
48 fn resize<M: Mapping>(
49 self,
50 mapping: &M,
51 new_size: usize,
52 ) -> <Split<Self, M::Shape, M::Layout> as Layout>::Mapping<Resize<Self, M::Shape>> {
53 Mapping::resize_dim::<M>(mapping, self.index(mapping.rank()), new_size)
54 }
55}
56
/// Column axis type.
///
/// Selects the second-to-last dimension of a shape (axis index `rank - 2`), so it requires
/// an array of rank at least 2.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Cols;
60
/// Row axis type.
///
/// Selects the last dimension of a shape (axis index `rank - 1`), so it requires an array
/// of rank at least 1.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Rows;
64
/// Shape obtained from `S` by removing axis `A` and inserting a `Dyn` dimension in its
/// place, i.e. the shape type after resizing along `A`.
#[doc(hidden)]
pub type Resize<A, S> = <A as Axis>::Insert<Dyn, <A as Axis>::Remove<S>>;
72
/// Layout of the rank-1 mapping returned by `Axis::get`.
///
/// It is derived from the original layout `L` through the shape of the dimensions that
/// follow axis `A` (`Axis::Rest`); the exact rule is defined by `Shape::Layout`.
#[doc(hidden)]
pub type Keep<A, S, L> = <<A as Axis>::Rest<S> as Shape>::Layout<L>;
75
/// Layout of the mappings returned by `Axis::remove` and `Axis::resize`.
///
/// It is derived from the original layout `L` through the shape of the dimensions that
/// precede axis `A` (`Axis::Init`); the exact rule is defined by `Shape::Layout`.
#[doc(hidden)]
pub type Split<A, S, L> = <<A as Axis>::Init<S> as Shape>::Layout<L>;
78
79impl Axis for Const<0> {
104 type Dim<S: Shape> = S::Head;
105
106 type Init<S: Shape> = ();
107 type Rest<S: Shape> = S::Tail;
108
109 type Remove<S: Shape> = S::Tail;
110 type Insert<D: Dim, S: Shape> = S::Prepend<D>;
111
112 fn index(self, rank: usize) -> usize {
113 assert!(rank > 0, "invalid dimension");
114
115 0
116 }
117}
118
119macro_rules! impl_axis {
120 (($($n:tt),*), ($($k:tt),*)) => {
121 $(
122 impl Axis for Const<$n> {
123 type Dim<S: Shape> = <Const<$k> as Axis>::Dim<S::Tail>;
124
125 type Init<S: Shape> =
126 <<Const<$k> as Axis>::Init<S::Tail> as Shape>::Prepend<S::Head>;
127 type Rest<S: Shape> = <Const<$k> as Axis>::Rest<S::Tail>;
128
129 type Remove<S: Shape> =
130 <<Const<$k> as Axis>::Remove<S::Tail> as Shape>::Prepend<S::Head>;
131 type Insert<D: Dim, S: Shape> =
132 <<Const<$k> as Axis>::Insert<D, S::Tail> as Shape>::Prepend<S::Head>;
133
134 fn index(self, rank: usize) -> usize {
135 assert!(rank > $n, "invalid dimension");
136
137 $n
138 }
139 }
140 )*
141 };
142}
143
144impl_axis!((1, 2, 3, 4, 5), (0, 1, 2, 3, 4));
145
146macro_rules! impl_cols_rows {
147 ($name:tt, $n:tt) => {
148 impl Axis for $name {
149 type Dim<S: Shape> = <Const<$n> as Axis>::Dim<S::Reverse>;
150
151 type Init<S: Shape> = <<Const<$n> as Axis>::Rest<S::Reverse> as Shape>::Reverse;
152 type Rest<S: Shape> = <<Const<$n> as Axis>::Init<S::Reverse> as Shape>::Reverse;
153
154 type Remove<S: Shape> = <<Const<$n> as Axis>::Remove<S::Reverse> as Shape>::Reverse;
155 type Insert<D: Dim, S: Shape> =
156 <<Const<$n> as Axis>::Insert<D, S::Reverse> as Shape>::Reverse;
157
158 fn index(self, rank: usize) -> usize {
159 assert!(rank > $n, "invalid dimension");
160
161 rank - $n - 1
162 }
163 }
164 };
165}
166
167impl_cols_rows!(Cols, 1);
168impl_cols_rows!(Rows, 0);
169
170impl Axis for Dyn {
171 type Dim<S: Shape> = Dyn;
172
173 type Init<S: Shape> = DynRank;
174 type Rest<S: Shape> = DynRank;
175
176 type Remove<S: Shape> = <S::Tail as Shape>::Dyn;
177 type Insert<D: Dim, S: Shape> = <S::Dyn as Shape>::Prepend<Dyn>;
178
179 fn index(self, rank: usize) -> usize {
180 assert!(self < rank, "invalid dimension");
181
182 self
183 }
184}