use alloc::{string::ToString, vec::Vec};
use core::marker::PhantomData;
use air::{
proof::{Proof, Queries, QuotientOodFrame, Table, TraceOodFrame},
Air,
};
use crypto::{ElementHasher, VectorCommitment};
use fri::VerifierChannel as FriVerifierChannel;
use math::{FieldElement, StarkField};
use crate::VerifierError;
/// Verifier-side channel built from a [`Proof`].
///
/// All proof sections are parsed up front in [`VerifierChannel::new`]. Data that is consumed
/// exactly once (queries, the out-of-domain frame, FRI commitments and remainder) is kept in an
/// `Option` and moved out on first read; a second read panics with "already read".
pub struct VerifierChannel<E, H, V>
where
    E: FieldElement,
    H: ElementHasher<BaseField = E::BaseField>,
    V: VectorCommitment<H>,
{
    // Commitment to each trace segment: index 0 is the main segment, index 1 the auxiliary one.
    trace_commitments: Vec<H::Digest>,
    // Queried trace rows plus their batch opening proofs; taken on first read.
    trace_queries: Option<TraceQueries<E, H, V>>,
    // Commitment to the constraint composition polynomial evaluations.
    constraint_commitment: H::Digest,
    // Queried constraint evaluation rows plus their opening proof; taken on first read.
    constraint_queries: Option<ConstraintQueries<E, H, V>>,
    // Number of row elements hashed together per partition when recomputing leaf digests.
    partition_size_main: usize,
    partition_size_aux: usize,
    partition_size_constraint: usize,
    // FRI layer commitments; taken on first read.
    fri_commitments: Option<Vec<H::Digest>>,
    // Per-layer FRI opening proofs and queried values, consumed from the front.
    fri_layer_proofs: Vec<V::MultiProof>,
    fri_layer_queries: Vec<Vec<E>>,
    // FRI remainder; taken on first read.
    fri_remainder: Option<Vec<E>>,
    fri_num_partitions: usize,
    // Out-of-domain trace frame and constraint evaluations; each taken on first read.
    ood_trace_frame: Option<TraceOodFrame<E>>,
    ood_constraint_evaluations: Option<QuotientOodFrame<E>>,
    // The `pow_nonce` value carried by the proof.
    pow_nonce: u64,
}
impl<E, H, V> VerifierChannel<E, H, V>
where
E: FieldElement,
H: ElementHasher<BaseField = E::BaseField>,
V: VectorCommitment<H>,
{
pub fn new<A: Air<BaseField = E::BaseField>>(
air: &A,
proof: Proof,
) -> Result<Self, VerifierError> {
let Proof {
context,
num_unique_queries,
commitments,
trace_queries,
constraint_queries,
ood_frame,
fri_proof,
pow_nonce,
} = proof;
if E::BaseField::get_modulus_le_bytes() != context.field_modulus_bytes() {
return Err(VerifierError::InconsistentBaseField);
}
let constraint_frame_width = air.context().num_constraint_composition_columns();
let num_trace_segments = air.trace_info().num_segments();
let main_trace_width = air.trace_info().main_trace_width();
let aux_trace_width = air.trace_info().aux_segment_width();
let lde_domain_size = air.lde_domain_size();
let fri_options = air.options().to_fri_options();
let partition_options = air.options().partition_options();
let (trace_commitments, constraint_commitment, fri_commitments) = commitments
.parse::<H>(num_trace_segments, fri_options.num_fri_layers(lde_domain_size))
.map_err(|err| VerifierError::ProofDeserializationError(err.to_string()))?;
let trace_queries =
TraceQueries::<E, H, V>::new(trace_queries, air, num_unique_queries as usize)?;
let constraint_queries = ConstraintQueries::<E, H, V>::new(
constraint_queries,
air,
num_unique_queries as usize,
)?;
let fri_num_partitions = fri_proof.num_partitions();
let fri_remainder = fri_proof
.parse_remainder()
.map_err(|err| VerifierError::ProofDeserializationError(err.to_string()))?;
let (fri_layer_queries, fri_layer_proofs) = fri_proof
.parse_layers::<E, H, V>(lde_domain_size, fri_options.folding_factor())
.map_err(|err| VerifierError::ProofDeserializationError(err.to_string()))?;
let (ood_trace_frame, ood_constraint_evaluations) = ood_frame
.parse(main_trace_width, aux_trace_width, constraint_frame_width)
.map_err(|err| VerifierError::ProofDeserializationError(err.to_string()))?;
let partition_size_main = partition_options
.partition_size::<E::BaseField>(air.context().trace_info().main_trace_width());
let partition_size_aux =
partition_options.partition_size::<E>(air.context().trace_info().aux_segment_width());
let partition_size_constraint = partition_options
.partition_size::<E>(air.context().num_constraint_composition_columns());
Ok(VerifierChannel {
trace_commitments,
trace_queries: Some(trace_queries),
constraint_commitment,
constraint_queries: Some(constraint_queries),
partition_size_main,
partition_size_aux,
partition_size_constraint,
fri_commitments: Some(fri_commitments),
fri_layer_proofs,
fri_layer_queries,
fri_remainder: Some(fri_remainder),
fri_num_partitions,
ood_trace_frame: Some(ood_trace_frame),
ood_constraint_evaluations: Some(ood_constraint_evaluations),
pow_nonce,
})
}
pub fn read_trace_commitments(&self) -> &[H::Digest] {
&self.trace_commitments
}
pub fn read_constraint_commitment(&self) -> H::Digest {
self.constraint_commitment
}
pub fn read_ood_trace_frame(&mut self) -> TraceOodFrame<E> {
self.ood_trace_frame.take().expect("already read")
}
pub fn read_ood_constraint_frame(&mut self) -> QuotientOodFrame<E> {
self.ood_constraint_evaluations.take().expect("already read")
}
pub fn read_pow_nonce(&self) -> u64 {
self.pow_nonce
}
#[allow(clippy::type_complexity)]
pub fn read_queried_trace_states(
&mut self,
positions: &[usize],
) -> Result<(Table<E::BaseField>, Option<Table<E>>), VerifierError> {
let queries = self.trace_queries.take().expect("already read");
let items: Vec<H::Digest> = queries
.main_states
.rows()
.map(|row| hash_row::<H, E::BaseField>(row, self.partition_size_main))
.collect();
<V as VectorCommitment<H>>::verify_many(
self.trace_commitments[0],
positions,
&items,
&queries.query_proofs[0],
)
.map_err(|_| VerifierError::TraceQueryDoesNotMatchCommitment)?;
if let Some(ref aux_states) = queries.aux_states {
let items: Vec<H::Digest> = aux_states
.rows()
.map(|row| hash_row::<H, E>(row, self.partition_size_aux))
.collect();
<V as VectorCommitment<H>>::verify_many(
self.trace_commitments[1],
positions,
&items,
&queries.query_proofs[1],
)
.map_err(|_| VerifierError::TraceQueryDoesNotMatchCommitment)?;
}
Ok((queries.main_states, queries.aux_states))
}
pub fn read_constraint_evaluations(
&mut self,
positions: &[usize],
) -> Result<Table<E>, VerifierError> {
let queries = self.constraint_queries.take().expect("already read");
let items: Vec<H::Digest> = queries
.evaluations
.rows()
.map(|row| hash_row::<H, E>(row, self.partition_size_constraint))
.collect();
<V as VectorCommitment<H>>::verify_many(
self.constraint_commitment,
positions,
&items,
&queries.query_proofs,
)
.map_err(|_| VerifierError::ConstraintQueryDoesNotMatchCommitment)?;
Ok(queries.evaluations)
}
}
impl<E, H, V> FriVerifierChannel<E> for VerifierChannel<E, H, V>
where
    E: FieldElement,
    H: ElementHasher<BaseField = E::BaseField>,
    V: VectorCommitment<H>,
{
    type Hasher = H;
    type VectorCommitment = V;

    /// Returns the FRI partition count recorded in the proof.
    fn read_fri_num_partitions(&self) -> usize {
        self.fri_num_partitions
    }

    /// Moves the FRI layer commitments out of the channel; panics with "already read" on a
    /// second call.
    fn read_fri_layer_commitments(&mut self) -> Vec<H::Digest> {
        core::mem::take(&mut self.fri_commitments).expect("already read")
    }

    /// Removes and returns the front-most remaining FRI layer proof.
    /// Panics if every layer proof has already been taken.
    fn take_next_fri_layer_proof(&mut self) -> V::MultiProof {
        self.fri_layer_proofs.remove(0)
    }

    /// Removes and returns the front-most remaining set of FRI layer query values.
    /// Panics if every layer's queries have already been taken.
    fn take_next_fri_layer_queries(&mut self) -> Vec<E> {
        self.fri_layer_queries.remove(0)
    }

    /// Moves the FRI remainder out of the channel; panics with "already read" on a second call.
    fn take_fri_remainder(&mut self) -> Vec<E> {
        core::mem::take(&mut self.fri_remainder).expect("already read")
    }
}
/// Parsed trace-segment queries: the queried rows of each segment together with the batch
/// opening proof for each segment.
struct TraceQueries<E, H, V>
where
    E: FieldElement,
    H: ElementHasher<BaseField = E::BaseField>,
    V: VectorCommitment<H>,
{
    // One batch opening proof per segment: index 0 is main, index 1 (if present) is auxiliary.
    query_proofs: Vec<V::MultiProof>,
    // Queried rows of the main segment, over the base field.
    main_states: Table<E::BaseField>,
    // Queried rows of the auxiliary segment; `None` for single-segment traces.
    aux_states: Option<Table<E>>,
    // Ties the hasher type parameter to the struct without storing a hasher.
    _h: PhantomData<H>,
}
impl<E, H, V> TraceQueries<E, H, V>
where
    E: FieldElement,
    H: ElementHasher<BaseField = E::BaseField>,
    V: VectorCommitment<H>,
{
    /// Parses the trace-segment queries carried by a proof.
    ///
    /// `queries` holds one entry per trace segment: the main segment first, then the
    /// auxiliary segment when `air` describes a multi-segment trace. Each entry is parsed into
    /// a batch opening proof and a table of `num_queries` queried rows, using
    /// `air.lde_domain_size()` as the domain size.
    ///
    /// # Errors
    /// Returns [`VerifierError::ProofDeserializationError`] if the number of segment queries
    /// differs from the number of segments declared by `air`, or if any segment fails to parse.
    pub fn new<A: Air<BaseField = E::BaseField>>(
        queries: Vec<Queries>,
        air: &A,
        num_queries: usize,
    ) -> Result<Self, VerifierError> {
        // `queries` comes straight from the proof, which is untrusted input. A wrong segment
        // count means the proof is malformed, so it must be reported as an error rather than
        // panic inside the verifier.
        let num_segments = air.trace_info().num_segments();
        if queries.len() != num_segments {
            return Err(VerifierError::ProofDeserializationError(format!(
                "expected {num_segments} trace segment queries, but received {}",
                queries.len()
            )));
        }
        let mut segments = queries.into_iter();

        let main_segment_width = air.trace_info().main_trace_width();
        let main_segment_queries = segments.next().ok_or_else(|| {
            VerifierError::ProofDeserializationError(
                "missing main trace segment queries".to_string(),
            )
        })?;
        let (main_segment_query_proofs, main_segment_states) = main_segment_queries
            .parse::<E::BaseField, H, V>(air.lde_domain_size(), num_queries, main_segment_width)
            .map_err(|err| {
                VerifierError::ProofDeserializationError(format!(
                    "main trace segment query deserialization failed: {err}"
                ))
            })?;
        let mut query_proofs = vec![main_segment_query_proofs];

        let aux_trace_states = if air.trace_info().is_multi_segment() {
            // Never panic on untrusted input, even if the segment count and the
            // multi-segment flag disagree.
            let segment_queries = segments.next().ok_or_else(|| {
                VerifierError::ProofDeserializationError(
                    "missing auxiliary trace segment queries".to_string(),
                )
            })?;
            let segment_width = air.trace_info().get_aux_segment_width();
            let (segment_query_proof, segment_trace_states) = segment_queries
                .parse::<E, H, V>(air.lde_domain_size(), num_queries, segment_width)
                .map_err(|err| {
                    VerifierError::ProofDeserializationError(format!(
                        "auxiliary trace segment query deserialization failed: {err}"
                    ))
                })?;
            query_proofs.push(segment_query_proof);
            Some(Table::merge(vec![segment_trace_states]))
        } else {
            None
        };

        Ok(Self {
            query_proofs,
            main_states: main_segment_states,
            aux_states: aux_trace_states,
            _h: PhantomData,
        })
    }
}
/// Parsed constraint-evaluation queries: the queried rows of the constraint composition
/// evaluations together with a single batch opening proof for all of them.
struct ConstraintQueries<E, H, V>
where
    E: FieldElement,
    H: ElementHasher<BaseField = E::BaseField>,
    V: VectorCommitment<H>,
{
    // Batch opening proof for every queried position of the constraint commitment.
    query_proofs: V::MultiProof,
    // Queried rows of constraint composition evaluations, over the extension field `E`.
    evaluations: Table<E>,
    // Ties the hasher type parameter to the struct without storing a hasher.
    _h: PhantomData<H>,
}
impl<E, H, V> ConstraintQueries<E, H, V>
where
    E: FieldElement,
    H: ElementHasher<BaseField = E::BaseField>,
    V: VectorCommitment<H>,
{
    /// Parses the constraint-evaluation queries carried by a proof into one batch opening
    /// proof and a table of `num_queries` rows. Each row is as wide as the AIR's constraint
    /// composition column count.
    ///
    /// # Errors
    /// Returns [`VerifierError::ProofDeserializationError`] if `queries` cannot be parsed.
    pub fn new<A: Air<BaseField = E::BaseField>>(
        queries: Queries,
        air: &A,
        num_queries: usize,
    ) -> Result<Self, VerifierError> {
        let num_columns = air.context().num_constraint_composition_columns();
        queries
            .parse::<E, H, V>(air.lde_domain_size(), num_queries, num_columns)
            .map(|(query_proofs, evaluations)| Self {
                query_proofs,
                evaluations,
                _h: PhantomData,
            })
            .map_err(|err| {
                VerifierError::ProofDeserializationError(format!(
                    "constraint evaluation query deserialization failed: {err}"
                ))
            })
    }
}
/// Computes the leaf digest of a single queried row.
///
/// If `partition_size` equals the row length, the whole row is hashed in one call. Otherwise
/// the row is split into consecutive chunks of `partition_size` elements (the last may be
/// shorter). Each chunk is hashed, and the chunk digests are merged with `H::merge_many`.
/// Callers check these digests against commitments, so the scheme must not change.
///
/// # Panics
/// Panics if `partition_size` is zero and differs from the row length.
fn hash_row<H, E>(row: &[E], partition_size: usize) -> H::Digest
where
    E: FieldElement,
    H: ElementHasher<BaseField = E::BaseField>,
{
    if row.len() == partition_size {
        return H::hash_elements(row);
    }

    let mut partition_digests: Vec<H::Digest> =
        Vec::with_capacity(row.len().div_ceil(partition_size));
    partition_digests.extend(row.chunks(partition_size).map(|chunk| H::hash_elements(chunk)));
    H::merge_many(&partition_digests)
}