// proof_engine/symbolic/integrate.rs
use super::expr::Expr;
4
5pub fn integrate(expr: &Expr, var: &str) -> Option<Expr> {
7 if !expr.contains_var(var) {
8 return Some(expr.clone().mul(Expr::var(var)));
10 }
11
12 match expr {
13 Expr::Var(name) if name == var => {
15 Some(Expr::Div(
16 Box::new(Expr::Pow(Box::new(Expr::var(var)), Box::new(Expr::c(2.0)))),
17 Box::new(Expr::c(2.0)),
18 ))
19 }
20
21 Expr::Pow(base, exp) if matches!(**base, Expr::Var(ref n) if n == var) && !exp.contains_var(var) => {
23 if let Expr::Const(n) = **exp {
24 if (n + 1.0).abs() < 1e-10 {
25 Some(Expr::Ln(Box::new(Expr::Abs(Box::new(Expr::var(var))))))
27 } else {
28 let n1 = n + 1.0;
29 Some(Expr::Div(
30 Box::new(Expr::Pow(Box::new(Expr::var(var)), Box::new(Expr::c(n1)))),
31 Box::new(Expr::c(n1)),
32 ))
33 }
34 } else { None }
35 }
36
37 Expr::Sin(a) if matches!(**a, Expr::Var(ref n) if n == var) => {
39 Some(Expr::Neg(Box::new(Expr::Cos(a.clone()))))
40 }
41
42 Expr::Cos(a) if matches!(**a, Expr::Var(ref n) if n == var) => {
44 Some(Expr::Sin(a.clone()))
45 }
46
47 Expr::Exp(a) if matches!(**a, Expr::Var(ref n) if n == var) => {
49 Some(Expr::Exp(a.clone()))
50 }
51
52 Expr::Add(a, b) => {
54 let ia = integrate(a, var)?;
55 let ib = integrate(b, var)?;
56 Some(Expr::Add(Box::new(ia), Box::new(ib)))
57 }
58
59 Expr::Sub(a, b) => {
61 let ia = integrate(a, var)?;
62 let ib = integrate(b, var)?;
63 Some(Expr::Sub(Box::new(ia), Box::new(ib)))
64 }
65
66 Expr::Mul(a, b) if !a.contains_var(var) => {
68 let ib = integrate(b, var)?;
69 Some(Expr::Mul(a.clone(), Box::new(ib)))
70 }
71 Expr::Mul(a, b) if !b.contains_var(var) => {
72 let ia = integrate(a, var)?;
73 Some(Expr::Mul(Box::new(ia), b.clone()))
74 }
75
76 Expr::Neg(a) => {
78 let ia = integrate(a, var)?;
79 Some(Expr::Neg(Box::new(ia)))
80 }
81
82 Expr::Div(a, b) if matches!(**a, Expr::Const(v) if (v - 1.0).abs() < 1e-10) &&
84 matches!(**b, Expr::Var(ref n) if n == var) => {
85 Some(Expr::Ln(Box::new(Expr::Abs(Box::new(Expr::var(var))))))
86 }
87
88 _ => None,
89 }
90}
91
92pub fn numerical_integrate(
94 expr: &Expr, var: &str, a: f64, b: f64, n: usize,
95) -> f64 {
96 let n = if n % 2 == 0 { n } else { n + 1 };
97 let h = (b - a) / n as f64;
98 let mut sum = 0.0;
99 let mut vars = std::collections::HashMap::new();
100
101 let mut f = |x: f64| -> f64 {
102 vars.insert(var.to_string(), x);
103 expr.eval(&vars)
104 };
105
106 sum += f(a) + f(b);
107 for i in 1..n {
108 let x = a + i as f64 * h;
109 sum += if i % 2 == 0 { 2.0 * f(x) } else { 4.0 * f(x) };
110 }
111
112 sum * h / 3.0
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Absolute tolerance used by every check in this module.
    const TOLERANCE: f64 = 0.01;

    /// Evaluates `expr` with the variable `x` bound to `value`.
    fn eval_at_x(expr: &Expr, value: f64) -> f64 {
        let bindings: HashMap<String, f64> =
            std::iter::once((String::from("x"), value)).collect();
        expr.eval(&bindings)
    }

    /// Asserts that `actual` lies within `TOLERANCE` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    // ∫ x dx = x^2 / 2, which is 4.5 at x = 3.
    #[test]
    fn integrate_x() {
        let antiderivative = integrate(&Expr::var("x"), "x").unwrap();
        assert_close(eval_at_x(&antiderivative, 3.0), 4.5);
    }

    // ∫ x^2 dx = x^3 / 3, which is 9 at x = 3.
    #[test]
    fn integrate_x_squared() {
        let integrand = Expr::var("x").pow(Expr::c(2.0));
        let antiderivative = integrate(&integrand, "x").unwrap();
        assert_close(eval_at_x(&antiderivative, 3.0), 9.0);
    }

    // ∫ sin(x) dx = -cos(x), which is -1 at x = 0.
    #[test]
    fn integrate_sin() {
        let antiderivative = integrate(&Expr::var("x").sin(), "x").unwrap();
        assert_close(eval_at_x(&antiderivative, 0.0), -1.0);
    }

    // Simpson's rule over [0, 3] for x^2 should give 9.
    #[test]
    fn numerical_integral_x_squared() {
        let integrand = Expr::var("x").pow(Expr::c(2.0));
        let area = numerical_integrate(&integrand, "x", 0.0, 3.0, 100);
        assert_close(area, 9.0);
    }

    // ∫ 5 dx = 5x, which is 15 at x = 3.
    #[test]
    fn integrate_constant() {
        let antiderivative = integrate(&Expr::c(5.0), "x").unwrap();
        assert_close(eval_at_x(&antiderivative, 3.0), 15.0);
    }
}