Skip to main content

gam_model_kernels/
monotone_root.rs

/// Result of the shared hybrid bracketing + Newton / Halley solver for
/// strictly monotone calibration equations `F(a) = 0`.
///
/// The solver drives an `eval(a)` callback that must return
/// `(F(a), F'(a), F''(a))`.  The second derivative feeds the safeguarded
/// Halley step used during bracketed refinement.
///
/// The monotone direction (increasing vs decreasing) is inferred from the
/// sign of `F'(a)` at the initial point, so the same code handles both the
/// Bernoulli case (F increasing) and the survival case (F decreasing).
///
/// The tuple-returning entry point `solve_monotone_root` reports
/// `(root, |F'(root)|, F(root))`, i.e. the first three fields of this struct
/// in declaration order.
#[derive(Clone, Copy, Debug)]
pub struct MonotoneRootSolution {
    /// Abscissa accepted as the root of `F`.
    pub root: f64,
    /// `|F'(root)|`.  The solver only returns finite, non-zero values here,
    /// so it can be used directly as the density-normalising calibration
    /// derivative.
    pub abs_deriv: f64,
    /// `F(root)`.  Callers must validate this residual against the scale of
    /// their calibration equation.
    pub residual: f64,
    /// Iteration count reported by the solver: `0` when the seed already met
    /// the tolerance, the number of warm-start Newton corrections when the
    /// warm start converged, otherwise the number of bracketed refinement
    /// iterations that were run.
    pub refine_iters: usize,
}
22
23pub use gam_problem::MonotoneRootError;
24
/// Magnitude floor for `|F'|` and for the Halley denominator. A Newton or
/// Halley step whose divisor is at or below this value is not attempted,
/// because the quotient carries no usable precision; the solver relies on the
/// bracket instead. The value sits far beneath any derivative magnitude a
/// well-posed monotone calibration equation produces, so only genuine flat
/// spots trigger it.
const NEWTON_DERIV_FLOOR: f64 = 1.0e-30;

/// Largest warm-start Newton correction that is still trusted, measured in
/// units of the iterate scale `1 + |a|`. A bigger correction suggests the
/// warm start is far from the root's basin, so the probe is dropped and the
/// globally convergent bracketed solver takes over.
const WARMSTART_NEWTON_STEP_LIMIT: f64 = 8.0;

/// First geometric bracketing step, as a fraction of the seed scale
/// `1 + |a_init|`. The resulting step magnitude is never smaller than 1 and
/// doubles after every probe until the root is straddled.
const BRACKET_INITIAL_STEP_FRAC: f64 = 0.25;
43
44/// Internal helper: wrap an `eval` closure error into `EvalFailed`.
45#[inline]
46fn map_eval_err(label: &str, a: f64, source: String) -> MonotoneRootError {
47    MonotoneRootError::EvalFailed {
48        label: label.to_string(),
49        a,
50        source,
51    }
52}
53
54pub fn solve_monotone_root(
55    eval: impl Fn(f64) -> Result<(f64, f64, f64), String>,
56    a_init: f64,
57    label: &str,
58    convergence_tol: f64,
59    max_bracket_iters: usize,
60    max_refine_iters: usize,
61) -> Result<(f64, f64, f64), MonotoneRootError> {
62    let solution = solve_monotone_root_detailed(
63        eval,
64        a_init,
65        label,
66        convergence_tol,
67        max_bracket_iters,
68        max_refine_iters,
69    )?;
70    Ok((solution.root, solution.abs_deriv, solution.residual))
71}
72
73pub fn solve_monotone_root_detailed(
74    eval: impl Fn(f64) -> Result<(f64, f64, f64), String>,
75    a_init: f64,
76    label: &str,
77    convergence_tol: f64,
78    max_bracket_iters: usize,
79    max_refine_iters: usize,
80) -> Result<MonotoneRootSolution, MonotoneRootError> {
81    solve_monotone_root_detailed_with_bracket(
82        eval,
83        a_init,
84        label,
85        convergence_tol,
86        max_bracket_iters,
87        max_refine_iters,
88        None,
89    )
90}
91
92pub fn solve_monotone_root_detailed_with_bracket(
93    eval: impl Fn(f64) -> Result<(f64, f64, f64), String>,
94    a_init: f64,
95    label: &str,
96    convergence_tol: f64,
97    max_bracket_iters: usize,
98    max_refine_iters: usize,
99    analytic_bracket: Option<(f64, f64)>,
100) -> Result<MonotoneRootSolution, MonotoneRootError> {
101    let (f_init, f_deriv_init, _) = eval(a_init).map_err(|e| map_eval_err(label, a_init, e))?;
102
103    // Exact root — rare but handle correctly.
104    if f_init.abs() <= convergence_tol {
105        let abs_d = f_deriv_init.abs();
106        if !abs_d.is_finite() || abs_d == 0.0 {
107            return Err(MonotoneRootError::exact_root_degenerate(label, a_init));
108        }
109        return Ok(MonotoneRootSolution {
110            root: a_init,
111            abs_deriv: abs_d,
112            residual: f_init,
113            refine_iters: 0,
114        });
115    }
116
117    if !f_deriv_init.is_finite() || f_deriv_init == 0.0 {
118        return Err(MonotoneRootError::DegenerateDerivative {
119            label: label.to_string(),
120            a: a_init,
121            fp: f_deriv_init,
122        });
123    }
124
125    // With a good warm start, the root is often within one or two Newton
126    // corrections. Try that local basin before spending evaluations on a
127    // global bracket; fall back to the bracketed solver unchanged if the
128    // probe is not decisive.
129    let mut a = a_init;
130    let mut f = f_init;
131    let mut fp = f_deriv_init;
132    for probe_iter in 0..2 {
133        if f.abs() <= convergence_tol {
134            let abs_d = fp.abs();
135            if !abs_d.is_finite() || abs_d == 0.0 {
136                break;
137            }
138            return Ok(MonotoneRootSolution {
139                root: a,
140                abs_deriv: abs_d,
141                residual: f,
142                refine_iters: probe_iter,
143            });
144        }
145
146        if !fp.is_finite() || fp.abs() <= NEWTON_DERIV_FLOOR {
147            break;
148        }
149
150        let step = -f / fp;
151        if !step.is_finite() || step.abs() > WARMSTART_NEWTON_STEP_LIMIT * (1.0 + a.abs()) {
152            break;
153        }
154
155        let cand = a + step;
156        let (f_cand, fp_cand, _) = eval(cand).map_err(|e| map_eval_err(label, cand, e))?;
157        if f_cand.abs() <= convergence_tol {
158            let abs_d = fp_cand.abs();
159            if !abs_d.is_finite() || abs_d == 0.0 {
160                break;
161            }
162            return Ok(MonotoneRootSolution {
163                root: cand,
164                abs_deriv: abs_d,
165                residual: f_cand,
166                refine_iters: probe_iter + 1,
167            });
168        }
169
170        a = cand;
171        f = f_cand;
172        fp = fp_cand;
173    }
174
175    // --- Phase 1: bracket the root -------------------------------------------
176    let (mut neg_pt, mut pos_pt) = if let Some((lo, hi)) = analytic_bracket {
177        if !lo.is_finite() || !hi.is_finite() || lo == hi {
178            return Err(MonotoneRootError::analytic_bracket_invalid(label, lo, hi));
179        }
180        let (f_lo, _, _) = eval(lo).map_err(|e| map_eval_err(label, lo, e))?;
181        let (f_hi, _, _) = eval(hi).map_err(|e| map_eval_err(label, hi, e))?;
182        if f_lo <= 0.0 && f_hi >= 0.0 {
183            (lo, hi)
184        } else if f_hi <= 0.0 && f_lo >= 0.0 {
185            (hi, lo)
186        } else {
187            return Err(MonotoneRootError::analytic_bracket_no_straddle(
188                label, f_lo, f_hi,
189            ));
190        }
191    } else {
192        // We need a point on the opposite side of zero from f_init.
193        // The correct search direction depends on both the sign of f_init and
194        // the monotonicity of F:
195        //
196        //   F increasing, f < 0 → root is to the right  (+)
197        //   F increasing, f > 0 → root is to the left   (−)
198        //   F decreasing, f < 0 → root is to the left   (−)
199        //   F decreasing, f > 0 → root is to the right  (+)
200        //
201        // Compactly:  step_sign = −sign(f · F')
202        let step_sign: f64 = if f_init * f_deriv_init < 0.0 {
203            1.0
204        } else {
205            -1.0
206        };
207
208        let f_init_negative = f_init < 0.0;
209        let mut same_side = a_init; // last point with same sign as f_init
210        let mut step_mag = (BRACKET_INITIAL_STEP_FRAC * (1.0 + a_init.abs())).max(1.0);
211        // Geometric step growth is unbounded mathematically, but in practice
212        // we cap to avoid runaway evaluations when F flatlines and never
213        // crosses (e.g. probit calibration where every probe saturates at
214        // ±∞). The cap scales with the magnitude of the seed: a huge
215        // `a_init` (say 1e6) needs proportional reach because a doubling
216        // schedule starting at 0.25·|a_init| only spans an `O(|a_init|)`
217        // window before a step would overshoot. An absolute 1e6 cap leaks
218        // when the seed itself sits near that bound; the scaled cap
219        // guarantees at least ~`max_bracket_iters` useful probes regardless
220        // of seed magnitude.
221        let step_cap = 1e6_f64.max(1024.0 * (1.0 + a_init.abs()));
222        let mut found_other: Option<(f64, f64)> = None;
223
224        for _ in 0..max_bracket_iters {
225            let probe = same_side + step_mag * step_sign;
226            let (f_probe, _, _) = eval(probe).map_err(|e| map_eval_err(label, probe, e))?;
227            let crossed = if f_init_negative {
228                f_probe >= 0.0
229            } else {
230                f_probe <= 0.0
231            };
232            if crossed {
233                found_other = Some((probe, f_probe));
234                break;
235            }
236            same_side = probe;
237            step_mag *= 2.0;
238            if step_mag > step_cap {
239                break;
240            }
241        }
242
243        let Some((other, _)) = found_other else {
244            return Err(MonotoneRootError::search_exhausted(
245                label, step_sign, a_init,
246            ));
247        };
248
249        if f_init_negative {
250            (same_side, other)
251        } else {
252            (other, same_side)
253        }
254    };
255
256    // --- Phase 2: hybrid bisection / Newton refinement -----------------------
257
258    let mut best_a = a_init;
259    let mut best_f = f_init;
260    let mut best_abs_deriv = f_deriv_init.abs();
261
262    #[inline]
263    fn update_best(
264        best_a: &mut f64,
265        best_f: &mut f64,
266        best_abs_d: &mut f64,
267        a: f64,
268        f: f64,
269        f_d: f64,
270    ) {
271        if f.abs() < best_f.abs() {
272            *best_a = a;
273            *best_f = f;
274            *best_abs_d = f_d.abs();
275        }
276    }
277
278    let mut refine_iters = 0usize;
279    for _ in 0..max_refine_iters {
280        refine_iters += 1;
281        let (lo, hi) = if neg_pt <= pos_pt {
282            (neg_pt, pos_pt)
283        } else {
284            (pos_pt, neg_pt)
285        };
286        let mid = 0.5 * (lo + hi);
287        let (f_mid, f_a_mid, f_aa_mid) = eval(mid).map_err(|e| map_eval_err(label, mid, e))?;
288        update_best(
289            &mut best_a,
290            &mut best_f,
291            &mut best_abs_deriv,
292            mid,
293            f_mid,
294            f_a_mid,
295        );
296
297        if f_mid.abs() <= convergence_tol {
298            break;
299        }
300
301        // Prefer a safeguarded Halley step when the second derivative is
302        // available and well-conditioned. The caller already computed F''(a),
303        // so using it here reduces expensive calibration evaluations for the
304        // exact denested likelihood paths without changing the objective.
305        let halley_probe = if f_a_mid.is_finite() && f_a_mid.abs() > NEWTON_DERIV_FLOOR {
306            let halley_denom = 2.0 * f_a_mid * f_a_mid - f_mid * f_aa_mid;
307            if halley_denom.is_finite() && halley_denom.abs() > NEWTON_DERIV_FLOOR {
308                let cand = mid - (2.0 * f_mid * f_a_mid) / halley_denom;
309                if cand > lo && cand < hi {
310                    Some(cand)
311                } else {
312                    None
313                }
314            } else {
315                None
316            }
317        } else {
318            None
319        };
320
321        // Fall back to the monotone Newton step if Halley is unavailable or
322        // would leave the current bracket.
323        let probe = if let Some(cand) = halley_probe {
324            cand
325        } else if f_a_mid.is_finite() && f_a_mid.abs() > NEWTON_DERIV_FLOOR {
326            let cand = mid - f_mid / f_a_mid;
327            if cand > lo && cand < hi { cand } else { mid }
328        } else {
329            mid
330        };
331
332        // Evaluate probe if it differs from midpoint.
333        let (bracket_pt, f_bracket) = if (probe - mid).abs() > 0.0 {
334            let (f_p, f_a_p, _) = eval(probe).map_err(|e| map_eval_err(label, probe, e))?;
335            update_best(
336                &mut best_a,
337                &mut best_f,
338                &mut best_abs_deriv,
339                probe,
340                f_p,
341                f_a_p,
342            );
343            (probe, f_p)
344        } else {
345            (mid, f_mid)
346        };
347
348        if f_bracket <= 0.0 {
349            neg_pt = bracket_pt;
350        } else {
351            pos_pt = bracket_pt;
352        }
353
354        let (next_lo, next_hi) = if neg_pt <= pos_pt {
355            (neg_pt, pos_pt)
356        } else {
357            (pos_pt, neg_pt)
358        };
359        if (next_hi - next_lo).abs() <= convergence_tol * (1.0 + next_hi.abs() + next_lo.abs()) {
360            break;
361        }
362    }
363
364    // Final validation: re-evaluate at best_a if the derivative is suspect.
365    if !best_abs_deriv.is_finite() || best_abs_deriv == 0.0 {
366        let (_, f_a_best, _) = eval(best_a).map_err(|e| map_eval_err(label, best_a, e))?;
367        best_abs_deriv = f_a_best.abs();
368    }
369    if !best_abs_deriv.is_finite() || best_abs_deriv == 0.0 {
370        return Err(MonotoneRootError::converged_root_degenerate(label, best_a));
371    }
372
373    Ok(MonotoneRootSolution {
374        root: best_a,
375        abs_deriv: best_abs_deriv,
376        residual: best_f,
377        refine_iters,
378    })
379}
380
#[cfg(test)]
mod tests {
    use super::solve_monotone_root;
    use std::cell::RefCell;

    /// `F(a) = exp(a) - 2` is increasing; its root is `ln 2` and `F'(ln 2) = 2`.
    #[test]
    fn solve_monotone_root_converges_for_increasing_function() {
        let shifted_exp = |a: f64| -> Result<(f64, f64, f64), String> {
            let value = a.exp();
            Ok((value - 2.0, value, value))
        };

        let (root, abs_deriv, residual) =
            solve_monotone_root(shifted_exp, 0.0, "increasing", 1e-12, 32, 32).expect("root");

        assert!((root - std::f64::consts::LN_2).abs() < 1e-10);
        assert!((abs_deriv - 2.0).abs() < 1e-10);
        assert!(residual.abs() < 1e-12);
    }

    /// `F(a) = exp(-a) - 1/2` is decreasing; its root is `ln 2` and
    /// `|F'(ln 2)| = 1/2`. The first refinement midpoint is 0.5, so the
    /// Halley candidate computed from that midpoint must be among the
    /// abscissae the solver evaluated.
    #[test]
    fn solve_monotone_root_accepts_halley_probe_for_decreasing_function() {
        let eval_points = RefCell::new(Vec::new());
        let recording_eval = |a: f64| -> Result<(f64, f64, f64), String> {
            eval_points.borrow_mut().push(a);
            let decay = (-a).exp();
            Ok((decay - 0.5, -decay, decay))
        };

        let (root, abs_deriv, residual) =
            solve_monotone_root(recording_eval, 0.0, "decreasing", 1e-12, 32, 32).expect("root");

        // Halley candidate from the midpoint a = 0.5.
        let decay_mid = (-0.5f64).exp();
        let (f_mid, f_a_mid, f_aa_mid) = (decay_mid - 0.5, -decay_mid, decay_mid);
        let halley_denom = 2.0 * f_a_mid * f_a_mid - f_mid * f_aa_mid;
        let expected_probe = 0.5 - (2.0 * f_mid * f_a_mid) / halley_denom;

        assert!((root - std::f64::consts::LN_2).abs() < 1e-10);
        assert!((abs_deriv - 0.5).abs() < 1e-10);
        assert!(residual.abs() < 1e-12);

        let probe_was_evaluated = eval_points
            .borrow()
            .iter()
            .any(|&a| (a - expected_probe).abs() < 1e-12);
        assert!(probe_was_evaluated);
    }
}
439}