Skip to main content

lox_core/math/
optim.rs

// SPDX-FileCopyrightText: 2026 Helge Eichhorn <git@helgeeichhorn.de>
//
// SPDX-License-Identifier: MPL-2.0

//! Bracketed optimization algorithms.
6
7use lox_test_utils::approx_eq;
8
9use super::roots::{Callback, RootFinderError};
10
11/// Finds the minimum of a function within a bracket.
12pub trait FindBracketedMinimum<F>
13where
14    F: Callback,
15{
16    /// Finds the x value that minimizes `f` within the given `bracket`.
17    fn find_minimum_in_bracket(&self, f: F, bracket: (f64, f64)) -> Result<f64, RootFinderError>;
18}
19
/// Brent's method for finding the minimum of a unimodal function in a bracket.
///
/// Combines golden section search with parabolic interpolation.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct BrentMinimizer {
    /// Maximum number of iterations before the search gives up.
    pub max_iter: u32,
    /// Convergence tolerance on the position of the minimum.
    ///
    /// Despite its name, this value is applied *relative* to the magnitude of
    /// the current best point: every iteration works with the tolerance
    /// `abs_tol * |x| + 1e-10`. The fixed `1e-10` term acts as an absolute
    /// floor, so the effective tolerance never drops below that value.
    pub abs_tol: f64,
}
30
31impl Default for BrentMinimizer {
32    fn default() -> Self {
33        Self {
34            max_iter: 500,
35            abs_tol: 1e-10,
36        }
37    }
38}
39
/// Golden-section fraction `(3 - sqrt(5)) / 2` (about `0.382`) used in Brent minimization.
///
/// Note that this is *not* the golden ratio `(1 + sqrt(5)) / 2` (about `1.618`);
/// it equals two minus the golden ratio and is the fraction of an interval
/// that a golden-section step moves into it.
const GOLDEN: f64 = 0.381_966_011_250_105_1;
42
43impl<F> FindBracketedMinimum<F> for BrentMinimizer
44where
45    F: Callback,
46{
47    fn find_minimum_in_bracket(&self, f: F, bracket: (f64, f64)) -> Result<f64, RootFinderError> {
48        let (mut a, mut b) = bracket;
49        if a > b {
50            std::mem::swap(&mut a, &mut b);
51        }
52
53        // x is the point with the least function value found so far.
54        // w is the point with the second least value.
55        // v is the previous value of w.
56        let mut x = a + GOLDEN * (b - a);
57        let mut w = x;
58        let mut v = x;
59        let mut fx = f.call(x)?;
60        let mut fw = fx;
61        let mut fv = fx;
62
63        // e is the distance moved on the step before last.
64        // d is the distance moved on the last step.
65        let mut e = 0.0_f64;
66        let mut d = 0.0_f64;
67
68        for _ in 0..self.max_iter {
69            let midpoint = 0.5 * (a + b);
70            let tol1 = self.abs_tol * x.abs() + 1e-10;
71            let tol2 = 2.0 * tol1;
72
73            // Check convergence.
74            if (x - midpoint).abs() <= tol2 - 0.5 * (b - a) {
75                return Ok(x);
76            }
77
78            // Try parabolic interpolation.
79            let mut use_golden = true;
80            if e.abs() > tol1 {
81                // Fit parabola through x, v, w.
82                let r = (x - w) * (fx - fv);
83                let q = (x - v) * (fx - fw);
84                let p = (x - v) * q - (x - w) * r;
85                let q = 2.0 * (q - r);
86                let (p, q) = if q > 0.0 { (-p, q) } else { (p, -q) };
87
88                // Is the parabola acceptable?
89                if p.abs() < (0.5 * q * e).abs() && p > q * (a - x) && p < q * (b - x) {
90                    e = d;
91                    d = p / q;
92                    let u = x + d;
93
94                    // f must not be evaluated too close to a or b.
95                    if (u - a) < tol2 || (b - u) < tol2 {
96                        d = if x < midpoint { tol1 } else { -tol1 };
97                    }
98                    use_golden = false;
99                }
100            }
101
102            if use_golden {
103                // Golden section step.
104                e = if x < midpoint { b - x } else { a - x };
105                d = GOLDEN * e;
106            }
107
108            // f must not be evaluated too close to x.
109            let u = if d.abs() >= tol1 {
110                x + d
111            } else if d > 0.0 {
112                x + tol1
113            } else {
114                x - tol1
115            };
116
117            let fu = f.call(u)?;
118
119            // Update a, b, v, w, x.
120            if fu <= fx {
121                if u < x {
122                    b = x;
123                } else {
124                    a = x;
125                }
126                v = w;
127                fv = fw;
128                w = x;
129                fw = fx;
130                x = u;
131                fx = fu;
132            } else {
133                if u < x {
134                    a = u;
135                } else {
136                    b = u;
137                }
138                if fu <= fw || approx_eq!(w, x, atol <= 1e-15) {
139                    v = w;
140                    fv = fw;
141                    w = u;
142                    fw = fu;
143                } else if fu <= fv
144                    || approx_eq!(v, x, atol <= 1e-15)
145                    || approx_eq!(v, w, atol <= 1e-15)
146                {
147                    v = u;
148                    fv = fu;
149                }
150            }
151        }
152
153        Err(RootFinderError::NotConverged(self.max_iter, fx))
154    }
155}
156
#[cfg(test)]
mod tests {
    use lox_test_utils::assert_approx_eq;
    use std::f64::consts::PI;

    use super::*;

    // Error and result types returned by the closures used as objective functions.
    type BoxedError = Box<dyn std::error::Error + Send + Sync + 'static>;
    type Result = std::result::Result<f64, BoxedError>;

    #[test]
    fn test_brent_minimizer_quadratic() {
        // (x - 3)^2 has its minimum at x = 3.
        let minimizer = BrentMinimizer::default();
        let x = minimizer
            .find_minimum_in_bracket(|x: f64| -> Result { Ok((x - 3.0).powi(2)) }, (0.0, 5.0))
            .expect("should converge");
        assert_approx_eq!(x, 3.0, atol <= 1e-8);
    }

    #[test]
    fn test_brent_minimizer_cosine() {
        // cos(x) has a minimum at PI in [PI/2, 3*PI/2]
        let minimizer = BrentMinimizer::default();
        let x = minimizer
            .find_minimum_in_bracket(
                |x: f64| -> Result { Ok(x.cos()) },
                (PI / 2.0, 3.0 * PI / 2.0),
            )
            .expect("should converge");
        assert_approx_eq!(x, PI, atol <= 1e-8);
    }

    #[test]
    fn test_brent_minimizer_reversed_bracket() {
        // The bracket end points are given in descending order on purpose.
        let minimizer = BrentMinimizer::default();
        let x = minimizer
            .find_minimum_in_bracket(|x: f64| -> Result { Ok((x - 2.0).powi(2)) }, (5.0, 0.0))
            .expect("should converge");
        assert_approx_eq!(x, 2.0, atol <= 1e-8);
    }

    #[test]
    fn test_brent_minimizer_custom_tolerance() {
        // A looser tolerance must still land close to the minimum at x = 1.
        let minimizer = BrentMinimizer {
            max_iter: 100,
            abs_tol: 1e-4,
        };
        let x = minimizer
            .find_minimum_in_bracket(|x: f64| -> Result { Ok((x - 1.0).powi(2)) }, (-2.0, 5.0))
            .expect("should converge");
        assert_approx_eq!(x, 1.0, atol <= 1e-3);
    }

    #[test]
    fn test_brent_minimizer_not_converged() {
        // With an iteration budget of zero the loop body never runs, so the
        // only acceptable outcome is `NotConverged` reporting that budget.
        let minimizer = BrentMinimizer {
            max_iter: 0,
            abs_tol: 1e-15,
        };
        let result = minimizer
            .find_minimum_in_bracket(|x: f64| -> Result { Ok((x - 1.0).powi(2)) }, (0.0, 5.0));
        assert!(matches!(
            result,
            Err(RootFinderError::NotConverged(0, _))
        ));
    }
}