use crate::symbolic::expr::Expr;
pub fn simplify(expr: &Expr) -> Expr {
match expr {
Expr::Constant(_) | Expr::Variable(_) => expr.clone(),
Expr::Add(f, g) => {
let f_simp = simplify(f);
let g_simp = simplify(g);
if let (Expr::Constant(a), Expr::Constant(b)) = (&f_simp, &g_simp) {
return Expr::Constant(a + b);
}
if matches!(g_simp, Expr::Constant(c) if c == 0.0) {
return f_simp;
}
if matches!(f_simp, Expr::Constant(c) if c == 0.0) {
return g_simp;
}
Expr::Add(Box::new(f_simp), Box::new(g_simp))
}
Expr::Sub(f, g) => {
let f_simp = simplify(f);
let g_simp = simplify(g);
if let (Expr::Constant(a), Expr::Constant(b)) = (&f_simp, &g_simp) {
return Expr::Constant(a - b);
}
if matches!(g_simp, Expr::Constant(c) if c == 0.0) {
return f_simp;
}
if matches!(f_simp, Expr::Constant(c) if c == 0.0) {
return Expr::Neg(Box::new(g_simp));
}
if f_simp == g_simp {
return Expr::Constant(0.0);
}
Expr::Sub(Box::new(f_simp), Box::new(g_simp))
}
Expr::Mul(f, g) => {
let f_simp = simplify(f);
let g_simp = simplify(g);
if let (Expr::Constant(a), Expr::Constant(b)) = (&f_simp, &g_simp) {
return Expr::Constant(a * b);
}
if matches!(f_simp, Expr::Constant(c) if c == 0.0)
|| matches!(g_simp, Expr::Constant(c) if c == 0.0)
{
return Expr::Constant(0.0);
}
if matches!(g_simp, Expr::Constant(c) if c == 1.0) {
return f_simp;
}
if matches!(f_simp, Expr::Constant(c) if c == 1.0) {
return g_simp;
}
if matches!(g_simp, Expr::Constant(c) if c == -1.0) {
return Expr::Neg(Box::new(f_simp));
}
if matches!(f_simp, Expr::Constant(c) if c == -1.0) {
return Expr::Neg(Box::new(g_simp));
}
Expr::Mul(Box::new(f_simp), Box::new(g_simp))
}
Expr::Div(f, g) => {
let f_simp = simplify(f);
let g_simp = simplify(g);
if let (Expr::Constant(a), Expr::Constant(b)) = (&f_simp, &g_simp) {
if *b != 0.0 {
return Expr::Constant(a / b);
}
}
if matches!(f_simp, Expr::Constant(c) if c == 0.0) {
return Expr::Constant(0.0);
}
if matches!(g_simp, Expr::Constant(c) if c == 1.0) {
return f_simp;
}
if f_simp == g_simp {
return Expr::Constant(1.0);
}
if matches!(g_simp, Expr::Constant(c) if c == -1.0) {
return Expr::Neg(Box::new(f_simp));
}
Expr::Div(Box::new(f_simp), Box::new(g_simp))
}
Expr::Pow(f, g) => {
let f_simp = simplify(f);
let g_simp = simplify(g);
if let (Expr::Constant(a), Expr::Constant(b)) = (&f_simp, &g_simp) {
if *a >= 0.0 || b.fract() == 0.0 {
return Expr::Constant(a.powf(*b));
}
}
if matches!(g_simp, Expr::Constant(c) if c == 0.0) {
return Expr::Constant(1.0);
}
if matches!(g_simp, Expr::Constant(c) if c == 1.0) {
return f_simp;
}
if matches!(f_simp, Expr::Constant(c) if c == 0.0) {
if let Expr::Constant(exp) = g_simp {
if exp > 0.0 {
return Expr::Constant(0.0);
}
}
}
if matches!(f_simp, Expr::Constant(c) if c == 1.0) {
return Expr::Constant(1.0);
}
Expr::Pow(Box::new(f_simp), Box::new(g_simp))
}
Expr::Neg(f) => {
let f_simp = simplify(f);
if let Expr::Neg(inner) = f_simp {
return *inner;
}
if matches!(f_simp, Expr::Constant(c) if c == 0.0) {
return Expr::Constant(0.0);
}
if let Expr::Constant(c) = f_simp {
return Expr::Constant(-c);
}
Expr::Neg(Box::new(f_simp))
}
Expr::Sin(f) => {
let f_simp = simplify(f);
if matches!(f_simp, Expr::Constant(c) if c == 0.0) {
return Expr::Constant(0.0);
}
if let Expr::Constant(c) = f_simp {
return Expr::Constant(c.sin());
}
Expr::Sin(Box::new(f_simp))
}
Expr::Cos(f) => {
let f_simp = simplify(f);
if matches!(f_simp, Expr::Constant(c) if c == 0.0) {
return Expr::Constant(1.0);
}
if let Expr::Constant(c) = f_simp {
return Expr::Constant(c.cos());
}
Expr::Cos(Box::new(f_simp))
}
Expr::Tan(f) => {
let f_simp = simplify(f);
if matches!(f_simp, Expr::Constant(c) if c == 0.0) {
return Expr::Constant(0.0);
}
if let Expr::Constant(c) = f_simp {
return Expr::Constant(c.tan());
}
Expr::Tan(Box::new(f_simp))
}
Expr::Exp(f) => {
let f_simp = simplify(f);
if matches!(f_simp, Expr::Constant(c) if c == 0.0) {
return Expr::Constant(1.0);
}
if let Expr::Constant(c) = f_simp {
return Expr::Constant(c.exp());
}
if let Expr::Ln(inner) = &f_simp {
return (**inner).clone();
}
Expr::Exp(Box::new(f_simp))
}
Expr::Ln(f) => {
let f_simp = simplify(f);
if matches!(f_simp, Expr::Constant(c) if c == 1.0) {
return Expr::Constant(0.0);
}
if let Expr::Constant(c) = f_simp {
if c > 0.0 {
return Expr::Constant(c.ln());
}
}
if let Expr::Exp(inner) = &f_simp {
return (**inner).clone();
}
Expr::Ln(Box::new(f_simp))
}
Expr::Sqrt(f) => {
let f_simp = simplify(f);
if matches!(f_simp, Expr::Constant(c) if c == 0.0) {
return Expr::Constant(0.0);
}
if matches!(f_simp, Expr::Constant(c) if c == 1.0) {
return Expr::Constant(1.0);
}
if let Expr::Constant(c) = f_simp {
if c >= 0.0 {
return Expr::Constant(c.sqrt());
}
}
Expr::Sqrt(Box::new(f_simp))
}
}
}
/// Expands products over sums and differences and unrolls small integer
/// powers, so that e.g. `(x + 1) * 2` becomes `x * 2 + 1 * 2`.
///
/// A power whose exponent is a literal whole number in `1..=5` is rewritten
/// as repeated multiplication of the expanded base and then distributed. Other
/// powers keep their shape, with both the base and the exponent expanded.
/// Division, negation and the unary functions are not distributed over; only
/// their arguments are expanded. No constant folding happens here — run
/// [`simplify`] on the result for that.
pub fn expand(expr: &Expr) -> Expr {
    match expr {
        Expr::Constant(_) | Expr::Variable(_) => expr.clone(),
        Expr::Add(f, g) => Expr::Add(Box::new(expand(f)), Box::new(expand(g))),
        Expr::Sub(f, g) => Expr::Sub(Box::new(expand(f)), Box::new(expand(g))),
        Expr::Mul(f, g) => distribute(expand(f), expand(g)),
        Expr::Pow(f, g) => {
            let base = expand(f);
            if let Expr::Constant(n) = **g {
                if n > 0.0 && n <= 5.0 && n.fract() == 0.0 {
                    // `n` is a whole number in 1..=5, so this cast is exact.
                    let n_int = n as i32;
                    let mut result = base.clone();
                    for _ in 1..n_int {
                        // Both operands are already expanded; only distribute.
                        result = distribute(result, base.clone());
                    }
                    return result;
                }
            }
            Expr::Pow(Box::new(base), Box::new(expand(g)))
        }
        Expr::Div(f, g) => Expr::Div(Box::new(expand(f)), Box::new(expand(g))),
        Expr::Neg(f) => Expr::Neg(Box::new(expand(f))),
        Expr::Sin(f) => Expr::Sin(Box::new(expand(f))),
        Expr::Cos(f) => Expr::Cos(Box::new(expand(f))),
        Expr::Tan(f) => Expr::Tan(Box::new(expand(f))),
        Expr::Exp(f) => Expr::Exp(Box::new(expand(f))),
        Expr::Ln(f) => Expr::Ln(Box::new(expand(f))),
        Expr::Sqrt(f) => Expr::Sqrt(Box::new(expand(f))),
    }
}

/// Multiplies two already-expanded operands, distributing over any `Add` or
/// `Sub` at the top of either side.
///
/// Precedence matches the original rules: a sum on the left is split first,
/// then a sum on the right, then a difference on the left, then a difference
/// on the right. The operands are not re-expanded, so repeated distribution
/// (such as unrolled powers) does not re-traverse finished subtrees.
fn distribute(f: Expr, g: Expr) -> Expr {
    match (f, g) {
        (Expr::Add(f1, f2), g) => Expr::Add(
            Box::new(distribute(*f1, g.clone())),
            Box::new(distribute(*f2, g)),
        ),
        (f, Expr::Add(g1, g2)) => Expr::Add(
            Box::new(distribute(f.clone(), *g1)),
            Box::new(distribute(f, *g2)),
        ),
        (Expr::Sub(f1, f2), g) => Expr::Sub(
            Box::new(distribute(*f1, g.clone())),
            Box::new(distribute(*f2, g)),
        ),
        (f, Expr::Sub(g1, g2)) => Expr::Sub(
            Box::new(distribute(f.clone(), *g1)),
            Box::new(distribute(f, *g2)),
        ),
        (f, g) => Expr::Mul(Box::new(f), Box::new(g)),
    }
}
pub fn factor(expr: &Expr) -> Expr {
simplify(expr)
}
/// Compares two expressions with `Expr`'s `PartialEq` implementation.
///
/// Neither side is simplified first, so expressions that are only
/// algebraically equal (such as `x + y` and `y + x`) compare equal only if
/// that `PartialEq` impl says so.
/// NOTE(review): the impl is assumed to be a derived, structural comparison.
pub fn structurally_equal(a: &Expr, b: &Expr) -> bool {
    a.eq(b)
}
pub fn collect_terms(expr: &Expr) -> Expr {
simplify(expr)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Simplifies `input` and checks that the result is exactly `expected`.
    fn assert_simplifies_to(input: Expr, expected: Expr) {
        assert_eq!(simplify(&input), expected);
    }

    #[test]
    fn test_simplify_addition() {
        let x = Expr::var("x");
        // Adding zero on either side is a no-op; constant sums fold.
        assert_simplifies_to(x.clone() + 0.0, x.clone());
        assert_simplifies_to(0.0 + x.clone(), x.clone());
        assert_simplifies_to(Expr::constant(2.0) + Expr::constant(3.0), Expr::constant(5.0));
    }

    #[test]
    fn test_simplify_multiplication() {
        let x = Expr::var("x");
        // Zero annihilates, one is the identity, constant products fold.
        assert_simplifies_to(x.clone() * 0.0, Expr::constant(0.0));
        assert_simplifies_to(x.clone() * 1.0, x.clone());
        assert_simplifies_to(Expr::constant(2.0) * Expr::constant(3.0), Expr::constant(6.0));
    }

    #[test]
    fn test_simplify_power() {
        let x = Expr::var("x");
        assert_simplifies_to(x.clone().pow(0.0), Expr::constant(1.0));
        assert_simplifies_to(x.clone().pow(1.0), x.clone());
        assert_simplifies_to(Expr::constant(2.0).pow(3.0), Expr::constant(8.0));
    }

    #[test]
    fn test_simplify_negation() {
        let x = Expr::var("x");
        // Double negation cancels; negated zero stays zero.
        assert_simplifies_to(-(-x.clone()), x.clone());
        assert_simplifies_to(-Expr::constant(0.0), Expr::constant(0.0));
    }

    #[test]
    fn test_simplify_exp_ln() {
        let x = Expr::var("x");
        // exp and ln are treated as mutual inverses.
        assert_simplifies_to(x.clone().ln().exp(), x.clone());
        assert_simplifies_to(x.clone().exp().ln(), x.clone());
    }

    #[test]
    fn test_expand_distributive() {
        let x = Expr::var("x");
        let original = (x.clone() + 1.0) * 2.0;
        let rewritten = simplify(&expand(&original));
        let mut bindings = HashMap::new();
        bindings.insert(String::from("x"), 3.0);
        // Expanding then simplifying must not change the value at a sample point.
        let before = original.eval(&bindings).expect("eval failed");
        let after = rewritten.eval(&bindings).expect("eval failed");
        assert_eq!(before, after);
    }

    #[test]
    fn test_simplify_division() {
        let x = Expr::var("x");
        // Dividing by one is a no-op, zero over x is zero, constant quotients fold.
        assert_simplifies_to(x.clone() / 1.0, x.clone());
        assert_simplifies_to(0.0 / x.clone(), Expr::constant(0.0));
        assert_simplifies_to(Expr::constant(6.0) / Expr::constant(2.0), Expr::constant(3.0));
    }
}