// arcis-compiler 0.9.4
//
// A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
// Documentation
use crate::{
    core::{
        expressions::expr::EvalFailure,
        global_value::{
            curve_value::CurveValue,
            global_expr_store::with_local_expr_store_as_global,
            value::FieldValue,
        },
        ir_builder::IRBuilder,
    },
    utils::{curve_point::CurvePoint, field::ScalarField},
};
use std::fmt::Debug;

pub trait CurveCircuit: Debug {
    /// Evaluates the operation on concrete inputs.
    ///
    /// Takes plain curve points and scalars and returns the resulting curve points, or an
    /// [`EvalFailure`] when the evaluation does not succeed.
    #[allow(dead_code)]
    fn eval(
        &self,
        curve_points: Vec<CurvePoint>,
        scalars: Vec<ScalarField>,
    ) -> Result<Vec<CurvePoint>, EvalFailure>;

    /// The same operation, in MPC.
    ///
    /// Works on [`CurveValue`]s and [`FieldValue`]s instead of concrete points and scalars, and
    /// returns the output [`CurveValue`]s.
    #[allow(dead_code)]
    fn run(
        &self,
        curve_vals: Vec<CurveValue>,
        scalar_vals: Vec<FieldValue<ScalarField>>,
    ) -> Vec<CurveValue>;

    /// Id-based front end for [`CurveCircuit::run`].
    ///
    /// Each id in `curve_vals` is wrapped with `CurveValue::new` and each id in `scalar_vals`
    /// with `FieldValue::from_id`; `run` is then called on those values through
    /// `with_local_expr_store_as_global`, which is handed `expr_store`. The returned vector
    /// holds the id (`get_id`) of every output value, in the order `run` produced them.
    #[allow(dead_code)]
    fn run_usize(
        &self,
        curve_vals: &[usize],
        scalar_vals: &[usize],
        expr_store: &mut IRBuilder,
    ) -> Vec<usize> {
        // The values are built inside the closure, next to the `run` call, so that they are
        // created while `with_local_expr_store_as_global` is executing the closure.
        // NOTE(review): presumably that helper makes `expr_store` the global store for the
        // duration of the closure -- confirm against its definition.
        let run_on_ids = || {
            let points: Vec<CurveValue> = curve_vals
                .iter()
                .map(|&id| CurveValue::new(id))
                .collect();
            let scalars: Vec<FieldValue<ScalarField>> = scalar_vals
                .iter()
                .map(|&id| FieldValue::from_id(id))
                .collect();
            self.run(points, scalars)
        };
        let outputs = with_local_expr_store_as_global(run_on_ids, expr_store);

        let output_ids: Vec<usize> = outputs.iter().map(|value| value.get_id()).collect();
        output_ids
    }
}

#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::{
        core::{
            bounds::FieldBounds,
            expressions::{
                curve_expr::{self, CurveExpr},
                domain::Domain,
                expr::EvalValue,
                field_expr::FieldExpr,
                InputKind,
            },
            ir_builder::{ExprStore, IRBuilder},
        },
        utils::used_field::UsedField,
    };
    use ff::PrimeField;
    use rand::Rng;
    use rustc_hash::FxHashMap;
    use std::rc::Rc;

    /// Randomized test harness for [`CurveCircuit`] implementations.
    ///
    /// Implementors describe how to generate a circuit description and its input counts;
    /// [`TestedCurveCircuit::test`] then compares `eval` against `run` on random inputs.
    pub trait TestedCurveCircuit: CurveCircuit + Clone + 'static {
        /// Randomly generates a description of the sub-circuit.
        fn gen_desc<R: Rng + ?Sized>(rng: &mut R) -> Self;

        /// Randomly generates the number of curve points given to the sub-circuit.
        fn gen_n_points<R: Rng + ?Sized>(&self, rng: &mut R) -> usize;

        /// Randomly generates the number of scalar field elements given to the sub-circuit,
        /// knowing the number of curve points `n_points` already chosen.
        fn gen_n_scalars<R: Rng + ?Sized>(&self, rng: &mut R, n_points: usize) -> usize;

        /// Hook that implementations may override to perform extra checks on a run.
        ///
        /// It receives the inputs of the run and the outputs obtained from the circuit.
        /// The default implementation does nothing.
        #[allow(unused_variables)]
        fn extra_checks(
            &self,
            curve_inputs: Vec<CurvePoint>,
            scalar_inputs: Vec<ScalarField>,
            curve_outputs: Vec<CurvePoint>,
        ) {
        }

        /// Generates the bounds of one scalar input. Should not be overridden in impls.
        ///
        /// With probability 1/8 the input is left unbounded (`FieldBounds::All`). Otherwise
        /// random bounds are drawn inside `[0, 2^k]`, or `[-2^k, 2^k]` for the signed case,
        /// where `k` is drawn below `ScalarField::NUM_BITS`.
        fn gen_input_bounds<R: Rng + ?Sized>(rng: &mut R) -> FieldBounds<ScalarField> {
            let unbounded = rng.gen_bool(0.125);
            if unbounded {
                return FieldBounds::All;
            }

            // The random draws are kept in this exact order: sign first, then the size.
            let is_signed = rng.gen_bool(0.5);
            let n_bits = (rng.next_u32() % ScalarField::NUM_BITS) as usize;
            let limit = ScalarField::power_of_two(n_bits);
            let lower = if is_signed {
                -limit
            } else {
                ScalarField::from(0)
            };
            FieldBounds::gen_bounds(rng, FieldBounds::new(lower, limit))
        }

        /// The actual test.
        ///
        /// `n_desc` is the number of different descriptions that will be generated.
        /// `n_runs_per_desc` is the number of runs that will be attempted per description.
        /// A run where `eval` fails is skipped and counts as a successful try.
        ///
        /// For every other run, the circuit produced by `run_usize` is evaluated and its
        /// outputs must equal the outputs of `eval`; `extra_checks` is called afterwards.
        fn test(n_desc: usize, n_runs_per_desc: usize) {
            /// Draws `n_inputs` random curve points and registers one curve expression per
            /// point in `expr_store`: a constant with probability 1/8, a secret input
            /// numbered `start_input_id + i` otherwise. Every point is also recorded in
            /// `inputs` under `start_input_id + i`. Returns the expression ids and the points.
            fn gen_input_points_and_expr<R: Rng + ?Sized>(
                rng: &mut R,
                n_inputs: usize,
                start_input_id: usize,
                inputs: &mut FxHashMap<usize, EvalValue>,
                expr_store: &mut impl ExprStore<ScalarField>,
            ) -> (Vec<usize>, Vec<CurvePoint>) {
                // All the points are drawn before any constant-vs-input decision is made.
                let mut points: Vec<CurvePoint> = Vec::with_capacity(n_inputs);
                for _ in 0..n_inputs {
                    points.push(R::gen(rng));
                }

                let mut ids: Vec<usize> = Vec::with_capacity(n_inputs);
                for (offset, point) in points.iter().enumerate() {
                    let as_constant = rng.gen_bool(0.125);
                    let id = if as_constant {
                        expr_store.push_curve(CurveExpr::Val(*point))
                    } else {
                        let info = Rc::new(curve_expr::InputInfo::from(InputKind::Secret));
                        expr_store.push_curve(CurveExpr::Input(start_input_id + offset, info))
                    };
                    ids.push(id);
                }

                // Constants are recorded too: the map is filled for every index.
                for (offset, point) in points.iter().enumerate() {
                    inputs.insert(start_input_id + offset, EvalValue::Curve(*point));
                }
                (ids, points)
            }

            /// Draws `n_inputs` bounds with `gen_bounds`, samples one scalar inside each of
            /// them and registers one field expression per scalar in `expr_store`: a constant
            /// with probability 1/8, a secret input numbered `start_input_id + i` and carrying
            /// its bounds otherwise. Every scalar is also recorded in `inputs` under
            /// `start_input_id + i`. Returns the expression ids and the scalars.
            fn gen_input_scalars_and_expr<R: Rng + ?Sized>(
                rng: &mut R,
                n_inputs: usize,
                start_input_id: usize,
                inputs: &mut FxHashMap<usize, EvalValue>,
                expr_store: &mut impl ExprStore<ScalarField>,
                mut gen_bounds: impl FnMut(&mut R) -> FieldBounds<ScalarField>,
            ) -> (Vec<usize>, Vec<ScalarField>) {
                // Three separate passes, so that the random draws happen in this order:
                // every bound, then every sample, then every constant-vs-input decision.
                let mut all_bounds: Vec<FieldBounds<ScalarField>> = Vec::with_capacity(n_inputs);
                for _ in 0..n_inputs {
                    all_bounds.push(gen_bounds(rng));
                }

                let mut scalars: Vec<ScalarField> = Vec::with_capacity(n_inputs);
                for bound in all_bounds.iter() {
                    scalars.push(bound.sample(rng));
                }

                let mut ids: Vec<usize> = Vec::with_capacity(n_inputs);
                let bounded_scalars = all_bounds.iter_mut().zip(scalars.iter());
                for (offset, (bound, scalar)) in bounded_scalars.enumerate() {
                    let as_constant = rng.gen_bool(0.125);
                    let id = if as_constant {
                        let val = *scalar;
                        // A constant is known exactly, so its bounds collapse to that value.
                        *bound = FieldBounds::new(val, val);
                        expr_store.push_field(FieldExpr::Val(val))
                    } else {
                        let info = bound.as_input_info(InputKind::Secret);
                        expr_store.push_field(FieldExpr::Input(start_input_id + offset, info))
                    };
                    ids.push(id);
                }

                // Constants are recorded too: the map is filled for every index.
                for (offset, scalar) in scalars.iter().enumerate() {
                    inputs.insert(start_input_id + offset, EvalValue::Scalar(*scalar));
                }
                (ids, scalars)
            }

            let rng = &mut crate::utils::test_rng::get();
            for _ in 0..n_desc {
                let desc = Self::gen_desc(rng);
                for _ in 0..n_runs_per_desc {
                    // A fresh builder per run; the circuit under test is recorded into it
                    // by `run_usize` below.
                    let mut builder = IRBuilder::new(false);
                    let mut known_inputs: FxHashMap<usize, _> = FxHashMap::default();
                    let point_count = desc.gen_n_points(rng);
                    let scalar_count = desc.gen_n_scalars(rng, point_count);
                    let (point_ids, points) = gen_input_points_and_expr(
                        rng,
                        point_count,
                        0,
                        &mut known_inputs,
                        &mut builder,
                    );
                    // Scalar inputs are numbered right after the curve point inputs.
                    let (scalar_ids, scalars) = gen_input_scalars_and_expr(
                        rng,
                        scalar_count,
                        points.len(),
                        &mut known_inputs,
                        &mut builder,
                        Self::gen_input_bounds,
                    );

                    let expected = match desc.eval(points.clone(), scalars.clone()) {
                        Ok(expected) => expected,
                        // If eval fails, then the inputs were wrong.
                        // We do not have to bother testing.
                        Err(_) => continue,
                    };

                    let output_ids = desc.run_usize(&point_ids, &scalar_ids, &mut builder);
                    // Taken now: `into_ir` consumes `output_ids`.
                    let n_outputs = output_ids.len();
                    let evaluated = builder.into_ir(output_ids).eval(rng, &mut known_inputs);
                    let Ok(outputs) = evaluated else {
                        panic!("run failed: {:?}", evaluated);
                    };
                    let actual: Vec<CurvePoint> = (0..n_outputs)
                        .map(|idx| CurvePoint::unwrap(outputs[idx]))
                        .collect();

                    assert_eq!(expected, actual);
                    desc.extra_checks(points, scalars, actual);
                }
            }
        }
    }
}