use concat_arrays::concat_arrays;
use std::fmt::{Debug, Display};
/// Error describing two shapes that could not be combined.
///
/// This is the error type returned by [`Broadcast::broadcast`]. `R` defaults to
/// `L` for the common case where both operands share the same shape type.
#[derive(Debug, Clone, Copy)]
pub struct IncompatibleShapes<L, R = L>
where
    L: Shape,
    R: Shape,
{
    /// The left-hand shape.
    pub left: L,
    /// The right-hand shape.
    pub right: R,
}

impl<L, R> IncompatibleShapes<L, R>
where
    L: Shape,
    R: Shape,
{
    /// Swaps the operands: the returned error's `left` is this error's `right`
    /// and vice versa.
    pub fn reverse(self) -> IncompatibleShapes<R, L> {
        let IncompatibleShapes { left, right } = self;
        IncompatibleShapes {
            left: right,
            right: left,
        }
    }
}

impl<L, R> std::fmt::Display for IncompatibleShapes<L, R>
where
    L: Shape,
    R: Shape,
{
    /// Renders both shapes using their own `Display` implementations.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (lhs, rhs) = (&self.left, &self.right);
        write!(f, "Buffer shapes are incompatible: {} vs {}.", lhs, rhs)
    }
}

// `Debug` (derived) and `Display` (above) are all `Error` needs.
impl<L, R> std::error::Error for IncompatibleShapes<L, R>
where
    L: Shape,
    R: Shape,
{
}
/// Index type used by a [`Shape`]: a copyable, comparable, debuggable value.
pub trait Ix: Eq + Copy + Debug {
    /// Returns `true` when all components of the index are equal to each other.
    ///
    /// Indices with fewer than two components are always diagonal.
    fn is_diagonal(&self) -> bool;
}

// The unit index has no components, so it is trivially diagonal.
impl Ix for () {
    fn is_diagonal(&self) -> bool {
        true
    }
}

// A lone `usize` is a single component, so it is trivially diagonal.
impl Ix for usize {
    fn is_diagonal(&self) -> bool {
        true
    }
}

impl<const DIM: usize> Ix for [usize; DIM] {
    /// Checks every adjacent pair of components. Equality is transitive, so
    /// all pairs matching means every component equals the first. For
    /// `DIM < 2` there are no pairs and the result is `true`.
    fn is_diagonal(&self) -> bool {
        self.windows(2).all(|pair| pair[0] == pair[1])
    }
}
/// Types that can report a [`Shape`].
pub trait Shaped {
    /// The shape type returned by [`Shaped::shape`].
    type Shape: Shape;

    /// Returns the shape of `self`.
    fn shape(&self) -> Self::Shape;
}

/// Shorthand for the shape type reported by a [`Shaped`] type `T`.
pub type ShapeOf<T> = <T as Shaped>::Shape;
/// Describes an index space: its rank, index type and the set of valid indices.
pub trait Shape: Copy + Debug + Display {
    /// Rank of this shape type. The `is_scalar`, `is_vector` and `is_matrix`
    /// helpers below compare it against 0, 1 and 2.
    const DIM: usize;

    /// Type of a single index into this shape.
    type Index: Ix;

    /// Iterator type returned by [`Shape::indices`].
    type IndexIter: Iterator<Item = Self::Index>;

    /// Reports whether `ix` belongs to this shape.
    fn contains(&self, ix: Self::Index) -> bool;

    /// Cardinality of the shape. Presumably the number of indices yielded by
    /// [`Shape::indices`]; implementors should keep the two consistent.
    fn cardinality(&self) -> usize;

    /// Returns an iterator over the indices of this shape.
    fn indices(&self) -> Self::IndexIter;

    /// `true` when `Self::DIM` is 0.
    fn is_scalar(&self) -> bool {
        matches!(Self::DIM, 0)
    }

    /// `true` when `Self::DIM` is 1.
    fn is_vector(&self) -> bool {
        matches!(Self::DIM, 1)
    }

    /// `true` when `Self::DIM` is 2.
    fn is_matrix(&self) -> bool {
        matches!(Self::DIM, 2)
    }

    /// `true` when both shapes have the same cardinality.
    ///
    /// Only cardinalities are compared. Rank and shape type are ignored, so
    /// a 2x3 shape is equivalent to a length-6 shape.
    fn is_equivalent<S>(&self, other: &S) -> bool
    where
        S: Shape,
    {
        let (mine, theirs) = (self.cardinality(), other.cardinality());
        mine == theirs
    }
}

/// Shorthand for the index type of the shape `S`.
pub type IndexOf<S> = <S as Shape>::Index;
/// Shapes that can be split into a left part and a right part.
///
/// The bound on [`Split::Left`] ties splitting back to [`Concat`]:
/// concatenating a `Left` with a `Right` yields a shape of type `Self`.
pub trait Split: Shape + Sized {
    /// Shape type of the right-hand part.
    type Right: Shape;

    /// Shape type of the left-hand part. Concatenating it with
    /// [`Split::Right`] must produce `Self`.
    type Left: Concat<Self::Right, Shape = Self>;

    /// Splits the shape into its `(left, right)` parts.
    fn split(self) -> (Self::Left, Self::Right);

    /// Splits an index of `Self` into the matching `(left, right)` indices.
    fn split_index(ix: Self::Index) -> (IndexOf<Self::Left>, IndexOf<Self::Right>);
}
/// Shapes that can be joined with a right-hand shape `RHS` (defaults to `Self`).
pub trait Concat<RHS = Self>: Shape
where
    RHS: Shape,
{
    /// Shape type produced by the concatenation.
    type Shape: Shape;

    /// Concatenates `self` with `rhs`.
    fn concat(self, rhs: RHS) -> Self::Shape;

    /// Combines an index into `Self` and an index into `RHS` into an index
    /// into the concatenated shape.
    fn concat_indices(left: Self::Index, right: RHS::Index) -> IndexOf<Self::Shape>;
}

/// Shorthand for the shape produced by concatenating `X` with `Y`.
pub type CShape<X, Y> = <X as Concat<Y>>::Shape;
/// Shapes that can be broadcast against a right-hand shape `RHS`
/// (defaults to `Self`).
pub trait Broadcast<RHS = Self>: Shape
where
    RHS: Shape,
{
    /// Shape type produced by a successful broadcast.
    type Shape: Shape;

    /// Broadcasts `self` against `rhs`.
    ///
    /// # Errors
    ///
    /// Returns an [`IncompatibleShapes`] error, typed over `Self` and `RHS`,
    /// when the two shapes cannot be broadcast together.
    fn broadcast(self, rhs: RHS) -> Result<Self::Shape, IncompatibleShapes<Self, RHS>>;
}

/// Shorthand for the shape produced by broadcasting `X` against `Y`.
pub type BShape<X, Y> = <X as Broadcast<Y>>::Shape;
mod multi_product;
mod runtime;
pub use self::runtime::*;
mod compiled;
pub use self::compiled::*;