Skip to main content

machine_cat/
bridge.rs

1//! Trace-to-sumcheck bridge: proving AIR constraint satisfaction.
2//!
3//! Converts an AIR's transition constraints over an execution trace
4//! into a sumcheck claim and delegates to proof-cat for the proof.
5//!
6//! [`air_prove`] produces an [`AirProof`]; [`air_verify`] checks it.
7//! The current protocol opens all trace values (not zero-knowledge).
8
9use proof_cat::commit::merkle::{MerkleProof, MerkleRoot, MerkleTree};
10use proof_cat::poly::MultilinearPoly;
11use proof_cat::sumcheck::{SumcheckClaim, SumcheckProof, sumcheck_prove, sumcheck_verify};
12use proof_cat::transcript::Transcript;
13use proof_cat::FieldBytes;
14
15use plonkish_cat::Field;
16
17use crate::air::Air;
18use crate::air_expr::AirExpr;
19use crate::column::Column;
20use crate::error::Error;
21use crate::trace::Trace;
22
/// Domain-separation label passed to `Transcript::new` by both the prover
/// and the verifier, so this protocol's transcript starts from its own tag.
const TRANSCRIPT_LABEL: &[u8] = "machine-cat-v0.1".as_bytes();
25
26// ── Proof types ──────────────────────────────────────────────
27
28/// An opened column: all values in a column with their Merkle proofs.
29#[derive(Debug, Clone)]
30pub struct ColumnOpening<F: Field> {
31    column_index: usize,
32    values: Vec<F>,
33    merkle_proofs: Vec<MerkleProof>,
34}
35
36impl<F: Field> ColumnOpening<F> {
37    /// The column index.
38    #[must_use]
39    pub fn column_index(&self) -> usize {
40        self.column_index
41    }
42
43    /// The opened values (one per row).
44    #[must_use]
45    pub fn values(&self) -> &[F] {
46        &self.values
47    }
48
49    /// The Merkle proofs (one per value).
50    #[must_use]
51    pub fn merkle_proofs(&self) -> &[MerkleProof] {
52        &self.merkle_proofs
53    }
54}
55
56/// A proof that an AIR's constraints hold over a given trace.
57///
58/// # Examples
59///
60/// ```
61/// use plonkish_cat::F101;
62/// use machine_cat::{Air, FibonacciAir, FibonacciInput, StepCount};
63/// use machine_cat::bridge::{air_prove, air_verify};
64///
65/// let air = FibonacciAir;
66/// let input = FibonacciInput::new(
67///     F101::new(1), F101::new(1), StepCount::new(7),
68/// );
69/// let trace = air.generate_trace(&input)?;
70///
71/// let proof = air_prove(&air, &trace)?;
72/// assert!(air_verify(&air, &proof)?);
73/// # Ok::<(), machine_cat::Error>(())
74/// ```
75#[derive(Debug, Clone)]
76pub struct AirProof<F: Field> {
77    trace_commitment: MerkleRoot,
78    sumcheck: SumcheckProof<F>,
79    column_openings: Vec<ColumnOpening<F>>,
80    row_count: usize,
81}
82
83impl<F: Field> AirProof<F> {
84    /// The Merkle root committing to the trace.
85    #[must_use]
86    pub fn trace_commitment(&self) -> &MerkleRoot {
87        &self.trace_commitment
88    }
89
90    /// The sumcheck proof.
91    #[must_use]
92    pub fn sumcheck(&self) -> &SumcheckProof<F> {
93        &self.sumcheck
94    }
95
96    /// The opened column values with Merkle proofs.
97    #[must_use]
98    pub fn column_openings(&self) -> &[ColumnOpening<F>] {
99        &self.column_openings
100    }
101
102    /// The trace row count.
103    #[must_use]
104    pub fn row_count(&self) -> usize {
105        self.row_count
106    }
107}
108
109// ── Prove ────────────────────────────────────────────────────
110
111/// Prove that a trace satisfies an AIR's constraints.
112///
113/// # Protocol
114///
115/// 1. Validate the trace dimensions and constraint satisfaction.
116/// 2. Commit the trace (all values, row-major) to a Merkle tree.
117/// 3. Squeeze random challenges for batching constraints.
118/// 4. Compute the random-linear-combination of constraint evaluations
119///    at each row pair, producing a vector of length `N-1`.
120/// 5. Pad to a power of two, build a multilinear polynomial, and
121///    run sumcheck (claim: sum = 0).
122/// 6. Open all trace column values with Merkle proofs.
123///
124/// # Errors
125///
126/// Returns an error if the trace does not satisfy the constraints,
127/// or if any internal step fails.
128pub fn air_prove<F: FieldBytes, A: Air<F>>(
129    air: &A,
130    trace: &Trace<F>,
131) -> Result<AirProof<F>, Error> {
132    // 1. Validate dimensions.
133    validate_trace(air, trace)?;
134
135    let constraints = air.constraints();
136    if constraints.is_empty() {
137        Err(Error::NoConstraints)
138    } else {
139        // 2. Validate constraint satisfaction at every row pair.
140        validate_constraints(&constraints, trace)?;
141
142        // 3. Commit trace (row-major flat data).
143        let tree = MerkleTree::from_field_values(trace.data());
144
145        // 4. Initialize transcript.
146        let transcript = Transcript::new(TRANSCRIPT_LABEL)
147            .absorb_bytes(tree.root().as_bytes())
148            .absorb_bytes(&air.column_count().count().to_le_bytes())
149            .absorb_bytes(&constraints.len().to_le_bytes());
150
151        // 5. Squeeze combination challenges (one per constraint).
152        let (alphas, transcript) = squeeze_challenges(constraints.len(), transcript)?;
153
154        // 6. Compute combined constraint evaluations at each row pair.
155        let combined_evals = compute_combined_evals(&constraints, &alphas, trace)?;
156
157        // 7. Pad to power of two and build MLE.
158        let padded = pad_to_power_of_two(combined_evals);
159        let poly = MultilinearPoly::from_evals(padded)?;
160
161        // 8. Run sumcheck.
162        let (sumcheck, _, _) = sumcheck_prove(
163            &SumcheckClaim::new(poly, F::zero()),
164            transcript,
165        )?;
166
167        // 9. Open all columns.
168        let column_openings = open_all_columns(air, trace, &tree)?;
169
170        Ok(AirProof {
171            trace_commitment: tree.root(),
172            sumcheck,
173            column_openings,
174            row_count: trace.row_count().count(),
175        })
176    }
177}
178
179// ── Verify ───────────────────────────────────────────────────
180
181/// Verify an AIR proof.
182///
183/// Replays the transcript, checks the sumcheck proof, verifies
184/// all Merkle openings, and confirms the final evaluation matches.
185///
186/// # Errors
187///
188/// Returns an error if any verification step fails structurally.
189pub fn air_verify<F: FieldBytes, A: Air<F>>(
190    air: &A,
191    proof: &AirProof<F>,
192) -> Result<bool, Error> {
193    let constraints = air.constraints();
194    if constraints.is_empty() {
195        Err(Error::NoConstraints)
196    } else {
197        // 1. Replay transcript.
198        let transcript = Transcript::new(TRANSCRIPT_LABEL)
199            .absorb_bytes(proof.trace_commitment.as_bytes())
200            .absorb_bytes(&air.column_count().count().to_le_bytes())
201            .absorb_bytes(&constraints.len().to_le_bytes());
202
203        let (alphas, transcript) = squeeze_challenges(constraints.len(), transcript)?;
204
205        // 2. Compute padded length and num_vars for sumcheck.
206        let num_row_pairs = proof.row_count.saturating_sub(1);
207        let padded_len = pad_to_power_of_two_len(num_row_pairs);
208        let num_vars = usize::try_from(padded_len.trailing_zeros())
209            .map_err(|_| Error::TraceNotPowerOfTwo { row_count: padded_len })?;
210
211        // 3. Run sumcheck verifier.
212        let (final_eval, challenges, _) = sumcheck_verify(
213            proof.sumcheck(),
214            &F::zero(),
215            proof_cat::NumVars::new(num_vars),
216            transcript,
217        )?;
218
219        // 4. Verify Merkle openings.
220        if verify_merkle_openings(proof) {
221            // 5. Reconstruct trace from openings and re-evaluate.
222            let trace = reconstruct_trace(air, proof)?;
223            let combined_evals = compute_combined_evals(&constraints, &alphas, &trace)?;
224            let padded = pad_to_power_of_two(combined_evals);
225            let poly = MultilinearPoly::from_evals(padded)?;
226
227            // 6. Check MLE evaluation at challenges.
228            let expected = poly.evaluate(&challenges)?;
229            Ok(expected == final_eval)
230        } else {
231            Err(Error::ProofCat(proof_cat::Error::MerkleVerificationFailed))
232        }
233    }
234}
235
236// ── Helpers ──────────────────────────────────────────────────
237
238/// Validate that the trace matches the AIR's column count and has enough rows.
239fn validate_trace<F: Field, A: Air<F>>(air: &A, trace: &Trace<F>) -> Result<(), Error> {
240    if trace.column_count() != air.column_count() {
241        Err(Error::ColumnCountMismatch {
242            expected: air.column_count().count(),
243            actual: trace.column_count().count(),
244        })
245    } else if trace.row_count().count() < 2 {
246        Err(Error::InsufficientRows {
247            row_count: trace.row_count().count(),
248        })
249    } else {
250        Ok(())
251    }
252}
253
254/// Validate that all constraints hold at every row pair.
255fn validate_constraints<F: Field>(
256    constraints: &[AirExpr<F>],
257    trace: &Trace<F>,
258) -> Result<(), Error> {
259    (0..trace.row_count().count() - 1).try_for_each(|row| {
260        let assign = trace.row_pair_assignment(row)?;
261        constraints.iter().try_for_each(|c| {
262            let val = c.evaluate(&assign)?;
263            if val == F::zero() {
264                Ok(())
265            } else {
266                Err(Error::UnsatisfiedAirConstraint { row })
267            }
268        })
269    })
270}
271
272/// Squeeze `count` challenges from the transcript.
273fn squeeze_challenges<F: FieldBytes>(
274    count: usize,
275    transcript: Transcript,
276) -> Result<(Vec<F>, Transcript), Error> {
277    (0..count).try_fold((Vec::with_capacity(count), transcript), |(alphas, t), _| {
278        let (challenge, t): (F, Transcript) = t.squeeze_challenge()?;
279        Ok((
280            alphas.into_iter().chain(core::iter::once(challenge)).collect(),
281            t,
282        ))
283    })
284}
285
286/// Compute the random-linear-combination of constraint evaluations.
287///
288/// For each row pair `(i, i+1)`, computes `sum_j alpha_j * P_j(row_i, row_{i+1})`.
289fn compute_combined_evals<F: Field>(
290    constraints: &[AirExpr<F>],
291    alphas: &[F],
292    trace: &Trace<F>,
293) -> Result<Vec<F>, Error> {
294    (0..trace.row_count().count() - 1)
295        .map(|row| {
296            let assign = trace.row_pair_assignment(row)?;
297            constraints
298                .iter()
299                .zip(alphas.iter())
300                .try_fold(F::zero(), |acc, (c, alpha)| {
301                    let val = c.evaluate(&assign)?;
302                    Ok(acc + alpha.clone() * val)
303                })
304        })
305        .collect()
306}
307
308/// Open all columns from the Merkle tree.
309fn open_all_columns<F: FieldBytes, A: Air<F>>(
310    air: &A,
311    trace: &Trace<F>,
312    tree: &MerkleTree,
313) -> Result<Vec<ColumnOpening<F>>, Error> {
314    let cols = air.column_count().count();
315    let rows = trace.row_count().count();
316    (0..cols)
317        .map(|col_idx| {
318            let values = trace.column_values(Column::new(col_idx))?;
319            let merkle_proofs: Result<Vec<MerkleProof>, Error> = (0..rows)
320                .map(|row| {
321                    let flat_idx = row * cols + col_idx;
322                    tree.open(flat_idx).map_err(Error::from)
323                })
324                .collect();
325            Ok(ColumnOpening {
326                column_index: col_idx,
327                values,
328                merkle_proofs: merkle_proofs?,
329            })
330        })
331        .collect()
332}
333
334/// Verify all Merkle openings in the proof.
335fn verify_merkle_openings<F: FieldBytes>(proof: &AirProof<F>) -> bool {
336    let cols = proof.column_openings.len();
337    proof.column_openings.iter().all(|opening| {
338        opening
339            .values
340            .iter()
341            .enumerate()
342            .all(|(row, value)| {
343                let flat_idx = row * cols + opening.column_index;
344                MerkleTree::verify_opening(
345                    &proof.trace_commitment,
346                    flat_idx,
347                    value,
348                    &opening.merkle_proofs[row],
349                )
350            })
351    })
352}
353
354/// Reconstruct a Trace from the opened column values.
355fn reconstruct_trace<F: Field, A: Air<F>>(
356    air: &A,
357    proof: &AirProof<F>,
358) -> Result<Trace<F>, Error> {
359    let cols = air.column_count().count();
360    let rows = proof.row_count;
361    let row_vecs: Vec<Vec<F>> = (0..rows)
362        .map(|r| {
363            (0..cols)
364                .map(|c| {
365                    proof.column_openings
366                        .get(c)
367                        .and_then(|opening| opening.values.get(r).cloned())
368                        .ok_or(Error::ColumnOutOfBounds {
369                            index: c,
370                            column_count: cols,
371                        })
372                })
373                .collect::<Result<Vec<F>, Error>>()
374        })
375        .collect::<Result<Vec<Vec<F>>, Error>>()?;
376    Trace::from_rows(air.column_count(), &row_vecs)
377}
378
379/// Pad a vector to the next power of two with `F::zero()`.
380fn pad_to_power_of_two<F: Field>(v: Vec<F>) -> Vec<F> {
381    let target = pad_to_power_of_two_len(v.len());
382    let padding = target - v.len();
383    v.into_iter()
384        .chain((0..padding).map(|_| F::zero()))
385        .collect()
386}
387
/// Smallest power of two that is `>= n`, with a floor of 1 (so `0 -> 1`).
fn pad_to_power_of_two_len(n: usize) -> usize {
    n.max(1).next_power_of_two()
}
392
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fibonacci::{FibonacciAir, FibonacciInput, StepCount};
    use plonkish_cat::F101;

    #[test]
    fn fibonacci_prove_verify_roundtrip() -> Result<(), Error> {
        let air = FibonacciAir;
        let input = FibonacciInput::new(F101::new(1), F101::new(1), StepCount::new(7));
        let trace = air.generate_trace(&input)?;

        let proof = air_prove(&air, &trace)?;
        assert!(air_verify(&air, &proof)?);
        Ok(())
    }

    #[test]
    fn fibonacci_small_trace() -> Result<(), Error> {
        let air = FibonacciAir;
        // Minimum: 2 rows (1 step).
        let input = FibonacciInput::new(F101::new(1), F101::new(1), StepCount::new(1));
        let trace = air.generate_trace(&input)?;

        let proof = air_prove(&air, &trace)?;
        assert!(air_verify(&air, &proof)?);
        Ok(())
    }

    #[test]
    fn fibonacci_different_initial_values() -> Result<(), Error> {
        let air = FibonacciAir;
        let input = FibonacciInput::new(F101::new(3), F101::new(5), StepCount::new(3));
        let trace = air.generate_trace(&input)?;

        let proof = air_prove(&air, &trace)?;
        assert!(air_verify(&air, &proof)?);
        Ok(())
    }

    #[test]
    fn invalid_trace_column_count_rejected() {
        let air = FibonacciAir;
        // 3 columns instead of 2. Either `Trace::from_rows` or `air_prove` may
        // reject this. If the trace is built, proving must report the mismatch.
        let trace = Trace::from_rows(
            crate::column::ColumnCount::new(3),
            &[
                vec![F101::new(1), F101::new(1), F101::new(0)],
                vec![F101::new(1), F101::new(2), F101::new(0)],
            ],
        );
        if let Ok(t) = trace {
            assert!(matches!(
                air_prove::<F101, _>(&air, &t),
                Err(Error::ColumnCountMismatch { .. })
            ));
        }
    }

    #[test]
    fn tampered_trace_rejected() -> Result<(), Error> {
        let air = FibonacciAir;
        // Valid first row, invalid second row. The trace is well-formed
        // (2 columns, 2 rows), so construction must succeed. Propagating its
        // error keeps the test from passing vacuously; proving must then fail
        // on the violated constraint.
        let trace = Trace::from_rows(
            crate::column::ColumnCount::new(2),
            &[
                vec![F101::new(1), F101::new(1)],
                vec![F101::new(1), F101::new(99)], // Should be 2, not 99
            ],
        )?;
        assert!(matches!(
            air_prove::<F101, _>(&air, &trace),
            Err(Error::UnsatisfiedAirConstraint { .. })
        ));
        Ok(())
    }

    #[test]
    fn tampered_opening_rejected() -> Result<(), Error> {
        let air = FibonacciAir;
        let input = FibonacciInput::new(F101::new(1), F101::new(1), StepCount::new(3));
        let trace = air.generate_trace(&input)?;
        let mut proof = air_prove(&air, &trace)?;

        // Change one opened value without updating its Merkle proof.
        let opened = &mut proof.column_openings[0].values[1];
        *opened = opened.clone() + F101::new(1);

        assert!(!matches!(air_verify(&air, &proof), Ok(true)));
        Ok(())
    }
}