1use super::{
2 make_sumcheck_state::make_sumcheck_prover_state, FinalRoundBuilder, FirstRoundBuilder,
3 ProofPlan, QueryData, QueryResult, SumcheckMleEvaluations, SumcheckRandomScalars,
4 VerificationBuilderImpl,
5};
6use crate::{
7 base::{
8 bit::BitDistribution,
9 commitment::CommitmentEvaluationProof,
10 database::{
11 ColumnRef, CommitmentAccessor, DataAccessor, LiteralValue, MetadataAccessor,
12 OwnedTable, Table, TableRef,
13 },
14 map::{IndexMap, IndexSet},
15 math::log2_up,
16 polynomial::{compute_evaluation_vector, MultilinearExtension},
17 proof::{Keccak256Transcript, ProofError, Transcript},
18 },
19 proof_primitive::sumcheck::SumcheckProof,
20 utils::log,
21};
22use alloc::{boxed::Box, vec, vec::Vec};
23use bumpalo::Bump;
24use core::cmp;
25use num_traits::Zero;
26use serde::{Deserialize, Serialize};
27
28fn get_index_range<'a>(
33 accessor: &dyn MetadataAccessor,
34 table_refs: impl IntoIterator<Item = &'a TableRef>,
35) -> (usize, usize) {
36 table_refs
37 .into_iter()
38 .map(|table_ref| {
39 let length = accessor.get_length(table_ref);
40 let offset = accessor.get_offset(table_ref);
41 (offset, offset + length)
42 })
43 .reduce(|(min_start, max_end), (start, end)| (min_start.min(start), max_end.max(end)))
44 .unwrap_or((0, 1))
46}
47
/// First-round message of a [`QueryProof`].
///
/// The prover fills it from the first-round builder and both prover and
/// verifier absorb it into the transcript before the post-result challenges
/// are drawn. Field order is part of the serialized form.
#[derive(Clone, Serialize, Deserialize)]
pub struct FirstRoundMessage<C> {
    /// Range length reported by the first-round builder. The verifier derives
    /// the number of sumcheck variables from it as `max(log2_up(range_length), 1)`.
    pub range_length: usize,
    /// Number of scalar challenges drawn from the transcript immediately
    /// after this message has been absorbed.
    pub post_result_challenge_count: usize,
    /// Chi evaluation lengths recorded by the first-round builder.
    pub chi_evaluation_lengths: Vec<usize>,
    /// Rho evaluation lengths recorded by the first-round builder.
    pub rho_evaluation_lengths: Vec<usize>,
    /// Commitments to the intermediate MLEs produced in the first round.
    pub round_commitments: Vec<C>,
}
60
/// Final-round message of a [`QueryProof`].
///
/// The prover fills it from the final-round builder; both prover and verifier
/// absorb it into the transcript before the sumcheck random scalars are drawn.
/// Field order is part of the serialized form.
#[derive(Clone, Serialize, Deserialize)]
pub struct FinalRoundMessage<C> {
    /// Number of sumcheck subpolynomials reported by the final-round builder.
    /// It is added to the number of sumcheck variables to determine how many
    /// random scalars are drawn for sumcheck.
    pub subpolynomial_constraint_count: usize,
    /// Commitments to the intermediate MLEs produced in the final round.
    pub round_commitments: Vec<C>,
    /// Bit distributions reported by the final-round builder. The verifier
    /// rejects the proof if any of them is invalid or outside the acceptable
    /// range.
    pub bit_distributions: Vec<BitDistribution>,
}
/// MLE evaluations carried by a [`QueryProof`] and checked by the commitment
/// evaluation proof.
///
/// The three groups are always consumed in the order first round, column
/// references, final round. Field order is part of the serialized form.
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProofPCSProofEvaluations<S> {
    /// Evaluations of the first-round builder's PCS proof MLEs.
    pub first_round: Vec<S>,
    /// Evaluations of the columns referenced by the query plan, one per
    /// column reference and in the order the plan reports them.
    pub column_ref: Vec<S>,
    /// Evaluations of the final-round builder's PCS proof MLEs.
    pub final_round: Vec<S>,
}
78
/// The proof for a query.
///
/// Produced by `QueryProof::new` and consumed by `QueryProof::verify`.
/// Field order is part of the serialized form.
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProof<CP: CommitmentEvaluationProof> {
    /// Message sent after the first-round evaluation of the plan.
    pub(super) first_round_message: FirstRoundMessage<CP::Commitment>,
    /// Message sent after the final-round evaluation of the plan.
    pub(super) final_round_message: FinalRoundMessage<CP::Commitment>,
    /// Sumcheck proof.
    pub(super) sumcheck_proof: SumcheckProof<CP::Scalar>,
    /// Evaluations of the MLEs covered by `evaluation_proof`.
    pub(super) pcs_proof_evaluations: QueryProofPCSProofEvaluations<CP::Scalar>,
    /// Commitment evaluation proof for the folded MLE at the sumcheck
    /// evaluation point.
    pub(super) evaluation_proof: CP,
}
94
95impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
96 #[tracing::instrument(name = "QueryProof::new", level = "debug", skip_all)]
98 pub fn new(
99 expr: &(impl ProofPlan + Serialize),
100 accessor: &impl DataAccessor<CP::Scalar>,
101 setup: &CP::ProverPublicSetup<'_>,
102 params: &[LiteralValue],
103 ) -> (Self, OwnedTable<CP::Scalar>) {
104 log::log_memory_usage("Start");
105
106 let (min_row_num, max_row_num) = get_index_range(accessor, &expr.get_table_references());
107 let initial_range_length = (max_row_num - min_row_num).max(1);
108 let alloc = Bump::new();
109
110 let total_col_refs = expr.get_column_references();
111 let table_map: IndexMap<TableRef, Table<CP::Scalar>> = expr
112 .get_table_references()
113 .into_iter()
114 .map(|table_ref| {
115 let col_refs: IndexSet<ColumnRef> = total_col_refs
116 .iter()
117 .filter(|col_ref| col_ref.table_ref() == table_ref)
118 .cloned()
119 .collect();
120 (table_ref.clone(), accessor.get_table(table_ref, &col_refs))
121 })
122 .collect();
123
124 let mut first_round_builder = FirstRoundBuilder::new(initial_range_length);
126 let query_result =
127 expr.first_round_evaluate(&mut first_round_builder, &alloc, &table_map, params);
128 let owned_table_result = OwnedTable::from(&query_result);
129 let provable_result = query_result.into();
130 let chi_evaluation_lengths = first_round_builder.chi_evaluation_lengths();
131 let rho_evaluation_lengths = first_round_builder.rho_evaluation_lengths();
132
133 let range_length = first_round_builder.range_length();
134 let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
135 assert!(num_sumcheck_variables > 0);
136 let post_result_challenge_count = first_round_builder.num_post_result_challenges();
137
138 let first_round_commitments =
140 first_round_builder.commit_intermediate_mles(min_row_num, setup);
141
142 let mut transcript: Keccak256Transcript = Transcript::new();
144 transcript.extend_serialize_as_le(expr);
145 transcript.extend_serialize_as_le(&owned_table_result);
146 transcript.extend_serialize_as_le(&min_row_num);
147 transcript.challenge_as_le();
148
149 let first_round_message = FirstRoundMessage {
150 range_length,
151 chi_evaluation_lengths: chi_evaluation_lengths.to_vec(),
152 rho_evaluation_lengths: rho_evaluation_lengths.to_vec(),
153 post_result_challenge_count,
154 round_commitments: first_round_commitments,
155 };
156 transcript.extend_serialize_as_le(&first_round_message);
157
158 let post_result_challenges =
164 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
165 .take(post_result_challenge_count)
166 .collect();
167
168 let mut final_round_builder =
169 FinalRoundBuilder::new(num_sumcheck_variables, post_result_challenges);
170
171 expr.final_round_evaluate(&mut final_round_builder, &alloc, &table_map, params);
172
173 let num_sumcheck_variables = final_round_builder.num_sumcheck_variables();
174
175 let final_round_commitments =
177 final_round_builder.commit_intermediate_mles(min_row_num, setup);
178
179 let final_round_message = FinalRoundMessage {
180 subpolynomial_constraint_count: final_round_builder.num_sumcheck_subpolynomials(),
181 round_commitments: final_round_commitments,
182 bit_distributions: final_round_builder.bit_distributions().to_vec(),
183 };
184
185 transcript.challenge_as_le();
187 transcript.extend_serialize_as_le(&final_round_message);
188
189 let num_random_scalars =
191 num_sumcheck_variables + final_round_message.subpolynomial_constraint_count;
192 let random_scalars: Vec<_> =
193 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
194 .take(num_random_scalars)
195 .collect();
196 let state = make_sumcheck_prover_state(
197 final_round_builder.sumcheck_subpolynomials(),
198 num_sumcheck_variables,
199 &SumcheckRandomScalars::new(&random_scalars, range_length, num_sumcheck_variables),
200 );
201 transcript.challenge_as_le();
202
203 let mut evaluation_point = vec![Zero::zero(); state.num_vars];
205 let sumcheck_proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, state);
206
207 let mut evaluation_vec = vec![Zero::zero(); range_length];
209 compute_evaluation_vector(&mut evaluation_vec, &evaluation_point);
210 let first_round_pcs_proof_evaluations =
211 first_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
212 let column_ref_pcs_proof_evaluations: Vec<_> = total_col_refs
213 .iter()
214 .map(|col_ref| {
215 accessor
216 .get_column(col_ref.clone())
217 .inner_product(&evaluation_vec)
218 })
219 .collect();
220 let final_round_pcs_proof_evaluations =
221 final_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
222
223 let pcs_proof_evaluations = QueryProofPCSProofEvaluations {
225 first_round: first_round_pcs_proof_evaluations,
226 column_ref: column_ref_pcs_proof_evaluations,
227 final_round: final_round_pcs_proof_evaluations,
228 };
229 transcript.extend_serialize_as_le(&pcs_proof_evaluations);
230
231 let random_scalars: Vec<_> =
234 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
235 .take(
236 pcs_proof_evaluations.first_round.len()
237 + pcs_proof_evaluations.column_ref.len()
238 + pcs_proof_evaluations.final_round.len(),
239 )
240 .collect();
241
242 let mut folded_mle = vec![Zero::zero(); range_length];
243 let column_ref_mles: Vec<_> = total_col_refs
244 .into_iter()
245 .map(|c| Box::new(accessor.get_column(c)) as Box<dyn MultilinearExtension<_>>)
246 .collect();
247 for (multiplier, evaluator) in random_scalars.iter().zip(
248 first_round_builder
249 .pcs_proof_mles()
250 .iter()
251 .chain(&column_ref_mles)
252 .chain(final_round_builder.pcs_proof_mles().iter()),
253 ) {
254 evaluator.mul_add(&mut folded_mle, multiplier);
255 }
256
257 let evaluation_proof = CP::new(
259 &mut transcript,
260 &folded_mle,
261 &evaluation_point,
262 min_row_num as u64,
263 setup,
264 );
265
266 let proof = Self {
267 first_round_message,
268 final_round_message,
269 sumcheck_proof,
270 pcs_proof_evaluations,
271 evaluation_proof,
272 };
273
274 log::log_memory_usage("End");
275
276 (proof, provable_result)
277 }
278
279 #[tracing::instrument(name = "QueryProof::verify", level = "debug", skip_all, err)]
280 pub fn verify(
282 self,
283 expr: &(impl ProofPlan + Serialize),
284 accessor: &impl CommitmentAccessor<CP::Commitment>,
285 result: OwnedTable<CP::Scalar>,
286 setup: &CP::VerifierPublicSetup<'_>,
287 params: &[LiteralValue],
288 ) -> QueryResult<CP::Scalar> {
289 log::log_memory_usage("Start");
290
291 let table_refs = expr.get_table_references();
292 let (min_row_num, _) = get_index_range(accessor, &table_refs);
293 let num_sumcheck_variables = cmp::max(log2_up(self.first_round_message.range_length), 1);
294 assert!(num_sumcheck_variables > 0);
295
296 for dist in &self.final_round_message.bit_distributions {
298 if !dist.is_valid() {
299 Err(ProofError::VerificationError {
300 error: "invalid bit distributions",
301 })?;
302 } else if !dist.is_within_acceptable_range() {
303 Err(ProofError::VerificationError {
304 error: "bit distribution outside of acceptable range",
305 })?;
306 }
307 }
308
309 let column_references = expr.get_column_references();
310
311 let mut transcript: Keccak256Transcript = Transcript::new();
313 transcript.extend_serialize_as_le(expr);
314 transcript.extend_serialize_as_le(&result);
315 transcript.extend_serialize_as_le(&min_row_num);
316 transcript.challenge_as_le();
317
318 transcript.extend_serialize_as_le(&self.first_round_message);
319
320 let post_result_challenges =
326 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
327 .take(self.first_round_message.post_result_challenge_count)
328 .collect();
329
330 transcript.challenge_as_le();
332 transcript.extend_serialize_as_le(&self.final_round_message);
333
334 let num_random_scalars =
336 num_sumcheck_variables + self.final_round_message.subpolynomial_constraint_count;
337 let random_scalars: Vec<_> =
338 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
339 .take(num_random_scalars)
340 .collect();
341 let sumcheck_random_scalars = SumcheckRandomScalars::new(
342 &random_scalars,
343 self.first_round_message.range_length,
344 num_sumcheck_variables,
345 );
346 transcript.challenge_as_le();
347
348 let subclaim = self.sumcheck_proof.verify_without_evaluation(
350 &mut transcript,
351 num_sumcheck_variables,
352 &Zero::zero(),
353 )?;
354
355 transcript.extend_serialize_as_le(&self.pcs_proof_evaluations);
357
358 let evaluation_random_scalars: Vec<_> =
361 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
362 .take(
363 self.pcs_proof_evaluations.first_round.len()
364 + self.pcs_proof_evaluations.column_ref.len()
365 + self.pcs_proof_evaluations.final_round.len(),
366 )
367 .collect();
368
369 let table_length_map = table_refs
371 .into_iter()
372 .map(|table_ref| {
373 let len = accessor.get_length(&table_ref);
374 (table_ref, len)
375 })
376 .collect::<IndexMap<TableRef, usize>>();
377
378 let chi_evaluation_lengths = table_length_map
379 .values()
380 .chain(self.first_round_message.chi_evaluation_lengths.iter())
381 .copied();
382
383 let sumcheck_evaluations = SumcheckMleEvaluations::new(
385 self.first_round_message.range_length,
386 chi_evaluation_lengths,
387 self.first_round_message.rho_evaluation_lengths.clone(),
388 &subclaim.evaluation_point,
389 &sumcheck_random_scalars,
390 &self.pcs_proof_evaluations.first_round,
391 &self.pcs_proof_evaluations.final_round,
392 );
393 let chi_eval_map: IndexMap<TableRef, CP::Scalar> = table_length_map
394 .into_iter()
395 .map(|(table_ref, length)| (table_ref, sumcheck_evaluations.chi_evaluations[&length]))
396 .collect();
397 let mut builder = VerificationBuilderImpl::new(
398 sumcheck_evaluations,
399 &self.final_round_message.bit_distributions,
400 sumcheck_random_scalars.subpolynomial_multipliers,
401 post_result_challenges,
402 self.first_round_message.chi_evaluation_lengths.clone(),
403 self.first_round_message.rho_evaluation_lengths.clone(),
404 subclaim.max_multiplicands,
405 );
406
407 let pcs_proof_commitments: Vec<_> = self
408 .first_round_message
409 .round_commitments
410 .iter()
411 .cloned()
412 .chain(
413 column_references
414 .iter()
415 .map(|col| accessor.get_commitment(col.clone())),
416 )
417 .chain(self.final_round_message.round_commitments.iter().cloned())
418 .collect();
419 let evaluation_accessor: IndexMap<_, _> = column_references
420 .into_iter()
421 .zip(self.pcs_proof_evaluations.column_ref.iter().copied())
422 .collect();
423
424 let verifier_evaluations = expr.verifier_evaluate(
425 &mut builder,
426 &evaluation_accessor,
427 Some(&result),
428 &chi_eval_map,
429 params,
430 )?;
431 let result_evaluations = result.mle_evaluations(&subclaim.evaluation_point);
433 if verifier_evaluations.column_evals() != result_evaluations {
435 Err(ProofError::VerificationError {
436 error: "result evaluation check failed",
437 })?;
438 }
439
440 if builder.sumcheck_evaluation() != subclaim.expected_evaluation {
442 Err(ProofError::VerificationError {
443 error: "sumcheck evaluation check failed",
444 })?;
445 }
446
447 let pcs_proof_evaluations: Vec<_> = self
448 .pcs_proof_evaluations
449 .first_round
450 .iter()
451 .chain(self.pcs_proof_evaluations.column_ref.iter())
452 .chain(self.pcs_proof_evaluations.final_round.iter())
453 .copied()
454 .collect();
455
456 self.evaluation_proof
458 .verify_batched_proof(
459 &mut transcript,
460 &pcs_proof_commitments,
461 &evaluation_random_scalars,
462 &pcs_proof_evaluations,
463 &subclaim.evaluation_point,
464 min_row_num as u64,
465 self.first_round_message.range_length,
466 setup,
467 )
468 .map_err(|_e| ProofError::VerificationError {
469 error: "Inner product proof of MLE evaluations failed",
470 })?;
471
472 let verification_hash = transcript.challenge_as_le();
473
474 log::log_memory_usage("End");
475
476 Ok(QueryData {
477 table: result,
478 verification_hash,
479 })
480 }
481}