use super::{axes::*, ReduceShape, ReduceShapeTo};
pub trait Dim: 'static + Copy + Clone + std::fmt::Debug + Send + Sync + Eq + PartialEq {
fn size(&self) -> usize;
fn from_size(size: usize) -> Option<Self>;
}
pub trait ConstDim: Default + Dim {
const SIZE: usize;
}
impl Dim for usize {
#[inline(always)]
fn size(&self) -> usize {
*self
}
#[inline(always)]
fn from_size(size: usize) -> Option<Self> {
Some(size)
}
}
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct Const<const M: usize>;
impl<const M: usize> Dim for Const<M> {
#[inline(always)]
fn size(&self) -> usize {
M
}
#[inline(always)]
fn from_size(size: usize) -> Option<Self> {
if size == M {
Some(Const)
} else {
None
}
}
}
impl<const M: usize> ConstDim for Const<M> {
const SIZE: usize = M;
}
impl<const N: usize> core::ops::Add<Const<N>> for usize {
type Output = usize;
fn add(self, _: Const<N>) -> Self::Output {
self.size() + N
}
}
impl<const N: usize> core::ops::Add<usize> for Const<N> {
type Output = usize;
fn add(self, rhs: usize) -> Self::Output {
N + rhs.size()
}
}
#[cfg(feature = "nightly")]
impl<const N: usize, const M: usize> core::ops::Add<Const<N>> for Const<M>
where
Const<{ M + N }>: Sized,
{
type Output = Const<{ M + N }>;
fn add(self, _: Const<N>) -> Self::Output {
Const
}
}
impl<const N: usize> core::ops::Mul<Const<N>> for usize {
type Output = usize;
fn mul(self, _: Const<N>) -> Self::Output {
self.size() * N
}
}
impl<const N: usize> core::ops::Mul<usize> for Const<N> {
type Output = usize;
fn mul(self, rhs: usize) -> Self::Output {
N * rhs.size()
}
}
#[cfg(feature = "nightly")]
impl<const N: usize, const M: usize> core::ops::Mul<Const<N>> for Const<M>
where
Const<{ M * N }>: Sized,
{
type Output = Const<{ M * N }>;
fn mul(self, _: Const<N>) -> Self::Output {
Const
}
}
impl<const N: usize> core::ops::Div<Const<N>> for usize {
type Output = usize;
fn div(self, _: Const<N>) -> Self::Output {
self.size() / N
}
}
impl<const N: usize> core::ops::Div<usize> for Const<N> {
type Output = usize;
fn div(self, rhs: usize) -> Self::Output {
N / rhs.size()
}
}
#[cfg(feature = "nightly")]
impl<const N: usize, const M: usize> core::ops::Div<Const<N>> for Const<M>
where
Const<{ M / N }>: Sized,
{
type Output = Const<{ M / N }>;
fn div(self, _: Const<N>) -> Self::Output {
Const
}
}
pub trait Array<T>: IntoIterator<Item = T> {
type Dim: Dim;
fn dim(&self) -> Self::Dim;
}
impl<T, const N: usize> Array<T> for [T; N] {
type Dim = Const<N>;
fn dim(&self) -> Self::Dim {
Const
}
}
impl<T> Array<T> for std::vec::Vec<T> {
type Dim = usize;
fn dim(&self) -> Self::Dim {
self.len()
}
}
pub trait Shape:
'static
+ std::fmt::Debug
+ Clone
+ Copy
+ Send
+ Sync
+ Eq
+ PartialEq
+ HasAxes<Self::AllAxes>
+ HasAxes<Self::LastAxis>
+ ReduceShapeTo<(), Self::AllAxes>
+ ReduceShape<Self::LastAxis>
{
const NUM_DIMS: usize;
type Concrete: std::fmt::Debug
+ Clone
+ Copy
+ Default
+ Eq
+ PartialEq
+ std::ops::Index<usize, Output = usize>
+ std::ops::IndexMut<usize>
+ Send
+ Sync
+ IntoIterator<Item = usize>
+ Into<std::vec::Vec<usize>>
+ AsRef<[usize]>;
type AllAxes: Axes;
type LastAxis: Axes;
fn concrete(&self) -> Self::Concrete;
fn from_concrete(concrete: &Self::Concrete) -> Option<Self>;
#[inline(always)]
fn num_elements(&self) -> usize {
self.concrete().into_iter().product()
}
#[inline(always)]
fn strides(&self) -> Self::Concrete {
let sizes = self.concrete();
let mut strides: Self::Concrete = Default::default();
strides[Self::NUM_DIMS - 1] = 1;
for i in (0..(Self::NUM_DIMS - 1)).rev() {
strides[i] = strides[i + 1] * sizes[i + 1];
}
strides
}
}
pub trait ConstShape: Default + Shape {
const NUMEL: usize;
}
pub trait HasShape {
type WithShape<New: Shape>: HasShape<Shape = New>;
type Shape: Shape;
fn shape(&self) -> &Self::Shape;
}
impl<S: Shape> HasShape for S {
type WithShape<New: Shape> = New;
type Shape = Self;
fn shape(&self) -> &Self::Shape {
self
}
}
pub type Rank0 = ();
pub type Rank1<const M: usize> = (Const<M>,);
pub type Rank2<const M: usize, const N: usize> = (Const<M>, Const<N>);
pub type Rank3<const M: usize, const N: usize, const O: usize> = (Const<M>, Const<N>, Const<O>);
pub type Rank4<const M: usize, const N: usize, const O: usize, const P: usize> =
(Const<M>, Const<N>, Const<O>, Const<P>);
pub type Rank5<const M: usize, const N: usize, const O: usize, const P: usize, const Q: usize> =
(Const<M>, Const<N>, Const<O>, Const<P>, Const<Q>);
#[rustfmt::skip]
pub type Rank6<const M: usize, const N: usize, const O: usize, const P: usize, const Q: usize, const R: usize> =
(Const<M>, Const<N>, Const<O>, Const<P>, Const<Q>, Const<R>);
macro_rules! shape {
(($($D:tt $Idx:tt),*), rank=$Num:expr, all=$All:tt) => {
impl<$($D: Dim, )*> Shape for ($($D, )*) {
const NUM_DIMS: usize = $Num;
type Concrete = [usize; $Num];
type AllAxes = $All<$($Idx,)*>;
type LastAxis = Axis<{$Num - 1}>;
#[inline(always)]
fn concrete(&self) -> Self::Concrete {
[$(self.$Idx.size(), )*]
}
#[inline(always)]
fn from_concrete(concrete: &Self::Concrete) -> Option<Self> {
Some(($(Dim::from_size(concrete[$Idx])?, )*))
}
}
impl<$($D: ConstDim, )*> ConstShape for ($($D, )*) {
const NUMEL: usize = $($D::SIZE * )* 1;
}
impl Shape for [usize; $Num] {
const NUM_DIMS: usize = $Num;
type Concrete = Self;
type AllAxes = $All<$($Idx,)*>;
type LastAxis = Axis<{$Num - 1}>;
fn concrete(&self) -> Self::Concrete {
*self
}
fn from_concrete(concrete: &Self::Concrete) -> Option<Self> {
Some(*concrete)
}
}
};
}
impl Shape for () {
const NUM_DIMS: usize = 0;
type Concrete = [usize; 0];
type AllAxes = Axis<0>;
type LastAxis = Axis<0>;
#[inline(always)]
fn concrete(&self) -> Self::Concrete {
[]
}
#[inline(always)]
fn strides(&self) -> Self::Concrete {
[]
}
#[inline(always)]
fn from_concrete(_: &Self::Concrete) -> Option<Self> {
Some(())
}
}
impl ConstShape for () {
const NUMEL: usize = 1;
}
shape!((D1 0), rank=1, all=Axis);
shape!((D1 0, D2 1), rank=2, all=Axes2);
shape!((D1 0, D2 1, D3 2), rank=3, all=Axes3);
shape!((D1 0, D2 1, D3 2, D4 3), rank=4, all=Axes4);
shape!((D1 0, D2 1, D3 2, D4 3, D5 4), rank=5, all=Axes5);
shape!((D1 0, D2 1, D3 2, D4 3, D5 4, D6 5), rank=6, all=Axes6);