use crate::{
data_structures::finalize,
ligero::ligero_commit,
sumcheck_polys::induce_sumcheck_poly,
transcript::{FiatShamir, Transcript},
utils::{eval_sk_at_vks, partial_eval_multilinear},
FinalLigeroProof, FinalizedLigeritoProof, LigeritoProof, ProverConfig,
RecursiveLigeroCommitment, RecursiveLigeroProof, SumcheckTranscript,
};
use binary_fields::BinaryFieldElement;
#[cfg(feature = "parallel")]
use crate::sumcheck_polys::induce_sumcheck_poly_parallel;
/// Induces the next sumcheck basis polynomial using the parallel implementation
/// (selected because the `parallel` feature is enabled).
///
/// This is a thin forwarding shim: `T` appears only in the bounds, and the
/// argument slices are already typed over `U`, so the callee's own element type
/// is inferred as `U`. Returns whatever [`induce_sumcheck_poly_parallel`]
/// returns; call sites destructure it as `(basis_poly, enforced_sum)`.
#[cfg(feature = "parallel")]
#[inline(always)]
fn induce_sumcheck_poly_auto<
    T: BinaryFieldElement + Send + Sync,
    U: BinaryFieldElement + Send + Sync + From<T>,
>(
    n: usize,
    sks_vks: &[U],
    opened_rows: &[Vec<U>],
    v_challenges: &[U],
    sorted_queries: &[usize],
    alpha: U,
) -> (Vec<U>, U) {
    induce_sumcheck_poly_parallel(n, sks_vks, opened_rows, v_challenges, sorted_queries, alpha)
}
/// Induces the next sumcheck basis polynomial using the sequential
/// implementation (selected because the `parallel` feature is disabled).
///
/// Same signature as the parallel variant so call sites need no `cfg` of their
/// own. `T` appears only in the bounds; the callee's element type is inferred
/// from the `U`-typed slices. Returns `(basis_poly, enforced_sum)` as produced by
/// [`induce_sumcheck_poly`].
#[cfg(not(feature = "parallel"))]
#[inline(always)]
fn induce_sumcheck_poly_auto<T: BinaryFieldElement, U: BinaryFieldElement + From<T>>(
    n: usize,
    sks_vks: &[U],
    opened_rows: &[Vec<U>],
    v_challenges: &[U],
    sorted_queries: &[usize],
    alpha: U,
) -> (Vec<U>, U) {
    induce_sumcheck_poly(n, sks_vks, opened_rows, v_challenges, sorted_queries, alpha)
}
fn prove_core<T, U>(
config: &ProverConfig<T, U>,
poly: &[T],
wtns_0: crate::data_structures::RecursiveLigeroWitness<T>,
cm_0: RecursiveLigeroCommitment,
fs: &mut impl Transcript,
) -> crate::Result<FinalizedLigeritoProof<T, U>>
where
T: BinaryFieldElement + Send + Sync + bytemuck::Pod + 'static,
U: BinaryFieldElement + Send + Sync + From<T> + bytemuck::Pod + 'static,
{
let mut proof = LigeritoProof::<T, U>::new();
proof.initial_ligero_cm = Some(cm_0);
let partial_evals_0: Vec<T> = (0..config.initial_k).map(|_| fs.get_challenge()).collect();
let mut f_evals = poly.to_vec();
partial_eval_multilinear(&mut f_evals, &partial_evals_0);
let partial_evals_0_u: Vec<U> = partial_evals_0.iter().map(|&x| U::from(x)).collect();
let f_evals_u: Vec<U> = f_evals.iter().map(|&x| U::from(x)).collect();
let wtns_1 = ligero_commit(
&f_evals_u,
config.dims[0].0,
config.dims[0].1,
&config.reed_solomon_codes[0],
);
let cm_1 = RecursiveLigeroCommitment {
root: wtns_1.tree.get_root(),
};
proof.recursive_commitments.push(cm_1.clone());
fs.absorb_root(&cm_1.root);
let rows = wtns_0.mat.len();
let queries = fs.get_distinct_queries(rows, config.num_queries); let alpha = fs.get_challenge::<U>();
let n = f_evals.len().trailing_zeros() as usize;
let sks_vks: Vec<T> = eval_sk_at_vks(1 << n);
let opened_rows: Vec<Vec<T>> = queries.iter().map(|&q| wtns_0.mat[q].clone()).collect();
let mtree_proof = wtns_0.tree.prove(&queries); proof.initial_ligero_proof = Some(RecursiveLigeroProof {
opened_rows: opened_rows.clone(),
merkle_proof: mtree_proof,
});
let (basis_poly, enforced_sum) = induce_sumcheck_poly(
n,
&sks_vks,
&opened_rows,
&partial_evals_0_u,
&queries,
alpha,
);
let mut sumcheck_transcript = vec![];
let mut current_poly = basis_poly;
let mut current_sum = enforced_sum;
fs.absorb_elem(current_sum);
let mut wtns_prev = wtns_1;
for i in 0..config.recursive_steps {
let mut rs = Vec::new();
for _ in 0..config.ks[i] {
let coeffs = compute_sumcheck_coefficients(¤t_poly);
sumcheck_transcript.push(coeffs);
let ri = fs.get_challenge::<U>();
rs.push(ri);
current_poly = fold_polynomial_with_challenge(¤t_poly, ri);
current_sum = evaluate_quadratic(coeffs, ri);
fs.absorb_elem(current_sum);
}
if i == config.recursive_steps - 1 {
fs.absorb_elems(¤t_poly);
let rows = wtns_prev.mat.len();
let queries = fs.get_distinct_queries(rows, config.num_queries);
let opened_rows: Vec<Vec<U>> =
queries.iter().map(|&q| wtns_prev.mat[q].clone()).collect();
let mtree_proof = wtns_prev.tree.prove(&queries);
proof.final_ligero_proof = Some(FinalLigeroProof {
yr: current_poly.clone(),
opened_rows,
merkle_proof: mtree_proof,
});
proof.sumcheck_transcript = Some(SumcheckTranscript {
transcript: sumcheck_transcript,
});
return finalize(proof);
}
let wtns_next = ligero_commit(
¤t_poly,
config.dims[i + 1].0,
config.dims[i + 1].1,
&config.reed_solomon_codes[i + 1],
);
let cm_next = RecursiveLigeroCommitment {
root: wtns_next.tree.get_root(),
};
proof.recursive_commitments.push(cm_next.clone());
fs.absorb_root(&cm_next.root);
let rows = wtns_prev.mat.len();
let queries = fs.get_distinct_queries(rows, config.num_queries); let alpha = fs.get_challenge::<U>();
let opened_rows: Vec<Vec<U>> = queries.iter().map(|&q| wtns_prev.mat[q].clone()).collect();
let mtree_proof = wtns_prev.tree.prove(&queries); proof.recursive_proofs.push(RecursiveLigeroProof {
opened_rows: opened_rows.clone(),
merkle_proof: mtree_proof,
});
let n = current_poly.len().trailing_zeros() as usize;
let sks_vks: Vec<U> = eval_sk_at_vks(1 << n);
let (basis_poly, enforced_sum) =
induce_sumcheck_poly_auto::<U, U>(n, &sks_vks, &opened_rows, &rs, &queries, alpha);
let glue_sum = current_sum.add(&enforced_sum);
fs.absorb_elem(glue_sum);
let beta = fs.get_challenge::<U>();
current_poly = glue_polynomials(¤t_poly, &basis_poly, beta);
current_sum = glue_sums(current_sum, enforced_sum, beta);
wtns_prev = wtns_next;
}
unreachable!("Should have returned in final round");
}
/// Commits to `poly` and produces a finalized Ligerito proof driven by `fs`.
///
/// The configuration is validated first. The root of the initial Ligero
/// commitment is then absorbed into the transcript before any challenge is
/// drawn, and the rest of the protocol is delegated to `prove_core`.
///
/// # Errors
/// Returns an error if `config.validate()` fails or if proof finalization fails.
pub fn prove_with_transcript<T, U>(
    config: &ProverConfig<T, U>,
    poly: &[T],
    mut fs: impl Transcript,
) -> crate::Result<FinalizedLigeritoProof<T, U>>
where
    T: BinaryFieldElement + Send + Sync + bytemuck::Pod + 'static,
    U: BinaryFieldElement + Send + Sync + From<T> + bytemuck::Pod + 'static,
{
    config.validate()?;

    let (witness, commitment) = {
        let witness = ligero_commit(
            poly,
            config.initial_dims.0,
            config.initial_dims.1,
            &config.initial_reed_solomon,
        );
        let root = witness.tree.get_root();
        (witness, RecursiveLigeroCommitment { root })
    };

    // Bind the transcript to the commitment before any challenge is derived.
    fs.absorb_root(&commitment.root);
    prove_core(config, poly, witness, commitment, &mut fs)
}
/// Like `prove_with_transcript`, but also proves the point-evaluation `claims`.
///
/// After committing to `poly` and absorbing the root, one challenge is drawn
/// per claim. An evaluation sumcheck is then run on the same transcript, and
/// only afterwards does the main Ligerito proof start. The evaluation-sumcheck
/// rounds are stored in `proof.eval_rounds`.
///
/// NOTE(review): the variable count is `poly.len().trailing_zeros()`, which
/// assumes `poly.len()` is a power of two — confirm this is enforced upstream.
///
/// # Errors
/// Returns an error if `config.validate()` fails or if proof finalization fails.
pub fn prove_with_evaluations<T, U>(
    config: &ProverConfig<T, U>,
    poly: &[T],
    claims: &[crate::eval_proof::EvalClaim<T>],
    mut fs: impl Transcript,
) -> crate::Result<FinalizedLigeritoProof<T, U>>
where
    T: BinaryFieldElement + Send + Sync + bytemuck::Pod + 'static,
    U: BinaryFieldElement + Send + Sync + From<T> + bytemuck::Pod + 'static,
{
    config.validate()?;

    let initial_witness = ligero_commit(
        poly,
        config.initial_dims.0,
        config.initial_dims.1,
        &config.initial_reed_solomon,
    );
    let initial_commitment = RecursiveLigeroCommitment {
        root: initial_witness.tree.get_root(),
    };
    fs.absorb_root(&initial_commitment.root);

    // One transcript challenge per claim, drawn after the commitment is bound.
    let num_vars = poly.len().trailing_zeros() as usize;
    let alphas: Vec<U> = claims.iter().map(|_| fs.get_challenge()).collect();
    let (eval_rounds, _, _) = crate::eval_proof::eval_sumcheck_prove::<T, U>(
        poly, claims, &alphas, num_vars, &mut fs,
    );

    prove_core(config, poly, initial_witness, initial_commitment, &mut fs).map(|mut proof| {
        proof.eval_rounds = eval_rounds;
        proof
    })
}
/// Produces a Ligerito proof for `poly` using the crate's default transcript.
///
/// With the `transcript-merlin` feature this is a Merlin transcript; otherwise
/// it is a SHA-256 transcript seeded with `0`. A verifier must replay the same
/// transcript construction.
///
/// # Errors
/// Returns an error if `config.validate()` fails or if proof finalization fails.
pub fn prove<T, U>(
    config: &ProverConfig<T, U>,
    poly: &[T],
) -> crate::Result<FinalizedLigeritoProof<T, U>>
where
    T: BinaryFieldElement + Send + Sync + bytemuck::Pod + 'static,
    U: BinaryFieldElement + Send + Sync + From<T> + bytemuck::Pod + 'static,
{
    #[cfg(feature = "transcript-merlin")]
    let transcript = FiatShamir::new_merlin();
    #[cfg(not(feature = "transcript-merlin"))]
    let transcript = FiatShamir::new_sha256(0);

    prove_with_transcript(config, poly, transcript)
}
/// Produces a Ligerito proof using a SHA-256 transcript seeded with `1234`.
///
/// Note that this seed differs from the `0` seed used by `prove` when Merlin is
/// disabled, so a verifier must use the matching seed.
///
/// # Errors
/// Returns an error if `config.validate()` fails or if proof finalization fails.
pub fn prove_sha256<T, U>(
    config: &ProverConfig<T, U>,
    poly: &[T],
) -> crate::Result<FinalizedLigeritoProof<T, U>>
where
    T: BinaryFieldElement + Send + Sync + bytemuck::Pod + 'static,
    U: BinaryFieldElement + Send + Sync + From<T> + bytemuck::Pod + 'static,
{
    prove_with_transcript(config, poly, FiatShamir::new_sha256(1234))
}
/// Produces a Ligerito proof using a BLAKE2b Fiat-Shamir transcript.
///
/// Only compiled with the `transcript-blake2b` feature.
///
/// # Errors
/// Returns an error if `config.validate()` fails or if proof finalization fails.
#[cfg(feature = "transcript-blake2b")]
pub fn prove_blake2b<T, U>(
    config: &ProverConfig<T, U>,
    poly: &[T],
) -> crate::Result<FinalizedLigeritoProof<T, U>>
where
    T: BinaryFieldElement + Send + Sync + bytemuck::Pod + 'static,
    U: BinaryFieldElement + Send + Sync + From<T> + bytemuck::Pod + 'static,
{
    prove_with_transcript(config, poly, FiatShamir::new_blake2b())
}
pub fn prove_debug<T, U>(
config: &ProverConfig<T, U>,
poly: &[T],
) -> crate::Result<FinalizedLigeritoProof<T, U>>
where
T: BinaryFieldElement + Send + Sync + bytemuck::Pod + 'static,
U: BinaryFieldElement + Send + Sync + From<T> + bytemuck::Pod + 'static,
{
println!("\n=== PROVER DEBUG ===");
#[cfg(feature = "transcript-merlin")]
let mut fs = FiatShamir::new_merlin();
#[cfg(not(feature = "transcript-merlin"))]
let mut fs = FiatShamir::new_sha256(0);
let mut proof = LigeritoProof::<T, U>::new();
println!("Creating initial commitment...");
let wtns_0 = ligero_commit(
poly,
config.initial_dims.0,
config.initial_dims.1,
&config.initial_reed_solomon,
);
let cm_0 = RecursiveLigeroCommitment {
root: wtns_0.tree.get_root(),
};
proof.initial_ligero_cm = Some(cm_0.clone());
fs.absorb_root(&cm_0.root);
println!("Initial commitment root: {:?}", cm_0.root);
let partial_evals_0: Vec<T> = (0..config.initial_k)
.map(|i| {
let challenge = fs.get_challenge();
println!("Initial challenge {}: {:?}", i, challenge);
challenge
})
.collect();
println!("\nPerforming partial evaluation...");
let mut f_evals = poly.to_vec();
partial_eval_multilinear(&mut f_evals, &partial_evals_0);
println!("Partial eval complete, new size: {}", f_evals.len());
let partial_evals_0_u: Vec<U> = partial_evals_0.iter().map(|&x| U::from(x)).collect();
let f_evals_u: Vec<U> = f_evals.iter().map(|&x| U::from(x)).collect();
println!("\nFirst recursive step...");
let wtns_1 = ligero_commit(
&f_evals_u,
config.dims[0].0,
config.dims[0].1,
&config.reed_solomon_codes[0],
);
let cm_1 = RecursiveLigeroCommitment {
root: wtns_1.tree.get_root(),
};
proof.recursive_commitments.push(cm_1.clone());
fs.absorb_root(&cm_1.root);
let rows = wtns_0.mat.len();
println!("\nSelecting queries from {} rows...", rows);
let queries = fs.get_distinct_queries(rows, config.num_queries);
println!(
"Selected queries (0-based): {:?}",
&queries[..queries.len().min(5)]
);
let alpha = fs.get_challenge::<U>();
println!("Alpha challenge: {:?}", alpha);
let n = f_evals.len().trailing_zeros() as usize;
println!("\nPreparing sumcheck, n = {}", n);
let sks_vks: Vec<T> = eval_sk_at_vks(1 << n);
let opened_rows: Vec<Vec<T>> = queries.iter().map(|&q| wtns_0.mat[q].clone()).collect();
let mtree_proof = wtns_0.tree.prove(&queries);
proof.initial_ligero_proof = Some(RecursiveLigeroProof {
opened_rows: opened_rows.clone(),
merkle_proof: mtree_proof,
});
println!("\nInducing sumcheck polynomial...");
let (basis_poly, enforced_sum) = induce_sumcheck_poly(
n,
&sks_vks,
&opened_rows,
&partial_evals_0_u,
&queries,
alpha,
);
println!("Enforced sum: {:?}", enforced_sum);
let mut sumcheck_transcript = vec![];
let mut current_poly = basis_poly;
let mut current_sum = enforced_sum;
fs.absorb_elem(current_sum);
let mut wtns_prev = wtns_1;
for i in 0..config.recursive_steps {
println!(
"\n--- Recursive step {}/{} ---",
i + 1,
config.recursive_steps
);
let mut rs = Vec::new();
for j in 0..config.ks[i] {
let coeffs = compute_sumcheck_coefficients(¤t_poly);
println!(" Round {}: coeffs = {:?}", j, coeffs);
sumcheck_transcript.push(coeffs);
let ri = fs.get_challenge::<U>();
println!(" Challenge: {:?}", ri);
rs.push(ri);
current_poly = fold_polynomial_with_challenge(¤t_poly, ri);
current_sum = evaluate_quadratic(coeffs, ri);
println!(" New sum: {:?}", current_sum);
fs.absorb_elem(current_sum);
}
if i == config.recursive_steps - 1 {
println!("\nFinal round - creating proof...");
fs.absorb_elems(¤t_poly);
let rows = wtns_prev.mat.len();
let queries = fs.get_distinct_queries(rows, config.num_queries);
let opened_rows: Vec<Vec<U>> =
queries.iter().map(|&q| wtns_prev.mat[q].clone()).collect();
let mtree_proof = wtns_prev.tree.prove(&queries);
proof.final_ligero_proof = Some(FinalLigeroProof {
yr: current_poly.clone(),
opened_rows,
merkle_proof: mtree_proof,
});
proof.sumcheck_transcript = Some(SumcheckTranscript {
transcript: sumcheck_transcript,
});
println!("Proof generation complete!");
return finalize(proof);
}
println!("\nContinuing recursion...");
let wtns_next = ligero_commit(
¤t_poly,
config.dims[i + 1].0,
config.dims[i + 1].1,
&config.reed_solomon_codes[i + 1],
);
let cm_next = RecursiveLigeroCommitment {
root: wtns_next.tree.get_root(),
};
proof.recursive_commitments.push(cm_next.clone());
fs.absorb_root(&cm_next.root);
let rows = wtns_prev.mat.len();
let queries = fs.get_distinct_queries(rows, config.num_queries);
let alpha = fs.get_challenge::<U>();
let opened_rows: Vec<Vec<U>> = queries.iter().map(|&q| wtns_prev.mat[q].clone()).collect();
let mtree_proof = wtns_prev.tree.prove(&queries);
proof.recursive_proofs.push(RecursiveLigeroProof {
opened_rows: opened_rows.clone(),
merkle_proof: mtree_proof,
});
let n = current_poly.len().trailing_zeros() as usize;
let sks_vks: Vec<U> = eval_sk_at_vks(1 << n);
println!("\nInducing next sumcheck polynomial...");
let (basis_poly, enforced_sum) =
induce_sumcheck_poly(n, &sks_vks, &opened_rows, &rs, &queries, alpha);
println!("Next enforced sum: {:?}", enforced_sum);
let glue_sum = current_sum.add(&enforced_sum);
fs.absorb_elem(glue_sum);
println!("Glue sum: {:?}", glue_sum);
let beta = fs.get_challenge::<U>();
println!("Beta challenge: {:?}", beta);
current_poly = glue_polynomials(¤t_poly, &basis_poly, beta);
current_sum = glue_sums(current_sum, enforced_sum, beta);
println!("Updated current sum: {:?}", current_sum);
wtns_prev = wtns_next;
}
unreachable!("Should have returned in final round");
}
/// Computes the sumcheck round message for the lowest variable of `poly`.
///
/// `poly` is read as consecutive `(lo, hi)` pairs (a trailing odd element is
/// ignored). Returns `(Σ lo, Σ (lo + hi), Σ hi)`: the round polynomial's value at
/// 0, its linear coefficient in characteristic 2, and its value at 1.
fn compute_sumcheck_coefficients<F: BinaryFieldElement>(poly: &[F]) -> (F, F, F) {
    let zero = F::zero();
    poly.chunks_exact(2)
        .fold((zero, zero, zero), |(at_zero, linear, at_one), pair| {
            let (lo, hi) = (pair[0], pair[1]);
            (
                at_zero.add(&lo),
                linear.add(&lo.add(&hi)),
                at_one.add(&hi),
            )
        })
}
/// Binds the lowest variable of `poly` to the challenge `r`.
///
/// Each consecutive `(lo, hi)` pair becomes `lo + r * (hi + lo)`, the linear
/// interpolation between the pair in characteristic 2. The output is half the
/// length of the input (a trailing odd element is ignored).
fn fold_polynomial_with_challenge<F: BinaryFieldElement>(poly: &[F], r: F) -> Vec<F> {
    poly.chunks_exact(2)
        .map(|pair| {
            let (lo, hi) = (pair[0], pair[1]);
            lo.add(&r.mul(&hi.add(&lo)))
        })
        .collect()
}
/// Evaluates a sumcheck round message `(s0, s1, s2)` at `x` as `s0 + s1 * x`.
///
/// Despite the name, the round polynomial produced by
/// `compute_sumcheck_coefficients` is linear. Its third entry (the value at 1)
/// equals `s0 + s1` in characteristic 2 and is therefore not needed here.
fn evaluate_quadratic<F: BinaryFieldElement>(coeffs: (F, F, F), x: F) -> F {
    let (constant, linear, _value_at_one) = coeffs;
    linear.mul(&x).add(&constant)
}
/// Returns the element-wise random combination `f + beta * g`.
///
/// # Panics
/// Panics if `f` and `g` have different lengths.
fn glue_polynomials<F: BinaryFieldElement>(f: &[F], g: &[F], beta: F) -> Vec<F> {
    assert_eq!(f.len(), g.len());
    let mut glued = Vec::with_capacity(f.len());
    for (&fi, &gi) in f.iter().zip(g.iter()) {
        glued.push(fi.add(&beta.mul(&gi)));
    }
    glued
}
/// Combines two claimed sums with the same combiner used by `glue_polynomials`:
/// `sum_f + beta * sum_g`.
fn glue_sums<F: BinaryFieldElement>(sum_f: F, sum_g: F, beta: F) -> F {
    let scaled_g = beta.mul(&sum_g);
    sum_f.add(&scaled_g)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::configs::hardcoded_config_12;
    use ligerito_binary_fields::{BinaryElem128, BinaryElem32};
    use std::marker::PhantomData;

    /// Sums every element of `p` with field addition.
    fn field_sum(p: &[BinaryElem32]) -> BinaryElem32 {
        p.iter().fold(BinaryElem32::zero(), |acc, &x| acc.add(&x))
    }

    #[test]
    fn test_evaluate_quadratic() {
        let coeffs = (
            BinaryElem32::from(1),
            BinaryElem32::from(3),
            BinaryElem32::from(2),
        );
        let val0 = evaluate_quadratic(coeffs, BinaryElem32::zero());
        assert_eq!(val0, BinaryElem32::from(1));
        let val1 = evaluate_quadratic(coeffs, BinaryElem32::one());
        assert_eq!(val1, BinaryElem32::from(2));
    }

    #[test]
    fn test_compute_sumcheck_coefficients() {
        let poly = vec![
            BinaryElem32::from(1),
            BinaryElem32::from(2),
            BinaryElem32::from(3),
            BinaryElem32::from(4),
        ];
        let (s0, s1, s2) = compute_sumcheck_coefficients(&poly);
        // s0 sums the even ("lo") entries, s2 the odd ("hi") entries.
        assert_eq!(s0, BinaryElem32::from(1).add(&BinaryElem32::from(3)));
        assert_eq!(s2, BinaryElem32::from(2).add(&BinaryElem32::from(4)));
        // In characteristic 2 the linear coefficient is s0 + s2.
        assert_eq!(s1, s0.add(&s2));
    }

    #[test]
    fn test_fold_polynomial_with_challenge() {
        let poly = vec![
            BinaryElem32::from(1),
            BinaryElem32::from(2),
            BinaryElem32::from(3),
            BinaryElem32::from(4),
        ];
        assert_eq!(
            fold_polynomial_with_challenge(&poly, BinaryElem32::zero()),
            vec![BinaryElem32::from(1), BinaryElem32::from(3)]
        );
        assert_eq!(
            fold_polynomial_with_challenge(&poly, BinaryElem32::one()),
            vec![BinaryElem32::from(2), BinaryElem32::from(4)]
        );
    }

    /// The prover sets `current_sum = evaluate_quadratic(coeffs, r)` after each
    /// fold; this checks that value really is the sum of the folded polynomial.
    #[test]
    fn test_fold_preserves_round_sum() {
        let poly = vec![
            BinaryElem32::from(1),
            BinaryElem32::from(2),
            BinaryElem32::from(3),
            BinaryElem32::from(4),
            BinaryElem32::from(5),
            BinaryElem32::from(6),
            BinaryElem32::from(7),
            BinaryElem32::from(8),
        ];
        let r = BinaryElem32::from(5);
        let coeffs = compute_sumcheck_coefficients(&poly);
        let folded = fold_polynomial_with_challenge(&poly, r);
        assert_eq!(field_sum(&folded), evaluate_quadratic(coeffs, r));
        // The third coefficient is the round polynomial evaluated at 1.
        assert_eq!(coeffs.2, evaluate_quadratic(coeffs, BinaryElem32::one()));
    }

    #[test]
    fn test_glue_polynomials() {
        let f = vec![BinaryElem32::from(1), BinaryElem32::from(2)];
        let g = vec![BinaryElem32::from(3), BinaryElem32::from(4)];
        let beta = BinaryElem32::from(5);
        let result = glue_polynomials(&f, &g, beta);
        assert_eq!(result.len(), 2);
        assert_eq!(
            result[0],
            BinaryElem32::from(1).add(&beta.mul(&BinaryElem32::from(3)))
        );
        assert_eq!(
            result[1],
            BinaryElem32::from(2).add(&beta.mul(&BinaryElem32::from(4)))
        );
    }

    /// Gluing sums must agree with summing the glued polynomial.
    #[test]
    fn test_glue_sums_matches_glued_polynomial() {
        let f = vec![
            BinaryElem32::from(1),
            BinaryElem32::from(2),
            BinaryElem32::from(3),
            BinaryElem32::from(4),
        ];
        let g = vec![
            BinaryElem32::from(5),
            BinaryElem32::from(6),
            BinaryElem32::from(7),
            BinaryElem32::from(8),
        ];
        let beta = BinaryElem32::from(9);
        assert_eq!(
            field_sum(&glue_polynomials(&f, &g, beta)),
            glue_sums(field_sum(&f), field_sum(&g), beta)
        );
    }

    #[test]
    fn test_simple_prove() {
        let config = hardcoded_config_12(PhantomData::<BinaryElem32>, PhantomData::<BinaryElem128>);
        let poly = vec![BinaryElem32::one(); 1 << 12];
        prove(&config, &poly).expect("Simple proof generation should succeed");
    }

    /// Checks that proving succeeds on degenerate inputs (all-zero and a sparse
    /// pattern); it does not verify the resulting proofs.
    #[test]
    fn test_sumcheck_consistency_in_prover() {
        let config = hardcoded_config_12(PhantomData::<BinaryElem32>, PhantomData::<BinaryElem128>);
        let poly = vec![BinaryElem32::zero(); 1 << 12];
        prove(&config, &poly).expect("Zero polynomial proof should succeed");

        let mut poly = vec![BinaryElem32::zero(); 1 << 12];
        poly[0] = BinaryElem32::one();
        poly[1] = BinaryElem32::from(2);
        prove(&config, &poly).expect("Simple pattern proof should succeed");
    }

    #[test]
    fn test_prove_with_evaluations() {
        use crate::eval_proof::EvalClaim;
        let config = hardcoded_config_12(PhantomData::<BinaryElem32>, PhantomData::<BinaryElem128>);
        let verifier_config = crate::hardcoded_config_12_verifier();
        let mut poly = vec![BinaryElem32::zero(); 1 << 12];
        poly[0] = BinaryElem32::from(42);
        poly[7] = BinaryElem32::from(99);
        poly[100] = BinaryElem32::from(255);
        let claims = vec![
            EvalClaim {
                index: 0,
                value: BinaryElem32::from(42),
            },
            EvalClaim {
                index: 7,
                value: BinaryElem32::from(99),
            },
            EvalClaim {
                index: 100,
                value: BinaryElem32::from(255),
            },
        ];

        let fs = FiatShamir::new_sha256(0);
        let proof = prove_with_evaluations(&config, &poly, &claims, fs)
            .expect("prove_with_evaluations should succeed");
        assert_eq!(
            proof.eval_rounds.len(),
            12,
            "should have 12 eval sumcheck rounds for 2^12 poly"
        );

        let fs = FiatShamir::new_sha256(0);
        let result = crate::verifier::verify_with_evaluations::<BinaryElem32, BinaryElem128>(
            &verifier_config,
            &proof,
            &claims,
            fs,
        )
        .expect("verify_with_evaluations should not error")
        .expect("eval sumcheck should pass");
        assert!(result.proximity_valid, "proximity test should pass");
    }

    #[test]
    fn test_prove_with_evaluations_wrong_claim_fails() {
        use crate::eval_proof::EvalClaim;
        let config = hardcoded_config_12(PhantomData::<BinaryElem32>, PhantomData::<BinaryElem128>);
        let verifier_config = crate::hardcoded_config_12_verifier();
        let mut poly = vec![BinaryElem32::zero(); 1 << 12];
        poly[5] = BinaryElem32::from(77);
        // Deliberately wrong: the polynomial holds 77 at index 5.
        let claims = vec![EvalClaim {
            index: 5,
            value: BinaryElem32::from(88),
        }];

        let fs = FiatShamir::new_sha256(0);
        let proof = prove_with_evaluations(&config, &poly, &claims, fs)
            .expect("proving should succeed even for a false claim");

        let fs = FiatShamir::new_sha256(0);
        let result = crate::verifier::verify_with_evaluations::<BinaryElem32, BinaryElem128>(
            &verifier_config,
            &proof,
            &claims,
            fs,
        )
        .expect("verify_with_evaluations should not error");
        assert!(
            result.is_none(),
            "wrong eval claim should fail verification"
        );
    }
}