use std::{borrow::Cow, time::Instant};
use ark_ff::FftField;
use ark_std::rand::{distributions::Standard, prelude::Distribution, thread_rng, Rng};
use provekit_whir::{
algebra::{
fields::Field256,
linear_form::{Evaluate, LinearForm, MultilinearExtension},
},
hash::HASH_COUNTER,
parameters::ProtocolParameters,
transcript::{codecs::Empty, Codec, DomainSeparator, ProverState, VerifierState},
};
/// Number of polynomials committed to together in one batch.
const NUM_POLYNOMIALS: usize = 40;
/// Number of variables of each multilinear polynomial.
const NUM_VARIABLES: usize = 19;
/// Number of coefficients per polynomial, i.e. `2^NUM_VARIABLES`.
const NUM_COEFFS: usize = 1 << NUM_VARIABLES;
/// Number of random points at which every polynomial is evaluated.
const NUM_EVAL_POINTS: usize = 1;
/// Program entry point: runs the zero-knowledge batch-commit example
/// instantiated over `Field256`.
fn main() {
    run_zk_batch::<Field256>()
}
/// Runs an end-to-end zero-knowledge WHIR batch-commitment example over field `F`.
///
/// The run proceeds in three stages:
/// 1. Sample `NUM_POLYNOMIALS` random vectors of `NUM_COEFFS` elements each, plus
///    `NUM_EVAL_POINTS` random evaluation points.
/// 2. Commit to all vectors in one batch and prove their evaluations at those points.
/// 3. Verify the resulting proof several times.
///
/// It prints the commit, prove and average verify times, the proof size, and the
/// average number of hashes counted per verification.
///
/// # Panics
///
/// Panics if receiving the commitments, the verification itself, or the final
/// linear-form claim check fails.
fn run_zk_batch<F>()
where
    F: FftField + Codec,
    Standard: Distribution<F>,
{
    use provekit_whir::protocols::whir_zk::Config;

    let mut rng = thread_rng();

    // Protocol parameters.
    let security_level: usize = 128;
    let pow_bits = 10;
    let starting_rate = 1;
    let initial_folding_factor = 4;
    let folding_factor = 4;
    let whir_params = ProtocolParameters {
        security_level,
        pow_bits,
        initial_folding_factor,
        folding_factor,
        unique_decoding: false,
        starting_log_inv_rate: starting_rate,
        batch_size: NUM_POLYNOMIALS,
        hash_id: provekit_whir::hash::BLAKE3,
    };
    let params = Config::<F>::new(NUM_VARIABLES, &whir_params);

    println!("=========================================");
    println!("zk_whir Batch Commit Example");
    println!("=========================================");
    println!("Polynomials: {NUM_POLYNOMIALS}");
    println!("Variables: {NUM_VARIABLES} (size = {NUM_COEFFS})");
    println!("Eval points: {NUM_EVAL_POINTS}");
    println!("Security level: {security_level}");
    println!("PoW bits: {pow_bits}");
    // Report the field that is actually instantiated. A hard-coded name goes stale
    // as soon as `main` picks a different `F`.
    println!("Field: {}", std::any::type_name::<F>());
    println!("Hash: Blake3");
    println!("{params}");

    // Random witness vectors and evaluation points.
    let vectors: Vec<Vec<F>> = (0..NUM_POLYNOMIALS)
        .map(|_| (0..NUM_COEFFS).map(|_| rng.gen()).collect())
        .collect();
    let vec_refs: Vec<&[F]> = vectors.iter().map(Vec::as_slice).collect();
    let embedding = params.embedding();
    let points: Vec<Vec<F>> = (0..NUM_EVAL_POINTS)
        .map(|_| (0..NUM_VARIABLES).map(|_| rng.gen()).collect())
        .collect();
    let forms: Vec<MultilinearExtension<F>> = points
        .iter()
        .map(|p| MultilinearExtension::new(p.clone()))
        .collect();

    // Claimed values, ordered form-major: for each form, one value per vector.
    let evaluations: Vec<F> = forms
        .iter()
        .flat_map(|form| vec_refs.iter().map(|v| form.evaluate(embedding, v)))
        .collect();
    let prove_forms: Vec<Box<dyn LinearForm<F>>> = forms
        .iter()
        .map(|f| Box::new(f.clone()) as Box<dyn LinearForm<F>>)
        .collect();

    // Build the session label from the constants so it cannot drift from the
    // actual batch shape.
    let ds = DomainSeparator::protocol(&params)
        .session(&format!(
            "zk_batch_commit_{NUM_POLYNOMIALS}x2^{NUM_VARIABLES} at {}:{}",
            file!(),
            line!()
        ))
        .instance(&Empty);
    let mut prover_state = ProverState::new_std(&ds);

    println!("-----------------------------------------");
    println!("Committing to {NUM_POLYNOMIALS} polynomials...");
    let commit_start = Instant::now();
    let witness = params.commit(&mut prover_state, &vec_refs);
    let commit_time = commit_start.elapsed();
    println!("Commit time: {commit_time:.2?}");

    println!("Proving...");
    let prove_start = Instant::now();
    // The return value is not used by this example; the proof is read back from
    // `prover_state` below.
    let _ = params.prove(
        &mut prover_state,
        vec_refs.iter().map(|&v| Cow::Borrowed(v)).collect(),
        witness,
        prove_forms,
        Cow::Borrowed(&evaluations),
    );
    let prove_time = prove_start.elapsed();
    println!("Prove time: {prove_time:.2?}");

    let proof = prover_state.proof();
    let proof_size = proof.narg_string.len() + proof.hints.len();
    println!(
        "Proof size: {:.1} KiB ({} bytes)",
        proof_size as f64 / 1024.0,
        proof_size,
    );
    println!("Total prover: {:.2?}", commit_time + prove_time);

    let verify_forms: Vec<&dyn LinearForm<F>> =
        forms.iter().map(|f| f as &dyn LinearForm<F>).collect();

    // Reset the counter here so the average below reflects only the hashes
    // counted during the verification loop.
    HASH_COUNTER.reset();
    let reps: u32 = 10;
    println!("-----------------------------------------");
    println!("Verifying ({reps} repetitions)...");
    let verify_start = Instant::now();
    for _ in 0..reps {
        let mut verifier_state = VerifierState::new_std(&ds, &proof);
        let commitment = params
            .receive_commitments(&mut verifier_state)
            .expect("receive commitments");
        params
            .verify(
                &mut verifier_state,
                &verify_forms,
                &evaluations,
                &commitment,
            )
            .expect("verification failed")
            .verify(verify_forms.iter().copied())
            .expect("final claim check failed");
    }
    let verify_elapsed = verify_start.elapsed();
    println!(
        "Verify time: {:.2?} (avg over {reps})",
        verify_elapsed / reps,
    );
    println!(
        "Avg hashes: {:.1}k",
        (HASH_COUNTER.get() as f64 / f64::from(reps)) / 1000.0,
    );
    println!("=========================================");
    println!("All verifications passed.");
}