Skip to main content

p3_air/symbolic/
mod.rs

1//! Symbolic expression types for AIR constraint representation.
2
3mod builder;
4mod expression;
5pub(crate) mod expression_ext;
6mod variable;
7
8use alloc::sync::Arc;
9use core::iter::{Product, Sum};
10use core::ops;
11
12pub use builder::*;
13pub use expression::{BaseLeaf, SymbolicExpression};
14pub use expression_ext::{ExtLeaf, SymbolicExpressionExt};
15use p3_field::{ExtensionField, Field, PrimeCharacteristicRing};
16pub use variable::{BaseEntry, ExtEntry, SymbolicVariable, SymbolicVariableExt};
17
18/// Properties that leaf nodes must provide for the generic expression tree.
19///
20/// Both [`BaseLeaf`](expression::BaseLeaf) (base-field) and
21/// [`ExtLeaf`](expression_ext::ExtLeaf) (extension-field) implement this trait,
22/// enabling [`SymbolicExpr`] to handle constant folding, degree tracking, and
23/// arithmetic generically.
24pub trait SymLeaf: Clone + core::fmt::Debug {
25    /// The base field type used for constant folding.
26    type F: Field;
27
28    const ZERO: Self;
29    const ONE: Self;
30    const TWO: Self;
31    const NEG_ONE: Self;
32
33    /// Returns the degree multiple of this leaf.
34    fn degree_multiple(&self) -> usize;
35
36    /// Try to view this leaf as a base-field constant.
37    fn as_const(&self) -> Option<&Self::F>;
38
39    /// Create a leaf from a base-field constant.
40    fn from_const(c: Self::F) -> Self;
41}
42
/// Symbolic expression tree whose leaves have type `A`.
///
/// The interior nodes (`Add`, `Sub`, `Neg`, `Mul`) keep their operands behind
/// `Arc`, so cloning an interior node only bumps reference counts. Each
/// interior node also stores its degree multiple, so reading it never needs a
/// tree walk. The same shape serves base-field and extension-field constraints.
///
/// The concrete expression types are aliases of this enum:
/// - [`SymbolicExpression<F>`] is `SymbolicExpr<BaseLeaf<F>>` (base-field constraints);
/// - [`SymbolicExpressionExt<F, EF>`] is `SymbolicExpr<ExtLeaf<F, EF>>` (extension-field constraints).
#[derive(Clone, Debug)]
pub enum SymbolicExpr<A> {
    /// Leaf node: a variable, constant, selector, or lifted sub-expression.
    Leaf(A),

    /// The sum `x + y`.
    Add {
        /// Left operand.
        x: Arc<Self>,
        /// Right operand.
        y: Arc<Self>,
        /// Degree multiple cached for this node.
        degree_multiple: usize,
    },

    /// The difference `x - y`.
    Sub {
        /// Minuend.
        x: Arc<Self>,
        /// Subtrahend.
        y: Arc<Self>,
        /// Degree multiple cached for this node.
        degree_multiple: usize,
    },

    /// The negation `-x`.
    Neg {
        /// Operand being negated.
        x: Arc<Self>,
        /// Degree multiple cached for this node.
        degree_multiple: usize,
    },

    /// The product `x * y`.
    Mul {
        /// Left factor.
        x: Arc<Self>,
        /// Right factor.
        y: Arc<Self>,
        /// Degree multiple cached for this node.
        degree_multiple: usize,
    },
}
84
85impl<A: SymLeaf> SymbolicExpr<A> {
86    /// Returns the degree multiple of this expression.
87    pub fn degree_multiple(&self) -> usize {
88        match self {
89            Self::Leaf(a) => a.degree_multiple(),
90            Self::Add {
91                degree_multiple, ..
92            }
93            | Self::Sub {
94                degree_multiple, ..
95            }
96            | Self::Neg {
97                degree_multiple, ..
98            }
99            | Self::Mul {
100                degree_multiple, ..
101            } => *degree_multiple,
102        }
103    }
104
105    /// Try to view this expression as a base-field constant.
106    fn as_const(&self) -> Option<&A::F> {
107        match self {
108            Self::Leaf(a) => a.as_const(),
109            _ => None,
110        }
111    }
112
113    /// Addition with constant folding and zero-identity elimination.
114    fn sym_add(self, rhs: Self) -> Self {
115        if let (Some(&a), Some(&b)) = (self.as_const(), rhs.as_const()) {
116            return Self::Leaf(A::from_const(a + b));
117        }
118        if self.as_const().is_some_and(|c| c.is_zero()) {
119            return rhs;
120        }
121        if rhs.as_const().is_some_and(|c| c.is_zero()) {
122            return self;
123        }
124        let dm = self.degree_multiple().max(rhs.degree_multiple());
125        Self::Add {
126            x: Arc::new(self),
127            y: Arc::new(rhs),
128            degree_multiple: dm,
129        }
130    }
131
132    /// Subtraction with constant folding and zero-identity elimination.
133    fn sym_sub(self, rhs: Self) -> Self {
134        if let (Some(&a), Some(&b)) = (self.as_const(), rhs.as_const()) {
135            return Self::Leaf(A::from_const(a - b));
136        }
137        if self.as_const().is_some_and(|c| c.is_zero()) {
138            return rhs.sym_neg();
139        }
140        if rhs.as_const().is_some_and(|c| c.is_zero()) {
141            return self;
142        }
143        let dm = self.degree_multiple().max(rhs.degree_multiple());
144        Self::Sub {
145            x: Arc::new(self),
146            y: Arc::new(rhs),
147            degree_multiple: dm,
148        }
149    }
150
151    /// Negation with constant folding.
152    fn sym_neg(self) -> Self {
153        if let Some(&c) = self.as_const() {
154            return Self::Leaf(A::from_const(-c));
155        }
156        let dm = self.degree_multiple();
157        Self::Neg {
158            x: Arc::new(self),
159            degree_multiple: dm,
160        }
161    }
162
163    /// Multiplication with constant folding, zero-annihilation, and one-identity.
164    fn sym_mul(self, rhs: Self) -> Self {
165        if let (Some(&a), Some(&b)) = (self.as_const(), rhs.as_const()) {
166            return Self::Leaf(A::from_const(a * b));
167        }
168        if self.as_const().is_some_and(|c| c.is_zero())
169            || rhs.as_const().is_some_and(|c| c.is_zero())
170        {
171            return Self::Leaf(A::from_const(A::F::ZERO));
172        }
173        if self.as_const().is_some_and(|c| c.is_one()) {
174            return rhs;
175        }
176        if rhs.as_const().is_some_and(|c| c.is_one()) {
177            return self;
178        }
179        let dm = self.degree_multiple() + rhs.degree_multiple();
180        Self::Mul {
181            x: Arc::new(self),
182            y: Arc::new(rhs),
183            degree_multiple: dm,
184        }
185    }
186}
187
188impl<A: SymLeaf> PrimeCharacteristicRing for SymbolicExpr<A> {
189    type PrimeSubfield = <A::F as PrimeCharacteristicRing>::PrimeSubfield;
190
191    const ZERO: Self = Self::Leaf(A::ZERO);
192    const ONE: Self = Self::Leaf(A::ONE);
193    const TWO: Self = Self::Leaf(A::TWO);
194    const NEG_ONE: Self = Self::Leaf(A::NEG_ONE);
195
196    #[inline]
197    fn from_prime_subfield(f: Self::PrimeSubfield) -> Self {
198        Self::Leaf(A::from_const(A::F::from_prime_subfield(f)))
199    }
200}
201
202impl<A: SymLeaf> Default for SymbolicExpr<A> {
203    fn default() -> Self {
204        Self::ZERO
205    }
206}
207
208impl<A: SymLeaf, T: Into<Self>> ops::Add<T> for SymbolicExpr<A> {
209    type Output = Self;
210    fn add(self, rhs: T) -> Self {
211        self.sym_add(rhs.into())
212    }
213}
214
215impl<A: SymLeaf, T: Into<Self>> ops::Sub<T> for SymbolicExpr<A> {
216    type Output = Self;
217    fn sub(self, rhs: T) -> Self {
218        self.sym_sub(rhs.into())
219    }
220}
221
222impl<A: SymLeaf> ops::Neg for SymbolicExpr<A> {
223    type Output = Self;
224    fn neg(self) -> Self {
225        self.sym_neg()
226    }
227}
228
229impl<A: SymLeaf, T: Into<Self>> ops::Mul<T> for SymbolicExpr<A> {
230    type Output = Self;
231    fn mul(self, rhs: T) -> Self {
232        self.sym_mul(rhs.into())
233    }
234}
235
236impl<A: SymLeaf, T: Into<Self>> ops::AddAssign<T> for SymbolicExpr<A> {
237    fn add_assign(&mut self, rhs: T) {
238        *self = self.clone() + rhs.into();
239    }
240}
241
242impl<A: SymLeaf, T: Into<Self>> ops::SubAssign<T> for SymbolicExpr<A> {
243    fn sub_assign(&mut self, rhs: T) {
244        *self = self.clone() - rhs.into();
245    }
246}
247
248impl<A: SymLeaf, T: Into<Self>> ops::MulAssign<T> for SymbolicExpr<A> {
249    fn mul_assign(&mut self, rhs: T) {
250        *self = self.clone() * rhs.into();
251    }
252}
253
254impl<A: SymLeaf, T: Into<Self>> Sum<T> for SymbolicExpr<A> {
255    fn sum<I: Iterator<Item = T>>(iter: I) -> Self {
256        iter.map(Into::into)
257            .reduce(|a, b| a + b)
258            .unwrap_or(Self::ZERO)
259    }
260}
261
262impl<A: SymLeaf, T: Into<Self>> Product<T> for SymbolicExpr<A> {
263    fn product<I: Iterator<Item = T>>(iter: I) -> Self {
264        iter.map(Into::into)
265            .reduce(|a, b| a * b)
266            .unwrap_or(Self::ONE)
267    }
268}
269
270impl<F: Field, T: Into<SymbolicExpression<F>>> ops::Add<T> for SymbolicVariable<F> {
271    type Output = SymbolicExpression<F>;
272    fn add(self, rhs: T) -> Self::Output {
273        Self::Output::from(self) + rhs.into()
274    }
275}
276
277impl<F: Field, T: Into<SymbolicExpression<F>>> ops::Sub<T> for SymbolicVariable<F> {
278    type Output = SymbolicExpression<F>;
279    fn sub(self, rhs: T) -> Self::Output {
280        Self::Output::from(self) - rhs.into()
281    }
282}
283
284impl<F: Field, T: Into<SymbolicExpression<F>>> ops::Mul<T> for SymbolicVariable<F> {
285    type Output = SymbolicExpression<F>;
286    fn mul(self, rhs: T) -> Self::Output {
287        Self::Output::from(self) * rhs.into()
288    }
289}
290
291impl<F: Field, EF: ExtensionField<F>, T: Into<SymbolicExpressionExt<F, EF>>> ops::Add<T>
292    for SymbolicVariableExt<F, EF>
293{
294    type Output = SymbolicExpressionExt<F, EF>;
295    fn add(self, rhs: T) -> Self::Output {
296        Self::Output::from(self) + rhs.into()
297    }
298}
299
300impl<F: Field, EF: ExtensionField<F>, T: Into<SymbolicExpressionExt<F, EF>>> ops::Sub<T>
301    for SymbolicVariableExt<F, EF>
302{
303    type Output = SymbolicExpressionExt<F, EF>;
304    fn sub(self, rhs: T) -> Self::Output {
305        Self::Output::from(self) - rhs.into()
306    }
307}
308
309impl<F: Field, EF: ExtensionField<F>, T: Into<SymbolicExpressionExt<F, EF>>> ops::Mul<T>
310    for SymbolicVariableExt<F, EF>
311{
312    type Output = SymbolicExpressionExt<F, EF>;
313    fn mul(self, rhs: T) -> Self::Output {
314        Self::Output::from(self) * rhs.into()
315    }
316}
317
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::extension::BinomialExtensionField;

    // `BaseLeaf`, `ExtLeaf`, `BaseEntry` and `ExtEntry` are brought in by this
    // glob through the parent module's `pub use` re-exports.
    use super::*;

    type F = BabyBear;
    type EF = BinomialExtensionField<BabyBear, 4>;

    #[test]
    fn symbolic_variable_add_produces_add_node() {
        // Adding a variable and a non-zero constant creates an addition node.
        let lhs = SymbolicVariable::<F>::new(BaseEntry::Main { offset: 0 }, 0);
        let rhs = SymbolicExpression::from(F::new(5));

        let SymbolicExpr::Add { x, y, degree_multiple } = lhs + rhs else {
            panic!("Expected an Add node");
        };
        assert_eq!(degree_multiple, 1);
        assert!(matches!(
            &*x,
            SymbolicExpr::Leaf(BaseLeaf::Variable(v))
                if v.index == 0 && v.entry == BaseEntry::Main { offset: 0 }
        ));
        assert!(matches!(
            &*y,
            SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if *c == F::new(5)
        ));
    }

    #[test]
    fn symbolic_variable_sub_produces_sub_node() {
        // Subtracting two variables creates a subtraction node.
        let lhs = SymbolicVariable::<F>::new(BaseEntry::Main { offset: 0 }, 0);
        let rhs = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::new(
            BaseEntry::Main { offset: 0 },
            1,
        )));

        let SymbolicExpr::Sub { x, y, degree_multiple } = lhs - rhs else {
            panic!("Expected a Sub node");
        };
        assert_eq!(degree_multiple, 1);
        assert!(matches!(
            &*x,
            SymbolicExpr::Leaf(BaseLeaf::Variable(v))
                if v.index == 0 && v.entry == BaseEntry::Main { offset: 0 }
        ));
        assert!(matches!(
            &*y,
            SymbolicExpr::Leaf(BaseLeaf::Variable(v))
                if v.index == 1 && v.entry == BaseEntry::Main { offset: 0 }
        ));
    }

    #[test]
    fn symbolic_variable_mul_produces_mul_node() {
        // Multiplying two variables creates a multiplication node with summed degree.
        let lhs = SymbolicVariable::<F>::new(BaseEntry::Main { offset: 0 }, 0);
        let rhs = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::new(
            BaseEntry::Main { offset: 0 },
            1,
        )));

        let SymbolicExpr::Mul { x, y, degree_multiple } = lhs * rhs else {
            panic!("Expected a Mul node");
        };
        assert_eq!(degree_multiple, 2);
        assert!(matches!(
            &*x,
            SymbolicExpr::Leaf(BaseLeaf::Variable(v))
                if v.index == 0 && v.entry == BaseEntry::Main { offset: 0 }
        ));
        assert!(matches!(
            &*y,
            SymbolicExpr::Leaf(BaseLeaf::Variable(v))
                if v.index == 1 && v.entry == BaseEntry::Main { offset: 0 }
        ));
    }

    #[test]
    fn symbolic_variable_ext_add_produces_add_node() {
        // Adding an extension variable and a non-zero constant creates an addition node.
        let lhs = SymbolicVariableExt::<F, EF>::new(ExtEntry::Permutation { offset: 0 }, 0);
        let rhs = SymbolicExpressionExt::<F, EF>::from(F::new(3));

        let SymbolicExpr::Add { x, y, degree_multiple } = lhs + rhs else {
            panic!("Expected an Add node");
        };
        assert_eq!(degree_multiple, 1);
        assert!(matches!(
            &*x,
            SymbolicExpr::Leaf(ExtLeaf::ExtVariable(v))
                if v.index == 0 && v.entry == ExtEntry::Permutation { offset: 0 }
        ));
        assert!(matches!(
            &*y,
            SymbolicExpr::Leaf(ExtLeaf::Base(SymbolicExpr::Leaf(BaseLeaf::Constant(c))))
                if *c == F::new(3)
        ));
    }

    #[test]
    fn symbolic_variable_ext_sub_produces_sub_node() {
        // Subtracting two extension variables creates a subtraction node.
        let lhs = SymbolicVariableExt::<F, EF>::new(ExtEntry::Permutation { offset: 0 }, 0);
        let rhs = SymbolicExpressionExt::<F, EF>::from(SymbolicVariableExt::<F, EF>::new(
            ExtEntry::Permutation { offset: 0 },
            1,
        ));

        let SymbolicExpr::Sub { x, y, degree_multiple } = lhs - rhs else {
            panic!("Expected a Sub node");
        };
        assert_eq!(degree_multiple, 1);
        assert!(matches!(
            &*x,
            SymbolicExpr::Leaf(ExtLeaf::ExtVariable(v))
                if v.index == 0 && v.entry == ExtEntry::Permutation { offset: 0 }
        ));
        assert!(matches!(
            &*y,
            SymbolicExpr::Leaf(ExtLeaf::ExtVariable(v))
                if v.index == 1 && v.entry == ExtEntry::Permutation { offset: 0 }
        ));
    }

    #[test]
    fn symbolic_variable_ext_mul_produces_mul_node() {
        // Multiplying two extension variables creates a multiplication node with summed degree.
        let lhs = SymbolicVariableExt::<F, EF>::new(ExtEntry::Permutation { offset: 0 }, 0);
        let rhs = SymbolicExpressionExt::<F, EF>::from(SymbolicVariableExt::<F, EF>::new(
            ExtEntry::Permutation { offset: 0 },
            1,
        ));

        let SymbolicExpr::Mul { x, y, degree_multiple } = lhs * rhs else {
            panic!("Expected a Mul node");
        };
        assert_eq!(degree_multiple, 2);
        assert!(matches!(
            &*x,
            SymbolicExpr::Leaf(ExtLeaf::ExtVariable(v))
                if v.index == 0 && v.entry == ExtEntry::Permutation { offset: 0 }
        ));
        assert!(matches!(
            &*y,
            SymbolicExpr::Leaf(ExtLeaf::ExtVariable(v))
                if v.index == 1 && v.entry == ExtEntry::Permutation { offset: 0 }
        ));
    }
}