1use super::{
2 make_sumcheck_state::make_sumcheck_prover_state, FinalRoundBuilder, FirstRoundBuilder,
3 ProofPlan, QueryData, QueryResult, SumcheckMleEvaluations, SumcheckRandomScalars,
4 VerificationBuilderImpl,
5};
6use crate::{
7 base::{
8 bit::BitDistribution,
9 commitment::CommitmentEvaluationProof,
10 database::{
11 ColumnRef, CommitmentAccessor, DataAccessor, MetadataAccessor, OwnedTable, Table,
12 TableRef,
13 },
14 map::{IndexMap, IndexSet},
15 math::log2_up,
16 polynomial::{compute_evaluation_vector, MultilinearExtension},
17 proof::{Keccak256Transcript, ProofError, Transcript},
18 },
19 proof_primitive::sumcheck::SumcheckProof,
20 utils::log,
21};
22use alloc::{boxed::Box, vec, vec::Vec};
23use bumpalo::Bump;
24use core::cmp;
25use num_traits::Zero;
26use serde::{Deserialize, Serialize};
27
28fn get_index_range<'a>(
33 accessor: &dyn MetadataAccessor,
34 table_refs: impl IntoIterator<Item = &'a TableRef>,
35) -> (usize, usize) {
36 table_refs
37 .into_iter()
38 .map(|table_ref| {
39 let length = accessor.get_length(table_ref);
40 let offset = accessor.get_offset(table_ref);
41 (offset, offset + length)
42 })
43 .reduce(|(min_start, max_end), (start, end)| (min_start.min(start), max_end.max(end)))
44 .unwrap_or((0, 1))
46}
47
/// Prover message produced after the first round of proof construction.
///
/// Both `QueryProof::new` and `QueryProof::verify` absorb this message into the
/// transcript with `extend_serialize_as_le`. NOTE(review): assuming the transcript
/// serializer is field-order sensitive (as serde's derived impls are), changing the
/// field order or types changes the transcript and breaks existing proofs.
#[derive(Clone, Serialize, Deserialize)]
pub struct FirstRoundMessage<C> {
    /// Range length reported by the first-round builder; the verifier derives the
    /// number of sumcheck variables from it (`max(log2_up(range_length), 1)`).
    pub range_length: usize,
    /// Number of scalar challenges drawn from the transcript right after this
    /// message is absorbed; they are handed to the final round.
    pub post_result_challenge_count: usize,
    /// Extra lengths for which the verifier computes chi evaluations, in addition
    /// to the lengths of the referenced tables.
    pub chi_evaluation_lengths: Vec<usize>,
    /// Lengths for which the verifier computes rho evaluations.
    pub rho_evaluation_lengths: Vec<usize>,
    /// Commitments to the first round's intermediate MLEs.
    pub round_commitments: Vec<C>,
}
60
/// Prover message produced after the final round of proof construction.
///
/// Absorbed into the transcript by both prover and verifier, so its serialized
/// form (field order and types) is part of the proof format.
#[derive(Clone, Serialize, Deserialize)]
pub struct FinalRoundMessage<C> {
    /// Number of sumcheck subpolynomials. Together with the number of sumcheck
    /// variables, it determines how many random scalars are drawn for the sumcheck.
    pub subpolynomial_constraint_count: usize,
    /// Commitments to the final round's intermediate MLEs.
    pub round_commitments: Vec<C>,
    /// Bit distributions reported by the final round. The verifier rejects the
    /// proof if any of them is invalid or outside the acceptable range.
    pub bit_distributions: Vec<BitDistribution>,
}
/// Evaluations, at the point produced by the sumcheck, of every MLE covered by
/// the batched polynomial-commitment evaluation proof.
///
/// The groups are concatenated in field order (first round, column refs, final
/// round) when building the batched proof. That order must match the order of
/// the corresponding commitments.
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProofPCSProofEvaluations<S> {
    /// Evaluations of the first round's committed MLEs.
    pub first_round: Vec<S>,
    /// Evaluations of the referenced table columns, in the order returned by the
    /// plan's `get_column_references`.
    pub column_ref: Vec<S>,
    /// Evaluations of the final round's committed MLEs.
    pub final_round: Vec<S>,
}
78
/// A non-interactive proof that a [`ProofPlan`] evaluates to a given result
/// table over committed data.
///
/// Created with [`QueryProof::new`] and checked with [`QueryProof::verify`].
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProof<CP: CommitmentEvaluationProof> {
    /// First-round message: range length, chi/rho lengths, challenge count, commitments.
    pub(super) first_round_message: FirstRoundMessage<CP::Commitment>,
    /// Final-round message: subpolynomial count, commitments, bit distributions.
    pub(super) final_round_message: FinalRoundMessage<CP::Commitment>,
    /// Sumcheck proof over the final round's subpolynomials.
    pub(super) sumcheck_proof: SumcheckProof<CP::Scalar>,
    /// Claimed MLE evaluations at the sumcheck point.
    pub(super) pcs_proof_evaluations: QueryProofPCSProofEvaluations<CP::Scalar>,
    /// Batched evaluation proof tying `pcs_proof_evaluations` to the commitments.
    pub(super) evaluation_proof: CP,
}
94
95impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
96 #[tracing::instrument(name = "QueryProof::new", level = "debug", skip_all)]
98 pub fn new(
99 expr: &(impl ProofPlan + Serialize),
100 accessor: &impl DataAccessor<CP::Scalar>,
101 setup: &CP::ProverPublicSetup<'_>,
102 ) -> (Self, OwnedTable<CP::Scalar>) {
103 log::log_memory_usage("Start");
104
105 let (min_row_num, max_row_num) = get_index_range(accessor, &expr.get_table_references());
106 let initial_range_length = (max_row_num - min_row_num).max(1);
107 let alloc = Bump::new();
108
109 let total_col_refs = expr.get_column_references();
110 let table_map: IndexMap<TableRef, Table<CP::Scalar>> = expr
111 .get_table_references()
112 .into_iter()
113 .map(|table_ref| {
114 let col_refs: IndexSet<ColumnRef> = total_col_refs
115 .iter()
116 .filter(|col_ref| col_ref.table_ref() == table_ref)
117 .cloned()
118 .collect();
119 (table_ref.clone(), accessor.get_table(table_ref, &col_refs))
120 })
121 .collect();
122
123 let mut first_round_builder = FirstRoundBuilder::new(initial_range_length);
125 let query_result = expr.first_round_evaluate(&mut first_round_builder, &alloc, &table_map);
126 let owned_table_result = OwnedTable::from(&query_result);
127 let provable_result = query_result.into();
128 let chi_evaluation_lengths = first_round_builder.chi_evaluation_lengths();
129 let rho_evaluation_lengths = first_round_builder.rho_evaluation_lengths();
130
131 let range_length = first_round_builder.range_length();
132 let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
133 assert!(num_sumcheck_variables > 0);
134 let post_result_challenge_count = first_round_builder.num_post_result_challenges();
135
136 let first_round_commitments =
138 first_round_builder.commit_intermediate_mles(min_row_num, setup);
139
140 let mut transcript: Keccak256Transcript = Transcript::new();
142 transcript.extend_serialize_as_le(expr);
143 transcript.extend_serialize_as_le(&owned_table_result);
144 transcript.extend_serialize_as_le(&min_row_num);
145 transcript.challenge_as_le();
146
147 let first_round_message = FirstRoundMessage {
148 range_length,
149 chi_evaluation_lengths: chi_evaluation_lengths.to_vec(),
150 rho_evaluation_lengths: rho_evaluation_lengths.to_vec(),
151 post_result_challenge_count,
152 round_commitments: first_round_commitments,
153 };
154 transcript.extend_serialize_as_le(&first_round_message);
155
156 let post_result_challenges =
162 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
163 .take(post_result_challenge_count)
164 .collect();
165
166 let mut final_round_builder =
167 FinalRoundBuilder::new(num_sumcheck_variables, post_result_challenges);
168
169 expr.final_round_evaluate(&mut final_round_builder, &alloc, &table_map);
170
171 let num_sumcheck_variables = final_round_builder.num_sumcheck_variables();
172
173 let final_round_commitments =
175 final_round_builder.commit_intermediate_mles(min_row_num, setup);
176
177 let final_round_message = FinalRoundMessage {
178 subpolynomial_constraint_count: final_round_builder.num_sumcheck_subpolynomials(),
179 round_commitments: final_round_commitments,
180 bit_distributions: final_round_builder.bit_distributions().to_vec(),
181 };
182
183 transcript.challenge_as_le();
185 transcript.extend_serialize_as_le(&final_round_message);
186
187 let num_random_scalars =
189 num_sumcheck_variables + final_round_message.subpolynomial_constraint_count;
190 let random_scalars: Vec<_> =
191 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
192 .take(num_random_scalars)
193 .collect();
194 let state = make_sumcheck_prover_state(
195 final_round_builder.sumcheck_subpolynomials(),
196 num_sumcheck_variables,
197 &SumcheckRandomScalars::new(&random_scalars, range_length, num_sumcheck_variables),
198 );
199 transcript.challenge_as_le();
200
201 let mut evaluation_point = vec![Zero::zero(); state.num_vars];
203 let sumcheck_proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, state);
204
205 let mut evaluation_vec = vec![Zero::zero(); range_length];
207 compute_evaluation_vector(&mut evaluation_vec, &evaluation_point);
208 let first_round_pcs_proof_evaluations =
209 first_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
210 let column_ref_pcs_proof_evaluations: Vec<_> = total_col_refs
211 .iter()
212 .map(|col_ref| {
213 accessor
214 .get_column(col_ref.clone())
215 .inner_product(&evaluation_vec)
216 })
217 .collect();
218 let final_round_pcs_proof_evaluations =
219 final_round_builder.evaluate_pcs_proof_mles(&evaluation_vec);
220
221 let pcs_proof_evaluations = QueryProofPCSProofEvaluations {
223 first_round: first_round_pcs_proof_evaluations,
224 column_ref: column_ref_pcs_proof_evaluations,
225 final_round: final_round_pcs_proof_evaluations,
226 };
227 transcript.extend_serialize_as_le(&pcs_proof_evaluations);
228
229 let random_scalars: Vec<_> =
232 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
233 .take(
234 pcs_proof_evaluations.first_round.len()
235 + pcs_proof_evaluations.column_ref.len()
236 + pcs_proof_evaluations.final_round.len(),
237 )
238 .collect();
239
240 let mut folded_mle = vec![Zero::zero(); range_length];
241 let column_ref_mles: Vec<_> = total_col_refs
242 .into_iter()
243 .map(|c| Box::new(accessor.get_column(c)) as Box<dyn MultilinearExtension<_>>)
244 .collect();
245 for (multiplier, evaluator) in random_scalars.iter().zip(
246 first_round_builder
247 .pcs_proof_mles()
248 .iter()
249 .chain(&column_ref_mles)
250 .chain(final_round_builder.pcs_proof_mles().iter()),
251 ) {
252 evaluator.mul_add(&mut folded_mle, multiplier);
253 }
254
255 let evaluation_proof = CP::new(
257 &mut transcript,
258 &folded_mle,
259 &evaluation_point,
260 min_row_num as u64,
261 setup,
262 );
263
264 let proof = Self {
265 first_round_message,
266 final_round_message,
267 sumcheck_proof,
268 pcs_proof_evaluations,
269 evaluation_proof,
270 };
271
272 log::log_memory_usage("End");
273
274 (proof, provable_result)
275 }
276
277 #[tracing::instrument(name = "QueryProof::verify", level = "debug", skip_all, err)]
278 pub fn verify(
280 self,
281 expr: &(impl ProofPlan + Serialize),
282 accessor: &impl CommitmentAccessor<CP::Commitment>,
283 result: OwnedTable<CP::Scalar>,
284 setup: &CP::VerifierPublicSetup<'_>,
285 ) -> QueryResult<CP::Scalar> {
286 log::log_memory_usage("Start");
287
288 let table_refs = expr.get_table_references();
289 let (min_row_num, _) = get_index_range(accessor, &table_refs);
290 let num_sumcheck_variables = cmp::max(log2_up(self.first_round_message.range_length), 1);
291 assert!(num_sumcheck_variables > 0);
292
293 for dist in &self.final_round_message.bit_distributions {
295 if !dist.is_valid() {
296 Err(ProofError::VerificationError {
297 error: "invalid bit distributions",
298 })?;
299 } else if !dist.is_within_acceptable_range() {
300 Err(ProofError::VerificationError {
301 error: "bit distribution outside of acceptable range",
302 })?;
303 }
304 }
305
306 let column_references = expr.get_column_references();
307
308 let mut transcript: Keccak256Transcript = Transcript::new();
310 transcript.extend_serialize_as_le(expr);
311 transcript.extend_serialize_as_le(&result);
312 transcript.extend_serialize_as_le(&min_row_num);
313 transcript.challenge_as_le();
314
315 transcript.extend_serialize_as_le(&self.first_round_message);
316
317 let post_result_challenges =
323 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
324 .take(self.first_round_message.post_result_challenge_count)
325 .collect();
326
327 transcript.challenge_as_le();
329 transcript.extend_serialize_as_le(&self.final_round_message);
330
331 let num_random_scalars =
333 num_sumcheck_variables + self.final_round_message.subpolynomial_constraint_count;
334 let random_scalars: Vec<_> =
335 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
336 .take(num_random_scalars)
337 .collect();
338 let sumcheck_random_scalars = SumcheckRandomScalars::new(
339 &random_scalars,
340 self.first_round_message.range_length,
341 num_sumcheck_variables,
342 );
343 transcript.challenge_as_le();
344
345 let subclaim = self.sumcheck_proof.verify_without_evaluation(
347 &mut transcript,
348 num_sumcheck_variables,
349 &Zero::zero(),
350 )?;
351
352 transcript.extend_serialize_as_le(&self.pcs_proof_evaluations);
354
355 let evaluation_random_scalars: Vec<_> =
358 core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
359 .take(
360 self.pcs_proof_evaluations.first_round.len()
361 + self.pcs_proof_evaluations.column_ref.len()
362 + self.pcs_proof_evaluations.final_round.len(),
363 )
364 .collect();
365
366 let table_length_map = table_refs
368 .into_iter()
369 .map(|table_ref| {
370 let len = accessor.get_length(&table_ref);
371 (table_ref, len)
372 })
373 .collect::<IndexMap<TableRef, usize>>();
374
375 let chi_evaluation_lengths = table_length_map
376 .values()
377 .chain(self.first_round_message.chi_evaluation_lengths.iter())
378 .copied();
379
380 let sumcheck_evaluations = SumcheckMleEvaluations::new(
382 self.first_round_message.range_length,
383 chi_evaluation_lengths,
384 self.first_round_message.rho_evaluation_lengths.clone(),
385 &subclaim.evaluation_point,
386 &sumcheck_random_scalars,
387 &self.pcs_proof_evaluations.first_round,
388 &self.pcs_proof_evaluations.final_round,
389 );
390 let chi_eval_map: IndexMap<TableRef, CP::Scalar> = table_length_map
391 .into_iter()
392 .map(|(table_ref, length)| (table_ref, sumcheck_evaluations.chi_evaluations[&length]))
393 .collect();
394 let mut builder = VerificationBuilderImpl::new(
395 sumcheck_evaluations,
396 &self.final_round_message.bit_distributions,
397 sumcheck_random_scalars.subpolynomial_multipliers,
398 post_result_challenges,
399 self.first_round_message.chi_evaluation_lengths.clone(),
400 self.first_round_message.rho_evaluation_lengths.clone(),
401 subclaim.max_multiplicands,
402 );
403
404 let pcs_proof_commitments: Vec<_> = self
405 .first_round_message
406 .round_commitments
407 .iter()
408 .cloned()
409 .chain(
410 column_references
411 .iter()
412 .map(|col| accessor.get_commitment(col.clone())),
413 )
414 .chain(self.final_round_message.round_commitments.iter().cloned())
415 .collect();
416 let evaluation_accessor: IndexMap<_, _> = column_references
417 .into_iter()
418 .zip(self.pcs_proof_evaluations.column_ref.iter().copied())
419 .collect();
420
421 let verifier_evaluations = expr.verifier_evaluate(
422 &mut builder,
423 &evaluation_accessor,
424 Some(&result),
425 &chi_eval_map,
426 )?;
427 let result_evaluations = result.mle_evaluations(&subclaim.evaluation_point);
429 if verifier_evaluations.column_evals() != result_evaluations {
431 Err(ProofError::VerificationError {
432 error: "result evaluation check failed",
433 })?;
434 }
435
436 if builder.sumcheck_evaluation() != subclaim.expected_evaluation {
438 Err(ProofError::VerificationError {
439 error: "sumcheck evaluation check failed",
440 })?;
441 }
442
443 let pcs_proof_evaluations: Vec<_> = self
444 .pcs_proof_evaluations
445 .first_round
446 .iter()
447 .chain(self.pcs_proof_evaluations.column_ref.iter())
448 .chain(self.pcs_proof_evaluations.final_round.iter())
449 .copied()
450 .collect();
451
452 self.evaluation_proof
454 .verify_batched_proof(
455 &mut transcript,
456 &pcs_proof_commitments,
457 &evaluation_random_scalars,
458 &pcs_proof_evaluations,
459 &subclaim.evaluation_point,
460 min_row_num as u64,
461 self.first_round_message.range_length,
462 setup,
463 )
464 .map_err(|_e| ProofError::VerificationError {
465 error: "Inner product proof of MLE evaluations failed",
466 })?;
467
468 let verification_hash = transcript.challenge_as_le();
469
470 log::log_memory_usage("End");
471
472 Ok(QueryData {
473 table: result,
474 verification_hash,
475 })
476 }
477}