use field_cat::{Field, FieldBytes};
use proof_cat::commit::merkle::{MerkleProof, MerkleRoot, MerkleTree};
use proof_cat::poly::MultilinearPoly;
use proof_cat::sumcheck::{SumcheckClaim, SumcheckProof, sumcheck_prove, sumcheck_verify};
use proof_cat::transcript::Transcript;
use crate::air::Air;
use crate::air_expr::AirExpr;
use crate::column::Column;
use crate::error::Error;
use crate::trace::Trace;
/// Initial label for the Fiat–Shamir transcript. Both `air_prove` and `air_verify` start their
/// transcript with `Transcript::new(TRANSCRIPT_LABEL)`; the two sides must begin from the same
/// transcript state for their squeezed challenges to match.
const TRANSCRIPT_LABEL: &[u8] = b"machine-cat-v0.1";
/// All values of one trace column together with a Merkle opening proof for each of them.
///
/// Built by `open_all_columns` during proving and checked by `verify_merkle_openings`
/// during verification.
#[derive(Debug, Clone)]
pub struct ColumnOpening<F: Field> {
    /// Index of the trace column these values belong to.
    column_index: usize,
    /// The column's values as returned by `Trace::column_values`; the verifier treats the
    /// position in this vector as the row index.
    values: Vec<F>,
    /// Merkle opening proofs; entry `row` opens leaf `row * column_count + column_index`
    /// of the tree committed in the enclosing proof.
    merkle_proofs: Vec<MerkleProof>,
}
impl<F: Field> ColumnOpening<F> {
    /// Index of the trace column this opening covers.
    #[must_use]
    pub fn column_index(&self) -> usize {
        self.column_index
    }

    /// Opened column values; position in the slice is the row index.
    #[must_use]
    pub fn values(&self) -> &[F] {
        self.values.as_slice()
    }

    /// Merkle opening proofs accompanying [`Self::values`], one per opened row.
    #[must_use]
    pub fn merkle_proofs(&self) -> &[MerkleProof] {
        self.merkle_proofs.as_slice()
    }
}
/// Proof produced by `air_prove` and checked by `air_verify`.
#[derive(Debug, Clone)]
pub struct AirProof<F: Field> {
    /// Merkle root over the flattened trace data (`trace.data()`).
    trace_commitment: MerkleRoot,
    /// Sumcheck proof that the challenge-weighted sum of constraint evaluations, taken over all
    /// consecutive row pairs and zero-padded to a power of two, equals zero.
    sumcheck: SumcheckProof<F>,
    /// One opening per trace column, in column order.
    column_openings: Vec<ColumnOpening<F>>,
    /// Number of rows in the proved trace.
    row_count: usize,
}
impl<F: Field> AirProof<F> {
    /// Merkle root committing to the flattened execution trace.
    #[must_use]
    pub fn trace_commitment(&self) -> &MerkleRoot {
        &self.trace_commitment
    }

    /// Sumcheck proof for the combined constraint polynomial.
    #[must_use]
    pub fn sumcheck(&self) -> &SumcheckProof<F> {
        &self.sumcheck
    }

    /// Per-column openings of the committed trace, in column order.
    #[must_use]
    pub fn column_openings(&self) -> &[ColumnOpening<F>] {
        self.column_openings.as_slice()
    }

    /// Number of rows in the proved trace.
    #[must_use]
    pub fn row_count(&self) -> usize {
        self.row_count
    }
}
pub fn air_prove<F: FieldBytes, A: Air<F>>(
air: &A,
trace: &Trace<F>,
) -> Result<AirProof<F>, Error> {
validate_trace(air, trace)?;
let constraints = air.constraints();
if constraints.is_empty() {
Err(Error::NoConstraints)
} else {
validate_constraints(&constraints, trace)?;
let tree = MerkleTree::from_field_values(trace.data());
let transcript = Transcript::new(TRANSCRIPT_LABEL)
.absorb_bytes(tree.root().as_bytes())
.absorb_bytes(&air.column_count().count().to_le_bytes())
.absorb_bytes(&constraints.len().to_le_bytes());
let (alphas, transcript) = squeeze_challenges(constraints.len(), transcript)?;
let combined_evals = compute_combined_evals(&constraints, &alphas, trace)?;
let padded = pad_to_power_of_two(combined_evals);
let poly = MultilinearPoly::from_evals(padded)?;
let (sumcheck, _, _) = sumcheck_prove(&SumcheckClaim::new(poly, F::zero()), transcript)?;
let column_openings = open_all_columns(air, trace, &tree)?;
Ok(AirProof {
trace_commitment: tree.root(),
sumcheck,
column_openings,
row_count: trace.row_count().count(),
})
}
}
pub fn air_verify<F: FieldBytes, A: Air<F>>(air: &A, proof: &AirProof<F>) -> Result<bool, Error> {
let constraints = air.constraints();
if constraints.is_empty() {
Err(Error::NoConstraints)
} else {
let transcript = Transcript::new(TRANSCRIPT_LABEL)
.absorb_bytes(proof.trace_commitment.as_bytes())
.absorb_bytes(&air.column_count().count().to_le_bytes())
.absorb_bytes(&constraints.len().to_le_bytes());
let (alphas, transcript) = squeeze_challenges(constraints.len(), transcript)?;
let num_row_pairs = proof.row_count.saturating_sub(1);
let padded_len = pad_to_power_of_two_len(num_row_pairs);
let num_vars = usize::try_from(padded_len.trailing_zeros()).map_err(|_| {
Error::TraceNotPowerOfTwo {
row_count: padded_len,
}
})?;
let (final_eval, challenges, _) = sumcheck_verify(
proof.sumcheck(),
&F::zero(),
proof_cat::NumVars::new(num_vars),
transcript,
)?;
if verify_merkle_openings(proof) {
let trace = reconstruct_trace(air, proof)?;
let combined_evals = compute_combined_evals(&constraints, &alphas, &trace)?;
let padded = pad_to_power_of_two(combined_evals);
let poly = MultilinearPoly::from_evals(padded)?;
let expected = poly.evaluate(&challenges)?;
Ok(expected == final_eval)
} else {
Err(Error::ProofCat(proof_cat::Error::MerkleVerificationFailed))
}
}
}
fn validate_trace<F: Field, A: Air<F>>(air: &A, trace: &Trace<F>) -> Result<(), Error> {
if trace.column_count() != air.column_count() {
Err(Error::ColumnCountMismatch {
expected: air.column_count().count(),
actual: trace.column_count().count(),
})
} else if trace.row_count().count() < 2 {
Err(Error::InsufficientRows {
row_count: trace.row_count().count(),
})
} else {
Ok(())
}
}
/// Evaluates every constraint on every consecutive row pair `(row, row + 1)` and requires each
/// result to be zero.
///
/// A trace with fewer than two rows has no row pairs and is trivially accepted here; callers
/// enforce the minimum row count separately.
///
/// # Errors
///
/// - [`Error::UnsatisfiedAirConstraint`] for the first row pair with a non-zero constraint.
/// - Any error from building the row-pair assignment or evaluating a constraint.
fn validate_constraints<F: Field>(
    constraints: &[AirExpr<F>],
    trace: &Trace<F>,
) -> Result<(), Error> {
    // `saturating_sub` avoids a debug-build underflow panic on an empty trace.
    let row_pairs = trace.row_count().count().saturating_sub(1);
    (0..row_pairs).try_for_each(|row| {
        let assign = trace.row_pair_assignment(row)?;
        constraints.iter().try_for_each(|c| {
            if c.evaluate(&assign)? == F::zero() {
                Ok(())
            } else {
                Err(Error::UnsatisfiedAirConstraint { row })
            }
        })
    })
}
/// Squeezes `count` field challenges from `transcript`, in order, and returns them together with
/// the advanced transcript.
///
/// # Errors
///
/// Propagates any error from `Transcript::squeeze_challenge`.
fn squeeze_challenges<F: FieldBytes>(
    count: usize,
    transcript: Transcript,
) -> Result<(Vec<F>, Transcript), Error> {
    // Push into one preallocated Vec. The previous fold rebuilt the Vec on every step
    // (quadratic copying), which also threw away the reserved capacity.
    let mut alphas = Vec::with_capacity(count);
    let mut transcript = transcript;
    for _ in 0..count {
        let (challenge, next): (F, Transcript) = transcript.squeeze_challenge()?;
        alphas.push(challenge);
        transcript = next;
    }
    Ok((alphas, transcript))
}
/// For each consecutive row pair `(row, row + 1)`, computes `sum_i alphas[i] * constraints[i]`
/// evaluated on that pair. Returns one field element per row pair.
///
/// `alphas` must contain exactly one challenge per constraint.
///
/// # Errors
///
/// Propagates any error from building a row-pair assignment or evaluating a constraint.
fn compute_combined_evals<F: Field>(
    constraints: &[AirExpr<F>],
    alphas: &[F],
    trace: &Trace<F>,
) -> Result<Vec<F>, Error> {
    // `zip` below would silently ignore surplus constraints or challenges.
    debug_assert_eq!(
        constraints.len(),
        alphas.len(),
        "one challenge is required per constraint"
    );
    // `saturating_sub` avoids an underflow when called on an empty (e.g. reconstructed) trace.
    let row_pairs = trace.row_count().count().saturating_sub(1);
    (0..row_pairs)
        .map(|row| {
            let assign = trace.row_pair_assignment(row)?;
            constraints
                .iter()
                .zip(alphas.iter())
                .try_fold(F::zero(), |acc, (c, alpha)| {
                    let val = c.evaluate(&assign)?;
                    Ok(acc + alpha.clone() * val)
                })
        })
        .collect()
}
/// Opens every column of `trace` in full.
///
/// Returns one [`ColumnOpening`] per column, holding the column's values and a Merkle proof for
/// each row at leaf index `row * column_count + column_index`. That index assumes `tree` was
/// built over the trace in row-major order.
///
/// # Errors
///
/// Propagates errors from reading a column or opening a Merkle leaf.
fn open_all_columns<F: FieldBytes, A: Air<F>>(
    air: &A,
    trace: &Trace<F>,
    tree: &MerkleTree,
) -> Result<Vec<ColumnOpening<F>>, Error> {
    let column_count = air.column_count().count();
    let row_count = trace.row_count().count();

    let mut openings = Vec::new();
    for column_index in 0..column_count {
        let values = trace.column_values(Column::new(column_index))?;

        let mut merkle_proofs = Vec::new();
        for row in 0..row_count {
            merkle_proofs.push(tree.open(row * column_count + column_index)?);
        }

        openings.push(ColumnOpening {
            column_index,
            values,
            merkle_proofs,
        });
    }
    Ok(openings)
}
/// Authenticates every opened value against `proof.trace_commitment`.
///
/// The value in row `row` of the opening at position `p` is checked at leaf index
/// `row * column_count + p`, where `column_count` is the number of openings. Returns `false`
/// in three cases:
/// - an opening's `column_index` does not equal its position;
/// - a value has no matching Merkle proof;
/// - any Merkle check fails.
fn verify_merkle_openings<F: FieldBytes>(proof: &AirProof<F>) -> bool {
    let cols = proof.column_openings.len();
    proof
        .column_openings
        .iter()
        .enumerate()
        .all(|(position, opening)| {
            // `reconstruct_trace` reads columns by position, so an opening authenticated under a
            // different column index must not be accepted in this slot.
            opening.column_index == position
                && opening.values.iter().enumerate().all(|(row, value)| {
                    // `.get` rather than indexing: a missing proof rejects instead of panicking.
                    opening.merkle_proofs.get(row).is_some_and(|merkle_proof| {
                        MerkleTree::verify_opening(
                            &proof.trace_commitment,
                            row * cols + opening.column_index,
                            value,
                            merkle_proof,
                        )
                    })
                })
        })
}
fn reconstruct_trace<F: Field, A: Air<F>>(air: &A, proof: &AirProof<F>) -> Result<Trace<F>, Error> {
let cols = air.column_count().count();
let rows = proof.row_count;
let row_vecs: Vec<Vec<F>> = (0..rows)
.map(|r| {
(0..cols)
.map(|c| {
proof
.column_openings
.get(c)
.and_then(|opening| opening.values.get(r).cloned())
.ok_or(Error::ColumnOutOfBounds {
index: c,
column_count: cols,
})
})
.collect::<Result<Vec<F>, Error>>()
})
.collect::<Result<Vec<Vec<F>>, Error>>()?;
Trace::from_rows(air.column_count(), &row_vecs)
}
/// Appends `F::zero()` to `v` until its length is `pad_to_power_of_two_len(v.len())`,
/// i.e. the next power of two (at least 1).
fn pad_to_power_of_two<F: Field>(v: Vec<F>) -> Vec<F> {
    let target_len = pad_to_power_of_two_len(v.len());
    let mut padded = v;
    // `target_len >= padded.len()` always holds, so this only ever grows the vector.
    padded.resize_with(target_len, F::zero);
    padded
}
/// Smallest power of two that is `>= n`, with a minimum of 1 (so `0` and `1` both map to `1`).
fn pad_to_power_of_two_len(n: usize) -> usize {
    // `max(1)` makes the minimum explicit; `next_power_of_two` alone already maps 0 to 1.
    n.max(1).next_power_of_two()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fibonacci::{FibonacciAir, FibonacciInput, StepCount};
    use field_cat::F101;

    #[test]
    fn fibonacci_prove_verify_roundtrip() -> Result<(), Error> {
        let air = FibonacciAir;
        let input = FibonacciInput::new(F101::new(1), F101::new(1), StepCount::new(7));
        let trace = air.generate_trace(&input)?;
        let proof = air_prove(&air, &trace)?;
        assert!(air_verify(&air, &proof)?);
        Ok(())
    }

    #[test]
    fn fibonacci_small_trace() -> Result<(), Error> {
        let air = FibonacciAir;
        let input = FibonacciInput::new(F101::new(1), F101::new(1), StepCount::new(1));
        let trace = air.generate_trace(&input)?;
        let proof = air_prove(&air, &trace)?;
        assert!(air_verify(&air, &proof)?);
        Ok(())
    }

    #[test]
    fn fibonacci_different_initial_values() -> Result<(), Error> {
        let air = FibonacciAir;
        let input = FibonacciInput::new(F101::new(3), F101::new(5), StepCount::new(3));
        let trace = air.generate_trace(&input)?;
        let proof = air_prove(&air, &trace)?;
        assert!(air_verify(&air, &proof)?);
        Ok(())
    }

    #[test]
    fn invalid_trace_column_count_rejected() -> Result<(), Error> {
        let air = FibonacciAir;
        // Building this well-formed 3-column trace must succeed; previously a failure here
        // made the test pass without ever calling `air_prove`.
        let trace = Trace::from_rows(
            crate::column::ColumnCount::new(3),
            &[
                vec![F101::new(1), F101::new(1), F101::new(0)],
                vec![F101::new(1), F101::new(2), F101::new(0)],
            ],
        )?;
        let result = air_prove::<F101, _>(&air, &trace);
        assert!(matches!(result, Err(Error::ColumnCountMismatch { .. })));
        Ok(())
    }

    #[test]
    fn tampered_trace_rejected() -> Result<(), Error> {
        let air = FibonacciAir;
        let trace = Trace::from_rows(
            crate::column::ColumnCount::new(2),
            &[
                vec![F101::new(1), F101::new(1)],
                // 99 is not the Fibonacci successor of (1, 1).
                vec![F101::new(1), F101::new(99)],
            ],
        )?;
        let result = air_prove::<F101, _>(&air, &trace);
        assert!(matches!(
            result,
            Err(Error::UnsatisfiedAirConstraint { .. })
        ));
        Ok(())
    }

    #[test]
    fn tampered_opening_rejected() -> Result<(), Error> {
        let air = FibonacciAir;
        let input = FibonacciInput::new(F101::new(1), F101::new(1), StepCount::new(3));
        let trace = air.generate_trace(&input)?;
        let mut proof = air_prove(&air, &trace)?;
        // Row 0 of the trace holds the initial values (1, 1); overwrite one opened value so it
        // no longer matches the committed Merkle leaf.
        proof.column_openings[0].values[0] = F101::new(50);
        assert!(air_verify(&air, &proof).is_err());
        Ok(())
    }
}