use super::*;
/// Conversion of a value into a [`Spec`] describing a buffer.
///
/// Implemented in this file both for buffers themselves (which become
/// `Spec::Raw`) and for `Spec` values (which are returned unchanged).
pub trait IntoSpec {
/// The buffer type that the resulting specification describes.
type Buffer: Buffer;
/// Consumes `self` and returns the corresponding specification.
fn into_spec(self) -> Spec<Self::Buffer>;
}
/// Every buffer is trivially a specification of itself.
impl<B: Buffer> IntoSpec for B {
    type Buffer = B;

    /// Wraps the buffer, untouched, in the `Raw` variant.
    fn into_spec(self) -> Spec<Self::Buffer> {
        Spec::Raw(self)
    }
}
/// A description of a buffer that is either already built or still symbolic.
///
/// The symbolic variants carry only a shape and a single value; `Spec::unwrap`
/// turns any variant into a concrete buffer of type `B`.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum Spec<B: Buffer> {
/// An already-constructed buffer.
Raw(B),
/// A shape plus one value; built via the buffer class's `full` constructor.
Full(B::Shape, B::Field),
/// A shape plus one value for the diagonal positions; `Spec::map` treats
/// every off-diagonal position as zero.
Diagonal(B::Shape, B::Field),
}
/// A specification converts into itself, so APIs taking `impl IntoSpec`
/// accept buffers and specs alike.
impl<B: Buffer> IntoSpec for Spec<B> {
    type Buffer = B;

    /// Identity conversion.
    fn into_spec(self) -> Self {
        self
    }
}
/// The shape of a specification is the shape of the buffer it describes.
impl<B: Buffer> shapes::Shaped for Spec<B> {
    type Shape = B::Shape;

    /// Returns the wrapped buffer's shape for `Raw`, or the stored shape for
    /// the symbolic variants.
    fn shape(&self) -> Self::Shape {
        match self {
            Spec::Raw(buffer) => buffer.shape(),
            Spec::Full(shape, _) | Spec::Diagonal(shape, _) => *shape,
        }
    }
}
impl<F, B> Spec<B>
where
    F: Scalar,
    B: Buffer<Field = F>,
{
    /// Materialises the specification into a concrete buffer of type `B`.
    ///
    /// A `Raw` spec hands back the buffer it already owns. `Full` and
    /// `Diagonal` are built through the `full` / `diagonal` constructors of
    /// `B::Class`, using the stored shape and value.
    #[inline]
    pub fn unwrap(self) -> B {
        match self {
            Spec::Raw(b) => b,
            Spec::Full(s, x) => <B::Class as Class<B::Shape>>::full(s, x),
            Spec::Diagonal(s, x) => <B::Class as Class<B::Shape>>::diagonal(s, x),
        }
    }

    /// Applies `f` to every element described by this specification.
    ///
    /// * `Raw` delegates to the buffer's own `map`.
    /// * `Full` stays symbolic: only the single stored value is transformed.
    /// * `Diagonal` is turned into a `Raw` buffer, because `f` is not
    ///   guaranteed to send zero to zero; diagonal positions receive
    ///   `f(value)` and all other positions receive `f(F::zero())`.
    pub fn map<A: Scalar, M: Fn(F) -> A>(
        self,
        f: M,
    ) -> Spec<<B::Class as Class<B::Shape>>::Buffer<A>> {
        match self {
            Spec::Raw(buffer) => Spec::Raw(buffer.map(f)),
            Spec::Full(shape, value) => Spec::Full(shape, f(value)),
            Spec::Diagonal(shape, value) => Spec::Raw({
                let zero = F::zero();

                <B::Class as Class<B::Shape>>::build(shape, |ix| {
                    if ix.is_diagonal() {
                        f(value)
                    } else {
                        f(zero)
                    }
                })
            }),
        }
    }

    /// Combines this specification with `other` element-wise using `f`.
    ///
    /// When both sides are `Full`, the result stays symbolic: the shapes are
    /// broadcast and `f` is applied once to the two stored values. In every
    /// other case the left-hand side is materialised with [`Spec::unwrap`]
    /// and the buffer-level `zip_map` is used, yielding a `Raw` result.
    ///
    /// # Errors
    ///
    /// Propagates the `IncompatibleShapes` error reported by the shape
    /// broadcast or by the buffer-level `zip_map`.
    pub fn zip_map<R, T, M: Fn(B::Field, R::Field) -> T>(
        self,
        other: &Spec<R>,
        f: M,
    ) -> Result<Spec<B::Output<T>>, IncompatibleShapes<B::Shape, R::Shape>>
    where
        B: ZipMap<R>,
        B::Shape: Broadcast<R::Shape>,
        R: Buffer,
        T: Scalar,
    {
        use Spec::*;

        match (self, other) {
            (Full(sx, fx), &Full(sy, fy)) => sx.broadcast(sy).map(|sz| Full(sz, f(fx, fy))),
            // The right-hand side is already a concrete buffer: borrow it
            // directly instead of cloning the whole buffer just to take a
            // reference to the copy.
            (lhs, &Raw(ref rhs)) => lhs.unwrap().zip_map(rhs, f).map(Spec::Raw),
            // Symbolic right-hand side: cloning only copies a shape and a
            // value, after which the spec is materialised into a temporary.
            (lhs, rhs) => lhs.unwrap().zip_map(&rhs.clone().unwrap(), f).map(Spec::Raw),
        }
    }

    /// A symbolic specification of the given shape filled with `B::Field::zero()`.
    #[inline]
    pub fn zeroes(shape: B::Shape) -> Self {
        Spec::Full(shape, B::Field::zero())
    }

    /// A symbolic specification of the given shape filled with `B::Field::one()`.
    #[inline]
    pub fn ones(shape: B::Shape) -> Self {
        Spec::Full(shape, B::Field::one())
    }

    /// A symbolic diagonal specification of the given shape whose diagonal
    /// value is `B::Field::one()`.
    #[inline]
    pub fn eye(shape: B::Shape) -> Self {
        Spec::Diagonal(shape, B::Field::one())
    }
}
/// Lets a concrete buffer be used wherever a `Spec` is expected via `.into()`.
impl<B: Buffer> From<B> for Spec<B> {
    /// Wraps `buffer` in the `Raw` variant without modifying it.
    fn from(buffer: B) -> Spec<B> {
        Self::Raw(buffer)
    }
}