pub mod shapes;
use shapes::{Ix, Shape, IncompatibleShapes, Broadcast, BShape};
/// A family of buffer types indexed by shape `S`.
///
/// For every scalar field `F` a class names one concrete buffer type
/// (`Self::Buffer<F>`) and knows how to construct instances of it.
pub trait Class<S: Shape> {
    /// The buffer type this class produces for elements of type `F`.
    ///
    /// Its `Class`, `Shape` and `Field` are pinned to `Self`, `S` and `F`.
    type Buffer<F: Scalar>: Buffer<Class = Self, Shape = S, Field = F>;

    /// Builds a buffer of the given `shape` whose element at each index is
    /// produced by calling `f` with that index.
    fn build<F: Scalar>(shape: S, f: impl Fn(S::Index) -> F) -> Self::Buffer<F>;

    /// Builds a buffer of the given `shape` from a `base` value, an iterator of
    /// `subset` indices and an `active` element generator.
    ///
    /// NOTE(review): presumably indices yielded by `subset` take `active(ix)`
    /// while every other index takes `base` — confirm against implementors.
    fn build_subset<F: Scalar>(
        shape: S,
        base: F,
        subset: impl Iterator<Item = shapes::IndexOf<S>>,
        active: impl Fn(S::Index) -> F,
    ) -> Self::Buffer<F>;

    /// Buffer of the given `shape` with every element equal to `value`.
    fn full<F: Scalar>(shape: S, value: F) -> Self::Buffer<F> {
        Self::build(shape, move |_| value)
    }

    /// Buffer of the given `shape` with every element equal to `F::zero()`.
    fn zeroes<F: Scalar>(shape: S) -> Self::Buffer<F> {
        let zero = F::zero();
        Self::full(shape, zero)
    }

    /// Buffer of the given `shape` with every element equal to `F::one()`.
    fn ones<F: Scalar>(shape: S) -> Self::Buffer<F> {
        let one = F::one();
        Self::full(shape, one)
    }

    /// Buffer holding `value` at every index for which `is_diagonal()` is true
    /// and zero at every other index.
    fn diagonal<F: Scalar>(shape: S, value: F) -> Self::Buffer<F> {
        Self::build(shape, move |ix: S::Index| {
            if ix.is_diagonal() { value } else { num_traits::zero() }
        })
    }

    /// Diagonal buffer with `F::one()` on the diagonal and zero elsewhere.
    fn identity<F: Scalar>(shape: S) -> Self::Buffer<F> {
        Self::diagonal(shape, F::one())
    }
}
/// The concrete buffer type that class `TClass` produces for shape `TShape`
/// and element type `TField`.
pub type BufferOf<TClass, TShape, TField> = <TClass as Class<TShape>>::Buffer<TField>;
/// A shaped collection of scalar elements that belongs to some [`Class`].
///
/// Implementors supply unchecked element access, by-value mapping and in-place
/// mutation. Checked access, by-reference mapping, folding and summation are
/// provided on top of those.
pub trait Buffer: Clone + shapes::Shaped + IntoSpec<Buffer = Self> {
    /// The class of this buffer; its buffer type for `Self::Field` is `Self`.
    type Class: Class<Self::Shape, Buffer<Self::Field> = Self>;

    /// The scalar element type stored in the buffer.
    type Field: Scalar;

    /// Returns the class value associated with this buffer type.
    fn class() -> Self::Class;

    /// Returns the element at `ix`, or `None` when `ix` is not contained in
    /// `self.shape()`.
    fn get(&self, ix: shapes::IndexOf<Self::Shape>) -> Option<Self::Field> {
        // Only touch the storage once the index is known to be in-shape.
        self.shape().contains(ix).then(|| self.get_unchecked(ix))
    }

    /// Returns the element at `ix` without checking it against the shape.
    ///
    /// Use [`Buffer::get`] for a checked lookup.
    /// NOTE(review): behaviour for out-of-shape indices is up to the
    /// implementor — confirm per implementation.
    fn get_unchecked(&self, ix: shapes::IndexOf<Self::Shape>) -> Self::Field;

    /// Consumes `self` and applies `f` to produce a buffer of the same class
    /// and shape with element type `F`.
    ///
    /// NOTE(review): presumably element-wise, like [`Buffer::map_ref`].
    fn map<F, M>(self, f: M) -> <Self::Class as Class<Self::Shape>>::Buffer<F>
    where
        F: Scalar,
        M: Fn(Self::Field) -> F;

    /// Builds a new buffer of the same class and shape whose element at each
    /// index is `f` applied to the corresponding element of `self`.
    fn map_ref<F, M>(&self, f: M) -> <Self::Class as Class<Self::Shape>>::Buffer<F>
    where
        F: Scalar,
        M: Fn(Self::Field) -> F,
    {
        let shape = self.shape();
        <Self::Class as Class<Self::Shape>>::build(shape, move |ix| f(self.get_unchecked(ix)))
    }

    /// Updates `self` in place using `f`.
    ///
    /// NOTE(review): presumably every element `x` becomes `f(x)` — confirm
    /// against implementors.
    fn mutate<M>(&mut self, f: M)
    where
        M: Fn(Self::Field) -> Self::Field;

    /// Folds every element into an accumulator starting from `init`, visiting
    /// indices in the order yielded by `self.shape().indices()`.
    fn fold<F, M>(&self, init: F, f: M) -> F
    where
        M: Fn(F, Self::Field) -> F,
    {
        let mut acc = init;
        for ix in self.shape().indices() {
            acc = f(acc, self.get_unchecked(ix));
        }
        acc
    }

    /// Sum of all elements, accumulated from zero via [`Buffer::fold`].
    fn sum(&self) -> Self::Field {
        self.fold(num_traits::zero(), |total, x| total + x)
    }

    /// Wraps `self` in [`crate::meta::Constant`].
    fn into_constant(self) -> crate::meta::Constant<Self> {
        crate::meta::Constant(self)
    }
}
/// The scalar element type of buffer `TBuffer`.
pub type FieldOf<TBuffer> = <TBuffer as Buffer>::Field;

/// The class that buffer `TBuffer` belongs to.
pub type ClassOf<TBuffer> = <TBuffer as Buffer>::Class;
/// Folds over `self` and a right-hand buffer `RHS` (defaulting to `Self`) together.
pub trait ZipFold<RHS: Buffer = Self>: Buffer {
    /// Folds from `init`, passing `f` the accumulator and a pair
    /// `(self element, rhs element)`.
    ///
    /// # Errors
    ///
    /// Returns [`IncompatibleShapes`] carrying both operand shapes —
    /// presumably when the two shapes cannot be zipped together.
    fn zip_fold<F, M>(
        &self,
        rhs: &RHS,
        init: F,
        f: M,
    ) -> Result<F, IncompatibleShapes<Self::Shape, RHS::Shape>>
    where
        F: Scalar,
        M: Fn(F, (Self::Field, RHS::Field)) -> F;
}
/// Combines `self` with a right-hand buffer `RHS` (defaulting to `Self`) through a
/// binary function, yielding a buffer whose shape is `BShape<Self::Shape, RHS::Shape>`.
pub trait ZipMap<RHS: Buffer = Self>: Buffer
where
    Self::Shape: Broadcast<RHS::Shape>,
{
    /// Result buffer for element type `F`, shaped as the broadcast of both operands.
    type Output<F: Scalar>: Buffer<Field = F, Shape = BShape<Self::Shape, RHS::Shape>>;

    /// Consumes `self` and combines it with `rhs` using `f`.
    ///
    /// # Errors
    ///
    /// Returns [`IncompatibleShapes`] carrying both operand shapes —
    /// presumably when they cannot be broadcast together.
    fn zip_map<F, M>(
        self,
        rhs: &RHS,
        f: M,
    ) -> Result<Self::Output<F>, IncompatibleShapes<Self::Shape, RHS::Shape>>
    where
        F: Scalar,
        M: Fn(Self::Field, RHS::Field) -> F;

    /// [`ZipMap::zip_map`] restricted to combiners that return `Self::Field`.
    #[inline]
    fn zip_map_id<M>(
        self,
        rhs: &RHS,
        f: M,
    ) -> Result<Self::Output<Self::Field>, IncompatibleShapes<Self::Shape, RHS::Shape>>
    where
        M: Fn(Self::Field, RHS::Field) -> Self::Field,
    {
        self.zip_map::<Self::Field, M>(rhs, f)
    }

    /// Consumes `self` and zips it against a right-hand *shape* `rshape`
    /// rather than a buffer.
    ///
    /// NOTE(review): semantics are not visible here — presumably broadcasts
    /// `self` to the combined shape; confirm against implementors.
    fn zip_shape(
        self,
        rshape: RHS::Shape,
    ) -> Result<Self::Output<Self::Field>, IncompatibleShapes<Self::Shape, RHS::Shape>>;
}
/// Combines a right-hand buffer `RHS` (defaulting to `Self`) into `self` in place.
pub trait ZipMut<RHS: Buffer = Self>: Buffer {
    /// Updates `self` in place from its own elements and those of `rhs` via `f`.
    ///
    /// # Errors
    ///
    /// Returns [`IncompatibleShapes`] carrying both operand shapes —
    /// presumably when the two shapes cannot be zipped together.
    fn zip_mut<M>(
        &mut self,
        rhs: &RHS,
        f: M,
    ) -> Result<(), IncompatibleShapes<Self::Shape, RHS::Shape>>
    where
        M: Fn(Self::Field, RHS::Field) -> Self::Field;
}
/// Shorthand bound for buffers that support every zip operation against `RHS`:
/// [`ZipFold`], [`ZipMap`] and [`ZipMut`].
///
/// Implemented automatically by the blanket impl that follows for every type
/// satisfying all three.
pub trait ZipOps<RHS: Buffer>: ZipFold<RHS> + ZipMap<RHS> + ZipMut<RHS>
where
    Self::Shape: Broadcast<RHS::Shape>,
{
}

impl<L, R> ZipOps<R> for L
where
    R: Buffer,
    L: ZipFold<R> + ZipMap<R> + ZipMut<R>,
    L::Shape: Broadcast<R::Shape>,
{
}
/// Contraction of `Self` with a right-hand buffer `RHS` that has the same field.
///
/// NOTE(review): `AXES` (default `1`) is presumably the number of axes
/// contracted — confirm against implementors.
pub trait Contract<RHS: Buffer<Field = Self::Field>, const AXES: usize = 1>: Buffer {
    /// Buffer produced by the contraction; it shares `Self::Field`.
    type Output: Buffer<Field = Self::Field>;

    /// Contracts `self` with `rhs`, consuming both operands.
    ///
    /// # Errors
    ///
    /// Returns [`IncompatibleShapes`] carrying both operand shapes.
    fn contract(
        self,
        rhs: RHS,
    ) -> Result<Self::Output, IncompatibleShapes<Self::Shape, RHS::Shape>>;

    /// Same contraction, expressed over [`Spec`]-wrapped operands and result.
    ///
    /// # Errors
    ///
    /// Returns [`IncompatibleShapes`] carrying both operand shapes.
    fn contract_spec(
        lhs: Spec<Self>,
        rhs: Spec<RHS>,
    ) -> Result<Spec<Self::Output>, IncompatibleShapes<Self::Shape, RHS::Shape>>;

    /// Computes the contraction's output shape from the operand shapes alone;
    /// no buffer data is involved.
    ///
    /// # Errors
    ///
    /// Returns [`IncompatibleShapes`] carrying both operand shapes.
    fn contract_shape(
        lhs: shapes::ShapeOf<Self>,
        rhs: shapes::ShapeOf<RHS>,
    ) -> Result<shapes::ShapeOf<Self::Output>, IncompatibleShapes<Self::Shape, RHS::Shape>>;
}
/// Free-function form of [`Contract::contract`]; `AXES` selects which
/// `Contract` impl is used, e.g. `contract::<1, _, _>(x, y)`.
///
/// # Errors
///
/// Propagates the [`IncompatibleShapes`] error from the `Contract` impl.
pub fn contract<const AXES: usize, X, Y>(
    x: X,
    y: Y,
) -> Result<X::Output, IncompatibleShapes<X::Shape, Y::Shape>>
where
    X: Contract<Y, AXES>,
    Y: Buffer<Field = X::Field>,
{
    <X as Contract<Y, AXES>>::contract(x, y)
}
/// Free-function form of [`Contract::contract_spec`]; `AXES` selects which
/// `Contract` impl is used.
///
/// # Errors
///
/// Propagates the [`IncompatibleShapes`] error from the `Contract` impl.
pub fn contract_spec<const AXES: usize, X, Y>(
    x: Spec<X>,
    y: Spec<Y>,
) -> Result<Spec<X::Output>, IncompatibleShapes<X::Shape, Y::Shape>>
where
    X: Contract<Y, AXES>,
    Y: Buffer<Field = X::Field>,
{
    // `X: Contract<Y, AXES>` is the only bound providing this item, so the
    // type-relative path resolves to exactly that impl.
    X::contract_spec(x, y)
}
/// Free-function form of [`Contract::contract_shape`]: computes the output
/// shape of contracting shapes `x` and `y`; `AXES` selects which `Contract`
/// impl is used.
///
/// # Errors
///
/// Propagates the [`IncompatibleShapes`] error from the `Contract` impl.
pub fn contract_shape<const AXES: usize, X, Y>(
    x: X::Shape,
    y: Y::Shape,
) -> Result<shapes::ShapeOf<X::Output>, IncompatibleShapes<X::Shape, Y::Shape>>
where
    X: Contract<Y, AXES>,
    Y: Buffer<Field = X::Field>,
{
    // `X: Contract<Y, AXES>` is the only bound providing this item, so the
    // type-relative path resolves to exactly that impl.
    X::contract_shape(x, y)
}
mod scalars;
pub use scalars::*;
mod tuples;
pub use tuples::*;
mod vecs;
pub use vecs::*;
mod arrays;
pub use arrays::*;
mod spec;
pub use spec::{IntoSpec, Spec};