use p3_challenger::{FieldChallenger, GrindingChallenger};
use p3_field::{Algebra, ExtensionField, Field, PrimeCharacteristicRing};
use p3_maybe_rayon::prelude::*;
use p3_multilinear_util::point::Point;
use p3_multilinear_util::poly::Poly;
use crate::constraints::Constraint;
use crate::product_polynomial::ProductPolynomial;
use crate::{SumcheckData, extrapolate_01inf};
/// Input-size threshold (2^14) above which the coefficient kernels use parallel iteration.
const PAR_THRESHOLD: usize = 16_384;
/// Number of pairs handled per unrolled chunk in the coefficient kernels.
const K: usize = 8;
/// Processes `K` pairs at once and returns their combined `(c0, c_inf)` contribution.
///
/// Entry `j` of `e_lo`/`w_lo` holds the `X = 0` values and entry `j` of `e_hi`/`w_hi` holds the
/// `X = 1` values. The result is
/// `(Σ w_lo[j] · e_lo[j], Σ (w_hi[j] - w_lo[j]) · (e_hi[j] - e_lo[j]))`. Both sums are computed
/// with `Algebra::mixed_dot_product`.
#[inline(always)]
fn chunk_round_step<B, A>(e_lo: &[B; K], e_hi: &[B; K], w_lo: &[A; K], w_hi: &[A; K]) -> (A, A)
where
    B: PrimeCharacteristicRing + Copy,
    A: Algebra<B> + Copy,
{
    let eval_slopes: [B; K] = core::array::from_fn(|j| e_hi[j] - e_lo[j]);
    let weight_slopes: [A; K] = core::array::from_fn(|j| w_hi[j] - w_lo[j]);
    (
        A::mixed_dot_product::<K>(w_lo, e_lo),
        A::mixed_dot_product::<K>(&weight_slopes, &eval_slopes),
    )
}
/// Adds one pair's contribution to the running `(c0, c_inf)` sums.
///
/// `(e0, w0)` are the values at `X = 0` and `(e1, w1)` the values at `X = 1`. The pair adds
/// `w0 · e0` to the constant term and `(w1 - w0) · (e1 - e0)` to the `X²` term.
#[inline(always)]
fn round_step<B, A>(acc: (A, A), e0: B, e1: B, w0: A, w1: A) -> (A, A)
where
    B: PrimeCharacteristicRing + Copy,
    A: Algebra<B> + Copy,
{
    let (constant, quadratic) = acc;
    let eval_slope = e1 - e0;
    let weight_slope = w1 - w0;
    (constant + w0 * e0, quadratic + weight_slope * eval_slope)
}
/// Adds two `(c0, c_inf)` partial results component-wise.
#[inline(always)]
fn round_reduce<A>((lhs0, lhs_inf): (A, A), (rhs0, rhs_inf): (A, A)) -> (A, A)
where
    A: Copy + PrimeCharacteristicRing,
{
    (lhs0 + rhs0, lhs_inf + rhs_inf)
}
/// Computes the sumcheck round coefficients `(c0, c_inf)` when the tables are split into a low
/// half and a high half.
///
/// Entry `i` of the low half (`X = 0`) is paired with entry `i` of the high half (`X = 1`).
/// Summed over all pairs, this produces:
/// - `c0 = Σ w_lo · e_lo`, the constant term of `Σ e(X) · w(X)`;
/// - `c_inf = Σ (w_hi - w_lo) · (e_hi - e_lo)`, its `X²` coefficient.
///
/// Full groups of `K` pairs go through the unrolled kernel. That kernel runs in parallel when
/// there are more than `PAR_THRESHOLD` pairs. Leftover pairs are accumulated one at a time.
///
/// # Panics
///
/// Panics if `evals` and `weights` have different lengths or an odd length.
pub fn sumcheck_coefficients_prefix<B, A>(evals: &[B], weights: &[A]) -> (A, A)
where
    B: PrimeCharacteristicRing + Copy + Send + Sync,
    A: Algebra<B> + Copy + Send + Sync,
{
    // Folds one group of `K` pairs, given as four `K`-long slices, into the running sums.
    #[inline(always)]
    fn fold_chunk<E, W>(
        acc: (W, W),
        ((e_lo_c, e_hi_c), (w_lo_c, w_hi_c)): ((&[E], &[E]), (&[W], &[W])),
    ) -> (W, W)
    where
        E: PrimeCharacteristicRing + Copy,
        W: Algebra<E> + Copy,
    {
        let contribution = chunk_round_step::<E, W>(
            e_lo_c.try_into().unwrap(),
            e_hi_c.try_into().unwrap(),
            w_lo_c.try_into().unwrap(),
            w_hi_c.try_into().unwrap(),
        );
        round_reduce(acc, contribution)
    }

    assert_eq!(evals.len(), weights.len());
    assert!(evals.len().is_multiple_of(2));

    let n_pairs = evals.len() / 2;
    let (evals_lo, evals_hi) = evals.split_at(n_pairs);
    let (weights_lo, weights_hi) = weights.split_at(n_pairs);

    // Largest multiple of `K` not exceeding `n_pairs`; pairs beyond it form the scalar rest.
    let split = n_pairs - n_pairs % K;
    let (e_lo_body, e_lo_rest) = evals_lo.split_at(split);
    let (e_hi_body, e_hi_rest) = evals_hi.split_at(split);
    let (w_lo_body, w_lo_rest) = weights_lo.split_at(split);
    let (w_hi_body, w_hi_rest) = weights_hi.split_at(split);

    let body_sum: (A, A) = if n_pairs > PAR_THRESHOLD {
        let eval_chunks = e_lo_body
            .par_chunks_exact(K)
            .zip(e_hi_body.par_chunks_exact(K));
        let weight_chunks = w_lo_body
            .par_chunks_exact(K)
            .zip(w_hi_body.par_chunks_exact(K));
        eval_chunks.zip(weight_chunks).par_fold_reduce(
            || (A::ZERO, A::ZERO),
            fold_chunk::<B, A>,
            round_reduce,
        )
    } else {
        let eval_chunks = e_lo_body.chunks_exact(K).zip(e_hi_body.chunks_exact(K));
        let weight_chunks = w_lo_body.chunks_exact(K).zip(w_hi_body.chunks_exact(K));
        eval_chunks
            .zip(weight_chunks)
            .fold((A::ZERO, A::ZERO), fold_chunk::<B, A>)
    };

    let mut rest_sum = (A::ZERO, A::ZERO);
    let rest_pairs = e_lo_rest
        .iter()
        .zip(e_hi_rest)
        .zip(w_lo_rest.iter().zip(w_hi_rest));
    for ((&e0, &e1), (&w0, &w1)) in rest_pairs {
        rest_sum = round_step(rest_sum, e0, e1, w0, w1);
    }

    round_reduce(body_sum, rest_sum)
}
/// Computes the sumcheck round coefficients `(c0, c_inf)` when pairs are adjacent entries.
///
/// `(v[2i], v[2i + 1])` hold the `X = 0` and `X = 1` values of pair `i`. Summed over all
/// pairs, this produces:
/// - `c0 = Σ w0 · e0`, the constant term of `Σ e(X) · w(X)`;
/// - `c_inf = Σ (w1 - w0) · (e1 - e0)`, its `X²` coefficient.
///
/// Full groups of `K` pairs are de-interleaved and fed to the same unrolled kernel
/// (`chunk_round_step`) that the prefix kernel uses. That kernel runs in parallel when there
/// are more than `PAR_THRESHOLD` pairs. Leftover pairs are accumulated one at a time.
///
/// # Panics
///
/// Panics if `evals` and `weights` have different lengths or an odd length.
pub fn sumcheck_coefficients_suffix<B, A>(evals: &[B], weights: &[A]) -> (A, A)
where
    B: PrimeCharacteristicRing + Copy + Send + Sync,
    A: Algebra<B> + Copy + Send + Sync,
{
    assert_eq!(evals.len(), weights.len());
    assert!(evals.len().is_multiple_of(2));
    let half = evals.len() / 2;
    // Number of pairs covered by full `K`-pair chunks; the remaining pairs form the tail.
    let body_pairs = (half / K) * K;
    let body_elems = body_pairs * 2;
    let (evals_main, evals_tail) = evals.split_at(body_elems);
    let (weights_main, weights_tail) = weights.split_at(body_elems);

    // Splits a `2 * K`-element chunk into its even (`X = 0`) and odd (`X = 1`) positions.
    #[inline(always)]
    fn gather_pairs<T: Copy>(chunk: &[T]) -> ([T; K], [T; K]) {
        let lo: [T; K] = core::array::from_fn(|i| chunk[2 * i]);
        let hi: [T; K] = core::array::from_fn(|i| chunk[2 * i + 1]);
        (lo, hi)
    }

    // Compare the number of *pairs* against the threshold, as `sumcheck_coefficients_prefix`
    // does, so both variable orders switch to the parallel path at the same problem size.
    let main: (A, A) = if half > PAR_THRESHOLD {
        evals_main
            .par_chunks_exact(2 * K)
            .zip(weights_main.par_chunks_exact(2 * K))
            .par_fold_reduce(
                || (A::ZERO, A::ZERO),
                |acc, (e_chunk, w_chunk)| {
                    let (e_lo, e_hi) = gather_pairs::<B>(e_chunk);
                    let (w_lo, w_hi) = gather_pairs::<A>(w_chunk);
                    let chunk = chunk_round_step::<B, A>(&e_lo, &e_hi, &w_lo, &w_hi);
                    round_reduce(acc, chunk)
                },
                round_reduce,
            )
    } else {
        evals_main
            .chunks_exact(2 * K)
            .zip(weights_main.chunks_exact(2 * K))
            .fold((A::ZERO, A::ZERO), |acc, (e_chunk, w_chunk)| {
                let (e_lo, e_hi) = gather_pairs::<B>(e_chunk);
                let (w_lo, w_hi) = gather_pairs::<A>(w_chunk);
                let chunk = chunk_round_step::<B, A>(&e_lo, &e_hi, &w_lo, &w_hi);
                round_reduce(acc, chunk)
            })
    };

    // The tail length is even (both the total and the body length are even), so every
    // `chunks_exact(2)` item is a complete pair and `e[1]` / `w[1]` are always in bounds.
    let tail = evals_tail
        .chunks_exact(2)
        .zip(weights_tail.chunks_exact(2))
        .fold((A::ZERO, A::ZERO), |acc, (e, w)| {
            round_step(acc, e[0], e[1], w[0], w[1])
        });
    round_reduce(main, tail)
}
/// Selects which ordering of sumcheck kernels to use: the `*_prefix` or the `*_suffix`
/// variants.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VariableOrder {
    /// Dispatches to `sumcheck_coefficients_prefix` / `Poly::fix_prefix_var_mut`.
    Prefix,
    /// Dispatches to `sumcheck_coefficients_suffix` / `Poly::fix_suffix_var_mut`.
    Suffix,
}
impl VariableOrder {
    /// Returns the round coefficients `(c0, c_inf)` of `evals` against `weights`, computed by
    /// the kernel that matches this variable order.
    ///
    /// # Panics
    ///
    /// Panics (inside the kernel) if the slices differ in length or have an odd length.
    pub fn sumcheck_coefficients<B, A>(self, evals: &[B], weights: &[A]) -> (A, A)
    where
        B: PrimeCharacteristicRing + Copy + Send + Sync,
        A: Algebra<B> + Copy + Send + Sync,
    {
        match self {
            Self::Suffix => sumcheck_coefficients_suffix(evals, weights),
            Self::Prefix => sumcheck_coefficients_prefix(evals, weights),
        }
    }

    /// Binds one variable of `poly` to `r` in place.
    ///
    /// Delegates to `Poly::fix_prefix_var_mut` for `Prefix` and `Poly::fix_suffix_var_mut`
    /// for `Suffix`.
    pub fn fix_var<A, Ch>(self, poly: &mut Poly<A>, r: Ch)
    where
        A: Algebra<Ch> + Copy + Send + Sync,
        Ch: Copy + Send + Sync,
    {
        match self {
            Self::Suffix => poly.fix_suffix_var_mut(r),
            Self::Prefix => poly.fix_prefix_var_mut(r),
        }
    }

    /// Sums the evaluations of every constraint polynomial at (a window of) `challenge`.
    ///
    /// For each constraint the window is the leading `num_variables()` coordinates of the
    /// *reversed* challenge. For `Prefix` that window is reversed back before use; for
    /// `Suffix` it is used as-is.
    ///
    /// Each constraint contributes two terms:
    /// - `Σ coeff · eq(point, window)` over its equality constraints;
    /// - `Σ coeff · select(var, window)` over its select constraints.
    pub fn eval_constraints_poly<F, EF>(
        self,
        constraints: &[Constraint<F, EF>],
        challenge: &Point<EF>,
    ) -> EF
    where
        F: Field,
        EF: ExtensionField<F>,
    {
        let reversed = challenge.reversed();
        let mut total = EF::ZERO;
        for constraint in constraints {
            let window = reversed.get_subpoint_over_range(..constraint.num_variables());
            let local_challenge = match self {
                Self::Prefix => window.reversed(),
                Self::Suffix => window,
            };

            let eq_part: EF = constraint
                .iter_eqs()
                .map(|(point, coeff)| coeff * point.eq_poly(&local_challenge))
                .sum();
            let sel_part: EF = constraint
                .iter_sels()
                .map(|(&var, coeff)| coeff * local_challenge.select_poly(var))
                .sum();

            total += eq_part + sel_part;
        }
        total
    }
}
/// Sumcheck prover state: a product polynomial together with the sum claimed for it.
#[derive(Clone, Debug)]
pub struct SumcheckProver<F, EF>
where
    F: Field,
    EF: ExtensionField<F>,
{
    /// The product polynomial being reduced round by round.
    poly: ProductPolynomial<F, EF>,
    /// The current claim. The prover's methods debug-assert that it equals
    /// `poly.dot_product()`.
    sum: EF,
}
impl<F: Field, EF: ExtensionField<F>> SumcheckProver<F, EF> {
    /// Creates a prover for `poly` with claimed sum `sum`.
    ///
    /// In debug builds this checks that `sum == poly.dot_product()`.
    pub fn new(poly: ProductPolynomial<F, EF>, sum: EF) -> Self {
        let prover = Self { poly, sum };
        debug_assert_eq!(prover.poly.dot_product(), prover.sum);
        prover
    }

    /// Returns the current claimed sum.
    pub const fn claimed_sum(&self) -> EF {
        self.sum
    }

    /// Number of variables of the underlying product polynomial.
    pub fn num_variables(&self) -> usize {
        self.poly.num_variables()
    }

    /// Returns the evaluation table of the underlying polynomial
    /// (delegates to `ProductPolynomial::evals`).
    #[tracing::instrument(skip_all)]
    pub fn evals(&self) -> Poly<EF> {
        self.poly.evals()
    }

    /// Evaluates the underlying polynomial at `point` (delegates to `ProductPolynomial::eval`).
    pub fn eval(&self, point: &Point<EF>) -> EF {
        self.poly.eval(point)
    }

    /// Returns `(c0, c_inf)` for the next round
    /// (delegates to `ProductPolynomial::round_coefficients`).
    pub(crate) fn round_coefficients(&self) -> (EF, EF) {
        self.poly.round_coefficients()
    }

    /// Folds the current round at `gamma` using already-computed round coefficients.
    ///
    /// The claim is updated by extrapolating the round polynomial to `gamma`.
    pub(crate) fn fold_round_with_coefficients(&mut self, c0: EF, c_inf: EF, gamma: EF) {
        // Presumably h(1): the relation h(0) + h(1) = claim lets it be derived from the claim.
        let c1 = self.sum - c0;
        self.sum = extrapolate_01inf(c0, c1, c_inf, gamma);
        self.poly.fold_round(gamma);
        debug_assert_eq!(self.sum, self.poly.dot_product());
    }

    /// Multiplies both the weights and the claimed sum by `scale`, keeping them consistent.
    pub(crate) fn scale_weights_and_claim(&mut self, scale: EF) {
        self.sum *= scale;
        self.poly.scale_weights(scale);
    }

    /// Returns the weight table of the underlying polynomial
    /// (delegates to `ProductPolynomial::weights`).
    pub fn weights(&self) -> Poly<EF> {
        self.poly.weights()
    }

    /// Adds `weights_delta` to the weights and `sum_delta` to the claim.
    ///
    /// In debug builds this checks that the claim still matches `poly.dot_product()`.
    pub fn accumulate_claim(&mut self, weights_delta: &[EF], sum_delta: EF) {
        self.sum += sum_delta;
        self.poly.accumulate_weights(weights_delta);
        debug_assert_eq!(self.sum, self.poly.dot_product());
    }

    /// Runs `folding_factor` sumcheck rounds and returns the per-round values produced by
    /// `ProductPolynomial::round` as a `Point`.
    ///
    /// If `constraint` is given, it is first combined into the polynomial and the claim.
    #[tracing::instrument(skip_all)]
    pub fn compute_sumcheck_polynomials<Challenger>(
        &mut self,
        sumcheck_data: &mut SumcheckData<F, EF>,
        challenger: &mut Challenger,
        folding_factor: usize,
        pow_bits: usize,
        constraint: Option<Constraint<F, EF>>,
    ) -> Point<EF>
    where
        Challenger: FieldChallenger<F> + GrindingChallenger<Witness = F>,
    {
        if let Some(extra) = constraint {
            self.poly.combine(&mut self.sum, &extra);
        }

        // One call to `round` per folded variable, in order.
        let per_round = core::iter::repeat_with(|| {
            self.poly
                .round(sumcheck_data, challenger, &mut self.sum, pow_bits)
        })
        .take(folding_factor)
        .collect();
        Point::new(per_round)
    }
}
#[cfg(test)]
mod tests {
// Cross-checks `VariableOrder::eval_constraints_poly` against a slower reference.
// The reference materialises each constraint as a dense polynomial and evaluates it directly.
use alloc::vec::Vec;
use p3_baby_bear::BabyBear;
use p3_field::PrimeCharacteristicRing;
use p3_field::extension::BinomialExtensionField;
use p3_multilinear_util::point::Point;
use p3_multilinear_util::poly::Poly;
use proptest::prelude::*;
use rand::rngs::SmallRng;
use rand::{RngExt, SeedableRng};
use super::VariableOrder;
use crate::constraints::Constraint;
use crate::constraints::statement::{EqStatement, SelectStatement};
type F = BabyBear;
type EF = BinomialExtensionField<BabyBear, 4>;
/// Reference implementation of `VariableOrder::eval_constraints_poly`.
///
/// For each constraint it builds a dense polynomial via `Constraint::combine`, starting from
/// `Poly::zero`. It then evaluates that polynomial at the same window of `challenge` that the
/// fast path selects for `order`, and sums the results over all constraints.
fn eval_constraints_poly_reference(
order: VariableOrder,
constraints: &[Constraint<F, EF>],
challenge: &Point<EF>,
) -> EF {
constraints
.iter()
.map(|constraint| {
let mut combined = Poly::zero(constraint.num_variables());
// `eval` is an out-parameter required by `combine`; its value is not used here.
let mut eval = EF::ZERO;
constraint.combine(&mut combined, &mut eval);
// Window: leading `num_variables()` coordinates of the reversed challenge,
// reversed back for `Prefix` and used as-is for `Suffix`.
let point = match order {
VariableOrder::Prefix => challenge
.reversed()
.get_subpoint_over_range(..constraint.num_variables())
.reversed(),
VariableOrder::Suffix => challenge
.reversed()
.get_subpoint_over_range(..constraint.num_variables()),
};
combined.eval_ext::<F>(&point)
})
.sum()
}
/// Builds `rounds` random constraints.
///
/// Each constraint has:
/// - a variable count drawn from `1..=num_variables`;
/// - a random `gamma`;
/// - 0–3 equality constraints at random points with random values;
/// - 0–3 select constraints.
///
/// NOTE: the exact sequence of `rng` calls (including argument evaluation order) determines
/// the generated data for a given seed. Reordering these calls changes what the seeded tests
/// check.
fn random_constraints(
rng: &mut SmallRng,
num_variables: usize,
rounds: usize,
) -> Vec<Constraint<F, EF>> {
(0..rounds)
.map(|_| {
// Shadows the upper bound with this constraint's own (possibly smaller) variable count.
let num_variables = rng.random_range(1..=num_variables);
let gamma = rng.random();
let mut eq_statement = EqStatement::initialize(num_variables);
(0..rng.random_range(0..=3)).for_each(|_| {
eq_statement
.add_evaluated_constraint(Point::rand(rng, num_variables), rng.random());
});
let mut sel_statement = SelectStatement::<F, EF>::initialize(num_variables);
(0..rng.random_range(0..=3))
.for_each(|_| sel_statement.add_constraint(rng.random(), rng.random()));
Constraint::new(gamma, eq_statement, sel_statement)
})
.collect()
}
/// Seeded check for `Prefix`: six constraints over at most 20 variables against a 20-variable
/// challenge.
#[test]
fn test_eval_constraints_poly_prefix() {
let mut rng = SmallRng::seed_from_u64(0);
let constraints = random_constraints(&mut rng, 20, 6);
let challenge = Point::rand(&mut rng, 20);
let got = VariableOrder::Prefix.eval_constraints_poly(&constraints, &challenge);
let expected =
eval_constraints_poly_reference(VariableOrder::Prefix, &constraints, &challenge);
assert_eq!(got, expected);
}
/// Seeded check for `Suffix`, mirroring the prefix test with a different seed.
#[test]
fn test_eval_constraints_poly_suffix() {
let mut rng = SmallRng::seed_from_u64(1);
let constraints = random_constraints(&mut rng, 20, 6);
let challenge = Point::rand(&mut rng, 20);
let got = VariableOrder::Suffix.eval_constraints_poly(&constraints, &challenge);
let expected =
eval_constraints_poly_reference(VariableOrder::Suffix, &constraints, &challenge);
assert_eq!(got, expected);
}
proptest! {
// Property: for random sizes and seeds, both orders agree with the reference.
#[test]
fn prop_eval_constraints_poly_matches_reference(
total_num_variables in 2usize..=20,
rounds in 1usize..=8,
seed in any::<u64>(),
) {
let mut rng = SmallRng::seed_from_u64(seed);
let constraints = random_constraints(&mut rng, total_num_variables, rounds);
let challenge = Point::rand(&mut rng, total_num_variables);
prop_assert_eq!(
VariableOrder::Prefix.eval_constraints_poly(&constraints, &challenge),
eval_constraints_poly_reference(VariableOrder::Prefix, &constraints, &challenge),
);
prop_assert_eq!(
VariableOrder::Suffix.eval_constraints_poly(&constraints, &challenge),
eval_constraints_poly_reference(VariableOrder::Suffix, &constraints, &challenge),
);
}
}
}