proof_engine/symbolic/
solve.rs1use super::expr::Expr;
4
/// Real solutions reported by the equation solvers in this module.
///
/// A repeated root is reported once rather than once per multiplicity. For
/// example, a quadratic whose discriminant vanishes yields [`Solutions::Single`].
#[derive(Debug, Clone, PartialEq)]
pub enum Solutions {
    /// The equation has no real solution.
    None,
    /// Exactly one distinct real solution.
    Single(f64),
    /// Two distinct real solutions.
    Two(f64, f64),
    /// A list of real solutions; `solve_cubic` uses it for three real roots.
    Many(Vec<f64>),
    /// Every real number is a solution (a degenerate `0 = 0` equation).
    Infinite,
}
14
15pub fn solve_linear(a: f64, b: f64) -> Solutions {
17 if a.abs() < 1e-15 {
18 if b.abs() < 1e-15 { Solutions::Infinite } else { Solutions::None }
19 } else {
20 Solutions::Single(-b / a)
21 }
22}
23
24pub fn solve_quadratic(a: f64, b: f64, c: f64) -> Solutions {
26 if a.abs() < 1e-15 { return solve_linear(b, c); }
27 let disc = b * b - 4.0 * a * c;
28 if disc < -1e-15 {
29 Solutions::None
30 } else if disc.abs() < 1e-15 {
31 Solutions::Single(-b / (2.0 * a))
32 } else {
33 let sqrt_disc = disc.sqrt();
34 Solutions::Two(
35 (-b + sqrt_disc) / (2.0 * a),
36 (-b - sqrt_disc) / (2.0 * a),
37 )
38 }
39}
40
41pub fn solve_cubic(a: f64, b: f64, c: f64, d: f64) -> Solutions {
43 if a.abs() < 1e-15 { return solve_quadratic(b, c, d); }
44
45 let p = (3.0 * a * c - b * b) / (3.0 * a * a);
47 let q = (2.0 * b * b * b - 9.0 * a * b * c + 27.0 * a * a * d) / (27.0 * a * a * a);
48 let disc = q * q / 4.0 + p * p * p / 27.0;
49 let offset = -b / (3.0 * a);
50
51 if disc > 1e-15 {
52 let u = (-q / 2.0 + disc.sqrt()).cbrt();
53 let v = (-q / 2.0 - disc.sqrt()).cbrt();
54 Solutions::Single(u + v + offset)
55 } else if disc.abs() < 1e-15 {
56 if q.abs() < 1e-15 {
57 Solutions::Single(offset)
58 } else {
59 let u = (-q / 2.0).cbrt();
60 Solutions::Two(2.0 * u + offset, -u + offset)
61 }
62 } else {
63 let r = (-p * p * p / 27.0).sqrt();
65 let phi = (-q / (2.0 * r)).acos();
66 let cube_r = r.cbrt();
67 Solutions::Many(vec![
68 2.0 * cube_r * (phi / 3.0).cos() + offset,
69 2.0 * cube_r * ((phi + std::f64::consts::TAU) / 3.0).cos() + offset,
70 2.0 * cube_r * ((phi + 2.0 * std::f64::consts::TAU) / 3.0).cos() + offset,
71 ])
72 }
73}
74
/// Solves the 2×2 linear system
///
/// ```text
/// a1·x + b1·y = c1
/// a2·x + b2·y = c2
/// ```
///
/// by Cramer's rule. Returns `Some((x, y))`, or `None` when the determinant
/// `a1·b2 − a2·b1` has magnitude below `1e-15` (the system is treated as
/// singular).
pub fn solve_system_2x2(a1: f64, b1: f64, c1: f64, a2: f64, b2: f64, c2: f64) -> Option<(f64, f64)> {
    let determinant = a1 * b2 - a2 * b1;
    let singular = determinant.abs() < 1e-15;
    if singular {
        None
    } else {
        // Numerators are the determinants with column x (resp. y) replaced by c.
        let x_numerator = c1 * b2 - c2 * b1;
        let y_numerator = a1 * c2 - a2 * c1;
        Some((x_numerator / determinant, y_numerator / determinant))
    }
}
85
/// Finds a root of `f` by Newton–Raphson iteration starting from `x0`.
///
/// On each of at most `max_iter` steps the iterate is returned as soon as
/// `|f(x)| < tol`. Otherwise it is updated by `x ← x − f(x) / df(x)`, where
/// `df` is expected to be the derivative of `f`.
///
/// Returns `None` if `|df(x)| < 1e-15` at any visited iterate, since the
/// Newton step would blow up. After the iteration budget is spent, the final
/// iterate is still accepted when `|f(x)| < 100·tol`; otherwise the result is
/// `None`.
pub fn newton_raphson(
    f: &dyn Fn(f64) -> f64,
    df: &dyn Fn(f64) -> f64,
    x0: f64,
    tol: f64,
    max_iter: u32,
) -> Option<f64> {
    let mut guess = x0;
    let mut remaining = max_iter;
    while remaining > 0 {
        remaining -= 1;

        let residual = f(guess);
        if residual.abs() < tol {
            return Some(guess);
        }

        let slope = df(guess);
        if slope.abs() < 1e-15 {
            return None;
        }
        guess -= residual / slope;
    }

    // Budget exhausted: accept a slightly looser fit.
    let relaxed_tol = tol * 100.0;
    (f(guess).abs() < relaxed_tol).then_some(guess)
}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the test unless `actual` is within `1e-10` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Returns `roots` sorted ascending; every root in these tests is finite.
    fn sorted(mut roots: Vec<f64>) -> Vec<f64> {
        roots.sort_by(|x, y| x.partial_cmp(y).expect("roots are finite"));
        roots
    }

    #[test]
    fn linear_solution() {
        match solve_linear(2.0, -4.0) {
            Solutions::Single(x) => assert_close(x, 2.0),
            other => panic!("Expected single solution, got {:?}", other),
        }
    }

    #[test]
    fn linear_degenerate() {
        assert!(matches!(solve_linear(0.0, 0.0), Solutions::Infinite));
        assert!(matches!(solve_linear(0.0, 1.0), Solutions::None));
    }

    #[test]
    fn quadratic_two_roots() {
        match solve_quadratic(1.0, -5.0, 6.0) {
            // Both roots must be present, not the same root reported twice.
            Solutions::Two(a, b) => {
                let roots = sorted(vec![a, b]);
                assert_close(roots[0], 2.0);
                assert_close(roots[1], 3.0);
            }
            other => panic!("Expected two solutions, got {:?}", other),
        }
    }

    #[test]
    fn quadratic_double_root() {
        match solve_quadratic(1.0, -2.0, 1.0) {
            Solutions::Single(x) => assert_close(x, 1.0),
            other => panic!("Expected single solution, got {:?}", other),
        }
    }

    #[test]
    fn quadratic_no_real_roots() {
        assert!(matches!(solve_quadratic(1.0, 0.0, 1.0), Solutions::None));
    }

    #[test]
    fn system_2x2() {
        let (x, y) = solve_system_2x2(1.0, 1.0, 3.0, 1.0, -1.0, 1.0).unwrap();
        assert_close(x, 2.0);
        assert_close(y, 1.0);
    }

    #[test]
    fn system_2x2_singular() {
        // The second equation is twice the first, so the determinant is zero.
        assert_eq!(solve_system_2x2(1.0, 2.0, 3.0, 2.0, 4.0, 6.0), None);
    }

    #[test]
    fn newton_raphson_sqrt2() {
        let root = newton_raphson(&|x| x * x - 2.0, &|x| 2.0 * x, 1.0, 1e-10, 100).unwrap();
        assert!((root - std::f64::consts::SQRT_2).abs() < 1e-8);
    }

    #[test]
    fn newton_raphson_flat_derivative() {
        // f'(0) = 0, so the first Newton step is undefined.
        assert_eq!(
            newton_raphson(&|x| x * x + 1.0, &|x| 2.0 * x, 0.0, 1e-10, 50),
            None
        );
    }

    #[test]
    fn cubic_one_real_root() {
        // x³ + x + 1 = 0 has the single real root x ≈ -0.6823.
        match solve_cubic(1.0, 0.0, 1.0, 1.0) {
            Solutions::Single(x) => assert!((x * x * x + x + 1.0).abs() < 1e-10),
            other => panic!("Expected single solution, got {:?}", other),
        }
    }

    #[test]
    fn cubic_three_real_roots() {
        // (x - 1)(x - 2)(x - 3) = x³ - 6x² + 11x - 6
        match solve_cubic(1.0, -6.0, 11.0, -6.0) {
            Solutions::Many(roots) => {
                let roots = sorted(roots);
                assert_eq!(roots.len(), 3);
                assert_close(roots[0], 1.0);
                assert_close(roots[1], 2.0);
                assert_close(roots[2], 3.0);
            }
            other => panic!("Expected three solutions, got {:?}", other),
        }
    }
}