use super::Expr;
/// Attempts to find an antiderivative of `expr` with respect to the variable `var`.
///
/// Rules applied (no constant of integration is added):
/// - an expression that does not mention `var` at all (a constant, another variable, or any
///   function of those such as `sin(y)` or `y^2`) is treated as a constant `k`, giving
///   `k * var`;
/// - `var` gives `0.5 * var^2`;
/// - sums and negations are integrated term by term;
/// - a product with one `var`-free factor uses the constant-multiple rule;
/// - `var^n` for a constant `n` gives `var^(n+1) / (n+1)`, except when `n` is `-1` (within
///   `1e-15`), which gives `ln(var)` — note: not `ln|var|`;
/// - `sin(var)`, `cos(var)` and `exp(var)` use their standard antiderivatives.
///
/// Returns `None` when no rule applies, e.g. for `var * var`, `ln(var)` or `sin(2 * var)`.
#[must_use]
pub fn symbolic_integrate(expr: &Expr, var: &str) -> Option<Expr> {
    // Pattern rules run first so previously supported inputs keep their exact result shape;
    // only when they fail do we check whether the whole expression is constant in `var`.
    integrate_by_rules(expr, var).or_else(|| integrate_var_free(expr, var))
}

/// Applies the structural integration rules to `expr`.
///
/// Sub-terms recurse through [`symbolic_integrate`], so they also get the `var`-free
/// fallback (e.g. the `sin(y)` term of `x + sin(y)`).
fn integrate_by_rules(expr: &Expr, var: &str) -> Option<Expr> {
    let x = || Box::new(Expr::Var(var.into()));
    match expr {
        Expr::Var(name) if name == var => Some(Expr::Mul(
            Box::new(Expr::Const(0.5)),
            Box::new(Expr::Pow(x(), Box::new(Expr::Const(2.0)))),
        )),
        Expr::Add(a, b) => {
            let ia = symbolic_integrate(a, var)?;
            let ib = symbolic_integrate(b, var)?;
            Some(Expr::Add(Box::new(ia), Box::new(ib)))
        }
        Expr::Neg(a) => Some(Expr::Neg(Box::new(symbolic_integrate(a, var)?))),
        Expr::Mul(a, b) => {
            if !contains_var(a, var) {
                let ib = symbolic_integrate(b, var)?;
                Some(Expr::Mul(a.clone(), Box::new(ib)))
            } else if !contains_var(b, var) {
                let ia = symbolic_integrate(a, var)?;
                Some(Expr::Mul(Box::new(ia), b.clone()))
            } else {
                // Both factors depend on `var`; that would need integration by parts.
                None
            }
        }
        Expr::Pow(base, exp) if is_var(base, var) => match exp.as_ref() {
            // The power rule is singular at n = -1; the antiderivative of x^-1 is ln(x).
            Expr::Const(n) if (*n + 1.0).abs() < 1e-15 => Some(Expr::Ln(x())),
            Expr::Const(n) => {
                let np1 = *n + 1.0;
                Some(Expr::Mul(
                    Box::new(Expr::Const(1.0 / np1)),
                    Box::new(Expr::Pow(x(), Box::new(Expr::Const(np1)))),
                ))
            }
            // A non-constant exponent might equal -1, so the power rule cannot be applied.
            _ => None,
        },
        Expr::Sin(a) if is_var(a, var) => Some(Expr::Neg(Box::new(Expr::Cos(x())))),
        Expr::Cos(a) if is_var(a, var) => Some(Expr::Sin(x())),
        Expr::Exp(a) if is_var(a, var) => Some(Expr::Exp(x())),
        // Constants and other variables are handled by `integrate_var_free`; `ln(var)` and
        // non-trivial arguments such as `sin(2 * var)` are not supported.
        _ => None,
    }
}

/// Integrates an expression that does not mention `var` as a constant `k`, giving `k * var`.
///
/// Returns `None` if `expr` does depend on `var`.
fn integrate_var_free(expr: &Expr, var: &str) -> Option<Expr> {
    (!contains_var(expr, var))
        .then(|| Expr::Mul(Box::new(expr.clone()), Box::new(Expr::Var(var.into()))))
}
/// Returns `true` only when `expr` is exactly the variable named `var`
/// (not merely an expression that contains it).
fn is_var(expr: &Expr, var: &str) -> bool {
    if let Expr::Var(name) = expr {
        name == var
    } else {
        false
    }
}
/// Reports whether the variable `var` occurs anywhere inside `expr`.
///
/// Any variant not matched explicitly is conservatively reported as containing `var`, so an
/// unrecognised node is never mistaken for a constant.
fn contains_var(expr: &Expr, var: &str) -> bool {
    match expr {
        Expr::Var(name) => name == var,
        Expr::Const(_) => false,
        Expr::Neg(arg) | Expr::Sin(arg) | Expr::Cos(arg) | Expr::Exp(arg) | Expr::Ln(arg) => {
            contains_var(arg, var)
        }
        Expr::Add(lhs, rhs) | Expr::Mul(lhs, rhs) | Expr::Pow(lhs, rhs) => {
            // Left operand first, short-circuiting on the first match.
            [lhs, rhs].iter().any(|side| contains_var(side, var))
        }
        // `Expr` is declared outside this file. When the arms above already cover every
        // variant this arm is unreachable, hence the lint allowance.
        #[allow(unreachable_patterns)]
        _ => true,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Absolute tolerance shared by every numeric check below.
    const TOL: f64 = 1e-10;

    fn sym(name: &str) -> Expr {
        Expr::Var(name.into())
    }

    fn num(value: f64) -> Expr {
        Expr::Const(value)
    }

    /// `base ^ exponent` with a constant exponent.
    fn pow(base: Expr, exponent: f64) -> Expr {
        Expr::Pow(Box::new(base), Box::new(num(exponent)))
    }

    /// Integrates `e` with respect to `x`, panicking if no rule applies.
    fn antiderivative(e: &Expr) -> Expr {
        symbolic_integrate(e, "x").unwrap()
    }

    /// Evaluates `e` with each `(name, value)` pair bound; panics if evaluation fails.
    fn eval_with(e: &Expr, bindings: &[(&str, f64)]) -> f64 {
        let vars: HashMap<String, f64> = bindings
            .iter()
            .map(|&(name, value)| (name.to_string(), value))
            .collect();
        e.evaluate(&vars).unwrap()
    }

    /// Evaluates `e` with only `x` bound.
    fn eval_at(e: &Expr, x: f64) -> f64 {
        eval_with(e, &[("x", x)])
    }

    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    #[test]
    fn integrate_constant() {
        let integral = antiderivative(&num(3.0));
        assert_close(eval_at(&integral.simplify(), 2.0), 6.0);
    }

    #[test]
    fn integrate_x() {
        let integral = antiderivative(&sym("x"));
        assert_close(eval_at(&integral.simplify(), 4.0), 8.0);
    }

    #[test]
    fn integrate_x_squared() {
        let integral = antiderivative(&pow(sym("x"), 2.0));
        assert_close(eval_at(&integral.simplify(), 3.0), 9.0);
    }

    #[test]
    fn integrate_reciprocal() {
        let integral = antiderivative(&pow(sym("x"), -1.0));
        assert_close(eval_at(&integral, std::f64::consts::E), 1.0);
    }

    #[test]
    fn integrate_sin() {
        let integral = antiderivative(&Expr::Sin(Box::new(sym("x"))));
        assert_close(eval_at(&integral.simplify(), 0.0), -1.0);
    }

    #[test]
    fn integrate_cos() {
        let integral = antiderivative(&Expr::Cos(Box::new(sym("x"))));
        assert_close(
            eval_at(&integral.simplify(), std::f64::consts::FRAC_PI_2),
            1.0,
        );
    }

    #[test]
    fn integrate_exp() {
        let integral = antiderivative(&Expr::Exp(Box::new(sym("x"))));
        assert_close(eval_at(&integral, 0.0), 1.0);
    }

    #[test]
    fn integrate_sum() {
        let integrand = Expr::Add(Box::new(sym("x")), Box::new(num(1.0)));
        let integral = antiderivative(&integrand);
        assert_close(eval_at(&integral.simplify(), 2.0), 4.0);
    }

    #[test]
    fn integrate_constant_multiple() {
        let integrand = Expr::Mul(Box::new(num(3.0)), Box::new(pow(sym("x"), 2.0)));
        let integral = antiderivative(&integrand);
        assert_close(eval_at(&integral.simplify(), 2.0), 8.0);
    }

    #[test]
    fn integrate_negation() {
        let integral = antiderivative(&Expr::Neg(Box::new(sym("x"))));
        assert_close(eval_at(&integral.simplify(), 4.0), -8.0);
    }

    #[test]
    fn integrate_unsupported_returns_none() {
        let integrand = Expr::Mul(Box::new(sym("x")), Box::new(sym("x")));
        assert!(symbolic_integrate(&integrand, "x").is_none());
    }

    #[test]
    fn integrate_other_var_as_constant() {
        let integral = antiderivative(&sym("y"));
        let value = eval_with(&integral.simplify(), &[("x", 3.0), ("y", 5.0)]);
        assert_close(value, 15.0);
    }
}