use std::{
collections::HashMap,
sync::{Arc, OnceLock, mpsc::Receiver},
};
use criterion::{Criterion, criterion_group, criterion_main};
use dashmap::DashMap;
use parasol_runtime::{
CircuitProcessor,
ComputeKey,
ComputeKeyNonFft,
DEFAULT_128,
Encryption,
Evaluation,
FAST_BIG_128,
L1GgswCiphertext,
L1GlweCiphertext,
Params,
SecretKey,
TURBO_CHUNGUS_128, fluent::{Bit, CiphertextOps, FheCircuitCtx, Muxable, UInt, UIntGraphNodes},
};
/// Parameter sets iterated by the benchmark helpers; each benchmark is registered once per
/// entry, in this order.
const PARAMS: &[Params] = &[DEFAULT_128, FAST_BIG_128, TURBO_CHUNGUS_128];
/// Returns the label used in benchmark names for a known parameter set.
///
/// # Panics
///
/// Panics if `params` is not equal to one of `DEFAULT_128`, `FAST_BIG_128`, or
/// `TURBO_CHUNGUS_128`.
fn params_name(params: &Params) -> String {
    // A three-entry table scanned linearly: this avoids allocating a `HashMap` plus three
    // `String`s on every call just to look up a single label.
    let known = [
        (DEFAULT_128, "DEFAULT_128"),
        (FAST_BIG_128, "FAST_BIG_128"),
        (TURBO_CHUNGUS_128, "TURBO_CHUNGUS_128"),
    ];

    known
        .iter()
        .find(|(candidate, _)| candidate == params)
        .map(|(_, label)| (*label).to_owned())
        .expect("params_name: parameter set has no registered name")
}
fn make_computer(
params: &Params,
) -> (
Encryption,
Arc<SecretKey>,
CircuitProcessor,
Receiver<()>,
Evaluation,
) {
static SK: OnceLock<DashMap<Params, OnceLock<Arc<SecretKey>>>> = OnceLock::new();
static COMPUTE_KEY: OnceLock<DashMap<Params, OnceLock<Arc<ComputeKey>>>> = OnceLock::new();
let sk_cache = SK.get_or_init(DashMap::new);
let sk = sk_cache
.entry(params.clone())
.or_default()
.get_or_init(|| Arc::new(SecretKey::generate(params)))
.clone();
let ck_cache = COMPUTE_KEY.get_or_init(DashMap::new);
let compute_key = ck_cache
.entry(params.clone())
.or_default()
.get_or_init(|| {
let compute = ComputeKeyNonFft::generate(&sk, params);
Arc::new(compute.fft(params))
})
.clone();
let enc = Encryption::new(params);
let eval = Evaluation::new(compute_key.to_owned(), params, &enc);
let (uproc, fc) = CircuitProcessor::new(16384, None, &eval, &enc);
(enc, sk, uproc, fc, eval)
}
/// Benchmarks a binary operation on two encrypted `N`-bit unsigned integers, once for every
/// parameter set in `PARAMS`.
///
/// For each parameter set, the constants 42 and 35 (reduced to their low `N` bits) are
/// encrypted as `UInt<N, InCt>`, registered as graph inputs, converted to `L1GgswCiphertext`
/// nodes, and passed to `op`, which adds the operation under test to the circuit. A benchmark
/// named `"{name} params: {params_name}"` then times `run_graph_blocking` on that circuit.
///
/// # Panics
///
/// Panics if running the circuit returns an error.
fn bench_binary_function<const N: usize, InCt, F>(crit: &mut Criterion, name: &str, op: F)
where
    InCt: CiphertextOps,
    F: Fn(
        &FheCircuitCtx,
        &UIntGraphNodes<N, L1GgswCiphertext>,
        &UIntGraphNodes<N, L1GgswCiphertext>,
    ),
{
    // Mask selecting the low N bits of each operand. It is built in `u128` so the shift cannot
    // overflow at N = 64: if the plaintext argument of `encrypt_secret` is 64 bits wide,
    // `0x1 << 64` panics with overflow checks on and otherwise yields a mask of 0, which
    // silently turns both operands into 0. The `as _` casts convert back to whatever integer
    // type `encrypt_secret` expects.
    let low_bits = (1u128 << N) - 1;

    for p in PARAMS {
        let (enc, sk, mut uproc, fc, _) = make_computer(p);
        let ctx = FheCircuitCtx::new();

        let a = UInt::<N, InCt>::encrypt_secret((42 & low_bits) as _, &enc, &sk)
            .graph_inputs(&ctx);
        let b = UInt::<N, InCt>::encrypt_secret((35 & low_bits) as _, &enc, &sk)
            .graph_inputs(&ctx);

        // The target node type of `into()` is fixed by the parameter types of `op`.
        let a = a.convert::<L1GgswCiphertext>(&ctx).into();
        let b = b.convert::<L1GgswCiphertext>(&ctx).into();

        // Build the circuit once, outside the timed closure; only graph execution is measured.
        op(&ctx, &a, &b);

        crit.bench_function(&format!("{name} params: {}", params_name(p)), |bench| {
            bench.iter(|| {
                uproc
                    .run_graph_blocking(&ctx.circuit.borrow(), &fc)
                    .unwrap();
            });
        });
    }
}
/// Benchmarks `select` between two encrypted `N`-bit unsigned integers, once for every
/// parameter set in `PARAMS`.
///
/// For each parameter set, an encrypted `true` bit is registered as a graph input and
/// converted to an `L1GgswCiphertext` selector. The constants 42 and 35 (reduced to their low
/// `N` bits) are encrypted as `UInt<N, InCt>` and registered as graph inputs, and
/// `selector.select` is added to the circuit. A benchmark named
/// `"{name} params: {params_name}"` then times `run_graph_blocking` on that circuit.
///
/// # Panics
///
/// Panics if running the circuit returns an error.
fn bench_select_function<const N: usize, InCt>(crit: &mut Criterion, name: &str)
where
    InCt: CiphertextOps + Muxable,
{
    // Mask selecting the low N bits of each operand. It is built in `u128` so the shift cannot
    // overflow at N = 64: if the plaintext argument of `encrypt_secret` is 64 bits wide,
    // `0x1 << 64` panics with overflow checks on and otherwise yields a mask of 0, which
    // silently turns both operands into 0. The `as _` casts convert back to whatever integer
    // type `encrypt_secret` expects.
    let low_bits = (1u128 << N) - 1;

    for p in PARAMS {
        let (enc, sk, mut uproc, fc, _) = make_computer(p);
        let ctx = FheCircuitCtx::new();

        let selector = Bit::<InCt>::encrypt_secret(true, &enc, &sk)
            .graph_input(&ctx)
            .convert::<L1GgswCiphertext>(&ctx);

        let a: UIntGraphNodes<N, InCt> =
            UInt::<N, InCt>::encrypt_secret((42 & low_bits) as _, &enc, &sk)
                .graph_inputs(&ctx)
                .into();
        let b: UIntGraphNodes<N, InCt> =
            UInt::<N, InCt>::encrypt_secret((35 & low_bits) as _, &enc, &sk)
                .graph_inputs(&ctx)
                .into();

        // Build the circuit once, outside the timed closure; only graph execution is measured.
        selector.select(&a, &b, &ctx);

        crit.bench_function(&format!("{name} params: {}", params_name(p)), |bench| {
            bench.iter(|| {
                uproc
                    .run_graph_blocking(&ctx.circuit.borrow(), &fc)
                    .unwrap();
            });
        });
    }
}
/// Criterion target: registers the `add`, `gt`, `mul` and `select` benchmarks on
/// `L1GlweCiphertext` inputs for bit widths 2, 4, 8, 16, 32 and 64.
fn ops(c: &mut Criterion) {
    /// Registers the four benchmarks for a single bit width `N`, in the order
    /// add, gt, mul, select.
    fn run_benchmarks<const N: usize>(c: &mut Criterion) {
        let add_label = format!("add-{N}-glwe");
        let gt_label = format!("gt-{N}-glwe");
        let mul_label = format!("mul-{N}-glwe");
        let select_label = format!("select-{N}-glwe");

        bench_binary_function::<N, L1GlweCiphertext, _>(c, &add_label, |ctx, lhs, rhs| {
            lhs.add::<L1GlweCiphertext>(rhs, ctx);
        });
        bench_binary_function::<N, L1GlweCiphertext, _>(c, &gt_label, |ctx, lhs, rhs| {
            lhs.gt::<L1GlweCiphertext>(rhs, ctx);
        });
        bench_binary_function::<N, L1GlweCiphertext, _>(c, &mul_label, |ctx, lhs, rhs| {
            lhs.mul::<L1GlweCiphertext>(rhs, ctx);
        });
        bench_select_function::<N, L1GlweCiphertext>(c, &select_label);
    }

    // Const generics cannot be driven by a runtime loop, so each width is spelled out.
    run_benchmarks::<2>(c);
    run_benchmarks::<4>(c);
    run_benchmarks::<8>(c);
    run_benchmarks::<16>(c);
    run_benchmarks::<32>(c);
    run_benchmarks::<64>(c);
}
// Register `ops` in the Criterion group `benches`, then generate the benchmark binary's
// `main` that runs that group.
criterion_group!(benches, ops);
criterion_main!(benches);