use super::{
CountBuilder, FinalRoundBuilder, ProofCounts, ProofPlan, ProvableQueryResult, QueryResult,
SumcheckMleEvaluations, SumcheckRandomScalars, VerificationBuilder,
};
use crate::{
base::{
bit::BitDistribution,
commitment::CommitmentEvaluationProof,
database::{Column, CommitmentAccessor, DataAccessor, MetadataAccessor, TableRef},
math::log2_up,
polynomial::{compute_evaluation_vector, CompositePolynomialInfo},
proof::{Keccak256Transcript, ProofError, Transcript},
},
proof_primitive::sumcheck::SumcheckProof,
sql::proof::{FirstRoundBuilder, QueryData},
};
use alloc::{vec, vec::Vec};
use bumpalo::Bump;
use core::cmp;
use num_traits::Zero;
use serde::{Deserialize, Serialize};
/// Returns the `(start, end)` row range spanned by all tables in `table_refs`.
///
/// For every table, `start` is the offset reported by `accessor` and `end` is
/// `offset + length`. The result pairs the smallest `start` with the largest `end`.
/// When `table_refs` yields no tables at all, `(0, 1)` is returned.
fn get_index_range(
    accessor: &dyn MetadataAccessor,
    table_refs: impl IntoIterator<Item = TableRef>,
) -> (usize, usize) {
    let mut bounds: Option<(usize, usize)> = None;
    for table_ref in table_refs {
        let length = accessor.get_length(table_ref);
        let offset = accessor.get_offset(table_ref);
        let end = offset + length;
        bounds = Some(match bounds {
            // Widen the running range so that it also covers this table.
            Some((lowest, highest)) => (lowest.min(offset), highest.max(end)),
            // The first table seen defines the initial range.
            None => (offset, end),
        });
    }
    // No tables referenced: fall back to a range of length one.
    bounds.unwrap_or((0, 1))
}
/// The proof object produced by [`QueryProof::new`] and checked by [`QueryProof::verify`].
///
/// Generic over the commitment evaluation proof scheme `CP`, which supplies the
/// commitment and scalar types used by the fields below.
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProof<CP: CommitmentEvaluationProof> {
/// Bit distributions taken from the final-round builder when proving; each one is
/// checked with `is_valid` when verifying.
pub bit_distributions: Vec<BitDistribution>,
/// Commitments returned by the builder's `commit_intermediate_mles` when proving.
pub commitments: Vec<CP::Commitment>,
/// Sumcheck proof created over the builder's sumcheck polynomial.
pub sumcheck_proof: SumcheckProof<CP::Scalar>,
/// Evaluations returned by the builder's `evaluate_pcs_proof_mles`; the verifier
/// requires their count to equal `intermediate_mles + anchored_mles`.
pub pcs_proof_evaluations: Vec<CP::Scalar>,
/// Commitment evaluation proof built from the folded MLE and the sumcheck
/// evaluation point.
pub evaluation_proof: CP,
}
impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
/// Builds a proof for `expr` together with the provable query result.
///
/// The steps run in a fixed order because every challenge is drawn from a transcript
/// that absorbs the prover's messages as they are produced: result, then commitments
/// and bit distributions, then the sumcheck proof, then the MLE evaluations, and
/// finally the commitment evaluation proof.
///
/// Returns the proof and the `ProvableQueryResult` it was built for.
///
/// # Panics
///
/// Contains an `assert!` on `num_sumcheck_variables > 0`; the value is clamped to at
/// least 1 on the preceding line.
#[tracing::instrument(name = "QueryProof::new", level = "debug", skip_all)]
pub fn new(
expr: &(impl ProofPlan + Serialize),
accessor: &impl DataAccessor<CP::Scalar>,
setup: &CP::ProverPublicSetup<'_>,
) -> (Self, ProvableQueryResult) {
// Row range covered by the tables the plan references.
let (min_row_num, max_row_num) = get_index_range(accessor, expr.get_table_references());
let range_length = max_row_num - min_row_num;
// Use at least one sumcheck variable, even for a range of length one.
let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
assert!(num_sumcheck_variables > 0);
// Arena handed to the plan's evaluation routines below.
let alloc = Bump::new();
// Evaluate the query result; the output length is the length of the first result
// column, or 0 when there are no result columns.
let result_cols = expr.result_evaluate(range_length, &alloc, accessor);
let output_length = result_cols.first().map_or(0, Column::len);
let provable_result = ProvableQueryResult::new(output_length as u64, &result_cols);
// First prover round.
let mut first_round_builder = FirstRoundBuilder::new();
expr.first_round_evaluate(&mut first_round_builder);
// Start the transcript from the result, the plan, the range length and the first row.
let mut transcript: Keccak256Transcript =
make_transcript(expr, &provable_result, range_length, min_row_num);
// Challenges drawn after the result has been absorbed but before any commitments
// are added; the first-round builder says how many are needed.
let post_result_challenges =
core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
.take(first_round_builder.num_post_result_challenges())
.collect();
// Final prover round.
let mut builder =
FinalRoundBuilder::new(range_length, num_sumcheck_variables, post_result_challenges);
expr.final_round_evaluate(&mut builder, &alloc, accessor);
// Re-read the variable count from the builder; this shadows the value computed above.
// NOTE(review): presumably the final round can change it -- confirm against the builder.
let num_sumcheck_variables = builder.num_sumcheck_variables();
// Commit to the intermediate MLEs and absorb the commitments and bit distributions.
let commitments = builder.commit_intermediate_mles(min_row_num, setup);
extend_transcript(&mut transcript, &commitments, builder.bit_distributions());
// Draw one random scalar per sumcheck variable plus one per sumcheck subpolynomial.
let num_random_scalars = num_sumcheck_variables + builder.num_sumcheck_subpolynomials();
let random_scalars: Vec<_> =
core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
.take(num_random_scalars)
.collect();
let poly = builder.make_sumcheck_polynomial(&SumcheckRandomScalars::new(
&random_scalars,
range_length,
num_sumcheck_variables,
));
// Create the sumcheck proof; `evaluation_point` is passed mutably and is used
// afterwards as the point at which the MLEs are evaluated.
let mut evaluation_point = vec![Zero::zero(); poly.num_variables];
let sumcheck_proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, &poly);
// Build the evaluation vector for that point and evaluate the PCS proof MLEs with it.
let mut evaluation_vec = vec![Zero::zero(); range_length];
compute_evaluation_vector(&mut evaluation_vec, &evaluation_point);
let pcs_proof_evaluations = builder.evaluate_pcs_proof_mles(&evaluation_vec);
// Absorb the evaluations before drawing the scalars used to fold the MLEs.
transcript.extend_canonical_serialize_as_le(&pcs_proof_evaluations);
// One folding scalar per evaluation.
let random_scalars: Vec<_> =
core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
.take(pcs_proof_evaluations.len())
.collect();
let folded_mle = builder.fold_pcs_proof_mles(&random_scalars);
// Commitment evaluation proof for the folded MLE at the sumcheck evaluation point.
let evaluation_proof = CP::new(
&mut transcript,
&folded_mle,
&evaluation_point,
min_row_num as u64,
setup,
);
let proof = Self {
bit_distributions: builder.bit_distributions().to_vec(),
commitments,
sumcheck_proof,
pcs_proof_evaluations,
evaluation_proof,
};
(proof, provable_result)
}
/// Verifies this proof against `expr` and the claimed `result`.
///
/// The transcript is rebuilt in the same order the prover used, so the challenges
/// drawn here line up with the ones the proof was created from.
///
/// On success returns `QueryData` holding the result converted to an owned table and a
/// verification hash drawn from the final transcript state.
///
/// # Errors
///
/// Returns a `ProofError::VerificationError` when
/// - any bit distribution is invalid,
/// - the proof's sizes do not match the counts computed from `expr`,
/// - the verifier's evaluations differ from the evaluations of `result`,
/// - the sumcheck evaluation differs from the subclaim's expected evaluation, or
/// - the batched commitment evaluation proof fails.
///
/// Errors from `expr.count`, the count builder, the sumcheck verification,
/// `result.to_owned_table`, `expr.verifier_evaluate` and `result.evaluate` are
/// propagated with `?`.
///
/// # Panics
///
/// Contains an `assert!` on `num_sumcheck_variables > 0`; the value is clamped to at
/// least 1 on the preceding line.
#[tracing::instrument(name = "QueryProof::verify", level = "debug", skip_all, err)]
pub fn verify(
&self,
expr: &(impl ProofPlan + Serialize),
accessor: &impl CommitmentAccessor<CP::Commitment>,
result: &ProvableQueryResult,
setup: &CP::VerifierPublicSetup<'_>,
) -> QueryResult<CP::Scalar> {
// Row range covered by the tables the plan references.
let (min_row_num, max_row_num) = get_index_range(accessor, expr.get_table_references());
let range_length = max_row_num - min_row_num;
// NOTE(review): the prover (`new`) re-reads this count from its final-round builder,
// while the value here is never updated -- confirm the two always agree.
let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
assert!(num_sumcheck_variables > 0);
let output_length = result.table_length();
// Reject the proof if any bit distribution is invalid.
for dist in &self.bit_distributions {
if !dist.is_valid() {
Err(ProofError::VerificationError {
error: "invalid bit distributions",
})?;
}
}
// Let the plan count the terms it expects the proof to contain.
let counts = {
let mut builder = CountBuilder::new(&self.bit_distributions);
expr.count(&mut builder)?;
builder.counts()
}?;
// The proof's vectors must have exactly the expected lengths.
if !self.validate_sizes(&counts) {
Err(ProofError::VerificationError {
error: "invalid proof size",
})?;
}
// Start the transcript from the result, the plan, the range length and the first row.
let mut transcript: Keccak256Transcript =
make_transcript(expr, result, range_length, min_row_num);
// Challenges drawn after the result has been absorbed but before the commitments.
let post_result_challenges =
core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
.take(counts.post_result_challenges)
.collect();
// Absorb the commitments and bit distributions carried by the proof.
extend_transcript(&mut transcript, &self.commitments, &self.bit_distributions);
// Draw one random scalar per sumcheck variable plus one per sumcheck subpolynomial.
let num_random_scalars = num_sumcheck_variables + counts.sumcheck_subpolynomials;
let random_scalars: Vec<_> =
core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
.take(num_random_scalars)
.collect();
let sumcheck_random_scalars =
SumcheckRandomScalars::new(&random_scalars, range_length, num_sumcheck_variables);
// Verify sumcheck up to the final evaluation check; `max_multiplicands` is
// clamped to at least 2.
let poly_info = CompositePolynomialInfo {
max_multiplicands: core::cmp::max(counts.sumcheck_max_multiplicands, 2),
num_variables: num_sumcheck_variables,
};
let subclaim = self.sumcheck_proof.verify_without_evaluation(
&mut transcript,
poly_info,
&Zero::zero(),
)?;
// Absorb the claimed MLE evaluations before drawing the folding scalars.
transcript.extend_canonical_serialize_as_le(&self.pcs_proof_evaluations);
// One folding scalar per evaluation.
let evaluation_random_scalars: Vec<_> =
core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
.take(self.pcs_proof_evaluations.len())
.collect();
let column_result_fields = expr.get_column_result_fields();
// Set up the verification builder that the plan fills in below.
let sumcheck_evaluations = SumcheckMleEvaluations::new(
range_length,
output_length,
&subclaim.evaluation_point,
&sumcheck_random_scalars,
&self.pcs_proof_evaluations,
);
let mut builder = VerificationBuilder::new(
min_row_num,
sumcheck_evaluations,
&self.bit_distributions,
&self.commitments,
sumcheck_random_scalars.subpolynomial_multipliers,
&evaluation_random_scalars,
post_result_challenges,
);
// Convert the claimed result to an owned table and let the plan evaluate against it.
let owned_table_result = result.to_owned_table(&column_result_fields[..])?;
let verifier_evaluations =
expr.verifier_evaluate(&mut builder, accessor, Some(&owned_table_result))?;
// Evaluate the claimed result at the sumcheck evaluation point.
let result_evaluations = result.evaluate(
&subclaim.evaluation_point,
output_length,
&column_result_fields[..],
)?;
// The plan's evaluations must match those of the claimed result.
if verifier_evaluations != result_evaluations {
Err(ProofError::VerificationError {
error: "result evaluation check failed",
})?;
}
// Final evaluation check of the sumcheck polynomial.
if builder.sumcheck_evaluation() != subclaim.expected_evaluation {
Err(ProofError::VerificationError {
error: "sumcheck evaluation check failed",
})?;
}
// Check the folded MLE evaluation with the batched commitment evaluation proof.
let product = builder.folded_pcs_proof_evaluation();
self.evaluation_proof
.verify_batched_proof(
&mut transcript,
builder.pcs_proof_commitments(),
builder.inner_product_multipliers(),
&product,
&subclaim.evaluation_point,
min_row_num as u64,
range_length,
setup,
)
.map_err(|_e| ProofError::VerificationError {
error: "Inner product proof of MLE evaluations failed",
})?;
// Draw a final challenge from the transcript to serve as the verification hash.
let verification_hash = transcript.challenge_as_le();
Ok(QueryData {
table: owned_table_result,
verification_hash,
})
}
/// Returns `true` when the proof's vector lengths match `counts`.
///
/// There must be exactly `counts.intermediate_mles` commitments and exactly
/// `counts.intermediate_mles + counts.anchored_mles` PCS proof evaluations.
fn validate_sizes(&self, counts: &ProofCounts) -> bool {
    // Bail out before looking at the evaluations if the commitment count is already wrong.
    if self.commitments.len() != counts.intermediate_mles {
        return false;
    }
    let expected_evaluations = counts.intermediate_mles + counts.anchored_mles;
    self.pcs_proof_evaluations.len() == expected_evaluations
}
}
/// Creates a fresh transcript of type `T` and seeds it with the public inputs of a proof.
///
/// The values are absorbed in this exact order: `result`, `expr`, `range_length`,
/// `min_row_num`. The prover and the verifier both call this function, so the order
/// must not change.
fn make_transcript<T: Transcript>(
    expr: &(impl ProofPlan + Serialize),
    result: &ProvableQueryResult,
    range_length: usize,
    min_row_num: usize,
) -> T {
    let mut seeded = T::new();
    // The claimed result goes in first, followed by the plan itself.
    seeded.extend_serialize_as_le(result);
    seeded.extend_serialize_as_le(expr);
    // Then the row range the proof is about: its length and its first row.
    seeded.extend_serialize_as_le(&range_length);
    seeded.extend_serialize_as_le(&min_row_num);
    seeded
}
/// Appends the serialized `commitments` and then the serialized `bit_distributions`
/// to `transcript`.
///
/// Called by both the prover and the verifier, so the order of the two appends must
/// stay the same on both sides.
fn extend_transcript<C: serde::Serialize>(
transcript: &mut impl Transcript,
commitments: &C,
bit_distributions: &[BitDistribution],
) {
// Commitments first, bit distributions second.
transcript.extend_serialize_as_le(commitments);
transcript.extend_serialize_as_le(bit_distributions);
}