proof_of_sql/sql/proof/
query_proof.rs

1use super::{
2    make_sumcheck_state::make_sumcheck_prover_state, FinalRoundBuilder, FirstRoundBuilder,
3    ProofPlan, QueryData, QueryResult, SumcheckMleEvaluations, SumcheckRandomScalars,
4    VerificationBuilderImpl,
5};
6use crate::{
7    base::{
8        bit::BitDistribution,
9        commitment::CommitmentEvaluationProof,
10        database::{
11            ColumnRef, CommitmentAccessor, DataAccessor, MetadataAccessor, OwnedTable, Table,
12            TableRef,
13        },
14        map::{IndexMap, IndexSet},
15        math::log2_up,
16        polynomial::{compute_evaluation_vector, MultilinearExtension},
17        proof::{Keccak256Transcript, ProofError, Transcript},
18    },
19    proof_primitive::sumcheck::SumcheckProof,
20    utils::log,
21};
22use alloc::{boxed::Box, vec, vec::Vec};
23use bumpalo::Bump;
24use core::cmp;
25use num_traits::Zero;
26use serde::{Deserialize, Serialize};
27
28/// Return the row number range of tables referenced in the Query
29///
30/// Basically we are looking for the smallest offset and the largest offset + length
31/// so that we have an index range of the table rows that the query is referencing.
32fn get_index_range<'a>(
33    accessor: &dyn MetadataAccessor,
34    table_refs: impl IntoIterator<Item = &'a TableRef>,
35) -> (usize, usize) {
36    table_refs
37        .into_iter()
38        .map(|table_ref| {
39            let length = accessor.get_length(table_ref);
40            let offset = accessor.get_offset(table_ref);
41            (offset, offset + length)
42        })
43        .reduce(|(min_start, max_end), (start, end)| (min_start.min(start), max_end.max(end)))
44        // Only applies to `EmptyExec` where there are no tables
45        .unwrap_or((0, 1))
46}
47
48#[derive(Clone, Serialize, Deserialize)]
49pub struct FirstRoundMessage<C> {
50    /// Length of the range of generators we use
51    pub range_length: usize,
52    pub post_result_challenge_count: usize,
53    /// Chi evaluation lengths
54    pub chi_evaluation_lengths: Vec<usize>,
55    /// Rho evaluation lengths
56    pub rho_evaluation_lengths: Vec<usize>,
57    /// First Round Commitments
58    pub round_commitments: Vec<C>,
59}
60
61#[derive(Clone, Serialize, Deserialize)]
62pub struct FinalRoundMessage<C> {
63    pub subpolynomial_constraint_count: usize,
64    /// Final Round Commitments
65    pub round_commitments: Vec<C>,
66    /// Bit distributions
67    pub bit_distributions: Vec<BitDistribution>,
68}
69#[derive(Clone, Serialize, Deserialize)]
70pub struct QueryProofPCSProofEvaluations<S> {
71    /// MLEs used in first round sumcheck except for the result columns
72    pub first_round: Vec<S>,
73    /// evaluations of the columns referenced in the query
74    pub column_ref: Vec<S>,
75    /// MLEs used in final round sumcheck except for the result columns
76    pub final_round: Vec<S>,
77}
78
79/// The proof for a query.
80///
81/// Note: Because the class is deserialized from untrusted data, it
82/// cannot maintain any invariant on its data members; hence, they are
83/// all public so as to allow for easy manipulation for testing.
84#[derive(Clone, Serialize, Deserialize)]
85pub struct QueryProof<CP: CommitmentEvaluationProof> {
86    pub(super) first_round_message: FirstRoundMessage<CP::Commitment>,
87    pub(super) final_round_message: FinalRoundMessage<CP::Commitment>,
88    /// Sumcheck Proof
89    pub(super) sumcheck_proof: SumcheckProof<CP::Scalar>,
90    pub(super) pcs_proof_evaluations: QueryProofPCSProofEvaluations<CP::Scalar>,
91    /// Inner product proof of the MLEs' evaluations
92    pub(super) evaluation_proof: CP,
93}
94
95impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
96    /// Create a new `QueryProof`.
97    #[tracing::instrument(name = "QueryProof::new", level = "debug", skip_all)]
98    pub fn new(
99        expr: &(impl ProofPlan + Serialize),
100        accessor: &impl DataAccessor<CP::Scalar>,
101        setup: &CP::ProverPublicSetup<'_>,
102    ) -> (Self, OwnedTable<CP::Scalar>) {
103        log::log_memory_usage("Start");
104
105        let (min_row_num, max_row_num) = get_index_range(accessor, &expr.get_table_references());
106        let initial_range_length = (max_row_num - min_row_num).max(1);
107        let alloc = Bump::new();
108
109        let total_col_refs = expr.get_column_references();
110        let table_map: IndexMap<TableRef, Table<CP::Scalar>> = expr
111            .get_table_references()
112            .into_iter()
113            .map(|table_ref| {
114                let col_refs: IndexSet<ColumnRef> = total_col_refs
115                    .iter()
116                    .filter(|col_ref| col_ref.table_ref() == table_ref)
117                    .cloned()
118                    .collect();
119                (table_ref.clone(), accessor.get_table(table_ref, &col_refs))
120            })
121            .collect();
122
123        // Prover First Round: Evaluate the query && get the right number of post result challenges
124        let mut first_round_builder = FirstRoundBuilder::new(initial_range_length);
125        let query_result = expr.first_round_evaluate(&mut first_round_builder, &alloc, &table_map);
126        let owned_table_result = OwnedTable::from(&query_result);
127        let provable_result = query_result.into();
128        let chi_evaluation_lengths = first_round_builder.chi_evaluation_lengths();
129        let rho_evaluation_lengths = first_round_builder.rho_evaluation_lengths();
130
131        let range_length = first_round_builder.range_length();
132        let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
133        assert!(num_sumcheck_variables > 0);
134        let post_result_challenge_count = first_round_builder.num_post_result_challenges();
135
136        // commit to any intermediate MLEs
137        let first_round_commitments =
138            first_round_builder.commit_intermediate_mles(min_row_num, setup);
139
140        // construct a transcript for the proof
141        let mut transcript: Keccak256Transcript = Transcript::new();
142        transcript.extend_serialize_as_le(expr);
143        transcript.extend_serialize_as_le(&owned_table_result);
144        transcript.extend_serialize_as_le(&min_row_num);
145        transcript.challenge_as_le();
146
147        let first_round_message = FirstRoundMessage {
148            range_length,
149            chi_evaluation_lengths: chi_evaluation_lengths.to_vec(),
150            rho_evaluation_lengths: rho_evaluation_lengths.to_vec(),
151            post_result_challenge_count,
152            round_commitments: first_round_commitments,
153        };
154        transcript.extend_serialize_as_le(&first_round_message);
155
156        // These are the challenges that will be consumed by the proof
157        // Specifically, these are the challenges that the verifier sends to
158        // the prover after the prover sends the result, but before the prover
159        // send commitments to the intermediate witness columns.
160        // Note: the last challenge in the vec is the first one that is consumed.
161        let post_result_challenges =
162            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
163                .take(post_result_challenge_count)
164                .collect();
165
166        let mut final_round_builder =
167            FinalRoundBuilder::new(num_sumcheck_variables, post_result_challenges);
168
169        expr.final_round_evaluate(&mut final_round_builder, &alloc, &table_map);
170
171        let num_sumcheck_variables = final_round_builder.num_sumcheck_variables();
172
173        // commit to any intermediate MLEs
174        let final_round_commitments =
175            final_round_builder.commit_intermediate_mles(min_row_num, setup);
176
177        let final_round_message = FinalRoundMessage {
178            subpolynomial_constraint_count: final_round_builder.num_sumcheck_subpolynomials(),
179            round_commitments: final_round_commitments,
180            bit_distributions: final_round_builder.bit_distributions().to_vec(),
181        };
182
183        // add the commitments, bit distributions and chi evaluation lengths to the proof
184        transcript.challenge_as_le();
185        transcript.extend_serialize_as_le(&final_round_message);
186
187        // construct the sumcheck polynomial
188        let num_random_scalars =
189            num_sumcheck_variables + final_round_message.subpolynomial_constraint_count;
190        let random_scalars: Vec<_> =
191            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
192                .take(num_random_scalars)
193                .collect();
194        let state = make_sumcheck_prover_state(
195            final_round_builder.sumcheck_subpolynomials(),
196            num_sumcheck_variables,
197            &SumcheckRandomScalars::new(&random_scalars, range_length, num_sumcheck_variables),
198        );
199        transcript.challenge_as_le();
200
201        // create the sumcheck proof -- this is the main part of proving a query
202        let mut evaluation_point = vec![Zero::zero(); state.num_vars];
203        let sumcheck_proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, state);
204
205        // evaluate the MLEs used in sumcheck except for the result columns
206        let mut evaluation_vec = vec![Zero::zero(); range_length];
207        compute_evaluation_vector(&mut evaluation_vec, &evaluation_point);
208        let first_round_pcs_proof_evaluations =
209            first_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
210        let column_ref_pcs_proof_evaluations: Vec<_> = total_col_refs
211            .iter()
212            .map(|col_ref| {
213                accessor
214                    .get_column(col_ref.clone())
215                    .inner_product(&evaluation_vec)
216            })
217            .collect();
218        let final_round_pcs_proof_evaluations =
219            final_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
220
221        // commit to the MLE evaluations
222        let pcs_proof_evaluations = QueryProofPCSProofEvaluations {
223            first_round: first_round_pcs_proof_evaluations,
224            column_ref: column_ref_pcs_proof_evaluations,
225            final_round: final_round_pcs_proof_evaluations,
226        };
227        transcript.extend_serialize_as_le(&pcs_proof_evaluations);
228
229        // fold together the pre result MLEs -- this will form the input to an inner product proof
230        // of their evaluations (fold in this context means create a random linear combination)
231        let random_scalars: Vec<_> =
232            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
233                .take(
234                    pcs_proof_evaluations.first_round.len()
235                        + pcs_proof_evaluations.column_ref.len()
236                        + pcs_proof_evaluations.final_round.len(),
237                )
238                .collect();
239
240        let mut folded_mle = vec![Zero::zero(); range_length];
241        let column_ref_mles: Vec<_> = total_col_refs
242            .into_iter()
243            .map(|c| Box::new(accessor.get_column(c)) as Box<dyn MultilinearExtension<_>>)
244            .collect();
245        for (multiplier, evaluator) in random_scalars.iter().zip(
246            first_round_builder
247                .pcs_proof_mles()
248                .iter()
249                .chain(&column_ref_mles)
250                .chain(final_round_builder.pcs_proof_mles().iter()),
251        ) {
252            evaluator.mul_add(&mut folded_mle, multiplier);
253        }
254
255        // finally, form the inner product proof of the MLEs' evaluations
256        let evaluation_proof = CP::new(
257            &mut transcript,
258            &folded_mle,
259            &evaluation_point,
260            min_row_num as u64,
261            setup,
262        );
263
264        let proof = Self {
265            first_round_message,
266            final_round_message,
267            sumcheck_proof,
268            pcs_proof_evaluations,
269            evaluation_proof,
270        };
271
272        log::log_memory_usage("End");
273
274        (proof, provable_result)
275    }
276
277    #[tracing::instrument(name = "QueryProof::verify", level = "debug", skip_all, err)]
278    /// Verify a `QueryProof`. Note: This does NOT transform the result!
279    pub fn verify(
280        self,
281        expr: &(impl ProofPlan + Serialize),
282        accessor: &impl CommitmentAccessor<CP::Commitment>,
283        result: OwnedTable<CP::Scalar>,
284        setup: &CP::VerifierPublicSetup<'_>,
285    ) -> QueryResult<CP::Scalar> {
286        log::log_memory_usage("Start");
287
288        let table_refs = expr.get_table_references();
289        let (min_row_num, _) = get_index_range(accessor, &table_refs);
290        let num_sumcheck_variables = cmp::max(log2_up(self.first_round_message.range_length), 1);
291        assert!(num_sumcheck_variables > 0);
292
293        // validate bit decompositions
294        for dist in &self.final_round_message.bit_distributions {
295            if !dist.is_valid() {
296                Err(ProofError::VerificationError {
297                    error: "invalid bit distributions",
298                })?;
299            } else if !dist.is_within_acceptable_range() {
300                Err(ProofError::VerificationError {
301                    error: "bit distribution outside of acceptable range",
302                })?;
303            }
304        }
305
306        let column_references = expr.get_column_references();
307
308        // construct a transcript for the proof
309        let mut transcript: Keccak256Transcript = Transcript::new();
310        transcript.extend_serialize_as_le(expr);
311        transcript.extend_serialize_as_le(&result);
312        transcript.extend_serialize_as_le(&min_row_num);
313        transcript.challenge_as_le();
314
315        transcript.extend_serialize_as_le(&self.first_round_message);
316
317        // These are the challenges that will be consumed by the proof
318        // Specifically, these are the challenges that the verifier sends to
319        // the prover after the prover sends the result, but before the prover
320        // send commitments to the intermediate witness columns.
321        // Note: the last challenge in the vec is the first one that is consumed.
322        let post_result_challenges =
323            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
324                .take(self.first_round_message.post_result_challenge_count)
325                .collect();
326
327        // add the commitments and bit distributions to the proof
328        transcript.challenge_as_le();
329        transcript.extend_serialize_as_le(&self.final_round_message);
330
331        // draw the random scalars for sumcheck
332        let num_random_scalars =
333            num_sumcheck_variables + self.final_round_message.subpolynomial_constraint_count;
334        let random_scalars: Vec<_> =
335            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
336                .take(num_random_scalars)
337                .collect();
338        let sumcheck_random_scalars = SumcheckRandomScalars::new(
339            &random_scalars,
340            self.first_round_message.range_length,
341            num_sumcheck_variables,
342        );
343        transcript.challenge_as_le();
344
345        // verify sumcheck up to the evaluation check
346        let subclaim = self.sumcheck_proof.verify_without_evaluation(
347            &mut transcript,
348            num_sumcheck_variables,
349            &Zero::zero(),
350        )?;
351
352        // commit to mle evaluations
353        transcript.extend_serialize_as_le(&self.pcs_proof_evaluations);
354
355        // draw the random scalars for the evaluation proof
356        // (i.e. the folding/random linear combination of the pcs_proof_mles)
357        let evaluation_random_scalars: Vec<_> =
358            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
359                .take(
360                    self.pcs_proof_evaluations.first_round.len()
361                        + self.pcs_proof_evaluations.column_ref.len()
362                        + self.pcs_proof_evaluations.final_round.len(),
363                )
364                .collect();
365
366        // Always prepend input lengths to the chi evaluation lengths
367        let table_length_map = table_refs
368            .into_iter()
369            .map(|table_ref| {
370                let len = accessor.get_length(&table_ref);
371                (table_ref, len)
372            })
373            .collect::<IndexMap<TableRef, usize>>();
374
375        let chi_evaluation_lengths = table_length_map
376            .values()
377            .chain(self.first_round_message.chi_evaluation_lengths.iter())
378            .copied();
379
380        // pass over the provable AST to fill in the verification builder
381        let sumcheck_evaluations = SumcheckMleEvaluations::new(
382            self.first_round_message.range_length,
383            chi_evaluation_lengths,
384            self.first_round_message.rho_evaluation_lengths.clone(),
385            &subclaim.evaluation_point,
386            &sumcheck_random_scalars,
387            &self.pcs_proof_evaluations.first_round,
388            &self.pcs_proof_evaluations.final_round,
389        );
390        let chi_eval_map: IndexMap<TableRef, CP::Scalar> = table_length_map
391            .into_iter()
392            .map(|(table_ref, length)| (table_ref, sumcheck_evaluations.chi_evaluations[&length]))
393            .collect();
394        let mut builder = VerificationBuilderImpl::new(
395            sumcheck_evaluations,
396            &self.final_round_message.bit_distributions,
397            sumcheck_random_scalars.subpolynomial_multipliers,
398            post_result_challenges,
399            self.first_round_message.chi_evaluation_lengths.clone(),
400            self.first_round_message.rho_evaluation_lengths.clone(),
401            subclaim.max_multiplicands,
402        );
403
404        let pcs_proof_commitments: Vec<_> = self
405            .first_round_message
406            .round_commitments
407            .iter()
408            .cloned()
409            .chain(
410                column_references
411                    .iter()
412                    .map(|col| accessor.get_commitment(col.clone())),
413            )
414            .chain(self.final_round_message.round_commitments.iter().cloned())
415            .collect();
416        let evaluation_accessor: IndexMap<_, _> = column_references
417            .into_iter()
418            .zip(self.pcs_proof_evaluations.column_ref.iter().copied())
419            .collect();
420
421        let verifier_evaluations = expr.verifier_evaluate(
422            &mut builder,
423            &evaluation_accessor,
424            Some(&result),
425            &chi_eval_map,
426        )?;
427        // compute the evaluation of the result MLEs
428        let result_evaluations = result.mle_evaluations(&subclaim.evaluation_point);
429        // check the evaluation of the result MLEs
430        if verifier_evaluations.column_evals() != result_evaluations {
431            Err(ProofError::VerificationError {
432                error: "result evaluation check failed",
433            })?;
434        }
435
436        // perform the evaluation check of the sumcheck polynomial
437        if builder.sumcheck_evaluation() != subclaim.expected_evaluation {
438            Err(ProofError::VerificationError {
439                error: "sumcheck evaluation check failed",
440            })?;
441        }
442
443        let pcs_proof_evaluations: Vec<_> = self
444            .pcs_proof_evaluations
445            .first_round
446            .iter()
447            .chain(self.pcs_proof_evaluations.column_ref.iter())
448            .chain(self.pcs_proof_evaluations.final_round.iter())
449            .copied()
450            .collect();
451
452        // finally, check the MLE evaluations with the inner product proof
453        self.evaluation_proof
454            .verify_batched_proof(
455                &mut transcript,
456                &pcs_proof_commitments,
457                &evaluation_random_scalars,
458                &pcs_proof_evaluations,
459                &subclaim.evaluation_point,
460                min_row_num as u64,
461                self.first_round_message.range_length,
462                setup,
463            )
464            .map_err(|_e| ProofError::VerificationError {
465                error: "Inner product proof of MLE evaluations failed",
466            })?;
467
468        let verification_hash = transcript.challenge_as_le();
469
470        log::log_memory_usage("End");
471
472        Ok(QueryData {
473            table: result,
474            verification_hash,
475        })
476    }
477}