use crate::error::{NumRs2Error, Result};
use crate::symbolic::expr::Expr;
/// Symbolically differentiates `expr` with respect to the variable named `var`.
///
/// Every node is rewritten by its textbook rule: sum/difference, product,
/// quotient, power/exponential, and the chain rule for the elementary
/// functions. The resulting tree is returned unsimplified, so it can contain
/// terms such as `0 * x` or `x * 1`.
///
/// # Errors
///
/// Only propagates errors from the recursive calls; no rule in this function
/// constructs an error of its own.
pub fn differentiate(expr: &Expr, var: &str) -> Result<Expr> {
    // Shorthand for boxing a freshly built node, to keep the trees below legible.
    fn bx(node: Expr) -> Box<Expr> {
        Box::new(node)
    }

    let derivative = match expr {
        // d/dx c = 0
        Expr::Constant(_) => Expr::Constant(0.0),
        // d/dx x = 1, and any other variable is treated as a constant.
        Expr::Variable(name) => {
            let slope = if name == var { 1.0 } else { 0.0 };
            Expr::Constant(slope)
        }
        // (f + g)' = f' + g'
        Expr::Add(lhs, rhs) => {
            let d_lhs = differentiate(lhs, var)?;
            let d_rhs = differentiate(rhs, var)?;
            Expr::Add(bx(d_lhs), bx(d_rhs))
        }
        // (f - g)' = f' - g'
        Expr::Sub(lhs, rhs) => {
            let d_lhs = differentiate(lhs, var)?;
            let d_rhs = differentiate(rhs, var)?;
            Expr::Sub(bx(d_lhs), bx(d_rhs))
        }
        // (f * g)' = f' * g + f * g'
        Expr::Mul(lhs, rhs) => {
            let d_lhs = differentiate(lhs, var)?;
            let d_rhs = differentiate(rhs, var)?;
            let left = Expr::Mul(bx(d_lhs), rhs.clone());
            let right = Expr::Mul(lhs.clone(), bx(d_rhs));
            Expr::Add(bx(left), bx(right))
        }
        // (f / g)' = (f' * g - f * g') / (g * g)
        Expr::Div(num, den) => {
            let d_num = differentiate(num, var)?;
            let d_den = differentiate(den, var)?;
            let top = Expr::Sub(
                bx(Expr::Mul(bx(d_num), den.clone())),
                bx(Expr::Mul(num.clone(), bx(d_den))),
            );
            let bottom = Expr::Mul(den.clone(), den.clone());
            Expr::Div(bx(top), bx(bottom))
        }
        Expr::Pow(base, exponent) => {
            // Power-rule term: g * f^(g - 1) * f'
            let power_term = |d_base: Expr| {
                let lowered = Expr::Sub(exponent.clone(), bx(Expr::Constant(1.0)));
                let scaled = Expr::Mul(
                    exponent.clone(),
                    bx(Expr::Pow(base.clone(), bx(lowered))),
                );
                Expr::Mul(bx(scaled), bx(d_base))
            };
            // Exponential-rule term: f^g * ln(f) * g'
            let exponential_term = |d_exponent: Expr| {
                let scaled = Expr::Mul(
                    bx(Expr::Pow(base.clone(), exponent.clone())),
                    bx(Expr::Ln(base.clone())),
                );
                Expr::Mul(bx(scaled), bx(d_exponent))
            };

            if !exponent.contains_var(var) {
                // Exponent is constant in `var`: plain power rule.
                power_term(differentiate(base, var)?)
            } else if !base.contains_var(var) {
                // Base is constant in `var`: plain exponential rule.
                exponential_term(differentiate(exponent, var)?)
            } else {
                // Both depend on `var`: d/dx f^g is the sum of the two terms.
                let d_base = differentiate(base, var)?;
                let d_exponent = differentiate(exponent, var)?;
                Expr::Add(bx(power_term(d_base)), bx(exponential_term(d_exponent)))
            }
        }
        // (-f)' = -f'
        Expr::Neg(inner) => Expr::Neg(bx(differentiate(inner, var)?)),
        // sin(f)' = cos(f) * f'
        Expr::Sin(inner) => {
            let d_inner = differentiate(inner, var)?;
            Expr::Mul(bx(Expr::Cos(inner.clone())), bx(d_inner))
        }
        // cos(f)' = -sin(f) * f'
        Expr::Cos(inner) => {
            let d_inner = differentiate(inner, var)?;
            let neg_sin = Expr::Neg(bx(Expr::Sin(inner.clone())));
            Expr::Mul(bx(neg_sin), bx(d_inner))
        }
        // tan(f)' = 1 / (cos(f) * cos(f)) * f'
        Expr::Tan(inner) => {
            let d_inner = differentiate(inner, var)?;
            let cos_inner = Expr::Cos(inner.clone());
            let cos_squared = Expr::Mul(bx(cos_inner.clone()), bx(cos_inner));
            let sec_squared = Expr::Div(bx(Expr::Constant(1.0)), bx(cos_squared));
            Expr::Mul(bx(sec_squared), bx(d_inner))
        }
        // exp(f)' = exp(f) * f'
        Expr::Exp(inner) => {
            let d_inner = differentiate(inner, var)?;
            Expr::Mul(bx(Expr::Exp(inner.clone())), bx(d_inner))
        }
        // ln(f)' = f' / f
        Expr::Ln(inner) => {
            let d_inner = differentiate(inner, var)?;
            Expr::Div(bx(d_inner), inner.clone())
        }
        // sqrt(f)' = f' / (2 * sqrt(f))
        Expr::Sqrt(inner) => {
            let d_inner = differentiate(inner, var)?;
            let twice_root = Expr::Mul(bx(Expr::Constant(2.0)), bx(Expr::Sqrt(inner.clone())));
            Expr::Div(bx(d_inner), bx(twice_root))
        }
    };

    Ok(derivative)
}
/// Computes the symbolic gradient of `expr`.
///
/// The result holds one partial derivative per name in `vars`, in the same
/// order as `vars`.
///
/// # Errors
///
/// Returns the first error produced by [`differentiate`]; the remaining
/// variables are not differentiated once an error occurs.
pub fn gradient(expr: &Expr, vars: &[&str]) -> Result<Vec<Expr>> {
    vars.iter()
        .map(|&name| differentiate(expr, name))
        .collect()
}
/// Computes the symbolic Jacobian of the vector-valued function `exprs`.
///
/// Row `i` is the gradient of `exprs[i]` over `vars`, so entry `[i][j]` is
/// `∂exprs[i] / ∂vars[j]`.
///
/// # Errors
///
/// Returns the first error produced by [`gradient`].
pub fn jacobian(exprs: &[Expr], vars: &[&str]) -> Result<Vec<Vec<Expr>>> {
    exprs
        .iter()
        .map(|component| gradient(component, vars))
        .collect()
}
pub fn hessian(expr: &Expr, vars: &[&str]) -> Result<Vec<Vec<Expr>>> {
let mut hess = Vec::with_capacity(vars.len());
for &var1 in vars {
let mut row = Vec::with_capacity(vars.len());
for &var2 in vars {
let first_deriv = differentiate(expr, var1)?;
let second_deriv = differentiate(&first_deriv, var2)?;
row.push(second_deriv);
}
hess.push(row);
}
Ok(hess)
}
/// Computes the symbolic directional derivative `∇expr · direction`.
///
/// `direction[i]` is paired with the partial derivative with respect to
/// `vars[i]`. The products are summed left to right, i.e.
/// `((g0 * d0 + g1 * d1) + g2 * d2) + ...`. The direction is used as given;
/// it is not normalized.
///
/// Returns `Expr::Constant(0.0)` when `vars` (and therefore `direction`) is
/// empty.
///
/// # Errors
///
/// Returns `NumRs2Error::ValueError` if `vars` and `direction` have different
/// lengths, and propagates any error from [`gradient`].
pub fn directional_derivative(expr: &Expr, vars: &[&str], direction: &[Expr]) -> Result<Expr> {
    if vars.len() != direction.len() {
        return Err(NumRs2Error::ValueError(
            "Direction vector must have same length as variable list".to_string(),
        ));
    }
    // `gradient` returns an owned Vec, so its entries are moved into the
    // products instead of being cloned.
    let dot = gradient(expr, vars)?
        .into_iter()
        .zip(direction)
        .map(|(partial, component)| Expr::Mul(Box::new(partial), Box::new(component.clone())))
        .reduce(|acc, term| Expr::Add(Box::new(acc), Box::new(term)));
    Ok(dot.unwrap_or(Expr::Constant(0.0)))
}
#[cfg(test)]
mod tests {
    //! Unit tests for symbolic differentiation. Expected values are exact
    //! because every test uses small integers, where f64 arithmetic is exact.
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn test_differentiate_constant() {
        let c = Expr::constant(5.0);
        let dc = differentiate(&c, "x").expect("differentiation failed");
        // Guard instead of a float-literal pattern (older compilers lint on those).
        assert!(matches!(dc, Expr::Constant(v) if v == 0.0));
    }

    #[test]
    fn test_differentiate_variable() {
        let x = Expr::var("x");
        let dx = differentiate(&x, "x").expect("differentiation failed");
        assert!(matches!(dx, Expr::Constant(v) if v == 1.0));
        let dy = differentiate(&x, "y").expect("differentiation failed");
        assert!(matches!(dy, Expr::Constant(v) if v == 0.0));
    }

    #[test]
    fn test_differentiate_sum() {
        let x = Expr::var("x");
        let expr = x.clone() + x.clone();
        let derivative = differentiate(&expr, "x").expect("differentiation failed");
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 5.0);
        let result = derivative.eval(&vars).expect("evaluation failed");
        assert_eq!(result, 2.0);
    }

    #[test]
    fn test_differentiate_product() {
        let x = Expr::var("x");
        let expr = x.clone() * x.clone();
        let derivative = differentiate(&expr, "x").expect("differentiation failed");
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 3.0);
        let result = derivative.eval(&vars).expect("evaluation failed");
        assert_eq!(result, 6.0);
    }

    #[test]
    fn test_differentiate_power() {
        let x = Expr::var("x");
        let expr = x.clone().pow(3.0);
        let derivative = differentiate(&expr, "x").expect("differentiation failed");
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 2.0);
        let result = derivative.eval(&vars).expect("evaluation failed");
        assert_eq!(result, 12.0);
    }

    #[test]
    fn test_differentiate_sin() {
        let x = Expr::var("x");
        let expr = x.clone().sin();
        let derivative = differentiate(&expr, "x").expect("differentiation failed");
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 0.0);
        let result = derivative.eval(&vars).expect("evaluation failed");
        assert_eq!(result, 1.0);
    }

    #[test]
    fn test_differentiate_exp() {
        let x = Expr::var("x");
        let expr = x.clone().exp();
        let derivative = differentiate(&expr, "x").expect("differentiation failed");
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 0.0);
        let result = derivative.eval(&vars).expect("evaluation failed");
        assert_eq!(result, 1.0);
    }

    #[test]
    fn test_gradient() {
        let x = Expr::var("x");
        let y = Expr::var("y");
        let f = x.clone() * x.clone() + y.clone() * y.clone();
        let grad = gradient(&f, &["x", "y"]).expect("gradient computation failed");
        assert_eq!(grad.len(), 2);
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 3.0);
        vars.insert("y".to_string(), 4.0);
        let dx = grad[0].eval(&vars).expect("evaluation failed");
        let dy = grad[1].eval(&vars).expect("evaluation failed");
        assert_eq!(dx, 6.0);
        assert_eq!(dy, 8.0);
    }

    #[test]
    fn test_jacobian() {
        let x = Expr::var("x");
        let y = Expr::var("y");
        // F(x, y) = (x * y, x + y)  =>  J = [[y, x], [1, 1]]
        let exprs = vec![x.clone() * y.clone(), x.clone() + y.clone()];
        let jac = jacobian(&exprs, &["x", "y"]).expect("jacobian computation failed");
        assert_eq!(jac.len(), 2);
        assert!(jac.iter().all(|row| row.len() == 2));
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 3.0);
        vars.insert("y".to_string(), 4.0);
        let at = |e: &Expr| e.eval(&vars).expect("evaluation failed");
        assert_eq!(at(&jac[0][0]), 4.0);
        assert_eq!(at(&jac[0][1]), 3.0);
        assert_eq!(at(&jac[1][0]), 1.0);
        assert_eq!(at(&jac[1][1]), 1.0);
    }

    #[test]
    fn test_hessian() {
        let x = Expr::var("x");
        let y = Expr::var("y");
        // f(x, y) = x^2 * y  =>  H = [[2y, 2x], [2x, 0]]
        let f = x.clone() * x.clone() * y.clone();
        let hess = hessian(&f, &["x", "y"]).expect("hessian computation failed");
        assert_eq!(hess.len(), 2);
        assert!(hess.iter().all(|row| row.len() == 2));
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 3.0);
        vars.insert("y".to_string(), 4.0);
        let at = |e: &Expr| e.eval(&vars).expect("evaluation failed");
        assert_eq!(at(&hess[0][0]), 8.0);
        assert_eq!(at(&hess[0][1]), 6.0);
        assert_eq!(at(&hess[1][0]), 6.0);
        assert_eq!(at(&hess[1][1]), 0.0);
    }

    #[test]
    fn test_directional_derivative() {
        let x = Expr::var("x");
        let y = Expr::var("y");
        // ∇(x^2 + y^2) · (1, 2) = 2x + 4y
        let f = x.clone() * x.clone() + y.clone() * y.clone();
        let direction = [Expr::constant(1.0), Expr::constant(2.0)];
        let dd = directional_derivative(&f, &["x", "y"], &direction)
            .expect("directional derivative failed");
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 3.0);
        vars.insert("y".to_string(), 4.0);
        let result = dd.eval(&vars).expect("evaluation failed");
        assert_eq!(result, 22.0);
    }

    #[test]
    fn test_directional_derivative_length_mismatch() {
        let x = Expr::var("x");
        let direction = [Expr::constant(1.0)];
        assert!(directional_derivative(&x, &["x", "y"], &direction).is_err());
    }

    #[test]
    fn test_directional_derivative_empty() {
        let x = Expr::var("x");
        let dd = directional_derivative(&x, &[], &[]).expect("directional derivative failed");
        assert!(matches!(dd, Expr::Constant(v) if v == 0.0));
    }

    #[test]
    fn test_chain_rule() {
        let x = Expr::var("x");
        let inner = x.clone() * 2.0;
        let expr = inner.sin();
        let derivative = differentiate(&expr, "x").expect("differentiation failed");
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 0.0);
        let result = derivative.eval(&vars).expect("evaluation failed");
        assert_eq!(result, 2.0);
    }

    #[test]
    fn test_quotient_rule() {
        let x = Expr::var("x");
        let expr = x.clone() / (x.clone() + 1.0);
        let derivative = differentiate(&expr, "x").expect("differentiation failed");
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 1.0);
        let result = derivative.eval(&vars).expect("evaluation failed");
        assert_eq!(result, 0.25);
    }
}