// proof_of_sql/sql/proof/query_proof.rs

1use super::{
2    make_sumcheck_state::make_sumcheck_prover_state, FinalRoundBuilder, FirstRoundBuilder,
3    ProofPlan, QueryData, QueryResult, SumcheckMleEvaluations, SumcheckRandomScalars,
4    VerificationBuilderImpl,
5};
6use crate::{
7    base::{
8        bit::BitDistribution,
9        commitment::CommitmentEvaluationProof,
10        database::{
11            ColumnRef, CommitmentAccessor, DataAccessor, LiteralValue, MetadataAccessor,
12            OwnedTable, Table, TableRef,
13        },
14        map::{IndexMap, IndexSet},
15        math::log2_up,
16        polynomial::{compute_evaluation_vector, MultilinearExtension},
17        proof::{Keccak256Transcript, ProofError, Transcript},
18    },
19    proof_primitive::sumcheck::SumcheckProof,
20    utils::log,
21};
22use alloc::{boxed::Box, vec, vec::Vec};
23use bumpalo::Bump;
24use core::cmp;
25use num_traits::Zero;
26use serde::{Deserialize, Serialize};
27
28/// Return the row number range of tables referenced in the Query
29///
30/// Basically we are looking for the smallest offset and the largest offset + length
31/// so that we have an index range of the table rows that the query is referencing.
32fn get_index_range<'a>(
33    accessor: &dyn MetadataAccessor,
34    table_refs: impl IntoIterator<Item = &'a TableRef>,
35) -> (usize, usize) {
36    table_refs
37        .into_iter()
38        .map(|table_ref| {
39            let length = accessor.get_length(table_ref);
40            let offset = accessor.get_offset(table_ref);
41            (offset, offset + length)
42        })
43        .reduce(|(min_start, max_end), (start, end)| (min_start.min(start), max_end.max(end)))
44        // Only applies to `EmptyExec` where there are no tables
45        .unwrap_or((0, 1))
46}
47
48#[derive(Clone, Serialize, Deserialize)]
49pub struct FirstRoundMessage<C> {
50    /// Length of the range of generators we use
51    pub range_length: usize,
52    pub post_result_challenge_count: usize,
53    /// Chi evaluation lengths
54    pub chi_evaluation_lengths: Vec<usize>,
55    /// Rho evaluation lengths
56    pub rho_evaluation_lengths: Vec<usize>,
57    /// First Round Commitments
58    pub round_commitments: Vec<C>,
59}
60
61#[derive(Clone, Serialize, Deserialize)]
62pub struct FinalRoundMessage<C> {
63    pub subpolynomial_constraint_count: usize,
64    /// Final Round Commitments
65    pub round_commitments: Vec<C>,
66    /// Bit distributions
67    pub bit_distributions: Vec<BitDistribution>,
68}
69#[derive(Clone, Serialize, Deserialize)]
70pub struct QueryProofPCSProofEvaluations<S> {
71    /// MLEs used in first round sumcheck except for the result columns
72    pub first_round: Vec<S>,
73    /// evaluations of the columns referenced in the query
74    pub column_ref: Vec<S>,
75    /// MLEs used in final round sumcheck except for the result columns
76    pub final_round: Vec<S>,
77}
78
79/// The proof for a query.
80///
81/// Note: Because the class is deserialized from untrusted data, it
82/// cannot maintain any invariant on its data members; hence, they are
83/// all public so as to allow for easy manipulation for testing.
84#[derive(Clone, Serialize, Deserialize)]
85pub struct QueryProof<CP: CommitmentEvaluationProof> {
86    pub(super) first_round_message: FirstRoundMessage<CP::Commitment>,
87    pub(super) final_round_message: FinalRoundMessage<CP::Commitment>,
88    /// Sumcheck Proof
89    pub(super) sumcheck_proof: SumcheckProof<CP::Scalar>,
90    pub(super) pcs_proof_evaluations: QueryProofPCSProofEvaluations<CP::Scalar>,
91    /// Inner product proof of the MLEs' evaluations
92    pub(super) evaluation_proof: CP,
93}
94
95impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
96    /// Create a new `QueryProof`.
97    #[tracing::instrument(name = "QueryProof::new", level = "debug", skip_all)]
98    pub fn new(
99        expr: &(impl ProofPlan + Serialize),
100        accessor: &impl DataAccessor<CP::Scalar>,
101        setup: &CP::ProverPublicSetup<'_>,
102        params: &[LiteralValue],
103    ) -> (Self, OwnedTable<CP::Scalar>) {
104        log::log_memory_usage("Start");
105
106        let (min_row_num, max_row_num) = get_index_range(accessor, &expr.get_table_references());
107        let initial_range_length = (max_row_num - min_row_num).max(1);
108        let alloc = Bump::new();
109
110        let total_col_refs = expr.get_column_references();
111        let table_map: IndexMap<TableRef, Table<CP::Scalar>> = expr
112            .get_table_references()
113            .into_iter()
114            .map(|table_ref| {
115                let col_refs: IndexSet<ColumnRef> = total_col_refs
116                    .iter()
117                    .filter(|col_ref| col_ref.table_ref() == table_ref)
118                    .cloned()
119                    .collect();
120                (table_ref.clone(), accessor.get_table(table_ref, &col_refs))
121            })
122            .collect();
123
124        // Prover First Round: Evaluate the query && get the right number of post result challenges
125        let mut first_round_builder = FirstRoundBuilder::new(initial_range_length);
126        let query_result =
127            expr.first_round_evaluate(&mut first_round_builder, &alloc, &table_map, params);
128        let owned_table_result = OwnedTable::from(&query_result);
129        let provable_result = query_result.into();
130        let chi_evaluation_lengths = first_round_builder.chi_evaluation_lengths();
131        let rho_evaluation_lengths = first_round_builder.rho_evaluation_lengths();
132
133        let range_length = first_round_builder.range_length();
134        let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
135        assert!(num_sumcheck_variables > 0);
136        let post_result_challenge_count = first_round_builder.num_post_result_challenges();
137
138        // commit to any intermediate MLEs
139        let first_round_commitments =
140            first_round_builder.commit_intermediate_mles(min_row_num, setup);
141
142        // construct a transcript for the proof
143        let mut transcript: Keccak256Transcript = Transcript::new();
144        transcript.extend_serialize_as_le(expr);
145        transcript.extend_serialize_as_le(&owned_table_result);
146        transcript.extend_serialize_as_le(&min_row_num);
147        transcript.challenge_as_le();
148
149        let first_round_message = FirstRoundMessage {
150            range_length,
151            chi_evaluation_lengths: chi_evaluation_lengths.to_vec(),
152            rho_evaluation_lengths: rho_evaluation_lengths.to_vec(),
153            post_result_challenge_count,
154            round_commitments: first_round_commitments,
155        };
156        transcript.extend_serialize_as_le(&first_round_message);
157
158        // These are the challenges that will be consumed by the proof
159        // Specifically, these are the challenges that the verifier sends to
160        // the prover after the prover sends the result, but before the prover
161        // send commitments to the intermediate witness columns.
162        // Note: the last challenge in the vec is the first one that is consumed.
163        let post_result_challenges =
164            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
165                .take(post_result_challenge_count)
166                .collect();
167
168        let mut final_round_builder =
169            FinalRoundBuilder::new(num_sumcheck_variables, post_result_challenges);
170
171        expr.final_round_evaluate(&mut final_round_builder, &alloc, &table_map, params);
172
173        let num_sumcheck_variables = final_round_builder.num_sumcheck_variables();
174
175        // commit to any intermediate MLEs
176        let final_round_commitments =
177            final_round_builder.commit_intermediate_mles(min_row_num, setup);
178
179        let final_round_message = FinalRoundMessage {
180            subpolynomial_constraint_count: final_round_builder.num_sumcheck_subpolynomials(),
181            round_commitments: final_round_commitments,
182            bit_distributions: final_round_builder.bit_distributions().to_vec(),
183        };
184
185        // add the commitments, bit distributions and chi evaluation lengths to the proof
186        transcript.challenge_as_le();
187        transcript.extend_serialize_as_le(&final_round_message);
188
189        // construct the sumcheck polynomial
190        let num_random_scalars =
191            num_sumcheck_variables + final_round_message.subpolynomial_constraint_count;
192        let random_scalars: Vec<_> =
193            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
194                .take(num_random_scalars)
195                .collect();
196        let state = make_sumcheck_prover_state(
197            final_round_builder.sumcheck_subpolynomials(),
198            num_sumcheck_variables,
199            &SumcheckRandomScalars::new(&random_scalars, range_length, num_sumcheck_variables),
200        );
201        transcript.challenge_as_le();
202
203        // create the sumcheck proof -- this is the main part of proving a query
204        let mut evaluation_point = vec![Zero::zero(); state.num_vars];
205        let sumcheck_proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, state);
206
207        // evaluate the MLEs used in sumcheck except for the result columns
208        let mut evaluation_vec = vec![Zero::zero(); range_length];
209        compute_evaluation_vector(&mut evaluation_vec, &evaluation_point);
210        let first_round_pcs_proof_evaluations =
211            first_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
212        let column_ref_pcs_proof_evaluations: Vec<_> = total_col_refs
213            .iter()
214            .map(|col_ref| {
215                accessor
216                    .get_column(col_ref.clone())
217                    .inner_product(&evaluation_vec)
218            })
219            .collect();
220        let final_round_pcs_proof_evaluations =
221            final_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
222
223        // commit to the MLE evaluations
224        let pcs_proof_evaluations = QueryProofPCSProofEvaluations {
225            first_round: first_round_pcs_proof_evaluations,
226            column_ref: column_ref_pcs_proof_evaluations,
227            final_round: final_round_pcs_proof_evaluations,
228        };
229        transcript.extend_serialize_as_le(&pcs_proof_evaluations);
230
231        // fold together the pre result MLEs -- this will form the input to an inner product proof
232        // of their evaluations (fold in this context means create a random linear combination)
233        let random_scalars: Vec<_> =
234            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
235                .take(
236                    pcs_proof_evaluations.first_round.len()
237                        + pcs_proof_evaluations.column_ref.len()
238                        + pcs_proof_evaluations.final_round.len(),
239                )
240                .collect();
241
242        let mut folded_mle = vec![Zero::zero(); range_length];
243        let column_ref_mles: Vec<_> = total_col_refs
244            .into_iter()
245            .map(|c| Box::new(accessor.get_column(c)) as Box<dyn MultilinearExtension<_>>)
246            .collect();
247        for (multiplier, evaluator) in random_scalars.iter().zip(
248            first_round_builder
249                .pcs_proof_mles()
250                .iter()
251                .chain(&column_ref_mles)
252                .chain(final_round_builder.pcs_proof_mles().iter()),
253        ) {
254            evaluator.mul_add(&mut folded_mle, multiplier);
255        }
256
257        // finally, form the inner product proof of the MLEs' evaluations
258        let evaluation_proof = CP::new(
259            &mut transcript,
260            &folded_mle,
261            &evaluation_point,
262            min_row_num as u64,
263            setup,
264        );
265
266        let proof = Self {
267            first_round_message,
268            final_round_message,
269            sumcheck_proof,
270            pcs_proof_evaluations,
271            evaluation_proof,
272        };
273
274        log::log_memory_usage("End");
275
276        (proof, provable_result)
277    }
278
279    #[tracing::instrument(name = "QueryProof::verify", level = "debug", skip_all, err)]
280    /// Verify a `QueryProof`. Note: This does NOT transform the result!
281    pub fn verify(
282        self,
283        expr: &(impl ProofPlan + Serialize),
284        accessor: &impl CommitmentAccessor<CP::Commitment>,
285        result: OwnedTable<CP::Scalar>,
286        setup: &CP::VerifierPublicSetup<'_>,
287        params: &[LiteralValue],
288    ) -> QueryResult<CP::Scalar> {
289        log::log_memory_usage("Start");
290
291        let table_refs = expr.get_table_references();
292        let (min_row_num, _) = get_index_range(accessor, &table_refs);
293        let num_sumcheck_variables = cmp::max(log2_up(self.first_round_message.range_length), 1);
294        assert!(num_sumcheck_variables > 0);
295
296        // validate bit decompositions
297        for dist in &self.final_round_message.bit_distributions {
298            if !dist.is_valid() {
299                Err(ProofError::VerificationError {
300                    error: "invalid bit distributions",
301                })?;
302            } else if !dist.is_within_acceptable_range() {
303                Err(ProofError::VerificationError {
304                    error: "bit distribution outside of acceptable range",
305                })?;
306            }
307        }
308
309        let column_references = expr.get_column_references();
310
311        // construct a transcript for the proof
312        let mut transcript: Keccak256Transcript = Transcript::new();
313        transcript.extend_serialize_as_le(expr);
314        transcript.extend_serialize_as_le(&result);
315        transcript.extend_serialize_as_le(&min_row_num);
316        transcript.challenge_as_le();
317
318        transcript.extend_serialize_as_le(&self.first_round_message);
319
320        // These are the challenges that will be consumed by the proof
321        // Specifically, these are the challenges that the verifier sends to
322        // the prover after the prover sends the result, but before the prover
323        // send commitments to the intermediate witness columns.
324        // Note: the last challenge in the vec is the first one that is consumed.
325        let post_result_challenges =
326            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
327                .take(self.first_round_message.post_result_challenge_count)
328                .collect();
329
330        // add the commitments and bit distributions to the proof
331        transcript.challenge_as_le();
332        transcript.extend_serialize_as_le(&self.final_round_message);
333
334        // draw the random scalars for sumcheck
335        let num_random_scalars =
336            num_sumcheck_variables + self.final_round_message.subpolynomial_constraint_count;
337        let random_scalars: Vec<_> =
338            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
339                .take(num_random_scalars)
340                .collect();
341        let sumcheck_random_scalars = SumcheckRandomScalars::new(
342            &random_scalars,
343            self.first_round_message.range_length,
344            num_sumcheck_variables,
345        );
346        transcript.challenge_as_le();
347
348        // verify sumcheck up to the evaluation check
349        let subclaim = self.sumcheck_proof.verify_without_evaluation(
350            &mut transcript,
351            num_sumcheck_variables,
352            &Zero::zero(),
353        )?;
354
355        // commit to mle evaluations
356        transcript.extend_serialize_as_le(&self.pcs_proof_evaluations);
357
358        // draw the random scalars for the evaluation proof
359        // (i.e. the folding/random linear combination of the pcs_proof_mles)
360        let evaluation_random_scalars: Vec<_> =
361            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
362                .take(
363                    self.pcs_proof_evaluations.first_round.len()
364                        + self.pcs_proof_evaluations.column_ref.len()
365                        + self.pcs_proof_evaluations.final_round.len(),
366                )
367                .collect();
368
369        // Always prepend input lengths to the chi evaluation lengths
370        let table_length_map = table_refs
371            .into_iter()
372            .map(|table_ref| {
373                let len = accessor.get_length(&table_ref);
374                (table_ref, len)
375            })
376            .collect::<IndexMap<TableRef, usize>>();
377
378        let chi_evaluation_lengths = table_length_map
379            .values()
380            .chain(self.first_round_message.chi_evaluation_lengths.iter())
381            .copied();
382
383        // pass over the provable AST to fill in the verification builder
384        let sumcheck_evaluations = SumcheckMleEvaluations::new(
385            self.first_round_message.range_length,
386            chi_evaluation_lengths,
387            self.first_round_message.rho_evaluation_lengths.clone(),
388            &subclaim.evaluation_point,
389            &sumcheck_random_scalars,
390            &self.pcs_proof_evaluations.first_round,
391            &self.pcs_proof_evaluations.final_round,
392        );
393        let chi_eval_map: IndexMap<TableRef, CP::Scalar> = table_length_map
394            .into_iter()
395            .map(|(table_ref, length)| (table_ref, sumcheck_evaluations.chi_evaluations[&length]))
396            .collect();
397        let mut builder = VerificationBuilderImpl::new(
398            sumcheck_evaluations,
399            &self.final_round_message.bit_distributions,
400            sumcheck_random_scalars.subpolynomial_multipliers,
401            post_result_challenges,
402            self.first_round_message.chi_evaluation_lengths.clone(),
403            self.first_round_message.rho_evaluation_lengths.clone(),
404            subclaim.max_multiplicands,
405        );
406
407        let pcs_proof_commitments: Vec<_> = self
408            .first_round_message
409            .round_commitments
410            .iter()
411            .cloned()
412            .chain(
413                column_references
414                    .iter()
415                    .map(|col| accessor.get_commitment(col.clone())),
416            )
417            .chain(self.final_round_message.round_commitments.iter().cloned())
418            .collect();
419        let evaluation_accessor: IndexMap<_, _> = column_references
420            .into_iter()
421            .zip(self.pcs_proof_evaluations.column_ref.iter().copied())
422            .collect();
423
424        let verifier_evaluations = expr.verifier_evaluate(
425            &mut builder,
426            &evaluation_accessor,
427            Some(&result),
428            &chi_eval_map,
429            params,
430        )?;
431        // compute the evaluation of the result MLEs
432        let result_evaluations = result.mle_evaluations(&subclaim.evaluation_point);
433        // check the evaluation of the result MLEs
434        if verifier_evaluations.column_evals() != result_evaluations {
435            Err(ProofError::VerificationError {
436                error: "result evaluation check failed",
437            })?;
438        }
439
440        // perform the evaluation check of the sumcheck polynomial
441        if builder.sumcheck_evaluation() != subclaim.expected_evaluation {
442            Err(ProofError::VerificationError {
443                error: "sumcheck evaluation check failed",
444            })?;
445        }
446
447        let pcs_proof_evaluations: Vec<_> = self
448            .pcs_proof_evaluations
449            .first_round
450            .iter()
451            .chain(self.pcs_proof_evaluations.column_ref.iter())
452            .chain(self.pcs_proof_evaluations.final_round.iter())
453            .copied()
454            .collect();
455
456        // finally, check the MLE evaluations with the inner product proof
457        self.evaluation_proof
458            .verify_batched_proof(
459                &mut transcript,
460                &pcs_proof_commitments,
461                &evaluation_random_scalars,
462                &pcs_proof_evaluations,
463                &subclaim.evaluation_point,
464                min_row_num as u64,
465                self.first_round_message.range_length,
466                setup,
467            )
468            .map_err(|_e| ProofError::VerificationError {
469                error: "Inner product proof of MLE evaluations failed",
470            })?;
471
472        let verification_hash = transcript.challenge_as_le();
473
474        log::log_memory_usage("End");
475
476        Ok(QueryData {
477            table: result,
478            verification_hash,
479        })
480    }
481}