use crate::calculus::derivatives::Derivative;
use crate::core::{Expression, Symbol};
use crate::simplify::Simplify;
/// Stateless namespace for the basic differentiation rules in this file.
///
/// The associated functions cover calculus nodes, bare symbols, and sums;
/// the struct itself carries no data.
pub struct BasicDerivatives;
impl BasicDerivatives {
    /// Differentiates an expression carrying calculus metadata.
    ///
    /// # Arguments
    ///
    /// * `expr` - the expression that is cloned into the resulting derivative node
    /// * `data` - the calculus payload that decides which rule applies
    /// * `variable` - the symbol to differentiate with respect to
    ///
    /// # Returns
    ///
    /// * When `data` is a `Derivative` over the same `variable`: a derivative
    ///   node built from `expr` with the stored order increased by one.
    /// * When `data` is a `Derivative` over a different variable: the integer `0`.
    /// * For every other kind of calculus data: a first-order derivative node
    ///   built from `expr`.
    ///
    /// NOTE(review): returning `0` for a derivative taken over a different
    /// variable treats mixed partials as vanishing — confirm this is the
    /// intended convention for callers.
    pub fn handle_calculus(
        expr: &Expression,
        data: &crate::core::expression::CalculusData,
        variable: Symbol,
    ) -> Expression {
        if let crate::core::expression::CalculusData::Derivative {
            variable: inner_variable,
            order,
            ..
        } = data
        {
            // Already a derivative: either raise its order or collapse to zero.
            return if *inner_variable == variable {
                Expression::derivative(expr.clone(), variable, order + 1)
            } else {
                Expression::integer(0)
            };
        }

        // Any other calculus payload is wrapped in a first-order derivative.
        Expression::derivative(expr.clone(), variable, 1)
    }

    /// Differentiates a bare symbol.
    ///
    /// Returns the integer `1` when `sym` equals `variable`, and the integer
    /// `0` otherwise.
    pub fn handle_symbol(sym: &Symbol, variable: &Symbol) -> Expression {
        let value = if sym == variable { 1 } else { 0 };
        Expression::integer(value)
    }

    /// Differentiates a sum term by term.
    ///
    /// Each entry of `terms` is differentiated with respect to `variable`;
    /// the results are combined with `Expression::add` and then simplified.
    pub fn handle_sum(terms: &[Expression], variable: &Symbol) -> Expression {
        let differentiated: Vec<_> = terms
            .iter()
            .map(|term| term.derivative(variable.clone()))
            .collect();

        Expression::add(differentiated).simplify()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;
    use crate::MathConstant;

    /// Literals and constants differentiate to zero; a symbol differentiates
    /// to one only with respect to itself.
    #[test]
    fn test_basic_constant_derivatives() {
        let x = symbol!(x);
        let y = symbol!(y);
        let zero = Expression::integer(0);
        let x_expr = Expression::symbol(x.clone());

        assert_eq!(Expression::integer(2).derivative(x.clone()), zero);
        assert_eq!(x_expr.derivative(x.clone()), Expression::integer(1));
        assert_eq!(x_expr.derivative(y.clone()), zero);
        assert_eq!(Expression::integer(-1).derivative(x.clone()), zero);
        assert_eq!(
            Expression::constant(MathConstant::Pi).derivative(x.clone()),
            zero
        );
    }

    /// Sums are differentiated term by term.
    #[test]
    fn test_sum_linearity() {
        let x = symbol!(x);

        // d/dx (x + 5) = 1
        let shifted = Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(5)]);
        let shifted_derivative = shifted.derivative(x.clone()).simplify();
        assert_eq!(shifted_derivative, Expression::integer(1));

        // d/dx (2x + 3x) = 5
        let two_x = Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]);
        let three_x = Expression::mul(vec![Expression::integer(3), Expression::symbol(x.clone())]);
        let combined = Expression::add(vec![two_x, three_x]);
        let combined_derivative = combined.derivative(x.clone()).simplify();
        assert_eq!(combined_derivative, Expression::integer(5));
    }

    /// Partial derivatives of f(x, y) = x*y + x^2 + y with respect to each variable.
    #[test]
    fn test_multiple_variables() {
        let x = symbol!(x);
        let y = symbol!(y);
        let x_expr = Expression::symbol(x.clone());
        let y_expr = Expression::symbol(y.clone());

        let f = Expression::add(vec![
            Expression::mul(vec![x_expr.clone(), y_expr.clone()]),
            Expression::pow(x_expr.clone(), Expression::integer(2)),
            y_expr.clone(),
        ]);

        // df/dx = y + 2x
        let expected_dx = Expression::add(vec![
            y_expr,
            Expression::mul(vec![Expression::integer(2), x_expr.clone()]),
        ])
        .simplify();
        // df/dy = x + 1
        let expected_dy = Expression::add(vec![x_expr, Expression::integer(1)]).simplify();

        let actual_dx = f.derivative(x.clone()).simplify();
        let actual_dy = f.derivative(y.clone()).simplify();

        assert_eq!(actual_dx, expected_dx);
        assert_eq!(actual_dy, expected_dy);
    }

    /// Pi, E, and I all differentiate to zero.
    #[test]
    fn test_special_constants() {
        let x = symbol!(x);
        let zero = Expression::integer(0);

        for constant in [MathConstant::Pi, MathConstant::E, MathConstant::I] {
            let derivative = Expression::constant(constant).derivative(x.clone());
            assert_eq!(derivative, zero);
        }
    }

    /// A sum nested inside another sum: d/dx (x + (2x + 3)) = 3.
    #[test]
    fn test_nested_sums() {
        let x = symbol!(x);

        let two_x = Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]);
        let inner_sum = Expression::add(vec![two_x, Expression::integer(3)]);
        let outer_sum = Expression::add(vec![Expression::symbol(x.clone()), inner_sum]);

        let derivative = outer_sum.derivative(x.clone()).simplify();
        assert_eq!(derivative, Expression::integer(3));
    }

    /// Zero, negative integers, and float literals all differentiate to zero.
    #[test]
    fn test_zero_and_negative_constants() {
        let x = symbol!(x);
        let zero = Expression::integer(0);
        let constants = [
            Expression::integer(0),
            Expression::integer(-42),
            Expression::float(std::f64::consts::PI),
        ];

        for constant in &constants {
            assert_eq!(constant.derivative(x.clone()), zero);
        }
    }
}