p3_sumcheck/layout/prover/
prefix.rs1use alloc::vec::Vec;
4
5use p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger};
6use p3_commit::Mmcs;
7use p3_dft::TwoAdicSubgroupDft;
8use p3_field::{ExtensionField, Field, TwoAdicField, dot_product};
9use p3_matrix::dense::DenseMatrix;
10use p3_multilinear_util::point::Point;
11use p3_multilinear_util::poly::Poly;
12use p3_multilinear_util::split_eq::SplitEq;
13
14use crate::commit::commit_base;
15use crate::lagrange::lagrange_weights_01inf_multi;
16use crate::layout::opening::Opening;
17use crate::layout::prover::Layout;
18use crate::layout::witness::{Table, TablePlacement};
19use crate::layout::{LayoutStrategy, ProverMultiClaim, ProverVirtualClaim, Witness};
20use crate::product_polynomial::ProductPolynomial;
21use crate::strategy::{SumcheckProver, VariableOrder};
22use crate::svo::{SvoPoint, calculate_accumulators_batch};
23use crate::{Claim, SumcheckData, extrapolate_01inf};
24
25#[derive(Debug, Clone)]
32pub struct PrefixProver<F: Field, EF: ExtensionField<F>> {
33 pub(crate) tables: Vec<Table<F>>,
35 pub(crate) placements: Vec<TablePlacement>,
37 pub(crate) num_variables: usize,
39 pub(crate) folding: usize,
41 pub(crate) poly: Poly<F>,
43 pub(crate) claim_map: Vec<Vec<ProverMultiClaim<F, EF>>>,
51 pub(crate) virtual_claims: Vec<ProverVirtualClaim<EF>>,
53}
54
55impl<F: TwoAdicField, EF: ExtensionField<F>> Layout<F, EF> for PrefixProver<F, EF> {
56 fn from_witness(witness: Witness<F>) -> Self {
57 let parts = witness.into_parts();
59 let num_tables = parts.tables.len();
61 Self {
62 tables: parts.tables,
63 placements: parts.placements,
64 num_variables: parts.num_variables,
65 folding: parts.folding,
66 poly: parts.poly,
67 claim_map: (0..num_tables).map(|_| Vec::new()).collect(),
68 virtual_claims: Vec::new(),
69 }
70 }
71
72 fn new_witness(tables: Vec<Table<F>>, folding: usize) -> Witness<F> {
73 Witness::new_interleaved(tables, folding)
74 }
75
76 fn commit<Dft, MT, Challenger>(
77 dft: &Dft,
78 mmcs: &MT,
79 challenger: &mut Challenger,
80 witness: Witness<F>,
81 folding: usize,
82 starting_log_inv_rate: usize,
83 ) -> (Self, MT::Commitment, MT::ProverData<DenseMatrix<F>>)
84 where
85 Dft: TwoAdicSubgroupDft<F>,
86 MT: Mmcs<F>,
87 Challenger: CanObserve<MT::Commitment>,
88 {
89 let (root, prover_data) = commit_base(
90 Self::variable_order(),
91 dft,
92 mmcs,
93 challenger,
94 &witness.poly,
95 folding,
96 starting_log_inv_rate,
97 );
98
99 (Self::from_witness(witness), root, prover_data)
100 }
101
102 fn folding(&self) -> usize {
103 self.folding
104 }
105
106 fn num_variables(&self) -> usize {
108 self.num_variables
109 }
110
111 fn num_variables_table(&self, id: usize) -> usize {
113 self.tables[id].num_variables()
114 }
115
116 #[tracing::instrument(skip_all)]
134 fn eval<Ch>(&mut self, table_idx: usize, polys: &[usize], challenger: &mut Ch) -> Vec<EF>
135 where
136 Ch: FieldChallenger<F> + GrindingChallenger<Witness = F>,
137 {
138 assert!(
140 !polys.is_empty(),
141 "opening schedule must name at least one column"
142 );
143
144 let table = &self.tables[table_idx];
146 let point = Point::expand_from_univariate(
147 challenger.sample_algebra_element(),
148 table.num_variables(),
149 );
150
151 let point = SvoPoint::new_packed(self.folding, &point);
153
154 let (openings, evals): (Vec<_>, Vec<EF>) = polys
156 .iter()
157 .map(|&poly_idx| {
158 let (eval, partial_evals) = point.eval(table.poly(poly_idx));
159 let opening = Opening {
160 poly_idx: Some(poly_idx),
161 eval,
162 data: partial_evals,
163 };
164 (opening, eval)
165 })
166 .unzip();
167
168 challenger.observe_algebra_slice(&evals);
170
171 self.claim_map[table_idx].push(ProverMultiClaim::new(point, openings));
173
174 evals
175 }
176
177 #[tracing::instrument(skip_all)]
185 fn add_virtual_eval<Ch>(&mut self, challenger: &mut Ch) -> EF
186 where
187 Ch: FieldChallenger<F> + GrindingChallenger<Witness = F>,
188 {
189 let point =
191 Point::expand_from_univariate(challenger.sample_algebra_element(), self.num_variables);
192
193 let mut eval = EF::ZERO;
194 let mut openings = Vec::new();
195 let mut weights = Vec::new();
196
197 for placement in &self.placements {
198 let table = &self.tables[placement.idx()];
199 for (poly_idx, selector) in placement.selectors().iter().enumerate() {
200 let poly = table.poly(poly_idx);
201
202 let (local_part, selector_part) = point.split_at(table.num_variables());
203
204 let weight =
205 Point::eval_eq::<EF>(selector.point().as_slice(), selector_part.as_slice());
206
207 let local_svo = SvoPoint::new_packed(self.folding, &local_part);
208 let (column_eval, partial_evals) = local_svo.eval(poly);
209
210 eval += weight * column_eval;
211 openings.push(Opening {
212 poly_idx: None,
213 eval: column_eval,
214 data: partial_evals,
215 });
216 weights.push(weight);
217 }
218 }
219
220 let accumulators = calculate_accumulators_batch(
221 &ProverMultiClaim::new(
222 SvoPoint::new_unpacked(self.folding, &point, VariableOrder::Prefix),
223 openings,
224 ),
225 &weights,
226 );
227
228 challenger.observe_algebra_element(eval);
230 self.virtual_claims.push(Claim {
231 point,
232 eval,
233 data: accumulators,
234 });
235
236 eval
237 }
238
239 #[tracing::instrument(skip_all)]
263 fn into_sumcheck<Ch>(
264 self,
265 sumcheck_data: &mut SumcheckData<F, EF>,
266 pow_bits: usize,
267 challenger: &mut Ch,
268 ) -> (SumcheckProver<F, EF>, Point<EF>)
269 where
270 Ch: FieldChallenger<F> + GrindingChallenger<Witness = F>,
271 {
272 assert!(self.folding <= self.num_variables);
274
275 let alpha: EF = challenger.sample_algebra_element();
276 let n_claims = self.num_claims();
277
278 let mut alphas = alpha.powers();
279 let accumulators: Vec<_> = self
280 .placements
281 .iter()
282 .flat_map(|placement| self.claim_map[placement.idx()].iter())
283 .map(|claim| {
284 let per_claim: Vec<EF> = alphas.by_ref().take(claim.len()).collect();
285 calculate_accumulators_batch(claim, &per_claim)
286 })
287 .collect();
288
289 let mut sum = self.sum(alpha);
290 let mut rs = Vec::new();
291
292 for round_idx in 0..self.folding {
293 let weights = lagrange_weights_01inf_multi(&rs);
294
295 let mut c0 = EF::ZERO;
296 let mut c_inf = EF::ZERO;
297
298 for accs in &accumulators {
299 c0 += dot_product::<EF, _, _>(
300 accs[round_idx][0].iter().copied(),
301 weights.iter().copied(),
302 );
303 c_inf += dot_product::<EF, _, _>(
304 accs[round_idx][1].iter().copied(),
305 weights.iter().copied(),
306 );
307 }
308
309 for (vc, alpha_i) in self
310 .virtual_claims
311 .iter()
312 .zip(alpha.shifted_powers(alpha.exp_u64(n_claims as u64)))
313 {
314 let vc_accs = &vc.data;
315 c0 += alpha_i
316 * dot_product::<EF, _, _>(
317 vc_accs[round_idx][0].iter().copied(),
318 weights.iter().copied(),
319 );
320 c_inf += alpha_i
321 * dot_product::<EF, _, _>(
322 vc_accs[round_idx][1].iter().copied(),
323 weights.iter().copied(),
324 );
325 }
326
327 let r = sumcheck_data.observe_and_sample(challenger, c0, c_inf, pow_bits);
328 sum = extrapolate_01inf(c0, sum - c0, c_inf, r);
329 rs.push(r);
330 }
331
332 let rs = Point::new(rs);
333 let compressed = tracing::info_span!("compress_prefix_to_packed")
334 .in_scope(|| self.poly.compress_prefix_to_packed(&rs, EF::ONE));
335
336 let weights = self.combine_eqs(&rs, alpha).pack::<F, EF>();
337 let prod_poly =
338 ProductPolynomial::<F, EF>::new_packed(VariableOrder::Prefix, compressed, weights);
339 debug_assert_eq!(prod_poly.dot_product(), sum);
340
341 (SumcheckProver::new(prod_poly, sum), rs)
342 }
343
344 fn num_claims(&self) -> usize {
346 self.claim_map
347 .iter()
348 .flat_map(|claims| claims.iter().map(ProverMultiClaim::len))
349 .sum()
350 }
351
352 fn strategy() -> LayoutStrategy {
353 LayoutStrategy::new(true, VariableOrder::Prefix)
354 }
355}
356
357impl<F: TwoAdicField, EF: ExtensionField<F>> PrefixProver<F, EF> {
358 pub(crate) fn sum(&self, alpha: EF) -> EF {
384 let mut sum = EF::ZERO;
385 let mut alphas = alpha.powers();
386
387 for placement in &self.placements {
389 for claim in &self.claim_map[placement.idx()] {
390 for opening in claim.openings() {
391 sum += opening.eval() * alphas.next().unwrap();
392 }
393 }
394 }
395
396 sum += dot_product::<EF, _, _>(
398 self.virtual_claims.iter().map(Claim::eval),
399 alpha.shifted_powers(alpha.exp_u64(self.num_claims() as u64)),
400 );
401
402 sum
403 }
404
405 #[tracing::instrument(skip_all)]
407 pub(crate) fn combine_eqs(&self, rs: &Point<EF>, alpha: EF) -> Poly<EF> {
408 assert_eq!(rs.num_variables(), self.folding);
409 let mut out = Poly::<EF>::zero(self.num_variables - rs.num_variables());
410
411 let mut alphas = alpha.powers();
412
413 for placement in &self.placements {
414 let local_rest_variables =
415 self.num_variables_table(placement.idx()) - rs.num_variables();
416 for claim in &self.claim_map[placement.idx()] {
417 for opening in claim.openings() {
418 let col = opening.poly_idx().unwrap();
419 let selector = &placement.selectors()[col];
420 let mut local = Poly::<EF>::zero(local_rest_variables);
421 claim
422 .point()
423 .accumulate_into(local.as_mut_slice(), rs, alphas.next().unwrap());
424
425 for (local_idx, &value) in local.as_slice().iter().enumerate() {
426 let dst = (local_idx << selector.num_variables()) | selector.index();
427 out.as_mut_slice()[dst] += value;
428 }
429 }
430 }
431 }
432
433 let mut alpha_i = alpha.exp_u64(self.num_claims() as u64);
434 for claim in &self.virtual_claims {
435 let (svo, rest) = claim.point.split_at(rs.num_variables());
436 let scale = alpha_i * Point::eval_eq(svo.as_slice(), rs.as_slice());
437 SplitEq::new_unpacked(&rest, scale).accumulate_into(out.as_mut_slice(), None);
438 alpha_i *= alpha;
439 }
440
441 out
442 }
443}