1use alloc::sync::Arc;
2use core::fmt::Debug;
3use core::iter::{Product, Sum};
4use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
5
6use p3_field::extension::BinomialExtensionField;
7use p3_field::{Algebra, ExtensionField, Field, InjectiveMonomial, PrimeCharacteristicRing};
8
9use crate::symbolic_variable::SymbolicVariable;
10
11impl<F, const D: usize> From<SymbolicExpression<F>>
12 for SymbolicExpression<BinomialExtensionField<F, D>>
13where
14 F: Field,
15 BinomialExtensionField<F, D>: ExtensionField<F>,
16{
17 fn from(expr: SymbolicExpression<F>) -> Self {
24 match expr {
25 SymbolicExpression::Variable(v) => {
26 Self::Variable(SymbolicVariable::new(v.entry, v.index))
27 }
28 SymbolicExpression::IsFirstRow => Self::IsFirstRow,
29 SymbolicExpression::IsLastRow => Self::IsLastRow,
30 SymbolicExpression::IsTransition => Self::IsTransition,
31 SymbolicExpression::Constant(c) => {
32 Self::Constant(BinomialExtensionField::<F, D>::from(c))
34 }
35 SymbolicExpression::Add {
36 x,
37 y,
38 degree_multiple,
39 } => Self::Add {
40 x: Arc::new(Self::from((*x).clone())),
41 y: Arc::new(Self::from((*y).clone())),
42 degree_multiple,
43 },
44 SymbolicExpression::Sub {
45 x,
46 y,
47 degree_multiple,
48 } => Self::Sub {
49 x: Arc::new(Self::from((*x).clone())),
50 y: Arc::new(Self::from((*y).clone())),
51 degree_multiple,
52 },
53 SymbolicExpression::Neg { x, degree_multiple } => Self::Neg {
54 x: Arc::new(Self::from((*x).clone())),
55 degree_multiple,
56 },
57 SymbolicExpression::Mul {
58 x,
59 y,
60 degree_multiple,
61 } => Self::Mul {
62 x: Arc::new(Self::from((*x).clone())),
63 y: Arc::new(Self::from((*y).clone())),
64 degree_multiple,
65 },
66 }
67 }
68}
69
70#[derive(Clone, Debug)]
78pub enum SymbolicExpression<F> {
79 Variable(SymbolicVariable<F>),
83
84 IsFirstRow,
90
91 IsLastRow,
97
98 IsTransition,
104
105 Constant(F),
107
108 Add {
110 x: Arc<Self>,
112 y: Arc<Self>,
114 degree_multiple: usize,
116 },
117
118 Sub {
120 x: Arc<Self>,
122 y: Arc<Self>,
124 degree_multiple: usize,
126 },
127
128 Neg {
130 x: Arc<Self>,
132 degree_multiple: usize,
134 },
135
136 Mul {
138 x: Arc<Self>,
140 y: Arc<Self>,
142 degree_multiple: usize,
144 },
145}
146
147impl<F> SymbolicExpression<F> {
148 pub const fn degree_multiple(&self) -> usize {
172 match self {
173 Self::Variable(v) => v.degree_multiple(),
174 Self::IsFirstRow | Self::IsLastRow => 1,
175 Self::IsTransition | Self::Constant(_) => 0,
176 Self::Add {
177 degree_multiple, ..
178 }
179 | Self::Sub {
180 degree_multiple, ..
181 }
182 | Self::Neg {
183 degree_multiple, ..
184 }
185 | Self::Mul {
186 degree_multiple, ..
187 } => *degree_multiple,
188 }
189 }
190}
191
192impl<F: Field> Default for SymbolicExpression<F> {
193 fn default() -> Self {
194 Self::Constant(F::ZERO)
195 }
196}
197
198impl<F: Field, EF: ExtensionField<F>> From<SymbolicVariable<F>> for SymbolicExpression<EF> {
199 fn from(var: SymbolicVariable<F>) -> Self {
200 Self::Variable(SymbolicVariable::new(var.entry, var.index))
201 }
202}
203
204impl<F: Field, EF: ExtensionField<F>> From<F> for SymbolicExpression<EF> {
205 fn from(var: F) -> Self {
206 Self::Constant(var.into())
207 }
208}
209
210impl<F: Field> PrimeCharacteristicRing for SymbolicExpression<F> {
211 type PrimeSubfield = F::PrimeSubfield;
212
213 const ZERO: Self = Self::Constant(F::ZERO);
214 const ONE: Self = Self::Constant(F::ONE);
215 const TWO: Self = Self::Constant(F::TWO);
216 const NEG_ONE: Self = Self::Constant(F::NEG_ONE);
217
218 #[inline]
219 fn from_prime_subfield(f: Self::PrimeSubfield) -> Self {
220 F::from_prime_subfield(f).into()
221 }
222}
223
224impl<F: Field> Algebra<F> for SymbolicExpression<F> {}
225
226impl<F: Field> Algebra<SymbolicVariable<F>> for SymbolicExpression<F> {}
227
228impl<F: Field + InjectiveMonomial<N>, const N: u64> InjectiveMonomial<N> for SymbolicExpression<F> {}
231
232impl<F: Field, T> Add<T> for SymbolicExpression<F>
233where
234 T: Into<Self>,
235{
236 type Output = Self;
237
238 fn add(self, rhs: T) -> Self {
239 match (self, rhs.into()) {
240 (Self::Constant(lhs), Self::Constant(rhs)) => Self::Constant(lhs + rhs),
241 (lhs, rhs) => Self::Add {
242 degree_multiple: lhs.degree_multiple().max(rhs.degree_multiple()),
243 x: Arc::new(lhs),
244 y: Arc::new(rhs),
245 },
246 }
247 }
248}
249
250impl<F: Field, T> AddAssign<T> for SymbolicExpression<F>
251where
252 T: Into<Self>,
253{
254 fn add_assign(&mut self, rhs: T) {
255 *self = self.clone() + rhs.into();
256 }
257}
258
259impl<F: Field, T> Sum<T> for SymbolicExpression<F>
260where
261 T: Into<Self>,
262{
263 fn sum<I: Iterator<Item = T>>(iter: I) -> Self {
264 iter.map(Into::into)
265 .reduce(|x, y| x + y)
266 .unwrap_or(Self::ZERO)
267 }
268}
269
270impl<F: Field, T: Into<Self>> Sub<T> for SymbolicExpression<F> {
271 type Output = Self;
272
273 fn sub(self, rhs: T) -> Self {
274 match (self, rhs.into()) {
275 (Self::Constant(lhs), Self::Constant(rhs)) => Self::Constant(lhs - rhs),
276 (lhs, rhs) => Self::Sub {
277 degree_multiple: lhs.degree_multiple().max(rhs.degree_multiple()),
278 x: Arc::new(lhs),
279 y: Arc::new(rhs),
280 },
281 }
282 }
283}
284
285impl<F: Field, T> SubAssign<T> for SymbolicExpression<F>
286where
287 T: Into<Self>,
288{
289 fn sub_assign(&mut self, rhs: T) {
290 *self = self.clone() - rhs.into();
291 }
292}
293
294impl<F: Field> Neg for SymbolicExpression<F> {
295 type Output = Self;
296
297 fn neg(self) -> Self {
298 match self {
299 Self::Constant(c) => Self::Constant(-c),
300 expr => Self::Neg {
301 degree_multiple: expr.degree_multiple(),
302 x: Arc::new(expr),
303 },
304 }
305 }
306}
307
308impl<F: Field, T: Into<Self>> Mul<T> for SymbolicExpression<F> {
309 type Output = Self;
310
311 fn mul(self, rhs: T) -> Self {
312 match (self, rhs.into()) {
313 (Self::Constant(lhs), Self::Constant(rhs)) => Self::Constant(lhs * rhs),
314 (lhs, rhs) => Self::Mul {
315 degree_multiple: lhs.degree_multiple() + rhs.degree_multiple(),
316 x: Arc::new(lhs),
317 y: Arc::new(rhs),
318 },
319 }
320 }
321}
322
323impl<F: Field, T> MulAssign<T> for SymbolicExpression<F>
324where
325 T: Into<Self>,
326{
327 fn mul_assign(&mut self, rhs: T) {
328 *self = self.clone() * rhs.into();
329 }
330}
331
332impl<F: Field, T: Into<Self>> Product<T> for SymbolicExpression<F> {
333 fn product<I: Iterator<Item = T>>(iter: I) -> Self {
334 iter.map(Into::into)
335 .reduce(|x, y| x * y)
336 .unwrap_or(Self::ONE)
337 }
338}
339
// Unit tests for `SymbolicExpression` over the `BabyBear` field: cached degree multiples,
// constant folding in the arithmetic operators, the `Sum`/`Product`/`Default`/ring-constant
// impls, and the `From` conversions.
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use p3_baby_bear::BabyBear;

    use super::*;
    use crate::Entry;

    // Checks `degree_multiple()` for every leaf kind. It also checks hand-built composite
    // nodes, which simply report the `degree_multiple` field they were constructed with.
    #[test]
    fn test_symbolic_expression_degree_multiple() {
        let constant_expr = SymbolicExpression::<BabyBear>::Constant(BabyBear::new(5));
        assert_eq!(
            constant_expr.degree_multiple(),
            0,
            "Constant should have degree 0"
        );

        let variable_expr =
            SymbolicExpression::Variable(SymbolicVariable::new(Entry::Main { offset: 0 }, 1));
        assert_eq!(
            variable_expr.degree_multiple(),
            1,
            "Main variable should have degree 1"
        );

        let preprocessed_var = SymbolicExpression::Variable(SymbolicVariable::new(
            Entry::Preprocessed { offset: 0 },
            2,
        ));
        assert_eq!(
            preprocessed_var.degree_multiple(),
            1,
            "Preprocessed variable should have degree 1"
        );

        let permutation_var = SymbolicExpression::Variable(SymbolicVariable::<BabyBear>::new(
            Entry::Permutation { offset: 0 },
            3,
        ));
        assert_eq!(
            permutation_var.degree_multiple(),
            1,
            "Permutation variable should have degree 1"
        );

        let public_var =
            SymbolicExpression::Variable(SymbolicVariable::<BabyBear>::new(Entry::Public, 4));
        assert_eq!(
            public_var.degree_multiple(),
            0,
            "Public variable should have degree 0"
        );

        let challenge_var =
            SymbolicExpression::Variable(SymbolicVariable::<BabyBear>::new(Entry::Challenge, 5));
        assert_eq!(
            challenge_var.degree_multiple(),
            0,
            "Challenge variable should have degree 0"
        );

        let is_first_row = SymbolicExpression::<BabyBear>::IsFirstRow;
        assert_eq!(
            is_first_row.degree_multiple(),
            1,
            "IsFirstRow should have degree 1"
        );

        let is_last_row = SymbolicExpression::<BabyBear>::IsLastRow;
        assert_eq!(
            is_last_row.degree_multiple(),
            1,
            "IsLastRow should have degree 1"
        );

        let is_transition = SymbolicExpression::<BabyBear>::IsTransition;
        assert_eq!(
            is_transition.degree_multiple(),
            0,
            "IsTransition should have degree 0"
        );

        // The composite nodes below are built directly, not via the operators, so these
        // checks only confirm that the cached field is returned.
        let add_expr = SymbolicExpression::<BabyBear>::Add {
            x: Arc::new(variable_expr.clone()),
            y: Arc::new(preprocessed_var.clone()),
            degree_multiple: 1,
        };
        assert_eq!(
            add_expr.degree_multiple(),
            1,
            "Addition should take max degree of inputs"
        );

        let sub_expr = SymbolicExpression::<BabyBear>::Sub {
            x: Arc::new(variable_expr.clone()),
            y: Arc::new(preprocessed_var.clone()),
            degree_multiple: 1,
        };
        assert_eq!(
            sub_expr.degree_multiple(),
            1,
            "Subtraction should take max degree of inputs"
        );

        let neg_expr = SymbolicExpression::<BabyBear>::Neg {
            x: Arc::new(variable_expr.clone()),
            degree_multiple: 1,
        };
        assert_eq!(
            neg_expr.degree_multiple(),
            1,
            "Negation should keep the degree"
        );

        let mul_expr = SymbolicExpression::<BabyBear>::Mul {
            x: Arc::new(variable_expr),
            y: Arc::new(preprocessed_var),
            degree_multiple: 2,
        };
        assert_eq!(
            mul_expr.degree_multiple(),
            2,
            "Multiplication should sum degrees"
        );
    }

    // `Constant + Constant` folds to a single constant (3 + 4 = 7).
    #[test]
    fn test_addition_of_constants() {
        let a = SymbolicExpression::Constant(BabyBear::new(3));
        let b = SymbolicExpression::Constant(BabyBear::new(4));
        let result = a + b;
        match result {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::new(7)),
            _ => panic!("Addition of constants did not simplify correctly"),
        }
    }

    // `Constant - Constant` folds to a single constant (10 - 4 = 6).
    #[test]
    fn test_subtraction_of_constants() {
        let a = SymbolicExpression::Constant(BabyBear::new(10));
        let b = SymbolicExpression::Constant(BabyBear::new(4));
        let result = a - b;
        match result {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::new(6)),
            _ => panic!("Subtraction of constants did not simplify correctly"),
        }
    }

    // Negating a constant folds to the negated constant rather than building a `Neg` node.
    #[test]
    fn test_negation() {
        let a = SymbolicExpression::Constant(BabyBear::new(7));
        let result = -a;
        match result {
            SymbolicExpression::Constant(val) => {
                assert_eq!(val, BabyBear::NEG_ONE * BabyBear::new(7));
            }
            _ => panic!("Negation did not work correctly"),
        }
    }

    // `Constant * Constant` folds to a single constant (3 * 5 = 15).
    #[test]
    fn test_multiplication_of_constants() {
        let a = SymbolicExpression::Constant(BabyBear::new(3));
        let b = SymbolicExpression::Constant(BabyBear::new(5));
        let result = a * b;
        match result {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::new(15)),
            _ => panic!("Multiplication of constants did not simplify correctly"),
        }
    }

    // Adding two degree-1 variables builds an `Add` node with degree max(1, 1) = 1.
    // The operands keep their order (`x` = lhs, `y` = rhs).
    #[test]
    fn test_degree_multiple_for_addition() {
        let a = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            1,
        ));
        let b = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            2,
        ));
        let result = a + b;
        match result {
            SymbolicExpression::Add {
                degree_multiple,
                x,
                y,
            } => {
                assert_eq!(degree_multiple, 1);
                assert!(
                    matches!(*x, SymbolicExpression::Variable(ref v) if v.index == 1 && matches!(v.entry, Entry::Main { offset: 0 }))
                );
                assert!(
                    matches!(*y, SymbolicExpression::Variable(ref v) if v.index == 2 && matches!(v.entry, Entry::Main { offset: 0 }))
                );
            }
            _ => panic!("Addition did not create an Add expression"),
        }
    }

    // Multiplying two degree-1 variables builds a `Mul` node with degree 1 + 1 = 2.
    // The operands keep their order.
    #[test]
    fn test_degree_multiple_for_multiplication() {
        let a = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            1,
        ));
        let b = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            2,
        ));
        let result = a * b;

        match result {
            SymbolicExpression::Mul {
                degree_multiple,
                x,
                y,
            } => {
                assert_eq!(degree_multiple, 2, "Multiplication should sum degrees");

                assert!(
                    matches!(*x, SymbolicExpression::Variable(ref v)
                        if v.index == 1 && matches!(v.entry, Entry::Main { offset: 0 })
                    ),
                    "Left operand should match `a`"
                );

                assert!(
                    matches!(*y, SymbolicExpression::Variable(ref v)
                        if v.index == 2 && matches!(v.entry, Entry::Main { offset: 0 })
                    ),
                    "Right operand should match `b`"
                );
            }
            _ => panic!("Multiplication did not create a `Mul` expression"),
        }
    }

    // Summing only constants folds all the way down to one constant (2 + 3 + 5 = 10).
    #[test]
    fn test_sum_operator() {
        let expressions = vec![
            SymbolicExpression::Constant(BabyBear::new(2)),
            SymbolicExpression::Constant(BabyBear::new(3)),
            SymbolicExpression::Constant(BabyBear::new(5)),
        ];
        let result: SymbolicExpression<BabyBear> = expressions.into_iter().sum();
        match result {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::new(10)),
            _ => panic!("Sum did not produce correct result"),
        }
    }

    // Multiplying only constants folds all the way down to one constant (2 * 3 * 4 = 24).
    #[test]
    fn test_product_operator() {
        let expressions = vec![
            SymbolicExpression::Constant(BabyBear::new(2)),
            SymbolicExpression::Constant(BabyBear::new(3)),
            SymbolicExpression::Constant(BabyBear::new(4)),
        ];
        let result: SymbolicExpression<BabyBear> = expressions.into_iter().product();
        match result {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::new(24)),
            _ => panic!("Product did not produce correct result"),
        }
    }

    // `Default` is the constant zero.
    #[test]
    fn test_default_is_zero() {
        let expr: SymbolicExpression<BabyBear> = Default::default();

        assert!(matches!(
            expr,
            SymbolicExpression::Constant(c) if c == BabyBear::ZERO
        ));
    }

    // Each ring constant (`ZERO`, `ONE`, `TWO`, `NEG_ONE`) is the matching field constant
    // wrapped in `Constant`.
    #[test]
    fn test_ring_constants() {
        assert!(matches!(
            SymbolicExpression::<BabyBear>::ZERO,
            SymbolicExpression::Constant(c) if c == BabyBear::ZERO
        ));

        assert!(matches!(
            SymbolicExpression::<BabyBear>::ONE,
            SymbolicExpression::Constant(c) if c == BabyBear::ONE
        ));

        assert!(matches!(
            SymbolicExpression::<BabyBear>::TWO,
            SymbolicExpression::Constant(c) if c == BabyBear::TWO
        ));

        assert!(matches!(
            SymbolicExpression::<BabyBear>::NEG_ONE,
            SymbolicExpression::Constant(c) if c == BabyBear::NEG_ONE
        ));
    }

    // Converting a `SymbolicVariable` keeps its entry and index.
    #[test]
    fn test_from_symbolic_variable() {
        let var = SymbolicVariable::<BabyBear>::new(Entry::Main { offset: 0 }, 3);

        let expr: SymbolicExpression<BabyBear> = var.into();

        match expr {
            SymbolicExpression::Variable(v) => {
                assert!(matches!(v.entry, Entry::Main { offset: 0 }));
                assert_eq!(v.index, 3);
            }
            _ => panic!("Expected Variable variant"),
        }
    }

    // Converting a field element yields a `Constant` holding that element.
    #[test]
    fn test_from_field_element() {
        let field_val = BabyBear::new(42);
        let expr: SymbolicExpression<BabyBear> = field_val.into();

        assert!(matches!(
            expr,
            SymbolicExpression::Constant(c) if c == field_val
        ));
    }

    // `from_prime_subfield` embeds the value into `F` and wraps it as a constant.
    #[test]
    fn test_from_prime_subfield() {
        let prime_subfield_val = <BabyBear as PrimeCharacteristicRing>::PrimeSubfield::new(7);
        let expr = SymbolicExpression::<BabyBear>::from_prime_subfield(prime_subfield_val);

        assert!(matches!(
            expr,
            SymbolicExpression::Constant(c) if c == BabyBear::new(7)
        ));
    }

    // `+=`, `-=` and `*=` on two constants fold just like the binary operators.
    #[test]
    fn test_assign_operators() {
        let mut expr = SymbolicExpression::Constant(BabyBear::new(5));
        expr += SymbolicExpression::Constant(BabyBear::new(3));
        assert!(matches!(
            expr,
            SymbolicExpression::Constant(c) if c == BabyBear::new(8)
        ));

        let mut expr = SymbolicExpression::Constant(BabyBear::new(10));
        expr -= SymbolicExpression::Constant(BabyBear::new(4));
        assert!(matches!(
            expr,
            SymbolicExpression::Constant(c) if c == BabyBear::new(6)
        ));

        let mut expr = SymbolicExpression::Constant(BabyBear::new(6));
        expr *= SymbolicExpression::Constant(BabyBear::new(7));
        assert!(matches!(
            expr,
            SymbolicExpression::Constant(c) if c == BabyBear::new(42)
        ));
    }

    // Subtracting two variables builds a `Sub` node with degree max(1, 1) = 1.
    // The operands keep their order (`x` = minuend, `y` = subtrahend).
    #[test]
    fn test_subtraction_creates_sub_node() {
        let a = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            0,
        ));
        let b = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            1,
        ));

        let result = a - b;

        match result {
            SymbolicExpression::Sub {
                x,
                y,
                degree_multiple,
            } => {
                assert_eq!(degree_multiple, 1);

                assert!(matches!(
                    x.as_ref(),
                    SymbolicExpression::Variable(v)
                        if v.index == 0 && matches!(v.entry, Entry::Main { offset: 0 })
                ));

                assert!(matches!(
                    y.as_ref(),
                    SymbolicExpression::Variable(v)
                        if v.index == 1 && matches!(v.entry, Entry::Main { offset: 0 })
                ));
            }
            _ => panic!("Expected Sub variant"),
        }
    }

    // Negating a variable (not a constant) builds a `Neg` node that keeps the operand's degree.
    #[test]
    fn test_negation_creates_neg_node() {
        let var = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            0,
        ));

        let result = -var;

        match result {
            SymbolicExpression::Neg { x, degree_multiple } => {
                assert_eq!(degree_multiple, 1);

                assert!(matches!(
                    x.as_ref(),
                    SymbolicExpression::Variable(v)
                        if v.index == 0 && matches!(v.entry, Entry::Main { offset: 0 })
                ));
            }
            _ => panic!("Expected Neg variant"),
        }
    }

    // The empty sum is the additive identity.
    #[test]
    fn test_empty_sum_returns_zero() {
        let empty: Vec<SymbolicExpression<BabyBear>> = vec![];
        let result: SymbolicExpression<BabyBear> = empty.into_iter().sum();

        assert!(matches!(
            result,
            SymbolicExpression::Constant(c) if c == BabyBear::ZERO
        ));
    }

    // The empty product is the multiplicative identity.
    #[test]
    fn test_empty_product_returns_one() {
        let empty: Vec<SymbolicExpression<BabyBear>> = vec![];
        let result: SymbolicExpression<BabyBear> = empty.into_iter().product();

        assert!(matches!(
            result,
            SymbolicExpression::Constant(c) if c == BabyBear::ONE
        ));
    }

    // Constant (degree 0) + variable (degree 1) is not folded. It becomes an `Add` node of
    // degree max(0, 1) = 1, with the constant kept as the left operand.
    #[test]
    fn test_mixed_degree_addition() {
        let constant = SymbolicExpression::Constant(BabyBear::new(5));

        let var = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            0,
        ));

        let result = constant + var;

        match result {
            SymbolicExpression::Add {
                x,
                y,
                degree_multiple,
            } => {
                assert_eq!(degree_multiple, 1);

                assert!(matches!(
                    x.as_ref(),
                    SymbolicExpression::Constant(c) if *c == BabyBear::new(5)
                ));

                assert!(matches!(
                    y.as_ref(),
                    SymbolicExpression::Variable(v)
                        if v.index == 0 && matches!(v.entry, Entry::Main { offset: 0 })
                ));
            }
            _ => panic!("Expected Add variant"),
        }
    }

    // Degrees accumulate through chained products: (a * b) has degree 2, (a * b) * c has 3.
    #[test]
    fn test_chained_multiplication_degree() {
        let a = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            0,
        ));
        let b = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            1,
        ));
        let c = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            2,
        ));

        let ab = a * b;
        assert_eq!(ab.degree_multiple(), 2);

        let abc = ab * c;
        assert_eq!(abc.degree_multiple(), 3);
    }
}