arcis-compiler 0.9.4

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    core::expressions::{
        bit_expr::ApplyBitExpr,
        conversion_expr::ApplyConversionExpr,
        curve_expr::ApplyCurveExpr,
        domain::{Domain, DomainElement},
        expr::Expr,
        field_expr::ApplyFieldExpr,
        other_expr::ApplyOtherExpr,
    },
    utils::used_field::UsedField,
};
use std::marker::PhantomData;

/// Adapts a single closure `FN: FnMut(I) -> O` so that it can be used wherever an
/// `ApplyExpr` is expected, applying the same closure in every domain.
///
/// `I` and `O` only appear in the closure bound of the `impl` blocks, so the two
/// `PhantomData` markers are what keep those type parameters "used" (Rust rejects
/// unused type parameters on a struct). No `I` or `O` value is ever stored.
///
/// The `FnMut(I) -> O` bound is intentionally not placed on the type itself: the
/// `impl` blocks that call `func` already state it, and keeping the definition
/// bound-free avoids having to repeat it everywhere the type is merely named.
struct ClosureWrapper<I, O, FN> {
    /// The wrapped closure; every mapping call is forwarded to it.
    func: FN,
    /// Records the closure's input type.
    input_marker: PhantomData<I>,
    /// Records the closure's output type.
    output_marker: PhantomData<O>,
}

impl<I, O, FN: FnMut(I) -> O> ClosureWrapper<I, O, FN> {
    fn new(func: FN) -> Self {
        ClosureWrapper {
            func,
            input_marker: PhantomData,
            output_marker: PhantomData,
        }
    }
}

/// A bundle of four mappings, one per value domain: scalar, bit, base and curve.
///
/// Every domain has its own input and output type. Implementors supply the four
/// `f_*` methods. The provided [`ApplyExpr::wrapped`] method then maps a
/// [`DomainElement`] by dispatching on its variant.
pub trait ApplyExpr {
    /// Value type consumed by [`ApplyExpr::f_scalar`].
    type ScalarInput: Clone;
    /// Value type produced by [`ApplyExpr::f_scalar`].
    type ScalarOutput: Clone;
    /// Value type consumed by [`ApplyExpr::f_bit`].
    type BitInput: Clone;
    /// Value type produced by [`ApplyExpr::f_bit`].
    type BitOutput: Clone;
    /// Value type consumed by [`ApplyExpr::f_base`].
    type BaseInput: Clone;
    /// Value type produced by [`ApplyExpr::f_base`].
    type BaseOutput: Clone;
    /// Value type consumed by [`ApplyExpr::f_curve`].
    type CurveInput: Clone;
    /// Value type produced by [`ApplyExpr::f_curve`].
    type CurveOutput: Clone;

    /// Maps one scalar-domain value.
    fn f_scalar(&mut self, val: Self::ScalarInput) -> Self::ScalarOutput;
    /// Maps one bit-domain value.
    fn f_bit(&mut self, val: Self::BitInput) -> Self::BitOutput;
    /// Maps one base-domain value.
    fn f_base(&mut self, val: Self::BaseInput) -> Self::BaseOutput;
    /// Maps one curve-domain value.
    fn f_curve(&mut self, val: Self::CurveInput) -> Self::CurveOutput;

    /// Maps a [`DomainElement`] by applying the `f_*` method that matches its
    /// variant. The variant itself is preserved: a `Bit` input yields a `Bit` output,
    /// and so on.
    #[inline(always)]
    fn wrapped(
        &mut self,
        val: DomainElement<Self::BitInput, Self::ScalarInput, Self::BaseInput, Self::CurveInput>,
    ) -> DomainElement<Self::BitOutput, Self::ScalarOutput, Self::BaseOutput, Self::CurveOutput>
    {
        match val {
            DomainElement::Scalar(scalar) => DomainElement::Scalar(self.f_scalar(scalar)),
            DomainElement::Bit(bit) => DomainElement::Bit(self.f_bit(bit)),
            DomainElement::Base(base) => DomainElement::Base(self.f_base(base)),
            DomainElement::Curve(point) => DomainElement::Curve(self.f_curve(point)),
        }
    }
}

/// Lets a plain closure act as an [`ApplyExpr`]. All four domains share the
/// closure's input type `T` and output type `NewT`, and every `f_*` method simply
/// forwards its value to the wrapped closure.
impl<T: Clone, NewT: Clone, FN: FnMut(T) -> NewT> ApplyExpr for ClosureWrapper<T, NewT, FN> {
    type ScalarInput = T;
    type ScalarOutput = NewT;
    type BitInput = T;
    type BitOutput = NewT;
    type BaseInput = T;
    type BaseOutput = NewT;
    type CurveInput = T;
    type CurveOutput = NewT;

    fn f_scalar(&mut self, val: T) -> NewT {
        self.invoke(val)
    }

    fn f_bit(&mut self, val: T) -> NewT {
        self.invoke(val)
    }

    fn f_base(&mut self, val: T) -> NewT {
        self.invoke(val)
    }

    fn f_curve(&mut self, val: T) -> NewT {
        self.invoke(val)
    }
}

impl<T, NewT, FN: FnMut(T) -> NewT> ClosureWrapper<T, NewT, FN> {
    /// Calls the wrapped closure exactly once. This is shared by the four `f_*`
    /// methods in the `ApplyExpr` impl above.
    fn invoke(&mut self, val: T) -> NewT {
        let func = &mut self.func;
        func(val)
    }
}
/// Blanket field-expression mapping for any [`ApplyExpr`].
///
/// `c`, `p` and `t` behave identically. Each one converts the `A` value into a
/// [`DomainElement`] with `F::wrap`, maps it with [`ApplyExpr::wrapped`], and then
/// converts the result back into a `NewA` with `F::unwrap`.
///
/// `F` has to satisfy `Domain` twice. One bound ties the input types to `A` and the
/// other ties the output types to `NewA`, so that both conversions are well-typed.
impl<F, A, NewA, Scalar, NewScalar, Bit, NewBit, Base, NewBase, Curve, NewCurve, U>
    ApplyFieldExpr<F, A, NewA> for U
where
    F: UsedField,
    A: Clone,
    NewA: Clone,
    Scalar: Clone,
    NewScalar: Clone,
    Bit: Clone,
    NewBit: Clone,
    Base: Clone,
    NewBase: Clone,
    Curve: Clone,
    NewCurve: Clone,
    U: ApplyExpr<
        ScalarInput = Scalar,
        ScalarOutput = NewScalar,
        BitInput = Bit,
        BitOutput = NewBit,
        BaseInput = Base,
        BaseOutput = NewBase,
        CurveInput = Curve,
        CurveOutput = NewCurve,
    >,
    F: Domain<Bit, Scalar, Base, Curve, DomainType = A>,
    F: Domain<NewBit, NewScalar, NewBase, NewCurve, DomainType = NewA>,
{
    fn c(&mut self, val: A) -> NewA {
        F::unwrap(self.wrapped(F::wrap(val)))
    }

    fn p(&mut self, val: A) -> NewA {
        F::unwrap(self.wrapped(F::wrap(val)))
    }

    fn t(&mut self, val: A) -> NewA {
        F::unwrap(self.wrapped(F::wrap(val)))
    }
}

impl<B: Clone, NewB: Clone, U: ApplyExpr<BitInput = B, BitOutput = NewB>> ApplyBitExpr<B, NewB>
    for U
{
    fn b(&mut self, val: B) -> NewB {
        self.f_bit(val)
    }
}
impl<
        F: UsedField,
        A: Clone,
        NewA: Clone,
        Scalar: Clone,
        NewScalar: Clone,
        Bit: Clone,
        NewBit: Clone,
        Base: Clone,
        NewBase: Clone,
        Curve: Clone,
        NewCurve: Clone,
        U: ApplyExpr<
            ScalarInput = Scalar,
            ScalarOutput = NewScalar,
            BitInput = Bit,
            BitOutput = NewBit,
            BaseInput = Base,
            BaseOutput = NewBase,
            CurveInput = Curve,
            CurveOutput = NewCurve,
        >,
    > ApplyConversionExpr<F, A, NewA, Bit, NewBit> for U
where
    F: Domain<Bit, Scalar, Base, Curve, DomainType = A>,
    F: Domain<NewBit, NewScalar, NewBase, NewCurve, DomainType = NewA>,
{
    fn t(&mut self, val: A) -> NewA {
        F::unwrap(self.wrapped(F::wrap(val)))
    }

    fn e(&mut self, val: A) -> NewA {
        F::unwrap(self.wrapped(F::wrap(val)))
    }

    fn b(&mut self, val: Bit) -> NewBit {
        self.f_bit(val)
    }
}

impl<
        C: Clone,
        NewC: Clone,
        S: Clone,
        NewS: Clone,
        U: ApplyExpr<CurveInput = C, CurveOutput = NewC, ScalarInput = S, ScalarOutput = NewS>,
    > ApplyCurveExpr<C, NewC, S, NewS> for U
{
    fn c(&mut self, val: C) -> NewC {
        self.f_curve(val)
    }

    fn s(&mut self, val: S) -> NewS {
        self.f_scalar(val)
    }
}

impl<
        S: Clone,
        NewS: Clone,
        B: Clone,
        NewB: Clone,
        C: Clone,
        NewC: Clone,
        U: ApplyExpr<
            ScalarInput = S,
            ScalarOutput = NewS,
            BaseInput = B,
            BaseOutput = NewB,
            CurveInput = C,
            CurveOutput = NewC,
        >,
    > ApplyOtherExpr<S, NewS, B, NewB, C, NewC> for U
{
    fn t(&mut self, val: S) -> NewS {
        self.f_scalar(val)
    }

    fn b(&mut self, val: B) -> NewB {
        self.f_base(val)
    }

    fn c(&mut self, val: C) -> NewC {
        self.f_curve(val)
    }
}

impl<Scalar, Bit, Base, Curve> Expr<Scalar, Bit, Base, Curve>
where
    Scalar: Clone,
    Bit: Clone,
    Base: Clone,
    Curve: Clone,
{
    /// Maps this expression through `func`, producing an expression over the output
    /// types of `func`.
    ///
    /// The variant is preserved. Field-typed variants (`Scalar`, `Base`) go through
    /// `ApplyFieldExpr::apply`, and conversion variants go through
    /// `ApplyConversionExpr::apply`. `Bit`, `Curve` and `Other` use
    /// `ApplyBitExpr::apply`, `ApplyCurveExpr::apply` and `ApplyOtherExpr::apply`
    /// respectively. All of these are satisfied by the blanket impls over `ApplyExpr`.
    pub fn apply_2<NewScalar, NewBit, NewBase, NewCurve>(
        self,
        func: &mut impl ApplyExpr<
            ScalarInput = Scalar,
            ScalarOutput = NewScalar,
            BitInput = Bit,
            BitOutput = NewBit,
            BaseInput = Base,
            BaseOutput = NewBase,
            CurveInput = Curve,
            CurveOutput = NewCurve,
        >,
    ) -> Expr<NewScalar, NewBit, NewBase, NewCurve>
    where
        NewScalar: Clone,
        NewBit: Clone,
        NewBase: Clone,
        NewCurve: Clone,
    {
        match self {
            // Field expressions over the scalar and base fields.
            Expr::Scalar(inner) => Expr::Scalar(ApplyFieldExpr::apply(func, inner)),
            Expr::Base(inner) => Expr::Base(ApplyFieldExpr::apply(func, inner)),
            // Conversions into the scalar and base fields.
            Expr::ScalarConversion(inner) => {
                Expr::ScalarConversion(ApplyConversionExpr::apply(func, inner))
            }
            Expr::BaseConversion(inner) => {
                Expr::BaseConversion(ApplyConversionExpr::apply(func, inner))
            }
            // Remaining single-purpose expression kinds.
            Expr::Bit(inner) => Expr::Bit(ApplyBitExpr::apply(func, inner)),
            Expr::Curve(inner) => Expr::Curve(ApplyCurveExpr::apply(func, inner)),
            Expr::Other(inner) => Expr::Other(ApplyOtherExpr::apply(func, inner)),
        }
    }
}

impl<T: Clone> Expr<T> {
    /// Maps every value of a single-typed expression with the closure `func`.
    ///
    /// The closure is wrapped in a `ClosureWrapper` so that the same closure serves
    /// all four domains, and the work is delegated to [`Expr::apply_2`].
    pub fn apply<NewT: Clone>(self, func: impl FnMut(T) -> NewT) -> Expr<NewT, NewT> {
        let mut adapter = ClosureWrapper::new(func);
        self.apply_2(&mut adapter)
    }

    /// Calls `func` on a reference to each value held by this expression. Visiting
    /// is delegated to the inner expression's own `proc`, which also decides the
    /// order of the calls.
    pub fn proc(&self, func: impl FnMut(&T)) {
        match self {
            Expr::Scalar(inner) => inner.proc(func),
            Expr::Bit(inner) => inner.proc(func),
            Expr::ScalarConversion(inner) => inner.proc(func),
            Expr::Base(inner) => inner.proc(func),
            Expr::BaseConversion(inner) => inner.proc(func),
            Expr::Curve(inner) => inner.proc(func),
            Expr::Other(inner) => inner.proc(func),
        }
    }

    /// Returns clones of every value visited by [`Expr::proc`], in visiting order.
    pub fn get_deps(&self) -> Vec<T> {
        let mut deps = Vec::new();
        self.proc(|dep| deps.push(dep.clone()));
        deps
    }

    /// Returns how many values [`Expr::proc`] visits, without cloning any of them.
    pub fn n_deps(&self) -> usize {
        let mut count = 0;
        self.proc(|_| count += 1);
        count
    }
}