use std::{collections::HashMap, time::Instant};
use criterion::Criterion;
use garble_lang::{
CircuitKind, CompileOptions, compile_with_options,
literal::{Literal, VariantLiteral},
register_circuit::Circuit,
token::UnsignedNumType,
};
use polytune::{Error, channel, mpc};
use tracing::{error, info};
async fn simulate_mpc_async(
circuit: &Circuit,
inputs: &[&[bool]],
output_parties: &[usize],
) -> Result<Vec<bool>, Error> {
let p_eval = 0;
let channels = channel::SimpleChannel::channels(inputs.len());
let mut parties = channels.into_iter().zip(inputs).enumerate();
let Some((_, (eval_channel, inputs))) = parties.next() else {
return Ok(vec![]);
};
let mut computation: tokio::task::JoinSet<Vec<bool>> = tokio::task::JoinSet::new();
for (p_own, (channel, inputs)) in parties {
let circuit = circuit.clone();
let inputs = inputs.to_vec();
let output_parties = output_parties.to_vec();
computation.spawn(async move {
match mpc(
&channel,
&circuit,
&inputs,
p_eval,
p_own,
&output_parties,
None,
)
.await
{
Ok(res) => {
info!(
"Party {p_own} sent {:.2}MB of messages",
channel.bytes_sent() as f64 / 1024.0 / 1024.0
);
res
}
Err(e) => {
error!("SMPC protocol failed for party {p_own}: {:?}", e);
vec![]
}
}
});
}
let eval_result = mpc(
&eval_channel,
circuit,
inputs,
p_eval,
p_eval,
output_parties,
None,
)
.await;
match eval_result {
Err(e) => {
error!("SMPC protocol failed for Evaluator: {:?}", e);
Ok(vec![])
}
Ok(res) => {
let mut outputs = vec![res];
while let Some(output) = computation.join_next().await {
if let Ok(output) = output {
outputs.push(output);
}
}
outputs.retain(|o| !o.is_empty());
if !outputs.windows(2).all(|w| w[0] == w[1]) {
error!("The result does not match for all output parties: {outputs:?}");
}
let mb = eval_channel.bytes_sent() as f64 / 1024.0 / 1024.0;
info!("Party {p_eval} sent {mb:.2}MB of messages");
info!("MPC simulation finished successfully!");
Ok(outputs.pop().unwrap_or_default())
}
}
}
/// Criterion benchmark for the join circuit in `.join.garble.rs`.
///
/// The number of rows per party is read from the `POLYTUNE_BENCH_JOIN_SIZE`
/// environment variable and defaults to 10. It is passed to the circuit as the
/// `ROWS` constant of both `PARTY_0` and `PARTY_1`.
///
/// Each benchmark iteration does two things, so the measured time covers both:
/// 1. Compiles the circuit as a register circuit.
/// 2. Runs a two-party MPC simulation, with only party 0 receiving the output.
///
/// The inputs are dummy rows whose ids are 20 zero bytes. Compilation time and
/// MPC time are also logged separately.
///
/// # Panics
///
/// Panics in any of these cases:
/// - `POLYTUNE_BENCH_JOIN_SIZE` is set but is not a `u64`, or does not fit in
///   a `usize`.
/// - Compilation fails.
/// - An input literal is rejected.
/// - The MPC simulation returns an error.
pub fn join_benchmark(c: &mut Criterion) {
    let n_records: u64 = std::env::var("POLYTUNE_BENCH_JOIN_SIZE")
        .map(|v| v.parse().expect("POLYTUNE_BENCH_JOIN_SIZE must be a u64"))
        .unwrap_or(10);
    // `ArrayRepeat` takes its length as `usize`. Convert once, with a check,
    // instead of an unchecked `as` cast in every iteration.
    let n_rows =
        usize::try_from(n_records).expect("POLYTUNE_BENCH_JOIN_SIZE must fit in a usize");
    let code = include_str!(".join.garble.rs");
    let bench_id = format!("join {n_records} records");
    c.bench_function(&bench_id, |b| {
        b.to_async(tokio::runtime::Runtime::new().unwrap())
            .iter(|| async {
                let now = Instant::now();
                info!("\n\nRUNNING MPC SIMULATION FOR {n_records} RECORDS:\n");
                info!("Compiling circuit...");
                // Both parties get the same `ROWS` constant.
                let consts = ["PARTY_0", "PARTY_1"]
                    .into_iter()
                    .map(|party| {
                        let rows = Literal::NumUnsigned(n_records, UnsignedNumType::Usize);
                        (party.into(), HashMap::from_iter([("ROWS".into(), rows)]))
                    })
                    .collect();
                let prg = compile_with_options(
                    code,
                    CompileOptions {
                        circuit_kind: CircuitKind::Register,
                        consts,
                        ..Default::default()
                    },
                )
                .expect("Circuit compilation failed");
                let circuit = prg.circuit.unwrap_register_ref();
                info!(
                    "Compiled circuit with total instructions: {}, AND ops {}",
                    circuit.insts.len(),
                    circuit.and_ops
                );
                let elapsed = now.elapsed();
                info!(
                    "Compilation took {} minute(s), {} second(s)",
                    elapsed.as_secs() / 60,
                    elapsed.as_secs() % 60,
                );
                // Dummy inputs: every row uses an id of 20 zero bytes.
                let id = Literal::ArrayRepeat(
                    Box::new(Literal::NumUnsigned(0, UnsignedNumType::U8)),
                    20,
                );
                let screening_status = Literal::Enum(
                    "ScreeningStatus".into(),
                    "Missing".into(),
                    VariantLiteral::Unit,
                );
                let rows0 = Literal::ArrayRepeat(
                    Box::new(Literal::Tuple(vec![id.clone(), screening_status])),
                    n_rows,
                );
                let rows1 = Literal::ArrayRepeat(
                    Box::new(Literal::Tuple(vec![
                        id,
                        Literal::NumUnsigned(0, UnsignedNumType::U8),
                    ])),
                    n_rows,
                );
                let input0 = prg.literal_arg(0, rows0).unwrap().as_bits();
                let input1 = prg.literal_arg(1, rows1).unwrap().as_bits();
                let inputs = [input0.as_slice(), input1.as_slice()];
                // Time the MPC run on its own. `now` was started before
                // compilation, so it would also count compile time.
                let mpc_start = Instant::now();
                simulate_mpc_async(circuit, &inputs, &[0])
                    .await
                    .expect("MPC simulation failed");
                let elapsed = mpc_start.elapsed();
                info!(
                    "MPC computation took {} minute(s), {} second(s)",
                    elapsed.as_secs() / 60,
                    elapsed.as_secs() % 60,
                );
            })
    });
}