Skip to main content

p3_sumcheck/layout/prover/
prefix.rs

1//! Prefix-mode stacked-sumcheck prover.
2
3use alloc::vec::Vec;
4
5use p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger};
6use p3_commit::Mmcs;
7use p3_dft::TwoAdicSubgroupDft;
8use p3_field::{ExtensionField, Field, TwoAdicField, dot_product};
9use p3_matrix::dense::DenseMatrix;
10use p3_multilinear_util::point::Point;
11use p3_multilinear_util::poly::Poly;
12use p3_multilinear_util::split_eq::SplitEq;
13
14use crate::commit::commit_base;
15use crate::lagrange::lagrange_weights_01inf_multi;
16use crate::layout::opening::Opening;
17use crate::layout::prover::Layout;
18use crate::layout::witness::{Table, TablePlacement};
19use crate::layout::{LayoutStrategy, ProverMultiClaim, ProverVirtualClaim, Witness};
20use crate::product_polynomial::ProductPolynomial;
21use crate::strategy::{SumcheckProver, VariableOrder};
22use crate::svo::{SvoPoint, calculate_accumulators_batch};
23use crate::{Claim, SumcheckData, extrapolate_01inf};
24
25/// Stacked-sumcheck prover with prefix-first variable binding.
26///
27/// # Flow
28///
29/// - Round one runs in SIMD-packed form.
30/// - Every later round runs on the residual product polynomial.
31#[derive(Debug, Clone)]
32pub struct PrefixProver<F: Field, EF: ExtensionField<F>> {
33    /// Source tables behind the stacked polynomial.
34    pub(crate) tables: Vec<Table<F>>,
35    /// Per-table placement metadata inside the stacked polynomial.
36    pub(crate) placements: Vec<TablePlacement>,
37    /// Number of variables of the stacked polynomial.
38    pub(crate) num_variables: usize,
39    /// Number of preprocessing rounds consumed before residual sumcheck.
40    pub(crate) folding: usize,
41    /// Stacked committed polynomial.
42    pub(crate) poly: Poly<F>,
43    /// Concrete claims recorded per source table.
44    ///
45    /// # Invariants
46    ///
47    /// - Every opening stored here is tied to a concrete source column.
48    /// - Virtual openings never enter this map.
49    /// - Claims are appended in insertion order.
50    pub(crate) claim_map: Vec<Vec<ProverMultiClaim<F, EF>>>,
51    /// Virtual claims sampled directly on the stacked polynomial.
52    pub(crate) virtual_claims: Vec<ProverVirtualClaim<EF>>,
53}
54
55impl<F: TwoAdicField, EF: ExtensionField<F>> Layout<F, EF> for PrefixProver<F, EF> {
56    fn from_witness(witness: Witness<F>) -> Self {
57        // Move the witness fields out so the prover owns them outright.
58        let parts = witness.into_parts();
59        // One claim list per source table; virtual claims live in their own bucket.
60        let num_tables = parts.tables.len();
61        Self {
62            tables: parts.tables,
63            placements: parts.placements,
64            num_variables: parts.num_variables,
65            folding: parts.folding,
66            poly: parts.poly,
67            claim_map: (0..num_tables).map(|_| Vec::new()).collect(),
68            virtual_claims: Vec::new(),
69        }
70    }
71
72    fn new_witness(tables: Vec<Table<F>>, folding: usize) -> Witness<F> {
73        Witness::new_interleaved(tables, folding)
74    }
75
76    fn commit<Dft, MT, Challenger>(
77        dft: &Dft,
78        mmcs: &MT,
79        challenger: &mut Challenger,
80        witness: Witness<F>,
81        folding: usize,
82        starting_log_inv_rate: usize,
83    ) -> (Self, MT::Commitment, MT::ProverData<DenseMatrix<F>>)
84    where
85        Dft: TwoAdicSubgroupDft<F>,
86        MT: Mmcs<F>,
87        Challenger: CanObserve<MT::Commitment>,
88    {
89        let (root, prover_data) = commit_base(
90            Self::variable_order(),
91            dft,
92            mmcs,
93            challenger,
94            &witness.poly,
95            folding,
96            starting_log_inv_rate,
97        );
98
99        (Self::from_witness(witness), root, prover_data)
100    }
101
102    fn folding(&self) -> usize {
103        self.folding
104    }
105
106    /// Returns the number of variables of the stacked polynomial.
107    fn num_variables(&self) -> usize {
108        self.num_variables
109    }
110
111    /// Returns the number of variables of table `id`.
112    fn num_variables_table(&self, id: usize) -> usize {
113        self.tables[id].num_variables()
114    }
115
116    /// Records opening claims for the selected columns of `table_idx`.
117    ///
118    /// # Arguments
119    ///
120    /// - `table_idx`  — source table index.
121    /// - `polys`      — columns to open; must be non-empty.
122    /// - `challenger` — Fiat–Shamir transcript.
123    ///
124    /// # Fiat–Shamir
125    ///
126    /// - Samples the opening point internally from the challenger.
127    /// - Absorbs the evaluations into the transcript before returning.
128    /// - The verifier's `add_claim` performs the symmetric absorption.
129    ///
130    /// # Panics
131    ///
132    /// - Columns list must be non-empty.
133    #[tracing::instrument(skip_all)]
134    fn eval<Ch>(&mut self, table_idx: usize, polys: &[usize], challenger: &mut Ch) -> Vec<EF>
135    where
136        Ch: FieldChallenger<F> + GrindingChallenger<Witness = F>,
137    {
138        // Precondition: opening nothing would silently push an empty ProverMultiClaim.
139        assert!(
140            !polys.is_empty(),
141            "opening schedule must name at least one column"
142        );
143
144        // Sample the local-frame opening point from the transcript.
145        let table = &self.tables[table_idx];
146        let point = Point::expand_from_univariate(
147            challenger.sample_algebra_element(),
148            table.num_variables(),
149        );
150
151        // Factorise the point once; every selected column reuses it.
152        let point = SvoPoint::new_packed(self.folding, &point);
153
154        // Evaluate each column at the SVO point; split into (opening, eval).
155        let (openings, evals): (Vec<_>, Vec<EF>) = polys
156            .iter()
157            .map(|&poly_idx| {
158                let (eval, partial_evals) = point.eval(table.poly(poly_idx));
159                let opening = Opening {
160                    poly_idx: Some(poly_idx),
161                    eval,
162                    data: partial_evals,
163                };
164                (opening, eval)
165            })
166            .unzip();
167
168        // Bind the evaluations into the transcript; the verifier absorbs the same bytes.
169        challenger.observe_algebra_slice(&evals);
170
171        // Store the batch for the later sumcheck reduction.
172        self.claim_map[table_idx].push(ProverMultiClaim::new(point, openings));
173
174        evals
175    }
176
177    /// Samples a virtual evaluation on the full stacked polynomial.
178    ///
179    /// # Why
180    ///
181    /// The WHIR protocol occasionally pins the stacked polynomial at a fresh
182    /// random point for soundness amplification. Prefix mode evaluates the
183    /// stacked polynomial directly — no per-column weighting needed.
184    #[tracing::instrument(skip_all)]
185    fn add_virtual_eval<Ch>(&mut self, challenger: &mut Ch) -> EF
186    where
187        Ch: FieldChallenger<F> + GrindingChallenger<Witness = F>,
188    {
189        // Sample a challenge point covering every stacked variable.
190        let point =
191            Point::expand_from_univariate(challenger.sample_algebra_element(), self.num_variables);
192
193        let mut eval = EF::ZERO;
194        let mut openings = Vec::new();
195        let mut weights = Vec::new();
196
197        for placement in &self.placements {
198            let table = &self.tables[placement.idx()];
199            for (poly_idx, selector) in placement.selectors().iter().enumerate() {
200                let poly = table.poly(poly_idx);
201
202                let (local_part, selector_part) = point.split_at(table.num_variables());
203
204                let weight =
205                    Point::eval_eq::<EF>(selector.point().as_slice(), selector_part.as_slice());
206
207                let local_svo = SvoPoint::new_packed(self.folding, &local_part);
208                let (column_eval, partial_evals) = local_svo.eval(poly);
209
210                eval += weight * column_eval;
211                openings.push(Opening {
212                    poly_idx: None,
213                    eval: column_eval,
214                    data: partial_evals,
215                });
216                weights.push(weight);
217            }
218        }
219
220        let accumulators = calculate_accumulators_batch(
221            &ProverMultiClaim::new(
222                SvoPoint::new_unpacked(self.folding, &point, VariableOrder::Prefix),
223                openings,
224            ),
225            &weights,
226        );
227
228        // Commit the evaluation to the transcript.
229        challenger.observe_algebra_element(eval);
230        self.virtual_claims.push(Claim {
231            point,
232            eval,
233            data: accumulators,
234        });
235
236        eval
237    }
238
239    /// Finalises preprocessing and returns the residual sumcheck prover.
240    ///
241    /// # Returns
242    ///
243    /// - Residual sumcheck prover over the packed product polynomial.
244    /// - Folding challenges sampled during preprocessing.
245    ///
246    /// # Algorithm
247    ///
248    /// ```text
249    ///     Phase | Action
250    ///     ------+-----------------------------------------------
251    ///       1   | Sample the batching challenge  a.
252    ///       2   | running sum  = sum_{i}  a^i * eval_i.
253    ///       3   | weight poly  = sum_{i}  a^i * eq(z_i, X).
254    ///       4   | Fold round 1 in SIMD-packed arithmetic.
255    ///       5   | Drive rounds 2..folding on the product polynomial.
256    /// ```
257    ///
258    /// # Precondition
259    ///
260    /// - Each table's arity is at least  log_2(W), with W the packing width.
261    /// - Guarantees every per-slot packed accumulation spans a whole packed element.
262    #[tracing::instrument(skip_all)]
263    fn into_sumcheck<Ch>(
264        self,
265        sumcheck_data: &mut SumcheckData<F, EF>,
266        pow_bits: usize,
267        challenger: &mut Ch,
268    ) -> (SumcheckProver<F, EF>, Point<EF>)
269    where
270        Ch: FieldChallenger<F> + GrindingChallenger<Witness = F>,
271    {
272        // Sanity: preprocessing cannot consume more rounds than the stacked arity.
273        assert!(self.folding <= self.num_variables);
274
275        let alpha: EF = challenger.sample_algebra_element();
276        let n_claims = self.num_claims();
277
278        let mut alphas = alpha.powers();
279        let accumulators: Vec<_> = self
280            .placements
281            .iter()
282            .flat_map(|placement| self.claim_map[placement.idx()].iter())
283            .map(|claim| {
284                let per_claim: Vec<EF> = alphas.by_ref().take(claim.len()).collect();
285                calculate_accumulators_batch(claim, &per_claim)
286            })
287            .collect();
288
289        let mut sum = self.sum(alpha);
290        let mut rs = Vec::new();
291
292        for round_idx in 0..self.folding {
293            let weights = lagrange_weights_01inf_multi(&rs);
294
295            let mut c0 = EF::ZERO;
296            let mut c_inf = EF::ZERO;
297
298            for accs in &accumulators {
299                c0 += dot_product::<EF, _, _>(
300                    accs[round_idx][0].iter().copied(),
301                    weights.iter().copied(),
302                );
303                c_inf += dot_product::<EF, _, _>(
304                    accs[round_idx][1].iter().copied(),
305                    weights.iter().copied(),
306                );
307            }
308
309            for (vc, alpha_i) in self
310                .virtual_claims
311                .iter()
312                .zip(alpha.shifted_powers(alpha.exp_u64(n_claims as u64)))
313            {
314                let vc_accs = &vc.data;
315                c0 += alpha_i
316                    * dot_product::<EF, _, _>(
317                        vc_accs[round_idx][0].iter().copied(),
318                        weights.iter().copied(),
319                    );
320                c_inf += alpha_i
321                    * dot_product::<EF, _, _>(
322                        vc_accs[round_idx][1].iter().copied(),
323                        weights.iter().copied(),
324                    );
325            }
326
327            let r = sumcheck_data.observe_and_sample(challenger, c0, c_inf, pow_bits);
328            sum = extrapolate_01inf(c0, sum - c0, c_inf, r);
329            rs.push(r);
330        }
331
332        let rs = Point::new(rs);
333        let compressed = tracing::info_span!("compress_prefix_to_packed")
334            .in_scope(|| self.poly.compress_prefix_to_packed(&rs, EF::ONE));
335
336        let weights = self.combine_eqs(&rs, alpha).pack::<F, EF>();
337        let prod_poly =
338            ProductPolynomial::<F, EF>::new_packed(VariableOrder::Prefix, compressed, weights);
339        debug_assert_eq!(prod_poly.dot_product(), sum);
340
341        (SumcheckProver::new(prod_poly, sum), rs)
342    }
343
344    /// Returns the total number of concrete openings recorded so far.
345    fn num_claims(&self) -> usize {
346        self.claim_map
347            .iter()
348            .flat_map(|claims| claims.iter().map(ProverMultiClaim::len))
349            .sum()
350    }
351
352    fn strategy() -> LayoutStrategy {
353        LayoutStrategy::new(true, VariableOrder::Prefix)
354    }
355}
356
357impl<F: TwoAdicField, EF: ExtensionField<F>> PrefixProver<F, EF> {
358    /// Computes the batched claimed sum from concrete and virtual openings.
359    ///
360    /// # Identity
361    ///
362    /// ```text
363    ///     sum = sum_{i}  alpha^i * eval_i
364    /// ```
365    ///
366    /// # Alpha ordering
367    ///
368    /// Powers of `alpha` are handed out in insertion order:
369    ///
370    /// - Outer: placements, in the order the witness laid them out.
371    /// - Middle: claims recorded against that placement's source table.
372    /// - Inner: openings inside each claim, in the order they were recorded.
373    ///
374    /// # Virtual claims
375    ///
376    /// - Virtual evaluations continue the same alpha sequence.
377    /// - They start at `alpha^n`, with `n` the total concrete opening count.
378    ///
379    /// # Verifier agreement
380    ///
381    /// The verifier walks its claim registry with the same three-loop order,
382    /// so both sides assign the same `alpha^i` to the same claim point.
383    pub(crate) fn sum(&self, alpha: EF) -> EF {
384        let mut sum = EF::ZERO;
385        let mut alphas = alpha.powers();
386
387        // Concrete openings: three loops, no filter.
388        for placement in &self.placements {
389            for claim in &self.claim_map[placement.idx()] {
390                for opening in claim.openings() {
391                    sum += opening.eval() * alphas.next().unwrap();
392                }
393            }
394        }
395
396        // Virtual claims continue the alpha sequence right after the concrete ones.
397        sum += dot_product::<EF, _, _>(
398            self.virtual_claims.iter().map(Claim::eval),
399            alpha.shifted_powers(alpha.exp_u64(self.num_claims() as u64)),
400        );
401
402        sum
403    }
404
405    /// Builds the residual equality-weight polynomial after the prefix SVO rounds.
406    #[tracing::instrument(skip_all)]
407    pub(crate) fn combine_eqs(&self, rs: &Point<EF>, alpha: EF) -> Poly<EF> {
408        assert_eq!(rs.num_variables(), self.folding);
409        let mut out = Poly::<EF>::zero(self.num_variables - rs.num_variables());
410
411        let mut alphas = alpha.powers();
412
413        for placement in &self.placements {
414            let local_rest_variables =
415                self.num_variables_table(placement.idx()) - rs.num_variables();
416            for claim in &self.claim_map[placement.idx()] {
417                for opening in claim.openings() {
418                    let col = opening.poly_idx().unwrap();
419                    let selector = &placement.selectors()[col];
420                    let mut local = Poly::<EF>::zero(local_rest_variables);
421                    claim
422                        .point()
423                        .accumulate_into(local.as_mut_slice(), rs, alphas.next().unwrap());
424
425                    for (local_idx, &value) in local.as_slice().iter().enumerate() {
426                        let dst = (local_idx << selector.num_variables()) | selector.index();
427                        out.as_mut_slice()[dst] += value;
428                    }
429                }
430            }
431        }
432
433        let mut alpha_i = alpha.exp_u64(self.num_claims() as u64);
434        for claim in &self.virtual_claims {
435            let (svo, rest) = claim.point.split_at(rs.num_variables());
436            let scale = alpha_i * Point::eval_eq(svo.as_slice(), rs.as_slice());
437            SplitEq::new_unpacked(&rest, scale).accumulate_into(out.as_mut_slice(), None);
438            alpha_i *= alpha;
439        }
440
441        out
442    }
443}