arcis-compiler 0.9.4

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    core::{
        actually_used_field::ActuallyUsedField,
        circuits::{
            arithmetic::{PowCircuit, SqrtCircuit},
            traits::arithmetic_circuit::ArithmeticCircuit,
        },
        compile_passes::compilation_pass::LocalCompilationPass,
        expressions::{
            circuit::{ArithmeticCircuitId, BaseCircuitId, GeneralCircuitId},
            expr::Expr,
            field_expr::FieldExpr,
            other_expr::OtherExpr,
        },
        ir_builder::IRBuilder,
    },
    utils::field::{BaseField, ScalarField},
};
use ff::Field;
use rustc_hash::FxHashMap;
use std::collections::hash_map::Entry;

/// Memo of arithmetic sub-circuit runs: `(input expr ids, circuit)` -> output expr ids.
type ArithmeticCircuitRunCache = FxHashMap<(Vec<usize>, ArithmeticCircuitId), Vec<usize>>;

/// Memo of base-field circuit runs: `(input expr ids, circuit)` -> output expr ids.
type BaseCircuitRunCache = FxHashMap<(Vec<usize>, BaseCircuitId), Vec<usize>>;

/// Memo of general circuit runs.
///
/// Key: two input expr-id lists (presumably scalar then base inputs) plus the circuit.
/// Value: `(scalar output ids, base output ids)`.
type GeneralCircuitRunCache =
    FxHashMap<(Vec<usize>, Vec<usize>, GeneralCircuitId), (Vec<usize>, Vec<usize>)>;

/// Compilation pass that expands circuit-valued expressions into plain IR expressions.
///
/// The output is not optimised. Circuit runs are memoised by `(input expr ids, circuit id)`,
/// so repeated identical invocations reuse one set of emitted output expressions.
#[derive(Default)]
pub struct ArithmeticCircuitBuilder {
    /// IR store that receives every expression emitted while running circuits.
    expr_store: IRBuilder,
    /// Output ids of arithmetic sub-circuits, keyed by `(input ids, circuit id)`.
    arithmetic_circuit_cache: ArithmeticCircuitRunCache,
    /// Output ids of base-field circuits, keyed by `(input ids, circuit id)`.
    base_circuit_cache: BaseCircuitRunCache,
    /// `(scalar output ids, base output ids)` of general circuits, keyed by both input id
    /// lists and the circuit id.
    general_circuit_cache: GeneralCircuitRunCache,
}

impl ArithmeticCircuitBuilder {
    fn build_arithmetic_circuits<F: ActuallyUsedField>(
        &mut self,
        expr: FieldExpr<F, usize>,
    ) -> Expr<usize> {
        match expr {
            FieldExpr::SubCircuit(v, c, idx) => {
                let res = match self.arithmetic_circuit_cache.get(&(v.clone(), c)) {
                    None => {
                        let circuit_outputs =
                            c.to_circuit::<F>().run_usize(&v, &mut self.expr_store);
                        let res = circuit_outputs.get(idx).cloned();
                        self.arithmetic_circuit_cache
                            .insert((v.clone(), c), circuit_outputs);
                        res
                    }
                    Some(circuit_outputs) => circuit_outputs.get(idx).cloned(),
                };
                res.map(|expr_id| self.expr_store.get_expr(expr_id).clone())
                    .unwrap_or(F::field_expr_to_expr(FieldExpr::Val(F::ZERO)))
            }
            FieldExpr::Sqrt(v) if !self.expr_store.get_is_plaintext(v) => {
                let circuit_outputs = <SqrtCircuit as ArithmeticCircuit<F>>::run_usize(
                    &SqrtCircuit,
                    &[v],
                    &mut self.expr_store,
                );
                let res = circuit_outputs[0];
                self.expr_store.get_expr(res).clone()
            }
            FieldExpr::Pow(v, e, is_expected_non_zero) if !self.expr_store.get_is_plaintext(v) => {
                let circuit_outputs = <PowCircuit as ArithmeticCircuit<F>>::run_usize(
                    &PowCircuit {
                        exponent: e,
                        is_expected_non_zero,
                    },
                    &[v],
                    &mut self.expr_store,
                );
                let res = circuit_outputs[0];
                self.expr_store.get_expr(res).clone()
            }
            _ => F::field_expr_to_expr(expr),
        }
    }

    fn build_general_circuits(&mut self, expr: OtherExpr<usize>) -> Expr<usize> {
        match expr {
            OtherExpr::ScalarGeneralCircuit(s, b, c, idx) => {
                let res = match self.general_circuit_cache.get(&(s.clone(), b.clone(), c)) {
                    None => {
                        let result = c.to_circuit().run_usize(&s, &b, &mut self.expr_store);
                        self.general_circuit_cache
                            .insert((s.clone(), b.clone(), c), result.clone());
                        result.0.get(idx).cloned()
                    }
                    Some((scalar_res, _)) => scalar_res.get(idx).cloned(),
                };
                res.map(|expr_id| self.expr_store.get_expr(expr_id).clone())
                    .unwrap_or(Expr::Scalar(FieldExpr::Val(ScalarField::ZERO)))
            }
            OtherExpr::BaseGeneralCircuit(s, b, c, idx) => {
                let res = match self.general_circuit_cache.get(&(s.clone(), b.clone(), c)) {
                    None => {
                        let result = c.to_circuit().run_usize(&s, &b, &mut self.expr_store);
                        self.general_circuit_cache
                            .insert((s.clone(), b.clone(), c), result.clone());
                        result.1.get(idx).cloned()
                    }
                    Some((_, base_res)) => base_res.get(idx).cloned(),
                };
                res.map(|expr_id| self.expr_store.get_expr(expr_id).clone())
                    .unwrap_or(Expr::Base(FieldExpr::Val(BaseField::ZERO)))
            }
            OtherExpr::BaseArithmeticCircuit(v, c, idx) => {
                let res = match self.base_circuit_cache.get(&(v.clone(), c)) {
                    None => {
                        let circuit_outputs = c.to_circuit().run_usize(&v, &mut self.expr_store);
                        let res = circuit_outputs.get(idx).cloned();
                        self.base_circuit_cache
                            .insert((v.clone(), c), circuit_outputs);
                        res
                    }
                    Some(circuit_outputs) => circuit_outputs.get(idx).cloned(),
                };
                res.map(|expr_id| self.expr_store.get_expr(expr_id).clone())
                    .unwrap_or(Expr::Base(FieldExpr::Val(BaseField::ZERO)))
            }
            _ => Expr::Other(expr),
        }
    }
}

impl LocalCompilationPass for ArithmeticCircuitBuilder {
    /// Hands the pass driver mutable access to the IR store this builder writes into.
    fn expr_store(&mut self) -> &mut IRBuilder {
        &mut self.expr_store
    }

    /// Rewrites a single expression, expanding any circuit it stands for.
    ///
    /// Non-field expressions go to `build_general_circuits`. Scalar- and base-field
    /// expressions go to `build_arithmetic_circuits`. Every other variant is returned as-is.
    fn transform(&mut self, expr: Expr<usize>, _is_plaintext: bool) -> Expr<usize> {
        // The plaintext flag is unused: the Sqrt/Pow expansion queries the IR store for
        // its operand's plaintext status instead.
        match expr {
            Expr::Other(other_expr) => self.build_general_circuits(other_expr),
            Expr::Base(base_expr) => self.build_arithmetic_circuits(base_expr),
            Expr::Scalar(scalar_expr) => self.build_arithmetic_circuits(scalar_expr),
            untouched => untouched,
        }
    }
}