1use super::{
2 make_sumcheck_state::make_sumcheck_prover_state, FinalRoundBuilder, FirstRoundBuilder,
3 ProofPlan, QueryData, QueryResult, SumcheckMleEvaluations, SumcheckRandomScalars,
4 VerificationBuilderImpl,
5};
6use crate::{
7 base::{
8 bit::BitDistribution,
9 commitment::{Commitment, CommitmentEvaluationProof, CommittableColumn},
10 database::{
11 ColumnRef, CommitmentAccessor, DataAccessor, LiteralValue, MetadataAccessor,
12 OwnedTable, Table, TableRef,
13 },
14 map::{IndexMap, IndexSet},
15 math::log2_up,
16 polynomial::{compute_evaluation_vector, MultilinearExtension},
17 proof::{Keccak256Transcript, PlaceholderResult, ProofError, Transcript},
18 },
19 proof_primitive::sumcheck::SumcheckProof,
20 utils::log,
21};
22use alloc::{boxed::Box, vec, vec::Vec};
23use bumpalo::Bump;
24use core::cmp;
25use itertools::Itertools;
26use num_traits::Zero;
27use serde::{Deserialize, Serialize};
28
29fn get_index_range<'a>(
34 accessor: &dyn MetadataAccessor,
35 table_refs: impl IntoIterator<Item = &'a TableRef>,
36) -> (usize, usize) {
37 table_refs
38 .into_iter()
39 .map(|table_ref| {
40 let length = accessor.get_length(table_ref);
41 let offset = accessor.get_offset(table_ref);
42 (offset, offset + length)
43 })
44 .reduce(|(min_start, max_end), (start, end)| (min_start.min(start), max_end.max(end)))
45 .unwrap_or((0, 1))
47}
48
/// Message sent by the prover after the first round of query evaluation.
///
/// This struct is absorbed into the proof transcript through its serde `Serialize`
/// impl, so the field declaration order is part of the proof format and must not change.
#[derive(Clone, Serialize, Deserialize)]
pub struct FirstRoundMessage<C> {
    /// Length of the row range the proof covers. Determines the number of sumcheck
    /// variables and the length of the evaluation vector.
    pub range_length: usize,
    /// Number of challenges drawn from the transcript after this message is absorbed;
    /// they are handed to the final round.
    pub post_result_challenge_count: usize,
    /// Chi evaluation lengths produced by the first round (the verifier adds the
    /// referenced table lengths to these).
    pub chi_evaluation_lengths: Vec<usize>,
    /// Rho evaluation lengths produced by the first round.
    pub rho_evaluation_lengths: Vec<usize>,
    /// Commitments to the intermediate MLEs produced during the first round.
    pub round_commitments: Vec<C>,
}
61
/// Message sent by the prover after the final round of query evaluation.
///
/// This struct is absorbed into the proof transcript through its serde `Serialize`
/// impl, so the field declaration order is part of the proof format and must not change.
#[derive(Clone, Serialize, Deserialize)]
pub struct FinalRoundMessage<C> {
    /// Number of sumcheck subpolynomial constraints. The verifier draws
    /// `num_sumcheck_variables + subpolynomial_constraint_count` random scalars.
    pub subpolynomial_constraint_count: usize,
    /// Commitments to the intermediate MLEs produced during the final round.
    pub round_commitments: Vec<C>,
    /// Bit distributions reported by the prover. The verifier rejects the proof if any
    /// of them is invalid or outside the acceptable range.
    pub bit_distributions: Vec<BitDistribution>,
}
/// Evaluations of every committed MLE at the sumcheck evaluation point, grouped by origin.
///
/// For the batched PCS proof the groups are concatenated in declaration order
/// (first round, column references, final round). The struct is also absorbed into the
/// transcript through serde, so the field order must not change.
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProofPCSProofEvaluations<S> {
    /// Evaluations of the first-round intermediate MLEs.
    pub first_round: Vec<S>,
    /// Evaluations of the columns referenced by the query, in the order returned by
    /// `get_column_references`.
    pub column_ref: Vec<S>,
    /// Evaluations of the final-round intermediate MLEs.
    pub final_round: Vec<S>,
}
79
/// A proof that a query result is correct with respect to committed table data.
///
/// Created by [`QueryProof::new`] and checked by [`QueryProof::verify`].
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProof<CP: CommitmentEvaluationProof> {
    /// Prover's message after the first round.
    pub(super) first_round_message: FirstRoundMessage<CP::Commitment>,
    /// Prover's message after the final round.
    pub(super) final_round_message: FinalRoundMessage<CP::Commitment>,
    /// Sumcheck proof over the subpolynomials produced in the final round.
    pub(super) sumcheck_proof: SumcheckProof<CP::Scalar>,
    /// Claimed evaluations of all committed MLEs at the sumcheck evaluation point.
    pub(super) pcs_proof_evaluations: QueryProofPCSProofEvaluations<CP::Scalar>,
    /// Batched PCS proof that `pcs_proof_evaluations` are correct.
    pub(super) evaluation_proof: CP,
}
95
96impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
97 #[tracing::instrument(name = "QueryProof::new", level = "debug", skip_all)]
99 pub fn new(
100 expr: &(impl ProofPlan + Serialize),
101 accessor: &impl DataAccessor<CP::Scalar>,
102 setup: &CP::ProverPublicSetup<'_>,
103 params: &[LiteralValue],
104 ) -> PlaceholderResult<(Self, OwnedTable<CP::Scalar>)> {
105 log::log_memory_usage("Start");
106
107 let (min_row_num, max_row_num) = get_index_range(accessor, &expr.get_table_references());
108 let initial_range_length = (max_row_num - min_row_num).max(1);
109 let alloc = Bump::new();
110
111 let total_col_refs = expr.get_column_references();
112 let table_map: IndexMap<TableRef, Table<CP::Scalar>> = expr
113 .get_table_references()
114 .into_iter()
115 .map(|table_ref| {
116 let col_refs: IndexSet<ColumnRef> = total_col_refs
117 .iter()
118 .filter(|col_ref| col_ref.table_ref() == table_ref)
119 .cloned()
120 .collect();
121 (table_ref.clone(), accessor.get_table(table_ref, &col_refs))
122 })
123 .collect();
124
125 let mut first_round_builder = FirstRoundBuilder::new(initial_range_length);
127 let query_result =
128 expr.first_round_evaluate(&mut first_round_builder, &alloc, &table_map, params)?;
129 let owned_table_result = OwnedTable::from(&query_result);
130 let provable_result = query_result.into();
131 let chi_evaluation_lengths = first_round_builder.chi_evaluation_lengths();
132 let rho_evaluation_lengths = first_round_builder.rho_evaluation_lengths();
133
134 let range_length = first_round_builder.range_length();
135 let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
136 assert!(num_sumcheck_variables > 0);
137 let post_result_challenge_count = first_round_builder.num_post_result_challenges();
138
139 let first_round_commitments =
141 first_round_builder.commit_intermediate_mles(min_row_num, setup);
142
143 let mut transcript: Keccak256Transcript = Transcript::new();
145 transcript.challenge_as_le();
146 transcript.extend_serialize_as_le(expr);
147 transcript.challenge_as_le();
148 transcript.extend_serialize_as_le(&owned_table_result);
149 transcript.challenge_as_le();
150
151 for table in expr.get_table_references() {
152 let length = accessor.get_length(&table);
153 transcript.extend_serialize_as_le(&[0, 0, 0, length]);
154 }
155 transcript.challenge_as_le();
156
157 for commitment in CP::Commitment::compute_commitments(
158 &expr
159 .get_column_references()
160 .into_iter()
161 .map(|col| CommittableColumn::from(accessor.get_column(col)))
162 .collect_vec(),
163 min_row_num,
164 setup,
165 ) {
166 transcript.extend_serialize_as_le(&commitment);
167 }
168 transcript.challenge_as_le();
169
170 transcript.extend_serialize_as_le(&min_row_num);
171 transcript.challenge_as_le();
172
173 let first_round_message = FirstRoundMessage {
174 range_length,
175 chi_evaluation_lengths: chi_evaluation_lengths.to_vec(),
176 rho_evaluation_lengths: rho_evaluation_lengths.to_vec(),
177 post_result_challenge_count,
178 round_commitments: first_round_commitments,
179 };
180 transcript.extend_serialize_as_le(&first_round_message);
181
182 let post_result_challenges =
188 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
189 .take(post_result_challenge_count)
190 .collect();
191
192 let mut final_round_builder =
193 FinalRoundBuilder::new(num_sumcheck_variables, post_result_challenges);
194
195 expr.final_round_evaluate(&mut final_round_builder, &alloc, &table_map, params)?;
196
197 let num_sumcheck_variables = final_round_builder.num_sumcheck_variables();
198
199 let final_round_commitments =
201 final_round_builder.commit_intermediate_mles(min_row_num, setup);
202
203 let final_round_message = FinalRoundMessage {
204 subpolynomial_constraint_count: final_round_builder.num_sumcheck_subpolynomials(),
205 round_commitments: final_round_commitments,
206 bit_distributions: final_round_builder.bit_distributions().to_vec(),
207 };
208
209 transcript.challenge_as_le();
211 transcript.extend_serialize_as_le(&final_round_message);
212
213 let num_random_scalars =
215 num_sumcheck_variables + final_round_message.subpolynomial_constraint_count;
216 let random_scalars: Vec<_> =
217 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
218 .take(num_random_scalars)
219 .collect();
220 let state = make_sumcheck_prover_state(
221 final_round_builder.sumcheck_subpolynomials(),
222 num_sumcheck_variables,
223 &SumcheckRandomScalars::new(&random_scalars, range_length, num_sumcheck_variables),
224 );
225 transcript.challenge_as_le();
226
227 let mut evaluation_point = vec![Zero::zero(); state.num_vars];
229 let sumcheck_proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, state);
230
231 let mut evaluation_vec = vec![Zero::zero(); range_length];
233 compute_evaluation_vector(&mut evaluation_vec, &evaluation_point);
234 let first_round_pcs_proof_evaluations =
235 first_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
236 let column_ref_pcs_proof_evaluations: Vec<_> = total_col_refs
237 .iter()
238 .map(|col_ref| {
239 accessor
240 .get_column(col_ref.clone())
241 .inner_product(&evaluation_vec)
242 })
243 .collect();
244 let final_round_pcs_proof_evaluations =
245 final_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
246
247 let pcs_proof_evaluations = QueryProofPCSProofEvaluations {
249 first_round: first_round_pcs_proof_evaluations,
250 column_ref: column_ref_pcs_proof_evaluations,
251 final_round: final_round_pcs_proof_evaluations,
252 };
253 transcript.extend_serialize_as_le(&pcs_proof_evaluations);
254
255 let random_scalars: Vec<_> =
258 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
259 .take(
260 pcs_proof_evaluations.first_round.len()
261 + pcs_proof_evaluations.column_ref.len()
262 + pcs_proof_evaluations.final_round.len(),
263 )
264 .collect();
265
266 let mut folded_mle = vec![Zero::zero(); range_length];
267 let column_ref_mles: Vec<_> = total_col_refs
268 .into_iter()
269 .map(|c| Box::new(accessor.get_column(c)) as Box<dyn MultilinearExtension<_>>)
270 .collect();
271 for (multiplier, evaluator) in random_scalars.iter().zip(
272 first_round_builder
273 .pcs_proof_mles()
274 .iter()
275 .chain(&column_ref_mles)
276 .chain(final_round_builder.pcs_proof_mles().iter()),
277 ) {
278 evaluator.mul_add(&mut folded_mle, multiplier);
279 }
280
281 let evaluation_proof = CP::new(
283 &mut transcript,
284 &folded_mle,
285 &evaluation_point,
286 min_row_num as u64,
287 setup,
288 );
289
290 let proof = Self {
291 first_round_message,
292 final_round_message,
293 sumcheck_proof,
294 pcs_proof_evaluations,
295 evaluation_proof,
296 };
297
298 log::log_memory_usage("End");
299
300 Ok((proof, provable_result))
301 }
302
303 #[tracing::instrument(name = "QueryProof::verify", level = "debug", skip_all, err)]
304 pub fn verify(
306 self,
307 expr: &(impl ProofPlan + Serialize),
308 accessor: &impl CommitmentAccessor<CP::Commitment>,
309 result: OwnedTable<CP::Scalar>,
310 setup: &CP::VerifierPublicSetup<'_>,
311 params: &[LiteralValue],
312 ) -> QueryResult<CP::Scalar> {
313 log::log_memory_usage("Start");
314
315 let table_refs = expr.get_table_references();
316 let (min_row_num, _) = get_index_range(accessor, &table_refs);
317 let num_sumcheck_variables = cmp::max(log2_up(self.first_round_message.range_length), 1);
318 assert!(num_sumcheck_variables > 0);
319
320 for dist in &self.final_round_message.bit_distributions {
322 if !dist.is_valid() {
323 Err(ProofError::VerificationError {
324 error: "invalid bit distributions",
325 })?;
326 } else if !dist.is_within_acceptable_range() {
327 Err(ProofError::VerificationError {
328 error: "bit distribution outside of acceptable range",
329 })?;
330 }
331 }
332
333 let column_references = expr.get_column_references();
334
335 let mut transcript: Keccak256Transcript = Transcript::new();
337 transcript.challenge_as_le();
338 transcript.extend_serialize_as_le(expr);
339 transcript.challenge_as_le();
340 transcript.extend_serialize_as_le(&result);
341 transcript.challenge_as_le();
342
343 for table in expr.get_table_references() {
344 let length = accessor.get_length(&table);
345 transcript.extend_serialize_as_le(&[0, 0, 0, length]);
346 }
347 transcript.challenge_as_le();
348
349 for commitment in expr
350 .get_column_references()
351 .into_iter()
352 .map(|col| accessor.get_commitment(col))
353 {
354 transcript.extend_serialize_as_le(&commitment);
355 }
356 transcript.challenge_as_le();
357
358 transcript.extend_serialize_as_le(&min_row_num);
359 transcript.challenge_as_le();
360
361 transcript.extend_serialize_as_le(&self.first_round_message);
362
363 let post_result_challenges =
369 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
370 .take(self.first_round_message.post_result_challenge_count)
371 .collect();
372
373 transcript.challenge_as_le();
375 transcript.extend_serialize_as_le(&self.final_round_message);
376
377 let num_random_scalars =
379 num_sumcheck_variables + self.final_round_message.subpolynomial_constraint_count;
380 let random_scalars: Vec<_> =
381 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
382 .take(num_random_scalars)
383 .collect();
384 let sumcheck_random_scalars = SumcheckRandomScalars::new(
385 &random_scalars,
386 self.first_round_message.range_length,
387 num_sumcheck_variables,
388 );
389 transcript.challenge_as_le();
390
391 let subclaim = self.sumcheck_proof.verify_without_evaluation(
393 &mut transcript,
394 num_sumcheck_variables,
395 &Zero::zero(),
396 )?;
397
398 transcript.extend_serialize_as_le(&self.pcs_proof_evaluations);
400
401 let evaluation_random_scalars: Vec<_> =
404 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
405 .take(
406 self.pcs_proof_evaluations.first_round.len()
407 + self.pcs_proof_evaluations.column_ref.len()
408 + self.pcs_proof_evaluations.final_round.len(),
409 )
410 .collect();
411
412 let table_length_map = table_refs
414 .into_iter()
415 .map(|table_ref| {
416 let len = accessor.get_length(&table_ref);
417 (table_ref, len)
418 })
419 .collect::<IndexMap<TableRef, usize>>();
420
421 let chi_evaluation_lengths = table_length_map
422 .values()
423 .chain(self.first_round_message.chi_evaluation_lengths.iter())
424 .copied();
425
426 let sumcheck_evaluations = SumcheckMleEvaluations::new(
428 self.first_round_message.range_length,
429 chi_evaluation_lengths,
430 self.first_round_message.rho_evaluation_lengths.clone(),
431 &subclaim.evaluation_point,
432 &sumcheck_random_scalars,
433 &self.pcs_proof_evaluations.first_round,
434 &self.pcs_proof_evaluations.final_round,
435 );
436 let chi_eval_map: IndexMap<TableRef, CP::Scalar> = table_length_map
437 .into_iter()
438 .map(|(table_ref, length)| (table_ref, sumcheck_evaluations.chi_evaluations[&length]))
439 .collect();
440 let mut builder = VerificationBuilderImpl::new(
441 sumcheck_evaluations,
442 &self.final_round_message.bit_distributions,
443 sumcheck_random_scalars.subpolynomial_multipliers,
444 post_result_challenges,
445 self.first_round_message.chi_evaluation_lengths.clone(),
446 self.first_round_message.rho_evaluation_lengths.clone(),
447 subclaim.max_multiplicands,
448 );
449
450 let pcs_proof_commitments: Vec<_> = self
451 .first_round_message
452 .round_commitments
453 .iter()
454 .cloned()
455 .chain(
456 column_references
457 .iter()
458 .map(|col| accessor.get_commitment(col.clone())),
459 )
460 .chain(self.final_round_message.round_commitments.iter().cloned())
461 .collect();
462 let evaluation_accessor: IndexMap<_, _> = column_references
463 .into_iter()
464 .zip(self.pcs_proof_evaluations.column_ref.iter().copied())
465 .collect();
466
467 let verifier_evaluations = expr.verifier_evaluate(
468 &mut builder,
469 &evaluation_accessor,
470 Some(&result),
471 &chi_eval_map,
472 params,
473 )?;
474 let result_evaluations = result.mle_evaluations(&subclaim.evaluation_point);
476 if verifier_evaluations.column_evals() != result_evaluations {
478 Err(ProofError::VerificationError {
479 error: "result evaluation check failed",
480 })?;
481 }
482
483 if builder.sumcheck_evaluation() != subclaim.expected_evaluation {
485 Err(ProofError::VerificationError {
486 error: "sumcheck evaluation check failed",
487 })?;
488 }
489
490 let pcs_proof_evaluations: Vec<_> = self
491 .pcs_proof_evaluations
492 .first_round
493 .iter()
494 .chain(self.pcs_proof_evaluations.column_ref.iter())
495 .chain(self.pcs_proof_evaluations.final_round.iter())
496 .copied()
497 .collect();
498
499 self.evaluation_proof
501 .verify_batched_proof(
502 &mut transcript,
503 &pcs_proof_commitments,
504 &evaluation_random_scalars,
505 &pcs_proof_evaluations,
506 &subclaim.evaluation_point,
507 min_row_num as u64,
508 self.first_round_message.range_length,
509 setup,
510 )
511 .map_err(|_e| ProofError::VerificationError {
512 error: "Inner product proof of MLE evaluations failed",
513 })?;
514
515 let verification_hash = transcript.challenge_as_le();
516
517 log::log_memory_usage("End");
518
519 Ok(QueryData {
520 table: result,
521 verification_hash,
522 })
523 }
524}