// arcis-compiler 0.9.7
//
// A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
// Documentation
use crate::core::{
    actually_used_field::ActuallyUsedField,
    bounds::IsBounds,
    compile_passes::optimizer::Optimizer,
    expressions::{
        circuit::ArithmeticCircuitId,
        expr::Expr,
        field_expr::{
            expr_lincomb,
            FieldExpr,
            FieldExpr::{LinComb, Val},
        },
    },
    ir_builder::ExprStore,
};
use std::collections::BTreeMap;

impl Optimizer {
    /// Returns `Some(value)` when `as_constant` reports that the bounds stored for
    /// expression `e` are a constant, and `None` otherwise.
    fn is_constant<F: ActuallyUsedField>(&self, e: usize) -> Option<F> {
        let bounds = self.expr_store.bounds(e);
        bounds.as_constant()
    }
    /// Builds the normalized form of the linear combination `sum(factor * expr) + c`.
    ///
    /// Each term of `v` is looked up in `expr_store`: a term that is itself a short linear
    /// combination is expanded in place, a term that is a literal value is folded into the
    /// constant, and every other term is kept. Factors of repeated expressions are summed
    /// and terms whose factor ends up being zero are dropped.
    ///
    /// The result is:
    /// - a plain value when no term is left;
    /// - the stored expression of the only remaining term when its factor is one and the
    ///   constant is zero;
    /// - a `LinComb` over the remaining terms otherwise.
    ///
    /// `is_result_plaintext` should be `false` when you do not know.
    fn merge_lincomb<F: ActuallyUsedField>(
        &mut self,
        v: Vec<(usize, F)>,
        c: F,
        is_result_plaintext: bool,
    ) -> Expr<usize> {
        // Nested linear combinations are only expanded below this size, for performance
        // reasons: unlimited, this would lead to quadratic complexity for compiling sums,
        // and potentially for the circuit size too.
        const INLINE_LIMIT: usize = 16;

        /// Adds `weight` to the factor recorded for `id`, inserting it when absent.
        fn accumulate<F: ActuallyUsedField>(terms: &mut BTreeMap<usize, F>, id: usize, weight: F) {
            terms
                .entry(id)
                .and_modify(|total| *total += weight)
                .or_insert(weight);
        }

        let mut constant = c;
        let mut terms: BTreeMap<usize, F> = BTreeMap::new();
        for (id, weight) in v {
            let stored = self.expr_store.get_expr(id).clone();
            match F::expr_to_field_expr(stored) {
                Some(LinComb(inner, inner_c)) if inner.len() < INLINE_LIMIT => {
                    for (inner_id, inner_weight) in inner {
                        accumulate(&mut terms, inner_id, weight * inner_weight);
                    }
                    constant += weight * inner_c;
                }
                Some(Val(value)) => constant += weight * value,
                _ => accumulate(&mut terms, id, weight),
            }
        }

        // Terms whose factors cancelled out do not belong in the result.
        let terms: Vec<(usize, F)> = terms
            .into_iter()
            .filter(|(_, weight)| *weight != F::ZERO)
            .collect();

        // Nothing but the constant is left.
        if terms.is_empty() {
            return F::field_expr_to_expr(Val(constant));
        }

        // `1 * e + 0` is just `e`: reuse the expression stored for it.
        if terms.len() == 1 && constant == F::ZERO && terms[0].1 == F::ONE {
            let id = terms[0].0;
            if is_result_plaintext {
                self.expr_store.reveal(id);
            }
            return self.expr_store.get_expr(id).clone();
        }

        if is_result_plaintext {
            let hidden = terms
                .iter()
                .filter(|(id, _)| !self.expr_store.get_is_plaintext(*id))
                .count();
            // if `a` is plaintext, then `reveal(a + b)` is equal to `a + reveal(b)`
            if hidden == 1 {
                for &(id, _) in &terms {
                    self.expr_store.reveal(id);
                }
            }
        }
        F::field_expr_to_expr(LinComb(terms, constant))
    }

    /// Simplifies one field expression and returns the resulting `Expr`.
    ///
    /// `expr` refers to its operands by index into `self.expr_store`. `is_plaintext` is
    /// forwarded to `merge_lincomb` as `is_result_plaintext` and gates the calls to
    /// `expr_store.reveal` made here (presumably it means "the value of `expr` is known to
    /// be plaintext" — confirm against the caller).
    ///
    /// Summary of the rewrites:
    /// - `LinComb`, `Add`, `Sub`, `Neg`: normalized by `merge_lincomb`.
    /// - `Mul`: turned into a linear combination when `is_constant` reports a constant
    ///   operand; otherwise the operands are put in increasing index order.
    /// - `Equal`: decided immediately when the difference of the operands folds to a
    ///   value; otherwise possibly rewritten as a comparison with zero.
    /// - `Gt` / `Ge` of an operand with itself: replaced by the values `0` / `1`.
    /// - `Reveal`, `Bounds`: replaced by the expression stored for their operand.
    /// - `Where(e1, e2, e3)`: rewritten as `e1 * (e2 - e3) + e3`.
    /// - `KeepLsBits`: dropped when it does not change the bounds of its operand.
    /// - `Div` sub-circuit whose divisor is a `Sqrt` sub-circuit: fused into `DivSqrt`.
    /// - Anything else is converted back unchanged.
    ///
    /// Besides returning a value, this may add expressions to `self.expr_store`
    /// (`new_expr`, `push_field`) and call `reveal` on existing ones.
    pub fn optimize_field_expr<F: ActuallyUsedField>(
        &mut self,
        expr: FieldExpr<F, usize>,
        is_plaintext: bool,
    ) -> Expr<usize> {
        use FieldExpr::*;
        match expr {
            LinComb(v, c) => self.merge_lincomb(v, c, is_plaintext),
            Mul(e1, e2) => {
                // When the product is plaintext and one factor is plaintext with bounds
                // that exclude zero, `reveal` is called on the other factor (presumably
                // because it can then be recovered by dividing the product — confirm).
                if is_plaintext
                    && self.expr_store.get_is_plaintext(e1)
                    && !F::bounds_to_field_bounds(*self.expr_store.get_bounds(e1)).contains(F::ZERO)
                {
                    self.expr_store.reveal(e2);
                }
                if is_plaintext
                    && self.expr_store.get_is_plaintext(e2)
                    && !F::bounds_to_field_bounds(*self.expr_store.get_bounds(e2)).contains(F::ZERO)
                {
                    self.expr_store.reveal(e1);
                }
                // Multiplying by a constant is a one-term linear combination.
                if let Some(c) = self.is_constant(e1) {
                    self.merge_lincomb(vec![(e2, c)], F::ZERO, is_plaintext)
                } else if let Some(c) = self.is_constant(e2) {
                    self.merge_lincomb(vec![(e1, c)], F::ZERO, is_plaintext)
                } else if e1 > e2 {
                    // Put the smaller index first so that both operand orders produce the
                    // same expression.
                    F::field_expr_to_expr(Mul(e2, e1))
                } else {
                    F::field_expr_to_expr(expr)
                }
            }
            Equal(e1, e2) => {
                // Keep the operand order when `e2` is the constant zero; otherwise order
                // the operands by index before taking their difference.
                let (e1, e2) = if self.is_constant(e2) == Some(F::ZERO) {
                    (e1, e2)
                } else {
                    (e1.min(e2), e1.max(e2))
                };
                // `sub` is the normalized form of `e1 - e2`.
                let sub = self.merge_lincomb(expr_lincomb!((e1, 1), (e2, -1)), F::ZERO, false);
                let sub_field_expr = F::expr_to_field_expr(sub.clone());
                if let Some(Val(val)) = sub_field_expr.as_ref() {
                    // The difference is a known value: the comparison is decided now.
                    F::field_expr_to_expr(Val((*val == F::ZERO).into()))
                } else {
                    // Look for a single expression `e_sub` such that the original
                    // comparison is equivalent to `e_sub == 0`.
                    // Optimizing a*b == 0
                    let e_sub = if let Some(Mul(f1, f2)) = sub_field_expr.as_ref() {
                        let f1_nonzero =
                            !F::bounds_to_field_bounds(*self.expr_store.get_bounds(*f1))
                                .contains(F::ZERO);
                        let f2_nonzero =
                            !F::bounds_to_field_bounds(*self.expr_store.get_bounds(*f2))
                                .contains(F::ZERO);
                        if f1_nonzero && f2_nonzero {
                            // Neither factor can be zero, so the product is never zero.
                            return F::field_expr_to_expr(Val(F::ZERO));
                        } else if f1_nonzero {
                            // Only the other factor can make the product zero.
                            Some(*f2)
                        } else if f2_nonzero {
                            Some(*f1)
                        } else {
                            None
                        }
                    } else if let Some(LinComb(v, c)) = sub_field_expr.as_ref() {
                        // A single term with no constant: `k * x == 0` is tested as
                        // `x == 0` (`merge_lincomb` drops terms with a zero factor).
                        if *c == F::ZERO && v.len() == 1 {
                            Some(v[0].0)
                        } else {
                            None
                        }
                    } else {
                        None
                    };
                    if is_plaintext || e_sub.is_some() {
                        // Compare against zero, storing the difference as a new expression
                        // when no existing expression can stand in for it.
                        let e_sub = e_sub.unwrap_or_else(|| self.expr_store.new_expr(sub));
                        let e_zero = self
                            .expr_store
                            .new_expr(F::field_expr_to_expr(Val(F::ZERO)));
                        F::field_expr_to_expr(Equal(e_sub, e_zero))
                    } else {
                        // No rewrite applies: return the comparison as it was given.
                        F::field_expr_to_expr(expr)
                    }
                }
            }
            Gt(e1, e2, _) => {
                // An expression is never strictly greater than itself.
                if e1 == e2 {
                    F::field_expr_to_expr(Val(F::ZERO))
                } else {
                    F::field_expr_to_expr(expr)
                }
            }
            Ge(e1, e2, _) => {
                // An expression is always greater than or equal to itself.
                if e1 == e2 {
                    F::field_expr_to_expr(Val(1.into()))
                } else {
                    F::field_expr_to_expr(expr)
                }
            }
            Add(e1, e2) => {
                self.merge_lincomb(expr_lincomb!((e1, 1), (e2, 1)), F::ZERO, is_plaintext)
            }
            Sub(e1, e2) => {
                self.merge_lincomb(expr_lincomb!((e1, 1), (e2, -1)), F::ZERO, is_plaintext)
            }
            Neg(e) => self.merge_lincomb(expr_lincomb!((e, -1)), F::ZERO, is_plaintext),
            // `Reveal` and `Bounds` are both replaced by a copy of the expression stored
            // for their operand.
            Reveal(e) => {
                let expr = self.expr_store.get_expr(e);
                expr.clone()
            }
            Bounds(e, _) => {
                let expr = self.expr_store.get_expr(e);
                expr.clone()
            }
            Where(e1, e2, e3) => {
                // Rewritten as `e1 * (e2 - e3) + e3`; the difference and the product are
                // stored as new expressions.
                // NOTE(review): presumably `e1` is a 0/1 condition — confirm.
                let sub = self.merge_lincomb(expr_lincomb!((e2, 1), (e3, -1)), F::ZERO, false);
                let e_sub = self.expr_store.new_expr(sub);
                let e_prod = self.expr_store.push_field(Mul::<F, _>(e1, e_sub));
                self.merge_lincomb(expr_lincomb!((e_prod, 1), (e3, 1)), F::ZERO, is_plaintext)
            }
            KeepLsBits(e, c, signed_output) => {
                // If applying `KeepLsBits` to the operand's bounds leaves them unchanged,
                // the operation is dropped and the operand's expression is returned.
                let bounds = F::bounds_to_field_bounds(*self.expr_store.get_bounds(e));
                if bounds == KeepLsBits(bounds, c, signed_output).bounds() {
                    self.expr_store.get_expr(e).clone()
                } else {
                    F::field_expr_to_expr(expr)
                }
            }
            SubCircuit(v, ArithmeticCircuitId::Div, 0) => {
                // A division whose divisor is a `Sqrt` sub-circuit is replaced by a single
                // `DivSqrt` sub-circuit taking the dividend and the first input of that
                // `Sqrt`.
                // NOTE(review): the meaning of the third field (`0`) is not visible
                // here — confirm.
                assert_eq!(v.len(), 2);
                let e_a = v[0];
                let e_b = v[1];
                let expr_b = self.expr_store.get_expr(e_b).clone();
                let res = if let Some(SubCircuit(v2, ArithmeticCircuitId::Sqrt, 0)) =
                    F::expr_to_field_expr(expr_b)
                {
                    SubCircuit(vec![e_a, v2[0]], ArithmeticCircuitId::DivSqrt, 0)
                } else {
                    SubCircuit(v, ArithmeticCircuitId::Div, 0)
                };
                F::field_expr_to_expr(res)
            }
            // Every other variant is passed through untouched.
            y => F::field_expr_to_expr(y),
        }
    }
}