Skip to main content

machine_cat/
bridge.rs

1//! Trace-to-sumcheck bridge: proving AIR constraint satisfaction.
2//!
3//! Converts an AIR's transition constraints over an execution trace
4//! into a sumcheck claim and delegates to proof-cat for the proof.
5//!
6//! [`air_prove`] produces an [`AirProof`]; [`air_verify`] checks it.
7//! The current protocol opens all trace values (not zero-knowledge).
8
9use field_cat::{Field, FieldBytes};
10use proof_cat_core::commit::merkle::{MerkleProof, MerkleRoot, MerkleTree};
11use proof_cat_core::poly::MultilinearPoly;
12use proof_cat_core::sumcheck::{SumcheckClaim, SumcheckProof, sumcheck_prove, sumcheck_verify};
13use proof_cat_core::transcript::Transcript;
14
15use crate::air::Air;
16use crate::air_expr::AirExpr;
17use crate::column::Column;
18use crate::error::Error;
19use crate::trace::Trace;
20
/// Domain separation label for the AIR proof transcript.
///
/// Both `air_prove` and `air_verify` start their transcript from this label,
/// so the prover and the verifier must use exactly the same bytes.
const TRANSCRIPT_LABEL: &[u8] = b"machine-cat-v0.1";
23
24// ── Proof types ──────────────────────────────────────────────
25
26/// An opened column: all values in a column with their Merkle proofs.
27#[derive(Debug, Clone)]
28pub struct ColumnOpening<F: Field> {
29    column_index: usize,
30    values: Vec<F>,
31    merkle_proofs: Vec<MerkleProof>,
32}
33
34impl<F: Field> ColumnOpening<F> {
35    /// The column index.
36    #[must_use]
37    pub fn column_index(&self) -> usize {
38        self.column_index
39    }
40
41    /// The opened values (one per row).
42    #[must_use]
43    pub fn values(&self) -> &[F] {
44        &self.values
45    }
46
47    /// The Merkle proofs (one per value).
48    #[must_use]
49    pub fn merkle_proofs(&self) -> &[MerkleProof] {
50        &self.merkle_proofs
51    }
52}
53
54/// A proof that an AIR's constraints hold over a given trace.
55///
56/// # Examples
57///
58/// ```
59/// use field_cat::F101;
60/// use machine_cat::{Air, FibonacciAir, FibonacciInput, StepCount};
61/// use machine_cat::bridge::{air_prove, air_verify};
62///
63/// let air = FibonacciAir;
64/// let input = FibonacciInput::new(
65///     F101::new(1), F101::new(1), StepCount::new(7),
66/// );
67/// let trace = air.generate_trace(&input)?;
68///
69/// let proof = air_prove(&air, &trace)?;
70/// assert!(air_verify(&air, &proof)?);
71/// # Ok::<(), machine_cat::Error>(())
72/// ```
73#[derive(Debug, Clone)]
74pub struct AirProof<F: Field> {
75    trace_commitment: MerkleRoot,
76    sumcheck: SumcheckProof<F>,
77    column_openings: Vec<ColumnOpening<F>>,
78    row_count: usize,
79}
80
81impl<F: Field> AirProof<F> {
82    /// The Merkle root committing to the trace.
83    #[must_use]
84    pub fn trace_commitment(&self) -> &MerkleRoot {
85        &self.trace_commitment
86    }
87
88    /// The sumcheck proof.
89    #[must_use]
90    pub fn sumcheck(&self) -> &SumcheckProof<F> {
91        &self.sumcheck
92    }
93
94    /// The opened column values with Merkle proofs.
95    #[must_use]
96    pub fn column_openings(&self) -> &[ColumnOpening<F>] {
97        &self.column_openings
98    }
99
100    /// The trace row count.
101    #[must_use]
102    pub fn row_count(&self) -> usize {
103        self.row_count
104    }
105}
106
107// ── Prove ────────────────────────────────────────────────────
108
109/// Prove that a trace satisfies an AIR's constraints.
110///
111/// # Protocol
112///
113/// 1. Validate the trace dimensions and constraint satisfaction.
114/// 2. Commit the trace (all values, row-major) to a Merkle tree.
115/// 3. Squeeze random challenges for batching constraints.
116/// 4. Compute the random-linear-combination of constraint evaluations
117///    at each row pair, producing a vector of length `N-1`.
118/// 5. Pad to a power of two, build a multilinear polynomial, and
119///    run sumcheck (claim: sum = 0).
120/// 6. Open all trace column values with Merkle proofs.
121///
122/// # Errors
123///
124/// Returns an error if the trace does not satisfy the constraints,
125/// or if any internal step fails.
126pub fn air_prove<F: FieldBytes, A: Air<F>>(
127    air: &A,
128    trace: &Trace<F>,
129) -> Result<AirProof<F>, Error> {
130    // 1. Validate dimensions.
131    validate_trace(air, trace)?;
132
133    let constraints = air.constraints();
134    if constraints.is_empty() {
135        Err(Error::NoConstraints)
136    } else {
137        // 2. Validate constraint satisfaction at every row pair.
138        validate_constraints(&constraints, trace)?;
139
140        // 3. Commit trace (row-major flat data).
141        let tree = MerkleTree::from_field_values(trace.data());
142
143        // 4. Initialize transcript.
144        let transcript = Transcript::new(TRANSCRIPT_LABEL)
145            .absorb_bytes(tree.root().as_bytes())
146            .absorb_bytes(&air.column_count().count().to_le_bytes())
147            .absorb_bytes(&constraints.len().to_le_bytes());
148
149        // 5. Squeeze combination challenges (one per constraint).
150        let (alphas, transcript) = squeeze_challenges(constraints.len(), transcript)?;
151
152        // 6. Compute combined constraint evaluations at each row pair.
153        let combined_evals = compute_combined_evals(&constraints, &alphas, trace)?;
154
155        // 7. Pad to power of two and build MLE.
156        let padded = pad_to_power_of_two(combined_evals);
157        let poly = MultilinearPoly::from_evals(padded)?;
158
159        // 8. Run sumcheck.
160        let (sumcheck, _, _) = sumcheck_prove(&SumcheckClaim::new(poly, F::zero()), transcript)?;
161
162        // 9. Open all columns.
163        let column_openings = open_all_columns(air, trace, &tree)?;
164
165        Ok(AirProof {
166            trace_commitment: tree.root(),
167            sumcheck,
168            column_openings,
169            row_count: trace.row_count().count(),
170        })
171    }
172}
173
174// ── Verify ───────────────────────────────────────────────────
175
176/// Verify an AIR proof.
177///
178/// Replays the transcript, checks the sumcheck proof, verifies
179/// all Merkle openings, and confirms the final evaluation matches.
180///
181/// # Errors
182///
183/// Returns an error if any verification step fails structurally.
184pub fn air_verify<F: FieldBytes, A: Air<F>>(air: &A, proof: &AirProof<F>) -> Result<bool, Error> {
185    let constraints = air.constraints();
186    if constraints.is_empty() {
187        Err(Error::NoConstraints)
188    } else {
189        // 1. Replay transcript.
190        let transcript = Transcript::new(TRANSCRIPT_LABEL)
191            .absorb_bytes(proof.trace_commitment.as_bytes())
192            .absorb_bytes(&air.column_count().count().to_le_bytes())
193            .absorb_bytes(&constraints.len().to_le_bytes());
194
195        let (alphas, transcript) = squeeze_challenges(constraints.len(), transcript)?;
196
197        // 2. Compute padded length and num_vars for sumcheck.
198        let num_row_pairs = proof.row_count.saturating_sub(1);
199        let padded_len = pad_to_power_of_two_len(num_row_pairs);
200        let num_vars = usize::try_from(padded_len.trailing_zeros()).map_err(|_| {
201            Error::TraceNotPowerOfTwo {
202                row_count: padded_len,
203            }
204        })?;
205
206        // 3. Run sumcheck verifier.
207        let (final_eval, challenges, _) = sumcheck_verify(
208            proof.sumcheck(),
209            &F::zero(),
210            proof_cat_core::NumVars::new(num_vars),
211            transcript,
212        )?;
213
214        // 4. Verify Merkle openings.
215        if verify_merkle_openings(proof) {
216            // 5. Reconstruct trace from openings and re-evaluate.
217            let trace = reconstruct_trace(air, proof)?;
218            let combined_evals = compute_combined_evals(&constraints, &alphas, &trace)?;
219            let padded = pad_to_power_of_two(combined_evals);
220            let poly = MultilinearPoly::from_evals(padded)?;
221
222            // 6. Check MLE evaluation at challenges.
223            let expected = poly.evaluate(&challenges)?;
224            Ok(expected == final_eval)
225        } else {
226            Err(Error::ProofCatCore(proof_cat_core::Error::MerkleVerificationFailed))
227        }
228    }
229}
230
231// ── Helpers ──────────────────────────────────────────────────
232
233/// Validate that the trace matches the AIR's column count and has enough rows.
234fn validate_trace<F: Field, A: Air<F>>(air: &A, trace: &Trace<F>) -> Result<(), Error> {
235    if trace.column_count() != air.column_count() {
236        Err(Error::ColumnCountMismatch {
237            expected: air.column_count().count(),
238            actual: trace.column_count().count(),
239        })
240    } else if trace.row_count().count() < 2 {
241        Err(Error::InsufficientRows {
242            row_count: trace.row_count().count(),
243        })
244    } else {
245        Ok(())
246    }
247}
248
249/// Validate that all constraints hold at every row pair.
250fn validate_constraints<F: Field>(
251    constraints: &[AirExpr<F>],
252    trace: &Trace<F>,
253) -> Result<(), Error> {
254    (0..trace.row_count().count() - 1).try_for_each(|row| {
255        let assign = trace.row_pair_assignment(row)?;
256        constraints.iter().try_for_each(|c| {
257            let val = c.evaluate(&assign)?;
258            if val == F::zero() {
259                Ok(())
260            } else {
261                Err(Error::UnsatisfiedAirConstraint { row })
262            }
263        })
264    })
265}
266
267/// Squeeze `count` challenges from the transcript.
268fn squeeze_challenges<F: FieldBytes>(
269    count: usize,
270    transcript: Transcript,
271) -> Result<(Vec<F>, Transcript), Error> {
272    (0..count).try_fold((Vec::with_capacity(count), transcript), |(alphas, t), _| {
273        let (challenge, t): (F, Transcript) = t.squeeze_challenge()?;
274        Ok((
275            alphas
276                .into_iter()
277                .chain(core::iter::once(challenge))
278                .collect(),
279            t,
280        ))
281    })
282}
283
284/// Compute the random-linear-combination of constraint evaluations.
285///
286/// For each row pair `(i, i+1)`, computes `sum_j alpha_j * P_j(row_i, row_{i+1})`.
287fn compute_combined_evals<F: Field>(
288    constraints: &[AirExpr<F>],
289    alphas: &[F],
290    trace: &Trace<F>,
291) -> Result<Vec<F>, Error> {
292    (0..trace.row_count().count() - 1)
293        .map(|row| {
294            let assign = trace.row_pair_assignment(row)?;
295            constraints
296                .iter()
297                .zip(alphas.iter())
298                .try_fold(F::zero(), |acc, (c, alpha)| {
299                    let val = c.evaluate(&assign)?;
300                    Ok(acc + alpha.clone() * val)
301                })
302        })
303        .collect()
304}
305
306/// Open all columns from the Merkle tree.
307fn open_all_columns<F: FieldBytes, A: Air<F>>(
308    air: &A,
309    trace: &Trace<F>,
310    tree: &MerkleTree,
311) -> Result<Vec<ColumnOpening<F>>, Error> {
312    let cols = air.column_count().count();
313    let rows = trace.row_count().count();
314    (0..cols)
315        .map(|col_idx| {
316            let values = trace.column_values(Column::new(col_idx))?;
317            let merkle_proofs: Result<Vec<MerkleProof>, Error> = (0..rows)
318                .map(|row| {
319                    let flat_idx = row * cols + col_idx;
320                    tree.open(flat_idx).map_err(Error::from)
321                })
322                .collect();
323            Ok(ColumnOpening {
324                column_index: col_idx,
325                values,
326                merkle_proofs: merkle_proofs?,
327            })
328        })
329        .collect()
330}
331
332/// Verify all Merkle openings in the proof.
333fn verify_merkle_openings<F: FieldBytes>(proof: &AirProof<F>) -> bool {
334    let cols = proof.column_openings.len();
335    proof.column_openings.iter().all(|opening| {
336        opening.values.iter().enumerate().all(|(row, value)| {
337            let flat_idx = row * cols + opening.column_index;
338            MerkleTree::verify_opening(
339                &proof.trace_commitment,
340                flat_idx,
341                value,
342                &opening.merkle_proofs[row],
343            )
344        })
345    })
346}
347
348/// Reconstruct a Trace from the opened column values.
349fn reconstruct_trace<F: Field, A: Air<F>>(air: &A, proof: &AirProof<F>) -> Result<Trace<F>, Error> {
350    let cols = air.column_count().count();
351    let rows = proof.row_count;
352    let row_vecs: Vec<Vec<F>> = (0..rows)
353        .map(|r| {
354            (0..cols)
355                .map(|c| {
356                    proof
357                        .column_openings
358                        .get(c)
359                        .and_then(|opening| opening.values.get(r).cloned())
360                        .ok_or(Error::ColumnOutOfBounds {
361                            index: c,
362                            column_count: cols,
363                        })
364                })
365                .collect::<Result<Vec<F>, Error>>()
366        })
367        .collect::<Result<Vec<Vec<F>>, Error>>()?;
368    Trace::from_rows(air.column_count(), &row_vecs)
369}
370
371/// Pad a vector to the next power of two with `F::zero()`.
372fn pad_to_power_of_two<F: Field>(v: Vec<F>) -> Vec<F> {
373    let target = pad_to_power_of_two_len(v.len());
374    let padding = target - v.len();
375    v.into_iter()
376        .chain((0..padding).map(|_| F::zero()))
377        .collect()
378}
379
/// Smallest power of two that is `>= n`, with `0` treated as `1`.
fn pad_to_power_of_two_len(n: usize) -> usize {
    n.max(1).next_power_of_two()
}
384
#[cfg(test)]
mod tests {
    use super::*;
    use crate::column::ColumnCount;
    use crate::fibonacci::{FibonacciAir, FibonacciInput, StepCount};
    use field_cat::F101;

    #[test]
    fn fibonacci_prove_verify_roundtrip() -> Result<(), Error> {
        let air = FibonacciAir;
        let input = FibonacciInput::new(F101::new(1), F101::new(1), StepCount::new(7));
        let trace = air.generate_trace(&input)?;

        let proof = air_prove(&air, &trace)?;
        assert!(air_verify(&air, &proof)?);
        Ok(())
    }

    #[test]
    fn fibonacci_small_trace() -> Result<(), Error> {
        let air = FibonacciAir;
        // Minimum: 2 rows (1 step).
        let input = FibonacciInput::new(F101::new(1), F101::new(1), StepCount::new(1));
        let trace = air.generate_trace(&input)?;

        let proof = air_prove(&air, &trace)?;
        assert!(air_verify(&air, &proof)?);
        Ok(())
    }

    #[test]
    fn fibonacci_different_initial_values() -> Result<(), Error> {
        let air = FibonacciAir;
        let input = FibonacciInput::new(F101::new(3), F101::new(5), StepCount::new(3));
        let trace = air.generate_trace(&input)?;

        let proof = air_prove(&air, &trace)?;
        assert!(air_verify(&air, &proof)?);
        Ok(())
    }

    /// A 3-column trace must be rejected by the 2-column Fibonacci AIR with
    /// a column-count error, not just "some" error.
    #[test]
    fn invalid_trace_column_count_rejected() -> Result<(), Error> {
        let air = FibonacciAir;
        // 3 columns instead of 2. Building the trace itself must succeed;
        // otherwise the rejection below would never be exercised.
        let trace = Trace::from_rows(
            ColumnCount::new(3),
            &[
                vec![F101::new(1), F101::new(1), F101::new(0)],
                vec![F101::new(1), F101::new(2), F101::new(0)],
            ],
        )?;
        let result = air_prove::<F101, _>(&air, &trace);
        assert!(matches!(
            result,
            Err(Error::ColumnCountMismatch {
                expected: 2,
                actual: 3
            })
        ));
        Ok(())
    }

    /// A trace that violates the transition on its only row pair must be
    /// refused by the prover at row 0.
    #[test]
    fn tampered_trace_rejected() -> Result<(), Error> {
        let air = FibonacciAir;
        // Valid first row, invalid second row.
        let trace = Trace::from_rows(
            ColumnCount::new(2),
            &[
                vec![F101::new(1), F101::new(1)],
                vec![F101::new(1), F101::new(99)], // Should be 2, not 99
            ],
        )?;
        let result = air_prove::<F101, _>(&air, &trace);
        assert!(matches!(
            result,
            Err(Error::UnsatisfiedAirConstraint { row: 0 })
        ));
        Ok(())
    }

    /// Changing an opened value without updating its Merkle path must make
    /// verification fail.
    #[test]
    fn tampered_opening_rejected() -> Result<(), Error> {
        let air = FibonacciAir;
        let input = FibonacciInput::new(F101::new(1), F101::new(1), StepCount::new(3));
        let trace = air.generate_trace(&input)?;

        let mut proof = air_prove(&air, &trace)?;
        let original = proof.column_openings[0].values[0].clone();
        proof.column_openings[0].values[0] = original + F101::new(1);

        assert!(air_verify(&air, &proof).is_err());
        Ok(())
    }
}