arcis-compiler 0.9.2

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::core::{
    actually_used_field::ActuallyUsedField,
    compile_passes::compilation_pass::LocalCompilationPass,
    expressions::{
        bit_expr::BitExpr,
        conversion_expr::{ConversionExpr, EdaBitId},
        expr::Expr,
        field_expr::FieldExpr,
    },
    ir_builder::{ExprStore, IRBuilder},
};
use rustc_hash::FxHashMap;
use std::marker::PhantomData;

/// Compilation pass that expands complex expressions into gate-able ones:
/// - field `LinComb` expressions (scalar and base field) are expanded into
///   `Add`/`Mul`/`Neg` gates;
/// - Boolean `Binop` expressions (two operands plus a truth table) are
///   rewritten in terms of `Xor` and `And`.
///
/// Intermediate sub-expressions created during expansion are allocated in
/// `expr_store`.
#[derive(Default)]
pub struct ComplexExprExpander {
    /// Expression store read by this pass and extended with the intermediate
    /// expressions it creates.
    expr_store: IRBuilder,
}

impl ComplexExprExpander {
    /// Expands a field `LinComb` into `Add`/`Mul`/`Neg` gates, allocating any
    /// intermediate sub-expressions in this pass's expression store.
    fn expand_lincombs<F: ActuallyUsedField>(&mut self, expr: FieldExpr<F, usize>) -> Expr<usize> {
        expand_lincombs(expr, &mut self.expr_store)
    }

    /// Rewrites a Boolean `Binop` into `Xor`/`And` gates; any other `BitExpr`
    /// is returned unchanged.
    ///
    /// The truth table has one entry per input pair. Entry 0 is the output for
    /// `(e1, e2) = (false, false)`. Following the lowerings below, entry 1 is
    /// `(false, true)`, entry 2 is `(true, false)` and entry 3 is `(true, true)`.
    /// Intermediate `And`/`Xor` nodes are allocated in the expression store.
    ///
    /// # Panics
    /// - if entry 0 of the truth table is `true` (the lowerings assume such
    ///   tables, i.e. NOT-producing functions, have been optimized out earlier);
    /// - if the table is constant (zero or four `true` entries).
    ///
    /// In debug builds it also panics if either operand is a `Not` expression,
    /// or if a two-`true`-entry table is not XOR.
    fn expand_binop(&mut self, expr: BitExpr<usize>) -> BitExpr<usize> {
        use BitExpr::*;
        use Expr::Bit;
        let (e1, e2, truth_table) = match expr {
            Binop(e1, e2, truth_table) => (e1, e2, truth_table),
            other => return other,
        };
        let expr_store = &mut self.expr_store;
        // There could not be NOT(A) or NOT(B) within e1 or e2.
        debug_assert!(!is_bit_expr_not(e1, expr_store), "a NOT has escaped");
        debug_assert!(!is_bit_expr_not(e2, expr_store), "a NOT has escaped");
        // This is a hard `assert!`, not a `debug_assert!`. Every lowering below assumes
        // the output is false on (false, false). Without this check a release build
        // would silently emit the wrong gate (e.g. NAND would be lowered as OR).
        assert!(
            !truth_table[0],
            "a truth table is true for (false, false). This is unacceptable"
        );
        // Number of true values in the truth table.
        let n_true_values = truth_table.iter().filter(|x| **x).count();
        match n_true_values {
            0 | 4 => panic!("unexpected constant binary operation"),
            2 => {
                // Of the 6 two-true-entry functions (e1, !e1, e2, !e2, e1 ^ e2, !(e1 ^ e2)),
                // only e1 ^ e2 remains: the first 4 have been optimized out and entry 0
                // is false.
                debug_assert!(
                    truth_table[1] && truth_table[2],
                    "expected an XOR truth table"
                );
                bit_xor(e1, e2)
            }
            3 => {
                // !(!a & !b), that we write as a^b^(a&b)
                let and = expr_store.new_expr(Bit(bit_and(e1, e2)));
                let sum = expr_store.new_expr(Bit(bit_xor(e1, and)));
                bit_xor(e2, sum)
            }
            _ => {
                // Exactly one true entry.
                if truth_table[1] {
                    // !a & b, that we write as b^(a&b)
                    let and = expr_store.new_expr(Bit(bit_and(e1, e2)));
                    bit_xor(e2, and)
                } else if truth_table[2] {
                    // a & !b, that we write as a^(a&b)
                    let and = expr_store.new_expr(Bit(bit_and(e1, e2)));
                    bit_xor(e1, and)
                } else {
                    debug_assert!(truth_table[3], "this should be and");
                    // a & b
                    bit_and(e1, e2)
                }
            }
        }
    }
}

/// Lowers a `LinComb` field expression into elementary `Add`/`Mul`/`Neg` gates.
///
/// Any expression other than `LinComb` is converted with `F::field_expr_to_expr`
/// unchanged. For `LinComb(terms, c)`, i.e. `c + Σ fᵢ·eᵢ`:
/// - no terms: `Val(c)`;
/// - one term with `c == 0`:
///   - a clone of the stored expression `e` when `f == 1`;
///   - `Neg(e)` when `f == -1`;
///   - otherwise `Mul(e, Val(f))`;
/// - one term with `c != 0`: `Val(c)` is pushed first, then the lowered term is
///   added to it;
/// - two or more terms: each term is lowered to `eᵢ`, `Neg(eᵢ)` or
///   `Mul(eᵢ, Val(fᵢ))`, the results are summed left to right, and `Val(c)` is
///   added last when `c != 0`.
///
/// Every intermediate sub-expression is pushed into `expr_store`; the returned
/// root expression is not.
fn expand_lincombs<F: ActuallyUsedField>(
    expr: FieldExpr<F, usize>,
    expr_store: &mut impl ExprStore<F>,
) -> Expr<usize> {
    use FieldExpr::*;

    /// Turns the single term `coeff · id` into an expression id. It pushes
    /// `Neg`, or `Val` + `Mul`, as needed; a unit coefficient needs no new node.
    fn lower_term<F: ActuallyUsedField>(
        id: usize,
        coeff: F,
        store: &mut impl ExprStore<F>,
    ) -> usize {
        if coeff == F::ONE {
            id
        } else if coeff == -F::ONE {
            store.push_field(FieldExpr::Neg(id))
        } else {
            let coeff_id = store.push_field(FieldExpr::Val(coeff));
            store.push_field(FieldExpr::Mul(id, coeff_id))
        }
    }

    let (terms, constant) = match expr {
        LinComb(terms, constant) => (terms, constant),
        other => return F::field_expr_to_expr(other),
    };

    let lowered = match terms.as_slice() {
        [] => Val(constant),
        [(id, coeff)] if constant == F::ZERO => {
            if *coeff == F::ONE {
                // `1 · e` is just `e`: reuse its expression directly.
                return expr_store.get_expr(*id).clone();
            }
            if *coeff == -F::ONE {
                Neg(*id)
            } else {
                let coeff_id = expr_store.push_field(Val(*coeff));
                Mul(*id, coeff_id)
            }
        }
        [(id, coeff)] => {
            // The constant node is pushed before the term's nodes.
            let constant_id = expr_store.push_field(Val(constant));
            let term_id = lower_term(*id, *coeff, &mut *expr_store);
            Add(term_id, constant_id)
        }
        _ => {
            // Lower every term before emitting any `Add`, so all term nodes
            // precede the partial sums in the store.
            let term_ids: Vec<usize> = terms
                .iter()
                .map(|(id, coeff)| lower_term(*id, *coeff, &mut *expr_store))
                .collect();
            let (head, tail) = term_ids.split_at(2);
            let sum = tail.iter().fold(Add(head[0], head[1]), |acc, &next| {
                let acc_id = expr_store.push_field(acc);
                Add(acc_id, next)
            });
            if constant == F::ZERO {
                sum
            } else {
                let sum_id = expr_store.push_field(sum);
                let constant_id = expr_store.push_field(Val(constant));
                Add(sum_id, constant_id)
            }
        }
    };
    F::field_expr_to_expr(lowered)
}

/// Builds `Xor(e1, e2)` with the smaller operand id first, giving a single
/// operand order for this commutative gate regardless of argument order.
fn bit_xor(e1: usize, e2: usize) -> BitExpr<usize> {
    BitExpr::Xor(e1.min(e2), e1.max(e2))
}
/// Builds `And(e1, e2)` with the smaller operand id first, giving a single
/// operand order for this commutative gate regardless of argument order.
fn bit_and(e1: usize, e2: usize) -> BitExpr<usize> {
    BitExpr::And(e1.min(e2), e1.max(e2))
}
/// Returns `true` if expression `e` in `expr_store` is a Boolean `Not`.
///
/// The function only reads the store, so a shared borrow is enough. Callers
/// holding a `&mut IRBuilder` can still pass it, because it coerces to `&IRBuilder`.
fn is_bit_expr_not(e: usize, expr_store: &IRBuilder) -> bool {
    matches!(expr_store.get_expr(e), Expr::Bit(BitExpr::Not(_)))
}

impl LocalCompilationPass for ComplexExprExpander {
    /// Returns the expression store owned by this pass.
    fn expr_store(&mut self) -> &mut IRBuilder {
        &mut self.expr_store
    }

    /// Expands `LinComb`s in scalar- and base-field expressions and `Binop`s in
    /// Boolean expressions; every other expression kind passes through untouched.
    fn transform(&mut self, expr: Expr<usize>, _is_plaintext: bool) -> Expr<usize> {
        match expr {
            Expr::Bit(bit_expr) => {
                let expanded = self.expand_binop(bit_expr);
                Expr::Bit(expanded)
            }
            Expr::Scalar(scalar_expr) => self.expand_lincombs(scalar_expr),
            Expr::Base(base_expr) => self.expand_lincombs(base_expr),
            untouched => untouched,
        }
    }
}

/// Variant of [`ComplexExprExpander`] that, in addition to its expansions,
/// splits each edaBit conversion expression into width-1 edaBits (daBits).
#[derive(Default)]
pub struct ComplexExprExpanderTestnet {
    /// Underlying expander; it handles `LinComb` and `Binop` expansion and owns
    /// the expression store.
    inner: ComplexExprExpander,
    /// For each expanded edaBit, the expression ids of the width-1 edaBits
    /// created for it, indexed by bit position.
    eda_bit_id_to_da_bit_expr_ids_map: FxHashMap<EdaBitId, Vec<usize>>,
}

impl ComplexExprExpanderTestnet {
    /// Returns the daBit expression ids recorded for the edaBit defined by
    /// expression `eda_bit_expr_id`, ordered by bit position.
    ///
    /// # Panics
    /// Panics with "cannot expand eda_bits_to_da_bits" in either of these cases:
    /// - `eda_bit_expr_id` is not an `EdaBit` conversion expression;
    /// - no daBits were recorded for that edaBit, i.e. its `EdaBit` definition
    ///   has not been expanded by this pass.
    fn eda_bit_expr_id_to_da_bit_expr_ids<F: ActuallyUsedField>(
        &self,
        eda_bit_expr_id: usize,
    ) -> &[usize] {
        let Some(ConversionExpr::EdaBit(eda_bit_id, _, _)) =
            F::expr_to_conversion_expr(self.inner.expr_store.get_expr(eda_bit_expr_id).clone())
        else {
            panic!("cannot expand eda_bits_to_da_bits");
        };
        match self.eda_bit_id_to_da_bit_expr_ids_map.get(&eda_bit_id) {
            Some(da_bits) => da_bits.as_slice(),
            None => panic!("cannot expand eda_bits_to_da_bits"),
        }
    }

    /// Expands edaBit conversion expressions into width-1 edaBits (daBits).
    ///
    /// - `EdaBit(id, width)`: allocates `width` fresh width-1 edaBits, records
    ///   their expression ids under `id`, and returns the `EdaBit` unchanged.
    /// - `BitFromEdaBit(e, i)`: becomes bit 0 of the `i`-th recorded daBit of `e`.
    /// - `ScalarFromEdaBit(e)`: becomes `Σ 2^i · ScalarFromEdaBit(daBit_i)`,
    ///   expanded into gates.
    /// - Any other conversion expression is returned unchanged.
    ///
    /// # Panics
    /// Panics in either of these cases:
    /// - a `BitFromEdaBit`/`ScalarFromEdaBit` refers to an edaBit whose definition
    ///   has not been expanded yet;
    /// - a `BitFromEdaBit` index is out of range.
    fn expand_eda_bits_to_da_bits<F: ActuallyUsedField>(
        &mut self,
        expr: ConversionExpr<F, usize>,
    ) -> Expr<usize> {
        match expr {
            ConversionExpr::EdaBit(eda_bit_id, width, _) => {
                let da_bits = (0..width)
                    .map(|_| {
                        self.inner.expr_store.new_expr(F::conversion_expr_to_expr(
                            ConversionExpr::EdaBit(EdaBitId::new(), 1, PhantomData),
                        ))
                    })
                    .collect::<Vec<_>>();
                self.eda_bit_id_to_da_bit_expr_ids_map
                    .insert(eda_bit_id, da_bits);
                F::conversion_expr_to_expr(ConversionExpr::EdaBit(eda_bit_id, width, PhantomData))
            }
            ConversionExpr::BitFromEdaBit(eda_bit_expr_id, bit_idx) => {
                let da_bits = self.eda_bit_expr_id_to_da_bit_expr_ids::<F>(eda_bit_expr_id);
                assert!(
                    bit_idx < da_bits.len(),
                    "bit index {bit_idx} out of range for an edaBit of width {}",
                    da_bits.len()
                );
                // Each daBit is a width-1 edaBit, so its only bit is bit 0.
                F::conversion_expr_to_expr(ConversionExpr::BitFromEdaBit(da_bits[bit_idx], 0))
            }
            ConversionExpr::ScalarFromEdaBit(eda_bit_expr_id) => {
                // Owned copy: the loop below needs `&mut self` to allocate expressions.
                let da_bits = self
                    .eda_bit_expr_id_to_da_bit_expr_ids::<F>(eda_bit_expr_id)
                    .to_vec();
                let terms = da_bits
                    .into_iter()
                    .enumerate()
                    .map(|(i, da_bit)| {
                        let expr_id = self.inner.expr_store.new_expr(F::conversion_expr_to_expr(
                            ConversionExpr::ScalarFromEdaBit(da_bit),
                        ));
                        (expr_id, F::power_of_two(i))
                    })
                    .collect::<Vec<_>>();
                self.inner
                    .expand_lincombs(FieldExpr::LinComb(terms, F::ZERO))
            }
            _ => F::conversion_expr_to_expr(expr),
        }
    }
}

impl LocalCompilationPass for ComplexExprExpanderTestnet {
    /// Returns the expression store owned by the inner expander.
    fn expr_store(&mut self) -> &mut IRBuilder {
        &mut self.inner.expr_store
    }

    /// Expands base- and scalar-field edaBit conversion expressions into daBits.
    /// Every other expression goes to the inner [`ComplexExprExpander`], which
    /// expands `LinComb`s and `Binop`s and passes the rest through.
    fn transform(&mut self, expr: Expr<usize>, is_plaintext: bool) -> Expr<usize> {
        match expr {
            Expr::BaseConversion(conversion) => self.expand_eda_bits_to_da_bits(conversion),
            Expr::ScalarConversion(conversion) => self.expand_eda_bits_to_da_bits(conversion),
            other => self.inner.transform(other, is_plaintext),
        }
    }
}