use std::fmt;
use ark_ff::FftField;
use serde::{Deserialize, Serialize};
use crate::algebra::embedding::Embedding;
mod committer;
mod prover;
mod utils;
mod verifier;
pub use self::{committer::Witness, verifier::Commitments};
use crate::{
algebra::embedding::Identity,
bits::Bits,
parameters::ProtocolParameters,
protocols::{irs_commit, whir},
};
/// Configuration for zkWHIR 2.0: a pair of WHIR instances over `F`.
///
/// `blinded_polynomial` is the instance whose initial variable count is the
/// witness size (see `num_witness_variables`); `blinding_polynomial` is a
/// second, smaller instance sized by `Config::new`.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug)]
// Empty bound: stops serde from inferring `F: Serialize`/`F: Deserialize`
// bounds on the generated impls.
#[serde(bound = "")]
pub struct Config<F: FftField> {
    /// WHIR instance for the witness; `Config::new` builds it with size
    /// `2^num_variables_main`.
    pub blinded_polynomial: whir::Config<Identity<F>>,
    /// WHIR instance for the blinding vectors; `Config::new` builds it with
    /// size `2^(ell + 1)` and batch size `batch_size + num_variables_main / ell`.
    pub blinding_polynomial: whir::Config<Identity<F>>,
}
impl<F: FftField> fmt::Display for Config<F> {
    /// Prints a title line followed by each WHIR instance, each preceded by
    /// its own label line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "zkWHIR 2.0 — Alternative Randomness Sampling")?;
        let sections = [
            ("Blinded polynomial instance:", &self.blinded_polynomial),
            ("Blinding polynomial instance:", &self.blinding_polynomial),
        ];
        for (label, instance) in sections {
            writeln!(f, "{label}")?;
            write!(f, "{instance}")?;
        }
        Ok(())
    }
}
impl<F: FftField> Config<F> {
    /// Returns `true` if the proof-of-work requirements of *both* WHIR
    /// instances fit within `bits`.
    ///
    /// The blinded and the blinding instance are both proven (and both have
    /// their PoW toggled by `disable_pow`), so checking only the blinded one
    /// would accept a configuration whose blinding instance exceeds the budget.
    pub fn check_max_pow_bits(&self, bits: Bits) -> bool {
        self.blinded_polynomial.check_max_pow_bits(bits)
            && self.blinding_polynomial.check_max_pow_bits(bits)
    }

    /// Disables proof-of-work in both WHIR instances (test builds only).
    #[cfg(test)]
    pub(crate) fn disable_pow(&mut self) {
        self.blinded_polynomial.disable_pow();
        self.blinding_polynomial.disable_pow();
    }

    /// Embedding of the blinded instance (the identity embedding on `F`).
    pub fn embedding(&self) -> &Identity<F> {
        self.blinded_polynomial.embedding()
    }

    /// Number of variables of the committed witness, i.e. the blinded
    /// instance's initial variable count.
    pub fn num_witness_variables(&self) -> usize {
        self.blinded_polynomial.initial_num_variables()
    }

    /// Returns `(blinded_sec, blinding_sec)`, the security levels reported by
    /// the two WHIR instances.
    ///
    /// `num_vectors` and `num_linear_forms` are forwarded to the blinded
    /// instance. The blinding instance is evaluated with its own committed
    /// vector count used for both the vector and linear-form arguments.
    pub fn security_levels(&self, num_vectors: usize, num_linear_forms: usize) -> (f64, f64) {
        let num_blinding_vecs = self.blinding_polynomial.initial_committer.num_vectors;
        let blinded_sec = self
            .blinded_polynomial
            .security_level(num_vectors, num_linear_forms);
        let blinding_sec = self
            .blinding_polynomial
            .security_level(num_blinding_vecs, num_blinding_vecs);
        (blinded_sec, blinding_sec)
    }

    /// Builds a zkWHIR 2.0 configuration for a witness with
    /// `num_variables_main` variables.
    ///
    /// The blinded instance is a WHIR config of size `2^num_variables_main`
    /// built from `params`. The blinding instance has `ell + 1` variables.
    /// Here `ell` is the bit length of a query upper bound derived from the
    /// two per-instance leak estimates. The blinding instance's batch size is
    /// `params.batch_size + num_variables_main / ell`.
    ///
    /// # Panics
    ///
    /// - if `params.unique_decoding` is `true`;
    /// - if the blinded WHIR config has no rounds;
    /// - if `ell + 1 >= num_variables_main`;
    /// - if `ell < params.initial_folding_factor`.
    pub fn new(num_variables_main: usize, params: &ProtocolParameters) -> Self {
        assert!(
            !params.unique_decoding,
            "zkWHIR 2.0 requires list decoding (unique_decoding must be false). \
             The protocol relies on OOD queries in Step 5 for blinding claim \
             generation; unique decoding sets out_domain_samples = 0, making \
             Commitment::num_vectors() undefined."
        );
        let blinded_config: whir::Config<Identity<F>> =
            whir::Config::new(1 << num_variables_main, params);
        assert!(
            !blinded_config.round_configs.is_empty(),
            "zkWHIR 2.0 requires at least one WHIR round \
             (num_variables_main too small for folding_factor)"
        );

        // The witness query target excludes the PoW bits; the blinding
        // estimate uses the full security level.
        let witness_sec = params.security_level.saturating_sub(params.pow_bits) as f64;
        let blinding_sec = params.security_level as f64;

        // For the witness instance, the `t` term is the final sumcheck size
        // rather than the default (the in-domain query count).
        let witness_t_delta = blinded_config.final_sumcheck.initial_size;
        let mut witness_leak = InstanceLeak::new(
            params,
            witness_sec,
            &blinded_config.initial_committer,
            num_variables_main,
        );
        witness_leak.t_delta = witness_t_delta;

        // NOTE(review): the blinding instance is not built yet (its size
        // depends on `ell`), so its leak is estimated with the blinded
        // committer and `num_variables_main`. This is presumably a deliberate
        // over-estimate; confirm.
        let blinding_leak = InstanceLeak::new(
            params,
            blinding_sec,
            &blinded_config.initial_committer,
            num_variables_main,
        );

        let ell = ell_from_q_ub(query_upper_bound(&witness_leak, &blinding_leak));
        assert!(
            ell + 1 < num_variables_main,
            "blinding variables ell+1={} must be < mu={num_variables_main}",
            ell + 1
        );
        assert!(
            ell >= params.initial_folding_factor,
            "ell={ell} must be >= initial_folding_factor={} \
             (parameters too aggressive for ZK sizing)",
            params.initial_folding_factor
        );

        // `ell >= 1` always holds here (`ell_from_q_ub` returns a bit length),
        // so this division cannot be by zero.
        let nu = num_variables_main / ell;
        let blinding_params = ProtocolParameters {
            batch_size: params.batch_size + nu,
            ..*params
        };
        Self {
            blinded_polynomial: blinded_config,
            blinding_polynomial: whir::Config::new(1 << (ell + 1), &blinding_params),
        }
    }
}
/// Sumcheck degree bound assumed by the leak estimate; stored as
/// `InstanceLeak::d`.
const MAX_SUMCHECK_DEGREE: usize = 3;

/// Per-instance terms of the query upper bound that `Config::new` uses to size
/// the blinding polynomial.
///
/// `InstanceLeak::leak` combines them (with saturating arithmetic) as
/// `k * (q_delta + stir_delta) + t_delta + ood_delta + (d + 1) * mu`.
struct InstanceLeak {
    /// `2^initial_folding_factor`.
    k: usize,
    /// Number of variables of the instance.
    mu: usize,
    /// Sumcheck degree bound; always `MAX_SUMCHECK_DEGREE`.
    d: usize,
    /// In-domain query count from `irs_commit::num_in_domain_queries`.
    q_delta: usize,
    /// Second query count computed by `InstanceLeak::query_counts`.
    stir_delta: usize,
    /// Out-of-domain sample count of the committer config.
    ood_delta: usize,
    /// Extra term. It defaults to `q_delta`; `Config::new` overrides it for
    /// the witness instance with the final sumcheck's initial size.
    t_delta: usize,
}
impl InstanceLeak {
    /// Builds the leak estimate for one WHIR instance.
    ///
    /// * `params` – supplies the starting rate, folding factor and decoding mode.
    /// * `security_target` – bits of security the query counts must reach.
    /// * `irs_config` – committer config; only its out-of-domain sample count is read.
    /// * `num_variables` – number of variables of the instance (`mu`).
    ///
    /// `t_delta` is initialised to the in-domain query count; callers may
    /// overwrite it.
    fn new<M>(
        params: &ProtocolParameters,
        security_target: f64,
        irs_config: &irs_commit::Config<M>,
        num_variables: usize,
    ) -> Self
    where
        M: Embedding,
        M::Source: FftField,
        M::Target: FftField,
    {
        #[allow(clippy::cast_possible_wrap)]
        let log_inv_rate = params.starting_log_inv_rate as i32;
        let rate = 0.5_f64.powi(log_inv_rate);
        let folding = params.initial_folding_factor;
        let (in_domain, stir_queries) =
            Self::query_counts(params.unique_decoding, security_target, rate, folding);
        Self {
            k: 1 << folding,
            mu: num_variables,
            d: MAX_SUMCHECK_DEGREE,
            q_delta: in_domain,
            stir_delta: stir_queries,
            ood_delta: irs_config.out_domain_samples,
            t_delta: in_domain,
        }
    }

    /// Total leak:
    /// `k * (q_delta + stir_delta) + t_delta + ood_delta + (d + 1) * mu`,
    /// where every operation saturates at `usize::MAX`.
    const fn leak(&self) -> usize {
        let queries_per_fold = self.q_delta.saturating_add(self.stir_delta);
        let folded_queries = self.k.saturating_mul(queries_per_fold);
        let sumcheck_terms = self.d.saturating_add(1).saturating_mul(self.mu);
        folded_queries
            .saturating_add(self.t_delta)
            .saturating_add(self.ood_delta)
            .saturating_add(sumcheck_terms)
    }

    /// Returns `(in_domain_queries, stir_queries)`.
    ///
    /// The first value comes from `irs_commit::num_in_domain_queries`. The
    /// second is `ceil(security_target / (folding_factor - log2(per_sample)))`.
    /// Here `per_sample` is `(1 + rate) / 2` under unique decoding and
    /// `sqrt(rate) + johnson_slack` otherwise.
    #[allow(clippy::cast_sign_loss)]
    fn query_counts(
        unique_decoding: bool,
        security_target: f64,
        rate: f64,
        folding_factor: usize,
    ) -> (usize, usize) {
        let in_domain = irs_commit::num_in_domain_queries(unique_decoding, security_target, rate);
        let folding_bits = folding_factor as f64;
        let slack = irs_commit::johnson_slack(unique_decoding, rate);
        let per_sample = match unique_decoding {
            true => f64::midpoint(1., rate),
            false => rate.sqrt() + slack,
        };
        let bits_per_query = folding_bits - per_sample.log2();
        let stir_queries = (security_target / bits_per_query).ceil() as usize;
        (in_domain, stir_queries)
    }
}
/// Upper bound on the total query count: the witness leak plus the blinding
/// leak, saturating at `usize::MAX`.
const fn query_upper_bound(witness: &InstanceLeak, blinding: &InstanceLeak) -> usize {
    let witness_leak = witness.leak();
    let blinding_leak = blinding.leak();
    witness_leak.saturating_add(blinding_leak)
}
/// Bit length of `q_ub`, i.e. `floor(log2(q_ub)) + 1`. This is the smallest
/// `ell` with `2^ell > q_ub`.
///
/// # Panics
///
/// Panics if `q_ub == 0`.
const fn ell_from_q_ub(q_ub: usize) -> usize {
    assert!(q_ub > 0, "query upper bound must be positive");
    q_ub.ilog2() as usize + 1
}
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use ark_ff::{AdditiveGroup, Field};

    use super::Config;
    use crate::{
        algebra::{
            fields::Field64,
            linear_form::{Covector, Evaluate, LinearForm, MultilinearExtension},
            MultilinearPoint,
        },
        hash,
        parameters::ProtocolParameters,
        transcript::{codecs::Empty, DomainSeparator, ProverState, VerifierState},
    };

    type F = Field64;

    const TEST_NUM_VARIABLES: usize = 12;
    const TEST_NUM_COEFFS: usize = 1 << TEST_NUM_VARIABLES;

    /// Protocol parameters shared by every test: list decoding, 16-bit
    /// security, no PoW, folding factor 2 and starting log-inverse rate 1.
    fn test_params(batch_size: usize) -> ProtocolParameters {
        ProtocolParameters {
            unique_decoding: false,
            security_level: 16,
            pow_bits: 0,
            initial_folding_factor: 2,
            folding_factor: 2,
            starting_log_inv_rate: 1,
            batch_size,
            hash_id: hash::SHA2,
        }
    }

    /// Builds a config with PoW disabled for `num_variables` variables and
    /// `batch_size` committed vectors.
    fn make_test_config_for(num_variables: usize, batch_size: usize) -> Config<F> {
        let mut config = Config::new(num_variables, &test_params(batch_size));
        config.disable_pow();
        config
    }

    /// Single-vector config over `TEST_NUM_VARIABLES` variables.
    fn make_test_config() -> Config<F> {
        make_test_config_batch(1)
    }

    /// `batch_size`-vector config over `TEST_NUM_VARIABLES` variables.
    fn make_test_config_batch(batch_size: usize) -> Config<F> {
        make_test_config_for(TEST_NUM_VARIABLES, batch_size)
    }

    /// Vector of length `TEST_NUM_COEFFS` whose `i`-th entry is
    /// `i * scale + offset`.
    fn affine_vector(scale: u64, offset: u64) -> Vec<F> {
        (0..TEST_NUM_COEFFS)
            .map(|i| F::from(i as u64 * scale + offset))
            .collect()
    }

    /// Materialises each form as a dense `Covector` of length `size`, the
    /// representation handed to the prover.
    fn to_prove_forms(
        forms: &[Box<dyn LinearForm<F>>],
        size: usize,
    ) -> Vec<Box<dyn LinearForm<F>>> {
        forms
            .iter()
            .map(|f| {
                let mut cv = vec![F::ZERO; size];
                f.accumulate(&mut cv, F::ONE);
                Box::new(Covector::new(cv)) as Box<dyn LinearForm<F>>
            })
            .collect()
    }

    /// Borrows each boxed form as `&dyn LinearForm<F>`, as the verifier expects.
    fn form_refs(forms: &[Box<dyn LinearForm<F>>]) -> Vec<&dyn LinearForm<F>> {
        forms
            .iter()
            .map(|f| f.as_ref() as &dyn LinearForm<F>)
            .collect()
    }

    /// Commits to `vectors`, proves `evaluations` for `forms`, then verifies
    /// the proof. Panics on any failure.
    #[allow(clippy::needless_pass_by_value)]
    fn prove_and_verify(
        config: &Config<F>,
        vectors: Vec<Vec<F>>,
        forms: Vec<Box<dyn LinearForm<F>>>,
        evaluations: &[F],
    ) {
        let prove_forms = to_prove_forms(forms.as_slice(), vectors[0].len());
        let ds = DomainSeparator::protocol(config)
            .session(&format!("zk2-pv {}:{}", file!(), line!()))
            .instance(&Empty);
        let mut prover_state = ProverState::new_std(&ds);
        let poly_refs: Vec<&[F]> = vectors.iter().map(|v| v.as_slice()).collect();
        let witness = config.commit(&mut prover_state, &poly_refs);
        config.prove(
            &mut prover_state,
            vectors.into_iter().map(Cow::Owned).collect(),
            witness,
            prove_forms,
            Cow::Borrowed(evaluations),
        );
        let proof = prover_state.proof();
        let mut verifier_state = VerifierState::new_std(&ds, &proof);
        let commitments = config
            .receive_commitments(&mut verifier_state)
            .expect("receive_commitments failed");
        let weight_refs = form_refs(&forms);
        config
            .verify(&mut verifier_state, &weight_refs, evaluations, &commitments)
            .expect("verification failed")
            .verify(weight_refs)
            .expect("blinded polynomial final claim check failed");
    }

    #[test]
    fn test_zk_prove_verify_single_point() {
        let mut rng = ark_std::test_rng();
        let config = make_test_config();
        let vector = vec![F::ONE; TEST_NUM_COEFFS];
        let point = MultilinearPoint::rand(&mut rng, TEST_NUM_VARIABLES);
        let form = MultilinearExtension { point: point.0 };
        let evaluation = form.evaluate(config.embedding(), &vector);
        prove_and_verify(&config, vec![vector], vec![Box::new(form)], &[evaluation]);
    }

    #[test]
    fn test_zk_prove_verify_multiple_points() {
        let mut rng = ark_std::test_rng();
        let config = make_test_config();
        let vector = affine_vector(1, 1);
        let p0 = MultilinearPoint::rand(&mut rng, TEST_NUM_VARIABLES);
        let p1 = MultilinearPoint::rand(&mut rng, TEST_NUM_VARIABLES);
        let f0 = MultilinearExtension { point: p0.0 };
        let f1 = MultilinearExtension { point: p1.0 };
        let embedding = config.embedding();
        let eval0 = f0.evaluate(embedding, &vector);
        let eval1 = f1.evaluate(embedding, &vector);
        prove_and_verify(
            &config,
            vec![vector],
            vec![Box::new(f0), Box::new(f1)],
            &[eval0, eval1],
        );
    }

    #[test]
    fn test_zk_prove_verify_with_covector() {
        let mut rng = ark_std::test_rng();
        let config = make_test_config();
        let vector = vec![F::ONE; TEST_NUM_COEFFS];
        let point = MultilinearPoint::rand(&mut rng, TEST_NUM_VARIABLES);
        let mle_form = MultilinearExtension { point: point.0 };
        let embedding = config.embedding();
        let mle_eval = mle_form.evaluate(embedding, &vector);
        let cov = Covector::new(affine_vector(1, 0));
        let cov_eval = cov.evaluate(embedding, &vector);
        prove_and_verify(
            &config,
            vec![vector],
            vec![Box::new(mle_form), Box::new(cov)],
            &[mle_eval, cov_eval],
        );
    }

    /// `NUM_VARS` chosen so that `NUM_VARS % ell != 0`.
    #[test]
    fn test_zk_prove_verify_nonzero_rem() {
        const NUM_VARS: usize = 14;
        const NUM_COEFFS: usize = 1 << NUM_VARS;
        let mut rng = ark_std::test_rng();
        let config = make_test_config_for(NUM_VARS, 1);
        let ell = config.blinding_polynomial.initial_num_variables() - 1;
        assert_ne!(NUM_VARS % ell, 0, "test requires non-zero rem");
        let vector = vec![F::ONE; NUM_COEFFS];
        let point = MultilinearPoint::rand(&mut rng, NUM_VARS);
        let form = MultilinearExtension { point: point.0 };
        let evaluation = form.evaluate(config.embedding(), &vector);
        prove_and_verify(&config, vec![vector], vec![Box::new(form)], &[evaluation]);
    }

    #[test]
    fn test_zk_prove_verify_multi_vector() {
        let mut rng = ark_std::test_rng();
        let config = make_test_config_batch(2);
        let v0 = affine_vector(1, 1);
        let v1 = affine_vector(3, 7);
        let p0 = MultilinearPoint::rand(&mut rng, TEST_NUM_VARIABLES);
        let f0 = MultilinearExtension { point: p0.0 };
        let embedding = config.embedding();
        let eval_0_0 = f0.evaluate(embedding, &v0);
        let eval_0_1 = f0.evaluate(embedding, &v1);
        prove_and_verify(
            &config,
            vec![v0, v1],
            vec![Box::new(f0)],
            &[eval_0_0, eval_0_1],
        );
    }

    /// Evaluations are ordered form-major: for each form, one per vector.
    #[test]
    fn test_zk_prove_verify_multi_vector_multi_form() {
        let mut rng = ark_std::test_rng();
        let config = make_test_config_batch(2);
        let v0 = affine_vector(1, 1);
        let v1 = affine_vector(3, 7);
        let p0 = MultilinearPoint::rand(&mut rng, TEST_NUM_VARIABLES);
        let p1 = MultilinearPoint::rand(&mut rng, TEST_NUM_VARIABLES);
        let f0 = MultilinearExtension { point: p0.0 };
        let f1 = MultilinearExtension { point: p1.0 };
        let embedding = config.embedding();
        let eval_0_0 = f0.evaluate(embedding, &v0);
        let eval_0_1 = f0.evaluate(embedding, &v1);
        let eval_1_0 = f1.evaluate(embedding, &v0);
        let eval_1_1 = f1.evaluate(embedding, &v1);
        prove_and_verify(
            &config,
            vec![v0, v1],
            vec![Box::new(f0), Box::new(f1)],
            &[eval_0_0, eval_0_1, eval_1_0, eval_1_1],
        );
    }

    /// An honest proof must not verify against altered public evaluations.
    /// A panic during verification also counts as rejection.
    #[test]
    fn test_zk_rejects_wrong_evaluations() {
        let mut rng = ark_std::test_rng();
        let config = make_test_config();
        let vector = affine_vector(1, 1);
        let p0 = MultilinearPoint::rand(&mut rng, TEST_NUM_VARIABLES);
        let p1 = MultilinearPoint::rand(&mut rng, TEST_NUM_VARIABLES);
        let f0 = MultilinearExtension { point: p0.0 };
        let f1 = MultilinearExtension { point: p1.0 };
        let embedding = config.embedding();
        let evaluations = vec![
            f0.evaluate(embedding, &vector),
            f1.evaluate(embedding, &vector),
        ];
        let forms: Vec<Box<dyn LinearForm<F>>> = vec![Box::new(f0), Box::new(f1)];
        let prove_forms = to_prove_forms(&forms, vector.len());
        let ds = DomainSeparator::protocol(&config)
            .session(&format!("zk2-wrong-eval {}:{}", file!(), line!()))
            .instance(&Empty);
        let mut prover_state = ProverState::new_std(&ds);
        let witness = config.commit(&mut prover_state, &[&vector]);
        config.prove(
            &mut prover_state,
            vec![Cow::Owned(vector)],
            witness,
            prove_forms,
            Cow::Borrowed(&evaluations),
        );
        let proof = prover_state.proof();
        let mut wrong_evaluations = evaluations;
        wrong_evaluations[0] += F::ONE;
        let weight_refs = form_refs(&forms);
        let verify_outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let mut verifier_state = VerifierState::new_std(&ds, &proof);
            let commitments = config
                .receive_commitments(&mut verifier_state)
                .expect("receive_commitments");
            config
                .verify(
                    &mut verifier_state,
                    &weight_refs,
                    &wrong_evaluations,
                    &commitments,
                )?
                .verify(weight_refs.iter().copied())
        }));
        if let Ok(result) = verify_outcome {
            assert!(
                result.is_err(),
                "verification should reject wrong public evaluations"
            );
        }
    }

    /// Flipping one bit of the proof transcript must cause rejection.
    /// A panic during verification also counts as rejection.
    #[test]
    fn test_zk_rejects_tampered_proof() {
        let mut rng = ark_std::test_rng();
        let config = make_test_config();
        let vector = affine_vector(1, 1);
        let p0 = MultilinearPoint::rand(&mut rng, TEST_NUM_VARIABLES);
        let p1 = MultilinearPoint::rand(&mut rng, TEST_NUM_VARIABLES);
        let f0 = MultilinearExtension { point: p0.0 };
        let f1 = MultilinearExtension { point: p1.0 };
        let embedding = config.embedding();
        let evaluations = vec![
            f0.evaluate(embedding, &vector),
            f1.evaluate(embedding, &vector),
        ];
        let forms: Vec<Box<dyn LinearForm<F>>> = vec![Box::new(f0), Box::new(f1)];
        let prove_forms = to_prove_forms(&forms, vector.len());
        let ds = DomainSeparator::protocol(&config)
            .session(&format!("zk2-tamper {}:{}", file!(), line!()))
            .instance(&Empty);
        let mut prover_state = ProverState::new_std(&ds);
        let witness = config.commit(&mut prover_state, &[&vector]);
        config.prove(
            &mut prover_state,
            vec![Cow::Owned(vector)],
            witness,
            prove_forms,
            Cow::Borrowed(&evaluations),
        );
        let mut tampered_proof = prover_state.proof();
        let Some(last) = tampered_proof.narg_string.last_mut() else {
            panic!("expected non-empty proof transcript");
        };
        *last ^= 1;
        let weight_refs = form_refs(&forms);
        let verify_outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let mut verifier_state = VerifierState::new_std(&ds, &tampered_proof);
            let commitments = config
                .receive_commitments(&mut verifier_state)
                .expect("receive_commitments");
            config
                .verify(
                    &mut verifier_state,
                    &weight_refs,
                    &evaluations,
                    &commitments,
                )?
                .verify(weight_refs.iter().copied())
        }));
        if let Ok(result) = verify_outcome {
            assert!(
                result.is_err(),
                "verification should reject tampered proof bytes"
            );
        }
    }

    /// A prover claiming a wrong evaluation must not convince the verifier.
    /// A panic in either the prover or the verifier counts as rejection.
    #[test]
    fn test_zk_malicious_prover_wrong_evaluation() {
        let mut rng = ark_std::test_rng();
        let config = make_test_config();
        let vector = vec![F::ONE; TEST_NUM_COEFFS];
        let point = MultilinearPoint::rand(&mut rng, TEST_NUM_VARIABLES);
        let form = MultilinearExtension { point: point.0 };
        let correct_evaluation = form.evaluate(config.embedding(), &vector);
        let wrong_evaluation = correct_evaluation + F::from(42u64);
        let forms: Vec<Box<dyn LinearForm<F>>> = vec![Box::new(form)];
        let prove_forms = to_prove_forms(&forms, vector.len());
        let weight_refs = form_refs(&forms);
        let ds = DomainSeparator::protocol(&config)
            .session(&format!("zk2-malicious {}:{}", file!(), line!()))
            .instance(&Empty);
        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let mut prover_state = ProverState::new_std(&ds);
            let witness = config.commit(&mut prover_state, &[&vector]);
            config.prove(
                &mut prover_state,
                vec![Cow::Borrowed(vector.as_slice())],
                witness,
                prove_forms,
                Cow::Owned(vec![wrong_evaluation]),
            );
            let proof = prover_state.proof();
            let mut verifier_state = VerifierState::new_std(&ds, &proof);
            let commitments = config
                .receive_commitments(&mut verifier_state)
                .expect("receive_commitments");
            config
                .verify(
                    &mut verifier_state,
                    &weight_refs,
                    &[wrong_evaluation],
                    &commitments,
                )?
                .verify(weight_refs.iter().copied())
        }));
        if let Ok(result) = outcome {
            assert!(
                result.is_err(),
                "SOUNDNESS BUG: verifier accepted wrong evaluation from malicious prover \
                 (correct={correct_evaluation:?}, claimed={wrong_evaluation:?})"
            );
        }
    }

    #[test]
    #[should_panic(expected = "zkWHIR 2.0 requires list decoding")]
    fn test_zk_unique_decoding_unsupported() {
        let whir_params = ProtocolParameters {
            unique_decoding: true,
            security_level: 32,
            ..test_params(1)
        };
        Config::<F>::new(TEST_NUM_VARIABLES, &whir_params);
    }

    /// `NUM_VARS` chosen so that `NUM_VARS % ell == 0`.
    #[test]
    fn test_zk_prove_verify_zero_rem() {
        const NUM_VARS: usize = 20;
        const NUM_COEFFS: usize = 1 << NUM_VARS;
        let mut rng = ark_std::test_rng();
        let config = make_test_config_for(NUM_VARS, 1);
        let ell = config.blinding_polynomial.initial_num_variables() - 1;
        assert_eq!(NUM_VARS % ell, 0, "test requires rem == 0");
        let vector = vec![F::ONE; NUM_COEFFS];
        let point = MultilinearPoint::rand(&mut rng, NUM_VARS);
        let form = MultilinearExtension { point: point.0 };
        let evaluation = form.evaluate(config.embedding(), &vector);
        prove_and_verify(&config, vec![vector], vec![Box::new(form)], &[evaluation]);
    }
}