mod axes;
mod broadcast;
mod permute;
mod realize;
mod slice;
pub mod symbolic;
pub mod tracker;
pub use realize::*;
pub use slice::*;
pub use axes::*;
pub use broadcast::*;
pub use permute::*;
pub use tracker::*;
use self::symbolic::Expression;
/// One axis length of a [`Shape`].
///
/// In this module it is implemented by [`Const`] (a compile-time `usize`) and by [`Dyn`]
/// (a dimension named by a `char` symbol).
pub trait Dimension:
    Copy + Clone + Send + Sync + PartialEq + Eq + std::fmt::Debug + 'static
{
    /// Returns this dimension's size as a symbolic [`Expression`].
    ///
    /// There is no `self` parameter: the size is determined entirely by the implementing type.
    fn const_size() -> Expression;
}
/// A [`Dimension`] whose size is available as a plain compile-time `usize`.
pub trait ConstDim: Dimension + Default {
    /// The number of elements along this dimension.
    const SIZE: usize;
}
/// A dimension whose size is represented symbolically by the character `C` rather than by a
/// number. The symbol `'-'` is used elsewhere in this module as an anonymous placeholder.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct Dyn<const C: char>;

impl<const C: char> Dimension for Dyn<C> {
    /// The size is the symbolic expression built from the character `C`.
    fn const_size() -> Expression {
        let symbol = C;
        symbol.into()
    }
}
/// A dimension whose size `M` is fixed at compile time.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct Const<const M: usize>;

impl<const M: usize> Dimension for Const<M> {
    /// The size is the constant `M`, lifted into an [`Expression`].
    fn const_size() -> Expression {
        let size = M;
        size.into()
    }
}

impl<const M: usize> ConstDim for Const<M> {
    const SIZE: usize = M;
}
impl<const N: usize, const C: char> core::ops::Add<Const<N>> for Dyn<C> {
type Output = Dyn<C>;
fn add(self, _: Const<N>) -> Self::Output {
todo!();
}
}
impl<const N: usize, const C: char> core::ops::Add<Dyn<C>> for Const<N> {
type Output = Dyn<C>;
fn add(self, _: Dyn<C>) -> Self::Output {
todo!();
}
}
impl<const N: usize, const C: char> core::ops::Mul<Const<N>> for Dyn<C> {
type Output = Dyn<C>;
fn mul(self, _: Const<N>) -> Self::Output {
todo!();
}
}
impl<const N: usize, const C: char> core::ops::Mul<Dyn<C>> for Const<N> {
type Output = Dyn<C>;
fn mul(self, _: Dyn<C>) -> Self::Output {
todo!();
}
}
impl<const N: usize, const C: char> core::ops::Div<Const<N>> for Dyn<C> {
type Output = Dyn<C>;
fn div(self, _: Const<N>) -> Self::Output {
todo!();
}
}
impl<const N: usize, const C: char> core::ops::Div<Dyn<C>> for Const<N> {
type Output = Dyn<C>;
fn div(self, _: Dyn<C>) -> Self::Output {
todo!();
}
}
impl<const A: char, const C: char> core::ops::Add<Dyn<A>> for Dyn<C> {
type Output = Dyn<'-'>;
fn add(self, _: Dyn<A>) -> Self::Output {
todo!();
}
}
impl<const A: char, const C: char> core::ops::Mul<Dyn<A>> for Dyn<C> {
type Output = Dyn<'-'>;
fn mul(self, _: Dyn<A>) -> Self::Output {
todo!();
}
}
impl<const A: char, const C: char> core::ops::Div<Dyn<A>> for Dyn<C> {
type Output = Dyn<'-'>;
fn div(self, _: Dyn<A>) -> Self::Output {
todo!();
}
}
/// An iterable collection of `T` whose length is described by a [`Dimension`] type.
pub trait Array<T>: IntoIterator<Item = T> {
    /// The dimension type describing this collection's length.
    type Dim: Dimension;

    /// Returns the (zero-sized, in this module's impls) dimension marker for this collection.
    fn dim(&self) -> Self::Dim;
}
/// A fixed-size array's length is the compile-time constant `N`.
impl<T, const N: usize> Array<T> for [T; N] {
    type Dim = Const<N>;
    fn dim(&self) -> Const<N> {
        Const::<N>
    }
}

/// A `Vec`'s length is only known at runtime, so it is reported as the anonymous
/// dynamic dimension `'-'`.
impl<T> Array<T> for Vec<T> {
    type Dim = Dyn<'-'>;
    fn dim(&self) -> Dyn<'-'> {
        Dyn
    }
}
/// A tensor shape: an ordered list of [`Dimension`]s whose rank is fixed by the type.
///
/// This module implements it for `()` (rank 0), for tuples of one to six dimensions, and for
/// the runtime arrays `[usize; 1]` through `[usize; 6]`.
pub trait Shape:
    'static
    + Copy
    + Clone
    + Send
    + Sync
    + PartialEq
    + Eq
    + std::fmt::Debug
    + HasAxes<Self::AllAxes>
    + HasAxes<Self::LastAxis>
    + ReduceShape<Self::LastAxis>
    + ReduceShapeTo<(), Self::AllAxes>
{
    /// The rank: how many dimensions the shape has.
    const NUM_DIMS: usize;

    /// Axes selecting every dimension at once.
    type AllAxes: Axes;

    /// Axes selecting only the final dimension (`Axis<0>` for the rank-0 shape `()`).
    type LastAxis: Axes;

    /// A plain-`usize` representation of the shape's sizes. In this module's impls it is
    /// `[usize; NUM_DIMS]`.
    type Concrete: Copy
        + Clone
        + Default
        + Send
        + Sync
        + PartialEq
        + Eq
        + std::fmt::Debug
        + AsRef<[usize]>
        + Into<Vec<usize>>
        + IntoIterator<Item = usize>
        + std::ops::Index<usize, Output = usize>
        + std::ops::IndexMut<usize>;

    /// The size of each dimension, in order, as symbolic [`Expression`]s.
    ///
    /// Note: [`ConstShape`] declares a method with the same name returning `Vec<usize>`.
    /// For a type implementing both traits, call `<S as Shape>::realized_shape()`.
    fn realized_shape() -> Vec<Expression>;

    /// Builds a shape tracker describing this shape.
    fn to_tracker() -> crate::core::shape::tracker::ShapeTracker;
}
/// A [`Shape`] whose sizes are all plain compile-time `usize`s.
///
/// In this module, tuples implement it when every dimension is a [`ConstDim`].
pub trait ConstShape: Shape + Default {
    /// Total element count: the product of all dimension sizes (`1` for `()`).
    const NUMEL: usize;

    /// The size of each dimension, in order.
    ///
    /// Note: [`Shape`] declares a method with the same name returning symbolic sizes.
    /// Call `<S as ConstShape>::realized_shape()` to pick this one unambiguously.
    fn realized_shape() -> Vec<usize>;
}
/// Types that carry a [`Shape`].
pub trait HasShape {
    /// The shape type carried by this value.
    type Shape: Shape;

    /// This type re-parameterized to carry shape `New` instead.
    type WithShape<New: Shape>: HasShape<Shape = New>;

    /// Borrows the shape.
    fn shape(&self) -> &Self::Shape;
}

/// Every shape is its own shape; swapping in another shape simply yields that shape type.
impl<S: Shape> HasShape for S {
    type Shape = S;
    type WithShape<New: Shape> = New;
    fn shape(&self) -> &Self::Shape {
        self
    }
}
/// Rank-0 (scalar) shape.
pub type R0 = ();
/// Rank-1 shape whose size is a compile-time constant.
pub type R1<const M: usize> = (Const<M>,);
/// Rank-2 shape whose sizes are compile-time constants.
pub type R2<const M: usize, const N: usize> = (Const<M>, Const<N>);
/// Rank-3 shape whose sizes are compile-time constants.
pub type R3<const M: usize, const N: usize, const O: usize> = (Const<M>, Const<N>, Const<O>);
/// Rank-4 shape whose sizes are compile-time constants.
pub type R4<const M: usize, const N: usize, const O: usize, const P: usize> =
(Const<M>, Const<N>, Const<O>, Const<P>);
/// Rank-5 shape whose sizes are compile-time constants.
pub type R5<const M: usize, const N: usize, const O: usize, const P: usize, const Q: usize> =
(Const<M>, Const<N>, Const<O>, Const<P>, Const<Q>);
/// Rank-6 shape whose sizes are compile-time constants.
// rustfmt is skipped so the long parameter list stays on one line.
#[rustfmt::skip]
pub type R6<const M: usize, const N: usize, const O: usize, const P: usize, const Q: usize, const R: usize> =
(Const<M>, Const<N>, Const<O>, Const<P>, Const<Q>, Const<R>);
/// Implements the shape traits for one rank.
///
/// For the given rank this generates:
/// * [`Shape`] for the tuple `($($D,)*)` of any [`Dimension`]s. Each dimension's size is
///   taken from `Dimension::const_size`.
/// * [`ConstShape`] for the same tuple when every dimension is a [`ConstDim`]. `NUMEL` is the
///   product of the dimension sizes.
/// * [`Shape`] for the runtime array `[usize; rank]`. Every dimension is reported as the
///   anonymous symbol `'-'`.
///
/// Arguments:
/// * `($($D $Idx),*)` pairs a type-parameter name with its axis index.
/// * `rank` is the number of dimensions.
/// * `all` names the axes type that spans every axis of that rank.
macro_rules! shape {
    (($($D:tt $Idx:tt),*), rank=$Num:expr, all=$All:tt) => {
        impl<$($D: Dimension, )*> Shape for ($($D, )*) {
            const NUM_DIMS: usize = $Num;
            type Concrete = [usize; $Num];
            type AllAxes = $All<$($Idx,)*>;
            type LastAxis = Axis<{$Num - 1}>;
            fn realized_shape() -> Vec<Expression> {
                vec![$($D::const_size(), )*]
            }
            fn to_tracker() -> ShapeTracker {
                ShapeTracker::new(&Self::realized_shape())
            }
        }
        impl<$($D: ConstDim, )*> ConstShape for ($($D, )*) {
            const NUMEL: usize = $($D::SIZE * )* 1;
            fn realized_shape() -> Vec<usize> {
                vec![$($D::SIZE , )*]
            }
        }
        impl Shape for [usize; $Num] {
            const NUM_DIMS: usize = $Num;
            type Concrete = Self;
            type AllAxes = $All<$($Idx,)*>;
            type LastAxis = Axis<{$Num - 1}>;
            fn realized_shape() -> Vec<Expression> {
                // The array type carries no sizes at the type level, so every axis gets
                // the anonymous dynamic symbol.
                vec!['-'.into(); $Num]
            }
            fn to_tracker() -> ShapeTracker {
                ShapeTracker::new(&Self::realized_shape())
            }
        }
    };
}
/// The rank-0 (scalar) shape.
impl Shape for () {
    const NUM_DIMS: usize = 0;
    type AllAxes = Axis<0>;
    type LastAxis = Axis<0>;
    type Concrete = [usize; 0];

    /// A scalar has no dimensions.
    fn realized_shape() -> Vec<Expression> {
        Vec::new()
    }

    fn to_tracker() -> ShapeTracker {
        // `Self::realized_shape()` would be ambiguous here: `()` also implements
        // `ConstShape::realized_shape`.
        ShapeTracker::new(&[])
    }
}

/// A scalar holds exactly one element and has no dimension sizes.
impl ConstShape for () {
    const NUMEL: usize = 1;
    fn realized_shape() -> Vec<usize> {
        Vec::new()
    }
}
// Implement the shape traits for ranks 1 through 6.
// Each `DN idx` pair is a dimension's type-parameter name and its axis index.
// `all` is the axes type spanning every axis of that rank; rank 1 uses the single-axis `Axis`.
shape!((D1 0), rank=1, all=Axis);
shape!((D1 0, D2 1), rank=2, all=Axes2);
shape!((D1 0, D2 1, D3 2), rank=3, all=Axes3);
shape!((D1 0, D2 1, D3 2, D4 3), rank=4, all=Axes4);
shape!((D1 0, D2 1, D3 2, D4 3, D5 4), rank=5, all=Axes5);
shape!((D1 0, D2 1, D3 2, D4 3, D5 4, D6 5), rank=6, all=Axes6);
/// Compile-time check that two constant shapes hold the same number of elements.
///
/// The blanket impl below defines `TYPE_CHECK` as `assert!(Src::NUMEL == Dst::NUMEL)`.
/// Because it is an associated constant, the assertion runs during const evaluation. When
/// `assert_same_numel` is instantiated for a `Self`/`Dst` pair with different element
/// counts, compilation fails instead of panicking at runtime.
pub trait AssertSameNumel<Dst: ConstShape>: ConstShape {
/// Unit-valued constant whose evaluation performs the element-count assertion.
const TYPE_CHECK: ();
/// Forces evaluation of `TYPE_CHECK` for this `Self`/`Dst` pair.
fn assert_same_numel() {
// Referencing the constant is what triggers its evaluation; the `()` value is discarded.
#[allow(clippy::let_unit_value)]
let _ = <Self as AssertSameNumel<Dst>>::TYPE_CHECK;
}
}
/// Every pair of constant shapes gets the check; it only fires when actually referenced.
impl<Src: ConstShape, Dst: ConstShape> AssertSameNumel<Dst> for Src {
const TYPE_CHECK: () = assert!(Src::NUMEL == Dst::NUMEL);
}
/// One entry of a reshape specification.
///
/// NOTE(review): variant meanings below are inferred from their names only — confirm
/// against the reshape implementation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ReshapeDim {
    /// Presumably an explicit, fixed dimension size.
    Const(usize),
    /// Presumably the index of a dimension of the previous (source) shape to reuse.
    PrevDim(usize),
}