// p3_uni_stark/symbolic_expression.rs

1use alloc::rc::Rc;
2use core::fmt::Debug;
3use core::iter::{Product, Sum};
4use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
5
6use p3_field::{Algebra, Field, InjectiveMonomial, PrimeCharacteristicRing};
7
8use crate::symbolic_variable::SymbolicVariable;
9
/// An expression over `SymbolicVariable`s.
///
/// Expressions form a tree whose leaves are variables, row selectors and constants. Compound
/// nodes (`Add`, `Sub`, `Neg`, `Mul`) hold their operands behind `Rc`, so cloning an expression
/// shares sub-trees instead of deep-copying them. Each compound node also stores its own
/// `degree_multiple`, which [`SymbolicExpression::degree_multiple`] returns directly without
/// walking the tree.
#[derive(Clone, Debug)]
pub enum SymbolicExpression<F> {
    /// A leaf referring to a symbolic variable; its degree multiple is the variable's own.
    Variable(SymbolicVariable<F>),
    /// First-row selector (meaning inferred from the name; it is evaluated elsewhere).
    /// Counts as degree multiple 1.
    IsFirstRow,
    /// Last-row selector (meaning inferred from the name; it is evaluated elsewhere).
    /// Counts as degree multiple 1.
    IsLastRow,
    /// Transition selector (meaning inferred from the name; it is evaluated elsewhere).
    /// Counts as degree multiple 0.
    IsTransition,
    /// A constant field element. Counts as degree multiple 0; the arithmetic operators in this
    /// module fold two constants into a single constant instead of building a node.
    Constant(F),
    /// `x + y`.
    Add {
        x: Rc<Self>,
        y: Rc<Self>,
        /// Cached degree multiple; the `Add` operator sets it to the max of the operands'.
        degree_multiple: usize,
    },
    /// `x - y`.
    Sub {
        x: Rc<Self>,
        y: Rc<Self>,
        /// Cached degree multiple; the `Sub` operator sets it to the max of the operands'.
        degree_multiple: usize,
    },
    /// `-x`.
    Neg {
        x: Rc<Self>,
        /// Cached degree multiple; the `Neg` operator copies the operand's value.
        degree_multiple: usize,
    },
    /// `x * y`.
    Mul {
        x: Rc<Self>,
        y: Rc<Self>,
        /// Cached degree multiple; the `Mul` operator sets it to the sum of the operands'.
        degree_multiple: usize,
    },
}
38
39impl<F> SymbolicExpression<F> {
40    /// Returns the multiple of `n` (the trace length) in this expression's degree.
41    pub const fn degree_multiple(&self) -> usize {
42        match self {
43            Self::Variable(v) => v.degree_multiple(),
44            Self::IsFirstRow | Self::IsLastRow => 1,
45            Self::IsTransition | Self::Constant(_) => 0,
46            Self::Add {
47                degree_multiple, ..
48            }
49            | Self::Sub {
50                degree_multiple, ..
51            }
52            | Self::Neg {
53                degree_multiple, ..
54            }
55            | Self::Mul {
56                degree_multiple, ..
57            } => *degree_multiple,
58        }
59    }
60}
61
62impl<F: Field> Default for SymbolicExpression<F> {
63    fn default() -> Self {
64        Self::Constant(F::ZERO)
65    }
66}
67
68impl<F: Field> From<F> for SymbolicExpression<F> {
69    fn from(value: F) -> Self {
70        Self::Constant(value)
71    }
72}
73
74impl<F: Field> PrimeCharacteristicRing for SymbolicExpression<F> {
75    type PrimeSubfield = F::PrimeSubfield;
76
77    const ZERO: Self = Self::Constant(F::ZERO);
78    const ONE: Self = Self::Constant(F::ONE);
79    const TWO: Self = Self::Constant(F::TWO);
80    const NEG_ONE: Self = Self::Constant(F::NEG_ONE);
81
82    #[inline]
83    fn from_prime_subfield(f: Self::PrimeSubfield) -> Self {
84        F::from_prime_subfield(f).into()
85    }
86}
87
/// Marker impl: `SymbolicExpression<F>` is an algebra over its base field `F`. Field elements
/// enter expressions through the `From<F>` impl above.
impl<F: Field> Algebra<F> for SymbolicExpression<F> {}

/// Marker impl: `SymbolicExpression<F>` is an algebra over `SymbolicVariable<F>`.
// NOTE(review): presumably the conversion from `SymbolicVariable<F>` that this needs is defined
// next to `SymbolicVariable` — it is not visible in this file; confirm.
impl<F: Field> Algebra<SymbolicVariable<F>> for SymbolicExpression<F> {}

// Note: we cannot implement `PermutationMonomial`, because the `degree_multiple` bookkeeping
// makes the operations non-invertible.
impl<F: Field + InjectiveMonomial<N>, const N: u64> InjectiveMonomial<N> for SymbolicExpression<F> {}
95
96impl<F: Field, T> Add<T> for SymbolicExpression<F>
97where
98    T: Into<Self>,
99{
100    type Output = Self;
101
102    fn add(self, rhs: T) -> Self {
103        match (self, rhs.into()) {
104            (Self::Constant(lhs), Self::Constant(rhs)) => Self::Constant(lhs + rhs),
105            (lhs, rhs) => Self::Add {
106                degree_multiple: lhs.degree_multiple().max(rhs.degree_multiple()),
107                x: Rc::new(lhs),
108                y: Rc::new(rhs),
109            },
110        }
111    }
112}
113
114impl<F: Field, T> AddAssign<T> for SymbolicExpression<F>
115where
116    T: Into<Self>,
117{
118    fn add_assign(&mut self, rhs: T) {
119        *self = self.clone() + rhs.into();
120    }
121}
122
123impl<F: Field, T> Sum<T> for SymbolicExpression<F>
124where
125    T: Into<Self>,
126{
127    fn sum<I: Iterator<Item = T>>(iter: I) -> Self {
128        iter.map(Into::into)
129            .reduce(|x, y| x + y)
130            .unwrap_or(Self::ZERO)
131    }
132}
133
134impl<F: Field, T: Into<Self>> Sub<T> for SymbolicExpression<F> {
135    type Output = Self;
136
137    fn sub(self, rhs: T) -> Self {
138        match (self, rhs.into()) {
139            (Self::Constant(lhs), Self::Constant(rhs)) => Self::Constant(lhs - rhs),
140            (lhs, rhs) => Self::Sub {
141                degree_multiple: lhs.degree_multiple().max(rhs.degree_multiple()),
142                x: Rc::new(lhs),
143                y: Rc::new(rhs),
144            },
145        }
146    }
147}
148
149impl<F: Field, T> SubAssign<T> for SymbolicExpression<F>
150where
151    T: Into<Self>,
152{
153    fn sub_assign(&mut self, rhs: T) {
154        *self = self.clone() - rhs.into();
155    }
156}
157
158impl<F: Field> Neg for SymbolicExpression<F> {
159    type Output = Self;
160
161    fn neg(self) -> Self {
162        match self {
163            Self::Constant(c) => Self::Constant(-c),
164            expr => Self::Neg {
165                degree_multiple: expr.degree_multiple(),
166                x: Rc::new(expr),
167            },
168        }
169    }
170}
171
172impl<F: Field, T: Into<Self>> Mul<T> for SymbolicExpression<F> {
173    type Output = Self;
174
175    fn mul(self, rhs: T) -> Self {
176        match (self, rhs.into()) {
177            (Self::Constant(lhs), Self::Constant(rhs)) => Self::Constant(lhs * rhs),
178            (lhs, rhs) => Self::Mul {
179                degree_multiple: lhs.degree_multiple() + rhs.degree_multiple(),
180                x: Rc::new(lhs),
181                y: Rc::new(rhs),
182            },
183        }
184    }
185}
186
187impl<F: Field, T> MulAssign<T> for SymbolicExpression<F>
188where
189    T: Into<Self>,
190{
191    fn mul_assign(&mut self, rhs: T) {
192        *self = self.clone() * rhs.into();
193    }
194}
195
196impl<F: Field, T: Into<Self>> Product<T> for SymbolicExpression<F> {
197    fn product<I: Iterator<Item = T>>(iter: I) -> Self {
198        iter.map(Into::into)
199            .reduce(|x, y| x * y)
200            .unwrap_or(Self::ONE)
201    }
202}
203
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_baby_bear::BabyBear;

    use super::*;
    use crate::Entry;

    #[test]
    fn test_symbolic_expression_degree_multiple() {
        let constant_expr = SymbolicExpression::<BabyBear>::Constant(BabyBear::new(5));
        assert_eq!(
            constant_expr.degree_multiple(),
            0,
            "Constant should have degree 0"
        );

        let variable_expr =
            SymbolicExpression::Variable(SymbolicVariable::new(Entry::Main { offset: 0 }, 1));
        assert_eq!(
            variable_expr.degree_multiple(),
            1,
            "Main variable should have degree 1"
        );

        let preprocessed_var = SymbolicExpression::Variable(SymbolicVariable::new(
            Entry::Preprocessed { offset: 0 },
            2,
        ));
        assert_eq!(
            preprocessed_var.degree_multiple(),
            1,
            "Preprocessed variable should have degree 1"
        );

        let permutation_var = SymbolicExpression::Variable(SymbolicVariable::<BabyBear>::new(
            Entry::Permutation { offset: 0 },
            3,
        ));
        assert_eq!(
            permutation_var.degree_multiple(),
            1,
            "Permutation variable should have degree 1"
        );

        let public_var =
            SymbolicExpression::Variable(SymbolicVariable::<BabyBear>::new(Entry::Public, 4));
        assert_eq!(
            public_var.degree_multiple(),
            0,
            "Public variable should have degree 0"
        );

        let challenge_var =
            SymbolicExpression::Variable(SymbolicVariable::<BabyBear>::new(Entry::Challenge, 5));
        assert_eq!(
            challenge_var.degree_multiple(),
            0,
            "Challenge variable should have degree 0"
        );

        let is_first_row = SymbolicExpression::<BabyBear>::IsFirstRow;
        assert_eq!(
            is_first_row.degree_multiple(),
            1,
            "IsFirstRow should have degree 1"
        );

        let is_last_row = SymbolicExpression::<BabyBear>::IsLastRow;
        assert_eq!(
            is_last_row.degree_multiple(),
            1,
            "IsLastRow should have degree 1"
        );

        let is_transition = SymbolicExpression::<BabyBear>::IsTransition;
        assert_eq!(
            is_transition.degree_multiple(),
            0,
            "IsTransition should have degree 0"
        );

        let add_expr = SymbolicExpression::<BabyBear>::Add {
            x: Rc::new(variable_expr.clone()),
            y: Rc::new(preprocessed_var.clone()),
            degree_multiple: 1,
        };
        assert_eq!(
            add_expr.degree_multiple(),
            1,
            "Addition should take max degree of inputs"
        );

        let sub_expr = SymbolicExpression::<BabyBear>::Sub {
            x: Rc::new(variable_expr.clone()),
            y: Rc::new(preprocessed_var.clone()),
            degree_multiple: 1,
        };
        assert_eq!(
            sub_expr.degree_multiple(),
            1,
            "Subtraction should take max degree of inputs"
        );

        let neg_expr = SymbolicExpression::<BabyBear>::Neg {
            x: Rc::new(variable_expr.clone()),
            degree_multiple: 1,
        };
        assert_eq!(
            neg_expr.degree_multiple(),
            1,
            "Negation should keep the degree"
        );

        let mul_expr = SymbolicExpression::<BabyBear>::Mul {
            x: Rc::new(variable_expr.clone()),
            y: Rc::new(preprocessed_var.clone()),
            degree_multiple: 2,
        };
        assert_eq!(
            mul_expr.degree_multiple(),
            2,
            "Multiplication should sum degrees"
        );
    }

    #[test]
    fn test_addition_of_constants() {
        let a = SymbolicExpression::Constant(BabyBear::new(3));
        let b = SymbolicExpression::Constant(BabyBear::new(4));
        let result = a + b;
        match result {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::new(7)),
            _ => panic!("Addition of constants did not simplify correctly"),
        }
    }

    #[test]
    fn test_subtraction_of_constants() {
        let a = SymbolicExpression::Constant(BabyBear::new(10));
        let b = SymbolicExpression::Constant(BabyBear::new(4));
        let result = a - b;
        match result {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::new(6)),
            _ => panic!("Subtraction of constants did not simplify correctly"),
        }
    }

    #[test]
    fn test_negation() {
        let a = SymbolicExpression::Constant(BabyBear::new(7));
        let result = -a;
        match result {
            SymbolicExpression::Constant(val) => {
                assert_eq!(val, BabyBear::NEG_ONE * BabyBear::new(7))
            }
            _ => panic!("Negation did not work correctly"),
        }
    }

    #[test]
    fn test_multiplication_of_constants() {
        let a = SymbolicExpression::Constant(BabyBear::new(3));
        let b = SymbolicExpression::Constant(BabyBear::new(5));
        let result = a * b;
        match result {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::new(15)),
            _ => panic!("Multiplication of constants did not simplify correctly"),
        }
    }

    #[test]
    fn test_degree_multiple_for_addition() {
        let a = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            1,
        ));
        let b = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            2,
        ));
        let result = a + b;
        match result {
            SymbolicExpression::Add {
                degree_multiple,
                x,
                y,
            } => {
                assert_eq!(degree_multiple, 1);
                assert!(
                    matches!(*x, SymbolicExpression::Variable(ref v) if v.index == 1 && matches!(v.entry, Entry::Main { offset: 0 }))
                );
                assert!(
                    matches!(*y, SymbolicExpression::Variable(ref v) if v.index == 2 && matches!(v.entry, Entry::Main { offset: 0 }))
                );
            }
            _ => panic!("Addition did not create an Add expression"),
        }
    }

    #[test]
    fn test_degree_multiple_for_multiplication() {
        let a = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            1,
        ));
        let b = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            2,
        ));
        let result = a * b;

        match result {
            SymbolicExpression::Mul {
                degree_multiple,
                x,
                y,
            } => {
                assert_eq!(degree_multiple, 2, "Multiplication should sum degrees");

                assert!(
                    matches!(*x, SymbolicExpression::Variable(ref v)
                        if v.index == 1 && matches!(v.entry, Entry::Main { offset: 0 })
                    ),
                    "Left operand should match `a`"
                );

                assert!(
                    matches!(*y, SymbolicExpression::Variable(ref v)
                        if v.index == 2 && matches!(v.entry, Entry::Main { offset: 0 })
                    ),
                    "Right operand should match `b`"
                );
            }
            _ => panic!("Multiplication did not create a `Mul` expression"),
        }
    }

    #[test]
    fn test_subtraction_and_negation_of_variables() {
        let main = SymbolicExpression::Variable(SymbolicVariable::<BabyBear>::new(
            Entry::Main { offset: 0 },
            0,
        ));
        let public =
            SymbolicExpression::Variable(SymbolicVariable::<BabyBear>::new(Entry::Public, 0));

        let diff = main.clone() - public;
        assert!(
            matches!(diff, SymbolicExpression::Sub { degree_multiple: 1, .. }),
            "Subtraction of non-constants should build a Sub node with the max input degree"
        );

        let negated = -main;
        assert!(
            matches!(negated, SymbolicExpression::Neg { degree_multiple: 1, .. }),
            "Negation of a variable should build a Neg node keeping its degree"
        );
    }

    #[test]
    fn test_assign_operators() {
        // Constants fold through every compound-assignment operator: (2 + 5 - 3) * 6 = 24.
        let mut acc = SymbolicExpression::Constant(BabyBear::new(2));
        acc += SymbolicExpression::Constant(BabyBear::new(5));
        acc -= SymbolicExpression::Constant(BabyBear::new(3));
        acc *= SymbolicExpression::Constant(BabyBear::new(6));
        match acc {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::new(24)),
            _ => panic!("Assign operators on constants did not simplify correctly"),
        }

        // Non-constants build nodes with the same degree bookkeeping as the binary operators.
        let x = SymbolicExpression::Variable(SymbolicVariable::<BabyBear>::new(
            Entry::Main { offset: 0 },
            0,
        ));
        let mut expr = x.clone();
        expr *= x.clone();
        assert_eq!(expr.degree_multiple(), 2, "x *= x should have degree 2");
        expr += x.clone();
        assert_eq!(
            expr.degree_multiple(),
            2,
            "Adding x should keep the max degree"
        );
        expr -= x;
        assert!(
            matches!(expr, SymbolicExpression::Sub { degree_multiple: 2, .. }),
            "-= should build a Sub node"
        );
    }

    #[test]
    fn test_sum_operator() {
        let expressions = vec![
            SymbolicExpression::Constant(BabyBear::new(2)),
            SymbolicExpression::Constant(BabyBear::new(3)),
            SymbolicExpression::Constant(BabyBear::new(5)),
        ];
        let result: SymbolicExpression<BabyBear> = expressions.into_iter().sum();
        match result {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::new(10)),
            _ => panic!("Sum did not produce correct result"),
        }
    }

    #[test]
    fn test_sum_of_field_elements() {
        // `Sum<T>` accepts any `T: Into<SymbolicExpression<F>>`, including plain field elements.
        let result: SymbolicExpression<BabyBear> =
            vec![BabyBear::new(1), BabyBear::new(2), BabyBear::new(4)]
                .into_iter()
                .sum();
        match result {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::new(7)),
            _ => panic!("Sum of field elements did not produce correct result"),
        }
    }

    #[test]
    fn test_product_operator() {
        let expressions = vec![
            SymbolicExpression::Constant(BabyBear::new(2)),
            SymbolicExpression::Constant(BabyBear::new(3)),
            SymbolicExpression::Constant(BabyBear::new(4)),
        ];
        let result: SymbolicExpression<BabyBear> = expressions.into_iter().product();
        match result {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::new(24)),
            _ => panic!("Product did not produce correct result"),
        }
    }

    #[test]
    fn test_empty_sum_and_product() {
        let sum: SymbolicExpression<BabyBear> =
            core::iter::empty::<SymbolicExpression<BabyBear>>().sum();
        match sum {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::ZERO),
            _ => panic!("Empty sum should be the constant zero"),
        }

        let product: SymbolicExpression<BabyBear> =
            core::iter::empty::<SymbolicExpression<BabyBear>>().product();
        match product {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::ONE),
            _ => panic!("Empty product should be the constant one"),
        }
    }
}
470}