1use super::{
2 make_sumcheck_state::make_sumcheck_prover_state, FinalRoundBuilder, FirstRoundBuilder,
3 ProofPlan, QueryData, QueryResult, SumcheckMleEvaluations, SumcheckRandomScalars,
4 VerificationBuilderImpl,
5};
6use crate::{
7 base::{
8 bit::BitDistribution,
9 commitment::{Commitment, CommitmentEvaluationProof, CommittableColumn},
10 database::{
11 ColumnRef, CommitmentAccessor, DataAccessor, LiteralValue, MetadataAccessor,
12 OwnedTable, Table, TableRef,
13 },
14 map::{IndexMap, IndexSet},
15 math::log2_up,
16 polynomial::{compute_evaluation_vector, MultilinearExtension},
17 proof::{Keccak256Transcript, PlaceholderResult, ProofError, Transcript},
18 },
19 proof_primitive::sumcheck::SumcheckProof,
20 utils::log,
21};
22use alloc::{boxed::Box, vec, vec::Vec};
23use bumpalo::Bump;
24use core::cmp;
25use itertools::Itertools;
26use num_traits::Zero;
27use serde::{Deserialize, Serialize};
28use sqlparser::ast::Ident;
29use tracing::{span, Level};
30
31const SETUP_HASH: [u8; 32] = [
32 0xe8, 0x84, 0x0d, 0x8a, 0x41, 0xce, 0x9d, 0x4e, 0x14, 0xe7, 0xba, 0x0e, 0x1b, 0x02, 0x32, 0x24,
33 0x75, 0x13, 0x61, 0x57, 0x73, 0x78, 0x29, 0x1f, 0xcd, 0x3f, 0x0f, 0x05, 0xf0, 0xf7, 0xe8, 0x75,
34]; fn get_index_range<'a>(
41 accessor: &dyn MetadataAccessor,
42 table_refs: impl IntoIterator<Item = &'a TableRef>,
43) -> (usize, usize) {
44 table_refs
45 .into_iter()
46 .map(|table_ref| {
47 let length = accessor.get_length(table_ref);
48 let offset = accessor.get_offset(table_ref);
49 (offset, offset + length)
50 })
51 .reduce(|(min_start, max_end), (start, end)| (min_start.min(start), max_end.max(end)))
52 .unwrap_or((0, 1))
54}
55
56#[derive(Clone, Serialize, Deserialize)]
57pub struct FirstRoundMessage<C> {
58 pub range_length: usize,
60 pub post_result_challenge_count: usize,
61 pub chi_evaluation_lengths: Vec<usize>,
63 pub rho_evaluation_lengths: Vec<usize>,
65 pub round_commitments: Vec<C>,
67}
68
69#[derive(Clone, Serialize, Deserialize)]
70pub struct FinalRoundMessage<C> {
71 pub subpolynomial_constraint_count: usize,
72 pub round_commitments: Vec<C>,
74 pub bit_distributions: Vec<BitDistribution>,
76}
77#[derive(Clone, Serialize, Deserialize)]
78pub struct QueryProofPCSProofEvaluations<S> {
79 pub first_round: Vec<S>,
81 pub column_ref: Vec<S>,
83 pub final_round: Vec<S>,
85}
86
87#[derive(Clone, Serialize, Deserialize)]
93pub struct QueryProof<CP: CommitmentEvaluationProof> {
94 pub(super) first_round_message: FirstRoundMessage<CP::Commitment>,
95 pub(super) final_round_message: FinalRoundMessage<CP::Commitment>,
96 pub(super) sumcheck_proof: SumcheckProof<CP::Scalar>,
98 pub(super) pcs_proof_evaluations: QueryProofPCSProofEvaluations<CP::Scalar>,
99 pub(super) evaluation_proof: CP,
101}
102
103impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
104 #[tracing::instrument(name = "QueryProof::new", level = "debug", skip_all)]
106 #[expect(clippy::too_many_lines)]
107 pub fn new(
108 expr: &(impl ProofPlan + Serialize),
109 accessor: &impl DataAccessor<CP::Scalar>,
110 setup: &CP::ProverPublicSetup<'_>,
111 params: &[LiteralValue],
112 ) -> PlaceholderResult<(Self, OwnedTable<CP::Scalar>)> {
113 log::log_memory_usage("Start");
114
115 let (min_row_num, max_row_num) = get_index_range(accessor, &expr.get_table_references());
116 let initial_range_length = (max_row_num - min_row_num).max(1);
117 let alloc = Bump::new();
118
119 let total_col_refs = expr.get_column_references();
120 let table_map: IndexMap<TableRef, Table<CP::Scalar>> = expr
121 .get_table_references()
122 .into_iter()
123 .map(|table_ref| {
124 let idents: IndexSet<Ident> = total_col_refs
125 .iter()
126 .filter(|col_ref| col_ref.table_ref() == table_ref)
127 .map(ColumnRef::column_id)
128 .collect();
129 (table_ref.clone(), accessor.get_table(&table_ref, &idents))
130 })
131 .collect();
132
133 let mut first_round_builder = FirstRoundBuilder::new(initial_range_length);
135 let query_result =
136 expr.first_round_evaluate(&mut first_round_builder, &alloc, &table_map, params)?;
137 let owned_table_result = OwnedTable::from(&query_result);
138 let provable_result = query_result.into();
139 let chi_evaluation_lengths = first_round_builder.chi_evaluation_lengths();
140 let rho_evaluation_lengths = first_round_builder.rho_evaluation_lengths();
141
142 let range_length = first_round_builder.range_length();
143 let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
144 assert!(num_sumcheck_variables > 0);
145 let post_result_challenge_count = first_round_builder.num_post_result_challenges();
146
147 let first_round_commitments =
149 first_round_builder.commit_intermediate_mles(min_row_num, setup);
150
151 let mut transcript: Keccak256Transcript = Transcript::new();
153 transcript.extend_as_le([SETUP_HASH]);
154 transcript.challenge_as_le();
155 transcript.extend_serialize_as_le(expr);
156 transcript.challenge_as_le();
157 transcript.extend_serialize_as_le(&owned_table_result);
158 transcript.challenge_as_le();
159
160 for table in expr.get_table_references() {
161 let length = accessor.get_length(&table);
162 transcript.extend_serialize_as_le(&[0, 0, 0, length]);
163 }
164 transcript.challenge_as_le();
165
166 for commitment in CP::Commitment::compute_commitments(
167 &expr
168 .get_column_references()
169 .into_iter()
170 .map(|col| {
171 CommittableColumn::from(accessor.get_column(&col.table_ref(), &col.column_id()))
172 })
173 .collect_vec(),
174 min_row_num,
175 setup,
176 ) {
177 transcript.extend_serialize_as_le(&commitment);
178 }
179 transcript.challenge_as_le();
180
181 transcript.extend_serialize_as_le(&min_row_num);
182 transcript.challenge_as_le();
183
184 let first_round_message = FirstRoundMessage {
185 range_length,
186 chi_evaluation_lengths: chi_evaluation_lengths.to_vec(),
187 rho_evaluation_lengths: rho_evaluation_lengths.to_vec(),
188 post_result_challenge_count,
189 round_commitments: first_round_commitments,
190 };
191 transcript.extend_serialize_as_le(&first_round_message);
192
193 let post_result_challenges =
199 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
200 .take(post_result_challenge_count)
201 .collect();
202
203 let mut final_round_builder =
204 FinalRoundBuilder::new(num_sumcheck_variables, post_result_challenges);
205
206 expr.final_round_evaluate(&mut final_round_builder, &alloc, &table_map, params)?;
207
208 let num_sumcheck_variables = final_round_builder.num_sumcheck_variables();
209
210 let final_round_commitments =
212 final_round_builder.commit_intermediate_mles(min_row_num, setup);
213
214 let final_round_message = FinalRoundMessage {
215 subpolynomial_constraint_count: final_round_builder.num_sumcheck_subpolynomials(),
216 round_commitments: final_round_commitments,
217 bit_distributions: final_round_builder.bit_distributions().to_vec(),
218 };
219
220 transcript.challenge_as_le();
222 transcript.extend_serialize_as_le(&final_round_message);
223
224 let num_random_scalars =
226 num_sumcheck_variables + final_round_message.subpolynomial_constraint_count;
227 let random_scalars: Vec<_> =
228 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
229 .take(num_random_scalars)
230 .collect();
231 let state = make_sumcheck_prover_state(
232 final_round_builder.sumcheck_subpolynomials(),
233 num_sumcheck_variables,
234 &SumcheckRandomScalars::new(&random_scalars, range_length, num_sumcheck_variables),
235 );
236 transcript.challenge_as_le();
237
238 let span = span!(Level::DEBUG, "Sumcheck with initialization").entered();
240 let mut evaluation_point = vec![Zero::zero(); state.num_vars];
241 let sumcheck_proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, state);
242 span.exit();
243
244 let span = span!(Level::DEBUG, "initialize evaluation_vec").entered();
246 let mut evaluation_vec = vec![Zero::zero(); range_length];
247 span.exit();
248 compute_evaluation_vector(&mut evaluation_vec, &evaluation_point);
249 let first_round_pcs_proof_evaluations =
250 first_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
251 let span = span!(Level::DEBUG, "initialize column_ref_pcs_proof_evaluations").entered();
252 let column_ref_pcs_proof_evaluations: Vec<_> = total_col_refs
253 .iter()
254 .map(|col_ref| {
255 accessor
256 .get_column(&col_ref.table_ref(), &col_ref.column_id())
257 .inner_product(&evaluation_vec)
258 })
259 .collect();
260 span.exit();
261 let final_round_pcs_proof_evaluations =
262 final_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
263
264 let pcs_proof_evaluations = QueryProofPCSProofEvaluations {
266 first_round: first_round_pcs_proof_evaluations,
267 column_ref: column_ref_pcs_proof_evaluations,
268 final_round: final_round_pcs_proof_evaluations,
269 };
270 transcript.extend_serialize_as_le(&pcs_proof_evaluations);
271
272 let random_scalars: Vec<_> =
275 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
276 .take(
277 pcs_proof_evaluations.first_round.len()
278 + pcs_proof_evaluations.column_ref.len()
279 + pcs_proof_evaluations.final_round.len(),
280 )
281 .collect();
282
283 let column_ref_mles: Vec<_> = total_col_refs
284 .into_iter()
285 .map(|c| {
286 Box::new(accessor.get_column(&c.table_ref(), &c.column_id()))
287 as Box<dyn MultilinearExtension<_>>
288 })
289 .collect();
290
291 let span = span!(Level::DEBUG, "QueryProof get folded_mle").entered();
292 let mut folded_mle = vec![Zero::zero(); range_length];
293 for (multiplier, evaluator) in random_scalars.iter().zip(
294 first_round_builder
295 .pcs_proof_mles()
296 .iter()
297 .chain(&column_ref_mles)
298 .chain(final_round_builder.pcs_proof_mles().iter()),
299 ) {
300 evaluator.mul_add(&mut folded_mle, multiplier);
301 }
302 span.exit();
303
304 let evaluation_proof = CP::new(
306 &mut transcript,
307 &folded_mle,
308 &evaluation_point,
309 min_row_num as u64,
310 setup,
311 );
312
313 let proof = Self {
314 first_round_message,
315 final_round_message,
316 sumcheck_proof,
317 pcs_proof_evaluations,
318 evaluation_proof,
319 };
320
321 log::log_memory_usage("End");
322
323 Ok((proof, provable_result))
324 }
325
326 #[tracing::instrument(name = "QueryProof::verify", level = "debug", skip_all, err)]
327 #[expect(clippy::too_many_lines)]
328 pub fn verify(
330 self,
331 expr: &(impl ProofPlan + Serialize),
332 accessor: &impl CommitmentAccessor<CP::Commitment>,
333 result: OwnedTable<CP::Scalar>,
334 setup: &CP::VerifierPublicSetup<'_>,
335 params: &[LiteralValue],
336 ) -> QueryResult<CP::Scalar> {
337 log::log_memory_usage("Start");
338
339 let table_refs = expr.get_table_references();
340 let (min_row_num, _) = get_index_range(accessor, &table_refs);
341 let num_sumcheck_variables = cmp::max(log2_up(self.first_round_message.range_length), 1);
342 assert!(num_sumcheck_variables > 0);
343
344 for dist in &self.final_round_message.bit_distributions {
346 if !dist.is_valid() {
347 Err(ProofError::VerificationError {
348 error: "invalid bit distributions",
349 })?;
350 } else if !dist.is_within_acceptable_range() {
351 Err(ProofError::VerificationError {
352 error: "bit distribution outside of acceptable range",
353 })?;
354 }
355 }
356
357 let column_references = expr.get_column_references();
358
359 let mut transcript: Keccak256Transcript = Transcript::new();
361 transcript.extend_as_le([SETUP_HASH]);
362 transcript.challenge_as_le();
363 transcript.extend_serialize_as_le(expr);
364 transcript.challenge_as_le();
365 transcript.extend_serialize_as_le(&result);
366 transcript.challenge_as_le();
367
368 for table in expr.get_table_references() {
369 let length = accessor.get_length(&table);
370 transcript.extend_serialize_as_le(&[0, 0, 0, length]);
371 }
372 transcript.challenge_as_le();
373
374 for commitment in expr
375 .get_column_references()
376 .into_iter()
377 .map(|col| accessor.get_commitment(&col.table_ref(), &col.column_id()))
378 {
379 transcript.extend_serialize_as_le(&commitment);
380 }
381 transcript.challenge_as_le();
382
383 transcript.extend_serialize_as_le(&min_row_num);
384 transcript.challenge_as_le();
385
386 transcript.extend_serialize_as_le(&self.first_round_message);
387
388 let post_result_challenges =
394 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
395 .take(self.first_round_message.post_result_challenge_count)
396 .collect();
397
398 transcript.challenge_as_le();
400 transcript.extend_serialize_as_le(&self.final_round_message);
401
402 let num_random_scalars =
404 num_sumcheck_variables + self.final_round_message.subpolynomial_constraint_count;
405 let random_scalars: Vec<_> =
406 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
407 .take(num_random_scalars)
408 .collect();
409 let sumcheck_random_scalars = SumcheckRandomScalars::new(
410 &random_scalars,
411 self.first_round_message.range_length,
412 num_sumcheck_variables,
413 );
414 transcript.challenge_as_le();
415
416 let subclaim = self.sumcheck_proof.verify_without_evaluation(
418 &mut transcript,
419 num_sumcheck_variables,
420 &Zero::zero(),
421 )?;
422
423 transcript.extend_serialize_as_le(&self.pcs_proof_evaluations);
425
426 let evaluation_random_scalars: Vec<_> =
429 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
430 .take(
431 self.pcs_proof_evaluations.first_round.len()
432 + self.pcs_proof_evaluations.column_ref.len()
433 + self.pcs_proof_evaluations.final_round.len(),
434 )
435 .collect();
436
437 let table_length_map = table_refs
439 .into_iter()
440 .map(|table_ref| {
441 let len = accessor.get_length(&table_ref);
442 (table_ref, len)
443 })
444 .collect::<IndexMap<TableRef, usize>>();
445
446 let chi_evaluation_lengths = table_length_map
447 .values()
448 .chain(self.first_round_message.chi_evaluation_lengths.iter())
449 .copied();
450
451 let sumcheck_evaluations = SumcheckMleEvaluations::new(
453 self.first_round_message.range_length,
454 chi_evaluation_lengths,
455 self.first_round_message.rho_evaluation_lengths.clone(),
456 &subclaim.evaluation_point,
457 &sumcheck_random_scalars,
458 &self.pcs_proof_evaluations.first_round,
459 &self.pcs_proof_evaluations.final_round,
460 );
461 let chi_eval_map: IndexMap<TableRef, (CP::Scalar, usize)> = table_length_map
462 .into_iter()
463 .map(|(table_ref, length)| {
464 (
465 table_ref,
466 (sumcheck_evaluations.chi_evaluations[&length], length),
467 )
468 })
469 .collect();
470 let mut builder = VerificationBuilderImpl::new(
471 sumcheck_evaluations,
472 &self.final_round_message.bit_distributions,
473 sumcheck_random_scalars.subpolynomial_multipliers,
474 post_result_challenges,
475 self.first_round_message.chi_evaluation_lengths.clone(),
476 self.first_round_message.rho_evaluation_lengths.clone(),
477 subclaim.max_multiplicands,
478 );
479
480 let pcs_proof_commitments: Vec<_> = self
481 .first_round_message
482 .round_commitments
483 .iter()
484 .cloned()
485 .chain(
486 column_references
487 .iter()
488 .map(|col| accessor.get_commitment(&col.table_ref(), &col.column_id())),
489 )
490 .chain(self.final_round_message.round_commitments.iter().cloned())
491 .collect();
492 let evaluation_accessor: IndexMap<_, _> = column_references
493 .into_iter()
494 .zip(self.pcs_proof_evaluations.column_ref.iter().copied())
495 .chunk_by(|(r, _)| r.table_ref())
496 .into_iter()
497 .map(|(tr, g)| {
498 let im: IndexMap<_, _> = g.map(|(cr, eval)| (cr.column_id(), eval)).collect();
499 (tr, im)
500 })
501 .collect();
502
503 let verifier_evaluations = expr.verifier_evaluate(
504 &mut builder,
505 &evaluation_accessor,
506 Some(&result),
507 &chi_eval_map,
508 params,
509 )?;
510 let result_evaluations = result.mle_evaluations(&subclaim.evaluation_point);
512 if verifier_evaluations.column_evals() != result_evaluations {
514 Err(ProofError::VerificationError {
515 error: "result evaluation check failed",
516 })?;
517 }
518
519 if builder.sumcheck_evaluation() != subclaim.expected_evaluation {
521 Err(ProofError::VerificationError {
522 error: "sumcheck evaluation check failed",
523 })?;
524 }
525
526 let pcs_proof_evaluations: Vec<_> = self
527 .pcs_proof_evaluations
528 .first_round
529 .iter()
530 .chain(self.pcs_proof_evaluations.column_ref.iter())
531 .chain(self.pcs_proof_evaluations.final_round.iter())
532 .copied()
533 .collect();
534
535 self.evaluation_proof
537 .verify_batched_proof(
538 &mut transcript,
539 &pcs_proof_commitments,
540 &evaluation_random_scalars,
541 &pcs_proof_evaluations,
542 &subclaim.evaluation_point,
543 min_row_num as u64,
544 self.first_round_message.range_length,
545 setup,
546 )
547 .map_err(|_e| ProofError::VerificationError {
548 error: "Inner product proof of MLE evaluations failed",
549 })?;
550
551 let verification_hash = transcript.challenge_as_le();
552
553 log::log_memory_usage("End");
554
555 Ok(QueryData {
556 table: result,
557 verification_hash,
558 })
559 }
560}