use std::{borrow::Cow, time::Instant};
use ark_ff::FftField;
use ark_std::rand::{distributions::Standard, prelude::Distribution, thread_rng, Rng};
use provekit_whir::{
algebra::{
embedding::{Embedding, Identity},
fields::Field256,
linear_form::{Evaluate, LinearForm, MultilinearExtension},
},
parameters::ProtocolParameters,
transcript::{codecs::Empty, Codec, DomainSeparator, ProverState, VerifierState},
};
/// Number of polynomials committed together in one batch.
const NUM_POLYNOMIALS: usize = 44;
/// Number of variables of each multilinear polynomial.
const NUM_VARIABLES: usize = 19;
/// Coefficients per polynomial: one per point of the boolean hypercube, i.e. 2^NUM_VARIABLES.
const NUM_COEFFS: usize = usize::pow(2, NUM_VARIABLES as u32);
fn main() {
println!("=== PARAMETER SWEEP: 44 poly × 2^19, BN254, 128-bit security ===\n");
let mut rng = thread_rng();
let vectors: Vec<Vec<Field256>> = (0..NUM_POLYNOMIALS)
.map(|_| (0..NUM_COEFFS).map(|_| rng.gen()).collect())
.collect();
println!("--- NON-ZK ---");
println!("{:<45} {:>10} {:>10} {:>10} {:>10}",
"Config", "Commit", "Prove", "Total", "Proof KiB");
for &pow in &[10, 15, 20, 25] {
for &rate in &[1, 2] {
for &ff in &[3, 4] {
let label = format!("pow={pow}, rate={rate}, ff={ff}");
bench_nonzk(&label, &vectors, pow, rate, ff, ff);
}
}
}
println!("\n--- ZK ---");
println!("{:<45} {:>10} {:>10} {:>10} {:>10}",
"Config", "Commit", "Prove", "Total", "Proof KiB");
for &pow in &[10, 15, 20, 25] {
for &rate in &[1, 2] {
for &ff in &[3, 4] {
let label = format!("pow={pow}, rate={rate}, ff={ff}");
bench_zk(&label, &vectors, pow, rate, ff, ff);
}
}
}
}
/// Benchmarks one non-ZK WHIR configuration and prints a single table row.
///
/// All of `vectors` is committed as one batch. The prover then proves the evaluation of
/// every vector at one random multilinear point. Commit and prove are timed separately.
/// The reported proof size is `narg_string.len() + hints.len()` of the final transcript,
/// in KiB.
///
/// Any panic while building the configuration, committing, or proving is caught. This
/// covers the case where a parameter combination is rejected, and the row is then printed
/// as `FAILED`.
///
/// * `label` – text shown in the first column.
/// * `pow` – proof-of-work bits (`pow_bits`).
/// * `rate` – starting log inverse rate.
/// * `iff` / `ff` – initial and subsequent folding factors.
fn bench_nonzk(
    label: &str,
    vectors: &[Vec<Field256>],
    pow: usize,
    rate: usize,
    iff: usize,
    ff: usize,
) {
    use provekit_whir::protocols::whir::Config;

    let whir_params = ProtocolParameters {
        security_level: 128,
        pow_bits: pow,
        initial_folding_factor: iff,
        folding_factor: ff,
        unique_decoding: false,
        starting_log_inv_rate: rate,
        batch_size: NUM_POLYNOMIALS,
        hash_id: provekit_whir::hash::BLAKE3,
    };

    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        let params = Config::<Identity<Field256>>::new(1 << NUM_VARIABLES, &whir_params);
        // `&params` was previously mis-encoded as `¶ms` (HTML `&para` entity) and did not compile.
        let ds = DomainSeparator::protocol(&params)
            .session(&String::from("sweep"))
            .instance(&Empty);
        let mut prover_state = ProverState::new_std(&ds);
        let vec_refs: Vec<&[Field256]> = vectors.iter().map(Vec::as_slice).collect();

        // One random evaluation point; every committed vector is opened at it, so
        // `evaluations[i]` is the claimed value of `vectors[i]` at that point.
        let mut rng = thread_rng();
        let point: Vec<Field256> = (0..NUM_VARIABLES).map(|_| rng.gen()).collect();
        let form = MultilinearExtension::new(point);
        let evaluations: Vec<Field256> = vec_refs
            .iter()
            .map(|v| form.evaluate(params.embedding(), v))
            .collect();
        let prove_forms: Vec<Box<dyn LinearForm<Field256>>> =
            vec![Box::new(form) as Box<dyn LinearForm<Field256>>];

        let t_commit = Instant::now();
        let witness = params.commit(&mut prover_state, &vec_refs);
        let commit = t_commit.elapsed();

        let t_prove = Instant::now();
        let _ = params.prove(
            &mut prover_state,
            vec_refs.iter().map(|&v| Cow::Borrowed(v)).collect(),
            vec![Cow::Owned(witness)],
            prove_forms,
            Cow::Borrowed(&evaluations),
        );
        let prove = t_prove.elapsed();

        let proof = prover_state.proof();
        let proof_kib = (proof.narg_string.len() + proof.hints.len()) as f64 / 1024.0;
        (commit, prove, proof_kib)
    }));

    match result {
        Ok((commit, prove, proof_kib)) => {
            let total = commit + prove;
            // Widths match the 10-wide header columns printed by `main`.
            println!(
                "{label:<45} {commit:>10.2?} {prove:>10.2?} {total:>10.2?} {proof_kib:>10.1}"
            );
        }
        Err(_) => println!("{label:<45} FAILED"),
    }
}
/// Benchmarks one zero-knowledge WHIR configuration and prints a single table row.
///
/// All of `vectors` is committed as one batch. The prover then proves the evaluation of
/// every vector at one random multilinear point. Commit and prove are timed separately.
/// The reported proof size is `narg_string.len() + hints.len()` of the final transcript,
/// in KiB.
///
/// Any panic while building the configuration, committing, or proving is caught. This
/// covers the case where a parameter combination is rejected, and the row is then printed
/// as `FAILED`.
///
/// * `label` – text shown in the first column.
/// * `pow` – proof-of-work bits (`pow_bits`).
/// * `rate` – starting log inverse rate.
/// * `iff` / `ff` – initial and subsequent folding factors.
fn bench_zk(
    label: &str,
    vectors: &[Vec<Field256>],
    pow: usize,
    rate: usize,
    iff: usize,
    ff: usize,
) {
    use provekit_whir::protocols::whir_zk::Config;

    let whir_params = ProtocolParameters {
        security_level: 128,
        pow_bits: pow,
        initial_folding_factor: iff,
        folding_factor: ff,
        unique_decoding: false,
        starting_log_inv_rate: rate,
        batch_size: NUM_POLYNOMIALS,
        hash_id: provekit_whir::hash::BLAKE3,
    };

    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        let params = Config::<Field256>::new(NUM_VARIABLES, &whir_params);
        // `&params` was previously mis-encoded as `¶ms` (HTML `&para` entity) and did not compile.
        let ds = DomainSeparator::protocol(&params)
            .session(&String::from("sweep"))
            .instance(&Empty);
        let mut prover_state = ProverState::new_std(&ds);
        let vec_refs: Vec<&[Field256]> = vectors.iter().map(Vec::as_slice).collect();
        let embedding = params.embedding();

        // One random evaluation point; every committed vector is opened at it, so
        // `evaluations[i]` is the claimed value of `vectors[i]` at that point.
        let mut rng = thread_rng();
        let point: Vec<Field256> = (0..NUM_VARIABLES).map(|_| rng.gen()).collect();
        let form = MultilinearExtension::new(point);
        let evaluations: Vec<Field256> =
            vec_refs.iter().map(|v| form.evaluate(embedding, v)).collect();
        let prove_forms: Vec<Box<dyn LinearForm<Field256>>> =
            vec![Box::new(form) as Box<dyn LinearForm<Field256>>];

        let t_commit = Instant::now();
        let witness = params.commit(&mut prover_state, &vec_refs);
        let commit = t_commit.elapsed();

        let t_prove = Instant::now();
        let _ = params.prove(
            &mut prover_state,
            vec_refs.iter().map(|&v| Cow::Borrowed(v)).collect(),
            witness,
            prove_forms,
            Cow::Borrowed(&evaluations),
        );
        let prove = t_prove.elapsed();

        let proof = prover_state.proof();
        let proof_kib = (proof.narg_string.len() + proof.hints.len()) as f64 / 1024.0;
        (commit, prove, proof_kib)
    }));

    match result {
        Ok((commit, prove, proof_kib)) => {
            let total = commit + prove;
            // Widths match the 10-wide header columns printed by `main`.
            println!(
                "{label:<45} {commit:>10.2?} {prove:>10.2?} {total:>10.2?} {proof_kib:>10.1}"
            );
        }
        Err(_) => println!("{label:<45} FAILED"),
    }
}