use super::SymExpr;
use crate::Float;
/// Recursively simplifies a symbolic expression, bottom-up.
///
/// Both operands of every binary node are simplified first. The node is then
/// rewritten using the first matching rule below (rules are tried in order):
///
/// * `Add`: `0 + x → x`, `x + 0 → x`, `c1 + c2 → (c1 + c2)`.
/// * `Mul`: `0 * x → 0`, `x * 0 → 0`, `1 * x → x`, `x * 1 → x`,
///   `c1 * c2 → (c1 * c2)`.
/// * `Sub`: `x - 0 → x`, `x - x → 0` (structural equality of the simplified
///   operands), `c1 - c2 → (c1 - c2)`.
/// * `Div`: `0 / x → 0` unless `x` simplified to the constant zero,
///   `x / 1 → x`, `c1 / c2 → (c1 / c2)` when `c2` is not near zero.
///   Division by a constant zero is never folded, so `0 / 0` is kept as a
///   `Div` node rather than being reported as `0`.
/// * `Pow`: `x ^ 0 → 1` (this includes `0 ^ 0`), `x ^ 1 → x`,
///   `0 ^ c → 0` for a constant `c > 0`, `c1 ^ c2 → c1.powf(c2)`.
/// * `Exp`, `Log`, `Sin`, `Cos`, `Tanh`: only the argument is simplified.
///
/// "Zero" and "one" use an absolute tolerance of `T::epsilon()`.
/// A constant with `|c| < epsilon` is treated as zero, and a constant with
/// `|c - 1| < epsilon` is treated as one.
///
/// The input is only borrowed; a newly built tree is returned.
pub fn simplify_expr<T: Float>(expr: &SymExpr<T>) -> SymExpr<T> {
    // True when `c` lies strictly within `epsilon` of zero.
    fn near_zero<U: Float>(c: U) -> bool {
        c.abs() < U::epsilon()
    }

    // True when `c` lies strictly within `epsilon` of one.
    fn near_one<U: Float>(c: U) -> bool {
        (c - U::one()).abs() < U::epsilon()
    }

    match expr {
        SymExpr::Const(c) => SymExpr::Const(*c),
        SymExpr::Var(name) => SymExpr::Var(name.clone()),
        SymExpr::Add(a, b) => {
            let a_simp = simplify_expr(a);
            let b_simp = simplify_expr(b);
            match (&a_simp, &b_simp) {
                (SymExpr::Const(c), _) if near_zero(*c) => b_simp,
                (_, SymExpr::Const(c)) if near_zero(*c) => a_simp,
                (SymExpr::Const(c1), SymExpr::Const(c2)) => SymExpr::Const(*c1 + *c2),
                _ => SymExpr::Add(Box::new(a_simp), Box::new(b_simp)),
            }
        }
        SymExpr::Mul(a, b) => {
            let a_simp = simplify_expr(a);
            let b_simp = simplify_expr(b);
            match (&a_simp, &b_simp) {
                (SymExpr::Const(c), _) if near_zero(*c) => SymExpr::Const(T::zero()),
                (_, SymExpr::Const(c)) if near_zero(*c) => SymExpr::Const(T::zero()),
                (SymExpr::Const(c), _) if near_one(*c) => b_simp,
                (_, SymExpr::Const(c)) if near_one(*c) => a_simp,
                (SymExpr::Const(c1), SymExpr::Const(c2)) => SymExpr::Const(*c1 * *c2),
                _ => SymExpr::Mul(Box::new(a_simp), Box::new(b_simp)),
            }
        }
        SymExpr::Sub(a, b) => {
            let a_simp = simplify_expr(a);
            let b_simp = simplify_expr(b);
            match (&a_simp, &b_simp) {
                (_, SymExpr::Const(c)) if near_zero(*c) => a_simp,
                _ if a_simp == b_simp => SymExpr::Const(T::zero()),
                (SymExpr::Const(c1), SymExpr::Const(c2)) => SymExpr::Const(*c1 - *c2),
                _ => SymExpr::Sub(Box::new(a_simp), Box::new(b_simp)),
            }
        }
        SymExpr::Div(a, b) => {
            let a_simp = simplify_expr(a);
            let b_simp = simplify_expr(b);
            // Division by a constant zero is never folded (see the c1 / c2 arm
            // below). Without this check, `0 / 0` would be rewritten to `0`
            // instead of being left for evaluation, where it yields NaN.
            let den_is_zero = matches!(&b_simp, SymExpr::Const(d) if near_zero(*d));
            match (&a_simp, &b_simp) {
                (SymExpr::Const(c), _) if near_zero(*c) && !den_is_zero => {
                    SymExpr::Const(T::zero())
                }
                (_, SymExpr::Const(c)) if near_one(*c) => a_simp,
                (SymExpr::Const(c1), SymExpr::Const(c2)) if c2.abs() > T::epsilon() => {
                    SymExpr::Const(*c1 / *c2)
                }
                _ => SymExpr::Div(Box::new(a_simp), Box::new(b_simp)),
            }
        }
        SymExpr::Pow(a, b) => {
            let a_simp = simplify_expr(a);
            let b_simp = simplify_expr(b);
            match (&a_simp, &b_simp) {
                (_, SymExpr::Const(c)) if near_zero(*c) => SymExpr::Const(T::one()),
                (_, SymExpr::Const(c)) if near_one(*c) => a_simp,
                (SymExpr::Const(c1), SymExpr::Const(c2)) if near_zero(*c1) && *c2 > T::zero() => {
                    SymExpr::Const(T::zero())
                }
                (SymExpr::Const(c1), SymExpr::Const(c2)) => SymExpr::Const(c1.powf(*c2)),
                _ => SymExpr::Pow(Box::new(a_simp), Box::new(b_simp)),
            }
        }
        SymExpr::Exp(a) => SymExpr::Exp(Box::new(simplify_expr(a))),
        SymExpr::Log(a) => SymExpr::Log(Box::new(simplify_expr(a))),
        SymExpr::Sin(a) => SymExpr::Sin(Box::new(simplify_expr(a))),
        SymExpr::Cos(a) => SymExpr::Cos(Box::new(simplify_expr(a))),
        SymExpr::Tanh(a) => SymExpr::Tanh(Box::new(simplify_expr(a))),
    }
}