1use p3_challenger::{FieldChallenger, GrindingChallenger};
10use p3_field::{Algebra, ExtensionField, Field, PrimeCharacteristicRing};
11use p3_maybe_rayon::prelude::*;
12use p3_multilinear_util::point::Point;
13use p3_multilinear_util::poly::Poly;
14
15use crate::constraints::Constraint;
16use crate::product_polynomial::ProductPolynomial;
17use crate::{SumcheckData, extrapolate_01inf};
18
/// Size above which the coefficient routines fold their main body in parallel
/// instead of serially.
///
/// NOTE(review): `sumcheck_coefficients_prefix` compares the number of pairs
/// (`len / 2`) against this value, whereas `sumcheck_coefficients_suffix`
/// compares the total number of evaluations (`len`); confirm which is intended.
const PAR_THRESHOLD: usize = 16_384;

/// Number of pairs handed to `chunk_round_step` per fixed-size chunk.
const K: usize = 8;
36
37#[inline(always)]
52fn chunk_round_step<B, A>(e_lo: &[B; K], e_hi: &[B; K], w_lo: &[A; K], w_hi: &[A; K]) -> (A, A)
53where
54 B: PrimeCharacteristicRing + Copy,
55 A: Algebra<B> + Copy,
56{
57 let acc0 = A::mixed_dot_product::<K>(w_lo, e_lo);
59
60 let diffs_e: [B; K] = core::array::from_fn(|i| e_hi[i] - e_lo[i]);
63 let diffs_w: [A; K] = core::array::from_fn(|i| w_hi[i] - w_lo[i]);
64
65 let acc_inf = A::mixed_dot_product::<K>(&diffs_w, &diffs_e);
67
68 (acc0, acc_inf)
69}
70
71#[inline(always)]
73fn round_step<B, A>((acc0, acc_inf): (A, A), e0: B, e1: B, w0: A, w1: A) -> (A, A)
74where
75 B: PrimeCharacteristicRing + Copy,
76 A: Algebra<B> + Copy,
77{
78 (acc0 + w0 * e0, acc_inf + (w1 - w0) * (e1 - e0))
79}
80
81#[inline(always)]
83fn round_reduce<A: Copy + PrimeCharacteristicRing>(a: (A, A), b: (A, A)) -> (A, A) {
84 (a.0 + b.0, a.1 + b.1)
85}
86
87pub fn sumcheck_coefficients_prefix<B, A>(evals: &[B], weights: &[A]) -> (A, A)
105where
106 B: PrimeCharacteristicRing + Copy + Send + Sync,
107 A: Algebra<B> + Copy + Send + Sync,
108{
109 assert_eq!(evals.len(), weights.len());
111 assert!(evals.len().is_multiple_of(2));
112 let half = evals.len() / 2;
113 let (e_lo, e_hi) = evals.split_at(half);
114 let (w_lo, w_hi) = weights.split_at(half);
115
116 let body = (half / K) * K;
117 let (e_lo_main, e_lo_tail) = e_lo.split_at(body);
118 let (e_hi_main, e_hi_tail) = e_hi.split_at(body);
119 let (w_lo_main, w_lo_tail) = w_lo.split_at(body);
120 let (w_hi_main, w_hi_tail) = w_hi.split_at(body);
121
122 let main: (A, A) = if half > PAR_THRESHOLD {
124 e_lo_main
125 .par_chunks_exact(K)
126 .zip(e_hi_main.par_chunks_exact(K))
127 .zip(
128 w_lo_main
129 .par_chunks_exact(K)
130 .zip(w_hi_main.par_chunks_exact(K)),
131 )
132 .par_fold_reduce(
133 || (A::ZERO, A::ZERO),
134 |acc, ((e_lo_c, e_hi_c), (w_lo_c, w_hi_c))| {
135 let chunk = chunk_round_step::<B, A>(
136 e_lo_c.try_into().unwrap(),
137 e_hi_c.try_into().unwrap(),
138 w_lo_c.try_into().unwrap(),
139 w_hi_c.try_into().unwrap(),
140 );
141 round_reduce(acc, chunk)
142 },
143 round_reduce,
144 )
145 } else {
146 e_lo_main
147 .chunks_exact(K)
148 .zip(e_hi_main.chunks_exact(K))
149 .zip(w_lo_main.chunks_exact(K).zip(w_hi_main.chunks_exact(K)))
150 .fold(
151 (A::ZERO, A::ZERO),
152 |acc, ((e_lo_c, e_hi_c), (w_lo_c, w_hi_c))| {
153 let chunk = chunk_round_step::<B, A>(
154 e_lo_c.try_into().unwrap(),
155 e_hi_c.try_into().unwrap(),
156 w_lo_c.try_into().unwrap(),
157 w_hi_c.try_into().unwrap(),
158 );
159 round_reduce(acc, chunk)
160 },
161 )
162 };
163
164 let tail = e_lo_tail
166 .iter()
167 .zip(e_hi_tail.iter())
168 .zip(w_lo_tail.iter().zip(w_hi_tail.iter()))
169 .fold((A::ZERO, A::ZERO), |acc, ((&e0, &e1), (&w0, &w1))| {
170 round_step(acc, e0, e1, w0, w1)
171 });
172
173 round_reduce(main, tail)
174}
175
176pub fn sumcheck_coefficients_suffix<B, A>(evals: &[B], weights: &[A]) -> (A, A)
195where
196 B: PrimeCharacteristicRing + Copy + Send + Sync,
197 A: Algebra<B> + Copy + Send + Sync,
198{
199 assert_eq!(evals.len(), weights.len());
201 assert!(evals.len().is_multiple_of(2));
202
203 let half = evals.len() / 2;
204 let body_pairs = (half / K) * K;
206 let body_elems = body_pairs * 2;
207 let (evals_main, evals_tail) = evals.split_at(body_elems);
208 let (weights_main, weights_tail) = weights.split_at(body_elems);
209
210 #[inline(always)]
211 fn gather_pairs<T: Copy>(chunk: &[T]) -> ([T; K], [T; K]) {
212 let lo: [T; K] = core::array::from_fn(|i| chunk[2 * i]);
214 let hi: [T; K] = core::array::from_fn(|i| chunk[2 * i + 1]);
215 (lo, hi)
216 }
217
218 let main: (A, A) = if evals.len() > PAR_THRESHOLD {
219 evals_main
220 .par_chunks_exact(2 * K)
221 .zip(weights_main.par_chunks_exact(2 * K))
222 .par_fold_reduce(
223 || (A::ZERO, A::ZERO),
224 |acc, (e_chunk, w_chunk)| {
225 let (e_lo, e_hi) = gather_pairs::<B>(e_chunk);
226 let (w_lo, w_hi) = gather_pairs::<A>(w_chunk);
227 let chunk = chunk_round_step::<B, A>(&e_lo, &e_hi, &w_lo, &w_hi);
228 round_reduce(acc, chunk)
229 },
230 round_reduce,
231 )
232 } else {
233 evals_main
234 .chunks_exact(2 * K)
235 .zip(weights_main.chunks_exact(2 * K))
236 .fold((A::ZERO, A::ZERO), |acc, (e_chunk, w_chunk)| {
237 let (e_lo, e_hi) = gather_pairs::<B>(e_chunk);
238 let (w_lo, w_hi) = gather_pairs::<A>(w_chunk);
239 let chunk = chunk_round_step::<B, A>(&e_lo, &e_hi, &w_lo, &w_hi);
240 round_reduce(acc, chunk)
241 })
242 };
243
244 let tail = evals_tail
246 .chunks(2)
247 .zip(weights_tail.chunks(2))
248 .fold((A::ZERO, A::ZERO), |acc, (e, w)| {
249 round_step(acc, e[0], e[1], w[0], w[1])
250 });
251
252 round_reduce(main, tail)
253}
254
/// Selects which variable a sumcheck round binds.
///
/// - `Prefix` pairs entry `i` with entry `i + len / 2` (the table is split into
///   halves) and folds with `Poly::fix_prefix_var_mut`.
/// - `Suffix` pairs entry `2i` with entry `2i + 1` (the table is interleaved)
///   and folds with `Poly::fix_suffix_var_mut`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VariableOrder {
    /// Bind the variable that selects the lower or upper half of the table.
    Prefix,
    /// Bind the variable that selects even or odd positions of the table.
    Suffix,
}
272
273impl VariableOrder {
274 pub fn sumcheck_coefficients<B, A>(self, evals: &[B], weights: &[A]) -> (A, A)
276 where
277 B: PrimeCharacteristicRing + Copy + Send + Sync,
278 A: Algebra<B> + Copy + Send + Sync,
279 {
280 match self {
281 Self::Prefix => sumcheck_coefficients_prefix(evals, weights),
282 Self::Suffix => sumcheck_coefficients_suffix(evals, weights),
283 }
284 }
285
286 pub fn fix_var<A, Ch>(self, poly: &mut Poly<A>, r: Ch)
288 where
289 A: Algebra<Ch> + Copy + Send + Sync,
290 Ch: Copy + Send + Sync,
291 {
292 match self {
293 Self::Prefix => poly.fix_prefix_var_mut(r),
294 Self::Suffix => poly.fix_suffix_var_mut(r),
295 }
296 }
297
298 pub fn eval_constraints_poly<F, EF>(
307 self,
308 constraints: &[Constraint<F, EF>],
309 challenge: &Point<EF>,
310 ) -> EF
311 where
312 F: Field,
313 EF: ExtensionField<F>,
314 {
315 let reversed = challenge.reversed();
317
318 constraints
319 .iter()
320 .map(|constraint| {
321 let local_challenge = match self {
323 Self::Prefix => reversed
324 .get_subpoint_over_range(..constraint.num_variables())
325 .reversed(),
326 Self::Suffix => reversed.get_subpoint_over_range(..constraint.num_variables()),
327 };
328
329 let eq_contrib = constraint
331 .iter_eqs()
332 .map(|(point, coeff)| coeff * point.eq_poly(&local_challenge))
333 .sum::<EF>();
334 let sel_contrib = constraint
336 .iter_sels()
337 .map(|(&var, coeff)| coeff * local_challenge.select_poly(var))
338 .sum::<EF>();
339 eq_contrib + sel_contrib
340 })
341 .sum()
342 }
343}
344
345#[derive(Debug, Clone)]
358pub struct SumcheckProver<F: Field, EF: ExtensionField<F>> {
359 poly: ProductPolynomial<F, EF>,
361 sum: EF,
363}
364
365impl<F: Field, EF: ExtensionField<F>> SumcheckProver<F, EF> {
366 pub fn new(poly: ProductPolynomial<F, EF>, sum: EF) -> Self {
368 debug_assert_eq!(poly.dot_product(), sum);
370 Self { poly, sum }
371 }
372
373 pub const fn claimed_sum(&self) -> EF {
375 self.sum
376 }
377
378 pub fn num_variables(&self) -> usize {
380 self.poly.num_variables()
381 }
382
383 #[tracing::instrument(skip_all)]
385 pub fn evals(&self) -> Poly<EF> {
386 self.poly.evals()
387 }
388
389 pub fn eval(&self, point: &Point<EF>) -> EF {
391 self.poly.eval(point)
392 }
393
394 pub(crate) fn round_coefficients(&self) -> (EF, EF) {
397 self.poly.round_coefficients()
398 }
399
400 pub(crate) fn fold_round_with_coefficients(&mut self, c0: EF, c_inf: EF, gamma: EF) {
403 self.sum = extrapolate_01inf(c0, self.sum - c0, c_inf, gamma);
404 self.poly.fold_round(gamma);
405 debug_assert_eq!(self.sum, self.poly.dot_product());
406 }
407
408 pub(crate) fn scale_weights_and_claim(&mut self, scale: EF) {
413 self.poly.scale_weights(scale);
414 self.sum *= scale;
415 }
416
417 pub fn weights(&self) -> Poly<EF> {
419 self.poly.weights()
420 }
421
422 pub fn accumulate_claim(&mut self, weights_delta: &[EF], sum_delta: EF) {
429 self.poly.accumulate_weights(weights_delta);
430 self.sum += sum_delta;
431 debug_assert_eq!(self.sum, self.poly.dot_product());
432 }
433
434 #[tracing::instrument(skip_all)]
451 pub fn compute_sumcheck_polynomials<Challenger>(
452 &mut self,
453 sumcheck_data: &mut SumcheckData<F, EF>,
454 challenger: &mut Challenger,
455 folding_factor: usize,
456 pow_bits: usize,
457 constraint: Option<Constraint<F, EF>>,
458 ) -> Point<EF>
459 where
460 Challenger: FieldChallenger<F> + GrindingChallenger<Witness = F>,
461 {
462 if let Some(constraint) = constraint {
464 self.poly.combine(&mut self.sum, &constraint);
465 }
466
467 let res = (0..folding_factor)
469 .map(|_| {
470 self.poly
471 .round(sumcheck_data, challenger, &mut self.sum, pow_bits)
472 })
473 .collect();
474
475 Point::new(res)
476 }
477}
478
#[cfg(test)]
mod tests {
    // Checks `VariableOrder::eval_constraints_poly` against a slow reference
    // that materialises every constraint as a dense polynomial.
    use alloc::vec::Vec;

    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;
    use p3_field::extension::BinomialExtensionField;
    use p3_multilinear_util::point::Point;
    use p3_multilinear_util::poly::Poly;
    use proptest::prelude::*;
    use rand::rngs::SmallRng;
    use rand::{RngExt, SeedableRng};

    use super::VariableOrder;
    use crate::constraints::Constraint;
    use crate::constraints::statement::{EqStatement, SelectStatement};

    type F = BabyBear;
    type EF = BinomialExtensionField<BabyBear, 4>;

    /// Reference for `VariableOrder::eval_constraints_poly`.
    ///
    /// Each constraint is written into a zero `Poly` via `Constraint::combine`
    /// and then evaluated densely at the same local challenge the fast path
    /// uses. The per-constraint results are summed.
    fn eval_constraints_poly_reference(
        order: VariableOrder,
        constraints: &[Constraint<F, EF>],
        challenge: &Point<EF>,
    ) -> EF {
        constraints
            .iter()
            .map(|constraint| {
                let mut combined = Poly::zero(constraint.num_variables());
                // Out-parameter required by `combine`; its value is not used here.
                let mut eval = EF::ZERO;
                constraint.combine(&mut combined, &mut eval);

                // Local challenge: the last `num_variables` coordinates of
                // `challenge`, in order for Prefix and reversed for Suffix.
                let point = match order {
                    VariableOrder::Prefix => challenge
                        .reversed()
                        .get_subpoint_over_range(..constraint.num_variables())
                        .reversed(),
                    VariableOrder::Suffix => challenge
                        .reversed()
                        .get_subpoint_over_range(..constraint.num_variables()),
                };

                combined.eval_ext::<F>(&point)
            })
            .sum()
    }

    /// Builds `rounds` random constraints.
    ///
    /// Each constraint is over a random number of variables in
    /// `1..=num_variables` and has 0-3 equality entries and 0-3 selector
    /// entries. The order of the RNG draws below determines the constraints
    /// generated for a given seed.
    fn random_constraints(
        rng: &mut SmallRng,
        num_variables: usize,
        rounds: usize,
    ) -> Vec<Constraint<F, EF>> {
        (0..rounds)
            .map(|_| {
                // Shadows the upper bound with this constraint's own arity.
                let num_variables = rng.random_range(1..=num_variables);
                let gamma = rng.random();

                let mut eq_statement = EqStatement::initialize(num_variables);
                // Between 0 and 3 random (point, value) equality entries.
                (0..rng.random_range(0..=3)).for_each(|_| {
                    eq_statement
                        .add_evaluated_constraint(Point::rand(rng, num_variables), rng.random());
                });

                let mut sel_statement = SelectStatement::<F, EF>::initialize(num_variables);
                // Between 0 and 3 random selector entries.
                (0..rng.random_range(0..=3))
                    .for_each(|_| sel_statement.add_constraint(rng.random(), rng.random()));

                Constraint::new(gamma, eq_statement, sel_statement)
            })
            .collect()
    }

    /// Fast Prefix evaluation matches the dense reference (fixed seed).
    #[test]
    fn test_eval_constraints_poly_prefix() {
        let mut rng = SmallRng::seed_from_u64(0);
        // Six constraints over at most 20 variables, evaluated at a 20-variable point.
        let constraints = random_constraints(&mut rng, 20, 6);
        let challenge = Point::rand(&mut rng, 20);

        let got = VariableOrder::Prefix.eval_constraints_poly(&constraints, &challenge);
        let expected =
            eval_constraints_poly_reference(VariableOrder::Prefix, &constraints, &challenge);
        assert_eq!(got, expected);
    }

    /// Fast Suffix evaluation matches the dense reference (fixed seed).
    #[test]
    fn test_eval_constraints_poly_suffix() {
        let mut rng = SmallRng::seed_from_u64(1);
        // Six constraints over at most 20 variables, evaluated at a 20-variable point.
        let constraints = random_constraints(&mut rng, 20, 6);
        let challenge = Point::rand(&mut rng, 20);

        let got = VariableOrder::Suffix.eval_constraints_poly(&constraints, &challenge);
        let expected =
            eval_constraints_poly_reference(VariableOrder::Suffix, &constraints, &challenge);
        assert_eq!(got, expected);
    }

    proptest! {
        #[test]
        // Both orders agree with the reference for random sizes, constraint
        // counts and seeds.
        fn prop_eval_constraints_poly_matches_reference(
            total_num_variables in 2usize..=20,
            rounds in 1usize..=8,
            seed in any::<u64>(),
        ) {
            let mut rng = SmallRng::seed_from_u64(seed);
            let constraints = random_constraints(&mut rng, total_num_variables, rounds);
            let challenge = Point::rand(&mut rng, total_num_variables);

            prop_assert_eq!(
                VariableOrder::Prefix.eval_constraints_poly(&constraints, &challenge),
                eval_constraints_poly_reference(VariableOrder::Prefix, &constraints, &challenge),
            );
            prop_assert_eq!(
                VariableOrder::Suffix.eval_constraints_poly(&constraints, &challenge),
                eval_constraints_poly_reference(VariableOrder::Suffix, &constraints, &challenge),
            );
        }
    }
}