p3_sumcheck/layout/prover/suffix.rs
1//! Suffix-mode stacked-sumcheck prover.
2
3use alloc::vec;
4use alloc::vec::Vec;
5
6use p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger};
7use p3_commit::Mmcs;
8use p3_dft::TwoAdicSubgroupDft;
9use p3_field::{ExtensionField, Field, TwoAdicField, dot_product};
10use p3_matrix::dense::DenseMatrix;
11use p3_multilinear_util::point::Point;
12use p3_multilinear_util::poly::Poly;
13use p3_multilinear_util::split_eq::SplitEq;
14
15use crate::commit::commit_base;
16use crate::lagrange::lagrange_weights_01inf_multi;
17use crate::layout::opening::{Opening, ProverMultiClaim, ProverVirtualClaim};
18use crate::layout::prover::Layout;
19use crate::layout::witness::{Table, TablePlacement};
20use crate::layout::{LayoutStrategy, Witness};
21use crate::product_polynomial::ProductPolynomial;
22use crate::strategy::{SumcheckProver, VariableOrder};
23use crate::svo::{SvoPoint, calculate_accumulators_batch};
24use crate::{Claim, SumcheckData, extrapolate_01inf};
25
26/// Stacked-sumcheck prover with suffix-first variable binding.
27///
28/// # Flow
29///
30/// - SVO accumulators are precomputed at claim-recording time.
31/// - Each preprocessing round reads its slice of those accumulators.
32/// - The residual product polynomial is built once, after all rounds.
33#[derive(Debug, Clone)]
34pub struct SuffixProver<F: Field, EF: ExtensionField<F>> {
35 /// Source tables behind the stacked polynomial.
36 pub(crate) tables: Vec<Table<F>>,
37 /// Per-table placement metadata inside the stacked polynomial.
38 pub(crate) placements: Vec<TablePlacement>,
39 /// Number of variables of the stacked polynomial.
40 pub(crate) num_variables: usize,
41 /// Number of preprocessing rounds consumed before residual sumcheck.
42 pub(crate) folding: usize,
43 /// Concrete claims recorded per source table (carries per-round SVO partials).
44 ///
45 /// # Invariants
46 ///
47 /// - Every opening stored here is tied to a concrete source column.
48 /// - Virtual openings never enter this map.
49 /// - Claims are appended in insertion order.
50 pub(crate) claim_map: Vec<Vec<ProverMultiClaim<F, EF>>>,
51 /// Virtual claims carrying precomputed SVO accumulators.
52 pub(crate) virtual_claims: Vec<ProverVirtualClaim<EF>>,
53}
54
55impl<F: TwoAdicField, EF: ExtensionField<F>> Layout<F, EF> for SuffixProver<F, EF> {
56 fn from_witness(witness: Witness<F>) -> Self {
57 // Move the witness fields out so the prover owns them outright.
58 // The stacked polynomial is intentionally discarded: every suffix-mode
59 // primitive walks the per-table data instead.
60 let parts = witness.into_parts();
61 // One claim list per source table; virtual claims live in their own bucket.
62 let num_tables = parts.tables.len();
63 Self {
64 tables: parts.tables,
65 placements: parts.placements,
66 num_variables: parts.num_variables,
67 folding: parts.folding,
68 claim_map: (0..num_tables).map(|_| Vec::new()).collect(),
69 virtual_claims: Vec::new(),
70 }
71 }
72
73 fn new_witness(tables: Vec<Table<F>>, folding: usize) -> Witness<F> {
74 Witness::new(tables, folding)
75 }
76
77 fn commit<Dft, MT, Challenger>(
78 dft: &Dft,
79 mmcs: &MT,
80 challenger: &mut Challenger,
81 witness: Witness<F>,
82 folding: usize,
83 starting_log_inv_rate: usize,
84 ) -> (Self, MT::Commitment, MT::ProverData<DenseMatrix<F>>)
85 where
86 Dft: TwoAdicSubgroupDft<F>,
87 MT: Mmcs<F>,
88 Challenger: CanObserve<MT::Commitment>,
89 {
90 let (root, prover_data) = commit_base(
91 Self::variable_order(),
92 dft,
93 mmcs,
94 challenger,
95 &witness.poly,
96 folding,
97 starting_log_inv_rate,
98 );
99
100 (Self::from_witness(witness), root, prover_data)
101 }
102
103 fn folding(&self) -> usize {
104 self.folding
105 }
106
107 /// Returns the number of variables of the stacked polynomial.
108 fn num_variables(&self) -> usize {
109 self.num_variables
110 }
111
112 /// Returns the number of variables of table `id`.
113 fn num_variables_table(&self, id: usize) -> usize {
114 self.tables[id].num_variables()
115 }
116
117 /// Records opening claims for the selected columns of `table_idx`.
118 ///
119 /// # Arguments
120 ///
121 /// - `table_idx` — source table index.
122 /// - `polys` — columns to open; must be non-empty.
123 /// - `challenger` — Fiat–Shamir transcript.
124 ///
125 /// # Fiat–Shamir
126 ///
127 /// - Samples the opening point internally from the challenger.
128 /// - Absorbs the evaluations into the transcript before returning.
129 /// - The verifier's `add_claim` performs the symmetric absorption.
130 ///
131 /// # Panics
132 ///
133 /// - Columns list must be non-empty.
134 #[tracing::instrument(skip_all)]
135 fn eval<Ch>(&mut self, table_idx: usize, polys: &[usize], challenger: &mut Ch) -> Vec<EF>
136 where
137 Ch: FieldChallenger<F> + GrindingChallenger<Witness = F>,
138 {
139 // Precondition: opening nothing would silently push an empty ProverMultiClaim.
140 assert!(
141 !polys.is_empty(),
142 "opening schedule must name at least one column"
143 );
144
145 // Sample the local-frame opening point from the transcript.
146 let table = &self.tables[table_idx];
147 let point = Point::expand_from_univariate(
148 challenger.sample_algebra_element(),
149 table.num_variables(),
150 );
151
152 // Factorise the point with the suffix split; every selected column reuses it.
153 let point = SvoPoint::new_unpacked(self.folding, &point, VariableOrder::Suffix);
154
155 // Evaluate each requested column and split into (opening, eval) in a single pass.
156 let (openings, evals): (Vec<_>, Vec<EF>) = polys
157 .iter()
158 .map(|&poly_idx| {
159 // Per-column eval plus the per-round partial-eval polynomials.
160 let (eval, partial_evals) = point.eval(table.poly(poly_idx));
161 // Wrap the outputs as a concrete opening on this column.
162 let opening = Opening {
163 poly_idx: Some(poly_idx),
164 eval,
165 data: partial_evals,
166 };
167 (opening, eval)
168 })
169 .unzip();
170
171 // Bind the evaluations into the transcript; the verifier absorbs the same bytes.
172 challenger.observe_algebra_slice(&evals);
173
174 // Store the batch with its shared SVO point.
175 self.claim_map[table_idx].push(ProverMultiClaim::new(point, openings));
176
177 evals
178 }
179
180 /// Samples a virtual evaluation on the full stacked polynomial.
181 ///
182 /// # Why heavier than prefix binding
183 ///
184 /// The stacked evaluation factors per column via the selector:
185 ///
186 /// ```text
187 /// stacked(point) = sum_{i} eq(selector_i, point_selector_part)
188 /// * col_i(point_local_part)
189 /// ```
190 ///
191 /// # Flow
192 ///
193 /// - Each column is evaluated at its local sub-point.
194 /// - Per-column partials are collected on the fly.
195 /// - Those partials feed the SVO accumulator batcher.
196 #[tracing::instrument(skip_all)]
197 fn add_virtual_eval<Ch>(&mut self, challenger: &mut Ch) -> EF
198 where
199 Ch: FieldChallenger<F> + GrindingChallenger<Witness = F>,
200 {
201 // Sample a challenge point covering every stacked variable.
202 let point =
203 Point::expand_from_univariate(challenger.sample_algebra_element(), self.num_variables);
204
205 // Per-column accumulation state:
206 //
207 // - eval : running stacked evaluation.
208 // - openings: one virtual opening per column, carrying SVO partials.
209 // - weights : per-column selector-equality scalars.
210 let mut eval = EF::ZERO;
211 let mut openings = Vec::new();
212 let mut weights = Vec::new();
213
214 for placement in &self.placements {
215 for (poly_idx, selector) in placement.selectors().iter().enumerate() {
216 // Source column behind this slot.
217 let poly = self.tables[placement.idx()].poly(poly_idx);
218
219 // Split the challenge into (selector_bits, local_bits).
220 let (selector_part, local_part) = point.split_at(selector.num_variables());
221
222 // Scalar weight: eq(selector, selector_part) for this column.
223 let weight =
224 Point::eval_eq::<EF>(selector.point().as_slice(), selector_part.as_slice());
225
226 // Factorise the local part with the suffix split, then evaluate.
227 let local_svo =
228 SvoPoint::new_unpacked(self.folding, &local_part, VariableOrder::Suffix);
229 let (column_eval, partial_evals) = local_svo.eval(poly);
230
231 // Record a virtual opening (no source column tag) with partials.
232 let opening = Opening {
233 poly_idx: None,
234 eval: column_eval,
235 data: partial_evals,
236 };
237
238 // Add the weighted column evaluation into the stacked total.
239 eval += weight * column_eval;
240
241 // Stash opening and weight for the accumulator-batcher call.
242 openings.push(opening);
243 weights.push(weight);
244 }
245 }
246
247 // Batch every per-column opening into per-round SVO accumulators.
248 let accumulators = calculate_accumulators_batch(
249 &ProverMultiClaim::new(
250 SvoPoint::new_unpacked(self.folding, &point, VariableOrder::Suffix),
251 openings,
252 ),
253 &weights,
254 );
255
256 // Debug-only consistency check:
257 //
258 // - hand-rolled weighted sum must equal the direct stacked evaluation.
259 // - accumulators batched per column must equal the single-opening batch.
260 #[cfg(debug_assertions)]
261 {
262 // Materialise the stacked polynomial with no challenges applied.
263 let poly = &self.compress_stacked(&Point::default());
264 // Check 1: weighted sum equals the direct evaluation.
265 assert_eq!(eval, poly.eval_base(&point));
266
267 // Build the reference opening by evaluating the materialised poly directly.
268 let ref_svo =
269 SvoPoint::<EF, EF>::new_unpacked(self.folding, &point, VariableOrder::Suffix);
270 let (ref_eval, ref_partials) = ref_svo.eval(poly);
271 let opening = Opening {
272 poly_idx: None,
273 eval: ref_eval,
274 data: ref_partials,
275 };
276 // Check 2: the reference evaluation matches the weighted one.
277 assert_eq!(eval, ref_eval);
278 // Check 3: accumulators from per-column batching match the single-opening batch.
279 assert_eq!(
280 accumulators,
281 calculate_accumulators_batch(
282 &ProverMultiClaim::new(
283 SvoPoint::new_unpacked(self.folding, &point, VariableOrder::Suffix),
284 vec![opening],
285 ),
286 &[EF::ONE],
287 ),
288 );
289 }
290
291 // Commit the evaluation to the transcript and record the claim.
292 challenger.observe_algebra_element(eval);
293 self.virtual_claims.push(Claim {
294 point,
295 eval,
296 data: accumulators,
297 });
298
299 eval
300 }
301
302 /// Finalises SVO preprocessing and returns the residual sumcheck prover.
303 ///
304 /// # Returns
305 ///
306 /// - Residual sumcheck prover over the unpacked product polynomial.
307 /// - Folding challenges sampled during preprocessing.
308 ///
309 /// # Algorithm
310 ///
311 /// ```text
312 /// Phase | Action
313 /// ------+------------------------------------------------------------
314 /// 1 | Sample batching challenge a; flatten alphas by opening_idx.
315 /// 2 | Pre-batch per-claim accumulators with the a-powers.
316 /// 3 | Loop over preprocessing rounds:
317 /// a. (h(0), h(inf)) = dot(accumulators, Lagrange weights).
318 /// b. Sample challenge r; extrapolate the running sum.
319 /// 4 | Compose the residual product polynomial from compressed slots.
320 /// ```
321 #[tracing::instrument(skip_all)]
322 fn into_sumcheck<Ch>(
323 self,
324 sumcheck_data: &mut SumcheckData<F, EF>,
325 pow_bits: usize,
326 challenger: &mut Ch,
327 ) -> (SumcheckProver<F, EF>, Point<EF>)
328 where
329 Ch: FieldChallenger<F> + GrindingChallenger<Witness = F>,
330 {
331 // Sanity: preprocessing cannot consume more rounds than the stacked arity.
332 assert!(self.folding <= self.num_variables);
333 let alpha: EF = challenger.sample_algebra_element();
334 let n_claims = self.num_claims();
335
336 // Stage A: batch per-claim accumulators using insertion-order alpha powers.
337 //
338 // - Iteration order is placement order, matching `sum` and `combine_eqs`.
339 // - Each claim consumes exactly `claim.len()` consecutive powers from
340 // the shared iterator, so the per-claim alpha vector is aligned with
341 // the claim's opening list by construction.
342 let mut alphas = alpha.powers();
343 let accumulators: Vec<_> = self
344 .placements
345 .iter()
346 .flat_map(|placement| self.claim_map[placement.idx()].iter())
347 .map(|claim| {
348 let per_claim: Vec<EF> = alphas.by_ref().take(claim.len()).collect();
349 calculate_accumulators_batch(claim, &per_claim)
350 })
351 .collect();
352
353 // Stage C: drive the preprocessing rounds from the accumulators.
354 let mut sum = self.sum(alpha);
355 let mut rs: Vec<EF> = vec![];
356
357 for round_idx in 0..self.folding {
358 // Lagrange weights at the challenges sampled so far.
359 let weights = lagrange_weights_01inf_multi(&rs);
360
361 // Round-coefficient identity (linearity of the dot product):
362 //
363 // c0 = sum_c dot(claim_c.accs[0], weights)
364 // + sum_v alpha_v * dot(virtual_v.accs[0], weights)
365 // c_inf = same with accs[1]
366 //
367 // - Concrete claims carry alpha pre-batched in stage B.
368 // - Virtual claims keep a separate scalar per claim.
369 // - No intermediate element-wise accumulator is needed.
370 let mut c0 = EF::ZERO;
371 let mut c_inf = EF::ZERO;
372
373 for accs in &accumulators {
374 c0 += dot_product::<EF, _, _>(
375 accs[round_idx][0].iter().copied(),
376 weights.iter().copied(),
377 );
378 c_inf += dot_product::<EF, _, _>(
379 accs[round_idx][1].iter().copied(),
380 weights.iter().copied(),
381 );
382 }
383
384 // Virtual-claim contributions: scale each claim's dot by its alpha power.
385 for (vc, alpha_i) in self
386 .virtual_claims
387 .iter()
388 .zip(alpha.shifted_powers(alpha.exp_u64(n_claims as u64)))
389 {
390 let vc_accs = &vc.data;
391 c0 += alpha_i
392 * dot_product::<EF, _, _>(
393 vc_accs[round_idx][0].iter().copied(),
394 weights.iter().copied(),
395 );
396 c_inf += alpha_i
397 * dot_product::<EF, _, _>(
398 vc_accs[round_idx][1].iter().copied(),
399 weights.iter().copied(),
400 );
401 }
402
403 // Observe coefficients, sample r, extrapolate the running sum.
404 let r = sumcheck_data.observe_and_sample(challenger, c0, c_inf, pow_bits);
405 sum = extrapolate_01inf(c0, sum - c0, c_inf, r);
406 rs.push(r);
407 }
408
409 // Stage D: materialise the residual product polynomial.
410 //
411 // - Suffix binding folds variables in reverse.
412 // - The residual poly therefore lives in the reversed-challenges frame.
413 let rs = Point::new(rs);
414 // Reverse the challenges before handing them to the compressors.
415 let reversed = rs.reversed();
416 // Factor 1 of the product: the compressed stacked poly at rs.
417 // No external scaling here; the plain path keeps the running sum unchanged.
418 let compressed = self.compress_stacked(&reversed);
419 // Factor 2 of the product: the batched equality-weight poly.
420 let weights = self.combine_eqs(&reversed, alpha);
421 // Pair them; the product polynomial drives the remaining rounds.
422 let poly = ProductPolynomial::new_unpacked(VariableOrder::Suffix, compressed, weights);
423 // Cross-check: the dot product of the two factors must equal the
424 // running sum accumulated across the preprocessing rounds.
425 debug_assert_eq!(poly.dot_product(), sum);
426
427 (SumcheckProver::new(poly, sum), rs)
428 }
429
430 /// Returns the total number of concrete openings recorded so far.
431 fn num_claims(&self) -> usize {
432 self.claim_map
433 .iter()
434 .flat_map(|claims| claims.iter().map(ProverMultiClaim::len))
435 .sum()
436 }
437
438 fn strategy() -> LayoutStrategy {
439 LayoutStrategy::new(false, VariableOrder::Suffix)
440 }
441}
442
443impl<F: TwoAdicField, EF: ExtensionField<F>> SuffixProver<F, EF> {
444 /// Computes the batched claimed sum from concrete and virtual openings.
445 ///
446 /// # Identity
447 ///
448 /// ```text
449 /// sum = sum_{i} alpha^i * eval_i
450 /// ```
451 ///
452 /// # Alpha ordering
453 ///
454 /// Powers of `alpha` are handed out in insertion order:
455 ///
456 /// - Outer: placements, in the order the witness laid them out.
457 /// - Middle: claims recorded against that placement's source table.
458 /// - Inner: openings inside each claim, in the order they were recorded.
459 ///
460 /// # Virtual claims
461 ///
462 /// - Virtual evaluations continue the same alpha sequence.
463 /// - They start at `alpha^n`, with `n` the total concrete opening count.
464 ///
465 /// # Verifier agreement
466 ///
467 /// The verifier walks its claim registry with the same three-loop order,
468 /// so both sides assign the same `alpha^i` to the same claim point.
469 pub(crate) fn sum(&self, alpha: EF) -> EF {
470 let mut sum = EF::ZERO;
471 let mut alphas = alpha.powers();
472
473 // Concrete openings: three loops, no filter.
474 for placement in &self.placements {
475 for claim in &self.claim_map[placement.idx()] {
476 for opening in claim.openings() {
477 sum += opening.eval() * alphas.next().unwrap();
478 }
479 }
480 }
481
482 // Virtual claims continue the alpha sequence right after the concrete ones.
483 sum += dot_product::<EF, _, _>(
484 self.virtual_claims.iter().map(Claim::eval),
485 alpha.shifted_powers(alpha.exp_u64(self.num_claims() as u64)),
486 );
487
488 sum
489 }
490
491 /// Compress every stacked-table slot by fixing the suffix challenges.
492 #[tracing::instrument(skip_all)]
493 pub(crate) fn compress_stacked(&self, rs: &Point<EF>) -> Poly<EF> {
494 self.compress_stacked_scaled(rs, EF::ONE)
495 }
496
497 /// Compress every stacked-table slot, folding `scale` into the equality table.
498 ///
499 /// ```text
500 /// out[slot, x_rest] = sum_{y in {0,1}^|r|} scale * eq(r, y) * col(x_rest, y)
501 /// ```
502 ///
503 /// One output slot per column.
504 /// Writes never overlap.
505 /// Output arity equals the stacked arity minus the challenge count.
506 ///
507 /// # Arguments
508 ///
509 /// - `rs` — suffix challenges already sampled.
510 /// - `scale` — extra factor folded into the equality table.
511 ///
512 /// # Why a scale parameter
513 ///
514 /// - A non-unit `scale` lets the caller absorb a combining challenge into
515 /// the residual factor without a second pass.
516 ///
517 /// # Panics
518 ///
519 /// - `scale` is zero: a zero scale silently zeroes the residual.
520 #[tracing::instrument(skip_all)]
521 pub(crate) fn compress_stacked_scaled(&self, rs: &Point<EF>, scale: EF) -> Poly<EF> {
522 assert!(rs.num_variables() <= self.num_variables);
523 assert!(scale != EF::ZERO, "compress scale must be non-zero");
524 // Output spans the residual stacked space.
525 // Size is 2^(num_variables - |rs|).
526 let mut out = Poly::<EF>::zero(self.num_variables - rs.num_variables());
527 // Bake the scalar into the prefix-half equality table.
528 // Each slot compression then returns scale * eq(r, y) * col(...) in one pass.
529 let rs = SplitEq::new_unpacked(rs, scale);
530
531 for placement in &self.placements {
532 for (poly_idx, selector) in placement.selectors().iter().enumerate() {
533 let poly = self.tables[placement.idx()].poly(poly_idx);
534 assert!(rs.num_variables() <= poly.num_variables());
535 // Slot start in the compressed output.
536 let off = selector.index() << (poly.num_variables() - rs.num_variables());
537 // Write this column's compression into its own slot.
538 rs.compress_suffix_into(
539 &mut out.as_mut_slice()
540 [off..off + (1 << (poly.num_variables() - rs.num_variables()))],
541 poly,
542 );
543 }
544 }
545 out
546 }
547
548 /// Builds the residual weight polynomial after the SVO rounds.
549 ///
550 /// # Contributions
551 ///
552 /// - Concrete claim: factored equality table scaled by
553 /// `alpha^i * eq(svo_part, rs)`, written into the owning slot only.
554 /// - Virtual claim: scaled equality table written across the full output.
555 #[tracing::instrument(skip_all)]
556 pub(crate) fn combine_eqs(&self, rs: &Point<EF>, alpha: EF) -> Poly<EF> {
557 // Preconditions: challenge count matches the folding depth.
558 assert_eq!(rs.num_variables(), self.folding);
559 // Output arity: stacked arity minus the folded challenges.
560 let mut out = Poly::<EF>::zero(self.num_variables - rs.num_variables());
561
562 let mut alphas = alpha.powers();
563
564 // Concrete claims: write each into the slot its column's selector addresses.
565 for placement in &self.placements {
566 let num_variables_table = self.num_variables_table(placement.idx());
567 let slot_size = 1usize << num_variables_table;
568 for claim in &self.claim_map[placement.idx()] {
569 for opening in claim.openings() {
570 // The opening's column tells us which selector picks the slot.
571 let col = opening.poly_idx().unwrap();
572 let off = placement.selectors()[col].index() << num_variables_table;
573 // Fold the scalar slot range down by the SVO depth.
574 let folded_range = (off >> self.folding)..((off + slot_size) >> self.folding);
575 claim.point().accumulate_into(
576 &mut out.as_mut_slice()[folded_range],
577 rs,
578 alphas.next().unwrap(),
579 );
580 }
581 }
582 }
583
584 // Virtual claims: span the full output; alpha continues after concrete ones.
585 let mut alpha_i = alpha.exp_u64(self.num_claims() as u64);
586 for claim in &self.virtual_claims {
587 // Split the claim point into (rest-of-space, svo-sub-point).
588 let (rest, svo) = claim
589 .point
590 .split_at(claim.point.num_variables() - rs.num_variables());
591 // Scalar weight: alpha^i times the equality between svo part and rs.
592 let scale = alpha_i * Point::eval_eq(svo.as_slice(), rs.as_slice());
593 // Contribute the scaled equality table across the whole output.
594 SplitEq::new_packed(&rest, scale).accumulate_into(out.as_mut_slice(), None);
595 // Advance alpha for the next virtual claim.
596 alpha_i *= alpha;
597 }
598
599 out
600 }
601}