proof_engine/symbolic/
taylor.rs1use super::expr::Expr;
4use super::differentiate::diff;
5use super::simplify::simplify;
6
7pub fn taylor_expand(expr: &Expr, var: &str, center: f64, order: u32) -> Expr {
10 let mut result = Expr::zero();
11 let mut current = expr.clone();
12 let mut factorial = 1u64;
13
14 for k in 0..=order {
15 if k > 0 { factorial *= k as u64; }
16
17 let value_expr = current.substitute(var, &Expr::c(center));
19 let value_simplified = simplify(&value_expr);
20
21 let coeff = if let Expr::Const(v) = value_simplified {
23 v / factorial as f64
24 } else {
25 let mut vars = std::collections::HashMap::new();
27 vars.insert(var.to_string(), center);
28 value_simplified.eval(&vars) / factorial as f64
29 };
30
31 if coeff.abs() > 1e-15 {
32 let term = if k == 0 {
34 Expr::c(coeff)
35 } else {
36 let x_minus_a = if center.abs() < 1e-15 {
37 Expr::var(var)
38 } else {
39 Expr::var(var).sub(Expr::c(center))
40 };
41 let power = if k == 1 {
42 x_minus_a
43 } else {
44 x_minus_a.pow(Expr::c(k as f64))
45 };
46 Expr::c(coeff).mul(power)
47 };
48 result = result.add(term);
49 }
50
51 current = diff(¤t, var);
53 }
54
55 simplify(&result)
56}
57
58pub fn taylor_eval(
60 f: &dyn Fn(f64) -> f64,
61 center: f64,
62 x: f64,
63 order: u32,
64 h: f64,
65) -> f64 {
66 let mut result = 0.0;
67 let mut factorial = 1.0;
68 let dx = x - center;
69 let mut dx_power = 1.0;
70
71 for k in 0..=order {
72 if k > 0 {
73 factorial *= k as f64;
74 dx_power *= dx;
75 }
76
77 let deriv = numerical_derivative(f, center, k, h);
79 result += deriv / factorial * dx_power;
80 }
81
82 result
83}
84
/// Estimates the `order`-th derivative of `f` at `x` with nested central differences.
///
/// * `order == 0` returns `f(x)`.
/// * `order == 1` is the central difference `(f(x + h) - f(x - h)) / (2h)`.
/// * `order >= 2` applies a central difference with step `s = 1.5 * h` to the
///   order-1 estimate `order - 1` times.
///
/// The repeated outer difference is expanded in closed form:
/// `sum_{j=0}^{m} (-1)^j C(m, j) g(x + (m - 2j) s) / (2s)^m`, with `m = order - 1` and
/// `g` the order-1 estimate. This needs only `2 * order` evaluations of `f`, instead of
/// the `2^order` that the equivalent naive recursion performs. The two forms are equal
/// in exact arithmetic and agree up to floating-point rounding.
///
/// For a fixed `h`, accuracy degrades quickly as `order` grows, because rounding error
/// is amplified roughly like `1 / h^order`. `h` should be non-zero; a zero step divides
/// by zero.
fn numerical_derivative(f: &dyn Fn(f64) -> f64, x: f64, order: u32, h: f64) -> f64 {
    if order == 0 {
        return f(x);
    }
    // Order-1 central difference with the base step `h`.
    let first = |y: f64| (f(y + h) - f(y - h)) / (2.0 * h);
    if order == 1 {
        return first(x);
    }

    let m = order - 1;
    let s = 1.5 * h;
    let mut binom = 1.0_f64; // C(m, j), updated incrementally
    let mut sum = 0.0_f64;
    for j in 0..=m {
        let sign = if j % 2 == 0 { 1.0 } else { -1.0 };
        let offset = f64::from(m) - 2.0 * f64::from(j);
        sum += sign * binom * first(x + offset * s);
        binom = binom * f64::from(m - j) / f64::from(j + 1);
    }
    // (2s)^m by repeated multiplication (exact for m == 1).
    let denom = (0..m).fold(1.0_f64, |d, _| d * (2.0 * s));
    sum / denom
}
92
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Evaluates `expr` with the single variable `x` bound to `value`.
    fn eval_at_x(expr: &Expr, value: f64) -> f64 {
        let mut bindings = HashMap::new();
        bindings.insert(String::from("x"), value);
        expr.eval(&bindings)
    }

    #[test]
    fn taylor_exp_at_zero() {
        let series = taylor_expand(&Expr::var("x").exp(), "x", 0.0, 4);
        let approx = eval_at_x(&series, 0.5);
        let exact = 0.5_f64.exp();
        assert!((approx - exact).abs() < 0.01, "approx={approx}, exact={exact}");
    }

    #[test]
    fn taylor_sin_at_zero() {
        let series = taylor_expand(&Expr::var("x").sin(), "x", 0.0, 5);
        let approx = eval_at_x(&series, 0.5);
        let exact = 0.5_f64.sin();
        assert!((approx - exact).abs() < 0.001, "approx={approx}, exact={exact}");
    }

    #[test]
    fn taylor_eval_numerical() {
        let exact = 0.5_f64.exp();
        let estimate = taylor_eval(&|t| t.exp(), 0.0, 0.5, 6, 0.001);
        assert!((estimate - exact).abs() < 0.001);
    }
}