use crate::expr::{Expr, MathFn, cos, exp, ln, one, sin, zero};
use crate::simplify::simplify;
#[must_use]
pub fn diff(expr: &Expr, var: &str) -> Expr {
let raw = diff_raw(expr, var);
simplify(&raw)
}
#[must_use]
pub fn diff_n(expr: &Expr, var: &str, n: usize) -> Expr {
let mut result = expr.clone();
for _ in 0..n {
result = diff(&result, var);
}
result
}
fn diff_raw(expr: &Expr, var: &str) -> Expr {
match expr {
Expr::Const(_) => zero(),
Expr::Var(name) => {
if name == var {
one()
} else {
zero()
}
}
Expr::Add(a, b) => {
let da = diff_raw(a, var);
let db = diff_raw(b, var);
Expr::Add(Box::new(da), Box::new(db))
}
Expr::Mul(a, b) => {
let da = diff_raw(a, var);
let db = diff_raw(b, var);
Expr::Add(
Box::new(Expr::Mul(Box::new(da), b.clone())),
Box::new(Expr::Mul(a.clone(), Box::new(db))),
)
}
Expr::Pow(base, exp) => {
let base_has_var = base.free_variables_contains(var);
let exp_has_var = exp.free_variables_contains(var);
if !base_has_var && !exp_has_var {
zero()
} else if base_has_var && !exp_has_var {
let df = diff_raw(base, var);
let c_minus_1 = Expr::Add(Box::new(*exp.clone()), Box::new(Expr::Const(-1.0)));
Expr::Mul(
Box::new(Expr::Mul(
exp.clone(),
Box::new(Expr::Pow(base.clone(), Box::new(c_minus_1))),
)),
Box::new(df),
)
} else if !base_has_var && exp_has_var {
let dg = diff_raw(exp, var);
Expr::Mul(
Box::new(Expr::Mul(
Box::new(expr.clone()),
Box::new(ln(*base.clone())),
)),
Box::new(dg),
)
} else {
let df = diff_raw(base, var);
let dg = diff_raw(exp, var);
let ln_f = ln(*base.clone());
let f_prime_over_f = Expr::Mul(
Box::new(df),
Box::new(Expr::Pow(base.clone(), Box::new(Expr::Const(-1.0)))),
);
let inner = Expr::Add(
Box::new(Expr::Mul(Box::new(dg), Box::new(ln_f))),
Box::new(Expr::Mul(exp.clone(), Box::new(f_prime_over_f))),
);
Expr::Mul(Box::new(expr.clone()), Box::new(inner))
}
}
Expr::Neg(inner) => Expr::Neg(Box::new(diff_raw(inner, var))),
Expr::Fn(func, arg) => {
let darg = diff_raw(arg, var);
let outer_deriv = match func {
MathFn::Sin => cos(*arg.clone()),
MathFn::Cos => Expr::Neg(Box::new(sin(*arg.clone()))),
MathFn::Tan => {
Expr::Pow(Box::new(cos(*arg.clone())), Box::new(Expr::Const(-2.0)))
}
MathFn::Exp => exp(*arg.clone()),
MathFn::Ln => {
Expr::Pow(arg.clone(), Box::new(Expr::Const(-1.0)))
}
MathFn::Sqrt => {
Expr::Mul(
Box::new(Expr::Const(0.5)),
Box::new(Expr::Pow(
Box::new(Expr::Fn(MathFn::Sqrt, arg.clone())),
Box::new(Expr::Const(-1.0)),
)),
)
}
MathFn::Abs => {
Expr::Mul(
arg.clone(),
Box::new(Expr::Pow(
Box::new(Expr::Fn(MathFn::Abs, arg.clone())),
Box::new(Expr::Const(-1.0)),
)),
)
}
};
Expr::Mul(Box::new(outer_deriv), Box::new(darg))
}
}
}
trait ContainsVar {
fn free_variables_contains(&self, var: &str) -> bool;
}
impl ContainsVar for Expr {
fn free_variables_contains(&self, var: &str) -> bool {
match self {
Expr::Const(_) => false,
Expr::Var(name) => name == var,
Expr::Add(a, b) | Expr::Mul(a, b) | Expr::Pow(a, b) => {
a.free_variables_contains(var) || b.free_variables_contains(var)
}
Expr::Neg(inner) | Expr::Fn(_, inner) => inner.free_variables_contains(var),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::expr::{constant, var};
use std::collections::HashMap;
fn eval(e: &Expr, x_val: f64) -> f64 {
let mut vars = HashMap::new();
vars.insert("x".into(), x_val);
e.eval(&vars).unwrap()
}
#[test]
fn diff_constant() {
let e = constant(5.0);
assert_eq!(diff(&e, "x"), zero());
}
#[test]
fn diff_variable() {
assert_eq!(diff(&var("x"), "x"), one());
assert_eq!(diff(&var("y"), "x"), zero());
}
#[test]
fn diff_power_rule() {
let e = Expr::Pow(Box::new(var("x")), Box::new(constant(3.0)));
let d = diff(&e, "x");
assert!((eval(&d, 2.0) - 12.0).abs() < 1e-10);
}
#[test]
fn diff_product_rule() {
let e = var("x") * var("x");
let d = diff(&e, "x");
assert!((eval(&d, 3.0) - 6.0).abs() < 1e-10);
}
#[test]
fn diff_sin() {
let e = sin(var("x"));
let d = diff(&e, "x");
assert!((eval(&d, 0.0) - 1.0).abs() < 1e-10);
}
#[test]
fn diff_cos() {
let e = cos(var("x"));
let d = diff(&e, "x");
assert!((eval(&d, std::f64::consts::FRAC_PI_2) - (-1.0)).abs() < 1e-10);
}
#[test]
fn diff_exp_func() {
let e = exp(var("x"));
let d = diff(&e, "x");
assert!((eval(&d, 1.0) - std::f64::consts::E).abs() < 1e-10);
}
#[test]
fn diff_ln_func() {
let e = ln(var("x"));
let d = diff(&e, "x");
assert!((eval(&d, 2.0) - 0.5).abs() < 1e-10);
}
#[test]
fn diff_chain_rule() {
let x_sq = Expr::Pow(Box::new(var("x")), Box::new(constant(2.0)));
let e = sin(x_sq);
let d = diff(&e, "x");
let expected = 1.0_f64.cos() * 2.0;
assert!((eval(&d, 1.0) - expected).abs() < 1e-10);
}
#[test]
fn diff_n_second_derivative() {
let e = Expr::Pow(Box::new(var("x")), Box::new(constant(3.0)));
let d2 = diff_n(&e, "x", 2);
assert!((eval(&d2, 2.0) - 12.0).abs() < 1e-10);
}
}