Skip to main content

lox_core/math/
roots.rs

1// SPDX-FileCopyrightText: 2024 Helge Eichhorn <git@helgeeichhorn.de>
2//
3// SPDX-License-Identifier: MPL-2.0
4
5//! Root-finding algorithms: Steffensen, Newton, Brent, and Secant methods.
6
7use lox_test_utils::approx_eq;
8use thiserror::Error;
9
10/// Error returned by root-finding algorithms.
11#[derive(Debug, Error)]
12pub enum RootFinderError {
13    /// The algorithm did not converge within the maximum number of iterations.
14    #[error("not converged after {0} iterations, residual {1}")]
15    NotConverged(u32, f64),
16    /// The root is not within the given bracket.
17    #[error("root not in bracket")]
18    NotInBracket,
19    /// The objective function returned an error.
20    #[error(transparent)]
21    Callback(#[from] CallbackError),
22}
23
24/// A boxed error type for use in root-finding callbacks.
25pub type BoxedError = Box<dyn std::error::Error + Send + Sync + 'static>;
26
27/// An error returned by a root-finding callback function.
28#[derive(Debug, Error)]
29#[error(transparent)]
30pub struct CallbackError(BoxedError);
31
32impl From<&str> for CallbackError {
33    fn from(s: &str) -> Self {
34        CallbackError(s.into())
35    }
36}
37
38impl From<BoxedError> for CallbackError {
39    fn from(e: BoxedError) -> Self {
40        CallbackError(e)
41    }
42}
43
44/// A callable function for root-finding algorithms.
45pub trait Callback {
46    /// Evaluates the function at `v`.
47    fn call(&self, v: f64) -> Result<f64, CallbackError>;
48}
49
50impl<F> Callback for F
51where
52    F: Fn(f64) -> Result<f64, BoxedError>,
53{
54    fn call(&self, v: f64) -> Result<f64, CallbackError> {
55        self(v).map_err(CallbackError)
56    }
57}
58
59/// Finds a root of `f` starting from an initial guess.
60pub trait FindRoot<F>
61where
62    F: Callback,
63{
64    /// Finds a root of `f` starting from `initial_guess`.
65    fn find(&self, f: F, initial_guess: f64) -> Result<f64, RootFinderError>;
66}
67
68/// Finds a root of `f` using both the function and its derivative.
69pub trait FindRootWithDerivative<F, D>
70where
71    F: Callback,
72    D: Callback,
73{
74    /// Finds a root of `f` using `derivative`, starting from `initial_guess`.
75    fn find_with_derivative(
76        &self,
77        f: F,
78        derivative: D,
79        initial_guess: f64,
80    ) -> Result<f64, RootFinderError>;
81}
82
83/// Finds a root of `f` within a bracket `(a, b)`.
84pub trait FindBracketedRoot<F>
85where
86    F: Callback,
87{
88    /// Finds a root of `f` within the given `bracket`.
89    fn find_in_bracket(&self, f: F, bracket: (f64, f64)) -> Result<f64, RootFinderError>;
90}
91
/// Configuration for Steffensen's derivative-free root-finding method.
///
/// The default allows up to 1000 iterations with an absolute tolerance of
/// `sqrt(f64::EPSILON)` on the change between successive estimates.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Steffensen {
    /// Upper bound on the number of iterations before giving up.
    max_iter: u32,
    /// Absolute tolerance on the change between successive estimates.
    tolerance: f64,
}

impl Default for Steffensen {
    fn default() -> Self {
        let tolerance = f64::EPSILON.sqrt();
        Steffensen {
            tolerance,
            max_iter: 1000,
        }
    }
}
107
108impl<F> FindRoot<F> for Steffensen
109where
110    F: Callback,
111{
112    fn find(&self, f: F, initial_guess: f64) -> Result<f64, RootFinderError> {
113        let mut p0 = initial_guess;
114        for _ in 0..self.max_iter {
115            let f1 = p0 + f.call(p0).map_err(RootFinderError::Callback)?;
116            let f2 = f1 + f.call(f1).map_err(RootFinderError::Callback)?;
117            let p = p0 - (f1 - p0).powi(2) / (f2 - 2.0 * f1 + p0);
118            if approx_eq!(p, p0, atol <= self.tolerance) {
119                return Ok(p);
120            }
121            p0 = p;
122        }
123        Err(RootFinderError::NotConverged(self.max_iter, p0))
124    }
125}
126
/// Configuration for the Newton–Raphson method, which needs the function's derivative.
///
/// By default at most 50 steps are taken, and the iteration stops once successive estimates
/// agree to within an absolute tolerance of `sqrt(f64::EPSILON)`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Newton {
    /// Maximum number of Newton steps.
    max_iter: u32,
    /// Absolute tolerance on the change between successive estimates.
    tolerance: f64,
}

impl Default for Newton {
    fn default() -> Self {
        Newton {
            tolerance: f64::EPSILON.sqrt(),
            max_iter: 50,
        }
    }
}
142
143impl<F, D> FindRootWithDerivative<F, D> for Newton
144where
145    F: Callback,
146    D: Callback,
147{
148    fn find_with_derivative(
149        &self,
150        f: F,
151        derivative: D,
152        initial_guess: f64,
153    ) -> Result<f64, RootFinderError> {
154        let mut p0 = initial_guess;
155        for _ in 0..self.max_iter {
156            let p = p0
157                - f.call(p0).map_err(RootFinderError::Callback)?
158                    / derivative.call(p0).map_err(RootFinderError::Callback)?;
159            if approx_eq!(p, p0, atol <= self.tolerance) {
160                return Ok(p);
161            }
162            p0 = p;
163        }
164        Err(RootFinderError::NotConverged(self.max_iter, p0))
165    }
166}
167
/// Configuration for Brent's bracketing root finder.
///
/// Defaults: 100 iterations, `abs_tol = 1e-6`, `rel_tol = sqrt(f64::EPSILON)`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Brent {
    /// Maximum number of iterations.
    max_iter: u32,
    /// Absolute tolerance, applied both to `|f(x)|` and to the bracket width.
    abs_tol: f64,
    /// Relative tolerance (scaled by `|x|`) used in the bracket-width stopping criterion.
    rel_tol: f64,
}

impl Default for Brent {
    fn default() -> Self {
        let (abs_tol, rel_tol) = (1e-6, f64::EPSILON.sqrt());
        Brent {
            max_iter: 100,
            abs_tol,
            rel_tol,
        }
    }
}
185
186impl<F> FindBracketedRoot<F> for Brent
187where
188    F: Callback,
189{
190    fn find_in_bracket(&self, f: F, bracket: (f64, f64)) -> Result<f64, RootFinderError> {
191        let mut fblk = 0.0;
192        let mut xblk = 0.0;
193        let (mut xpre, mut xcur) = bracket;
194        let mut spre = 0.0;
195        let mut scur = 0.0;
196
197        let mut fpre = f.call(xpre).map_err(RootFinderError::Callback)?;
198        let mut fcur = f.call(xcur).map_err(RootFinderError::Callback)?;
199
200        if fpre * fcur > 0.0 {
201            return Err(RootFinderError::NotInBracket);
202        }
203
204        if approx_eq!(fpre, 0.0, atol <= self.abs_tol) {
205            return Ok(xpre);
206        }
207
208        if approx_eq!(fcur, 0.0, atol <= self.abs_tol) {
209            return Ok(xcur);
210        }
211
212        for _ in 0..self.max_iter {
213            if fpre * fcur < 0.0 {
214                xblk = xpre;
215                fblk = fpre;
216                spre = xcur - xpre;
217                scur = xcur - xpre;
218            }
219
220            if fblk.abs() < fcur.abs() {
221                xpre = xcur;
222                xcur = xblk;
223                xblk = xpre;
224                fpre = fcur;
225                fcur = fblk;
226                fblk = fpre;
227            }
228
229            let delta = (self.abs_tol + self.rel_tol * xcur.abs()) / 2.0;
230            let sbis = (xblk - xcur) / 2.0;
231
232            if approx_eq!(fcur, 0.0, atol <= self.abs_tol) || sbis.abs() < delta {
233                return Ok(xcur);
234            }
235
236            if spre.abs() > delta && fcur.abs() < fpre.abs() {
237                let stry = if approx_eq!(xpre, xblk, rtol <= self.rel_tol) {
238                    // interpolate
239                    -fcur * (xcur - xpre) / (fcur - fpre)
240                } else {
241                    // extrapolate
242                    let dpre = (fpre - fcur) / (xpre - xcur);
243                    let dblk = (fblk - fcur) / (xblk - xcur);
244                    -fcur * (fblk * dblk - fpre * dpre) / (dblk * dpre * (fblk - fpre))
245                };
246
247                if 2.0 * stry.abs() < spre.abs().min(3.0 * sbis.abs() - delta) {
248                    spre = scur;
249                    scur = stry;
250                } else {
251                    // bisect
252                    spre = sbis;
253                    scur = sbis;
254                }
255            } else {
256                // bisect
257                spre = sbis;
258                scur = sbis;
259            }
260
261            xpre = xcur;
262            fpre = fcur;
263
264            if scur.abs() > delta {
265                xcur += scur
266            } else {
267                xcur += if sbis > 0.0 { delta } else { -delta };
268            }
269
270            fcur = f.call(xcur).map_err(RootFinderError::Callback)?;
271        }
272
273        Err(RootFinderError::NotConverged(self.max_iter, fcur))
274    }
275}
276
/// Configuration for the secant method.
///
/// Defaults: 100 iterations, `rel_tol = sqrt(f64::EPSILON)`, `abs_tol = 1e-6`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Secant {
    /// Maximum number of secant updates.
    max_iter: u32,
    /// Relative tolerance on the change between successive estimates.
    rel_tol: f64,
    /// Absolute tolerance on the change between successive estimates.
    abs_tol: f64,
}

impl Default for Secant {
    fn default() -> Self {
        let rel_tol = f64::EPSILON.sqrt();
        let abs_tol = 1e-6;
        Secant {
            max_iter: 100,
            rel_tol,
            abs_tol,
        }
    }
}
294
295impl<F> FindBracketedRoot<F> for Secant
296where
297    F: Callback,
298{
299    fn find_in_bracket(&self, f: F, bracket: (f64, f64)) -> Result<f64, RootFinderError> {
300        let (x0, x1) = bracket;
301        let mut p0 = x0;
302        let mut p1 = x1;
303        let mut q0 = f.call(p0).map_err(RootFinderError::Callback)?;
304        let mut q1 = f.call(p1).map_err(RootFinderError::Callback)?;
305        if q1.abs() < q0.abs() {
306            std::mem::swap(&mut p0, &mut p1);
307            std::mem::swap(&mut q0, &mut q1);
308        }
309        for i in 0..self.max_iter {
310            if q1 == q0 {
311                if p1 != p0 {
312                    return Err(RootFinderError::NotConverged(i, q0));
313                }
314                return Ok((p1 + p0) / 2.0);
315            }
316            let p = if q1.abs() > q0.abs() {
317                (-q0 / q1 * p1 + p0) / (1.0 - q0 / q1)
318            } else {
319                (-q1 / q0 * p0 + p1) / (1.0 - q1 / q0)
320            };
321            if approx_eq!(p, p1, rtol <= self.rel_tol, atol <= self.abs_tol) {
322                return Ok(p);
323            }
324            p0 = p1;
325            q0 = q1;
326            p1 = p;
327            q1 = f.call(p).map_err(RootFinderError::Callback)?;
328        }
329        Err(RootFinderError::NotConverged(self.max_iter, p0))
330    }
331}
332
333impl<F> FindRoot<F> for Secant
334where
335    F: Callback,
336{
337    fn find(&self, f: F, initial_guess: f64) -> Result<f64, RootFinderError> {
338        let x0 = initial_guess;
339        let eps = 1e-4;
340        let mut x1 = x0 * (1.0 + eps);
341        x1 += if x1 > x0 { eps } else { -eps };
342        self.find_in_bracket(f, (x0, x1))
343    }
344}
345
#[cfg(test)]
mod tests {
    use lox_test_utils::assert_approx_eq;
    use std::f64::consts::PI;

    use super::*;

    /// Return type of the test callbacks.
    type Result = std::result::Result<f64, BoxedError>;

    #[test]
    fn test_newton_kepler() {
        /// Solves Kepler's equation `E - e sin E = M` for the eccentric anomaly `E`.
        fn mean_to_ecc(mean: f64, eccentricity: f64) -> std::result::Result<f64, RootFinderError> {
            let newton = Newton::default();
            newton.find_with_derivative(
                |e: f64| -> Result { Ok(e - eccentricity * e.sin() - mean) },
                |e: f64| -> Result { Ok(1.0 - eccentricity * e.cos()) },
                mean,
            )
        }
        let act = mean_to_ecc(PI / 2.0, 0.3).expect("should converge");
        assert_approx_eq!(act, 1.85846841205333, rtol <= 1e-8);
    }

    #[test]
    fn test_newton_cubic() {
        let newton = Newton::default();
        let act = newton
            .find_with_derivative(
                |x: f64| -> Result { Ok(x.powi(3) + 4.0 * x.powi(2) - 10.0) },
                // d/dx (x^3 + 4x^2 - 10) = 3x^2 + 8x
                |x: f64| -> Result { Ok(3.0 * x.powi(2) + 8.0 * x) },
                1.5,
            )
            .expect("should converge");
        assert_approx_eq!(act, 1.3652300134140969, rtol <= 1e-8);
    }

    #[test]
    fn test_steffensen_cubic() {
        let steffensen = Steffensen::default();
        let act = steffensen
            .find(
                |x: f64| -> Result { Ok(x.powi(3) + 4.0 * x.powi(2) - 10.0) },
                1.5,
            )
            .expect("should converge");
        assert_approx_eq!(act, 1.3652300134140969, rtol <= 1e-8);
    }

    #[test]
    fn test_brent_cubic() {
        let brent = Brent::default();
        let act = brent
            .find_in_bracket(
                |x: f64| -> Result { Ok(x.powi(3) + 4.0 * x.powi(2) - 10.0) },
                (1.0, 1.5),
            )
            .expect("should converge");
        assert_approx_eq!(act, 1.3652300134140969, rtol <= 1e-8);
    }

    #[test]
    fn test_secant_cubic() {
        let secant = Secant::default();
        let act = secant
            .find_in_bracket(
                |x: f64| -> Result { Ok(x.powi(3) + 4.0 * x.powi(2) - 10.0) },
                (1.0, 1.5),
            )
            .expect("should converge");
        assert_approx_eq!(act, 1.3652300134140969, rtol <= 1e-8);

        let act = secant
            .find(
                |x: f64| -> Result { Ok(x.powi(3) + 4.0 * x.powi(2) - 10.0) },
                1.0,
            )
            .expect("should converge");
        assert_approx_eq!(act, 1.3652300134140969, rtol <= 1e-8);
    }

    #[test]
    #[should_panic(expected = "derivative failed")]
    fn test_newton_kepler_callback_error() {
        let newton = Newton::default();
        newton
            .find_with_derivative(
                |e: f64| -> Result { Ok(e) },
                |_e: f64| -> Result { Err("derivative failed".into()) },
                1.0,
            )
            .unwrap();
    }

    #[test]
    #[should_panic(expected = "f failed")]
    fn test_steffensen_cubic_error() {
        let steffensen = Steffensen::default();
        // function errors immediately
        steffensen
            .find(|_x| -> Result { Err("f failed".into()) }, 1.0)
            .unwrap();
    }

    #[test]
    #[should_panic(expected = "negative x")]
    fn test_brent_cubic_error() {
        let brent = Brent::default();
        // the callback fails on the negative endpoint, which is evaluated before iterating
        brent
            .find_in_bracket(
                |x: f64| -> Result {
                    if x.is_sign_negative() {
                        Err("negative x".into())
                    } else {
                        Ok(x * x - 2.0)
                    }
                },
                (-1.0, 2.0),
            )
            .unwrap();
    }
}