1use alloc::rc::Rc;
2use core::fmt::Debug;
3use core::iter::{Product, Sum};
4use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
5
6use p3_field::{Algebra, Field, InjectiveMonomial, PrimeCharacteristicRing};
7
8use crate::symbolic_variable::SymbolicVariable;
9
/// A symbolic expression tree built from variables, row selectors and constants.
///
/// Leaves are `Variable`, `IsFirstRow`, `IsLastRow`, `IsTransition` and `Constant`.
/// Interior nodes (`Add`, `Sub`, `Neg`, `Mul`) hold their operands behind `Rc`, so
/// cloning an expression shares sub-trees instead of deep-copying them. Each interior
/// node also caches a `degree_multiple`. The operator impls in this file compute it as:
/// - the max of the operands for `Add`/`Sub`;
/// - the operand's value for `Neg`;
/// - the sum of the operands for `Mul`.
#[derive(Clone, Debug)]
pub enum SymbolicExpression<F> {
    /// A leaf referring to a `SymbolicVariable` (e.g. a main, preprocessed,
    /// permutation, public or challenge entry); its degree comes from the variable.
    Variable(SymbolicVariable<F>),
    /// First-row selector (degree multiple 1). NOTE(review): exact evaluation is
    /// defined by whoever interprets the tree — confirm against the evaluator.
    IsFirstRow,
    /// Last-row selector (degree multiple 1). NOTE(review): see `IsFirstRow`.
    IsLastRow,
    /// Transition selector (degree multiple 0). NOTE(review): see `IsFirstRow`.
    IsTransition,
    /// A constant field element (degree multiple 0).
    Constant(F),
    /// `x + y`, with the cached degree multiple of the sum.
    Add {
        x: Rc<Self>,
        y: Rc<Self>,
        degree_multiple: usize,
    },
    /// `x - y`, with the cached degree multiple of the difference.
    Sub {
        x: Rc<Self>,
        y: Rc<Self>,
        degree_multiple: usize,
    },
    /// `-x`, with the cached degree multiple (same as that of `x` when built via `Neg`).
    Neg {
        x: Rc<Self>,
        degree_multiple: usize,
    },
    /// `x * y`, with the cached degree multiple of the product.
    Mul {
        x: Rc<Self>,
        y: Rc<Self>,
        degree_multiple: usize,
    },
}
38
39impl<F> SymbolicExpression<F> {
40 pub const fn degree_multiple(&self) -> usize {
42 match self {
43 Self::Variable(v) => v.degree_multiple(),
44 Self::IsFirstRow | Self::IsLastRow => 1,
45 Self::IsTransition | Self::Constant(_) => 0,
46 Self::Add {
47 degree_multiple, ..
48 }
49 | Self::Sub {
50 degree_multiple, ..
51 }
52 | Self::Neg {
53 degree_multiple, ..
54 }
55 | Self::Mul {
56 degree_multiple, ..
57 } => *degree_multiple,
58 }
59 }
60}
61
62impl<F: Field> Default for SymbolicExpression<F> {
63 fn default() -> Self {
64 Self::Constant(F::ZERO)
65 }
66}
67
68impl<F: Field> From<F> for SymbolicExpression<F> {
69 fn from(value: F) -> Self {
70 Self::Constant(value)
71 }
72}
73
74impl<F: Field> PrimeCharacteristicRing for SymbolicExpression<F> {
75 type PrimeSubfield = F::PrimeSubfield;
76
77 const ZERO: Self = Self::Constant(F::ZERO);
78 const ONE: Self = Self::Constant(F::ONE);
79 const TWO: Self = Self::Constant(F::TWO);
80 const NEG_ONE: Self = Self::Constant(F::NEG_ONE);
81
82 #[inline]
83 fn from_prime_subfield(f: Self::PrimeSubfield) -> Self {
84 F::from_prime_subfield(f).into()
85 }
86}
87
88impl<F: Field> Algebra<F> for SymbolicExpression<F> {}
89
90impl<F: Field> Algebra<SymbolicVariable<F>> for SymbolicExpression<F> {}
91
92impl<F: Field + InjectiveMonomial<N>, const N: u64> InjectiveMonomial<N> for SymbolicExpression<F> {}
95
96impl<F: Field, T> Add<T> for SymbolicExpression<F>
97where
98 T: Into<Self>,
99{
100 type Output = Self;
101
102 fn add(self, rhs: T) -> Self {
103 match (self, rhs.into()) {
104 (Self::Constant(lhs), Self::Constant(rhs)) => Self::Constant(lhs + rhs),
105 (lhs, rhs) => Self::Add {
106 degree_multiple: lhs.degree_multiple().max(rhs.degree_multiple()),
107 x: Rc::new(lhs),
108 y: Rc::new(rhs),
109 },
110 }
111 }
112}
113
114impl<F: Field, T> AddAssign<T> for SymbolicExpression<F>
115where
116 T: Into<Self>,
117{
118 fn add_assign(&mut self, rhs: T) {
119 *self = self.clone() + rhs.into();
120 }
121}
122
123impl<F: Field, T> Sum<T> for SymbolicExpression<F>
124where
125 T: Into<Self>,
126{
127 fn sum<I: Iterator<Item = T>>(iter: I) -> Self {
128 iter.map(Into::into)
129 .reduce(|x, y| x + y)
130 .unwrap_or(Self::ZERO)
131 }
132}
133
134impl<F: Field, T: Into<Self>> Sub<T> for SymbolicExpression<F> {
135 type Output = Self;
136
137 fn sub(self, rhs: T) -> Self {
138 match (self, rhs.into()) {
139 (Self::Constant(lhs), Self::Constant(rhs)) => Self::Constant(lhs - rhs),
140 (lhs, rhs) => Self::Sub {
141 degree_multiple: lhs.degree_multiple().max(rhs.degree_multiple()),
142 x: Rc::new(lhs),
143 y: Rc::new(rhs),
144 },
145 }
146 }
147}
148
149impl<F: Field, T> SubAssign<T> for SymbolicExpression<F>
150where
151 T: Into<Self>,
152{
153 fn sub_assign(&mut self, rhs: T) {
154 *self = self.clone() - rhs.into();
155 }
156}
157
158impl<F: Field> Neg for SymbolicExpression<F> {
159 type Output = Self;
160
161 fn neg(self) -> Self {
162 match self {
163 Self::Constant(c) => Self::Constant(-c),
164 expr => Self::Neg {
165 degree_multiple: expr.degree_multiple(),
166 x: Rc::new(expr),
167 },
168 }
169 }
170}
171
172impl<F: Field, T: Into<Self>> Mul<T> for SymbolicExpression<F> {
173 type Output = Self;
174
175 fn mul(self, rhs: T) -> Self {
176 match (self, rhs.into()) {
177 (Self::Constant(lhs), Self::Constant(rhs)) => Self::Constant(lhs * rhs),
178 (lhs, rhs) => Self::Mul {
179 degree_multiple: lhs.degree_multiple() + rhs.degree_multiple(),
180 x: Rc::new(lhs),
181 y: Rc::new(rhs),
182 },
183 }
184 }
185}
186
187impl<F: Field, T> MulAssign<T> for SymbolicExpression<F>
188where
189 T: Into<Self>,
190{
191 fn mul_assign(&mut self, rhs: T) {
192 *self = self.clone() * rhs.into();
193 }
194}
195
196impl<F: Field, T: Into<Self>> Product<T> for SymbolicExpression<F> {
197 fn product<I: Iterator<Item = T>>(iter: I) -> Self {
198 iter.map(Into::into)
199 .reduce(|x, y| x * y)
200 .unwrap_or(Self::ONE)
201 }
202}
203
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_baby_bear::BabyBear;

    use super::*;
    use crate::Entry;

    #[test]
    fn test_symbolic_expression_degree_multiple() {
        let constant_expr = SymbolicExpression::<BabyBear>::Constant(BabyBear::new(5));
        assert_eq!(
            constant_expr.degree_multiple(),
            0,
            "Constant should have degree 0"
        );

        let variable_expr =
            SymbolicExpression::Variable(SymbolicVariable::new(Entry::Main { offset: 0 }, 1));
        assert_eq!(
            variable_expr.degree_multiple(),
            1,
            "Main variable should have degree 1"
        );

        let preprocessed_var = SymbolicExpression::Variable(SymbolicVariable::new(
            Entry::Preprocessed { offset: 0 },
            2,
        ));
        assert_eq!(
            preprocessed_var.degree_multiple(),
            1,
            "Preprocessed variable should have degree 1"
        );

        let permutation_var = SymbolicExpression::Variable(SymbolicVariable::<BabyBear>::new(
            Entry::Permutation { offset: 0 },
            3,
        ));
        assert_eq!(
            permutation_var.degree_multiple(),
            1,
            "Permutation variable should have degree 1"
        );

        let public_var =
            SymbolicExpression::Variable(SymbolicVariable::<BabyBear>::new(Entry::Public, 4));
        assert_eq!(
            public_var.degree_multiple(),
            0,
            "Public variable should have degree 0"
        );

        let challenge_var =
            SymbolicExpression::Variable(SymbolicVariable::<BabyBear>::new(Entry::Challenge, 5));
        assert_eq!(
            challenge_var.degree_multiple(),
            0,
            "Challenge variable should have degree 0"
        );

        let is_first_row = SymbolicExpression::<BabyBear>::IsFirstRow;
        assert_eq!(
            is_first_row.degree_multiple(),
            1,
            "IsFirstRow should have degree 1"
        );

        let is_last_row = SymbolicExpression::<BabyBear>::IsLastRow;
        assert_eq!(
            is_last_row.degree_multiple(),
            1,
            "IsLastRow should have degree 1"
        );

        let is_transition = SymbolicExpression::<BabyBear>::IsTransition;
        assert_eq!(
            is_transition.degree_multiple(),
            0,
            "IsTransition should have degree 0"
        );

        let add_expr = SymbolicExpression::<BabyBear>::Add {
            x: Rc::new(variable_expr.clone()),
            y: Rc::new(preprocessed_var.clone()),
            degree_multiple: 1,
        };
        assert_eq!(
            add_expr.degree_multiple(),
            1,
            "Addition should take max degree of inputs"
        );

        let sub_expr = SymbolicExpression::<BabyBear>::Sub {
            x: Rc::new(variable_expr.clone()),
            y: Rc::new(preprocessed_var.clone()),
            degree_multiple: 1,
        };
        assert_eq!(
            sub_expr.degree_multiple(),
            1,
            "Subtraction should take max degree of inputs"
        );

        let neg_expr = SymbolicExpression::<BabyBear>::Neg {
            x: Rc::new(variable_expr.clone()),
            degree_multiple: 1,
        };
        assert_eq!(
            neg_expr.degree_multiple(),
            1,
            "Negation should keep the degree"
        );

        let mul_expr = SymbolicExpression::<BabyBear>::Mul {
            x: Rc::new(variable_expr.clone()),
            y: Rc::new(preprocessed_var.clone()),
            degree_multiple: 2,
        };
        assert_eq!(
            mul_expr.degree_multiple(),
            2,
            "Multiplication should sum degrees"
        );
    }

    #[test]
    fn test_addition_of_constants() {
        let a = SymbolicExpression::Constant(BabyBear::new(3));
        let b = SymbolicExpression::Constant(BabyBear::new(4));
        let result = a + b;
        match result {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::new(7)),
            _ => panic!("Addition of constants did not simplify correctly"),
        }
    }

    #[test]
    fn test_subtraction_of_constants() {
        let a = SymbolicExpression::Constant(BabyBear::new(10));
        let b = SymbolicExpression::Constant(BabyBear::new(4));
        let result = a - b;
        match result {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::new(6)),
            _ => panic!("Subtraction of constants did not simplify correctly"),
        }
    }

    #[test]
    fn test_negation() {
        let a = SymbolicExpression::Constant(BabyBear::new(7));
        let result = -a;
        match result {
            SymbolicExpression::Constant(val) => {
                assert_eq!(val, BabyBear::NEG_ONE * BabyBear::new(7))
            }
            _ => panic!("Negation did not work correctly"),
        }
    }

    #[test]
    fn test_multiplication_of_constants() {
        let a = SymbolicExpression::Constant(BabyBear::new(3));
        let b = SymbolicExpression::Constant(BabyBear::new(5));
        let result = a * b;
        match result {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::new(15)),
            _ => panic!("Multiplication of constants did not simplify correctly"),
        }
    }

    #[test]
    fn test_degree_multiple_for_addition() {
        let a = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            1,
        ));
        let b = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            2,
        ));
        let result = a + b;
        match result {
            SymbolicExpression::Add {
                degree_multiple,
                x,
                y,
            } => {
                assert_eq!(degree_multiple, 1);
                assert!(
                    matches!(*x, SymbolicExpression::Variable(ref v) if v.index == 1 && matches!(v.entry, Entry::Main { offset: 0 }))
                );
                assert!(
                    matches!(*y, SymbolicExpression::Variable(ref v) if v.index == 2 && matches!(v.entry, Entry::Main { offset: 0 }))
                );
            }
            _ => panic!("Addition did not create an Add expression"),
        }
    }

    #[test]
    fn test_degree_multiple_for_multiplication() {
        let a = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            1,
        ));
        let b = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            2,
        ));
        let result = a * b;

        match result {
            SymbolicExpression::Mul {
                degree_multiple,
                x,
                y,
            } => {
                assert_eq!(degree_multiple, 2, "Multiplication should sum degrees");

                assert!(
                    matches!(*x, SymbolicExpression::Variable(ref v)
                        if v.index == 1 && matches!(v.entry, Entry::Main { offset: 0 })
                    ),
                    "Left operand should match `a`"
                );

                assert!(
                    matches!(*y, SymbolicExpression::Variable(ref v)
                        if v.index == 2 && matches!(v.entry, Entry::Main { offset: 0 })
                    ),
                    "Right operand should match `b`"
                );
            }
            _ => panic!("Multiplication did not create a `Mul` expression"),
        }
    }

    #[test]
    fn test_degree_multiple_for_subtraction() {
        let a = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            0,
        ));
        // `a * IsFirstRow` has degree 1 + 1 = 2; subtracting `a` (degree 1) keeps the max.
        let product = a.clone() * SymbolicExpression::<BabyBear>::IsFirstRow;
        match product - a {
            SymbolicExpression::Sub {
                degree_multiple, ..
            } => assert_eq!(
                degree_multiple, 2,
                "Subtraction should take max degree of inputs"
            ),
            _ => panic!("Subtraction did not create a `Sub` expression"),
        }
    }

    #[test]
    fn test_negation_of_variable_keeps_degree() {
        let a = SymbolicExpression::Variable::<BabyBear>(SymbolicVariable::new(
            Entry::Main { offset: 0 },
            0,
        ));
        match -a {
            SymbolicExpression::Neg { degree_multiple, x } => {
                assert_eq!(degree_multiple, 1, "Negation should keep the degree");
                assert!(
                    matches!(*x, SymbolicExpression::Variable(ref v) if v.index == 0),
                    "Negated operand should be the original variable"
                );
            }
            _ => panic!("Negation of a variable did not create a `Neg` expression"),
        }
    }

    #[test]
    fn test_assign_operators_fold_constants() {
        // Exercises `AddAssign`/`MulAssign`/`SubAssign` with a raw field element on the
        // right-hand side, which goes through the `From<F>` conversion.
        let mut acc = SymbolicExpression::Constant(BabyBear::new(2));
        acc += BabyBear::new(3); // 5
        acc *= BabyBear::new(4); // 20
        acc -= BabyBear::new(6); // 14
        match acc {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::new(14)),
            _ => panic!("Compound assignment of constants did not simplify correctly"),
        }
    }

    #[test]
    fn test_sum_operator() {
        let expressions = vec![
            SymbolicExpression::Constant(BabyBear::new(2)),
            SymbolicExpression::Constant(BabyBear::new(3)),
            SymbolicExpression::Constant(BabyBear::new(5)),
        ];
        let result: SymbolicExpression<BabyBear> = expressions.into_iter().sum();
        match result {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::new(10)),
            _ => panic!("Sum did not produce correct result"),
        }
    }

    #[test]
    fn test_product_operator() {
        let expressions = vec![
            SymbolicExpression::Constant(BabyBear::new(2)),
            SymbolicExpression::Constant(BabyBear::new(3)),
            SymbolicExpression::Constant(BabyBear::new(4)),
        ];
        let result: SymbolicExpression<BabyBear> = expressions.into_iter().product();
        match result {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::new(24)),
            _ => panic!("Product did not produce correct result"),
        }
    }

    #[test]
    fn test_empty_sum_and_product_are_identities() {
        // An empty sum must be the additive identity and an empty product the
        // multiplicative identity.
        let empty_sum: SymbolicExpression<BabyBear> =
            core::iter::empty::<SymbolicExpression<BabyBear>>().sum();
        match empty_sum {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::ZERO),
            _ => panic!("Empty sum did not produce the zero constant"),
        }

        let empty_product: SymbolicExpression<BabyBear> =
            core::iter::empty::<SymbolicExpression<BabyBear>>().product();
        match empty_product {
            SymbolicExpression::Constant(val) => assert_eq!(val, BabyBear::ONE),
            _ => panic!("Empty product did not produce the one constant"),
        }
    }
}