arcis-compiler 0.9.1

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::core::{expressions::expr::Expr, ir::IntermediateRepresentation, ir_builder::IRBuilder};

/// A whole-program transformation: consumes one [`IntermediateRepresentation`] and yields
/// the rewritten one.
pub trait CompilationPass: Default {
    /// Runs this pass over `old_ir` and returns the resulting IR.
    fn run(&mut self, old_ir: IntermediateRepresentation) -> IntermediateRepresentation;

    /// Convenience entry point: builds a fresh pass via [`Default`] and runs it exactly once
    /// on `old_ir`, discarding the pass afterwards.
    fn make_and_run(old_ir: IntermediateRepresentation) -> IntermediateRepresentation {
        Self::default().run(old_ir)
    }
}

/// A pass that rewrites an IR one expression at a time.
///
/// Implementors describe only how a single expression is rewritten. Traversal, operand
/// remapping and final IR assembly are handled by the blanket
/// `impl<T: LocalCompilationPass> CompilationPass for T`.
pub trait LocalCompilationPass: Default {
    /// Hook invoked once with the complete input IR, before any call to
    /// [`Self::transform`]. The default implementation does nothing.
    fn setup(&mut self, old_ir: &IntermediateRepresentation) {
        // The default hook has no use for the IR. Discarding it explicitly keeps the
        // parameter name visible to implementors and rustdoc without tripping the
        // `unused_variables` lint.
        let _ = old_ir;
    }

    /// Rewrites a single expression.
    ///
    /// The driver calls this once per expression of the input IR, in index order. Any
    /// operand ids inside `expr` already refer to expressions of the *new* IR.
    /// `is_plaintext` is the input expression's plaintext flag. For an expression whose
    /// bound reduces to a constant, `expr` is that constant and `is_plaintext` is `true`.
    /// The returned expression is appended to [`Self::expr_store`].
    fn transform(&mut self, expr: Expr<usize>, is_plaintext: bool) -> Expr<usize>;

    /// Mutable access to the builder that accumulates the rewritten IR.
    ///
    /// At the end of a run, the driver moves the builder out with `std::mem::take`. That
    /// leaves a default builder behind.
    fn expr_store(&mut self) -> &mut IRBuilder;
}

/// Every [`LocalCompilationPass`] is a [`CompilationPass`].
///
/// The input IR is rebuilt one expression at a time, in index order. Each expression is
/// relinked so its operands point into the new IR, then handed to
/// [`LocalCompilationPass::transform`]. The result is recorded in the pass's builder
/// together with the original bound and plaintext flag.
impl<T: LocalCompilationPass> CompilationPass for T {
    fn run(&mut self, old_ir: IntermediateRepresentation) -> IntermediateRepresentation {
        self.setup(&old_ir);
        let (exprs, outputs, bounds, plaintext_flags, tracking) = old_ir.destructure();

        // `remap[old_id]` is the id that expression received in the new IR.
        let mut remap = Vec::with_capacity(exprs.len());
        for id in 0..exprs.len() {
            let bound = bounds[id];
            let is_plaintext = plaintext_flags[id];

            let rewritten = match bound.as_constant_expr() {
                // The bound pins this expression to one value, so the pass sees that
                // constant (flagged as plaintext) instead of the original expression.
                // NOTE(review): the flag recorded below is still the original
                // `is_plaintext`, not `true` — confirm this asymmetry is intended.
                Some(constant) => self.transform(constant, true),
                None => {
                    // Indexing `remap` assumes operands only reference earlier ids
                    // (topologically ordered IR); a forward reference would panic here.
                    let relinked = exprs[id].clone().apply(|operand| remap[operand]);
                    self.transform(relinked, is_plaintext)
                }
            };

            let new_id = self
                .expr_store()
                .new_expr_with_info(rewritten, Some(bound), is_plaintext);
            remap.push(new_id);
        }

        let new_outputs: Vec<usize> = outputs.iter().map(|output| remap[*output]).collect();
        // Move the accumulated builder out of the pass, leaving a default one behind.
        std::mem::take(self.expr_store()).into_ir_with_tracking(new_outputs, tracking, remap)
    }
}