/// Outcome of a successful monotone root solve.
///
/// `PartialEq` is derived alongside `Clone`/`Copy`/`Debug` so callers and tests
/// can compare whole solutions directly.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct MonotoneRootSolution {
    /// Abscissa reported as the root.
    pub root: f64,
    /// Absolute value of the residual's derivative at `root`; the solver only
    /// returns a solution when this is finite and non-zero.
    pub abs_deriv: f64,
    /// Signed residual `f(root)` (not its absolute value).
    pub residual: f64,
    /// Iterations spent: `0` when the starting point was already a root, the
    /// number of warm-start Newton probes when one of them converged, otherwise
    /// the number of bracketed refinement iterations.
    pub refine_iters: usize,
}
22
23pub use gam_problem::MonotoneRootError;
24
/// Magnitude at or below which a first derivative (or a Halley denominator) is
/// considered too small to divide by when forming a Newton/Halley step.
const NEWTON_DERIV_FLOOR: f64 = 1.0e-30;

/// Warm-start Newton steps longer than `WARMSTART_NEWTON_STEP_LIMIT * (1 + |a|)`
/// are rejected, handing control to the bracketing search.
const WARMSTART_NEWTON_STEP_LIMIT: f64 = 8.0_f64;

/// Initial outward bracket-search step, as a fraction of `1 + |a_init|`; the
/// resulting step is floored at `1.0` where it is used.
const BRACKET_INITIAL_STEP_FRAC: f64 = 0.25_f64;
43
44#[inline]
46fn map_eval_err(label: &str, a: f64, source: String) -> MonotoneRootError {
47 MonotoneRootError::EvalFailed {
48 label: label.to_string(),
49 a,
50 source,
51 }
52}
53
54pub fn solve_monotone_root(
55 eval: impl Fn(f64) -> Result<(f64, f64, f64), String>,
56 a_init: f64,
57 label: &str,
58 convergence_tol: f64,
59 max_bracket_iters: usize,
60 max_refine_iters: usize,
61) -> Result<(f64, f64, f64), MonotoneRootError> {
62 let solution = solve_monotone_root_detailed(
63 eval,
64 a_init,
65 label,
66 convergence_tol,
67 max_bracket_iters,
68 max_refine_iters,
69 )?;
70 Ok((solution.root, solution.abs_deriv, solution.residual))
71}
72
73pub fn solve_monotone_root_detailed(
74 eval: impl Fn(f64) -> Result<(f64, f64, f64), String>,
75 a_init: f64,
76 label: &str,
77 convergence_tol: f64,
78 max_bracket_iters: usize,
79 max_refine_iters: usize,
80) -> Result<MonotoneRootSolution, MonotoneRootError> {
81 solve_monotone_root_detailed_with_bracket(
82 eval,
83 a_init,
84 label,
85 convergence_tol,
86 max_bracket_iters,
87 max_refine_iters,
88 None,
89 )
90}
91
92pub fn solve_monotone_root_detailed_with_bracket(
93 eval: impl Fn(f64) -> Result<(f64, f64, f64), String>,
94 a_init: f64,
95 label: &str,
96 convergence_tol: f64,
97 max_bracket_iters: usize,
98 max_refine_iters: usize,
99 analytic_bracket: Option<(f64, f64)>,
100) -> Result<MonotoneRootSolution, MonotoneRootError> {
101 let (f_init, f_deriv_init, _) = eval(a_init).map_err(|e| map_eval_err(label, a_init, e))?;
102
103 if f_init.abs() <= convergence_tol {
105 let abs_d = f_deriv_init.abs();
106 if !abs_d.is_finite() || abs_d == 0.0 {
107 return Err(MonotoneRootError::exact_root_degenerate(label, a_init));
108 }
109 return Ok(MonotoneRootSolution {
110 root: a_init,
111 abs_deriv: abs_d,
112 residual: f_init,
113 refine_iters: 0,
114 });
115 }
116
117 if !f_deriv_init.is_finite() || f_deriv_init == 0.0 {
118 return Err(MonotoneRootError::DegenerateDerivative {
119 label: label.to_string(),
120 a: a_init,
121 fp: f_deriv_init,
122 });
123 }
124
125 let mut a = a_init;
130 let mut f = f_init;
131 let mut fp = f_deriv_init;
132 for probe_iter in 0..2 {
133 if f.abs() <= convergence_tol {
134 let abs_d = fp.abs();
135 if !abs_d.is_finite() || abs_d == 0.0 {
136 break;
137 }
138 return Ok(MonotoneRootSolution {
139 root: a,
140 abs_deriv: abs_d,
141 residual: f,
142 refine_iters: probe_iter,
143 });
144 }
145
146 if !fp.is_finite() || fp.abs() <= NEWTON_DERIV_FLOOR {
147 break;
148 }
149
150 let step = -f / fp;
151 if !step.is_finite() || step.abs() > WARMSTART_NEWTON_STEP_LIMIT * (1.0 + a.abs()) {
152 break;
153 }
154
155 let cand = a + step;
156 let (f_cand, fp_cand, _) = eval(cand).map_err(|e| map_eval_err(label, cand, e))?;
157 if f_cand.abs() <= convergence_tol {
158 let abs_d = fp_cand.abs();
159 if !abs_d.is_finite() || abs_d == 0.0 {
160 break;
161 }
162 return Ok(MonotoneRootSolution {
163 root: cand,
164 abs_deriv: abs_d,
165 residual: f_cand,
166 refine_iters: probe_iter + 1,
167 });
168 }
169
170 a = cand;
171 f = f_cand;
172 fp = fp_cand;
173 }
174
175 let (mut neg_pt, mut pos_pt) = if let Some((lo, hi)) = analytic_bracket {
177 if !lo.is_finite() || !hi.is_finite() || lo == hi {
178 return Err(MonotoneRootError::analytic_bracket_invalid(label, lo, hi));
179 }
180 let (f_lo, _, _) = eval(lo).map_err(|e| map_eval_err(label, lo, e))?;
181 let (f_hi, _, _) = eval(hi).map_err(|e| map_eval_err(label, hi, e))?;
182 if f_lo <= 0.0 && f_hi >= 0.0 {
183 (lo, hi)
184 } else if f_hi <= 0.0 && f_lo >= 0.0 {
185 (hi, lo)
186 } else {
187 return Err(MonotoneRootError::analytic_bracket_no_straddle(
188 label, f_lo, f_hi,
189 ));
190 }
191 } else {
192 let step_sign: f64 = if f_init * f_deriv_init < 0.0 {
203 1.0
204 } else {
205 -1.0
206 };
207
208 let f_init_negative = f_init < 0.0;
209 let mut same_side = a_init; let mut step_mag = (BRACKET_INITIAL_STEP_FRAC * (1.0 + a_init.abs())).max(1.0);
211 let step_cap = 1e6_f64.max(1024.0 * (1.0 + a_init.abs()));
222 let mut found_other: Option<(f64, f64)> = None;
223
224 for _ in 0..max_bracket_iters {
225 let probe = same_side + step_mag * step_sign;
226 let (f_probe, _, _) = eval(probe).map_err(|e| map_eval_err(label, probe, e))?;
227 let crossed = if f_init_negative {
228 f_probe >= 0.0
229 } else {
230 f_probe <= 0.0
231 };
232 if crossed {
233 found_other = Some((probe, f_probe));
234 break;
235 }
236 same_side = probe;
237 step_mag *= 2.0;
238 if step_mag > step_cap {
239 break;
240 }
241 }
242
243 let Some((other, _)) = found_other else {
244 return Err(MonotoneRootError::search_exhausted(
245 label, step_sign, a_init,
246 ));
247 };
248
249 if f_init_negative {
250 (same_side, other)
251 } else {
252 (other, same_side)
253 }
254 };
255
256 let mut best_a = a_init;
259 let mut best_f = f_init;
260 let mut best_abs_deriv = f_deriv_init.abs();
261
262 #[inline]
263 fn update_best(
264 best_a: &mut f64,
265 best_f: &mut f64,
266 best_abs_d: &mut f64,
267 a: f64,
268 f: f64,
269 f_d: f64,
270 ) {
271 if f.abs() < best_f.abs() {
272 *best_a = a;
273 *best_f = f;
274 *best_abs_d = f_d.abs();
275 }
276 }
277
278 let mut refine_iters = 0usize;
279 for _ in 0..max_refine_iters {
280 refine_iters += 1;
281 let (lo, hi) = if neg_pt <= pos_pt {
282 (neg_pt, pos_pt)
283 } else {
284 (pos_pt, neg_pt)
285 };
286 let mid = 0.5 * (lo + hi);
287 let (f_mid, f_a_mid, f_aa_mid) = eval(mid).map_err(|e| map_eval_err(label, mid, e))?;
288 update_best(
289 &mut best_a,
290 &mut best_f,
291 &mut best_abs_deriv,
292 mid,
293 f_mid,
294 f_a_mid,
295 );
296
297 if f_mid.abs() <= convergence_tol {
298 break;
299 }
300
301 let halley_probe = if f_a_mid.is_finite() && f_a_mid.abs() > NEWTON_DERIV_FLOOR {
306 let halley_denom = 2.0 * f_a_mid * f_a_mid - f_mid * f_aa_mid;
307 if halley_denom.is_finite() && halley_denom.abs() > NEWTON_DERIV_FLOOR {
308 let cand = mid - (2.0 * f_mid * f_a_mid) / halley_denom;
309 if cand > lo && cand < hi {
310 Some(cand)
311 } else {
312 None
313 }
314 } else {
315 None
316 }
317 } else {
318 None
319 };
320
321 let probe = if let Some(cand) = halley_probe {
324 cand
325 } else if f_a_mid.is_finite() && f_a_mid.abs() > NEWTON_DERIV_FLOOR {
326 let cand = mid - f_mid / f_a_mid;
327 if cand > lo && cand < hi { cand } else { mid }
328 } else {
329 mid
330 };
331
332 let (bracket_pt, f_bracket) = if (probe - mid).abs() > 0.0 {
334 let (f_p, f_a_p, _) = eval(probe).map_err(|e| map_eval_err(label, probe, e))?;
335 update_best(
336 &mut best_a,
337 &mut best_f,
338 &mut best_abs_deriv,
339 probe,
340 f_p,
341 f_a_p,
342 );
343 (probe, f_p)
344 } else {
345 (mid, f_mid)
346 };
347
348 if f_bracket <= 0.0 {
349 neg_pt = bracket_pt;
350 } else {
351 pos_pt = bracket_pt;
352 }
353
354 let (next_lo, next_hi) = if neg_pt <= pos_pt {
355 (neg_pt, pos_pt)
356 } else {
357 (pos_pt, neg_pt)
358 };
359 if (next_hi - next_lo).abs() <= convergence_tol * (1.0 + next_hi.abs() + next_lo.abs()) {
360 break;
361 }
362 }
363
364 if !best_abs_deriv.is_finite() || best_abs_deriv == 0.0 {
366 let (_, f_a_best, _) = eval(best_a).map_err(|e| map_eval_err(label, best_a, e))?;
367 best_abs_deriv = f_a_best.abs();
368 }
369 if !best_abs_deriv.is_finite() || best_abs_deriv == 0.0 {
370 return Err(MonotoneRootError::converged_root_degenerate(label, best_a));
371 }
372
373 Ok(MonotoneRootSolution {
374 root: best_a,
375 abs_deriv: best_abs_deriv,
376 residual: best_f,
377 refine_iters,
378 })
379}
380
#[cfg(test)]
mod tests {
    use super::{
        solve_monotone_root, solve_monotone_root_detailed,
        solve_monotone_root_detailed_with_bracket, MonotoneRootError,
    };
    use std::cell::RefCell;
    use std::f64::consts::LN_2;

    /// `exp(a) - 2` is increasing; its root is ln 2, where the slope equals 2.
    #[test]
    fn solve_monotone_root_converges_for_increasing_function() {
        let outcome = solve_monotone_root(
            |a| {
                let growth = a.exp();
                Ok((growth - 2.0, growth, growth))
            },
            0.0,
            "increasing",
            1e-12,
            32,
            32,
        );
        let (root, abs_deriv, residual) = outcome.expect("root");

        assert!((root - LN_2).abs() < 1e-10);
        assert!((abs_deriv - 2.0).abs() < 1e-10);
        assert!(residual.abs() < 1e-12);
    }

    /// `exp(-a) - 0.5` is decreasing with root ln 2. Besides the root itself,
    /// checks that the Halley point computed from `a = 0.5` was evaluated.
    #[test]
    fn solve_monotone_root_accepts_halley_probe_for_decreasing_function() {
        let visited = RefCell::new(Vec::new());
        let (root, abs_deriv, residual) = solve_monotone_root(
            |a| {
                visited.borrow_mut().push(a);
                let decay = (-a).exp();
                Ok((decay - 0.5, -decay, decay))
            },
            0.0,
            "decreasing",
            1e-12,
            32,
            32,
        )
        .expect("root");

        // Halley step taken from a = 0.5.
        let e_half = (-0.5f64).exp();
        let (f_mid, f_a_mid, f_aa_mid) = (e_half - 0.5, -e_half, e_half);
        let expected_probe =
            0.5 - (2.0 * f_mid * f_a_mid) / (2.0 * f_a_mid * f_a_mid - f_mid * f_aa_mid);

        assert!((root - LN_2).abs() < 1e-10);
        assert!((abs_deriv - 0.5).abs() < 1e-10);
        assert!(residual.abs() < 1e-12);
        let probe_was_evaluated = visited
            .borrow()
            .iter()
            .any(|&a| (a - expected_probe).abs() < 1e-12);
        assert!(probe_was_evaluated);
    }

    /// Linear residual `2a - 7`: root 3.5 with constant slope 2.
    #[test]
    fn solve_linear_function_reaches_exact_root() {
        let linear = solve_monotone_root(
            |a| Ok((2.0 * a - 7.0, 2.0, 0.0)),
            0.0,
            "linear",
            1e-12,
            32,
            64,
        );
        let (root, abs_deriv, residual) = linear.expect("root");

        assert!((root - 3.5).abs() < 1e-10, "root={root}");
        assert!((abs_deriv - 2.0).abs() < 1e-10, "abs_deriv={abs_deriv}");
        assert!(residual.abs() < 1e-12, "residual={residual}");
    }

    /// Starting exactly on the root must return without any iterations.
    #[test]
    fn exact_root_at_init_returns_zero_iters() {
        let solution = solve_monotone_root_detailed(
            |a| Ok((a, 1.0, 0.0)),
            0.0,
            "exact_at_init",
            1e-12,
            32,
            32,
        )
        .expect("solution");

        assert!(solution.root.abs() < 1e-12, "root={}", solution.root);
        assert_eq!(solution.refine_iters, 0);
    }

    /// A zero derivative at a non-root starting point is rejected up front.
    #[test]
    fn degenerate_derivative_returns_error() {
        match solve_monotone_root(
            |a| Ok((a - 5.0, 0.0, 0.0)),
            0.0,
            "degenerate_fp",
            1e-12,
            32,
            32,
        )
        .unwrap_err()
        {
            MonotoneRootError::DegenerateDerivative { .. } => (),
            other => panic!("expected DegenerateDerivative, got {other:?}"),
        }
    }

    /// A caller-supplied bracket straddling the root yields that root.
    #[test]
    fn analytic_bracket_is_honored() {
        let bracket = Some((0.0, 10.0));
        let sol = solve_monotone_root_detailed_with_bracket(
            |a| Ok((a - 3.0, 1.0, 0.0)),
            5.0,
            "analytic_bracket",
            1e-12,
            32,
            64,
            bracket,
        )
        .expect("solution");

        assert!((sol.root - 3.0).abs() < 1e-10, "root={}", sol.root);
        assert!(sol.residual.abs() < 1e-12, "residual={}", sol.residual);
    }

    /// With no bracket-search budget and a rejected warm start, bracketing fails.
    #[test]
    fn search_exhausted_with_zero_bracket_iters() {
        let no_bracket_budget = 0;
        match solve_monotone_root(
            |a| Ok((a - 100.0, 1.0, 0.0)),
            0.0,
            "no_bracket",
            1e-12,
            no_bracket_budget,
            32,
        )
        .unwrap_err()
        {
            MonotoneRootError::BracketingExhausted { .. } => (),
            other => panic!("expected BracketingExhausted, got {other:?}"),
        }
    }
}