#![allow(type_alias_bounds)]
mod config;
mod prover;
pub(crate) mod rounds;
mod verifier;
use std::fmt::Debug;
use ark_ff::{FftField, Field};
use ark_std::rand::{CryptoRng, RngCore};
use serde::{Deserialize, Serialize};
#[cfg(feature = "tracing")]
use tracing::instrument;
use crate::{
algebra::{
embedding::{Embedding, Identity},
linear_form::LinearForm,
},
hash::Hash,
protocols::{irs_commit, proof_of_work, sumcheck},
transcript::{
Codec, DuplexSpongeInterface, ProverMessage, ProverState, VerificationResult, VerifierState,
},
utils::zip_strict,
verify,
};
/// Top-level protocol configuration, generic over the field embedding `M`.
///
/// Committed vectors are over `M::Source` (see [`Config::commit`]); the sumcheck
/// configurations and the per-round configurations are over `M::Target`.
///
/// NOTE(review): for sequence-based serde formats the field order below is part
/// of the encoding (and it also fixes the derived `Debug` output), so do not
/// reorder fields casually.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug)]
// Explicit bound replaces serde's inferred `M: Serialize` / `M: Deserialize` bounds.
#[serde(bound = "M: Embedding, M::Source: FftField, M::Target: FftField")]
pub struct Config<M>
where
M: Embedding,
M::Source: FftField,
M::Target: FftField,
{
/// IRS committer for the initial vectors; used by [`Config::commit`] and
/// [`Config::receive_commitment`].
pub initial_committer: irs_commit::Config<M>,
/// Sumcheck configuration for the initial phase (over `M::Target`).
pub initial_sumcheck: sumcheck::Config<M::Target>,
/// Proof-of-work configuration for the initial "skip" step.
/// NOTE(review): its exact role is defined in the prover/verifier modules, not visible here.
pub initial_skip_pow: proof_of_work::Config,
/// One configuration per intermediate round; see [`RoundConfig`].
pub round_configs: Vec<RoundConfig<M::Target>>,
/// Sumcheck configuration for the final phase.
pub final_sumcheck: sumcheck::Config<M::Target>,
/// Proof-of-work configuration for the final phase.
pub final_pow: proof_of_work::Config,
}
/// Configuration of a single intermediate round, entirely over the field `F`.
///
/// The round's committer uses the `Identity<F>` embedding, so the round commits
/// to vectors over `F` itself (presumably `Config::Target`; see `round_configs`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
// Explicit bound replaces serde's inferred `F: Serialize` / `F: Deserialize` bounds.
#[serde(bound = "F: FftField")]
pub struct RoundConfig<F>
where
F: FftField,
{
/// IRS committer used by this round.
pub irs_committer: irs_commit::Config<Identity<F>>,
/// Sumcheck configuration for this round; its `round_pow` threshold is
/// overridden by `Config::disable_pow` in tests.
pub sumcheck: sumcheck::Config<F>,
/// Proof-of-work configuration for this round.
pub pow: proof_of_work::Config,
}
/// Prover-side data returned by [`Config::commit`] and consumed later when proving.
///
/// Committed vectors are over `M::Source`; `F` is the embedding's target field
/// (`M: Embedding<Target = F>`). Bounds on type-alias parameters are not enforced
/// by the compiler (hence `#![allow(type_alias_bounds)]` at the top of the file);
/// `M: Embedding` is what lets the shorthand `M::Source` projection resolve.
pub type Witness<F: FftField, M: Embedding<Target = F>> = irs_commit::Witness<M::Source, F>;
/// Verifier-side commitment returned by [`Config::receive_commitment`].
pub type Commitment<F: Field> = irs_commit::Commitment<F>;
/// Residual claim handed back to the caller once protocol verification succeeds.
///
/// Checking it with [`FinalClaim::verify`] means recomputing
/// `sum_i rlc_coefficients[i] * L_i(evaluation_point)`, where `L_i` is the
/// multilinear extension of the `i`-th linear form, and comparing the result
/// with `linear_form_rlc`.
#[must_use = "The final claim must be checked if there were any linear forms."]
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct FinalClaim<F: Field> {
    /// Point at which every linear form's multilinear extension is evaluated.
    pub evaluation_point: Vec<F>,
    /// One coefficient per linear form, in the order the forms are supplied
    /// to [`FinalClaim::verify`].
    pub rlc_coefficients: Vec<F>,
    /// Value that the weighted sum of linear-form evaluations must equal.
    pub linear_form_rlc: F,
}
impl<F: Field> FinalClaim<F> {
    /// Checks this claim against the caller's linear forms.
    ///
    /// `linear_forms` must be supplied in the same order as `rlc_coefficients`.
    /// Each form's multilinear extension is evaluated at `evaluation_point` and
    /// weighted by its coefficient. The weighted sum must equal
    /// `linear_form_rlc`.
    ///
    /// # Errors
    ///
    /// Fails through the `verify!` macro when the recomputed combination differs
    /// from `linear_form_rlc`. NOTE(review): under the `verifier_panics` feature
    /// this presumably panics instead of returning an error. A count mismatch
    /// between coefficients and forms is handled by `zip_strict`; confirm its
    /// behavior in `utils`.
    pub fn verify<'a>(
        &'a self,
        linear_forms: impl IntoIterator<Item = &'a dyn LinearForm<F>>,
    ) -> VerificationResult<()> {
        let point = &self.evaluation_point;
        let recomputed: F = zip_strict(&self.rlc_coefficients, linear_forms)
            .map(|(coefficient, form)| *coefficient * form.mle_evaluate(point))
            .sum();
        verify!(recomputed == self.linear_form_rlc);
        Ok(())
    }
}
impl<M> Config<M>
where
    M: Embedding,
    M::Source: FftField,
    M::Target: FftField,
{
    /// Commits to `vectors` with the initial IRS committer.
    ///
    /// Delegates to `self.initial_committer.commit`, passing `prover_state`
    /// through. Returns the resulting [`Witness`], which the prover needs later.
    ///
    /// With the `tracing` feature, the span records the length of the first
    /// vector as `size`. An empty `vectors` slice records `0` instead of
    /// panicking inside the instrumentation.
    #[cfg_attr(
        feature = "tracing",
        instrument(skip_all, fields(size = vectors.first().map_or(0, |vector| vector.len())))
    )]
    pub fn commit<H, R>(
        &self,
        prover_state: &mut ProverState<H, R>,
        vectors: &[&[M::Source]],
    ) -> Witness<M::Target, M>
    where
        H: DuplexSpongeInterface,
        R: RngCore + CryptoRng,
        M::Target: Codec<[H::U]>,
        Hash: ProverMessage<[H::U]>,
    {
        self.initial_committer.commit(prover_state, vectors)
    }

    /// Reads a commitment (as produced by [`Config::commit`]) from the verifier
    /// transcript by delegating to `self.initial_committer.receive_commitment`.
    ///
    /// # Errors
    ///
    /// Propagates any verification error reported by the initial committer.
    pub fn receive_commitment<H>(
        &self,
        verifier_state: &mut VerifierState<H>,
    ) -> VerificationResult<Commitment<M::Target>>
    where
        H: DuplexSpongeInterface,
        M::Target: Codec<[H::U]>,
        Hash: ProverMessage<[H::U]>,
    {
        self.initial_committer.receive_commitment(verifier_state)
    }

    /// Test helper that sets every proof-of-work threshold to `u64::MAX`.
    ///
    /// The affected thresholds are:
    /// - the initial sumcheck's rounds;
    /// - the initial skip PoW;
    /// - each round's sumcheck rounds and round PoW;
    /// - the final sumcheck's rounds;
    /// - the final PoW.
    ///
    /// NOTE(review): presumably a maximal threshold makes grinding trivially
    /// satisfiable so tests run fast; confirm against `proof_of_work`.
    #[cfg(test)]
    pub(crate) fn disable_pow(&mut self) {
        self.initial_sumcheck.round_pow.threshold = u64::MAX;
        self.initial_skip_pow.threshold = u64::MAX;
        for round in &mut self.round_configs {
            round.sumcheck.round_pow.threshold = u64::MAX;
            round.pow.threshold = u64::MAX;
        }
        self.final_sumcheck.round_pow.threshold = u64::MAX;
        self.final_pow.threshold = u64::MAX;
    }
}
#[cfg(test)]
mod tests {
use std::borrow::Cow;
use ark_ff::{Field, UniformRand};
use super::*;
use crate::{
algebra::{
embedding::Basefield,
fields::{Field64, Field64_3},
linear_form::{Covector, Evaluate, LinearForm, MultilinearExtension},
MultilinearPoint,
},
hash,
parameters::ProtocolParameters,
transcript::{codecs::Empty, DomainSeparator, ProverState, VerifierState},
utils::test_serde,
};
/// Field of the committed test vectors.
type F = Field64;
/// Embedding target field for most tests (used via `Basefield<EF>`);
/// presumably a degree-3 extension of `Field64`, going by the name.
type EF = Field64_3;
/// Builds the prover-side linear forms for `points`.
///
/// Yields one [`MultilinearExtension`] per point, in order. When
/// `include_covector` is set, these are followed by a single [`Covector`] whose
/// entries are `0, 1, …, 2^num_variables - 1` lifted into `F`.
fn build_prove_forms<F: Field>(
    points: &[MultilinearPoint<F>],
    num_variables: usize,
    include_covector: bool,
) -> Vec<Box<dyn LinearForm<F>>> {
    let trailing_covector = include_covector.then(|| -> Box<dyn LinearForm<F>> {
        Box::new(Covector {
            vector: (0..1 << num_variables).map(F::from).collect(),
        })
    });
    points
        .iter()
        .map(|point| -> Box<dyn LinearForm<F>> {
            Box::new(MultilinearExtension {
                point: point.0.clone(),
            })
        })
        .chain(trailing_covector)
        .collect()
}
/// Runs a complete commit / prove / verify round trip on a single vector.
///
/// The committed vector is all ones, of length `2^num_variables`. The
/// constraints are one multilinear evaluation per random point, plus one
/// covector with entries `0, 1, …, 2^num_variables - 1`.
///
/// `pow_bits` is passed to `Config::new`. All proof-of-work thresholds are then
/// overridden by `Config::disable_pow`.
fn make_whir_things(
    num_variables: usize,
    initial_folding_factor: usize,
    folding_factor: usize,
    num_points: usize,
    unique_decoding: bool,
    pow_bits: usize,
) {
    let num_coeffs = 1 << num_variables;
    let mut rng = ark_std::test_rng();
    let whir_params = ProtocolParameters {
        security_level: 32,
        pow_bits,
        initial_folding_factor,
        folding_factor,
        unique_decoding,
        starting_log_inv_rate: 1,
        batch_size: 1,
        hash_id: hash::SHA2,
    };
    let mut params = Config::<Basefield<EF>>::new(1 << num_variables, &whir_params);
    params.disable_pow();
    eprintln!("{params}");
    test_serde(&params);

    let vector = vec![F::ONE; num_coeffs];
    let points: Vec<_> = (0..num_points)
        .map(|_| MultilinearPoint::rand(&mut rng, num_variables))
        .collect();

    // Verifier-side linear forms and their claimed values. These mirror the
    // prover-side forms produced by `build_prove_forms(.., true)` below.
    let mut linear_forms: Vec<Box<dyn LinearForm<EF>>> = Vec::new();
    let mut evaluations = Vec::new();
    for point in &points {
        let linear_form = MultilinearExtension {
            point: point.0.clone(),
        };
        evaluations.push(linear_form.evaluate(params.embedding(), &vector));
        linear_forms.push(Box::new(linear_form));
    }
    let covector = Covector {
        vector: (0..1 << num_variables).map(EF::from).collect(),
    };
    let sum = covector.evaluate(params.embedding(), &vector);
    linear_forms.push(Box::new(covector));
    evaluations.push(sum);

    // Prover: commit, then prove all constraints.
    let ds = DomainSeparator::protocol(&params)
        .session(&format!("Test at {}:{}", file!(), line!()))
        .instance(&Empty);
    let mut prover_state = ProverState::new_std(&ds);
    let witness = params.commit(&mut prover_state, &[&vector]);
    let prove_linear_forms = build_prove_forms(&points, num_variables, true);
    let _ = params.prove(
        &mut prover_state,
        vec![Cow::from(vector)],
        vec![Cow::Owned(witness)],
        prove_linear_forms,
        Cow::Borrowed(evaluations.as_slice()),
    );
    let proof = prover_state.proof();

    // Verifier: replay the transcript, then check the residual linear-form claim.
    let mut verifier_state = VerifierState::new_std(&ds, &proof);
    let commitment = params.receive_commitment(&mut verifier_state).unwrap();
    let final_claim = params
        .verify(&mut verifier_state, &[&commitment], &evaluations)
        .unwrap();
    final_claim
        .verify(
            linear_forms
                .iter()
                .map(|l| l.as_ref() as &dyn LinearForm<EF>),
        )
        .unwrap();
}
/// Sweeps the single-vector round trip over equal initial/subsequent folding
/// factors 1..=4, sizes `ff..=3*ff` variables, 0..=2 evaluation points, both
/// decoding modes, and several proof-of-work settings.
#[test]
fn test_whir_1() {
    for folding_factor in 1..=4 {
        for num_variable in folding_factor..=3 * folding_factor {
            for num_points in 0..=2 {
                for unique_decoding in [true, false] {
                    for pow_bits in [0, 5, 10] {
                        eprintln!();
                        dbg!(
                            folding_factor,
                            num_variable,
                            num_points,
                            unique_decoding,
                            pow_bits
                        );
                        make_whir_things(
                            num_variable,
                            folding_factor,
                            folding_factor,
                            num_points,
                            unique_decoding,
                            pow_bits,
                        );
                    }
                }
            }
        }
    }
}
/// Fixed single-vector case: 3 variables, folding factor 2 throughout, no
/// evaluation points, `unique_decoding = false`, `pow_bits = 0`.
#[test]
fn test_fail() {
    let folding_factor = 2;
    make_whir_things(3, folding_factor, folding_factor, 0, false, 0);
}
/// Sweeps the single-vector round trip over every pair of *distinct* initial and
/// subsequent folding factors in 1..=4.
#[test]
fn test_whir_mixed_folding_factors() {
    let folding_factors: [usize; 4] = [1, 2, 3, 4];
    for initial_folding_factor in folding_factors {
        // Equal factors are covered by `test_whir_1`; only mixed pairs here.
        let distinct_factors = folding_factors
            .iter()
            .copied()
            .filter(|&candidate| candidate != initial_folding_factor);
        for folding_factor in distinct_factors {
            let n = initial_folding_factor.max(folding_factor);
            for num_variable in n..=3 * n {
                for num_points in [0, 1, 2] {
                    eprintln!();
                    dbg!(
                        initial_folding_factor,
                        folding_factor,
                        num_variable,
                        num_points,
                    );
                    make_whir_things(
                        num_variable,
                        initial_folding_factor,
                        folding_factor,
                        num_points,
                        false,
                        5,
                    );
                }
            }
        }
    }
}
/// Round trip with `num_vectors` separately committed vectors (one commitment each).
///
/// Vector `i` is the constant vector with value `i + 1`. Every linear form (one
/// per random point, plus one covector) is evaluated on every vector. The
/// evaluations are laid out form-major: all vectors for the first form, then
/// all vectors for the next form, and so on.
fn make_whir_batch_things(
    num_variables: usize,
    initial_folding_factor: usize,
    folding_factor: usize,
    num_points_per_poly: usize,
    num_vectors: usize,
    unique_decoding: bool,
    pow_bits: usize,
) {
    let num_coeffs = 1 << num_variables;
    let mut rng = ark_std::test_rng();
    let whir_params = ProtocolParameters {
        security_level: 32,
        pow_bits,
        initial_folding_factor,
        folding_factor,
        unique_decoding,
        starting_log_inv_rate: 1,
        batch_size: 1,
        hash_id: hash::SHA2,
    };
    let mut params = Config::new(1 << num_variables, &whir_params);
    params.disable_pow();
    eprintln!("{params}");

    let vectors: Vec<_> = (0..num_vectors)
        .map(|i| vec![F::from((i + 1) as u64); num_coeffs])
        .collect();
    let vec_refs = vectors.iter().map(|v| v.as_slice()).collect::<Vec<_>>();
    let points: Vec<_> = (0..num_points_per_poly)
        .map(|_| MultilinearPoint::rand(&mut rng, num_variables))
        .collect();

    let mut linear_forms: Vec<Box<dyn Evaluate<Basefield<EF>>>> = Vec::new();
    for point in &points {
        linear_forms.push(Box::new(MultilinearExtension {
            point: point.0.clone(),
        }));
    }
    linear_forms.push(Box::new(Covector {
        vector: (0..1 << num_variables).map(EF::from).collect(),
    }));
    // Form-major layout: for each form, its evaluation on every vector.
    let evaluations = linear_forms
        .iter()
        .flat_map(|linear_form| {
            vec_refs
                .iter()
                .map(|vec| linear_form.evaluate(params.embedding(), vec))
        })
        .collect::<Vec<_>>();

    // Prover: one commitment per vector, then a single batched proof.
    let ds = DomainSeparator::protocol(&params)
        .session(&format!("Test at {}:{}", file!(), line!()))
        .instance(&Empty);
    let mut prover_state = ProverState::new_std(&ds);
    let mut witnesses = Vec::new();
    for &vec in &vec_refs {
        let witness = params.commit(&mut prover_state, &[vec]);
        witnesses.push(witness);
    }
    let prove_linear_forms = build_prove_forms(&points, num_variables, true);
    let _ = params.prove(
        &mut prover_state,
        vectors
            .iter()
            .map(|v| Cow::Borrowed(v.as_slice()))
            .collect(),
        witnesses.into_iter().map(Cow::Owned).collect(),
        prove_linear_forms,
        Cow::Borrowed(evaluations.as_slice()),
    );
    let proof = prover_state.proof();

    // Verifier: read the commitments in the same order, then verify.
    let mut verifier_state = VerifierState::new_std(&ds, &proof);
    let mut commitments = Vec::new();
    for _ in 0..num_vectors {
        let commitment = params.receive_commitment(&mut verifier_state).unwrap();
        commitments.push(commitment);
    }
    let commitment_refs = commitments.iter().collect::<Vec<_>>();
    let final_claim = params
        .verify(&mut verifier_state, &commitment_refs, &evaluations)
        .unwrap();
    final_claim
        .verify(
            linear_forms
                .iter()
                .map(|l| l.as_ref() as &dyn LinearForm<EF>),
        )
        .unwrap();
}
/// Sweeps the multi-commitment round trip over all folding-factor pairs in 1..=4,
/// 2..=4 committed vectors, and 0..=2 evaluation points per vector.
#[test]
fn test_whir_batch_1() {
    for initial_folding_factor in 1_usize..=4 {
        for folding_factor in 1_usize..=4 {
            let max_num_variables = 3 * initial_folding_factor.max(folding_factor);
            for num_variables in (initial_folding_factor + folding_factor)..=max_num_variables {
                for num_polys in 2..=4 {
                    for num_points_per_poly in 0..=2 {
                        eprintln!();
                        dbg!(
                            initial_folding_factor,
                            folding_factor,
                            num_variables,
                            num_polys,
                            num_points_per_poly,
                        );
                        make_whir_batch_things(
                            num_variables,
                            initial_folding_factor,
                            folding_factor,
                            num_points_per_poly,
                            num_polys,
                            false,
                            0,
                        );
                    }
                }
            }
        }
    }
}
/// The multi-commitment helper with just one committed vector.
#[test]
fn test_whir_batch_single_polynomial() {
    let (num_variables, folding_factor) = (6, 2);
    let (num_points_per_poly, num_vectors) = (2, 1);
    make_whir_batch_things(
        num_variables,
        folding_factor,
        folding_factor,
        num_points_per_poly,
        num_vectors,
        false,
        0,
    );
}
/// Negative test: the prover commits to `vec1` and `vec2`, but then proves (and
/// claims evaluations for) `vec1` and `vec_wrong`.
///
/// Transcript verification is unwrapped below and so is expected to pass. The
/// final linear-form check must reject. NOTE(review): under the
/// `verifier_panics` feature the rejection presumably surfaces as a panic, hence
/// `should_panic`. In debug builds a prover-side `debug_assert` fires first, so
/// the test is ignored there.
#[test]
#[cfg_attr(feature = "verifier_panics", should_panic)]
#[cfg_attr(
    debug_assertions,
    ignore = "debug_assert in prover panics on intentionally invalid input"
)]
fn test_whir_batch_rejects_invalid_constraint() {
    let num_variables = 4;
    let initial_folding_factor = 2;
    let folding_factor = 2;
    let num_polynomials = 2;
    let num_coeffs = 1 << num_variables;
    let mut rng = ark_std::test_rng();
    let whir_params = ProtocolParameters {
        security_level: 32,
        pow_bits: 0,
        initial_folding_factor,
        folding_factor,
        unique_decoding: false,
        starting_log_inv_rate: 1,
        batch_size: 1,
        hash_id: hash::SHA2,
    };
    let mut params = Config::<Basefield<EF>>::new(1 << num_variables, &whir_params);
    params.disable_pow();

    let vec1 = vec![F::ONE; num_coeffs];
    let vec2 = vec![F::from(2u64); num_coeffs];
    let vec_wrong = vec![F::from(999u64); num_coeffs];
    let constraint_points: Vec<_> = (0..2)
        .map(|_| MultilinearPoint::rand(&mut rng, num_variables))
        .collect();
    let linear_forms: [Box<dyn Evaluate<Basefield<EF>>>; 2] = [
        Box::new(MultilinearExtension {
            point: constraint_points[0].0.clone(),
        }),
        Box::new(MultilinearExtension {
            point: constraint_points[1].0.clone(),
        }),
    ];
    // Claimed evaluations are computed on `vec_wrong`, not on the committed `vec2`.
    let evaluations = linear_forms
        .iter()
        .flat_map(|weights| {
            [&vec1, &vec_wrong].map(|v| weights.evaluate(params.embedding(), v))
        })
        .collect::<Vec<_>>();

    let ds = DomainSeparator::protocol(&params)
        .session(&format!("Test at {}:{}", file!(), line!()))
        .instance(&Empty);
    let mut prover_state = ProverState::new_std(&ds);
    let witness1 = params.commit(&mut prover_state, &[&vec1]);
    let witness2 = params.commit(&mut prover_state, &[&vec2]);
    let prove_linear_forms = build_prove_forms(&constraint_points, num_variables, false);
    let _ = params.prove(
        &mut prover_state,
        vec![Cow::Borrowed(vec1.as_slice()), Cow::from(vec_wrong)],
        vec![Cow::Owned(witness1), Cow::Owned(witness2)],
        prove_linear_forms,
        Cow::Borrowed(evaluations.as_slice()),
    );
    let proof = prover_state.proof();

    let mut verifier_state = VerifierState::new_std(&ds, &proof);
    let mut commitments = Vec::new();
    for _ in 0..num_polynomials {
        let parsed_commitment = params.receive_commitment(&mut verifier_state).unwrap();
        commitments.push(parsed_commitment);
    }
    let final_claim = params
        .verify(
            &mut verifier_state,
            &[&commitments[0], &commitments[1]],
            &evaluations,
        )
        .unwrap();
    let verifier_result = final_claim.verify(
        linear_forms
            .iter()
            .map(|l| l.as_ref() as &dyn LinearForm<EF>),
    );
    assert!(
        verifier_result.is_err(),
        "Verifier should reject mismatched polynomial"
    );
}
/// Round trip with `num_witnesses` commitments, each covering `batch_size` vectors.
///
/// Vector `i` (over all witnesses) is the constant vector with value `i + 1`.
/// Commitment `k` covers the `k`-th consecutive chunk of `batch_size` vectors.
/// Evaluations are laid out form-major over all vectors.
#[allow(clippy::too_many_arguments)]
fn make_whir_batch_with_batch_size(
    num_variables: usize,
    initial_folding_factor: usize,
    folding_factor: usize,
    num_points_per_poly: usize,
    num_witnesses: usize,
    batch_size: usize,
    unique_decoding: bool,
    pow_bits: usize,
) {
    let num_coeffs = 1 << num_variables;
    let mut rng = ark_std::test_rng();
    let whir_params = ProtocolParameters {
        security_level: 32,
        pow_bits,
        initial_folding_factor,
        folding_factor,
        unique_decoding,
        starting_log_inv_rate: 1,
        batch_size,
        hash_id: hash::SHA2,
    };
    let mut params = Config::<Basefield<EF>>::new(1 << num_variables, &whir_params);
    params.disable_pow();

    let all_vectors: Vec<Vec<F>> = (0..num_witnesses * batch_size)
        .map(|i| vec![F::from((i + 1) as u64); num_coeffs])
        .collect::<Vec<_>>();
    let vec_refs = all_vectors.iter().map(|p| p.as_slice()).collect::<Vec<_>>();
    let points: Vec<_> = (0..num_points_per_poly)
        .map(|_| MultilinearPoint::rand(&mut rng, num_variables))
        .collect();

    let mut linear_forms: Vec<Box<dyn Evaluate<Basefield<EF>>>> = Vec::new();
    for point in &points {
        linear_forms.push(Box::new(MultilinearExtension {
            point: point.0.clone(),
        }));
    }
    linear_forms.push(Box::new(Covector {
        vector: (0..1 << num_variables).map(EF::from).collect(),
    }));
    let evaluations = linear_forms
        .iter()
        .flat_map(|linear_form| {
            vec_refs
                .iter()
                .map(|vec| linear_form.evaluate(params.embedding(), vec))
        })
        .collect::<Vec<_>>();

    // Prover: one commitment per chunk of `batch_size` vectors.
    let ds = DomainSeparator::protocol(&params)
        .session(&format!("Test at {}:{}", file!(), line!()))
        .instance(&Empty);
    let mut prover_state = ProverState::new_std(&ds);
    let mut witnesses = Vec::new();
    for witness_polys in vec_refs.chunks(batch_size) {
        let witness = params.commit(&mut prover_state, witness_polys);
        witnesses.push(witness);
    }
    let prove_linear_forms = build_prove_forms(&points, num_variables, true);
    let _ = params.prove(
        &mut prover_state,
        all_vectors
            .iter()
            .map(|v| Cow::Borrowed(v.as_slice()))
            .collect(),
        witnesses.into_iter().map(Cow::Owned).collect(),
        prove_linear_forms,
        Cow::Borrowed(evaluations.as_slice()),
    );
    let proof = prover_state.proof();

    // Verifier: one commitment per witness, then verify.
    let mut verifier_state = VerifierState::new_std(&ds, &proof);
    let mut commitments = Vec::new();
    for _ in 0..num_witnesses {
        let commitment = params.receive_commitment(&mut verifier_state).unwrap();
        commitments.push(commitment);
    }
    let commitment_refs = commitments.iter().collect::<Vec<_>>();
    let final_claim = params
        .verify(&mut verifier_state, &commitment_refs, &evaluations)
        .unwrap();
    final_claim
        .verify(
            linear_forms
                .iter()
                .map(|l| l.as_ref() as &dyn LinearForm<EF>),
        )
        .unwrap();
}
/// Multiple commitments, each batching several vectors: batch sizes 2..=3,
/// 2..=3 commitments, folding factors 2..=3 with `2 * ff` variables, and one
/// evaluation point per vector.
#[test]
fn test_whir_batch_with_batch_size_2() {
    for batch_size in 2..=3 {
        for witness_count in 2..=3 {
            for factor in 2..=3 {
                let num_vars = 2 * factor;
                make_whir_batch_with_batch_size(
                    num_vars,
                    factor,
                    factor,
                    1,
                    witness_count,
                    batch_size,
                    false,
                    0,
                );
            }
        }
    }
}
/// Samples `num_coefficients` uniformly random elements of `F` from the thread RNG.
fn random_vector(num_coefficients: usize) -> Vec<F> {
    let mut rng = ark_std::rand::thread_rng();
    (0..num_coefficients)
        .map(|_| F::rand(&mut rng))
        .collect()
}
/// Round trip with a single commitment over `batch_size` random vectors.
///
/// The linear forms (one per random point, plus one covector) are over the base
/// field `F`. Their evaluations are laid out form-major over the batched vectors.
fn make_batched_whir_things(
    batch_size: usize,
    num_variables: usize,
    initial_folding_factor: usize,
    folding_factor: usize,
    num_points: usize,
    unique_decoding: bool,
    pow_bits: usize,
) {
    eprintln!("\n---------------------");
    eprintln!("Test parameters: ");
    eprintln!(" num_vectors : {batch_size}");
    eprintln!(" num_variables : {num_variables}");
    eprintln!(" initial_folding : {initial_folding_factor}");
    eprintln!(" folding_factor : {folding_factor}");
    eprintln!(" num_points : {num_points:?}");
    eprintln!(" unique_decoding : {unique_decoding:?}");
    eprintln!(" pow_bits : {pow_bits}");
    let num_coeffs = 1 << num_variables;
    let mut rng = ark_std::test_rng();
    let whir_params = ProtocolParameters {
        security_level: 32,
        pow_bits,
        initial_folding_factor,
        folding_factor,
        unique_decoding,
        starting_log_inv_rate: 1,
        batch_size,
        hash_id: hash::SHA2,
    };
    let mut params = Config::new(1 << num_variables, &whir_params);
    params.disable_pow();

    let vectors: Vec<Vec<F>> = (0..batch_size).map(|_| random_vector(num_coeffs)).collect();
    let vec_refs = vectors.iter().map(|v| v.as_slice()).collect::<Vec<_>>();
    let points: Vec<_> = (0..num_points)
        .map(|_| MultilinearPoint::rand(&mut rng, num_variables))
        .collect();

    // Prover: a single commitment covering the whole batch.
    let ds = DomainSeparator::protocol(&params)
        .session(&format!("Test at {}:{}", file!(), line!()))
        .instance(&Empty);
    let mut prover_state = ProverState::new_std(&ds);
    let batched_witness = params.commit(&mut prover_state, &vec_refs);

    let mut linear_forms: Vec<Box<dyn Evaluate<Basefield<F>>>> = Vec::new();
    for point in &points {
        linear_forms.push(Box::new(MultilinearExtension {
            point: point.0.clone(),
        }));
    }
    linear_forms.push(Box::new(Covector {
        vector: (0..1 << num_variables).map(F::from).collect(),
    }));
    let values = linear_forms
        .iter()
        .flat_map(|linear_form| {
            vec_refs
                .iter()
                .map(|vec| linear_form.evaluate(params.embedding(), vec))
        })
        .collect::<Vec<_>>();
    let prove_linear_forms = build_prove_forms(&points, num_variables, true);
    let weights_dyn_refs = linear_forms
        .iter()
        .map(|w| w.as_ref() as &dyn LinearForm<F>)
        .collect::<Vec<_>>();
    let _ = params.prove(
        &mut prover_state,
        vectors
            .iter()
            .map(|v| Cow::Borrowed(v.as_slice()))
            .collect(),
        vec![Cow::Owned(batched_witness)],
        prove_linear_forms,
        Cow::Borrowed(values.as_slice()),
    );
    let proof = prover_state.proof();

    // Verifier: read the single commitment, verify, then check the final claim.
    let mut verifier_state = VerifierState::new_std(&ds, &proof);
    let commitment = params.receive_commitment(&mut verifier_state).unwrap();
    params
        .verify(&mut verifier_state, &[&commitment], &values)
        .unwrap()
        .verify(weights_dyn_refs)
        .unwrap();
}
/// Sweeps the single-commitment batched round trip over:
/// - folding factors 1 and 4, with `2*ff..=3*ff` variables;
/// - 0 or 2 evaluation points;
/// - both decoding modes;
/// - `pow_bits` of 0 or 10;
/// - batch sizes 1..=4.
#[test]
fn test_batched_whir() {
    for factor in [1, 4] {
        for num_vars in (2 * factor)..=3 * factor {
            for point_count in [0, 2] {
                for unique in [false, true] {
                    for bits in [0, 10] {
                        for batch_size in 1..=4 {
                            make_batched_whir_things(
                                batch_size,
                                num_vars,
                                factor,
                                factor,
                                point_count,
                                unique,
                                bits,
                            );
                        }
                    }
                }
            }
        }
    }
}
}