use crate::{
errors::NovaError,
spartan::{
polys::{
multilinear::MultilinearPolynomial,
univariate::{CompressedUniPoly, UniPoly},
},
sumcheck::eq_sumcheck::EqSumCheckInstance,
},
traits::{Engine, TranscriptEngineTrait},
};
use ff::{Field, PrimeField};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
/// A sum-check instance that can be driven one round at a time.
///
/// Implementors must be `Send + Sync`, so instances can be shared with rayon workers.
pub trait SumcheckEngine<E: Engine>: Send + Sync {
  /// Claims this instance starts from, before any round has been run.
  fn initial_claims(&self) -> Vec<E::Scalar>;

  /// Degree reported by this instance (presumably the round-polynomial degree bound;
  /// confirm against implementors).
  fn degree(&self) -> usize;

  /// Size reported by this instance (semantics defined by implementors).
  fn size(&self) -> usize;

  /// Claims held by the instance once proving is finished.
  fn final_claims(&self) -> Vec<Vec<E::Scalar>>;

  /// Evaluation points of the current round's polynomial(s); may update internal state.
  fn evaluation_points(&mut self) -> Vec<Vec<E::Scalar>>;

  /// Fixes the current variable to the verifier challenge `r`.
  fn bound(&mut self, r: &E::Scalar);
}
/// A sum-check proof: one compressed univariate polynomial per round,
/// stored in the order the rounds were executed.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "")]
pub struct SumcheckProof<E: Engine> {
  /// Round polynomials; each is decompressed during verification using the running claim.
  compressed_polys: Vec<CompressedUniPoly<E::Scalar>>,
}
impl<E: Engine> SumcheckProof<E> {
pub fn new(compressed_polys: Vec<CompressedUniPoly<E::Scalar>>) -> Self {
Self { compressed_polys }
}
/// Evaluates at `r` the (at most cubic) round polynomial `p` described by
/// `evals = [p(0), x³ coefficient of p, p(-1)]` together with `claim = p(0) + p(1)`.
#[inline]
pub fn update_claim(claim: E::Scalar, evals: &[E::Scalar; 3], r: &E::Scalar) -> E::Scalar {
  let (p_zero, c3, p_minus_one) = (evals[0], evals[1], evals[2]);
  let p_one = claim - p_zero;
  // p(1) - p(-1) = 2(c1 + c3) and p(1) + p(-1) = 2(c0 + c2).
  let c1 = (p_one - p_minus_one) * E::Scalar::TWO_INV - c3;
  let c2 = (p_one + p_minus_one) * E::Scalar::TWO_INV - p_zero;
  // Horner's rule, highest coefficient first.
  let x = *r;
  let mut acc = c3;
  for coeff in [c2, c1, p_zero] {
    acc = acc * x + coeff;
  }
  acc
}
/// Verifies this proof against `claim` over `num_rounds` rounds.
///
/// Each round polynomial is decompressed with the running claim as hint. It is
/// rejected if its degree exceeds `degree_bound`; otherwise it is absorbed into
/// `transcript` and a challenge is squeezed.
///
/// Returns the final running claim and the challenges in round order.
///
/// # Errors
/// `NovaError::InvalidSumcheckProof` if the proof does not contain exactly
/// `num_rounds` polynomials or a polynomial exceeds `degree_bound`. Errors from
/// `transcript.squeeze` are propagated.
pub fn verify(
  &self,
  claim: E::Scalar,
  num_rounds: usize,
  degree_bound: usize,
  transcript: &mut E::TE,
) -> Result<(E::Scalar, Vec<E::Scalar>), NovaError> {
  if self.compressed_polys.len() != num_rounds {
    return Err(NovaError::InvalidSumcheckProof);
  }
  let mut running_claim = claim;
  let mut challenges: Vec<E::Scalar> = Vec::new();
  for compressed in &self.compressed_polys {
    let round_poly = compressed.decompress(&running_claim);
    if round_poly.degree() > degree_bound {
      return Err(NovaError::InvalidSumcheckProof);
    }
    debug_assert_eq!(
      round_poly.eval_at_zero() + round_poly.eval_at_one(),
      running_claim
    );
    transcript.absorb(b"p", &round_poly);
    let challenge = transcript.squeeze(b"c")?;
    challenges.push(challenge);
    running_claim = round_poly.evaluate(&challenge);
  }
  Ok((running_claim, challenges))
}
pub fn verify_batch(
&self,
claims: &[E::Scalar],
num_rounds: &[usize],
coeffs: &[E::Scalar],
degree_bound: usize,
transcript: &mut E::TE,
) -> Result<(E::Scalar, Vec<E::Scalar>), NovaError> {
let num_instances = claims.len();
assert_eq!(num_rounds.len(), num_instances);
assert_eq!(coeffs.len(), num_instances);
let num_rounds_max = *num_rounds.iter().max().unwrap();
let claim = claims
.iter()
.zip(num_rounds.iter())
.zip(coeffs.iter())
.map(|((claim, &nr), coeff)| {
let scale = E::Scalar::from(2u64).pow_vartime([(num_rounds_max - nr) as u64]);
*claim * scale * coeff
})
.sum();
self.verify(claim, num_rounds_max, degree_bound, transcript)
}
/// For the product `A·B` restricted to its top variable, sums over the lower half
/// of the hypercube:
/// - the value at `x = 0`, and
/// - the `x²` coefficient `(A_hi - A_lo)·(B_hi - B_lo)`.
#[inline]
fn compute_eval_points_quad_prod(
  poly_A: &MultilinearPolynomial<E::Scalar>,
  poly_B: &MultilinearPolynomial<E::Scalar>,
) -> (E::Scalar, E::Scalar) {
  let half = poly_A.len() / 2;
  let term = |j: usize| {
    let (a_lo, a_hi) = (poly_A[j], poly_A[half + j]);
    let (b_lo, b_hi) = (poly_B[j], poly_B[half + j]);
    (a_lo * b_lo, (a_hi - a_lo) * (b_hi - b_lo))
  };
  (0..half)
    .into_par_iter()
    .map(term)
    .reduce(
      || (E::Scalar::ZERO, E::Scalar::ZERO),
      |(s0, s2), (t0, t2)| (s0 + t0, s2 + t2),
    )
}
/// Proves `claim = Σ_x A(x)·B(x)` with `num_rounds` rounds of sum-check.
///
/// Each round sends a degree-2 polynomial, built from its value at 0, its value at 1
/// (`claim - p(0)`) and its leading coefficient. Both polynomials are then bound
/// to the squeezed challenge.
///
/// Returns the proof, the challenges, and `[A(r), B(r)]` after all rounds.
///
/// # Errors
/// Propagates errors from `transcript.squeeze`.
pub fn prove_quad_prod(
  claim: &E::Scalar,
  num_rounds: usize,
  poly_A: &mut MultilinearPolynomial<E::Scalar>,
  poly_B: &mut MultilinearPolynomial<E::Scalar>,
  transcript: &mut E::TE,
) -> Result<(Self, Vec<E::Scalar>, Vec<E::Scalar>), NovaError> {
  let mut challenges: Vec<E::Scalar> = Vec::new();
  let mut round_polys: Vec<CompressedUniPoly<E::Scalar>> = Vec::new();
  let mut current_claim = *claim;
  for _round in 0..num_rounds {
    let (at_zero, leading) = Self::compute_eval_points_quad_prod(poly_A, poly_B);
    // p(0) + p(1) must equal the current claim, which determines p(1).
    let evals = vec![at_zero, current_claim - at_zero, leading];
    let round_poly = UniPoly::from_evals_deg2(&evals);

    transcript.absorb(b"p", &round_poly);
    let r_i = transcript.squeeze(b"c")?;
    challenges.push(r_i);
    round_polys.push(round_poly.compress());
    current_claim = round_poly.evaluate(&r_i);

    rayon::join(
      || poly_A.bind_poly_var_top(&r_i),
      || poly_B.bind_poly_var_top(&r_i),
    );
  }
  let final_evals = vec![poly_A[0], poly_B[0]];
  Ok((Self::new(round_polys), challenges, final_evals))
}
pub fn prove_batch_eval(
claims: &[E::Scalar],
num_rounds: &[usize],
mut polys: Vec<MultilinearPolynomial<E::Scalar>>,
eq_points: Vec<Vec<E::Scalar>>,
coeffs: &[E::Scalar],
transcript: &mut E::TE,
) -> Result<(Self, Vec<E::Scalar>, Vec<E::Scalar>), NovaError> {
let num_claims = claims.len();
assert_eq!(num_rounds.len(), num_claims);
assert_eq!(polys.len(), num_claims);
assert_eq!(eq_points.len(), num_claims);
assert_eq!(coeffs.len(), num_claims);
for (i, &nr) in num_rounds.iter().enumerate() {
assert_eq!(polys[i].len(), 1 << nr, "poly size mismatch at index {i}");
assert_eq!(
eq_points[i].len(),
nr,
"eq_point length mismatch at index {i}"
);
}
let num_rounds_max = *num_rounds.iter().max().unwrap();
let mut eq_instances: Vec<EqSumCheckInstance<E>> =
eq_points.into_iter().map(EqSumCheckInstance::new).collect();
let mut running_claims: Vec<E::Scalar> = claims.to_vec();
let mut e: E::Scalar = claims
.iter()
.zip(num_rounds.iter())
.zip(coeffs.iter())
.map(|((claim, &nr), coeff)| {
let scale = E::Scalar::from(2u64).pow_vartime([(num_rounds_max - nr) as u64]);
*claim * scale * coeff
})
.sum();
let mut r: Vec<E::Scalar> = Vec::new();
let mut quad_polys: Vec<CompressedUniPoly<E::Scalar>> = Vec::new();
for current_round in 0..num_rounds_max {
let remaining_rounds = num_rounds_max - current_round;
let evals: Vec<[E::Scalar; 3]> = (0..num_claims)
.into_par_iter()
.map(|i| {
if remaining_rounds <= num_rounds[i] {
let (eval_0, _cubic_coeff, eval_m1) = eq_instances[i]
.evaluation_points_quadratic_with_one_input(&polys[i], running_claims[i]);
[eval_0, E::Scalar::ZERO, eval_m1]
} else {
let remaining_variables = remaining_rounds - num_rounds[i] - 1;
let scaled_claim =
E::Scalar::from(2u64).pow_vartime([remaining_variables as u64]) * claims[i];
[scaled_claim, E::Scalar::ZERO, scaled_claim]
}
})
.collect();
let evals_combined_0: E::Scalar = (0..num_claims).map(|i| evals[i][0] * coeffs[i]).sum();
let evals_combined_m1: E::Scalar = (0..num_claims).map(|i| evals[i][2] * coeffs[i]).sum();
let evals_combined_1 = e - evals_combined_0;
let quad_coeff =
(evals_combined_1 + evals_combined_m1 - evals_combined_0.double()) * E::Scalar::TWO_INV;
let uni_evals = vec![evals_combined_0, evals_combined_1, quad_coeff];
let poly = UniPoly::from_evals_deg2(&uni_evals);
transcript.absorb(b"p", &poly);
let r_i = transcript.squeeze(b"c")?;
r.push(r_i);
for i in 0..num_claims {
if remaining_rounds <= num_rounds[i] {
running_claims[i] = Self::update_claim(running_claims[i], &evals[i], &r_i);
polys[i].bind_poly_var_top(&r_i);
eq_instances[i].bound(&r_i);
}
}
e = poly.evaluate(&r_i);
quad_polys.push(poly.compress());
}
polys.iter().for_each(|p| assert_eq!(p.len(), 1));
let poly_finals: Vec<E::Scalar> = polys.into_iter().map(|poly| poly[0]).collect();
Ok((SumcheckProof::new(quad_polys), r, poly_finals))
}
/// For `A - B` restricted to its top variable, sums over the lower half of the
/// hypercube the value at `x = 0` and the value at `x = -1`.
/// The linear extension at -1 is `2·lo - hi`.
#[inline]
pub fn compute_eval_points_linear(
  poly_A: &MultilinearPolynomial<E::Scalar>,
  poly_B: &MultilinearPolynomial<E::Scalar>,
) -> (E::Scalar, E::Scalar) {
  let half = poly_A.len() / 2;
  (0..half)
    .into_par_iter()
    .map(|j| {
      let (a_lo, a_hi) = (poly_A[j], poly_A[half + j]);
      let (b_lo, b_hi) = (poly_B[j], poly_B[half + j]);
      let at_minus_one = (a_lo.double() - a_hi) - (b_lo.double() - b_hi);
      (a_lo - b_lo, at_minus_one)
    })
    .reduce(
      || (E::Scalar::ZERO, E::Scalar::ZERO),
      |(x0, x1), (y0, y1)| (x0 + y0, x1 + y1),
    )
}
/// For `A·B` restricted to its top variable, sums over the lower half of the
/// hypercube the product at `x = 0` and the product at `x = -1`.
/// Each factor's value at -1 is `2·lo - hi`.
#[inline]
pub fn compute_eval_points_quadratic(
  poly_A: &MultilinearPolynomial<E::Scalar>,
  poly_B: &MultilinearPolynomial<E::Scalar>,
) -> (E::Scalar, E::Scalar) {
  let half = poly_A.len() / 2;
  (0..half)
    .into_par_iter()
    .map(|j| {
      let (a_lo, a_hi) = (poly_A[j], poly_A[half + j]);
      let (b_lo, b_hi) = (poly_B[j], poly_B[half + j]);
      let product_at_minus_one = (a_lo.double() - a_hi) * (b_lo.double() - b_hi);
      (a_lo * b_lo, product_at_minus_one)
    })
    .reduce(
      || (E::Scalar::ZERO, E::Scalar::ZERO),
      |(x0, x1), (y0, y1)| (x0 + y0, x1 + y1),
    )
}
/// For `A·B·C` restricted to its top variable, sums over the lower half of the
/// hypercube:
/// - the value at `x = 0`,
/// - the `x³` coefficient (product of the three slopes), and
/// - the value at `x = -1`.
#[inline]
pub fn compute_eval_points_cubic(
  poly_A: &MultilinearPolynomial<E::Scalar>,
  poly_B: &MultilinearPolynomial<E::Scalar>,
  poly_C: &MultilinearPolynomial<E::Scalar>,
) -> (E::Scalar, E::Scalar, E::Scalar) {
  let half = poly_A.len() / 2;
  (0..half)
    .into_par_iter()
    .map(|j| {
      // (value at 0, slope) of each factor along the top variable.
      let (a0, da) = (poly_A[j], poly_A[half + j] - poly_A[j]);
      let (b0, db) = (poly_B[j], poly_B[half + j] - poly_B[j]);
      let (c0, dc) = (poly_C[j], poly_C[half + j] - poly_C[j]);
      let at_zero = a0 * b0 * c0;
      let leading = da * db * dc;
      let at_minus_one = (a0 - da) * (b0 - db) * (c0 - dc);
      (at_zero, leading, at_minus_one)
    })
    .reduce(
      || (E::Scalar::ZERO, E::Scalar::ZERO, E::Scalar::ZERO),
      |(x0, x1, x2), (y0, y1, y2)| (x0 + y0, x1 + y1, x2 + y2),
    )
}
pub fn prove_cubic_with_three_inputs(
claim: &E::Scalar,
taus: Vec<E::Scalar>,
poly_A: &mut MultilinearPolynomial<E::Scalar>,
poly_B: &mut MultilinearPolynomial<E::Scalar>,
poly_C: &mut MultilinearPolynomial<E::Scalar>,
transcript: &mut E::TE,
) -> Result<(Self, Vec<E::Scalar>, Vec<E::Scalar>), NovaError> {
let mut r: Vec<E::Scalar> = Vec::new();
let mut polys: Vec<CompressedUniPoly<E::Scalar>> = Vec::new();
let mut claim_per_round = *claim;
let num_rounds = taus.len();
let mut eq_instance = EqSumCheckInstance::<E>::new(taus);
for _ in 0..num_rounds {
let poly = {
let (eval_point_0, eval_point_bound_coeff, eval_point_inf) = eq_instance
.evaluation_points_cubic_with_three_inputs(poly_A, poly_B, poly_C, claim_per_round);
let evals = vec![
eval_point_0,
claim_per_round - eval_point_0,
eval_point_bound_coeff,
eval_point_inf,
];
UniPoly::from_evals_deg3(&evals)
};
transcript.absorb(b"p", &poly);
let r_i = transcript.squeeze(b"c")?;
r.push(r_i);
polys.push(poly.compress());
claim_per_round = poly.evaluate(&r_i);
rayon::join(
|| poly_A.bind_poly_var_top(&r_i),
|| poly_B.bind_poly_var_top(&r_i),
);
rayon::join(
|| poly_C.bind_poly_var_top(&r_i),
|| eq_instance.bound(&r_i),
);
}
Ok((
SumcheckProof {
compressed_polys: polys,
},
r,
vec![poly_A[0], poly_B[0], poly_C[0]],
))
}
pub fn prove_batched_cubic(
claim: &E::Scalar,
taus: Vec<E::Scalar>,
polys_A: &mut [MultilinearPolynomial<E::Scalar>],
polys_B: &mut [MultilinearPolynomial<E::Scalar>],
polys_C: &mut [MultilinearPolynomial<E::Scalar>],
alphas: &[E::Scalar],
transcript: &mut E::TE,
) -> Result<(Self, Vec<E::Scalar>, Vec<Vec<E::Scalar>>), NovaError> {
let k = polys_A.len();
if k == 0 {
return Err(NovaError::InvalidNumInstances);
}
assert_eq!(k, polys_B.len());
assert_eq!(k, polys_C.len());
assert_eq!(k, alphas.len());
let mut r: Vec<E::Scalar> = Vec::new();
let mut polys: Vec<CompressedUniPoly<E::Scalar>> = Vec::new();
let mut claim_per_round = *claim;
let num_rounds = taus.len();
let mut eq_instance = EqSumCheckInstance::<E>::new(taus);
for _ in 0..num_rounds {
let poly = {
let (eval_point_0, eval_point_bound_coeff, eval_point_inf) = eq_instance
.evaluation_points_batched_cubic(polys_A, polys_B, polys_C, alphas, claim_per_round);
let evals = vec![
eval_point_0,
claim_per_round - eval_point_0,
eval_point_bound_coeff,
eval_point_inf,
];
UniPoly::from_evals_deg3(&evals)
};
transcript.absorb(b"p", &poly);
let r_i = transcript.squeeze(b"c")?;
r.push(r_i);
polys.push(poly.compress());
claim_per_round = poly.evaluate(&r_i);
polys_A
.par_iter_mut()
.chain(polys_B.par_iter_mut())
.chain(polys_C.par_iter_mut())
.for_each(|p| p.bind_poly_var_top(&r_i));
eq_instance.bound(&r_i);
}
let claims: Vec<Vec<E::Scalar>> = (0..k)
.map(|i| vec![polys_A[i][0], polys_B[i][0], polys_C[i][0]])
.collect();
Ok((
SumcheckProof {
compressed_polys: polys,
},
r,
claims,
))
}
}
pub mod eq_sumcheck {
use crate::{spartan::polys::multilinear::MultilinearPolynomial, traits::Engine};
use ff::Field;
use rayon::{iter::ZipEq, prelude::*, slice::Iter};
/// Round-by-round helper for sum-checks whose summand carries a factor
/// `eq(taus, x)`.
///
/// The eq factor is split into three parts:
/// - the variables already bound (`eval_eq_left`),
/// - the current variable (`eq_tau_0_a_inf`), and
/// - precomputed tables for the remaining left-half and right-half variables.
pub struct EqSumCheckInstance<E: Engine> {
  /// The eq point; `taus[j]` is paired with the variable bound in round `j + 1`.
  taus: Vec<E::Scalar>,
  /// `taus.len()` at construction time.
  init_num_vars: usize,
  /// `init_num_vars / 2`: size of the left half of the variables.
  first_half: usize,
  /// `init_num_vars - first_half`: size of the right half of the variables.
  second_half: usize,
  /// 1-based index of the current round; starts at 1 and is advanced by `bound`.
  round: usize,
  /// Product of `eq(taus[j], r_j)` over the variables bound so far (starts at one).
  eval_eq_left: E::Scalar,
  /// `poly_eq_left[j]`: eq table over the last `j` taus of the left half.
  /// The left half's first tau is never included.
  poly_eq_left: Vec<Vec<E::Scalar>>,
  /// `poly_eq_right[j]`: eq table over the last `j` taus of the right half.
  poly_eq_right: Vec<Vec<E::Scalar>>,
  /// Per tau: `(eq(tau, 0), slope of eq(tau, x) in x, eq(tau, -1))`.
  eq_tau_0_a_inf: Vec<(E::Scalar, E::Scalar, E::Scalar)>,
}
impl<E: Engine> EqSumCheckInstance<E> {
/// Builds the helper for the eq point `taus`, positioned at round 1.
pub fn new(taus: Vec<E::Scalar>) -> Self {
  let num_vars = taus.len();
  let first_half = num_vars / 2;

  // Given [s_0, s_1, ...], returns tables where tables[0] = [1] and tables[j + 1]
  // extends tables[j] with s_j as the new most significant index bit:
  // - low half  = prev·(1 - s_j)
  // - high half = prev·s_j
  let eq_tables = |ts: Vec<&E::Scalar>| -> Vec<Vec<E::Scalar>> {
    let mut tables: Vec<Vec<E::Scalar>> = Vec::with_capacity(ts.len() + 1);
    tables.push(vec![E::Scalar::ONE]);
    for tau in ts {
      let prev = &tables[tables.len() - 1];
      let high: Vec<E::Scalar> = prev.par_iter().map(|v| *v * tau).collect();
      let mut next: Vec<E::Scalar> = prev
        .par_iter()
        .zip(high.par_iter())
        .map(|(v, h)| *v - *h)
        .collect();
      next.extend(high);
      tables.push(next);
    }
    tables
  };

  let (left, right) = taus.split_at(first_half);
  // Reversing makes the earliest tau of each suffix the top bit. The left half's
  // first tau is never needed in a left table, so it is dropped.
  let left_rev: Vec<&E::Scalar> = left.iter().skip(1).rev().collect();
  let right_rev: Vec<&E::Scalar> = right.iter().rev().collect();
  let (poly_eq_left, poly_eq_right) =
    rayon::join(|| eq_tables(left_rev), || eq_tables(right_rev));

  // eq(tau, x) = (1 - tau) + (2·tau - 1)·x: keep its value at 0, slope, and value at -1.
  let eq_tau_0_a_inf: Vec<(E::Scalar, E::Scalar, E::Scalar)> = taus
    .par_iter()
    .map(|tau| {
      let at_zero = E::Scalar::ONE - tau;
      let slope = *tau - at_zero;
      (at_zero, slope, at_zero - slope)
    })
    .collect();

  Self {
    init_num_vars: num_vars,
    first_half,
    second_half: num_vars - first_half,
    round: 1,
    taus,
    eval_eq_left: E::Scalar::ONE,
    poly_eq_left,
    poly_eq_right,
    eq_tau_0_a_inf,
  }
}
/// Derives `(s(0), x³ coefficient of s, s(-1))` for
/// `s(x) = eval_eq_left · eq(tau_round, x) · t(x)`.
///
/// Here `t` is quadratic in the current variable, given by `t_0 = t(0)` and its
/// `x²` coefficient `t_inf`. `claim = s(0) + s(1)` is used to recover `t(1)`.
///
/// Returns `None` when `eval_eq_left · eq(tau_round, 1)` is not invertible.
#[inline]
fn derive_from_claim_deg2(
  &self,
  t_0: E::Scalar,
  t_inf: E::Scalar,
  claim: E::Scalar,
) -> Option<(E::Scalar, E::Scalar, E::Scalar)> {
  let (eq_0, eq_slope, eq_m1) = self.eq_tau_0_a_inf[self.round - 1];
  let scale = self.eval_eq_left;
  let s_0 = eq_0 * scale * t_0;
  let factor_at_one = (eq_0 + eq_slope) * scale;
  let factor_at_one_inv: Option<E::Scalar> = factor_at_one.invert().into();
  factor_at_one_inv.map(|inv| {
    let t_1 = (claim - s_0) * inv;
    // For quadratic t: t(-1) = 2·t(0) + 2·t_inf - t(1).
    let t_m1 = (t_0 + t_inf).double() - t_1;
    (s_0, eq_slope * scale * t_inf, eq_m1 * scale * t_m1)
  })
}
/// Linear-`t` variant of `derive_from_claim_deg2`.
///
/// A linear `t` is a quadratic with zero `x²` coefficient. With `t_inf = 0` the
/// derived `x³` coefficient is zero and `t(-1) = 2·t(0) - t(1)`, exactly the
/// linear case. Returns `None` under the same invertibility condition.
#[inline]
fn derive_from_claim_deg1(
  &self,
  t_0: E::Scalar,
  claim: E::Scalar,
) -> Option<(E::Scalar, E::Scalar, E::Scalar)> {
  self.derive_from_claim_deg2(t_0, E::Scalar::ZERO, claim)
}
/// Computes `(s(0), x³ coefficient of s, s(-1))` for this round of
/// `Σ_x eq(taus, x)·Σ_i alphas[i]·(A_i(x)·B_i(x) - C_i(x))`.
///
/// Falls back to computing `s(-1)` directly when it cannot be derived from `claim`.
///
/// # Panics
/// If the inputs are empty, their lengths disagree, or `polys_A[0]` has odd length.
#[inline]
pub fn evaluation_points_batched_cubic(
  &self,
  polys_A: &[MultilinearPolynomial<E::Scalar>],
  polys_B: &[MultilinearPolynomial<E::Scalar>],
  polys_C: &[MultilinearPolynomial<E::Scalar>],
  alphas: &[E::Scalar],
  claim: E::Scalar,
) -> (E::Scalar, E::Scalar, E::Scalar) {
  let k = polys_A.len();
  assert!(k > 0);
  assert_eq!(k, polys_B.len());
  assert_eq!(k, polys_C.len());
  assert_eq!(k, alphas.len());
  assert_eq!(polys_A[0].len() % 2, 0);
  let half_p = polys_A[0].Z.len() / 2;

  // At hypercube point `id`, returns the alpha-weighted sums of:
  // - A·B - C at x = 0, and
  // - the x² coefficient of A·B.
  let combined_at = |id: usize| -> (E::Scalar, E::Scalar) {
    let mut at_zero = E::Scalar::ZERO;
    let mut quad = E::Scalar::ZERO;
    for (((a, b), c), alpha) in polys_A.iter().zip(polys_B).zip(polys_C).zip(alphas) {
      let (a_lo, a_hi) = (a.Z[id], a.Z[id + half_p]);
      let (b_lo, b_hi) = (b.Z[id], b.Z[id + half_p]);
      at_zero += *alpha * (a_lo * b_lo - c.Z[id]);
      quad += *alpha * (a_hi - a_lo) * (b_hi - b_lo);
    }
    (at_zero, quad)
  };
  let add = |x: (E::Scalar, E::Scalar), y: (E::Scalar, E::Scalar)| (x.0 + y.0, x.1 + y.1);

  let (t_0, t_inf) = if self.round < self.first_half {
    let (poly_eq_left, poly_eq_right, second_half, low_mask) = self.poly_eqs_first_half();
    (0..half_p)
      .into_par_iter()
      .map(|id| {
        let weight = poly_eq_left[id >> second_half] * poly_eq_right[id & low_mask];
        let (at_zero, quad) = combined_at(id);
        (at_zero * weight, quad * weight)
      })
      .reduce(|| (E::Scalar::ZERO, E::Scalar::ZERO), add)
  } else {
    let poly_eq_right = self.poly_eq_right_last_half();
    (0..half_p)
      .into_par_iter()
      .map(|id| {
        let weight = poly_eq_right[id];
        let (at_zero, quad) = combined_at(id);
        (at_zero * weight, quad * weight)
      })
      .reduce(|| (E::Scalar::ZERO, E::Scalar::ZERO), add)
  };

  self
    .derive_from_claim_deg2(t_0, t_inf, claim)
    .unwrap_or_else(|| {
      self.fallback_eval_inf_batched_cubic(t_0, t_inf, polys_A, polys_B, polys_C, alphas)
    })
}
/// Computes the batched-cubic round triple `(s(0), x³ coefficient, s(-1))`
/// without using the claim.
///
/// `t(-1)` is evaluated directly from the polynomials, each at `2·lo - hi`.
#[inline]
fn fallback_eval_inf_batched_cubic(
  &self,
  t_0: E::Scalar,
  t_inf: E::Scalar,
  polys_A: &[MultilinearPolynomial<E::Scalar>],
  polys_B: &[MultilinearPolynomial<E::Scalar>],
  polys_C: &[MultilinearPolynomial<E::Scalar>],
  alphas: &[E::Scalar],
) -> (E::Scalar, E::Scalar, E::Scalar) {
  let scale = self.eval_eq_left;
  let (eq_0, eq_slope, eq_m1) = self.eq_tau_0_a_inf[self.round - 1];
  let k = polys_A.len();
  let half_p = polys_A[0].Z.len() / 2;

  let at_minus_one =
    |p: &MultilinearPolynomial<E::Scalar>, id: usize| p.Z[id].double() - p.Z[id + half_p];
  // Alpha-weighted A·B - C at x = -1 for hypercube point `id`.
  let combined_m1 = |id: usize| -> E::Scalar {
    (0..k).fold(E::Scalar::ZERO, |acc, i| {
      let a = at_minus_one(&polys_A[i], id);
      let b = at_minus_one(&polys_B[i], id);
      let c = at_minus_one(&polys_C[i], id);
      acc + alphas[i] * (a * b - c)
    })
  };

  let t_m1 = if self.round < self.first_half {
    let (poly_eq_left, poly_eq_right, second_half, low_mask) = self.poly_eqs_first_half();
    (0..half_p)
      .into_par_iter()
      .map(|id| {
        let weight = poly_eq_left[id >> second_half] * poly_eq_right[id & low_mask];
        combined_m1(id) * weight
      })
      .reduce(|| E::Scalar::ZERO, |x, y| x + y)
  } else {
    let poly_eq_right = self.poly_eq_right_last_half();
    (0..half_p)
      .into_par_iter()
      .map(|id| combined_m1(id) * poly_eq_right[id])
      .reduce(|| E::Scalar::ZERO, |x, y| x + y)
  };

  (
    eq_0 * scale * t_0,
    eq_slope * scale * t_inf,
    eq_m1 * scale * t_m1,
  )
}
/// Computes `(s(0), x³ coefficient of s, s(-1))` for this round of
/// `Σ_x eq(taus, x)·(A(x)·B(x) - C(x))`.
///
/// Falls back to computing `s(-1)` directly when it cannot be derived from `claim`.
#[inline]
pub fn evaluation_points_cubic_with_three_inputs(
  &self,
  poly_A: &MultilinearPolynomial<E::Scalar>,
  poly_B: &MultilinearPolynomial<E::Scalar>,
  poly_C: &MultilinearPolynomial<E::Scalar>,
  claim: E::Scalar,
) -> (E::Scalar, E::Scalar, E::Scalar) {
  debug_assert_eq!(poly_A.len() % 2, 0);
  let half_p = poly_A.Z.len() / 2;
  let [zip_A, zip_B, zip_C] = split_and_zip([&poly_A.Z, &poly_B.Z, &poly_C.Z], half_p);

  let (t_0, t_inf) = if self.round < self.first_half {
    let (poly_eq_left, poly_eq_right, second_half, low_mask) = self.poly_eqs_first_half();
    zip_A
      .zip_eq(zip_B)
      .zip_eq(zip_C)
      .enumerate()
      .map(|(id, (((a_lo, a_hi), (b_lo, b_hi)), (c_lo, _)))| {
        let weight = poly_eq_left[id >> second_half] * poly_eq_right[id & low_mask];
        let at_zero = *a_lo * *b_lo - *c_lo;
        let quad = (*a_hi - *a_lo) * (*b_hi - *b_lo);
        (at_zero * weight, quad * weight)
      })
      .reduce(
        || (E::Scalar::ZERO, E::Scalar::ZERO),
        |x, y| (x.0 + y.0, x.1 + y.1),
      )
  } else {
    let poly_eq_right = self.poly_eq_right_last_half().par_iter();
    zip_A
      .zip_eq(zip_B)
      .zip_eq(zip_C)
      .zip_eq(poly_eq_right)
      .map(|((((a_lo, a_hi), (b_lo, b_hi)), (c_lo, _)), eq_r)| {
        let at_zero = *a_lo * *b_lo - *c_lo;
        let quad = (*a_hi - *a_lo) * (*b_hi - *b_lo);
        (at_zero * eq_r, quad * eq_r)
      })
      .reduce(
        || (E::Scalar::ZERO, E::Scalar::ZERO),
        |x, y| (x.0 + y.0, x.1 + y.1),
      )
  };

  self
    .derive_from_claim_deg2(t_0, t_inf, claim)
    .unwrap_or_else(|| self.fallback_eval_inf_three_inputs(t_0, t_inf, poly_A, poly_B, poly_C))
}
/// Computes `(s(0), x³ coefficient of s, s(-1))` for this round of
/// `Σ_x eq(taus, x)·(A(x)·B(x) - 1)`.
///
/// Falls back to computing `s(-1)` directly when it cannot be derived from `claim`.
#[inline]
pub fn evaluation_points_cubic_with_two_inputs(
  &self,
  poly_A: &MultilinearPolynomial<E::Scalar>,
  poly_B: &MultilinearPolynomial<E::Scalar>,
  claim: E::Scalar,
) -> (E::Scalar, E::Scalar, E::Scalar) {
  debug_assert_eq!(poly_A.len() % 2, 0);
  let half_p = poly_A.Z.len() / 2;
  let [zip_A, zip_B] = split_and_zip([&poly_A.Z, &poly_B.Z], half_p);

  let (t_0, t_inf) = if self.round < self.first_half {
    let (poly_eq_left, poly_eq_right, second_half, low_mask) = self.poly_eqs_first_half();
    zip_A
      .zip_eq(zip_B)
      .enumerate()
      .map(|(id, ((a_lo, a_hi), (b_lo, b_hi)))| {
        let weight = poly_eq_left[id >> second_half] * poly_eq_right[id & low_mask];
        let at_zero = *a_lo * *b_lo - E::Scalar::ONE;
        let quad = (*a_hi - *a_lo) * (*b_hi - *b_lo);
        (at_zero * weight, quad * weight)
      })
      .reduce(
        || (E::Scalar::ZERO, E::Scalar::ZERO),
        |x, y| (x.0 + y.0, x.1 + y.1),
      )
  } else {
    let poly_eq_right = self.poly_eq_right_last_half().par_iter();
    zip_A
      .zip_eq(zip_B)
      .zip_eq(poly_eq_right)
      .map(|(((a_lo, a_hi), (b_lo, b_hi)), eq_r)| {
        let at_zero = *a_lo * *b_lo - E::Scalar::ONE;
        let quad = (*a_hi - *a_lo) * (*b_hi - *b_lo);
        (at_zero * eq_r, quad * eq_r)
      })
      .reduce(
        || (E::Scalar::ZERO, E::Scalar::ZERO),
        |x, y| (x.0 + y.0, x.1 + y.1),
      )
  };

  self
    .derive_from_claim_deg2(t_0, t_inf, claim)
    .unwrap_or_else(|| self.fallback_eval_inf_two_inputs(t_0, t_inf, poly_A, poly_B))
}
/// Computes `(s(0), 0, s(-1))` for this round of `Σ_x eq(taus, x)·A(x)`.
///
/// The round polynomial is at most quadratic, so the `x³` slot is zero. Falls back
/// to computing `s(-1)` directly when it cannot be derived from `claim`.
#[inline]
pub fn evaluation_points_quadratic_with_one_input(
  &self,
  poly_A: &MultilinearPolynomial<E::Scalar>,
  claim: E::Scalar,
) -> (E::Scalar, E::Scalar, E::Scalar) {
  debug_assert_eq!(poly_A.len() % 2, 0);
  let half_p = poly_A.Z.len() / 2;
  let [zip_A] = split_and_zip([&poly_A.Z], half_p);

  let t_0 = if self.round < self.first_half {
    let (poly_eq_left, poly_eq_right, second_half, low_mask) = self.poly_eqs_first_half();
    zip_A
      .enumerate()
      .map(|(id, (a_lo, _))| {
        let weight = poly_eq_left[id >> second_half] * poly_eq_right[id & low_mask];
        *a_lo * weight
      })
      .reduce(|| E::Scalar::ZERO, |x, y| x + y)
  } else {
    let poly_eq_right = self.poly_eq_right_last_half().par_iter();
    zip_A
      .zip_eq(poly_eq_right)
      .map(|((a_lo, _), eq_r)| *a_lo * eq_r)
      .reduce(|| E::Scalar::ZERO, |x, y| x + y)
  };

  self
    .derive_from_claim_deg1(t_0, claim)
    .unwrap_or_else(|| self.fallback_eval_inf_one_input(t_0, poly_A))
}
/// Computes the three-input round triple `(s(0), x³ coefficient, s(-1))` without
/// using the claim.
///
/// `t(-1) = Σ eq·(A(-1)·B(-1) - C(-1))` is evaluated directly, each factor at `2·lo - hi`.
#[inline]
fn fallback_eval_inf_three_inputs(
  &self,
  t_0: E::Scalar,
  t_inf: E::Scalar,
  poly_A: &MultilinearPolynomial<E::Scalar>,
  poly_B: &MultilinearPolynomial<E::Scalar>,
  poly_C: &MultilinearPolynomial<E::Scalar>,
) -> (E::Scalar, E::Scalar, E::Scalar) {
  let scale = self.eval_eq_left;
  let (eq_0, eq_slope, eq_m1) = self.eq_tau_0_a_inf[self.round - 1];
  let half_p = poly_A.Z.len() / 2;
  let [zip_A, zip_B, zip_C] = split_and_zip([&poly_A.Z, &poly_B.Z, &poly_C.Z], half_p);

  let t_m1 = if self.round < self.first_half {
    let (poly_eq_left, poly_eq_right, second_half, low_mask) = self.poly_eqs_first_half();
    zip_A
      .zip_eq(zip_B)
      .zip_eq(zip_C)
      .enumerate()
      .map(|(id, (((a_lo, a_hi), (b_lo, b_hi)), (c_lo, c_hi)))| {
        let weight = poly_eq_left[id >> second_half] * poly_eq_right[id & low_mask];
        let inner =
          (a_lo.double() - *a_hi) * (b_lo.double() - *b_hi) - (c_lo.double() - *c_hi);
        inner * weight
      })
      .reduce(|| E::Scalar::ZERO, |x, y| x + y)
  } else {
    let poly_eq_right = self.poly_eq_right_last_half().par_iter();
    zip_A
      .zip_eq(zip_B)
      .zip_eq(zip_C)
      .zip_eq(poly_eq_right)
      .map(|((((a_lo, a_hi), (b_lo, b_hi)), (c_lo, c_hi)), eq_r)| {
        let inner =
          (a_lo.double() - *a_hi) * (b_lo.double() - *b_hi) - (c_lo.double() - *c_hi);
        inner * eq_r
      })
      .reduce(|| E::Scalar::ZERO, |x, y| x + y)
  };

  (
    eq_0 * scale * t_0,
    eq_slope * scale * t_inf,
    eq_m1 * scale * t_m1,
  )
}
/// Computes the two-input round triple `(s(0), x³ coefficient, s(-1))` without
/// using the claim.
///
/// `t(-1) = Σ eq·(A(-1)·B(-1) - 1)` is evaluated directly, each factor at `2·lo - hi`.
#[inline]
fn fallback_eval_inf_two_inputs(
  &self,
  t_0: E::Scalar,
  t_inf: E::Scalar,
  poly_A: &MultilinearPolynomial<E::Scalar>,
  poly_B: &MultilinearPolynomial<E::Scalar>,
) -> (E::Scalar, E::Scalar, E::Scalar) {
  let scale = self.eval_eq_left;
  let (eq_0, eq_slope, eq_m1) = self.eq_tau_0_a_inf[self.round - 1];
  let half_p = poly_A.Z.len() / 2;
  let [zip_A, zip_B] = split_and_zip([&poly_A.Z, &poly_B.Z], half_p);

  let t_m1 = if self.round < self.first_half {
    let (poly_eq_left, poly_eq_right, second_half, low_mask) = self.poly_eqs_first_half();
    zip_A
      .zip_eq(zip_B)
      .enumerate()
      .map(|(id, ((a_lo, a_hi), (b_lo, b_hi)))| {
        let weight = poly_eq_left[id >> second_half] * poly_eq_right[id & low_mask];
        let inner = (a_lo.double() - *a_hi) * (b_lo.double() - *b_hi) - E::Scalar::ONE;
        inner * weight
      })
      .reduce(|| E::Scalar::ZERO, |x, y| x + y)
  } else {
    let poly_eq_right = self.poly_eq_right_last_half().par_iter();
    zip_A
      .zip_eq(zip_B)
      .zip_eq(poly_eq_right)
      .map(|(((a_lo, a_hi), (b_lo, b_hi)), eq_r)| {
        let inner = (a_lo.double() - *a_hi) * (b_lo.double() - *b_hi) - E::Scalar::ONE;
        inner * eq_r
      })
      .reduce(|| E::Scalar::ZERO, |x, y| x + y)
  };

  (
    eq_0 * scale * t_0,
    eq_slope * scale * t_inf,
    eq_m1 * scale * t_m1,
  )
}
/// Computes the one-input round triple `(s(0), 0, s(-1))` without using the claim.
///
/// `t(-1) = Σ eq·A(-1)` is evaluated directly, with `A(-1) = 2·lo - hi`.
#[inline]
fn fallback_eval_inf_one_input(
  &self,
  t_0: E::Scalar,
  poly_A: &MultilinearPolynomial<E::Scalar>,
) -> (E::Scalar, E::Scalar, E::Scalar) {
  let scale = self.eval_eq_left;
  let (eq_0, _eq_slope, eq_m1) = self.eq_tau_0_a_inf[self.round - 1];
  let half_p = poly_A.Z.len() / 2;
  let [zip_A] = split_and_zip([&poly_A.Z], half_p);

  let t_m1 = if self.round < self.first_half {
    let (poly_eq_left, poly_eq_right, second_half, low_mask) = self.poly_eqs_first_half();
    zip_A
      .enumerate()
      .map(|(id, (a_lo, a_hi))| {
        let weight = poly_eq_left[id >> second_half] * poly_eq_right[id & low_mask];
        (a_lo.double() - *a_hi) * weight
      })
      .reduce(|| E::Scalar::ZERO, |x, y| x + y)
  } else {
    let poly_eq_right = self.poly_eq_right_last_half().par_iter();
    zip_A
      .zip_eq(poly_eq_right)
      .map(|((a_lo, a_hi), eq_r)| (a_lo.double() - *a_hi) * eq_r)
      .reduce(|| E::Scalar::ZERO, |x, y| x + y)
  };

  (eq_0 * scale * t_0, E::Scalar::ZERO, eq_m1 * scale * t_m1)
}
/// Fixes the current variable to `r` and advances to the next round.
///
/// `eq(tau, r) = (1 - tau)(1 - r) + tau·r` is folded into the running product.
#[inline]
pub fn bound(&mut self, r: &E::Scalar) {
  let tau = self.taus[self.round - 1];
  let eq_at_r = E::Scalar::ONE - tau - r + (*r * tau).double();
  self.eval_eq_left = self.eval_eq_left * eq_at_r;
  self.round += 1;
}
/// Eq tables for a round inside the left half.
///
/// Returns `(left, right, shift, mask)`, where:
/// - `left` covers the left-half variables still unbound after this round,
/// - `right` is the full right-half table, and
/// - `shift`/`mask` split a hypercube index into its left (high bits) and right
///   (low bits) parts.
#[inline]
fn poly_eqs_first_half(&self) -> (&Vec<E::Scalar>, &Vec<E::Scalar>, usize, usize) {
  let shift = self.second_half;
  let left = &self.poly_eq_left[self.first_half - self.round];
  let right = &self.poly_eq_right[shift];
  debug_assert_eq!(right.len(), 1 << shift);
  let low_mask = (1 << shift) - 1;
  (left, right, shift, low_mask)
}
/// Eq table over the right-half variables still unbound after the current round.
/// Used once the left half has been exhausted.
#[inline]
fn poly_eq_right_last_half(&self) -> &Vec<E::Scalar> {
  let remaining = self.init_num_vars - self.round;
  &self.poly_eq_right[remaining]
}
}
/// Splits each of the `N` slices at `half_size` and pairs element `j` of the lower
/// half with element `j` of the upper half, as parallel iterators.
///
/// Panics (inside rayon's `zip_eq`) if the two halves differ in length.
#[inline]
fn split_and_zip<const N: usize, T: Sync>(
  vec: [&[T]; N],
  half_size: usize,
) -> [ZipEq<Iter<'_, T>, Iter<'_, T>>; N] {
  vec.map(|slice| {
    let (lo, hi) = slice.split_at(half_size);
    lo.par_iter().zip_eq(hi.par_iter())
  })
}
}