use crate::{
network_content::NetworkContent,
preprocess_info::PreprocessInfo,
profile_summary::ProfileSummary,
AsyncMPCCircuit,
};
use core_utils::circuit::CircuitPreprocessing;
use indexmap::IndexMap;
use primitives::izip_eq;
/// Identifier of a gate within a circuit; an alias for `core_utils::circuit::GateIndex`.
type Label = core_utils::circuit::GateIndex;
fn communication_depth_per_gate(circuit: &AsyncMPCCircuit) -> IndexMap<Label, usize> {
let mut depth_per_gate = IndexMap::default();
for (label, gate) in izip_eq!(0..circuit.nb_gates(), circuit.iter_gates_ext()) {
depth_per_gate.insert(label, gate.level.comm_level());
}
depth_per_gate
}
/// Profiles `circuit` incrementally and snapshots the accumulated cost at each tracking point.
///
/// Gates are visited in index order (`0..nb_gates`). While walking, the function accumulates:
/// * the maximum communication level seen so far;
/// * the network content of the visited gates (`NetworkContent::add_gate`);
/// * the preprocessing they require (`add_to_required_preprocessing`).
///
/// For every entry `t` of `tracking`, a [`ProfileSummary`] is built from the state after the
/// first `t` gates (gates `0..t`) have been added.
///
/// # Arguments
/// * `circuit` - the circuit to profile.
/// * `tracking` - gate counts at which to take a snapshot. Must be sorted in non-decreasing
///   order, and every value must be at most `circuit.nb_gates()`.
///
/// # Returns
/// One summary per entry of `tracking`, in the same order.
///
/// # Panics
/// * If any tracking point exceeds the number of gates in the circuit.
/// * In debug builds, if `tracking` is not sorted.
/// * If a gate index in `0..nb_gates` cannot be fetched from the circuit.
pub fn profile_circuit(circuit: &AsyncMPCCircuit, tracking: &[usize]) -> Vec<ProfileSummary> {
    let n_ops = circuit.nb_gates() as usize;
    // Validate up front: a point beyond `n_ops` can never be reached by the walk below.
    // Discovering that only after profiling the whole circuit wastes work and the resulting
    // panic would not say what was wrong.
    if let Some(&out_of_range) = tracking.iter().find(|&&t| t > n_ops) {
        panic!(
            "tracking point {out_of_range} exceeds gate count {n_ops} (tracking = {tracking:?})"
        );
    }
    // Snapshots are taken in a single forward pass, so out-of-order points would be
    // recorded at the wrong gate count.
    debug_assert!(
        tracking.windows(2).all(|w| w[0] <= w[1]),
        "tracking points must be sorted in non-decreasing order: {tracking:?}"
    );

    let mut max_depth = 0;
    let mut network_content = NetworkContent::default();
    let mut preprocessing = CircuitPreprocessing::default();
    let mut res = Vec::with_capacity(tracking.len());
    for idx in 0..=n_ops {
        // Record every snapshot whose gate count has been reached. This is a loop so that
        // duplicate tracking points each get their own summary.
        while res.len() < tracking.len() && idx >= tracking[res.len()] {
            res.push(ProfileSummary::new(
                max_depth,
                idx,
                network_content.network_size(),
                PreprocessInfo::from(&preprocessing).weight(),
            ));
        }
        if res.len() == tracking.len() {
            break;
        }
        // Some points are still pending, so `idx < n_ops` here. Every point is <= n_ops
        // (checked above), so all of them are recorded at the latest when `idx == n_ops`.
        let gate = circuit.gate_ext(idx as Label).expect("Gate not found.");
        max_depth = max_depth.max(gate.level.comm_level());
        network_content.add_gate(circuit, gate);
        circuit.add_to_required_preprocessing(gate, &mut preprocessing);
    }
    res
}
/// Returns the communication depth of `circuit`: the largest communication level of any gate.
///
/// A circuit with no gates has depth `0`.
pub fn get_circuit_depth(circuit: &AsyncMPCCircuit) -> usize {
    communication_depth_per_gate(circuit)
        .values()
        .copied()
        .fold(0, usize::max)
}
/// Prints the chain of gates that realizes the circuit's maximum communication depth.
///
/// Starts at a gate with the maximum communication level (the last one in label order on
/// ties, per `Iterator::max_by_key`). It then repeatedly steps to that gate's input with the
/// highest communication level, again taking the last on ties, until it reaches a gate with
/// no inputs.
///
/// The chain is printed from the input-less gate up to the deepest gate, one
/// `"{depth} - {label}: {gate:?}"` line per step. Does nothing for a circuit with no gates.
///
/// # Panics
/// If a gate on the chain cannot be fetched from `circuit`. Also panics if an input label
/// is missing from the depth map, which presumably cannot happen for a well-formed circuit.
// The allow silences the unused-function lint for this debugging aid when nothing calls it.
#[allow(dead_code)]
pub fn explain_circuit_depth(circuit: &AsyncMPCCircuit) {
    let depths = communication_depth_per_gate(circuit);
    let Some(deepest) = depths
        .iter()
        .max_by_key(|(_, level)| **level)
        .map(|(&label, _)| label)
    else {
        return;
    };

    // Walk backwards through the inputs, always following the deepest one.
    let mut chain = Vec::new();
    let mut cursor = Some(deepest);
    while let Some(label) = cursor {
        chain.push((label, depths[&label]));
        cursor = circuit
            .gate(label)
            .expect("Gate not found")
            .get_inputs()
            .into_iter()
            .max_by_key(|input| depths[input]);
    }

    // Emit from the input-less gate up to the deepest gate.
    for (label, depth) in chain.into_iter().rev() {
        println!("{depth} - {label}: {:?}", circuit.gate(label));
    }
}