proof_of_sql/sql/proof/
query_proof.rs

1use super::{
2    make_sumcheck_state::make_sumcheck_prover_state, FinalRoundBuilder, FirstRoundBuilder,
3    ProofPlan, QueryData, QueryResult, SumcheckMleEvaluations, SumcheckRandomScalars,
4    VerificationBuilderImpl,
5};
6use crate::{
7    base::{
8        bit::BitDistribution,
9        commitment::{Commitment, CommitmentEvaluationProof, CommittableColumn},
10        database::{
11            ColumnRef, CommitmentAccessor, DataAccessor, LiteralValue, MetadataAccessor,
12            OwnedTable, Table, TableRef,
13        },
14        map::{IndexMap, IndexSet},
15        math::log2_up,
16        polynomial::{compute_evaluation_vector, MultilinearExtension},
17        proof::{Keccak256Transcript, PlaceholderResult, ProofError, Transcript},
18    },
19    proof_primitive::sumcheck::SumcheckProof,
20    utils::log,
21};
22use alloc::{boxed::Box, vec, vec::Vec};
23use bumpalo::Bump;
24use core::cmp;
25use itertools::Itertools;
26use num_traits::Zero;
27use serde::{Deserialize, Serialize};
28use sqlparser::ast::Ident;
29
30/// Return the row number range of tables referenced in the Query
31///
32/// Basically we are looking for the smallest offset and the largest offset + length
33/// so that we have an index range of the table rows that the query is referencing.
34fn get_index_range<'a>(
35    accessor: &dyn MetadataAccessor,
36    table_refs: impl IntoIterator<Item = &'a TableRef>,
37) -> (usize, usize) {
38    table_refs
39        .into_iter()
40        .map(|table_ref| {
41            let length = accessor.get_length(table_ref);
42            let offset = accessor.get_offset(table_ref);
43            (offset, offset + length)
44        })
45        .reduce(|(min_start, max_end), (start, end)| (min_start.min(start), max_end.max(end)))
46        // Only applies to `EmptyExec` where there are no tables
47        .unwrap_or((0, 1))
48}
49
/// The prover's first round message.
///
/// The prover serializes this whole struct into the transcript before the
/// post-result challenges are drawn, and the verifier absorbs the same struct
/// at the same point.
// NOTE(review): because the struct is serialized into the transcript, changing
// or reordering fields presumably changes the transcript bytes -- confirm
// before touching the layout.
#[derive(Clone, Serialize, Deserialize)]
pub struct FirstRoundMessage<C> {
    /// Length of the range of generators we use
    pub range_length: usize,
    /// Number of post-result challenges to draw from the transcript after
    /// this message has been absorbed
    pub post_result_challenge_count: usize,
    /// Chi evaluation lengths
    pub chi_evaluation_lengths: Vec<usize>,
    /// Rho evaluation lengths
    pub rho_evaluation_lengths: Vec<usize>,
    /// First Round Commitments
    pub round_commitments: Vec<C>,
}
62
/// The prover's final round message.
///
/// The prover serializes this whole struct into the transcript before the
/// sumcheck random scalars are drawn, and the verifier absorbs the same
/// struct at the same point.
#[derive(Clone, Serialize, Deserialize)]
pub struct FinalRoundMessage<C> {
    /// Number of sumcheck subpolynomial constraints; together with the number
    /// of sumcheck variables it determines how many random scalars are drawn
    /// for sumcheck
    pub subpolynomial_constraint_count: usize,
    /// Final Round Commitments
    pub round_commitments: Vec<C>,
    /// Bit distributions
    pub bit_distributions: Vec<BitDistribution>,
}
/// Evaluations of the MLEs that are checked by the batched evaluation proof.
///
/// When the evaluations are folded or verified they are always taken in the
/// order `first_round`, then `column_ref`, then `final_round`.
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProofPCSProofEvaluations<S> {
    /// Evaluations of the MLEs used in first round sumcheck except for the result columns
    pub first_round: Vec<S>,
    /// Evaluations of the columns referenced in the query
    pub column_ref: Vec<S>,
    /// Evaluations of the MLEs used in final round sumcheck except for the result columns
    pub final_round: Vec<S>,
}
80
/// The proof for a query.
///
/// Note: Because the struct is deserialized from untrusted data, it
/// cannot maintain any invariant on its data members; hence, they are
/// all visible to the parent module (`pub(super)`) so as to allow for
/// easy manipulation for testing.
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProof<CP: CommitmentEvaluationProof> {
    /// First round message (lengths, challenge count and first round commitments)
    pub(super) first_round_message: FirstRoundMessage<CP::Commitment>,
    /// Final round message (constraint count, final round commitments and bit distributions)
    pub(super) final_round_message: FinalRoundMessage<CP::Commitment>,
    /// Sumcheck Proof
    pub(super) sumcheck_proof: SumcheckProof<CP::Scalar>,
    /// Claimed evaluations of the MLEs covered by `evaluation_proof`
    pub(super) pcs_proof_evaluations: QueryProofPCSProofEvaluations<CP::Scalar>,
    /// Inner product proof of the MLEs' evaluations
    pub(super) evaluation_proof: CP,
}
96
97impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
    /// Create a new `QueryProof`.
    ///
    /// Evaluates `expr` over the tables provided by `accessor`, builds the
    /// transcript, runs sumcheck, and forms the batched evaluation proof.
    ///
    /// * `expr` - the plan to prove; it is also serialized into the transcript.
    /// * `accessor` - source of table data, lengths and offsets.
    /// * `setup` - the prover's public setup, used for all commitments and for
    ///   the evaluation proof.
    /// * `params` - literal values forwarded unchanged to the plan's first and
    ///   final round evaluation.
    ///
    /// Returns the proof together with the query result as an `OwnedTable`.
    ///
    /// # Errors
    /// Propagates any error returned by `first_round_evaluate` or
    /// `final_round_evaluate`.
    ///
    /// Note: the order in which data is absorbed into the transcript and in
    /// which challenges are drawn must match `verify` step for step.
    #[tracing::instrument(name = "QueryProof::new", level = "debug", skip_all)]
    pub fn new(
        expr: &(impl ProofPlan + Serialize),
        accessor: &impl DataAccessor<CP::Scalar>,
        setup: &CP::ProverPublicSetup<'_>,
        params: &[LiteralValue],
    ) -> PlaceholderResult<(Self, OwnedTable<CP::Scalar>)> {
        log::log_memory_usage("Start");

        // Row range spanned by all referenced tables; the length is clamped to at least 1.
        let (min_row_num, max_row_num) = get_index_range(accessor, &expr.get_table_references());
        let initial_range_length = (max_row_num - min_row_num).max(1);
        let alloc = Bump::new();

        // For every referenced table, load only the columns the plan actually references.
        let total_col_refs = expr.get_column_references();
        let table_map: IndexMap<TableRef, Table<CP::Scalar>> = expr
            .get_table_references()
            .into_iter()
            .map(|table_ref| {
                let idents: IndexSet<Ident> = total_col_refs
                    .iter()
                    .filter(|col_ref| col_ref.table_ref() == table_ref)
                    .map(ColumnRef::column_id)
                    .collect();
                (table_ref.clone(), accessor.get_table(&table_ref, &idents))
            })
            .collect();

        // Prover First Round: Evaluate the query && get the right number of post result challenges
        let mut first_round_builder = FirstRoundBuilder::new(initial_range_length);
        let query_result =
            expr.first_round_evaluate(&mut first_round_builder, &alloc, &table_map, params)?;
        let owned_table_result = OwnedTable::from(&query_result);
        let provable_result = query_result.into();
        let chi_evaluation_lengths = first_round_builder.chi_evaluation_lengths();
        let rho_evaluation_lengths = first_round_builder.rho_evaluation_lengths();

        // The builder's range length is read back after evaluation; it is the value
        // used for the rest of the proof and sent to the verifier.
        let range_length = first_round_builder.range_length();
        let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
        // Always holds because of the `cmp::max(.., 1)` above.
        assert!(num_sumcheck_variables > 0);
        let post_result_challenge_count = first_round_builder.num_post_result_challenges();

        // commit to any intermediate MLEs
        let first_round_commitments =
            first_round_builder.commit_intermediate_mles(min_row_num, setup);

        // construct a transcript for the proof
        // Absorbed in order: the plan, the result table, the table lengths, the
        // column commitments and the minimum row number, with a challenge drawn
        // (and discarded) around each step.
        let mut transcript: Keccak256Transcript = Transcript::new();
        transcript.challenge_as_le();
        transcript.extend_serialize_as_le(expr);
        transcript.challenge_as_le();
        transcript.extend_serialize_as_le(&owned_table_result);
        transcript.challenge_as_le();

        for table in expr.get_table_references() {
            let length = accessor.get_length(&table);
            // NOTE(review): the purpose of the three leading zeros is not evident
            // from this file; `verify` absorbs the same four-element array.
            transcript.extend_serialize_as_le(&[0, 0, 0, length]);
        }
        transcript.challenge_as_le();

        // Commitments to the referenced columns are computed here from the data.
        for commitment in CP::Commitment::compute_commitments(
            &expr
                .get_column_references()
                .into_iter()
                .map(|col| {
                    CommittableColumn::from(accessor.get_column(&col.table_ref(), &col.column_id()))
                })
                .collect_vec(),
            min_row_num,
            setup,
        ) {
            transcript.extend_serialize_as_le(&commitment);
        }
        transcript.challenge_as_le();

        transcript.extend_serialize_as_le(&min_row_num);
        transcript.challenge_as_le();

        let first_round_message = FirstRoundMessage {
            range_length,
            chi_evaluation_lengths: chi_evaluation_lengths.to_vec(),
            rho_evaluation_lengths: rho_evaluation_lengths.to_vec(),
            post_result_challenge_count,
            round_commitments: first_round_commitments,
        };
        transcript.extend_serialize_as_le(&first_round_message);

        // These are the challenges that will be consumed by the proof
        // Specifically, these are the challenges that the verifier sends to
        // the prover after the prover sends the result, but before the prover
        // sends commitments to the intermediate witness columns.
        // Note: the last challenge in the vec is the first one that is consumed.
        let post_result_challenges =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(post_result_challenge_count)
                .collect();

        let mut final_round_builder =
            FinalRoundBuilder::new(num_sumcheck_variables, post_result_challenges);

        expr.final_round_evaluate(&mut final_round_builder, &alloc, &table_map, params)?;

        // Re-read the variable count from the builder after the final round evaluation.
        let num_sumcheck_variables = final_round_builder.num_sumcheck_variables();

        // commit to any intermediate MLEs
        let final_round_commitments =
            final_round_builder.commit_intermediate_mles(min_row_num, setup);

        let final_round_message = FinalRoundMessage {
            subpolynomial_constraint_count: final_round_builder.num_sumcheck_subpolynomials(),
            round_commitments: final_round_commitments,
            bit_distributions: final_round_builder.bit_distributions().to_vec(),
        };

        // add the final round message (constraint count, commitments and bit
        // distributions) to the transcript
        transcript.challenge_as_le();
        transcript.extend_serialize_as_le(&final_round_message);

        // construct the sumcheck polynomial
        // One random scalar per sumcheck variable plus one per subpolynomial constraint.
        let num_random_scalars =
            num_sumcheck_variables + final_round_message.subpolynomial_constraint_count;
        let random_scalars: Vec<_> =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(num_random_scalars)
                .collect();
        let state = make_sumcheck_prover_state(
            final_round_builder.sumcheck_subpolynomials(),
            num_sumcheck_variables,
            &SumcheckRandomScalars::new(&random_scalars, range_length, num_sumcheck_variables),
        );
        transcript.challenge_as_le();

        // create the sumcheck proof -- this is the main part of proving a query
        // `evaluation_point` is passed mutably and is used below as the point at
        // which all MLEs are evaluated.
        let mut evaluation_point = vec![Zero::zero(); state.num_vars];
        let sumcheck_proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, state);

        // evaluate the MLEs used in sumcheck except for the result columns
        let mut evaluation_vec = vec![Zero::zero(); range_length];
        compute_evaluation_vector(&mut evaluation_vec, &evaluation_point);
        let first_round_pcs_proof_evaluations =
            first_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
        let column_ref_pcs_proof_evaluations: Vec<_> = total_col_refs
            .iter()
            .map(|col_ref| {
                accessor
                    .get_column(&col_ref.table_ref(), &col_ref.column_id())
                    .inner_product(&evaluation_vec)
            })
            .collect();
        let final_round_pcs_proof_evaluations =
            final_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);

        // commit to the MLE evaluations
        let pcs_proof_evaluations = QueryProofPCSProofEvaluations {
            first_round: first_round_pcs_proof_evaluations,
            column_ref: column_ref_pcs_proof_evaluations,
            final_round: final_round_pcs_proof_evaluations,
        };
        transcript.extend_serialize_as_le(&pcs_proof_evaluations);

        // fold together the pre result MLEs -- this will form the input to an inner product proof
        // of their evaluations (fold in this context means create a random linear combination)
        // One random multiplier is drawn per evaluation.
        let random_scalars: Vec<_> =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(
                    pcs_proof_evaluations.first_round.len()
                        + pcs_proof_evaluations.column_ref.len()
                        + pcs_proof_evaluations.final_round.len(),
                )
                .collect();

        let mut folded_mle = vec![Zero::zero(); range_length];
        let column_ref_mles: Vec<_> = total_col_refs
            .into_iter()
            .map(|c| {
                Box::new(accessor.get_column(&c.table_ref(), &c.column_id()))
                    as Box<dyn MultilinearExtension<_>>
            })
            .collect();
        // The MLEs are chained in the same order as the evaluations above:
        // first round, then column references, then final round.
        for (multiplier, evaluator) in random_scalars.iter().zip(
            first_round_builder
                .pcs_proof_mles()
                .iter()
                .chain(&column_ref_mles)
                .chain(final_round_builder.pcs_proof_mles().iter()),
        ) {
            evaluator.mul_add(&mut folded_mle, multiplier);
        }

        // finally, form the inner product proof of the MLEs' evaluations
        let evaluation_proof = CP::new(
            &mut transcript,
            &folded_mle,
            &evaluation_point,
            min_row_num as u64,
            setup,
        );

        let proof = Self {
            first_round_message,
            final_round_message,
            sumcheck_proof,
            pcs_proof_evaluations,
            evaluation_proof,
        };

        log::log_memory_usage("End");

        Ok((proof, provable_result))
    }
308
    #[tracing::instrument(name = "QueryProof::verify", level = "debug", skip_all, err)]
    /// Verify a `QueryProof`. Note: This does NOT transform the result!
    ///
    /// * `expr` - the plan the proof is claimed to be for; it is also
    ///   serialized into the transcript.
    /// * `accessor` - source of the column commitments, table lengths and offsets.
    /// * `result` - the claimed query result; it is returned unchanged inside
    ///   the `QueryData` on success.
    /// * `setup` - the verifier's public setup for the evaluation proof.
    /// * `params` - literal values forwarded unchanged to `verifier_evaluate`.
    ///
    /// On success, returns the `result` table together with a verification
    /// hash, which is the last challenge drawn from the transcript.
    ///
    /// # Errors
    /// Returns `ProofError::VerificationError` when a bit distribution is
    /// invalid or outside of the acceptable range, when the result evaluation
    /// check fails, when the sumcheck evaluation check fails, or when the
    /// batched evaluation proof does not verify. Errors from
    /// `verify_without_evaluation` and `verifier_evaluate` are propagated.
    ///
    /// Note: the order in which data is absorbed into the transcript and in
    /// which challenges are drawn must match `new` step for step.
    pub fn verify(
        self,
        expr: &(impl ProofPlan + Serialize),
        accessor: &impl CommitmentAccessor<CP::Commitment>,
        result: OwnedTable<CP::Scalar>,
        setup: &CP::VerifierPublicSetup<'_>,
        params: &[LiteralValue],
    ) -> QueryResult<CP::Scalar> {
        log::log_memory_usage("Start");

        let table_refs = expr.get_table_references();
        let (min_row_num, _) = get_index_range(accessor, &table_refs);
        // The range length comes from the (untrusted) proof rather than from the accessor.
        let num_sumcheck_variables = cmp::max(log2_up(self.first_round_message.range_length), 1);
        // Always holds because of the `cmp::max(.., 1)` above.
        assert!(num_sumcheck_variables > 0);

        // validate bit decompositions
        for dist in &self.final_round_message.bit_distributions {
            if !dist.is_valid() {
                Err(ProofError::VerificationError {
                    error: "invalid bit distributions",
                })?;
            } else if !dist.is_within_acceptable_range() {
                Err(ProofError::VerificationError {
                    error: "bit distribution outside of acceptable range",
                })?;
            }
        }

        let column_references = expr.get_column_references();

        // construct a transcript for the proof
        // Absorbed in order: the plan, the claimed result, the table lengths, the
        // column commitments and the minimum row number, with a challenge drawn
        // (and discarded) around each step.
        let mut transcript: Keccak256Transcript = Transcript::new();
        transcript.challenge_as_le();
        transcript.extend_serialize_as_le(expr);
        transcript.challenge_as_le();
        transcript.extend_serialize_as_le(&result);
        transcript.challenge_as_le();

        for table in expr.get_table_references() {
            let length = accessor.get_length(&table);
            // NOTE(review): the purpose of the three leading zeros is not evident
            // from this file; `new` absorbs the same four-element array.
            transcript.extend_serialize_as_le(&[0, 0, 0, length]);
        }
        transcript.challenge_as_le();

        // Column commitments are taken from the accessor here.
        for commitment in expr
            .get_column_references()
            .into_iter()
            .map(|col| accessor.get_commitment(&col.table_ref(), &col.column_id()))
        {
            transcript.extend_serialize_as_le(&commitment);
        }
        transcript.challenge_as_le();

        transcript.extend_serialize_as_le(&min_row_num);
        transcript.challenge_as_le();

        transcript.extend_serialize_as_le(&self.first_round_message);

        // These are the challenges that will be consumed by the proof
        // Specifically, these are the challenges that the verifier sends to
        // the prover after the prover sends the result, but before the prover
        // sends commitments to the intermediate witness columns.
        // Note: the last challenge in the vec is the first one that is consumed.
        let post_result_challenges =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(self.first_round_message.post_result_challenge_count)
                .collect();

        // add the final round message (constraint count, commitments and bit
        // distributions) to the transcript
        transcript.challenge_as_le();
        transcript.extend_serialize_as_le(&self.final_round_message);

        // draw the random scalars for sumcheck
        // One per sumcheck variable plus one per subpolynomial constraint.
        let num_random_scalars =
            num_sumcheck_variables + self.final_round_message.subpolynomial_constraint_count;
        let random_scalars: Vec<_> =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(num_random_scalars)
                .collect();
        let sumcheck_random_scalars = SumcheckRandomScalars::new(
            &random_scalars,
            self.first_round_message.range_length,
            num_sumcheck_variables,
        );
        transcript.challenge_as_le();

        // verify sumcheck up to the evaluation check
        let subclaim = self.sumcheck_proof.verify_without_evaluation(
            &mut transcript,
            num_sumcheck_variables,
            &Zero::zero(),
        )?;

        // commit to mle evaluations
        transcript.extend_serialize_as_le(&self.pcs_proof_evaluations);

        // draw the random scalars for the evaluation proof
        // (i.e. the folding/random linear combination of the pcs_proof_mles)
        let evaluation_random_scalars: Vec<_> =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(
                    self.pcs_proof_evaluations.first_round.len()
                        + self.pcs_proof_evaluations.column_ref.len()
                        + self.pcs_proof_evaluations.final_round.len(),
                )
                .collect();

        // Always prepend input lengths to the chi evaluation lengths
        let table_length_map = table_refs
            .into_iter()
            .map(|table_ref| {
                let len = accessor.get_length(&table_ref);
                (table_ref, len)
            })
            .collect::<IndexMap<TableRef, usize>>();

        let chi_evaluation_lengths = table_length_map
            .values()
            .chain(self.first_round_message.chi_evaluation_lengths.iter())
            .copied();

        // pass over the provable AST to fill in the verification builder
        let sumcheck_evaluations = SumcheckMleEvaluations::new(
            self.first_round_message.range_length,
            chi_evaluation_lengths,
            self.first_round_message.rho_evaluation_lengths.clone(),
            &subclaim.evaluation_point,
            &sumcheck_random_scalars,
            &self.pcs_proof_evaluations.first_round,
            &self.pcs_proof_evaluations.final_round,
        );
        // Look up the chi evaluation for each table by the table's length.
        // NOTE(review): indexing with `[&length]` assumes `chi_evaluations` has an
        // entry for every table length passed in above -- confirm.
        let chi_eval_map: IndexMap<TableRef, CP::Scalar> = table_length_map
            .into_iter()
            .map(|(table_ref, length)| (table_ref, sumcheck_evaluations.chi_evaluations[&length]))
            .collect();
        let mut builder = VerificationBuilderImpl::new(
            sumcheck_evaluations,
            &self.final_round_message.bit_distributions,
            sumcheck_random_scalars.subpolynomial_multipliers,
            post_result_challenges,
            self.first_round_message.chi_evaluation_lengths.clone(),
            self.first_round_message.rho_evaluation_lengths.clone(),
            subclaim.max_multiplicands,
        );

        // Commitments in the same order as the evaluations: first round, then
        // column references, then final round.
        let pcs_proof_commitments: Vec<_> = self
            .first_round_message
            .round_commitments
            .iter()
            .cloned()
            .chain(
                column_references
                    .iter()
                    .map(|col| accessor.get_commitment(&col.table_ref(), &col.column_id())),
            )
            .chain(self.final_round_message.round_commitments.iter().cloned())
            .collect();
        // Map each table to a map from column identifier to the claimed evaluation.
        // NOTE(review): `zip` stops at the shorter input, so a proof carrying a
        // different number of `column_ref` evaluations than there are column
        // references is not rejected here -- confirm it is caught downstream.
        // NOTE(review): `chunk_by` only groups consecutive items, so this relies on
        // the column references of one table being adjacent -- confirm.
        let evaluation_accessor: IndexMap<_, _> = column_references
            .into_iter()
            .zip(self.pcs_proof_evaluations.column_ref.iter().copied())
            .chunk_by(|(r, _)| r.table_ref())
            .into_iter()
            .map(|(tr, g)| {
                let im: IndexMap<_, _> = g.map(|(cr, eval)| (cr.column_id(), eval)).collect();
                (tr, im)
            })
            .collect();

        let verifier_evaluations = expr.verifier_evaluate(
            &mut builder,
            &evaluation_accessor,
            Some(&result),
            &chi_eval_map,
            params,
        )?;
        // compute the evaluation of the result MLEs
        let result_evaluations = result.mle_evaluations(&subclaim.evaluation_point);
        // check the evaluation of the result MLEs
        if verifier_evaluations.column_evals() != result_evaluations {
            Err(ProofError::VerificationError {
                error: "result evaluation check failed",
            })?;
        }

        // perform the evaluation check of the sumcheck polynomial
        if builder.sumcheck_evaluation() != subclaim.expected_evaluation {
            Err(ProofError::VerificationError {
                error: "sumcheck evaluation check failed",
            })?;
        }

        // Evaluations in the same order as `pcs_proof_commitments`.
        let pcs_proof_evaluations: Vec<_> = self
            .pcs_proof_evaluations
            .first_round
            .iter()
            .chain(self.pcs_proof_evaluations.column_ref.iter())
            .chain(self.pcs_proof_evaluations.final_round.iter())
            .copied()
            .collect();

        // finally, check the MLE evaluations with the inner product proof
        self.evaluation_proof
            .verify_batched_proof(
                &mut transcript,
                &pcs_proof_commitments,
                &evaluation_random_scalars,
                &pcs_proof_evaluations,
                &subclaim.evaluation_point,
                min_row_num as u64,
                self.first_round_message.range_length,
                setup,
            )
            .map_err(|_e| ProofError::VerificationError {
                error: "Inner product proof of MLE evaluations failed",
            })?;

        // The final transcript challenge is returned as the verification hash.
        let verification_hash = transcript.challenge_as_le();

        log::log_memory_usage("End");

        Ok(QueryData {
            table: result,
            verification_hash,
        })
    }
536}