proof_of_sql/sql/proof/
query_proof.rs

1use super::{
2    make_sumcheck_state::make_sumcheck_prover_state, FinalRoundBuilder, FirstRoundBuilder,
3    ProofPlan, QueryData, QueryResult, SumcheckMleEvaluations, SumcheckRandomScalars,
4    VerificationBuilderImpl,
5};
6use crate::{
7    base::{
8        bit::BitDistribution,
9        commitment::{Commitment, CommitmentEvaluationProof, CommittableColumn},
10        database::{
11            ColumnRef, CommitmentAccessor, DataAccessor, LiteralValue, MetadataAccessor,
12            OwnedTable, Table, TableRef,
13        },
14        map::{IndexMap, IndexSet},
15        math::log2_up,
16        polynomial::{compute_evaluation_vector, MultilinearExtension},
17        proof::{Keccak256Transcript, PlaceholderResult, ProofError, Transcript},
18    },
19    proof_primitive::sumcheck::SumcheckProof,
20    utils::log,
21};
22use alloc::{boxed::Box, vec, vec::Vec};
23use bumpalo::Bump;
24use core::cmp;
25use itertools::Itertools;
26use num_traits::Zero;
27use serde::{Deserialize, Serialize};
28
29/// Return the row number range of tables referenced in the Query
30///
31/// Basically we are looking for the smallest offset and the largest offset + length
32/// so that we have an index range of the table rows that the query is referencing.
33fn get_index_range<'a>(
34    accessor: &dyn MetadataAccessor,
35    table_refs: impl IntoIterator<Item = &'a TableRef>,
36) -> (usize, usize) {
37    table_refs
38        .into_iter()
39        .map(|table_ref| {
40            let length = accessor.get_length(table_ref);
41            let offset = accessor.get_offset(table_ref);
42            (offset, offset + length)
43        })
44        .reduce(|(min_start, max_end), (start, end)| (min_start.min(start), max_end.max(end)))
45        // Only applies to `EmptyExec` where there are no tables
46        .unwrap_or((0, 1))
47}
48
/// Message produced by the prover after the first-round evaluation of the plan.
///
/// Both prover and verifier absorb the whole message into the transcript
/// (via `extend_serialize_as_le`) before the post-result challenges are drawn.
#[derive(Clone, Serialize, Deserialize)]
pub struct FirstRoundMessage<C> {
    /// Length of the range of generators we use.
    ///
    /// The number of sumcheck variables is derived from this value as
    /// `max(log2_up(range_length), 1)`.
    pub range_length: usize,
    /// Number of post-result challenges drawn from the transcript once this
    /// message has been absorbed.
    pub post_result_challenge_count: usize,
    /// Chi evaluation lengths reported by the first-round builder.
    ///
    /// The verifier prepends the lengths of the input tables to this list.
    pub chi_evaluation_lengths: Vec<usize>,
    /// Rho evaluation lengths reported by the first-round builder.
    pub rho_evaluation_lengths: Vec<usize>,
    /// First-round commitments to the intermediate MLEs.
    pub round_commitments: Vec<C>,
}
61
/// Message produced by the prover after the final-round evaluation of the plan.
///
/// Both prover and verifier absorb the whole message into the transcript
/// (via `extend_serialize_as_le`) before the sumcheck random scalars are drawn.
#[derive(Clone, Serialize, Deserialize)]
pub struct FinalRoundMessage<C> {
    /// Number of sumcheck subpolynomials reported by the final-round builder.
    ///
    /// Added to the number of sumcheck variables to determine how many random
    /// scalars are drawn for sumcheck.
    pub subpolynomial_constraint_count: usize,
    /// Final-round commitments to the intermediate MLEs.
    pub round_commitments: Vec<C>,
    /// Bit distributions.
    ///
    /// The verifier rejects the proof if any entry fails `is_valid` or
    /// `is_within_acceptable_range`.
    pub bit_distributions: Vec<BitDistribution>,
}
/// Claimed MLE evaluations that are covered by the evaluation proof.
///
/// The three groups are always combined in the order `first_round`,
/// `column_ref`, `final_round`, both when the prover folds the MLEs and when
/// the verifier lines the evaluations up with their commitments.
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProofPCSProofEvaluations<S> {
    /// MLEs used in first round sumcheck except for the result columns
    pub first_round: Vec<S>,
    /// Evaluations of the columns referenced in the query, one per column
    /// reference of the plan and in the same order
    pub column_ref: Vec<S>,
    /// MLEs used in final round sumcheck except for the result columns
    pub final_round: Vec<S>,
}
79
/// The proof for a query.
///
/// Note: Because the struct is deserialized from untrusted data, it
/// cannot maintain any invariant on its data members; hence, they are
/// all visible to the parent module (`pub(super)`) so as to allow for
/// easy manipulation for testing.
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProof<CP: CommitmentEvaluationProof> {
    /// Message produced after the first-round evaluation
    pub(super) first_round_message: FirstRoundMessage<CP::Commitment>,
    /// Message produced after the final-round evaluation
    pub(super) final_round_message: FinalRoundMessage<CP::Commitment>,
    /// Sumcheck Proof
    pub(super) sumcheck_proof: SumcheckProof<CP::Scalar>,
    /// Claimed MLE evaluations at the sumcheck evaluation point
    pub(super) pcs_proof_evaluations: QueryProofPCSProofEvaluations<CP::Scalar>,
    /// Inner product proof of the MLEs' evaluations
    pub(super) evaluation_proof: CP,
}
95
96impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
97    /// Create a new `QueryProof`.
98    #[tracing::instrument(name = "QueryProof::new", level = "debug", skip_all)]
99    pub fn new(
100        expr: &(impl ProofPlan + Serialize),
101        accessor: &impl DataAccessor<CP::Scalar>,
102        setup: &CP::ProverPublicSetup<'_>,
103        params: &[LiteralValue],
104    ) -> PlaceholderResult<(Self, OwnedTable<CP::Scalar>)> {
105        log::log_memory_usage("Start");
106
107        let (min_row_num, max_row_num) = get_index_range(accessor, &expr.get_table_references());
108        let initial_range_length = (max_row_num - min_row_num).max(1);
109        let alloc = Bump::new();
110
111        let total_col_refs = expr.get_column_references();
112        let table_map: IndexMap<TableRef, Table<CP::Scalar>> = expr
113            .get_table_references()
114            .into_iter()
115            .map(|table_ref| {
116                let col_refs: IndexSet<ColumnRef> = total_col_refs
117                    .iter()
118                    .filter(|col_ref| col_ref.table_ref() == table_ref)
119                    .cloned()
120                    .collect();
121                (table_ref.clone(), accessor.get_table(table_ref, &col_refs))
122            })
123            .collect();
124
125        // Prover First Round: Evaluate the query && get the right number of post result challenges
126        let mut first_round_builder = FirstRoundBuilder::new(initial_range_length);
127        let query_result =
128            expr.first_round_evaluate(&mut first_round_builder, &alloc, &table_map, params)?;
129        let owned_table_result = OwnedTable::from(&query_result);
130        let provable_result = query_result.into();
131        let chi_evaluation_lengths = first_round_builder.chi_evaluation_lengths();
132        let rho_evaluation_lengths = first_round_builder.rho_evaluation_lengths();
133
134        let range_length = first_round_builder.range_length();
135        let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
136        assert!(num_sumcheck_variables > 0);
137        let post_result_challenge_count = first_round_builder.num_post_result_challenges();
138
139        // commit to any intermediate MLEs
140        let first_round_commitments =
141            first_round_builder.commit_intermediate_mles(min_row_num, setup);
142
143        // construct a transcript for the proof
144        let mut transcript: Keccak256Transcript = Transcript::new();
145        transcript.challenge_as_le();
146        transcript.extend_serialize_as_le(expr);
147        transcript.challenge_as_le();
148        transcript.extend_serialize_as_le(&owned_table_result);
149        transcript.challenge_as_le();
150
151        for table in expr.get_table_references() {
152            let length = accessor.get_length(&table);
153            transcript.extend_serialize_as_le(&[0, 0, 0, length]);
154        }
155        transcript.challenge_as_le();
156
157        for commitment in CP::Commitment::compute_commitments(
158            &expr
159                .get_column_references()
160                .into_iter()
161                .map(|col| CommittableColumn::from(accessor.get_column(col)))
162                .collect_vec(),
163            min_row_num,
164            setup,
165        ) {
166            transcript.extend_serialize_as_le(&commitment);
167        }
168        transcript.challenge_as_le();
169
170        transcript.extend_serialize_as_le(&min_row_num);
171        transcript.challenge_as_le();
172
173        let first_round_message = FirstRoundMessage {
174            range_length,
175            chi_evaluation_lengths: chi_evaluation_lengths.to_vec(),
176            rho_evaluation_lengths: rho_evaluation_lengths.to_vec(),
177            post_result_challenge_count,
178            round_commitments: first_round_commitments,
179        };
180        transcript.extend_serialize_as_le(&first_round_message);
181
182        // These are the challenges that will be consumed by the proof
183        // Specifically, these are the challenges that the verifier sends to
184        // the prover after the prover sends the result, but before the prover
185        // send commitments to the intermediate witness columns.
186        // Note: the last challenge in the vec is the first one that is consumed.
187        let post_result_challenges =
188            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
189                .take(post_result_challenge_count)
190                .collect();
191
192        let mut final_round_builder =
193            FinalRoundBuilder::new(num_sumcheck_variables, post_result_challenges);
194
195        expr.final_round_evaluate(&mut final_round_builder, &alloc, &table_map, params)?;
196
197        let num_sumcheck_variables = final_round_builder.num_sumcheck_variables();
198
199        // commit to any intermediate MLEs
200        let final_round_commitments =
201            final_round_builder.commit_intermediate_mles(min_row_num, setup);
202
203        let final_round_message = FinalRoundMessage {
204            subpolynomial_constraint_count: final_round_builder.num_sumcheck_subpolynomials(),
205            round_commitments: final_round_commitments,
206            bit_distributions: final_round_builder.bit_distributions().to_vec(),
207        };
208
209        // add the commitments, bit distributions and chi evaluation lengths to the proof
210        transcript.challenge_as_le();
211        transcript.extend_serialize_as_le(&final_round_message);
212
213        // construct the sumcheck polynomial
214        let num_random_scalars =
215            num_sumcheck_variables + final_round_message.subpolynomial_constraint_count;
216        let random_scalars: Vec<_> =
217            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
218                .take(num_random_scalars)
219                .collect();
220        let state = make_sumcheck_prover_state(
221            final_round_builder.sumcheck_subpolynomials(),
222            num_sumcheck_variables,
223            &SumcheckRandomScalars::new(&random_scalars, range_length, num_sumcheck_variables),
224        );
225        transcript.challenge_as_le();
226
227        // create the sumcheck proof -- this is the main part of proving a query
228        let mut evaluation_point = vec![Zero::zero(); state.num_vars];
229        let sumcheck_proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, state);
230
231        // evaluate the MLEs used in sumcheck except for the result columns
232        let mut evaluation_vec = vec![Zero::zero(); range_length];
233        compute_evaluation_vector(&mut evaluation_vec, &evaluation_point);
234        let first_round_pcs_proof_evaluations =
235            first_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
236        let column_ref_pcs_proof_evaluations: Vec<_> = total_col_refs
237            .iter()
238            .map(|col_ref| {
239                accessor
240                    .get_column(col_ref.clone())
241                    .inner_product(&evaluation_vec)
242            })
243            .collect();
244        let final_round_pcs_proof_evaluations =
245            final_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
246
247        // commit to the MLE evaluations
248        let pcs_proof_evaluations = QueryProofPCSProofEvaluations {
249            first_round: first_round_pcs_proof_evaluations,
250            column_ref: column_ref_pcs_proof_evaluations,
251            final_round: final_round_pcs_proof_evaluations,
252        };
253        transcript.extend_serialize_as_le(&pcs_proof_evaluations);
254
255        // fold together the pre result MLEs -- this will form the input to an inner product proof
256        // of their evaluations (fold in this context means create a random linear combination)
257        let random_scalars: Vec<_> =
258            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
259                .take(
260                    pcs_proof_evaluations.first_round.len()
261                        + pcs_proof_evaluations.column_ref.len()
262                        + pcs_proof_evaluations.final_round.len(),
263                )
264                .collect();
265
266        let mut folded_mle = vec![Zero::zero(); range_length];
267        let column_ref_mles: Vec<_> = total_col_refs
268            .into_iter()
269            .map(|c| Box::new(accessor.get_column(c)) as Box<dyn MultilinearExtension<_>>)
270            .collect();
271        for (multiplier, evaluator) in random_scalars.iter().zip(
272            first_round_builder
273                .pcs_proof_mles()
274                .iter()
275                .chain(&column_ref_mles)
276                .chain(final_round_builder.pcs_proof_mles().iter()),
277        ) {
278            evaluator.mul_add(&mut folded_mle, multiplier);
279        }
280
281        // finally, form the inner product proof of the MLEs' evaluations
282        let evaluation_proof = CP::new(
283            &mut transcript,
284            &folded_mle,
285            &evaluation_point,
286            min_row_num as u64,
287            setup,
288        );
289
290        let proof = Self {
291            first_round_message,
292            final_round_message,
293            sumcheck_proof,
294            pcs_proof_evaluations,
295            evaluation_proof,
296        };
297
298        log::log_memory_usage("End");
299
300        Ok((proof, provable_result))
301    }
302
303    #[tracing::instrument(name = "QueryProof::verify", level = "debug", skip_all, err)]
304    /// Verify a `QueryProof`. Note: This does NOT transform the result!
305    pub fn verify(
306        self,
307        expr: &(impl ProofPlan + Serialize),
308        accessor: &impl CommitmentAccessor<CP::Commitment>,
309        result: OwnedTable<CP::Scalar>,
310        setup: &CP::VerifierPublicSetup<'_>,
311        params: &[LiteralValue],
312    ) -> QueryResult<CP::Scalar> {
313        log::log_memory_usage("Start");
314
315        let table_refs = expr.get_table_references();
316        let (min_row_num, _) = get_index_range(accessor, &table_refs);
317        let num_sumcheck_variables = cmp::max(log2_up(self.first_round_message.range_length), 1);
318        assert!(num_sumcheck_variables > 0);
319
320        // validate bit decompositions
321        for dist in &self.final_round_message.bit_distributions {
322            if !dist.is_valid() {
323                Err(ProofError::VerificationError {
324                    error: "invalid bit distributions",
325                })?;
326            } else if !dist.is_within_acceptable_range() {
327                Err(ProofError::VerificationError {
328                    error: "bit distribution outside of acceptable range",
329                })?;
330            }
331        }
332
333        let column_references = expr.get_column_references();
334
335        // construct a transcript for the proof
336        let mut transcript: Keccak256Transcript = Transcript::new();
337        transcript.challenge_as_le();
338        transcript.extend_serialize_as_le(expr);
339        transcript.challenge_as_le();
340        transcript.extend_serialize_as_le(&result);
341        transcript.challenge_as_le();
342
343        for table in expr.get_table_references() {
344            let length = accessor.get_length(&table);
345            transcript.extend_serialize_as_le(&[0, 0, 0, length]);
346        }
347        transcript.challenge_as_le();
348
349        for commitment in expr
350            .get_column_references()
351            .into_iter()
352            .map(|col| accessor.get_commitment(col))
353        {
354            transcript.extend_serialize_as_le(&commitment);
355        }
356        transcript.challenge_as_le();
357
358        transcript.extend_serialize_as_le(&min_row_num);
359        transcript.challenge_as_le();
360
361        transcript.extend_serialize_as_le(&self.first_round_message);
362
363        // These are the challenges that will be consumed by the proof
364        // Specifically, these are the challenges that the verifier sends to
365        // the prover after the prover sends the result, but before the prover
366        // send commitments to the intermediate witness columns.
367        // Note: the last challenge in the vec is the first one that is consumed.
368        let post_result_challenges =
369            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
370                .take(self.first_round_message.post_result_challenge_count)
371                .collect();
372
373        // add the commitments and bit distributions to the proof
374        transcript.challenge_as_le();
375        transcript.extend_serialize_as_le(&self.final_round_message);
376
377        // draw the random scalars for sumcheck
378        let num_random_scalars =
379            num_sumcheck_variables + self.final_round_message.subpolynomial_constraint_count;
380        let random_scalars: Vec<_> =
381            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
382                .take(num_random_scalars)
383                .collect();
384        let sumcheck_random_scalars = SumcheckRandomScalars::new(
385            &random_scalars,
386            self.first_round_message.range_length,
387            num_sumcheck_variables,
388        );
389        transcript.challenge_as_le();
390
391        // verify sumcheck up to the evaluation check
392        let subclaim = self.sumcheck_proof.verify_without_evaluation(
393            &mut transcript,
394            num_sumcheck_variables,
395            &Zero::zero(),
396        )?;
397
398        // commit to mle evaluations
399        transcript.extend_serialize_as_le(&self.pcs_proof_evaluations);
400
401        // draw the random scalars for the evaluation proof
402        // (i.e. the folding/random linear combination of the pcs_proof_mles)
403        let evaluation_random_scalars: Vec<_> =
404            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
405                .take(
406                    self.pcs_proof_evaluations.first_round.len()
407                        + self.pcs_proof_evaluations.column_ref.len()
408                        + self.pcs_proof_evaluations.final_round.len(),
409                )
410                .collect();
411
412        // Always prepend input lengths to the chi evaluation lengths
413        let table_length_map = table_refs
414            .into_iter()
415            .map(|table_ref| {
416                let len = accessor.get_length(&table_ref);
417                (table_ref, len)
418            })
419            .collect::<IndexMap<TableRef, usize>>();
420
421        let chi_evaluation_lengths = table_length_map
422            .values()
423            .chain(self.first_round_message.chi_evaluation_lengths.iter())
424            .copied();
425
426        // pass over the provable AST to fill in the verification builder
427        let sumcheck_evaluations = SumcheckMleEvaluations::new(
428            self.first_round_message.range_length,
429            chi_evaluation_lengths,
430            self.first_round_message.rho_evaluation_lengths.clone(),
431            &subclaim.evaluation_point,
432            &sumcheck_random_scalars,
433            &self.pcs_proof_evaluations.first_round,
434            &self.pcs_proof_evaluations.final_round,
435        );
436        let chi_eval_map: IndexMap<TableRef, CP::Scalar> = table_length_map
437            .into_iter()
438            .map(|(table_ref, length)| (table_ref, sumcheck_evaluations.chi_evaluations[&length]))
439            .collect();
440        let mut builder = VerificationBuilderImpl::new(
441            sumcheck_evaluations,
442            &self.final_round_message.bit_distributions,
443            sumcheck_random_scalars.subpolynomial_multipliers,
444            post_result_challenges,
445            self.first_round_message.chi_evaluation_lengths.clone(),
446            self.first_round_message.rho_evaluation_lengths.clone(),
447            subclaim.max_multiplicands,
448        );
449
450        let pcs_proof_commitments: Vec<_> = self
451            .first_round_message
452            .round_commitments
453            .iter()
454            .cloned()
455            .chain(
456                column_references
457                    .iter()
458                    .map(|col| accessor.get_commitment(col.clone())),
459            )
460            .chain(self.final_round_message.round_commitments.iter().cloned())
461            .collect();
462        let evaluation_accessor: IndexMap<_, _> = column_references
463            .into_iter()
464            .zip(self.pcs_proof_evaluations.column_ref.iter().copied())
465            .collect();
466
467        let verifier_evaluations = expr.verifier_evaluate(
468            &mut builder,
469            &evaluation_accessor,
470            Some(&result),
471            &chi_eval_map,
472            params,
473        )?;
474        // compute the evaluation of the result MLEs
475        let result_evaluations = result.mle_evaluations(&subclaim.evaluation_point);
476        // check the evaluation of the result MLEs
477        if verifier_evaluations.column_evals() != result_evaluations {
478            Err(ProofError::VerificationError {
479                error: "result evaluation check failed",
480            })?;
481        }
482
483        // perform the evaluation check of the sumcheck polynomial
484        if builder.sumcheck_evaluation() != subclaim.expected_evaluation {
485            Err(ProofError::VerificationError {
486                error: "sumcheck evaluation check failed",
487            })?;
488        }
489
490        let pcs_proof_evaluations: Vec<_> = self
491            .pcs_proof_evaluations
492            .first_round
493            .iter()
494            .chain(self.pcs_proof_evaluations.column_ref.iter())
495            .chain(self.pcs_proof_evaluations.final_round.iter())
496            .copied()
497            .collect();
498
499        // finally, check the MLE evaluations with the inner product proof
500        self.evaluation_proof
501            .verify_batched_proof(
502                &mut transcript,
503                &pcs_proof_commitments,
504                &evaluation_random_scalars,
505                &pcs_proof_evaluations,
506                &subclaim.evaluation_point,
507                min_row_num as u64,
508                self.first_round_message.range_length,
509                setup,
510            )
511            .map_err(|_e| ProofError::VerificationError {
512                error: "Inner product proof of MLE evaluations failed",
513            })?;
514
515        let verification_hash = transcript.challenge_as_le();
516
517        log::log_memory_usage("End");
518
519        Ok(QueryData {
520            table: result,
521            verification_hash,
522        })
523    }
524}