Skip to main content

gam_solve/
latent_inner.rs

1//! Joint `(t, β)` inner Newton driver for [`gam_terms::latent::LatentCoordValues`]
2//! blocks.
3//!
4//! The arrow-Schur inner step is `O(N d³ + K³)`; the REML outer
5//! gradient w.r.t. `t` carries a shared `Schur⁻¹` factor and is handled
6//! at the REML driver level, not here.
7//!
8//! ## Scope
9//!
10//! This module wires together:
11//!
12//! * [`gam_terms::latent::LatentCoordValues`] — the per-row
13//!   latent field;
14//! * [`crate::arrow_schur::ArrowSchurSystem`] — the bordered
15//!   `(t, β)` Newton system with arrow structure;
16//! * [`crate::arrow_schur::ArrowFactorCache`] — the per-row
17//!   Cholesky factors saved between inner iterations and reused by the
18//!   evidence-side IFT delta-t predictor loop in
19//!   [`crate::evidence`].
20//!
//! The driver is a thin coordinator: it expects the caller to supply an
//! [`ArrowSystemAssembler`] implementation that, given the current
//! `(β, t)`, assembles a fresh `ArrowSchurSystem` (per-row `H_tt^(i)`,
//! `H_tβ^(i)`, `g_t^(i)`, and the β-block `H_ββ`, `g_β`) and evaluates the
//! true joint objective used to accept or reject each step. The assembler
//! is the natural home for "evaluate Φ(t), form Gauss–Newton blocks from
//! the radial jet, fold in analytic penalties" — all of which depend on the
//! basis and the analytic-penalty registry the outer-fit configuration owns.
28//!
29//! ## Convergence criterion
30//!
31//! The inner loop terminates when the relative joint gradient norm
32//! drops below `tol`:
33//!
34//! ```text
35//!   ‖[g_t; g_β]‖₂ / (1 + ‖[t; β]‖₂)  <  tol
36//! ```
37//!
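//! where `tol` is [`LatentInnerOptions::convergence_tolerance`]. If
//! `max_iterations` accepted steps are taken without meeting this test,
//! the solver returns with `converged == false`. A trial step rejected by
//! the gain-ratio test against the true objective is rolled back, and a
//! failure to factor the per-row blocks or the Schur complement at the
//! current ridge discards the step; both cases escalate `ridge_t` /
//! `ridge_beta` by `lm_grow` and retry without consuming an iteration.
//! Once either ridge exceeds `max_ridge`, the solve returns an error.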
43
44use std::ops::Range;
45use std::sync::Arc;
46
47use ndarray::{Array1, ArrayView1};
48
49use crate::arrow_schur::{
50    ArrowFactorCache, ArrowSchurError, ArrowSchurSystem, ArrowSolveOptions, ArrowSolverMode,
51    arrow_bare_quadratic_model_reduction, solve_arrow_newton_step_with_options,
52};
53use gam_terms::latent::LatentCoordValues;
54
55/// Configuration for [`LatentInnerSolver::solve`].
56#[derive(Debug, Clone)]
57pub struct LatentInnerOptions {
58    /// Maximum joint `(t, β)` Newton iterations.
59    pub max_iterations: usize,
60    /// Relative-gradient convergence tolerance (see module docs).
61    pub convergence_tolerance: f64,
62    /// Initial LM-style ridge on the per-row latent blocks.
63    pub initial_ridge_t: f64,
64    /// Initial LM-style ridge on the β block.
65    pub initial_ridge_beta: f64,
66    /// Multiplicative growth factor for the LM ridges on a rejected step.
67    pub lm_grow: f64,
68    /// Multiplicative shrink factor for the LM ridges on an accepted step.
69    pub lm_shrink: f64,
70    /// Maximum ridge value before declaring failure.
71    pub max_ridge: f64,
72    /// BA Schur mode for the reduced shared system. `None` selects Direct for
73    /// `K <= 2000` and InexactPCG above, matching large-scale BA practice.
74    pub solver_mode: Option<ArrowSolverMode>,
75    /// Reduced-system trust-region radius for Steihaug-CG. This is the
76    /// Ceres/BA trust-region bound layered on top of the existing LM ridges.
77    pub trust_region_radius: f64,
78    /// Optional β-block column ranges for the block-Jacobi Schur preconditioner.
79    ///
80    /// When `Some`, the solver calls
81    /// [`crate::arrow_schur::ArrowSchurSystem::set_block_offsets`] on
82    /// every assembled system, engaging the block-Jacobi PCG preconditioner
83    /// (one dense Schur sub-block per term, max size 256 columns) instead of
84    /// the scalar-diagonal fallback.
85    ///
86    /// Derive from `ParameterBlockSpec` slices via
87    /// `gam_custom_family::block_offsets_from_specs` or an
88    /// equivalent family-owned block layout.
89    /// When `None`, the preconditioner falls back to scalar-diagonal Jacobi
90    /// (the pre-block-Jacobi behaviour); when `Some([])` (empty slice), the
91    /// same fallback applies.
92    pub block_offsets: Option<Arc<[Range<usize>]>>,
93}
94
95impl Default for LatentInnerOptions {
96    fn default() -> Self {
97        Self {
98            max_iterations: 50,
99            convergence_tolerance: 1e-8,
100            initial_ridge_t: 0.0,
101            initial_ridge_beta: 0.0,
102            lm_grow: 10.0,
103            lm_shrink: 0.5,
104            max_ridge: 1e12,
105            solver_mode: None,
106            trust_region_radius: 1.0,
107            block_offsets: None,
108        }
109    }
110}
111
112/// Outcome of a [`LatentInnerSolver::solve`] call.
113#[derive(Debug, Clone)]
114pub struct LatentInnerOutcome {
115    /// Final β coefficient vector.
116    pub beta: Array1<f64>,
117    /// Per-row Cholesky factor cache from the *last* accepted Newton
118    /// step. Consumed by the evidence-side IFT delta-t predictor loop in
119    /// [`crate::evidence`].
120    pub factor_cache: Option<ArrowFactorCache>,
121    /// Number of iterations executed.
122    pub iterations: usize,
123    /// Whether the convergence test was satisfied.
124    pub converged: bool,
125    /// Final ridge values (for warm-starting subsequent solves).
126    pub final_ridge_t: f64,
127    pub final_ridge_beta: f64,
128}
129
130/// Driver trait the caller implements to assemble the arrow system at
131/// the current iterate.
132///
133/// The driver owns the basis evaluation (`Φ(t)`), the radial jet
134/// (`∂Φ/∂t` via
135/// [`gam_terms::latent::LatentCoordValues::design_gradient_wrt_t`]),
136/// the Gauss–Newton block assembly
137/// (`H_tt^(i) ← (g_i β)(g_i β)^T`, `H_tβ^(i) ← (g_i β) ⊗ Φ_i`,
138/// `H_ββ ← Φ^T W Φ + Σ_k λ_k S_k`), and the analytic-penalty fold-in via
139/// [`crate::arrow_schur::ArrowSchurSystem::add_analytic_penalty_contributions`].
140pub trait ArrowSystemAssembler {
141    /// Build a freshly-populated arrow system at the current `(β, t)`.
142    ///
143    /// `latent` is the *current* latent field; the assembler should read
144    /// its values via `latent.as_flat()` / `latent.row(i)` and form Φ at
145    /// those coordinates. β is supplied as a view.
146    fn assemble(
147        &mut self,
148        beta: ArrayView1<'_, f64>,
149        latent: &LatentCoordValues,
150    ) -> Result<ArrowSchurSystem, String>;
151
152    /// Evaluate the true joint merit/objective at the current `(β, t)`.
153    ///
154    /// This is deliberately separate from [`Self::assemble`]: the Schur system
155    /// is a local quadratic model, but nonlinear latent retractions must be
156    /// accepted against the objective they actually change.
157    fn objective(
158        &mut self,
159        beta: ArrayView1<'_, f64>,
160        latent: &LatentCoordValues,
161    ) -> Result<f64, String>;
162}
163
164/// Joint `(t, β)` inner Newton solver exploiting arrow structure.
165///
166/// ## Usage
167///
168/// 1. Call [`LatentInnerSolver::new`] with the initial `β`, a mutable
169///    [`LatentCoordValues`], an [`ArrowSystemAssembler`], and options.
170/// 2. Call [`LatentInnerSolver::solve`] to run the inner Newton loop.
171///    Both `β` and the latent field are updated in place.
172/// 3. The returned [`LatentInnerOutcome::factor_cache`] is the
173///    artifact Piece 3's IFT warm-start consumes when the outer loop
174///    next perturbs `(β, ρ)`.
175pub struct LatentInnerSolver<'a, A: ArrowSystemAssembler> {
176    pub beta: Array1<f64>,
177    pub latent: &'a mut LatentCoordValues,
178    pub assembler: A,
179    pub options: LatentInnerOptions,
180}
181
182impl<'a, A: ArrowSystemAssembler> LatentInnerSolver<'a, A> {
183    #[must_use]
184    pub fn new(
185        beta: Array1<f64>,
186        latent: &'a mut LatentCoordValues,
187        assembler: A,
188        options: LatentInnerOptions,
189    ) -> Self {
190        Self {
191            beta,
192            latent,
193            assembler,
194            options,
195        }
196    }
197
198    /// Run the joint Newton loop.
199    ///
200    /// Numerical-stability invariants:
201    ///   * `initial_ridge_t`, `initial_ridge_beta` are clamped to `≥ 0`.
202    ///   * On a per-row or Schur PD failure, both ridges are escalated
203    ///     by `lm_grow` (or seeded at `1e-6` when starting from `0`).
204    ///   * `max_ridge` is the cold-restart trigger: if we exhaust the
205    ///     ramp and the Hessian is still non-PSD, the loop bails with a
206    ///     clear diagnostic citing the iteration index and both ridge
207    ///     levels reached — the outer driver should treat this as an
208    ///     identifiability failure (missing gauge-fixing penalty,
209    ///     collinear basis, etc.).
210    pub fn solve(&mut self) -> Result<LatentInnerOutcome, String> {
211        let opts = self.options.clone();
212        assert!(opts.lm_grow > 1.0, "LM ridge grow factor must exceed 1");
213        assert!(
214            opts.lm_shrink > 0.0 && opts.lm_shrink < 1.0,
215            "LM ridge shrink factor must lie in (0, 1)"
216        );
217        let mut ridge_t = opts.initial_ridge_t.max(0.0);
218        let mut ridge_beta = opts.initial_ridge_beta.max(0.0);
219        let mut last_cache: Option<ArrowFactorCache> = None;
220        let mut converged = false;
221        let mut iter = 0_usize;
222        let mut current_objective = self
223            .assembler
224            .objective(self.beta.view(), self.latent)
225            .map_err(|e| format!("LatentInnerSolver: objective failed at start: {e}"))?;
226        if !current_objective.is_finite() {
227            return Err("LatentInnerSolver: non-finite objective at start".to_string());
228        }
229
230        while iter < opts.max_iterations {
231            let mut system = self
232                .assembler
233                .assemble(self.beta.view(), self.latent)
234                .map_err(|e| format!("LatentInnerSolver: assembler failed at iter {iter}: {e}"))?;
235            system.apply_riemannian_latent_geometry(self.latent);
236            // Wire per-term β-block ranges so block-Jacobi engages in the PCG
237            // preconditioner. Mirrors the PIRLS-driver wiring at
238            // `pirls::runworking_model_pirls` line 5169–5171: the driver calls
239            // `set_block_offsets` from `ArrowSchurInnerConfig.block_offsets` on
240            // every system returned by the `build` closure. Here the assembler
241            // owns the system, so the LatentInnerSolver is the natural place to
242            // inject the offsets — one call per assembled system covers all
243            // families that supply block ranges via `LatentInnerOptions` rather
244            // than baking the call into each assembler impl.
245            if let Some(offsets) = opts.block_offsets.as_ref() {
246                system.set_block_offsets(offsets.clone());
247            }
248
249            // Convergence test: relative joint gradient norm.
250            let g_norm_sq = system_gradient_norm_sq(&system);
251            let scale = 1.0 + iterate_norm(self.beta.view(), self.latent.as_flat().view());
252            let rel = (g_norm_sq.sqrt()) / scale;
253            if rel < opts.convergence_tolerance {
254                converged = true;
255                // Build a final factor cache for the warm-start IFT
256                // predictor even though we didn't take a step. Best-
257                // effort — if the factorization fails (e.g. ill-
258                // conditioned at the very converged point), skip the
259                // cache; the predictor will then no-op.
260                let solve_options = latent_arrow_solve_options(
261                    &system,
262                    &opts,
263                    !self.latent.effective_is_all_euclidean(),
264                );
265                if let Ok((_, _, cache)) = solve_arrow_newton_step_with_options(
266                    &system,
267                    ridge_t.max(1e-12),
268                    ridge_beta.max(1e-12),
269                    &solve_options,
270                ) {
271                    last_cache = Some(cache);
272                }
273                break;
274            }
275
276            // Attempt the LM-damped arrow-Schur step. On failure (per-
277            // row PD violation or Schur PD violation), grow the ridge
278            // and retry without consuming an outer iteration.
279            let solve_options = latent_arrow_solve_options(
280                &system,
281                &opts,
282                !self.latent.effective_is_all_euclidean(),
283            );
284            let step_result =
285                solve_arrow_newton_step_with_options(&system, ridge_t, ridge_beta, &solve_options);
286            match step_result {
287                Ok((delta_t, delta_beta, cache)) => {
288                    let delta_t = limit_delta_to_riemannian_trust_region(
289                        delta_t,
290                        self.latent,
291                        solve_options.riemannian_trust_region,
292                        solve_options.trust_region.radius,
293                    );
294                    let predicted_reduction = arrow_bare_quadratic_model_reduction(
295                        &system,
296                        delta_t.view(),
297                        delta_beta.view(),
298                        ridge_t,
299                        ridge_beta,
300                    )
301                    .map_err(|e| {
302                        format!("LatentInnerSolver: predicted reduction failed at iter {iter}: {e}")
303                    })?;
304                    let beta_before = self.beta.clone();
305                    let t_before = self.latent.as_flat().clone();
306                    for (b, db) in self.beta.iter_mut().zip(delta_beta.iter()) {
307                        *b += *db;
308                    }
309                    self.latent.retract_flat_delta(delta_t.view());
310                    let trial_objective = self
311                        .assembler
312                        .objective(self.beta.view(), self.latent)
313                        .map_err(|e| {
314                        format!("LatentInnerSolver: objective failed at trial iter {iter}: {e}")
315                    })?;
316                    // Trust-region gain-ratio noise floor, keyed to the
317                    // objective's own magnitude so it is equivariant under a
318                    // response rescaling `y → a·y` (the penalized objective and
319                    // both the predicted and actual reductions all scale as
320                    // `O(a²)`). The previous `.max(1.0)` absolute floor broke
321                    // this for a micro-unit response: it pinned the floor at
322                    // `1e-14` while genuine reductions were `O(a²)`, treating a
323                    // real step as numerical noise and stalling the inner solve
324                    // at an over-smoothed iterate (issue #1127). A perfectly
325                    // converged objective (`current_objective == 0`) yields a
326                    // `0` floor, so the `predicted_reduction > 0` branch still
327                    // governs and no step is misclassified.
328                    let objective_scale = current_objective.abs();
329                    let noise_floor = objective_scale * 1e-14;
330                    let actual_reduction = current_objective - trial_objective;
331                    let rho = if predicted_reduction > noise_floor {
332                        actual_reduction / predicted_reduction
333                    } else if actual_reduction >= -noise_floor {
334                        1.0
335                    } else {
336                        -1.0
337                    };
338                    if rho > 0.0 && trial_objective.is_finite() {
339                        current_objective = trial_objective;
340                        ridge_t = (ridge_t * opts.lm_shrink).max(0.0);
341                        ridge_beta = (ridge_beta * opts.lm_shrink).max(0.0);
342                        last_cache = Some(cache);
343                        iter += 1;
344                    } else {
345                        self.beta = beta_before;
346                        self.latent.set_flat(t_before.view());
347                        ridge_t = if ridge_t == 0.0 {
348                            1e-6
349                        } else {
350                            ridge_t * opts.lm_grow
351                        };
352                        ridge_beta = if ridge_beta == 0.0 {
353                            1e-6
354                        } else {
355                            ridge_beta * opts.lm_grow
356                        };
357                        if ridge_t > opts.max_ridge || ridge_beta > opts.max_ridge {
358                            return Err(format!(
359                                "LatentInnerSolver: LM rejected nonlinear step until ridge \
360                                 exceeded max ({}) at iter {} \
361                                 (ridge_t={ridge_t:.3e}, ridge_beta={ridge_beta:.3e}, \
362                                 rho={rho:.3e}, predicted_reduction={predicted_reduction:.3e}, \
363                                 actual_reduction={actual_reduction:.3e})",
364                                opts.max_ridge, iter,
365                            ));
366                        }
367                    }
368                }
369                Err(err @ ArrowSchurError::PerRowFactorFailed { .. })
370                | Err(err @ ArrowSchurError::PerRowFactorIllConditioned { .. })
371                | Err(err @ ArrowSchurError::SchurFactorFailed { .. })
372                | Err(err @ ArrowSchurError::PcgFailed { .. })
373                | Err(err @ ArrowSchurError::UnboundedNegativeCurvature { .. })
374                | Err(err @ ArrowSchurError::AdaptiveCorrectionFailed { .. }) => {
375                    // Grow ridges; retry without burning an iteration.
376                    // The per-row `factor_blocks` already ran an internal
377                    // ridge-ramp before surfacing the error here — if we
378                    // see it at this layer, the row block is
379                    // genuinely under-regularized (gauge issue or
380                    // collinear basis under U_i). Escalate the LM ridge
381                    // and let the outer Newton step damp.
382                    ridge_t = if ridge_t == 0.0 {
383                        1e-6
384                    } else {
385                        ridge_t * opts.lm_grow
386                    };
387                    ridge_beta = if ridge_beta == 0.0 {
388                        1e-6
389                    } else {
390                        ridge_beta * opts.lm_grow
391                    };
392                    if ridge_t > opts.max_ridge || ridge_beta > opts.max_ridge {
393                        return Err(format!(
394                            "LatentInnerSolver: cold-restart condition — LM ridge \
395                             exceeded max ({}) at iter {} \
396                             (ridge_t={ridge_t:.3e}, ridge_beta={ridge_beta:.3e}); \
397                             root-cause arrow-Schur error: {err}",
398                            opts.max_ridge, iter,
399                        ));
400                    }
401                }
402            }
403        }
404
405        Ok(LatentInnerOutcome {
406            beta: self.beta.clone(),
407            factor_cache: last_cache,
408            iterations: iter,
409            converged,
410            final_ridge_t: ridge_t,
411            final_ridge_beta: ridge_beta,
412        })
413    }
414}
415
416fn latent_arrow_solve_options(
417    system: &ArrowSchurSystem,
418    opts: &LatentInnerOptions,
419    riemannian_trust_region: bool,
420) -> ArrowSolveOptions {
421    let mut solve_options = ArrowSolveOptions::automatic(system.k);
422    if let Some(mode) = opts.solver_mode {
423        solve_options.mode = mode;
424    }
425    solve_options.trust_region.radius = opts.trust_region_radius;
426    solve_options.riemannian_trust_region = riemannian_trust_region;
427    solve_options
428}
429
430fn limit_delta_to_riemannian_trust_region(
431    mut delta_t: Array1<f64>,
432    latent: &LatentCoordValues,
433    enabled: bool,
434    radius: f64,
435) -> Array1<f64> {
436    if !enabled || !radius.is_finite() || radius <= 0.0 {
437        return delta_t;
438    }
439    let row_weights = latent.effective_metric_weights();
440    assert_eq!(row_weights.len(), latent.latent_dim());
441    let mut norm_sq = 0.0_f64;
442    for n in 0..latent.n_obs() {
443        let start = n * latent.latent_dim();
444        for a in 0..latent.latent_dim() {
445            let value = delta_t[start + a];
446            norm_sq += row_weights[a] * value * value;
447        }
448    }
449    let norm = norm_sq.sqrt();
450    if norm <= radius || norm == 0.0 {
451        return delta_t;
452    }
453    let shrink = radius / norm;
454    for value in delta_t.iter_mut() {
455        *value *= shrink;
456    }
457    delta_t
458}
459
460fn system_gradient_norm_sq(sys: &ArrowSchurSystem) -> f64 {
461    let mut acc = 0.0_f64;
462    for j in 0..sys.k {
463        acc += sys.gb[j] * sys.gb[j];
464    }
465    for row in &sys.rows {
466        for &v in row.gt.iter() {
467            acc += v * v;
468        }
469    }
470    acc
471}
472
473fn iterate_norm(beta: ArrayView1<'_, f64>, t: ArrayView1<'_, f64>) -> f64 {
474    let mut acc = 0.0_f64;
475    for v in beta.iter() {
476        acc += v * v;
477    }
478    for v in t.iter() {
479        acc += v * v;
480    }
481    acc.sqrt()
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use gam_terms::latent::{LatentCoordValues, LatentIdMode};
    use ndarray::array;

    /// Test assembler with identity curvature blocks, all-zero gradients and
    /// a constant zero objective, so every iterate is already stationary.
    struct ZeroAssembler {
        n: usize,
        d: usize,
        k: usize,
    }

    impl ArrowSystemAssembler for ZeroAssembler {
        fn assemble(
            &mut self,
            beta: ArrayView1<'_, f64>,
            _latent: &LatentCoordValues,
        ) -> Result<ArrowSchurSystem, String> {
            assert!(!beta.iter().any(|x| x.is_nan()));
            let (k, d) = (self.k, self.d);
            let mut system = ArrowSchurSystem::new(self.n, d, k);
            (0..k).for_each(|j| system.hbb[[j, j]] = 1.0);
            for block in system.rows.iter_mut() {
                (0..d).for_each(|c| block.htt[[c, c]] = 1.0);
            }
            Ok(system)
        }

        fn objective(
            &mut self,
            beta: ArrayView1<'_, f64>,
            _latent: &LatentCoordValues,
        ) -> Result<f64, String> {
            assert!(!beta.iter().any(|x| x.is_nan()));
            Ok(0.0)
        }
    }

    #[test]
    fn zero_assembler_converges_immediately() {
        let coords = array![[0.0_f64, 0.0], [0.0, 0.0]];
        let mut latent = LatentCoordValues::from_matrix(coords.view(), LatentIdMode::None);
        let mut solver = LatentInnerSolver::new(
            Array1::<f64>::zeros(3),
            &mut latent,
            ZeroAssembler { n: 2, d: 2, k: 3 },
            LatentInnerOptions::default(),
        );
        let outcome = solver.solve().expect("zero assembler always succeeds");
        assert!(outcome.converged);
        assert_eq!(outcome.iterations, 0);
    }
}