use crate::dim::{Dim, Dyn};
use crate::layout::{Flat, Layout};
use crate::mapping::{DenseMapping, Mapping};
use crate::shape::Shape;
/// Type-level selector for one dimension of an array shape.
///
/// The associated types describe how a shape and its layout change when the selected
/// dimension is kept on its own, removed, inserted, replaced or resized. The hidden methods
/// apply the matching transformation to a [`Mapping`].
pub trait Axis {
    /// Type of the selected dimension in shape `S`.
    type Dim<S: Shape>: Dim;

    /// Shape `S` with the selected dimension taken out.
    type Other<S: Shape>: Shape;

    /// Shape `S` with dimension `D` inserted at the axis position.
    type Insert<D: Dim, S: Shape>: Shape;

    /// Shape `S` with the selected dimension replaced by `D`.
    type Replace<D: Dim, S: Shape>: Shape;

    /// Layout of the mapping returned by `keep`, for shape `S` with layout `L`.
    type Keep<S: Shape, L: Layout>: Layout;

    /// Layout of the mapping returned by `remove`, for shape `S` with layout `L`.
    type Remove<S: Shape, L: Layout>: Layout;

    /// Layout of the mapping returned by `resize`, for shape `S` with layout `L`.
    type Resize<S: Shape, L: Layout>: Layout;

    /// Returns the position of the selected dimension in a shape of rank `rank`.
    ///
    /// The implementations in this module panic with "invalid dimension" when a shape of
    /// that rank has no such dimension.
    #[doc(hidden)]
    fn index(rank: usize) -> usize;

    /// Builds a mapping holding only the selected dimension.
    ///
    /// The mapping starts from an empty-shape `DenseMapping` and adds the size and stride of
    /// the selected dimension of `mapping`.
    #[doc(hidden)]
    fn keep<M: Mapping>(
        mapping: M,
    ) -> <Self::Keep<M::Shape, M::Layout> as Layout>::Mapping<Self::Dim<M::Shape>> {
        let axis = Self::index(M::Shape::RANK);
        let empty = DenseMapping::new(());

        Mapping::add_dim(empty, mapping.dim(axis), mapping.stride(axis))
    }

    /// Drops the selected dimension from `mapping` via `Mapping::remove_dim`.
    #[doc(hidden)]
    fn remove<M: Mapping>(
        mapping: M,
    ) -> <Self::Remove<M::Shape, M::Layout> as Layout>::Mapping<Self::Other<M::Shape>> {
        let axis = Self::index(M::Shape::RANK);

        Mapping::remove_dim::<M>(mapping, axis)
    }

    /// Passes `new_size` for the selected dimension to `Mapping::resize_dim`.
    ///
    /// The selected dimension becomes `Dyn` in the resulting shape.
    #[doc(hidden)]
    fn resize<M: Mapping>(
        mapping: M,
        new_size: usize,
    ) -> <Self::Resize<M::Shape, M::Layout> as Layout>::Mapping<Self::Replace<Dyn, M::Shape>> {
        let axis = Self::index(M::Shape::RANK);

        Mapping::resize_dim::<M>(mapping, axis, new_size)
    }
}
/// Axis selecting dimension `N` of a shape, counting from the first dimension.
///
/// `Axis` is implemented for `N` from 0 through 5; using it on a shape whose rank is not
/// greater than `N` panics with "invalid dimension".
pub struct Inner<const N: usize>;
/// Axis selecting the last dimension of a shape, i.e. index `rank - 1`.
///
/// Using it on a rank-0 shape panics with "invalid dimension".
pub struct Outer;
// Axis for the first dimension (index 0) of a shape: works directly on the shape's head.
impl Axis for Inner<0> {
    type Dim<S: Shape> = <S as Shape>::Head;
    type Other<S: Shape> = <S as Shape>::Tail;
    type Insert<D: Dim, S: Shape> = <S as Shape>::Prepend<D>;
    type Replace<D: Dim, S: Shape> = <<S as Shape>::Tail as Shape>::Prepend<D>;
    type Keep<S: Shape, L: Layout> = <L as Layout>::Uniform;
    type Remove<S: Shape, L: Layout> =
        <<S as Shape>::Tail as Shape>::Layout<Flat, <L as Layout>::NonUnitStrided>;
    type Resize<S: Shape, L: Layout> = <S as Shape>::Layout<L, <L as Layout>::NonUniform>;

    fn index(rank: usize) -> usize {
        // A rank-0 shape has no first dimension.
        if rank == 0 {
            panic!("invalid dimension");
        }
        0
    }
}
impl Axis for Inner<1> {
type Dim<S: Shape> = <S::Tail as Shape>::Head;
type Other<S: Shape> = <<S::Tail as Shape>::Tail as Shape>::Prepend<S::Head>;
type Insert<D: Dim, S: Shape> = <<S::Tail as Shape>::Prepend<D> as Shape>::Prepend<S::Head>;
type Replace<D: Dim, S: Shape> =
<<<S::Tail as Shape>::Tail as Shape>::Prepend<D> as Shape>::Prepend<S::Head>;
type Keep<S: Shape, L: Layout> = Flat;
type Remove<S: Shape, L: Layout> = <S::Tail as Shape>::Layout<L::Uniform, L::NonUniform>;
type Resize<S: Shape, L: Layout> = <S::Tail as Shape>::Layout<L, L::NonUniform>;
fn index(rank: usize) -> usize {
assert!(rank > 1, "invalid dimension");
1
}
}
macro_rules! impl_axis {
(($($n:tt),*), ($($k:tt),*)) => {
$(
impl Axis for Inner<$n> {
type Dim<S: Shape> = <Inner<$k> as Axis>::Dim<S::Tail>;
type Other<S: Shape> =
<<Inner<$k> as Axis>::Other<S::Tail> as Shape>::Prepend<S::Head>;
type Insert<D: Dim, S: Shape> =
<<Inner<$k> as Axis>::Insert<D, S::Tail> as Shape>::Prepend<S::Head>;
type Replace<D: Dim, S: Shape> =
<<Inner<$k> as Axis>::Replace<D, S::Tail> as Shape>::Prepend<S::Head>;
type Keep<S: Shape, L: Layout> = Flat;
type Remove<S: Shape, L: Layout> = <Inner<$k> as Axis>::Remove<S::Tail, L>;
type Resize<S: Shape, L: Layout> = <Inner<$k> as Axis>::Resize<S::Tail, L>;
fn index(rank: usize) -> usize {
assert!(rank > $n, "invalid dimension");
$n
}
}
)*
};
}
impl_axis!((2, 3, 4, 5), (1, 2, 3, 4));
// Axis for the last dimension of a shape.
//
// Shape-level types work on the reversed shape, where the last dimension becomes the head,
// and reverse the result back.
impl Axis for Outer {
    type Dim<S: Shape> = <<S as Shape>::Reverse as Shape>::Head;
    type Other<S: Shape> = <<<S as Shape>::Reverse as Shape>::Tail as Shape>::Reverse;
    type Insert<D: Dim, S: Shape> =
        <<Inner<0> as Axis>::Insert<D, <S as Shape>::Reverse> as Shape>::Reverse;
    type Replace<D: Dim, S: Shape> =
        <<Inner<0> as Axis>::Replace<D, <S as Shape>::Reverse> as Shape>::Reverse;
    type Keep<S: Shape, L: Layout> = <S as Shape>::Layout<<L as Layout>::Uniform, Flat>;
    type Remove<S: Shape, L: Layout> =
        <<S as Shape>::Tail as Shape>::Layout<<L as Layout>::Uniform, L>;
    type Resize<S: Shape, L: Layout> = L;

    fn index(rank: usize) -> usize {
        match rank {
            // A rank-0 shape has no last dimension.
            0 => panic!("invalid dimension"),
            count => count - 1,
        }
    }
}