proof_of_sql/sql/proof/
query_proof.rs

1use super::{
2    make_sumcheck_state::make_sumcheck_prover_state, FinalRoundBuilder, FirstRoundBuilder,
3    ProofPlan, QueryData, QueryResult, SumcheckMleEvaluations, SumcheckRandomScalars,
4    VerificationBuilderImpl,
5};
6use crate::{
7    base::{
8        bit::BitDistribution,
9        commitment::{Commitment, CommitmentEvaluationProof, CommittableColumn},
10        database::{
11            ColumnRef, CommitmentAccessor, DataAccessor, LiteralValue, MetadataAccessor,
12            OwnedTable, Table, TableRef,
13        },
14        map::{IndexMap, IndexSet},
15        math::log2_up,
16        polynomial::{compute_evaluation_vector, MultilinearExtension},
17        proof::{Keccak256Transcript, PlaceholderResult, ProofError, Transcript},
18    },
19    proof_primitive::sumcheck::SumcheckProof,
20    utils::log,
21};
22use alloc::{boxed::Box, vec, vec::Vec};
23use bumpalo::Bump;
24use core::cmp;
25use itertools::Itertools;
26use num_traits::Zero;
27use serde::{Deserialize, Serialize};
28use sqlparser::ast::Ident;
29use tracing::{span, Level};
30
/// Fixed 32-byte value that is the first thing absorbed into the transcript,
/// both when a proof is created (`QueryProof::new`) and when it is verified
/// (`QueryProof::verify`).
const SETUP_HASH: [u8; 32] = [
    0xe8, 0x84, 0x0d, 0x8a, 0x41, 0xce, 0x9d, 0x4e, 0x14, 0xe7, 0xba, 0x0e, 0x1b, 0x02, 0x32, 0x24,
    0x75, 0x13, 0x61, 0x57, 0x73, 0x78, 0x29, 0x1f, 0xcd, 0x3f, 0x0f, 0x05, 0xf0, 0xf7, 0xe8, 0x75,
]; // TODO: make this different for each setup
35
36/// Return the row number range of tables referenced in the Query
37///
38/// Basically we are looking for the smallest offset and the largest offset + length
39/// so that we have an index range of the table rows that the query is referencing.
40fn get_index_range<'a>(
41    accessor: &dyn MetadataAccessor,
42    table_refs: impl IntoIterator<Item = &'a TableRef>,
43) -> (usize, usize) {
44    table_refs
45        .into_iter()
46        .map(|table_ref| {
47            let length = accessor.get_length(table_ref);
48            let offset = accessor.get_offset(table_ref);
49            (offset, offset + length)
50        })
51        .reduce(|(min_start, max_end), (start, end)| (min_start.min(start), max_end.max(end)))
52        // Only applies to `EmptyExec` where there are no tables
53        .unwrap_or((0, 1))
54}
55
/// Data the prover produces in the first round.
///
/// The whole message is serialized into the transcript, both when the proof
/// is created and when it is verified, before the post-result challenges are
/// drawn.
#[derive(Clone, Serialize, Deserialize)]
pub struct FirstRoundMessage<C> {
    /// Length of the range of generators we use
    pub range_length: usize,
    /// Number of post-result challenges that are drawn from the transcript
    /// right after this message has been absorbed
    pub post_result_challenge_count: usize,
    /// Chi evaluation lengths
    ///
    /// The verifier prepends the lengths of the referenced tables to this list
    /// before computing the chi evaluations.
    pub chi_evaluation_lengths: Vec<usize>,
    /// Rho evaluation lengths
    pub rho_evaluation_lengths: Vec<usize>,
    /// First Round Commitments
    pub round_commitments: Vec<C>,
}
68
/// Data the prover produces in the final round.
///
/// The whole message is serialized into the transcript, both when the proof
/// is created and when it is verified, before the sumcheck random scalars are
/// drawn.
#[derive(Clone, Serialize, Deserialize)]
pub struct FinalRoundMessage<C> {
    /// Number of sumcheck subpolynomials reported by the final round builder.
    ///
    /// Added to the number of sumcheck variables, it determines how many
    /// random scalars are drawn for sumcheck.
    pub subpolynomial_constraint_count: usize,
    /// Final Round Commitments
    pub round_commitments: Vec<C>,
    /// Bit distributions
    ///
    /// Each one is validated by the verifier before anything else is checked.
    pub bit_distributions: Vec<BitDistribution>,
}
/// MLE evaluations that accompany a proof.
///
/// The three groups are serialized into the transcript together. For the
/// batched evaluation proof they are concatenated in the order `first_round`,
/// `column_ref`, `final_round`, matching the order of the corresponding
/// commitments.
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProofPCSProofEvaluations<S> {
    /// MLEs used in first round sumcheck except for the result columns
    pub first_round: Vec<S>,
    /// evaluations of the columns referenced in the query
    ///
    /// One entry per column reference of the plan, in the plan's order.
    pub column_ref: Vec<S>,
    /// MLEs used in final round sumcheck except for the result columns
    pub final_round: Vec<S>,
}
86
/// The proof for a query.
///
/// Note: Because the class is deserialized from untrusted data, it
/// cannot maintain any invariant on its data members; hence, they are
/// all exposed to the parent module (`pub(super)`) so as to allow for easy
/// manipulation for testing.
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProof<CP: CommitmentEvaluationProof> {
    /// Message produced in the first round (lengths, challenge count and
    /// first round commitments)
    pub(super) first_round_message: FirstRoundMessage<CP::Commitment>,
    /// Message produced in the final round (constraint count, final round
    /// commitments and bit distributions)
    pub(super) final_round_message: FinalRoundMessage<CP::Commitment>,
    /// Sumcheck Proof
    pub(super) sumcheck_proof: SumcheckProof<CP::Scalar>,
    /// Evaluations of the MLEs that are checked by `evaluation_proof`
    pub(super) pcs_proof_evaluations: QueryProofPCSProofEvaluations<CP::Scalar>,
    /// Inner product proof of the MLEs' evaluations
    pub(super) evaluation_proof: CP,
}
102
103impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
104    /// Create a new `QueryProof`.
105    #[tracing::instrument(name = "QueryProof::new", level = "debug", skip_all)]
106    #[expect(clippy::too_many_lines)]
107    pub fn new(
108        expr: &(impl ProofPlan + Serialize),
109        accessor: &impl DataAccessor<CP::Scalar>,
110        setup: &CP::ProverPublicSetup<'_>,
111        params: &[LiteralValue],
112    ) -> PlaceholderResult<(Self, OwnedTable<CP::Scalar>)> {
113        log::log_memory_usage("Start");
114
115        let (min_row_num, max_row_num) = get_index_range(accessor, &expr.get_table_references());
116        let initial_range_length = (max_row_num - min_row_num).max(1);
117        let alloc = Bump::new();
118
119        let total_col_refs = expr.get_column_references();
120        let table_map: IndexMap<TableRef, Table<CP::Scalar>> = expr
121            .get_table_references()
122            .into_iter()
123            .map(|table_ref| {
124                let idents: IndexSet<Ident> = total_col_refs
125                    .iter()
126                    .filter(|col_ref| col_ref.table_ref() == table_ref)
127                    .map(ColumnRef::column_id)
128                    .collect();
129                (table_ref.clone(), accessor.get_table(&table_ref, &idents))
130            })
131            .collect();
132
133        // Prover First Round: Evaluate the query && get the right number of post result challenges
134        let mut first_round_builder = FirstRoundBuilder::new(initial_range_length);
135        let query_result =
136            expr.first_round_evaluate(&mut first_round_builder, &alloc, &table_map, params)?;
137        let owned_table_result = OwnedTable::from(&query_result);
138        let provable_result = query_result.into();
139        let chi_evaluation_lengths = first_round_builder.chi_evaluation_lengths();
140        let rho_evaluation_lengths = first_round_builder.rho_evaluation_lengths();
141
142        let range_length = first_round_builder.range_length();
143        let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
144        assert!(num_sumcheck_variables > 0);
145        let post_result_challenge_count = first_round_builder.num_post_result_challenges();
146
147        // commit to any intermediate MLEs
148        let first_round_commitments =
149            first_round_builder.commit_intermediate_mles(min_row_num, setup);
150
151        // construct a transcript for the proof
152        let mut transcript: Keccak256Transcript = Transcript::new();
153        transcript.extend_as_le([SETUP_HASH]);
154        transcript.challenge_as_le();
155        transcript.extend_serialize_as_le(expr);
156        transcript.challenge_as_le();
157        transcript.extend_serialize_as_le(&owned_table_result);
158        transcript.challenge_as_le();
159
160        for table in expr.get_table_references() {
161            let length = accessor.get_length(&table);
162            transcript.extend_serialize_as_le(&[0, 0, 0, length]);
163        }
164        transcript.challenge_as_le();
165
166        for commitment in CP::Commitment::compute_commitments(
167            &expr
168                .get_column_references()
169                .into_iter()
170                .map(|col| {
171                    CommittableColumn::from(accessor.get_column(&col.table_ref(), &col.column_id()))
172                })
173                .collect_vec(),
174            min_row_num,
175            setup,
176        ) {
177            transcript.extend_serialize_as_le(&commitment);
178        }
179        transcript.challenge_as_le();
180
181        transcript.extend_serialize_as_le(&min_row_num);
182        transcript.challenge_as_le();
183
184        let first_round_message = FirstRoundMessage {
185            range_length,
186            chi_evaluation_lengths: chi_evaluation_lengths.to_vec(),
187            rho_evaluation_lengths: rho_evaluation_lengths.to_vec(),
188            post_result_challenge_count,
189            round_commitments: first_round_commitments,
190        };
191        transcript.extend_serialize_as_le(&first_round_message);
192
193        // These are the challenges that will be consumed by the proof
194        // Specifically, these are the challenges that the verifier sends to
195        // the prover after the prover sends the result, but before the prover
196        // send commitments to the intermediate witness columns.
197        // Note: the last challenge in the vec is the first one that is consumed.
198        let post_result_challenges =
199            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
200                .take(post_result_challenge_count)
201                .collect();
202
203        let mut final_round_builder =
204            FinalRoundBuilder::new(num_sumcheck_variables, post_result_challenges);
205
206        expr.final_round_evaluate(&mut final_round_builder, &alloc, &table_map, params)?;
207
208        let num_sumcheck_variables = final_round_builder.num_sumcheck_variables();
209
210        // commit to any intermediate MLEs
211        let final_round_commitments =
212            final_round_builder.commit_intermediate_mles(min_row_num, setup);
213
214        let final_round_message = FinalRoundMessage {
215            subpolynomial_constraint_count: final_round_builder.num_sumcheck_subpolynomials(),
216            round_commitments: final_round_commitments,
217            bit_distributions: final_round_builder.bit_distributions().to_vec(),
218        };
219
220        // add the commitments, bit distributions and chi evaluation lengths to the proof
221        transcript.challenge_as_le();
222        transcript.extend_serialize_as_le(&final_round_message);
223
224        // construct the sumcheck polynomial
225        let num_random_scalars =
226            num_sumcheck_variables + final_round_message.subpolynomial_constraint_count;
227        let random_scalars: Vec<_> =
228            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
229                .take(num_random_scalars)
230                .collect();
231        let state = make_sumcheck_prover_state(
232            final_round_builder.sumcheck_subpolynomials(),
233            num_sumcheck_variables,
234            &SumcheckRandomScalars::new(&random_scalars, range_length, num_sumcheck_variables),
235        );
236        transcript.challenge_as_le();
237
238        // create the sumcheck proof -- this is the main part of proving a query
239        let span = span!(Level::DEBUG, "Sumcheck with initialization").entered();
240        let mut evaluation_point = vec![Zero::zero(); state.num_vars];
241        let sumcheck_proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, state);
242        span.exit();
243
244        // evaluate the MLEs used in sumcheck except for the result columns
245        let span = span!(Level::DEBUG, "initialize evaluation_vec").entered();
246        let mut evaluation_vec = vec![Zero::zero(); range_length];
247        span.exit();
248        compute_evaluation_vector(&mut evaluation_vec, &evaluation_point);
249        let first_round_pcs_proof_evaluations =
250            first_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
251        let span = span!(Level::DEBUG, "initialize column_ref_pcs_proof_evaluations").entered();
252        let column_ref_pcs_proof_evaluations: Vec<_> = total_col_refs
253            .iter()
254            .map(|col_ref| {
255                accessor
256                    .get_column(&col_ref.table_ref(), &col_ref.column_id())
257                    .inner_product(&evaluation_vec)
258            })
259            .collect();
260        span.exit();
261        let final_round_pcs_proof_evaluations =
262            final_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
263
264        // commit to the MLE evaluations
265        let pcs_proof_evaluations = QueryProofPCSProofEvaluations {
266            first_round: first_round_pcs_proof_evaluations,
267            column_ref: column_ref_pcs_proof_evaluations,
268            final_round: final_round_pcs_proof_evaluations,
269        };
270        transcript.extend_serialize_as_le(&pcs_proof_evaluations);
271
272        // fold together the pre result MLEs -- this will form the input to an inner product proof
273        // of their evaluations (fold in this context means create a random linear combination)
274        let random_scalars: Vec<_> =
275            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
276                .take(
277                    pcs_proof_evaluations.first_round.len()
278                        + pcs_proof_evaluations.column_ref.len()
279                        + pcs_proof_evaluations.final_round.len(),
280                )
281                .collect();
282
283        let column_ref_mles: Vec<_> = total_col_refs
284            .into_iter()
285            .map(|c| {
286                Box::new(accessor.get_column(&c.table_ref(), &c.column_id()))
287                    as Box<dyn MultilinearExtension<_>>
288            })
289            .collect();
290
291        let span = span!(Level::DEBUG, "QueryProof get folded_mle").entered();
292        let mut folded_mle = vec![Zero::zero(); range_length];
293        for (multiplier, evaluator) in random_scalars.iter().zip(
294            first_round_builder
295                .pcs_proof_mles()
296                .iter()
297                .chain(&column_ref_mles)
298                .chain(final_round_builder.pcs_proof_mles().iter()),
299        ) {
300            evaluator.mul_add(&mut folded_mle, multiplier);
301        }
302        span.exit();
303
304        // finally, form the inner product proof of the MLEs' evaluations
305        let evaluation_proof = CP::new(
306            &mut transcript,
307            &folded_mle,
308            &evaluation_point,
309            min_row_num as u64,
310            setup,
311        );
312
313        let proof = Self {
314            first_round_message,
315            final_round_message,
316            sumcheck_proof,
317            pcs_proof_evaluations,
318            evaluation_proof,
319        };
320
321        log::log_memory_usage("End");
322
323        Ok((proof, provable_result))
324    }
325
326    #[tracing::instrument(name = "QueryProof::verify", level = "debug", skip_all, err)]
327    #[expect(clippy::too_many_lines)]
328    /// Verify a `QueryProof`. Note: This does NOT transform the result!
329    pub fn verify(
330        self,
331        expr: &(impl ProofPlan + Serialize),
332        accessor: &impl CommitmentAccessor<CP::Commitment>,
333        result: OwnedTable<CP::Scalar>,
334        setup: &CP::VerifierPublicSetup<'_>,
335        params: &[LiteralValue],
336    ) -> QueryResult<CP::Scalar> {
337        log::log_memory_usage("Start");
338
339        let table_refs = expr.get_table_references();
340        let (min_row_num, _) = get_index_range(accessor, &table_refs);
341        let num_sumcheck_variables = cmp::max(log2_up(self.first_round_message.range_length), 1);
342        assert!(num_sumcheck_variables > 0);
343
344        // validate bit decompositions
345        for dist in &self.final_round_message.bit_distributions {
346            if !dist.is_valid() {
347                Err(ProofError::VerificationError {
348                    error: "invalid bit distributions",
349                })?;
350            } else if !dist.is_within_acceptable_range() {
351                Err(ProofError::VerificationError {
352                    error: "bit distribution outside of acceptable range",
353                })?;
354            }
355        }
356
357        let column_references = expr.get_column_references();
358
359        // construct a transcript for the proof
360        let mut transcript: Keccak256Transcript = Transcript::new();
361        transcript.extend_as_le([SETUP_HASH]);
362        transcript.challenge_as_le();
363        transcript.extend_serialize_as_le(expr);
364        transcript.challenge_as_le();
365        transcript.extend_serialize_as_le(&result);
366        transcript.challenge_as_le();
367
368        for table in expr.get_table_references() {
369            let length = accessor.get_length(&table);
370            transcript.extend_serialize_as_le(&[0, 0, 0, length]);
371        }
372        transcript.challenge_as_le();
373
374        for commitment in expr
375            .get_column_references()
376            .into_iter()
377            .map(|col| accessor.get_commitment(&col.table_ref(), &col.column_id()))
378        {
379            transcript.extend_serialize_as_le(&commitment);
380        }
381        transcript.challenge_as_le();
382
383        transcript.extend_serialize_as_le(&min_row_num);
384        transcript.challenge_as_le();
385
386        transcript.extend_serialize_as_le(&self.first_round_message);
387
388        // These are the challenges that will be consumed by the proof
389        // Specifically, these are the challenges that the verifier sends to
390        // the prover after the prover sends the result, but before the prover
391        // send commitments to the intermediate witness columns.
392        // Note: the last challenge in the vec is the first one that is consumed.
393        let post_result_challenges =
394            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
395                .take(self.first_round_message.post_result_challenge_count)
396                .collect();
397
398        // add the commitments and bit distributions to the proof
399        transcript.challenge_as_le();
400        transcript.extend_serialize_as_le(&self.final_round_message);
401
402        // draw the random scalars for sumcheck
403        let num_random_scalars =
404            num_sumcheck_variables + self.final_round_message.subpolynomial_constraint_count;
405        let random_scalars: Vec<_> =
406            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
407                .take(num_random_scalars)
408                .collect();
409        let sumcheck_random_scalars = SumcheckRandomScalars::new(
410            &random_scalars,
411            self.first_round_message.range_length,
412            num_sumcheck_variables,
413        );
414        transcript.challenge_as_le();
415
416        // verify sumcheck up to the evaluation check
417        let subclaim = self.sumcheck_proof.verify_without_evaluation(
418            &mut transcript,
419            num_sumcheck_variables,
420            &Zero::zero(),
421        )?;
422
423        // commit to mle evaluations
424        transcript.extend_serialize_as_le(&self.pcs_proof_evaluations);
425
426        // draw the random scalars for the evaluation proof
427        // (i.e. the folding/random linear combination of the pcs_proof_mles)
428        let evaluation_random_scalars: Vec<_> =
429            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
430                .take(
431                    self.pcs_proof_evaluations.first_round.len()
432                        + self.pcs_proof_evaluations.column_ref.len()
433                        + self.pcs_proof_evaluations.final_round.len(),
434                )
435                .collect();
436
437        // Always prepend input lengths to the chi evaluation lengths
438        let table_length_map = table_refs
439            .into_iter()
440            .map(|table_ref| {
441                let len = accessor.get_length(&table_ref);
442                (table_ref, len)
443            })
444            .collect::<IndexMap<TableRef, usize>>();
445
446        let chi_evaluation_lengths = table_length_map
447            .values()
448            .chain(self.first_round_message.chi_evaluation_lengths.iter())
449            .copied();
450
451        // pass over the provable AST to fill in the verification builder
452        let sumcheck_evaluations = SumcheckMleEvaluations::new(
453            self.first_round_message.range_length,
454            chi_evaluation_lengths,
455            self.first_round_message.rho_evaluation_lengths.clone(),
456            &subclaim.evaluation_point,
457            &sumcheck_random_scalars,
458            &self.pcs_proof_evaluations.first_round,
459            &self.pcs_proof_evaluations.final_round,
460        );
461        let chi_eval_map: IndexMap<TableRef, (CP::Scalar, usize)> = table_length_map
462            .into_iter()
463            .map(|(table_ref, length)| {
464                (
465                    table_ref,
466                    (sumcheck_evaluations.chi_evaluations[&length], length),
467                )
468            })
469            .collect();
470        let mut builder = VerificationBuilderImpl::new(
471            sumcheck_evaluations,
472            &self.final_round_message.bit_distributions,
473            sumcheck_random_scalars.subpolynomial_multipliers,
474            post_result_challenges,
475            self.first_round_message.chi_evaluation_lengths.clone(),
476            self.first_round_message.rho_evaluation_lengths.clone(),
477            subclaim.max_multiplicands,
478        );
479
480        let pcs_proof_commitments: Vec<_> = self
481            .first_round_message
482            .round_commitments
483            .iter()
484            .cloned()
485            .chain(
486                column_references
487                    .iter()
488                    .map(|col| accessor.get_commitment(&col.table_ref(), &col.column_id())),
489            )
490            .chain(self.final_round_message.round_commitments.iter().cloned())
491            .collect();
492        let evaluation_accessor: IndexMap<_, _> = column_references
493            .into_iter()
494            .zip(self.pcs_proof_evaluations.column_ref.iter().copied())
495            .chunk_by(|(r, _)| r.table_ref())
496            .into_iter()
497            .map(|(tr, g)| {
498                let im: IndexMap<_, _> = g.map(|(cr, eval)| (cr.column_id(), eval)).collect();
499                (tr, im)
500            })
501            .collect();
502
503        let verifier_evaluations = expr.verifier_evaluate(
504            &mut builder,
505            &evaluation_accessor,
506            Some(&result),
507            &chi_eval_map,
508            params,
509        )?;
510        // compute the evaluation of the result MLEs
511        let result_evaluations = result.mle_evaluations(&subclaim.evaluation_point);
512        // check the evaluation of the result MLEs
513        if verifier_evaluations.column_evals() != result_evaluations {
514            Err(ProofError::VerificationError {
515                error: "result evaluation check failed",
516            })?;
517        }
518
519        // perform the evaluation check of the sumcheck polynomial
520        if builder.sumcheck_evaluation() != subclaim.expected_evaluation {
521            Err(ProofError::VerificationError {
522                error: "sumcheck evaluation check failed",
523            })?;
524        }
525
526        let pcs_proof_evaluations: Vec<_> = self
527            .pcs_proof_evaluations
528            .first_round
529            .iter()
530            .chain(self.pcs_proof_evaluations.column_ref.iter())
531            .chain(self.pcs_proof_evaluations.final_round.iter())
532            .copied()
533            .collect();
534
535        // finally, check the MLE evaluations with the inner product proof
536        self.evaluation_proof
537            .verify_batched_proof(
538                &mut transcript,
539                &pcs_proof_commitments,
540                &evaluation_random_scalars,
541                &pcs_proof_evaluations,
542                &subclaim.evaluation_point,
543                min_row_num as u64,
544                self.first_round_message.range_length,
545                setup,
546            )
547            .map_err(|_e| ProofError::VerificationError {
548                error: "Inner product proof of MLE evaluations failed",
549            })?;
550
551        let verification_hash = transcript.challenge_as_le();
552
553        log::log_memory_usage("End");
554
555        Ok(QueryData {
556            table: result,
557            verification_hash,
558        })
559    }
560}