use super::{
CountBuilder, ProofBuilder, ProofCounts, ProofPlan, ProvableQueryResult, QueryResult,
SumcheckMleEvaluations, SumcheckRandomScalars, VerificationBuilder,
};
use crate::{
base::{
bit::BitDistribution,
commitment::{Commitment, CommitmentEvaluationProof},
database::{CommitmentAccessor, DataAccessor},
math::log2_up,
polynomial::{compute_evaluation_vector, CompositePolynomialInfo},
proof::{Keccak256Transcript, ProofError, Transcript},
},
proof_primitive::sumcheck::SumcheckProof,
sql::proof::{QueryData, ResultBuilder},
};
use alloc::{vec, vec::Vec};
use bumpalo::Bump;
use core::cmp;
use num_traits::Zero;
use serde::{Deserialize, Serialize};
/// A proof that a query plan was evaluated correctly over committed data.
///
/// Produced by [`QueryProof::new`] and checked by [`QueryProof::verify`].
///
/// NOTE: serde's derive serializes fields in declaration order, so reordering or renaming
/// the fields below changes the serialized format of the proof.
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProof<CP: CommitmentEvaluationProof> {
    /// Bit distributions reported by the prover.
    ///
    /// Each one is checked with `is_valid` during verification. They are also absorbed
    /// into the transcript alongside `commitments`.
    pub bit_distributions: Vec<BitDistribution>,
    /// Commitments to the intermediate MLEs created while proving.
    pub commitments: Vec<CP::Commitment>,
    /// The sumcheck proof, verified for a claimed sum of zero.
    pub sumcheck_proof: SumcheckProof<CP::Scalar>,
    /// Evaluations of the PCS-proof MLEs at the sumcheck evaluation point.
    ///
    /// There is one entry per intermediate MLE plus one per anchored MLE.
    pub pcs_proof_evaluations: Vec<CP::Scalar>,
    /// Batched evaluation proof for the folded PCS-proof MLEs at the sumcheck point.
    pub evaluation_proof: CP,
}
impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
    /// Create a proof that `expr`, evaluated over the data in `accessor`, yields the
    /// returned [`ProvableQueryResult`].
    ///
    /// The transcript operations here mirror those in [`Self::verify`] one-for-one.
    /// Any change to the order of absorbs or challenge draws must be made in both places.
    #[tracing::instrument(name = "QueryProof::new", level = "debug", skip_all)]
    pub fn new(
        expr: &(impl ProofPlan<CP::Commitment> + Serialize),
        accessor: &impl DataAccessor<CP::Scalar>,
        setup: &CP::ProverPublicSetup<'_>,
    ) -> (Self, ProvableQueryResult) {
        let table_length = expr.get_length(accessor);
        // Clamp to at least one sumcheck variable, even when log2_up(table_length) is 0.
        let num_sumcheck_variables = cmp::max(log2_up(table_length), 1);
        let generator_offset = expr.get_offset(accessor);
        // Always true because of the clamp above; kept as a sanity check.
        assert!(num_sumcheck_variables > 0);
        // Bump arena shared by `result_evaluate` and `prover_evaluate` below.
        let alloc = Bump::new();

        // Evaluate the query to produce the result columns the proof will attest to.
        let mut result_builder = ResultBuilder::new(table_length);
        let result_cols = expr.result_evaluate(&mut result_builder, &alloc, accessor);
        let provable_result =
            ProvableQueryResult::new(&result_builder.result_index_vector, &result_cols);

        // Seed the transcript with the result, plan, table length and generator offset
        // before any challenge is drawn.
        let mut transcript: Keccak256Transcript =
            make_transcript(expr, &provable_result, table_length, generator_offset);

        // Challenges drawn after the result is bound but before the intermediate
        // commitments are absorbed.
        let post_result_challenges =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(result_builder.num_post_result_challenges())
                .collect();

        let mut builder =
            ProofBuilder::new(table_length, num_sumcheck_variables, post_result_challenges);
        expr.prover_evaluate(&mut builder, &alloc, accessor);

        // From here on, read the sizes back from the builder.
        let num_sumcheck_variables = builder.num_sumcheck_variables();
        let table_length = builder.table_length();

        // Commit to the intermediate MLEs, then absorb those commitments and the bit
        // distributions (same order as in `verify`).
        let commitments = builder.commit_intermediate_mles(generator_offset, setup);
        extend_transcript(&mut transcript, &commitments, builder.bit_distributions());

        // One random scalar per sumcheck variable plus one per sumcheck subpolynomial.
        let num_random_scalars = num_sumcheck_variables + builder.num_sumcheck_subpolynomials();
        let random_scalars: Vec<_> =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(num_random_scalars)
                .collect();
        let poly = builder.make_sumcheck_polynomial(&SumcheckRandomScalars::new(
            &random_scalars,
            table_length,
            num_sumcheck_variables,
        ));

        // Run sumcheck. `evaluation_point` is passed mutably; it is presumably filled
        // with the sumcheck challenges by `SumcheckProof::create` (confirm).
        let mut evaluation_point = vec![Zero::zero(); poly.num_variables];
        let sumcheck_proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, &poly);

        // Evaluate every PCS-proof MLE at the sumcheck evaluation point.
        let mut evaluation_vec = vec![Zero::zero(); table_length];
        compute_evaluation_vector(&mut evaluation_vec, &evaluation_point);
        let pcs_proof_evaluations = builder.evaluate_pcs_proof_mles(&evaluation_vec);

        // Absorb the evaluations, then draw one folding scalar per evaluation.
        transcript.extend_canonical_serialize_as_le(&pcs_proof_evaluations);
        let random_scalars: Vec<_> =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(pcs_proof_evaluations.len())
                .collect();

        // Fold the PCS-proof MLEs with the scalars and prove the folded MLE's
        // evaluation at `evaluation_point`.
        let folded_mle = builder.fold_pcs_proof_mles(&random_scalars);
        let evaluation_proof = CP::new(
            &mut transcript,
            &folded_mle,
            &evaluation_point,
            generator_offset as u64,
            setup,
        );
        let proof = Self {
            bit_distributions: builder.bit_distributions().to_vec(),
            commitments,
            sumcheck_proof,
            pcs_proof_evaluations,
            evaluation_proof,
        };
        (proof, provable_result)
    }

    /// Verify this proof against the plan `expr`, the commitments in `accessor` and the
    /// claimed `result`.
    ///
    /// On success, returns the result as an owned table together with a verification
    /// hash squeezed from the final transcript state.
    ///
    /// # Errors
    /// Returns an error if:
    /// - any bit distribution is invalid,
    /// - the proof or result sizes do not match the counts derived from `expr`,
    /// - the sumcheck evaluation check fails, or
    /// - the batched evaluation proof fails.
    ///
    /// Those errors are raised as `ProofError::VerificationError` and converted with
    /// `?`. Errors from counting, sumcheck verification, result evaluation, table
    /// conversion and `verifier_evaluate` are propagated as-is.
    #[tracing::instrument(name = "QueryProof::verify", level = "debug", skip_all, err)]
    pub fn verify(
        &self,
        expr: &(impl ProofPlan<CP::Commitment> + Serialize),
        accessor: &impl CommitmentAccessor<CP::Commitment>,
        result: &ProvableQueryResult,
        setup: &CP::VerifierPublicSetup<'_>,
    ) -> QueryResult<CP::Scalar> {
        let table_length = expr.get_length(accessor);
        let generator_offset = expr.get_offset(accessor);
        // Must match the clamp used by the prover in `new`.
        let num_sumcheck_variables = cmp::max(log2_up(table_length), 1);
        assert!(num_sumcheck_variables > 0);

        // Reject malformed bit distributions before using them anywhere.
        for dist in self.bit_distributions.iter() {
            if !dist.is_valid() {
                Err(ProofError::VerificationError {
                    error: "invalid bit distributions",
                })?;
            }
        }

        // Count the terms the plan expects, given the claimed bit distributions.
        let counts = {
            let mut builder = CountBuilder::new(&self.bit_distributions);
            expr.count(&mut builder, accessor)?;
            builder.counts()
        }?;

        // Check that the proof and result sizes agree with those counts before any
        // cryptographic work.
        if !self.validate_sizes(&counts, result) {
            Err(ProofError::VerificationError {
                error: "invalid proof size",
            })?;
        }

        // Rebuild the transcript exactly as the prover did in `new`.
        let mut transcript: Keccak256Transcript =
            make_transcript(expr, result, table_length, generator_offset);

        // Challenges drawn after the result is bound but before the intermediate
        // commitments are absorbed. NOTE(review): the order in which these are consumed
        // is decided by `VerificationBuilder`, which is not visible here.
        let post_result_challenges =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(counts.post_result_challenges)
                .collect();

        // Absorb the intermediate commitments and bit distributions.
        extend_transcript(&mut transcript, &self.commitments, &self.bit_distributions);

        // One random scalar per sumcheck variable plus one per sumcheck subpolynomial.
        let num_random_scalars = num_sumcheck_variables + counts.sumcheck_subpolynomials;
        let random_scalars: Vec<_> =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(num_random_scalars)
                .collect();
        let sumcheck_random_scalars =
            SumcheckRandomScalars::new(&random_scalars, table_length, num_sumcheck_variables);

        // `max_multiplicands` is clamped to at least 2. NOTE(review): presumably because
        // the prover's composite polynomial always contains a degree-2 term; confirm.
        let poly_info = CompositePolynomialInfo {
            max_multiplicands: core::cmp::max(counts.sumcheck_max_multiplicands, 2),
            num_variables: num_sumcheck_variables,
        };
        // Verify sumcheck for a claimed sum of zero. The final evaluation check is
        // done further below against the builder's sumcheck evaluation.
        let subclaim = self.sumcheck_proof.verify_without_evaluation(
            &mut transcript,
            poly_info,
            &Zero::zero(),
        )?;

        // Absorb the claimed MLE evaluations, then draw one folding scalar per evaluation.
        transcript.extend_canonical_serialize_as_le(&self.pcs_proof_evaluations);
        let evaluation_random_scalars: Vec<_> =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(self.pcs_proof_evaluations.len())
                .collect();

        let column_result_fields = expr.get_column_result_fields();

        // Evaluate the claimed result columns at the sumcheck evaluation point.
        let result_evaluations = result.evaluate(
            &subclaim.evaluation_point,
            table_length,
            &column_result_fields[..],
        )?;
        let sumcheck_evaluations = SumcheckMleEvaluations::new(
            table_length,
            &subclaim.evaluation_point,
            &sumcheck_random_scalars,
            &self.pcs_proof_evaluations,
            &result_evaluations,
            result.indexes(),
        );
        let mut builder = VerificationBuilder::new(
            generator_offset,
            sumcheck_evaluations,
            &self.bit_distributions,
            &self.commitments,
            sumcheck_random_scalars.subpolynomial_multipliers,
            &evaluation_random_scalars,
            post_result_challenges,
        );
        // Walk the plan on the verifier side to fill in the verification builder.
        let owned_table_result = result.to_owned_table(&column_result_fields[..])?;
        expr.verifier_evaluate(&mut builder, accessor, Some(&owned_table_result))?;

        // Evaluation check: the plan's recomputed sumcheck value must match the subclaim.
        if builder.sumcheck_evaluation() != subclaim.expected_evaluation {
            Err(ProofError::VerificationError {
                error: "sumcheck evaluation check failed",
            })?;
        }

        // Finally, check the folded MLE evaluation against the commitments with the
        // batched evaluation proof.
        let product = builder.folded_pcs_proof_evaluation();
        self.evaluation_proof
            .verify_batched_proof(
                &mut transcript,
                builder.pcs_proof_commitments(),
                builder.inner_product_multipliers(),
                &product,
                &subclaim.evaluation_point,
                generator_offset as u64,
                table_length,
                setup,
            )
            .map_err(|_e| ProofError::VerificationError {
                error: "Inner product proof of MLE evaluations failed",
            })?;

        // Squeeze a hash from the final transcript state and return it with the result.
        let verification_hash = transcript.challenge_as_le();
        Ok(QueryData {
            table: owned_table_result,
            verification_hash,
        })
    }

    /// Return `true` when the proof's sizes are consistent with `counts`:
    /// - the result has `counts.result_columns` columns,
    /// - there is one commitment per intermediate MLE, and
    /// - there is one PCS-proof evaluation per intermediate MLE plus one per anchored MLE.
    fn validate_sizes(&self, counts: &ProofCounts, result: &ProvableQueryResult) -> bool {
        result.num_columns() == counts.result_columns
            && self.commitments.len() == counts.intermediate_mles
            && self.pcs_proof_evaluations.len() == counts.intermediate_mles + counts.anchored_mles
    }
}
/// Start a new transcript seeded with the public inputs of a query proof.
///
/// The values are absorbed in this order:
/// 1. the query result,
/// 2. the serialized plan,
/// 3. the table length,
/// 4. the generator offset.
///
/// `QueryProof::new` and `QueryProof::verify` both start from this function, so they
/// begin from identical transcript states.
fn make_transcript<C, T>(
    expr: &(impl ProofPlan<C> + Serialize),
    result: &ProvableQueryResult,
    table_length: usize,
    generator_offset: usize,
) -> T
where
    C: Commitment,
    T: Transcript,
{
    let mut seeded = T::new();
    // Absorption order is part of the protocol; do not reorder these calls.
    seeded.extend_serialize_as_le(result);
    seeded.extend_serialize_as_le(expr);
    seeded.extend_serialize_as_le(&table_length);
    seeded.extend_serialize_as_le(&generator_offset);
    seeded
}
/// Absorb the prover's intermediate commitments into `transcript`, followed by the bit
/// distributions.
///
/// Both `QueryProof::new` and `QueryProof::verify` call this at the same point in the
/// protocol, so the two sides stay in lockstep.
fn extend_transcript<C>(
    transcript: &mut impl Transcript,
    commitments: &C,
    bit_distributions: &[BitDistribution],
) where
    C: serde::Serialize,
{
    // Commitments first, bit distributions second; this order is part of the protocol.
    transcript.extend_serialize_as_le(commitments);
    transcript.extend_serialize_as_le(bit_distributions);
}