1use lox_test_utils::approx_eq;
8use thiserror::Error;
9
10#[derive(Debug, Error)]
12pub enum RootFinderError {
13 #[error("not converged after {0} iterations, residual {1}")]
15 NotConverged(u32, f64),
16 #[error("root not in bracket")]
18 NotInBracket,
19 #[error(transparent)]
21 Callback(#[from] CallbackError),
22}
23
/// Type-erased, thread-safe error returned by user callbacks.
///
/// `Box<dyn Trait>` already defaults to a `'static` object bound.
pub type BoxedError = Box<dyn std::error::Error + Sync + Send>;
26
27#[derive(Debug, Error)]
29#[error(transparent)]
30pub struct CallbackError(BoxedError);
31
32impl From<&str> for CallbackError {
33 fn from(s: &str) -> Self {
34 CallbackError(s.into())
35 }
36}
37
38impl From<BoxedError> for CallbackError {
39 fn from(e: BoxedError) -> Self {
40 CallbackError(e)
41 }
42}
43
44pub trait Callback {
46 fn call(&self, v: f64) -> Result<f64, CallbackError>;
48}
49
50impl<F> Callback for F
51where
52 F: Fn(f64) -> Result<f64, BoxedError>,
53{
54 fn call(&self, v: f64) -> Result<f64, CallbackError> {
55 self(v).map_err(CallbackError)
56 }
57}
58
59pub trait FindRoot<F>
61where
62 F: Callback,
63{
64 fn find(&self, f: F, initial_guess: f64) -> Result<f64, RootFinderError>;
66}
67
68pub trait FindRootWithDerivative<F, D>
70where
71 F: Callback,
72 D: Callback,
73{
74 fn find_with_derivative(
76 &self,
77 f: F,
78 derivative: D,
79 initial_guess: f64,
80 ) -> Result<f64, RootFinderError>;
81}
82
83pub trait FindBracketedRoot<F>
85where
86 F: Callback,
87{
88 fn find_in_bracket(&self, f: F, bracket: (f64, f64)) -> Result<f64, RootFinderError>;
90}
91
/// Derivative-free root finder based on Steffensen's method.
///
/// Fixed-point iteration on `g(x) = x + f(x)` is accelerated with Aitken's delta-squared
/// process.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Steffensen {
    /// Maximum number of iterations before giving up.
    max_iter: u32,
    /// Absolute tolerance on the difference between successive iterates.
    tolerance: f64,
}

impl Steffensen {
    /// Creates a solver with the given iteration cap and absolute step tolerance.
    pub fn new(max_iter: u32, tolerance: f64) -> Self {
        Self { max_iter, tolerance }
    }
}

impl Default for Steffensen {
    /// 1000 iterations with a tolerance of `sqrt(f64::EPSILON)`.
    fn default() -> Self {
        Self::new(1000, f64::EPSILON.sqrt())
    }
}
107
108impl<F> FindRoot<F> for Steffensen
109where
110 F: Callback,
111{
112 fn find(&self, f: F, initial_guess: f64) -> Result<f64, RootFinderError> {
113 let mut p0 = initial_guess;
114 for _ in 0..self.max_iter {
115 let f1 = p0 + f.call(p0).map_err(RootFinderError::Callback)?;
116 let f2 = f1 + f.call(f1).map_err(RootFinderError::Callback)?;
117 let p = p0 - (f1 - p0).powi(2) / (f2 - 2.0 * f1 + p0);
118 if approx_eq!(p, p0, atol <= self.tolerance) {
119 return Ok(p);
120 }
121 p0 = p;
122 }
123 Err(RootFinderError::NotConverged(self.max_iter, p0))
124 }
125}
126
/// Newton–Raphson root finder for functions with a known derivative.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Newton {
    /// Maximum number of iterations before giving up.
    max_iter: u32,
    /// Absolute tolerance on the difference between successive iterates.
    tolerance: f64,
}

impl Newton {
    /// Creates a solver with the given iteration cap and absolute step tolerance.
    pub fn new(max_iter: u32, tolerance: f64) -> Self {
        Self { max_iter, tolerance }
    }
}

impl Default for Newton {
    /// 50 iterations with a tolerance of `sqrt(f64::EPSILON)`.
    fn default() -> Self {
        Self::new(50, f64::EPSILON.sqrt())
    }
}
142
143impl<F, D> FindRootWithDerivative<F, D> for Newton
144where
145 F: Callback,
146 D: Callback,
147{
148 fn find_with_derivative(
149 &self,
150 f: F,
151 derivative: D,
152 initial_guess: f64,
153 ) -> Result<f64, RootFinderError> {
154 let mut p0 = initial_guess;
155 for _ in 0..self.max_iter {
156 let p = p0
157 - f.call(p0).map_err(RootFinderError::Callback)?
158 / derivative.call(p0).map_err(RootFinderError::Callback)?;
159 if approx_eq!(p, p0, atol <= self.tolerance) {
160 return Ok(p);
161 }
162 p0 = p;
163 }
164 Err(RootFinderError::NotConverged(self.max_iter, p0))
165 }
166}
167
/// Bracketing root finder based on Brent's method.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Brent {
    /// Maximum number of iterations before giving up.
    max_iter: u32,
    /// Absolute tolerance. It is used in the step/bracket-width convergence test and also as
    /// the threshold below which `|f(x)|` is treated as zero.
    abs_tol: f64,
    /// Relative tolerance (scaled by `|x|`) for the step/bracket-width convergence test.
    rel_tol: f64,
}

impl Brent {
    /// Creates a solver with the given iteration cap and absolute/relative tolerances.
    pub fn new(max_iter: u32, abs_tol: f64, rel_tol: f64) -> Self {
        Self {
            max_iter,
            abs_tol,
            rel_tol,
        }
    }
}

impl Default for Brent {
    /// 100 iterations, `abs_tol = 1e-6` and `rel_tol = sqrt(f64::EPSILON)`.
    fn default() -> Self {
        Self::new(100, 1e-6, f64::EPSILON.sqrt())
    }
}
185
186impl<F> FindBracketedRoot<F> for Brent
187where
188 F: Callback,
189{
190 fn find_in_bracket(&self, f: F, bracket: (f64, f64)) -> Result<f64, RootFinderError> {
191 let mut fblk = 0.0;
192 let mut xblk = 0.0;
193 let (mut xpre, mut xcur) = bracket;
194 let mut spre = 0.0;
195 let mut scur = 0.0;
196
197 let mut fpre = f.call(xpre).map_err(RootFinderError::Callback)?;
198 let mut fcur = f.call(xcur).map_err(RootFinderError::Callback)?;
199
200 if fpre * fcur > 0.0 {
201 return Err(RootFinderError::NotInBracket);
202 }
203
204 if approx_eq!(fpre, 0.0, atol <= self.abs_tol) {
205 return Ok(xpre);
206 }
207
208 if approx_eq!(fcur, 0.0, atol <= self.abs_tol) {
209 return Ok(xcur);
210 }
211
212 for _ in 0..self.max_iter {
213 if fpre * fcur < 0.0 {
214 xblk = xpre;
215 fblk = fpre;
216 spre = xcur - xpre;
217 scur = xcur - xpre;
218 }
219
220 if fblk.abs() < fcur.abs() {
221 xpre = xcur;
222 xcur = xblk;
223 xblk = xpre;
224 fpre = fcur;
225 fcur = fblk;
226 fblk = fpre;
227 }
228
229 let delta = (self.abs_tol + self.rel_tol * xcur.abs()) / 2.0;
230 let sbis = (xblk - xcur) / 2.0;
231
232 if approx_eq!(fcur, 0.0, atol <= self.abs_tol) || sbis.abs() < delta {
233 return Ok(xcur);
234 }
235
236 if spre.abs() > delta && fcur.abs() < fpre.abs() {
237 let stry = if approx_eq!(xpre, xblk, rtol <= self.rel_tol) {
238 -fcur * (xcur - xpre) / (fcur - fpre)
240 } else {
241 let dpre = (fpre - fcur) / (xpre - xcur);
243 let dblk = (fblk - fcur) / (xblk - xcur);
244 -fcur * (fblk * dblk - fpre * dpre) / (dblk * dpre * (fblk - fpre))
245 };
246
247 if 2.0 * stry.abs() < spre.abs().min(3.0 * sbis.abs() - delta) {
248 spre = scur;
249 scur = stry;
250 } else {
251 spre = sbis;
253 scur = sbis;
254 }
255 } else {
256 spre = sbis;
258 scur = sbis;
259 }
260
261 xpre = xcur;
262 fpre = fcur;
263
264 if scur.abs() > delta {
265 xcur += scur
266 } else {
267 xcur += if sbis > 0.0 { delta } else { -delta };
268 }
269
270 fcur = f.call(xcur).map_err(RootFinderError::Callback)?;
271 }
272
273 Err(RootFinderError::NotConverged(self.max_iter, fcur))
274 }
275}
276
/// Root finder based on the secant method.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Secant {
    /// Maximum number of iterations before giving up.
    max_iter: u32,
    /// Relative tolerance on the difference between successive iterates.
    rel_tol: f64,
    /// Absolute tolerance on the difference between successive iterates.
    abs_tol: f64,
}

impl Secant {
    /// Creates a solver with the given iteration cap and absolute/relative tolerances.
    ///
    /// The parameter order `(max_iter, abs_tol, rel_tol)` matches [`Brent::new`].
    pub fn new(max_iter: u32, abs_tol: f64, rel_tol: f64) -> Self {
        Self {
            max_iter,
            rel_tol,
            abs_tol,
        }
    }
}

impl Default for Secant {
    /// 100 iterations, `rel_tol = sqrt(f64::EPSILON)` and `abs_tol = 1e-6`.
    fn default() -> Self {
        Self::new(100, 1e-6, f64::EPSILON.sqrt())
    }
}
294
295impl<F> FindBracketedRoot<F> for Secant
296where
297 F: Callback,
298{
299 fn find_in_bracket(&self, f: F, bracket: (f64, f64)) -> Result<f64, RootFinderError> {
300 let (x0, x1) = bracket;
301 let mut p0 = x0;
302 let mut p1 = x1;
303 let mut q0 = f.call(p0).map_err(RootFinderError::Callback)?;
304 let mut q1 = f.call(p1).map_err(RootFinderError::Callback)?;
305 if q1.abs() < q0.abs() {
306 std::mem::swap(&mut p0, &mut p1);
307 std::mem::swap(&mut q0, &mut q1);
308 }
309 for i in 0..self.max_iter {
310 if q1 == q0 {
311 if p1 != p0 {
312 return Err(RootFinderError::NotConverged(i, q0));
313 }
314 return Ok((p1 + p0) / 2.0);
315 }
316 let p = if q1.abs() > q0.abs() {
317 (-q0 / q1 * p1 + p0) / (1.0 - q0 / q1)
318 } else {
319 (-q1 / q0 * p0 + p1) / (1.0 - q1 / q0)
320 };
321 if approx_eq!(p, p1, rtol <= self.rel_tol, atol <= self.abs_tol) {
322 return Ok(p);
323 }
324 p0 = p1;
325 q0 = q1;
326 p1 = p;
327 q1 = f.call(p).map_err(RootFinderError::Callback)?;
328 }
329 Err(RootFinderError::NotConverged(self.max_iter, p0))
330 }
331}
332
333impl<F> FindRoot<F> for Secant
334where
335 F: Callback,
336{
337 fn find(&self, f: F, initial_guess: f64) -> Result<f64, RootFinderError> {
338 let x0 = initial_guess;
339 let eps = 1e-4;
340 let mut x1 = x0 * (1.0 + eps);
341 x1 += if x1 > x0 { eps } else { -eps };
342 self.find_in_bracket(f, (x0, x1))
343 }
344}
345
#[cfg(test)]
mod tests {
    use lox_test_utils::assert_approx_eq;
    use std::f64::consts::PI;

    use super::*;

    /// Return type of the test callbacks.
    type Result = std::result::Result<f64, BoxedError>;

    #[test]
    fn test_newton_kepler() {
        // Solves Kepler's equation `E - e * sin(E) = M` for `E`, starting from `E = M`.
        fn mean_to_ecc(mean: f64, eccentricity: f64) -> std::result::Result<f64, RootFinderError> {
            let newton = Newton::default();
            newton.find_with_derivative(
                |e: f64| -> Result { Ok(e - eccentricity * e.sin() - mean) },
                |e: f64| -> Result { Ok(1.0 - eccentricity * e.cos()) },
                mean,
            )
        }
        let act = mean_to_ecc(PI / 2.0, 0.3).expect("should converge");
        assert_approx_eq!(act, 1.85846841205333, rtol <= 1e-8);
    }

    #[test]
    fn test_newton_cubic() {
        let newton = Newton::default();
        let act = newton
            .find_with_derivative(
                |x: f64| -> Result { Ok(x.powi(3) + 4.0 * x.powi(2) - 10.0) },
                // Derivative of x^3 + 4x^2 - 10.
                |x: f64| -> Result { Ok(3.0 * x.powi(2) + 8.0 * x) },
                1.5,
            )
            .expect("should converge");
        assert_approx_eq!(act, 1.3652300134140969, rtol <= 1e-8);
    }

    #[test]
    fn test_steffensen_cubic() {
        let steffensen = Steffensen::default();
        let act = steffensen
            .find(
                |x: f64| -> Result { Ok(x.powi(3) + 4.0 * x.powi(2) - 10.0) },
                1.5,
            )
            .expect("should converge");
        assert_approx_eq!(act, 1.3652300134140969, rtol <= 1e-8);
    }

    #[test]
    fn test_brent_cubic() {
        let brent = Brent::default();
        let act = brent
            .find_in_bracket(
                |x: f64| -> Result { Ok(x.powi(3) + 4.0 * x.powi(2) - 10.0) },
                (1.0, 1.5),
            )
            .expect("should converge");
        assert_approx_eq!(act, 1.3652300134140969, rtol <= 1e-8);
    }

    #[test]
    fn test_secant_cubic() {
        let secant = Secant::default();
        let act = secant
            .find_in_bracket(
                |x: f64| -> Result { Ok(x.powi(3) + 4.0 * x.powi(2) - 10.0) },
                (1.0, 1.5),
            )
            .expect("should converge");
        assert_approx_eq!(act, 1.3652300134140969, rtol <= 1e-8);

        let act = secant
            .find(
                |x: f64| -> Result { Ok(x.powi(3) + 4.0 * x.powi(2) - 10.0) },
                1.0,
            )
            .expect("should converge");
        assert_approx_eq!(act, 1.3652300134140969, rtol <= 1e-8);
    }

    #[test]
    #[should_panic(expected = "derivative failed")]
    fn test_newton_kepler_callback_error() {
        let newton = Newton::default();
        newton
            .find_with_derivative(
                |e: f64| -> Result { Ok(e) },
                |_e: f64| -> Result { Err("derivative failed".into()) },
                1.0,
            )
            .unwrap();
    }

    #[test]
    #[should_panic(expected = "f failed")]
    fn test_steffensen_cubic_error() {
        let steffensen = Steffensen::default();
        steffensen
            .find(|_x| -> Result { Err("f failed".into()) }, 1.0)
            .unwrap();
    }

    #[test]
    #[should_panic(expected = "negative x")]
    fn test_brent_cubic_error() {
        let brent = Brent::default();
        brent
            .find_in_bracket(
                |x: f64| -> Result {
                    if x.is_sign_negative() {
                        Err("negative x".into())
                    } else {
                        Ok(x * x - 2.0)
                    }
                },
                (-1.0, 2.0),
            )
            .unwrap();
    }
}