1use lox_test_utils::approx_eq;
6use thiserror::Error;
7
8#[derive(Debug, Error)]
9pub enum RootFinderError {
10 #[error("not converged after {0} iterations, residual {1}")]
11 NotConverged(u32, f64),
12 #[error("root not in bracket")]
13 NotInBracket,
14 #[error(transparent)]
15 Callback(#[from] CallbackError),
16}
17
/// Type-erased, thread-safe error that user callbacks may return.
pub type BoxedError = Box<dyn std::error::Error + Sync + Send + 'static>;
19
20#[derive(Debug, Error)]
21#[error(transparent)]
22pub struct CallbackError(BoxedError);
23
24impl From<&str> for CallbackError {
25 fn from(s: &str) -> Self {
26 CallbackError(s.into())
27 }
28}
29
30impl From<BoxedError> for CallbackError {
31 fn from(e: BoxedError) -> Self {
32 CallbackError(e)
33 }
34}
35
36pub trait Callback {
37 fn call(&self, v: f64) -> Result<f64, CallbackError>;
38}
39
40impl<F> Callback for F
41where
42 F: Fn(f64) -> Result<f64, BoxedError>,
43{
44 fn call(&self, v: f64) -> Result<f64, CallbackError> {
45 self(v).map_err(CallbackError)
46 }
47}
48
49pub trait FindRoot<F>
50where
51 F: Callback,
52{
53 fn find(&self, f: F, initial_guess: f64) -> Result<f64, RootFinderError>;
54}
55
56pub trait FindRootWithDerivative<F, D>
57where
58 F: Callback,
59 D: Callback,
60{
61 fn find_with_derivative(
62 &self,
63 f: F,
64 derivative: D,
65 initial_guess: f64,
66 ) -> Result<f64, RootFinderError>;
67}
68
69pub trait FindBracketedRoot<F>
70where
71 F: Callback,
72{
73 fn find_in_bracket(&self, f: F, bracket: (f64, f64)) -> Result<f64, RootFinderError>;
74}
75
/// Steffensen's method: a derivative-free root finder started from a single guess.
///
/// Iteration stops once two successive iterates agree within the absolute `tolerance`.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Steffensen {
    /// Maximum number of iterations before giving up.
    max_iter: u32,
    /// Absolute tolerance on the difference between successive iterates.
    tolerance: f64,
}

impl Steffensen {
    /// Creates a solver with the given iteration limit and absolute step tolerance.
    pub fn new(max_iter: u32, tolerance: f64) -> Self {
        Self {
            max_iter,
            tolerance,
        }
    }
}

impl Default for Steffensen {
    /// `max_iter = 1000`, `tolerance = sqrt(f64::EPSILON)` (about `1.5e-8`).
    fn default() -> Self {
        Self::new(1000, f64::EPSILON.sqrt())
    }
}
90
91impl<F> FindRoot<F> for Steffensen
92where
93 F: Callback,
94{
95 fn find(&self, f: F, initial_guess: f64) -> Result<f64, RootFinderError> {
96 let mut p0 = initial_guess;
97 for _ in 0..self.max_iter {
98 let f1 = p0 + f.call(p0).map_err(RootFinderError::Callback)?;
99 let f2 = f1 + f.call(f1).map_err(RootFinderError::Callback)?;
100 let p = p0 - (f1 - p0).powi(2) / (f2 - 2.0 * f1 + p0);
101 if approx_eq!(p, p0, atol <= self.tolerance) {
102 return Ok(p);
103 }
104 p0 = p;
105 }
106 Err(RootFinderError::NotConverged(self.max_iter, p0))
107 }
108}
109
/// Newton's (Newton-Raphson) method, which needs the derivative of the function.
///
/// Iteration stops once two successive iterates agree within the absolute `tolerance`.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Newton {
    /// Maximum number of iterations before giving up.
    max_iter: u32,
    /// Absolute tolerance on the difference between successive iterates.
    tolerance: f64,
}

impl Newton {
    /// Creates a solver with the given iteration limit and absolute step tolerance.
    pub fn new(max_iter: u32, tolerance: f64) -> Self {
        Self {
            max_iter,
            tolerance,
        }
    }
}

impl Default for Newton {
    /// `max_iter = 50`, `tolerance = sqrt(f64::EPSILON)` (about `1.5e-8`).
    fn default() -> Self {
        Self::new(50, f64::EPSILON.sqrt())
    }
}
124
125impl<F, D> FindRootWithDerivative<F, D> for Newton
126where
127 F: Callback,
128 D: Callback,
129{
130 fn find_with_derivative(
131 &self,
132 f: F,
133 derivative: D,
134 initial_guess: f64,
135 ) -> Result<f64, RootFinderError> {
136 let mut p0 = initial_guess;
137 for _ in 0..self.max_iter {
138 let p = p0
139 - f.call(p0).map_err(RootFinderError::Callback)?
140 / derivative.call(p0).map_err(RootFinderError::Callback)?;
141 if approx_eq!(p, p0, atol <= self.tolerance) {
142 return Ok(p);
143 }
144 p0 = p;
145 }
146 Err(RootFinderError::NotConverged(self.max_iter, p0))
147 }
148}
149
/// Brent's bracketing root finder.
///
/// `abs_tol` and `rel_tol` combine into the x tolerance `abs_tol + rel_tol * |x|`.
/// `abs_tol` is also used as the tolerance on `|f(x)|`.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Brent {
    /// Maximum number of iterations before giving up.
    max_iter: u32,
    /// Absolute tolerance, applied both to `x` and to `|f(x)|`.
    abs_tol: f64,
    /// Relative tolerance on `x`.
    rel_tol: f64,
}

impl Brent {
    /// Creates a solver with the given iteration limit, absolute tolerance and relative
    /// tolerance (in that order).
    pub fn new(max_iter: u32, abs_tol: f64, rel_tol: f64) -> Self {
        Self {
            max_iter,
            abs_tol,
            rel_tol,
        }
    }
}

impl Default for Brent {
    /// `max_iter = 100`, `abs_tol = 1e-6`, `rel_tol = sqrt(f64::EPSILON)` (about `1.5e-8`).
    fn default() -> Self {
        Self::new(100, 1e-6, f64::EPSILON.sqrt())
    }
}
166
impl<F> FindBracketedRoot<F> for Brent
where
    F: Callback,
{
    /// Finds a root of `f` inside `bracket` with Brent's method.
    ///
    /// Each iteration tries a secant step (two distinct points) or an inverse quadratic
    /// interpolation step (three distinct points). It falls back to bisection toward the
    /// contrapoint when the interpolated step is not acceptable. The search succeeds when
    /// `|f(xcur)| <= abs_tol` or when half the distance to the contrapoint drops below
    /// `(abs_tol + rel_tol * |xcur|) / 2`.
    ///
    /// # Errors
    ///
    /// * [`RootFinderError::NotInBracket`] if `f(a) * f(b) > 0` for `bracket = (a, b)`.
    /// * [`RootFinderError::Callback`] if `f` returns an error.
    /// * [`RootFinderError::NotConverged`] with `max_iter` and the residual `f(xcur)` at the
    ///   last evaluated iterate when the iteration limit is reached.
    fn find_in_bracket(&self, f: F, bracket: (f64, f64)) -> Result<f64, RootFinderError> {
        // Contrapoint (the bracket end opposite `xcur`) and its function value. These are
        // assigned in the loop once a sign change between `xpre` and `xcur` is seen.
        let mut fblk = 0.0;
        let mut xblk = 0.0;
        // `xpre` is the previous iterate, `xcur` the current best estimate.
        let (mut xpre, mut xcur) = bracket;
        // Step before last (`spre`) and last step (`scur`).
        let mut spre = 0.0;
        let mut scur = 0.0;

        let mut fpre = f.call(xpre).map_err(RootFinderError::Callback)?;
        let mut fcur = f.call(xcur).map_err(RootFinderError::Callback)?;

        // Same sign at both ends: the bracket does not enclose a sign change.
        if fpre * fcur > 0.0 {
            return Err(RootFinderError::NotInBracket);
        }

        // Either end may already be a root to within `abs_tol`.
        if approx_eq!(fpre, 0.0, atol <= self.abs_tol) {
            return Ok(xpre);
        }

        if approx_eq!(fcur, 0.0, atol <= self.abs_tol) {
            return Ok(xcur);
        }

        for _ in 0..self.max_iter {
            // The root lies between `xpre` and `xcur`: make `xpre` the contrapoint and reset
            // the step history to the (signed) bracket width.
            if fpre * fcur < 0.0 {
                xblk = xpre;
                fblk = fpre;
                spre = xcur - xpre;
                scur = xcur - xpre;
            }

            // Make `xcur` the point with the smaller |f| by swapping it with the
            // contrapoint; `xpre` is set to the old `xcur`.
            if fblk.abs() < fcur.abs() {
                xpre = xcur;
                xcur = xblk;
                xblk = xpre;
                fpre = fcur;
                fcur = fblk;
                fblk = fpre;
            }

            // Half the combined x tolerance, and the bisection step toward the contrapoint.
            let delta = (self.abs_tol + self.rel_tol * xcur.abs()) / 2.0;
            let sbis = (xblk - xcur) / 2.0;

            // Converged: either the residual is within `abs_tol` (used here as an |f|
            // tolerance) or the bracket is narrower than the x tolerance.
            if approx_eq!(fcur, 0.0, atol <= self.abs_tol) || sbis.abs() < delta {
                return Ok(xcur);
            }

            // Try interpolation only if the step before last was larger than the tolerance
            // and the last step reduced |f|; otherwise bisect.
            if spre.abs() > delta && fcur.abs() < fpre.abs() {
                // The previous point and the contrapoint count as the same point when they
                // agree to within `rel_tol`.
                let stry = if approx_eq!(xpre, xblk, rtol <= self.rel_tol) {
                    // Two distinct points: secant step.
                    -fcur * (xcur - xpre) / (fcur - fpre)
                } else {
                    // Three distinct points: inverse quadratic interpolation step, written in
                    // terms of the secant slopes from `xcur` to `xpre` and to `xblk`.
                    let dpre = (fpre - fcur) / (xpre - xcur);
                    let dblk = (fblk - fcur) / (xblk - xcur);
                    -fcur * (fblk * dblk - fpre * dpre) / (dblk * dpre * (fblk - fpre))
                };

                // Accept the interpolated step only if it is small relative to both the
                // step before last and the bisection step; otherwise bisect.
                if 2.0 * stry.abs() < spre.abs().min(3.0 * sbis.abs() - delta) {
                    spre = scur;
                    scur = stry;
                } else {
                    spre = sbis;
                    scur = sbis;
                }
            } else {
                // Bisect.
                spre = sbis;
                scur = sbis;
            }

            xpre = xcur;
            fpre = fcur;

            // Steps no larger than `delta` are replaced by a step of exactly `delta` toward
            // the contrapoint (the sign of `sbis`).
            if scur.abs() > delta {
                xcur += scur
            } else {
                xcur += if sbis > 0.0 { delta } else { -delta };
            }

            fcur = f.call(xcur).map_err(RootFinderError::Callback)?;
        }

        // Report the residual at the last evaluated iterate.
        Err(RootFinderError::NotConverged(self.max_iter, fcur))
    }
}
257
/// The secant method, seeded with two points (or one point plus a derived second point).
///
/// Iteration stops once a new iterate agrees with the previous one within the relative
/// tolerance `rel_tol` and the absolute tolerance `abs_tol`.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Secant {
    /// Maximum number of iterations before giving up.
    max_iter: u32,
    /// Relative tolerance on the difference between successive iterates.
    rel_tol: f64,
    /// Absolute tolerance on the difference between successive iterates.
    abs_tol: f64,
}

impl Secant {
    /// Creates a solver with the given iteration limit, absolute tolerance and relative
    /// tolerance (in that order, matching [`Brent::new`]).
    pub fn new(max_iter: u32, abs_tol: f64, rel_tol: f64) -> Self {
        Self {
            max_iter,
            rel_tol,
            abs_tol,
        }
    }
}

impl Default for Secant {
    /// `max_iter = 100`, `rel_tol = sqrt(f64::EPSILON)` (about `1.5e-8`), `abs_tol = 1e-6`.
    fn default() -> Self {
        Self::new(100, 1e-6, f64::EPSILON.sqrt())
    }
}
274
275impl<F> FindBracketedRoot<F> for Secant
276where
277 F: Callback,
278{
279 fn find_in_bracket(&self, f: F, bracket: (f64, f64)) -> Result<f64, RootFinderError> {
280 let (x0, x1) = bracket;
281 let mut p0 = x0;
282 let mut p1 = x1;
283 let mut q0 = f.call(p0).map_err(RootFinderError::Callback)?;
284 let mut q1 = f.call(p1).map_err(RootFinderError::Callback)?;
285 if q1.abs() < q0.abs() {
286 std::mem::swap(&mut p0, &mut p1);
287 std::mem::swap(&mut q0, &mut q1);
288 }
289 for i in 0..self.max_iter {
290 if q1 == q0 {
291 if p1 != p0 {
292 return Err(RootFinderError::NotConverged(i, q0));
293 }
294 return Ok((p1 + p0) / 2.0);
295 }
296 let p = if q1.abs() > q0.abs() {
297 (-q0 / q1 * p1 + p0) / (1.0 - q0 / q1)
298 } else {
299 (-q1 / q0 * p0 + p1) / (1.0 - q1 / q0)
300 };
301 if approx_eq!(p, p1, rtol <= self.rel_tol, atol <= self.abs_tol) {
302 return Ok(p);
303 }
304 p0 = p1;
305 q0 = q1;
306 p1 = p;
307 q1 = f.call(p).map_err(RootFinderError::Callback)?;
308 }
309 Err(RootFinderError::NotConverged(self.max_iter, p0))
310 }
311}
312
313impl<F> FindRoot<F> for Secant
314where
315 F: Callback,
316{
317 fn find(&self, f: F, initial_guess: f64) -> Result<f64, RootFinderError> {
318 let x0 = initial_guess;
319 let eps = 1e-4;
320 let mut x1 = x0 * (1.0 + eps);
321 x1 += if x1 > x0 { eps } else { -eps };
322 self.find_in_bracket(f, (x0, x1))
323 }
324}
325
#[cfg(test)]
mod tests {
    use lox_test_utils::assert_approx_eq;
    use std::f64::consts::PI;

    use super::*;

    /// Return type of the test callbacks.
    type Result = std::result::Result<f64, BoxedError>;

    /// Root of [`cubic`].
    const CUBIC_ROOT: f64 = 1.3652300134140969;

    /// `x^3 + 4x^2 - 10`, shared by several solver tests.
    fn cubic(x: f64) -> Result {
        Ok(x.powi(3) + 4.0 * x.powi(2) - 10.0)
    }

    /// First derivative of [`cubic`]: `3x^2 + 8x`.
    fn cubic_derivative(x: f64) -> Result {
        Ok(3.0 * x.powi(2) + 8.0 * x)
    }

    #[test]
    fn test_newton_kepler() {
        /// Solves Kepler's equation `E - e sin(E) = M` for `E`, starting from `E = M`.
        fn mean_to_ecc(mean: f64, eccentricity: f64) -> std::result::Result<f64, RootFinderError> {
            let newton = Newton::default();
            newton.find_with_derivative(
                |e: f64| -> Result { Ok(e - eccentricity * e.sin() - mean) },
                |e: f64| -> Result { Ok(1.0 - eccentricity * e.cos()) },
                mean,
            )
        }
        let act = mean_to_ecc(PI / 2.0, 0.3).expect("should converge");
        assert_approx_eq!(act, 1.85846841205333, rtol <= 1e-8);
    }

    #[test]
    fn test_newton_cubic() {
        let newton = Newton::default();
        let act = newton
            .find_with_derivative(cubic, cubic_derivative, 1.5)
            .expect("should converge");
        assert_approx_eq!(act, CUBIC_ROOT, rtol <= 1e-8);
    }

    #[test]
    fn test_steffensen_cubic() {
        let steffensen = Steffensen::default();
        let act = steffensen.find(cubic, 1.5).expect("should converge");
        assert_approx_eq!(act, CUBIC_ROOT, rtol <= 1e-8);
    }

    #[test]
    fn test_brent_cubic() {
        let brent = Brent::default();
        let act = brent
            .find_in_bracket(cubic, (1.0, 1.5))
            .expect("should converge");
        assert_approx_eq!(act, CUBIC_ROOT, rtol <= 1e-8);
    }

    #[test]
    fn test_brent_not_in_bracket() {
        // cubic(2) = 14 and cubic(3) = 53 have the same sign.
        let brent = Brent::default();
        let result = brent.find_in_bracket(cubic, (2.0, 3.0));
        assert!(matches!(result, Err(RootFinderError::NotInBracket)));
    }

    #[test]
    fn test_brent_root_at_endpoint() {
        let brent = Brent::default();
        let act = brent
            .find_in_bracket(|x: f64| -> Result { Ok(x - 1.0) }, (1.0, 2.0))
            .expect("should converge");
        assert_eq!(act, 1.0);
    }

    #[test]
    fn test_secant_cubic() {
        let secant = Secant::default();
        let act = secant
            .find_in_bracket(cubic, (1.0, 1.5))
            .expect("should converge");
        assert_approx_eq!(act, CUBIC_ROOT, rtol <= 1e-8);

        let act = secant.find(cubic, 1.0).expect("should converge");
        assert_approx_eq!(act, CUBIC_ROOT, rtol <= 1e-8);
    }

    #[test]
    #[should_panic(expected = "derivative failed")]
    fn test_newton_kepler_callback_error() {
        let newton = Newton::default();
        newton
            .find_with_derivative(
                |e: f64| -> Result { Ok(e) },
                |_e: f64| -> Result { Err("derivative failed".into()) },
                1.0,
            )
            .unwrap();
    }

    #[test]
    #[should_panic(expected = "f failed")]
    fn test_steffensen_cubic_error() {
        let steffensen = Steffensen::default();
        steffensen
            .find(|_x| -> Result { Err("f failed".into()) }, 1.0)
            .unwrap();
    }

    #[test]
    #[should_panic(expected = "negative x")]
    fn test_brent_cubic_error() {
        let brent = Brent::default();
        brent
            .find_in_bracket(
                |x: f64| -> Result {
                    if x.is_sign_negative() {
                        Err("negative x".into())
                    } else {
                        Ok(x * x - 2.0)
                    }
                },
                (-1.0, 2.0),
            )
            .unwrap();
    }
}