use num_complex::Complex;
use std::rc::Rc;
use crate::astnode::AstNode;
use crate::core::Real;
use crate::functions::{
Arity,
FunctionKind,
};
use crate::err::ParseError;
use crate::lexer::Span;
use crate::operators::BinaryOperatorKind;
impl<T: Real> AstNode<T> {
    /// Symbolically differentiates this expression with respect to the argument at index `var`.
    ///
    /// The result is an unsimplified expression tree; run it through `simplify` if a compact
    /// form is needed.
    ///
    /// # Errors
    ///
    /// - [`ParseError::InvalidFormula`] if the expression contains `abs` or `conj`, which are
    ///   not differentiable in the complex domain.
    /// - [`ParseError::DerivativeUndefined`] if the chain rule needs a partial derivative of a
    ///   user-defined function that the function does not provide.
    /// - Any error produced while differentiating a subexpression.
    pub(crate) fn differentiate(self, var: usize) -> Result<Self, ParseError> {
        match self {
            Self::Number { span, .. } => Ok(Self::zero(span)),
            Self::Argument { index: i, span } => {
                Ok(if i == var { Self::one(span) } else { Self::zero(span) })
            }
            // NOTE(review): pushing the derivative through the operator assumes every unary
            // operator is linear (e.g. negation) — confirm if new unary operators are added.
            Self::UnaryOperator { kind, expr, span } => {
                let expr = Self::unwrap_rc(expr);
                Ok(Self::UnaryOperator {
                    kind,
                    expr: Rc::new(expr.differentiate(var)?),
                    span,
                })
            }
            Self::BinaryOperator { kind, left, right, span } => {
                let left = Self::unwrap_rc(left);
                let right = Self::unwrap_rc(right);
                Self::diff_binary(kind, left, right, var, span)
            }
            Self::FunctionCall { kind, args, span } => Self::diff_function(kind, args, var, span),
            Self::UserFunctionCall { func, args, span } => {
                // Multivariate chain rule:
                //   d/dx f(g_0, .., g_k) = sum_i (∂_i f)(g_0, .., g_k) * d(g_i)/dx
                // `var` indexes the variables of *this* expression; `i` indexes `f`'s parameters.
                let mut total: Option<Self> = None;
                for (i, arg) in args.iter().enumerate() {
                    let darg = match &**arg {
                        // Constant in `var`: the term vanishes, so ∂_i f is not required.
                        Self::Number { .. } => continue,
                        Self::Argument { index, .. } if *index != var => continue,
                        // d(g_i)/dx == 1: emit the partial alone instead of `partial * 1`.
                        Self::Argument { .. } => None,
                        other => Some(other.clone().differentiate(var)?),
                    };
                    let partial = func.derivative(i).cloned().ok_or_else(|| {
                        ParseError::DerivativeUndefined { func: func.name().into(), idx: i, span }
                    })?;
                    let call = Self::UserFunctionCall { func: partial, args: args.clone(), span };
                    let term = match darg {
                        Some(d) => call * d,
                        None => call,
                    };
                    total = Some(match total {
                        Some(acc) => acc + term,
                        None => term,
                    });
                }
                // No argument depends on `var`, so the call is constant in it.
                Ok(total.unwrap_or_else(|| Self::zero(span)))
            }
            Self::Derivative { expr, var: inner_var, order, span } => {
                if inner_var == var {
                    // Same variable again: just raise the order of the pending derivative.
                    Ok(Self::Derivative { expr, var, order: order + 1, span })
                } else {
                    // Mixed partials are assumed to commute, so the new derivative is pushed
                    // inside the existing (unevaluated) one.
                    let expr = Self::unwrap_rc(expr);
                    Ok(Self::Derivative {
                        expr: Rc::new(expr.differentiate(var)?),
                        var: inner_var,
                        order,
                        span,
                    })
                }
            }
        }
    }

    /// Differentiates `left <kind> right` with respect to `var`.
    ///
    /// Operand derivatives are computed inside each arm so that `Pow`, whose rule
    /// differentiates its operands itself, does not compute them twice.
    fn diff_binary(
        kind: BinaryOperatorKind,
        left: Self,
        right: Self,
        var: usize,
        span: Span,
    ) -> Result<Self, ParseError> {
        match kind {
            BinaryOperatorKind::Add | BinaryOperatorKind::Sub => {
                let dl = left.differentiate(var)?;
                let dr = right.differentiate(var)?;
                Ok(Self::BinaryOperator { kind, left: Rc::new(dl), right: Rc::new(dr), span })
            }
            // Product rule: (l r)' = l' r + l r'
            BinaryOperatorKind::Mul => {
                let dl = left.clone().differentiate(var)?;
                let dr = right.clone().differentiate(var)?;
                Ok(dl * right + left * dr)
            }
            // Quotient rule: (l / r)' = (l' r - l r') / r^2
            BinaryOperatorKind::Div => {
                let dl = left.clone().differentiate(var)?;
                let dr = right.clone().differentiate(var)?;
                Ok((dl * right.clone() - left * dr) / right.powi(2))
            }
            BinaryOperatorKind::Pow => Self::diff_pow(left, right, var),
        }
    }

    /// Differentiates a builtin function call with respect to `var`.
    ///
    /// NOTE(review): assumes the parser has already validated the argument count for `kind`;
    /// `Vec::remove(0)` panics if an argument is missing.
    fn diff_function(
        kind: FunctionKind,
        mut args: Vec<Rc<Self>>,
        var: usize,
        span: Span,
    ) -> Result<Self, ParseError> {
        let x = Self::unwrap_rc(args.remove(0));
        match kind {
            // The power rules differentiate their operands themselves; computing `dx` here
            // first would double the work at every nesting level.
            FunctionKind::Pow => {
                let y = Self::unwrap_rc(args.remove(0));
                Self::diff_pow(x, y, var)
            }
            FunctionKind::Powi => {
                let n = Self::unwrap_rc(args.remove(0));
                Self::diff_powi(x, n, var)
            }
            unary => {
                let dx = x.clone().differentiate(var)?;
                Self::diff_unary_function(unary, x, dx, span)
            }
        }
    }

    /// Applies the chain rule `f'(x) * dx` for a single-argument builtin `f`.
    ///
    /// `dx` is the already-computed derivative of `x`; `span` is the call's span and is used
    /// for inserted constants and errors.
    fn diff_unary_function(
        kind: FunctionKind,
        x: Self,
        dx: Self,
        span: Span,
    ) -> Result<Self, ParseError> {
        match kind {
            FunctionKind::Sin => Ok(x.cos() * dx),
            FunctionKind::Cos => Ok(-x.sin() * dx),
            FunctionKind::Tan => Ok(dx / x.cos().powi(2)),
            FunctionKind::Asin => Ok(dx / (Self::one(span) - x.powi(2)).sqrt()),
            FunctionKind::Acos => Ok(-dx / (Self::one(span) - x.powi(2)).sqrt()),
            FunctionKind::Atan => Ok(dx / (Self::one(span) + x.powi(2))),
            FunctionKind::Sinh => Ok(dx * x.cosh()),
            FunctionKind::Cosh => Ok(dx * x.sinh()),
            FunctionKind::Tanh => Ok(dx / x.cosh().powi(2)),
            FunctionKind::Asinh => Ok(dx / (x.powi(2) + Self::one(span)).sqrt()),
            FunctionKind::Acosh => Ok(dx / (x.powi(2) - Self::one(span)).sqrt()),
            FunctionKind::Atanh => Ok(dx / (Self::one(span) - x.powi(2))),
            FunctionKind::Exp => Ok(dx * x.exp()),
            FunctionKind::Ln => Ok(dx / x),
            // d/dx log10(x) = log10(e) / x
            FunctionKind::Log10 => {
                Ok(dx * Self::Number { value: Complex::from(T::log10_e()), span } / x)
            }
            // d/dx sqrt(x) = 0.5 / sqrt(x)
            FunctionKind::Sqrt => {
                Ok(dx * Self::Number { value: Complex::from(T::from_f64(0.5)), span } / x.sqrt())
            }
            FunctionKind::Abs => Err(ParseError::InvalidFormula {
                reason: "`abs(z)` is not differentiable in the complex domain".into(),
                span,
            }),
            FunctionKind::Conj => Err(ParseError::InvalidFormula {
                reason: "`conj(z)` is not differentiable in the complex domain".into(),
                span,
            }),
            FunctionKind::Pow | FunctionKind::Powi => {
                unreachable!("two-argument functions are dispatched by `diff_function`")
            }
        }
    }

    /// Differentiates `u ^ v` with respect to `var`.
    ///
    /// - Constant (`Number`) exponent: the power rule `u^(v-1) * v * u'` is used. This avoids
    ///   the general form's `v * u' / u` and `0 * ln(u)` subterms, which evaluate to NaN at
    ///   `u = 0` (e.g. `d/dx x^2` at `x = 0`).
    /// - Otherwise: `u^v * (v' ln(u) + v u' / u)`.
    fn diff_pow(u: Self, v: Self, var: usize) -> Result<Self, ParseError> {
        let du = u.clone().differentiate(var)?;
        if matches!(v, Self::Number { .. }) {
            let s = v.span();
            return Ok(u.pow(v.clone() - Self::one(s)) * v * du);
        }
        let dv = v.clone().differentiate(var)?;
        let ln_u = Self::FunctionCall {
            kind: FunctionKind::Ln,
            args: vec![Rc::new(u.clone())],
            span: u.span(),
        };
        Ok(u.clone().pow(v.clone()) * (dv * ln_u + v * du / u))
    }

    /// Differentiates `powi(u, n)` with respect to `var` via `powi(u, n - 1) * n * u'`.
    ///
    /// `n` is treated as constant and is not differentiated.
    fn diff_powi(u: Self, n: Self, var: usize) -> Result<Self, ParseError> {
        let s = u.span();
        let du = u.clone().differentiate(var)?;
        Ok(Self::FunctionCall {
            kind: FunctionKind::Powi,
            args: vec![Rc::new(u), Rc::new(n.clone() - Self::one(s))],
            span: s,
        } * n
            * du)
    }
}
#[cfg(test)]
mod differentiate_tests {
    use super::*;
    use num_complex::Complex;

    #[test]
    fn differentiate_number() {
        let node = AstNode::Number { value: Complex::new(5.0, 0.0), span: Span::from(0..1) };
        let diff = node.differentiate(0).unwrap();
        assert_eq!(diff, AstNode::Number { value: Complex::ZERO, span: Span::from(0..1) });
    }

    #[test]
    fn differentiate_argument() {
        let node = AstNode::<f64>::Argument { index: 1, span: Span::from(0..1) };
        let diff = node.clone().differentiate(1).unwrap();
        assert_eq!(diff, AstNode::Number { value: Complex::ONE, span: Span::from(0..1) });
        let diff_other = node.differentiate(0).unwrap();
        assert_eq!(diff_other, AstNode::Number { value: Complex::ZERO, span: Span::from(0..1) });
    }

    #[test]
    fn differentiate_unary_operator() {
        let node = -AstNode::<f64>::Argument { index: 0, span: Span::from(2..3) };
        let diff = node.differentiate(0).unwrap();
        assert_eq!(diff, -AstNode::Number { value: Complex::ONE, span: Span::from(2..3) });
    }

    #[test]
    fn differentiate_binary_add() {
        let node = AstNode::Argument { index: 0, span: Span::from(0..1) }
            + AstNode::Number { value: Complex::new(2.0, 0.0), span: Span::from(4..5) };
        let diff = node.differentiate(0).unwrap();
        match diff {
            AstNode::BinaryOperator { kind, .. } => {
                assert_eq!(kind, BinaryOperatorKind::Add);
            }
            _ => panic!("Expected BinaryOperator Add"),
        }
    }

    #[test]
    fn differentiate_function_sin() {
        let node = AstNode::<f64>::Argument { index: 0, span: Span::from(2..3) }.sin();
        let diff = node.differentiate(0).unwrap();
        match diff {
            AstNode::BinaryOperator { kind, .. } => {
                assert_eq!(kind, BinaryOperatorKind::Mul);
            }
            _ => panic!("Expected BinaryOperator Mul for cos(x) * 1"),
        }
    }

    #[test]
    fn differentiate_derivative_order() {
        let node = AstNode::Derivative {
            expr: Rc::new(AstNode::<f64>::Argument { index: 0, span: Span::from(2..3) }),
            var: 0,
            order: 1,
            span: Span::from(0..4),
        };
        let diff = node.differentiate(0).unwrap();
        assert_eq!(
            diff,
            AstNode::Derivative {
                expr: Rc::new(AstNode::Argument { index: 0, span: Span::from(2..3) }),
                var: 0,
                order: 2,
                span: Span::from(0..4),
            }
        );
    }

    #[test]
    fn differentiate_mul_x2() {
        let node = AstNode::<f64>::Argument { index: 0, span: Span::from(0..1) }
            .mul(AstNode::Argument { index: 0, span: Span::from(4..5) })
            .differentiate(0)
            .unwrap()
            .simplify();
        match node {
            AstNode::BinaryOperator { kind, .. } => {
                // `matches!` only yields a bool; it must be asserted to actually check anything.
                assert!(
                    matches!(kind, BinaryOperatorKind::Add | BinaryOperatorKind::Mul),
                    "unexpected operator {:?}",
                    kind
                );
            }
            _ => panic!("Expected BinaryOperator after differentiation and simplification"),
        }
    }

    #[test]
    fn differentiate_builtin_functions() {
        let span = Span::from(2..3);
        let x = AstNode::<f64>::Argument { index: 0, span };
        let y = AstNode::<f64>::Argument { index: 1, span };
        let one = AstNode::one(span);
        let three = AstNode::Number { value: Complex::new(3.0, 0.0), span };
        let half = AstNode::Number { value: Complex::new(0.5, 0.0), span };
        let log10_e = AstNode::Number { value: Complex::new(std::f64::consts::LOG10_E, 0.0), span };
        let cases = vec![
            (
                FunctionKind::Sin,
                AstNode::FunctionCall { kind: FunctionKind::Sin, args: vec![Rc::new(x.clone())], span },
                x.clone().cos() * one.clone(),
            ),
            (
                FunctionKind::Cos,
                AstNode::FunctionCall { kind: FunctionKind::Cos, args: vec![Rc::new(x.clone())], span },
                (-x.clone().sin()) * one.clone(),
            ),
            (
                FunctionKind::Tan,
                AstNode::FunctionCall { kind: FunctionKind::Tan, args: vec![Rc::new(x.clone())], span },
                one.clone() / x.clone().cos().powi(2),
            ),
            (
                FunctionKind::Asin,
                AstNode::FunctionCall { kind: FunctionKind::Asin, args: vec![Rc::new(x.clone())], span },
                one.clone() / (one.clone() - x.clone().powi(2)).sqrt(),
            ),
            (
                FunctionKind::Acos,
                AstNode::FunctionCall { kind: FunctionKind::Acos, args: vec![Rc::new(x.clone())], span },
                (-one.clone()) / (one.clone() - x.clone().powi(2)).sqrt(),
            ),
            (
                FunctionKind::Atan,
                AstNode::FunctionCall { kind: FunctionKind::Atan, args: vec![Rc::new(x.clone())], span },
                one.clone() / (one.clone() + x.clone().powi(2)),
            ),
            (
                FunctionKind::Sinh,
                AstNode::FunctionCall { kind: FunctionKind::Sinh, args: vec![Rc::new(x.clone())], span },
                one.clone() * x.clone().cosh(),
            ),
            (
                FunctionKind::Cosh,
                AstNode::FunctionCall { kind: FunctionKind::Cosh, args: vec![Rc::new(x.clone())], span },
                one.clone() * x.clone().sinh(),
            ),
            (
                FunctionKind::Tanh,
                AstNode::FunctionCall { kind: FunctionKind::Tanh, args: vec![Rc::new(x.clone())], span },
                one.clone() / x.clone().cosh().powi(2),
            ),
            (
                FunctionKind::Asinh,
                AstNode::FunctionCall { kind: FunctionKind::Asinh, args: vec![Rc::new(x.clone())], span },
                one.clone() / (x.clone().powi(2) + one.clone()).sqrt(),
            ),
            (
                FunctionKind::Acosh,
                AstNode::FunctionCall { kind: FunctionKind::Acosh, args: vec![Rc::new(x.clone())], span },
                one.clone() / (x.clone().powi(2) - one.clone()).sqrt(),
            ),
            (
                FunctionKind::Atanh,
                AstNode::FunctionCall { kind: FunctionKind::Atanh, args: vec![Rc::new(x.clone())], span },
                one.clone() / (one.clone() - x.clone().powi(2)),
            ),
            (
                FunctionKind::Exp,
                AstNode::FunctionCall { kind: FunctionKind::Exp, args: vec![Rc::new(x.clone())], span },
                one.clone() * x.clone().exp(),
            ),
            (
                FunctionKind::Ln,
                AstNode::FunctionCall { kind: FunctionKind::Ln, args: vec![Rc::new(x.clone())], span },
                one.clone() / x.clone(),
            ),
            (
                FunctionKind::Log10,
                AstNode::FunctionCall { kind: FunctionKind::Log10, args: vec![Rc::new(x.clone())], span },
                one.clone() * log10_e.clone() / x.clone(),
            ),
            (
                FunctionKind::Sqrt,
                AstNode::FunctionCall { kind: FunctionKind::Sqrt, args: vec![Rc::new(x.clone())], span },
                one.clone() * half.clone() / x.clone().sqrt(),
            ),
        ];
        for (kind, node, expected) in cases {
            assert_eq!(node.differentiate(0).unwrap(), expected, "FunctionKind::{:?}", kind);
        }
        let pow_node = x.clone().pow(y.clone());
        let pow_expected = x.clone().pow(y.clone())
            * (one.clone() * AstNode::FunctionCall { kind: FunctionKind::Ln, args: vec![Rc::new(x.clone())], span }
                + y.clone() * AstNode::zero(span) / x.clone());
        assert_eq!(pow_node.differentiate(1).unwrap(), pow_expected);
        let powi_node = x.clone().powi(3);
        let powi_expected = AstNode::FunctionCall {
            kind: FunctionKind::Powi,
            args: vec![Rc::new(x.clone()), Rc::new(three.clone() - one.clone())],
            span,
        } * three.clone()
            * one.clone();
        assert_eq!(powi_node.differentiate(0).unwrap(), powi_expected);
        // abs/conj must be rejected for the right reason, not merely fail somehow.
        assert!(matches!(
            AstNode::FunctionCall { kind: FunctionKind::Abs, args: vec![Rc::new(x.clone())], span }
                .differentiate(0),
            Err(ParseError::InvalidFormula { .. })
        ));
        assert!(matches!(
            AstNode::FunctionCall { kind: FunctionKind::Conj, args: vec![Rc::new(x.clone())], span }
                .differentiate(0),
            Err(ParseError::InvalidFormula { .. })
        ));
    }
}