use super::helpers::is_one;
use crate::core::{Expression, Symbol};
/// One level of the differential extension tower built by [`build_extension_tower`].
///
/// The tower starts with `Rational` as its base; each further level records a
/// transcendental generator `θ` by its argument together with the expression
/// stored as `θ'` (the derivative with respect to the integration variable).
#[derive(Debug, Clone, PartialEq)]
pub enum DifferentialExtension {
    /// Base of the tower (presumably rational functions in the integration
    /// variable — NOTE(review): confirm against the integrator that consumes it).
    Rational,
    /// Generator `θ = exp(argument)`.
    Exponential {
        /// The expression inside `exp(...)`.
        argument: Box<Expression>,
        /// Derivative recorded for `θ`; the detector builds it as
        /// `argument' * exp(argument)`.
        derivative: Box<Expression>,
    },
    /// Generator `θ = ln(argument)`.
    Logarithmic {
        /// The expression inside `ln(...)`.
        argument: Box<Expression>,
        /// Derivative recorded for `θ`, intended to be `argument' / argument`.
        derivative: Box<Expression>,
    },
}
/// Builds the differential extension tower for `expr` with respect to `var`.
///
/// The first entry is always [`DifferentialExtension::Rational`]. After it come,
/// in this order, the exponential extension (if one is detected) and then the
/// logarithmic extension (if one is detected). Each kind contributes at most
/// one level.
///
/// The current implementation never returns `None`.
pub fn build_extension_tower(expr: &Expression, var: Symbol) -> Option<Vec<DifferentialExtension>> {
    let exponential = detect_exponential_extension(expr, var.clone());
    let logarithmic = detect_logarithmic_extension(expr, var);

    // `Option` is iterable (zero or one item), so absent extensions simply vanish.
    let tower = std::iter::once(DifferentialExtension::Rational)
        .chain(exponential)
        .chain(logarithmic)
        .collect::<Vec<_>>();

    Some(tower)
}
/// Looks for an `exp(u)` generator in `expr` whose argument `u` depends on `var`.
///
/// Two cases are handled:
/// - A product: its factors are searched left to right, recursively, and the
///   first hit is returned.
/// - A single-argument `exp` call: it yields an extension only when its argument
///   contains `var`.
///
/// Every other shape yields `None`.
fn detect_exponential_extension(expr: &Expression, var: Symbol) -> Option<DifferentialExtension> {
    match expr {
        Expression::Mul(factors) => factors
            .iter()
            .find_map(|factor| detect_exponential_extension(factor, var.clone())),
        Expression::Function { name, args } if name.as_ref() == "exp" && args.len() == 1 => {
            let exponent = &args[0];
            // `exp` of a var-free argument is a constant, not a new generator.
            exponent
                .contains_variable(&var)
                .then(|| DifferentialExtension::Exponential {
                    argument: Box::new(exponent.clone()),
                    derivative: Box::new(compute_exponential_derivative(exponent, var)),
                })
        }
        _ => None,
    }
}
/// Looks for a logarithmic generator `θ = ln(a)` in `expr`, where `a` depends on `var`.
///
/// Two shapes are recognised:
/// - a single-argument `ln` or `log` call (both are treated as the natural log);
/// - a product or power that `extract_division` decomposes as `1 / a`. This is
///   taken as a signal that `ln(a)` is needed.
///
/// In both cases the recorded derivative is `θ' = a' / a`, computed by
/// [`compute_logarithmic_derivative`]. Returns `None` for any other shape, or
/// when `a` does not contain `var`.
fn detect_logarithmic_extension(expr: &Expression, var: Symbol) -> Option<DifferentialExtension> {
    use super::helpers::extract_division;
    match expr {
        Expression::Function { name, args }
            if (name.as_ref() == "ln" || name.as_ref() == "log") && args.len() == 1 =>
        {
            let arg = &args[0];
            if arg.contains_variable(&var) {
                Some(DifferentialExtension::Logarithmic {
                    argument: Box::new(arg.clone()),
                    derivative: Box::new(compute_logarithmic_derivative(arg, var)),
                })
            } else {
                None
            }
        }
        // A reciprocal may surface as either a product or a power; in both cases
        // `extract_division` decides whether `expr` really is `num / den`.
        Expression::Mul(_) | Expression::Pow(_, _) => {
            let (num, den) = extract_division(expr)?;
            if !is_one(&num) || !den.contains_variable(&var) {
                return None;
            }
            // θ = ln(den) ⇒ θ' = den'/den. The previous hard-coded `1/den` was
            // only correct when den' = 1 (e.g. wrong for `1/(2x)`).
            let derivative = compute_logarithmic_derivative(&den, var);
            Some(DifferentialExtension::Logarithmic {
                argument: Box::new(den),
                derivative: Box::new(derivative),
            })
        }
        _ => None,
    }
}
/// Chain rule for the exponential: returns `d/dvar exp(arg) = arg' * exp(arg)`.
fn compute_exponential_derivative(arg: &Expression, var: Symbol) -> Expression {
    let inner_rate = derivative_of(arg, var);
    let outer = Expression::function("exp", vec![arg.clone()]);
    Expression::mul(vec![inner_rate, outer])
}
/// Chain rule for the natural logarithm: returns `d/dvar ln(arg) = arg' / arg`.
fn compute_logarithmic_derivative(arg: &Expression, var: Symbol) -> Expression {
    Expression::div(derivative_of(arg, var), arg.clone())
}
/// Symbolically differentiates `expr` with respect to `var` (unsimplified result).
///
/// Supported shapes:
/// - the symbol `var` (→ 1), other symbols, numbers and constants (→ 0);
/// - sums, term by term;
/// - products of any arity, via the general product rule;
/// - powers:
///   - the power rule when the exponent is free of `var`;
///   - logarithmic differentiation otherwise:
///     `(b^e)' = b^e * (e' * ln(b) + e * b'/b)`;
/// - single-argument `exp`, `ln` and `log` calls, via the chain rule. `log` is
///   treated as the natural log, as elsewhere in this module.
///
/// NOTE(review): every other shape (e.g. other named functions) still yields `0`,
/// which is not a true derivative. Callers must not rely on it for such input.
fn derivative_of(expr: &Expression, var: Symbol) -> Expression {
    match expr {
        Expression::Symbol(s) if *s == var => Expression::integer(1),
        Expression::Number(_) | Expression::Constant(_) => Expression::integer(0),
        Expression::Symbol(_) => Expression::integer(0),
        Expression::Mul(factors) => {
            if factors.is_empty() {
                // Empty product is the constant 1.
                Expression::integer(0)
            } else {
                // (f1 * ... * fn)' = Σ_i f1 * ... * fi' * ... * fn.
                // Each factor keeps its position, so the 2-factor case yields
                // f'*g + f*g' exactly as before. Other arities used to return 0.
                let terms: Vec<Expression> = (0..factors.len())
                    .map(|i| {
                        let product: Vec<Expression> = factors
                            .iter()
                            .enumerate()
                            .map(|(j, factor)| {
                                if j == i {
                                    derivative_of(factor, var.clone())
                                } else {
                                    factor.clone()
                                }
                            })
                            .collect();
                        Expression::mul(product)
                    })
                    .collect();
                Expression::add(terms)
            }
        }
        Expression::Add(terms) => Expression::add(
            terms
                .iter()
                .map(|t| derivative_of(t, var.clone()))
                .collect(),
        ),
        Expression::Pow(base, exp) => {
            if !exp.contains_variable(&var) {
                // Power rule: (b^e)' = e * b^(e - 1) * b'.
                let base_derivative = derivative_of(base, var);
                Expression::mul(vec![
                    (**exp).clone(),
                    Expression::pow(
                        (**base).clone(),
                        Expression::add(vec![(**exp).clone(), Expression::integer(-1)]),
                    ),
                    base_derivative,
                ])
            } else {
                // Logarithmic differentiation: (b^e)' = b^e * (e' * ln(b) + e * b'/b).
                let base_derivative = derivative_of(base, var.clone());
                let exp_derivative = derivative_of(exp, var);
                Expression::mul(vec![
                    expr.clone(),
                    Expression::add(vec![
                        Expression::mul(vec![
                            exp_derivative,
                            Expression::function("ln", vec![(**base).clone()]),
                        ]),
                        Expression::mul(vec![
                            (**exp).clone(),
                            Expression::div(base_derivative, (**base).clone()),
                        ]),
                    ]),
                ])
            }
        }
        Expression::Function { name, args } if args.len() == 1 => {
            let inner = &args[0];
            if name.as_ref() == "exp" {
                // (exp u)' = u' * exp(u); same shape as `compute_exponential_derivative`.
                Expression::mul(vec![derivative_of(inner, var), expr.clone()])
            } else if name.as_ref() == "ln" || name.as_ref() == "log" {
                // (ln u)' = u' / u; same shape as `compute_logarithmic_derivative`.
                Expression::div(derivative_of(inner, var), inner.clone())
            } else {
                Expression::integer(0)
            }
        }
        _ => Expression::integer(0),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    /// Builds `name(sym)` for a single-argument function call.
    fn call_on(name: &str, sym: &Symbol) -> Expression {
        Expression::function(name, vec![Expression::symbol(sym.clone())])
    }

    #[test]
    fn test_detect_exponential_simple() {
        let x = symbol!(x);
        let found = detect_exponential_extension(&call_on("exp", &x), x);
        assert!(matches!(
            found,
            Some(DifferentialExtension::Exponential { .. })
        ));
    }

    #[test]
    fn test_detect_logarithmic_simple() {
        let x = symbol!(x);
        let found = detect_logarithmic_extension(&call_on("ln", &x), x);
        assert!(matches!(
            found,
            Some(DifferentialExtension::Logarithmic { .. })
        ));
    }

    #[test]
    fn test_detect_logarithmic_derivative() {
        let x = symbol!(x);
        let reciprocal = Expression::div(Expression::integer(1), Expression::symbol(x.clone()));
        let found = detect_logarithmic_extension(&reciprocal, x);
        assert!(matches!(
            found,
            Some(DifferentialExtension::Logarithmic { .. })
        ));
    }

    #[test]
    fn test_build_tower_exponential() {
        let x = symbol!(x);
        let levels = build_extension_tower(&call_on("exp", &x), x)
            .expect("tower construction should succeed");
        assert!(levels.len() >= 2);
    }
}