1use super::{
2 make_sumcheck_state::make_sumcheck_prover_state, FinalRoundBuilder, FirstRoundBuilder,
3 ProofPlan, QueryData, QueryResult, SumcheckMleEvaluations, SumcheckRandomScalars,
4 VerificationBuilderImpl,
5};
6use crate::{
7 base::{
8 bit::BitDistribution,
9 commitment::{Commitment, CommitmentEvaluationProof, CommittableColumn},
10 database::{
11 ColumnRef, CommitmentAccessor, DataAccessor, LiteralValue, MetadataAccessor,
12 OwnedTable, Table, TableRef,
13 },
14 map::{IndexMap, IndexSet},
15 math::log2_up,
16 polynomial::{compute_evaluation_vector, MultilinearExtension},
17 proof::{Keccak256Transcript, PlaceholderResult, ProofError, Transcript},
18 },
19 proof_primitive::sumcheck::SumcheckProof,
20 utils::log,
21};
22use alloc::{boxed::Box, vec, vec::Vec};
23use bumpalo::Bump;
24use core::cmp;
25use itertools::Itertools;
26use num_traits::Zero;
27use serde::{Deserialize, Serialize};
28use sqlparser::ast::Ident;
29use tracing::{span, Level};
30
31const SETUP_HASH: [u8; 32] = [
32 0xe8, 0x84, 0x0d, 0x8a, 0x41, 0xce, 0x9d, 0x4e, 0x14, 0xe7, 0xba, 0x0e, 0x1b, 0x02, 0x32, 0x24,
33 0x75, 0x13, 0x61, 0x57, 0x73, 0x78, 0x29, 0x1f, 0xcd, 0x3f, 0x0f, 0x05, 0xf0, 0xf7, 0xe8, 0x75,
34]; fn get_index_range<'a>(
41 accessor: &dyn MetadataAccessor,
42 table_refs: impl IntoIterator<Item = &'a TableRef>,
43) -> (usize, usize) {
44 table_refs
45 .into_iter()
46 .map(|table_ref| {
47 let length = accessor.get_length(table_ref);
48 let offset = accessor.get_offset(table_ref);
49 (offset, offset + length)
50 })
51 .reduce(|(min_start, max_end), (start, end)| (min_start.min(start), max_end.max(end)))
52 .unwrap_or((0, 1))
54}
55
/// Message sent by the prover after the first round of plan evaluation.
///
/// Both the prover and the verifier serialize this whole struct into the Fiat-Shamir
/// transcript. The serde derive serializes fields in declaration order, so the field order
/// is part of the proof format.
#[derive(Clone, Serialize, Deserialize)]
pub struct FirstRoundMessage<C> {
    /// Length of the row range covered by the proof. The verifier derives the number of
    /// sumcheck variables from it (`max(log2_up(range_length), 1)`).
    pub range_length: usize,
    /// Number of scalar challenges drawn from the transcript right after this message is
    /// absorbed.
    pub post_result_challenge_count: usize,
    /// Additional chi evaluation lengths from the first round. The verifier appends these to
    /// the lengths of the referenced tables.
    pub chi_evaluation_lengths: Vec<usize>,
    /// Rho evaluation lengths from the first round.
    pub rho_evaluation_lengths: Vec<usize>,
    /// Commitments to the first round's intermediate MLEs.
    pub round_commitments: Vec<C>,
}
68
/// Message sent by the prover after the final round of plan evaluation.
///
/// Both the prover and the verifier serialize this whole struct into the Fiat-Shamir
/// transcript, so the field order is part of the proof format.
#[derive(Clone, Serialize, Deserialize)]
pub struct FinalRoundMessage<C> {
    /// Number of sumcheck subpolynomials. The verifier draws
    /// `num_sumcheck_variables + subpolynomial_constraint_count` random scalars after
    /// absorbing this message.
    pub subpolynomial_constraint_count: usize,
    /// Commitments to the final round's intermediate MLEs.
    pub round_commitments: Vec<C>,
    /// Bit distributions recorded during the final round. Verification fails if any of them
    /// is invalid or outside the acceptable range.
    pub bit_distributions: Vec<BitDistribution>,
}
/// Claimed evaluations, at the sumcheck evaluation point, of every MLE opened by the batched
/// polynomial-commitment (PCS) proof.
///
/// The three lists are concatenated in field order (first round, column references, final
/// round) to line up with the corresponding commitments. The struct is also serialized into
/// the transcript, so the field order is part of the proof format.
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProofPCSProofEvaluations<S> {
    /// Evaluations of the first round's PCS proof MLEs.
    pub first_round: Vec<S>,
    /// Evaluations of the columns referenced by the plan, in the order returned by
    /// `get_column_references`.
    pub column_ref: Vec<S>,
    /// Evaluations of the final round's PCS proof MLEs.
    pub final_round: Vec<S>,
}
86
/// A non-interactive (Fiat-Shamir) proof that a [`ProofPlan`] evaluates to a given result over
/// committed data.
///
/// Created by [`QueryProof::new`] and checked by [`QueryProof::verify`].
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProof<CP: CommitmentEvaluationProof> {
    /// Prover message from the first round of evaluation.
    pub(super) first_round_message: FirstRoundMessage<CP::Commitment>,
    /// Prover message from the final round of evaluation.
    pub(super) final_round_message: FinalRoundMessage<CP::Commitment>,
    /// Sumcheck proof over the final round's subpolynomials.
    pub(super) sumcheck_proof: SumcheckProof<CP::Scalar>,
    /// Claimed MLE evaluations at the sumcheck evaluation point.
    pub(super) pcs_proof_evaluations: QueryProofPCSProofEvaluations<CP::Scalar>,
    /// Batched opening proof that backs `pcs_proof_evaluations`.
    pub(super) evaluation_proof: CP,
}
102
103impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
104 #[tracing::instrument(name = "QueryProof::new", level = "debug", skip_all)]
106 pub fn new(
107 expr: &(impl ProofPlan + Serialize),
108 accessor: &impl DataAccessor<CP::Scalar>,
109 setup: &CP::ProverPublicSetup<'_>,
110 params: &[LiteralValue],
111 ) -> PlaceholderResult<(Self, OwnedTable<CP::Scalar>)> {
112 log::log_memory_usage("Start");
113
114 let (min_row_num, max_row_num) = get_index_range(accessor, &expr.get_table_references());
115 let initial_range_length = (max_row_num - min_row_num).max(1);
116 let alloc = Bump::new();
117
118 let total_col_refs = expr.get_column_references();
119 let table_map: IndexMap<TableRef, Table<CP::Scalar>> = expr
120 .get_table_references()
121 .into_iter()
122 .map(|table_ref| {
123 let idents: IndexSet<Ident> = total_col_refs
124 .iter()
125 .filter(|col_ref| col_ref.table_ref() == table_ref)
126 .map(ColumnRef::column_id)
127 .collect();
128 (table_ref.clone(), accessor.get_table(&table_ref, &idents))
129 })
130 .collect();
131
132 let mut first_round_builder = FirstRoundBuilder::new(initial_range_length);
134 let query_result =
135 expr.first_round_evaluate(&mut first_round_builder, &alloc, &table_map, params)?;
136 let owned_table_result = OwnedTable::from(&query_result);
137 let provable_result = query_result.into();
138 let chi_evaluation_lengths = first_round_builder.chi_evaluation_lengths();
139 let rho_evaluation_lengths = first_round_builder.rho_evaluation_lengths();
140
141 let range_length = first_round_builder.range_length();
142 let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
143 assert!(num_sumcheck_variables > 0);
144 let post_result_challenge_count = first_round_builder.num_post_result_challenges();
145
146 let first_round_commitments =
148 first_round_builder.commit_intermediate_mles(min_row_num, setup);
149
150 let mut transcript: Keccak256Transcript = Transcript::new();
152 transcript.extend_as_le([SETUP_HASH]);
153 transcript.challenge_as_le();
154 transcript.extend_serialize_as_le(expr);
155 transcript.challenge_as_le();
156 transcript.extend_serialize_as_le(&owned_table_result);
157 transcript.challenge_as_le();
158
159 for table in expr.get_table_references() {
160 let length = accessor.get_length(&table);
161 transcript.extend_serialize_as_le(&[0, 0, 0, length]);
162 }
163 transcript.challenge_as_le();
164
165 for commitment in CP::Commitment::compute_commitments(
166 &expr
167 .get_column_references()
168 .into_iter()
169 .map(|col| {
170 CommittableColumn::from(accessor.get_column(&col.table_ref(), &col.column_id()))
171 })
172 .collect_vec(),
173 min_row_num,
174 setup,
175 ) {
176 transcript.extend_serialize_as_le(&commitment);
177 }
178 transcript.challenge_as_le();
179
180 transcript.extend_serialize_as_le(&min_row_num);
181 transcript.challenge_as_le();
182
183 let first_round_message = FirstRoundMessage {
184 range_length,
185 chi_evaluation_lengths: chi_evaluation_lengths.to_vec(),
186 rho_evaluation_lengths: rho_evaluation_lengths.to_vec(),
187 post_result_challenge_count,
188 round_commitments: first_round_commitments,
189 };
190 transcript.extend_serialize_as_le(&first_round_message);
191
192 let post_result_challenges =
198 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
199 .take(post_result_challenge_count)
200 .collect();
201
202 let mut final_round_builder =
203 FinalRoundBuilder::new(num_sumcheck_variables, post_result_challenges);
204
205 expr.final_round_evaluate(&mut final_round_builder, &alloc, &table_map, params)?;
206
207 let num_sumcheck_variables = final_round_builder.num_sumcheck_variables();
208
209 let final_round_commitments =
211 final_round_builder.commit_intermediate_mles(min_row_num, setup);
212
213 let final_round_message = FinalRoundMessage {
214 subpolynomial_constraint_count: final_round_builder.num_sumcheck_subpolynomials(),
215 round_commitments: final_round_commitments,
216 bit_distributions: final_round_builder.bit_distributions().to_vec(),
217 };
218
219 transcript.challenge_as_le();
221 transcript.extend_serialize_as_le(&final_round_message);
222
223 let num_random_scalars =
225 num_sumcheck_variables + final_round_message.subpolynomial_constraint_count;
226 let random_scalars: Vec<_> =
227 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
228 .take(num_random_scalars)
229 .collect();
230 let state = make_sumcheck_prover_state(
231 final_round_builder.sumcheck_subpolynomials(),
232 num_sumcheck_variables,
233 &SumcheckRandomScalars::new(&random_scalars, range_length, num_sumcheck_variables),
234 );
235 transcript.challenge_as_le();
236
237 let mut evaluation_point = vec![Zero::zero(); state.num_vars];
239 let sumcheck_proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, state);
240
241 let mut evaluation_vec = vec![Zero::zero(); range_length];
243 compute_evaluation_vector(&mut evaluation_vec, &evaluation_point);
244 let first_round_pcs_proof_evaluations =
245 first_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
246 let column_ref_pcs_proof_evaluations: Vec<_> = total_col_refs
247 .iter()
248 .map(|col_ref| {
249 accessor
250 .get_column(&col_ref.table_ref(), &col_ref.column_id())
251 .inner_product(&evaluation_vec)
252 })
253 .collect();
254 let final_round_pcs_proof_evaluations =
255 final_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
256
257 let pcs_proof_evaluations = QueryProofPCSProofEvaluations {
259 first_round: first_round_pcs_proof_evaluations,
260 column_ref: column_ref_pcs_proof_evaluations,
261 final_round: final_round_pcs_proof_evaluations,
262 };
263 transcript.extend_serialize_as_le(&pcs_proof_evaluations);
264
265 let random_scalars: Vec<_> =
268 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
269 .take(
270 pcs_proof_evaluations.first_round.len()
271 + pcs_proof_evaluations.column_ref.len()
272 + pcs_proof_evaluations.final_round.len(),
273 )
274 .collect();
275
276 let column_ref_mles: Vec<_> = total_col_refs
277 .into_iter()
278 .map(|c| {
279 Box::new(accessor.get_column(&c.table_ref(), &c.column_id()))
280 as Box<dyn MultilinearExtension<_>>
281 })
282 .collect();
283
284 let span = span!(Level::DEBUG, "QueryProof get folded_mle").entered();
285 let mut folded_mle = vec![Zero::zero(); range_length];
286 for (multiplier, evaluator) in random_scalars.iter().zip(
287 first_round_builder
288 .pcs_proof_mles()
289 .iter()
290 .chain(&column_ref_mles)
291 .chain(final_round_builder.pcs_proof_mles().iter()),
292 ) {
293 evaluator.mul_add(&mut folded_mle, multiplier);
294 }
295 span.exit();
296
297 let evaluation_proof = CP::new(
299 &mut transcript,
300 &folded_mle,
301 &evaluation_point,
302 min_row_num as u64,
303 setup,
304 );
305
306 let proof = Self {
307 first_round_message,
308 final_round_message,
309 sumcheck_proof,
310 pcs_proof_evaluations,
311 evaluation_proof,
312 };
313
314 log::log_memory_usage("End");
315
316 Ok((proof, provable_result))
317 }
318
319 #[tracing::instrument(name = "QueryProof::verify", level = "debug", skip_all, err)]
320 pub fn verify(
322 self,
323 expr: &(impl ProofPlan + Serialize),
324 accessor: &impl CommitmentAccessor<CP::Commitment>,
325 result: OwnedTable<CP::Scalar>,
326 setup: &CP::VerifierPublicSetup<'_>,
327 params: &[LiteralValue],
328 ) -> QueryResult<CP::Scalar> {
329 log::log_memory_usage("Start");
330
331 let table_refs = expr.get_table_references();
332 let (min_row_num, _) = get_index_range(accessor, &table_refs);
333 let num_sumcheck_variables = cmp::max(log2_up(self.first_round_message.range_length), 1);
334 assert!(num_sumcheck_variables > 0);
335
336 for dist in &self.final_round_message.bit_distributions {
338 if !dist.is_valid() {
339 Err(ProofError::VerificationError {
340 error: "invalid bit distributions",
341 })?;
342 } else if !dist.is_within_acceptable_range() {
343 Err(ProofError::VerificationError {
344 error: "bit distribution outside of acceptable range",
345 })?;
346 }
347 }
348
349 let column_references = expr.get_column_references();
350
351 let mut transcript: Keccak256Transcript = Transcript::new();
353 transcript.extend_as_le([SETUP_HASH]);
354 transcript.challenge_as_le();
355 transcript.extend_serialize_as_le(expr);
356 transcript.challenge_as_le();
357 transcript.extend_serialize_as_le(&result);
358 transcript.challenge_as_le();
359
360 for table in expr.get_table_references() {
361 let length = accessor.get_length(&table);
362 transcript.extend_serialize_as_le(&[0, 0, 0, length]);
363 }
364 transcript.challenge_as_le();
365
366 for commitment in expr
367 .get_column_references()
368 .into_iter()
369 .map(|col| accessor.get_commitment(&col.table_ref(), &col.column_id()))
370 {
371 transcript.extend_serialize_as_le(&commitment);
372 }
373 transcript.challenge_as_le();
374
375 transcript.extend_serialize_as_le(&min_row_num);
376 transcript.challenge_as_le();
377
378 transcript.extend_serialize_as_le(&self.first_round_message);
379
380 let post_result_challenges =
386 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
387 .take(self.first_round_message.post_result_challenge_count)
388 .collect();
389
390 transcript.challenge_as_le();
392 transcript.extend_serialize_as_le(&self.final_round_message);
393
394 let num_random_scalars =
396 num_sumcheck_variables + self.final_round_message.subpolynomial_constraint_count;
397 let random_scalars: Vec<_> =
398 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
399 .take(num_random_scalars)
400 .collect();
401 let sumcheck_random_scalars = SumcheckRandomScalars::new(
402 &random_scalars,
403 self.first_round_message.range_length,
404 num_sumcheck_variables,
405 );
406 transcript.challenge_as_le();
407
408 let subclaim = self.sumcheck_proof.verify_without_evaluation(
410 &mut transcript,
411 num_sumcheck_variables,
412 &Zero::zero(),
413 )?;
414
415 transcript.extend_serialize_as_le(&self.pcs_proof_evaluations);
417
418 let evaluation_random_scalars: Vec<_> =
421 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
422 .take(
423 self.pcs_proof_evaluations.first_round.len()
424 + self.pcs_proof_evaluations.column_ref.len()
425 + self.pcs_proof_evaluations.final_round.len(),
426 )
427 .collect();
428
429 let table_length_map = table_refs
431 .into_iter()
432 .map(|table_ref| {
433 let len = accessor.get_length(&table_ref);
434 (table_ref, len)
435 })
436 .collect::<IndexMap<TableRef, usize>>();
437
438 let chi_evaluation_lengths = table_length_map
439 .values()
440 .chain(self.first_round_message.chi_evaluation_lengths.iter())
441 .copied();
442
443 let sumcheck_evaluations = SumcheckMleEvaluations::new(
445 self.first_round_message.range_length,
446 chi_evaluation_lengths,
447 self.first_round_message.rho_evaluation_lengths.clone(),
448 &subclaim.evaluation_point,
449 &sumcheck_random_scalars,
450 &self.pcs_proof_evaluations.first_round,
451 &self.pcs_proof_evaluations.final_round,
452 );
453 let chi_eval_map: IndexMap<TableRef, CP::Scalar> = table_length_map
454 .into_iter()
455 .map(|(table_ref, length)| (table_ref, sumcheck_evaluations.chi_evaluations[&length]))
456 .collect();
457 let mut builder = VerificationBuilderImpl::new(
458 sumcheck_evaluations,
459 &self.final_round_message.bit_distributions,
460 sumcheck_random_scalars.subpolynomial_multipliers,
461 post_result_challenges,
462 self.first_round_message.chi_evaluation_lengths.clone(),
463 self.first_round_message.rho_evaluation_lengths.clone(),
464 subclaim.max_multiplicands,
465 );
466
467 let pcs_proof_commitments: Vec<_> = self
468 .first_round_message
469 .round_commitments
470 .iter()
471 .cloned()
472 .chain(
473 column_references
474 .iter()
475 .map(|col| accessor.get_commitment(&col.table_ref(), &col.column_id())),
476 )
477 .chain(self.final_round_message.round_commitments.iter().cloned())
478 .collect();
479 let evaluation_accessor: IndexMap<_, _> = column_references
480 .into_iter()
481 .zip(self.pcs_proof_evaluations.column_ref.iter().copied())
482 .chunk_by(|(r, _)| r.table_ref())
483 .into_iter()
484 .map(|(tr, g)| {
485 let im: IndexMap<_, _> = g.map(|(cr, eval)| (cr.column_id(), eval)).collect();
486 (tr, im)
487 })
488 .collect();
489
490 let verifier_evaluations = expr.verifier_evaluate(
491 &mut builder,
492 &evaluation_accessor,
493 Some(&result),
494 &chi_eval_map,
495 params,
496 )?;
497 let result_evaluations = result.mle_evaluations(&subclaim.evaluation_point);
499 if verifier_evaluations.column_evals() != result_evaluations {
501 Err(ProofError::VerificationError {
502 error: "result evaluation check failed",
503 })?;
504 }
505
506 if builder.sumcheck_evaluation() != subclaim.expected_evaluation {
508 Err(ProofError::VerificationError {
509 error: "sumcheck evaluation check failed",
510 })?;
511 }
512
513 let pcs_proof_evaluations: Vec<_> = self
514 .pcs_proof_evaluations
515 .first_round
516 .iter()
517 .chain(self.pcs_proof_evaluations.column_ref.iter())
518 .chain(self.pcs_proof_evaluations.final_round.iter())
519 .copied()
520 .collect();
521
522 self.evaluation_proof
524 .verify_batched_proof(
525 &mut transcript,
526 &pcs_proof_commitments,
527 &evaluation_random_scalars,
528 &pcs_proof_evaluations,
529 &subclaim.evaluation_point,
530 min_row_num as u64,
531 self.first_round_message.range_length,
532 setup,
533 )
534 .map_err(|_e| ProofError::VerificationError {
535 error: "Inner product proof of MLE evaluations failed",
536 })?;
537
538 let verification_hash = transcript.challenge_as_le();
539
540 log::log_memory_usage("End");
541
542 Ok(QueryData {
543 table: result,
544 verification_hash,
545 })
546 }
547}