use crate::diff::diff;
use crate::error::{Result, SymError};
use crate::expr::{Expr, constant, var};
use crate::simplify::simplify;
/// Computes the order-`n` Taylor polynomial of `expr` in `var_name` about `center`.
///
/// The result is `Σ_{k=0}^{n} f⁽ᵏ⁾(center) / k! · (var_name − center)ᵏ`. Each derivative
/// is obtained symbolically with [`diff`], substituted at `center`, simplified, and then
/// evaluated numerically. Terms whose derivative value has magnitude at most
/// `f64::EPSILON` are omitted. This suppresses floating-point noise, but it also drops
/// genuinely tiny coefficients. When `center == 0.0` the factor `(x − a)` is emitted as
/// plain `x`. The final sum is passed through [`simplify`]; if every term was omitted,
/// the constant `0` is returned.
///
/// # Errors
///
/// - [`SymError::InvalidExpr`] if `n > 20`.
/// - [`SymError::UnsupportedOperation`] if a derivative cannot be evaluated at `center`
///   with an empty variable binding. Presumably this happens when the expression has
///   free variables other than `var_name` — confirm against `Expr::eval`.
/// - [`SymError::UnsupportedOperation`] if a derivative evaluates to NaN or ±∞ at
///   `center`. In that case the Taylor series does not exist at that point.
pub fn taylor(expr: &Expr, var_name: &str, center: f64, n: usize) -> Result<Expr> {
    if n > 20 {
        return Err(SymError::InvalidExpr {
            reason: "Taylor expansion order must be <= 20",
        });
    }

    let a = constant(center);
    // The (x - a) factor; about the origin it is just x, keeping Maclaurin output compact.
    let x_minus_a = if center == 0.0 {
        var(var_name)
    } else {
        Expr::Add(
            Box::new(var(var_name)),
            Box::new(Expr::Neg(Box::new(a.clone()))),
        )
    };

    // Every variable has been substituted before evaluation, so no bindings are needed.
    let empty = std::collections::HashMap::new();
    let mut current = expr.clone();
    let mut factorial = 1.0_f64;
    let mut terms: Vec<Expr> = Vec::with_capacity(n + 1);

    for k in 0..=n {
        if k > 0 {
            factorial *= k as f64;
        }

        let at_center = simplify(&current.substitute(var_name, &a));
        let coeff_val = at_center
            .eval(&empty)
            .map_err(|_| SymError::UnsupportedOperation {
                reason: "could not evaluate derivative at expansion point",
            })?;

        // A NaN or infinite derivative means the function is not smooth at `center`.
        // NaN would otherwise slip past the magnitude test below and be silently dropped,
        // while ±∞ would produce meaningless infinite terms.
        if !coeff_val.is_finite() {
            return Err(SymError::UnsupportedOperation {
                reason: "derivative is not finite at expansion point",
            });
        }

        // Skip (near-)zero coefficients to keep the polynomial free of noise terms.
        if coeff_val.abs() > f64::EPSILON {
            let coeff = constant(coeff_val / factorial);
            let term = match k {
                0 => coeff,
                1 => Expr::Mul(Box::new(coeff), Box::new(x_minus_a.clone())),
                _ => Expr::Mul(
                    Box::new(coeff),
                    Box::new(Expr::Pow(
                        Box::new(x_minus_a.clone()),
                        Box::new(constant(k as f64)),
                    )),
                ),
            };
            terms.push(term);
        }

        // The derivative after the last term is never used, so don't compute it.
        if k < n {
            current = diff(&current, var_name);
        }
    }

    // Left-fold the surviving terms into a chain of `Add`s; none left means the zero polynomial.
    match terms
        .into_iter()
        .reduce(|acc, t| Expr::Add(Box::new(acc), Box::new(t)))
    {
        Some(sum) => Ok(simplify(&sum)),
        None => Ok(constant(0.0)),
    }
}
/// Computes the order-`n` Maclaurin polynomial of `expr` in `var_name`.
///
/// This is [`taylor`] with the expansion point fixed at the origin. It therefore shares
/// that function's order limit and error conditions.
pub fn maclaurin(expr: &Expr, var_name: &str, n: usize) -> Result<Expr> {
    const ORIGIN: f64 = 0.0;
    taylor(expr, var_name, ORIGIN, n)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::{cos, exp, sin, var};
    use std::collections::HashMap;

    /// Evaluates `expr` with `x` bound to `x_val`, panicking if evaluation fails.
    fn eval_at(expr: &Expr, x_val: f64) -> f64 {
        let mut env = HashMap::new();
        env.insert("x".into(), x_val);
        expr.eval(&env).unwrap()
    }

    #[test]
    fn maclaurin_exp_order_4() {
        let series = maclaurin(&exp(var("x")), "x", 4).unwrap();
        let (approx, exact) = (eval_at(&series, 0.5), 0.5_f64.exp());
        assert!(
            (approx - exact).abs() < 0.01,
            "approx={approx}, exact={exact}"
        );
    }

    #[test]
    fn maclaurin_sin_order_5() {
        let series = maclaurin(&sin(var("x")), "x", 5).unwrap();
        let (approx, exact) = (eval_at(&series, 0.3), 0.3_f64.sin());
        assert!(
            (approx - exact).abs() < 1e-6,
            "approx={approx}, exact={exact}"
        );
    }

    #[test]
    fn maclaurin_cos_order_4() {
        let series = maclaurin(&cos(var("x")), "x", 4).unwrap();
        let (approx, exact) = (eval_at(&series, 0.5), 0.5_f64.cos());
        assert!(
            (approx - exact).abs() < 1e-4,
            "approx={approx}, exact={exact}"
        );
    }

    #[test]
    fn taylor_around_nonzero() {
        // x^2 is a degree-2 polynomial, so its order-2 expansion about 1 reproduces it exactly.
        let square = Expr::Pow(Box::new(var("x")), Box::new(constant(2.0)));
        let series = taylor(&square, "x", 1.0, 2).unwrap();
        for (x, expected) in [(2.0, 4.0), (3.0, 9.0)] {
            assert!((eval_at(&series, x) - expected).abs() < 1e-10);
        }
    }

    #[test]
    fn taylor_constant() {
        let series = taylor(&constant(7.0), "x", 0.0, 3).unwrap();
        assert!((eval_at(&series, 100.0) - 7.0).abs() < 1e-10);
    }

    #[test]
    fn taylor_high_order_rejected() {
        // Orders above the cap of 20 must be refused.
        assert!(taylor(&var("x"), "x", 0.0, 25).is_err());
    }

    #[test]
    fn maclaurin_polynomial_exact() {
        let cube = Expr::Pow(Box::new(var("x")), Box::new(constant(3.0)));
        let series = maclaurin(&cube, "x", 3).unwrap();
        for (x, expected) in [(2.0, 8.0), (-1.0, -1.0)] {
            assert!((eval_at(&series, x) - expected).abs() < 1e-10);
        }
    }

    #[test]
    fn maclaurin_1_over_1_minus_x() {
        // 1/(1 - x) written as (1 + (-x))^-1. The order-4 geometric partial sum at 0.5
        // is 1.9375, within the 0.1 tolerance of the exact value 2.
        let one_minus_x = constant(1.0) + Expr::Neg(Box::new(var("x")));
        let recip = Expr::Pow(Box::new(one_minus_x), Box::new(constant(-1.0)));
        let approx = eval_at(&maclaurin(&recip, "x", 4).unwrap(), 0.5);
        assert!((approx - 2.0).abs() < 0.1, "approx={approx}");
    }
}