arcis-compiler 0.9.7

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    core::{
        actually_used_field::ActuallyUsedField,
        bounds::FieldBounds,
        expressions::expr::EvalFailure,
        global_value::{global_expr_store::with_local_expr_store_as_global, value::FieldValue},
        ir_builder::IRBuilder,
    },
    utils::used_field::UsedField,
};
use std::fmt::Debug;

/// Describes a scalar sub-circuit: an operation that can be evaluated in the clear
/// (`eval`), bounded (`bounds`) and executed in MPC (`run`).
pub trait ArithmeticCircuit<F: UsedField>: Debug {
    /// Evaluates the operation in the clear on the field elements `x`.
    ///
    /// Returns the outputs, or an `EvalFailure` when the operation cannot be evaluated
    /// on these inputs.
    fn eval(&self, x: Vec<F>) -> Result<Vec<F>, EvalFailure>;
    /// The maximum tolerated gap between the result of `eval` and the result of `run`
    /// for the inputs `x`.
    ///
    /// Defaults to zero. The test harness in this file demands exact equality between
    /// both results when the gap is zero.
    #[allow(unused_variables)]
    fn eval_gap(&self, x: &[F]) -> F {
        F::zero()
    }
    /// Computes the bounds on the outputs from the bounds on the inputs.
    fn bounds(&self, bounds: Vec<FieldBounds<F>>) -> Vec<FieldBounds<F>>;
    /// Performs the operation in MPC, on `FieldValue`s.
    fn run(&self, vals: Vec<FieldValue<F>>) -> Vec<FieldValue<F>>
    where
        F: ActuallyUsedField;

    /// Same as `run`, but working on expression ids instead of `FieldValue`s.
    ///
    /// Each id of `vals` is wrapped with `FieldValue::from_id` and the circuit is run; both
    /// steps happen inside the closure handed to `with_local_expr_store_as_global` together
    /// with `expr_store`. The ids of the resulting values are returned.
    fn run_usize(&self, vals: &[usize], expr_store: &mut IRBuilder) -> Vec<usize>
    where
        F: ActuallyUsedField,
    {
        // The wrapping must stay inside the closure, exactly where the original performs it.
        let run_on_ids = || {
            let wrapped: Vec<FieldValue<F>> = vals
                .iter()
                .map(|&id| FieldValue::<F>::from_id(id))
                .collect();
            self.run(wrapped)
        };
        let outputs = with_local_expr_store_as_global(run_on_ids, expr_store);
        outputs.iter().map(|value| value.get_id()).collect()
    }
}

#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::{
        core::{
            bounds::IsBounds,
            circuits::traits::SAVE_CIRC_TEST_FOLDER_ENV_VAR,
            expressions::{field_expr::FieldExpr, InputKind},
            ir_builder::ExprStore,
        },
        ArcisFloatValue,
    };
    use rand::Rng;
    use std::{marker::PhantomData, path::Path};

    /// Asserts that `vals` and `bounds` have the same length and that each value is
    /// contained in the bound found at the same index.
    ///
    /// Panics on the first length mismatch or out-of-bounds value.
    fn check_bounds<F: UsedField>(vals: &[F], bounds: &[FieldBounds<F>]) {
        assert_eq!(vals.len(), bounds.len());
        for (val, bound) in vals.iter().zip(bounds) {
            assert!(
                bound.contains(*val),
                "val ({:?}) is not in bounds ({:?})",
                val,
                bound
            );
        }
    }

    fn desc_file_path() -> String {
        let folder = std::env::var(SAVE_CIRC_TEST_FOLDER_ENV_VAR).unwrap();
        let binding = std::thread::current();
        let test_name = binding.name().unwrap();
        format!("{folder}/{}.desc", test_name)
    }
    fn run_file_path() -> String {
        let folder = std::env::var(SAVE_CIRC_TEST_FOLDER_ENV_VAR).unwrap();
        let binding = std::thread::current();
        let test_name = binding.name().unwrap();
        format!("{folder}/{}.run", test_name)
    }

    /// Runs one randomized comparison between `desc.eval` and the circuit built by
    /// `desc.run_usize`.
    ///
    /// The steps are:
    /// 1. generate the number of inputs, their bounds, and sample one value per bound;
    /// 2. register each input in a fresh `IRBuilder`, either as a constant (with
    ///    probability 1/8, tightening its bounds to that single value) or as a secret input;
    /// 3. evaluate `desc.eval` on the sampled values, and build and evaluate the circuit;
    /// 4. check that the circuit outputs (when the circuit evaluation succeeds) lie within
    ///    `desc.bounds`;
    /// 5. if `eval` failed, stop there: such a run counts as successful;
    /// 6. otherwise compare both results, exactly when `desc.eval_gap` is zero and up to
    ///    that gap otherwise, check the `eval` outputs against the bounds, and call
    ///    `desc.extra_checks`.
    ///
    /// The order in which `rng` is consumed is part of the behavior: runs are replayed from
    /// saved rng states.
    ///
    /// # Panics
    ///
    /// Panics when any of the checks above fails, or when `eval` succeeds but the circuit
    /// evaluation fails.
    fn test<R: Rng + ?Sized, F: ActuallyUsedField, C: TestedArithmeticCircuit<F>>(
        rng: &mut R,
        desc: &C,
    ) {
        let n_inputs = desc.gen_n_inputs(rng);
        let mut bounds: Vec<FieldBounds<F>> =
            (0..n_inputs).map(|_| C::gen_input_bounds(rng)).collect();
        let input_vals: Vec<F> = bounds.iter().map(|bound| bound.sample(rng)).collect();
        // This circuit will be built by the compiler when circuits should be built.
        let mut expr_store = IRBuilder::new(false);
        let input_ids: Vec<usize> = bounds
            .iter_mut()
            .enumerate()
            .map(|(i, bound)| {
                if rng.gen_bool(0.125) {
                    // Constant input: its bounds collapse to the single sampled value.
                    let val = input_vals[i];
                    *bound = FieldBounds::new(val, val);
                    expr_store.push_field(FieldExpr::Val(val))
                } else {
                    expr_store
                        .push_field(FieldExpr::Input(i, bound.as_input_info(InputKind::Secret)))
                }
            })
            .collect();
        // Sampled values keyed by input index; the collection type is inferred from
        // `test_ir.eval` below.
        let mut input_values = input_vals
            .iter()
            .map(|x| F::field_to_eval_value(*x))
            .enumerate()
            .collect();
        let eval_result = desc.eval(input_vals.clone());
        let outputs = desc.run_usize(&input_ids, &mut expr_store);
        let test_ir = expr_store.into_ir(outputs);
        let test_result: Result<Vec<_>, _> = test_ir
            .eval(rng, &mut input_values)
            .map(|x| x.into_iter().map(F::eval_value_to_field).collect());
        // `bounds` is not needed afterwards, so it is moved rather than cloned.
        let result_bounds = desc.bounds(bounds);
        // The circuit outputs must respect the bounds even when `eval` fails.
        if let Ok(test_result) = &test_result {
            check_bounds(test_result, &result_bounds);
        }
        // A run where eval fails is a successful try.
        let eval_result = match eval_result {
            Ok(eval_result) => eval_result,
            Err(_) => return,
        };
        let test_result = test_result.unwrap_or_else(|err| {
            panic!(
                "circuit {:?} failed to run ({:?}) although eval succeeded for inputs {:?}",
                desc, err, input_vals
            )
        });
        let eval_gap = desc.eval_gap(&input_vals);
        if eval_gap != F::ZERO {
            assert_eq!(eval_result.len(), test_result.len());
            // Float views of the inputs, only used in the failure message below.
            let input_precisions = desc.input_precisions();
            let input_vals_float = if input_precisions.is_empty() {
                input_vals
                    .iter()
                    .map(|val| ArcisFloatValue::number_to_f64(val.to_signed_number()))
                    .collect::<Vec<_>>()
            } else {
                input_vals
                    .iter()
                    .zip(input_precisions)
                    .map(|(val, precision)| {
                        ArcisFloatValue::number_with_precision_to_f64(
                            val.to_signed_number(),
                            precision,
                        )
                    })
                    .collect::<Vec<_>>()
            };
            // `test_result` was already checked against `result_bounds` above.
            test_result
                .iter()
                .zip(eval_result.iter())
                .for_each(|(x, y)| {
                    assert!(
                        (*x - *y).abs() <= eval_gap,
                        "\nGap mismatch:\n{} (likely representing float {:?})\nis far from expected\n{} (likely representing float {:?})\n\nThe gap\n{}\nis larger than the tolerated\n{}.\n\nInputs were\n{:?} (likely representing floats {:?}).\n",
                        x.to_signed_number(),
                        ArcisFloatValue::number_to_f64(x.to_signed_number()),
                        y.to_signed_number(),
                        ArcisFloatValue::number_to_f64(y.to_signed_number()),
                        (*x - *y).abs().to_unsigned_number(),
                        eval_gap.to_unsigned_number(),
                        input_vals.iter().map(|val| val.to_signed_number()).collect::<Vec<_>>(),
                        input_vals_float,
                    )
                });
        } else {
            assert_eq!(
                test_result, eval_result,
                "different results in circuit {:?} for inputs {:?}",
                desc, input_vals
            );
        }
        check_bounds(&eval_result, &result_bounds);
        desc.extra_checks(input_vals, eval_result)
    }

    /// A trait to test ArithmeticCircuits.
    ///
    /// Implementors provide random generation of a circuit description and of its number
    /// of inputs; the provided methods drive the randomized comparison implemented by the
    /// free function `test` of this module.
    pub trait TestedArithmeticCircuit<F: ActuallyUsedField>:
        ArithmeticCircuit<F> + Clone + 'static
    {
        /// A function to randomly generate a description of the sub-circuit.
        fn gen_desc<R: Rng + ?Sized>(rng: &mut R) -> Self;

        /// A function to randomly generate the number of inputs for the sub-circuit.
        fn gen_n_inputs<R: Rng + ?Sized>(&self, rng: &mut R) -> usize;

        /// A function that can be overwritten in impls to perform extra-checks.
        ///
        /// It is called at the end of a run with the sampled inputs and the outputs of
        /// `eval`. The default implementation does nothing.
        // The default body ignores both parameters.
        #[allow(unused_variables)]
        fn extra_checks(&self, inputs: Vec<F>, outputs: Vec<F>) {}

        /// This generates the bounds for each input. Should not be rewritten in impls.
        ///
        /// With probability 1/8 the bounds are `FieldBounds::All`. Otherwise a bit size is
        /// drawn in `0..F::NUM_BITS`, and the bounds are generated by
        /// `FieldBounds::gen_bounds` from `[-2^size, 2^size]` (signed case, probability 1/2)
        /// or `[0, 2^size]` (unsigned case).
        fn gen_input_bounds<R: Rng + ?Sized>(rng: &mut R) -> FieldBounds<F> {
            if rng.gen_bool(0.125) {
                return FieldBounds::All;
            }

            let signed = rng.gen_bool(0.5);
            let size = (rng.next_u32() % F::NUM_BITS) as usize;
            let two_power_size = F::power_of_two(size);
            let bounds_bounds = if signed {
                FieldBounds::new(-two_power_size, two_power_size)
            } else {
                FieldBounds::new(F::ZERO, two_power_size)
            };
            FieldBounds::gen_bounds(rng, bounds_bounds)
        }

        /// The actual test.
        /// `n_desc` is the number of different descriptions that will be generated.
        /// `n_runs_per_desc` is the number of runs that will be attempted per description.
        /// A run where eval fails is a successful try.
        ///
        /// When the `SAVE_CIRC_TEST_FOLDER_ENV_VAR` environment variable is set, the rng is
        /// saved to the `.desc` file before each description is generated and to the `.run`
        /// file before each run, so that a run can be replayed with
        /// `test_once_with_rng_paths`.
        fn test(n_desc: usize, n_runs_per_desc: usize) {
            let rng = &mut crate::utils::test_rng::get();
            let (save_desc, save_run) = if std::env::var(SAVE_CIRC_TEST_FOLDER_ENV_VAR).is_ok() {
                let desc_path = desc_file_path();
                println!("saving the circuit description at {}", desc_path);
                let run_path = run_file_path();
                println!("saving the circuit run at {}", run_path);
                (Some(desc_path), Some(run_path))
            } else {
                (None, None)
            };
            for _ in 0..n_desc {
                if let Some(file_path) = &save_desc {
                    crate::utils::test_rng::save_to_file(rng, file_path);
                }
                let desc = Self::gen_desc(rng);
                for _ in 0..n_runs_per_desc {
                    if let Some(file_path) = &save_run {
                        crate::utils::test_rng::save_to_file(rng, file_path);
                    }
                    // This is the free function `test` of this module, not `Self::test`.
                    test(rng, &desc);
                }
            }
        }

        /// Same as `test`; the `_marker` argument is unused and only carries the field
        /// type `F`.
        fn test_with_marker(n_desc: usize, n_runs_per_desc: usize, _marker: PhantomData<F>) {
            Self::test(n_desc, n_runs_per_desc);
        }

        /// Performs a single run: the description is generated from the rng loaded from
        /// `desc_path`, and the run uses the rng loaded from `run_path`.
        #[allow(dead_code)]
        fn test_once_with_rng_paths(desc_path: impl AsRef<Path>, run_path: impl AsRef<Path>) {
            let desc = Self::gen_desc(&mut crate::utils::test_rng::load_from_file(desc_path));
            let mut rng = crate::utils::test_rng::load_from_file(run_path);
            test(&mut rng, &desc);
        }

        /// Same as `test_once_with_rng_paths`; the `_marker` argument is unused and only
        /// carries the field type `F`.
        #[allow(dead_code)]
        fn test_once_with_rng_paths_and_marker(
            desc_path: impl AsRef<Path>,
            run_path: impl AsRef<Path>,
            _marker: PhantomData<F>,
        ) {
            Self::test_once_with_rng_paths(desc_path, run_path);
        }

        /// A function that can be overwritten in impls to return the precision of floating point
        /// inputs.
        ///
        /// The precisions are only used to render the inputs as floats in gap-mismatch
        /// messages. The default is an empty vector, in which case no per-input precision
        /// is applied.
        fn input_precisions(&self) -> Vec<usize> {
            Vec::new()
        }
    }
}