use field_cat::FieldBytes;
use crate::error::Error;
use crate::poly::NumVars;
use crate::transcript::Transcript;
use super::protocol::SumcheckProof;
struct VerifierState<F: field_cat::Field> {
current_claim: F,
transcript: Transcript,
challenges: Vec<F>,
}
pub fn sumcheck_verify<F: FieldBytes>(
proof: &SumcheckProof<F>,
claimed_sum: &F,
num_vars: NumVars,
transcript: Transcript,
) -> Result<(F, Vec<F>, Transcript), Error> {
if proof.round_polys().len() == num_vars.count() {
let initial = VerifierState {
current_claim: claimed_sum.clone(),
transcript,
challenges: Vec::with_capacity(num_vars.count()),
};
let final_state = proof
.round_polys()
.iter()
.try_fold(initial, |state, round_poly| {
let sum = round_poly.eval_zero().clone() + round_poly.eval_one().clone();
if sum == state.current_claim {
let transcript = state
.transcript
.absorb_field(round_poly.eval_zero())
.absorb_field(round_poly.eval_one());
let (challenge, transcript): (F, Transcript) =
transcript.squeeze_challenge()?;
let new_claim = round_poly.evaluate(&challenge);
let challenges = state
.challenges
.into_iter()
.chain(core::iter::once(challenge))
.collect();
Ok(VerifierState {
current_claim: new_claim,
transcript,
challenges,
})
} else {
Err(Error::SumcheckFinalMismatch)
}
})?;
Ok((
final_state.current_claim,
final_state.challenges,
final_state.transcript,
))
} else {
Err(Error::RoundCountMismatch {
expected: num_vars.count(),
actual: proof.round_polys().len(),
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::poly::MultilinearPoly;
    use crate::sumcheck::protocol::SumcheckClaim;
    use crate::sumcheck::prover::sumcheck_prove;
    use field_cat::{F101, Field};

    /// An honest proof for the all-zero one-variable polynomial verifies.
    /// Both sides derive the same challenges, and the final claim matches the
    /// polynomial's evaluation at those challenges.
    #[test]
    fn prover_verifier_agree_on_zero_poly() -> Result<(), Error> {
        let poly = MultilinearPoly::from_evals(vec![F101::zero(), F101::zero()])?;
        let claim = SumcheckClaim::new(poly.clone(), F101::zero());
        let prover_transcript = Transcript::new(b"test");
        let (proof, prover_challenges, _) = sumcheck_prove(&claim, prover_transcript)?;
        let verifier_transcript = Transcript::new(b"test");
        let (final_eval, verifier_challenges, _) =
            sumcheck_verify(&proof, &F101::zero(), poly.num_vars(), verifier_transcript)?;
        assert_eq!(prover_challenges, verifier_challenges);
        let expected = poly.evaluate(&verifier_challenges)?;
        assert_eq!(final_eval, expected);
        Ok(())
    }

    /// The same agreement property holds for a non-trivial two-variable polynomial.
    #[test]
    fn prover_verifier_agree_on_two_var() -> Result<(), Error> {
        let poly = MultilinearPoly::from_evals(vec![
            F101::new(1),
            F101::new(2),
            F101::new(3),
            F101::new(4),
        ])?;
        let sum = poly.sum_over_boolean_hypercube();
        let claim = SumcheckClaim::new(poly.clone(), sum);
        let (proof, prover_challenges, _) = sumcheck_prove(&claim, Transcript::new(b"test"))?;
        let (final_eval, verifier_challenges, _) =
            sumcheck_verify(&proof, &sum, poly.num_vars(), Transcript::new(b"test"))?;
        assert_eq!(prover_challenges, verifier_challenges);
        let expected = poly.evaluate(&verifier_challenges)?;
        assert_eq!(final_eval, expected);
        Ok(())
    }

    /// Verifying an honest proof against a different claimed sum fails at the first
    /// round check. The assertion pins the specific variant so that an unrelated error
    /// (e.g. a round-count mismatch) cannot make the test pass.
    #[test]
    fn wrong_claimed_sum_rejected() -> Result<(), Error> {
        let poly = MultilinearPoly::from_evals(vec![
            F101::new(1),
            F101::new(2),
            F101::new(3),
            F101::new(4),
        ])?;
        let claim = SumcheckClaim::new(poly.clone(), F101::new(10));
        let (proof, _, _) = sumcheck_prove(&claim, Transcript::new(b"test"))?;
        let result = sumcheck_verify(
            &proof,
            &F101::new(99),
            poly.num_vars(),
            Transcript::new(b"test"),
        );
        assert!(matches!(result, Err(Error::SumcheckFinalMismatch)));
        Ok(())
    }

    /// A proof produced for a one-variable polynomial is rejected when the verifier
    /// expects the variable count of a two-variable polynomial.
    #[test]
    fn round_count_mismatch_rejected() -> Result<(), Error> {
        let one_var = MultilinearPoly::from_evals(vec![F101::new(1), F101::new(2)])?;
        let two_var = MultilinearPoly::from_evals(vec![
            F101::new(1),
            F101::new(2),
            F101::new(3),
            F101::new(4),
        ])?;
        let sum = one_var.sum_over_boolean_hypercube();
        let claim = SumcheckClaim::new(one_var.clone(), sum);
        let (proof, _, _) = sumcheck_prove(&claim, Transcript::new(b"test"))?;
        let result = sumcheck_verify(
            &proof,
            &sum,
            two_var.num_vars(),
            Transcript::new(b"test"),
        );
        assert!(matches!(result, Err(Error::RoundCountMismatch { .. })));
        Ok(())
    }
}