// arcis-compiler 0.9.7
//
// A framework for writing secure multi-party computation (MPC) circuits to be executed on the
// Arcium network.
use crate::{
    network_content::NetworkContent,
    preprocess_info::PreprocessInfo,
    profile_summary::ProfileSummary,
    AsyncMPCCircuit,
};
use core_utils::circuit::{
    Batched,
    BitShareBinaryOp,
    BitShareUnaryOp,
    CircuitPreprocessing,
    FieldShareBinaryOp,
    FieldShareUnaryOp,
    Input,
    PointShareBinaryOp,
    PointShareUnaryOp,
    ShareOrPlaintext,
};
use indexmap::IndexMap;
use primitives::algebra::elliptic_curve::Curve25519Ristretto;

/// Circuit gate type used throughout this module, specialised to `Curve25519Ristretto`.
type Gate = core_utils::circuit::Gate<Curve25519Ristretto>;
/// Index of a gate within a circuit; used as the key of the per-gate depth maps below.
type Label = core_utils::circuit::GateIndex;

/// Returns how many online network rounds `gate` adds on top of the deepest of its inputs.
///
/// Cost table:
/// * 2 rounds: field `MulInverse`, field/point `IsZero`, and `BaseFieldPow`;
/// * 1 round: `SecretPlaintext` inputs, every `Open`, and binary share operations
///   (`Mul`, `And`/`Or`, `ScalarMul`) whose second operand is itself a share;
/// * 0 rounds: everything else, including every gate whose name marks it as plaintext-only.
fn gate_online_network_rounds(gate: &Gate) -> usize {
    match gate {
        // Among the input kinds, only `SecretPlaintext` is counted as interactive.
        Gate::Input { input_type } => match input_type {
            Input::SecretPlaintext { .. } => 1,
            _ => 0,
        },

        // Secret-shared field elements.
        Gate::FieldShareUnaryOp { op, .. } => match op {
            FieldShareUnaryOp::Neg => 0,
            FieldShareUnaryOp::Open => 1,
            FieldShareUnaryOp::MulInverse | FieldShareUnaryOp::IsZero => 2,
        },
        Gate::FieldShareBinaryOp { y_form, op, .. } => match (op, y_form) {
            (FieldShareBinaryOp::Mul, ShareOrPlaintext::Share) => 1,
            (FieldShareBinaryOp::Add, _)
            | (FieldShareBinaryOp::Mul, ShareOrPlaintext::Plaintext) => 0,
        },
        Gate::BaseFieldPow { .. } => 2,

        // Secret-shared bits.
        Gate::BitShareUnaryOp { op, .. } => match op {
            BitShareUnaryOp::Not => 0,
            BitShareUnaryOp::Open => 1,
        },
        Gate::BitShareBinaryOp { y_form, op, .. } => match (op, y_form) {
            (BitShareBinaryOp::Or | BitShareBinaryOp::And, ShareOrPlaintext::Share) => 1,
            (BitShareBinaryOp::Xor, _)
            | (BitShareBinaryOp::Or | BitShareBinaryOp::And, ShareOrPlaintext::Plaintext) => 0,
        },

        // Secret-shared curve points.
        Gate::PointShareUnaryOp { op, .. } => match op {
            PointShareUnaryOp::Neg => 0,
            PointShareUnaryOp::Open => 1,
            PointShareUnaryOp::IsZero => 2,
        },
        Gate::PointShareBinaryOp { y_form, op, .. } => match (op, y_form) {
            (PointShareBinaryOp::ScalarMul, ShareOrPlaintext::Share) => 1,
            (PointShareBinaryOp::Add, _)
            | (PointShareBinaryOp::ScalarMul, ShareOrPlaintext::Plaintext) => 0,
        },

        // Every remaining gate kind is free in terms of online rounds.
        Gate::BatchSummation { .. }
        | Gate::FieldPlaintextUnaryOp { .. }
        | Gate::FieldPlaintextBinaryOp { .. }
        | Gate::BitPlaintextUnaryOp { .. }
        | Gate::BitPlaintextBinaryOp { .. }
        | Gate::DaBit { .. }
        | Gate::GetDaBitFieldShare { .. }
        | Gate::GetDaBitSharedBit { .. }
        | Gate::BitPlaintextToField { .. }
        | Gate::FieldPlaintextToBit { .. }
        | Gate::BatchGetIndex { .. }
        | Gate::CollectToBatch { .. }
        | Gate::PointPlaintextUnaryOp { .. }
        | Gate::PointPlaintextBinaryOp { .. }
        | Gate::PointFromPlaintextExtendedEdwards { .. }
        | Gate::PlaintextPointToExtendedEdwards { .. }
        | Gate::PlaintextKeccakF1600 { .. }
        | Gate::CompressPlaintextPoint { .. }
        | Gate::KeyRecoveryPlaintextComputeErrors { .. } => 0,
    }
}

/// Computes, for every gate, the number of online network rounds on the longest path from
/// the circuit's inputs up to and including that gate.
///
/// Gates are visited in circuit order, and each gate's position is used as its key, so
/// keys are inserted as `0, 1, 2, …`. Every input a gate refers to must already have been
/// visited; otherwise the lookup of that input's depth panics.
fn communication_depth_per_gate(circuit: &AsyncMPCCircuit) -> IndexMap<Label, usize> {
    let mut depths: IndexMap<Label, usize> = IndexMap::default();
    for (position, gate) in circuit.iter().enumerate() {
        // Depth of the deepest input feeding this gate (0 for gates without inputs).
        let deepest_input = gate
            .get_gate_indices()
            .iter()
            .map(|input| depths[input])
            .max()
            .unwrap_or(0);
        depths.insert(
            position as u32,
            deepest_input + gate_online_network_rounds(gate),
        );
    }
    depths
}

/// Profiles `circuit` incrementally and reports a [`ProfileSummary`] at each requested gate
/// count.
///
/// For every entry `t` of `tracking`, one summary is pushed describing the first `t` gates
/// of the circuit. Each summary is built from:
/// * the largest communication depth among those gates;
/// * `t` itself;
/// * `NetworkContent::network_size()` after adding those gates;
/// * the `PreprocessInfo` weight of the preprocessing they require.
///
/// `tracking` is expected to be sorted in non-decreasing order. Entries may repeat, and
/// `t == ops_count()` summarises the whole circuit. If an entry follows a larger one, it is
/// reported at that larger count. An empty `tracking` yields an empty vector.
///
/// # Panics
///
/// Panics if any entry of `tracking` is greater than the number of gates in `circuit`.
pub fn profile_circuit(circuit: &AsyncMPCCircuit, tracking: &[usize]) -> Vec<ProfileSummary> {
    let n_ops = circuit.ops_count() as usize;
    // Fail fast with a descriptive message instead of first running over the whole circuit
    // and then panicking without saying which tracking point was invalid.
    if let Some(&too_large) = tracking.iter().find(|&&t| t > n_ops) {
        panic!(
            "profile_circuit: tracking point {too_large} exceeds the circuit's gate count \
             {n_ops} (tracking = {tracking:?})"
        );
    }

    let depth_per_gate = communication_depth_per_gate(circuit);
    let batched_gates = circuit.determine_batched_gates();
    let mut max_depth = 0;
    let mut network_content = NetworkContent::default();
    let mut preprocessing = CircuitPreprocessing::default();
    let mut res = Vec::with_capacity(tracking.len());
    for idx in 0..=n_ops {
        // Report every tracking point reached once the first `idx` gates are accounted for.
        while res.len() < tracking.len() && idx >= tracking[res.len()] {
            res.push(ProfileSummary::new(
                max_depth,
                idx,
                network_content.network_size(),
                PreprocessInfo::from(&preprocessing).weight(),
            ));
        }
        // Every entry is <= n_ops (checked above), so at idx == n_ops the loop above drains
        // all remaining points and we stop here; `circuit.ops[idx]` is only reached for
        // idx < n_ops.
        if res.len() >= tracking.len() {
            break;
        }
        let label = idx as u32;
        // Key lookup (not positional indexing) into the depth map.
        max_depth = max_depth.max(depth_per_gate[&label]);
        let batch_size = batched_gates.get(&label).copied();
        let gate = &circuit.ops[idx];
        network_content.add_gate(gate, batch_size.unwrap_or(1));
        let batched = match batch_size {
            Some(size) => Batched::Yes(size),
            None => Batched::No,
        };
        gate.add_to_required_preprocessing(batched, &mut preprocessing);
    }
    res
}

/// Returns the online communication depth of the whole circuit: the largest per-gate
/// depth, or 0 for a circuit without gates.
pub fn get_circuit_depth(circuit: &AsyncMPCCircuit) -> usize {
    let depths = communication_depth_per_gate(circuit);
    depths
        .values()
        .fold(0, |deepest, &depth| deepest.max(depth))
}

/// Prints one critical path of the circuit to stdout, from its first gate to the deepest
/// gate.
///
/// The walk starts at a gate of maximal communication depth and repeatedly steps to the
/// input with the greatest depth, until it reaches a gate without inputs. Each visited gate
/// is printed as `<depth> - <label>: <gate>`, in input-to-output order. Nothing is printed
/// for an empty circuit.
// Debugging helper that may have no callers, hence the allow.
#[allow(dead_code)]
pub fn explain_circuit_depth(circuit: &AsyncMPCCircuit) {
    let depths = communication_depth_per_gate(circuit);
    let Some((&deepest, _)) = depths.iter().max_by_key(|(_, depth)| **depth) else {
        return;
    };

    // Collected from the deepest gate backwards; reversed when printing.
    let mut critical_path = vec![(deepest, depths[&deepest])];
    let mut current = deepest;
    while let Some(next) = circuit.ops[current as usize]
        .get_gate_indices()
        .into_iter()
        .max_by_key(|input| depths[input])
    {
        critical_path.push((next, depths[&next]));
        current = next;
    }

    for &(label, depth) in critical_path.iter().rev() {
        println!("{depth} - {label}: {:?}", circuit.ops[label as usize]);
    }
}