arcis-compiler 0.9.4

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    core::{
        bounds::{BoolBounds, Bounds, CurveBounds, FieldBounds, IsBounds},
        expressions::{domain::Domain, expr::EvalValue, expr_macros::ApplyExpr},
    },
    utils::{
        curve_point::CurvePoint,
        field::{BaseField, ScalarField},
    },
};
#[cfg(test)]
use rand::Rng;
use std::marker::PhantomData;

/// Converts an `Expr<EvalValue>` into an `Expr<Fp, bool>` by unwrapping every `EvalValue`.
///
/// For example,
/// `Expr::Scalar(FieldExpr::Add(EvalValue::Scalar(Fp::ONE), EvalValue::Scalar(Fp::ZERO)))`
/// becomes
/// `Expr::Scalar(FieldExpr::Add(Fp::ONE, Fp::ZERO))`.
pub struct EvalValueUnwrap;

impl ApplyExpr for EvalValueUnwrap {
    // Every domain comes in as the same dynamically-tagged `EvalValue` and leaves as the
    // concrete type of that domain.
    type ScalarInput = EvalValue;
    type ScalarOutput = ScalarField;

    type BitInput = EvalValue;
    type BitOutput = bool;

    type BaseInput = EvalValue;
    type BaseOutput = BaseField;

    type CurveInput = EvalValue;
    type CurveOutput = CurvePoint;

    /// Unwraps a scalar-domain value by delegating to `ScalarField::unwrap`.
    ///
    /// NOTE(review): what happens when `wrapped` does not hold a scalar is decided by
    /// `ScalarField::unwrap`, which is not visible from this file.
    fn f_scalar(&mut self, wrapped: Self::ScalarInput) -> Self::ScalarOutput {
        ScalarField::unwrap(wrapped)
    }

    /// Unwraps a bit-domain value by delegating to `bool::unwrap`.
    fn f_bit(&mut self, wrapped: Self::BitInput) -> Self::BitOutput {
        bool::unwrap(wrapped)
    }

    /// Unwraps a base-field value by delegating to `BaseField::unwrap`.
    fn f_base(&mut self, wrapped: Self::BaseInput) -> Self::BaseOutput {
        BaseField::unwrap(wrapped)
    }

    /// Unwraps a curve-point value by delegating to `CurvePoint::unwrap`.
    fn f_curve(&mut self, wrapped: Self::CurveInput) -> Self::CurveOutput {
        CurvePoint::unwrap(wrapped)
    }
}

/// Turns an `Expr<Bounds>` into an `Expr<FieldBounds<Fp>, BoolBounds>` by extracting the
/// domain-specific bounds out of each `Bounds` value.
pub struct BoundUnFold;

impl ApplyExpr for BoundUnFold {
    // All inputs share the `Bounds` enum; each output is the bounds type of its own domain.
    type ScalarInput = Bounds;
    type ScalarOutput = FieldBounds<ScalarField>;

    type BitInput = Bounds;
    type BitOutput = BoolBounds;

    type BaseInput = Bounds;
    type BaseOutput = FieldBounds<BaseField>;

    type CurveInput = Bounds;
    type CurveOutput = CurveBounds;

    /// Extracts the payload of `Bounds::Scalar`.
    ///
    /// # Panics
    ///
    /// Panics if `bounds` is any other variant.
    fn f_scalar(&mut self, bounds: Self::ScalarInput) -> Self::ScalarOutput {
        match bounds {
            Bounds::Scalar(inner) => inner,
            _ => panic!("These bounds must be scalar."),
        }
    }

    /// Extracts the payload of `Bounds::Bit`.
    ///
    /// # Panics
    ///
    /// Panics if `bounds` is any other variant.
    fn f_bit(&mut self, bounds: Self::BitInput) -> Self::BitOutput {
        match bounds {
            Bounds::Bit(inner) => inner,
            _ => panic!("These bounds must be boolean."),
        }
    }

    /// Extracts base-field bounds by delegating to `BaseField::unwrap`.
    ///
    /// NOTE(review): unlike the other three methods this one does not match on a variant
    /// here; the handling of mismatched bounds is whatever `BaseField::unwrap` does.
    fn f_base(&mut self, bounds: Self::BaseInput) -> Self::BaseOutput {
        BaseField::unwrap(bounds)
    }

    /// Extracts the payload of `Bounds::Curve`.
    ///
    /// # Panics
    ///
    /// Panics if `bounds` is any other variant.
    fn f_curve(&mut self, bounds: Self::CurveInput) -> Self::CurveOutput {
        match bounds {
            Bounds::Curve(inner) => inner,
            _ => panic!("These bounds must be curve."),
        }
    }
}

/// Turns an `Expr<FieldBounds<Fp>, BoolBounds>` into an `Expr<Bounds>` by wrapping each
/// domain-specific bounds value into the common `Bounds` type.
pub struct BoundWrap;

impl ApplyExpr for BoundWrap {
    // Each input is the bounds type of its own domain; all outputs share the `Bounds` enum.
    type ScalarInput = FieldBounds<ScalarField>;
    type ScalarOutput = Bounds;

    type BitInput = BoolBounds;
    type BitOutput = Bounds;

    type BaseInput = FieldBounds<BaseField>;
    type BaseOutput = Bounds;

    type CurveInput = CurveBounds;
    type CurveOutput = Bounds;

    /// Wraps scalar-field bounds as `Bounds::Scalar`.
    fn f_scalar(&mut self, inner: Self::ScalarInput) -> Self::ScalarOutput {
        Bounds::Scalar(inner)
    }

    /// Wraps boolean bounds as `Bounds::Bit`.
    fn f_bit(&mut self, inner: Self::BitInput) -> Self::BitOutput {
        Bounds::Bit(inner)
    }

    /// Wraps base-field bounds by delegating to `BaseField::wrap`.
    fn f_base(&mut self, inner: Self::BaseInput) -> Self::BaseOutput {
        BaseField::wrap(inner)
    }

    /// Wraps curve bounds as `Bounds::Curve`.
    fn f_curve(&mut self, inner: Self::CurveInput) -> Self::CurveOutput {
        Bounds::Curve(inner)
    }
}

/// Replaces every bounds value by the result of `.as_constant().expect("...")`.
///
/// # Panics
///
/// Each method panics, with a domain-specific message, when `expect` fails on the result
/// of `as_constant()`.
pub struct ConstantBoundSampler;

impl ApplyExpr for ConstantBoundSampler {
    // Inputs are per-domain bounds; outputs are the concrete value type of each domain.
    type ScalarInput = FieldBounds<ScalarField>;
    type ScalarOutput = ScalarField;

    type BitInput = BoolBounds;
    type BitOutput = bool;

    type BaseInput = FieldBounds<BaseField>;
    type BaseOutput = BaseField;

    type CurveInput = CurveBounds;
    type CurveOutput = CurvePoint;

    /// Returns the constant of the scalar bounds, panicking if there is none.
    fn f_scalar(&mut self, bounds: Self::ScalarInput) -> Self::ScalarOutput {
        let constant = bounds.as_constant();
        constant.expect("Scalar bounds were not constant.")
    }

    /// Returns the constant of the boolean bounds, panicking if there is none.
    fn f_bit(&mut self, bounds: Self::BitInput) -> Self::BitOutput {
        let constant = bounds.as_constant();
        constant.expect("Bit bounds were not constant.")
    }

    /// Returns the constant of the base-field bounds, panicking if there is none.
    fn f_base(&mut self, bounds: Self::BaseInput) -> Self::BaseOutput {
        let constant = bounds.as_constant();
        constant.expect("Base bounds were not constant.")
    }

    /// Returns the constant of the curve bounds, panicking if there is none.
    fn f_curve(&mut self, bounds: Self::CurveInput) -> Self::CurveOutput {
        let constant = bounds.as_constant();
        constant.expect("Curve bounds were not constant.")
    }
}

/// Zero-sized helper, generic over four input types, whose `ApplyExpr` implementation
/// ignores every input and produces the default value of each output domain.
///
/// The type parameters are only recorded through [`PhantomData`]; no value of `T`, `U`,
/// `V` or `W` is ever stored.
#[derive(Debug)]
pub struct DefaultFiller<T, U, V, W> {
    /// Marker tying the four type parameters to the struct without storing any data.
    phantom: PhantomData<(T, U, V, W)>,
}

/// Builds a `DefaultFiller` for any choice of type parameters.
///
/// No bounds are placed on `T`, `U`, `V` or `W`: constructing a `PhantomData` never
/// needs a value of those types, so requiring `Clone` here would only restrict callers.
impl<T, U, V, W> Default for DefaultFiller<T, U, V, W> {
    fn default() -> Self {
        Self {
            phantom: PhantomData,
        }
    }
}

impl<T: Clone, U: Clone, V: Clone, W: Clone> ApplyExpr for DefaultFiller<T, U, V, W> {
    type ScalarInput = U;
    type BitInput = T;
    type BaseInput = V;
    type CurveInput = W;
    type ScalarOutput = ScalarField;
    type BitOutput = bool;
    type BaseOutput = BaseField;
    type CurveOutput = CurvePoint;

    fn f_scalar(&mut self, _: Self::ScalarInput) -> Self::ScalarOutput {
        Default::default()
    }

    fn f_bit(&mut self, _: Self::BitInput) -> Self::BitOutput {
        Default::default()
    }

    fn f_base(&mut self, _: Self::BaseInput) -> Self::BaseOutput {
        Default::default()
    }

    fn f_curve(&mut self, _: Self::CurveInput) -> Self::CurveOutput {
        Default::default()
    }
}

/// Turns an `Expr<FieldBounds<Fp>, BoolBounds>` into an `Expr<Fp, bool>` by drawing a
/// random sample from each bounds value.
///
/// Wraps a mutable borrow of the random number generator used for every draw. Only
/// compiled in test builds.
#[cfg(test)]
pub struct BoundSampler<'a, R: Rng + ?Sized>(pub &'a mut R);

#[cfg(test)]
impl<R: Rng + ?Sized> ApplyExpr for BoundSampler<'_, R> {
    // Inputs are per-domain bounds; outputs are the concrete value type of each domain.
    type ScalarInput = FieldBounds<ScalarField>;
    type ScalarOutput = ScalarField;

    type BitInput = BoolBounds;
    type BitOutput = bool;

    type BaseInput = FieldBounds<BaseField>;
    type BaseOutput = BaseField;

    type CurveInput = CurveBounds;
    type CurveOutput = CurvePoint;

    /// Samples a scalar from `bounds` using the wrapped generator.
    fn f_scalar(&mut self, bounds: Self::ScalarInput) -> Self::ScalarOutput {
        let rng = &mut *self.0;
        bounds.sample(rng)
    }

    /// Samples a boolean from `bounds` using the wrapped generator.
    fn f_bit(&mut self, bounds: Self::BitInput) -> Self::BitOutput {
        let rng = &mut *self.0;
        bounds.sample(rng)
    }

    /// Samples a base-field element from `bounds` using the wrapped generator.
    fn f_base(&mut self, bounds: Self::BaseInput) -> Self::BaseOutput {
        let rng = &mut *self.0;
        bounds.sample(rng)
    }

    /// Samples a curve point from `bounds` using the wrapped generator.
    fn f_curve(&mut self, bounds: Self::CurveInput) -> Self::CurveOutput {
        let rng = &mut *self.0;
        bounds.sample(rng)
    }
}