arcis-compiler 0.9.7

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    core::{
        bounds::{Bounds, IsBounds},
        compile::Compiler,
        expressions::{
            bit_expr::BitExpr,
            conversion_expr::ConversionExpr,
            curve_expr::CurveExpr,
            domain::DomainElement,
            expr::{EvalFailure, EvalValue, Expr, UndefinedBehavior},
            field_expr::FieldExpr,
            macro_uses::{BoundUnFold, EvalValueUnwrap},
        },
        instruction::ArcisInstruction,
        tracking::Tracking,
    },
    utils::{
        field::{BaseField, ScalarField},
        number::Number,
        used_field::UsedField,
    },
};
use ff::Field;
use indexmap::IndexSet;
use rand::Rng;
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use std::{env, fmt, fmt::Formatter, io::Write};

/// Flat, index-based intermediate representation of a circuit.
///
/// Expressions reference their operands by position in `exprs`, and `outputs`
/// lists the positions of the expressions whose values are the results.
/// `bounds` and `is_plaintext` are parallel to `exprs` (one entry per expression).
///
/// NOTE: `to_bytes`/`from_bytes` serialize this struct with bincode, which encodes
/// struct fields positionally in declaration order. Reordering, adding or removing
/// fields therefore breaks reading previously written `.arcis.ir` files.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateRepresentation {
    /// Expressions; each `usize` operand is an index into this same vector.
    /// `eval_with_log` evaluates them front to back, so operands are expected to
    /// point at earlier entries.
    exprs: Vec<Expr<usize>>, // the usize are indices in exprs
    /// Indices (into `exprs`) of the expressions whose values are returned.
    outputs: Vec<usize>, // the items are indices in exprs
    /// Value bounds for each expression, checked during evaluation.
    bounds: Vec<Bounds>, // same size as exprs
    /// Per-expression flag (presumably "value is public rather than secret" — TODO confirm).
    is_plaintext: Vec<bool>, // same size as exprs
    /// Tracking data; `new` fills it with `Default::default()`.
    tracking: Tracking,
}

impl IntermediateRepresentation {
    /// Constructor
    pub fn new(
        exprs: Vec<Expr<usize>>,
        outputs: Vec<usize>,
        bounds: Vec<Bounds>,
        is_plaintext: Vec<bool>,
    ) -> IntermediateRepresentation {
        let tracking = Default::default();
        IntermediateRepresentation {
            exprs,
            outputs,
            bounds,
            is_plaintext,
            tracking,
        }
    }
    pub fn new_with_tracking(
        exprs: Vec<Expr<usize>>,
        outputs: Vec<usize>,
        bounds: Vec<Bounds>,
        is_plaintext: Vec<bool>,
        tracking: Tracking,
    ) -> IntermediateRepresentation {
        IntermediateRepresentation {
            exprs,
            outputs,
            bounds,
            is_plaintext,
            tracking,
        }
    }
    #[allow(clippy::type_complexity)]
    pub fn destructure(
        self,
    ) -> (
        Vec<Expr<usize>>, // the usize are indices in exprs
        Vec<usize>,       // the items are indices in exprs
        Vec<Bounds>,      // same size as exprs
        Vec<bool>,        // same size as exprs
        Tracking,
    ) {
        (
            self.exprs,
            self.outputs,
            self.bounds,
            self.is_plaintext,
            self.tracking,
        )
    }

    pub fn get_exprs(&self) -> &[Expr<usize>] {
        &self.exprs
    }

    pub fn get_outputs(&self) -> &[usize] {
        &self.outputs
    }

    pub fn get_bounds(&self) -> &[Bounds] {
        &self.bounds
    }

    pub fn get_is_plaintext(&self) -> &[bool] {
        &self.is_plaintext
    }
    pub fn get_output_domains(&self) -> Vec<DomainElement<(), (), (), ()>> {
        self.outputs
            .iter()
            .map(|x| self.bounds[*x].to_domain())
            .collect()
    }

    pub fn get_expr(&self, id: usize) -> &Expr<usize> {
        &self.exprs[id]
    }
    pub fn get_tracking(&self) -> &Tracking {
        &self.tracking
    }

    pub fn check_for_integrity(&self) -> Result<(), ProgramError> {
        let mut input_ids: Vec<usize> = self.exprs.iter().filter_map(Expr::get_input).collect();
        input_ids.sort();
        let has_duplicates = input_ids.windows(2).any(|x| x[0] == x[1]);
        if has_duplicates {
            return Err(ProgramError::DuplicateInputId);
        }
        Ok(())
    }
    pub fn optimize_into_circuitable(self) -> IntermediateRepresentation {
        Compiler::optimize_into_circuitable(self)
    }
    pub fn to_async_mpc_circuit(&self) -> ArcisInstruction {
        let (unimproved_circuit, metadata, _) = Compiler::ir_to_async_mpc_circuit(self);
        let circuit = Compiler::improve_async_mpc_circuit(unimproved_circuit);
        ArcisInstruction { circuit, metadata }
    }
    /// Finds the generated booleans that have been optimized out.
    /// `result[idx]` is
    /// `Some(n)` if the `idx`-th generated bool in `opt_ir` is the `n`-th generated bool in `self`.
    /// `None` if the `idx`-th generated bool in `opt_ir` was not in `self`.
    pub fn generated_bools_filter(
        &self,
        opt_ir: &IntermediateRepresentation,
    ) -> Vec<Option<usize>> {
        let mut old_ids = IndexSet::new();
        for expr in self.get_exprs() {
            if let Expr::Bit(BitExpr::Random(id)) = expr {
                old_ids.insert(id.clone());
            }
        }
        let mut result = Vec::new();
        for expr in opt_ir.get_exprs() {
            if let Expr::Bit(BitExpr::Random(id)) = expr {
                let n = old_ids.get_index_of(id);
                result.push(n);
            }
        }
        result
    }

    pub fn to_bytes(&self) -> Vec<u8> {
        bincode::serialize(self).unwrap()
    }
    pub fn from_bytes(bytes: &[u8]) -> Result<IntermediateRepresentation, std::io::Error> {
        bincode::deserialize(bytes).map_err(|e| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("Error while reading IntermediateRepresentation: {}.\nThis is probably caused by incompatible arcis versions used by the interpreter (in encrypted-ixs) and the CLI (the `arcium` binary in `arcium build`).", e),
            )
        })
    }
    const DEFAULT_CIRCUIT_OUT_DIR: &'static str = "build";

    pub fn write_bytes(
        &self,
        circuit_name: &str,
        out_dir: Option<String>,
    ) -> Result<String, std::io::Error> {
        let current_dir = env::current_dir()?;

        let circuits_dir =
            current_dir.join(out_dir.unwrap_or(Self::DEFAULT_CIRCUIT_OUT_DIR.to_string()));
        let file_path = circuits_dir.join(format!("{circuit_name}.arcis.ir"));
        let res = String::from(file_path.to_str().unwrap());

        std::fs::create_dir_all(&circuits_dir)?;

        let mut file = std::fs::File::create(file_path)?;
        file.write_all(&self.to_bytes())?;
        Ok(res)
    }
    pub fn eval_with_log<R: Rng + ?Sized>(
        &self,
        rng: &mut R,
        input_vals: &mut FxHashMap<usize, EvalValue>,
        always_recover_from_ub: bool,
        do_log: bool,
        skip_bound_check: bool,
        mut generated_bits: impl Iterator<Item = bool>,
    ) -> Result<Vec<EvalValue>, UndefinedBehavior> {
        let mut vals = Vec::<EvalValue>::with_capacity(self.get_exprs().len());
        for (i, expr) in self.get_exprs().iter().enumerate() {
            let val: EvalValue = match expr {
                Expr::Scalar(FieldExpr::Input(input_id, info)) => match input_vals.get(input_id) {
                    None => {
                        let v = ScalarField::gen_inclusive_range(rng, info.min, info.max);
                        let res = EvalValue::Scalar(v);
                        input_vals.insert(*input_id, res);
                        res
                    }
                    Some(v) => *v,
                },
                Expr::Base(FieldExpr::Input(input_id, info)) => match input_vals.get(input_id) {
                    None => {
                        let v = BaseField::gen_inclusive_range(rng, info.min, info.max);
                        let res = EvalValue::Base(v);
                        input_vals.insert(*input_id, res);
                        res
                    }
                    Some(v) => *v,
                },
                Expr::Bit(BitExpr::Input(input_id, _)) => match input_vals.get(input_id) {
                    None => {
                        let v = R::gen(rng);
                        let res = EvalValue::Bit(v);
                        input_vals.insert(*input_id, res);
                        res
                    }
                    Some(v) => *v,
                },
                Expr::Curve(CurveExpr::Input(input_id, _)) => match input_vals.get(input_id) {
                    None => {
                        let v = R::gen(rng);
                        let res = EvalValue::Curve(v);
                        input_vals.insert(*input_id, res);
                        res
                    }
                    Some(v) => *v,
                },
                Expr::Scalar(FieldExpr::RandomVal(_)) => {
                    EvalValue::Scalar(ScalarField::random(&mut *rng))
                }
                Expr::Base(FieldExpr::RandomVal(_)) => {
                    EvalValue::Base(BaseField::random(&mut *rng))
                }
                Expr::Bit(BitExpr::Random(_)) => {
                    EvalValue::Bit(generated_bits.next().unwrap_or(rng.gen()))
                }
                Expr::ScalarConversion(ConversionExpr::EdaBit(_, k, _)) => {
                    let v = Number::gen_range(rng, &0.into(), &Number::power_of_two(*k));
                    EvalValue::Scalar(v.into())
                }
                Expr::BaseConversion(ConversionExpr::EdaBit(_, k, _)) => {
                    let v = Number::gen_range(rng, &0.into(), &Number::power_of_two(*k));
                    EvalValue::Base(v.into())
                }
                _ => {
                    let res = expr
                        .clone()
                        .apply(|x| vals[x])
                        .apply_2(&mut EvalValueUnwrap)
                        .eval();
                    match res {
                        Ok(n) => n,
                        Err(e) => match e {
                            EvalFailure::UndefinedBehavior(ub) => {
                                if always_recover_from_ub {
                                    self.get_bounds()[i].get_sample_val()
                                } else {
                                    if do_log {
                                        println!("vals: ");
                                        for (i, val) in vals.iter().enumerate() {
                                            println!("{i}: {val:?}");
                                        }
                                    }
                                    return Err(ub);
                                }
                            }
                            _ => {
                                panic!("Error at expr {expr:?}: {e:?}")
                            }
                        },
                    }
                }
            };
            if !skip_bound_check && !self.get_bounds()[i].contains(val) {
                let expr_bound = expr
                    .clone()
                    .apply(|x| self.get_bounds()[x])
                    .apply_2(&mut BoundUnFold)
                    .bounds();
                panic!(
                    "{val:?} is not in {:?} for {expr:?}, or {expr_bound:?}: whose bounds are
            {:?}",
                    self.get_bounds()[i],
                    expr_bound
                );
            }
            vals.push(val);
        }
        if do_log {
            println!("vals: ");
            for (i, val) in vals.iter().enumerate() {
                println!("{i}: {val:?}");
            }
        }
        Ok(self.get_outputs().iter().map(|i| vals[*i]).collect())
    }
    pub fn eval<R: Rng + ?Sized>(
        &self,
        rng: &mut R,
        input_vals: &mut FxHashMap<usize, EvalValue>,
    ) -> Result<Vec<EvalValue>, UndefinedBehavior> {
        self.eval_with_log(rng, input_vals, false, false, false, std::iter::empty())
    }
    pub fn eval_vec<R: Rng + ?Sized>(
        &self,
        inputs: Vec<EvalValue>,
        rng: &mut R,
        generated_bits: &[bool],
    ) -> Result<Vec<EvalValue>, UndefinedBehavior> {
        let mut inputs = inputs.into_iter().enumerate().collect();
        self.eval_with_log(
            rng,
            &mut inputs,
            true,
            false,
            false,
            generated_bits.iter().cloned(),
        )
    }
}

impl fmt::Display for IntermediateRepresentation {
    /// Writes one line per expression, `"{index}: {expr:?};"`, then one line per output,
    /// `"{index}: Output({expr_index});"`. Output numbering continues after the last
    /// expression's index.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        for (i, gate) in self.exprs.iter().enumerate() {
            writeln!(f, "{i}: {gate:?};")?;
        }
        let first_output_line = self.exprs.len();
        for (k, output) in self.outputs.iter().enumerate() {
            writeln!(f, "{}: Output({output});", first_output_line + k)?;
        }
        Ok(())
    }
}

/// Structural problems detected by `IntermediateRepresentation::check_for_integrity`.
#[derive(Debug)]
pub enum ProgramError {
    /// Two or more input expressions carry the same input id.
    DuplicateInputId,
}

impl fmt::Display for ProgramError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            ProgramError::DuplicateInputId => {
                write!(f, "duplicate input id: two input expressions share the same id")
            }
        }
    }
}

// Lets callers propagate this with `?` into `Box<dyn Error>` / `anyhow`-style contexts.
impl std::error::Error for ProgramError {}

#[cfg(test)]
mod tests {
    use crate::core::{expressions::expr::EvalValue, ir::IntermediateRepresentation};
    use rand::Rng;
    use rustc_hash::FxHashMap;

    /// Test helpers that cross-check two IRs by evaluating them on identical inputs.
    impl IntermediateRepresentation {
        /// Evaluates `ctrl_ir` and then `test_ir` against the same `input_vals` map and
        /// asserts that both results are equal.
        ///
        /// Inputs missing from `input_vals` are sampled by the first evaluation and inserted
        /// into the map, so the second evaluation reuses them. The comparison is skipped when
        /// the control evaluation errors. On a mismatch, diagnostics (including a verbose
        /// re-evaluation of `test_ir`) are printed for IRs with fewer than 65536 expressions,
        /// and then the assertion fails.
        pub fn test_eq_with_vals<R: Rng + ?Sized>(
            rng: &mut R,
            ctrl_ir: &IntermediateRepresentation,
            test_ir: &IntermediateRepresentation,
            input_vals: &mut FxHashMap<usize, EvalValue>,
        ) {
            // Exclusive size limit above which the verbose dump is skipped.
            const DUMP_LIMIT: usize = 65536;

            let ctrl_res = ctrl_ir.eval(rng, input_vals);
            if ctrl_res.is_err() {
                return;
            }
            let test_res = test_ir.eval(rng, input_vals);
            if ctrl_res == test_res {
                return;
            }

            let small_enough_to_dump = test_ir.get_exprs().len() < DUMP_LIMIT;
            if small_enough_to_dump {
                println!("ctrl: {}", ctrl_ir);
                println!("test: {}", test_ir);
                println!("input_vals: {input_vals:?}");
                println!("ctrl_res: {ctrl_res:?}");
                println!("test_res: {test_res:?}");
                let always_recover_from_ub = false;
                let do_log = true;
                let skip_bound_check = false;
                // Only the printed log matters here; the result is already known to differ.
                let _ = test_ir.eval_with_log(
                    rng,
                    input_vals,
                    always_recover_from_ub,
                    do_log,
                    skip_bound_check,
                    std::iter::empty(),
                );
            }
            assert_eq!(ctrl_res, test_res);
        }

        /// Runs [`IntermediateRepresentation::test_eq_with_vals`] `n_tests` times, each round
        /// starting from an empty input map so that fresh inputs are sampled.
        pub fn test_eq<R: Rng + ?Sized>(
            rng: &mut R,
            ctrl_ir: &IntermediateRepresentation,
            test_ir: &IntermediateRepresentation,
            n_tests: usize,
        ) {
            (0..n_tests).for_each(|_| {
                let mut fresh_inputs = FxHashMap::<usize, EvalValue>::default();
                Self::test_eq_with_vals(rng, ctrl_ir, test_ir, &mut fresh_inputs);
            });
        }
    }
}