proof_of_sql/sql/proof/query_proof.rs

1use super::{
2    make_sumcheck_state::make_sumcheck_prover_state, FinalRoundBuilder, FirstRoundBuilder,
3    ProofPlan, QueryData, QueryResult, SumcheckMleEvaluations, SumcheckRandomScalars,
4    VerificationBuilderImpl,
5};
6use crate::{
7    base::{
8        bit::BitDistribution,
9        commitment::{Commitment, CommitmentEvaluationProof, CommittableColumn},
10        database::{
11            ColumnRef, CommitmentAccessor, DataAccessor, LiteralValue, MetadataAccessor,
12            OwnedTable, Table, TableRef,
13        },
14        map::{IndexMap, IndexSet},
15        math::log2_up,
16        polynomial::{compute_evaluation_vector, MultilinearExtension},
17        proof::{Keccak256Transcript, PlaceholderResult, ProofError, Transcript},
18    },
19    proof_primitive::sumcheck::SumcheckProof,
20    utils::log,
21};
22use alloc::{boxed::Box, vec, vec::Vec};
23use bumpalo::Bump;
24use core::cmp;
25use itertools::Itertools;
26use num_traits::Zero;
27use serde::{Deserialize, Serialize};
28use sqlparser::ast::Ident;
29use tracing::{span, Level};
30
/// Fixed 32-byte value absorbed at the very start of every query-proof transcript,
/// by both the prover (`QueryProof::new`) and the verifier (`QueryProof::verify`).
///
/// TODO: make this different for each setup
const SETUP_HASH: [u8; 32] = [
    0xe8, 0x84, 0x0d, 0x8a, 0x41, 0xce, 0x9d, 0x4e, 0x14, 0xe7, 0xba, 0x0e, 0x1b, 0x02, 0x32, 0x24,
    0x75, 0x13, 0x61, 0x57, 0x73, 0x78, 0x29, 0x1f, 0xcd, 0x3f, 0x0f, 0x05, 0xf0, 0xf7, 0xe8, 0x75,
];
35
36/// Return the row number range of tables referenced in the Query
37///
38/// Basically we are looking for the smallest offset and the largest offset + length
39/// so that we have an index range of the table rows that the query is referencing.
40fn get_index_range<'a>(
41    accessor: &dyn MetadataAccessor,
42    table_refs: impl IntoIterator<Item = &'a TableRef>,
43) -> (usize, usize) {
44    table_refs
45        .into_iter()
46        .map(|table_ref| {
47            let length = accessor.get_length(table_ref);
48            let offset = accessor.get_offset(table_ref);
49            (offset, offset + length)
50        })
51        .reduce(|(min_start, max_end), (start, end)| (min_start.min(start), max_end.max(end)))
52        // Only applies to `EmptyExec` where there are no tables
53        .unwrap_or((0, 1))
54}
55
56#[derive(Clone, Serialize, Deserialize)]
57pub struct FirstRoundMessage<C> {
58    /// Length of the range of generators we use
59    pub range_length: usize,
60    pub post_result_challenge_count: usize,
61    /// Chi evaluation lengths
62    pub chi_evaluation_lengths: Vec<usize>,
63    /// Rho evaluation lengths
64    pub rho_evaluation_lengths: Vec<usize>,
65    /// First Round Commitments
66    pub round_commitments: Vec<C>,
67}
68
69#[derive(Clone, Serialize, Deserialize)]
70pub struct FinalRoundMessage<C> {
71    pub subpolynomial_constraint_count: usize,
72    /// Final Round Commitments
73    pub round_commitments: Vec<C>,
74    /// Bit distributions
75    pub bit_distributions: Vec<BitDistribution>,
76}
77#[derive(Clone, Serialize, Deserialize)]
78pub struct QueryProofPCSProofEvaluations<S> {
79    /// MLEs used in first round sumcheck except for the result columns
80    pub first_round: Vec<S>,
81    /// evaluations of the columns referenced in the query
82    pub column_ref: Vec<S>,
83    /// MLEs used in final round sumcheck except for the result columns
84    pub final_round: Vec<S>,
85}
86
87/// The proof for a query.
88///
89/// Note: Because the class is deserialized from untrusted data, it
90/// cannot maintain any invariant on its data members; hence, they are
91/// all public so as to allow for easy manipulation for testing.
92#[derive(Clone, Serialize, Deserialize)]
93pub struct QueryProof<CP: CommitmentEvaluationProof> {
94    pub(super) first_round_message: FirstRoundMessage<CP::Commitment>,
95    pub(super) final_round_message: FinalRoundMessage<CP::Commitment>,
96    /// Sumcheck Proof
97    pub(super) sumcheck_proof: SumcheckProof<CP::Scalar>,
98    pub(super) pcs_proof_evaluations: QueryProofPCSProofEvaluations<CP::Scalar>,
99    /// Inner product proof of the MLEs' evaluations
100    pub(super) evaluation_proof: CP,
101}
102
103impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
104    /// Create a new `QueryProof`.
105    #[tracing::instrument(name = "QueryProof::new", level = "debug", skip_all)]
106    pub fn new(
107        expr: &(impl ProofPlan + Serialize),
108        accessor: &impl DataAccessor<CP::Scalar>,
109        setup: &CP::ProverPublicSetup<'_>,
110        params: &[LiteralValue],
111    ) -> PlaceholderResult<(Self, OwnedTable<CP::Scalar>)> {
112        log::log_memory_usage("Start");
113
114        let (min_row_num, max_row_num) = get_index_range(accessor, &expr.get_table_references());
115        let initial_range_length = (max_row_num - min_row_num).max(1);
116        let alloc = Bump::new();
117
118        let total_col_refs = expr.get_column_references();
119        let table_map: IndexMap<TableRef, Table<CP::Scalar>> = expr
120            .get_table_references()
121            .into_iter()
122            .map(|table_ref| {
123                let idents: IndexSet<Ident> = total_col_refs
124                    .iter()
125                    .filter(|col_ref| col_ref.table_ref() == table_ref)
126                    .map(ColumnRef::column_id)
127                    .collect();
128                (table_ref.clone(), accessor.get_table(&table_ref, &idents))
129            })
130            .collect();
131
132        // Prover First Round: Evaluate the query && get the right number of post result challenges
133        let mut first_round_builder = FirstRoundBuilder::new(initial_range_length);
134        let query_result =
135            expr.first_round_evaluate(&mut first_round_builder, &alloc, &table_map, params)?;
136        let owned_table_result = OwnedTable::from(&query_result);
137        let provable_result = query_result.into();
138        let chi_evaluation_lengths = first_round_builder.chi_evaluation_lengths();
139        let rho_evaluation_lengths = first_round_builder.rho_evaluation_lengths();
140
141        let range_length = first_round_builder.range_length();
142        let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
143        assert!(num_sumcheck_variables > 0);
144        let post_result_challenge_count = first_round_builder.num_post_result_challenges();
145
146        // commit to any intermediate MLEs
147        let first_round_commitments =
148            first_round_builder.commit_intermediate_mles(min_row_num, setup);
149
150        // construct a transcript for the proof
151        let mut transcript: Keccak256Transcript = Transcript::new();
152        transcript.extend_as_le([SETUP_HASH]);
153        transcript.challenge_as_le();
154        transcript.extend_serialize_as_le(expr);
155        transcript.challenge_as_le();
156        transcript.extend_serialize_as_le(&owned_table_result);
157        transcript.challenge_as_le();
158
159        for table in expr.get_table_references() {
160            let length = accessor.get_length(&table);
161            transcript.extend_serialize_as_le(&[0, 0, 0, length]);
162        }
163        transcript.challenge_as_le();
164
165        for commitment in CP::Commitment::compute_commitments(
166            &expr
167                .get_column_references()
168                .into_iter()
169                .map(|col| {
170                    CommittableColumn::from(accessor.get_column(&col.table_ref(), &col.column_id()))
171                })
172                .collect_vec(),
173            min_row_num,
174            setup,
175        ) {
176            transcript.extend_serialize_as_le(&commitment);
177        }
178        transcript.challenge_as_le();
179
180        transcript.extend_serialize_as_le(&min_row_num);
181        transcript.challenge_as_le();
182
183        let first_round_message = FirstRoundMessage {
184            range_length,
185            chi_evaluation_lengths: chi_evaluation_lengths.to_vec(),
186            rho_evaluation_lengths: rho_evaluation_lengths.to_vec(),
187            post_result_challenge_count,
188            round_commitments: first_round_commitments,
189        };
190        transcript.extend_serialize_as_le(&first_round_message);
191
192        // These are the challenges that will be consumed by the proof
193        // Specifically, these are the challenges that the verifier sends to
194        // the prover after the prover sends the result, but before the prover
195        // send commitments to the intermediate witness columns.
196        // Note: the last challenge in the vec is the first one that is consumed.
197        let post_result_challenges =
198            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
199                .take(post_result_challenge_count)
200                .collect();
201
202        let mut final_round_builder =
203            FinalRoundBuilder::new(num_sumcheck_variables, post_result_challenges);
204
205        expr.final_round_evaluate(&mut final_round_builder, &alloc, &table_map, params)?;
206
207        let num_sumcheck_variables = final_round_builder.num_sumcheck_variables();
208
209        // commit to any intermediate MLEs
210        let final_round_commitments =
211            final_round_builder.commit_intermediate_mles(min_row_num, setup);
212
213        let final_round_message = FinalRoundMessage {
214            subpolynomial_constraint_count: final_round_builder.num_sumcheck_subpolynomials(),
215            round_commitments: final_round_commitments,
216            bit_distributions: final_round_builder.bit_distributions().to_vec(),
217        };
218
219        // add the commitments, bit distributions and chi evaluation lengths to the proof
220        transcript.challenge_as_le();
221        transcript.extend_serialize_as_le(&final_round_message);
222
223        // construct the sumcheck polynomial
224        let num_random_scalars =
225            num_sumcheck_variables + final_round_message.subpolynomial_constraint_count;
226        let random_scalars: Vec<_> =
227            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
228                .take(num_random_scalars)
229                .collect();
230        let state = make_sumcheck_prover_state(
231            final_round_builder.sumcheck_subpolynomials(),
232            num_sumcheck_variables,
233            &SumcheckRandomScalars::new(&random_scalars, range_length, num_sumcheck_variables),
234        );
235        transcript.challenge_as_le();
236
237        // create the sumcheck proof -- this is the main part of proving a query
238        let mut evaluation_point = vec![Zero::zero(); state.num_vars];
239        let sumcheck_proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, state);
240
241        // evaluate the MLEs used in sumcheck except for the result columns
242        let mut evaluation_vec = vec![Zero::zero(); range_length];
243        compute_evaluation_vector(&mut evaluation_vec, &evaluation_point);
244        let first_round_pcs_proof_evaluations =
245            first_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
246        let column_ref_pcs_proof_evaluations: Vec<_> = total_col_refs
247            .iter()
248            .map(|col_ref| {
249                accessor
250                    .get_column(&col_ref.table_ref(), &col_ref.column_id())
251                    .inner_product(&evaluation_vec)
252            })
253            .collect();
254        let final_round_pcs_proof_evaluations =
255            final_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
256
257        // commit to the MLE evaluations
258        let pcs_proof_evaluations = QueryProofPCSProofEvaluations {
259            first_round: first_round_pcs_proof_evaluations,
260            column_ref: column_ref_pcs_proof_evaluations,
261            final_round: final_round_pcs_proof_evaluations,
262        };
263        transcript.extend_serialize_as_le(&pcs_proof_evaluations);
264
265        // fold together the pre result MLEs -- this will form the input to an inner product proof
266        // of their evaluations (fold in this context means create a random linear combination)
267        let random_scalars: Vec<_> =
268            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
269                .take(
270                    pcs_proof_evaluations.first_round.len()
271                        + pcs_proof_evaluations.column_ref.len()
272                        + pcs_proof_evaluations.final_round.len(),
273                )
274                .collect();
275
276        let column_ref_mles: Vec<_> = total_col_refs
277            .into_iter()
278            .map(|c| {
279                Box::new(accessor.get_column(&c.table_ref(), &c.column_id()))
280                    as Box<dyn MultilinearExtension<_>>
281            })
282            .collect();
283
284        let span = span!(Level::DEBUG, "QueryProof get folded_mle").entered();
285        let mut folded_mle = vec![Zero::zero(); range_length];
286        for (multiplier, evaluator) in random_scalars.iter().zip(
287            first_round_builder
288                .pcs_proof_mles()
289                .iter()
290                .chain(&column_ref_mles)
291                .chain(final_round_builder.pcs_proof_mles().iter()),
292        ) {
293            evaluator.mul_add(&mut folded_mle, multiplier);
294        }
295        span.exit();
296
297        // finally, form the inner product proof of the MLEs' evaluations
298        let evaluation_proof = CP::new(
299            &mut transcript,
300            &folded_mle,
301            &evaluation_point,
302            min_row_num as u64,
303            setup,
304        );
305
306        let proof = Self {
307            first_round_message,
308            final_round_message,
309            sumcheck_proof,
310            pcs_proof_evaluations,
311            evaluation_proof,
312        };
313
314        log::log_memory_usage("End");
315
316        Ok((proof, provable_result))
317    }
318
319    #[tracing::instrument(name = "QueryProof::verify", level = "debug", skip_all, err)]
320    /// Verify a `QueryProof`. Note: This does NOT transform the result!
321    pub fn verify(
322        self,
323        expr: &(impl ProofPlan + Serialize),
324        accessor: &impl CommitmentAccessor<CP::Commitment>,
325        result: OwnedTable<CP::Scalar>,
326        setup: &CP::VerifierPublicSetup<'_>,
327        params: &[LiteralValue],
328    ) -> QueryResult<CP::Scalar> {
329        log::log_memory_usage("Start");
330
331        let table_refs = expr.get_table_references();
332        let (min_row_num, _) = get_index_range(accessor, &table_refs);
333        let num_sumcheck_variables = cmp::max(log2_up(self.first_round_message.range_length), 1);
334        assert!(num_sumcheck_variables > 0);
335
336        // validate bit decompositions
337        for dist in &self.final_round_message.bit_distributions {
338            if !dist.is_valid() {
339                Err(ProofError::VerificationError {
340                    error: "invalid bit distributions",
341                })?;
342            } else if !dist.is_within_acceptable_range() {
343                Err(ProofError::VerificationError {
344                    error: "bit distribution outside of acceptable range",
345                })?;
346            }
347        }
348
349        let column_references = expr.get_column_references();
350
351        // construct a transcript for the proof
352        let mut transcript: Keccak256Transcript = Transcript::new();
353        transcript.extend_as_le([SETUP_HASH]);
354        transcript.challenge_as_le();
355        transcript.extend_serialize_as_le(expr);
356        transcript.challenge_as_le();
357        transcript.extend_serialize_as_le(&result);
358        transcript.challenge_as_le();
359
360        for table in expr.get_table_references() {
361            let length = accessor.get_length(&table);
362            transcript.extend_serialize_as_le(&[0, 0, 0, length]);
363        }
364        transcript.challenge_as_le();
365
366        for commitment in expr
367            .get_column_references()
368            .into_iter()
369            .map(|col| accessor.get_commitment(&col.table_ref(), &col.column_id()))
370        {
371            transcript.extend_serialize_as_le(&commitment);
372        }
373        transcript.challenge_as_le();
374
375        transcript.extend_serialize_as_le(&min_row_num);
376        transcript.challenge_as_le();
377
378        transcript.extend_serialize_as_le(&self.first_round_message);
379
380        // These are the challenges that will be consumed by the proof
381        // Specifically, these are the challenges that the verifier sends to
382        // the prover after the prover sends the result, but before the prover
383        // send commitments to the intermediate witness columns.
384        // Note: the last challenge in the vec is the first one that is consumed.
385        let post_result_challenges =
386            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
387                .take(self.first_round_message.post_result_challenge_count)
388                .collect();
389
390        // add the commitments and bit distributions to the proof
391        transcript.challenge_as_le();
392        transcript.extend_serialize_as_le(&self.final_round_message);
393
394        // draw the random scalars for sumcheck
395        let num_random_scalars =
396            num_sumcheck_variables + self.final_round_message.subpolynomial_constraint_count;
397        let random_scalars: Vec<_> =
398            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
399                .take(num_random_scalars)
400                .collect();
401        let sumcheck_random_scalars = SumcheckRandomScalars::new(
402            &random_scalars,
403            self.first_round_message.range_length,
404            num_sumcheck_variables,
405        );
406        transcript.challenge_as_le();
407
408        // verify sumcheck up to the evaluation check
409        let subclaim = self.sumcheck_proof.verify_without_evaluation(
410            &mut transcript,
411            num_sumcheck_variables,
412            &Zero::zero(),
413        )?;
414
415        // commit to mle evaluations
416        transcript.extend_serialize_as_le(&self.pcs_proof_evaluations);
417
418        // draw the random scalars for the evaluation proof
419        // (i.e. the folding/random linear combination of the pcs_proof_mles)
420        let evaluation_random_scalars: Vec<_> =
421            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
422                .take(
423                    self.pcs_proof_evaluations.first_round.len()
424                        + self.pcs_proof_evaluations.column_ref.len()
425                        + self.pcs_proof_evaluations.final_round.len(),
426                )
427                .collect();
428
429        // Always prepend input lengths to the chi evaluation lengths
430        let table_length_map = table_refs
431            .into_iter()
432            .map(|table_ref| {
433                let len = accessor.get_length(&table_ref);
434                (table_ref, len)
435            })
436            .collect::<IndexMap<TableRef, usize>>();
437
438        let chi_evaluation_lengths = table_length_map
439            .values()
440            .chain(self.first_round_message.chi_evaluation_lengths.iter())
441            .copied();
442
443        // pass over the provable AST to fill in the verification builder
444        let sumcheck_evaluations = SumcheckMleEvaluations::new(
445            self.first_round_message.range_length,
446            chi_evaluation_lengths,
447            self.first_round_message.rho_evaluation_lengths.clone(),
448            &subclaim.evaluation_point,
449            &sumcheck_random_scalars,
450            &self.pcs_proof_evaluations.first_round,
451            &self.pcs_proof_evaluations.final_round,
452        );
453        let chi_eval_map: IndexMap<TableRef, CP::Scalar> = table_length_map
454            .into_iter()
455            .map(|(table_ref, length)| (table_ref, sumcheck_evaluations.chi_evaluations[&length]))
456            .collect();
457        let mut builder = VerificationBuilderImpl::new(
458            sumcheck_evaluations,
459            &self.final_round_message.bit_distributions,
460            sumcheck_random_scalars.subpolynomial_multipliers,
461            post_result_challenges,
462            self.first_round_message.chi_evaluation_lengths.clone(),
463            self.first_round_message.rho_evaluation_lengths.clone(),
464            subclaim.max_multiplicands,
465        );
466
467        let pcs_proof_commitments: Vec<_> = self
468            .first_round_message
469            .round_commitments
470            .iter()
471            .cloned()
472            .chain(
473                column_references
474                    .iter()
475                    .map(|col| accessor.get_commitment(&col.table_ref(), &col.column_id())),
476            )
477            .chain(self.final_round_message.round_commitments.iter().cloned())
478            .collect();
479        let evaluation_accessor: IndexMap<_, _> = column_references
480            .into_iter()
481            .zip(self.pcs_proof_evaluations.column_ref.iter().copied())
482            .chunk_by(|(r, _)| r.table_ref())
483            .into_iter()
484            .map(|(tr, g)| {
485                let im: IndexMap<_, _> = g.map(|(cr, eval)| (cr.column_id(), eval)).collect();
486                (tr, im)
487            })
488            .collect();
489
490        let verifier_evaluations = expr.verifier_evaluate(
491            &mut builder,
492            &evaluation_accessor,
493            Some(&result),
494            &chi_eval_map,
495            params,
496        )?;
497        // compute the evaluation of the result MLEs
498        let result_evaluations = result.mle_evaluations(&subclaim.evaluation_point);
499        // check the evaluation of the result MLEs
500        if verifier_evaluations.column_evals() != result_evaluations {
501            Err(ProofError::VerificationError {
502                error: "result evaluation check failed",
503            })?;
504        }
505
506        // perform the evaluation check of the sumcheck polynomial
507        if builder.sumcheck_evaluation() != subclaim.expected_evaluation {
508            Err(ProofError::VerificationError {
509                error: "sumcheck evaluation check failed",
510            })?;
511        }
512
513        let pcs_proof_evaluations: Vec<_> = self
514            .pcs_proof_evaluations
515            .first_round
516            .iter()
517            .chain(self.pcs_proof_evaluations.column_ref.iter())
518            .chain(self.pcs_proof_evaluations.final_round.iter())
519            .copied()
520            .collect();
521
522        // finally, check the MLE evaluations with the inner product proof
523        self.evaluation_proof
524            .verify_batched_proof(
525                &mut transcript,
526                &pcs_proof_commitments,
527                &evaluation_random_scalars,
528                &pcs_proof_evaluations,
529                &subclaim.evaluation_point,
530                min_row_num as u64,
531                self.first_round_message.range_length,
532                setup,
533            )
534            .map_err(|_e| ProofError::VerificationError {
535                error: "Inner product proof of MLE evaluations failed",
536            })?;
537
538        let verification_hash = transcript.challenge_as_le();
539
540        log::log_memory_usage("End");
541
542        Ok(QueryData {
543            table: result,
544            verification_hash,
545        })
546    }
547}