1use super::{
2 make_sumcheck_state::make_sumcheck_prover_state, FinalRoundBuilder, FirstRoundBuilder,
3 ProofPlan, QueryData, QueryResult, SumcheckMleEvaluations, SumcheckRandomScalars,
4 VerificationBuilderImpl,
5};
6use crate::{
7 base::{
8 bit::BitDistribution,
9 commitment::{Commitment, CommitmentEvaluationProof, CommittableColumn},
10 database::{
11 ColumnRef, CommitmentAccessor, DataAccessor, LiteralValue, MetadataAccessor,
12 OwnedTable, Table, TableRef,
13 },
14 map::{IndexMap, IndexSet},
15 math::log2_up,
16 polynomial::{compute_evaluation_vector, MultilinearExtension},
17 proof::{Keccak256Transcript, PlaceholderResult, ProofError, Transcript},
18 },
19 proof_primitive::sumcheck::SumcheckProof,
20 utils::log,
21};
22use alloc::{boxed::Box, vec, vec::Vec};
23use bumpalo::Bump;
24use core::cmp;
25use itertools::Itertools;
26use num_traits::Zero;
27use serde::{Deserialize, Serialize};
28use sqlparser::ast::Ident;
29
30const SETUP_HASH: [u8; 32] = [
31 0xe8, 0x84, 0x0d, 0x8a, 0x41, 0xce, 0x9d, 0x4e, 0x14, 0xe7, 0xba, 0x0e, 0x1b, 0x02, 0x32, 0x24,
32 0x75, 0x13, 0x61, 0x57, 0x73, 0x78, 0x29, 0x1f, 0xcd, 0x3f, 0x0f, 0x05, 0xf0, 0xf7, 0xe8, 0x75,
33]; fn get_index_range<'a>(
40 accessor: &dyn MetadataAccessor,
41 table_refs: impl IntoIterator<Item = &'a TableRef>,
42) -> (usize, usize) {
43 table_refs
44 .into_iter()
45 .map(|table_ref| {
46 let length = accessor.get_length(table_ref);
47 let offset = accessor.get_offset(table_ref);
48 (offset, offset + length)
49 })
50 .reduce(|(min_start, max_end), (start, end)| (min_start.min(start), max_end.max(end)))
51 .unwrap_or((0, 1))
53}
54
/// Message sent by the prover at the end of the first round of query evaluation.
///
/// The whole message is serialized into the transcript (via serde) before the post-result
/// challenges are drawn, so changing its fields or their order changes the transcript.
#[derive(Clone, Serialize, Deserialize)]
pub struct FirstRoundMessage<C> {
    /// Length of the row range covered by the proof. The verifier derives the number of
    /// sumcheck variables from it as `max(log2_up(range_length), 1)`.
    pub range_length: usize,
    /// Number of scalar challenges drawn from the transcript right after this message is
    /// absorbed; they are handed to the final round.
    pub post_result_challenge_count: usize,
    /// Extra lengths for which chi evaluations are needed. The verifier appends these after
    /// the lengths of the referenced tables when building the sumcheck MLE evaluations.
    pub chi_evaluation_lengths: Vec<usize>,
    /// Lengths for which rho evaluations are needed, as reported by the first round.
    pub rho_evaluation_lengths: Vec<usize>,
    /// Commitments to the first-round intermediate MLEs. The batched evaluation proof opens
    /// them against `QueryProofPCSProofEvaluations::first_round`.
    pub round_commitments: Vec<C>,
}
67
/// Message sent by the prover at the end of the final round of query evaluation.
///
/// Serialized into the transcript (via serde) before the sumcheck random scalars are drawn.
#[derive(Clone, Serialize, Deserialize)]
pub struct FinalRoundMessage<C> {
    /// Number of sumcheck subpolynomials. One random multiplier is drawn per subpolynomial,
    /// in addition to one scalar per sumcheck variable.
    pub subpolynomial_constraint_count: usize,
    /// Commitments to the final-round intermediate MLEs. The batched evaluation proof opens
    /// them against `QueryProofPCSProofEvaluations::final_round`.
    pub round_commitments: Vec<C>,
    /// Bit distributions produced during the final round. The verifier rejects the proof if
    /// any of them is invalid or outside the acceptable range.
    pub bit_distributions: Vec<BitDistribution>,
}
/// Claimed evaluations, at the sumcheck evaluation point, of every MLE opened by the batched
/// PCS proof, grouped by where the MLE comes from.
///
/// The prover computes each entry as the inner product of the MLE with the evaluation vector
/// of the sumcheck evaluation point.
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProofPCSProofEvaluations<S> {
    /// Evaluations of the first-round intermediate MLEs.
    pub first_round: Vec<S>,
    /// Evaluations of each referenced column, in the order returned by
    /// `get_column_references`.
    pub column_ref: Vec<S>,
    /// Evaluations of the final-round intermediate MLEs.
    pub final_round: Vec<S>,
}
85
/// A proof that evaluating a `ProofPlan` over committed tables yields a given result.
///
/// Produced by `QueryProof::new` and checked by `QueryProof::verify`.
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProof<CP: CommitmentEvaluationProof> {
    /// Prover message from the first round.
    pub(super) first_round_message: FirstRoundMessage<CP::Commitment>,
    /// Prover message from the final round.
    pub(super) final_round_message: FinalRoundMessage<CP::Commitment>,
    /// Sumcheck proof over the final-round subpolynomials.
    pub(super) sumcheck_proof: SumcheckProof<CP::Scalar>,
    /// Claimed MLE evaluations at the sumcheck evaluation point.
    pub(super) pcs_proof_evaluations: QueryProofPCSProofEvaluations<CP::Scalar>,
    /// Batched evaluation proof for `pcs_proof_evaluations`.
    pub(super) evaluation_proof: CP,
}
101
102impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
103 #[tracing::instrument(name = "QueryProof::new", level = "debug", skip_all)]
105 pub fn new(
106 expr: &(impl ProofPlan + Serialize),
107 accessor: &impl DataAccessor<CP::Scalar>,
108 setup: &CP::ProverPublicSetup<'_>,
109 params: &[LiteralValue],
110 ) -> PlaceholderResult<(Self, OwnedTable<CP::Scalar>)> {
111 log::log_memory_usage("Start");
112
113 let (min_row_num, max_row_num) = get_index_range(accessor, &expr.get_table_references());
114 let initial_range_length = (max_row_num - min_row_num).max(1);
115 let alloc = Bump::new();
116
117 let total_col_refs = expr.get_column_references();
118 let table_map: IndexMap<TableRef, Table<CP::Scalar>> = expr
119 .get_table_references()
120 .into_iter()
121 .map(|table_ref| {
122 let idents: IndexSet<Ident> = total_col_refs
123 .iter()
124 .filter(|col_ref| col_ref.table_ref() == table_ref)
125 .map(ColumnRef::column_id)
126 .collect();
127 (table_ref.clone(), accessor.get_table(&table_ref, &idents))
128 })
129 .collect();
130
131 let mut first_round_builder = FirstRoundBuilder::new(initial_range_length);
133 let query_result =
134 expr.first_round_evaluate(&mut first_round_builder, &alloc, &table_map, params)?;
135 let owned_table_result = OwnedTable::from(&query_result);
136 let provable_result = query_result.into();
137 let chi_evaluation_lengths = first_round_builder.chi_evaluation_lengths();
138 let rho_evaluation_lengths = first_round_builder.rho_evaluation_lengths();
139
140 let range_length = first_round_builder.range_length();
141 let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
142 assert!(num_sumcheck_variables > 0);
143 let post_result_challenge_count = first_round_builder.num_post_result_challenges();
144
145 let first_round_commitments =
147 first_round_builder.commit_intermediate_mles(min_row_num, setup);
148
149 let mut transcript: Keccak256Transcript = Transcript::new();
151 transcript.extend_as_le([SETUP_HASH]);
152 transcript.challenge_as_le();
153 transcript.extend_serialize_as_le(expr);
154 transcript.challenge_as_le();
155 transcript.extend_serialize_as_le(&owned_table_result);
156 transcript.challenge_as_le();
157
158 for table in expr.get_table_references() {
159 let length = accessor.get_length(&table);
160 transcript.extend_serialize_as_le(&[0, 0, 0, length]);
161 }
162 transcript.challenge_as_le();
163
164 for commitment in CP::Commitment::compute_commitments(
165 &expr
166 .get_column_references()
167 .into_iter()
168 .map(|col| {
169 CommittableColumn::from(accessor.get_column(&col.table_ref(), &col.column_id()))
170 })
171 .collect_vec(),
172 min_row_num,
173 setup,
174 ) {
175 transcript.extend_serialize_as_le(&commitment);
176 }
177 transcript.challenge_as_le();
178
179 transcript.extend_serialize_as_le(&min_row_num);
180 transcript.challenge_as_le();
181
182 let first_round_message = FirstRoundMessage {
183 range_length,
184 chi_evaluation_lengths: chi_evaluation_lengths.to_vec(),
185 rho_evaluation_lengths: rho_evaluation_lengths.to_vec(),
186 post_result_challenge_count,
187 round_commitments: first_round_commitments,
188 };
189 transcript.extend_serialize_as_le(&first_round_message);
190
191 let post_result_challenges =
197 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
198 .take(post_result_challenge_count)
199 .collect();
200
201 let mut final_round_builder =
202 FinalRoundBuilder::new(num_sumcheck_variables, post_result_challenges);
203
204 expr.final_round_evaluate(&mut final_round_builder, &alloc, &table_map, params)?;
205
206 let num_sumcheck_variables = final_round_builder.num_sumcheck_variables();
207
208 let final_round_commitments =
210 final_round_builder.commit_intermediate_mles(min_row_num, setup);
211
212 let final_round_message = FinalRoundMessage {
213 subpolynomial_constraint_count: final_round_builder.num_sumcheck_subpolynomials(),
214 round_commitments: final_round_commitments,
215 bit_distributions: final_round_builder.bit_distributions().to_vec(),
216 };
217
218 transcript.challenge_as_le();
220 transcript.extend_serialize_as_le(&final_round_message);
221
222 let num_random_scalars =
224 num_sumcheck_variables + final_round_message.subpolynomial_constraint_count;
225 let random_scalars: Vec<_> =
226 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
227 .take(num_random_scalars)
228 .collect();
229 let state = make_sumcheck_prover_state(
230 final_round_builder.sumcheck_subpolynomials(),
231 num_sumcheck_variables,
232 &SumcheckRandomScalars::new(&random_scalars, range_length, num_sumcheck_variables),
233 );
234 transcript.challenge_as_le();
235
236 let mut evaluation_point = vec![Zero::zero(); state.num_vars];
238 let sumcheck_proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, state);
239
240 let mut evaluation_vec = vec![Zero::zero(); range_length];
242 compute_evaluation_vector(&mut evaluation_vec, &evaluation_point);
243 let first_round_pcs_proof_evaluations =
244 first_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
245 let column_ref_pcs_proof_evaluations: Vec<_> = total_col_refs
246 .iter()
247 .map(|col_ref| {
248 accessor
249 .get_column(&col_ref.table_ref(), &col_ref.column_id())
250 .inner_product(&evaluation_vec)
251 })
252 .collect();
253 let final_round_pcs_proof_evaluations =
254 final_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
255
256 let pcs_proof_evaluations = QueryProofPCSProofEvaluations {
258 first_round: first_round_pcs_proof_evaluations,
259 column_ref: column_ref_pcs_proof_evaluations,
260 final_round: final_round_pcs_proof_evaluations,
261 };
262 transcript.extend_serialize_as_le(&pcs_proof_evaluations);
263
264 let random_scalars: Vec<_> =
267 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
268 .take(
269 pcs_proof_evaluations.first_round.len()
270 + pcs_proof_evaluations.column_ref.len()
271 + pcs_proof_evaluations.final_round.len(),
272 )
273 .collect();
274
275 let mut folded_mle = vec![Zero::zero(); range_length];
276 let column_ref_mles: Vec<_> = total_col_refs
277 .into_iter()
278 .map(|c| {
279 Box::new(accessor.get_column(&c.table_ref(), &c.column_id()))
280 as Box<dyn MultilinearExtension<_>>
281 })
282 .collect();
283 for (multiplier, evaluator) in random_scalars.iter().zip(
284 first_round_builder
285 .pcs_proof_mles()
286 .iter()
287 .chain(&column_ref_mles)
288 .chain(final_round_builder.pcs_proof_mles().iter()),
289 ) {
290 evaluator.mul_add(&mut folded_mle, multiplier);
291 }
292
293 let evaluation_proof = CP::new(
295 &mut transcript,
296 &folded_mle,
297 &evaluation_point,
298 min_row_num as u64,
299 setup,
300 );
301
302 let proof = Self {
303 first_round_message,
304 final_round_message,
305 sumcheck_proof,
306 pcs_proof_evaluations,
307 evaluation_proof,
308 };
309
310 log::log_memory_usage("End");
311
312 Ok((proof, provable_result))
313 }
314
315 #[tracing::instrument(name = "QueryProof::verify", level = "debug", skip_all, err)]
316 pub fn verify(
318 self,
319 expr: &(impl ProofPlan + Serialize),
320 accessor: &impl CommitmentAccessor<CP::Commitment>,
321 result: OwnedTable<CP::Scalar>,
322 setup: &CP::VerifierPublicSetup<'_>,
323 params: &[LiteralValue],
324 ) -> QueryResult<CP::Scalar> {
325 log::log_memory_usage("Start");
326
327 let table_refs = expr.get_table_references();
328 let (min_row_num, _) = get_index_range(accessor, &table_refs);
329 let num_sumcheck_variables = cmp::max(log2_up(self.first_round_message.range_length), 1);
330 assert!(num_sumcheck_variables > 0);
331
332 for dist in &self.final_round_message.bit_distributions {
334 if !dist.is_valid() {
335 Err(ProofError::VerificationError {
336 error: "invalid bit distributions",
337 })?;
338 } else if !dist.is_within_acceptable_range() {
339 Err(ProofError::VerificationError {
340 error: "bit distribution outside of acceptable range",
341 })?;
342 }
343 }
344
345 let column_references = expr.get_column_references();
346
347 let mut transcript: Keccak256Transcript = Transcript::new();
349 transcript.extend_as_le([SETUP_HASH]);
350 transcript.challenge_as_le();
351 transcript.extend_serialize_as_le(expr);
352 transcript.challenge_as_le();
353 transcript.extend_serialize_as_le(&result);
354 transcript.challenge_as_le();
355
356 for table in expr.get_table_references() {
357 let length = accessor.get_length(&table);
358 transcript.extend_serialize_as_le(&[0, 0, 0, length]);
359 }
360 transcript.challenge_as_le();
361
362 for commitment in expr
363 .get_column_references()
364 .into_iter()
365 .map(|col| accessor.get_commitment(&col.table_ref(), &col.column_id()))
366 {
367 transcript.extend_serialize_as_le(&commitment);
368 }
369 transcript.challenge_as_le();
370
371 transcript.extend_serialize_as_le(&min_row_num);
372 transcript.challenge_as_le();
373
374 transcript.extend_serialize_as_le(&self.first_round_message);
375
376 let post_result_challenges =
382 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
383 .take(self.first_round_message.post_result_challenge_count)
384 .collect();
385
386 transcript.challenge_as_le();
388 transcript.extend_serialize_as_le(&self.final_round_message);
389
390 let num_random_scalars =
392 num_sumcheck_variables + self.final_round_message.subpolynomial_constraint_count;
393 let random_scalars: Vec<_> =
394 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
395 .take(num_random_scalars)
396 .collect();
397 let sumcheck_random_scalars = SumcheckRandomScalars::new(
398 &random_scalars,
399 self.first_round_message.range_length,
400 num_sumcheck_variables,
401 );
402 transcript.challenge_as_le();
403
404 let subclaim = self.sumcheck_proof.verify_without_evaluation(
406 &mut transcript,
407 num_sumcheck_variables,
408 &Zero::zero(),
409 )?;
410
411 transcript.extend_serialize_as_le(&self.pcs_proof_evaluations);
413
414 let evaluation_random_scalars: Vec<_> =
417 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
418 .take(
419 self.pcs_proof_evaluations.first_round.len()
420 + self.pcs_proof_evaluations.column_ref.len()
421 + self.pcs_proof_evaluations.final_round.len(),
422 )
423 .collect();
424
425 let table_length_map = table_refs
427 .into_iter()
428 .map(|table_ref| {
429 let len = accessor.get_length(&table_ref);
430 (table_ref, len)
431 })
432 .collect::<IndexMap<TableRef, usize>>();
433
434 let chi_evaluation_lengths = table_length_map
435 .values()
436 .chain(self.first_round_message.chi_evaluation_lengths.iter())
437 .copied();
438
439 let sumcheck_evaluations = SumcheckMleEvaluations::new(
441 self.first_round_message.range_length,
442 chi_evaluation_lengths,
443 self.first_round_message.rho_evaluation_lengths.clone(),
444 &subclaim.evaluation_point,
445 &sumcheck_random_scalars,
446 &self.pcs_proof_evaluations.first_round,
447 &self.pcs_proof_evaluations.final_round,
448 );
449 let chi_eval_map: IndexMap<TableRef, CP::Scalar> = table_length_map
450 .into_iter()
451 .map(|(table_ref, length)| (table_ref, sumcheck_evaluations.chi_evaluations[&length]))
452 .collect();
453 let mut builder = VerificationBuilderImpl::new(
454 sumcheck_evaluations,
455 &self.final_round_message.bit_distributions,
456 sumcheck_random_scalars.subpolynomial_multipliers,
457 post_result_challenges,
458 self.first_round_message.chi_evaluation_lengths.clone(),
459 self.first_round_message.rho_evaluation_lengths.clone(),
460 subclaim.max_multiplicands,
461 );
462
463 let pcs_proof_commitments: Vec<_> = self
464 .first_round_message
465 .round_commitments
466 .iter()
467 .cloned()
468 .chain(
469 column_references
470 .iter()
471 .map(|col| accessor.get_commitment(&col.table_ref(), &col.column_id())),
472 )
473 .chain(self.final_round_message.round_commitments.iter().cloned())
474 .collect();
475 let evaluation_accessor: IndexMap<_, _> = column_references
476 .into_iter()
477 .zip(self.pcs_proof_evaluations.column_ref.iter().copied())
478 .chunk_by(|(r, _)| r.table_ref())
479 .into_iter()
480 .map(|(tr, g)| {
481 let im: IndexMap<_, _> = g.map(|(cr, eval)| (cr.column_id(), eval)).collect();
482 (tr, im)
483 })
484 .collect();
485
486 let verifier_evaluations = expr.verifier_evaluate(
487 &mut builder,
488 &evaluation_accessor,
489 Some(&result),
490 &chi_eval_map,
491 params,
492 )?;
493 let result_evaluations = result.mle_evaluations(&subclaim.evaluation_point);
495 if verifier_evaluations.column_evals() != result_evaluations {
497 Err(ProofError::VerificationError {
498 error: "result evaluation check failed",
499 })?;
500 }
501
502 if builder.sumcheck_evaluation() != subclaim.expected_evaluation {
504 Err(ProofError::VerificationError {
505 error: "sumcheck evaluation check failed",
506 })?;
507 }
508
509 let pcs_proof_evaluations: Vec<_> = self
510 .pcs_proof_evaluations
511 .first_round
512 .iter()
513 .chain(self.pcs_proof_evaluations.column_ref.iter())
514 .chain(self.pcs_proof_evaluations.final_round.iter())
515 .copied()
516 .collect();
517
518 self.evaluation_proof
520 .verify_batched_proof(
521 &mut transcript,
522 &pcs_proof_commitments,
523 &evaluation_random_scalars,
524 &pcs_proof_evaluations,
525 &subclaim.evaluation_point,
526 min_row_num as u64,
527 self.first_round_message.range_length,
528 setup,
529 )
530 .map_err(|_e| ProofError::VerificationError {
531 error: "Inner product proof of MLE evaluations failed",
532 })?;
533
534 let verification_hash = transcript.challenge_as_le();
535
536 log::log_memory_usage("End");
537
538 Ok(QueryData {
539 table: result,
540 verification_hash,
541 })
542 }
543}