1use super::{
2 make_sumcheck_state::make_sumcheck_prover_state, FinalRoundBuilder, FirstRoundBuilder,
3 ProofPlan, QueryData, QueryResult, SumcheckMleEvaluations, SumcheckRandomScalars,
4 VerificationBuilderImpl,
5};
6use crate::{
7 base::{
8 bit::BitDistribution,
9 commitment::{Commitment, CommitmentEvaluationProof, CommittableColumn},
10 database::{
11 ColumnRef, CommitmentAccessor, DataAccessor, LiteralValue, MetadataAccessor,
12 OwnedTable, Table, TableRef,
13 },
14 map::{IndexMap, IndexSet},
15 math::log2_up,
16 polynomial::{compute_evaluation_vector, MultilinearExtension},
17 proof::{Keccak256Transcript, PlaceholderResult, ProofError, Transcript},
18 },
19 proof_primitive::sumcheck::SumcheckProof,
20 utils::log,
21};
22use alloc::{boxed::Box, vec, vec::Vec};
23use bumpalo::Bump;
24use core::cmp;
25use itertools::Itertools;
26use num_traits::Zero;
27use serde::{Deserialize, Serialize};
28use sqlparser::ast::Ident;
29
30fn get_index_range<'a>(
35 accessor: &dyn MetadataAccessor,
36 table_refs: impl IntoIterator<Item = &'a TableRef>,
37) -> (usize, usize) {
38 table_refs
39 .into_iter()
40 .map(|table_ref| {
41 let length = accessor.get_length(table_ref);
42 let offset = accessor.get_offset(table_ref);
43 (offset, offset + length)
44 })
45 .reduce(|(min_start, max_end), (start, end)| (min_start.min(start), max_end.max(end)))
46 .unwrap_or((0, 1))
48}
49
/// The values `QueryProof::new` reads back from the first-round builder once the
/// first evaluation round has run.
///
/// Both `QueryProof::new` and `QueryProof::verify` serialize the whole message into
/// the transcript before the post-result challenges are drawn.
#[derive(Clone, Serialize, Deserialize)]
pub struct FirstRoundMessage<C> {
    /// Length of the row range. It sizes the prover's evaluation vector and folded MLE,
    /// and `max(log2_up(range_length), 1)` is used as the number of sumcheck variables.
    pub range_length: usize,
    /// Number of scalar challenges drawn from the transcript right after this message
    /// has been appended to it.
    pub post_result_challenge_count: usize,
    /// Chi evaluation lengths reported by the first-round builder. The verifier chains
    /// them after the lengths of the referenced tables.
    pub chi_evaluation_lengths: Vec<usize>,
    /// Rho evaluation lengths reported by the first-round builder.
    pub rho_evaluation_lengths: Vec<usize>,
    /// Commitments produced by the first-round builder's `commit_intermediate_mles`.
    pub round_commitments: Vec<C>,
}
62
/// The values `QueryProof::new` reads back from the final-round builder once the
/// final evaluation round has run.
///
/// Both `QueryProof::new` and `QueryProof::verify` serialize the whole message into
/// the transcript before the sumcheck random scalars are drawn.
#[derive(Clone, Serialize, Deserialize)]
pub struct FinalRoundMessage<C> {
    /// Number of sumcheck subpolynomials reported by the final-round builder. One random
    /// scalar is drawn per constraint, in addition to one per sumcheck variable.
    pub subpolynomial_constraint_count: usize,
    /// Commitments produced by the final-round builder's `commit_intermediate_mles`.
    pub round_commitments: Vec<C>,
    /// Bit distributions reported by the final-round builder. `QueryProof::verify`
    /// rejects the proof if any of them is invalid or outside the acceptable range.
    pub bit_distributions: Vec<BitDistribution>,
}
/// The claimed evaluations that are checked by the single batched evaluation proof.
///
/// The three groups are always combined in the order `first_round`, `column_ref`,
/// `final_round`, both when the prover folds the MLEs and when the verifier assembles
/// the commitments and evaluations for the batched check.
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProofPCSProofEvaluations<S> {
    /// Evaluations returned by the first-round builder's `evaluate_pcs_proof_mles`.
    pub first_round: Vec<S>,
    /// One value per column reference of the plan: the inner product of that column
    /// with the prover's evaluation vector.
    pub column_ref: Vec<S>,
    /// Evaluations returned by the final-round builder's `evaluate_pcs_proof_mles`.
    pub final_round: Vec<S>,
}
80
/// A proof that a query plan was evaluated correctly.
///
/// Produced by `QueryProof::new` and consumed by `QueryProof::verify`.
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProof<CP: CommitmentEvaluationProof> {
    /// Lengths, challenge count and commitments gathered during the first round.
    pub(super) first_round_message: FirstRoundMessage<CP::Commitment>,
    /// Constraint count, commitments and bit distributions gathered during the final round.
    pub(super) final_round_message: FinalRoundMessage<CP::Commitment>,
    /// The sumcheck proof created from the final-round subpolynomials.
    pub(super) sumcheck_proof: SumcheckProof<CP::Scalar>,
    /// Claimed evaluations covered by `evaluation_proof`.
    pub(super) pcs_proof_evaluations: QueryProofPCSProofEvaluations<CP::Scalar>,
    /// Evaluation proof over the random linear combination of all MLEs whose
    /// evaluations are listed in `pcs_proof_evaluations`.
    pub(super) evaluation_proof: CP,
}
96
impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
    /// Evaluates `expr` against the data served by `accessor` and builds a proof of that
    /// evaluation, returning the proof together with the query result.
    ///
    /// The steps, in order, are:
    /// 1. load the referenced tables (restricted to the referenced columns) and run
    ///    `first_round_evaluate`;
    /// 2. append the plan, the result, the table lengths, the column commitments, the
    ///    minimum row number and the first-round message to a fresh transcript;
    /// 3. draw the post-result challenges and run `final_round_evaluate`;
    /// 4. create the sumcheck proof, compute the claimed MLE evaluations, and create one
    ///    evaluation proof over a random linear combination of those MLEs.
    ///
    /// `verify` replays the same transcript operations in the same order, so the
    /// `extend_serialize_as_le` / `challenge_as_le` / `scalar_challenge_as_be` calls below
    /// must not be reordered without making the matching change there.
    ///
    /// # Errors
    /// Propagates any error returned by `first_round_evaluate` or `final_round_evaluate`.
    #[tracing::instrument(name = "QueryProof::new", level = "debug", skip_all)]
    pub fn new(
        expr: &(impl ProofPlan + Serialize),
        accessor: &impl DataAccessor<CP::Scalar>,
        setup: &CP::ProverPublicSetup<'_>,
        params: &[LiteralValue],
    ) -> PlaceholderResult<(Self, OwnedTable<CP::Scalar>)> {
        log::log_memory_usage("Start");

        // Smallest offset and largest `offset + length` over all referenced tables.
        let (min_row_num, max_row_num) = get_index_range(accessor, &expr.get_table_references());
        // Never start from an empty range.
        let initial_range_length = (max_row_num - min_row_num).max(1);
        // Arena handed to both evaluation rounds.
        let alloc = Bump::new();

        // For each referenced table, fetch only the columns that the plan references.
        let total_col_refs = expr.get_column_references();
        let table_map: IndexMap<TableRef, Table<CP::Scalar>> = expr
            .get_table_references()
            .into_iter()
            .map(|table_ref| {
                let idents: IndexSet<Ident> = total_col_refs
                    .iter()
                    .filter(|col_ref| col_ref.table_ref() == table_ref)
                    .map(ColumnRef::column_id)
                    .collect();
                (table_ref.clone(), accessor.get_table(&table_ref, &idents))
            })
            .collect();

        // First round: evaluate the plan. The builder records the lengths, the challenge
        // count and the intermediate MLEs that are read back below.
        let mut first_round_builder = FirstRoundBuilder::new(initial_range_length);
        let query_result =
            expr.first_round_evaluate(&mut first_round_builder, &alloc, &table_map, params)?;
        // The owned copy goes into the transcript; the converted result is returned.
        let owned_table_result = OwnedTable::from(&query_result);
        let provable_result = query_result.into();
        let chi_evaluation_lengths = first_round_builder.chi_evaluation_lengths();
        let rho_evaluation_lengths = first_round_builder.rho_evaluation_lengths();

        // The range length is re-read from the builder after the first round.
        let range_length = first_round_builder.range_length();
        // Always use at least one sumcheck variable.
        let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
        // Holds by construction because of the `max(.., 1)` above.
        assert!(num_sumcheck_variables > 0);
        let post_result_challenge_count = first_round_builder.num_post_result_challenges();

        // Commit to the first-round intermediate MLEs.
        let first_round_commitments =
            first_round_builder.commit_intermediate_mles(min_row_num, setup);

        // Start the transcript from which every challenge is derived.
        let mut transcript: Keccak256Transcript = Transcript::new();
        transcript.challenge_as_le();
        // Append the plan ...
        transcript.extend_serialize_as_le(expr);
        transcript.challenge_as_le();
        // ... the query result ...
        transcript.extend_serialize_as_le(&owned_table_result);
        transcript.challenge_as_le();

        // ... the length of every referenced table (the three leading zeros are part of
        // the encoding and are mirrored in `verify`) ...
        for table in expr.get_table_references() {
            let length = accessor.get_length(&table);
            transcript.extend_serialize_as_le(&[0, 0, 0, length]);
        }
        transcript.challenge_as_le();

        // ... a commitment to every referenced column, computed here with offset
        // `min_row_num` ...
        for commitment in CP::Commitment::compute_commitments(
            &expr
                .get_column_references()
                .into_iter()
                .map(|col| {
                    CommittableColumn::from(accessor.get_column(&col.table_ref(), &col.column_id()))
                })
                .collect_vec(),
            min_row_num,
            setup,
        ) {
            transcript.extend_serialize_as_le(&commitment);
        }
        transcript.challenge_as_le();

        // ... the minimum row number ...
        transcript.extend_serialize_as_le(&min_row_num);
        transcript.challenge_as_le();

        // ... and finally the first-round message itself.
        let first_round_message = FirstRoundMessage {
            range_length,
            chi_evaluation_lengths: chi_evaluation_lengths.to_vec(),
            rho_evaluation_lengths: rho_evaluation_lengths.to_vec(),
            post_result_challenge_count,
            round_commitments: first_round_commitments,
        };
        transcript.extend_serialize_as_le(&first_round_message);

        // These challenges are drawn only after the result and the first-round message
        // have been appended to the transcript.
        let post_result_challenges =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(post_result_challenge_count)
                .collect();

        let mut final_round_builder =
            FinalRoundBuilder::new(num_sumcheck_variables, post_result_challenges);

        // Final round: evaluate the plan again, now with the challenges available.
        expr.final_round_evaluate(&mut final_round_builder, &alloc, &table_map, params)?;

        // Re-read the variable count from the builder (shadows the value computed above).
        let num_sumcheck_variables = final_round_builder.num_sumcheck_variables();

        // Commit to the final-round intermediate MLEs.
        let final_round_commitments =
            final_round_builder.commit_intermediate_mles(min_row_num, setup);

        let final_round_message = FinalRoundMessage {
            subpolynomial_constraint_count: final_round_builder.num_sumcheck_subpolynomials(),
            round_commitments: final_round_commitments,
            bit_distributions: final_round_builder.bit_distributions().to_vec(),
        };

        // Append the final-round message before any sumcheck randomness is drawn.
        transcript.challenge_as_le();
        transcript.extend_serialize_as_le(&final_round_message);

        // One random scalar per sumcheck variable plus one per subpolynomial constraint.
        let num_random_scalars =
            num_sumcheck_variables + final_round_message.subpolynomial_constraint_count;
        let random_scalars: Vec<_> =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(num_random_scalars)
                .collect();
        let state = make_sumcheck_prover_state(
            final_round_builder.sumcheck_subpolynomials(),
            num_sumcheck_variables,
            &SumcheckRandomScalars::new(&random_scalars, range_length, num_sumcheck_variables),
        );
        transcript.challenge_as_le();

        // Create the sumcheck proof; `evaluation_point` is filled in as an output.
        let mut evaluation_point = vec![Zero::zero(); state.num_vars];
        let sumcheck_proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, state);

        // Derive a vector of length `range_length` from the evaluation point and use it to
        // compute the claimed evaluation of every MLE covered by the evaluation proof.
        let mut evaluation_vec = vec![Zero::zero(); range_length];
        compute_evaluation_vector(&mut evaluation_vec, &evaluation_point);
        let first_round_pcs_proof_evaluations =
            first_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
        let column_ref_pcs_proof_evaluations: Vec<_> = total_col_refs
            .iter()
            .map(|col_ref| {
                accessor
                    .get_column(&col_ref.table_ref(), &col_ref.column_id())
                    .inner_product(&evaluation_vec)
            })
            .collect();
        let final_round_pcs_proof_evaluations =
            final_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);

        let pcs_proof_evaluations = QueryProofPCSProofEvaluations {
            first_round: first_round_pcs_proof_evaluations,
            column_ref: column_ref_pcs_proof_evaluations,
            final_round: final_round_pcs_proof_evaluations,
        };
        // Append the claimed evaluations before the folding multipliers are drawn.
        transcript.extend_serialize_as_le(&pcs_proof_evaluations);

        // One multiplier per claimed evaluation.
        let random_scalars: Vec<_> =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(
                    pcs_proof_evaluations.first_round.len()
                        + pcs_proof_evaluations.column_ref.len()
                        + pcs_proof_evaluations.final_round.len(),
                )
                .collect();

        // Accumulate every MLE into `folded_mle`, each with its own multiplier. The order
        // (first round, column references, final round) matches the order in which
        // `verify` assembles the commitments and evaluations.
        let mut folded_mle = vec![Zero::zero(); range_length];
        let column_ref_mles: Vec<_> = total_col_refs
            .into_iter()
            .map(|c| {
                Box::new(accessor.get_column(&c.table_ref(), &c.column_id()))
                    as Box<dyn MultilinearExtension<_>>
            })
            .collect();
        for (multiplier, evaluator) in random_scalars.iter().zip(
            first_round_builder
                .pcs_proof_mles()
                .iter()
                .chain(&column_ref_mles)
                .chain(final_round_builder.pcs_proof_mles().iter()),
        ) {
            evaluator.mul_add(&mut folded_mle, multiplier);
        }

        // A single evaluation proof for the folded MLE at the sumcheck evaluation point.
        let evaluation_proof = CP::new(
            &mut transcript,
            &folded_mle,
            &evaluation_point,
            min_row_num as u64,
            setup,
        );

        let proof = Self {
            first_round_message,
            final_round_message,
            sumcheck_proof,
            pcs_proof_evaluations,
            evaluation_proof,
        };

        log::log_memory_usage("End");

        Ok((proof, provable_result))
    }

    /// Checks this proof against the plan `expr` and the claimed `result`.
    ///
    /// The proof is consumed. The verifier rebuilds the transcript in exactly the order
    /// used by `new`, but takes column commitments and table lengths from `accessor`
    /// instead of reading any column data.
    ///
    /// On success the returned `QueryData` holds `result` and a `verification_hash`, which
    /// is the challenge drawn from the transcript after every check has passed.
    ///
    /// # Errors
    /// Fails with a `ProofError::VerificationError` when
    /// - a bit distribution is invalid or outside the acceptable range,
    /// - the evaluations produced by `verifier_evaluate` differ from the MLE evaluations
    ///   of `result`,
    /// - the builder's sumcheck evaluation differs from the subclaim's expected evaluation,
    /// - or the batched evaluation proof does not verify.
    ///
    /// Errors from `verify_without_evaluation` and `verifier_evaluate` are propagated.
    #[tracing::instrument(name = "QueryProof::verify", level = "debug", skip_all, err)]
    pub fn verify(
        self,
        expr: &(impl ProofPlan + Serialize),
        accessor: &impl CommitmentAccessor<CP::Commitment>,
        result: OwnedTable<CP::Scalar>,
        setup: &CP::VerifierPublicSetup<'_>,
        params: &[LiteralValue],
    ) -> QueryResult<CP::Scalar> {
        log::log_memory_usage("Start");

        let table_refs = expr.get_table_references();
        // Only the start of the range is needed here; the length comes from the proof.
        let (min_row_num, _) = get_index_range(accessor, &table_refs);
        // Same formula as in `new`, applied to the range length carried by the proof.
        let num_sumcheck_variables = cmp::max(log2_up(self.first_round_message.range_length), 1);
        // Holds by construction because of the `max(.., 1)` above.
        assert!(num_sumcheck_variables > 0);

        // Reject malformed bit distributions before doing any other work.
        for dist in &self.final_round_message.bit_distributions {
            if !dist.is_valid() {
                Err(ProofError::VerificationError {
                    error: "invalid bit distributions",
                })?;
            } else if !dist.is_within_acceptable_range() {
                Err(ProofError::VerificationError {
                    error: "bit distribution outside of acceptable range",
                })?;
            }
        }

        let column_references = expr.get_column_references();

        // Rebuild the transcript with the same sequence of operations as `new`.
        let mut transcript: Keccak256Transcript = Transcript::new();
        transcript.challenge_as_le();
        transcript.extend_serialize_as_le(expr);
        transcript.challenge_as_le();
        transcript.extend_serialize_as_le(&result);
        transcript.challenge_as_le();

        // Table lengths, using the same `[0, 0, 0, length]` encoding as the prover.
        for table in expr.get_table_references() {
            let length = accessor.get_length(&table);
            transcript.extend_serialize_as_le(&[0, 0, 0, length]);
        }
        transcript.challenge_as_le();

        // Column commitments are looked up from the accessor rather than computed.
        for commitment in expr
            .get_column_references()
            .into_iter()
            .map(|col| accessor.get_commitment(&col.table_ref(), &col.column_id()))
        {
            transcript.extend_serialize_as_le(&commitment);
        }
        transcript.challenge_as_le();

        transcript.extend_serialize_as_le(&min_row_num);
        transcript.challenge_as_le();

        transcript.extend_serialize_as_le(&self.first_round_message);

        // Draw as many post-result challenges as the first-round message announces.
        let post_result_challenges =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(self.first_round_message.post_result_challenge_count)
                .collect();

        transcript.challenge_as_le();
        transcript.extend_serialize_as_le(&self.final_round_message);

        // One random scalar per sumcheck variable plus one per subpolynomial constraint.
        let num_random_scalars =
            num_sumcheck_variables + self.final_round_message.subpolynomial_constraint_count;
        let random_scalars: Vec<_> =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(num_random_scalars)
                .collect();
        let sumcheck_random_scalars = SumcheckRandomScalars::new(
            &random_scalars,
            self.first_round_message.range_length,
            num_sumcheck_variables,
        );
        transcript.challenge_as_le();

        // Verify sumcheck up to, but not including, the final evaluation check; that
        // check is done further below once the builder has been filled in.
        let subclaim = self.sumcheck_proof.verify_without_evaluation(
            &mut transcript,
            num_sumcheck_variables,
            &Zero::zero(),
        )?;

        // Append the claimed evaluations before the folding multipliers are drawn.
        transcript.extend_serialize_as_le(&self.pcs_proof_evaluations);

        // One multiplier per claimed evaluation, as in `new`.
        let evaluation_random_scalars: Vec<_> =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(
                    self.pcs_proof_evaluations.first_round.len()
                        + self.pcs_proof_evaluations.column_ref.len()
                        + self.pcs_proof_evaluations.final_round.len(),
                )
                .collect();

        let table_length_map = table_refs
            .into_iter()
            .map(|table_ref| {
                let len = accessor.get_length(&table_ref);
                (table_ref, len)
            })
            .collect::<IndexMap<TableRef, usize>>();

        // The table lengths come first, followed by the lengths carried by the proof.
        let chi_evaluation_lengths = table_length_map
            .values()
            .chain(self.first_round_message.chi_evaluation_lengths.iter())
            .copied();

        let sumcheck_evaluations = SumcheckMleEvaluations::new(
            self.first_round_message.range_length,
            chi_evaluation_lengths,
            self.first_round_message.rho_evaluation_lengths.clone(),
            &subclaim.evaluation_point,
            &sumcheck_random_scalars,
            &self.pcs_proof_evaluations.first_round,
            &self.pcs_proof_evaluations.final_round,
        );
        // Look up the chi evaluation for each table by its length. The index expression
        // panics on a missing key.
        // NOTE(review): every table length is passed to `SumcheckMleEvaluations::new`
        // above, so each key is presumably present — confirm.
        let chi_eval_map: IndexMap<TableRef, CP::Scalar> = table_length_map
            .into_iter()
            .map(|(table_ref, length)| (table_ref, sumcheck_evaluations.chi_evaluations[&length]))
            .collect();
        let mut builder = VerificationBuilderImpl::new(
            sumcheck_evaluations,
            &self.final_round_message.bit_distributions,
            sumcheck_random_scalars.subpolynomial_multipliers,
            post_result_challenges,
            self.first_round_message.chi_evaluation_lengths.clone(),
            self.first_round_message.rho_evaluation_lengths.clone(),
            subclaim.max_multiplicands,
        );

        // Commitments in the same order as the evaluations: first round, column
        // references, final round.
        let pcs_proof_commitments: Vec<_> = self
            .first_round_message
            .round_commitments
            .iter()
            .cloned()
            .chain(
                column_references
                    .iter()
                    .map(|col| accessor.get_commitment(&col.table_ref(), &col.column_id())),
            )
            .chain(self.final_round_message.round_commitments.iter().cloned())
            .collect();
        // Pair each column reference with its claimed evaluation and group the pairs into
        // one map per table.
        // NOTE(review): `chunk_by` only groups adjacent items, so this presumably relies on
        // `column_references` listing each table's columns contiguously — confirm.
        let evaluation_accessor: IndexMap<_, _> = column_references
            .into_iter()
            .zip(self.pcs_proof_evaluations.column_ref.iter().copied())
            .chunk_by(|(r, _)| r.table_ref())
            .into_iter()
            .map(|(tr, g)| {
                let im: IndexMap<_, _> = g.map(|(cr, eval)| (cr.column_id(), eval)).collect();
                (tr, im)
            })
            .collect();

        // Walk the plan to fill in the verification builder and obtain the evaluations
        // the result columns are expected to have.
        let verifier_evaluations = expr.verifier_evaluate(
            &mut builder,
            &evaluation_accessor,
            Some(&result),
            &chi_eval_map,
            params,
        )?;
        // Evaluate the claimed result at the sumcheck evaluation point and compare.
        let result_evaluations = result.mle_evaluations(&subclaim.evaluation_point);
        if verifier_evaluations.column_evals() != result_evaluations {
            Err(ProofError::VerificationError {
                error: "result evaluation check failed",
            })?;
        }

        // The evaluation check that `verify_without_evaluation` left open.
        if builder.sumcheck_evaluation() != subclaim.expected_evaluation {
            Err(ProofError::VerificationError {
                error: "sumcheck evaluation check failed",
            })?;
        }

        // Evaluations in the same order as `pcs_proof_commitments`.
        let pcs_proof_evaluations: Vec<_> = self
            .pcs_proof_evaluations
            .first_round
            .iter()
            .chain(self.pcs_proof_evaluations.column_ref.iter())
            .chain(self.pcs_proof_evaluations.final_round.iter())
            .copied()
            .collect();

        // Check all claimed evaluations at once. The underlying error is discarded and
        // replaced by a generic verification error.
        self.evaluation_proof
            .verify_batched_proof(
                &mut transcript,
                &pcs_proof_commitments,
                &evaluation_random_scalars,
                &pcs_proof_evaluations,
                &subclaim.evaluation_point,
                min_row_num as u64,
                self.first_round_message.range_length,
                setup,
            )
            .map_err(|_e| ProofError::VerificationError {
                error: "Inner product proof of MLE evaluations failed",
            })?;

        // Drawn only after every check above has passed.
        let verification_hash = transcript.challenge_as_le();

        log::log_memory_usage("End");

        Ok(QueryData {
            table: result,
            verification_hash,
        })
    }
}