// machine_cat/air_expr.rs

//! Symbolic constraint expressions with row-relative addressing.
//!
//! [`AirExpr<F>`] mirrors plonkish-cat's [`Expression`](plonkish_cat::Expression)
//! but uses [`ColumnRef`] (current row / next row) instead of absolute
//! [`Wire`](plonkish_cat::Wire) indices. Constraints built from
//! `AirExpr` must evaluate to zero at every consecutive row pair
//! in the execution trace.

use std::ops::{Add, Mul, Neg, Sub};

use field_cat::Field;

use crate::column::{Column, ColumnRef};
use crate::error::Error;

13/// A symbolic polynomial expression over row-relative column references.
14///
15/// Used to define AIR transition constraints: expressions that must
16/// equal zero for every consecutive row pair `(row[i], row[i+1])`.
17///
18/// # Examples
19///
20/// ```
21/// use field_cat::F101;
22/// use machine_cat::{AirExpr, Column};
23///
24/// // Constraint: next_a - current_b = 0
25/// let current_b = AirExpr::<F101>::current(Column::new(1));
26/// let next_a = AirExpr::<F101>::next(Column::new(0));
27/// let constraint = next_a - current_b;
28/// ```
29#[derive(Debug, Clone)]
30pub enum AirExpr<F: Field> {
31    /// A field constant.
32    Constant(F),
33    /// A row-relative column reference.
34    Ref(ColumnRef),
35    /// Negation.
36    Neg(Box<AirExpr<F>>),
37    /// Sum of two expressions.
38    Sum(Box<AirExpr<F>>, Box<AirExpr<F>>),
39    /// Product of two expressions.
40    Product(Box<AirExpr<F>>, Box<AirExpr<F>>),
41}
42
43impl<F: Field> AirExpr<F> {
44    /// A constant expression.
45    #[must_use]
46    pub fn constant(c: F) -> Self {
47        Self::Constant(c)
48    }
49
50    /// Reference a column in the current row.
51    #[must_use]
52    pub fn current(col: Column) -> Self {
53        Self::Ref(ColumnRef::Current(col))
54    }
55
56    /// Reference a column in the next row.
57    #[must_use]
58    pub fn next(col: Column) -> Self {
59        Self::Ref(ColumnRef::Next(col))
60    }
61
62    /// Evaluate this expression given a row-pair assignment.
63    ///
64    /// The assignment maps each [`ColumnRef`] to a field value
65    /// for a specific `(row[i], row[i+1])` pair.
66    ///
67    /// # Errors
68    ///
69    /// Returns an error if the assignment fails for any
70    /// referenced column (e.g., column out of bounds).
71    pub fn evaluate<A: Fn(ColumnRef) -> Result<F, Error>>(
72        &self,
73        assignment: &A,
74    ) -> Result<F, Error> {
75        match self {
76            Self::Constant(c) => Ok(c.clone()),
77            Self::Ref(cr) => assignment(*cr),
78            Self::Neg(inner) => inner.evaluate(assignment).map(|v| -v),
79            Self::Sum(left, right) => {
80                let l = left.evaluate(assignment)?;
81                let r = right.evaluate(assignment)?;
82                Ok(l + r)
83            }
84            Self::Product(left, right) => {
85                let l = left.evaluate(assignment)?;
86                let r = right.evaluate(assignment)?;
87                Ok(l * r)
88            }
89        }
90    }
91}
92
93impl<F: Field> std::ops::Add for AirExpr<F> {
94    type Output = Self;
95    fn add(self, rhs: Self) -> Self {
96        Self::Sum(Box::new(self), Box::new(rhs))
97    }
98}
99
100impl<F: Field> std::ops::Sub for AirExpr<F> {
101    type Output = Self;
102    fn sub(self, rhs: Self) -> Self {
103        self + (-rhs)
104    }
105}
106
107impl<F: Field> std::ops::Mul for AirExpr<F> {
108    type Output = Self;
109    fn mul(self, rhs: Self) -> Self {
110        Self::Product(Box::new(self), Box::new(rhs))
111    }
112}
113
114impl<F: Field> std::ops::Neg for AirExpr<F> {
115    type Output = Self;
116    fn neg(self) -> Self {
117        Self::Neg(Box::new(self))
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use field_cat::F101;

    /// Builds an assignment backed by explicit current/next rows.
    ///
    /// Indices past a row's length yield `Error::ColumnOutOfBounds` carrying
    /// the offending index and that row's length.
    fn two_col_assignment(
        curr: Vec<F101>,
        next: Vec<F101>,
    ) -> impl Fn(ColumnRef) -> Result<F101, Error> {
        move |cr| {
            let (row, col) = match cr {
                ColumnRef::Current(c) => (&curr, c),
                ColumnRef::Next(c) => (&next, c),
            };
            row.get(col.index())
                .cloned()
                .ok_or(Error::ColumnOutOfBounds {
                    index: col.index(),
                    column_count: row.len(),
                })
        }
    }

    #[test]
    fn constant_evaluates_to_itself() -> Result<(), Error> {
        let e = AirExpr::constant(F101::new(42));
        let assign = two_col_assignment(vec![], vec![]);
        assert_eq!(e.evaluate(&assign)?, F101::new(42));
        Ok(())
    }

    #[test]
    fn current_ref_evaluates() -> Result<(), Error> {
        let e = AirExpr::<F101>::current(Column::new(1));
        let assign = two_col_assignment(
            vec![F101::new(10), F101::new(20)],
            vec![F101::new(30), F101::new(40)],
        );
        assert_eq!(e.evaluate(&assign)?, F101::new(20));
        Ok(())
    }

    #[test]
    fn next_ref_evaluates() -> Result<(), Error> {
        let e = AirExpr::<F101>::next(Column::new(0));
        let assign = two_col_assignment(
            vec![F101::new(10), F101::new(20)],
            vec![F101::new(30), F101::new(40)],
        );
        assert_eq!(e.evaluate(&assign)?, F101::new(30));
        Ok(())
    }

    #[test]
    fn arithmetic_works() -> Result<(), Error> {
        // next_b - current_a - current_b = 0
        // With curr = [3, 5], next = [5, 8]: 8 - 3 - 5 = 0
        let current_a = AirExpr::<F101>::current(Column::new(0));
        let current_b = AirExpr::<F101>::current(Column::new(1));
        let next_b = AirExpr::<F101>::next(Column::new(1));
        let expr = next_b - current_a - current_b;

        let assign = two_col_assignment(
            vec![F101::new(3), F101::new(5)],
            vec![F101::new(5), F101::new(8)],
        );
        assert_eq!(expr.evaluate(&assign)?, F101::zero());
        Ok(())
    }

    #[test]
    fn product_evaluates() -> Result<(), Error> {
        // current_a * current_b = 3 * 5 = 15
        let expr = AirExpr::<F101>::current(Column::new(0)) * AirExpr::current(Column::new(1));
        let assign = two_col_assignment(vec![F101::new(3), F101::new(5)], vec![]);
        assert_eq!(expr.evaluate(&assign)?, F101::new(15));
        Ok(())
    }

    #[test]
    fn negation_is_additive_inverse() -> Result<(), Error> {
        // (-x) + x = 0 for x = 3
        let x = AirExpr::<F101>::current(Column::new(0));
        let assign = two_col_assignment(vec![F101::new(3)], vec![]);
        let sum = (-x.clone()).evaluate(&assign)? + x.evaluate(&assign)?;
        assert_eq!(sum, F101::zero());
        Ok(())
    }

    #[test]
    fn out_of_bounds_column_fails() {
        let e = AirExpr::<F101>::current(Column::new(5));
        let assign = two_col_assignment(vec![F101::new(1)], vec![]);
        assert!(e.evaluate(&assign).is_err());
    }

    #[test]
    fn error_in_right_operand_propagates() {
        // The left operand is valid; the right one reads past the next row.
        let expr = AirExpr::<F101>::current(Column::new(0)) + AirExpr::next(Column::new(3));
        let assign = two_col_assignment(vec![F101::new(1)], vec![F101::new(2)]);
        assert!(matches!(
            expr.evaluate(&assign),
            Err(Error::ColumnOutOfBounds {
                index: 3,
                column_count: 1
            })
        ));
    }
}