use crate::{
network_content::NetworkContent,
preprocess_info::PreprocessInfo,
profile_summary::ProfileSummary,
AsyncMPCCircuit,
};
use core_utils::circuit::{
Batched,
BitShareBinaryOp,
BitShareUnaryOp,
CircuitPreprocessing,
FieldShareBinaryOp,
FieldShareUnaryOp,
Input,
PointShareBinaryOp,
PointShareUnaryOp,
ShareOrPlaintext,
};
use indexmap::IndexMap;
use primitives::algebra::elliptic_curve::Curve25519Ristretto;
/// Circuit gate type used throughout this module, instantiated for `Curve25519Ristretto`.
type Gate = core_utils::circuit::Gate<Curve25519Ristretto>;
/// Gate index within a circuit. In this module a gate's label is its position in the
/// circuit's op list (see how `circuit.iter().enumerate()` and `circuit.ops[..]` are used).
type Label = core_utils::circuit::GateIndex;
/// Number of online communication rounds a single gate adds on top of the depth of its
/// inputs.
///
/// Gates are grouped below by round count. Ops that take a `y_form` operand cost a round
/// only when that operand is `ShareOrPlaintext::Share`; with a plaintext operand they cost
/// nothing. Every gate variant and op variant is spelled out explicitly (no catch-all over
/// `Gate` or the op enums), so adding a variant forces a decision here at compile time.
fn gate_online_network_rounds(gate: &Gate) -> usize {
    match gate {
        // Two rounds.
        Gate::FieldShareUnaryOp {
            op: FieldShareUnaryOp::MulInverse | FieldShareUnaryOp::IsZero,
            ..
        }
        | Gate::PointShareUnaryOp {
            op: PointShareUnaryOp::IsZero,
            ..
        }
        | Gate::BaseFieldPow { .. } => 2,

        // One round: `SecretPlaintext` inputs and `Open` ops.
        Gate::Input {
            input_type: Input::SecretPlaintext { .. },
        }
        | Gate::FieldShareUnaryOp {
            op: FieldShareUnaryOp::Open,
            ..
        }
        | Gate::BitShareUnaryOp {
            op: BitShareUnaryOp::Open,
            ..
        }
        | Gate::PointShareUnaryOp {
            op: PointShareUnaryOp::Open,
            ..
        } => 1,

        // One round: multiplicative ops whose second operand is a share.
        Gate::FieldShareBinaryOp {
            op: FieldShareBinaryOp::Mul,
            y_form: ShareOrPlaintext::Share,
            ..
        }
        | Gate::BitShareBinaryOp {
            op: BitShareBinaryOp::Or | BitShareBinaryOp::And,
            y_form: ShareOrPlaintext::Share,
            ..
        }
        | Gate::PointShareBinaryOp {
            op: PointShareBinaryOp::ScalarMul,
            y_form: ShareOrPlaintext::Share,
            ..
        } => 1,

        // Zero rounds: every remaining share op and all plaintext / bookkeeping gates.
        Gate::Input { .. }
        | Gate::FieldShareUnaryOp {
            op: FieldShareUnaryOp::Neg,
            ..
        }
        | Gate::FieldShareBinaryOp {
            op: FieldShareBinaryOp::Add,
            ..
        }
        | Gate::FieldShareBinaryOp {
            op: FieldShareBinaryOp::Mul,
            y_form: ShareOrPlaintext::Plaintext,
            ..
        }
        | Gate::BitShareUnaryOp {
            op: BitShareUnaryOp::Not,
            ..
        }
        | Gate::BitShareBinaryOp {
            op: BitShareBinaryOp::Xor,
            ..
        }
        | Gate::BitShareBinaryOp {
            op: BitShareBinaryOp::Or | BitShareBinaryOp::And,
            y_form: ShareOrPlaintext::Plaintext,
            ..
        }
        | Gate::PointShareUnaryOp {
            op: PointShareUnaryOp::Neg,
            ..
        }
        | Gate::PointShareBinaryOp {
            op: PointShareBinaryOp::Add,
            ..
        }
        | Gate::PointShareBinaryOp {
            op: PointShareBinaryOp::ScalarMul,
            y_form: ShareOrPlaintext::Plaintext,
            ..
        }
        | Gate::BatchSummation { .. }
        | Gate::FieldPlaintextUnaryOp { .. }
        | Gate::FieldPlaintextBinaryOp { .. }
        | Gate::BitPlaintextUnaryOp { .. }
        | Gate::BitPlaintextBinaryOp { .. }
        | Gate::DaBit { .. }
        | Gate::GetDaBitFieldShare { .. }
        | Gate::GetDaBitSharedBit { .. }
        | Gate::BitPlaintextToField { .. }
        | Gate::FieldPlaintextToBit { .. }
        | Gate::BatchGetIndex { .. }
        | Gate::CollectToBatch { .. }
        | Gate::PointPlaintextUnaryOp { .. }
        | Gate::PointPlaintextBinaryOp { .. }
        | Gate::PointFromPlaintextExtendedEdwards { .. }
        | Gate::PlaintextPointToExtendedEdwards { .. }
        | Gate::PlaintextKeccakF1600 { .. }
        | Gate::CompressPlaintextPoint { .. }
        | Gate::KeyRecoveryPlaintextComputeErrors { .. } => 0,
    }
}
/// Computes the online communication depth of every gate in `circuit`.
///
/// A gate's depth is the largest depth among its inputs (0 if it has none) plus the rounds
/// the gate itself needs according to `gate_online_network_rounds`. The map is keyed by
/// each gate's position in `circuit.iter()`. Entries are inserted in that order, so the
/// map's positional index coincides with the gate label.
///
/// Input depths are read with keyed indexing, which panics on a missing key. Every gate's
/// inputs must therefore appear earlier in the circuit than the gate itself.
fn communication_depth_per_gate(circuit: &AsyncMPCCircuit) -> IndexMap<Label, usize> {
    let mut depths: IndexMap<Label, usize> = IndexMap::default();
    for (position, gate) in circuit.iter().enumerate() {
        let deepest_input = gate
            .get_gate_indices()
            .iter()
            .fold(0, |deepest, input| usize::max(deepest, depths[input]));
        let own_rounds = gate_online_network_rounds(gate);
        depths.insert(position as u32, deepest_input + own_rounds);
    }
    depths
}
/// Profiles growing prefixes of `circuit`, producing one snapshot per entry of `tracking`.
///
/// Ops are visited in order. For a checkpoint `t` in `tracking`, the returned
/// [`ProfileSummary`] is built from the first `t` ops only. It records:
/// - the maximum online communication depth among those ops,
/// - the position `t`,
/// - the `NetworkContent` network size after adding them,
/// - the weight of the `CircuitPreprocessing` they require.
///
/// `tracking` is expected to be sorted in non-decreasing order. An entry smaller than an
/// earlier one is not rejected. It is reported at the same position as the preceding
/// snapshot instead of at its own value.
///
/// # Panics
///
/// Panics if any entry of `tracking` exceeds the number of ops in `circuit`.
pub fn profile_circuit(circuit: &AsyncMPCCircuit, tracking: &[usize]) -> Vec<ProfileSummary> {
    let n_ops = circuit.ops_count() as usize;
    // A checkpoint past the end of the circuit can never be reached. Report it clearly up
    // front instead of failing deep inside the loop.
    if let Some(&bad) = tracking.iter().find(|&&t| t > n_ops) {
        panic!(
            "profile_circuit: tracking checkpoint {bad} exceeds the {n_ops} ops in the circuit \
             (tracking = {tracking:?})"
        );
    }
    let depth_per_gate = communication_depth_per_gate(circuit);
    let batch_sizes = circuit.determine_batched_gates();
    let mut max_depth = 0;
    let mut network_content = NetworkContent::default();
    let mut preprocessing = CircuitPreprocessing::default();
    let mut res = Vec::with_capacity(tracking.len());
    for idx in 0..=n_ops {
        // Snapshot every checkpoint reached once the first `idx` ops have been accumulated.
        while res.len() < tracking.len() && idx >= tracking[res.len()] {
            res.push(ProfileSummary::new(
                max_depth,
                idx,
                network_content.network_size(),
                PreprocessInfo::from(&preprocessing).weight(),
            ));
        }
        // Every checkpoint is <= n_ops (checked above). So at the latest when
        // idx == n_ops all snapshots have been taken, and we stop before indexing
        // past the last op.
        if res.len() >= tracking.len() {
            break;
        }
        // IndexMap's `usize` index is positional. Depths are inserted in gate order, so
        // position `idx` holds the depth of gate `idx`.
        max_depth = max_depth.max(depth_per_gate[idx]);
        let label = idx as u32;
        let gate = &circuit.ops[idx];
        let batch_size = batch_sizes.get(&label).copied();
        network_content.add_gate(gate, batch_size.unwrap_or(1));
        let batched = match batch_size {
            Some(size) => Batched::Yes(size),
            None => Batched::No,
        };
        gate.add_to_required_preprocessing(batched, &mut preprocessing);
    }
    res
}
/// Returns the online communication depth of the whole circuit.
///
/// This is the largest per-gate depth, or 0 for a circuit with no gates.
pub fn get_circuit_depth(circuit: &AsyncMPCCircuit) -> usize {
    let depths = communication_depth_per_gate(circuit);
    depths
        .values()
        .fold(0, |deepest, &depth| usize::max(deepest, depth))
}
/// Prints the critical path of `circuit`, i.e. the chain of gates that determines its
/// online communication depth.
///
/// The walk starts at a gate of maximum depth (the last one if several tie). From there it
/// repeatedly steps to the input with the greatest depth (again the last one on ties),
/// until it reaches a gate with no inputs. The path is then printed from that earliest gate
/// up to the deepest one, one line per gate, as `<depth> - <label>: <gate debug output>`.
/// An empty circuit prints nothing.
// Presumably a debugging aid. The allowance silences dead-code warnings when nothing
// calls it.
#[allow(dead_code)]
pub fn explain_circuit_depth(circuit: &AsyncMPCCircuit) {
    let depths = communication_depth_per_gate(circuit);
    let Some(deepest) = depths
        .iter()
        .max_by_key(|(_, depth)| **depth)
        .map(|(&label, _)| label)
    else {
        return;
    };
    // Walk backwards from the deepest gate, always following its deepest input.
    let critical_path: Vec<_> = std::iter::successors(Some(deepest), |&label| {
        circuit.ops[label as usize]
            .get_gate_indices()
            .into_iter()
            .max_by_key(|input| depths[input])
    })
    .collect();
    for label in critical_path.into_iter().rev() {
        let depth = depths[&label];
        println!("{depth} - {label}: {:?}", circuit.ops[label as usize]);
    }
}