use super::expr::Expr;
/// Symbolically integrates `expr` with respect to `var`.
///
/// Returns an antiderivative with the constant of integration omitted, or
/// `None` when `expr` (or any sub-expression that must be integrated)
/// matches none of the rules below.
///
/// Supported rules:
/// - anything independent of `var`: `∫ c dx = c·x`
/// - `∫ x dx = x² / 2`
/// - power rule for a literal constant exponent: `∫ xⁿ dx = xⁿ⁺¹ / (n+1)`,
///   with `n = -1` (within `1e-10`) giving `ln|x|`
/// - `sin(x)`, `cos(x)`, `exp(x)` applied directly to `var`
/// - linearity: sums, differences, negation, and products or quotients
///   where one operand is independent of `var`
/// - `∫ c/x dx = c·ln|x|` for a numerator `c` independent of `var`
pub fn integrate(expr: &Expr, var: &str) -> Option<Expr> {
    if !expr.contains_var(var) {
        return Some(expr.clone().mul(Expr::var(var)));
    }
    match expr {
        Expr::Var(name) if name == var => Some(Expr::Div(
            Box::new(Expr::Pow(Box::new(Expr::var(var)), Box::new(Expr::c(2.0)))),
            Box::new(Expr::c(2.0)),
        )),
        // Only a literal exponent is handled: a symbolic exponent might equal
        // -1, where the power rule would divide by zero.
        Expr::Pow(base, exp) if is_var(base, var) && !exp.contains_var(var) => {
            if let Expr::Const(n) = **exp {
                if (n + 1.0).abs() < 1e-10 {
                    Some(ln_abs_var(var))
                } else {
                    let n1 = n + 1.0;
                    Some(Expr::Div(
                        Box::new(Expr::Pow(Box::new(Expr::var(var)), Box::new(Expr::c(n1)))),
                        Box::new(Expr::c(n1)),
                    ))
                }
            } else {
                None
            }
        }
        Expr::Sin(a) if is_var(a, var) => Some(Expr::Neg(Box::new(Expr::Cos(a.clone())))),
        Expr::Cos(a) if is_var(a, var) => Some(Expr::Sin(a.clone())),
        Expr::Exp(a) if is_var(a, var) => Some(Expr::Exp(a.clone())),
        Expr::Add(a, b) => {
            let ia = integrate(a, var)?;
            let ib = integrate(b, var)?;
            Some(Expr::Add(Box::new(ia), Box::new(ib)))
        }
        Expr::Sub(a, b) => {
            let ia = integrate(a, var)?;
            let ib = integrate(b, var)?;
            Some(Expr::Sub(Box::new(ia), Box::new(ib)))
        }
        // Constant factor on the left or right pulls out of the integral.
        Expr::Mul(a, b) if !a.contains_var(var) => {
            let ib = integrate(b, var)?;
            Some(Expr::Mul(a.clone(), Box::new(ib)))
        }
        Expr::Mul(a, b) if !b.contains_var(var) => {
            let ia = integrate(a, var)?;
            Some(Expr::Mul(Box::new(ia), b.clone()))
        }
        Expr::Neg(a) => {
            let ia = integrate(a, var)?;
            Some(Expr::Neg(Box::new(ia)))
        }
        // 1/x is matched before the general c/x arm so its result stays a bare
        // `ln|x|` rather than `1 * ln|x|`.
        Expr::Div(a, b)
            if matches!(**a, Expr::Const(v) if (v - 1.0).abs() < 1e-10) && is_var(b, var) =>
        {
            Some(ln_abs_var(var))
        }
        // ∫ c/x dx = c·ln|x| for any numerator independent of `var`.
        Expr::Div(a, b) if !a.contains_var(var) && is_var(b, var) => {
            Some(Expr::Mul(a.clone(), Box::new(ln_abs_var(var))))
        }
        // ∫ f/c dx = (∫ f dx) / c, mirroring the constant-factor `Mul` arms.
        Expr::Div(a, b) if !b.contains_var(var) => {
            let ia = integrate(a, var)?;
            Some(Expr::Div(Box::new(ia), b.clone()))
        }
        _ => None,
    }
}

/// Returns `true` when `expr` is exactly the variable named `var`.
fn is_var(expr: &Expr, var: &str) -> bool {
    matches!(expr, Expr::Var(name) if name == var)
}

/// Builds `ln|var|`, the antiderivative of `1/var`.
fn ln_abs_var(var: &str) -> Expr {
    Expr::Ln(Box::new(Expr::Abs(Box::new(Expr::var(var)))))
}
/// Approximates the definite integral of `expr` over `[a, b]` with respect to
/// `var` using the composite Simpson's rule.
///
/// `n` is the requested number of subintervals. Simpson's rule needs an even
/// count of at least 2, so `n` is raised to 2 if smaller and an odd `n` is
/// rounded up. Without the lower bound, `n == 0` would make the step width
/// `(b - a) / 0` and the result `inf` or `NaN`.
///
/// `expr` is evaluated with only `var` bound; any other variable it mentions
/// is not supplied here.
pub fn numerical_integrate(
    expr: &Expr, var: &str, a: f64, b: f64, n: usize,
) -> f64 {
    let n = n.max(2);
    let n = if n % 2 == 0 { n } else { n + 1 };
    let h = (b - a) / n as f64;
    let mut sum = 0.0;
    // Allocate the variable's key once and overwrite its value per sample,
    // rather than building a fresh `String` key on every evaluation.
    let mut vars = std::collections::HashMap::new();
    vars.insert(var.to_string(), a);
    let mut f = |x: f64| -> f64 {
        if let Some(slot) = vars.get_mut(var) {
            *slot = x;
        }
        expr.eval(&vars)
    };
    // Simpson weights: 1 at the endpoints, then alternating 4 (odd i) and
    // 2 (even i) across the interior sample points.
    sum += f(a) + f(b);
    for i in 1..n {
        let x = a + i as f64 * h;
        sum += if i % 2 == 0 { 2.0 * f(x) } else { 4.0 * f(x) };
    }
    sum * h / 3.0
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Evaluates `expr` with the variable `x` bound to `x_value`.
    fn eval_with_x(expr: &Expr, x_value: f64) -> f64 {
        let mut bindings = HashMap::new();
        bindings.insert("x".to_string(), x_value);
        expr.eval(&bindings)
    }

    /// Asserts that `actual` lies within the suite's 0.01 tolerance of `expected`.
    fn assert_near(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 0.01);
    }

    #[test]
    fn integrate_x() {
        // x²/2 at x = 3
        let antiderivative = integrate(&Expr::var("x"), "x").unwrap();
        assert_near(eval_with_x(&antiderivative, 3.0), 4.5);
    }

    #[test]
    fn integrate_x_squared() {
        // x³/3 at x = 3
        let square = Expr::var("x").pow(Expr::c(2.0));
        let antiderivative = integrate(&square, "x").unwrap();
        assert_near(eval_with_x(&antiderivative, 3.0), 9.0);
    }

    #[test]
    fn integrate_sin() {
        // -cos(x) at x = 0
        let antiderivative = integrate(&Expr::var("x").sin(), "x").unwrap();
        assert_near(eval_with_x(&antiderivative, 0.0), -1.0);
    }

    #[test]
    fn numerical_integral_x_squared() {
        let square = Expr::var("x").pow(Expr::c(2.0));
        let area = numerical_integrate(&square, "x", 0.0, 3.0, 100);
        assert_near(area, 9.0);
    }

    #[test]
    fn integrate_constant() {
        // 5·x at x = 3
        let antiderivative = integrate(&Expr::c(5.0), "x").unwrap();
        assert_near(eval_with_x(&antiderivative, 3.0), 15.0);
    }
}