1mod builder;
4mod expression;
5pub(crate) mod expression_ext;
6mod variable;
7
8use alloc::sync::Arc;
9use core::iter::{Product, Sum};
10use core::ops;
11
12pub use builder::*;
13pub use expression::{BaseLeaf, SymbolicExpression};
14pub use expression_ext::{ExtLeaf, SymbolicExpressionExt};
15use p3_field::{Dup, ExtensionField, Field, PrimeCharacteristicRing};
16pub use variable::{BaseEntry, ExtEntry, SymbolicVariable, SymbolicVariableExt};
17
/// Payload type stored in [`SymbolicExpr::Leaf`].
///
/// Implementors supply the scalar field `F`, the ring constants that back
/// `SymbolicExpr`'s `PrimeCharacteristicRing` constants, and the hooks that the
/// arithmetic builders on `SymbolicExpr` use for constant folding and degree tracking.
pub trait SymLeaf: Clone + core::fmt::Debug {
    /// Field in which constant leaves take their values.
    type F: Field;

    /// Leaf used as `SymbolicExpr::ZERO`.
    const ZERO: Self;
    /// Leaf used as `SymbolicExpr::ONE`.
    const ONE: Self;
    /// Leaf used as `SymbolicExpr::TWO`.
    const TWO: Self;
    /// Leaf used as `SymbolicExpr::NEG_ONE`.
    const NEG_ONE: Self;

    /// Degree multiple contributed by this leaf.
    ///
    /// `SymbolicExpr` combines these with `max` for addition/subtraction and with `+`
    /// for multiplication; negation keeps the operand's value.
    fn degree_multiple(&self) -> usize;

    /// Returns the field value if this leaf is a known constant, `None` otherwise.
    ///
    /// Returning `Some` is what enables constant folding and the zero/one identity
    /// shortcuts in `SymbolicExpr`'s arithmetic builders.
    fn as_const(&self) -> Option<&Self::F>;

    /// Builds a constant leaf holding `c`; used to materialise folded constants.
    fn from_const(c: Self::F) -> Self;
}
42
/// A symbolic arithmetic expression tree over leaves of type `A`.
///
/// Interior nodes hold their operands behind [`Arc`], so cloning an expression is
/// shallow: children are shared rather than copied. Each interior node also carries a
/// `degree_multiple`, which the arithmetic builders in this module compute when the
/// node is created so it never has to be recomputed by walking the tree.
#[derive(Clone, Debug)]
pub enum SymbolicExpr<A> {
    /// A single leaf of type `A`.
    Leaf(A),

    /// The sum `x + y`.
    Add {
        /// Left operand.
        x: Arc<Self>,
        /// Right operand.
        y: Arc<Self>,
        /// Cached degree multiple (the builders store the max of the operands').
        degree_multiple: usize,
    },

    /// The difference `x - y`.
    Sub {
        /// Minuend.
        x: Arc<Self>,
        /// Subtrahend.
        y: Arc<Self>,
        /// Cached degree multiple (the builders store the max of the operands').
        degree_multiple: usize,
    },

    /// The negation `-x`.
    Neg {
        /// Negated operand.
        x: Arc<Self>,
        /// Cached degree multiple (the builders store the operand's value unchanged).
        degree_multiple: usize,
    },

    /// The product `x * y`.
    Mul {
        /// Left factor.
        x: Arc<Self>,
        /// Right factor.
        y: Arc<Self>,
        /// Cached degree multiple (the builders store the sum of the operands').
        degree_multiple: usize,
    },
}
84
85impl<A: SymLeaf> SymbolicExpr<A> {
86 pub fn degree_multiple(&self) -> usize {
88 match self {
89 Self::Leaf(a) => a.degree_multiple(),
90 Self::Add {
91 degree_multiple, ..
92 }
93 | Self::Sub {
94 degree_multiple, ..
95 }
96 | Self::Neg {
97 degree_multiple, ..
98 }
99 | Self::Mul {
100 degree_multiple, ..
101 } => *degree_multiple,
102 }
103 }
104
105 fn as_const(&self) -> Option<&A::F> {
107 match self {
108 Self::Leaf(a) => a.as_const(),
109 _ => None,
110 }
111 }
112
113 fn sym_add(self, rhs: Self) -> Self {
115 if let (Some(&a), Some(&b)) = (self.as_const(), rhs.as_const()) {
116 return Self::Leaf(A::from_const(a + b));
117 }
118 if self.as_const().is_some_and(|c| c.is_zero()) {
119 return rhs;
120 }
121 if rhs.as_const().is_some_and(|c| c.is_zero()) {
122 return self;
123 }
124 let dm = self.degree_multiple().max(rhs.degree_multiple());
125 Self::Add {
126 x: Arc::new(self),
127 y: Arc::new(rhs),
128 degree_multiple: dm,
129 }
130 }
131
132 fn sym_sub(self, rhs: Self) -> Self {
134 if let (Some(&a), Some(&b)) = (self.as_const(), rhs.as_const()) {
135 return Self::Leaf(A::from_const(a - b));
136 }
137 if self.as_const().is_some_and(|c| c.is_zero()) {
138 return rhs.sym_neg();
139 }
140 if rhs.as_const().is_some_and(|c| c.is_zero()) {
141 return self;
142 }
143 let dm = self.degree_multiple().max(rhs.degree_multiple());
144 Self::Sub {
145 x: Arc::new(self),
146 y: Arc::new(rhs),
147 degree_multiple: dm,
148 }
149 }
150
151 fn sym_neg(self) -> Self {
153 if let Some(&c) = self.as_const() {
154 return Self::Leaf(A::from_const(-c));
155 }
156 let dm = self.degree_multiple();
157 Self::Neg {
158 x: Arc::new(self),
159 degree_multiple: dm,
160 }
161 }
162
163 fn sym_mul(self, rhs: Self) -> Self {
165 if let (Some(&a), Some(&b)) = (self.as_const(), rhs.as_const()) {
166 return Self::Leaf(A::from_const(a * b));
167 }
168 if self.as_const().is_some_and(|c| c.is_zero())
169 || rhs.as_const().is_some_and(|c| c.is_zero())
170 {
171 return Self::Leaf(A::from_const(A::F::ZERO));
172 }
173 if self.as_const().is_some_and(|c| c.is_one()) {
174 return rhs;
175 }
176 if rhs.as_const().is_some_and(|c| c.is_one()) {
177 return self;
178 }
179 let dm = self.degree_multiple() + rhs.degree_multiple();
180 Self::Mul {
181 x: Arc::new(self),
182 y: Arc::new(rhs),
183 degree_multiple: dm,
184 }
185 }
186}
187
188impl<A: SymLeaf> PrimeCharacteristicRing for SymbolicExpr<A> {
189 type PrimeSubfield = <A::F as PrimeCharacteristicRing>::PrimeSubfield;
190
191 const ZERO: Self = Self::Leaf(A::ZERO);
192 const ONE: Self = Self::Leaf(A::ONE);
193 const TWO: Self = Self::Leaf(A::TWO);
194 const NEG_ONE: Self = Self::Leaf(A::NEG_ONE);
195
196 #[inline]
197 fn from_prime_subfield(f: Self::PrimeSubfield) -> Self {
198 Self::Leaf(A::from_const(A::F::from_prime_subfield(f)))
199 }
200}
201
202impl<A: SymLeaf> Dup for SymbolicExpr<A> {
203 #[inline(always)]
204 fn dup(&self) -> Self {
205 self.clone()
206 }
207}
208
209impl<A: SymLeaf> Default for SymbolicExpr<A> {
210 fn default() -> Self {
211 Self::ZERO
212 }
213}
214
215impl<A: SymLeaf, T: Into<Self>> ops::Add<T> for SymbolicExpr<A> {
216 type Output = Self;
217 fn add(self, rhs: T) -> Self {
218 self.sym_add(rhs.into())
219 }
220}
221
222impl<A: SymLeaf, T: Into<Self>> ops::Sub<T> for SymbolicExpr<A> {
223 type Output = Self;
224 fn sub(self, rhs: T) -> Self {
225 self.sym_sub(rhs.into())
226 }
227}
228
229impl<A: SymLeaf> ops::Neg for SymbolicExpr<A> {
230 type Output = Self;
231 fn neg(self) -> Self {
232 self.sym_neg()
233 }
234}
235
236impl<A: SymLeaf, T: Into<Self>> ops::Mul<T> for SymbolicExpr<A> {
237 type Output = Self;
238 fn mul(self, rhs: T) -> Self {
239 self.sym_mul(rhs.into())
240 }
241}
242
243impl<A: SymLeaf, T: Into<Self>> ops::AddAssign<T> for SymbolicExpr<A> {
244 fn add_assign(&mut self, rhs: T) {
245 *self = self.clone() + rhs.into();
246 }
247}
248
249impl<A: SymLeaf, T: Into<Self>> ops::SubAssign<T> for SymbolicExpr<A> {
250 fn sub_assign(&mut self, rhs: T) {
251 *self = self.clone() - rhs.into();
252 }
253}
254
255impl<A: SymLeaf, T: Into<Self>> ops::MulAssign<T> for SymbolicExpr<A> {
256 fn mul_assign(&mut self, rhs: T) {
257 *self = self.clone() * rhs.into();
258 }
259}
260
261impl<A: SymLeaf, T: Into<Self>> Sum<T> for SymbolicExpr<A> {
262 fn sum<I: Iterator<Item = T>>(iter: I) -> Self {
263 iter.map(Into::into)
264 .reduce(|a, b| a + b)
265 .unwrap_or(Self::ZERO)
266 }
267}
268
269impl<A: SymLeaf, T: Into<Self>> Product<T> for SymbolicExpr<A> {
270 fn product<I: Iterator<Item = T>>(iter: I) -> Self {
271 iter.map(Into::into)
272 .reduce(|a, b| a * b)
273 .unwrap_or(Self::ONE)
274 }
275}
276
277impl<F: Field, T: Into<SymbolicExpression<F>>> ops::Add<T> for SymbolicVariable<F> {
278 type Output = SymbolicExpression<F>;
279 fn add(self, rhs: T) -> Self::Output {
280 Self::Output::from(self) + rhs.into()
281 }
282}
283
284impl<F: Field, T: Into<SymbolicExpression<F>>> ops::Sub<T> for SymbolicVariable<F> {
285 type Output = SymbolicExpression<F>;
286 fn sub(self, rhs: T) -> Self::Output {
287 Self::Output::from(self) - rhs.into()
288 }
289}
290
291impl<F: Field, T: Into<SymbolicExpression<F>>> ops::Mul<T> for SymbolicVariable<F> {
292 type Output = SymbolicExpression<F>;
293 fn mul(self, rhs: T) -> Self::Output {
294 Self::Output::from(self) * rhs.into()
295 }
296}
297
298impl<F: Field, EF: ExtensionField<F>, T: Into<SymbolicExpressionExt<F, EF>>> ops::Add<T>
299 for SymbolicVariableExt<F, EF>
300{
301 type Output = SymbolicExpressionExt<F, EF>;
302 fn add(self, rhs: T) -> Self::Output {
303 Self::Output::from(self) + rhs.into()
304 }
305}
306
307impl<F: Field, EF: ExtensionField<F>, T: Into<SymbolicExpressionExt<F, EF>>> ops::Sub<T>
308 for SymbolicVariableExt<F, EF>
309{
310 type Output = SymbolicExpressionExt<F, EF>;
311 fn sub(self, rhs: T) -> Self::Output {
312 Self::Output::from(self) - rhs.into()
313 }
314}
315
316impl<F: Field, EF: ExtensionField<F>, T: Into<SymbolicExpressionExt<F, EF>>> ops::Mul<T>
317 for SymbolicVariableExt<F, EF>
318{
319 type Output = SymbolicExpressionExt<F, EF>;
320 fn mul(self, rhs: T) -> Self::Output {
321 Self::Output::from(self) * rhs.into()
322 }
323}
324
// Tests for the operator impls on `SymbolicVariable` / `SymbolicVariableExt`: each
// checks that the operator builds the expected node kind, keeps the operands in
// order (left = the variable, right = the other operand), and records the expected
// degree multiple.
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::extension::BinomialExtensionField;

    use super::*;
    use crate::symbolic::expression::BaseLeaf;
    use crate::symbolic::expression_ext::ExtLeaf;
    use crate::symbolic::variable::{BaseEntry, ExtEntry};

    type F = BabyBear;
    type EF = BinomialExtensionField<BabyBear, 4>;

    // Variable + non-zero constant: `sym_add` only folds two constants or drops a zero
    // constant, so an `Add` node must be produced. The test expects degree multiple 1.
    #[test]
    fn symbolic_variable_add_produces_add_node() {
        let var = SymbolicVariable::<F>::new(BaseEntry::Main { offset: 0 }, 0);
        let expr = SymbolicExpression::from(F::new(5));
        let result = var + expr;
        match result {
            SymbolicExpr::Add {
                x,
                y,
                degree_multiple,
            } => {
                assert_eq!(degree_multiple, 1);
                // Left operand is the original variable.
                assert!(matches!(
                    x.as_ref(),
                    SymbolicExpr::Leaf(BaseLeaf::Variable(v))
                        if v.index == 0 && v.entry == BaseEntry::Main { offset: 0 }
                ));
                // Right operand is the constant, kept as a constant leaf.
                assert!(matches!(
                    y.as_ref(),
                    SymbolicExpr::Leaf(BaseLeaf::Constant(c)) if *c == F::new(5)
                ));
            }
            _ => panic!("Expected an Add node"),
        }
    }

    // Variable - variable: no folding applies; degree multiple is the max of the
    // operands' (expected to be 1).
    #[test]
    fn symbolic_variable_sub_produces_sub_node() {
        let var = SymbolicVariable::<F>::new(BaseEntry::Main { offset: 0 }, 0);
        let other = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::new(
            BaseEntry::Main { offset: 0 },
            1,
        )));
        let result = var - other;
        match result {
            SymbolicExpr::Sub {
                x,
                y,
                degree_multiple,
            } => {
                assert_eq!(degree_multiple, 1);
                assert!(matches!(
                    x.as_ref(),
                    SymbolicExpr::Leaf(BaseLeaf::Variable(v))
                        if v.index == 0 && v.entry == BaseEntry::Main { offset: 0 }
                ));
                assert!(matches!(
                    y.as_ref(),
                    SymbolicExpr::Leaf(BaseLeaf::Variable(v))
                        if v.index == 1 && v.entry == BaseEntry::Main { offset: 0 }
                ));
            }
            _ => panic!("Expected a Sub node"),
        }
    }

    // Variable * variable: degree multiples add (expected 1 + 1 = 2).
    #[test]
    fn symbolic_variable_mul_produces_mul_node() {
        let var = SymbolicVariable::<F>::new(BaseEntry::Main { offset: 0 }, 0);
        let other = SymbolicExpression::Leaf(BaseLeaf::Variable(SymbolicVariable::new(
            BaseEntry::Main { offset: 0 },
            1,
        )));
        let result = var * other;
        match result {
            SymbolicExpr::Mul {
                x,
                y,
                degree_multiple,
            } => {
                assert_eq!(degree_multiple, 2);
                assert!(matches!(
                    x.as_ref(),
                    SymbolicExpr::Leaf(BaseLeaf::Variable(v))
                        if v.index == 0 && v.entry == BaseEntry::Main { offset: 0 }
                ));
                assert!(matches!(
                    y.as_ref(),
                    SymbolicExpr::Leaf(BaseLeaf::Variable(v))
                        if v.index == 1 && v.entry == BaseEntry::Main { offset: 0 }
                ));
            }
            _ => panic!("Expected a Mul node"),
        }
    }

    // Extension variable + base-field constant: the constant is expected to be lifted
    // as `ExtLeaf::Base(..)` wrapping a base constant leaf, and an `Add` node is built.
    #[test]
    fn symbolic_variable_ext_add_produces_add_node() {
        let var = SymbolicVariableExt::<F, EF>::new(ExtEntry::Permutation { offset: 0 }, 0);
        let expr = SymbolicExpressionExt::<F, EF>::from(F::new(3));
        let result = var + expr;
        match result {
            SymbolicExpr::Add {
                x,
                y,
                degree_multiple,
            } => {
                assert_eq!(degree_multiple, 1);
                assert!(matches!(
                    x.as_ref(),
                    SymbolicExpr::Leaf(ExtLeaf::ExtVariable(v))
                        if v.index == 0 && v.entry == ExtEntry::Permutation { offset: 0 }
                ));
                assert!(matches!(
                    y.as_ref(),
                    SymbolicExpr::Leaf(ExtLeaf::Base(SymbolicExpr::Leaf(BaseLeaf::Constant(c))))
                        if *c == F::new(3)
                ));
            }
            _ => panic!("Expected an Add node"),
        }
    }

    // Extension variable - extension variable: `Sub` node, degree multiple expected 1.
    #[test]
    fn symbolic_variable_ext_sub_produces_sub_node() {
        let var = SymbolicVariableExt::<F, EF>::new(ExtEntry::Permutation { offset: 0 }, 0);
        let other = SymbolicExpressionExt::<F, EF>::from(SymbolicVariableExt::<F, EF>::new(
            ExtEntry::Permutation { offset: 0 },
            1,
        ));
        let result = var - other;
        match result {
            SymbolicExpr::Sub {
                x,
                y,
                degree_multiple,
            } => {
                assert_eq!(degree_multiple, 1);
                assert!(matches!(
                    x.as_ref(),
                    SymbolicExpr::Leaf(ExtLeaf::ExtVariable(v))
                        if v.index == 0 && v.entry == ExtEntry::Permutation { offset: 0 }
                ));
                assert!(matches!(
                    y.as_ref(),
                    SymbolicExpr::Leaf(ExtLeaf::ExtVariable(v))
                        if v.index == 1 && v.entry == ExtEntry::Permutation { offset: 0 }
                ));
            }
            _ => panic!("Expected a Sub node"),
        }
    }

    // Extension variable * extension variable: `Mul` node, degree multiples add
    // (expected 2).
    #[test]
    fn symbolic_variable_ext_mul_produces_mul_node() {
        let var = SymbolicVariableExt::<F, EF>::new(ExtEntry::Permutation { offset: 0 }, 0);
        let other = SymbolicExpressionExt::<F, EF>::from(SymbolicVariableExt::<F, EF>::new(
            ExtEntry::Permutation { offset: 0 },
            1,
        ));
        let result = var * other;
        match result {
            SymbolicExpr::Mul {
                x,
                y,
                degree_multiple,
            } => {
                assert_eq!(degree_multiple, 2);
                assert!(matches!(
                    x.as_ref(),
                    SymbolicExpr::Leaf(ExtLeaf::ExtVariable(v))
                        if v.index == 0 && v.entry == ExtEntry::Permutation { offset: 0 }
                ));
                assert!(matches!(
                    y.as_ref(),
                    SymbolicExpr::Leaf(ExtLeaf::ExtVariable(v))
                        if v.index == 1 && v.entry == ExtEntry::Permutation { offset: 0 }
                ));
            }
            _ => panic!("Expected a Mul node"),
        }
    }
}
516}