Skip to main content

proof_engine/symbolic/
solve.rs

1//! Equation solving — linear, quadratic, systems of equations.
2
3use super::expr::Expr;
4
/// Real solutions of an equation, as returned by the solvers in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum Solutions {
    /// No real solution.
    None,
    /// Exactly one distinct real solution (a repeated root is reported once).
    Single(f64),
    /// Two distinct real solutions.
    Two(f64, f64),
    /// An explicit list of real solutions (e.g. the three real roots of a cubic).
    Many(Vec<f64>),
    /// Every real number is a solution (the equation reduces to `0 = 0`).
    Infinite,
}
14
15/// Solve a linear equation ax + b = 0.
16pub fn solve_linear(a: f64, b: f64) -> Solutions {
17    if a.abs() < 1e-15 {
18        if b.abs() < 1e-15 { Solutions::Infinite } else { Solutions::None }
19    } else {
20        Solutions::Single(-b / a)
21    }
22}
23
24/// Solve a quadratic equation ax² + bx + c = 0.
25pub fn solve_quadratic(a: f64, b: f64, c: f64) -> Solutions {
26    if a.abs() < 1e-15 { return solve_linear(b, c); }
27    let disc = b * b - 4.0 * a * c;
28    if disc < -1e-15 {
29        Solutions::None
30    } else if disc.abs() < 1e-15 {
31        Solutions::Single(-b / (2.0 * a))
32    } else {
33        let sqrt_disc = disc.sqrt();
34        Solutions::Two(
35            (-b + sqrt_disc) / (2.0 * a),
36            (-b - sqrt_disc) / (2.0 * a),
37        )
38    }
39}
40
41/// Solve a cubic equation ax³ + bx² + cx + d = 0 using Cardano's method.
42pub fn solve_cubic(a: f64, b: f64, c: f64, d: f64) -> Solutions {
43    if a.abs() < 1e-15 { return solve_quadratic(b, c, d); }
44
45    // Normalize: x³ + px + q = 0 via substitution x = t - b/(3a)
46    let p = (3.0 * a * c - b * b) / (3.0 * a * a);
47    let q = (2.0 * b * b * b - 9.0 * a * b * c + 27.0 * a * a * d) / (27.0 * a * a * a);
48    let disc = q * q / 4.0 + p * p * p / 27.0;
49    let offset = -b / (3.0 * a);
50
51    if disc > 1e-15 {
52        let u = (-q / 2.0 + disc.sqrt()).cbrt();
53        let v = (-q / 2.0 - disc.sqrt()).cbrt();
54        Solutions::Single(u + v + offset)
55    } else if disc.abs() < 1e-15 {
56        if q.abs() < 1e-15 {
57            Solutions::Single(offset)
58        } else {
59            let u = (-q / 2.0).cbrt();
60            Solutions::Two(2.0 * u + offset, -u + offset)
61        }
62    } else {
63        // Three real roots (casus irreducibilis)
64        let r = (-p * p * p / 27.0).sqrt();
65        let phi = (-q / (2.0 * r)).acos();
66        let cube_r = r.cbrt();
67        Solutions::Many(vec![
68            2.0 * cube_r * (phi / 3.0).cos() + offset,
69            2.0 * cube_r * ((phi + std::f64::consts::TAU) / 3.0).cos() + offset,
70            2.0 * cube_r * ((phi + 2.0 * std::f64::consts::TAU) / 3.0).cos() + offset,
71        ])
72    }
73}
74
/// Solve the 2×2 linear system below using Cramer's rule.
///
/// ```text
/// a1*x + b1*y = c1
/// a2*x + b2*y = c2
/// ```
///
/// Returns `None` when the coefficient determinant `a1*b2 - a2*b1` has magnitude
/// below `1e-15`, i.e. the system has no unique solution at that tolerance.
/// Otherwise returns `Some((x, y))`.
pub fn solve_system_2x2(
    a1: f64,
    b1: f64,
    c1: f64,
    a2: f64,
    b2: f64,
    c2: f64,
) -> Option<(f64, f64)> {
    let det = a1 * b2 - a2 * b1;
    if det.abs() < 1e-15 {
        // Singular (or numerically singular) coefficient matrix.
        None
    } else {
        // Cramer: replace the x column (resp. y column) with the right-hand side.
        let det_x = c1 * b2 - c2 * b1;
        let det_y = a1 * c2 - a2 * c1;
        Some((det_x / det, det_y / det))
    }
}
85
/// Find a root of `f(x) = 0` with the Newton–Raphson update `x ← x − f(x) / f'(x)`.
///
/// # Arguments
/// * `f` – function whose root is sought.
/// * `df` – its derivative `f'`.
/// * `x0` – starting guess.
/// * `tol` – convergence threshold on `|f(x)|`.
/// * `max_iter` – maximum number of Newton steps.
///
/// Returns `Some(x)` as soon as `|f(x)| < tol`. Returns `None` if an iterate has a
/// derivative of magnitude below `1e-15`. Once the step budget is exhausted, the
/// last iterate is still accepted when `|f(x)| < 100·tol`; otherwise `None`.
pub fn newton_raphson(
    f: &dyn Fn(f64) -> f64,
    df: &dyn Fn(f64) -> f64,
    x0: f64,
    tol: f64,
    max_iter: u32,
) -> Option<f64> {
    let mut guess = x0;
    let mut steps_left = max_iter;
    while steps_left > 0 {
        let residual = f(guess);
        if residual.abs() < tol {
            return Some(guess);
        }
        let slope = df(guess);
        if slope.abs() < 1e-15 {
            // A (near-)flat tangent gives no usable step.
            return None;
        }
        guess -= residual / slope;
        steps_left -= 1;
    }
    // Out of steps: accept a looser residual before giving up.
    let final_residual = f(guess);
    if final_residual.abs() < tol * 100.0 {
        Some(guess)
    } else {
        None
    }
}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing computed roots with exact values.
    const EPS: f64 = 1e-10;

    #[test]
    fn linear_solution() {
        match solve_linear(2.0, -4.0) {
            Solutions::Single(x) => assert!((x - 2.0).abs() < EPS),
            other => panic!("Expected single solution, got {:?}", other),
        }
    }

    #[test]
    fn linear_degenerate() {
        // 0·x + 0 = 0 holds for every x; 0·x + 3 = 0 holds for none.
        assert!(matches!(solve_linear(0.0, 0.0), Solutions::Infinite));
        assert!(matches!(solve_linear(0.0, 3.0), Solutions::None));
    }

    #[test]
    fn quadratic_two_roots() {
        // x² - 5x + 6 = 0 → x = 2, 3
        match solve_quadratic(1.0, -5.0, 6.0) {
            Solutions::Two(a, b) => {
                assert!((a - 3.0).abs() < EPS || (a - 2.0).abs() < EPS);
                assert!((b - 3.0).abs() < EPS || (b - 2.0).abs() < EPS);
                // Guard against both slots holding the same root.
                assert!((a - b).abs() > 0.5, "roots must be distinct: {} {}", a, b);
            }
            other => panic!("Expected two solutions, got {:?}", other),
        }
    }

    #[test]
    fn quadratic_double_root() {
        // (x - 1)² = x² - 2x + 1
        match solve_quadratic(1.0, -2.0, 1.0) {
            Solutions::Single(x) => assert!((x - 1.0).abs() < EPS),
            other => panic!("Expected single solution, got {:?}", other),
        }
    }

    #[test]
    fn quadratic_no_real_roots() {
        assert!(matches!(solve_quadratic(1.0, 0.0, 1.0), Solutions::None));
    }

    #[test]
    fn quadratic_degenerates_to_linear() {
        // 0·x² + 2x - 4 = 0 → x = 2
        match solve_quadratic(0.0, 2.0, -4.0) {
            Solutions::Single(x) => assert!((x - 2.0).abs() < EPS),
            other => panic!("Expected single solution, got {:?}", other),
        }
    }

    #[test]
    fn system_2x2() {
        // x + y = 3, x - y = 1 → x=2, y=1
        let (x, y) = solve_system_2x2(1.0, 1.0, 3.0, 1.0, -1.0, 1.0).unwrap();
        assert!((x - 2.0).abs() < EPS);
        assert!((y - 1.0).abs() < EPS);
    }

    #[test]
    fn system_2x2_singular() {
        // The second equation is twice the first: no unique solution.
        assert!(solve_system_2x2(1.0, 2.0, 3.0, 2.0, 4.0, 6.0).is_none());
    }

    #[test]
    fn newton_raphson_sqrt2() {
        // f(x) = x² - 2 = 0 → x = √2
        let root = newton_raphson(&|x| x * x - 2.0, &|x| 2.0 * x, 1.0, 1e-10, 100).unwrap();
        assert!((root - std::f64::consts::SQRT_2).abs() < 1e-8);
    }

    #[test]
    fn newton_raphson_zero_derivative() {
        // f'(0) = 0 for f(x) = x² + 1, so the very first step is impossible.
        assert!(newton_raphson(&|x| x * x + 1.0, &|x| 2.0 * x, 0.0, 1e-10, 100).is_none());
    }

    #[test]
    fn cubic_one_real_root() {
        // x³ + x + 1 = 0 has one real root ≈ -0.6824
        match solve_cubic(1.0, 0.0, 1.0, 1.0) {
            Solutions::Single(x) => assert!((x + 0.6824).abs() < 0.01),
            other => panic!("Expected single solution, got {:?}", other),
        }
    }

    #[test]
    fn cubic_three_real_roots() {
        // (x - 1)(x - 2)(x - 3) = x³ - 6x² + 11x - 6
        match solve_cubic(1.0, -6.0, 11.0, -6.0) {
            Solutions::Many(mut roots) => {
                roots.sort_by(|a, b| a.partial_cmp(b).unwrap());
                assert_eq!(roots.len(), 3);
                for (got, want) in roots.iter().zip([1.0, 2.0, 3.0].iter()) {
                    assert!((got - want).abs() < EPS, "got {}, want {}", got, want);
                }
            }
            other => panic!("Expected three solutions, got {:?}", other),
        }
    }

    #[test]
    fn cubic_triple_root() {
        // (x - 2)³ = x³ - 6x² + 12x - 8
        match solve_cubic(1.0, -6.0, 12.0, -8.0) {
            Solutions::Single(x) => assert!((x - 2.0).abs() < EPS),
            other => panic!("Expected single solution, got {:?}", other),
        }
    }
}