Skip to main content

p3_air/symbolic/
expression.rs

1use p3_field::{Algebra, ExtensionField, Field, InjectiveMonomial};
2
3use crate::symbolic::variable::{BaseEntry, SymbolicVariable};
4use crate::symbolic::{SymLeaf, SymbolicExpr};
5use crate::{AirBuilder, WindowAccess};
6
7/// Leaf nodes for base-field symbolic expressions.
8///
9/// These represent the atomic building blocks of AIR constraint expressions:
10/// trace column references, selectors, and field constants.
11#[derive(Clone, Debug)]
12pub enum BaseLeaf<F> {
13    /// A reference to a trace column or public input.
14    Variable(SymbolicVariable<F>),
15
16    /// Selector evaluating to a non-zero value only on the first row.
17    IsFirstRow,
18
19    /// Selector evaluating to a non-zero value only on the last row.
20    IsLastRow,
21
22    /// Selector evaluating to zero only on the last row.
23    IsTransition,
24
25    /// A constant field element.
26    Constant(F),
27}
28
29/// A symbolic expression tree for base-field AIR constraints.
30///
31/// This is a type alias for the generic [`SymbolicExpr`] parameterized with
32/// base-field [`BaseLeaf`] nodes.
33pub type SymbolicExpression<F> = SymbolicExpr<BaseLeaf<F>>;
34
35impl<F: Field> SymLeaf for BaseLeaf<F> {
36    type F = F;
37
38    const ZERO: Self = Self::Constant(F::ZERO);
39    const ONE: Self = Self::Constant(F::ONE);
40    const TWO: Self = Self::Constant(F::TWO);
41    const NEG_ONE: Self = Self::Constant(F::NEG_ONE);
42
43    fn degree_multiple(&self) -> usize {
44        match self {
45            Self::Variable(v) => v.degree_multiple(),
46            Self::IsFirstRow | Self::IsLastRow => 1,
47            Self::IsTransition | Self::Constant(_) => 0,
48        }
49    }
50
51    fn as_const(&self) -> Option<&F> {
52        match self {
53            Self::Constant(c) => Some(c),
54            _ => None,
55        }
56    }
57
58    fn from_const(c: F) -> Self {
59        Self::Constant(c)
60    }
61}
62
63impl<F: Field, EF: ExtensionField<F>> From<SymbolicVariable<F>> for SymbolicExpression<EF> {
64    fn from(var: SymbolicVariable<F>) -> Self {
65        Self::Leaf(BaseLeaf::Variable(SymbolicVariable::new(
66            var.entry, var.index,
67        )))
68    }
69}
70
71impl<F: Field, EF: ExtensionField<F>> From<F> for SymbolicExpression<EF> {
72    fn from(f: F) -> Self {
73        Self::Leaf(BaseLeaf::Constant(f.into()))
74    }
75}
76
77impl<F: Field> SymbolicExpression<F> {
78    /// Evaluate this symbolic expression against a concrete [`AirBuilder`].
79    ///
80    /// # Overview
81    ///
82    /// - Walk the expression tree top-down.
83    /// - Replace each leaf with the builder's concrete value.
84    /// - Recurse into arithmetic nodes; combine in the builder's algebra.
85    ///
86    /// # Algorithm
87    ///
88    /// ```text
89    ///     leaf   → builder lookup (main / preprocessed / public / periodic / selector / constant)
90    ///     x + y  → resolve(x) + resolve(y)
91    ///     x - y  → resolve(x) - resolve(y)
92    ///     x * y  → resolve(x) * resolve(y)
93    ///     -x     → -resolve(x)
94    /// ```
95    ///
96    /// # Panics
97    ///
98    /// - Row offset other than 0 or 1.
99    /// - Column index out of bounds.
100    pub fn resolve<AB>(&self, builder: &AB) -> AB::Expr
101    where
102        AB: AirBuilder<F = F>,
103    {
104        match self {
105            Self::Leaf(leaf) => match leaf {
106                BaseLeaf::Variable(v) => match v.entry {
107                    // Main trace: offset 0 = current row, offset 1 = next row.
108                    // Symbolic builders only emit two-row windows.
109                    BaseEntry::Main { offset } => {
110                        let main = builder.main();
111                        match offset {
112                            0 => main
113                                .current(v.index)
114                                .expect("main column index out of bounds")
115                                .into(),
116                            1 => main
117                                .next(v.index)
118                                .expect("main column index out of bounds")
119                                .into(),
120                            _ => panic!("expressions cannot span more than two rows"),
121                        }
122                    }
123                    // Preprocessed trace: same shape, commitment-free trace.
124                    BaseEntry::Preprocessed { offset } => {
125                        let prep = builder.preprocessed();
126                        match offset {
127                            0 => prep
128                                .current(v.index)
129                                .expect("preprocessed column index out of bounds")
130                                .into(),
131                            1 => prep
132                                .next(v.index)
133                                .expect("preprocessed column index out of bounds")
134                                .into(),
135                            _ => panic!("expressions cannot span more than two rows"),
136                        }
137                    }
138                    // Public input: direct slice lookup.
139                    BaseEntry::Public => builder.public_values()[v.index].into(),
140                    // Periodic column at the current row.
141                    // Empty default slice → out-of-bounds panic on stray emissions.
142                    BaseEntry::Periodic => builder.periodic_values()[v.index].into(),
143                },
144                // Boundary and transition selectors come straight from the builder.
145                BaseLeaf::IsFirstRow => builder.is_first_row(),
146                BaseLeaf::IsLastRow => builder.is_last_row(),
147                BaseLeaf::IsTransition => builder.is_transition_window(2),
148                // Lift the field constant into the builder's expression algebra.
149                BaseLeaf::Constant(c) => AB::Expr::from(*c),
150            },
151            // Arithmetic: recurse on operands, combine in the builder's algebra.
152            Self::Add { x, y, .. } => x.resolve(builder) + y.resolve(builder),
153            Self::Sub { x, y, .. } => x.resolve(builder) - y.resolve(builder),
154            Self::Neg { x, .. } => -x.resolve(builder),
155            Self::Mul { x, y, .. } => x.resolve(builder) * y.resolve(builder),
156        }
157    }
158}
159
160impl<F: Field> Algebra<F> for SymbolicExpression<F> {}
161
162impl<F: Field> Algebra<SymbolicVariable<F>> for SymbolicExpression<F> {}
163
164// Note we cannot implement PermutationMonomial due to the degree_multiple part which makes
165// operations non invertible.
166impl<F: Field + InjectiveMonomial<N>, const N: u64> InjectiveMonomial<N> for SymbolicExpression<F> {}
167
168#[cfg(test)]
169mod tests {
170    use alloc::sync::Arc;
171    use alloc::vec;
172    use alloc::vec::Vec;
173
174    use p3_baby_bear::BabyBear;
175    use p3_field::PrimeCharacteristicRing;
176    use p3_matrix::dense::RowMajorMatrix;
177
178    use super::*;
179    use crate::symbolic::BaseEntry;
180
181    #[test]
182    fn test_symbolic_expression_degree_multiple() {
183        let constant_expr =
184            SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::Constant(BabyBear::new(5)));
185        assert_eq!(
186            constant_expr.degree_multiple(),
187            0,
188            "Constant should have degree 0"
189        );
190
191        let variable_expr = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::new(
192            BaseEntry::Main { offset: 0 },
193            1,
194        )));
195        assert_eq!(
196            variable_expr.degree_multiple(),
197            1,
198            "Main variable should have degree 1"
199        );
200
201        let preprocessed_var = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::new(
202            BaseEntry::Preprocessed { offset: 0 },
203            2,
204        )));
205        assert_eq!(
206            preprocessed_var.degree_multiple(),
207            1,
208            "Preprocessed variable should have degree 1"
209        );
210
211        let public_var = SymbolicExpression::Leaf(BaseLeaf::Variable(
212            SymbolicVariable::<BabyBear>::new(BaseEntry::Public, 4),
213        ));
214        assert_eq!(
215            public_var.degree_multiple(),
216            0,
217            "Public variable should have degree 0"
218        );
219
220        let is_first_row = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::IsFirstRow);
221        assert_eq!(
222            is_first_row.degree_multiple(),
223            1,
224            "IsFirstRow should have degree 1"
225        );
226
227        let is_last_row = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::IsLastRow);
228        assert_eq!(
229            is_last_row.degree_multiple(),
230            1,
231            "IsLastRow should have degree 1"
232        );
233
234        let is_transition = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::IsTransition);
235        assert_eq!(
236            is_transition.degree_multiple(),
237            0,
238            "IsTransition should have degree 0"
239        );
240
241        let add_expr = SymbolicExpr::<BaseLeaf<BabyBear>>::Add {
242            x: Arc::new(variable_expr.clone()),
243            y: Arc::new(preprocessed_var.clone()),
244            degree_multiple: 1,
245        };
246        assert_eq!(
247            add_expr.degree_multiple(),
248            1,
249            "Addition should take max degree of inputs"
250        );
251
252        let sub_expr = SymbolicExpr::<BaseLeaf<BabyBear>>::Sub {
253            x: Arc::new(variable_expr.clone()),
254            y: Arc::new(preprocessed_var.clone()),
255            degree_multiple: 1,
256        };
257        assert_eq!(
258            sub_expr.degree_multiple(),
259            1,
260            "Subtraction should take max degree of inputs"
261        );
262
263        let neg_expr = SymbolicExpr::<BaseLeaf<BabyBear>>::Neg {
264            x: Arc::new(variable_expr.clone()),
265            degree_multiple: 1,
266        };
267        assert_eq!(
268            neg_expr.degree_multiple(),
269            1,
270            "Negation should keep the degree"
271        );
272
273        let mul_expr = SymbolicExpr::<BaseLeaf<BabyBear>>::Mul {
274            x: Arc::new(variable_expr),
275            y: Arc::new(preprocessed_var),
276            degree_multiple: 2,
277        };
278        assert_eq!(
279            mul_expr.degree_multiple(),
280            2,
281            "Multiplication should sum degrees"
282        );
283    }
284
285    #[test]
286    fn test_addition_of_constants() {
287        let a = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(3)));
288        let b = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(4)));
289        let result = a + b;
290        match result {
291            SymbolicExpr::Leaf(BaseLeaf::Constant(val)) => assert_eq!(val, BabyBear::new(7)),
292            _ => panic!("Addition of constants did not simplify correctly"),
293        }
294    }
295
296    #[test]
297    fn test_subtraction_of_constants() {
298        let a = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(10)));
299        let b = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(4)));
300        let result = a - b;
301        match result {
302            SymbolicExpr::Leaf(BaseLeaf::Constant(val)) => assert_eq!(val, BabyBear::new(6)),
303            _ => panic!("Subtraction of constants did not simplify correctly"),
304        }
305    }
306
307    #[test]
308    fn test_negation() {
309        let a = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(7)));
310        let result = -a;
311        match result {
312            SymbolicExpr::Leaf(BaseLeaf::Constant(val)) => {
313                assert_eq!(val, BabyBear::NEG_ONE * BabyBear::new(7));
314            }
315            _ => panic!("Negation did not work correctly"),
316        }
317    }
318
319    #[test]
320    fn test_multiplication_of_constants() {
321        let a = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(3)));
322        let b = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(5)));
323        let result = a * b;
324        match result {
325            SymbolicExpr::Leaf(BaseLeaf::Constant(val)) => assert_eq!(val, BabyBear::new(15)),
326            _ => panic!("Multiplication of constants did not simplify correctly"),
327        }
328    }
329
330    #[test]
331    fn test_degree_multiple_for_addition() {
332        let a = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
333            BaseEntry::Main { offset: 0 },
334            1,
335        )));
336        let b = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
337            BaseEntry::Main { offset: 0 },
338            2,
339        )));
340        let result = a + b;
341        match result {
342            SymbolicExpr::Add {
343                degree_multiple,
344                x,
345                y,
346            } => {
347                assert_eq!(degree_multiple, 1);
348                assert!(
349                    matches!(&*x, SymbolicExpr::Leaf(BaseLeaf::Variable(v)) if v.index == 1 && matches!(v.entry, BaseEntry::Main { offset: 0 }))
350                );
351                assert!(
352                    matches!(&*y, SymbolicExpr::Leaf(BaseLeaf::Variable(v)) if v.index == 2 && matches!(v.entry, BaseEntry::Main { offset: 0 }))
353                );
354            }
355            _ => panic!("Addition did not create an Add expression"),
356        }
357    }
358
359    #[test]
360    fn test_degree_multiple_for_multiplication() {
361        let a = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
362            BaseEntry::Main { offset: 0 },
363            1,
364        )));
365        let b = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
366            BaseEntry::Main { offset: 0 },
367            2,
368        )));
369        let result = a * b;
370
371        match result {
372            SymbolicExpr::Mul {
373                degree_multiple,
374                x,
375                y,
376            } => {
377                assert_eq!(degree_multiple, 2, "Multiplication should sum degrees");
378
379                assert!(
380                    matches!(&*x, SymbolicExpr::Leaf(BaseLeaf::Variable(v))
381                        if v.index == 1 && matches!(v.entry, BaseEntry::Main { offset: 0 })
382                    ),
383                    "Left operand should match `a`"
384                );
385
386                assert!(
387                    matches!(&*y, SymbolicExpr::Leaf(BaseLeaf::Variable(v))
388                        if v.index == 2 && matches!(v.entry, BaseEntry::Main { offset: 0 })
389                    ),
390                    "Right operand should match `b`"
391                );
392            }
393            _ => panic!("Multiplication did not create a `Mul` expression"),
394        }
395    }
396
397    #[test]
398    fn test_sum_operator() {
399        let expressions = vec![
400            SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(2))),
401            SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(3))),
402            SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(5))),
403        ];
404        let result: SymbolicExpression<BabyBear> = expressions.into_iter().sum();
405        match result {
406            SymbolicExpr::Leaf(BaseLeaf::Constant(val)) => assert_eq!(val, BabyBear::new(10)),
407            _ => panic!("Sum did not produce correct result"),
408        }
409    }
410
411    #[test]
412    fn test_product_operator() {
413        let expressions = vec![
414            SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(2))),
415            SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(3))),
416            SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(4))),
417        ];
418        let result: SymbolicExpression<BabyBear> = expressions.into_iter().product();
419        match result {
420            SymbolicExpr::Leaf(BaseLeaf::Constant(val)) => assert_eq!(val, BabyBear::new(24)),
421            _ => panic!("Product did not produce correct result"),
422        }
423    }
424
425    #[test]
426    fn test_default_is_zero() {
427        // Default should produce ZERO constant.
428        let expr: SymbolicExpression<BabyBear> = Default::default();
429
430        // Verify it matches the zero constant.
431        assert!(matches!(
432            expr,
433            SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::ZERO
434        ));
435    }
436
437    #[test]
438    fn test_ring_constants() {
439        // ZERO is a Constant variant wrapping the field's zero element.
440        assert!(matches!(
441            SymbolicExpression::<BabyBear>::ZERO,
442            SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::ZERO
443        ));
444        // ONE is a Constant variant wrapping the field's one element.
445        assert!(matches!(
446            SymbolicExpression::<BabyBear>::ONE,
447            SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::ONE
448        ));
449        // TWO is a Constant variant wrapping the field's two element.
450        assert!(matches!(
451            SymbolicExpression::<BabyBear>::TWO,
452            SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::TWO
453        ));
454        // NEG_ONE is a Constant variant wrapping the field's -1 element.
455        assert!(matches!(
456            SymbolicExpression::<BabyBear>::NEG_ONE,
457            SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::NEG_ONE
458        ));
459    }
460
461    #[test]
462    fn test_from_symbolic_variable() {
463        // Create a main trace variable at column index 3.
464        let var = SymbolicVariable::<BabyBear>::new(BaseEntry::Main { offset: 0 }, 3);
465        // Convert to expression.
466        let expr: SymbolicExpression<BabyBear> = var.into();
467        // Verify the variable is preserved with correct entry and index.
468        match expr {
469            SymbolicExpr::Leaf(BaseLeaf::Variable(v)) => {
470                assert!(matches!(v.entry, BaseEntry::Main { offset: 0 }));
471                assert_eq!(v.index, 3);
472            }
473            _ => panic!("Expected Variable variant"),
474        }
475    }
476
477    #[test]
478    fn test_from_field_element() {
479        // Convert a field element directly to expression.
480        let field_val = BabyBear::new(42);
481        let expr: SymbolicExpression<BabyBear> = field_val.into();
482        // Verify it becomes a Constant with the same value.
483        assert!(matches!(
484            expr,
485            SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == field_val
486        ));
487    }
488
489    #[test]
490    fn test_from_prime_subfield() {
491        // Create expression from prime subfield element.
492        let prime_subfield_val = <BabyBear as PrimeCharacteristicRing>::PrimeSubfield::new(7);
493        let expr = SymbolicExpression::<BabyBear>::from_prime_subfield(prime_subfield_val);
494        // Verify it produces a constant with the converted value.
495        assert!(matches!(
496            expr,
497            SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::new(7)
498        ));
499    }
500
501    #[test]
502    fn test_assign_operators() {
503        // Test AddAssign with constants (should simplify).
504        let mut expr = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(5)));
505        expr += SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(3)));
506        assert!(matches!(
507            expr,
508            SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::new(8)
509        ));
510
511        // Test SubAssign with constants (should simplify).
512        let mut expr = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(10)));
513        expr -= SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(4)));
514        assert!(matches!(
515            expr,
516            SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::new(6)
517        ));
518
519        // Test MulAssign with constants (should simplify).
520        let mut expr = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(6)));
521        expr *= SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(7)));
522        assert!(matches!(
523            expr,
524            SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::new(42)
525        ));
526    }
527
528    #[test]
529    fn test_subtraction_creates_sub_node() {
530        // Create two trace variables.
531        let a = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
532            BaseEntry::Main { offset: 0 },
533            0,
534        )));
535        let b = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
536            BaseEntry::Main { offset: 0 },
537            1,
538        )));
539
540        // Subtract them.
541        let result = a - b;
542
543        // Should create Sub node (not simplified).
544        match result {
545            SymbolicExpr::Sub {
546                x,
547                y,
548                degree_multiple,
549            } => {
550                // Both operands have degree 1, so max is 1.
551                assert_eq!(degree_multiple, 1);
552
553                // Verify left operand is main trace variable at index 0, offset 0.
554                assert!(matches!(
555                    x.as_ref(),
556                    SymbolicExpr::Leaf(BaseLeaf::Variable(v))
557                        if v.index == 0 && matches!(v.entry, BaseEntry::Main { offset: 0 })
558                ));
559
560                // Verify right operand is main trace variable at index 1, offset 0.
561                assert!(matches!(
562                    y.as_ref(),
563                    SymbolicExpr::Leaf(BaseLeaf::Variable(v))
564                        if v.index == 1 && matches!(v.entry, BaseEntry::Main { offset: 0 })
565                ));
566            }
567            _ => panic!("Expected Sub variant"),
568        }
569    }
570
571    #[test]
572    fn test_negation_creates_neg_node() {
573        // Create a trace variable.
574        let var = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
575            BaseEntry::Main { offset: 0 },
576            0,
577        )));
578
579        // Negate it.
580        let result = -var;
581
582        // Should create Neg node (not simplified).
583        match result {
584            SymbolicExpr::Neg { x, degree_multiple } => {
585                // Degree is preserved from operand.
586                assert_eq!(degree_multiple, 1);
587
588                // Verify operand is main trace variable at index 0, offset 0.
589                assert!(matches!(
590                    x.as_ref(),
591                    SymbolicExpr::Leaf(BaseLeaf::Variable(v))
592                        if v.index == 0 && matches!(v.entry, BaseEntry::Main { offset: 0 })
593                ));
594            }
595            _ => panic!("Expected Neg variant"),
596        }
597    }
598
599    #[test]
600    fn test_empty_sum_returns_zero() {
601        // Sum of empty iterator should be additive identity.
602        let empty: Vec<SymbolicExpression<BabyBear>> = vec![];
603        let result: SymbolicExpression<BabyBear> = empty.into_iter().sum();
604        assert!(matches!(
605            result,
606            SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::ZERO
607        ));
608    }
609
610    #[test]
611    fn test_empty_product_returns_one() {
612        // Product of empty iterator should be multiplicative identity.
613        let empty: Vec<SymbolicExpression<BabyBear>> = vec![];
614        let result: SymbolicExpression<BabyBear> = empty.into_iter().product();
615        assert!(matches!(
616            result,
617            SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::ONE
618        ));
619    }
620
621    #[test]
622    fn test_mixed_degree_addition() {
623        // Constant has degree 0.
624        let constant = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(5)));
625
626        // Variable has degree 1.
627        let var = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
628            BaseEntry::Main { offset: 0 },
629            0,
630        )));
631
632        // Add them: max(0, 1) = 1.
633        let result = constant + var;
634
635        match result {
636            SymbolicExpr::Add {
637                x,
638                y,
639                degree_multiple,
640            } => {
641                // Degree is max(0, 1) = 1.
642                assert_eq!(degree_multiple, 1);
643
644                // Verify left operand is the constant 5.
645                assert!(matches!(
646                    x.as_ref(),
647                    SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if *c == BabyBear::new(5)
648                ));
649
650                // Verify right operand is main trace variable at index 0, offset 0.
651                assert!(matches!(
652                    y.as_ref(),
653                    SymbolicExpr::Leaf(BaseLeaf::Variable(v))
654                        if v.index == 0 && matches!(v.entry, BaseEntry::Main { offset: 0 })
655                ));
656            }
657            _ => panic!("Expected Add variant"),
658        }
659    }
660
661    #[test]
662    fn test_chained_multiplication_degree() {
663        // Create three variables, each with degree 1.
664        let a = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
665            BaseEntry::Main { offset: 0 },
666            0,
667        )));
668        let b = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
669            BaseEntry::Main { offset: 0 },
670            1,
671        )));
672        let c = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
673            BaseEntry::Main { offset: 0 },
674            2,
675        )));
676
677        // a * b has degree 1 + 1 = 2.
678        let ab = a * b;
679        assert_eq!(ab.degree_multiple(), 2);
680
681        // (a * b) * c has degree 2 + 1 = 3.
682        let abc = ab * c;
683        assert_eq!(abc.degree_multiple(), 3);
684    }
685
686    #[test]
687    fn test_add_zero_identity_folding() {
688        let var = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
689            BaseEntry::Main { offset: 0 },
690            0,
691        )));
692        let zero = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::Constant(BabyBear::ZERO));
693
694        // x + 0 should return x, not create an Add node.
695        let result = var.clone() + zero.clone();
696        assert!(
697            matches!(result, SymbolicExpr::Leaf(BaseLeaf::Variable(_))),
698            "x + 0 should fold to x"
699        );
700
701        // 0 + x should return x, not create an Add node.
702        let result = zero + var;
703        assert!(
704            matches!(result, SymbolicExpr::Leaf(BaseLeaf::Variable(_))),
705            "0 + x should fold to x"
706        );
707    }
708
709    #[test]
710    fn test_sub_zero_identity_folding() {
711        let var = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
712            BaseEntry::Main { offset: 0 },
713            0,
714        )));
715        let zero = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::Constant(BabyBear::ZERO));
716
717        // x - 0 should return x, not create a Sub node.
718        let result = var.clone() - zero.clone();
719        assert!(
720            matches!(result, SymbolicExpr::Leaf(BaseLeaf::Variable(_))),
721            "x - 0 should fold to x"
722        );
723
724        // 0 - x should return -x, not create a Sub node.
725        let result = zero - var;
726        match result {
727            SymbolicExpr::Neg { x, degree_multiple } => {
728                assert_eq!(degree_multiple, 1);
729                assert!(matches!(
730                    x.as_ref(),
731                    SymbolicExpr::Leaf(BaseLeaf::Variable(v))
732                        if v.index == 0 && v.entry == BaseEntry::Main { offset: 0 }
733                ));
734            }
735            _ => panic!("0 - x should fold to Neg(x)"),
736        }
737    }
738
739    #[test]
740    fn test_mul_zero_identity_folding() {
741        let var = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
742            BaseEntry::Main { offset: 0 },
743            0,
744        )));
745        let zero = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::Constant(BabyBear::ZERO));
746
747        // x * 0 should return Constant(0), not create a Mul node.
748        let result = var.clone() * zero.clone();
749        assert!(
750            matches!(result, SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::ZERO),
751            "x * 0 should fold to 0"
752        );
753
754        // 0 * x should return Constant(0), not create a Mul node.
755        let result = zero * var;
756        assert!(
757            matches!(result, SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::ZERO),
758            "0 * x should fold to 0"
759        );
760    }
761
762    #[test]
763    fn test_mul_one_identity_folding() {
764        let var = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
765            BaseEntry::Main { offset: 0 },
766            0,
767        )));
768        let one = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::Constant(BabyBear::ONE));
769
770        // x * 1 should return x, not create a Mul node.
771        let result = var.clone() * one.clone();
772        assert!(
773            matches!(result, SymbolicExpr::Leaf(BaseLeaf::Variable(_))),
774            "x * 1 should fold to x"
775        );
776
777        // 1 * x should return x, not create a Mul node.
778        let result = one * var;
779        assert!(
780            matches!(result, SymbolicExpr::Leaf(BaseLeaf::Variable(_))),
781            "1 * x should fold to x"
782        );
783    }
784
785    #[test]
786    fn test_identity_folding_preserves_degree() {
787        let var = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
788            BaseEntry::Main { offset: 0 },
789            0,
790        )));
791        let zero = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::Constant(BabyBear::ZERO));
792        let one = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::Constant(BabyBear::ONE));
793
794        // x + 0 should preserve degree of x.
795        let result = var.clone() + zero.clone();
796        assert_eq!(result.degree_multiple(), 1);
797
798        // x - 0 should preserve degree of x.
799        let result = var.clone() - zero.clone();
800        assert_eq!(result.degree_multiple(), 1);
801
802        // 0 - x should preserve degree of x.
803        let result = zero.clone() - var.clone();
804        assert_eq!(result.degree_multiple(), 1);
805
806        // x * 1 should preserve degree of x.
807        let result = var.clone() * one;
808        assert_eq!(result.degree_multiple(), 1);
809
810        // x * 0 should have degree 0 (constant).
811        let result = var * zero;
812        assert_eq!(result.degree_multiple(), 0);
813    }
814
815    /// Minimal builder used to drive symbolic-expression resolution.
816    ///
817    /// Carries:
818    /// - a 2-row main trace,
819    /// - a public-value slice,
820    /// - precomputed selector values for the current row,
821    /// - a periodic-column row evaluated at the current step.
822    struct ResolveTestBuilder {
823        main: RowMajorMatrix<BabyBear>,
824        public_values: Vec<BabyBear>,
825        periodic_row: Vec<BabyBear>,
826        is_first: BabyBear,
827        is_last: BabyBear,
828        is_transition: BabyBear,
829    }
830
831    impl AirBuilder for ResolveTestBuilder {
832        type F = BabyBear;
833        type Expr = BabyBear;
834        type Var = BabyBear;
835        type PreprocessedWindow = RowMajorMatrix<BabyBear>;
836        type MainWindow = RowMajorMatrix<BabyBear>;
837        type PublicVar = BabyBear;
838        type PeriodicVar = BabyBear;
839
840        fn main(&self) -> Self::MainWindow {
841            self.main.clone()
842        }
843
844        fn preprocessed(&self) -> &Self::PreprocessedWindow {
845            unimplemented!("no preprocessed columns in test builder")
846        }
847
848        fn is_first_row(&self) -> Self::Expr {
849            self.is_first
850        }
851
852        fn is_last_row(&self) -> Self::Expr {
853            self.is_last
854        }
855
856        fn is_transition(&self) -> Self::Expr {
857            self.is_transition
858        }
859
860        fn assert_zero<I: Into<Self::Expr>>(&mut self, _: I) {}
861
862        fn public_values(&self) -> &[Self::PublicVar] {
863            &self.public_values
864        }
865
866        fn periodic_values(&self) -> &[Self::PeriodicVar] {
867            &self.periodic_row
868        }
869    }
870
871    /// 2-row × 2-column trace, plus a 2-cell periodic row at the current step:
872    ///
873    /// ```text
874    ///     main row 0 (current): [10, 20]
875    ///     main row 1 (next):    [30, 40]
876    ///     periodic_row (curr):  [7, 13]
877    /// ```
878    fn test_builder() -> ResolveTestBuilder {
879        ResolveTestBuilder {
880            main: RowMajorMatrix::new(
881                vec![
882                    BabyBear::new(10),
883                    BabyBear::new(20), // current row
884                    BabyBear::new(30),
885                    BabyBear::new(40), // next row
886                ],
887                2, // width
888            ),
889            public_values: vec![BabyBear::new(99)],
890            // Two periodic columns at the current row.
891            // Distinct primes so any cross-stream mix-up is visible.
892            periodic_row: vec![BabyBear::new(7), BabyBear::new(13)],
893            is_first: BabyBear::ONE,
894            is_last: BabyBear::ZERO,
895            is_transition: BabyBear::ONE,
896        }
897    }
898
899    #[test]
900    fn resolve_main_current_row() {
901        let b = test_builder();
902        // Main column 0, offset 0 → current row value 10.
903        let expr =
904            SymbolicExpression::from(SymbolicVariable::new(BaseEntry::Main { offset: 0 }, 0));
905        assert_eq!(expr.resolve(&b), BabyBear::new(10));
906    }
907
908    #[test]
909    fn resolve_main_next_row() {
910        let b = test_builder();
911        // Main column 1, offset 1 → next row value 40.
912        let expr =
913            SymbolicExpression::from(SymbolicVariable::new(BaseEntry::Main { offset: 1 }, 1));
914        assert_eq!(expr.resolve(&b), BabyBear::new(40));
915    }
916
917    #[test]
918    fn resolve_public_value() {
919        let b = test_builder();
920        // Public value at index 0 → 99.
921        let expr = SymbolicExpression::from(SymbolicVariable::new(BaseEntry::Public, 0));
922        assert_eq!(expr.resolve(&b), BabyBear::new(99));
923    }
924
925    #[test]
926    fn resolve_constant() {
927        let b = test_builder();
928        let expr = SymbolicExpression::<BabyBear>::from(BabyBear::new(42));
929        assert_eq!(expr.resolve(&b), BabyBear::new(42));
930    }
931
932    #[test]
933    fn resolve_selectors() {
934        let b = test_builder();
935
936        let first = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::IsFirstRow);
937        assert_eq!(first.resolve(&b), BabyBear::ONE, "is_first_row = 1");
938
939        let last = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::IsLastRow);
940        assert_eq!(last.resolve(&b), BabyBear::ZERO, "is_last_row = 0");
941
942        let trans = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::IsTransition);
943        assert_eq!(trans.resolve(&b), BabyBear::ONE, "is_transition = 1");
944    }
945
946    #[test]
947    fn resolve_arithmetic() {
948        let b = test_builder();
949
950        // col0_curr = 10, col1_curr = 20.
951        let col0 =
952            SymbolicExpression::from(SymbolicVariable::new(BaseEntry::Main { offset: 0 }, 0));
953        let col1 =
954            SymbolicExpression::from(SymbolicVariable::new(BaseEntry::Main { offset: 0 }, 1));
955
956        // 10 + 20 = 30.
957        let add = col0.clone() + col1.clone();
958        assert_eq!(add.resolve(&b), BabyBear::new(30));
959
960        // 10 - 20 = -10 (mod p).
961        let sub = col0.clone() - col1.clone();
962        assert_eq!(sub.resolve(&b), BabyBear::new(10) - BabyBear::new(20));
963
964        // 10 * 20 = 200.
965        let mul = col0.clone() * col1;
966        assert_eq!(mul.resolve(&b), BabyBear::new(200));
967
968        // -10 (mod p).
969        let neg = -col0;
970        assert_eq!(neg.resolve(&b), -BabyBear::new(10));
971    }
972
973    #[test]
974    fn resolve_periodic_columns() {
975        // Invariant: a periodic leaf reads from the builder's
976        // periodic-value slice, in declared column order.
977        //
978        // Fixture:
979        //
980        //     periodic row (current step) : [7, 13]
981        //     index 0 →  7
982        //     index 1 → 13
983        let b = test_builder();
984
985        // Column 0 → 7.
986        let p0 =
987            SymbolicExpression::from(SymbolicVariable::<BabyBear>::new(BaseEntry::Periodic, 0));
988        assert_eq!(p0.resolve(&b), BabyBear::new(7));
989
990        // Column 1 → 13.
991        let p1 =
992            SymbolicExpression::from(SymbolicVariable::<BabyBear>::new(BaseEntry::Periodic, 1));
993        assert_eq!(p1.resolve(&b), BabyBear::new(13));
994    }
995
996    #[test]
997    fn resolve_periodic_combines_with_arithmetic() {
998        // Invariant: periodic leaves compose under the same algebra
999        // as main, public, and preprocessed leaves.
1000        //
1001        // Fixture:
1002        //
1003        //     periodic row : [7, 13]
1004        //     main row 0   : [10, 20]
1005        //
1006        //     expression   : main[0] * periodic[0] + periodic[1]
1007        //                  = 10 * 7 + 13
1008        //                  = 83
1009        let b = test_builder();
1010
1011        let col0 =
1012            SymbolicExpression::from(SymbolicVariable::new(BaseEntry::Main { offset: 0 }, 0));
1013        let p0 =
1014            SymbolicExpression::from(SymbolicVariable::<BabyBear>::new(BaseEntry::Periodic, 0));
1015        let p1 =
1016            SymbolicExpression::from(SymbolicVariable::<BabyBear>::new(BaseEntry::Periodic, 1));
1017
1018        let expr = col0 * p0 + p1;
1019        assert_eq!(expr.resolve(&b), BabyBear::new(83));
1020    }
1021}