use std::fmt::Debug;
use std::ops::{
Bound, IndexMut, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
};
use crate::layout::{Dense, Layout};
/// Type-level array dimension: fixes the rank together with the concrete
/// shape and strides types used for arrays of that rank.
pub trait Dim {
    /// Dimension used when appending an axis (see `add_dim`, which asserts
    /// that its rank is greater than `RANK`).
    type Higher: Dim;

    /// Dimension used when removing an axis (see `remove_dim`).
    type Lower: Dim;

    /// Dimension whose rank is the larger of the ranks of `Self` and `D`.
    type Max<D: Dim>: Dim;

    /// Layout type to use for arrays of this rank, derived from `L`.
    type Layout<L: Layout>: Layout;

    /// Per-axis extents; must map back to this dimension.
    type Shape: Shape<Dim = Self>;

    /// Per-axis strides; must map back to this dimension.
    type Strides: Strides<Dim = Self>;

    /// Number of axes.
    const RANK: usize;

    /// Returns `shape` extended with one trailing axis of extent `size`.
    ///
    /// # Panics
    ///
    /// Panics with `"invalid rank"` if `Self::Higher` does not have a greater
    /// rank than `Self`.
    #[doc(hidden)]
    fn add_dim(shape: Self::Shape, size: usize) -> <Self::Higher as Dim>::Shape {
        let rank = Self::RANK;
        assert!(rank < <Self::Higher as Dim>::RANK, "invalid rank");

        let mut grown = <Self::Higher as Dim>::Shape::default();
        grown[..rank].copy_from_slice(&shape[..]);
        grown[rank] = size;
        grown
    }

    /// Returns the product of all extents in `shape` (1 for an empty shape).
    ///
    /// # Panics
    ///
    /// Panics with `"length too large"` as soon as a partial product
    /// overflows `usize`.
    #[doc(hidden)]
    fn checked_len(shape: Self::Shape) -> usize {
        shape[..]
            .iter()
            .try_fold(1usize, |len, &extent| len.checked_mul(extent))
            .expect("length too large")
    }

    /// Returns `shape` with axis `dim` removed; the remaining axes keep
    /// their relative order.
    ///
    /// # Panics
    ///
    /// Panics with `"invalid dimension"` if `dim >= Self::RANK`.
    #[doc(hidden)]
    fn remove_dim(shape: Self::Shape, dim: usize) -> <Self::Lower as Dim>::Shape {
        assert!(dim < Self::RANK, "invalid dimension");

        // Everything before `dim` stays in place; everything after shifts left by one.
        let (before, after) = (&shape[..dim], &shape[dim + 1..]);
        let mut shrunk = <Self::Lower as Dim>::Shape::default();
        shrunk[..dim].copy_from_slice(before);
        shrunk[dim..].copy_from_slice(after);
        shrunk
    }

    /// Returns `shape` with the extent of axis `dim` replaced by `new_size`.
    ///
    /// # Panics
    ///
    /// Panics with `"invalid dimension"` if `dim >= Self::RANK`.
    #[doc(hidden)]
    fn resize_dim(shape: Self::Shape, dim: usize, new_size: usize) -> Self::Shape {
        assert!(dim < Self::RANK, "invalid dimension");

        let mut resized = shape;
        resized[dim] = new_size;
        resized
    }
}
/// Per-axis extents of an array, one `usize` per dimension.
///
/// Besides single-axis indexing, every standard range form (and a pair of
/// `Bound`s) can be used to borrow a sub-slice of the extents, mutably or not.
pub trait Shape:
    IndexMut<usize, Output = usize>
    + IndexMut<RangeFull, Output = [usize]>
    + IndexMut<Range<usize>, Output = [usize]>
    + IndexMut<RangeInclusive<usize>, Output = [usize]>
    + IndexMut<RangeFrom<usize>, Output = [usize]>
    + IndexMut<RangeTo<usize>, Output = [usize]>
    + IndexMut<RangeToInclusive<usize>, Output = [usize]>
    + IndexMut<(Bound<usize>, Bound<usize>), Output = [usize]>
    + Copy
    + Default
    + Debug
{
    /// Dimension whose `Shape` type is this type.
    type Dim: Dim<Shape = Self>;
}
/// Per-axis strides of an array, one `isize` per dimension.
///
/// Besides single-axis indexing, every standard range form (and a pair of
/// `Bound`s) can be used to borrow a sub-slice of the strides, mutably or not.
pub trait Strides:
    IndexMut<usize, Output = isize>
    + IndexMut<RangeFull, Output = [isize]>
    + IndexMut<Range<usize>, Output = [isize]>
    + IndexMut<RangeInclusive<usize>, Output = [isize]>
    + IndexMut<RangeFrom<usize>, Output = [isize]>
    + IndexMut<RangeTo<usize>, Output = [isize]>
    + IndexMut<RangeToInclusive<usize>, Output = [isize]>
    + IndexMut<(Bound<usize>, Bound<usize>), Output = [isize]>
    + Copy
    + Default
    + Debug
{
    /// Dimension whose `Strides` type is this type.
    type Dim: Dim<Strides = Self>;
}
/// Type-level constant used as a dimension marker: `Const<N>` is the
/// dimension of rank `N`.
///
/// The common standard traits are derived so the zero-sized marker can be
/// copied, compared, ordered, hashed, defaulted and debug-printed like any
/// other value (Rust API guidelines C-COMMON-TRAITS / C-DEBUG).
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Const<const N: usize>;
impl Dim for Const<0> {
    // Rank-0 dimension: the base case that the `Max` recursion in the
    // higher-rank impls bottoms out at.
    const RANK: usize = 0;

    type Shape = [usize; 0];
    type Strides = [isize; 0];

    // There is no rank below 0, so lowering stays at rank 0.
    type Lower = Const<0>;
    type Higher = Const<1>;

    // max(0, D) is simply D.
    type Max<D: Dim> = D;

    // Rank 0 always uses `Dense`, regardless of `L`.
    type Layout<L: Layout> = Dense;
}
// Implements `Dim` for `Const<rank>` from four parallel lists: the ranks,
// the rank used for `Lower`, the rank used for `Higher`, and the layout type
// (which may mention the GAT parameter `L`).
macro_rules! impl_dim {
    (($($rank:tt),*), ($($lower:tt),*), ($($higher:tt),*), ($($layout:ty),*)) => {
        $(
            impl Dim for Const<$rank> {
                const RANK: usize = $rank;

                type Shape = [usize; $rank];
                type Strides = [isize; $rank];

                type Lower = Const<$lower>;
                type Higher = Const<$higher>;

                // Max(rank, D) = Max(rank - 1, D::Lower) + 1, recursing down to `Const<0>`.
                type Max<D: Dim> = <<Const<$lower> as Dim>::Max<D::Lower> as Dim>::Higher;

                type Layout<L: Layout> = $layout;
            }
        )*
    };
}

// Rank 6 is the highest supported rank: its `Higher` is itself, so
// `add_dim` on a rank-6 shape fails its "invalid rank" assertion.
// Rank 1 uses `L::Uniform`; ranks 2 and up use `L` unchanged.
impl_dim!(
    (1, 2, 3, 4, 5, 6),
    (0, 1, 2, 3, 4, 5),
    (2, 3, 4, 5, 6, 6),
    (L::Uniform, L, L, L, L, L)
);
// Ties each fixed-size `usize` array (shape) and `isize` array (strides) to
// the `Const` dimension of the same length.
macro_rules! impl_shape_strides {
    ($($rank:tt),*) => {
        $(
            impl Shape for [usize; $rank] {
                type Dim = Const<$rank>;
            }
        )*
        $(
            impl Strides for [isize; $rank] {
                type Dim = Const<$rank>;
            }
        )*
    };
}

impl_shape_strides!(0, 1, 2, 3, 4, 5, 6);