use p3_field::{Algebra, ExtensionField, Field, InjectiveMonomial};
use crate::symbolic::variable::SymbolicVariable;
use crate::symbolic::{SymLeaf, SymbolicExpr};
/// A leaf of a base-field symbolic constraint expression.
///
/// Trees of these leaves form a [`SymbolicExpression`]; the leaf kind determines the
/// degree multiple reported by [`SymLeaf::degree_multiple`].
#[derive(Clone, Debug)]
pub enum BaseLeaf<F> {
    /// A reference to a column entry (e.g. main, preprocessed or public); its degree
    /// multiple is delegated to the variable itself.
    Variable(SymbolicVariable<F>),
    /// The `is_first_row` selector; reported with degree multiple 1.
    IsFirstRow,
    /// The `is_last_row` selector; reported with degree multiple 1.
    IsLastRow,
    /// The `is_transition` selector; reported with degree multiple 0.
    IsTransition,
    /// A constant field element; reported with degree multiple 0.
    Constant(F),
}

/// A symbolic expression tree whose leaves are [`BaseLeaf`] values over the field `F`.
pub type SymbolicExpression<F> = SymbolicExpr<BaseLeaf<F>>;
/// Plugs base-field leaves into the generic [`SymbolicExpr`] machinery.
///
/// Degree multiples: constants and the transition selector count as 0, the first/last-row
/// selectors count as 1, and variables defer to [`SymbolicVariable::degree_multiple`].
/// Only [`BaseLeaf::Constant`] leaves are treated as constants.
impl<F: Field> SymLeaf for BaseLeaf<F> {
    type F = F;

    const ZERO: Self = Self::Constant(F::ZERO);
    const ONE: Self = Self::Constant(F::ONE);
    const TWO: Self = Self::Constant(F::TWO);
    const NEG_ONE: Self = Self::Constant(F::NEG_ONE);

    fn degree_multiple(&self) -> usize {
        match self {
            Self::Constant(_) | Self::IsTransition => 0,
            Self::IsFirstRow | Self::IsLastRow => 1,
            Self::Variable(var) => var.degree_multiple(),
        }
    }

    fn as_const(&self) -> Option<&F> {
        if let Self::Constant(value) = self {
            Some(value)
        } else {
            None
        }
    }

    fn from_const(value: F) -> Self {
        Self::Constant(value)
    }
}
/// Lifts a base-field variable into an expression over the extension field `EF`.
///
/// Only the variable's `entry` and `index` are carried over; the variable is rebuilt
/// with `EF` as its field parameter.
impl<F, EF> From<SymbolicVariable<F>> for SymbolicExpression<EF>
where
    F: Field,
    EF: ExtensionField<F>,
{
    fn from(var: SymbolicVariable<F>) -> Self {
        let lifted = SymbolicVariable::<EF>::new(var.entry, var.index);
        Self::Leaf(BaseLeaf::Variable(lifted))
    }
}
/// Embeds a base-field element as a constant leaf of an expression over `EF`.
impl<F, EF> From<F> for SymbolicExpression<EF>
where
    F: Field,
    EF: ExtensionField<F>,
{
    fn from(f: F) -> Self {
        // Lift into the extension field first; the leaf stores an `EF` constant.
        let lifted: EF = f.into();
        Self::Leaf(BaseLeaf::Constant(lifted))
    }
}
// Marker impls with empty bodies: they declare that base-field expressions form an
// algebra over `F` and over `SymbolicVariable<F>`, and that an injective monomial of
// degree `N` on `F` carries over to expressions over `F`.
impl<F> Algebra<F> for SymbolicExpression<F> where F: Field {}

impl<F> Algebra<SymbolicVariable<F>> for SymbolicExpression<F> where F: Field {}

impl<F, const N: u64> InjectiveMonomial<N> for SymbolicExpression<F> where
    F: Field + InjectiveMonomial<N>
{
}
#[cfg(test)]
mod tests {
    use alloc::sync::Arc;
    use alloc::vec;
    use alloc::vec::Vec;

    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;

    use super::*;
    use crate::symbolic::BaseEntry;

    /// Shorthand for the expression type exercised throughout these tests.
    type Expr = SymbolicExpression<BabyBear>;

    /// Builds a constant leaf holding the field element `value`.
    fn leaf(value: BabyBear) -> Expr {
        SymbolicExpr::Leaf(BaseLeaf::Constant(value))
    }

    /// Builds a constant leaf holding `BabyBear::new(value)`.
    fn constant(value: u32) -> Expr {
        leaf(BabyBear::new(value))
    }

    /// Builds a variable leaf for column `index` of `entry`.
    fn var_at(entry: BaseEntry, index: usize) -> Expr {
        SymbolicExpr::Leaf(BaseLeaf::Variable(SymbolicVariable::new(entry, index)))
    }

    /// Builds a variable leaf for main-trace column `index` at row offset 0.
    fn main_var(index: usize) -> Expr {
        var_at(BaseEntry::Main { offset: 0 }, index)
    }

    /// Returns the constant stored in `expr` when it is a constant leaf.
    fn leaf_const(expr: &Expr) -> Option<BabyBear> {
        match expr {
            SymbolicExpr::Leaf(BaseLeaf::Constant(c)) => Some(*c),
            _ => None,
        }
    }

    /// Like [`leaf_const`], but panics with `failure` when `expr` is not a constant leaf.
    fn expect_const(expr: &Expr, failure: &str) -> BabyBear {
        leaf_const(expr).unwrap_or_else(|| panic!("{failure}"))
    }

    /// Whether `expr` is any variable leaf.
    fn is_var_leaf(expr: &Expr) -> bool {
        matches!(expr, SymbolicExpr::Leaf(BaseLeaf::Variable(_)))
    }

    /// Whether `expr` is the variable leaf for main-trace column `index` at offset 0.
    fn is_main_var(expr: &Expr, index: usize) -> bool {
        matches!(
            expr,
            SymbolicExpr::Leaf(BaseLeaf::Variable(v))
                if v.index == index && matches!(v.entry, BaseEntry::Main { offset: 0 })
        )
    }

    #[test]
    fn test_symbolic_expression_degree_multiple() {
        // Each leaf kind reports its own degree multiple.
        let leaf_cases = [
            (constant(5), 0, "Constant should have degree 0"),
            (main_var(1), 1, "Main variable should have degree 1"),
            (
                var_at(BaseEntry::Preprocessed { offset: 0 }, 2),
                1,
                "Preprocessed variable should have degree 1",
            ),
            (
                var_at(BaseEntry::Public, 4),
                0,
                "Public variable should have degree 0",
            ),
            (
                SymbolicExpr::Leaf(BaseLeaf::IsFirstRow),
                1,
                "IsFirstRow should have degree 1",
            ),
            (
                SymbolicExpr::Leaf(BaseLeaf::IsLastRow),
                1,
                "IsLastRow should have degree 1",
            ),
            (
                SymbolicExpr::Leaf(BaseLeaf::IsTransition),
                0,
                "IsTransition should have degree 0",
            ),
        ];
        for (expr, expected, message) in leaf_cases {
            assert_eq!(expr.degree_multiple(), expected, "{message}");
        }

        // Interior nodes report the degree multiple stored in the node.
        let main = Arc::new(main_var(1));
        let preprocessed = Arc::new(var_at(BaseEntry::Preprocessed { offset: 0 }, 2));

        let add_expr: Expr = SymbolicExpr::Add {
            x: Arc::clone(&main),
            y: Arc::clone(&preprocessed),
            degree_multiple: 1,
        };
        assert_eq!(
            add_expr.degree_multiple(),
            1,
            "Addition should take max degree of inputs"
        );

        let sub_expr: Expr = SymbolicExpr::Sub {
            x: Arc::clone(&main),
            y: Arc::clone(&preprocessed),
            degree_multiple: 1,
        };
        assert_eq!(
            sub_expr.degree_multiple(),
            1,
            "Subtraction should take max degree of inputs"
        );

        let neg_expr: Expr = SymbolicExpr::Neg {
            x: Arc::clone(&main),
            degree_multiple: 1,
        };
        assert_eq!(
            neg_expr.degree_multiple(),
            1,
            "Negation should keep the degree"
        );

        let mul_expr: Expr = SymbolicExpr::Mul {
            x: main,
            y: preprocessed,
            degree_multiple: 2,
        };
        assert_eq!(
            mul_expr.degree_multiple(),
            2,
            "Multiplication should sum degrees"
        );
    }

    #[test]
    fn test_addition_of_constants() {
        let sum = constant(3) + constant(4);
        let value = expect_const(&sum, "Addition of constants did not simplify correctly");
        assert_eq!(value, BabyBear::new(7));
    }

    #[test]
    fn test_subtraction_of_constants() {
        let difference = constant(10) - constant(4);
        let value = expect_const(
            &difference,
            "Subtraction of constants did not simplify correctly",
        );
        assert_eq!(value, BabyBear::new(6));
    }

    #[test]
    fn test_negation() {
        let negated = -constant(7);
        let value = expect_const(&negated, "Negation did not work correctly");
        assert_eq!(value, BabyBear::NEG_ONE * BabyBear::new(7));
    }

    #[test]
    fn test_multiplication_of_constants() {
        let product = constant(3) * constant(5);
        let value = expect_const(
            &product,
            "Multiplication of constants did not simplify correctly",
        );
        assert_eq!(value, BabyBear::new(15));
    }

    #[test]
    fn test_degree_multiple_for_addition() {
        let SymbolicExpr::Add {
            x,
            y,
            degree_multiple,
        } = main_var(1) + main_var(2)
        else {
            panic!("Addition did not create an Add expression");
        };
        assert_eq!(degree_multiple, 1);
        assert!(is_main_var(&x, 1));
        assert!(is_main_var(&y, 2));
    }

    #[test]
    fn test_degree_multiple_for_multiplication() {
        let SymbolicExpr::Mul {
            x,
            y,
            degree_multiple,
        } = main_var(1) * main_var(2)
        else {
            panic!("Multiplication did not create a `Mul` expression");
        };
        assert_eq!(degree_multiple, 2, "Multiplication should sum degrees");
        assert!(is_main_var(&x, 1), "Left operand should match `a`");
        assert!(is_main_var(&y, 2), "Right operand should match `b`");
    }

    #[test]
    fn test_sum_operator() {
        let terms: Vec<Expr> = [2, 3, 5].into_iter().map(constant).collect();
        let total: Expr = terms.into_iter().sum();
        let value = expect_const(&total, "Sum did not produce correct result");
        assert_eq!(value, BabyBear::new(10));
    }

    #[test]
    fn test_product_operator() {
        let factors: Vec<Expr> = [2, 3, 4].into_iter().map(constant).collect();
        let total: Expr = factors.into_iter().product();
        let value = expect_const(&total, "Product did not produce correct result");
        assert_eq!(value, BabyBear::new(24));
    }

    #[test]
    fn test_default_is_zero() {
        assert_eq!(leaf_const(&Expr::default()), Some(BabyBear::ZERO));
    }

    #[test]
    fn test_ring_constants() {
        let expectations = [
            (Expr::ZERO, BabyBear::ZERO),
            (Expr::ONE, BabyBear::ONE),
            (Expr::TWO, BabyBear::TWO),
            (Expr::NEG_ONE, BabyBear::NEG_ONE),
        ];
        for (expr, expected) in expectations {
            assert_eq!(leaf_const(&expr), Some(expected));
        }
    }

    #[test]
    fn test_from_symbolic_variable() {
        let var = SymbolicVariable::<BabyBear>::new(BaseEntry::Main { offset: 0 }, 3);
        let expr: Expr = var.into();
        let SymbolicExpr::Leaf(BaseLeaf::Variable(v)) = expr else {
            panic!("Expected Variable variant");
        };
        assert!(matches!(v.entry, BaseEntry::Main { offset: 0 }));
        assert_eq!(v.index, 3);
    }

    #[test]
    fn test_from_field_element() {
        let field_val = BabyBear::new(42);
        let expr: Expr = field_val.into();
        assert_eq!(leaf_const(&expr), Some(field_val));
    }

    #[test]
    fn test_from_prime_subfield() {
        let prime_subfield_val = <BabyBear as PrimeCharacteristicRing>::PrimeSubfield::new(7);
        let expr = Expr::from_prime_subfield(prime_subfield_val);
        assert_eq!(leaf_const(&expr), Some(BabyBear::new(7)));
    }

    #[test]
    fn test_assign_operators() {
        let mut sum = constant(5);
        sum += constant(3);
        assert_eq!(leaf_const(&sum), Some(BabyBear::new(8)));

        let mut difference = constant(10);
        difference -= constant(4);
        assert_eq!(leaf_const(&difference), Some(BabyBear::new(6)));

        let mut product = constant(6);
        product *= constant(7);
        assert_eq!(leaf_const(&product), Some(BabyBear::new(42)));
    }

    #[test]
    fn test_subtraction_creates_sub_node() {
        let SymbolicExpr::Sub {
            x,
            y,
            degree_multiple,
        } = main_var(0) - main_var(1)
        else {
            panic!("Expected Sub variant");
        };
        assert_eq!(degree_multiple, 1);
        assert!(is_main_var(&x, 0));
        assert!(is_main_var(&y, 1));
    }

    #[test]
    fn test_negation_creates_neg_node() {
        let SymbolicExpr::Neg { x, degree_multiple } = -main_var(0) else {
            panic!("Expected Neg variant");
        };
        assert_eq!(degree_multiple, 1);
        assert!(is_main_var(&x, 0));
    }

    #[test]
    fn test_empty_sum_returns_zero() {
        let empty: Vec<Expr> = vec![];
        let total: Expr = empty.into_iter().sum();
        assert_eq!(leaf_const(&total), Some(BabyBear::ZERO));
    }

    #[test]
    fn test_empty_product_returns_one() {
        let empty: Vec<Expr> = vec![];
        let total: Expr = empty.into_iter().product();
        assert_eq!(leaf_const(&total), Some(BabyBear::ONE));
    }

    #[test]
    fn test_mixed_degree_addition() {
        let SymbolicExpr::Add {
            x,
            y,
            degree_multiple,
        } = constant(5) + main_var(0)
        else {
            panic!("Expected Add variant");
        };
        assert_eq!(degree_multiple, 1);
        assert_eq!(leaf_const(&x), Some(BabyBear::new(5)));
        assert!(is_main_var(&y, 0));
    }

    #[test]
    fn test_chained_multiplication_degree() {
        let ab = main_var(0) * main_var(1);
        assert_eq!(ab.degree_multiple(), 2);
        let abc = ab * main_var(2);
        assert_eq!(abc.degree_multiple(), 3);
    }

    #[test]
    fn test_add_zero_identity_folding() {
        let zero = leaf(BabyBear::ZERO);
        assert!(
            is_var_leaf(&(main_var(0) + zero.clone())),
            "x + 0 should fold to x"
        );
        assert!(is_var_leaf(&(zero + main_var(0))), "0 + x should fold to x");
    }

    #[test]
    fn test_sub_zero_identity_folding() {
        let zero = leaf(BabyBear::ZERO);
        assert!(
            is_var_leaf(&(main_var(0) - zero.clone())),
            "x - 0 should fold to x"
        );

        let SymbolicExpr::Neg { x, degree_multiple } = zero - main_var(0) else {
            panic!("0 - x should fold to Neg(x)");
        };
        assert_eq!(degree_multiple, 1);
        assert!(is_main_var(&x, 0));
    }

    #[test]
    fn test_mul_zero_identity_folding() {
        let zero = leaf(BabyBear::ZERO);
        assert_eq!(
            leaf_const(&(main_var(0) * zero.clone())),
            Some(BabyBear::ZERO),
            "x * 0 should fold to 0"
        );
        assert_eq!(
            leaf_const(&(zero * main_var(0))),
            Some(BabyBear::ZERO),
            "0 * x should fold to 0"
        );
    }

    #[test]
    fn test_mul_one_identity_folding() {
        let one = leaf(BabyBear::ONE);
        assert!(
            is_var_leaf(&(main_var(0) * one.clone())),
            "x * 1 should fold to x"
        );
        assert!(is_var_leaf(&(one * main_var(0))), "1 * x should fold to x");
    }

    #[test]
    fn test_identity_folding_preserves_degree() {
        let x = main_var(0);
        let zero = leaf(BabyBear::ZERO);
        let one = leaf(BabyBear::ONE);

        assert_eq!((x.clone() + zero.clone()).degree_multiple(), 1);
        assert_eq!((x.clone() - zero.clone()).degree_multiple(), 1);
        assert_eq!((zero.clone() - x.clone()).degree_multiple(), 1);
        assert_eq!((x.clone() * one).degree_multiple(), 1);
        assert_eq!((x * zero).degree_multiple(), 0);
    }
}