Skip to main content

gam_solve/pirls/
state.rs

1use gam_terms::construction::ReparamResult;
2use crate::estimate::EstimationError;
3use gam_linalg::matrix::{
4    DesignMatrix, PsdWeightsView, ReparamOperator, SignedWeightsView, SymmetricMatrix,
5};
6use crate::active_set::ConstraintKktDiagnostics;
7use gam_problem::{Coefficients, GlmLikelihoodSpec, InverseLink, LinearPredictor, RidgePassport};
8use gam_problem::LinearInequalityConstraints;
9use ndarray::{Array1, Array2, ArrayView1};
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12
13use super::{compute_observed_hessian_curvature_arrays, computeworkingweight_derivatives_from_eta};
14
/// Whether the solve operates in sparse-native or dense-transformed coordinates.
///
/// Plain `Copy` tag; comparisons use the derived `PartialEq`/`Eq`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PirlsLinearSolvePath {
    /// The solve operates in dense-transformed coordinates.
    DenseTransformed,
    /// The solve operates in sparse-native coordinates.
    SparseNative,
}
21
/// Coordinate frame for the PIRLS inner iteration.
///
/// Stored on `PirlsResult::coordinate_frame` so consumers know which frame
/// the exported quantities refer to.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PirlsCoordinateFrame {
    /// Transformed frame — presumably the basis defined by the
    /// reparameterization matrix `qs`; TODO confirm against the solver.
    TransformedQs,
    /// Original (untransformed) frame used by the sparse-native path —
    /// NOTE(review): inferred from the variant name; confirm.
    OriginalSparseNative,
}
28
/// Firth bias-reduction diagnostics at convergence.
///
/// Defaults to [`FirthDiagnostics::Inactive`].
#[derive(Debug, Clone, Default)]
pub enum FirthDiagnostics {
    /// No Firth diagnostics were recorded (the default).
    #[default]
    Inactive,
    /// Firth diagnostics were recorded for this state.
    Active {
        /// Jeffreys log-determinant term.
        /// NOTE(review): exact definition and sign convention are not
        /// visible in this file — confirm at the construction site.
        jeffreys_logdet: f64,
        /// Presumably the diagonal of the hat matrix, one entry per
        /// observation — TODO confirm.
        hat_diag: Array1<f64>,
    },
}
39
40impl FirthDiagnostics {
41    #[inline]
42    pub fn jeffreys_logdet(&self) -> Option<f64> {
43        match self {
44            Self::Inactive => None,
45            Self::Active {
46                jeffreys_logdet, ..
47            } => Some(*jeffreys_logdet),
48        }
49    }
50}
51
/// Which information matrix the penalized Hessian carries at the current
/// PIRLS iterate.
///
/// Canonical links (logit-Binomial, log-Poisson) have `W_obs == W_Fisher`, so
/// the two choices coincide. Non-canonical links (probit, cloglog, mixture,
/// flexible, Gamma-log, ...) need observed information
/// `W_obs = W_Fisher - (y - mu) * B` for the outer REML/Laplace `log|H|` and
/// trace terms to be exact; Fisher weights alone yield a PQL-type surrogate.
/// We fall back to `Fisher` only when the observed-information Hessian fails
/// the positive-definiteness check, since the inner Newton step must be SPD.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum HessianCurvatureKind {
    /// Expected (Fisher) information: `W_Fisher = h'^2 / (phi * V(mu))`.
    /// Used as the inner iteration matrix when observed curvature fails (non-SPD).
    Fisher,
    /// Observed information: `W_obs = W_Fisher - (y - mu) * B`.
    /// Required for the outer REML `log|H|` and trace terms (exact Laplace).
    Observed,
}
71
/// The exported Laplace curvature kind used for the outer REML criterion.
///
/// This labels what the curvature handed to the outer criterion represents.
/// It is distinct from `HessianCurvatureKind`, which records what the inner
/// loop's last accepted step used.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ExportedLaplaceCurvature {
    /// The exported curvature is the observed information (exact Laplace).
    ObservedExact,
    /// The exported curvature is the expected (Fisher) information, i.e. a
    /// surrogate rather than the exact Laplace curvature.
    ExpectedInformationSurrogate,
    /// Observed curvature was found invalid; the fields carry the evidence.
    /// NOTE(review): field meanings below are inferred from their names —
    /// confirm at the construction site.
    InvalidObservedCurvature {
        /// Presumably the smallest eigenvalue of the observed-curvature Hessian.
        min_eigenvalue: f64,
        /// Presumably the positive-definiteness tolerance it was tested against.
        pd_tolerance: f64,
        /// Presumably the gradient norm at the point where the check was made.
        gradient_norm: f64,
    },
}
83
/// Working state at a PIRLS iterate: gradient, Hessian, deviance, etc.
#[derive(Debug, Clone)]
pub struct WorkingState {
    /// Linear predictor at this iterate. Its length is used as the
    /// observation count `n` by `kkt_dimension_scale`.
    pub eta: LinearPredictor,
    /// Gradient at this iterate (presumably of the penalized objective —
    /// confirm). Its length is used as the coefficient count `p` by
    /// `kkt_dimension_scale`.
    pub gradient: Array1<f64>,
    /// Hessian at this iterate; `hessian_curvature` records which
    /// information matrix it carries.
    pub hessian: gam_linalg::matrix::SymmetricMatrix,
    /// Log-likelihood at this iterate.
    pub log_likelihood: f64,
    /// Deviance at this iterate; enters the objective as
    /// `0.5 * (deviance + penalty_term)`.
    pub deviance: f64,
    /// Penalty quadratic form, including the ridge contribution described
    /// on `ridge_used` below.
    pub penalty_term: f64,
    /// Firth bias-reduction diagnostics for this iterate.
    pub firth: FirthDiagnostics,
    // Ridge added to ensure positive definiteness of the penalized Hessian.
    // `penalty_term` stores the full quadratic form contribution
    // ridge * ||beta||^2. The optimization objective uses
    // 0.5 * (deviance + penalty_term), so this corresponds to
    // 0.5 * ridge * ||beta||^2 on the log-likelihood scale.
    pub ridge_used: f64,
    /// Which information matrix (Fisher or observed) `hessian` carries.
    pub hessian_curvature: HessianCurvatureKind,
    // Natural scale of the penalized gradient, used to form a scale-invariant
    // KKT certificate.  Equal to ||X'(weighted_residual)||_2 + ||S*beta||_2
    // (+ ridge*||beta||_2 when a stabilizing ridge is active).  Under
    // stochastic noise the score component scales as O(sqrt(n)), so an
    // absolute ||g||_2 < tol test rejects fits whose normalized stationarity
    // residual is already negligible. Convergence uses ||g||_2 / (1 + this).
    pub gradient_natural_scale: f64,
}
109
110impl WorkingState {
111    #[inline]
112    pub fn jeffreys_logdet(&self) -> Option<f64> {
113        self.firth.jeffreys_logdet()
114    }
115
116    /// Scale-invariant relative gradient residual.
117    ///
118    /// Returns ||g||_2 / (1 + ||score||_2 + ||S*beta||_2 + ridge*||beta||_2).
119    /// `g_norm` is the projected/constrained stationarity residual in the
120    /// current PIRLS basis; the denominator is the natural magnitude of the
121    /// penalized gradient and is invariant under uniform rescaling of the
122    /// objective.
123    #[inline]
124    pub fn relative_gradient_norm(&self, g_norm: f64) -> f64 {
125        g_norm / (1.0 + self.gradient_natural_scale)
126    }
127
128    /// Dimension-based scale `√n · max(1, √p)` for the structural KKT bound.
129    ///
130    /// Under standardized columns, the score `Xᵀ(μ − y)` has components of
131    /// order O(√n), so the absolute test ‖g‖ < τ becomes systematically too
132    /// tight at large n. Multiplying τ by this scale restores the advertised
133    /// per-observation meaning.
134    #[inline]
135    pub(crate) fn kkt_dimension_scale(&self) -> f64 {
136        let n = self.eta.len().max(1) as f64;
137        let p = (self.gradient.len() as f64).max(1.0);
138        n.sqrt() * p.sqrt()
139    }
140
141    /// Strict KKT acceptance: `g_norm` certifies stationarity under EITHER
142    /// scale-invariant criterion (dimension-based or data-driven natural-scale).
143    ///
144    /// Both certificates are invariant under uniform rescaling of the objective
145    /// `F → c·F` (in the limit where the natural scale dominates the additive
146    /// `1` floor). Acceptance under either is sufficient because:
147    ///   - the natural-scale bound is tighter when the data are well-scaled
148    ///     (it tracks actual gradient component magnitudes);
149    ///   - the dimension bound is tighter when the design matrix has unusual
150    ///     scaling (so the natural scale is dominated by a single component).
151    #[inline]
152    pub fn certifies_kkt(&self, g_norm: f64, tol: f64) -> bool {
153        g_norm < tol * self.kkt_dimension_scale() || self.relative_gradient_norm(g_norm) < tol
154    }
155
156    /// Near-stationary band (10× the strict KKT tolerance) under EITHER
157    /// scale-invariant criterion. Used as a "good-enough" plateau check
158    /// that classifies a fit as `StalledAtValidMinimum` rather than as a
159    /// hard non-convergence. The band is `10 · tol` without a
160    /// floor — a caller asking for `tol = 1e-12` gets a 1e-11 band, not
161    /// the 1e-5 the old `tol.max(1e-6) * 10` formula silently widened it
162    /// to. The 1e-6 floor was masking real convergence regressions
163    /// (e.g. `constant_prior_mean_centers_penalty`'s LM-ridge induced
164    /// 2.5e-8 bias visible only when the user asked for sub-1e-6
165    /// precision).
166    #[inline]
167    pub fn near_stationary_kkt(&self, g_norm: f64, tol: f64) -> bool {
168        let near_tol = tol * 10.0;
169        g_norm <= near_tol * self.kkt_dimension_scale()
170            || self.relative_gradient_norm(g_norm) <= near_tol
171    }
172}
173
174/// Numerically stable Euclidean norm of an `Array1<f64>`.
175///
176/// Used to assemble the penalized-gradient natural scale at every
177/// `WorkingState` construction site (main GAM, identity-link short circuit,
178/// survival, test mocks). Centralizing here avoids drift between sites and
179/// makes the convergence certificate's denominator a single source of truth.
180///
181/// One pass, no allocation, O(p). At p≈10⁴ the cost is ≪ the O(np²) PIRLS
182/// inner work, so this is free in any setting where it matters.
183#[inline]
184pub fn array1_l2_norm(v: &Array1<f64>) -> f64 {
185    v.iter().map(|x| x * x).sum::<f64>().sqrt()
186}
187
/// Adaptive KKT tolerance parameters for the inner PIRLS convergence test.
///
/// NOTE(review): this file only declares the parameters; how they combine is
/// defined by the consumer. The field notes below are inferred from the
/// names — presumably the tolerance is `eta * outer_grad_norm` clamped to
/// `[floor, ceiling]`; TODO confirm.
#[derive(Clone, Copy, Debug)]
pub struct AdaptiveKktTolerance {
    /// Presumably the proportionality factor applied to `outer_grad_norm`
    /// (not the linear predictor, despite the name) — confirm.
    pub eta: f64,
    /// Presumably the lower bound on the resulting tolerance.
    pub floor: f64,
    /// Presumably the upper bound on the resulting tolerance.
    pub ceiling: f64,
    /// Presumably the current outer-optimizer gradient norm.
    pub outer_grad_norm: f64,
}
196
/// Per-iteration PIRLS diagnostic info reported to the callback.
#[derive(Clone, Debug)]
pub struct WorkingModelIterationInfo {
    /// Iteration counter of the inner loop (NOTE(review): 0- vs 1-based is
    /// not visible here — confirm at the call site).
    pub iteration: usize,
    /// Deviance at this iteration.
    pub deviance: f64,
    /// Gradient norm at this iteration.
    pub gradient_norm: f64,
    /// Step size taken at this iteration.
    pub step_size: f64,
    /// Step-halving count for this iteration (presumably the number of
    /// halvings applied before the step was accepted — confirm).
    pub step_halving: usize,
}
206
/// Result of the inner `runworking_model_pirls` loop.
#[derive(Clone)]
pub struct WorkingModelPirlsResult {
    /// Coefficient vector at exit.
    pub beta: Coefficients,
    /// Working state (gradient, Hessian, deviance, ...) at exit.
    pub state: WorkingState,
    /// How the inner loop terminated.
    pub status: PirlsStatus,
    /// Number of inner iterations performed.
    pub iterations: usize,
    /// Gradient norm at the last iteration.
    pub lastgradient_norm: f64,
    /// Deviance change at the last iteration.
    pub last_deviance_change: f64,
    /// Step size of the last iteration.
    pub last_step_size: f64,
    /// Step-halving count of the last iteration.
    pub last_step_halving: usize,
    /// Largest absolute linear-predictor value (presumably at exit — confirm).
    pub max_abs_eta: f64,
    /// Constraint KKT diagnostics, when available.
    pub constraint_kkt: Option<ConstraintKktDiagnostics>,
    /// Levenberg-Marquardt damping coefficient at the last accepted
    /// inner iter. Used by the REML runtime to seed the next PIRLS call
    /// at the same outer fit, avoiding 4-6 iters of damping rediscovery
    /// when the geometry calls for `λ_LM > 1e-6`.
    pub final_lm_lambda: f64,
    /// Gain ratio (`actual_reduction / predicted_reduction`) at the
    /// last accepted inner iter. `None` when no step was accepted
    /// (rejection-exhausted, MaxIterationsReached without acceptance).
    /// Programmatic counterpart to the per-iter `[PIRLS lm-trajectory]`
    /// log line's `accept_rho` field — the log is grep-only, this
    /// field is queryable by the outer schedule and convergence guard.
    /// Values near 1.0 indicate the quadratic model is faithful;
    /// values much smaller indicate the LM model is over-stating
    /// predicted reduction and the inner Newton may benefit from
    /// shorter steps.
    pub final_accept_rho: Option<f64>,
    /// Minimum penalized deviance (`state.deviance + state.penalty_term`)
    /// observed across all iterations whose state was computed during the
    /// inner P-IRLS loop. Penalized deviance is monotonically decreasing
    /// along any descent path the inner solver takes, so this minimum is a
    /// principled seed-screening proxy that remains meaningful even when the
    /// solver hit its iteration cap before reaching the mode. `f64::INFINITY`
    /// when no state was ever computed (paths that synthesize a result
    /// without iterating, e.g. zero-iteration warm-only paths).
    pub min_penalized_deviance: f64,
    /// Label of the curvature exported to the outer REML criterion.
    pub exported_laplace_curvature: ExportedLaplaceCurvature,
}
247
/// The status of the P-IRLS convergence.
///
/// Query helpers on this type: `is_converged` (clean convergence only),
/// `is_failed_max_iterations` (iteration-cap or LM-exhausted exits), and
/// `label` (stable human-readable text).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum PirlsStatus {
    /// Converged successfully within tolerance.
    Converged,
    /// Reached maximum iterations but the gradient and Hessian indicate a valid minimum.
    StalledAtValidMinimum,
    /// Reached maximum iterations without converging.
    MaxIterationsReached,
    /// Levenberg-Marquardt step search exhausted its retry budget (damping λ
    /// reached its ceiling, attempts counter expired, or λ went non-finite)
    /// before the projected gradient entered the near-stationary band. Distinct
    /// from `MaxIterationsReached`, which means the outer iteration counter
    /// itself ran out — that exhaustion is a "looped 100×, made progress each
    /// time but never converged" signal, while this one is a "no acceptable
    /// step direction even after damping" signal pointing at curvature trouble
    /// or saturated likelihoods.
    LmStepSearchExhausted,
    /// Fitting process became unstable, likely due to perfect separation.
    Unstable,
}
269
270impl PirlsStatus {
271    /// Whether the inner loop concluded without producing a usable mode.
272    /// Both the iteration-cap and LM-exhausted exits should be treated the
273    /// same by callers that just want to know "did we get a valid solution?".
274    #[inline]
275    pub const fn is_failed_max_iterations(self) -> bool {
276        matches!(
277            self,
278            PirlsStatus::MaxIterationsReached | PirlsStatus::LmStepSearchExhausted
279        )
280    }
281
282    /// Short human-readable label for reports and diagnostics. Stable text
283    /// (not the `Debug` rendering) so report output does not silently change if
284    /// the variant identifiers are ever renamed.
285    #[inline]
286    pub const fn label(self) -> &'static str {
287        match self {
288            PirlsStatus::Converged => "Converged",
289            PirlsStatus::StalledAtValidMinimum => "Stalled at valid minimum",
290            PirlsStatus::MaxIterationsReached => "Max iterations reached",
291            PirlsStatus::LmStepSearchExhausted => "LM step search exhausted",
292            PirlsStatus::Unstable => "Unstable (possible separation)",
293        }
294    }
295
296    /// Whether this status represents a clean convergence to the mode. Only
297    /// `Converged` qualifies; every other state carries a caveat a reader
298    /// should see flagged.
299    #[inline]
300    pub const fn is_converged(self) -> bool {
301        matches!(self, PirlsStatus::Converged)
302    }
303}
304
/// Holds the result of a converged P-IRLS inner loop for a fixed rho.
///
/// # Basis of Returned Tensors
///
/// **IMPORTANT:** All vector and matrix outputs in this struct (`beta_transformed`,
/// `penalized_hessian_transformed`) are in the **stable, transformed basis**
/// that was computed for the given set of smoothing parameters.
///
/// To obtain coefficients in the original, interpretable basis, the caller must
/// back-transform them using the `qs` matrix from the `reparam_result` field:
/// `beta_original = reparam_result.qs.dot(&beta_transformed)`
///
/// # Fields
///
/// * `beta_transformed`: The estimated coefficient vector in the STABLE, TRANSFORMED basis.
/// * `penalized_hessian_transformed`: The penalized Hessian matrix at convergence
///   (`X'W_H X + S_λ`, with `W_H` equal to Fisher or observed curvature,
///   depending on the accepted PIRLS step) in the STABLE, TRANSFORMED basis.
/// * `deviance`: The final deviance value. This is family-specific:
///    - Gaussian identity: weighted residual sum of squares.
///    - Binomial families: binomial deviance.
///    - Poisson log: Poisson deviance.
///    - Gamma log: Gamma unit deviance scaled by the fitted Gamma shape.
/// * `finalweights`: The final Hessian-side working weights at convergence.
/// * `solveweights`: The final score-side Fisher weights used in
///   `X'W(z-eta) - S beta`.
/// * `reparam_result`: Contains the transformation matrix (`qs`) and other reparameterization data.
///
/// # Point Estimate: Posterior Mode (MAP)
///
/// The coefficients returned by PIRLS are the **posterior mode** (Maximum A Posteriori estimate),
/// not the posterior mean. For risk predictions, the posterior mean is theoretically preferable.
/// For symmetric posteriors, mode ≈ mean and it doesn't matter. For asymmetric posteriors
/// (rare events, boundary effects), the mean would give more accurate calibrated probabilities.
/// To obtain the posterior mean, one would need MCMC sampling from the posterior and average
/// f(patient, β) over samples.
///
/// # REML cache compaction
///
/// `compact_for_reml_cache` replaces `finalweights`, `final_offset`, `finalmu`,
/// `solve_dmu_deta`, `solve_d2mu_deta2`, `solve_d3mu_deta3` and `x_transformed`
/// with empty placeholders and sets `cache_compacted`;
/// `rehydrate_after_reml_cache` rebuilds them.
#[derive(Clone)]
pub struct PirlsResult {
    /// Likelihood specification this fit was run under.
    pub likelihood: GlmLikelihoodSpec,
    // Coefficients and Hessian are now in the STABLE, TRANSFORMED basis
    pub beta_transformed: Coefficients,
    pub penalized_hessian_transformed: SymmetricMatrix,
    // Single stabilized Hessian for consistent cost/gradient computation
    pub stabilizedhessian_transformed: SymmetricMatrix,
    /// Canonical ridge metadata passport consumed by outer objective/gradient code.
    pub ridge_passport: RidgePassport,
    // Ridge added to make the stabilized Hessian positive definite. When > 0,
    // `stable_penalty_term` includes ridge_used * ||beta||^2 (which contributes
    // 0.5 * ridge_used * ||beta||^2 in -0.5 * (deviance + stable_penalty_term)).
    // Backward-compatible mirror of `ridge_passport.delta`.
    pub ridge_used: f64,

    // The unpenalized deviance, calculated from mu and y
    pub deviance: f64,

    // Effective degrees of freedom at the solution
    pub edf: f64,

    // The penalty term, calculated stably within P-IRLS.
    // This is beta_transformed' * S_transformed * beta_transformed, plus
    // ridge_used * ||beta||^2 when stabilization is active so that the
    // penalized deviance matches the stabilized Hessian.
    pub stable_penalty_term: f64,

    /// Firth diagnostics in the converged PIRLS state.
    pub firth: FirthDiagnostics,

    // Diagonal weights defining the Hessian surface returned to outer REML/LAML.
    //
    // For canonical links Fisher = Observed identically. For non-canonical links,
    // PIRLS always recomputes observed weights at the accepted β̂ in a
    // post-convergence finalization step (see "Post-convergence Laplace curvature
    // finalization"), so `finalweights` carries the *observed-information* diagonal
    // whenever the model supports it — even if the inner LM loop ended on Fisher
    // due to a fallback. Exact label of what these represent is in
    // `exported_laplace_curvature`; do not infer the kind from `hessian_curvature`
    // (which records what the inner loop's last accepted step happened to use).
    //
    // Empty while `cache_compacted` is true.
    pub finalweights: Array1<f64>,
    // Additional PIRLS state captured at the accepted step to support
    // cost/gradient consistency in the outer optimization
    //
    // `final_offset` and `finalmu` are empty while `cache_compacted` is true;
    // `final_eta` is retained.
    pub final_offset: Array1<f64>,
    pub final_eta: Array1<f64>,
    pub finalmu: Array1<f64>,
    /// Score-side Fisher weights used in `X'W(z-eta) - S beta`.
    pub solveweights: Array1<f64>,
    /// Working response paired with `solveweights` (the `z` in
    /// `X'W(z-eta) - S beta`, presumably — confirm).
    pub solveworking_response: Array1<f64>,
    /// Mean vector at the solve point; rehydration also restores `finalmu`
    /// from this array.
    pub solvemu: Array1<f64>,
    // First, second and third eta-derivatives of mu at the solve point
    // (inferred from the field names). Empty while `cache_compacted` is true.
    pub solve_dmu_deta: Array1<f64>,
    pub solve_d2mu_deta2: Array1<f64>,
    pub solve_d3mu_deta3: Array1<f64>,
    /// First eta-derivative of the diagonal Hessian curvature W_H(eta):
    /// c_i := dW_i/deta_i at the accepted PIRLS solution.
    ///
    /// This carries 3rd-order likelihood information used in exact dH/dρ
    /// terms for outer LAML derivatives.
    pub solve_c_array: Array1<f64>,
    /// Second eta-derivative of the diagonal Hessian curvature W_H(eta):
    /// d_i := d²W_i/deta_i² at the accepted PIRLS solution.
    ///
    /// This carries 4th-order likelihood information used in exact d²H/dρ²
    /// terms for the outer LAML Hessian.
    pub solve_d_array: Array1<f64>,
    /// True when `solve_c_array` / `solve_d_array` are placeholders rather
    /// than supported likelihood derivatives.
    pub derivatives_unsupported: bool,

    // Convergence status and last-iteration diagnostics of the inner loop.
    pub status: PirlsStatus,
    pub iteration: usize,
    pub max_abs_eta: f64,
    pub lastgradient_norm: f64,
    /// Natural scale of the penalized gradient at the accepted PIRLS state,
    /// equal to ‖Xᵀ(weighted residual)‖₂ + ‖Sβ‖₂ (+ ridge·‖β‖₂ when active).
    /// Mirrors `WorkingState::gradient_natural_scale` so that callers reading
    /// `PirlsResult` directly (e.g. seed-screening cost augmentation) can form
    /// the scale-invariant residual r_g = ‖g‖ / (1 + this) without rebuilding
    /// the score and penalty norms.
    pub gradient_natural_scale: f64,
    pub last_deviance_change: f64,
    pub last_step_halving: usize,
    /// Curvature kind used by the inner loop's last accepted step. Not the
    /// label of `finalweights`; see `exported_laplace_curvature` for that.
    pub hessian_curvature: HessianCurvatureKind,
    /// Label of the curvature exported to the outer REML/LAML criterion.
    pub exported_laplace_curvature: ExportedLaplaceCurvature,
    /// Levenberg-Marquardt damping coefficient at the converged inner
    /// iter. Cached by the REML runtime so the next PIRLS call in the
    /// same outer optimization can seed `λ_LM` to this value instead
    /// of cold-starting at `1e-6`. Mirrors `WorkingModelPirlsResult::final_lm_lambda`.
    pub final_lm_lambda: f64,
    /// Gain ratio of the last accepted LM step inside this PIRLS solve,
    /// `None` when no step was accepted (e.g. zero-iteration synthesis,
    /// rejection-exhausted, MaxIterations without acceptance). Mirrors
    /// `WorkingModelPirlsResult::final_accept_rho`. Programmatic
    /// counterpart to the per-iter `[PIRLS lm-trajectory]` log line's
    /// `accept_rho` field, queryable by outer consumers (cap schedule,
    /// convergence guard) for inner-Newton model-fidelity decisions.
    pub final_accept_rho: Option<f64>,
    /// Optional KKT diagnostics when inequality constraints were active.
    pub constraint_kkt: Option<ConstraintKktDiagnostics>,
    /// Linear inequality system enforced in transformed PIRLS coordinates:
    /// `A * beta_transformed >= b`.
    pub linear_constraints_transformed: Option<LinearInequalityConstraints>,

    // Pass through the entire reparameterization result for use in the gradient
    pub reparam_result: ReparamResult,
    // Cached X·Qs for this PIRLS result (transformed design matrix).
    // Replaced by an empty 0×0 dense matrix while `cache_compacted` is true.
    pub x_transformed: DesignMatrix,
    /// Coordinate frame of the inner iteration that produced this result.
    /// NOTE(review): the struct-level docs describe the transformed basis;
    /// confirm how `OriginalSparseNative` results should be interpreted.
    pub coordinate_frame: PirlsCoordinateFrame,
    /// True when this fixed-rho inner solve completed on a GPU path.
    pub used_device: bool,
    /// True when this result was compacted for REML LRU storage and needs
    /// cold artifacts (for example `x_transformed`) rehydrated before exact
    /// bundle construction.
    pub cache_compacted: bool,
    /// Minimum penalized deviance observed across the inner P-IRLS loop.
    /// Mirrors `WorkingModelPirlsResult::min_penalized_deviance`. Used as the
    /// seed-screening ranking proxy: penalized deviance descends monotonically
    /// along any inner descent path, so the per-seed minimum tells the outer
    /// cascade "how good a fit this rho's neighbourhood can support" even
    /// when the inner solver was capped before reaching the mode.
    pub min_penalized_deviance: f64,
}
464
465impl PirlsResult {
466    /// Export the stabilized transformed Hessian as an exact dense matrix for
467    /// downstream solve paths that require explicit Hessians.
468    ///
469    /// The returned matrix is the convergence Hessian already used by PIRLS and
470    /// REML (`X'W_HX + S_λ`, plus the explicit stabilization ridge when active).
471    /// Sparse-native fits are materialized from their assembled sparse Hessian;
472    /// no numerical Hessian approximation or compatibility fallback is used.
473    pub fn dense_stabilizedhessian_transformed(
474        &self,
475        context: &str,
476    ) -> Result<Array2<f64>, EstimationError> {
477        self.stabilizedhessian_transformed
478            .try_to_dense_exact(context)
479            .map_err(EstimationError::InvalidInput)
480    }
481
482    #[inline]
483    pub fn jeffreys_logdet(&self) -> Option<f64> {
484        self.firth.jeffreys_logdet()
485    }
486
487    /// Typed view of the Hessian-side working weight diagonal stored on this
488    /// result, sign-honest. `finalweights` carries the observed-information
489    /// diagonal whenever the model supports it (see `exported_laplace_curvature`),
490    /// and observed weights `W_obs = W_F - (y - μ) · B` can be negative for
491    /// non-canonical links. Consumers feeding this into the asymmetric
492    /// `X_iᵀ W X_j` path, `weighted_crossprod_dense_rows`, or
493    /// `xt_diag_x_signed_op` must use this typed view rather than borrowing
494    /// the raw `Array1<f64>` so the function-boundary type contract from
495    /// `linalg/matrix.rs` is construction-enforced.
496    #[inline]
497    pub fn final_weights_signed(&self) -> SignedWeightsView<'_> {
498        SignedWeightsView::from_array(&self.finalweights)
499    }
500
501    /// Typed view of the score-side Fisher weights `W_F = h'²/(φ V(μ)) ≥ 0`
502    /// stored on this result, PSD-by-construction. Used by PSD-Gram kernels
503    /// (`dense_xtwx_view`, `sparse_csr_weighted_xtwx_*`, `xt_diag_x_psd_op`)
504    /// without a runtime sign scan; the PSD obligation is discharged
505    /// algebraically by the Fisher formula at the construction site in
506    /// `solver/pirls/mod.rs`. New callers that need the same diagonal under
507    /// a sign-honest API should route through `as_signed()` on the returned
508    /// view rather than reconstructing from the raw array.
509    #[inline]
510    pub fn solve_weights_psd(&self) -> PsdWeightsView<'_> {
511        PsdWeightsView::from_view_unchecked(self.solveweights.view())
512    }
513
514    /// Scale-invariant relative gradient residual at the accepted PIRLS state.
515    ///
516    /// Returns ‖g‖ / (1 + ‖score‖ + ‖Sβ‖ + ridge·‖β‖). Numerator is
517    /// `lastgradient_norm`; denominator is `1 + gradient_natural_scale`.
518    /// This is the "r_g" used by seed-screening cost augmentation.
519    #[inline]
520    pub fn relative_gradient_norm(&self) -> f64 {
521        self.lastgradient_norm / (1.0 + self.gradient_natural_scale)
522    }
523
524    pub(crate) fn compact_for_reml_cache(&self) -> Self {
525        Self {
526            likelihood: self.likelihood.clone(),
527            beta_transformed: self.beta_transformed.clone(),
528            penalized_hessian_transformed: self.penalized_hessian_transformed.clone(),
529            stabilizedhessian_transformed: self.stabilizedhessian_transformed.clone(),
530            ridge_passport: self.ridge_passport,
531            ridge_used: self.ridge_used,
532            deviance: self.deviance,
533            edf: self.edf,
534            stable_penalty_term: self.stable_penalty_term,
535            firth: self.firth.clone(),
536            finalweights: Array1::zeros(0),
537            final_offset: Array1::zeros(0),
538            final_eta: self.final_eta.clone(),
539            finalmu: Array1::zeros(0),
540            solveweights: self.solveweights.clone(),
541            solveworking_response: self.solveworking_response.clone(),
542            solvemu: self.solvemu.clone(),
543            solve_dmu_deta: Array1::zeros(0),
544            solve_d2mu_deta2: Array1::zeros(0),
545            solve_d3mu_deta3: Array1::zeros(0),
546            solve_c_array: self.solve_c_array.clone(),
547            solve_d_array: self.solve_d_array.clone(),
548            derivatives_unsupported: self.derivatives_unsupported,
549            status: self.status,
550            iteration: self.iteration,
551            max_abs_eta: self.max_abs_eta,
552            lastgradient_norm: self.lastgradient_norm,
553            gradient_natural_scale: self.gradient_natural_scale,
554            last_deviance_change: self.last_deviance_change,
555            last_step_halving: self.last_step_halving,
556            hessian_curvature: self.hessian_curvature,
557            exported_laplace_curvature: self.exported_laplace_curvature.clone(),
558            final_lm_lambda: self.final_lm_lambda,
559            final_accept_rho: self.final_accept_rho,
560            constraint_kkt: self.constraint_kkt.clone(),
561            linear_constraints_transformed: self.linear_constraints_transformed.clone(),
562            reparam_result: self.reparam_result.clone(),
563            x_transformed: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
564                Array2::zeros((0, 0)),
565            )),
566            coordinate_frame: self.coordinate_frame,
567            used_device: self.used_device,
568            cache_compacted: true,
569            min_penalized_deviance: self.min_penalized_deviance,
570        }
571    }
572
573    pub(crate) fn rehydrate_after_reml_cache(
574        &self,
575        x_original: &DesignMatrix,
576        y: ArrayView1<'_, f64>,
577        priorweights: ArrayView1<'_, f64>,
578        offset: ArrayView1<'_, f64>,
579        inverse_link: &InverseLink,
580    ) -> Result<Self, EstimationError> {
581        if !self.cache_compacted {
582            return Ok(self.clone());
583        }
584
585        let (score_c_array, score_d_array, solve_dmu_deta, solve_d2mu_deta2, solve_d3mu_deta3) =
586            computeworkingweight_derivatives_from_eta(
587                &self.likelihood,
588                inverse_link,
589                &self.final_eta,
590                priorweights,
591            )?;
592        let (finalweights, solve_c_array, solve_d_array) =
593            if self.hessian_curvature == HessianCurvatureKind::Observed {
594                compute_observed_hessian_curvature_arrays(
595                    &self.likelihood,
596                    inverse_link,
597                    &self.final_eta,
598                    y,
599                    &self.solveweights,
600                    priorweights,
601                )?
602            } else {
603                (
604                    self.solveweights.clone(),
605                    score_c_array.clone(),
606                    score_d_array.clone(),
607                )
608            };
609        // Lazy rehydration: wrap in ReparamOperator instead of materializing X·Qs.
610        let qs_arc = Arc::new(self.reparam_result.qs.clone());
611        Ok(Self {
612            likelihood: self.likelihood.clone(),
613            beta_transformed: self.beta_transformed.clone(),
614            penalized_hessian_transformed: self.penalized_hessian_transformed.clone(),
615            stabilizedhessian_transformed: self.stabilizedhessian_transformed.clone(),
616            ridge_passport: self.ridge_passport,
617            ridge_used: self.ridge_used,
618            used_device: self.used_device,
619            deviance: self.deviance,
620            edf: self.edf,
621            stable_penalty_term: self.stable_penalty_term,
622            firth: self.firth.clone(),
623            finalweights,
624            final_offset: offset.to_owned(),
625            final_eta: self.final_eta.clone(),
626            finalmu: self.solvemu.clone(),
627            solveweights: self.solveweights.clone(),
628            solveworking_response: self.solveworking_response.clone(),
629            solvemu: self.solvemu.clone(),
630            solve_dmu_deta,
631            solve_d2mu_deta2,
632            solve_d3mu_deta3,
633            solve_c_array,
634            solve_d_array,
635            derivatives_unsupported: self.derivatives_unsupported,
636            status: self.status,
637            iteration: self.iteration,
638            max_abs_eta: self.max_abs_eta,
639            lastgradient_norm: self.lastgradient_norm,
640            gradient_natural_scale: self.gradient_natural_scale,
641            last_deviance_change: self.last_deviance_change,
642            last_step_halving: self.last_step_halving,
643            hessian_curvature: self.hessian_curvature,
644            exported_laplace_curvature: self.exported_laplace_curvature.clone(),
645            final_lm_lambda: self.final_lm_lambda,
646            final_accept_rho: self.final_accept_rho,
647            constraint_kkt: self.constraint_kkt.clone(),
648            linear_constraints_transformed: self.linear_constraints_transformed.clone(),
649            reparam_result: self.reparam_result.clone(),
650            x_transformed: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(
651                ReparamOperator::new(x_original.clone(), qs_arc),
652            ))),
653            coordinate_frame: self.coordinate_frame,
654            cache_compacted: false,
655            min_penalized_deviance: self.min_penalized_deviance,
656        })
657    }
658}