arcis-compiler 0.9.3

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    core::{
        bounds::FieldBounds,
        expressions::expr::EvalFailure,
        global_value::{global_expr_store::with_local_expr_store_as_global, value::FieldValue},
        ir_builder::IRBuilder,
    },
    utils::field::{BaseField, ScalarField},
};
use std::fmt::Debug;

/// A sub-circuit whose inputs and outputs are each split into a list of [`ScalarField`]
/// elements and a list of [`BaseField`] elements.
///
/// The same operation is described three ways: evaluated in the clear
/// ([`GeneralCircuit::eval`]), as propagation of bounds ([`GeneralCircuit::bounds`]) and as
/// MPC operations ([`GeneralCircuit::run`]).
// NOTE(review): the reason for this `allow` is not visible here — confirm it is still needed.
#[allow(dead_code)]
pub trait GeneralCircuit: Debug {
    /// Evaluates the operation in the clear on concrete field elements.
    ///
    /// Returns `(scalar_outputs, base_outputs)`, or an [`EvalFailure`] when the operation
    /// cannot be evaluated on these inputs.
    fn eval(
        &self,
        scalars: Vec<ScalarField>,
        bases: Vec<BaseField>,
    ) -> Result<(Vec<ScalarField>, Vec<BaseField>), EvalFailure>;
    /// Computes bounds on the outputs from bounds on the inputs.
    ///
    /// Returns `(scalar_output_bounds, base_output_bounds)`.
    fn bounds(
        &self,
        scalar_bounds: Vec<FieldBounds<ScalarField>>,
        base_bounds: Vec<FieldBounds<BaseField>>,
    ) -> (Vec<FieldBounds<ScalarField>>, Vec<FieldBounds<BaseField>>);
    /// Performs the operation in MPC on the given field values.
    ///
    /// Returns `(scalar_outputs, base_outputs)`.
    fn run(
        &self,
        scalar_vals: Vec<FieldValue<ScalarField>>,
        base_vals: Vec<FieldValue<BaseField>>,
    ) -> (Vec<FieldValue<ScalarField>>, Vec<FieldValue<BaseField>>);
    /// Id-based front end for [`GeneralCircuit::run`].
    ///
    /// `scalar_vals` and `base_vals` hold ids of existing values. Each id is wrapped with
    /// `FieldValue::from_id`, and `run` is executed through `with_local_expr_store_as_global`
    /// with `expr_store`. The ids of the returned values are handed back as
    /// `(scalar_ids, base_ids)`.
    fn run_usize(
        &self,
        scalar_vals: &[usize],
        base_vals: &[usize],
        expr_store: &mut IRBuilder,
    ) -> (Vec<usize>, Vec<usize>) {
        let (scalar_outputs, base_outputs) = with_local_expr_store_as_global(
            || {
                // The ids are wrapped inside the closure. Presumably this is while `expr_store`
                // is installed as the global store (per the helper's name) — confirm.
                let scalar_inputs: Vec<FieldValue<ScalarField>> =
                    scalar_vals.iter().map(|&id| FieldValue::from_id(id)).collect();
                let base_inputs: Vec<FieldValue<BaseField>> =
                    base_vals.iter().map(|&id| FieldValue::from_id(id)).collect();
                self.run(scalar_inputs, base_inputs)
            },
            expr_store,
        );
        let scalar_ids: Vec<usize> = scalar_outputs.iter().map(|value| value.get_id()).collect();
        let base_ids: Vec<usize> = base_outputs.iter().map(|value| value.get_id()).collect();
        (scalar_ids, base_ids)
    }
}

#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::{
        core::{
            actually_used_field::ActuallyUsedField,
            bounds::IsBounds,
            expressions::{domain::Domain, expr::EvalValue, field_expr::FieldExpr, InputKind},
            ir_builder::{ExprStore, IRBuilder},
        },
        utils::used_field::UsedField,
    };
    use rand::Rng;
    use rustc_hash::FxHashMap;

    /// A trait to randomly test [`GeneralCircuit`] implementations.
    ///
    /// [`TestedGeneralCircuit::test`] checks two things on random inputs. The MPC version
    /// (`run`) must agree with the cleartext version (`eval`). Its outputs must also lie within
    /// the bounds reported by `bounds`.
    pub trait TestedGeneralCircuit: GeneralCircuit + Clone + 'static {
        /// Randomly generates a description of the sub-circuit.
        fn gen_desc<R: Rng + ?Sized>(rng: &mut R) -> Self;

        /// Randomly generates the number of scalar inputs for the sub-circuit.
        fn gen_n_scalars<R: Rng + ?Sized>(&self, rng: &mut R) -> usize;
        /// Randomly generates the number of base-field inputs for the sub-circuit, given the
        /// number of scalar inputs.
        fn gen_n_bases<R: Rng + ?Sized>(&self, rng: &mut R, n_scalars: usize) -> usize;

        /// Hook that impls can override to perform extra checks. It does nothing by default.
        ///
        /// It is called at the end of each non-skipped run. It receives the inputs and the
        /// outputs of `run`, which have already been checked to equal those of `eval`.
        #[allow(unused_variables)] // the default body intentionally ignores its arguments
        fn extra_checks(
            &self,
            scalar_inputs: Vec<ScalarField>,
            base_inputs: Vec<BaseField>,
            scalar_outputs: Vec<ScalarField>,
            base_outputs: Vec<BaseField>,
        ) {
        }

        /// Generates the bounds for one input. This is not meant to be overridden in impls.
        ///
        /// With probability 1/8 the result is `FieldBounds::All`. Otherwise `FieldBounds::gen_bounds`
        /// draws it inside the bounds from `0` to `2^size`, or from `-2^size` to `2^size` (each with
        /// probability 1/2). Here `size` is random and below `F::NUM_BITS`.
        fn gen_input_bounds<F: UsedField, R: Rng + ?Sized>(rng: &mut R) -> FieldBounds<F> {
            if rng.gen_bool(0.125) {
                return FieldBounds::All;
            }

            let signed = rng.gen_bool(0.5);
            let size = (rng.next_u32() % F::NUM_BITS) as usize;
            let two_power_size = F::power_of_two(size);
            let bounds_bounds = if signed {
                FieldBounds::new(-two_power_size, two_power_size)
            } else {
                FieldBounds::new(F::ZERO, two_power_size)
            };
            FieldBounds::gen_bounds(rng, bounds_bounds)
        }

        /// Runs the randomized test.
        ///
        /// * `n_desc`: number of circuit descriptions generated with
        ///   [`TestedGeneralCircuit::gen_desc`].
        /// * `n_runs_per_desc`: number of runs attempted per description.
        ///
        /// Each run generates random inputs, some of which are pushed as constants. A run where
        /// `eval` fails is skipped, because the inputs were invalid for this description; it
        /// still counts as one of the `n_runs_per_desc` attempts. Otherwise, the IR built by
        /// `run` is evaluated. Its outputs must equal those of `eval` and lie within the bounds
        /// returned by `bounds`.
        ///
        /// # Panics
        /// Panics if any of these checks, or [`TestedGeneralCircuit::extra_checks`], fails.
        fn test(n_desc: usize, n_runs_per_desc: usize) {
            /// Creates `n_inputs` random inputs with input ids starting at `start_input_id`.
            ///
            /// Each input gets bounds from `gen_bounds` and a value sampled from those bounds.
            /// With probability 1/8 an input is pushed as a constant, and its bounds are narrowed
            /// to exactly its value. Otherwise it is pushed as a secret input. Every value is
            /// recorded in `inputs` under its input id.
            ///
            /// Returns the expression ids, the values and the (possibly narrowed) bounds.
            fn gen_input_values_and_expr<F: ActuallyUsedField, R: Rng + ?Sized>(
                rng: &mut R,
                n_inputs: usize,
                start_input_id: usize,
                inputs: &mut FxHashMap<usize, EvalValue>,
                expr_store: &mut impl ExprStore<F>,
                mut gen_bounds: impl FnMut(&mut R) -> FieldBounds<F>,
            ) -> (Vec<usize>, Vec<F>, Vec<FieldBounds<F>>) {
                let mut bounds: Vec<_> = (0..n_inputs).map(|_| gen_bounds(rng)).collect();
                let input_vals: Vec<F> = bounds.iter().map(|bound| bound.sample(rng)).collect();
                let input_ids: Vec<usize> = bounds
                    .iter_mut()
                    .enumerate()
                    .map(|(i, bound)| {
                        if rng.gen_bool(0.125) {
                            let val = input_vals[i];
                            *bound = FieldBounds::new(val, val);
                            expr_store.push_field(FieldExpr::Val(val))
                        } else {
                            expr_store.push_field(FieldExpr::Input(
                                start_input_id + i,
                                bound.as_input_info(InputKind::Secret),
                            ))
                        }
                    })
                    .collect();
                inputs.extend(
                    input_vals
                        .iter()
                        .enumerate()
                        .map(|(i, val)| (start_input_id + i, F::field_to_eval_value(*val))),
                );
                (input_ids, input_vals, bounds)
            }

            /// Asserts that every value in `vals` lies within the matching entry of `bounds`.
            ///
            /// `kind` ("scalar" or "base") only labels the failure messages.
            fn check_bounds<F: UsedField>(vals: &[F], bounds: &[FieldBounds<F>], kind: &str) {
                assert_eq!(
                    vals.len(),
                    bounds.len(),
                    "number of {} outputs differs from number of {} output bounds",
                    kind,
                    kind
                );
                for (i, (val, bound)) in vals.iter().zip(bounds).enumerate() {
                    assert!(
                        bound.contains(*val),
                        "{} output {} is outside the bounds returned by `bounds`",
                        kind,
                        i
                    );
                }
            }

            let rng = &mut crate::utils::test_rng::get();
            for _ in 0..n_desc {
                let desc = Self::gen_desc(rng);
                for _ in 0..n_runs_per_desc {
                    // This circuit will be built by the compiler when circuits should be built.
                    let mut expr_store = IRBuilder::new(false);
                    let mut input_values: FxHashMap<usize, EvalValue> = FxHashMap::default();
                    let n_scalars = desc.gen_n_scalars(rng);
                    let n_bases = desc.gen_n_bases(rng, n_scalars);
                    // Scalars use input ids 0..n_scalars, and bases use the ids right after them.
                    let (scalar_ids, scalar_vals, scalar_bounds) =
                        gen_input_values_and_expr::<ScalarField, _>(
                            rng,
                            n_scalars,
                            0,
                            &mut input_values,
                            &mut expr_store,
                            Self::gen_input_bounds,
                        );
                    let (base_ids, base_vals, base_bounds) =
                        gen_input_values_and_expr::<BaseField, _>(
                            rng,
                            n_bases,
                            scalar_vals.len(),
                            &mut input_values,
                            &mut expr_store,
                            Self::gen_input_bounds,
                        );
                    let Ok((expected_scalars, expected_bases)) =
                        desc.eval(scalar_vals.clone(), base_vals.clone())
                    else {
                        // If eval fails, then the inputs were wrong.
                        // We do not have to bother testing.
                        continue;
                    };
                    let (scalar_out_ids, base_out_ids) =
                        desc.run_usize(&scalar_ids, &base_ids, &mut expr_store);
                    // Outputs are evaluated as one list: scalars first, then bases.
                    let n_scalar_outputs = scalar_out_ids.len();
                    let output_ids = [scalar_out_ids, base_out_ids].concat();
                    let n_outputs = output_ids.len();
                    let run_result = expr_store
                        .into_ir(output_ids)
                        .eval(rng, &mut input_values)
                        .unwrap_or_else(|err| panic!("run of {:?} failed: {:?}", desc, err));
                    let scalar_res = (0..n_scalar_outputs)
                        .map(|i| ScalarField::unwrap(run_result[i]))
                        .collect::<Vec<_>>();
                    let base_res = (n_scalar_outputs..n_outputs)
                        .map(|i| BaseField::unwrap(run_result[i]))
                        .collect::<Vec<_>>();

                    let (scalar_bounds_res, base_bounds_res) =
                        desc.bounds(scalar_bounds, base_bounds);

                    assert_eq!(
                        expected_scalars, scalar_res,
                        "scalar outputs of `eval` and `run` differ for {:?} on inputs {:?} / {:?}",
                        desc, scalar_vals, base_vals
                    );
                    assert_eq!(
                        expected_bases, base_res,
                        "base outputs of `eval` and `run` differ for {:?} on inputs {:?} / {:?}",
                        desc, scalar_vals, base_vals
                    );
                    check_bounds(&scalar_res, &scalar_bounds_res, "scalar");
                    check_bounds(&base_res, &base_bounds_res, "base");
                    desc.extra_checks(scalar_vals, base_vals, scalar_res, base_res);
                }
            }
        }
    }
}