use crate::{
core::{
bounds::{BoolBounds, Bounds, CurveBounds, FieldBounds, IsBounds},
expressions::{domain::Domain, expr::EvalValue, expr_macros::ApplyExpr},
},
utils::{
curve_point::CurvePoint,
field::{BaseField, ScalarField},
},
};
#[cfg(test)]
use rand::Rng;
use std::marker::PhantomData;
/// [`ApplyExpr`] that turns an [`EvalValue`] into the concrete value of the
/// corresponding expression kind.
///
/// Every arm delegates to the output type's `unwrap` associated function
/// (`ScalarField::unwrap`, `bool::unwrap`, `BaseField::unwrap`,
/// `CurvePoint::unwrap`). What happens when the `EvalValue` holds a different
/// kind is decided by those functions, which are defined outside this file.
pub struct EvalValueUnwrap;

impl ApplyExpr for EvalValueUnwrap {
    // Every expression kind consumes the same dynamically-typed value...
    type ScalarInput = EvalValue;
    type BitInput = EvalValue;
    type BaseInput = EvalValue;
    type CurveInput = EvalValue;

    // ...and yields the statically-typed value for that kind.
    type ScalarOutput = ScalarField;
    type BitOutput = bool;
    type BaseOutput = BaseField;
    type CurveOutput = CurvePoint;

    fn f_scalar(&mut self, value: Self::ScalarInput) -> Self::ScalarOutput {
        ScalarField::unwrap(value)
    }

    fn f_bit(&mut self, value: Self::BitInput) -> Self::BitOutput {
        bool::unwrap(value)
    }

    fn f_base(&mut self, value: Self::BaseInput) -> Self::BaseOutput {
        BaseField::unwrap(value)
    }

    fn f_curve(&mut self, value: Self::CurveInput) -> Self::CurveOutput {
        CurvePoint::unwrap(value)
    }
}
/// [`ApplyExpr`] that extracts the kind-specific bounds from a type-erased
/// [`Bounds`] value.
///
/// # Panics
///
/// The scalar, bit and curve arms panic when the [`Bounds`] variant does not
/// match the expression kind being visited.
pub struct BoundUnFold;

impl ApplyExpr for BoundUnFold {
    type ScalarInput = Bounds;
    type BitInput = Bounds;
    type BaseInput = Bounds;
    type CurveInput = Bounds;

    type ScalarOutput = FieldBounds<ScalarField>;
    type BitOutput = BoolBounds;
    type BaseOutput = FieldBounds<BaseField>;
    type CurveOutput = CurveBounds;

    fn f_scalar(&mut self, bounds: Self::ScalarInput) -> Self::ScalarOutput {
        match bounds {
            Bounds::Scalar(inner) => inner,
            _ => panic!("These bounds must be scalar."),
        }
    }

    fn f_bit(&mut self, bounds: Self::BitInput) -> Self::BitOutput {
        match bounds {
            Bounds::Bit(inner) => inner,
            _ => panic!("These bounds must be boolean."),
        }
    }

    fn f_base(&mut self, bounds: Self::BaseInput) -> Self::BaseOutput {
        // NOTE(review): unlike the other arms this delegates to
        // `BaseField::unwrap` instead of matching a variant here; presumably
        // that helper performs the equivalent check — confirm.
        BaseField::unwrap(bounds)
    }

    fn f_curve(&mut self, bounds: Self::CurveInput) -> Self::CurveOutput {
        match bounds {
            Bounds::Curve(inner) => inner,
            _ => panic!("These bounds must be curve."),
        }
    }
}
/// [`ApplyExpr`] that erases kind-specific bounds back into the common
/// [`Bounds`] type; the inverse direction of `BoundUnFold`.
pub struct BoundWrap;

impl ApplyExpr for BoundWrap {
    type ScalarInput = FieldBounds<ScalarField>;
    type BitInput = BoolBounds;
    type BaseInput = FieldBounds<BaseField>;
    type CurveInput = CurveBounds;

    // All kinds collapse into the single type-erased `Bounds` enum.
    type ScalarOutput = Bounds;
    type BitOutput = Bounds;
    type BaseOutput = Bounds;
    type CurveOutput = Bounds;

    fn f_scalar(&mut self, scalar_bounds: Self::ScalarInput) -> Self::ScalarOutput {
        Bounds::Scalar(scalar_bounds)
    }

    fn f_bit(&mut self, bit_bounds: Self::BitInput) -> Self::BitOutput {
        Bounds::Bit(bit_bounds)
    }

    fn f_base(&mut self, base_bounds: Self::BaseInput) -> Self::BaseOutput {
        // Base-field bounds are wrapped through `BaseField::wrap` rather than a
        // variant constructor named here.
        BaseField::wrap(base_bounds)
    }

    fn f_curve(&mut self, curve_bounds: Self::CurveInput) -> Self::CurveOutput {
        Bounds::Curve(curve_bounds)
    }
}
/// [`ApplyExpr`] that reads the single value out of bounds that are expected
/// to be constant.
///
/// # Panics
///
/// Each arm panics (via `expect`) when `as_constant()` reports that the bounds
/// are not a single constant.
pub struct ConstantBoundSampler;

impl ApplyExpr for ConstantBoundSampler {
    type ScalarInput = FieldBounds<ScalarField>;
    type BitInput = BoolBounds;
    type BaseInput = FieldBounds<BaseField>;
    type CurveInput = CurveBounds;

    type ScalarOutput = ScalarField;
    type BitOutput = bool;
    type BaseOutput = BaseField;
    type CurveOutput = CurvePoint;

    fn f_scalar(&mut self, bounds: Self::ScalarInput) -> Self::ScalarOutput {
        bounds
            .as_constant()
            .expect("Scalar bounds were not constant.")
    }

    fn f_bit(&mut self, bounds: Self::BitInput) -> Self::BitOutput {
        bounds
            .as_constant()
            .expect("Bit bounds were not constant.")
    }

    fn f_base(&mut self, bounds: Self::BaseInput) -> Self::BaseOutput {
        bounds
            .as_constant()
            .expect("Base bounds were not constant.")
    }

    fn f_curve(&mut self, bounds: Self::CurveInput) -> Self::CurveOutput {
        bounds
            .as_constant()
            .expect("Curve bounds were not constant.")
    }
}
/// [`ApplyExpr`] that ignores its inputs and yields the `Default` value of each
/// output type (`ScalarField`, `bool`, `BaseField`, `CurvePoint`).
///
/// The type parameters only select the input types (`T` → bit, `U` → scalar,
/// `V` → base, `W` → curve); no values of these types are ever stored.
pub struct DefaultFiller<T, U, V, W> {
    phantom: PhantomData<(T, U, V, W)>,
}

// Written by hand instead of `#[derive(Debug)]`: the derive would demand
// `T, U, V, W: Debug` even though only a `PhantomData` is stored (and
// `PhantomData<X>` is `Debug` for every `X`). The output matches the derived impl.
impl<T, U, V, W> std::fmt::Debug for DefaultFiller<T, U, V, W> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DefaultFiller")
            .field("phantom", &self.phantom)
            .finish()
    }
}

// No bounds needed: building a `PhantomData` never touches a `T`/`U`/`V`/`W`.
impl<T, U, V, W> Default for DefaultFiller<T, U, V, W> {
    fn default() -> Self {
        Self {
            phantom: PhantomData,
        }
    }
}

impl<T: Clone, U: Clone, V: Clone, W: Clone> ApplyExpr for DefaultFiller<T, U, V, W> {
    type ScalarInput = U;
    type BitInput = T;
    type BaseInput = V;
    type CurveInput = W;
    type ScalarOutput = ScalarField;
    type BitOutput = bool;
    type BaseOutput = BaseField;
    type CurveOutput = CurvePoint;
    fn f_scalar(&mut self, _: Self::ScalarInput) -> Self::ScalarOutput {
        Default::default()
    }
    fn f_bit(&mut self, _: Self::BitInput) -> Self::BitOutput {
        Default::default()
    }
    fn f_base(&mut self, _: Self::BaseInput) -> Self::BaseOutput {
        Default::default()
    }
    fn f_curve(&mut self, _: Self::CurveInput) -> Self::CurveOutput {
        Default::default()
    }
}
/// Test-only [`ApplyExpr`] that draws a value from each set of bounds using the
/// wrapped random number generator.
#[cfg(test)]
pub struct BoundSampler<'a, R: Rng + ?Sized>(pub &'a mut R);

#[cfg(test)]
impl<R: Rng + ?Sized> ApplyExpr for BoundSampler<'_, R> {
    type ScalarInput = FieldBounds<ScalarField>;
    type BitInput = BoolBounds;
    type BaseInput = FieldBounds<BaseField>;
    type CurveInput = CurveBounds;

    type ScalarOutput = ScalarField;
    type BitOutput = bool;
    type BaseOutput = BaseField;
    type CurveOutput = CurvePoint;

    // Each arm reborrows the generator so it stays usable for later visits.
    fn f_scalar(&mut self, bounds: Self::ScalarInput) -> Self::ScalarOutput {
        bounds.sample(&mut *self.0)
    }

    fn f_bit(&mut self, bounds: Self::BitInput) -> Self::BitOutput {
        bounds.sample(&mut *self.0)
    }

    fn f_base(&mut self, bounds: Self::BaseInput) -> Self::BaseOutput {
        bounds.sample(&mut *self.0)
    }

    fn f_curve(&mut self, bounds: Self::CurveInput) -> Self::CurveOutput {
        bounds.sample(&mut *self.0)
    }
}