use crate::{
core::expressions::{
bit_expr::ApplyBitExpr,
conversion_expr::ApplyConversionExpr,
curve_expr::ApplyCurveExpr,
domain::{Domain, DomainElement},
expr::Expr,
field_expr::ApplyFieldExpr,
other_expr::ApplyOtherExpr,
},
utils::used_field::UsedField,
};
use std::marker::PhantomData;
/// Adapts a single closure `FnMut(I) -> O` so it can be used where an [`ApplyExpr`]
/// is expected.
///
/// `I` and `O` exist only at the type level (so an `ApplyExpr` impl can name them as
/// its input/output types). No `I` or `O` value is ever stored in the wrapper.
///
/// No trait bound is placed on the struct itself; bounds live on the impls that
/// actually need them.
struct ClosureWrapper<I, O, FN> {
    /// The wrapped closure.
    func: FN,
    // `fn(I)` / `fn() -> O` markers instead of `I` / `O`: the wrapper never owns an
    // `I` or `O`, so it must not inherit their auto traits (`Send`/`Sync`) or their
    // drop-check obligations.
    input_marker: PhantomData<fn(I)>,
    output_marker: PhantomData<fn() -> O>,
}
impl<I, O, FN: FnMut(I) -> O> ClosureWrapper<I, O, FN> {
    /// Wraps `func`; `I` and `O` are inferred from the closure's signature.
    fn new(func: FN) -> Self {
        ClosureWrapper {
            func,
            input_marker: PhantomData,
            output_marker: PhantomData,
        }
    }
}
/// A set of four independent per-domain maps (scalar, bit, base, curve).
///
/// Each domain has its own input and output type, so one implementor can transform
/// values of all four domains at once, each to a possibly different type. This
/// module provides blanket impls of `ApplyFieldExpr`, `ApplyBitExpr`,
/// `ApplyConversionExpr`, `ApplyCurveExpr` and `ApplyOtherExpr` for every
/// implementor.
pub trait ApplyExpr {
    /// Scalar-domain value type before mapping.
    type ScalarInput: Clone;
    /// Scalar-domain value type after mapping.
    type ScalarOutput: Clone;
    /// Bit-domain value type before mapping.
    type BitInput: Clone;
    /// Bit-domain value type after mapping.
    type BitOutput: Clone;
    /// Base-domain value type before mapping.
    type BaseInput: Clone;
    /// Base-domain value type after mapping.
    type BaseOutput: Clone;
    /// Curve-domain value type before mapping.
    type CurveInput: Clone;
    /// Curve-domain value type after mapping.
    type CurveOutput: Clone;

    /// Maps a single bit-domain value.
    fn f_bit(&mut self, val: Self::BitInput) -> Self::BitOutput;
    /// Maps a single scalar-domain value.
    fn f_scalar(&mut self, val: Self::ScalarInput) -> Self::ScalarOutput;
    /// Maps a single base-domain value.
    fn f_base(&mut self, val: Self::BaseInput) -> Self::BaseOutput;
    /// Maps a single curve-domain value.
    fn f_curve(&mut self, val: Self::CurveInput) -> Self::CurveOutput;

    /// Routes a tagged [`DomainElement`] to the matching per-domain map and re-tags
    /// the result with the same variant.
    #[inline(always)]
    fn wrapped(
        &mut self,
        val: DomainElement<Self::BitInput, Self::ScalarInput, Self::BaseInput, Self::CurveInput>,
    ) -> DomainElement<Self::BitOutput, Self::ScalarOutput, Self::BaseOutput, Self::CurveOutput>
    {
        match val {
            DomainElement::Scalar(scalar) => DomainElement::Scalar(self.f_scalar(scalar)),
            DomainElement::Bit(bit) => DomainElement::Bit(self.f_bit(bit)),
            DomainElement::Curve(point) => DomainElement::Curve(self.f_curve(point)),
            DomainElement::Base(base) => DomainElement::Base(self.f_base(base)),
        }
    }
}
/// Every domain shares the input type `T` and the output type `NewT`, and every
/// per-domain map simply calls the wrapped closure.
impl<T, NewT, FN> ApplyExpr for ClosureWrapper<T, NewT, FN>
where
    T: Clone,
    NewT: Clone,
    FN: FnMut(T) -> NewT,
{
    type BitInput = T;
    type BitOutput = NewT;
    type ScalarInput = T;
    type ScalarOutput = NewT;
    type BaseInput = T;
    type BaseOutput = NewT;
    type CurveInput = T;
    type CurveOutput = NewT;

    fn f_bit(&mut self, value: T) -> NewT {
        (self.func)(value)
    }
    fn f_scalar(&mut self, value: T) -> NewT {
        (self.func)(value)
    }
    fn f_base(&mut self, value: T) -> NewT {
        (self.func)(value)
    }
    fn f_curve(&mut self, value: T) -> NewT {
        (self.func)(value)
    }
}
/// Blanket `ApplyFieldExpr` for every [`ApplyExpr`].
///
/// Each `t`/`p`/`c` value of type `A` is converted into a [`DomainElement`] by
/// [`Domain::wrap`], mapped by [`ApplyExpr::wrapped`], and converted back to `NewA` by
/// [`Domain::unwrap`]. The variant chosen by `F`'s `wrap` therefore decides which
/// per-domain map is applied.
impl<F, A, NewA, Scalar, NewScalar, Bit, NewBit, Base, NewBase, Curve, NewCurve, U>
    ApplyFieldExpr<F, A, NewA> for U
where
    F: UsedField,
    A: Clone,
    NewA: Clone,
    Scalar: Clone,
    NewScalar: Clone,
    Bit: Clone,
    NewBit: Clone,
    Base: Clone,
    NewBase: Clone,
    Curve: Clone,
    NewCurve: Clone,
    U: ApplyExpr<
        BitInput = Bit,
        BitOutput = NewBit,
        ScalarInput = Scalar,
        ScalarOutput = NewScalar,
        BaseInput = Base,
        BaseOutput = NewBase,
        CurveInput = Curve,
        CurveOutput = NewCurve,
    >,
    F: Domain<Bit, Scalar, Base, Curve, DomainType = A>,
    F: Domain<NewBit, NewScalar, NewBase, NewCurve, DomainType = NewA>,
{
    fn t(&mut self, item: A) -> NewA {
        F::unwrap(self.wrapped(F::wrap(item)))
    }
    fn p(&mut self, item: A) -> NewA {
        F::unwrap(self.wrapped(F::wrap(item)))
    }
    fn c(&mut self, item: A) -> NewA {
        F::unwrap(self.wrapped(F::wrap(item)))
    }
}
/// Blanket `ApplyBitExpr` for every [`ApplyExpr`]: each `b` value is passed directly
/// to [`ApplyExpr::f_bit`].
impl<B, NewB, U> ApplyBitExpr<B, NewB> for U
where
    B: Clone,
    NewB: Clone,
    U: ApplyExpr<BitInput = B, BitOutput = NewB>,
{
    fn b(&mut self, bit: B) -> NewB {
        self.f_bit(bit)
    }
}
/// Blanket `ApplyConversionExpr` for every [`ApplyExpr`].
///
/// - The `t` and `e` hooks receive values of type `A`. Each is converted into a
///   [`DomainElement`] by [`Domain::wrap`], mapped by [`ApplyExpr::wrapped`], and
///   converted back to `NewA` by [`Domain::unwrap`].
/// - The `b` hook receives raw bit values and forwards them to [`ApplyExpr::f_bit`].
impl<F, A, NewA, Scalar, NewScalar, Bit, NewBit, Base, NewBase, Curve, NewCurve, U>
    ApplyConversionExpr<F, A, NewA, Bit, NewBit> for U
where
    F: UsedField,
    A: Clone,
    NewA: Clone,
    Scalar: Clone,
    NewScalar: Clone,
    Bit: Clone,
    NewBit: Clone,
    Base: Clone,
    NewBase: Clone,
    Curve: Clone,
    NewCurve: Clone,
    U: ApplyExpr<
        BitInput = Bit,
        BitOutput = NewBit,
        ScalarInput = Scalar,
        ScalarOutput = NewScalar,
        BaseInput = Base,
        BaseOutput = NewBase,
        CurveInput = Curve,
        CurveOutput = NewCurve,
    >,
    F: Domain<Bit, Scalar, Base, Curve, DomainType = A>,
    F: Domain<NewBit, NewScalar, NewBase, NewCurve, DomainType = NewA>,
{
    fn t(&mut self, item: A) -> NewA {
        F::unwrap(self.wrapped(F::wrap(item)))
    }
    fn e(&mut self, item: A) -> NewA {
        F::unwrap(self.wrapped(F::wrap(item)))
    }
    fn b(&mut self, bit: Bit) -> NewBit {
        self.f_bit(bit)
    }
}
/// Blanket `ApplyCurveExpr` for every [`ApplyExpr`].
///
/// `c` values go through [`ApplyExpr::f_curve`] and `s` values through
/// [`ApplyExpr::f_scalar`].
impl<C, NewC, S, NewS, U> ApplyCurveExpr<C, NewC, S, NewS> for U
where
    C: Clone,
    NewC: Clone,
    S: Clone,
    NewS: Clone,
    U: ApplyExpr<ScalarInput = S, ScalarOutput = NewS, CurveInput = C, CurveOutput = NewC>,
{
    fn c(&mut self, point: C) -> NewC {
        self.f_curve(point)
    }
    fn s(&mut self, scalar: S) -> NewS {
        self.f_scalar(scalar)
    }
}
/// Blanket `ApplyOtherExpr` for every [`ApplyExpr`].
///
/// `t` values go through [`ApplyExpr::f_scalar`], `b` values through
/// [`ApplyExpr::f_base`], and `c` values through [`ApplyExpr::f_curve`]. Note that
/// `b` maps the *base* domain here, not the bit domain.
impl<S, NewS, B, NewB, C, NewC, U> ApplyOtherExpr<S, NewS, B, NewB, C, NewC> for U
where
    S: Clone,
    NewS: Clone,
    B: Clone,
    NewB: Clone,
    C: Clone,
    NewC: Clone,
    U: ApplyExpr<
        CurveInput = C,
        CurveOutput = NewC,
        BaseInput = B,
        BaseOutput = NewB,
        ScalarInput = S,
        ScalarOutput = NewS,
    >,
{
    fn t(&mut self, scalar: S) -> NewS {
        self.f_scalar(scalar)
    }
    fn b(&mut self, base: B) -> NewB {
        self.f_base(base)
    }
    fn c(&mut self, point: C) -> NewC {
        self.f_curve(point)
    }
}
impl<Scalar, Bit, Base, Curve> Expr<Scalar, Bit, Base, Curve>
where
    Scalar: Clone,
    Bit: Clone,
    Base: Clone,
    Curve: Clone,
{
    /// Consumes this expression and rebuilds it with its values mapped through `func`,
    /// keeping the same `Expr` variant.
    ///
    /// `func` provides one map per domain, so scalar, bit, base and curve values may
    /// each change to a different type. The actual traversal of the inner expression
    /// is delegated to the matching `Apply*Expr::apply`.
    pub fn apply_2<NewScalar, NewBit, NewBase, NewCurve>(
        self,
        func: &mut impl ApplyExpr<
            BitInput = Bit,
            BitOutput = NewBit,
            ScalarInput = Scalar,
            ScalarOutput = NewScalar,
            BaseInput = Base,
            BaseOutput = NewBase,
            CurveInput = Curve,
            CurveOutput = NewCurve,
        >,
    ) -> Expr<NewScalar, NewBit, NewBase, NewCurve>
    where
        NewScalar: Clone,
        NewBit: Clone,
        NewBase: Clone,
        NewCurve: Clone,
    {
        match self {
            // Plain field expressions (scalar and base field).
            Self::Scalar(inner) => Expr::Scalar(ApplyFieldExpr::apply(func, inner)),
            Self::Base(inner) => Expr::Base(ApplyFieldExpr::apply(func, inner)),
            // Conversions between a field domain and bits.
            Self::ScalarConversion(inner) => {
                Expr::ScalarConversion(ApplyConversionExpr::apply(func, inner))
            }
            Self::BaseConversion(inner) => {
                Expr::BaseConversion(ApplyConversionExpr::apply(func, inner))
            }
            Self::Bit(inner) => Expr::Bit(ApplyBitExpr::apply(func, inner)),
            Self::Curve(inner) => Expr::Curve(ApplyCurveExpr::apply(func, inner)),
            Self::Other(inner) => Expr::Other(ApplyOtherExpr::apply(func, inner)),
        }
    }
}
impl<T: Clone> Expr<T> {
    /// Consumes this expression and maps every value through `func`.
    ///
    /// The same closure is used for every domain. This is a convenience wrapper
    /// around [`Expr::apply_2`].
    pub fn apply<NewT: Clone>(self, func: impl FnMut(T) -> NewT) -> Expr<NewT, NewT> {
        let mut wrapper = ClosureWrapper::new(func);
        self.apply_2(&mut wrapper)
    }

    /// Calls `func` on a reference to each value visited by the inner expression's
    /// own `proc`.
    pub fn proc(&self, func: impl FnMut(&T)) {
        match self {
            Self::Scalar(inner) => inner.proc(func),
            Self::Base(inner) => inner.proc(func),
            Self::ScalarConversion(inner) => inner.proc(func),
            Self::BaseConversion(inner) => inner.proc(func),
            Self::Bit(inner) => inner.proc(func),
            Self::Curve(inner) => inner.proc(func),
            Self::Other(inner) => inner.proc(func),
        }
    }

    /// Returns clones of the values visited by [`Expr::proc`], in visiting order.
    pub fn get_deps(&self) -> Vec<T> {
        let mut deps = Vec::new();
        self.proc(|dep| deps.push(T::clone(dep)));
        deps
    }

    /// Returns how many values [`Expr::proc`] visits.
    pub fn n_deps(&self) -> usize {
        let mut count: usize = 0;
        self.proc(|_| count += 1);
        count
    }
}