arcis-compiler 0.9.1

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::core::{
    circuits::{
        boolean::{boolean_value::BooleanValue, byte::Byte},
        f64::utils::F64,
    },
    expressions::expr::EvalFailure,
    global_value::global_expr_store::with_local_expr_store_as_global,
    ir_builder::IRBuilder,
};
use std::fmt::Debug;

/// A trait to define a f64 sub-circuit.
///
/// An implementor provides a reference implementation on plain `f64` values
/// ([`F64Circuit::eval`]) and the corresponding MPC implementation on [`F64`]
/// values ([`F64Circuit::run`]).
pub trait F64Circuit: Debug {
    /// The operation that is being performed, evaluated on plain `f64` values.
    ///
    /// # Errors
    /// Implementations may return an [`EvalFailure`] when the operation cannot be
    /// evaluated on `x`.
    #[allow(dead_code)]
    fn eval(&self, x: Vec<f64>) -> Result<Vec<f64>, EvalFailure>;

    /// The relative tolerance allowed between the result of `eval` and the result of
    /// `run`.
    ///
    /// Defaults to `0.0`.
    #[allow(dead_code)]
    fn rtol(&self) -> f64 {
        0f64
    }

    /// The operation in MPC.
    #[allow(dead_code)]
    fn run(&self, vals: Vec<F64>) -> Vec<F64>;

    /// Runs [`F64Circuit::run`] on raw bit ids and returns the ids of the output bits.
    ///
    /// Each id in `vals` is wrapped with `BooleanValue::new`. Every 64 consecutive bits
    /// are grouped into 8 `Byte`s of 8 bits each and turned into one [`F64`] with
    /// `F64::from_le_bytes`. The resulting `F64`s are passed to `run`. Each output
    /// `F64` is flattened back the same way, via `to_le_bytes`, `get_bits` and
    /// `get_id`.
    ///
    /// The whole computation runs with `expr_store` installed as the global
    /// expression store (via `with_local_expr_store_as_global`).
    ///
    /// # Panics
    /// Panics if `vals.len()` is not a multiple of 64.
    #[allow(dead_code)]
    fn run_usize(&self, vals: &[usize], expr_store: &mut IRBuilder) -> Vec<usize> {
        with_local_expr_store_as_global(
            || {
                // All input ids are wrapped up front, before any `Byte`/`F64` is built,
                // so the `BooleanValue`s are created in the same order as the ids.
                let bits = vals
                    .iter()
                    .map(|val| BooleanValue::new(*val))
                    .collect::<Vec<BooleanValue>>();
                let words = bits.chunks_exact(64);
                // Fail early with a clear message instead of a confusing
                // "length 8" panic deep inside the byte conversion.
                assert!(
                    words.remainder().is_empty(),
                    "Expected a multiple of 64 bits (found {})",
                    bits.len()
                );
                let inputs = words
                    .map(|word| {
                        let bytes = word
                            .chunks_exact(8)
                            .map(|byte_bits| {
                                Byte::new(byte_bits.to_vec().try_into().unwrap_or_else(
                                    |v: Vec<BooleanValue>| {
                                        panic!("Expected a Vec of length 8 (found {})", v.len())
                                    },
                                ))
                            })
                            .collect::<Vec<Byte<BooleanValue>>>();
                        F64::from_le_bytes(bytes.try_into().unwrap_or_else(
                            |v: Vec<Byte<BooleanValue>>| {
                                panic!("Expected a Vec of length 8 (found {})", v.len())
                            },
                        ))
                    })
                    .collect::<Vec<F64>>();
                self.run(inputs)
                    .into_iter()
                    .flat_map(|x| {
                        x.to_le_bytes()
                            .into_iter()
                            .flat_map(|byte| {
                                byte.get_bits()
                                    .into_iter()
                                    .map(|bit| bit.get_id())
                                    .collect::<Vec<usize>>()
                            })
                            .collect::<Vec<usize>>()
                    })
                    .collect::<Vec<usize>>()
            },
            expr_store,
        )
    }
}

#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::{
        core::{
            circuits::traits::SAVE_CIRC_TEST_FOLDER_ENV_VAR,
            expressions::{
                bit_expr::{BitExpr, BitInputInfo},
                domain::Domain,
                expr::EvalValue,
                InputKind,
            },
            ir_builder::ExprStore,
        },
        utils::field::BaseField,
    };
    use rand::Rng;
    use std::rc::Rc;

    /// Decodes a sequence of bits into `f64` values.
    ///
    /// Every 64 consecutive bits form one `f64`. They are split into 8 groups of 8
    /// bits, and each group becomes a `u8` via `Byte::new` and `u8::from`. The 8
    /// bytes are then read with `f64::from_le_bytes`.
    ///
    /// # Panics
    /// Panics if `bits.len()` is not a multiple of 64.
    fn bits_to_f64s(bits: &[bool]) -> Vec<f64> {
        bits.chunks(64)
            .map(|word| {
                let bytes = word
                    .chunks(8)
                    .map(|byte_bits| {
                        let byte: [bool; 8] = byte_bits.try_into().unwrap_or_else(|_| {
                            panic!("Expected a Vec of length 8 (found {})", byte_bits.len())
                        });
                        u8::from(Byte::new(byte))
                    })
                    .collect::<Vec<u8>>();
                f64::from_le_bytes(bytes.try_into().unwrap_or_else(|v: Vec<u8>| {
                    panic!("Expected a Vec of length 8 (found {})", v.len())
                }))
            })
            .collect()
    }

    /// Builds the path `<folder>/<current thread name>.<extension>`.
    ///
    /// The current thread name is used as the test name.
    ///
    /// # Panics
    /// Panics if the current thread has no name.
    fn save_file_path(folder: &str, extension: &str) -> String {
        let thread = std::thread::current();
        let test_name = thread
            .name()
            .expect("the test thread is expected to be named after the test");
        format!("{folder}/{test_name}.{extension}")
    }

    /// Returns `true` when `actual` agrees with `expected` within relative tolerance `rtol`.
    ///
    /// Identical values (including equal infinities) and pairs of NaNs are accepted.
    /// Otherwise `|expected - actual| <= rtol * |expected|` must hold.
    fn within_rtol(expected: f64, actual: f64, rtol: f64) -> bool {
        expected == actual
            || (expected.is_nan() && actual.is_nan())
            || (expected - actual).abs() <= rtol * expected.abs()
    }

    /// Runs one randomized comparison between `desc.eval` and the circuit built by
    /// `desc.run_usize`.
    ///
    /// It draws `64 * desc.gen_n_inputs(rng)` random secret input bits and evaluates
    /// the built IR on them. Unless `desc.eval` fails, it then checks three things:
    /// both sides produce the same number of outputs, each output agrees within
    /// `desc.rtol()`, and `desc.extra_checks` passes.
    fn check_one_run<R: Rng + ?Sized, C: TestedF64Circuit>(rng: &mut R, desc: &C) {
        let n_inputs = 64 * desc.gen_n_inputs(rng);
        let input_vals_bool = (0..n_inputs)
            .map(|_| rng.gen_bool(0.5))
            .collect::<Vec<bool>>();
        let input_vals_f64 = bits_to_f64s(&input_vals_bool);
        let eval_result = desc.eval(input_vals_f64.clone());
        // This circuit will be built by the compiler when circuits should be built.
        let mut expr_store = IRBuilder::new(false);
        let input_ids = (0..n_inputs)
            .map(|i| {
                <IRBuilder as ExprStore<BaseField>>::push_bit(
                    &mut expr_store,
                    BitExpr::Input(
                        i,
                        Rc::new(BitInputInfo {
                            kind: InputKind::Secret,
                            ..BitInputInfo::default()
                        }),
                    ),
                )
            })
            .collect::<Vec<usize>>();
        let outputs = desc.run_usize(&input_ids, &mut expr_store);
        let test_ir = expr_store.into_ir(outputs);
        let mut input_vals_map = input_vals_bool
            .into_iter()
            .map(EvalValue::Bit)
            .enumerate()
            .collect();
        // The IR is evaluated (and `rng` consumed) before looking at `eval_result`, so
        // the circuit is exercised even when `eval` fails.
        let test_bits = test_ir
            .eval(rng, &mut input_vals_map)
            .map(|x| x.into_iter().map(bool::unwrap).collect::<Vec<bool>>())
            .unwrap();
        let test_result = bits_to_f64s(&test_bits);
        // A run where `eval` fails counts as a success.
        let Ok(eval_result) = eval_result else {
            return;
        };
        // Without this, `zip` below would silently ignore missing or extra outputs.
        assert_eq!(
            eval_result.len(),
            test_result.len(),
            "\nNumber of outputs differs between eval and the circuit. Inputs were: {:?}.\n",
            input_vals_f64
        );
        let rtol = desc.rtol();
        for (eval_res, test_res) in eval_result.iter().zip(&test_result) {
            assert!(
                within_rtol(*eval_res, *test_res, rtol),
                "\nRelative difference between eval_res: {:?} and test_res: {:?} exceeds rtol: {:?}. Inputs were: {:?}.\n",
                *eval_res,
                *test_res,
                rtol,
                input_vals_f64
            );
        }
        desc.extra_checks(input_vals_f64, eval_result)
    }

    /// A trait to test [`F64Circuit`]s against their plain `f64` reference implementation.
    pub trait TestedF64Circuit: F64Circuit + Clone + 'static {
        /// A function to randomly generate a description of the sub-circuit.
        fn gen_desc<R: Rng + ?Sized>(rng: &mut R) -> Self;

        /// A function to randomly generate the number of `f64` inputs for the sub-circuit.
        fn gen_n_inputs<R: Rng + ?Sized>(&self, rng: &mut R) -> usize;

        /// A hook that impls can override to perform extra checks.
        ///
        /// Receives the `f64` inputs and the outputs of `eval` for runs where `eval`
        /// succeeded. Does nothing by default.
        #[allow(unused_variables)]
        fn extra_checks(&self, inputs: Vec<f64>, outputs: Vec<f64>) {}

        /// The actual test.
        ///
        /// `n_desc` is the number of different descriptions that will be generated.
        /// `n_runs_per_desc` is the number of runs that will be attempted per description.
        /// A run where eval fails is a successful try.
        ///
        /// When the environment variable named by `SAVE_CIRC_TEST_FOLDER_ENV_VAR` is set,
        /// the RNG is saved with `test_rng::save_to_file` to two paths:
        /// `<folder>/<test name>.desc` before each description is generated, and
        /// `<folder>/<test name>.run` before each run.
        fn test(n_desc: usize, n_runs_per_desc: usize) {
            let rng = &mut crate::utils::test_rng::get();
            let (save_desc, save_run) = match std::env::var(SAVE_CIRC_TEST_FOLDER_ENV_VAR) {
                Ok(folder) => {
                    let desc_path = save_file_path(&folder, "desc");
                    println!("saving the circuit description at {desc_path}");
                    let run_path = save_file_path(&folder, "run");
                    println!("saving the circuit run at {run_path}");
                    (Some(desc_path), Some(run_path))
                }
                Err(_) => (None, None),
            };
            for _ in 0..n_desc {
                if let Some(file_path) = &save_desc {
                    crate::utils::test_rng::save_to_file(rng, file_path);
                }
                let desc = Self::gen_desc(rng);
                for _ in 0..n_runs_per_desc {
                    if let Some(file_path) = &save_run {
                        crate::utils::test_rng::save_to_file(rng, file_path);
                    }
                    check_one_run(rng, &desc);
                }
            }
        }
    }
}