Skip to main content

lox_core/math/
roots.rs

1// SPDX-FileCopyrightText: 2024 Helge Eichhorn <git@helgeeichhorn.de>
2//
3// SPDX-License-Identifier: MPL-2.0
4
5use lox_test_utils::approx_eq;
6use thiserror::Error;
7
8#[derive(Debug, Error)]
9pub enum RootFinderError {
10    #[error("not converged after {0} iterations, residual {1}")]
11    NotConverged(u32, f64),
12    #[error("root not in bracket")]
13    NotInBracket,
14    #[error(transparent)]
15    Callback(#[from] CallbackError),
16}
17
/// Type-erased, thread-safe error type that user callbacks may return.
pub type BoxedError = Box<dyn std::error::Error + Sync + Send + 'static>;
19
20#[derive(Debug, Error)]
21#[error(transparent)]
22pub struct CallbackError(BoxedError);
23
24impl From<&str> for CallbackError {
25    fn from(s: &str) -> Self {
26        CallbackError(s.into())
27    }
28}
29
30impl From<BoxedError> for CallbackError {
31    fn from(e: BoxedError) -> Self {
32        CallbackError(e)
33    }
34}
35
36pub trait Callback {
37    fn call(&self, v: f64) -> Result<f64, CallbackError>;
38}
39
40impl<F> Callback for F
41where
42    F: Fn(f64) -> Result<f64, BoxedError>,
43{
44    fn call(&self, v: f64) -> Result<f64, CallbackError> {
45        self(v).map_err(CallbackError)
46    }
47}
48
49pub trait FindRoot<F>
50where
51    F: Callback,
52{
53    fn find(&self, f: F, initial_guess: f64) -> Result<f64, RootFinderError>;
54}
55
56pub trait FindRootWithDerivative<F, D>
57where
58    F: Callback,
59    D: Callback,
60{
61    fn find_with_derivative(
62        &self,
63        f: F,
64        derivative: D,
65        initial_guess: f64,
66    ) -> Result<f64, RootFinderError>;
67}
68
69pub trait FindBracketedRoot<F>
70where
71    F: Callback,
72{
73    fn find_in_bracket(&self, f: F, bracket: (f64, f64)) -> Result<f64, RootFinderError>;
74}
75
/// Steffensen's method (Aitken-accelerated fixed-point iteration), a derivative-free
/// root finder.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Steffensen {
    /// Maximum number of iterations before giving up.
    max_iter: u32,
    /// Absolute tolerance on the change between successive iterates.
    tolerance: f64,
}

impl Steffensen {
    /// Creates a solver with a custom iteration limit and absolute step tolerance.
    #[must_use]
    pub fn new(max_iter: u32, tolerance: f64) -> Self {
        Self {
            max_iter,
            tolerance,
        }
    }
}

impl Default for Steffensen {
    /// 1000 iterations and a tolerance of `sqrt(f64::EPSILON)`.
    fn default() -> Self {
        Self::new(1000, f64::EPSILON.sqrt())
    }
}
90
91impl<F> FindRoot<F> for Steffensen
92where
93    F: Callback,
94{
95    fn find(&self, f: F, initial_guess: f64) -> Result<f64, RootFinderError> {
96        let mut p0 = initial_guess;
97        for _ in 0..self.max_iter {
98            let f1 = p0 + f.call(p0).map_err(RootFinderError::Callback)?;
99            let f2 = f1 + f.call(f1).map_err(RootFinderError::Callback)?;
100            let p = p0 - (f1 - p0).powi(2) / (f2 - 2.0 * f1 + p0);
101            if approx_eq!(p, p0, atol <= self.tolerance) {
102                return Ok(p);
103            }
104            p0 = p;
105        }
106        Err(RootFinderError::NotConverged(self.max_iter, p0))
107    }
108}
109
/// Newton–Raphson root finder, which requires the first derivative of the function.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Newton {
    /// Maximum number of iterations before giving up.
    max_iter: u32,
    /// Absolute tolerance on the change between successive iterates.
    tolerance: f64,
}

impl Newton {
    /// Creates a solver with a custom iteration limit and absolute step tolerance.
    #[must_use]
    pub fn new(max_iter: u32, tolerance: f64) -> Self {
        Self {
            max_iter,
            tolerance,
        }
    }
}

impl Default for Newton {
    /// 50 iterations and a tolerance of `sqrt(f64::EPSILON)`.
    fn default() -> Self {
        Self::new(50, f64::EPSILON.sqrt())
    }
}
124
125impl<F, D> FindRootWithDerivative<F, D> for Newton
126where
127    F: Callback,
128    D: Callback,
129{
130    fn find_with_derivative(
131        &self,
132        f: F,
133        derivative: D,
134        initial_guess: f64,
135    ) -> Result<f64, RootFinderError> {
136        let mut p0 = initial_guess;
137        for _ in 0..self.max_iter {
138            let p = p0
139                - f.call(p0).map_err(RootFinderError::Callback)?
140                    / derivative.call(p0).map_err(RootFinderError::Callback)?;
141            if approx_eq!(p, p0, atol <= self.tolerance) {
142                return Ok(p);
143            }
144            p0 = p;
145        }
146        Err(RootFinderError::NotConverged(self.max_iter, p0))
147    }
148}
149
/// Brent-style bracketing root finder. It requires a sign change between the
/// bracket endpoints.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Brent {
    /// Maximum number of iterations before giving up.
    max_iter: u32,
    /// Absolute tolerance. It is used both for the step size and as a threshold
    /// on `|f|`.
    abs_tol: f64,
    /// Relative tolerance on the step size, scaled by `|x|`.
    rel_tol: f64,
}

impl Brent {
    /// Creates a solver with a custom iteration limit and tolerances.
    ///
    /// Note the argument order: absolute tolerance first, then relative tolerance.
    #[must_use]
    pub fn new(max_iter: u32, abs_tol: f64, rel_tol: f64) -> Self {
        Self {
            max_iter,
            abs_tol,
            rel_tol,
        }
    }
}

impl Default for Brent {
    /// 100 iterations, `abs_tol = 1e-6`, `rel_tol = sqrt(f64::EPSILON)`.
    fn default() -> Self {
        Self::new(100, 1e-6, f64::EPSILON.sqrt())
    }
}
166
impl<F> FindBracketedRoot<F> for Brent
where
    F: Callback,
{
    /// Finds a root of `f` inside `bracket` with a Brent-style bracketing method.
    ///
    /// The solver keeps a contrapoint `xblk` at which `f` has the opposite sign of
    /// `f(xcur)`. Each step tries an interpolated or extrapolated step and falls
    /// back to bisection when that step is not small enough. The structure closely
    /// resembles SciPy's C `brentq` routine (NOTE(review): provenance is not stated
    /// in this file).
    ///
    /// Tolerances:
    /// - `abs_tol` is used both as a threshold on `|f|` (via `approx_eq!` against
    ///   `0.0`) and as the absolute part of the step tolerance `delta`.
    /// - `rel_tol` scales `delta` with `|xcur|`.
    ///
    /// # Errors
    ///
    /// - [`RootFinderError::NotInBracket`] if `f` has the same strict sign at both
    ///   endpoints.
    /// - [`RootFinderError::Callback`] if any evaluation of `f` fails.
    /// - [`RootFinderError::NotConverged`] after `max_iter` iterations. The second
    ///   field is `f(xcur)` at the last evaluated point.
    fn find_in_bracket(&self, f: F, bracket: (f64, f64)) -> Result<f64, RootFinderError> {
        // Contrapoint (`xblk`) and its function value.
        let mut fblk = 0.0;
        let mut xblk = 0.0;
        // Previous and current iterate.
        let (mut xpre, mut xcur) = bracket;
        // Previous and current step sizes.
        let mut spre = 0.0;
        let mut scur = 0.0;

        let mut fpre = f.call(xpre).map_err(RootFinderError::Callback)?;
        let mut fcur = f.call(xcur).map_err(RootFinderError::Callback)?;

        // Same strict sign at both endpoints: no sign change to bracket.
        if fpre * fcur > 0.0 {
            return Err(RootFinderError::NotInBracket);
        }

        // An endpoint is already (approximately) a root.
        if approx_eq!(fpre, 0.0, atol <= self.abs_tol) {
            return Ok(xpre);
        }

        if approx_eq!(fcur, 0.0, atol <= self.abs_tol) {
            return Ok(xcur);
        }

        for _ in 0..self.max_iter {
            // Sign change between xpre and xcur: make xpre the new contrapoint and
            // reset the step history to the full interval width.
            if fpre * fcur < 0.0 {
                xblk = xpre;
                fblk = fpre;
                spre = xcur - xpre;
                scur = xcur - xpre;
            }

            // Keep the smaller |f| in xcur by swapping xcur and xblk. xpre/fpre serve
            // as the temporaries, so afterwards xpre == xblk.
            if fblk.abs() < fcur.abs() {
                xpre = xcur;
                xcur = xblk;
                xblk = xpre;
                fpre = fcur;
                fcur = fblk;
                fblk = fpre;
            }

            // Step tolerance, and the bisection step toward the contrapoint.
            let delta = (self.abs_tol + self.rel_tol * xcur.abs()) / 2.0;
            let sbis = (xblk - xcur) / 2.0;

            // Converged: the residual is negligible or the half-bracket is below
            // tolerance.
            if approx_eq!(fcur, 0.0, atol <= self.abs_tol) || sbis.abs() < delta {
                return Ok(xcur);
            }

            // Try a faster step only if the previous step was not tiny and |f|
            // decreased.
            if spre.abs() > delta && fcur.abs() < fpre.abs() {
                let stry = if approx_eq!(xpre, xblk, rtol <= self.rel_tol) {
                    // interpolate
                    // Only two distinct points: linear (secant) step through xpre
                    // and xcur.
                    -fcur * (xcur - xpre) / (fcur - fpre)
                } else {
                    // extrapolate
                    // Three-point step using xpre, xcur and xblk.
                    let dpre = (fpre - fcur) / (xpre - xcur);
                    let dblk = (fblk - fcur) / (xblk - xcur);
                    -fcur * (fblk * dblk - fpre * dpre) / (dblk * dpre * (fblk - fpre))
                };

                // Accept the trial step only if it is small compared with the
                // previous step and the bisection step.
                if 2.0 * stry.abs() < spre.abs().min(3.0 * sbis.abs() - delta) {
                    spre = scur;
                    scur = stry;
                } else {
                    // bisect
                    spre = sbis;
                    scur = sbis;
                }
            } else {
                // bisect
                spre = sbis;
                scur = sbis;
            }

            xpre = xcur;
            fpre = fcur;

            // If the chosen step is no larger than delta, move by exactly delta
            // toward the contrapoint instead.
            if scur.abs() > delta {
                xcur += scur
            } else {
                xcur += if sbis > 0.0 { delta } else { -delta };
            }

            fcur = f.call(xcur).map_err(RootFinderError::Callback)?;
        }

        Err(RootFinderError::NotConverged(self.max_iter, fcur))
    }
}
257
/// Secant-method root finder. It is derivative-free and does not require a
/// sign change.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Secant {
    /// Maximum number of iterations before giving up.
    max_iter: u32,
    /// Relative tolerance on the change between successive iterates.
    rel_tol: f64,
    /// Absolute tolerance on the change between successive iterates.
    abs_tol: f64,
}

impl Secant {
    /// Creates a solver with a custom iteration limit and tolerances.
    ///
    /// Note the argument order: absolute tolerance first, then relative tolerance
    /// (the same order as `Brent::new`).
    #[must_use]
    pub fn new(max_iter: u32, abs_tol: f64, rel_tol: f64) -> Self {
        Self {
            max_iter,
            rel_tol,
            abs_tol,
        }
    }
}

impl Default for Secant {
    /// 100 iterations, `rel_tol = sqrt(f64::EPSILON)`, `abs_tol = 1e-6`.
    fn default() -> Self {
        Self::new(100, 1e-6, f64::EPSILON.sqrt())
    }
}
274
275impl<F> FindBracketedRoot<F> for Secant
276where
277    F: Callback,
278{
279    fn find_in_bracket(&self, f: F, bracket: (f64, f64)) -> Result<f64, RootFinderError> {
280        let (x0, x1) = bracket;
281        let mut p0 = x0;
282        let mut p1 = x1;
283        let mut q0 = f.call(p0).map_err(RootFinderError::Callback)?;
284        let mut q1 = f.call(p1).map_err(RootFinderError::Callback)?;
285        if q1.abs() < q0.abs() {
286            std::mem::swap(&mut p0, &mut p1);
287            std::mem::swap(&mut q0, &mut q1);
288        }
289        for i in 0..self.max_iter {
290            if q1 == q0 {
291                if p1 != p0 {
292                    return Err(RootFinderError::NotConverged(i, q0));
293                }
294                return Ok((p1 + p0) / 2.0);
295            }
296            let p = if q1.abs() > q0.abs() {
297                (-q0 / q1 * p1 + p0) / (1.0 - q0 / q1)
298            } else {
299                (-q1 / q0 * p0 + p1) / (1.0 - q1 / q0)
300            };
301            if approx_eq!(p, p1, rtol <= self.rel_tol, atol <= self.abs_tol) {
302                return Ok(p);
303            }
304            p0 = p1;
305            q0 = q1;
306            p1 = p;
307            q1 = f.call(p).map_err(RootFinderError::Callback)?;
308        }
309        Err(RootFinderError::NotConverged(self.max_iter, p0))
310    }
311}
312
313impl<F> FindRoot<F> for Secant
314where
315    F: Callback,
316{
317    fn find(&self, f: F, initial_guess: f64) -> Result<f64, RootFinderError> {
318        let x0 = initial_guess;
319        let eps = 1e-4;
320        let mut x1 = x0 * (1.0 + eps);
321        x1 += if x1 > x0 { eps } else { -eps };
322        self.find_in_bracket(f, (x0, x1))
323    }
324}
325
#[cfg(test)]
mod tests {
    use lox_test_utils::assert_approx_eq;
    use std::f64::consts::PI;

    use super::*;

    type Result = std::result::Result<f64, BoxedError>;

    #[test]
    fn test_newton_kepler() {
        // Solves Kepler's equation E - e·sin(E) = M for E, starting at E0 = M.
        fn mean_to_ecc(mean: f64, eccentricity: f64) -> std::result::Result<f64, RootFinderError> {
            let newton = Newton::default();
            newton.find_with_derivative(
                |e: f64| -> Result { Ok(e - eccentricity * e.sin() - mean) },
                |e: f64| -> Result { Ok(1.0 - eccentricity * e.cos()) },
                mean,
            )
        }
        let act = mean_to_ecc(PI / 2.0, 0.3).expect("should converge");
        assert_approx_eq!(act, 1.85846841205333, rtol <= 1e-8);
    }

    #[test]
    fn test_newton_cubic() {
        let newton = Newton::default();
        let act = newton
            .find_with_derivative(
                |x: f64| -> Result { Ok(x.powi(3) + 4.0 * x.powi(2) - 10.0) },
                // d/dx (x³ + 4x² - 10) = 3x² + 8x
                |x: f64| -> Result { Ok(3.0 * x.powi(2) + 8.0 * x) },
                1.5,
            )
            .expect("should converge");
        assert_approx_eq!(act, 1.3652300134140969, rtol <= 1e-8);
    }

    #[test]
    fn test_steffensen_cubic() {
        let steffensen = Steffensen::default();
        let act = steffensen
            .find(
                |x: f64| -> Result { Ok(x.powi(3) + 4.0 * x.powi(2) - 10.0) },
                1.5,
            )
            .expect("should converge");
        assert_approx_eq!(act, 1.3652300134140969, rtol <= 1e-8);
    }

    #[test]
    fn test_brent_cubic() {
        let brent = Brent::default();
        let act = brent
            .find_in_bracket(
                |x: f64| -> Result { Ok(x.powi(3) + 4.0 * x.powi(2) - 10.0) },
                (1.0, 1.5),
            )
            .expect("should converge");
        assert_approx_eq!(act, 1.3652300134140969, rtol <= 1e-8);
    }

    #[test]
    fn test_secant_cubic() {
        let secant = Secant::default();
        let act = secant
            .find_in_bracket(
                |x: f64| -> Result { Ok(x.powi(3) + 4.0 * x.powi(2) - 10.0) },
                (1.0, 1.5),
            )
            .expect("should converge");
        assert_approx_eq!(act, 1.3652300134140969, rtol <= 1e-8);

        let act = secant
            .find(
                |x: f64| -> Result { Ok(x.powi(3) + 4.0 * x.powi(2) - 10.0) },
                1.0,
            )
            .expect("should converge");
        assert_approx_eq!(act, 1.3652300134140969, rtol <= 1e-8);
    }

    #[test]
    #[should_panic(expected = "derivative failed")]
    fn test_newton_kepler_callback_error() {
        let newton = Newton::default();
        newton
            .find_with_derivative(
                |e: f64| -> Result { Ok(e) },
                |_e: f64| -> Result { Err("derivative failed".into()) },
                1.0,
            )
            .unwrap();
    }

    #[test]
    #[should_panic(expected = "f failed")]
    fn test_steffensen_cubic_error() {
        let steffensen = Steffensen::default();
        // function errors immediately
        steffensen
            .find(|_x| -> Result { Err("f failed".into()) }, 1.0)
            .unwrap();
    }

    #[test]
    #[should_panic(expected = "negative x")]
    fn test_brent_cubic_error() {
        let brent = Brent::default();
        // The callback fails when evaluated at the negative bracket endpoint (-1.0),
        // before any iteration starts.
        brent
            .find_in_bracket(
                |x: f64| -> Result {
                    if x.is_sign_negative() {
                        Err("negative x".into())
                    } else {
                        Ok(x * x - 2.0)
                    }
                },
                (-1.0, 2.0),
            )
            .unwrap();
    }
}