1use proof_cat::commit::merkle::{MerkleProof, MerkleRoot, MerkleTree};
10use proof_cat::poly::MultilinearPoly;
11use proof_cat::sumcheck::{SumcheckClaim, SumcheckProof, sumcheck_prove, sumcheck_verify};
12use proof_cat::transcript::Transcript;
13use proof_cat::FieldBytes;
14
15use plonkish_cat::Field;
16
17use crate::air::Air;
18use crate::air_expr::AirExpr;
19use crate::column::Column;
20use crate::error::Error;
21use crate::trace::Trace;
22
/// Label passed to `Transcript::new` by both `air_prove` and `air_verify`.
///
/// Prover and verifier must start from the same label, otherwise the challenges
/// they squeeze from their transcripts diverge and no proof will verify.
const TRANSCRIPT_LABEL: &[u8] = b"machine-cat-v0.1";
25
26#[derive(Debug, Clone)]
30pub struct ColumnOpening<F: Field> {
31 column_index: usize,
32 values: Vec<F>,
33 merkle_proofs: Vec<MerkleProof>,
34}
35
36impl<F: Field> ColumnOpening<F> {
37 #[must_use]
39 pub fn column_index(&self) -> usize {
40 self.column_index
41 }
42
43 #[must_use]
45 pub fn values(&self) -> &[F] {
46 &self.values
47 }
48
49 #[must_use]
51 pub fn merkle_proofs(&self) -> &[MerkleProof] {
52 &self.merkle_proofs
53 }
54}
55
56#[derive(Debug, Clone)]
76pub struct AirProof<F: Field> {
77 trace_commitment: MerkleRoot,
78 sumcheck: SumcheckProof<F>,
79 column_openings: Vec<ColumnOpening<F>>,
80 row_count: usize,
81}
82
83impl<F: Field> AirProof<F> {
84 #[must_use]
86 pub fn trace_commitment(&self) -> &MerkleRoot {
87 &self.trace_commitment
88 }
89
90 #[must_use]
92 pub fn sumcheck(&self) -> &SumcheckProof<F> {
93 &self.sumcheck
94 }
95
96 #[must_use]
98 pub fn column_openings(&self) -> &[ColumnOpening<F>] {
99 &self.column_openings
100 }
101
102 #[must_use]
104 pub fn row_count(&self) -> usize {
105 self.row_count
106 }
107}
108
109pub fn air_prove<F: FieldBytes, A: Air<F>>(
129 air: &A,
130 trace: &Trace<F>,
131) -> Result<AirProof<F>, Error> {
132 validate_trace(air, trace)?;
134
135 let constraints = air.constraints();
136 if constraints.is_empty() {
137 Err(Error::NoConstraints)
138 } else {
139 validate_constraints(&constraints, trace)?;
141
142 let tree = MerkleTree::from_field_values(trace.data());
144
145 let transcript = Transcript::new(TRANSCRIPT_LABEL)
147 .absorb_bytes(tree.root().as_bytes())
148 .absorb_bytes(&air.column_count().count().to_le_bytes())
149 .absorb_bytes(&constraints.len().to_le_bytes());
150
151 let (alphas, transcript) = squeeze_challenges(constraints.len(), transcript)?;
153
154 let combined_evals = compute_combined_evals(&constraints, &alphas, trace)?;
156
157 let padded = pad_to_power_of_two(combined_evals);
159 let poly = MultilinearPoly::from_evals(padded)?;
160
161 let (sumcheck, _, _) = sumcheck_prove(
163 &SumcheckClaim::new(poly, F::zero()),
164 transcript,
165 )?;
166
167 let column_openings = open_all_columns(air, trace, &tree)?;
169
170 Ok(AirProof {
171 trace_commitment: tree.root(),
172 sumcheck,
173 column_openings,
174 row_count: trace.row_count().count(),
175 })
176 }
177}
178
179pub fn air_verify<F: FieldBytes, A: Air<F>>(
190 air: &A,
191 proof: &AirProof<F>,
192) -> Result<bool, Error> {
193 let constraints = air.constraints();
194 if constraints.is_empty() {
195 Err(Error::NoConstraints)
196 } else {
197 let transcript = Transcript::new(TRANSCRIPT_LABEL)
199 .absorb_bytes(proof.trace_commitment.as_bytes())
200 .absorb_bytes(&air.column_count().count().to_le_bytes())
201 .absorb_bytes(&constraints.len().to_le_bytes());
202
203 let (alphas, transcript) = squeeze_challenges(constraints.len(), transcript)?;
204
205 let num_row_pairs = proof.row_count.saturating_sub(1);
207 let padded_len = pad_to_power_of_two_len(num_row_pairs);
208 let num_vars = usize::try_from(padded_len.trailing_zeros())
209 .map_err(|_| Error::TraceNotPowerOfTwo { row_count: padded_len })?;
210
211 let (final_eval, challenges, _) = sumcheck_verify(
213 proof.sumcheck(),
214 &F::zero(),
215 proof_cat::NumVars::new(num_vars),
216 transcript,
217 )?;
218
219 if verify_merkle_openings(proof) {
221 let trace = reconstruct_trace(air, proof)?;
223 let combined_evals = compute_combined_evals(&constraints, &alphas, &trace)?;
224 let padded = pad_to_power_of_two(combined_evals);
225 let poly = MultilinearPoly::from_evals(padded)?;
226
227 let expected = poly.evaluate(&challenges)?;
229 Ok(expected == final_eval)
230 } else {
231 Err(Error::ProofCat(proof_cat::Error::MerkleVerificationFailed))
232 }
233 }
234}
235
236fn validate_trace<F: Field, A: Air<F>>(air: &A, trace: &Trace<F>) -> Result<(), Error> {
240 if trace.column_count() != air.column_count() {
241 Err(Error::ColumnCountMismatch {
242 expected: air.column_count().count(),
243 actual: trace.column_count().count(),
244 })
245 } else if trace.row_count().count() < 2 {
246 Err(Error::InsufficientRows {
247 row_count: trace.row_count().count(),
248 })
249 } else {
250 Ok(())
251 }
252}
253
254fn validate_constraints<F: Field>(
256 constraints: &[AirExpr<F>],
257 trace: &Trace<F>,
258) -> Result<(), Error> {
259 (0..trace.row_count().count() - 1).try_for_each(|row| {
260 let assign = trace.row_pair_assignment(row)?;
261 constraints.iter().try_for_each(|c| {
262 let val = c.evaluate(&assign)?;
263 if val == F::zero() {
264 Ok(())
265 } else {
266 Err(Error::UnsatisfiedAirConstraint { row })
267 }
268 })
269 })
270}
271
272fn squeeze_challenges<F: FieldBytes>(
274 count: usize,
275 transcript: Transcript,
276) -> Result<(Vec<F>, Transcript), Error> {
277 (0..count).try_fold((Vec::with_capacity(count), transcript), |(alphas, t), _| {
278 let (challenge, t): (F, Transcript) = t.squeeze_challenge()?;
279 Ok((
280 alphas.into_iter().chain(core::iter::once(challenge)).collect(),
281 t,
282 ))
283 })
284}
285
286fn compute_combined_evals<F: Field>(
290 constraints: &[AirExpr<F>],
291 alphas: &[F],
292 trace: &Trace<F>,
293) -> Result<Vec<F>, Error> {
294 (0..trace.row_count().count() - 1)
295 .map(|row| {
296 let assign = trace.row_pair_assignment(row)?;
297 constraints
298 .iter()
299 .zip(alphas.iter())
300 .try_fold(F::zero(), |acc, (c, alpha)| {
301 let val = c.evaluate(&assign)?;
302 Ok(acc + alpha.clone() * val)
303 })
304 })
305 .collect()
306}
307
308fn open_all_columns<F: FieldBytes, A: Air<F>>(
310 air: &A,
311 trace: &Trace<F>,
312 tree: &MerkleTree,
313) -> Result<Vec<ColumnOpening<F>>, Error> {
314 let cols = air.column_count().count();
315 let rows = trace.row_count().count();
316 (0..cols)
317 .map(|col_idx| {
318 let values = trace.column_values(Column::new(col_idx))?;
319 let merkle_proofs: Result<Vec<MerkleProof>, Error> = (0..rows)
320 .map(|row| {
321 let flat_idx = row * cols + col_idx;
322 tree.open(flat_idx).map_err(Error::from)
323 })
324 .collect();
325 Ok(ColumnOpening {
326 column_index: col_idx,
327 values,
328 merkle_proofs: merkle_proofs?,
329 })
330 })
331 .collect()
332}
333
334fn verify_merkle_openings<F: FieldBytes>(proof: &AirProof<F>) -> bool {
336 let cols = proof.column_openings.len();
337 proof.column_openings.iter().all(|opening| {
338 opening
339 .values
340 .iter()
341 .enumerate()
342 .all(|(row, value)| {
343 let flat_idx = row * cols + opening.column_index;
344 MerkleTree::verify_opening(
345 &proof.trace_commitment,
346 flat_idx,
347 value,
348 &opening.merkle_proofs[row],
349 )
350 })
351 })
352}
353
354fn reconstruct_trace<F: Field, A: Air<F>>(
356 air: &A,
357 proof: &AirProof<F>,
358) -> Result<Trace<F>, Error> {
359 let cols = air.column_count().count();
360 let rows = proof.row_count;
361 let row_vecs: Vec<Vec<F>> = (0..rows)
362 .map(|r| {
363 (0..cols)
364 .map(|c| {
365 proof.column_openings
366 .get(c)
367 .and_then(|opening| opening.values.get(r).cloned())
368 .ok_or(Error::ColumnOutOfBounds {
369 index: c,
370 column_count: cols,
371 })
372 })
373 .collect::<Result<Vec<F>, Error>>()
374 })
375 .collect::<Result<Vec<Vec<F>>, Error>>()?;
376 Trace::from_rows(air.column_count(), &row_vecs)
377}
378
379fn pad_to_power_of_two<F: Field>(v: Vec<F>) -> Vec<F> {
381 let target = pad_to_power_of_two_len(v.len());
382 let padding = target - v.len();
383 v.into_iter()
384 .chain((0..padding).map(|_| F::zero()))
385 .collect()
386}
387
/// Smallest power of two that is at least `n`, and never less than 1 (so `0 -> 1`).
fn pad_to_power_of_two_len(n: usize) -> usize {
    n.max(1).next_power_of_two()
}
392
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fibonacci::{FibonacciAir, FibonacciInput, StepCount};
    use plonkish_cat::F101;

    #[test]
    fn fibonacci_prove_verify_roundtrip() -> Result<(), Error> {
        let air = FibonacciAir;
        let input = FibonacciInput::new(F101::new(1), F101::new(1), StepCount::new(7));
        let trace = air.generate_trace(&input)?;

        let proof = air_prove(&air, &trace)?;
        assert!(air_verify(&air, &proof)?);
        Ok(())
    }

    #[test]
    fn fibonacci_small_trace() -> Result<(), Error> {
        let air = FibonacciAir;
        let input = FibonacciInput::new(F101::new(1), F101::new(1), StepCount::new(1));
        let trace = air.generate_trace(&input)?;

        let proof = air_prove(&air, &trace)?;
        assert!(air_verify(&air, &proof)?);
        Ok(())
    }

    #[test]
    fn fibonacci_different_initial_values() -> Result<(), Error> {
        let air = FibonacciAir;
        let input = FibonacciInput::new(F101::new(3), F101::new(5), StepCount::new(3));
        let trace = air.generate_trace(&input)?;

        let proof = air_prove(&air, &trace)?;
        assert!(air_verify(&air, &proof)?);
        Ok(())
    }

    #[test]
    fn invalid_trace_column_count_rejected() -> Result<(), Error> {
        let air = FibonacciAir;
        // The 3-column trace itself is well-formed. Building it must succeed so
        // the test really exercises `air_prove` instead of passing vacuously.
        let trace = Trace::from_rows(
            crate::column::ColumnCount::new(3),
            &[
                vec![F101::new(1), F101::new(1), F101::new(0)],
                vec![F101::new(1), F101::new(2), F101::new(0)],
            ],
        )?;
        assert!(matches!(
            air_prove::<F101, _>(&air, &trace),
            Err(Error::ColumnCountMismatch { actual: 3, .. })
        ));
        Ok(())
    }

    #[test]
    fn tampered_trace_rejected() -> Result<(), Error> {
        let air = FibonacciAir;
        let trace = Trace::from_rows(
            crate::column::ColumnCount::new(2),
            &[
                vec![F101::new(1), F101::new(1)],
                // Tampered: the honest successor of (1, 1) is (1, 2).
                vec![F101::new(1), F101::new(99)],
            ],
        )?;
        // Two rows form a single row pair, so the failure must be at row 0.
        assert!(matches!(
            air_prove::<F101, _>(&air, &trace),
            Err(Error::UnsatisfiedAirConstraint { row: 0 })
        ));
        Ok(())
    }

    #[test]
    fn pad_to_power_of_two_len_rounds_up() {
        assert_eq!(pad_to_power_of_two_len(0), 1);
        assert_eq!(pad_to_power_of_two_len(1), 1);
        assert_eq!(pad_to_power_of_two_len(3), 4);
        assert_eq!(pad_to_power_of_two_len(4), 4);
        assert_eq!(pad_to_power_of_two_len(5), 8);
    }

    #[test]
    fn pad_to_power_of_two_appends_zeros() {
        let padded = pad_to_power_of_two(vec![F101::new(7), F101::new(8), F101::new(9)]);
        assert_eq!(padded.len(), 4);
        assert!(padded[0] == F101::new(7));
        assert!(padded[2] == F101::new(9));
        assert!(padded[3] == F101::new(0));
    }
}