Skip to main content

p3_sumcheck/layout/prover/
suffix.rs

1//! Suffix-mode stacked-sumcheck prover.
2
3use alloc::vec;
4use alloc::vec::Vec;
5
6use p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger};
7use p3_commit::Mmcs;
8use p3_dft::TwoAdicSubgroupDft;
9use p3_field::{ExtensionField, Field, TwoAdicField, dot_product};
10use p3_matrix::dense::DenseMatrix;
11use p3_multilinear_util::point::Point;
12use p3_multilinear_util::poly::Poly;
13use p3_multilinear_util::split_eq::SplitEq;
14
15use crate::commit::commit_base;
16use crate::lagrange::lagrange_weights_01inf_multi;
17use crate::layout::opening::{Opening, ProverMultiClaim, ProverVirtualClaim};
18use crate::layout::prover::Layout;
19use crate::layout::witness::{Table, TablePlacement};
20use crate::layout::{LayoutStrategy, Witness};
21use crate::product_polynomial::ProductPolynomial;
22use crate::strategy::{SumcheckProver, VariableOrder};
23use crate::svo::{SvoPoint, calculate_accumulators_batch};
24use crate::{Claim, SumcheckData, extrapolate_01inf};
25
26/// Stacked-sumcheck prover with suffix-first variable binding.
27///
28/// # Flow
29///
30/// - SVO accumulators are precomputed at claim-recording time.
31/// - Each preprocessing round reads its slice of those accumulators.
32/// - The residual product polynomial is built once, after all rounds.
33#[derive(Debug, Clone)]
34pub struct SuffixProver<F: Field, EF: ExtensionField<F>> {
35    /// Source tables behind the stacked polynomial.
36    pub(crate) tables: Vec<Table<F>>,
37    /// Per-table placement metadata inside the stacked polynomial.
38    pub(crate) placements: Vec<TablePlacement>,
39    /// Number of variables of the stacked polynomial.
40    pub(crate) num_variables: usize,
41    /// Number of preprocessing rounds consumed before residual sumcheck.
42    pub(crate) folding: usize,
43    /// Concrete claims recorded per source table (carries per-round SVO partials).
44    ///
45    /// # Invariants
46    ///
47    /// - Every opening stored here is tied to a concrete source column.
48    /// - Virtual openings never enter this map.
49    /// - Claims are appended in insertion order.
50    pub(crate) claim_map: Vec<Vec<ProverMultiClaim<F, EF>>>,
51    /// Virtual claims carrying precomputed SVO accumulators.
52    pub(crate) virtual_claims: Vec<ProverVirtualClaim<EF>>,
53}
54
55impl<F: TwoAdicField, EF: ExtensionField<F>> Layout<F, EF> for SuffixProver<F, EF> {
56    fn from_witness(witness: Witness<F>) -> Self {
57        // Move the witness fields out so the prover owns them outright.
58        // The stacked polynomial is intentionally discarded: every suffix-mode
59        // primitive walks the per-table data instead.
60        let parts = witness.into_parts();
61        // One claim list per source table; virtual claims live in their own bucket.
62        let num_tables = parts.tables.len();
63        Self {
64            tables: parts.tables,
65            placements: parts.placements,
66            num_variables: parts.num_variables,
67            folding: parts.folding,
68            claim_map: (0..num_tables).map(|_| Vec::new()).collect(),
69            virtual_claims: Vec::new(),
70        }
71    }
72
73    fn new_witness(tables: Vec<Table<F>>, folding: usize) -> Witness<F> {
74        Witness::new(tables, folding)
75    }
76
77    fn commit<Dft, MT, Challenger>(
78        dft: &Dft,
79        mmcs: &MT,
80        challenger: &mut Challenger,
81        witness: Witness<F>,
82        folding: usize,
83        starting_log_inv_rate: usize,
84    ) -> (Self, MT::Commitment, MT::ProverData<DenseMatrix<F>>)
85    where
86        Dft: TwoAdicSubgroupDft<F>,
87        MT: Mmcs<F>,
88        Challenger: CanObserve<MT::Commitment>,
89    {
90        let (root, prover_data) = commit_base(
91            Self::variable_order(),
92            dft,
93            mmcs,
94            challenger,
95            &witness.poly,
96            folding,
97            starting_log_inv_rate,
98        );
99
100        (Self::from_witness(witness), root, prover_data)
101    }
102
103    fn folding(&self) -> usize {
104        self.folding
105    }
106
107    /// Returns the number of variables of the stacked polynomial.
108    fn num_variables(&self) -> usize {
109        self.num_variables
110    }
111
112    /// Returns the number of variables of table `id`.
113    fn num_variables_table(&self, id: usize) -> usize {
114        self.tables[id].num_variables()
115    }
116
117    /// Records opening claims for the selected columns of `table_idx`.
118    ///
119    /// # Arguments
120    ///
121    /// - `table_idx`  — source table index.
122    /// - `polys`      — columns to open; must be non-empty.
123    /// - `challenger` — Fiat–Shamir transcript.
124    ///
125    /// # Fiat–Shamir
126    ///
127    /// - Samples the opening point internally from the challenger.
128    /// - Absorbs the evaluations into the transcript before returning.
129    /// - The verifier's `add_claim` performs the symmetric absorption.
130    ///
131    /// # Panics
132    ///
133    /// - Columns list must be non-empty.
134    #[tracing::instrument(skip_all)]
135    fn eval<Ch>(&mut self, table_idx: usize, polys: &[usize], challenger: &mut Ch) -> Vec<EF>
136    where
137        Ch: FieldChallenger<F> + GrindingChallenger<Witness = F>,
138    {
139        // Precondition: opening nothing would silently push an empty ProverMultiClaim.
140        assert!(
141            !polys.is_empty(),
142            "opening schedule must name at least one column"
143        );
144
145        // Sample the local-frame opening point from the transcript.
146        let table = &self.tables[table_idx];
147        let point = Point::expand_from_univariate(
148            challenger.sample_algebra_element(),
149            table.num_variables(),
150        );
151
152        // Factorise the point with the suffix split; every selected column reuses it.
153        let point = SvoPoint::new_unpacked(self.folding, &point, VariableOrder::Suffix);
154
155        // Evaluate each requested column and split into (opening, eval) in a single pass.
156        let (openings, evals): (Vec<_>, Vec<EF>) = polys
157            .iter()
158            .map(|&poly_idx| {
159                // Per-column eval plus the per-round partial-eval polynomials.
160                let (eval, partial_evals) = point.eval(table.poly(poly_idx));
161                // Wrap the outputs as a concrete opening on this column.
162                let opening = Opening {
163                    poly_idx: Some(poly_idx),
164                    eval,
165                    data: partial_evals,
166                };
167                (opening, eval)
168            })
169            .unzip();
170
171        // Bind the evaluations into the transcript; the verifier absorbs the same bytes.
172        challenger.observe_algebra_slice(&evals);
173
174        // Store the batch with its shared SVO point.
175        self.claim_map[table_idx].push(ProverMultiClaim::new(point, openings));
176
177        evals
178    }
179
180    /// Samples a virtual evaluation on the full stacked polynomial.
181    ///
182    /// # Why heavier than prefix binding
183    ///
184    /// The stacked evaluation factors per column via the selector:
185    ///
186    /// ```text
187    ///     stacked(point) = sum_{i}  eq(selector_i, point_selector_part)
188    ///                               * col_i(point_local_part)
189    /// ```
190    ///
191    /// # Flow
192    ///
193    /// - Each column is evaluated at its local sub-point.
194    /// - Per-column partials are collected on the fly.
195    /// - Those partials feed the SVO accumulator batcher.
196    #[tracing::instrument(skip_all)]
197    fn add_virtual_eval<Ch>(&mut self, challenger: &mut Ch) -> EF
198    where
199        Ch: FieldChallenger<F> + GrindingChallenger<Witness = F>,
200    {
201        // Sample a challenge point covering every stacked variable.
202        let point =
203            Point::expand_from_univariate(challenger.sample_algebra_element(), self.num_variables);
204
205        // Per-column accumulation state:
206        //
207        // - eval    : running stacked evaluation.
208        // - openings: one virtual opening per column, carrying SVO partials.
209        // - weights : per-column selector-equality scalars.
210        let mut eval = EF::ZERO;
211        let mut openings = Vec::new();
212        let mut weights = Vec::new();
213
214        for placement in &self.placements {
215            for (poly_idx, selector) in placement.selectors().iter().enumerate() {
216                // Source column behind this slot.
217                let poly = self.tables[placement.idx()].poly(poly_idx);
218
219                // Split the challenge into (selector_bits, local_bits).
220                let (selector_part, local_part) = point.split_at(selector.num_variables());
221
222                // Scalar weight: eq(selector, selector_part) for this column.
223                let weight =
224                    Point::eval_eq::<EF>(selector.point().as_slice(), selector_part.as_slice());
225
226                // Factorise the local part with the suffix split, then evaluate.
227                let local_svo =
228                    SvoPoint::new_unpacked(self.folding, &local_part, VariableOrder::Suffix);
229                let (column_eval, partial_evals) = local_svo.eval(poly);
230
231                // Record a virtual opening (no source column tag) with partials.
232                let opening = Opening {
233                    poly_idx: None,
234                    eval: column_eval,
235                    data: partial_evals,
236                };
237
238                // Add the weighted column evaluation into the stacked total.
239                eval += weight * column_eval;
240
241                // Stash opening and weight for the accumulator-batcher call.
242                openings.push(opening);
243                weights.push(weight);
244            }
245        }
246
247        // Batch every per-column opening into per-round SVO accumulators.
248        let accumulators = calculate_accumulators_batch(
249            &ProverMultiClaim::new(
250                SvoPoint::new_unpacked(self.folding, &point, VariableOrder::Suffix),
251                openings,
252            ),
253            &weights,
254        );
255
256        // Debug-only consistency check:
257        //
258        // - hand-rolled weighted sum must equal the direct stacked evaluation.
259        // - accumulators batched per column must equal the single-opening batch.
260        #[cfg(debug_assertions)]
261        {
262            // Materialise the stacked polynomial with no challenges applied.
263            let poly = &self.compress_stacked(&Point::default());
264            // Check 1: weighted sum equals the direct evaluation.
265            assert_eq!(eval, poly.eval_base(&point));
266
267            // Build the reference opening by evaluating the materialised poly directly.
268            let ref_svo =
269                SvoPoint::<EF, EF>::new_unpacked(self.folding, &point, VariableOrder::Suffix);
270            let (ref_eval, ref_partials) = ref_svo.eval(poly);
271            let opening = Opening {
272                poly_idx: None,
273                eval: ref_eval,
274                data: ref_partials,
275            };
276            // Check 2: the reference evaluation matches the weighted one.
277            assert_eq!(eval, ref_eval);
278            // Check 3: accumulators from per-column batching match the single-opening batch.
279            assert_eq!(
280                accumulators,
281                calculate_accumulators_batch(
282                    &ProverMultiClaim::new(
283                        SvoPoint::new_unpacked(self.folding, &point, VariableOrder::Suffix),
284                        vec![opening],
285                    ),
286                    &[EF::ONE],
287                ),
288            );
289        }
290
291        // Commit the evaluation to the transcript and record the claim.
292        challenger.observe_algebra_element(eval);
293        self.virtual_claims.push(Claim {
294            point,
295            eval,
296            data: accumulators,
297        });
298
299        eval
300    }
301
302    /// Finalises SVO preprocessing and returns the residual sumcheck prover.
303    ///
304    /// # Returns
305    ///
306    /// - Residual sumcheck prover over the unpacked product polynomial.
307    /// - Folding challenges sampled during preprocessing.
308    ///
309    /// # Algorithm
310    ///
311    /// ```text
312    ///     Phase | Action
313    ///     ------+------------------------------------------------------------
314    ///       1   | Sample batching challenge  a; flatten alphas by opening_idx.
315    ///       2   | Pre-batch per-claim accumulators with the a-powers.
316    ///       3   | Loop over preprocessing rounds:
317    ///               a. (h(0), h(inf)) = dot(accumulators, Lagrange weights).
318    ///               b. Sample challenge r; extrapolate the running sum.
319    ///       4   | Compose the residual product polynomial from compressed slots.
320    /// ```
321    #[tracing::instrument(skip_all)]
322    fn into_sumcheck<Ch>(
323        self,
324        sumcheck_data: &mut SumcheckData<F, EF>,
325        pow_bits: usize,
326        challenger: &mut Ch,
327    ) -> (SumcheckProver<F, EF>, Point<EF>)
328    where
329        Ch: FieldChallenger<F> + GrindingChallenger<Witness = F>,
330    {
331        // Sanity: preprocessing cannot consume more rounds than the stacked arity.
332        assert!(self.folding <= self.num_variables);
333        let alpha: EF = challenger.sample_algebra_element();
334        let n_claims = self.num_claims();
335
336        // Stage A: batch per-claim accumulators using insertion-order alpha powers.
337        //
338        // - Iteration order is placement order, matching `sum` and `combine_eqs`.
339        // - Each claim consumes exactly `claim.len()` consecutive powers from
340        //   the shared iterator, so the per-claim alpha vector is aligned with
341        //   the claim's opening list by construction.
342        let mut alphas = alpha.powers();
343        let accumulators: Vec<_> = self
344            .placements
345            .iter()
346            .flat_map(|placement| self.claim_map[placement.idx()].iter())
347            .map(|claim| {
348                let per_claim: Vec<EF> = alphas.by_ref().take(claim.len()).collect();
349                calculate_accumulators_batch(claim, &per_claim)
350            })
351            .collect();
352
353        // Stage C: drive the preprocessing rounds from the accumulators.
354        let mut sum = self.sum(alpha);
355        let mut rs: Vec<EF> = vec![];
356
357        for round_idx in 0..self.folding {
358            // Lagrange weights at the challenges sampled so far.
359            let weights = lagrange_weights_01inf_multi(&rs);
360
361            // Round-coefficient identity (linearity of the dot product):
362            //
363            //     c0    = sum_c  dot(claim_c.accs[0], weights)
364            //           + sum_v  alpha_v * dot(virtual_v.accs[0], weights)
365            //     c_inf = same with accs[1]
366            //
367            // - Concrete claims carry alpha pre-batched in stage B.
368            // - Virtual claims keep a separate scalar per claim.
369            // - No intermediate element-wise accumulator is needed.
370            let mut c0 = EF::ZERO;
371            let mut c_inf = EF::ZERO;
372
373            for accs in &accumulators {
374                c0 += dot_product::<EF, _, _>(
375                    accs[round_idx][0].iter().copied(),
376                    weights.iter().copied(),
377                );
378                c_inf += dot_product::<EF, _, _>(
379                    accs[round_idx][1].iter().copied(),
380                    weights.iter().copied(),
381                );
382            }
383
384            // Virtual-claim contributions: scale each claim's dot by its alpha power.
385            for (vc, alpha_i) in self
386                .virtual_claims
387                .iter()
388                .zip(alpha.shifted_powers(alpha.exp_u64(n_claims as u64)))
389            {
390                let vc_accs = &vc.data;
391                c0 += alpha_i
392                    * dot_product::<EF, _, _>(
393                        vc_accs[round_idx][0].iter().copied(),
394                        weights.iter().copied(),
395                    );
396                c_inf += alpha_i
397                    * dot_product::<EF, _, _>(
398                        vc_accs[round_idx][1].iter().copied(),
399                        weights.iter().copied(),
400                    );
401            }
402
403            // Observe coefficients, sample r, extrapolate the running sum.
404            let r = sumcheck_data.observe_and_sample(challenger, c0, c_inf, pow_bits);
405            sum = extrapolate_01inf(c0, sum - c0, c_inf, r);
406            rs.push(r);
407        }
408
409        // Stage D: materialise the residual product polynomial.
410        //
411        // - Suffix binding folds variables in reverse.
412        // - The residual poly therefore lives in the reversed-challenges frame.
413        let rs = Point::new(rs);
414        // Reverse the challenges before handing them to the compressors.
415        let reversed = rs.reversed();
416        // Factor 1 of the product: the compressed stacked poly at rs.
417        // No external scaling here; the plain path keeps the running sum unchanged.
418        let compressed = self.compress_stacked(&reversed);
419        // Factor 2 of the product: the batched equality-weight poly.
420        let weights = self.combine_eqs(&reversed, alpha);
421        // Pair them; the product polynomial drives the remaining rounds.
422        let poly = ProductPolynomial::new_unpacked(VariableOrder::Suffix, compressed, weights);
423        // Cross-check: the dot product of the two factors must equal the
424        // running sum accumulated across the preprocessing rounds.
425        debug_assert_eq!(poly.dot_product(), sum);
426
427        (SumcheckProver::new(poly, sum), rs)
428    }
429
430    /// Returns the total number of concrete openings recorded so far.
431    fn num_claims(&self) -> usize {
432        self.claim_map
433            .iter()
434            .flat_map(|claims| claims.iter().map(ProverMultiClaim::len))
435            .sum()
436    }
437
438    fn strategy() -> LayoutStrategy {
439        LayoutStrategy::new(false, VariableOrder::Suffix)
440    }
441}
442
443impl<F: TwoAdicField, EF: ExtensionField<F>> SuffixProver<F, EF> {
444    /// Computes the batched claimed sum from concrete and virtual openings.
445    ///
446    /// # Identity
447    ///
448    /// ```text
449    ///     sum = sum_{i}  alpha^i * eval_i
450    /// ```
451    ///
452    /// # Alpha ordering
453    ///
454    /// Powers of `alpha` are handed out in insertion order:
455    ///
456    /// - Outer: placements, in the order the witness laid them out.
457    /// - Middle: claims recorded against that placement's source table.
458    /// - Inner: openings inside each claim, in the order they were recorded.
459    ///
460    /// # Virtual claims
461    ///
462    /// - Virtual evaluations continue the same alpha sequence.
463    /// - They start at `alpha^n`, with `n` the total concrete opening count.
464    ///
465    /// # Verifier agreement
466    ///
467    /// The verifier walks its claim registry with the same three-loop order,
468    /// so both sides assign the same `alpha^i` to the same claim point.
469    pub(crate) fn sum(&self, alpha: EF) -> EF {
470        let mut sum = EF::ZERO;
471        let mut alphas = alpha.powers();
472
473        // Concrete openings: three loops, no filter.
474        for placement in &self.placements {
475            for claim in &self.claim_map[placement.idx()] {
476                for opening in claim.openings() {
477                    sum += opening.eval() * alphas.next().unwrap();
478                }
479            }
480        }
481
482        // Virtual claims continue the alpha sequence right after the concrete ones.
483        sum += dot_product::<EF, _, _>(
484            self.virtual_claims.iter().map(Claim::eval),
485            alpha.shifted_powers(alpha.exp_u64(self.num_claims() as u64)),
486        );
487
488        sum
489    }
490
491    /// Compress every stacked-table slot by fixing the suffix challenges.
492    #[tracing::instrument(skip_all)]
493    pub(crate) fn compress_stacked(&self, rs: &Point<EF>) -> Poly<EF> {
494        self.compress_stacked_scaled(rs, EF::ONE)
495    }
496
497    /// Compress every stacked-table slot, folding `scale` into the equality table.
498    ///
499    /// ```text
500    ///     out[slot, x_rest] = sum_{y in {0,1}^|r|}  scale * eq(r, y) * col(x_rest, y)
501    /// ```
502    ///
503    /// One output slot per column.
504    /// Writes never overlap.
505    /// Output arity equals the stacked arity minus the challenge count.
506    ///
507    /// # Arguments
508    ///
509    /// - `rs` — suffix challenges already sampled.
510    /// - `scale` — extra factor folded into the equality table.
511    ///
512    /// # Why a scale parameter
513    ///
514    /// - A non-unit `scale` lets the caller absorb a combining challenge into
515    ///   the residual factor without a second pass.
516    ///
517    /// # Panics
518    ///
519    /// - `scale` is zero: a zero scale silently zeroes the residual.
520    #[tracing::instrument(skip_all)]
521    pub(crate) fn compress_stacked_scaled(&self, rs: &Point<EF>, scale: EF) -> Poly<EF> {
522        assert!(rs.num_variables() <= self.num_variables);
523        assert!(scale != EF::ZERO, "compress scale must be non-zero");
524        // Output spans the residual stacked space.
525        // Size is 2^(num_variables - |rs|).
526        let mut out = Poly::<EF>::zero(self.num_variables - rs.num_variables());
527        // Bake the scalar into the prefix-half equality table.
528        // Each slot compression then returns scale * eq(r, y) * col(...) in one pass.
529        let rs = SplitEq::new_unpacked(rs, scale);
530
531        for placement in &self.placements {
532            for (poly_idx, selector) in placement.selectors().iter().enumerate() {
533                let poly = self.tables[placement.idx()].poly(poly_idx);
534                assert!(rs.num_variables() <= poly.num_variables());
535                // Slot start in the compressed output.
536                let off = selector.index() << (poly.num_variables() - rs.num_variables());
537                // Write this column's compression into its own slot.
538                rs.compress_suffix_into(
539                    &mut out.as_mut_slice()
540                        [off..off + (1 << (poly.num_variables() - rs.num_variables()))],
541                    poly,
542                );
543            }
544        }
545        out
546    }
547
548    /// Builds the residual weight polynomial after the SVO rounds.
549    ///
550    /// # Contributions
551    ///
552    /// - Concrete claim: factored equality table scaled by
553    ///   `alpha^i * eq(svo_part, rs)`, written into the owning slot only.
554    /// - Virtual claim: scaled equality table written across the full output.
555    #[tracing::instrument(skip_all)]
556    pub(crate) fn combine_eqs(&self, rs: &Point<EF>, alpha: EF) -> Poly<EF> {
557        // Preconditions: challenge count matches the folding depth.
558        assert_eq!(rs.num_variables(), self.folding);
559        // Output arity: stacked arity minus the folded challenges.
560        let mut out = Poly::<EF>::zero(self.num_variables - rs.num_variables());
561
562        let mut alphas = alpha.powers();
563
564        // Concrete claims: write each into the slot its column's selector addresses.
565        for placement in &self.placements {
566            let num_variables_table = self.num_variables_table(placement.idx());
567            let slot_size = 1usize << num_variables_table;
568            for claim in &self.claim_map[placement.idx()] {
569                for opening in claim.openings() {
570                    // The opening's column tells us which selector picks the slot.
571                    let col = opening.poly_idx().unwrap();
572                    let off = placement.selectors()[col].index() << num_variables_table;
573                    // Fold the scalar slot range down by the SVO depth.
574                    let folded_range = (off >> self.folding)..((off + slot_size) >> self.folding);
575                    claim.point().accumulate_into(
576                        &mut out.as_mut_slice()[folded_range],
577                        rs,
578                        alphas.next().unwrap(),
579                    );
580                }
581            }
582        }
583
584        // Virtual claims: span the full output; alpha continues after concrete ones.
585        let mut alpha_i = alpha.exp_u64(self.num_claims() as u64);
586        for claim in &self.virtual_claims {
587            // Split the claim point into (rest-of-space, svo-sub-point).
588            let (rest, svo) = claim
589                .point
590                .split_at(claim.point.num_variables() - rs.num_variables());
591            // Scalar weight: alpha^i times the equality between svo part and rs.
592            let scale = alpha_i * Point::eval_eq(svo.as_slice(), rs.as_slice());
593            // Contribute the scaled equality table across the whole output.
594            SplitEq::new_packed(&rest, scale).accumulate_into(out.as_mut_slice(), None);
595            // Advance alpha for the next virtual claim.
596            alpha_i *= alpha;
597        }
598
599        out
600    }
601}