// proof_of_sql/sql/proof/query_proof.rs

1use super::{
2    make_sumcheck_state::make_sumcheck_prover_state, FinalRoundBuilder, FirstRoundBuilder,
3    ProofPlan, QueryData, QueryResult, SumcheckMleEvaluations, SumcheckRandomScalars,
4    VerificationBuilderImpl,
5};
6use crate::{
7    base::{
8        bit::BitDistribution,
9        commitment::{Commitment, CommitmentEvaluationProof, CommittableColumn},
10        database::{
11            ColumnRef, CommitmentAccessor, DataAccessor, LiteralValue, MetadataAccessor,
12            OwnedTable, Table, TableRef,
13        },
14        map::{IndexMap, IndexSet},
15        math::log2_up,
16        polynomial::{compute_evaluation_vector, MultilinearExtension},
17        proof::{Keccak256Transcript, PlaceholderResult, ProofError, Transcript},
18    },
19    proof_primitive::sumcheck::SumcheckProof,
20    utils::log,
21};
22use alloc::{boxed::Box, vec, vec::Vec};
23use bumpalo::Bump;
24use core::cmp;
25use itertools::Itertools;
26use num_traits::Zero;
27use serde::{Deserialize, Serialize};
28use sqlparser::ast::Ident;
29
/// Fixed 32-byte tag that is the first thing appended to the proof transcript by both
/// `QueryProof::new` and `QueryProof::verify`.
///
/// The same value is currently used for every setup (see the TODO below).
const SETUP_HASH: [u8; 32] = [
    0xe8, 0x84, 0x0d, 0x8a, 0x41, 0xce, 0x9d, 0x4e, 0x14, 0xe7, 0xba, 0x0e, 0x1b, 0x02, 0x32, 0x24,
    0x75, 0x13, 0x61, 0x57, 0x73, 0x78, 0x29, 0x1f, 0xcd, 0x3f, 0x0f, 0x05, 0xf0, 0xf7, 0xe8, 0x75,
]; // TODO: make this different for each setup
34
35/// Return the row number range of tables referenced in the Query
36///
37/// Basically we are looking for the smallest offset and the largest offset + length
38/// so that we have an index range of the table rows that the query is referencing.
39fn get_index_range<'a>(
40    accessor: &dyn MetadataAccessor,
41    table_refs: impl IntoIterator<Item = &'a TableRef>,
42) -> (usize, usize) {
43    table_refs
44        .into_iter()
45        .map(|table_ref| {
46            let length = accessor.get_length(table_ref);
47            let offset = accessor.get_offset(table_ref);
48            (offset, offset + length)
49        })
50        .reduce(|(min_start, max_end), (start, end)| (min_start.min(start), max_end.max(end)))
51        // Only applies to `EmptyExec` where there are no tables
52        .unwrap_or((0, 1))
53}
54
55#[derive(Clone, Serialize, Deserialize)]
56pub struct FirstRoundMessage<C> {
57    /// Length of the range of generators we use
58    pub range_length: usize,
59    pub post_result_challenge_count: usize,
60    /// Chi evaluation lengths
61    pub chi_evaluation_lengths: Vec<usize>,
62    /// Rho evaluation lengths
63    pub rho_evaluation_lengths: Vec<usize>,
64    /// First Round Commitments
65    pub round_commitments: Vec<C>,
66}
67
68#[derive(Clone, Serialize, Deserialize)]
69pub struct FinalRoundMessage<C> {
70    pub subpolynomial_constraint_count: usize,
71    /// Final Round Commitments
72    pub round_commitments: Vec<C>,
73    /// Bit distributions
74    pub bit_distributions: Vec<BitDistribution>,
75}
76#[derive(Clone, Serialize, Deserialize)]
77pub struct QueryProofPCSProofEvaluations<S> {
78    /// MLEs used in first round sumcheck except for the result columns
79    pub first_round: Vec<S>,
80    /// evaluations of the columns referenced in the query
81    pub column_ref: Vec<S>,
82    /// MLEs used in final round sumcheck except for the result columns
83    pub final_round: Vec<S>,
84}
85
86/// The proof for a query.
87///
88/// Note: Because the class is deserialized from untrusted data, it
89/// cannot maintain any invariant on its data members; hence, they are
90/// all public so as to allow for easy manipulation for testing.
91#[derive(Clone, Serialize, Deserialize)]
92pub struct QueryProof<CP: CommitmentEvaluationProof> {
93    pub(super) first_round_message: FirstRoundMessage<CP::Commitment>,
94    pub(super) final_round_message: FinalRoundMessage<CP::Commitment>,
95    /// Sumcheck Proof
96    pub(super) sumcheck_proof: SumcheckProof<CP::Scalar>,
97    pub(super) pcs_proof_evaluations: QueryProofPCSProofEvaluations<CP::Scalar>,
98    /// Inner product proof of the MLEs' evaluations
99    pub(super) evaluation_proof: CP,
100}
101
102impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
103    /// Create a new `QueryProof`.
104    #[tracing::instrument(name = "QueryProof::new", level = "debug", skip_all)]
105    pub fn new(
106        expr: &(impl ProofPlan + Serialize),
107        accessor: &impl DataAccessor<CP::Scalar>,
108        setup: &CP::ProverPublicSetup<'_>,
109        params: &[LiteralValue],
110    ) -> PlaceholderResult<(Self, OwnedTable<CP::Scalar>)> {
111        log::log_memory_usage("Start");
112
113        let (min_row_num, max_row_num) = get_index_range(accessor, &expr.get_table_references());
114        let initial_range_length = (max_row_num - min_row_num).max(1);
115        let alloc = Bump::new();
116
117        let total_col_refs = expr.get_column_references();
118        let table_map: IndexMap<TableRef, Table<CP::Scalar>> = expr
119            .get_table_references()
120            .into_iter()
121            .map(|table_ref| {
122                let idents: IndexSet<Ident> = total_col_refs
123                    .iter()
124                    .filter(|col_ref| col_ref.table_ref() == table_ref)
125                    .map(ColumnRef::column_id)
126                    .collect();
127                (table_ref.clone(), accessor.get_table(&table_ref, &idents))
128            })
129            .collect();
130
131        // Prover First Round: Evaluate the query && get the right number of post result challenges
132        let mut first_round_builder = FirstRoundBuilder::new(initial_range_length);
133        let query_result =
134            expr.first_round_evaluate(&mut first_round_builder, &alloc, &table_map, params)?;
135        let owned_table_result = OwnedTable::from(&query_result);
136        let provable_result = query_result.into();
137        let chi_evaluation_lengths = first_round_builder.chi_evaluation_lengths();
138        let rho_evaluation_lengths = first_round_builder.rho_evaluation_lengths();
139
140        let range_length = first_round_builder.range_length();
141        let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
142        assert!(num_sumcheck_variables > 0);
143        let post_result_challenge_count = first_round_builder.num_post_result_challenges();
144
145        // commit to any intermediate MLEs
146        let first_round_commitments =
147            first_round_builder.commit_intermediate_mles(min_row_num, setup);
148
149        // construct a transcript for the proof
150        let mut transcript: Keccak256Transcript = Transcript::new();
151        transcript.extend_as_le([SETUP_HASH]);
152        transcript.challenge_as_le();
153        transcript.extend_serialize_as_le(expr);
154        transcript.challenge_as_le();
155        transcript.extend_serialize_as_le(&owned_table_result);
156        transcript.challenge_as_le();
157
158        for table in expr.get_table_references() {
159            let length = accessor.get_length(&table);
160            transcript.extend_serialize_as_le(&[0, 0, 0, length]);
161        }
162        transcript.challenge_as_le();
163
164        for commitment in CP::Commitment::compute_commitments(
165            &expr
166                .get_column_references()
167                .into_iter()
168                .map(|col| {
169                    CommittableColumn::from(accessor.get_column(&col.table_ref(), &col.column_id()))
170                })
171                .collect_vec(),
172            min_row_num,
173            setup,
174        ) {
175            transcript.extend_serialize_as_le(&commitment);
176        }
177        transcript.challenge_as_le();
178
179        transcript.extend_serialize_as_le(&min_row_num);
180        transcript.challenge_as_le();
181
182        let first_round_message = FirstRoundMessage {
183            range_length,
184            chi_evaluation_lengths: chi_evaluation_lengths.to_vec(),
185            rho_evaluation_lengths: rho_evaluation_lengths.to_vec(),
186            post_result_challenge_count,
187            round_commitments: first_round_commitments,
188        };
189        transcript.extend_serialize_as_le(&first_round_message);
190
191        // These are the challenges that will be consumed by the proof
192        // Specifically, these are the challenges that the verifier sends to
193        // the prover after the prover sends the result, but before the prover
194        // send commitments to the intermediate witness columns.
195        // Note: the last challenge in the vec is the first one that is consumed.
196        let post_result_challenges =
197            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
198                .take(post_result_challenge_count)
199                .collect();
200
201        let mut final_round_builder =
202            FinalRoundBuilder::new(num_sumcheck_variables, post_result_challenges);
203
204        expr.final_round_evaluate(&mut final_round_builder, &alloc, &table_map, params)?;
205
206        let num_sumcheck_variables = final_round_builder.num_sumcheck_variables();
207
208        // commit to any intermediate MLEs
209        let final_round_commitments =
210            final_round_builder.commit_intermediate_mles(min_row_num, setup);
211
212        let final_round_message = FinalRoundMessage {
213            subpolynomial_constraint_count: final_round_builder.num_sumcheck_subpolynomials(),
214            round_commitments: final_round_commitments,
215            bit_distributions: final_round_builder.bit_distributions().to_vec(),
216        };
217
218        // add the commitments, bit distributions and chi evaluation lengths to the proof
219        transcript.challenge_as_le();
220        transcript.extend_serialize_as_le(&final_round_message);
221
222        // construct the sumcheck polynomial
223        let num_random_scalars =
224            num_sumcheck_variables + final_round_message.subpolynomial_constraint_count;
225        let random_scalars: Vec<_> =
226            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
227                .take(num_random_scalars)
228                .collect();
229        let state = make_sumcheck_prover_state(
230            final_round_builder.sumcheck_subpolynomials(),
231            num_sumcheck_variables,
232            &SumcheckRandomScalars::new(&random_scalars, range_length, num_sumcheck_variables),
233        );
234        transcript.challenge_as_le();
235
236        // create the sumcheck proof -- this is the main part of proving a query
237        let mut evaluation_point = vec![Zero::zero(); state.num_vars];
238        let sumcheck_proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, state);
239
240        // evaluate the MLEs used in sumcheck except for the result columns
241        let mut evaluation_vec = vec![Zero::zero(); range_length];
242        compute_evaluation_vector(&mut evaluation_vec, &evaluation_point);
243        let first_round_pcs_proof_evaluations =
244            first_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
245        let column_ref_pcs_proof_evaluations: Vec<_> = total_col_refs
246            .iter()
247            .map(|col_ref| {
248                accessor
249                    .get_column(&col_ref.table_ref(), &col_ref.column_id())
250                    .inner_product(&evaluation_vec)
251            })
252            .collect();
253        let final_round_pcs_proof_evaluations =
254            final_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
255
256        // commit to the MLE evaluations
257        let pcs_proof_evaluations = QueryProofPCSProofEvaluations {
258            first_round: first_round_pcs_proof_evaluations,
259            column_ref: column_ref_pcs_proof_evaluations,
260            final_round: final_round_pcs_proof_evaluations,
261        };
262        transcript.extend_serialize_as_le(&pcs_proof_evaluations);
263
264        // fold together the pre result MLEs -- this will form the input to an inner product proof
265        // of their evaluations (fold in this context means create a random linear combination)
266        let random_scalars: Vec<_> =
267            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
268                .take(
269                    pcs_proof_evaluations.first_round.len()
270                        + pcs_proof_evaluations.column_ref.len()
271                        + pcs_proof_evaluations.final_round.len(),
272                )
273                .collect();
274
275        let mut folded_mle = vec![Zero::zero(); range_length];
276        let column_ref_mles: Vec<_> = total_col_refs
277            .into_iter()
278            .map(|c| {
279                Box::new(accessor.get_column(&c.table_ref(), &c.column_id()))
280                    as Box<dyn MultilinearExtension<_>>
281            })
282            .collect();
283        for (multiplier, evaluator) in random_scalars.iter().zip(
284            first_round_builder
285                .pcs_proof_mles()
286                .iter()
287                .chain(&column_ref_mles)
288                .chain(final_round_builder.pcs_proof_mles().iter()),
289        ) {
290            evaluator.mul_add(&mut folded_mle, multiplier);
291        }
292
293        // finally, form the inner product proof of the MLEs' evaluations
294        let evaluation_proof = CP::new(
295            &mut transcript,
296            &folded_mle,
297            &evaluation_point,
298            min_row_num as u64,
299            setup,
300        );
301
302        let proof = Self {
303            first_round_message,
304            final_round_message,
305            sumcheck_proof,
306            pcs_proof_evaluations,
307            evaluation_proof,
308        };
309
310        log::log_memory_usage("End");
311
312        Ok((proof, provable_result))
313    }
314
315    #[tracing::instrument(name = "QueryProof::verify", level = "debug", skip_all, err)]
316    /// Verify a `QueryProof`. Note: This does NOT transform the result!
317    pub fn verify(
318        self,
319        expr: &(impl ProofPlan + Serialize),
320        accessor: &impl CommitmentAccessor<CP::Commitment>,
321        result: OwnedTable<CP::Scalar>,
322        setup: &CP::VerifierPublicSetup<'_>,
323        params: &[LiteralValue],
324    ) -> QueryResult<CP::Scalar> {
325        log::log_memory_usage("Start");
326
327        let table_refs = expr.get_table_references();
328        let (min_row_num, _) = get_index_range(accessor, &table_refs);
329        let num_sumcheck_variables = cmp::max(log2_up(self.first_round_message.range_length), 1);
330        assert!(num_sumcheck_variables > 0);
331
332        // validate bit decompositions
333        for dist in &self.final_round_message.bit_distributions {
334            if !dist.is_valid() {
335                Err(ProofError::VerificationError {
336                    error: "invalid bit distributions",
337                })?;
338            } else if !dist.is_within_acceptable_range() {
339                Err(ProofError::VerificationError {
340                    error: "bit distribution outside of acceptable range",
341                })?;
342            }
343        }
344
345        let column_references = expr.get_column_references();
346
347        // construct a transcript for the proof
348        let mut transcript: Keccak256Transcript = Transcript::new();
349        transcript.extend_as_le([SETUP_HASH]);
350        transcript.challenge_as_le();
351        transcript.extend_serialize_as_le(expr);
352        transcript.challenge_as_le();
353        transcript.extend_serialize_as_le(&result);
354        transcript.challenge_as_le();
355
356        for table in expr.get_table_references() {
357            let length = accessor.get_length(&table);
358            transcript.extend_serialize_as_le(&[0, 0, 0, length]);
359        }
360        transcript.challenge_as_le();
361
362        for commitment in expr
363            .get_column_references()
364            .into_iter()
365            .map(|col| accessor.get_commitment(&col.table_ref(), &col.column_id()))
366        {
367            transcript.extend_serialize_as_le(&commitment);
368        }
369        transcript.challenge_as_le();
370
371        transcript.extend_serialize_as_le(&min_row_num);
372        transcript.challenge_as_le();
373
374        transcript.extend_serialize_as_le(&self.first_round_message);
375
376        // These are the challenges that will be consumed by the proof
377        // Specifically, these are the challenges that the verifier sends to
378        // the prover after the prover sends the result, but before the prover
379        // send commitments to the intermediate witness columns.
380        // Note: the last challenge in the vec is the first one that is consumed.
381        let post_result_challenges =
382            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
383                .take(self.first_round_message.post_result_challenge_count)
384                .collect();
385
386        // add the commitments and bit distributions to the proof
387        transcript.challenge_as_le();
388        transcript.extend_serialize_as_le(&self.final_round_message);
389
390        // draw the random scalars for sumcheck
391        let num_random_scalars =
392            num_sumcheck_variables + self.final_round_message.subpolynomial_constraint_count;
393        let random_scalars: Vec<_> =
394            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
395                .take(num_random_scalars)
396                .collect();
397        let sumcheck_random_scalars = SumcheckRandomScalars::new(
398            &random_scalars,
399            self.first_round_message.range_length,
400            num_sumcheck_variables,
401        );
402        transcript.challenge_as_le();
403
404        // verify sumcheck up to the evaluation check
405        let subclaim = self.sumcheck_proof.verify_without_evaluation(
406            &mut transcript,
407            num_sumcheck_variables,
408            &Zero::zero(),
409        )?;
410
411        // commit to mle evaluations
412        transcript.extend_serialize_as_le(&self.pcs_proof_evaluations);
413
414        // draw the random scalars for the evaluation proof
415        // (i.e. the folding/random linear combination of the pcs_proof_mles)
416        let evaluation_random_scalars: Vec<_> =
417            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
418                .take(
419                    self.pcs_proof_evaluations.first_round.len()
420                        + self.pcs_proof_evaluations.column_ref.len()
421                        + self.pcs_proof_evaluations.final_round.len(),
422                )
423                .collect();
424
425        // Always prepend input lengths to the chi evaluation lengths
426        let table_length_map = table_refs
427            .into_iter()
428            .map(|table_ref| {
429                let len = accessor.get_length(&table_ref);
430                (table_ref, len)
431            })
432            .collect::<IndexMap<TableRef, usize>>();
433
434        let chi_evaluation_lengths = table_length_map
435            .values()
436            .chain(self.first_round_message.chi_evaluation_lengths.iter())
437            .copied();
438
439        // pass over the provable AST to fill in the verification builder
440        let sumcheck_evaluations = SumcheckMleEvaluations::new(
441            self.first_round_message.range_length,
442            chi_evaluation_lengths,
443            self.first_round_message.rho_evaluation_lengths.clone(),
444            &subclaim.evaluation_point,
445            &sumcheck_random_scalars,
446            &self.pcs_proof_evaluations.first_round,
447            &self.pcs_proof_evaluations.final_round,
448        );
449        let chi_eval_map: IndexMap<TableRef, CP::Scalar> = table_length_map
450            .into_iter()
451            .map(|(table_ref, length)| (table_ref, sumcheck_evaluations.chi_evaluations[&length]))
452            .collect();
453        let mut builder = VerificationBuilderImpl::new(
454            sumcheck_evaluations,
455            &self.final_round_message.bit_distributions,
456            sumcheck_random_scalars.subpolynomial_multipliers,
457            post_result_challenges,
458            self.first_round_message.chi_evaluation_lengths.clone(),
459            self.first_round_message.rho_evaluation_lengths.clone(),
460            subclaim.max_multiplicands,
461        );
462
463        let pcs_proof_commitments: Vec<_> = self
464            .first_round_message
465            .round_commitments
466            .iter()
467            .cloned()
468            .chain(
469                column_references
470                    .iter()
471                    .map(|col| accessor.get_commitment(&col.table_ref(), &col.column_id())),
472            )
473            .chain(self.final_round_message.round_commitments.iter().cloned())
474            .collect();
475        let evaluation_accessor: IndexMap<_, _> = column_references
476            .into_iter()
477            .zip(self.pcs_proof_evaluations.column_ref.iter().copied())
478            .chunk_by(|(r, _)| r.table_ref())
479            .into_iter()
480            .map(|(tr, g)| {
481                let im: IndexMap<_, _> = g.map(|(cr, eval)| (cr.column_id(), eval)).collect();
482                (tr, im)
483            })
484            .collect();
485
486        let verifier_evaluations = expr.verifier_evaluate(
487            &mut builder,
488            &evaluation_accessor,
489            Some(&result),
490            &chi_eval_map,
491            params,
492        )?;
493        // compute the evaluation of the result MLEs
494        let result_evaluations = result.mle_evaluations(&subclaim.evaluation_point);
495        // check the evaluation of the result MLEs
496        if verifier_evaluations.column_evals() != result_evaluations {
497            Err(ProofError::VerificationError {
498                error: "result evaluation check failed",
499            })?;
500        }
501
502        // perform the evaluation check of the sumcheck polynomial
503        if builder.sumcheck_evaluation() != subclaim.expected_evaluation {
504            Err(ProofError::VerificationError {
505                error: "sumcheck evaluation check failed",
506            })?;
507        }
508
509        let pcs_proof_evaluations: Vec<_> = self
510            .pcs_proof_evaluations
511            .first_round
512            .iter()
513            .chain(self.pcs_proof_evaluations.column_ref.iter())
514            .chain(self.pcs_proof_evaluations.final_round.iter())
515            .copied()
516            .collect();
517
518        // finally, check the MLE evaluations with the inner product proof
519        self.evaluation_proof
520            .verify_batched_proof(
521                &mut transcript,
522                &pcs_proof_commitments,
523                &evaluation_random_scalars,
524                &pcs_proof_evaluations,
525                &subclaim.evaluation_point,
526                min_row_num as u64,
527                self.first_round_message.range_length,
528                setup,
529            )
530            .map_err(|_e| ProofError::VerificationError {
531                error: "Inner product proof of MLE evaluations failed",
532            })?;
533
534        let verification_hash = transcript.challenge_as_le();
535
536        log::log_memory_usage("End");
537
538        Ok(QueryData {
539            table: result,
540            verification_hash,
541        })
542    }
543}