1use field_cat::{Field, FieldBytes};
10use proof_cat_core::commit::merkle::{MerkleProof, MerkleRoot, MerkleTree};
11use proof_cat_core::poly::MultilinearPoly;
12use proof_cat_core::sumcheck::{SumcheckClaim, SumcheckProof, sumcheck_prove, sumcheck_verify};
13use proof_cat_core::transcript::Transcript;
14
15use crate::air::Air;
16use crate::air_expr::AirExpr;
17use crate::column::Column;
18use crate::error::Error;
19use crate::trace::Trace;
20
/// Label passed to `Transcript::new` by both `air_prove` and `air_verify`.
/// The two sides must build their transcripts from identical inputs,
/// including this label, for the squeezed challenges to agree.
const TRANSCRIPT_LABEL: &[u8] = b"machine-cat-v0.1";
23
24#[derive(Debug, Clone)]
28pub struct ColumnOpening<F: Field> {
29 column_index: usize,
30 values: Vec<F>,
31 merkle_proofs: Vec<MerkleProof>,
32}
33
34impl<F: Field> ColumnOpening<F> {
35 #[must_use]
37 pub fn column_index(&self) -> usize {
38 self.column_index
39 }
40
41 #[must_use]
43 pub fn values(&self) -> &[F] {
44 &self.values
45 }
46
47 #[must_use]
49 pub fn merkle_proofs(&self) -> &[MerkleProof] {
50 &self.merkle_proofs
51 }
52}
53
54#[derive(Debug, Clone)]
74pub struct AirProof<F: Field> {
75 trace_commitment: MerkleRoot,
76 sumcheck: SumcheckProof<F>,
77 column_openings: Vec<ColumnOpening<F>>,
78 row_count: usize,
79}
80
81impl<F: Field> AirProof<F> {
82 #[must_use]
84 pub fn trace_commitment(&self) -> &MerkleRoot {
85 &self.trace_commitment
86 }
87
88 #[must_use]
90 pub fn sumcheck(&self) -> &SumcheckProof<F> {
91 &self.sumcheck
92 }
93
94 #[must_use]
96 pub fn column_openings(&self) -> &[ColumnOpening<F>] {
97 &self.column_openings
98 }
99
100 #[must_use]
102 pub fn row_count(&self) -> usize {
103 self.row_count
104 }
105}
106
107pub fn air_prove<F: FieldBytes, A: Air<F>>(
127 air: &A,
128 trace: &Trace<F>,
129) -> Result<AirProof<F>, Error> {
130 validate_trace(air, trace)?;
132
133 let constraints = air.constraints();
134 if constraints.is_empty() {
135 Err(Error::NoConstraints)
136 } else {
137 validate_constraints(&constraints, trace)?;
139
140 let tree = MerkleTree::from_field_values(trace.data());
142
143 let transcript = Transcript::new(TRANSCRIPT_LABEL)
145 .absorb_bytes(tree.root().as_bytes())
146 .absorb_bytes(&air.column_count().count().to_le_bytes())
147 .absorb_bytes(&constraints.len().to_le_bytes());
148
149 let (alphas, transcript) = squeeze_challenges(constraints.len(), transcript)?;
151
152 let combined_evals = compute_combined_evals(&constraints, &alphas, trace)?;
154
155 let padded = pad_to_power_of_two(combined_evals);
157 let poly = MultilinearPoly::from_evals(padded)?;
158
159 let (sumcheck, _, _) = sumcheck_prove(&SumcheckClaim::new(poly, F::zero()), transcript)?;
161
162 let column_openings = open_all_columns(air, trace, &tree)?;
164
165 Ok(AirProof {
166 trace_commitment: tree.root(),
167 sumcheck,
168 column_openings,
169 row_count: trace.row_count().count(),
170 })
171 }
172}
173
174pub fn air_verify<F: FieldBytes, A: Air<F>>(air: &A, proof: &AirProof<F>) -> Result<bool, Error> {
185 let constraints = air.constraints();
186 if constraints.is_empty() {
187 Err(Error::NoConstraints)
188 } else {
189 let transcript = Transcript::new(TRANSCRIPT_LABEL)
191 .absorb_bytes(proof.trace_commitment.as_bytes())
192 .absorb_bytes(&air.column_count().count().to_le_bytes())
193 .absorb_bytes(&constraints.len().to_le_bytes());
194
195 let (alphas, transcript) = squeeze_challenges(constraints.len(), transcript)?;
196
197 let num_row_pairs = proof.row_count.saturating_sub(1);
199 let padded_len = pad_to_power_of_two_len(num_row_pairs);
200 let num_vars = usize::try_from(padded_len.trailing_zeros()).map_err(|_| {
201 Error::TraceNotPowerOfTwo {
202 row_count: padded_len,
203 }
204 })?;
205
206 let (final_eval, challenges, _) = sumcheck_verify(
208 proof.sumcheck(),
209 &F::zero(),
210 proof_cat_core::NumVars::new(num_vars),
211 transcript,
212 )?;
213
214 if verify_merkle_openings(proof) {
216 let trace = reconstruct_trace(air, proof)?;
218 let combined_evals = compute_combined_evals(&constraints, &alphas, &trace)?;
219 let padded = pad_to_power_of_two(combined_evals);
220 let poly = MultilinearPoly::from_evals(padded)?;
221
222 let expected = poly.evaluate(&challenges)?;
224 Ok(expected == final_eval)
225 } else {
226 Err(Error::ProofCatCore(proof_cat_core::Error::MerkleVerificationFailed))
227 }
228 }
229}
230
231fn validate_trace<F: Field, A: Air<F>>(air: &A, trace: &Trace<F>) -> Result<(), Error> {
235 if trace.column_count() != air.column_count() {
236 Err(Error::ColumnCountMismatch {
237 expected: air.column_count().count(),
238 actual: trace.column_count().count(),
239 })
240 } else if trace.row_count().count() < 2 {
241 Err(Error::InsufficientRows {
242 row_count: trace.row_count().count(),
243 })
244 } else {
245 Ok(())
246 }
247}
248
249fn validate_constraints<F: Field>(
251 constraints: &[AirExpr<F>],
252 trace: &Trace<F>,
253) -> Result<(), Error> {
254 (0..trace.row_count().count() - 1).try_for_each(|row| {
255 let assign = trace.row_pair_assignment(row)?;
256 constraints.iter().try_for_each(|c| {
257 let val = c.evaluate(&assign)?;
258 if val == F::zero() {
259 Ok(())
260 } else {
261 Err(Error::UnsatisfiedAirConstraint { row })
262 }
263 })
264 })
265}
266
267fn squeeze_challenges<F: FieldBytes>(
269 count: usize,
270 transcript: Transcript,
271) -> Result<(Vec<F>, Transcript), Error> {
272 (0..count).try_fold((Vec::with_capacity(count), transcript), |(alphas, t), _| {
273 let (challenge, t): (F, Transcript) = t.squeeze_challenge()?;
274 Ok((
275 alphas
276 .into_iter()
277 .chain(core::iter::once(challenge))
278 .collect(),
279 t,
280 ))
281 })
282}
283
284fn compute_combined_evals<F: Field>(
288 constraints: &[AirExpr<F>],
289 alphas: &[F],
290 trace: &Trace<F>,
291) -> Result<Vec<F>, Error> {
292 (0..trace.row_count().count() - 1)
293 .map(|row| {
294 let assign = trace.row_pair_assignment(row)?;
295 constraints
296 .iter()
297 .zip(alphas.iter())
298 .try_fold(F::zero(), |acc, (c, alpha)| {
299 let val = c.evaluate(&assign)?;
300 Ok(acc + alpha.clone() * val)
301 })
302 })
303 .collect()
304}
305
306fn open_all_columns<F: FieldBytes, A: Air<F>>(
308 air: &A,
309 trace: &Trace<F>,
310 tree: &MerkleTree,
311) -> Result<Vec<ColumnOpening<F>>, Error> {
312 let cols = air.column_count().count();
313 let rows = trace.row_count().count();
314 (0..cols)
315 .map(|col_idx| {
316 let values = trace.column_values(Column::new(col_idx))?;
317 let merkle_proofs: Result<Vec<MerkleProof>, Error> = (0..rows)
318 .map(|row| {
319 let flat_idx = row * cols + col_idx;
320 tree.open(flat_idx).map_err(Error::from)
321 })
322 .collect();
323 Ok(ColumnOpening {
324 column_index: col_idx,
325 values,
326 merkle_proofs: merkle_proofs?,
327 })
328 })
329 .collect()
330}
331
332fn verify_merkle_openings<F: FieldBytes>(proof: &AirProof<F>) -> bool {
334 let cols = proof.column_openings.len();
335 proof.column_openings.iter().all(|opening| {
336 opening.values.iter().enumerate().all(|(row, value)| {
337 let flat_idx = row * cols + opening.column_index;
338 MerkleTree::verify_opening(
339 &proof.trace_commitment,
340 flat_idx,
341 value,
342 &opening.merkle_proofs[row],
343 )
344 })
345 })
346}
347
348fn reconstruct_trace<F: Field, A: Air<F>>(air: &A, proof: &AirProof<F>) -> Result<Trace<F>, Error> {
350 let cols = air.column_count().count();
351 let rows = proof.row_count;
352 let row_vecs: Vec<Vec<F>> = (0..rows)
353 .map(|r| {
354 (0..cols)
355 .map(|c| {
356 proof
357 .column_openings
358 .get(c)
359 .and_then(|opening| opening.values.get(r).cloned())
360 .ok_or(Error::ColumnOutOfBounds {
361 index: c,
362 column_count: cols,
363 })
364 })
365 .collect::<Result<Vec<F>, Error>>()
366 })
367 .collect::<Result<Vec<Vec<F>>, Error>>()?;
368 Trace::from_rows(air.column_count(), &row_vecs)
369}
370
371fn pad_to_power_of_two<F: Field>(v: Vec<F>) -> Vec<F> {
373 let target = pad_to_power_of_two_len(v.len());
374 let padding = target - v.len();
375 v.into_iter()
376 .chain((0..padding).map(|_| F::zero()))
377 .collect()
378}
379
/// Returns the smallest power of two that is `>= n`, treating `n == 0` as 1.
fn pad_to_power_of_two_len(n: usize) -> usize {
    n.max(1).next_power_of_two()
}
384
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fibonacci::{FibonacciAir, FibonacciInput, StepCount};
    use field_cat::F101;

    /// Proves `trace` with the Fibonacci AIR and reports whether the proof verifies.
    fn prove_and_verify(trace: &Trace<F101>) -> Result<bool, Error> {
        let air = FibonacciAir;
        let proof = air_prove(&air, trace)?;
        air_verify(&air, &proof)
    }

    /// Asserts that proving fails for `trace`.
    ///
    /// A trace that `Trace::from_rows` already refused counts as rejected.
    fn assert_prove_rejects(trace: Result<Trace<F101>, Error>) {
        if let Ok(t) = trace {
            assert!(air_prove::<F101, _>(&FibonacciAir, &t).is_err());
        }
    }

    #[test]
    fn fibonacci_prove_verify_roundtrip() -> Result<(), Error> {
        let air = FibonacciAir;
        let input = FibonacciInput::new(F101::new(1), F101::new(1), StepCount::new(7));
        let trace = air.generate_trace(&input)?;
        assert!(prove_and_verify(&trace)?);
        Ok(())
    }

    #[test]
    fn fibonacci_small_trace() -> Result<(), Error> {
        let air = FibonacciAir;
        // A single step is the smallest input that still yields a row pair.
        let input = FibonacciInput::new(F101::new(1), F101::new(1), StepCount::new(1));
        let trace = air.generate_trace(&input)?;
        assert!(prove_and_verify(&trace)?);
        Ok(())
    }

    #[test]
    fn fibonacci_different_initial_values() -> Result<(), Error> {
        let air = FibonacciAir;
        let input = FibonacciInput::new(F101::new(3), F101::new(5), StepCount::new(3));
        let trace = air.generate_trace(&input)?;
        assert!(prove_and_verify(&trace)?);
        Ok(())
    }

    #[test]
    fn invalid_trace_column_count_rejected() {
        // Three columns where the Fibonacci AIR expects two.
        assert_prove_rejects(Trace::from_rows(
            crate::column::ColumnCount::new(3),
            &[
                vec![F101::new(1), F101::new(1), F101::new(0)],
                vec![F101::new(1), F101::new(2), F101::new(0)],
            ],
        ));
    }

    #[test]
    fn tampered_trace_rejected() {
        // The second row's last value has been tampered with (99 instead of
        // the value the first row implies).
        assert_prove_rejects(Trace::from_rows(
            crate::column::ColumnCount::new(2),
            &[
                vec![F101::new(1), F101::new(1)],
                vec![F101::new(1), F101::new(99)],
            ],
        ));
    }
}