1use crate::column::{Column, ColumnRef};
10use crate::error::Error;
11use field_cat::Field;
12
13#[derive(Debug, Clone)]
30pub enum AirExpr<F: Field> {
31 Constant(F),
33 Ref(ColumnRef),
35 Neg(Box<AirExpr<F>>),
37 Sum(Box<AirExpr<F>>, Box<AirExpr<F>>),
39 Product(Box<AirExpr<F>>, Box<AirExpr<F>>),
41}
42
43impl<F: Field> AirExpr<F> {
44 #[must_use]
46 pub fn constant(c: F) -> Self {
47 Self::Constant(c)
48 }
49
50 #[must_use]
52 pub fn current(col: Column) -> Self {
53 Self::Ref(ColumnRef::Current(col))
54 }
55
56 #[must_use]
58 pub fn next(col: Column) -> Self {
59 Self::Ref(ColumnRef::Next(col))
60 }
61
62 pub fn evaluate<A: Fn(ColumnRef) -> Result<F, Error>>(
72 &self,
73 assignment: &A,
74 ) -> Result<F, Error> {
75 match self {
76 Self::Constant(c) => Ok(c.clone()),
77 Self::Ref(cr) => assignment(*cr),
78 Self::Neg(inner) => inner.evaluate(assignment).map(|v| -v),
79 Self::Sum(left, right) => {
80 let l = left.evaluate(assignment)?;
81 let r = right.evaluate(assignment)?;
82 Ok(l + r)
83 }
84 Self::Product(left, right) => {
85 let l = left.evaluate(assignment)?;
86 let r = right.evaluate(assignment)?;
87 Ok(l * r)
88 }
89 }
90 }
91}
92
93impl<F: Field> std::ops::Add for AirExpr<F> {
94 type Output = Self;
95 fn add(self, rhs: Self) -> Self {
96 Self::Sum(Box::new(self), Box::new(rhs))
97 }
98}
99
100impl<F: Field> std::ops::Sub for AirExpr<F> {
101 type Output = Self;
102 fn sub(self, rhs: Self) -> Self {
103 self + (-rhs)
104 }
105}
106
107impl<F: Field> std::ops::Mul for AirExpr<F> {
108 type Output = Self;
109 fn mul(self, rhs: Self) -> Self {
110 Self::Product(Box::new(self), Box::new(rhs))
111 }
112}
113
114impl<F: Field> std::ops::Neg for AirExpr<F> {
115 type Output = Self;
116 fn neg(self) -> Self {
117 Self::Neg(Box::new(self))
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use field_cat::F101;

    /// Reads column `c` from `row`, or reports `ColumnOutOfBounds` against `row`'s length.
    fn lookup(row: &[F101], c: Column) -> Result<F101, Error> {
        row.get(c.index())
            .cloned()
            .ok_or_else(|| Error::ColumnOutOfBounds {
                index: c.index(),
                column_count: row.len(),
            })
    }

    /// Builds an assignment resolving `Current` references against `curr` and
    /// `Next` references against `next`.
    fn two_col_assignment(
        curr: Vec<F101>,
        next: Vec<F101>,
    ) -> impl Fn(ColumnRef) -> Result<F101, Error> {
        move |cr| match cr {
            ColumnRef::Current(c) => lookup(&curr, c),
            ColumnRef::Next(c) => lookup(&next, c),
        }
    }

    #[test]
    fn constant_evaluates_to_itself() -> Result<(), Error> {
        let e = AirExpr::constant(F101::new(42));
        let assign = two_col_assignment(vec![], vec![]);
        assert_eq!(e.evaluate(&assign)?, F101::new(42));
        Ok(())
    }

    #[test]
    fn current_ref_evaluates() -> Result<(), Error> {
        let e = AirExpr::<F101>::current(Column::new(1));
        let assign = two_col_assignment(
            vec![F101::new(10), F101::new(20)],
            vec![F101::new(30), F101::new(40)],
        );
        assert_eq!(e.evaluate(&assign)?, F101::new(20));
        Ok(())
    }

    #[test]
    fn next_ref_evaluates() -> Result<(), Error> {
        let e = AirExpr::<F101>::next(Column::new(0));
        let assign = two_col_assignment(
            vec![F101::new(10), F101::new(20)],
            vec![F101::new(30), F101::new(40)],
        );
        assert_eq!(e.evaluate(&assign)?, F101::new(30));
        Ok(())
    }

    #[test]
    fn arithmetic_works() -> Result<(), Error> {
        // Fibonacci-style transition: the constraint next_b - a - b vanishes
        // when next_b == a + b.
        let current_a = AirExpr::<F101>::current(Column::new(0));
        let current_b = AirExpr::<F101>::current(Column::new(1));
        let next_b = AirExpr::<F101>::next(Column::new(1));
        let expr = next_b - current_a - current_b;

        let assign = two_col_assignment(
            vec![F101::new(3), F101::new(5)],
            vec![F101::new(5), F101::new(8)],
        );
        assert_eq!(expr.evaluate(&assign)?, F101::zero());
        Ok(())
    }

    #[test]
    fn subtraction_evaluates() -> Result<(), Error> {
        let expr = AirExpr::<F101>::current(Column::new(0)) - AirExpr::current(Column::new(1));
        let assign = two_col_assignment(vec![F101::new(9), F101::new(4)], vec![]);
        assert_eq!(expr.evaluate(&assign)?, F101::new(5));
        Ok(())
    }

    #[test]
    fn negation_is_additive_inverse() -> Result<(), Error> {
        let x = AirExpr::<F101>::current(Column::new(0));
        let expr = -x.clone() + x;
        let assign = two_col_assignment(vec![F101::new(7)], vec![]);
        assert_eq!(expr.evaluate(&assign)?, F101::zero());
        Ok(())
    }

    #[test]
    fn product_evaluates() -> Result<(), Error> {
        let expr = AirExpr::<F101>::current(Column::new(0)) * AirExpr::current(Column::new(1));
        let assign = two_col_assignment(vec![F101::new(3), F101::new(5)], vec![]);
        assert_eq!(expr.evaluate(&assign)?, F101::new(15));
        Ok(())
    }

    #[test]
    fn out_of_bounds_column_fails() {
        let e = AirExpr::<F101>::current(Column::new(5));
        let assign = two_col_assignment(vec![F101::new(1)], vec![]);
        assert!(matches!(
            e.evaluate(&assign),
            Err(Error::ColumnOutOfBounds {
                index: 5,
                column_count: 1
            })
        ));
    }

    #[test]
    fn error_in_subexpression_propagates() {
        // The failing reference sits in the right operand of a product; its error
        // must surface unchanged from the top-level evaluate call.
        let expr = AirExpr::<F101>::constant(F101::new(1)) * AirExpr::next(Column::new(2));
        let assign = two_col_assignment(vec![F101::new(1)], vec![F101::new(2)]);
        assert!(matches!(
            expr.evaluate(&assign),
            Err(Error::ColumnOutOfBounds {
                index: 2,
                column_count: 1
            })
        ));
    }
}