use crate::expr::Expr;
/// Returns a simplified copy of `expr`.
///
/// The tree is processed bottom-up: every operand is simplified first, and the
/// per-operator helper (`simplify_add`, `simplify_mul`, `simplify_pow`,
/// `simplify_neg`) is then applied to the already-simplified operands.
/// Constants and variables are returned as clones. A function application
/// whose simplified argument exposes a numeric value through `as_const` is
/// evaluated numerically and replaced by a constant; otherwise it is rebuilt
/// around the simplified argument.
#[must_use]
pub fn simplify(expr: &Expr) -> Expr {
    match expr {
        // Leaves have nothing to simplify.
        Expr::Const(_) | Expr::Var(_) => expr.clone(),
        // Operands are simplified left-to-right before the operator rule runs.
        Expr::Add(lhs, rhs) => simplify_add(simplify(lhs), simplify(rhs)),
        Expr::Mul(lhs, rhs) => simplify_mul(simplify(lhs), simplify(rhs)),
        Expr::Pow(base, exp) => simplify_pow(simplify(base), simplify(exp)),
        Expr::Neg(inner) => simplify_neg(simplify(inner)),
        Expr::Fn(func, arg) => {
            let arg = simplify(arg);
            // Fold the call when the argument is a known number.
            let folded = arg.as_const().map(|v| match func {
                crate::expr::MathFn::Sin => v.sin(),
                crate::expr::MathFn::Cos => v.cos(),
                crate::expr::MathFn::Tan => v.tan(),
                crate::expr::MathFn::Exp => v.exp(),
                crate::expr::MathFn::Ln => v.ln(),
                crate::expr::MathFn::Sqrt => v.sqrt(),
                crate::expr::MathFn::Abs => v.abs(),
            });
            match folded {
                Some(value) => Expr::Const(value),
                None => Expr::Fn(*func, Box::new(arg)),
            }
        }
    }
}
/// Applies the addition rules to two already-simplified operands.
///
/// Rules are tried in this order:
/// 1. both operands yield a value from `as_const` -> a constant holding their sum;
/// 2. `b.is_zero()` -> `a`;
/// 3. `a.is_zero()` -> `b`;
/// 4. otherwise an `Expr::Add` node is rebuilt from the operands.
fn simplify_add(a: Expr, b: Expr) -> Expr {
    match (a.as_const(), b.as_const()) {
        (Some(lhs), Some(rhs)) => Expr::Const(lhs + rhs),
        // Additive identity on either side drops out.
        _ if b.is_zero() => a,
        _ if a.is_zero() => b,
        _ => Expr::Add(Box::new(a), Box::new(b)),
    }
}
/// Applies the multiplication rules to two already-simplified operands.
///
/// Rules are tried in this order:
/// 1. both operands yield a value from `as_const` -> a constant holding their product;
/// 2. either operand reports `is_zero()` -> `Expr::Const(0.0)`;
/// 3. `b.is_one()` -> `a`;
/// 4. `a.is_one()` -> `b`;
/// 5. otherwise an `Expr::Mul` node is rebuilt from the operands.
fn simplify_mul(a: Expr, b: Expr) -> Expr {
    if let (Some(lhs), Some(rhs)) = (a.as_const(), b.as_const()) {
        Expr::Const(lhs * rhs)
    } else if a.is_zero() || b.is_zero() {
        // The other factor is discarded without being inspected.
        Expr::Const(0.0)
    } else if b.is_one() {
        a
    } else if a.is_one() {
        b
    } else {
        Expr::Mul(Box::new(a), Box::new(b))
    }
}
/// Applies the exponentiation rules to an already-simplified base and exponent.
///
/// Rules are tried in this order:
/// 1. both yield a value from `as_const` -> a constant holding `base.powf(exp)`;
/// 2. `exp.is_zero()` -> `Expr::Const(1.0)`;
/// 3. `exp.is_one()` -> `base`;
/// 4. otherwise an `Expr::Pow` node is rebuilt from the operands.
fn simplify_pow(base: Expr, exp: Expr) -> Expr {
    match (base.as_const(), exp.as_const()) {
        (Some(b), Some(e)) => Expr::Const(b.powf(e)),
        // The base is not inspected for the zero-exponent rule.
        _ if exp.is_zero() => Expr::Const(1.0),
        _ if exp.is_one() => base,
        _ => Expr::Pow(Box::new(base), Box::new(exp)),
    }
}
/// Applies the negation rules to an already-simplified operand.
///
/// A negation of a negation unwraps to the innermost expression; an operand
/// that yields a value from `as_const` becomes a constant holding the negated
/// value; anything else is wrapped in a new `Expr::Neg` node.
fn simplify_neg(inner: Expr) -> Expr {
    match inner {
        // -(-x) == x
        Expr::Neg(unwrapped) => *unwrapped,
        other => match other.as_const() {
            Some(v) => Expr::Const(-v),
            None => Expr::Neg(Box::new(other)),
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::{constant, one, var, zero};

    /// `x + zero()` simplifies to `x`.
    #[test]
    fn x_plus_zero() {
        let sum = Expr::Add(Box::new(var("x")), Box::new(zero()));
        assert_eq!(simplify(&sum), var("x"));
    }

    /// `zero() + x` simplifies to `x`.
    #[test]
    fn zero_plus_x() {
        let sum = Expr::Add(Box::new(zero()), Box::new(var("x")));
        assert_eq!(simplify(&sum), var("x"));
    }

    /// `x * one()` simplifies to `x`.
    #[test]
    fn x_times_one() {
        let product = Expr::Mul(Box::new(var("x")), Box::new(one()));
        assert_eq!(simplify(&product), var("x"));
    }

    /// `one() * x` simplifies to `x`.
    #[test]
    fn one_times_x() {
        let product = Expr::Mul(Box::new(one()), Box::new(var("x")));
        assert_eq!(simplify(&product), var("x"));
    }

    /// `x * zero()` simplifies to `zero()`.
    #[test]
    fn x_times_zero() {
        let product = Expr::Mul(Box::new(var("x")), Box::new(zero()));
        assert_eq!(simplify(&product), zero());
    }

    /// `x ^ zero()` simplifies to `one()`.
    #[test]
    fn x_pow_zero() {
        let power = Expr::Pow(Box::new(var("x")), Box::new(zero()));
        assert_eq!(simplify(&power), one());
    }

    /// `x ^ one()` simplifies to `x`.
    #[test]
    fn x_pow_one() {
        let power = Expr::Pow(Box::new(var("x")), Box::new(one()));
        assert_eq!(simplify(&power), var("x"));
    }

    /// A doubly negated variable simplifies to the bare variable.
    #[test]
    fn neg_neg_x() {
        let double_neg = Expr::Neg(Box::new(Expr::Neg(Box::new(var("x")))));
        assert_eq!(simplify(&double_neg), var("x"));
    }

    /// Negating `constant(5.0)` yields `constant(-5.0)`.
    #[test]
    fn neg_const() {
        let negated = Expr::Neg(Box::new(constant(5.0)));
        assert_eq!(simplify(&negated), constant(-5.0));
    }

    /// Sums and products of two constants are folded into one constant.
    #[test]
    fn constant_folding() {
        let sum = constant(2.0) + constant(3.0);
        assert_eq!(simplify(&sum), constant(5.0));

        let product = constant(4.0) * constant(5.0);
        assert_eq!(simplify(&product), constant(20.0));
    }
}