mathhook_core/algebra/root_finding/
bisection.rs1use super::{RootFinder, RootFindingConfig, RootResult};
31use crate::error::MathError;
32use crate::expr;
33
/// Bracketing root finder that repeatedly halves the interval between `a` and `b`.
///
/// The bracket is not validated at construction time; it is checked against the
/// target function when a root is requested.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BisectionMethod {
    /// First endpoint of the search bracket (`a < b` is not required).
    pub a: f64,
    /// Second endpoint of the search bracket.
    pub b: f64,
}
44
45impl BisectionMethod {
46 pub fn new(a: f64, b: f64) -> Self {
61 Self { a, b }
62 }
63
64 fn validate_bracket<F>(&self, f: &F) -> Result<(), MathError>
66 where
67 F: Fn(f64) -> f64,
68 {
69 let fa = f(self.a);
70 let fb = f(self.b);
71
72 if fa.is_nan() || fb.is_nan() {
73 return Err(MathError::DomainError {
74 operation: "bisection".to_owned(),
75 value: expr!(x),
76 reason: "Function evaluates to NaN at bracket endpoints".to_owned(),
77 });
78 }
79
80 if fa * fb > 0.0 {
81 return Err(MathError::ConvergenceFailed {
82 reason: format!(
83 "Function values at bracket endpoints must have opposite signs: f({}) = {}, f({}) = {}",
84 self.a, fa, self.b, fb
85 ),
86 });
87 }
88
89 Ok(())
90 }
91}
92
93impl RootFinder for BisectionMethod {
94 fn find_root<F>(&self, f: F, config: &RootFindingConfig) -> Result<RootResult, MathError>
95 where
96 F: Fn(f64) -> f64,
97 {
98 self.validate_bracket(&f)?;
99
100 let mut a = self.a;
101 let mut b = self.b;
102 let mut fa = f(a);
103
104 for iteration in 0..config.max_iterations {
105 let c = (a + b) / 2.0;
106 let fc = f(c);
107
108 if fc.abs() < config.tolerance || (b - a).abs() / 2.0 < config.tolerance {
110 return Ok(RootResult {
111 root: c,
112 iterations: iteration + 1,
113 function_value: fc,
114 converged: true,
115 });
116 }
117
118 if fa * fc < 0.0 {
120 b = c;
121 } else {
122 a = c;
123 fa = fc;
124 }
125 }
126
127 let final_c = (a + b) / 2.0;
129 let final_fc = f(final_c);
130
131 Ok(RootResult {
132 root: final_c,
133 iterations: config.max_iterations,
134 function_value: final_fc,
135 converged: false,
136 })
137 }
138}
139
#[cfg(test)]
mod tests {
    //! Behavioural tests for `BisectionMethod`: accuracy on smooth functions,
    //! rejection of invalid brackets, and the effect of tolerance and iteration
    //! limits.

    use super::*;

    /// Default configuration with only the tolerance overridden.
    fn config_with_tolerance(tolerance: f64) -> RootFindingConfig {
        RootFindingConfig {
            tolerance,
            ..Default::default()
        }
    }

    #[test]
    fn test_bisection_simple_linear() {
        let config = RootFindingConfig::default();
        let outcome = BisectionMethod::new(-1.0, 2.0)
            .find_root(|x| x - 1.0, &config)
            .unwrap();

        assert!(outcome.function_value.abs() < config.tolerance);
        assert!((outcome.root - 1.0).abs() < 1e-9);
        assert!(outcome.converged);
    }

    #[test]
    fn test_bisection_quadratic() {
        let config = config_with_tolerance(1e-10);
        let outcome = BisectionMethod::new(0.0, 3.0)
            .find_root(|x| x * x - 2.0, &config)
            .unwrap();

        let x = outcome.root;
        let residual = (x * x - 2.0).abs();
        assert!(
            residual < 1e-9,
            "Solution doesn't satisfy x² = 2: residual = {}",
            residual
        );
        assert!((x - 2.0_f64.sqrt()).abs() < 1e-9);
        assert!(outcome.converged);
    }

    #[test]
    fn test_bisection_transcendental() {
        let config = config_with_tolerance(1e-10);
        let outcome = BisectionMethod::new(0.0, 2.0)
            .find_root(|x| x.cos() - x, &config)
            .unwrap();

        let x = outcome.root;
        let residual = (x.cos() - x).abs();
        assert!(
            residual < 1e-9,
            "Solution doesn't satisfy cos(x) = x: residual = {}",
            residual
        );
        assert!(x > 0.73_f64 && x < 0.75_f64);
        assert!(outcome.converged);
    }

    #[test]
    fn test_bisection_invalid_bracket() {
        let config = RootFindingConfig::default();
        let attempt = BisectionMethod::new(0.0, 1.0).find_root(|x| x * x + 1.0, &config);

        assert!(attempt.is_err());
    }

    #[test]
    fn test_bisection_exact_root() {
        let config = config_with_tolerance(1e-15);
        let outcome = BisectionMethod::new(-1.0, 1.0)
            .find_root(|x| x, &config)
            .unwrap();

        assert!(outcome.root.abs() < 1e-14);
        assert!(outcome.function_value.abs() < 1e-14);
        assert!(outcome.converged);
    }

    #[test]
    fn test_bisection_cubic() {
        let config = RootFindingConfig::default();
        let outcome = BisectionMethod::new(0.0, 1.0)
            .find_root(|x| x * x * x + x * x - 1.0, &config)
            .unwrap();

        let x = outcome.root;
        let residual = (x.powi(3) + x.powi(2) - 1.0).abs();
        assert!(
            residual < 1e-9,
            "Solution doesn't satisfy x³ + x² = 1: residual = {}",
            residual
        );
        assert!(x > 0.75_f64 && x < 0.76_f64);
        assert!(outcome.converged);
    }

    #[test]
    fn test_bisection_sine() {
        let config = RootFindingConfig::default();
        let outcome = BisectionMethod::new(3.0, 4.0)
            .find_root(|x| x.sin(), &config)
            .unwrap();

        let residual = outcome.root.sin().abs();
        assert!(
            residual < 1e-9,
            "Solution doesn't satisfy sin(x) = 0: residual = {}",
            residual
        );
        assert!((outcome.root - std::f64::consts::PI).abs() < 1e-9);
        assert!(outcome.converged);
    }

    #[test]
    fn test_bisection_exponential() {
        let config = RootFindingConfig::default();
        let outcome = BisectionMethod::new(-1.0, 1.0)
            .find_root(|x| x.exp() - 2.0, &config)
            .unwrap();

        let residual = (outcome.root.exp() - 2.0).abs();
        assert!(
            residual < 1e-9,
            "Solution doesn't satisfy e^x = 2: residual = {}",
            residual
        );
        assert!((outcome.root - 2.0_f64.ln()).abs() < 1e-9);
        assert!(outcome.converged);
    }

    #[test]
    fn test_bisection_multiple_roots_finds_one() {
        let config = RootFindingConfig::default();
        let outcome = BisectionMethod::new(-2.0, 2.0)
            .find_root(|x| x * (x - 1.0) * (x + 1.0), &config)
            .unwrap();

        assert!(outcome.converged);

        let residual = outcome.function_value.abs();
        assert!(residual < 1e-9, "Not a valid root: f(x) = {}", residual);

        let is_root = [0.0_f64, 1.0, -1.0]
            .iter()
            .any(|&known| (outcome.root - known).abs() < 1e-9);
        assert!(is_root, "Root {} is not one of -1, 0, or 1", outcome.root);
    }

    #[test]
    fn test_bisection_convergence_rate() {
        let config = config_with_tolerance(1e-12);
        let outcome = BisectionMethod::new(0.0, 2.0)
            .find_root(|x| x * x - 2.0, &config)
            .unwrap();

        assert!(outcome.iterations > 0);
        assert!(
            outcome.iterations < 50,
            "Too many iterations: {}",
            outcome.iterations
        );
        assert!(outcome.converged);
    }

    #[test]
    fn test_bisection_near_discontinuity() {
        let config = config_with_tolerance(1e-8);
        let step = |x: f64| if x < 0.0 { -1.0 } else { 1.0 };
        let outcome = BisectionMethod::new(-1.0, 1.0)
            .find_root(step, &config)
            .unwrap();

        assert!(outcome.root.abs() < 1e-7);
    }

    #[test]
    fn test_bisection_polynomial_with_close_roots() {
        let config = RootFindingConfig::default();
        let outcome = BisectionMethod::new(0.5, 1.5)
            .find_root(|x| (x - 1.0) * (x - 2.0), &config)
            .unwrap();

        assert!(outcome.converged);

        let residual = outcome.function_value.abs();
        assert!(residual < 1e-9, "Not a valid root: f(x) = {}", residual);
        assert!((outcome.root - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_bisection_oscillatory_function() {
        let config = RootFindingConfig::default();
        let outcome = BisectionMethod::new(0.1, 0.5)
            .find_root(|x| (10.0 * x).sin(), &config)
            .unwrap();

        assert!(outcome.converged);

        let residual = (10.0 * outcome.root).sin().abs();
        assert!(
            residual < 1e-9,
            "Solution doesn't satisfy sin(10x) = 0: residual = {}",
            residual
        );
        assert!((outcome.root - std::f64::consts::PI / 10.0).abs() < 1e-9);
    }

    #[test]
    fn test_bisection_tolerance_control() {
        let solver = BisectionMethod::new(0.0, 2.0);
        // Non-capturing closure, so it is `Copy` and can be passed twice.
        let target = |x: f64| x * x - 2.0;

        let loose = solver
            .find_root(target, &config_with_tolerance(1e-4))
            .unwrap();
        let tight = solver
            .find_root(target, &config_with_tolerance(1e-12))
            .unwrap();

        assert!(loose.iterations < tight.iterations);
        assert!(tight.function_value.abs() < loose.function_value.abs());
    }

    #[test]
    fn test_bisection_negative_interval() {
        let config = RootFindingConfig::default();
        let outcome = BisectionMethod::new(-3.0, -1.0)
            .find_root(|x| x + 2.0, &config)
            .unwrap();

        let residual = (outcome.root + 2.0).abs();
        assert!(
            residual < 1e-9,
            "Solution doesn't satisfy x = -2: residual = {}",
            residual
        );
        assert!(outcome.converged);
    }

    #[test]
    fn test_bisection_max_iterations_reached() {
        let config = RootFindingConfig {
            tolerance: 1e-15,
            max_iterations: 10,
            ..Default::default()
        };
        let outcome = BisectionMethod::new(0.0, 2.0)
            .find_root(|x| x * x - 2.0, &config)
            .unwrap();

        assert!(
            !outcome.converged,
            "Should not have converged with only 10 iterations"
        );
        assert_eq!(outcome.iterations, 10);
        assert!(outcome.root > 1.0 && outcome.root < 2.0);
        assert!(outcome.function_value.abs() < 1.0);
    }

    #[test]
    fn test_bisection_function_value_convergence() {
        let config = config_with_tolerance(1e-10);
        let outcome = BisectionMethod::new(0.0, 2.0)
            .find_root(|x| x * x - 2.0, &config)
            .unwrap();

        assert!(outcome.converged);
        assert!(outcome.function_value.abs() < 1e-9);
    }

    #[test]
    fn test_bisection_bracket_width_convergence() {
        let config = config_with_tolerance(1e-6);
        let outcome = BisectionMethod::new(1.0, 2.0)
            .find_root(|x| x * x - 2.0, &config)
            .unwrap();

        assert!(outcome.converged);
        assert!((outcome.root - 2.0_f64.sqrt()).abs() < config.tolerance);
    }
}