Skip to main content

p3_air/symbolic/
mod.rs

1//! Symbolic expression types for AIR constraint representation.
2
3mod builder;
4mod expression;
5pub(crate) mod expression_ext;
6mod variable;
7
8use alloc::sync::Arc;
9use core::iter::{Product, Sum};
10use core::ops;
11
12pub use builder::*;
13pub use expression::{BaseLeaf, SymbolicExpression};
14pub use expression_ext::{ExtLeaf, SymbolicExpressionExt};
15use p3_field::{Dup, ExtensionField, Field, PrimeCharacteristicRing};
16pub use variable::{BaseEntry, ExtEntry, SymbolicVariable, SymbolicVariableExt};
17
/// Properties that leaf nodes must provide for the generic expression tree.
///
/// Both [`BaseLeaf`](expression::BaseLeaf) (base-field) and
/// [`ExtLeaf`](expression_ext::ExtLeaf) (extension-field) implement this trait,
/// enabling [`SymbolicExpr`] to handle constant folding, degree tracking, and
/// arithmetic generically.
pub trait SymLeaf: Clone + core::fmt::Debug {
    /// The base field type used for constant folding.
    type F: Field;

    /// Leaf wrapped by `SymbolicExpr::<Self>::ZERO`.
    const ZERO: Self;
    /// Leaf wrapped by `SymbolicExpr::<Self>::ONE`.
    const ONE: Self;
    /// Leaf wrapped by `SymbolicExpr::<Self>::TWO`.
    const TWO: Self;
    /// Leaf wrapped by `SymbolicExpr::<Self>::NEG_ONE`.
    const NEG_ONE: Self;

    /// Returns the degree multiple of this leaf.
    ///
    /// [`SymbolicExpr::degree_multiple`] forwards to this for `Leaf` nodes.
    fn degree_multiple(&self) -> usize;

    /// Try to view this leaf as a base-field constant.
    ///
    /// Returning `Some` makes the leaf eligible for constant folding and the
    /// zero/one simplifications performed by [`SymbolicExpr`] arithmetic.
    fn as_const(&self) -> Option<&Self::F>;

    /// Create a leaf from a base-field constant.
    ///
    /// Used to rebuild a leaf from the result of a folded constant operation.
    fn from_const(c: Self::F) -> Self;
}
42
/// A symbolic expression tree, generic over its leaf type `A`.
///
/// This enum captures the shared tree structure — Add/Sub/Neg/Mul nodes with
/// `Arc`-wrapped children and cached degree multiples — used by both base-field
/// and extension-field symbolic expressions.
///
/// Concrete types are provided via type aliases:
/// - [`SymbolicExpression<F>`] = `SymbolicExpr<BaseLeaf<F>>` (base-field constraints)
/// - [`SymbolicExpressionExt<F, EF>`] = `SymbolicExpr<ExtLeaf<F, EF>>` (extension-field constraints)
#[derive(Clone, Debug)]
pub enum SymbolicExpr<A> {
    /// A leaf node (variable, constant, selector, or lifted sub-expression).
    Leaf(A),

    /// Addition of two sub-expressions.
    Add {
        /// Left operand.
        x: Arc<Self>,
        /// Right operand.
        y: Arc<Self>,
        /// Cached degree multiple; the arithmetic in this module stores the
        /// maximum of the operands' degree multiples here.
        degree_multiple: usize,
    },

    /// Subtraction of two sub-expressions.
    Sub {
        /// Minuend (left operand).
        x: Arc<Self>,
        /// Subtrahend (right operand).
        y: Arc<Self>,
        /// Cached degree multiple; the arithmetic in this module stores the
        /// maximum of the operands' degree multiples here.
        degree_multiple: usize,
    },

    /// Negation of a sub-expression.
    Neg {
        /// The negated operand.
        x: Arc<Self>,
        /// Cached degree multiple; the arithmetic in this module copies the
        /// operand's degree multiple here.
        degree_multiple: usize,
    },

    /// Multiplication of two sub-expressions.
    Mul {
        /// Left factor.
        x: Arc<Self>,
        /// Right factor.
        y: Arc<Self>,
        /// Cached degree multiple; the arithmetic in this module stores the
        /// sum of the factors' degree multiples here.
        degree_multiple: usize,
    },
}
84
85impl<A: SymLeaf> SymbolicExpr<A> {
86    /// Returns the degree multiple of this expression.
87    pub fn degree_multiple(&self) -> usize {
88        match self {
89            Self::Leaf(a) => a.degree_multiple(),
90            Self::Add {
91                degree_multiple, ..
92            }
93            | Self::Sub {
94                degree_multiple, ..
95            }
96            | Self::Neg {
97                degree_multiple, ..
98            }
99            | Self::Mul {
100                degree_multiple, ..
101            } => *degree_multiple,
102        }
103    }
104
105    /// Try to view this expression as a base-field constant.
106    fn as_const(&self) -> Option<&A::F> {
107        match self {
108            Self::Leaf(a) => a.as_const(),
109            _ => None,
110        }
111    }
112
113    /// Addition with constant folding and zero-identity elimination.
114    fn sym_add(self, rhs: Self) -> Self {
115        if let (Some(&a), Some(&b)) = (self.as_const(), rhs.as_const()) {
116            return Self::Leaf(A::from_const(a + b));
117        }
118        if self.as_const().is_some_and(|c| c.is_zero()) {
119            return rhs;
120        }
121        if rhs.as_const().is_some_and(|c| c.is_zero()) {
122            return self;
123        }
124        let dm = self.degree_multiple().max(rhs.degree_multiple());
125        Self::Add {
126            x: Arc::new(self),
127            y: Arc::new(rhs),
128            degree_multiple: dm,
129        }
130    }
131
132    /// Subtraction with constant folding and zero-identity elimination.
133    fn sym_sub(self, rhs: Self) -> Self {
134        if let (Some(&a), Some(&b)) = (self.as_const(), rhs.as_const()) {
135            return Self::Leaf(A::from_const(a - b));
136        }
137        if self.as_const().is_some_and(|c| c.is_zero()) {
138            return rhs.sym_neg();
139        }
140        if rhs.as_const().is_some_and(|c| c.is_zero()) {
141            return self;
142        }
143        let dm = self.degree_multiple().max(rhs.degree_multiple());
144        Self::Sub {
145            x: Arc::new(self),
146            y: Arc::new(rhs),
147            degree_multiple: dm,
148        }
149    }
150
151    /// Negation with constant folding.
152    fn sym_neg(self) -> Self {
153        if let Some(&c) = self.as_const() {
154            return Self::Leaf(A::from_const(-c));
155        }
156        let dm = self.degree_multiple();
157        Self::Neg {
158            x: Arc::new(self),
159            degree_multiple: dm,
160        }
161    }
162
163    /// Multiplication with constant folding, zero-annihilation, and one-identity.
164    fn sym_mul(self, rhs: Self) -> Self {
165        if let (Some(&a), Some(&b)) = (self.as_const(), rhs.as_const()) {
166            return Self::Leaf(A::from_const(a * b));
167        }
168        if self.as_const().is_some_and(|c| c.is_zero())
169            || rhs.as_const().is_some_and(|c| c.is_zero())
170        {
171            return Self::Leaf(A::from_const(A::F::ZERO));
172        }
173        if self.as_const().is_some_and(|c| c.is_one()) {
174            return rhs;
175        }
176        if rhs.as_const().is_some_and(|c| c.is_one()) {
177            return self;
178        }
179        let dm = self.degree_multiple() + rhs.degree_multiple();
180        Self::Mul {
181            x: Arc::new(self),
182            y: Arc::new(rhs),
183            degree_multiple: dm,
184        }
185    }
186}
187
188impl<A: SymLeaf> PrimeCharacteristicRing for SymbolicExpr<A> {
189    type PrimeSubfield = <A::F as PrimeCharacteristicRing>::PrimeSubfield;
190
191    const ZERO: Self = Self::Leaf(A::ZERO);
192    const ONE: Self = Self::Leaf(A::ONE);
193    const TWO: Self = Self::Leaf(A::TWO);
194    const NEG_ONE: Self = Self::Leaf(A::NEG_ONE);
195
196    #[inline]
197    fn from_prime_subfield(f: Self::PrimeSubfield) -> Self {
198        Self::Leaf(A::from_const(A::F::from_prime_subfield(f)))
199    }
200}
201
202impl<A: SymLeaf> Dup for SymbolicExpr<A> {
203    #[inline(always)]
204    fn dup(&self) -> Self {
205        self.clone()
206    }
207}
208
209impl<A: SymLeaf> Default for SymbolicExpr<A> {
210    fn default() -> Self {
211        Self::ZERO
212    }
213}
214
215impl<A: SymLeaf, T: Into<Self>> ops::Add<T> for SymbolicExpr<A> {
216    type Output = Self;
217    fn add(self, rhs: T) -> Self {
218        self.sym_add(rhs.into())
219    }
220}
221
222impl<A: SymLeaf, T: Into<Self>> ops::Sub<T> for SymbolicExpr<A> {
223    type Output = Self;
224    fn sub(self, rhs: T) -> Self {
225        self.sym_sub(rhs.into())
226    }
227}
228
229impl<A: SymLeaf> ops::Neg for SymbolicExpr<A> {
230    type Output = Self;
231    fn neg(self) -> Self {
232        self.sym_neg()
233    }
234}
235
236impl<A: SymLeaf, T: Into<Self>> ops::Mul<T> for SymbolicExpr<A> {
237    type Output = Self;
238    fn mul(self, rhs: T) -> Self {
239        self.sym_mul(rhs.into())
240    }
241}
242
243impl<A: SymLeaf, T: Into<Self>> ops::AddAssign<T> for SymbolicExpr<A> {
244    fn add_assign(&mut self, rhs: T) {
245        *self = self.clone() + rhs.into();
246    }
247}
248
249impl<A: SymLeaf, T: Into<Self>> ops::SubAssign<T> for SymbolicExpr<A> {
250    fn sub_assign(&mut self, rhs: T) {
251        *self = self.clone() - rhs.into();
252    }
253}
254
255impl<A: SymLeaf, T: Into<Self>> ops::MulAssign<T> for SymbolicExpr<A> {
256    fn mul_assign(&mut self, rhs: T) {
257        *self = self.clone() * rhs.into();
258    }
259}
260
261impl<A: SymLeaf, T: Into<Self>> Sum<T> for SymbolicExpr<A> {
262    fn sum<I: Iterator<Item = T>>(iter: I) -> Self {
263        iter.map(Into::into)
264            .reduce(|a, b| a + b)
265            .unwrap_or(Self::ZERO)
266    }
267}
268
269impl<A: SymLeaf, T: Into<Self>> Product<T> for SymbolicExpr<A> {
270    fn product<I: Iterator<Item = T>>(iter: I) -> Self {
271        iter.map(Into::into)
272            .reduce(|a, b| a * b)
273            .unwrap_or(Self::ONE)
274    }
275}
276
277impl<F: Field, T: Into<SymbolicExpression<F>>> ops::Add<T> for SymbolicVariable<F> {
278    type Output = SymbolicExpression<F>;
279    fn add(self, rhs: T) -> Self::Output {
280        Self::Output::from(self) + rhs.into()
281    }
282}
283
284impl<F: Field, T: Into<SymbolicExpression<F>>> ops::Sub<T> for SymbolicVariable<F> {
285    type Output = SymbolicExpression<F>;
286    fn sub(self, rhs: T) -> Self::Output {
287        Self::Output::from(self) - rhs.into()
288    }
289}
290
291impl<F: Field, T: Into<SymbolicExpression<F>>> ops::Mul<T> for SymbolicVariable<F> {
292    type Output = SymbolicExpression<F>;
293    fn mul(self, rhs: T) -> Self::Output {
294        Self::Output::from(self) * rhs.into()
295    }
296}
297
298impl<F: Field, EF: ExtensionField<F>, T: Into<SymbolicExpressionExt<F, EF>>> ops::Add<T>
299    for SymbolicVariableExt<F, EF>
300{
301    type Output = SymbolicExpressionExt<F, EF>;
302    fn add(self, rhs: T) -> Self::Output {
303        Self::Output::from(self) + rhs.into()
304    }
305}
306
307impl<F: Field, EF: ExtensionField<F>, T: Into<SymbolicExpressionExt<F, EF>>> ops::Sub<T>
308    for SymbolicVariableExt<F, EF>
309{
310    type Output = SymbolicExpressionExt<F, EF>;
311    fn sub(self, rhs: T) -> Self::Output {
312        Self::Output::from(self) - rhs.into()
313    }
314}
315
316impl<F: Field, EF: ExtensionField<F>, T: Into<SymbolicExpressionExt<F, EF>>> ops::Mul<T>
317    for SymbolicVariableExt<F, EF>
318{
319    type Output = SymbolicExpressionExt<F, EF>;
320    fn mul(self, rhs: T) -> Self::Output {
321        Self::Output::from(self) * rhs.into()
322    }
323}
324
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::extension::BinomialExtensionField;

    use super::*;
    use crate::symbolic::expression::BaseLeaf;
    use crate::symbolic::expression_ext::ExtLeaf;
    use crate::symbolic::variable::{BaseEntry, ExtEntry};

    type F = BabyBear;
    type EF = BinomialExtensionField<BabyBear, 4>;

    #[test]
    fn symbolic_variable_add_produces_add_node() {
        // variable + non-zero constant must not be simplified away.
        let lhs = SymbolicVariable::<F>::new(BaseEntry::Main { offset: 0 }, 0);
        let rhs = SymbolicExpression::from(F::new(5));

        let SymbolicExpr::Add {
            x,
            y,
            degree_multiple,
        } = lhs + rhs
        else {
            panic!("Expected an Add node");
        };

        assert_eq!(degree_multiple, 1);
        assert!(matches!(
            &*x,
            SymbolicExpr::Leaf(BaseLeaf::Variable(v))
                if v.index == 0 && v.entry == BaseEntry::Main { offset: 0 }
        ));
        assert!(matches!(
            &*y,
            SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if *c == F::new(5)
        ));
    }

    #[test]
    fn symbolic_variable_sub_produces_sub_node() {
        // variable - variable yields a Sub node over both leaves.
        let lhs = SymbolicVariable::<F>::new(BaseEntry::Main { offset: 0 }, 0);
        let rhs = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::new(
            BaseEntry::Main { offset: 0 },
            1,
        )));

        let SymbolicExpr::Sub {
            x,
            y,
            degree_multiple,
        } = lhs - rhs
        else {
            panic!("Expected a Sub node");
        };

        assert_eq!(degree_multiple, 1);
        assert!(matches!(
            &*x,
            SymbolicExpr::Leaf(BaseLeaf::Variable(v))
                if v.index == 0 && v.entry == BaseEntry::Main { offset: 0 }
        ));
        assert!(matches!(
            &*y,
            SymbolicExpr::Leaf(BaseLeaf::Variable(v))
                if v.index == 1 && v.entry == BaseEntry::Main { offset: 0 }
        ));
    }

    #[test]
    fn symbolic_variable_mul_produces_mul_node() {
        // variable * variable yields a Mul node whose degree is the sum of both.
        let lhs = SymbolicVariable::<F>::new(BaseEntry::Main { offset: 0 }, 0);
        let rhs = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::new(
            BaseEntry::Main { offset: 0 },
            1,
        )));

        let SymbolicExpr::Mul {
            x,
            y,
            degree_multiple,
        } = lhs * rhs
        else {
            panic!("Expected a Mul node");
        };

        assert_eq!(degree_multiple, 2);
        assert!(matches!(
            &*x,
            SymbolicExpr::Leaf(BaseLeaf::Variable(v))
                if v.index == 0 && v.entry == BaseEntry::Main { offset: 0 }
        ));
        assert!(matches!(
            &*y,
            SymbolicExpr::Leaf(BaseLeaf::Variable(v))
                if v.index == 1 && v.entry == BaseEntry::Main { offset: 0 }
        ));
    }

    #[test]
    fn symbolic_variable_ext_add_produces_add_node() {
        // extension variable + non-zero base constant must not be simplified away.
        let lhs = SymbolicVariableExt::<F, EF>::new(ExtEntry::Permutation { offset: 0 }, 0);
        let rhs = SymbolicExpressionExt::<F, EF>::from(F::new(3));

        let SymbolicExpr::Add {
            x,
            y,
            degree_multiple,
        } = lhs + rhs
        else {
            panic!("Expected an Add node");
        };

        assert_eq!(degree_multiple, 1);
        assert!(matches!(
            &*x,
            SymbolicExpr::Leaf(ExtLeaf::ExtVariable(v))
                if v.index == 0 && v.entry == ExtEntry::Permutation { offset: 0 }
        ));
        assert!(matches!(
            &*y,
            SymbolicExpr::Leaf(ExtLeaf::Base(SymbolicExpr::Leaf(BaseLeaf::Constant(c))))
                if *c == F::new(3)
        ));
    }

    #[test]
    fn symbolic_variable_ext_sub_produces_sub_node() {
        // extension variable - extension variable yields a Sub node over both leaves.
        let lhs = SymbolicVariableExt::<F, EF>::new(ExtEntry::Permutation { offset: 0 }, 0);
        let rhs = SymbolicExpressionExt::<F, EF>::from(SymbolicVariableExt::<F, EF>::new(
            ExtEntry::Permutation { offset: 0 },
            1,
        ));

        let SymbolicExpr::Sub {
            x,
            y,
            degree_multiple,
        } = lhs - rhs
        else {
            panic!("Expected a Sub node");
        };

        assert_eq!(degree_multiple, 1);
        assert!(matches!(
            &*x,
            SymbolicExpr::Leaf(ExtLeaf::ExtVariable(v))
                if v.index == 0 && v.entry == ExtEntry::Permutation { offset: 0 }
        ));
        assert!(matches!(
            &*y,
            SymbolicExpr::Leaf(ExtLeaf::ExtVariable(v))
                if v.index == 1 && v.entry == ExtEntry::Permutation { offset: 0 }
        ));
    }

    #[test]
    fn symbolic_variable_ext_mul_produces_mul_node() {
        // extension variable * extension variable yields a Mul node with summed degree.
        let lhs = SymbolicVariableExt::<F, EF>::new(ExtEntry::Permutation { offset: 0 }, 0);
        let rhs = SymbolicExpressionExt::<F, EF>::from(SymbolicVariableExt::<F, EF>::new(
            ExtEntry::Permutation { offset: 0 },
            1,
        ));

        let SymbolicExpr::Mul {
            x,
            y,
            degree_multiple,
        } = lhs * rhs
        else {
            panic!("Expected a Mul node");
        };

        assert_eq!(degree_multiple, 2);
        assert!(matches!(
            &*x,
            SymbolicExpr::Leaf(ExtLeaf::ExtVariable(v))
                if v.index == 0 && v.entry == ExtEntry::Permutation { offset: 0 }
        ));
        assert!(matches!(
            &*y,
            SymbolicExpr::Leaf(ExtLeaf::ExtVariable(v))
                if v.index == 1 && v.entry == ExtEntry::Permutation { offset: 0 }
        ));
    }
}
516}