mathhook_core/algebra/root_finding/
bisection.rs

1//! Bisection method for root finding
2//!
3//! Implements the bisection method which guarantees convergence
4//! for continuous functions with a sign change in the interval.
5//! Uses interval halving to iteratively narrow down the root location.
6//!
7//! # Algorithm
8//!
9//! Given f(a) and f(b) with opposite signs:
10//! 1. Compute midpoint: c = (a + b) / 2
11//! 2. If f(c) has same sign as f(a), replace a with c
12//! 3. Otherwise, replace b with c
//! 4. Repeat until a stopping criterion is met (see "Tolerance Semantics")
14//!
15//! # Convergence
16//!
17//! - Guaranteed convergence if f is continuous and f(a)*f(b) < 0
18//! - Linear convergence rate: error halves each iteration
19//! - Requires O(log2((b-a)/tolerance)) iterations
20//!
21//! # Tolerance Semantics
22//!
23//! The algorithm stops when EITHER:
24//! - |f(c)| < tolerance (function value criterion)
25//! - |b - a| / 2 < tolerance (bracket width criterion)
26//!
27//! The bracket width criterion guarantees the root is within
28//! tolerance distance of the returned value.
29
30use super::{RootFinder, RootFindingConfig, RootResult};
31use crate::error::MathError;
32use crate::expr;
33
/// Bisection method root finder
///
/// Guaranteed convergence method that requires an initial bracket `[a, b]`
/// where `f(a)` and `f(b)` have opposite signs. The struct only stores the
/// bracket; the sign condition is checked against a concrete function when a
/// search is run, not when the value is constructed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BisectionMethod {
    /// Lower bound of initial bracket
    pub a: f64,
    /// Upper bound of initial bracket
    pub b: f64,
}
44
45impl BisectionMethod {
46    /// Create a new bisection method with initial bracket
47    ///
48    /// # Arguments
49    ///
50    /// * `a` - Lower bound of bracket
51    /// * `b` - Upper bound of bracket
52    ///
53    /// # Examples
54    ///
55    /// ```rust
56    /// use mathhook_core::algebra::root_finding::BisectionMethod;
57    ///
58    /// let method = BisectionMethod::new(0.0, 2.0);
59    /// ```
60    pub fn new(a: f64, b: f64) -> Self {
61        Self { a, b }
62    }
63
64    /// Check if the bracket is valid (function values have opposite signs)
65    fn validate_bracket<F>(&self, f: &F) -> Result<(), MathError>
66    where
67        F: Fn(f64) -> f64,
68    {
69        let fa = f(self.a);
70        let fb = f(self.b);
71
72        if fa.is_nan() || fb.is_nan() {
73            return Err(MathError::DomainError {
74                operation: "bisection".to_owned(),
75                value: expr!(x),
76                reason: "Function evaluates to NaN at bracket endpoints".to_owned(),
77            });
78        }
79
80        if fa * fb > 0.0 {
81            return Err(MathError::ConvergenceFailed {
82                reason: format!(
83                    "Function values at bracket endpoints must have opposite signs: f({}) = {}, f({}) = {}",
84                    self.a, fa, self.b, fb
85                ),
86            });
87        }
88
89        Ok(())
90    }
91}
92
93impl RootFinder for BisectionMethod {
94    fn find_root<F>(&self, f: F, config: &RootFindingConfig) -> Result<RootResult, MathError>
95    where
96        F: Fn(f64) -> f64,
97    {
98        self.validate_bracket(&f)?;
99
100        let mut a = self.a;
101        let mut b = self.b;
102        let mut fa = f(a);
103
104        for iteration in 0..config.max_iterations {
105            let c = (a + b) / 2.0;
106            let fc = f(c);
107
108            // Check convergence: function value OR bracket width
109            if fc.abs() < config.tolerance || (b - a).abs() / 2.0 < config.tolerance {
110                return Ok(RootResult {
111                    root: c,
112                    iterations: iteration + 1,
113                    function_value: fc,
114                    converged: true,
115                });
116            }
117
118            // Update bracket based on sign of f(c)
119            if fa * fc < 0.0 {
120                b = c;
121            } else {
122                a = c;
123                fa = fc;
124            }
125        }
126
127        // Max iterations reached - return best approximation with converged=false
128        let final_c = (a + b) / 2.0;
129        let final_fc = f(final_c);
130
131        Ok(RootResult {
132            root: final_c,
133            iterations: config.max_iterations,
134            function_value: final_fc,
135            converged: false,
136        })
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the default configuration with only `tolerance` overridden.
    fn config_with_tolerance(tolerance: f64) -> RootFindingConfig {
        RootFindingConfig {
            tolerance,
            ..Default::default()
        }
    }

    /// Runs bisection for `f` on the bracket `[lo, hi]`, panicking on error.
    fn solve<F>(lo: f64, hi: f64, f: F, config: &RootFindingConfig) -> RootResult
    where
        F: Fn(f64) -> f64,
    {
        BisectionMethod::new(lo, hi).find_root(f, config).unwrap()
    }

    #[test]
    fn test_bisection_simple_linear() {
        let config = RootFindingConfig::default();
        let outcome = solve(-1.0, 2.0, |x| x - 1.0, &config);

        // The reported residual must satisfy the configured tolerance
        assert!(outcome.function_value.abs() < config.tolerance);
        // The root of x - 1 is 1
        assert!((outcome.root - 1.0).abs() < 1e-9);
        assert!(outcome.converged);
    }

    #[test]
    fn test_bisection_quadratic() {
        let config = config_with_tolerance(1e-10);
        let outcome = solve(0.0, 3.0, |x| x * x - 2.0, &config);

        // Residual of x² = 2 at the reported root
        let residual = (outcome.root * outcome.root - 2.0).abs();
        assert!(
            residual < 1e-9,
            "Solution doesn't satisfy x² = 2: residual = {}",
            residual
        );

        // Distance to the analytic answer sqrt(2)
        let distance = (outcome.root - 2.0_f64.sqrt()).abs();
        assert!(distance < 1e-9);
        assert!(outcome.converged);
    }

    #[test]
    fn test_bisection_transcendental() {
        let config = config_with_tolerance(1e-10);
        let outcome = solve(0.0, 2.0, |x| x.cos() - x, &config);

        // Residual of cos(x) = x at the reported root
        let residual = (outcome.root.cos() - outcome.root).abs();
        assert!(
            residual < 1e-9,
            "Solution doesn't satisfy cos(x) = x: residual = {}",
            residual
        );

        // The fixed point of cosine (Dottie number) is ≈ 0.739085133215161
        assert!(0.73_f64 < outcome.root && outcome.root < 0.75_f64);
        assert!(outcome.converged);
    }

    #[test]
    fn test_bisection_invalid_bracket() {
        let config = RootFindingConfig::default();

        // x² + 1 is strictly positive, so the endpoints cannot differ in sign
        let attempt = BisectionMethod::new(0.0, 1.0).find_root(|x| x * x + 1.0, &config);
        assert!(attempt.is_err());
    }

    #[test]
    fn test_bisection_exact_root() {
        let config = config_with_tolerance(1e-15);
        let outcome = solve(-1.0, 1.0, |x| x, &config);

        // The identity function has its root at 0
        assert!(outcome.root.abs() < 1e-14);
        assert!(outcome.function_value.abs() < 1e-14);
        assert!(outcome.converged);
    }

    #[test]
    fn test_bisection_cubic() {
        let config = RootFindingConfig::default();
        let outcome = solve(0.0, 1.0, |x| x * x * x + x * x - 1.0, &config);

        // Residual of x³ + x² - 1 = 0 at the reported root
        let residual = (outcome.root.powi(3) + outcome.root.powi(2) - 1.0).abs();
        assert!(
            residual < 1e-9,
            "Solution doesn't satisfy x³ + x² = 1: residual = {}",
            residual
        );

        // The root is expected between 0.75 and 0.76
        assert!(0.75_f64 < outcome.root && outcome.root < 0.76_f64);
        assert!(outcome.converged);
    }

    #[test]
    fn test_bisection_sine() {
        let config = RootFindingConfig::default();
        let outcome = solve(3.0, 4.0, |x| x.sin(), &config);

        // Residual of sin(x) = 0 at the reported root
        let residual = outcome.root.sin().abs();
        assert!(
            residual < 1e-9,
            "Solution doesn't satisfy sin(x) = 0: residual = {}",
            residual
        );

        // The only zero of sine in [3, 4] is π
        let distance = (outcome.root - std::f64::consts::PI).abs();
        assert!(distance < 1e-9);
        assert!(outcome.converged);
    }

    #[test]
    fn test_bisection_exponential() {
        let config = RootFindingConfig::default();
        let outcome = solve(-1.0, 1.0, |x| x.exp() - 2.0, &config);

        // Residual of e^x = 2 at the reported root
        let residual = (outcome.root.exp() - 2.0).abs();
        assert!(
            residual < 1e-9,
            "Solution doesn't satisfy e^x = 2: residual = {}",
            residual
        );

        // The analytic answer is ln(2)
        let distance = (outcome.root - 2.0_f64.ln()).abs();
        assert!(distance < 1e-9);
        assert!(outcome.converged);
    }

    #[test]
    fn test_bisection_multiple_roots_finds_one() {
        let config = RootFindingConfig::default();

        // x(x-1)(x+1) vanishes at -1, 0 and 1
        let outcome = solve(-2.0, 2.0, |x| x * (x - 1.0) * (x + 1.0), &config);

        assert!(outcome.converged);

        // Whatever was found must actually be a root
        let residual = outcome.function_value.abs();
        assert!(residual < 1e-9, "Not a valid root: f(x) = {}", residual);

        // ... and it must coincide with one of the three known roots
        let is_root = [0.0_f64, 1.0, -1.0]
            .iter()
            .any(|&known| (outcome.root - known).abs() < 1e-9);
        assert!(is_root, "Root {} is not one of -1, 0, or 1", outcome.root);
    }

    #[test]
    fn test_bisection_convergence_rate() {
        let config = config_with_tolerance(1e-12);
        let outcome = solve(0.0, 2.0, |x| x * x - 2.0, &config);

        // Halving [0, 2] down to 1e-12 takes about log2(2/1e-12) ≈ 41 steps
        assert!(outcome.iterations > 0);
        assert!(
            outcome.iterations < 50,
            "Too many iterations: {}",
            outcome.iterations
        );
        assert!(outcome.converged);
    }

    #[test]
    fn test_bisection_near_discontinuity() {
        let config = config_with_tolerance(1e-8);

        // Step function jumping from -1 to +1 at x = 0
        let step = |x: f64| if x < 0.0 { -1.0 } else { 1.0 };
        let outcome = solve(-1.0, 1.0, step, &config);

        // Bisection homes in on the jump even though no true root exists
        assert!(outcome.root.abs() < 1e-7);
    }

    #[test]
    fn test_bisection_polynomial_with_close_roots() {
        let config = RootFindingConfig::default();

        // (x-1)(x-2) vanishes at 1 and 2; only 1 lies in the bracket
        let outcome = solve(0.5, 1.5, |x| (x - 1.0) * (x - 2.0), &config);

        assert!(outcome.converged);

        // Whatever was found must actually be a root
        let residual = outcome.function_value.abs();
        assert!(residual < 1e-9, "Not a valid root: f(x) = {}", residual);

        // With bracket [0.5, 1.5] the root found is x = 1
        assert!((outcome.root - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_bisection_oscillatory_function() {
        let config = RootFindingConfig::default();
        let outcome = solve(0.1, 0.5, |x| (10.0 * x).sin(), &config);

        assert!(outcome.converged);

        // Residual of sin(10x) = 0 at the reported root
        let residual = (10.0 * outcome.root).sin().abs();
        assert!(
            residual < 1e-9,
            "Solution doesn't satisfy sin(10x) = 0: residual = {}",
            residual
        );

        // The zero inside [0.1, 0.5] is π/10
        let distance = (outcome.root - std::f64::consts::PI / 10.0).abs();
        assert!(distance < 1e-9);
    }

    #[test]
    fn test_bisection_tolerance_control() {
        let target = |x: f64| x * x - 2.0;

        let loose = solve(0.0, 2.0, target, &config_with_tolerance(1e-4));
        let tight = solve(0.0, 2.0, target, &config_with_tolerance(1e-12));

        // A smaller tolerance costs more iterations
        assert!(loose.iterations < tight.iterations);

        // ... and buys a smaller residual
        assert!(tight.function_value.abs() < loose.function_value.abs());
    }

    #[test]
    fn test_bisection_negative_interval() {
        let config = RootFindingConfig::default();
        let outcome = solve(-3.0, -1.0, |x| x + 2.0, &config);

        // Residual of x + 2 = 0 at the reported root
        let residual = (outcome.root + 2.0).abs();
        assert!(
            residual < 1e-9,
            "Solution doesn't satisfy x = -2: residual = {}",
            residual
        );

        assert!(outcome.converged);
    }

    #[test]
    fn test_bisection_max_iterations_reached() {
        // Ten iterations cannot reach a tolerance of 1e-15 on [0, 2]
        let config = RootFindingConfig {
            tolerance: 1e-15,
            max_iterations: 10,
            ..Default::default()
        };
        let outcome = solve(0.0, 2.0, |x| x * x - 2.0, &config);

        // The result must be flagged as not converged
        assert!(
            !outcome.converged,
            "Should not have converged with only 10 iterations"
        );
        assert_eq!(outcome.iterations, 10);

        // Even so, the estimate has moved towards sqrt(2)
        assert!(outcome.root > 1.0 && outcome.root < 2.0);
        assert!(outcome.function_value.abs() < 1.0);
    }

    #[test]
    fn test_bisection_function_value_convergence() {
        let config = config_with_tolerance(1e-10);
        let outcome = solve(0.0, 2.0, |x| x * x - 2.0, &config);

        // A converged result carries a near-zero function value
        assert!(outcome.converged);
        assert!(outcome.function_value.abs() < 1e-9);
    }

    #[test]
    fn test_bisection_bracket_width_convergence() {
        let config = config_with_tolerance(1e-6);
        let outcome = solve(1.0, 2.0, |x| x * x - 2.0, &config);

        assert!(outcome.converged);

        // The reported root must lie within `tolerance` of sqrt(2)
        let sqrt2 = 2.0_f64.sqrt();
        assert!((outcome.root - sqrt2).abs() < config.tolerance);
    }
}