1use p3_field::{Algebra, ExtensionField, Field, InjectiveMonomial};
2
3use crate::symbolic::variable::{BaseEntry, SymbolicVariable};
4use crate::symbolic::{SymLeaf, SymbolicExpr};
5use crate::{AirBuilder, WindowAccess};
6
/// Leaf node of a base-field symbolic expression tree.
///
/// Every variant other than `Constant` is a placeholder that
/// `SymbolicExpression::resolve` later replaces with a value obtained from an
/// `AirBuilder`.
#[derive(Clone, Debug)]
pub enum BaseLeaf<F> {
    /// A symbolic variable (main, preprocessed, public or periodic entry),
    /// identified by its entry kind and index.
    Variable(SymbolicVariable<F>),

    /// Selector resolved through `AirBuilder::is_first_row`; reported with
    /// degree multiple 1 by `degree_multiple`.
    IsFirstRow,

    /// Selector resolved through `AirBuilder::is_last_row`; reported with
    /// degree multiple 1 by `degree_multiple`.
    IsLastRow,

    /// Selector resolved through `AirBuilder::is_transition_window(2)`;
    /// reported with degree multiple 0 by `degree_multiple`.
    IsTransition,

    /// A literal field element; the only variant `as_const` recognises.
    Constant(F),
}
28
/// A symbolic expression tree (`SymbolicExpr`) whose leaves are `BaseLeaf`
/// values over the field `F`.
pub type SymbolicExpression<F> = SymbolicExpr<BaseLeaf<F>>;
34
35impl<F: Field> SymLeaf for BaseLeaf<F> {
36 type F = F;
37
38 const ZERO: Self = Self::Constant(F::ZERO);
39 const ONE: Self = Self::Constant(F::ONE);
40 const TWO: Self = Self::Constant(F::TWO);
41 const NEG_ONE: Self = Self::Constant(F::NEG_ONE);
42
43 fn degree_multiple(&self) -> usize {
44 match self {
45 Self::Variable(v) => v.degree_multiple(),
46 Self::IsFirstRow | Self::IsLastRow => 1,
47 Self::IsTransition | Self::Constant(_) => 0,
48 }
49 }
50
51 fn as_const(&self) -> Option<&F> {
52 match self {
53 Self::Constant(c) => Some(c),
54 _ => None,
55 }
56 }
57
58 fn from_const(c: F) -> Self {
59 Self::Constant(c)
60 }
61}
62
63impl<F: Field, EF: ExtensionField<F>> From<SymbolicVariable<F>> for SymbolicExpression<EF> {
64 fn from(var: SymbolicVariable<F>) -> Self {
65 Self::Leaf(BaseLeaf::Variable(SymbolicVariable::new(
66 var.entry, var.index,
67 )))
68 }
69}
70
71impl<F: Field, EF: ExtensionField<F>> From<F> for SymbolicExpression<EF> {
72 fn from(f: F) -> Self {
73 Self::Leaf(BaseLeaf::Constant(f.into()))
74 }
75}
76
77impl<F: Field> SymbolicExpression<F> {
78 pub fn resolve<AB>(&self, builder: &AB) -> AB::Expr
101 where
102 AB: AirBuilder<F = F>,
103 {
104 match self {
105 Self::Leaf(leaf) => match leaf {
106 BaseLeaf::Variable(v) => match v.entry {
107 BaseEntry::Main { offset } => {
110 let main = builder.main();
111 match offset {
112 0 => main
113 .current(v.index)
114 .expect("main column index out of bounds")
115 .into(),
116 1 => main
117 .next(v.index)
118 .expect("main column index out of bounds")
119 .into(),
120 _ => panic!("expressions cannot span more than two rows"),
121 }
122 }
123 BaseEntry::Preprocessed { offset } => {
125 let prep = builder.preprocessed();
126 match offset {
127 0 => prep
128 .current(v.index)
129 .expect("preprocessed column index out of bounds")
130 .into(),
131 1 => prep
132 .next(v.index)
133 .expect("preprocessed column index out of bounds")
134 .into(),
135 _ => panic!("expressions cannot span more than two rows"),
136 }
137 }
138 BaseEntry::Public => builder.public_values()[v.index].into(),
140 BaseEntry::Periodic => builder.periodic_values()[v.index].into(),
143 },
144 BaseLeaf::IsFirstRow => builder.is_first_row(),
146 BaseLeaf::IsLastRow => builder.is_last_row(),
147 BaseLeaf::IsTransition => builder.is_transition_window(2),
148 BaseLeaf::Constant(c) => AB::Expr::from(*c),
150 },
151 Self::Add { x, y, .. } => x.resolve(builder) + y.resolve(builder),
153 Self::Sub { x, y, .. } => x.resolve(builder) - y.resolve(builder),
154 Self::Neg { x, .. } => -x.resolve(builder),
155 Self::Mul { x, y, .. } => x.resolve(builder) * y.resolve(builder),
156 }
157 }
158}
159
// Empty impl: declares symbolic expressions to be an algebra over their own
// field `F`. The body adds nothing; every item comes from the trait itself.
impl<F: Field> Algebra<F> for SymbolicExpression<F> {}
161
// Empty impl: declares symbolic expressions to be an algebra over
// `SymbolicVariable<F>` as well — presumably so variables can be mixed
// directly into expression arithmetic (confirm against the operator impls).
impl<F: Field> Algebra<SymbolicVariable<F>> for SymbolicExpression<F> {}
163
// Empty impl: expressions carry the `InjectiveMonomial<N>` marker exactly
// when the underlying field `F` does (see the `F: InjectiveMonomial<N>` bound).
impl<F: Field + InjectiveMonomial<N>, const N: u64> InjectiveMonomial<N> for SymbolicExpression<F> {}
167
168#[cfg(test)]
169mod tests {
170 use alloc::sync::Arc;
171 use alloc::vec;
172 use alloc::vec::Vec;
173
174 use p3_baby_bear::BabyBear;
175 use p3_field::PrimeCharacteristicRing;
176 use p3_matrix::dense::RowMajorMatrix;
177
178 use super::*;
179 use crate::symbolic::BaseEntry;
180
181 #[test]
182 fn test_symbolic_expression_degree_multiple() {
183 let constant_expr =
184 SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::Constant(BabyBear::new(5)));
185 assert_eq!(
186 constant_expr.degree_multiple(),
187 0,
188 "Constant should have degree 0"
189 );
190
191 let variable_expr = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::new(
192 BaseEntry::Main { offset: 0 },
193 1,
194 )));
195 assert_eq!(
196 variable_expr.degree_multiple(),
197 1,
198 "Main variable should have degree 1"
199 );
200
201 let preprocessed_var = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::new(
202 BaseEntry::Preprocessed { offset: 0 },
203 2,
204 )));
205 assert_eq!(
206 preprocessed_var.degree_multiple(),
207 1,
208 "Preprocessed variable should have degree 1"
209 );
210
211 let public_var = SymbolicExpression::Leaf(BaseLeaf::Variable(
212 SymbolicVariable::<BabyBear>::new(BaseEntry::Public, 4),
213 ));
214 assert_eq!(
215 public_var.degree_multiple(),
216 0,
217 "Public variable should have degree 0"
218 );
219
220 let is_first_row = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::IsFirstRow);
221 assert_eq!(
222 is_first_row.degree_multiple(),
223 1,
224 "IsFirstRow should have degree 1"
225 );
226
227 let is_last_row = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::IsLastRow);
228 assert_eq!(
229 is_last_row.degree_multiple(),
230 1,
231 "IsLastRow should have degree 1"
232 );
233
234 let is_transition = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::IsTransition);
235 assert_eq!(
236 is_transition.degree_multiple(),
237 0,
238 "IsTransition should have degree 0"
239 );
240
241 let add_expr = SymbolicExpr::<BaseLeaf<BabyBear>>::Add {
242 x: Arc::new(variable_expr.clone()),
243 y: Arc::new(preprocessed_var.clone()),
244 degree_multiple: 1,
245 };
246 assert_eq!(
247 add_expr.degree_multiple(),
248 1,
249 "Addition should take max degree of inputs"
250 );
251
252 let sub_expr = SymbolicExpr::<BaseLeaf<BabyBear>>::Sub {
253 x: Arc::new(variable_expr.clone()),
254 y: Arc::new(preprocessed_var.clone()),
255 degree_multiple: 1,
256 };
257 assert_eq!(
258 sub_expr.degree_multiple(),
259 1,
260 "Subtraction should take max degree of inputs"
261 );
262
263 let neg_expr = SymbolicExpr::<BaseLeaf<BabyBear>>::Neg {
264 x: Arc::new(variable_expr.clone()),
265 degree_multiple: 1,
266 };
267 assert_eq!(
268 neg_expr.degree_multiple(),
269 1,
270 "Negation should keep the degree"
271 );
272
273 let mul_expr = SymbolicExpr::<BaseLeaf<BabyBear>>::Mul {
274 x: Arc::new(variable_expr),
275 y: Arc::new(preprocessed_var),
276 degree_multiple: 2,
277 };
278 assert_eq!(
279 mul_expr.degree_multiple(),
280 2,
281 "Multiplication should sum degrees"
282 );
283 }
284
285 #[test]
286 fn test_addition_of_constants() {
287 let a = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(3)));
288 let b = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(4)));
289 let result = a + b;
290 match result {
291 SymbolicExpr::Leaf(BaseLeaf::Constant(val)) => assert_eq!(val, BabyBear::new(7)),
292 _ => panic!("Addition of constants did not simplify correctly"),
293 }
294 }
295
296 #[test]
297 fn test_subtraction_of_constants() {
298 let a = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(10)));
299 let b = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(4)));
300 let result = a - b;
301 match result {
302 SymbolicExpr::Leaf(BaseLeaf::Constant(val)) => assert_eq!(val, BabyBear::new(6)),
303 _ => panic!("Subtraction of constants did not simplify correctly"),
304 }
305 }
306
307 #[test]
308 fn test_negation() {
309 let a = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(7)));
310 let result = -a;
311 match result {
312 SymbolicExpr::Leaf(BaseLeaf::Constant(val)) => {
313 assert_eq!(val, BabyBear::NEG_ONE * BabyBear::new(7));
314 }
315 _ => panic!("Negation did not work correctly"),
316 }
317 }
318
319 #[test]
320 fn test_multiplication_of_constants() {
321 let a = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(3)));
322 let b = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(5)));
323 let result = a * b;
324 match result {
325 SymbolicExpr::Leaf(BaseLeaf::Constant(val)) => assert_eq!(val, BabyBear::new(15)),
326 _ => panic!("Multiplication of constants did not simplify correctly"),
327 }
328 }
329
330 #[test]
331 fn test_degree_multiple_for_addition() {
332 let a = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
333 BaseEntry::Main { offset: 0 },
334 1,
335 )));
336 let b = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
337 BaseEntry::Main { offset: 0 },
338 2,
339 )));
340 let result = a + b;
341 match result {
342 SymbolicExpr::Add {
343 degree_multiple,
344 x,
345 y,
346 } => {
347 assert_eq!(degree_multiple, 1);
348 assert!(
349 matches!(&*x, SymbolicExpr::Leaf(BaseLeaf::Variable(v)) if v.index == 1 && matches!(v.entry, BaseEntry::Main { offset: 0 }))
350 );
351 assert!(
352 matches!(&*y, SymbolicExpr::Leaf(BaseLeaf::Variable(v)) if v.index == 2 && matches!(v.entry, BaseEntry::Main { offset: 0 }))
353 );
354 }
355 _ => panic!("Addition did not create an Add expression"),
356 }
357 }
358
359 #[test]
360 fn test_degree_multiple_for_multiplication() {
361 let a = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
362 BaseEntry::Main { offset: 0 },
363 1,
364 )));
365 let b = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
366 BaseEntry::Main { offset: 0 },
367 2,
368 )));
369 let result = a * b;
370
371 match result {
372 SymbolicExpr::Mul {
373 degree_multiple,
374 x,
375 y,
376 } => {
377 assert_eq!(degree_multiple, 2, "Multiplication should sum degrees");
378
379 assert!(
380 matches!(&*x, SymbolicExpr::Leaf(BaseLeaf::Variable(v))
381 if v.index == 1 && matches!(v.entry, BaseEntry::Main { offset: 0 })
382 ),
383 "Left operand should match `a`"
384 );
385
386 assert!(
387 matches!(&*y, SymbolicExpr::Leaf(BaseLeaf::Variable(v))
388 if v.index == 2 && matches!(v.entry, BaseEntry::Main { offset: 0 })
389 ),
390 "Right operand should match `b`"
391 );
392 }
393 _ => panic!("Multiplication did not create a `Mul` expression"),
394 }
395 }
396
397 #[test]
398 fn test_sum_operator() {
399 let expressions = vec![
400 SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(2))),
401 SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(3))),
402 SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(5))),
403 ];
404 let result: SymbolicExpression<BabyBear> = expressions.into_iter().sum();
405 match result {
406 SymbolicExpr::Leaf(BaseLeaf::Constant(val)) => assert_eq!(val, BabyBear::new(10)),
407 _ => panic!("Sum did not produce correct result"),
408 }
409 }
410
411 #[test]
412 fn test_product_operator() {
413 let expressions = vec![
414 SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(2))),
415 SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(3))),
416 SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(4))),
417 ];
418 let result: SymbolicExpression<BabyBear> = expressions.into_iter().product();
419 match result {
420 SymbolicExpr::Leaf(BaseLeaf::Constant(val)) => assert_eq!(val, BabyBear::new(24)),
421 _ => panic!("Product did not produce correct result"),
422 }
423 }
424
425 #[test]
426 fn test_default_is_zero() {
427 let expr: SymbolicExpression<BabyBear> = Default::default();
429
430 assert!(matches!(
432 expr,
433 SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::ZERO
434 ));
435 }
436
437 #[test]
438 fn test_ring_constants() {
439 assert!(matches!(
441 SymbolicExpression::<BabyBear>::ZERO,
442 SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::ZERO
443 ));
444 assert!(matches!(
446 SymbolicExpression::<BabyBear>::ONE,
447 SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::ONE
448 ));
449 assert!(matches!(
451 SymbolicExpression::<BabyBear>::TWO,
452 SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::TWO
453 ));
454 assert!(matches!(
456 SymbolicExpression::<BabyBear>::NEG_ONE,
457 SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::NEG_ONE
458 ));
459 }
460
461 #[test]
462 fn test_from_symbolic_variable() {
463 let var = SymbolicVariable::<BabyBear>::new(BaseEntry::Main { offset: 0 }, 3);
465 let expr: SymbolicExpression<BabyBear> = var.into();
467 match expr {
469 SymbolicExpr::Leaf(BaseLeaf::Variable(v)) => {
470 assert!(matches!(v.entry, BaseEntry::Main { offset: 0 }));
471 assert_eq!(v.index, 3);
472 }
473 _ => panic!("Expected Variable variant"),
474 }
475 }
476
477 #[test]
478 fn test_from_field_element() {
479 let field_val = BabyBear::new(42);
481 let expr: SymbolicExpression<BabyBear> = field_val.into();
482 assert!(matches!(
484 expr,
485 SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == field_val
486 ));
487 }
488
489 #[test]
490 fn test_from_prime_subfield() {
491 let prime_subfield_val = <BabyBear as PrimeCharacteristicRing>::PrimeSubfield::new(7);
493 let expr = SymbolicExpression::<BabyBear>::from_prime_subfield(prime_subfield_val);
494 assert!(matches!(
496 expr,
497 SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::new(7)
498 ));
499 }
500
501 #[test]
502 fn test_assign_operators() {
503 let mut expr = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(5)));
505 expr += SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(3)));
506 assert!(matches!(
507 expr,
508 SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::new(8)
509 ));
510
511 let mut expr = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(10)));
513 expr -= SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(4)));
514 assert!(matches!(
515 expr,
516 SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::new(6)
517 ));
518
519 let mut expr = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(6)));
521 expr *= SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(7)));
522 assert!(matches!(
523 expr,
524 SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::new(42)
525 ));
526 }
527
528 #[test]
529 fn test_subtraction_creates_sub_node() {
530 let a = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
532 BaseEntry::Main { offset: 0 },
533 0,
534 )));
535 let b = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
536 BaseEntry::Main { offset: 0 },
537 1,
538 )));
539
540 let result = a - b;
542
543 match result {
545 SymbolicExpr::Sub {
546 x,
547 y,
548 degree_multiple,
549 } => {
550 assert_eq!(degree_multiple, 1);
552
553 assert!(matches!(
555 x.as_ref(),
556 SymbolicExpr::Leaf(BaseLeaf::Variable(v))
557 if v.index == 0 && matches!(v.entry, BaseEntry::Main { offset: 0 })
558 ));
559
560 assert!(matches!(
562 y.as_ref(),
563 SymbolicExpr::Leaf(BaseLeaf::Variable(v))
564 if v.index == 1 && matches!(v.entry, BaseEntry::Main { offset: 0 })
565 ));
566 }
567 _ => panic!("Expected Sub variant"),
568 }
569 }
570
571 #[test]
572 fn test_negation_creates_neg_node() {
573 let var = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
575 BaseEntry::Main { offset: 0 },
576 0,
577 )));
578
579 let result = -var;
581
582 match result {
584 SymbolicExpr::Neg { x, degree_multiple } => {
585 assert_eq!(degree_multiple, 1);
587
588 assert!(matches!(
590 x.as_ref(),
591 SymbolicExpr::Leaf(BaseLeaf::Variable(v))
592 if v.index == 0 && matches!(v.entry, BaseEntry::Main { offset: 0 })
593 ));
594 }
595 _ => panic!("Expected Neg variant"),
596 }
597 }
598
599 #[test]
600 fn test_empty_sum_returns_zero() {
601 let empty: Vec<SymbolicExpression<BabyBear>> = vec![];
603 let result: SymbolicExpression<BabyBear> = empty.into_iter().sum();
604 assert!(matches!(
605 result,
606 SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::ZERO
607 ));
608 }
609
610 #[test]
611 fn test_empty_product_returns_one() {
612 let empty: Vec<SymbolicExpression<BabyBear>> = vec![];
614 let result: SymbolicExpression<BabyBear> = empty.into_iter().product();
615 assert!(matches!(
616 result,
617 SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::ONE
618 ));
619 }
620
621 #[test]
622 fn test_mixed_degree_addition() {
623 let constant = SymbolicExpression::Leaf(BaseLeaf::Constant(BabyBear::new(5)));
625
626 let var = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
628 BaseEntry::Main { offset: 0 },
629 0,
630 )));
631
632 let result = constant + var;
634
635 match result {
636 SymbolicExpr::Add {
637 x,
638 y,
639 degree_multiple,
640 } => {
641 assert_eq!(degree_multiple, 1);
643
644 assert!(matches!(
646 x.as_ref(),
647 SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if *c == BabyBear::new(5)
648 ));
649
650 assert!(matches!(
652 y.as_ref(),
653 SymbolicExpr::Leaf(BaseLeaf::Variable(v))
654 if v.index == 0 && matches!(v.entry, BaseEntry::Main { offset: 0 })
655 ));
656 }
657 _ => panic!("Expected Add variant"),
658 }
659 }
660
661 #[test]
662 fn test_chained_multiplication_degree() {
663 let a = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
665 BaseEntry::Main { offset: 0 },
666 0,
667 )));
668 let b = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
669 BaseEntry::Main { offset: 0 },
670 1,
671 )));
672 let c = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
673 BaseEntry::Main { offset: 0 },
674 2,
675 )));
676
677 let ab = a * b;
679 assert_eq!(ab.degree_multiple(), 2);
680
681 let abc = ab * c;
683 assert_eq!(abc.degree_multiple(), 3);
684 }
685
686 #[test]
687 fn test_add_zero_identity_folding() {
688 let var = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
689 BaseEntry::Main { offset: 0 },
690 0,
691 )));
692 let zero = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::Constant(BabyBear::ZERO));
693
694 let result = var.clone() + zero.clone();
696 assert!(
697 matches!(result, SymbolicExpr::Leaf(BaseLeaf::Variable(_))),
698 "x + 0 should fold to x"
699 );
700
701 let result = zero + var;
703 assert!(
704 matches!(result, SymbolicExpr::Leaf(BaseLeaf::Variable(_))),
705 "0 + x should fold to x"
706 );
707 }
708
709 #[test]
710 fn test_sub_zero_identity_folding() {
711 let var = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
712 BaseEntry::Main { offset: 0 },
713 0,
714 )));
715 let zero = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::Constant(BabyBear::ZERO));
716
717 let result = var.clone() - zero.clone();
719 assert!(
720 matches!(result, SymbolicExpr::Leaf(BaseLeaf::Variable(_))),
721 "x - 0 should fold to x"
722 );
723
724 let result = zero - var;
726 match result {
727 SymbolicExpr::Neg { x, degree_multiple } => {
728 assert_eq!(degree_multiple, 1);
729 assert!(matches!(
730 x.as_ref(),
731 SymbolicExpr::Leaf(BaseLeaf::Variable(v))
732 if v.index == 0 && v.entry == BaseEntry::Main { offset: 0 }
733 ));
734 }
735 _ => panic!("0 - x should fold to Neg(x)"),
736 }
737 }
738
739 #[test]
740 fn test_mul_zero_identity_folding() {
741 let var = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
742 BaseEntry::Main { offset: 0 },
743 0,
744 )));
745 let zero = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::Constant(BabyBear::ZERO));
746
747 let result = var.clone() * zero.clone();
749 assert!(
750 matches!(result, SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::ZERO),
751 "x * 0 should fold to 0"
752 );
753
754 let result = zero * var;
756 assert!(
757 matches!(result, SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if c == BabyBear::ZERO),
758 "0 * x should fold to 0"
759 );
760 }
761
762 #[test]
763 fn test_mul_one_identity_folding() {
764 let var = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
765 BaseEntry::Main { offset: 0 },
766 0,
767 )));
768 let one = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::Constant(BabyBear::ONE));
769
770 let result = var.clone() * one.clone();
772 assert!(
773 matches!(result, SymbolicExpr::Leaf(BaseLeaf::Variable(_))),
774 "x * 1 should fold to x"
775 );
776
777 let result = one * var;
779 assert!(
780 matches!(result, SymbolicExpr::Leaf(BaseLeaf::Variable(_))),
781 "1 * x should fold to x"
782 );
783 }
784
785 #[test]
786 fn test_identity_folding_preserves_degree() {
787 let var = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::<BabyBear>::new(
788 BaseEntry::Main { offset: 0 },
789 0,
790 )));
791 let zero = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::Constant(BabyBear::ZERO));
792 let one = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::Constant(BabyBear::ONE));
793
794 let result = var.clone() + zero.clone();
796 assert_eq!(result.degree_multiple(), 1);
797
798 let result = var.clone() - zero.clone();
800 assert_eq!(result.degree_multiple(), 1);
801
802 let result = zero.clone() - var.clone();
804 assert_eq!(result.degree_multiple(), 1);
805
806 let result = var.clone() * one;
808 assert_eq!(result.degree_multiple(), 1);
809
810 let result = var * zero;
812 assert_eq!(result.degree_multiple(), 0);
813 }
814
    /// Minimal `AirBuilder` used by the `resolve_*` tests: every query is
    /// answered from a fixed field value stored here.
    struct ResolveTestBuilder {
        /// Matrix returned (cloned) by `main()`.
        main: RowMajorMatrix<BabyBear>,
        /// Slice returned by `public_values()`.
        public_values: Vec<BabyBear>,
        /// Slice returned by `periodic_values()`.
        periodic_row: Vec<BabyBear>,
        /// Value returned by `is_first_row()`.
        is_first: BabyBear,
        /// Value returned by `is_last_row()`.
        is_last: BabyBear,
        /// Value returned by `is_transition()`.
        is_transition: BabyBear,
    }
830
831 impl AirBuilder for ResolveTestBuilder {
832 type F = BabyBear;
833 type Expr = BabyBear;
834 type Var = BabyBear;
835 type PreprocessedWindow = RowMajorMatrix<BabyBear>;
836 type MainWindow = RowMajorMatrix<BabyBear>;
837 type PublicVar = BabyBear;
838 type PeriodicVar = BabyBear;
839
840 fn main(&self) -> Self::MainWindow {
841 self.main.clone()
842 }
843
844 fn preprocessed(&self) -> &Self::PreprocessedWindow {
845 unimplemented!("no preprocessed columns in test builder")
846 }
847
848 fn is_first_row(&self) -> Self::Expr {
849 self.is_first
850 }
851
852 fn is_last_row(&self) -> Self::Expr {
853 self.is_last
854 }
855
856 fn is_transition(&self) -> Self::Expr {
857 self.is_transition
858 }
859
860 fn assert_zero<I: Into<Self::Expr>>(&mut self, _: I) {}
861
862 fn public_values(&self) -> &[Self::PublicVar] {
863 &self.public_values
864 }
865
866 fn periodic_values(&self) -> &[Self::PeriodicVar] {
867 &self.periodic_row
868 }
869 }
870
871 fn test_builder() -> ResolveTestBuilder {
879 ResolveTestBuilder {
880 main: RowMajorMatrix::new(
881 vec![
882 BabyBear::new(10),
883 BabyBear::new(20), BabyBear::new(30),
885 BabyBear::new(40), ],
887 2, ),
889 public_values: vec![BabyBear::new(99)],
890 periodic_row: vec![BabyBear::new(7), BabyBear::new(13)],
893 is_first: BabyBear::ONE,
894 is_last: BabyBear::ZERO,
895 is_transition: BabyBear::ONE,
896 }
897 }
898
899 #[test]
900 fn resolve_main_current_row() {
901 let b = test_builder();
902 let expr =
904 SymbolicExpression::from(SymbolicVariable::new(BaseEntry::Main { offset: 0 }, 0));
905 assert_eq!(expr.resolve(&b), BabyBear::new(10));
906 }
907
908 #[test]
909 fn resolve_main_next_row() {
910 let b = test_builder();
911 let expr =
913 SymbolicExpression::from(SymbolicVariable::new(BaseEntry::Main { offset: 1 }, 1));
914 assert_eq!(expr.resolve(&b), BabyBear::new(40));
915 }
916
917 #[test]
918 fn resolve_public_value() {
919 let b = test_builder();
920 let expr = SymbolicExpression::from(SymbolicVariable::new(BaseEntry::Public, 0));
922 assert_eq!(expr.resolve(&b), BabyBear::new(99));
923 }
924
925 #[test]
926 fn resolve_constant() {
927 let b = test_builder();
928 let expr = SymbolicExpression::<BabyBear>::from(BabyBear::new(42));
929 assert_eq!(expr.resolve(&b), BabyBear::new(42));
930 }
931
932 #[test]
933 fn resolve_selectors() {
934 let b = test_builder();
935
936 let first = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::IsFirstRow);
937 assert_eq!(first.resolve(&b), BabyBear::ONE, "is_first_row = 1");
938
939 let last = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::IsLastRow);
940 assert_eq!(last.resolve(&b), BabyBear::ZERO, "is_last_row = 0");
941
942 let trans = SymbolicExpression::<BabyBear>::Leaf(BaseLeaf::IsTransition);
943 assert_eq!(trans.resolve(&b), BabyBear::ONE, "is_transition = 1");
944 }
945
946 #[test]
947 fn resolve_arithmetic() {
948 let b = test_builder();
949
950 let col0 =
952 SymbolicExpression::from(SymbolicVariable::new(BaseEntry::Main { offset: 0 }, 0));
953 let col1 =
954 SymbolicExpression::from(SymbolicVariable::new(BaseEntry::Main { offset: 0 }, 1));
955
956 let add = col0.clone() + col1.clone();
958 assert_eq!(add.resolve(&b), BabyBear::new(30));
959
960 let sub = col0.clone() - col1.clone();
962 assert_eq!(sub.resolve(&b), BabyBear::new(10) - BabyBear::new(20));
963
964 let mul = col0.clone() * col1;
966 assert_eq!(mul.resolve(&b), BabyBear::new(200));
967
968 let neg = -col0;
970 assert_eq!(neg.resolve(&b), -BabyBear::new(10));
971 }
972
973 #[test]
974 fn resolve_periodic_columns() {
975 let b = test_builder();
984
985 let p0 =
987 SymbolicExpression::from(SymbolicVariable::<BabyBear>::new(BaseEntry::Periodic, 0));
988 assert_eq!(p0.resolve(&b), BabyBear::new(7));
989
990 let p1 =
992 SymbolicExpression::from(SymbolicVariable::<BabyBear>::new(BaseEntry::Periodic, 1));
993 assert_eq!(p1.resolve(&b), BabyBear::new(13));
994 }
995
996 #[test]
997 fn resolve_periodic_combines_with_arithmetic() {
998 let b = test_builder();
1010
1011 let col0 =
1012 SymbolicExpression::from(SymbolicVariable::new(BaseEntry::Main { offset: 0 }, 0));
1013 let p0 =
1014 SymbolicExpression::from(SymbolicVariable::<BabyBear>::new(BaseEntry::Periodic, 0));
1015 let p1 =
1016 SymbolicExpression::from(SymbolicVariable::<BabyBear>::new(BaseEntry::Periodic, 1));
1017
1018 let expr = col0 * p0 + p1;
1019 assert_eq!(expr.resolve(&b), BabyBear::new(83));
1020 }
1021}