#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use crate::utils::partial_eval_multilinear;
use binary_fields::BinaryFieldElement;
/// Errors reported by [`SumcheckVerifierInstance`] when the transcript or the
/// verifier's own call sequence is inconsistent.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SumcheckError {
    /// Another transcript entry was needed but every entry has been read.
    TranscriptExhausted,
    /// In `fold`, the transcript's claimed total differs from the running sum.
    SumMismatch,
    /// In `introduce_new`, the transcript's claimed total differs from the supplied claim.
    ClaimMismatch,
    /// `glue` was called before any running polynomial existed.
    NoRunningPoly,
    /// `glue` was called with no pending polynomial to merge.
    NoPolyToGlue,
    /// A basis polynomial did not reduce to a single value after evaluation.
    IncompleteEvaluation,
    /// Verifier-internal tables had inconsistent sizes (e.g. partially
    /// evaluated basis polynomials of different lengths).
    LengthMismatch,
}

impl core::fmt::Display for SumcheckError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let msg = match self {
            SumcheckError::TranscriptExhausted => "sumcheck transcript exhausted",
            SumcheckError::SumMismatch => "transcript sum does not match the running sum",
            SumcheckError::ClaimMismatch => "transcript sum does not match the introduced claim",
            SumcheckError::NoRunningPoly => "no running polynomial to glue into",
            SumcheckError::NoPolyToGlue => "no pending polynomial to glue",
            SumcheckError::IncompleteEvaluation => {
                "basis polynomial did not evaluate to a single value"
            }
            SumcheckError::LengthMismatch => "basis polynomial lengths are inconsistent",
        };
        f.write_str(msg)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for SumcheckError {}
/// Polynomial of degree at most one, `c + b·X`, over a binary field.
#[derive(Debug, Clone)]
pub struct LinearPoly<F: BinaryFieldElement> {
    /// Coefficient of `X`.
    b: F,
    /// Constant term.
    c: F,
}

impl<F: BinaryFieldElement> LinearPoly<F> {
    /// Builds `c + b·X` from its two coefficients.
    pub fn new(b: F, c: F) -> Self {
        LinearPoly { b, c }
    }

    /// Evaluates the polynomial at `r`, returning `c + b·r`.
    pub fn eval(&self, r: F) -> F {
        let slope_term = self.b.mul(&r);
        self.c.add(&slope_term)
    }
}
/// Polynomial of degree at most two, `a·X² + b·X + c`, over a binary field.
#[derive(Debug, Clone)]
pub struct QuadraticPoly<F: BinaryFieldElement> {
    /// Coefficient of `X²`.
    a: F,
    /// Coefficient of `X`.
    b: F,
    /// Constant term.
    c: F,
}

impl<F: BinaryFieldElement> QuadraticPoly<F> {
    /// Builds `a·X² + b·X + c` from its three coefficients.
    pub fn new(a: F, b: F, c: F) -> Self {
        QuadraticPoly { a, b, c }
    }

    /// Evaluates the polynomial at `r`, returning `a·r·r + b·r + c`.
    pub fn eval_quadratic(&self, r: F) -> F {
        let quadratic_term = self.a.mul(&r).mul(&r);
        let linear_term = self.b.mul(&r);
        quadratic_term.add(&linear_term).add(&self.c)
    }
}
pub fn linear_from_evals<F: BinaryFieldElement>(s0: F, s2: F) -> LinearPoly<F> {
LinearPoly::new(s0.add(&s2), s0)
}
pub fn quadratic_from_evals<F: BinaryFieldElement>(at0: F, at1: F, atx: F) -> QuadraticPoly<F> {
let x = F::from_bits(3);
let numerator = atx.add(&at0).add(&x.mul(&at1.add(&at0)));
let denominator = x.mul(&x).add(&x);
let a = numerator.mul(&denominator.inv());
let b = at1.add(&at0).add(&a);
QuadraticPoly { a, b, c: at0 }
}
pub fn fold_linear<F: BinaryFieldElement>(
p1: LinearPoly<F>,
p2: LinearPoly<F>,
alpha: F,
) -> LinearPoly<F> {
LinearPoly::new(p1.b.add(&alpha.mul(&p2.b)), p1.c.add(&alpha.mul(&p2.c)))
}
pub fn fold_quadratic<F: BinaryFieldElement>(
p1: QuadraticPoly<F>,
p2: QuadraticPoly<F>,
alpha: F,
) -> QuadraticPoly<F> {
QuadraticPoly::new(
p1.a.add(&alpha.mul(&p2.a)),
p1.b.add(&alpha.mul(&p2.b)),
p1.c.add(&alpha.mul(&p2.c)),
)
}
/// Verifier state for a sumcheck over binary-field multilinear tables into
/// which further basis polynomials and claims can be merged ("glued") using
/// separation challenges.
///
/// `fold` consumes one transcript entry per round. `introduce_new` registers
/// an extra basis polynomial and claim, and `glue` merges it into the running
/// polynomial. `verify` / `verify_partial` perform the final check.
#[derive(Debug, Clone)]
pub struct SumcheckVerifierInstance<F: BinaryFieldElement> {
    /// Basis polynomial tables; index 0 comes from `new`, later ones from `introduce_new`.
    basis_polys: Vec<Vec<F>>,
    /// Scaling for each basis polynomial; index 0 is one, later entries come from `glue`.
    separation_challenges: Vec<F>,
    /// Current claim that the next `fold` compares the transcript against.
    pub sum: F,
    /// Prover messages `(s0, s_total, s2)`, consumed in order.
    transcript: Vec<(F, F, F)>,
    /// Verifier challenges recorded so far.
    ris: Vec<F>,
    /// Index of the next unread transcript entry.
    tr_reader: usize,
    /// Polynomial from the latest `fold`, possibly merged by `glue`.
    running_poly: Option<LinearPoly<F>>,
    /// Polynomial from `introduce_new` waiting to be merged by `glue`.
    to_glue: Option<LinearPoly<F>>,
}

impl<F: BinaryFieldElement> SumcheckVerifierInstance<F> {
    /// Creates a verifier for the initial basis polynomial `b1` with claimed sum `h1`.
    ///
    /// `transcript` holds the prover's `(s0, s_total, s2)` messages in the
    /// order they are consumed by [`fold`](Self::fold) and
    /// [`introduce_new`](Self::introduce_new). The first separation challenge
    /// is fixed to one so `b1` enters the combination unscaled.
    pub fn new(b1: Vec<F>, h1: F, transcript: Vec<(F, F, F)>) -> Self {
        Self {
            basis_polys: vec![b1],
            separation_challenges: vec![F::one()],
            sum: h1,
            transcript,
            ris: vec![],
            tr_reader: 0,
            running_poly: None,
            to_glue: None,
        }
    }

    /// Returns the next transcript entry and advances the cursor.
    ///
    /// # Errors
    /// [`SumcheckError::TranscriptExhausted`] if every entry has been read.
    fn read_tr(&mut self) -> Result<(F, F, F), SumcheckError> {
        let entry = self
            .transcript
            .get(self.tr_reader)
            .copied()
            .ok_or(SumcheckError::TranscriptExhausted)?;
        self.tr_reader += 1;
        Ok(entry)
    }

    /// Runs one sumcheck round with verifier challenge `r`.
    ///
    /// Reads the next `(s0, s_total, s2)` entry and requires `s_total` to equal
    /// [`sum`](Self::sum). It then replaces the running polynomial with the
    /// linear polynomial through `(0, s0)` and `(1, s2)`, sets `sum` to its
    /// value at `r`, and records `r`. Returns the consumed entry.
    ///
    /// # Errors
    /// [`SumcheckError::TranscriptExhausted`] or [`SumcheckError::SumMismatch`]
    /// (the cursor has advanced in the latter case).
    pub fn fold(&mut self, r: F) -> Result<(F, F, F), SumcheckError> {
        let (s0, s_total, s2) = self.read_tr()?;
        // NOTE(review): `s_total` is used as sent; nothing here checks
        // `s_total == s0 + s2` (= p(0) + p(1) of the polynomial built below).
        // Confirm that binding is enforced elsewhere in the protocol.
        if s_total != self.sum {
            return Err(SumcheckError::SumMismatch);
        }
        let poly = linear_from_evals(s0, s2);
        self.sum = poly.eval(r);
        self.running_poly = Some(poly);
        self.ris.push(r);
        Ok((s0, s_total, s2))
    }

    /// Registers another basis polynomial `bi` whose claimed sum is `h`.
    ///
    /// Reads the next transcript entry and requires its `s_total` to equal `h`.
    /// It then stores `bi` and keeps the linear polynomial through `(0, s0)`
    /// and `(1, s2)` pending until [`glue`](Self::glue). Returns the consumed entry.
    ///
    /// # Errors
    /// [`SumcheckError::TranscriptExhausted`] or [`SumcheckError::ClaimMismatch`].
    /// After a `ClaimMismatch` the cursor has advanced and `bi` is not stored.
    pub fn introduce_new(&mut self, bi: Vec<F>, h: F) -> Result<(F, F, F), SumcheckError> {
        let (s0, s_total, s2) = self.read_tr()?;
        if s_total != h {
            return Err(SumcheckError::ClaimMismatch);
        }
        self.basis_polys.push(bi);
        self.to_glue = Some(linear_from_evals(s0, s2));
        Ok((s0, s_total, s2))
    }

    /// Merges the pending polynomial into the running one as
    /// `running + alpha·pending`, and appends `alpha` to the separation
    /// challenges (pairing it by position with the next unpaired basis polynomial).
    ///
    /// # Errors
    /// [`SumcheckError::NoRunningPoly`] before the first `fold`, or
    /// [`SumcheckError::NoPolyToGlue`] if nothing is pending. State is
    /// unchanged on error.
    pub fn glue(&mut self, alpha: F) -> Result<(), SumcheckError> {
        if self.running_poly.is_none() {
            return Err(SumcheckError::NoRunningPoly);
        }
        let pending = self.to_glue.take().ok_or(SumcheckError::NoPolyToGlue)?;
        let running = self
            .running_poly
            .take()
            .ok_or(SumcheckError::NoRunningPoly)?;
        self.separation_challenges.push(alpha);
        // NOTE(review): `self.sum` is not adjusted here; confirm that is intended.
        self.running_poly = Some(fold_linear(running, pending, alpha));
        Ok(())
    }

    /// Ensures every basis polynomial has a separation challenge.
    ///
    /// `introduce_new` adds a basis polynomial but only `glue` adds its
    /// challenge. Without this check an un-glued polynomial would make the
    /// evaluation code index past the end of `separation_challenges`.
    fn ensure_all_glued(&self) -> Result<(), SumcheckError> {
        if self.separation_challenges.len() < self.basis_polys.len() {
            Err(SumcheckError::LengthMismatch)
        } else {
            Ok(())
        }
    }

    /// Number of variables of a table, `floor(log2(len))`.
    ///
    /// # Errors
    /// [`SumcheckError::LengthMismatch`] for an empty table, where `ilog2` would panic.
    fn num_vars(table: &[F]) -> Result<usize, SumcheckError> {
        table
            .len()
            .checked_ilog2()
            .map(|n| n as usize)
            .ok_or(SumcheckError::LengthMismatch)
    }

    /// Evaluates a copy of `table` at `points` and returns the single remaining value.
    ///
    /// # Errors
    /// [`SumcheckError::IncompleteEvaluation`] if more or fewer than one value remains.
    fn eval_fully(table: &[F], points: &[F]) -> Result<F, SumcheckError> {
        let mut scratch = table.to_vec();
        partial_eval_multilinear(&mut scratch, points);
        match scratch.as_slice() {
            [value] => Ok(*value),
            _ => Err(SumcheckError::IncompleteEvaluation),
        }
    }

    /// Records `r` and returns `sum_i alpha_i · b_i(...)`.
    ///
    /// `b_0` is evaluated at all recorded challenges. Each later `b_i` is
    /// evaluated at the most recent `log2(len(b_i))` challenges, or all of
    /// them if fewer exist.
    fn evaluate_basis_polys(&mut self, r: F) -> Result<F, SumcheckError> {
        self.ensure_all_glued()?;
        self.ris.push(r);
        let mut combined = Self::eval_fully(&self.basis_polys[0], &self.ris)?;
        let num_rs = self.ris.len();
        let glued = self
            .basis_polys
            .iter()
            .zip(self.separation_challenges.iter().copied())
            .skip(1);
        for (bi, alpha) in glued {
            let n = Self::num_vars(bi)?;
            let eval_pts = &self.ris[num_rs.saturating_sub(n)..];
            let bi_eval = Self::eval_fully(bi, eval_pts)?;
            combined = combined.add(&alpha.mul(&bi_eval));
        }
        Ok(combined)
    }

    /// Records `r` and returns the entrywise combination `sum_i alpha_i · b_i`
    /// of partially evaluated basis tables.
    ///
    /// `b_0` is partially evaluated at all recorded challenges. Each later
    /// `b_i` uses the most recent `n_i - k` challenges (capped at the number
    /// recorded), where `n_i = log2(len(b_i))`.
    ///
    /// # Errors
    /// [`SumcheckError::LengthMismatch`] if a table is empty, lacks a
    /// separation challenge, or does not line up with `b_0` entrywise.
    fn evaluate_basis_polys_partially(&mut self, r: F, k: usize) -> Result<Vec<F>, SumcheckError> {
        self.ensure_all_glued()?;
        self.ris.push(r);
        let mut acc = self.basis_polys[0].clone();
        partial_eval_multilinear(&mut acc, &self.ris);
        let num_rs = self.ris.len();
        let glued = self
            .basis_polys
            .iter()
            .zip(self.separation_challenges.iter().copied())
            .skip(1);
        for (bi, alpha) in glued {
            let n = Self::num_vars(bi)?;
            let eval_len = n.saturating_sub(k).min(num_rs);
            let mut bi_partial = bi.clone();
            if eval_len > 0 {
                partial_eval_multilinear(&mut bi_partial, &self.ris[num_rs - eval_len..]);
            }
            if acc.len() != bi_partial.len() {
                return Err(SumcheckError::LengthMismatch);
            }
            for (slot, value) in acc.iter_mut().zip(&bi_partial) {
                *slot = slot.add(&alpha.mul(value));
            }
        }
        Ok(acc)
    }

    /// Final check with challenge `r` and the claimed evaluation `f_eval`.
    ///
    /// It sets [`sum`](Self::sum) to the running polynomial at `r` and records
    /// `r`. It then returns whether `f_eval · (combined basis evaluation)`
    /// equals that sum. Returns `Ok(false)` if no `fold` has happened yet.
    ///
    /// # Errors
    /// Propagates basis-evaluation errors
    /// ([`SumcheckError::IncompleteEvaluation`], [`SumcheckError::LengthMismatch`]).
    pub fn verify(&mut self, r: F, f_eval: F) -> Result<bool, SumcheckError> {
        let Some(running) = self.running_poly.as_ref() else {
            return Ok(false);
        };
        self.sum = running.eval(r);
        let basis_eval = self.evaluate_basis_polys(r)?;
        Ok(f_eval.mul(&basis_eval) == self.sum)
    }

    /// Final check against a partially evaluated `f`, given as the table
    /// `f_partial_eval` of `2^k` entries.
    ///
    /// It sets [`sum`](Self::sum) to the running polynomial at `r`, records
    /// `r`, and returns whether the dot product of `f_partial_eval` with the
    /// partially evaluated basis combination equals that sum.
    ///
    /// Returns `Ok(false)` without touching state when:
    /// - no `fold` has happened, or
    /// - `f_partial_eval` is empty or not a power of two in length. It is
    ///   prover-supplied, and `ilog2` would otherwise panic or give a
    ///   meaningless `k`.
    ///
    /// # Errors
    /// Propagates basis-evaluation errors ([`SumcheckError::LengthMismatch`]).
    pub fn verify_partial(&mut self, r: F, f_partial_eval: &[F]) -> Result<bool, SumcheckError> {
        if !f_partial_eval.len().is_power_of_two() {
            return Ok(false);
        }
        let k = f_partial_eval.len().ilog2() as usize;
        let Some(running) = self.running_poly.as_ref() else {
            return Ok(false);
        };
        self.sum = running.eval(r);
        let basis_evals = self.evaluate_basis_polys_partially(r, k)?;
        if f_partial_eval.len() != basis_evals.len() {
            return Ok(false);
        }
        let dot_product = f_partial_eval
            .iter()
            .zip(&basis_evals)
            .fold(F::zero(), |acc, (f_i, b_i)| acc.add(&f_i.mul(b_i)));
        Ok(dot_product == self.sum)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ligerito_binary_fields::BinaryElem128;

    #[test]
    fn test_quadratic_eval() {
        let poly = QuadraticPoly::new(
            BinaryElem128::one(),
            BinaryElem128::from(2),
            BinaryElem128::from(3),
        );
        let val_at_0 = poly.eval_quadratic(BinaryElem128::zero());
        assert_eq!(val_at_0, BinaryElem128::from(3));
    }

    #[test]
    fn test_quadratic_from_evals() {
        let at0 = BinaryElem128::from(1);
        let at1 = BinaryElem128::from(2);
        let at3 = BinaryElem128::from(4);
        let poly = quadratic_from_evals(at0, at1, at3);
        assert_eq!(poly.eval_quadratic(BinaryElem128::zero()), at0);
        assert_eq!(poly.eval_quadratic(BinaryElem128::one()), at1);
        assert_eq!(poly.eval_quadratic(BinaryElem128::from(3)), at3);
    }

    #[test]
    fn test_fold_quadratic() {
        let p1 = QuadraticPoly::new(
            BinaryElem128::one(),
            BinaryElem128::from(2),
            BinaryElem128::from(3),
        );
        let p2 = QuadraticPoly::new(
            BinaryElem128::from(4),
            BinaryElem128::from(5),
            BinaryElem128::from(6),
        );
        let alpha = BinaryElem128::from(7);
        let folded = fold_quadratic(p1.clone(), p2.clone(), alpha);
        let x = BinaryElem128::from(11);
        let expected = p1.eval_quadratic(x).add(&alpha.mul(&p2.eval_quadratic(x)));
        let actual = folded.eval_quadratic(x);
        assert_eq!(actual, expected);
    }

    /// The linear polynomial built by the verifier must pass through (0, s0) and (1, s2).
    #[test]
    fn test_linear_from_evals() {
        let s0 = BinaryElem128::from(5);
        let s2 = BinaryElem128::from(9);
        let poly = linear_from_evals(s0, s2);
        assert_eq!(poly.eval(BinaryElem128::zero()), s0);
        assert_eq!(poly.eval(BinaryElem128::one()), s2);
    }

    #[test]
    fn test_fold_linear() {
        let p1 = LinearPoly::new(BinaryElem128::from(1), BinaryElem128::from(2));
        let p2 = LinearPoly::new(BinaryElem128::from(3), BinaryElem128::from(4));
        let alpha = BinaryElem128::from(7);
        let folded = fold_linear(p1.clone(), p2.clone(), alpha);
        let x = BinaryElem128::from(11);
        let expected = p1.eval(x).add(&alpha.mul(&p2.eval(x)));
        assert_eq!(folded.eval(x), expected);
    }

    #[test]
    fn test_fold_rejects_mismatch_then_exhausts() {
        let one = BinaryElem128::one();
        let zero = BinaryElem128::zero();
        let mut v = SumcheckVerifierInstance::new(vec![one, one], one, vec![(zero, zero, zero)]);
        assert_eq!(v.fold(one), Err(SumcheckError::SumMismatch));
        // The mismatching entry was still consumed.
        assert_eq!(v.fold(one), Err(SumcheckError::TranscriptExhausted));
    }

    #[test]
    fn test_glue_state_transitions() {
        let one = BinaryElem128::one();
        let zero = BinaryElem128::zero();
        let mut v = SumcheckVerifierInstance::new(
            vec![one, one],
            one,
            vec![(one, one, zero), (zero, one, one)],
        );
        assert_eq!(v.glue(one), Err(SumcheckError::NoRunningPoly));
        assert_eq!(v.fold(one), Ok((one, one, zero)));
        assert_eq!(v.glue(one), Err(SumcheckError::NoPolyToGlue));
        assert_eq!(v.introduce_new(vec![one, one], one), Ok((zero, one, one)));
        assert_eq!(v.glue(one), Ok(()));
        assert_eq!(v.glue(one), Err(SumcheckError::NoPolyToGlue));
    }

    #[test]
    fn test_verify_without_fold_is_false() {
        let one = BinaryElem128::one();
        let mut v = SumcheckVerifierInstance::new(vec![one, one], one, vec![]);
        assert_eq!(v.verify(one, one), Ok(false));
        assert_eq!(v.verify_partial(one, &[one]), Ok(false));
    }
}