1use lox_test_utils::approx_eq;
8
9use super::roots::{Callback, RootFinderError};
10
11pub trait FindBracketedMinimum<F>
13where
14 F: Callback,
15{
16 fn find_minimum_in_bracket(&self, f: F, bracket: (f64, f64)) -> Result<f64, RootFinderError>;
18}
19
/// Settings for a derivative-free, one-dimensional Brent minimizer.
///
/// Build it with `Default::default()` or with a struct literal such as
/// `BrentMinimizer { max_iter: 100, abs_tol: 1e-4 }`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct BrentMinimizer {
    /// Maximum number of search iterations. Each iteration evaluates the
    /// objective at most once before the search gives up.
    pub max_iter: u32,
    /// Convergence tolerance.
    ///
    /// NOTE(review): despite the name, the search computes its step tolerance as
    /// `abs_tol * |x| + 1e-10`. In practice this value therefore acts as a
    /// relative tolerance on top of a fixed absolute floor.
    pub abs_tol: f64,
}
30
31impl Default for BrentMinimizer {
32 fn default() -> Self {
33 Self {
34 max_iter: 500,
35 abs_tol: 1e-10,
36 }
37 }
38}
39
/// Golden-section fraction `(3 - sqrt(5)) / 2 ≈ 0.381966` (equal to `1 - 1/phi`).
///
/// A golden-section step moves this fraction of the way into the larger of the
/// two segments on either side of the current best point.
const GOLDEN: f64 = 0.381_966_011_250_105_1;

impl<F> FindBracketedMinimum<F> for BrentMinimizer
where
    F: Callback,
{
    /// Searches `bracket` for a minimizer of `f` using Brent's method.
    ///
    /// Each iteration first tries a parabolic-interpolation step through the
    /// points `x` (lowest value found so far), `w` (second lowest) and `v` (an
    /// older point). When that step is unusable, it takes a golden-section step
    /// instead. The endpoints of `bracket` may be given in either order.
    ///
    /// Returns `Ok(x)`, the point with the lowest function value found, once
    /// both ends of the shrinking bracket lie within `tol2` of `x`.
    ///
    /// # Errors
    ///
    /// * Any error from `f.call` is propagated with `?`.
    /// * `RootFinderError::NotConverged(self.max_iter, fx)` if the stopping test
    ///   is not met within `self.max_iter` iterations. Here `fx` is the smallest
    ///   function value evaluated.
    fn find_minimum_in_bracket(&self, f: F, bracket: (f64, f64)) -> Result<f64, RootFinderError> {
        // Normalize so that a <= b.
        let (mut a, mut b) = bracket;
        if a > b {
            std::mem::swap(&mut a, &mut b);
        }

        // Start at the golden-section point of the interval. All three tracked
        // points begin there, sharing one function evaluation.
        let mut x = a + GOLDEN * (b - a);
        let mut w = x;
        let mut v = x;
        let mut fx = f.call(x)?;
        let mut fw = fx;
        let mut fv = fx;

        // `d` is the step taken on the most recent iteration. `e` is the movement
        // recorded before it: the previous `d` after a parabolic step, or the
        // segment length after a golden-section step.
        let mut e = 0.0_f64;
        let mut d = 0.0_f64;

        for _ in 0..self.max_iter {
            let midpoint = 0.5 * (a + b);
            // NOTE(review): `abs_tol` is scaled by |x| here, so it behaves as a
            // relative tolerance. The fixed 1e-10 term keeps tol1 positive when
            // x is near zero.
            let tol1 = self.abs_tol * x.abs() + 1e-10;
            let tol2 = 2.0 * tol1;

            // Equivalent to max(x - a, b - x) <= tol2: both bracket ends are
            // within tol2 of the current best point.
            if (x - midpoint).abs() <= tol2 - 0.5 * (b - a) {
                return Ok(x);
            }

            let mut use_golden = true;
            // Attempt a parabolic step only if the earlier movement was not tiny.
            if e.abs() > tol1 {
                // Fit a parabola through (x, fx), (w, fw), (v, fv). After the sign
                // normalization below, q >= 0 and the vertex lies at x + p / q.
                let r = (x - w) * (fx - fv);
                let q = (x - v) * (fx - fw);
                let p = (x - v) * q - (x - w) * r;
                let q = 2.0 * (q - r);
                let (p, q) = if q > 0.0 { (-p, q) } else { (p, -q) };

                // Accept the parabolic step only if |p / q| < |e| / 2 (so steps
                // keep shrinking) and x + p / q lies strictly inside (a, b).
                // A zero q always fails the first test.
                if p.abs() < (0.5 * q * e).abs() && p > q * (a - x) && p < q * (b - x) {
                    e = d;
                    d = p / q;
                    let u = x + d;

                    // Do not evaluate within tol2 of either end. Step tol1
                    // toward the midpoint instead.
                    if (u - a) < tol2 || (b - u) < tol2 {
                        d = if x < midpoint { tol1 } else { -tol1 };
                    }
                    use_golden = false;
                }
            }

            if use_golden {
                // Golden-section step into the larger segment beside x.
                e = if x < midpoint { b - x } else { a - x };
                d = GOLDEN * e;
            }

            // Never move by less than tol1. A zero `d` steps to the left.
            let u = if d.abs() >= tol1 {
                x + d
            } else if d > 0.0 {
                x + tol1
            } else {
                x - tol1
            };

            let fu = f.call(u)?;

            if fu <= fx {
                // u is the new best point. Shrink the bracket to the side of x
                // that contains u, then shift v <- w <- x <- u.
                if u < x {
                    b = x;
                } else {
                    a = x;
                }
                v = w;
                fv = fw;
                w = x;
                fw = fx;
                x = u;
                fx = fu;
            } else {
                // u is worse than x, so it becomes the bracket end on its side.
                if u < x {
                    a = u;
                } else {
                    b = u;
                }
                // u replaces w if it is at least as good as w (or w coincides
                // with x). Otherwise it replaces v under the analogous test.
                // NOTE(review): `approx_eq!` comes from `lox_test_utils` and is
                // presumably an |lhs - rhs| <= 1e-15 check; confirm its semantics
                // and whether a test-utility crate belongs in non-test code.
                if fu <= fw || approx_eq!(w, x, atol <= 1e-15) {
                    v = w;
                    fv = fw;
                    w = u;
                    fw = fu;
                } else if fu <= fv
                    || approx_eq!(v, x, atol <= 1e-15)
                    || approx_eq!(v, w, atol <= 1e-15)
                {
                    v = u;
                    fv = fu;
                }
            }
        }

        Err(RootFinderError::NotConverged(self.max_iter, fx))
    }
}
156
#[cfg(test)]
mod tests {
    use lox_test_utils::assert_approx_eq;
    use std::f64::consts::PI;

    use super::*;

    type BoxedError = Box<dyn std::error::Error + Send + Sync + 'static>;
    type Result = std::result::Result<f64, BoxedError>;

    /// A parabola with its vertex at x = 3 is located to within 1e-8.
    #[test]
    fn test_brent_minimizer_quadratic() {
        let minimizer = BrentMinimizer::default();
        let x = minimizer
            .find_minimum_in_bracket(|x: f64| -> Result { Ok((x - 3.0).powi(2)) }, (0.0, 5.0))
            .expect("should converge");
        assert_approx_eq!(x, 3.0, atol <= 1e-8);
    }

    /// cos has its only minimum on [pi/2, 3pi/2] at pi.
    #[test]
    fn test_brent_minimizer_cosine() {
        let minimizer = BrentMinimizer::default();
        let x = minimizer
            .find_minimum_in_bracket(
                |x: f64| -> Result { Ok(x.cos()) },
                (PI / 2.0, 3.0 * PI / 2.0),
            )
            .expect("should converge");
        assert_approx_eq!(x, PI, atol <= 1e-8);
    }

    /// Bracket endpoints given in descending order are accepted.
    #[test]
    fn test_brent_minimizer_reversed_bracket() {
        let minimizer = BrentMinimizer::default();
        let x = minimizer
            .find_minimum_in_bracket(|x: f64| -> Result { Ok((x - 2.0).powi(2)) }, (5.0, 0.0))
            .expect("should converge");
        assert_approx_eq!(x, 2.0, atol <= 1e-8);
    }

    /// A looser tolerance still lands near the minimum.
    #[test]
    fn test_brent_minimizer_custom_tolerance() {
        let minimizer = BrentMinimizer {
            max_iter: 100,
            abs_tol: 1e-4,
        };
        let x = minimizer
            .find_minimum_in_bracket(|x: f64| -> Result { Ok((x - 1.0).powi(2)) }, (-2.0, 5.0))
            .expect("should converge");
        assert_approx_eq!(x, 1.0, atol <= 1e-3);
    }

    /// With no iterations allowed, the search must report non-convergence,
    /// not some other error, and carry the configured iteration limit.
    #[test]
    fn test_brent_minimizer_not_converged() {
        let minimizer = BrentMinimizer {
            max_iter: 0,
            abs_tol: 1e-15,
        };
        let result = minimizer
            .find_minimum_in_bracket(|x: f64| -> Result { Ok((x - 1.0).powi(2)) }, (0.0, 5.0));
        assert!(
            matches!(result, Err(RootFinderError::NotConverged(0, _))),
            "expected NotConverged(0, _), got {:?}",
            result
        );
    }

    /// A zero-width bracket satisfies the stopping test immediately and returns
    /// its single point, even though the objective's true minimum lies elsewhere.
    #[test]
    fn test_brent_minimizer_degenerate_bracket() {
        let minimizer = BrentMinimizer::default();
        let x = minimizer
            .find_minimum_in_bracket(|x: f64| -> Result { Ok((x - 1.0).powi(2)) }, (2.0, 2.0))
            .expect("a zero-width bracket should return its only point");
        assert_approx_eq!(x, 2.0, atol <= 1e-12);
    }
}