use crate::VerifierError;
use air::{
proof::{Queries, StarkProof, Table},
Air, EvaluationFrame,
};
use crypto::{BatchMerkleProof, ElementHasher, MerkleTree};
use fri::VerifierChannel as FriVerifierChannel;
use math::{FieldElement, StarkField};
use utils::{collections::Vec, string::ToString};
/// Verifier-side view of a parsed STARK proof.
///
/// Most proof components are held in `Option`s and handed out exactly once via `take()`; the
/// FRI layer proofs and queries are handed out one at a time from the front of their vectors.
pub struct VerifierChannel<E, H>
where
    E: FieldElement,
    H: ElementHasher<BaseField = E::BaseField>,
{
    /// Commitments (Merkle roots) to the trace segments, one per segment.
    trace_roots: Vec<H::Digest>,
    /// Queried trace states with their authentication paths; `None` once read.
    trace_queries: Option<TraceQueries<E, H>>,
    /// Commitment (Merkle root) to the constraint evaluations.
    constraint_root: H::Digest,
    /// Queried constraint evaluations with their authentication path; `None` once read.
    constraint_queries: Option<ConstraintQueries<E, H>>,
    /// Commitments to the FRI layers; `None` once read.
    fri_roots: Option<Vec<H::Digest>>,
    /// Remaining per-layer FRI batch Merkle proofs, consumed from the front.
    fri_layer_proofs: Vec<BatchMerkleProof<H>>,
    /// Remaining per-layer FRI query values, consumed from the front.
    fri_layer_queries: Vec<Vec<E>>,
    /// FRI remainder values; `None` once read.
    fri_remainder: Option<Vec<E>>,
    /// Number of partitions reported by the FRI proof.
    fri_num_partitions: usize,
    /// Out-of-domain trace frame(s); `None` once read.
    ood_trace_frame: Option<TraceOodFrame<E>>,
    /// Out-of-domain constraint evaluations; `None` once read.
    ood_constraint_evaluations: Option<Vec<E>>,
    /// Proof-of-work nonce included in the proof.
    pow_nonce: u64,
}
impl<E: FieldElement, H: ElementHasher<BaseField = E::BaseField>> VerifierChannel<E, H> {
pub fn new<A: Air<BaseField = E::BaseField>>(
air: &A,
proof: StarkProof,
) -> Result<Self, VerifierError> {
let StarkProof {
context,
commitments,
trace_queries,
constraint_queries,
ood_frame,
fri_proof,
pow_nonce,
} = proof;
if E::BaseField::get_modulus_le_bytes() != context.field_modulus_bytes() {
return Err(VerifierError::InconsistentBaseField);
}
let num_trace_segments = air.trace_layout().num_segments();
let main_trace_width = air.trace_layout().main_trace_width();
let aux_trace_width = air.trace_layout().aux_trace_width();
let lde_domain_size = air.lde_domain_size();
let fri_options = air.options().to_fri_options();
let (trace_roots, constraint_root, fri_roots) = commitments
.parse::<H>(
num_trace_segments,
fri_options.num_fri_layers(lde_domain_size),
)
.map_err(|err| VerifierError::ProofDeserializationError(err.to_string()))?;
let trace_queries = TraceQueries::new(trace_queries, air)?;
let constraint_queries = ConstraintQueries::new(constraint_queries, air)?;
let fri_num_partitions = fri_proof.num_partitions();
let fri_remainder = fri_proof
.parse_remainder()
.map_err(|err| VerifierError::ProofDeserializationError(err.to_string()))?;
let (fri_layer_queries, fri_layer_proofs) = fri_proof
.parse_layers::<H, E>(lde_domain_size, fri_options.folding_factor())
.map_err(|err| VerifierError::ProofDeserializationError(err.to_string()))?;
let (ood_main_trace_frame, ood_aux_trace_frame, ood_constraint_evaluations) = ood_frame
.parse(main_trace_width, aux_trace_width, air.ce_blowup_factor())
.map_err(|err| VerifierError::ProofDeserializationError(err.to_string()))?;
let ood_trace_frame = TraceOodFrame::new(ood_main_trace_frame, ood_aux_trace_frame);
Ok(VerifierChannel {
trace_roots,
trace_queries: Some(trace_queries),
constraint_root,
constraint_queries: Some(constraint_queries),
fri_roots: Some(fri_roots),
fri_layer_proofs,
fri_layer_queries,
fri_remainder: Some(fri_remainder),
fri_num_partitions,
ood_trace_frame: Some(ood_trace_frame),
ood_constraint_evaluations: Some(ood_constraint_evaluations),
pow_nonce,
})
}
pub fn read_trace_commitments(&self) -> &[H::Digest] {
&self.trace_roots
}
pub fn read_constraint_commitment(&self) -> H::Digest {
self.constraint_root
}
pub fn read_ood_trace_frame(&mut self) -> (EvaluationFrame<E>, Option<EvaluationFrame<E>>) {
let frame = self.ood_trace_frame.take().expect("already read");
(frame.main_frame, frame.aux_frame)
}
pub fn read_ood_constraint_evaluations(&mut self) -> Vec<E> {
self.ood_constraint_evaluations
.take()
.expect("already read")
}
pub fn read_pow_nonce(&self) -> u64 {
self.pow_nonce
}
#[allow(clippy::type_complexity)]
pub fn read_queried_trace_states(
&mut self,
positions: &[usize],
) -> Result<(Table<E::BaseField>, Option<Table<E>>), VerifierError> {
let queries = self.trace_queries.take().expect("already read");
for (root, proof) in self.trace_roots.iter().zip(queries.query_proofs.iter()) {
MerkleTree::verify_batch(root, positions, proof)
.map_err(|_| VerifierError::TraceQueryDoesNotMatchCommitment)?;
}
Ok((queries.main_states, queries.aux_states))
}
pub fn read_constraint_evaluations(
&mut self,
positions: &[usize],
) -> Result<Table<E>, VerifierError> {
let queries = self.constraint_queries.take().expect("already read");
MerkleTree::verify_batch(&self.constraint_root, positions, &queries.query_proofs)
.map_err(|_| VerifierError::ConstraintQueryDoesNotMatchCommitment)?;
Ok(queries.evaluations)
}
}
impl<E, H> FriVerifierChannel<E> for VerifierChannel<E, H>
where
    E: FieldElement,
    H: ElementHasher<BaseField = E::BaseField>,
{
    type Hasher = H;

    /// Returns the number of partitions reported by the FRI proof.
    fn read_fri_num_partitions(&self) -> usize {
        self.fri_num_partitions
    }

    /// Hands out the FRI layer commitments; panics if they were already taken.
    fn read_fri_layer_commitments(&mut self) -> Vec<H::Digest> {
        let commitments = self.fri_roots.take();
        commitments.expect("already read")
    }

    /// Hands out the FRI remainder; panics if it was already taken.
    fn take_fri_remainder(&mut self) -> Vec<E> {
        let remainder = self.fri_remainder.take();
        remainder.expect("already read")
    }

    /// Removes and returns the batch Merkle proof for the next FRI layer (front of the list).
    fn take_next_fri_layer_proof(&mut self) -> BatchMerkleProof<Self::Hasher> {
        self.fri_layer_proofs.remove(0)
    }

    /// Removes and returns the queried values for the next FRI layer (front of the list).
    fn take_next_fri_layer_queries(&mut self) -> Vec<E> {
        self.fri_layer_queries.remove(0)
    }
}
/// Parsed trace queries: per-segment authentication paths plus the queried trace states.
struct TraceQueries<E, H>
where
    E: FieldElement,
    H: ElementHasher<BaseField = E::BaseField>,
{
    /// One batch Merkle proof per trace segment, main segment first.
    query_proofs: Vec<BatchMerkleProof<H>>,
    /// Queried states of the main trace segment (base field elements).
    main_states: Table<E::BaseField>,
    /// Merged queried states of all auxiliary segments; `None` for single-segment traces.
    aux_states: Option<Table<E>>,
}
impl<E: FieldElement, H: ElementHasher<BaseField = E::BaseField>> TraceQueries<E, H> {
    /// Parses the trace queries of a proof into trace states and Merkle authentication paths.
    ///
    /// The first entry of `queries` is parsed as the main trace segment (states in
    /// `E::BaseField`). Every remaining entry is parsed as an auxiliary segment (states in `E`),
    /// and the auxiliary tables are merged into a single table.
    ///
    /// # Errors
    /// Returns [`VerifierError::ProofDeserializationError`] if:
    /// - the number of segment queries differs from the number of segments in `air`'s layout;
    /// - the queries of any segment fail to parse.
    pub fn new<A: Air<BaseField = E::BaseField>>(
        queries: Vec<Queries>,
        air: &A,
    ) -> Result<Self, VerifierError> {
        // `queries` comes straight from the proof, which is untrusted input. A segment-count
        // mismatch is therefore reported as an error instead of panicking the verifier.
        let num_segments = air.trace_layout().num_segments();
        if queries.len() != num_segments {
            return Err(VerifierError::ProofDeserializationError(format!(
                "expected {} trace segment queries, but received {}",
                num_segments,
                queries.len()
            )));
        }

        let num_queries = air.options().num_queries();
        let lde_domain_size = air.lde_domain_size();

        // consume segments front-to-back instead of shifting the vector with `remove(0)`
        let mut segments = queries.into_iter();

        // main trace segment: always the first entry
        let main_segment_queries = segments.next().ok_or_else(|| {
            VerifierError::ProofDeserializationError(
                "main trace segment queries are missing".to_string(),
            )
        })?;
        let main_segment_width = air.trace_layout().main_trace_width();
        let (main_segment_query_proofs, main_segment_states) = main_segment_queries
            .parse::<H, E::BaseField>(lde_domain_size, num_queries, main_segment_width)
            .map_err(|err| {
                VerifierError::ProofDeserializationError(format!(
                    "main trace segment query deserialization failed: {}",
                    err
                ))
            })?;

        let mut query_proofs = Vec::with_capacity(num_segments);
        query_proofs.push(main_segment_query_proofs);

        // auxiliary trace segments, parsed over the extension field and merged into one table
        let aux_states = if air.trace_info().is_multi_segment() {
            // num_segments >= 1 here, since the main segment was present
            let mut aux_segment_states = Vec::with_capacity(num_segments - 1);
            for (i, segment_queries) in segments.enumerate() {
                let segment_width = air.trace_layout().get_aux_segment_width(i);
                let (segment_query_proof, segment_trace_states) = segment_queries
                    .parse::<H, E>(lde_domain_size, num_queries, segment_width)
                    .map_err(|err| {
                        VerifierError::ProofDeserializationError(format!(
                            "auxiliary trace segment query deserialization failed: {}",
                            err
                        ))
                    })?;
                query_proofs.push(segment_query_proof);
                aux_segment_states.push(segment_trace_states);
            }
            Some(Table::merge(aux_segment_states))
        } else {
            None
        };

        Ok(Self {
            query_proofs,
            main_states: main_segment_states,
            aux_states,
        })
    }
}
/// Parsed constraint-evaluation queries together with their authentication path.
struct ConstraintQueries<E, H>
where
    E: FieldElement,
    H: ElementHasher<BaseField = E::BaseField>,
{
    /// Batch Merkle proof for the queried rows against the constraint commitment.
    query_proofs: BatchMerkleProof<H>,
    /// Queried constraint evaluations, one row per query.
    evaluations: Table<E>,
}
impl<E, H> ConstraintQueries<E, H>
where
    E: FieldElement,
    H: ElementHasher<BaseField = E::BaseField>,
{
    /// Parses constraint-evaluation queries into evaluation rows and their batch Merkle proof.
    ///
    /// Rows are parsed over `E` with a width equal to `air.ce_blowup_factor()`.
    ///
    /// # Errors
    /// Returns [`VerifierError::ProofDeserializationError`] if the queries fail to parse.
    pub fn new<A: Air<BaseField = E::BaseField>>(
        queries: Queries,
        air: &A,
    ) -> Result<Self, VerifierError> {
        let parsed = queries.parse::<H, E>(
            air.lde_domain_size(),
            air.options().num_queries(),
            air.ce_blowup_factor(),
        );

        match parsed {
            Ok((query_proofs, evaluations)) => Ok(Self {
                query_proofs,
                evaluations,
            }),
            Err(err) => Err(VerifierError::ProofDeserializationError(format!(
                "constraint evaluation query deserialization failed: {}",
                err
            ))),
        }
    }
}
/// Out-of-domain evaluation frames of the execution trace.
struct TraceOodFrame<E>
where
    E: FieldElement,
{
    /// Frame for the main trace segment.
    main_frame: EvaluationFrame<E>,
    /// Frame for the auxiliary trace segments, if any.
    aux_frame: Option<EvaluationFrame<E>>,
}
impl<E: FieldElement> TraceOodFrame<E> {
pub fn new(main_frame: EvaluationFrame<E>, aux_frame: Option<EvaluationFrame<E>>) -> Self {
Self {
main_frame,
aux_frame,
}
}
}