use std::collections::HashMap;
use crate::diff::diff;
use crate::error::{Result, SymError};
use crate::expr::{Expr, MathFn, constant, cos, exp, ln, sin, var};
use crate::simplify::simplify;
/// Returns an antiderivative of `expr` with respect to the variable `var_name`.
///
/// The input is simplified before the rule-based integration runs, and the raw result is
/// simplified again before it is returned. No constant of integration is added.
///
/// # Errors
///
/// Returns `SymError::UnsupportedOperation` when `expr` contains a form the integrator has no
/// rule for.
pub fn integrate(expr: &Expr, var_name: &str) -> Result<Expr> {
    let normalized = simplify(expr);
    integrate_inner(&normalized, var_name).map(|antiderivative| simplify(&antiderivative))
}
/// Evaluates the definite integral of `expr` over `var_name` from `a` to `b` as `F(b) - F(a)`.
///
/// `F` is the antiderivative produced by [`integrate`]. It is substituted at each bound,
/// simplified, and evaluated with no other variable bindings. Singularities lying between `a`
/// and `b` are not detected: the antiderivative is only evaluated at the two bounds.
///
/// # Errors
///
/// Propagates any error from [`integrate`]. Returns `SymError::UnsupportedOperation` if the
/// antiderivative cannot be evaluated numerically at either bound (for example because `expr`
/// mentions another free variable).
pub fn definite_integral(expr: &Expr, var_name: &str, a: f64, b: f64) -> Result<f64> {
    let antideriv = integrate(expr, var_name)?;
    let no_bindings = HashMap::new();
    // Substitute one bound, simplify, then evaluate; `reason` labels which bound failed.
    let value_at = |bound: f64, reason: &'static str| -> Result<f64> {
        simplify(&antideriv.substitute(var_name, &constant(bound)))
            .eval(&no_bindings)
            .map_err(|_| SymError::UnsupportedOperation { reason })
    };
    let lower = value_at(a, "could not evaluate antiderivative at lower bound")?;
    let upper = value_at(b, "could not evaluate antiderivative at upper bound")?;
    Ok(upper - lower)
}
/// Recursive worker behind [`integrate`]: integrates an already-simplified `expr` w.r.t. `v`.
///
/// Any sub-expression free of `v` is treated as a constant. Sums and negations are integrated
/// term-wise. Products, powers and function applications are delegated to their dedicated
/// helpers.
fn integrate_inner(expr: &Expr, v: &str) -> Result<Expr> {
    // Constant with respect to `v`: ∫ c dv = c·v.
    if !contains_var(expr, v) {
        return Ok(Expr::Mul(Box::new(expr.clone()), Box::new(var(v))));
    }
    match expr {
        // ∫ v dv = ½·v²
        Expr::Var(name) if name == v => {
            let v_squared = Expr::Pow(Box::new(var(v)), Box::new(constant(2.0)));
            Ok(Expr::Mul(Box::new(constant(0.5)), Box::new(v_squared)))
        }
        // Linearity: ∫ (f + g) = ∫ f + ∫ g
        Expr::Add(lhs, rhs) => {
            let left = integrate_inner(lhs, v)?;
            let right = integrate_inner(rhs, v)?;
            Ok(Expr::Add(Box::new(left), Box::new(right)))
        }
        Expr::Neg(inner) => integrate_inner(inner, v).map(|integral| Expr::Neg(Box::new(integral))),
        Expr::Mul(lhs, rhs) => integrate_mul(lhs, rhs, v),
        Expr::Pow(base, exponent) => integrate_pow(base, exponent, v),
        Expr::Fn(func, arg) => integrate_fn(*func, arg, v),
        _ => Err(SymError::UnsupportedOperation {
            reason: "cannot integrate this expression",
        }),
    }
}
/// Integrates the product `a · b` with respect to `v`.
///
/// A factor that does not depend on `v` is pulled outside the integral. When both factors depend
/// on `v`, integration by parts is attempted first with `a` as `u`, then with `b` as `u`.
///
/// # Errors
///
/// Propagates errors from integrating the non-constant factor. Returns
/// `SymError::UnsupportedOperation` when both factors depend on `v` and neither by-parts
/// ordering succeeds.
fn integrate_mul(a: &Expr, b: &Expr, v: &str) -> Result<Expr> {
    match (contains_var(a, v), contains_var(b, v)) {
        // `a` is constant in `v`: ∫ a·b dv = a·∫ b dv.
        (false, _) => {
            let integral_b = integrate_inner(b, v)?;
            Ok(Expr::Mul(Box::new(a.clone()), Box::new(integral_b)))
        }
        // `b` is constant in `v`: ∫ a·b dv = b·∫ a dv.
        (true, false) => {
            let integral_a = integrate_inner(a, v)?;
            Ok(Expr::Mul(Box::new(b.clone()), Box::new(integral_a)))
        }
        (true, true) => try_by_parts(a, b, v)
            .or_else(|| try_by_parts(b, a, v))
            .ok_or(SymError::UnsupportedOperation {
                reason: "cannot integrate this product",
            }),
    }
}
#[allow(clippy::collapsible_if)]
fn integrate_pow(base: &Expr, exp: &Expr, v: &str) -> Result<Expr> {
let base_has = contains_var(base, v);
let exp_has = contains_var(exp, v);
if base_has && !exp_has {
if let Expr::Var(name) = base {
if name == v {
if let Some(n) = exp.as_const() {
if (n - (-1.0)).abs() < f64::EPSILON {
return Ok(ln(var(v)));
}
let n1 = n + 1.0;
return Ok(Expr::Mul(
Box::new(constant(1.0 / n1)),
Box::new(Expr::Pow(Box::new(var(v)), Box::new(constant(n1)))),
));
}
}
}
if let Some((a_coeff, _b_coeff)) = as_linear(base, v) {
if let Some(n) = exp.as_const() {
if (n - (-1.0)).abs() < f64::EPSILON {
return Ok(Expr::Mul(
Box::new(constant(1.0 / a_coeff)),
Box::new(ln(base.clone())),
));
}
let n1 = n + 1.0;
return Ok(Expr::Mul(
Box::new(constant(1.0 / (a_coeff * n1))),
Box::new(Expr::Pow(Box::new(base.clone()), Box::new(constant(n1)))),
));
}
}
}
if !base_has && exp_has {
if let Expr::Var(name) = exp {
if name == v {
if let Some(c) = base.as_const() {
if c > 0.0 && (c - 1.0).abs() > f64::EPSILON {
return Ok(Expr::Mul(
Box::new(constant(1.0 / c.ln())),
Box::new(Expr::Pow(Box::new(base.clone()), Box::new(var(v)))),
));
}
}
}
}
}
Err(SymError::UnsupportedOperation {
reason: "cannot integrate this power expression",
})
}
#[allow(clippy::collapsible_if)]
fn integrate_fn(func: MathFn, arg: &Expr, v: &str) -> Result<Expr> {
if let Expr::Var(name) = arg {
if name == v {
return match func {
MathFn::Sin => Ok(Expr::Neg(Box::new(cos(var(v))))),
MathFn::Cos => Ok(sin(var(v))),
MathFn::Exp => Ok(exp(var(v))),
MathFn::Ln => Ok(Expr::Add(
Box::new(Expr::Mul(Box::new(var(v)), Box::new(ln(var(v))))),
Box::new(Expr::Neg(Box::new(var(v)))),
)),
_ => Err(SymError::UnsupportedOperation {
reason: "cannot integrate this function",
}),
};
}
}
if let Some((a_coeff, _b_coeff)) = as_linear(arg, v) {
let inv_a = constant(1.0 / a_coeff);
let antideriv = match func {
MathFn::Sin => Ok(Expr::Neg(Box::new(cos(arg.clone())))),
MathFn::Cos => Ok(sin(arg.clone())),
MathFn::Exp => Ok(exp(arg.clone())),
_ => Err(SymError::UnsupportedOperation {
reason: "cannot integrate this function with linear argument",
}),
}?;
return Ok(Expr::Mul(Box::new(inv_a), Box::new(antideriv)));
}
Err(SymError::UnsupportedOperation {
reason: "cannot integrate this function expression",
})
}
/// Attempts one step of integration by parts: `∫ u dv = u·V − ∫ V·du`, where `V = ∫ dv`.
///
/// `u_candidate` must be polynomial-like in `v` (see [`is_polynomial_like`]) so that
/// differentiating it lowers its degree. Returns `None` in any of these cases:
/// - `u_candidate` is not polynomial-like;
/// - `dv_candidate` cannot be integrated;
/// - the residual `V·du` cannot be integrated;
/// - by-parts steps are already nested more than `MAX_DEPTH` deep on the current thread.
fn try_by_parts(u_candidate: &Expr, dv_candidate: &Expr, v: &str) -> Option<Expr> {
    // Some pairings never terminate. For `x·ln(x)` with `u = x`, `V = x·ln(x) − x`, so the
    // residual integral contains `x·ln(x)` again and the recursion would overflow the stack,
    // aborting the process. Bounding the nesting turns that into `None`, which callers then
    // report as an error. Legitimate chains such as `x^n·exp(x)` nest roughly once per unit of
    // polynomial degree.
    const MAX_DEPTH: usize = 16;

    thread_local! {
        static DEPTH: std::cell::Cell<usize> = const { std::cell::Cell::new(0) };
    }

    // Restores `DEPTH` on every exit path, including the early `?` returns below.
    struct DepthGuard;
    impl Drop for DepthGuard {
        fn drop(&mut self) {
            DEPTH.with(|d| d.set(d.get() - 1));
        }
    }

    if !is_polynomial_like(u_candidate, v) {
        return None;
    }
    let depth = DEPTH.with(|d| {
        let next = d.get() + 1;
        d.set(next);
        next
    });
    let _guard = DepthGuard;
    if depth > MAX_DEPTH {
        return None;
    }

    let big_v = integrate_inner(dv_candidate, v).ok()?;
    let big_v = simplify(&big_v);
    let du = diff(u_candidate, v);
    let v_du = Expr::Mul(Box::new(big_v.clone()), Box::new(du));
    let v_du = simplify(&v_du);
    let int_v_du = integrate_inner(&v_du, v).ok()?;
    Some(Expr::Add(
        Box::new(Expr::Mul(Box::new(u_candidate.clone()), Box::new(big_v))),
        Box::new(Expr::Neg(Box::new(int_v_du))),
    ))
}
/// Reports whether `expr` is built only from constants, variables, sums, products, negations
/// and `v^n` with a non-negative integral numeric `n`.
///
/// Used to pick the `u` factor for integration by parts: differentiating such an expression
/// w.r.t. `v` lowers its degree. Any function application, and any power whose base is not
/// exactly `v` (or whose exponent is not such a numeric constant), disqualifies it.
fn is_polynomial_like(expr: &Expr, v: &str) -> bool {
    match expr {
        Expr::Const(_) | Expr::Var(_) => true,
        Expr::Neg(inner) => is_polynomial_like(inner, v),
        Expr::Add(lhs, rhs) | Expr::Mul(lhs, rhs) => {
            is_polynomial_like(lhs, v) && is_polynomial_like(rhs, v)
        }
        Expr::Pow(base, exponent) => match (base.as_ref(), exponent.as_const()) {
            (Expr::Var(name), Some(n)) if name == v => {
                n >= 0.0 && (n - n.floor()).abs() < f64::EPSILON
            }
            _ => false,
        },
        Expr::Fn(_, _) => false,
    }
}
/// Returns `true` if the variable named `v` occurs anywhere inside `expr`.
///
/// Walks the expression tree with an explicit work-list instead of recursion and stops at the
/// first occurrence.
fn contains_var(expr: &Expr, v: &str) -> bool {
    let mut pending: Vec<&Expr> = vec![expr];
    while let Some(node) = pending.pop() {
        match node {
            Expr::Var(name) if name == v => return true,
            Expr::Const(_) | Expr::Var(_) => {}
            Expr::Neg(inner) | Expr::Fn(_, inner) => pending.push(inner.as_ref()),
            Expr::Add(lhs, rhs) | Expr::Mul(lhs, rhs) | Expr::Pow(lhs, rhs) => {
                pending.push(rhs.as_ref());
                pending.push(lhs.as_ref());
            }
        }
    }
    false
}
/// Recognises `expr` as an affine function of `v`, i.e. `expr == a·v + b` with numeric `a`, `b`.
///
/// Returns `Some((a, b))`, where `a` is the simplified derivative (a numeric constant) and `b`
/// is the value of `expr` at `v = 0`. Returns `None` in any of these cases:
/// - the derivative is not a numeric constant, or the second derivative is not zero;
/// - the slope `a` is zero or non-finite;
/// - `b` cannot be evaluated numerically, e.g. because `expr` mentions another free variable.
fn as_linear(expr: &Expr, v: &str) -> Option<(f64, f64)> {
    let slope_expr = simplify(&diff(expr, v));
    let a = slope_expr.as_const()?;
    // Every caller divides by the slope. A zero slope (e.g. an uncancelled `x - x` inside an
    // argument) or a non-finite one would otherwise yield an infinite/NaN coefficient.
    if a == 0.0 || !a.is_finite() {
        return None;
    }
    // Cheap structural check before the numeric evaluation of the offset.
    let curvature = simplify(&diff(&slope_expr, v));
    if !curvature.is_zero() {
        return None;
    }
    let at_zero = simplify(&expr.substitute(v, &constant(0.0)));
    let b = at_zero.eval(&HashMap::new()).ok()?;
    Some((a, b))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    const TOL: f64 = 1e-10;

    /// Evaluates `e` with `x` bound to `x_val`, panicking if evaluation fails.
    fn eval_at(e: &Expr, x_val: f64) -> f64 {
        let mut env = HashMap::new();
        env.insert("x".into(), x_val);
        e.eval(&env).unwrap()
    }

    /// Evaluates the simplified derivative of `e` with respect to `x` at `x_val`.
    fn derivative_at(e: &Expr, x_val: f64) -> f64 {
        eval_at(&simplify(&diff(e, "x")), x_val)
    }

    /// Builds `x^n`.
    fn x_pow(n: f64) -> Expr {
        Expr::Pow(Box::new(var("x")), Box::new(constant(n)))
    }

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!((actual - expected).abs() < tol);
    }

    #[test]
    fn integrate_constant() {
        let antideriv = integrate(&constant(5.0), "x").unwrap();
        assert_close(eval_at(&antideriv, 3.0), 15.0, TOL);
    }

    #[test]
    fn integrate_x() {
        let antideriv = integrate(&var("x"), "x").unwrap();
        assert_close(eval_at(&antideriv, 4.0), 8.0, TOL);
    }

    #[test]
    fn integrate_x_squared() {
        let antideriv = integrate(&x_pow(2.0), "x").unwrap();
        assert_close(eval_at(&antideriv, 3.0), 9.0, TOL);
    }

    #[test]
    fn integrate_x_inverse() {
        let antideriv = integrate(&x_pow(-1.0), "x").unwrap();
        assert_close(eval_at(&antideriv, std::f64::consts::E), 1.0, TOL);
    }

    #[test]
    fn integrate_sin() {
        let antideriv = integrate(&sin(var("x")), "x").unwrap();
        assert_close(eval_at(&antideriv, 0.0), -1.0, TOL);
    }

    #[test]
    fn integrate_cos() {
        let antideriv = integrate(&cos(var("x")), "x").unwrap();
        assert_close(eval_at(&antideriv, std::f64::consts::FRAC_PI_2), 1.0, TOL);
    }

    #[test]
    fn integrate_exp() {
        let antideriv = integrate(&exp(var("x")), "x").unwrap();
        assert_close(eval_at(&antideriv, 1.0), std::f64::consts::E, TOL);
    }

    #[test]
    fn integrate_ln() {
        let antideriv = integrate(&ln(var("x")), "x").unwrap();
        assert_close(eval_at(&antideriv, std::f64::consts::E), 0.0, TOL);
    }

    #[test]
    fn integrate_constant_times_x() {
        let integrand = constant(3.0) * var("x");
        let antideriv = integrate(&integrand, "x").unwrap();
        assert_close(eval_at(&antideriv, 2.0), 6.0, TOL);
    }

    #[test]
    fn integrate_sum() {
        let integrand = var("x") + constant(1.0);
        let antideriv = integrate(&integrand, "x").unwrap();
        assert_close(eval_at(&antideriv, 2.0), 4.0, TOL);
    }

    #[test]
    fn integrate_polynomial() {
        let integrand = x_pow(2.0) + constant(2.0) * var("x") + constant(1.0);
        let antideriv = integrate(&integrand, "x").unwrap();
        assert_close(eval_at(&antideriv, 3.0), 21.0, TOL);
    }

    #[test]
    fn integrate_linear_sub_sin() {
        let integrand = sin(constant(2.0) * var("x"));
        let antideriv = integrate(&integrand, "x").unwrap();
        assert_close(derivative_at(&antideriv, 1.0), (2.0_f64).sin(), TOL);
    }

    #[test]
    fn integrate_linear_sub_exp() {
        let integrand = exp(constant(3.0) * var("x"));
        let antideriv = integrate(&integrand, "x").unwrap();
        assert_close(derivative_at(&antideriv, 1.0), (3.0_f64).exp(), 1e-8);
    }

    #[test]
    fn definite_integral_x_squared() {
        let area = definite_integral(&x_pow(2.0), "x", 0.0, 2.0).unwrap();
        assert_close(area, 8.0 / 3.0, TOL);
    }

    #[test]
    fn definite_integral_sin() {
        let area = definite_integral(&sin(var("x")), "x", 0.0, std::f64::consts::PI).unwrap();
        assert_close(area, 2.0, TOL);
    }

    #[test]
    fn integrate_by_parts_x_exp() {
        let integrand = var("x") * exp(var("x"));
        let antideriv = integrate(&integrand, "x").unwrap();
        assert_close(derivative_at(&antideriv, 1.0), std::f64::consts::E, 1e-8);
    }

    #[test]
    fn integrate_unsupported_returns_error() {
        let integrand = Expr::Fn(MathFn::Tan, Box::new(x_pow(2.0)));
        assert!(integrate(&integrand, "x").is_err());
    }

    #[test]
    fn integrate_wrt_other_var() {
        let antideriv = integrate(&var("x"), "y").unwrap();
        let mut env = HashMap::new();
        env.insert("x".into(), 3.0);
        env.insert("y".into(), 4.0);
        assert_close(antideriv.eval(&env).unwrap(), 12.0, TOL);
    }
}