use field_cat::FieldBytes;
use crate::error::Error;
use crate::transcript::Transcript;
use super::protocol::{RoundPoly, SumcheckClaim, SumcheckProof};
/// State carried from one round of [`sumcheck_prove`] to the next.
struct ProverState<F>
where
    F: field_cat::Field,
{
    /// Evaluation table of the polynomial with every already-bound variable folded in.
    evals: Vec<F>,
    /// Transcript into which each round's values are absorbed and from which challenges are squeezed.
    transcript: Transcript,
    /// Round polynomials produced so far, in round order.
    round_polys: Vec<RoundPoly<F>>,
    /// Challenges squeezed so far, in round order.
    challenges: Vec<F>,
}
/// Produces a sumcheck proof for the multilinear polynomial in `claim`.
///
/// One round is run per variable (`claim.poly().num_vars().count()` rounds). Each round:
/// 1. sums the lower and upper halves of the current evaluation table, giving the round
///    polynomial's values at 0 and 1 (`eval_zero`, `eval_one`);
/// 2. absorbs `eval_zero`, then `eval_one`, into the transcript and squeezes a challenge `r`;
/// 3. binds the variable that selects between the two halves (the most significant index
///    bit of the table) to `r`, halving the table in place.
///
/// Only `claim.poly()` is read: the claimed sum stored in `claim` is not checked here.
///
/// # Returns
/// `(proof, challenges, transcript)`: the proof holding the round polynomials in round
/// order, the squeezed challenges in round order, and the transcript after the last squeeze.
///
/// # Errors
/// Propagates any error returned by `Transcript::squeeze_challenge`; no partial proof is
/// returned in that case.
pub fn sumcheck_prove<F: FieldBytes>(
    claim: &SumcheckClaim<F>,
    transcript: Transcript,
) -> Result<(SumcheckProof<F>, Vec<F>, Transcript), Error> {
    let num_rounds = claim.poly().num_vars().count();
    let mut state = ProverState {
        evals: claim.poly().evals().to_vec(),
        transcript,
        round_polys: Vec::with_capacity(num_rounds),
        challenges: Vec::with_capacity(num_rounds),
    };
    for _ in 0..num_rounds {
        let (eval_zero, eval_one) = half_sums(&state.evals);
        // Absorb order (eval_zero, then eval_one) must match the verifier's.
        let transcript = state
            .transcript
            .absorb_field(&eval_zero)
            .absorb_field(&eval_one);
        let (challenge, transcript): (F, Transcript) = transcript.squeeze_challenge()?;
        state.transcript = transcript;
        bind_top_variable(&mut state.evals, &challenge);
        // Pushing into the preallocated vectors avoids a reallocation every round.
        state.round_polys.push(RoundPoly::new(eval_zero, eval_one));
        state.challenges.push(challenge);
    }
    Ok((
        SumcheckProof::new(state.round_polys),
        state.challenges,
        state.transcript,
    ))
}

/// Returns the sums of the lower and upper halves of `evals` (split at `len / 2`),
/// i.e. the round polynomial's values at 0 and at 1.
fn half_sums<F: FieldBytes>(evals: &[F]) -> (F, F) {
    let (lower, upper) = evals.split_at(evals.len() / 2);
    let sum = |part: &[F]| part.iter().cloned().fold(F::zero(), |acc, v| acc + v);
    (sum(lower), sum(upper))
}

/// Binds the half-selecting variable of `evals` to `r` in place and truncates the table
/// to its lower half: `evals[j] <- lo + (hi - lo) * r` with `lo = evals[j]`,
/// `hi = evals[j + len / 2]`.
fn bind_top_variable<F: FieldBytes>(evals: &mut Vec<F>, r: &F) {
    let half = evals.len() / 2;
    let (lower, upper) = evals.split_at_mut(half);
    for (lo, hi) in lower.iter_mut().zip(upper.iter()) {
        // Equal to lo·(1 − r) + hi·r in a field, with one multiplication instead of two
        // and no per-element recomputation of (1 − r).
        *lo = lo.clone() + (hi.clone() - lo.clone()) * r.clone();
    }
    evals.truncate(half);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::poly::MultilinearPoly;
    use field_cat::{F101, Field};

    /// An identically-zero one-variable polynomial yields one round summing to zero.
    #[test]
    fn zero_polynomial_sum_is_zero() -> Result<(), Error> {
        let poly = MultilinearPoly::from_evals(vec![F101::zero(), F101::zero()])?;
        let claim = SumcheckClaim::new(poly, F101::zero());
        let transcript = Transcript::new(b"test");
        let (proof, challenges, _) = sumcheck_prove(&claim, transcript)?;
        assert_eq!(proof.round_polys().len(), 1);
        assert_eq!(challenges.len(), 1);
        let rp = &proof.round_polys()[0];
        assert_eq!(*rp.eval_zero() + *rp.eval_one(), F101::zero());
        Ok(())
    }

    /// A constant one-variable polynomial's single round sums to twice the constant.
    #[test]
    fn constant_polynomial() -> Result<(), Error> {
        let poly = MultilinearPoly::from_evals(vec![F101::new(5), F101::new(5)])?;
        let claim = SumcheckClaim::new(poly, F101::new(10));
        let transcript = Transcript::new(b"test");
        let (proof, _, _) = sumcheck_prove(&claim, transcript)?;
        let rp = &proof.round_polys()[0];
        assert_eq!(*rp.eval_zero() + *rp.eval_one(), F101::new(10));
        Ok(())
    }

    /// Two-variable case: checks the first round's values and the round-to-round
    /// sumcheck invariant `g1(0) + g1(1) == g0(r0)`.
    #[test]
    fn two_variable_polynomial() -> Result<(), Error> {
        let poly = MultilinearPoly::from_evals(vec![
            F101::new(1),
            F101::new(2),
            F101::new(3),
            F101::new(4),
        ])?;
        let claim = SumcheckClaim::new(poly, F101::new(10));
        let transcript = Transcript::new(b"test");
        let (proof, challenges, _) = sumcheck_prove(&claim, transcript)?;
        assert_eq!(proof.round_polys().len(), 2);
        assert_eq!(challenges.len(), 2);
        let rp0 = &proof.round_polys()[0];
        assert_eq!(*rp0.eval_zero(), F101::new(3));
        assert_eq!(*rp0.eval_one(), F101::new(7));
        assert_eq!(*rp0.eval_zero() + *rp0.eval_one(), F101::new(10));

        // Round 1's endpoint sum must equal round 0's linear polynomial at r0.
        let rp1 = &proof.round_polys()[1];
        let r0 = challenges[0];
        let rp0_at_r0 = *rp0.eval_zero() + r0 * (*rp0.eval_one() - *rp0.eval_zero());
        assert_eq!(*rp1.eval_zero() + *rp1.eval_one(), rp0_at_r0);
        Ok(())
    }

    /// Proving the same claim from identically-labelled transcripts yields the same challenges.
    #[test]
    fn proving_is_deterministic() -> Result<(), Error> {
        let evals = vec![F101::new(1), F101::new(2), F101::new(3), F101::new(4)];
        let claim = SumcheckClaim::new(MultilinearPoly::from_evals(evals)?, F101::new(10));
        let (_, first, _) = sumcheck_prove(&claim, Transcript::new(b"test"))?;
        let (_, second, _) = sumcheck_prove(&claim, Transcript::new(b"test"))?;
        assert_eq!(first, second);
        Ok(())
    }
}