use super::ast::Expr;
/// Returns the simplified symbolic derivative of `expr` with respect to `var`.
///
/// The derivative is first built mechanically (via `diff_raw`) from the
/// standard differentiation rules. It is then passed through [`simplify`] to
/// strip the trivial `0`/`1` terms those rules introduce.
pub fn diff(expr: &Expr, var: &str) -> Expr {
    simplify(diff_raw(expr, var))
}
/// Differentiates `expr` with respect to `var` without simplifying the result.
///
/// The output follows the textbook rules literally:
/// - constants differentiate to `0`;
/// - a variable differentiates to `1` if it is `var`, otherwise `0`;
/// - sums and differences are differentiated term by term;
/// - products use the product rule;
/// - quotients use the quotient rule, except that a literal denominator is
///   treated as a constant divisor;
/// - integer powers use the power rule combined with the chain rule.
///
/// The resulting tree therefore contains many `0`/`1` literals, which the
/// caller is expected to clean up afterwards.
fn diff_raw(expr: &Expr, var: &str) -> Expr {
    match expr {
        Expr::Lit(_) => Expr::Lit(0.0),
        Expr::Var(name) => Expr::Lit(if name == var { 1.0 } else { 0.0 }),
        Expr::Add(lhs, rhs) => Expr::add(diff_raw(lhs, var), diff_raw(rhs, var)),
        Expr::Sub(lhs, rhs) => Expr::sub(diff_raw(lhs, var), diff_raw(rhs, var)),
        // Product rule: (u * v)' = u' * v + u * v'
        Expr::Mul(u, v) => Expr::add(
            Expr::mul(diff_raw(u, var), Expr::clone(v)),
            Expr::mul(Expr::clone(u), diff_raw(v, var)),
        ),
        Expr::Div(num, den) => match den.as_ref() {
            // A literal denominator is a constant divisor: (u / c)' = u' / c
            Expr::Lit(c) => Expr::div(diff_raw(num, var), Expr::Lit(*c)),
            // Quotient rule: (u / v)' = (u' * v - u * v') / v^2
            other => Expr::div(
                Expr::sub(
                    Expr::mul(diff_raw(num, var), other.clone()),
                    Expr::mul(Expr::clone(num), diff_raw(other, var)),
                ),
                Expr::pow(other.clone(), 2),
            ),
        },
        // b^0 is the constant 1, so its derivative is 0. Handling it here also
        // avoids building an exponent of `n - 1`.
        Expr::Pow(_, n) if *n == 0 => Expr::Lit(0.0),
        // Power rule with chain rule: (b^n)' = n * b^(n-1) * b'
        Expr::Pow(base, n) => {
            let coefficient = Expr::Lit(*n as f64);
            let lowered = Expr::pow(Expr::clone(base), n - 1);
            Expr::mul(Expr::mul(coefficient, lowered), diff_raw(base, var))
        }
    }
}
/// Applies local algebraic simplifications bottom-up and returns the result.
///
/// Children are simplified first. Each node is then rewritten once using the
/// following rules.
///
/// - Constant folding of two literal operands (`Lit op Lit`).
/// - Additive identities: `0 + x = x`, `x + 0 = x`, `x - 0 = x`, `0 - x = -1 * x`.
/// - Multiplicative identities: `0 * x = 0`, `x * 0 = 0`, `1 * x = x`, `x * 1 = x`.
/// - Division identities: `0 / x = 0` and `x / 1 = x`. A division by a literal
///   zero (including `0 / 0`) is never folded; it is left for evaluation to handle.
/// - Power identities: `x^1 = x`, `x^0 = 1`.
///
/// The rules are applied in a single pass, so the result is not guaranteed to
/// be in a canonical form.
pub fn simplify(expr: Expr) -> Expr {
    match expr {
        Expr::Lit(_) | Expr::Var(_) => expr,
        Expr::Add(a, b) => {
            let a = simplify(*a);
            let b = simplify(*b);
            match (&a, &b) {
                (Expr::Lit(v), _) if *v == 0.0 => b,
                (_, Expr::Lit(v)) if *v == 0.0 => a,
                (Expr::Lit(va), Expr::Lit(vb)) => Expr::Lit(va + vb),
                _ => Expr::add(a, b),
            }
        }
        Expr::Sub(a, b) => {
            let a = simplify(*a);
            let b = simplify(*b);
            match (&a, &b) {
                (_, Expr::Lit(v)) if *v == 0.0 => a,
                // Fold two literals before the `0 - x` rule, so that `0 - c`
                // becomes `Lit(-c)` instead of the unfolded `-1 * c`.
                (Expr::Lit(va), Expr::Lit(vb)) => Expr::Lit(va - vb),
                (Expr::Lit(v), _) if *v == 0.0 => Expr::mul(Expr::Lit(-1.0), b),
                _ => Expr::sub(a, b),
            }
        }
        Expr::Mul(a, b) => {
            let a = simplify(*a);
            let b = simplify(*b);
            match (&a, &b) {
                (Expr::Lit(v), _) if *v == 0.0 => Expr::Lit(0.0),
                (_, Expr::Lit(v)) if *v == 0.0 => Expr::Lit(0.0),
                (Expr::Lit(v), _) if *v == 1.0 => b,
                (_, Expr::Lit(v)) if *v == 1.0 => a,
                (Expr::Lit(va), Expr::Lit(vb)) => Expr::Lit(va * vb),
                _ => Expr::mul(a, b),
            }
        }
        Expr::Div(a, b) => {
            let a = simplify(*a);
            let b = simplify(*b);
            match (&a, &b) {
                // Never fold a division by a literal zero. Checking this first
                // also stops `0 / 0` from being rewritten to `0` by the next arm.
                (_, Expr::Lit(vb)) if *vb == 0.0 => Expr::div(a, b),
                (Expr::Lit(v), _) if *v == 0.0 => Expr::Lit(0.0),
                (_, Expr::Lit(v)) if *v == 1.0 => a,
                (Expr::Lit(va), Expr::Lit(vb)) => Expr::Lit(va / vb),
                _ => Expr::div(a, b),
            }
        }
        Expr::Pow(base, n) => {
            let base = simplify(*base);
            match n {
                1 => base,
                0 => Expr::Lit(1.0),
                _ => Expr::pow(base, n),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::eval::eval;
    use crate::expr::parser::parse;

    /// Position under constant acceleration; shared by the `*_in_trend` tests.
    const TREND: &str = "pos + vel*dt + 0.5*acc*dt^2";

    /// Parses `source`, differentiates it with respect to `var`, and evaluates
    /// the resulting derivative under `bindings`.
    fn diff_and_eval(source: &str, var: &str, bindings: &[(&str, f64)]) -> f64 {
        let derivative = diff(&parse(source).unwrap(), var);
        eval(&derivative, bindings).unwrap()
    }

    /// Asserts that `actual` lies strictly within `1e-15` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 1e-15);
    }

    #[test]
    fn test_diff_pos_in_trend() {
        let cases = [
            ([("pos", 0.0), ("vel", 0.0), ("acc", 0.0), ("dt", 1.0)], 1.0),
            ([("pos", 5.0), ("vel", 3.0), ("acc", 1.0), ("dt", 0.5)], 1.0),
        ];
        for (bindings, expected) in &cases {
            assert_close(diff_and_eval(TREND, "pos", bindings), *expected);
        }
    }

    #[test]
    fn test_diff_vel_in_trend() {
        // d/dvel of the trend is `dt`.
        for &(dt, expected) in &[(1.0, 1.0), (0.5, 0.5)] {
            let bindings = [("pos", 0.0), ("vel", 0.0), ("acc", 0.0), ("dt", dt)];
            assert_close(diff_and_eval(TREND, "vel", &bindings), expected);
        }
    }

    #[test]
    fn test_diff_acc_in_trend() {
        // d/dacc of the trend is `0.5 * dt^2`.
        for &(dt, expected) in &[(1.0, 0.5), (0.5, 0.125)] {
            let bindings = [("pos", 0.0), ("vel", 0.0), ("acc", 0.0), ("dt", dt)];
            assert_close(diff_and_eval(TREND, "acc", &bindings), expected);
        }
    }

    #[test]
    fn test_diff_pos_simple() {
        let derivative = diff(&parse("pos").unwrap(), "pos");
        assert!(matches!(derivative, Expr::Lit(v) if v == 1.0));
    }

    #[test]
    fn test_diff_vel_acc_expr() {
        // Each case is (variable to differentiate by, value of dt, expected derivative).
        let cases = [("vel", 1.0, 1.0), ("acc", 1.0, 1.0), ("acc", 0.5, 0.5)];
        for &(var, dt, expected) in &cases {
            let bindings = [("vel", 0.0), ("acc", 0.0), ("dt", dt)];
            assert_close(diff_and_eval("vel + acc*dt", var, &bindings), expected);
        }
    }

    #[test]
    fn test_simplify_zero_plus_x() {
        assert_eq!(
            simplify(Expr::add(Expr::Lit(0.0), Expr::var("x"))),
            Expr::var("x")
        );
    }

    #[test]
    fn test_simplify_one_times_x() {
        assert_eq!(
            simplify(Expr::mul(Expr::Lit(1.0), Expr::var("x"))),
            Expr::var("x")
        );
    }

    #[test]
    fn test_simplify_zero_times_x() {
        assert_eq!(
            simplify(Expr::mul(Expr::Lit(0.0), Expr::var("x"))),
            Expr::Lit(0.0)
        );
    }

    #[test]
    fn test_simplify_pow_one() {
        assert_eq!(simplify(Expr::pow(Expr::var("x"), 1)), Expr::var("x"));
    }

    #[test]
    fn test_simplify_pow_zero() {
        assert_eq!(simplify(Expr::pow(Expr::var("x"), 0)), Expr::Lit(1.0));
    }
}