Skip to main content

proof_engine/symbolic/
integrate.rs

1//! Symbolic integration — common patterns: power rule, trig, exponential.
2
3use super::expr::Expr;
4
5/// Attempt symbolic integration. Returns None if the integral can't be computed.
6pub fn integrate(expr: &Expr, var: &str) -> Option<Expr> {
7    if !expr.contains_var(var) {
8        // ∫c dx = c*x
9        return Some(expr.clone().mul(Expr::var(var)));
10    }
11
12    match expr {
13        // ∫x dx = x²/2
14        Expr::Var(name) if name == var => {
15            Some(Expr::Div(
16                Box::new(Expr::Pow(Box::new(Expr::var(var)), Box::new(Expr::c(2.0)))),
17                Box::new(Expr::c(2.0)),
18            ))
19        }
20
21        // ∫x^n dx = x^(n+1)/(n+1) for constant n ≠ -1
22        Expr::Pow(base, exp) if matches!(**base, Expr::Var(ref n) if n == var) && !exp.contains_var(var) => {
23            if let Expr::Const(n) = **exp {
24                if (n + 1.0).abs() < 1e-10 {
25                    // ∫x^(-1) dx = ln|x|
26                    Some(Expr::Ln(Box::new(Expr::Abs(Box::new(Expr::var(var))))))
27                } else {
28                    let n1 = n + 1.0;
29                    Some(Expr::Div(
30                        Box::new(Expr::Pow(Box::new(Expr::var(var)), Box::new(Expr::c(n1)))),
31                        Box::new(Expr::c(n1)),
32                    ))
33                }
34            } else { None }
35        }
36
37        // ∫sin(x) dx = -cos(x)
38        Expr::Sin(a) if matches!(**a, Expr::Var(ref n) if n == var) => {
39            Some(Expr::Neg(Box::new(Expr::Cos(a.clone()))))
40        }
41
42        // ∫cos(x) dx = sin(x)
43        Expr::Cos(a) if matches!(**a, Expr::Var(ref n) if n == var) => {
44            Some(Expr::Sin(a.clone()))
45        }
46
47        // ∫e^x dx = e^x
48        Expr::Exp(a) if matches!(**a, Expr::Var(ref n) if n == var) => {
49            Some(Expr::Exp(a.clone()))
50        }
51
52        // ∫(a + b) dx = ∫a dx + ∫b dx
53        Expr::Add(a, b) => {
54            let ia = integrate(a, var)?;
55            let ib = integrate(b, var)?;
56            Some(Expr::Add(Box::new(ia), Box::new(ib)))
57        }
58
59        // ∫(a - b) dx = ∫a dx - ∫b dx
60        Expr::Sub(a, b) => {
61            let ia = integrate(a, var)?;
62            let ib = integrate(b, var)?;
63            Some(Expr::Sub(Box::new(ia), Box::new(ib)))
64        }
65
66        // ∫c*f dx = c * ∫f dx (constant factor)
67        Expr::Mul(a, b) if !a.contains_var(var) => {
68            let ib = integrate(b, var)?;
69            Some(Expr::Mul(a.clone(), Box::new(ib)))
70        }
71        Expr::Mul(a, b) if !b.contains_var(var) => {
72            let ia = integrate(a, var)?;
73            Some(Expr::Mul(Box::new(ia), b.clone()))
74        }
75
76        // ∫-f dx = -∫f dx
77        Expr::Neg(a) => {
78            let ia = integrate(a, var)?;
79            Some(Expr::Neg(Box::new(ia)))
80        }
81
82        // ∫1/x dx = ln|x|
83        Expr::Div(a, b) if matches!(**a, Expr::Const(v) if (v - 1.0).abs() < 1e-10) &&
84            matches!(**b, Expr::Var(ref n) if n == var) => {
85            Some(Expr::Ln(Box::new(Expr::Abs(Box::new(Expr::var(var))))))
86        }
87
88        _ => None,
89    }
90}
91
92/// Numerical definite integration using Simpson's rule.
93pub fn numerical_integrate(
94    expr: &Expr, var: &str, a: f64, b: f64, n: usize,
95) -> f64 {
96    let n = if n % 2 == 0 { n } else { n + 1 };
97    let h = (b - a) / n as f64;
98    let mut sum = 0.0;
99    let mut vars = std::collections::HashMap::new();
100
101    let mut f = |x: f64| -> f64 {
102        vars.insert(var.to_string(), x);
103        expr.eval(&vars)
104    };
105
106    sum += f(a) + f(b);
107    for i in 1..n {
108        let x = a + i as f64 * h;
109        sum += if i % 2 == 0 { 2.0 * f(x) } else { 4.0 * f(x) };
110    }
111
112    sum * h / 3.0
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Evaluates `expr` in an environment where only `x` is bound, to `x_value`.
    fn eval_at_x(expr: &Expr, x_value: f64) -> f64 {
        let mut env = HashMap::new();
        env.insert("x".to_string(), x_value);
        expr.eval(&env)
    }

    /// Asserts that `actual` lies within 0.01 of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 0.01);
    }

    #[test]
    fn integrate_x() {
        // ∫x dx = x²/2, which is 4.5 at x = 3.
        let antiderivative = integrate(&Expr::var("x"), "x").unwrap();
        assert_close(eval_at_x(&antiderivative, 3.0), 4.5);
    }

    #[test]
    fn integrate_x_squared() {
        // ∫x² dx = x³/3, which is 9 at x = 3.
        let integrand = Expr::var("x").pow(Expr::c(2.0));
        let antiderivative = integrate(&integrand, "x").unwrap();
        assert_close(eval_at_x(&antiderivative, 3.0), 9.0);
    }

    #[test]
    fn integrate_sin() {
        // ∫sin(x) dx = -cos(x), which is -1 at x = 0.
        let antiderivative = integrate(&Expr::var("x").sin(), "x").unwrap();
        assert_close(eval_at_x(&antiderivative, 0.0), -1.0);
    }

    #[test]
    fn numerical_integral_x_squared() {
        // ∫₀³ x² dx = 9
        let integrand = Expr::var("x").pow(Expr::c(2.0));
        let approx = numerical_integrate(&integrand, "x", 0.0, 3.0, 100);
        assert_close(approx, 9.0);
    }

    #[test]
    fn integrate_constant() {
        // ∫5 dx = 5x, which is 15 at x = 3.
        let antiderivative = integrate(&Expr::c(5.0), "x").unwrap();
        assert_close(eval_at_x(&antiderivative, 3.0), 15.0);
    }
}
160}