use super::{
CountBuilder, FinalRoundBuilder, ProofCounts, ProofPlan, ProvableQueryResult, QueryResult,
SumcheckMleEvaluations, SumcheckRandomScalars, VerificationBuilder,
};
use crate::{
base::{
bit::BitDistribution,
commitment::CommitmentEvaluationProof,
database::{
ColumnRef, CommitmentAccessor, DataAccessor, MetadataAccessor, Table, TableRef,
},
map::{IndexMap, IndexSet},
math::log2_up,
polynomial::{compute_evaluation_vector, CompositePolynomialInfo},
proof::{Keccak256Transcript, ProofError, Transcript},
},
proof_primitive::sumcheck::SumcheckProof,
sql::proof::{FirstRoundBuilder, QueryData},
};
use alloc::{vec, vec::Vec};
use bumpalo::Bump;
use core::cmp;
use num_traits::Zero;
use serde::{Deserialize, Serialize};
/// Compute the combined row-index range `(start, end)` covered by a set of tables.
///
/// For each table we take `[offset, offset + length)` from the accessor and
/// merge the ranges into a single span: the minimum start and the maximum end.
/// With no tables at all, the fallback range `(0, 1)` is returned so callers
/// always have a non-empty span to work with.
fn get_index_range<'a>(
    accessor: &dyn MetadataAccessor,
    table_refs: impl IntoIterator<Item = &'a TableRef>,
) -> (usize, usize) {
    // Accumulate the merged span imperatively; `None` means "no table seen yet".
    let mut span: Option<(usize, usize)> = None;
    for table_ref in table_refs {
        let start = accessor.get_offset(*table_ref);
        let end = start + accessor.get_length(*table_ref);
        span = Some(match span {
            None => (start, end),
            Some((lo, hi)) => (lo.min(start), hi.max(end)),
        });
    }
    span.unwrap_or((0, 1))
}
/// The proof that a query result is correct, together with everything the
/// verifier needs to replay the Fiat–Shamir transcript and check it.
#[derive(Clone, Serialize, Deserialize)]
pub struct QueryProof<CP: CommitmentEvaluationProof> {
    /// Bit distributions produced during the final proving round; `verify`
    /// rejects the proof if any of them fails `BitDistribution::is_valid`.
    pub bit_distributions: Vec<BitDistribution>,
    /// Lengths reported by the plan's `result_evaluate`; absorbed into the
    /// transcript and used to build the verifier's sumcheck MLE evaluations.
    pub one_evaluation_lengths: Vec<usize>,
    /// Commitments to the intermediate MLEs created by the final-round builder.
    pub commitments: Vec<CP::Commitment>,
    /// The sumcheck proof for the query's aggregated subpolynomial identity.
    pub sumcheck_proof: SumcheckProof<CP::Scalar>,
    /// Evaluations of the PCS-proof MLEs at the sumcheck evaluation point.
    pub pcs_proof_evaluations: Vec<CP::Scalar>,
    /// Batched commitment-evaluation proof for the folded MLE.
    pub evaluation_proof: CP,
    /// The length of the evaluation range: the maximum of the table index
    /// range length and every entry of `one_evaluation_lengths`.
    pub range_length: usize,
}
impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
    /// Create a proof that `expr` evaluated against the data in `accessor`
    /// yields the returned `ProvableQueryResult`.
    ///
    /// The prover and `verify` must perform the exact same sequence of
    /// transcript absorptions and challenge draws; any reordering here must be
    /// mirrored there, or the derived challenges will disagree.
    #[tracing::instrument(name = "QueryProof::new", level = "debug", skip_all)]
    pub fn new(
        expr: &(impl ProofPlan + Serialize),
        accessor: &impl DataAccessor<CP::Scalar>,
        setup: &CP::ProverPublicSetup<'_>,
    ) -> (Self, ProvableQueryResult) {
        // The row span covered by all referenced tables; `min_row_num` is also
        // used later as the generator offset for commitments and the PCS proof.
        let (min_row_num, max_row_num) = get_index_range(accessor, &expr.get_table_references());
        let initial_range_length = max_row_num - min_row_num;
        let alloc = Bump::new();
        let total_col_refs = expr.get_column_references();
        // Fetch each referenced table, restricted to just the columns the plan needs.
        let table_map: IndexMap<TableRef, Table<CP::Scalar>> = expr
            .get_table_references()
            .into_iter()
            .map(|table_ref| {
                let col_refs: IndexSet<ColumnRef> = total_col_refs
                    .iter()
                    .filter(|col_ref| col_ref.table_ref() == table_ref)
                    .copied()
                    .collect();
                (table_ref, accessor.get_table(table_ref, &col_refs))
            })
            .collect();
        // Evaluate the plan to get the query result plus the per-column
        // "one evaluation" lengths that feed the transcript and sumcheck.
        let (query_result, one_evaluation_lengths) = expr.result_evaluate(&alloc, &table_map);
        let provable_result = query_result.into();
        let mut first_round_builder = FirstRoundBuilder::new();
        expr.first_round_evaluate(&mut first_round_builder);
        // The sumcheck range must cover both the table index range and every
        // result length; the chained `initial_range_length` guarantees the
        // iterator is non-empty, so `expect` cannot fire.
        let range_length = one_evaluation_lengths
            .iter()
            .copied()
            .chain(core::iter::once(initial_range_length))
            .max()
            .expect("Will always have at least one element");
        let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
        assert!(num_sumcheck_variables > 0);
        // Seed the Fiat–Shamir transcript with all public inputs; `verify`
        // builds the identical transcript from the proof and result.
        let mut transcript: Keccak256Transcript = make_transcript(
            expr,
            &provable_result,
            range_length,
            min_row_num,
            &one_evaluation_lengths,
        );
        // One scalar challenge per post-result challenge the plan requested.
        let post_result_challenges =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(first_round_builder.num_post_result_challenges())
                .collect();
        let mut builder = FinalRoundBuilder::new(num_sumcheck_variables, post_result_challenges);
        // Register every referenced column as an anchored MLE, in the order of
        // `get_column_references`; `verify` consumes them in the same order.
        for col_ref in total_col_refs {
            builder.produce_anchored_mle(accessor.get_column(col_ref));
        }
        expr.final_round_evaluate(&mut builder, &alloc, &table_map);
        let num_sumcheck_variables = builder.num_sumcheck_variables();
        // Commit to the intermediate MLEs and absorb the commitments (and bit
        // distributions) before drawing the sumcheck randomness.
        let commitments = builder.commit_intermediate_mles(min_row_num, setup);
        extend_transcript(&mut transcript, &commitments, builder.bit_distributions());
        let num_random_scalars = num_sumcheck_variables + builder.num_sumcheck_subpolynomials();
        let random_scalars: Vec<_> =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(num_random_scalars)
                .collect();
        // Build and prove the aggregated sumcheck polynomial; `create` fills
        // in the evaluation point chosen during the sumcheck rounds.
        let poly = builder.make_sumcheck_polynomial(&SumcheckRandomScalars::new(
            &random_scalars,
            range_length,
            num_sumcheck_variables,
        ));
        let mut evaluation_point = vec![Zero::zero(); poly.num_variables];
        let sumcheck_proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, &poly);
        // Expand the evaluation point into a vector used to evaluate each
        // PCS-proof MLE at that point.
        let mut evaluation_vec = vec![Zero::zero(); range_length];
        compute_evaluation_vector(&mut evaluation_vec, &evaluation_point);
        let pcs_proof_evaluations = builder.evaluate_pcs_proof_mles(&evaluation_vec);
        transcript.extend_canonical_serialize_as_le(&pcs_proof_evaluations);
        // Batching scalars: one per claimed MLE evaluation.
        let random_scalars: Vec<_> =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(pcs_proof_evaluations.len())
                .collect();
        assert_eq!(random_scalars.len(), builder.pcs_proof_mles().len());
        // Fold all PCS-proof MLEs into a single MLE via a random linear
        // combination, then prove its evaluation in one batched proof.
        let mut folded_mle = vec![Zero::zero(); range_length];
        for (multiplier, evaluator) in random_scalars.iter().zip(builder.pcs_proof_mles().iter()) {
            evaluator.mul_add(&mut folded_mle, multiplier);
        }
        let evaluation_proof = CP::new(
            &mut transcript,
            &folded_mle,
            &evaluation_point,
            min_row_num as u64,
            setup,
        );
        let proof = Self {
            bit_distributions: builder.bit_distributions().to_vec(),
            one_evaluation_lengths,
            commitments,
            sumcheck_proof,
            pcs_proof_evaluations,
            evaluation_proof,
            range_length,
        };
        (proof, provable_result)
    }

    /// Verify this proof against the plan, the committed data, and the claimed
    /// result, returning the decoded table and a verification hash on success.
    ///
    /// Replays the prover's transcript step-for-step; every challenge drawn in
    /// `new` is re-derived here in the same order.
    #[tracing::instrument(name = "QueryProof::verify", level = "debug", skip_all, err)]
    pub fn verify(
        &self,
        expr: &(impl ProofPlan + Serialize),
        accessor: &impl CommitmentAccessor<CP::Commitment>,
        result: &ProvableQueryResult,
        setup: &CP::VerifierPublicSetup<'_>,
    ) -> QueryResult<CP::Scalar> {
        // Decode the claimed result into a table (fails on malformed results).
        let owned_table_result = result.to_owned_table(&expr.get_column_result_fields())?;
        let table_refs = expr.get_table_references();
        let (min_row_num, _) = get_index_range(accessor, &table_refs);
        // Must match the prover's computation from `self.range_length`.
        let num_sumcheck_variables = cmp::max(log2_up(self.range_length), 1);
        assert!(num_sumcheck_variables > 0);
        // Reject structurally invalid bit distributions up front.
        for dist in &self.bit_distributions {
            if !dist.is_valid() {
                Err(ProofError::VerificationError {
                    error: "invalid bit distributions",
                })?;
            }
        }
        let column_references = expr.get_column_references();
        // Count the expected proof components for this plan and check the
        // proof's vectors have matching sizes.
        let mut builder = CountBuilder::new(&self.bit_distributions);
        builder.count_anchored_mles(column_references.len());
        expr.count(&mut builder)?;
        let counts = builder.counts()?;
        if !self.validate_sizes(&counts) {
            Err(ProofError::VerificationError {
                error: "invalid proof size",
            })?;
        }
        // Rebuild the prover's transcript from the same public inputs.
        let mut transcript: Keccak256Transcript = make_transcript(
            expr,
            result,
            self.range_length,
            min_row_num,
            &self.one_evaluation_lengths,
        );
        let post_result_challenges =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(counts.post_result_challenges)
                .collect();
        extend_transcript(&mut transcript, &self.commitments, &self.bit_distributions);
        let num_random_scalars = num_sumcheck_variables + counts.sumcheck_subpolynomials;
        let random_scalars: Vec<_> =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(num_random_scalars)
                .collect();
        let sumcheck_random_scalars =
            SumcheckRandomScalars::new(&random_scalars, self.range_length, num_sumcheck_variables);
        // Verify the sumcheck proof against a claimed sum of zero.
        let poly_info = CompositePolynomialInfo {
            max_multiplicands: core::cmp::max(counts.sumcheck_max_multiplicands, 2),
            num_variables: num_sumcheck_variables,
        };
        let subclaim = self.sumcheck_proof.verify_without_evaluation(
            &mut transcript,
            poly_info,
            &Zero::zero(),
        )?;
        transcript.extend_canonical_serialize_as_le(&self.pcs_proof_evaluations);
        // Re-derive the batching scalars the prover used to fold the MLEs.
        let evaluation_random_scalars: Vec<_> =
            core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
                .take(self.pcs_proof_evaluations.len())
                .collect();
        let table_length_map = table_refs
            .iter()
            .map(|table_ref| (table_ref, accessor.get_length(*table_ref)))
            .collect::<IndexMap<_, _>>();
        // Table lengths come first, then the proof's one_evaluation_lengths.
        let one_evaluation_lengths = table_length_map
            .values()
            .chain(self.one_evaluation_lengths.iter())
            .copied();
        let sumcheck_evaluations = SumcheckMleEvaluations::new(
            self.range_length,
            one_evaluation_lengths,
            &subclaim.evaluation_point,
            &sumcheck_random_scalars,
            &self.pcs_proof_evaluations,
        );
        // Look up each table's "one evaluation" by its length.
        let one_eval_map: IndexMap<TableRef, CP::Scalar> = table_length_map
            .iter()
            .map(|(table_ref, length)| (**table_ref, sumcheck_evaluations.one_evaluations[length]))
            .collect();
        let mut builder = VerificationBuilder::new(
            min_row_num,
            sumcheck_evaluations,
            &self.bit_distributions,
            sumcheck_random_scalars.subpolynomial_multipliers,
            &evaluation_random_scalars,
            post_result_challenges,
            self.one_evaluation_lengths.clone(),
        );
        // Anchored column commitments first, then the intermediate-MLE
        // commitments — the same order the prover folded the MLEs in.
        let pcs_proof_commitments: Vec<_> = column_references
            .iter()
            .map(|col| accessor.get_commitment(*col))
            .chain(self.commitments.iter().cloned())
            .collect();
        // Consume anchored MLE evaluations in `get_column_references` order,
        // mirroring the prover's `produce_anchored_mle` loop.
        let evaluation_accessor: IndexMap<_, _> = column_references
            .into_iter()
            .map(|col| (col, builder.consume_anchored_mle()))
            .collect();
        let verifier_evaluations = expr.verifier_evaluate(
            &mut builder,
            &evaluation_accessor,
            Some(&owned_table_result),
            &one_eval_map,
        )?;
        // The claimed result table must evaluate to the same MLE values the
        // plan's verifier derived.
        let result_evaluations = owned_table_result.mle_evaluations(&subclaim.evaluation_point);
        if verifier_evaluations.column_evals() != result_evaluations {
            Err(ProofError::VerificationError {
                error: "result evaluation check failed",
            })?;
        }
        if builder.sumcheck_evaluation() != subclaim.expected_evaluation {
            Err(ProofError::VerificationError {
                error: "sumcheck evaluation check failed",
            })?;
        }
        // Finally, check the batched commitment-evaluation proof for the
        // folded MLE at the sumcheck evaluation point.
        let product = builder.folded_pcs_proof_evaluation();
        self.evaluation_proof
            .verify_batched_proof(
                &mut transcript,
                &pcs_proof_commitments,
                builder.inner_product_multipliers(),
                &product,
                &subclaim.evaluation_point,
                min_row_num as u64,
                self.range_length,
                setup,
            )
            .map_err(|_e| ProofError::VerificationError {
                error: "Inner product proof of MLE evaluations failed",
            })?;
        let verification_hash = transcript.challenge_as_le();
        Ok(QueryData {
            table: owned_table_result,
            verification_hash,
        })
    }

    /// Check that the proof's vector lengths match the counts expected for the
    /// plan; a mismatch means the proof was built for a different plan shape.
    fn validate_sizes(&self, counts: &ProofCounts) -> bool {
        self.commitments.len() == counts.intermediate_mles
            && self.pcs_proof_evaluations.len() == counts.intermediate_mles + counts.anchored_mles
    }
}
/// Build a fresh Fiat–Shamir transcript seeded with the query's public inputs.
///
/// Both the prover and the verifier call this with the same values, so the
/// absorption order below is part of the protocol: result, plan, range length,
/// minimum row number, then the one-evaluation lengths. Changing the order
/// would change every challenge derived downstream.
fn make_transcript<T: Transcript>(
    expr: &(impl ProofPlan + Serialize),
    result: &ProvableQueryResult,
    range_length: usize,
    min_row_num: usize,
    one_evaluation_lengths: &[usize],
) -> T {
    let mut t = T::new();
    t.extend_serialize_as_le(result);
    t.extend_serialize_as_le(expr);
    t.extend_serialize_as_le(&range_length);
    t.extend_serialize_as_le(&min_row_num);
    t.extend_serialize_as_le(one_evaluation_lengths);
    t
}
/// Absorb the intermediate-MLE commitments and the bit distributions into the
/// transcript, in that fixed order.
///
/// Called by both the prover (after committing) and the verifier (replaying);
/// the append order is part of the Fiat–Shamir contract and must not change.
fn extend_transcript<C: serde::Serialize>(
    transcript: &mut impl Transcript,
    commitments: &C,
    bit_distributions: &[BitDistribution],
) {
    transcript.extend_serialize_as_le(commitments);
    transcript.extend_serialize_as_le(bit_distributions);
}