gam_solve/pirls/state.rs
1use gam_terms::construction::ReparamResult;
2use crate::estimate::EstimationError;
3use gam_linalg::matrix::{
4 DesignMatrix, PsdWeightsView, ReparamOperator, SignedWeightsView, SymmetricMatrix,
5};
6use crate::active_set::ConstraintKktDiagnostics;
7use gam_problem::{Coefficients, GlmLikelihoodSpec, InverseLink, LinearPredictor, RidgePassport};
8use gam_problem::LinearInequalityConstraints;
9use ndarray::{Array1, Array2, ArrayView1};
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12
13use super::{compute_observed_hessian_curvature_arrays, computeworkingweight_derivatives_from_eta};
14
/// Selects the linear-algebra path of a PIRLS solve: dense matrices in the
/// transformed coordinates, or sparse-native coordinates.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum PirlsLinearSolvePath {
    /// Dense solve in the transformed coordinates.
    DenseTransformed,
    /// Sparse solve in the native (untransformed) coordinates.
    SparseNative,
}
21
/// The coordinate frame in which the PIRLS inner iteration operates.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum PirlsCoordinateFrame {
    /// Coordinates transformed by the reparameterization `Qs`.
    TransformedQs,
    /// The original, sparse-native coordinates.
    OriginalSparseNative,
}
28
29/// Firth bias-reduction diagnostics at convergence.
30#[derive(Debug, Clone, Default)]
31pub enum FirthDiagnostics {
32 #[default]
33 Inactive,
34 Active {
35 jeffreys_logdet: f64,
36 hat_diag: Array1<f64>,
37 },
38}
39
40impl FirthDiagnostics {
41 #[inline]
42 pub fn jeffreys_logdet(&self) -> Option<f64> {
43 match self {
44 Self::Inactive => None,
45 Self::Active {
46 jeffreys_logdet, ..
47 } => Some(*jeffreys_logdet),
48 }
49 }
50}
51
52/// Which information matrix the penalized Hessian carries at the current
53/// PIRLS iterate.
54///
55/// Canonical links (logit-Binomial, log-Poisson) have W_obs == W_Fisher, so
56/// the two choices coincide. Non-canonical links (probit, cloglog, mixture,
57/// flexible, Gamma-log, ...) need observed information W_obs = W_Fisher -
58/// (y - mu) * B for the outer REML/Laplace log|H| and trace terms to be
59/// exact; Fisher weights alone yield a PQL-type surrogate. We fall back to
60/// `Fisher` only when the observed-information Hessian fails the
61/// positive-definiteness check, since the inner Newton step must be SPD.
62#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
63pub enum HessianCurvatureKind {
64 /// Expected (Fisher) information: W_Fisher = h'^2 / (phi * V(mu)).
65 /// Used as the inner iteration matrix when observed curvature fails (non-SPD).
66 Fisher,
67 /// Observed information: W_obs = W_Fisher - (y - mu) * B.
68 /// Required for the outer REML log|H| and trace terms (exact Laplace).
69 Observed,
70}
71
72/// The exported Laplace curvature kind used for the outer REML criterion.
73#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
74pub enum ExportedLaplaceCurvature {
75 ObservedExact,
76 ExpectedInformationSurrogate,
77 InvalidObservedCurvature {
78 min_eigenvalue: f64,
79 pd_tolerance: f64,
80 gradient_norm: f64,
81 },
82}
83
/// Working state at a PIRLS iterate: gradient, Hessian, deviance, etc.
#[derive(Debug, Clone)]
pub struct WorkingState {
    /// Linear predictor at this iterate. Its length is used as `n` in the
    /// dimension-based KKT scale.
    pub eta: LinearPredictor,
    /// Gradient of the penalized objective at this iterate. Its length is used
    /// as `p` in the dimension-based KKT scale.
    pub gradient: Array1<f64>,
    /// Penalized Hessian. The information matrix it carries (Fisher or
    /// observed) is recorded in `hessian_curvature`.
    pub hessian: gam_linalg::matrix::SymmetricMatrix,
    pub log_likelihood: f64,
    pub deviance: f64,
    pub penalty_term: f64,
    /// Firth bias-reduction diagnostics for this iterate.
    pub firth: FirthDiagnostics,
    /// Ridge added to keep the penalized Hessian positive definite.
    ///
    /// `penalty_term` includes the full quadratic-form contribution
    /// `ridge * ||beta||^2`. The optimization objective is
    /// `0.5 * (deviance + penalty_term)`, so on the log-likelihood scale this
    /// amounts to `0.5 * ridge * ||beta||^2`.
    pub ridge_used: f64,
    /// Which information matrix `hessian` carries at this iterate.
    pub hessian_curvature: HessianCurvatureKind,
    /// Natural scale of the penalized gradient, used for a scale-invariant
    /// KKT certificate.
    ///
    /// Equals `||X'(weighted_residual)||_2 + ||S*beta||_2`, plus
    /// `ridge*||beta||_2` when a stabilizing ridge is active.
    ///
    /// Under stochastic noise the score component grows like O(sqrt(n)). An
    /// absolute `||g||_2 < tol` test would therefore reject fits whose
    /// normalized stationarity residual is already negligible.
    ///
    /// The relative residual is `||g||_2 / (1 + this)`. `certifies_kkt`
    /// accepts when either this relative test or the dimension-based test
    /// passes.
    pub gradient_natural_scale: f64,
}
109
110impl WorkingState {
111 #[inline]
112 pub fn jeffreys_logdet(&self) -> Option<f64> {
113 self.firth.jeffreys_logdet()
114 }
115
116 /// Scale-invariant relative gradient residual.
117 ///
118 /// Returns ||g||_2 / (1 + ||score||_2 + ||S*beta||_2 + ridge*||beta||_2).
119 /// `g_norm` is the projected/constrained stationarity residual in the
120 /// current PIRLS basis; the denominator is the natural magnitude of the
121 /// penalized gradient and is invariant under uniform rescaling of the
122 /// objective.
123 #[inline]
124 pub fn relative_gradient_norm(&self, g_norm: f64) -> f64 {
125 g_norm / (1.0 + self.gradient_natural_scale)
126 }
127
128 /// Dimension-based scale `√n · max(1, √p)` for the structural KKT bound.
129 ///
130 /// Under standardized columns, the score `Xᵀ(μ − y)` has components of
131 /// order O(√n), so the absolute test ‖g‖ < τ becomes systematically too
132 /// tight at large n. Multiplying τ by this scale restores the advertised
133 /// per-observation meaning.
134 #[inline]
135 pub(crate) fn kkt_dimension_scale(&self) -> f64 {
136 let n = self.eta.len().max(1) as f64;
137 let p = (self.gradient.len() as f64).max(1.0);
138 n.sqrt() * p.sqrt()
139 }
140
141 /// Strict KKT acceptance: `g_norm` certifies stationarity under EITHER
142 /// scale-invariant criterion (dimension-based or data-driven natural-scale).
143 ///
144 /// Both certificates are invariant under uniform rescaling of the objective
145 /// `F → c·F` (in the limit where the natural scale dominates the additive
146 /// `1` floor). Acceptance under either is sufficient because:
147 /// - the natural-scale bound is tighter when the data are well-scaled
148 /// (it tracks actual gradient component magnitudes);
149 /// - the dimension bound is tighter when the design matrix has unusual
150 /// scaling (so the natural scale is dominated by a single component).
151 #[inline]
152 pub fn certifies_kkt(&self, g_norm: f64, tol: f64) -> bool {
153 g_norm < tol * self.kkt_dimension_scale() || self.relative_gradient_norm(g_norm) < tol
154 }
155
156 /// Near-stationary band (10× the strict KKT tolerance) under EITHER
157 /// scale-invariant criterion. Used as a "good-enough" plateau check
158 /// that classifies a fit as `StalledAtValidMinimum` rather than as a
159 /// hard non-convergence. The band is `10 · tol` without a
160 /// floor — a caller asking for `tol = 1e-12` gets a 1e-11 band, not
161 /// the 1e-5 the old `tol.max(1e-6) * 10` formula silently widened it
162 /// to. The 1e-6 floor was masking real convergence regressions
163 /// (e.g. `constant_prior_mean_centers_penalty`'s LM-ridge induced
164 /// 2.5e-8 bias visible only when the user asked for sub-1e-6
165 /// precision).
166 #[inline]
167 pub fn near_stationary_kkt(&self, g_norm: f64, tol: f64) -> bool {
168 let near_tol = tol * 10.0;
169 g_norm <= near_tol * self.kkt_dimension_scale()
170 || self.relative_gradient_norm(g_norm) <= near_tol
171 }
172}
173
174/// Numerically stable Euclidean norm of an `Array1<f64>`.
175///
176/// Used to assemble the penalized-gradient natural scale at every
177/// `WorkingState` construction site (main GAM, identity-link short circuit,
178/// survival, test mocks). Centralizing here avoids drift between sites and
179/// makes the convergence certificate's denominator a single source of truth.
180///
181/// One pass, no allocation, O(p). At p≈10⁴ the cost is ≪ the O(np²) PIRLS
182/// inner work, so this is free in any setting where it matters.
183#[inline]
184pub fn array1_l2_norm(v: &Array1<f64>) -> f64 {
185 v.iter().map(|x| x * x).sum::<f64>().sqrt()
186}
187
/// Parameters of the adaptive KKT tolerance used by the inner PIRLS
/// convergence test.
///
/// NOTE(review): field roles are inferred from their names. Presumably the
/// tolerance is something like `clamp(eta * outer_grad_norm, floor, ceiling)`.
/// Confirm at the use site.
#[derive(Debug, Clone, Copy)]
pub struct AdaptiveKktTolerance {
    /// Multiplier relating the inner tolerance to the outer gradient.
    pub eta: f64,
    /// Lower bound on the adaptive tolerance.
    pub floor: f64,
    /// Upper bound on the adaptive tolerance.
    pub ceiling: f64,
    /// Outer-loop gradient norm driving the adaptation.
    pub outer_grad_norm: f64,
}
196
/// Diagnostics for a single PIRLS iteration, passed to the progress callback.
#[derive(Debug, Clone)]
pub struct WorkingModelIterationInfo {
    /// Zero- or one-based iteration index as reported by the producer.
    pub iteration: usize,
    /// Deviance at this iteration.
    pub deviance: f64,
    /// Gradient norm at this iteration.
    pub gradient_norm: f64,
    /// Step size taken at this iteration.
    pub step_size: f64,
    /// Number of step halvings performed at this iteration.
    pub step_halving: usize,
}
206
/// Result of the inner `runworking_model_pirls` loop.
#[derive(Clone)]
pub struct WorkingModelPirlsResult {
    pub beta: Coefficients,
    /// Working state at the returned iterate.
    pub state: WorkingState,
    /// Convergence status of the inner loop.
    pub status: PirlsStatus,
    pub iterations: usize,
    pub lastgradient_norm: f64,
    pub last_deviance_change: f64,
    pub last_step_size: f64,
    pub last_step_halving: usize,
    pub max_abs_eta: f64,
    /// Optional KKT diagnostics when inequality constraints were active.
    pub constraint_kkt: Option<ConstraintKktDiagnostics>,
    /// Levenberg-Marquardt damping coefficient at the last accepted inner
    /// iteration.
    ///
    /// The REML runtime uses it to seed the next PIRLS call in the same outer
    /// fit. This avoids 4-6 iterations of rediscovering the damping when the
    /// geometry calls for `λ_LM > 1e-6`.
    pub final_lm_lambda: f64,
    /// Gain ratio (`actual_reduction / predicted_reduction`) at the last
    /// accepted inner iteration.
    ///
    /// `None` when no step was accepted (rejection-exhausted, or
    /// MaxIterationsReached without acceptance).
    ///
    /// This is the programmatic counterpart of the `accept_rho` field in the
    /// per-iteration `[PIRLS lm-trajectory]` log line. The log is grep-only;
    /// this field can be queried by the outer schedule and the convergence
    /// guard.
    ///
    /// - Values near 1.0 mean the quadratic model is faithful.
    /// - Much smaller values mean the LM model overstates the predicted
    ///   reduction, and the inner Newton may benefit from shorter steps.
    pub final_accept_rho: Option<f64>,
    /// Minimum penalized deviance (`state.deviance + state.penalty_term`) over
    /// every iteration whose state was computed during the inner P-IRLS loop.
    ///
    /// Penalized deviance decreases monotonically along any descent path the
    /// inner solver takes. This minimum is therefore a principled
    /// seed-screening proxy, and it stays meaningful even when the solver hit
    /// its iteration cap before reaching the mode.
    ///
    /// `f64::INFINITY` when no state was ever computed, i.e. on paths that
    /// synthesize a result without iterating (e.g. zero-iteration warm-only
    /// paths).
    pub min_penalized_deviance: f64,
    /// Which curvature kind is exported for the outer REML/Laplace criterion.
    pub exported_laplace_curvature: ExportedLaplaceCurvature,
}
247
248/// The status of the P-IRLS convergence.
249#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
250pub enum PirlsStatus {
251 /// Converged successfully within tolerance.
252 Converged,
253 /// Reached maximum iterations but the gradient and Hessian indicate a valid minimum.
254 StalledAtValidMinimum,
255 /// Reached maximum iterations without converging.
256 MaxIterationsReached,
257 /// Levenberg-Marquardt step search exhausted its retry budget (damping λ
258 /// reached its ceiling, attempts counter expired, or λ went non-finite)
259 /// before the projected gradient entered the near-stationary band. Distinct
260 /// from `MaxIterationsReached`, which means the outer iteration counter
261 /// itself ran out — that exhaustion is a "looped 100×, made progress each
262 /// time but never converged" signal, while this one is a "no acceptable
263 /// step direction even after damping" signal pointing at curvature trouble
264 /// or saturated likelihoods.
265 LmStepSearchExhausted,
266 /// Fitting process became unstable, likely due to perfect separation.
267 Unstable,
268}
269
270impl PirlsStatus {
271 /// Whether the inner loop concluded without producing a usable mode.
272 /// Both the iteration-cap and LM-exhausted exits should be treated the
273 /// same by callers that just want to know "did we get a valid solution?".
274 #[inline]
275 pub const fn is_failed_max_iterations(self) -> bool {
276 matches!(
277 self,
278 PirlsStatus::MaxIterationsReached | PirlsStatus::LmStepSearchExhausted
279 )
280 }
281
282 /// Short human-readable label for reports and diagnostics. Stable text
283 /// (not the `Debug` rendering) so report output does not silently change if
284 /// the variant identifiers are ever renamed.
285 #[inline]
286 pub const fn label(self) -> &'static str {
287 match self {
288 PirlsStatus::Converged => "Converged",
289 PirlsStatus::StalledAtValidMinimum => "Stalled at valid minimum",
290 PirlsStatus::MaxIterationsReached => "Max iterations reached",
291 PirlsStatus::LmStepSearchExhausted => "LM step search exhausted",
292 PirlsStatus::Unstable => "Unstable (possible separation)",
293 }
294 }
295
296 /// Whether this status represents a clean convergence to the mode. Only
297 /// `Converged` qualifies; every other state carries a caveat a reader
298 /// should see flagged.
299 #[inline]
300 pub const fn is_converged(self) -> bool {
301 matches!(self, PirlsStatus::Converged)
302 }
303}
304
/// Holds the result of a converged P-IRLS inner loop for a fixed rho.
///
/// # Basis of Returned Tensors
///
/// **IMPORTANT:** All vector and matrix outputs in this struct (`beta_transformed`,
/// `penalized_hessian_transformed`) are in the **stable, transformed basis**
/// that was computed for the given set of smoothing parameters.
///
/// To get coefficients in the original, interpretable basis, the caller must
/// back-transform them with the `qs` matrix from the `reparam_result` field:
/// `beta_original = reparam_result.qs.dot(&beta_transformed)`
///
/// # Fields
///
/// * `beta_transformed`: The estimated coefficient vector in the STABLE,
///   TRANSFORMED basis.
/// * `penalized_hessian_transformed`: The penalized Hessian at convergence,
///   `X'W_H X + S_λ`, in the STABLE, TRANSFORMED basis. `W_H` is the Fisher or
///   the observed curvature, depending on the accepted PIRLS step.
/// * `deviance`: The final deviance value. This is family-specific:
///   - Gaussian identity: weighted residual sum of squares.
///   - Binomial families: binomial deviance.
///   - Poisson log: Poisson deviance.
///   - Gamma log: Gamma unit deviance scaled by the fitted Gamma shape.
/// * `finalweights`: The final Hessian-side working weights at convergence.
/// * `solveweights`: The final score-side Fisher weights used in
///   `X'W(z-eta) - S beta`.
/// * `reparam_result`: Contains the transformation matrix (`qs`) and other
///   reparameterization data.
///
/// # Point Estimate: Posterior Mode (MAP)
///
/// The coefficients returned by PIRLS are the **posterior mode** (Maximum A
/// Posteriori estimate), not the posterior mean.
///
/// For risk predictions the posterior mean is theoretically preferable. For
/// well-behaved, approximately symmetric posteriors, mode ≈ mean and the
/// difference is negligible. For asymmetric posteriors (rare events, boundary
/// effects), the mean would give better-calibrated probabilities.
///
/// Obtaining the posterior mean would require MCMC sampling from the posterior
/// and averaging f(x, β) over the samples.
#[derive(Clone)]
pub struct PirlsResult {
    pub likelihood: GlmLikelihoodSpec,
    /// Coefficients in the STABLE, TRANSFORMED basis.
    pub beta_transformed: Coefficients,
    /// Penalized Hessian in the STABLE, TRANSFORMED basis.
    pub penalized_hessian_transformed: SymmetricMatrix,
    /// Single stabilized Hessian, so that cost and gradient are computed
    /// consistently.
    pub stabilizedhessian_transformed: SymmetricMatrix,
    /// Canonical ridge metadata passport consumed by the outer
    /// objective/gradient code.
    pub ridge_passport: RidgePassport,
    /// Ridge added to make the stabilized Hessian positive definite.
    ///
    /// When > 0, `stable_penalty_term` includes `ridge_used * ||beta||^2`. That
    /// contributes `0.5 * ridge_used * ||beta||^2` to
    /// `-0.5 * (deviance + stable_penalty_term)`.
    ///
    /// Backward-compatible mirror of `ridge_passport.delta`.
    pub ridge_used: f64,

    /// The unpenalized deviance, calculated from mu and y.
    pub deviance: f64,

    /// Effective degrees of freedom at the solution.
    pub edf: f64,

    /// The penalty term, calculated stably within P-IRLS.
    ///
    /// This is `beta_transformed' * S_transformed * beta_transformed`. When
    /// stabilization is active it also includes `ridge_used * ||beta||^2`, so
    /// that the penalized deviance matches the stabilized Hessian.
    pub stable_penalty_term: f64,

    /// Firth diagnostics in the converged PIRLS state.
    pub firth: FirthDiagnostics,

    /// Diagonal weights defining the Hessian surface returned to outer
    /// REML/LAML.
    ///
    /// For canonical links, Fisher and observed weights are identical.
    ///
    /// For non-canonical links, PIRLS always recomputes observed weights at
    /// the accepted β̂ in a post-convergence finalization step (see
    /// "Post-convergence Laplace curvature finalization"). So `finalweights`
    /// carries the *observed-information* diagonal whenever the model supports
    /// it, even if the inner LM loop ended on Fisher because of a fallback.
    ///
    /// `exported_laplace_curvature` labels exactly what these weights
    /// represent. Do not infer the kind from `hessian_curvature`; that field
    /// records what the inner loop's last accepted step happened to use.
    pub finalweights: Array1<f64>,
    /// Additional PIRLS state captured at the accepted step, so that cost and
    /// gradient stay consistent in the outer optimization.
    pub final_offset: Array1<f64>,
    pub final_eta: Array1<f64>,
    pub finalmu: Array1<f64>,
    /// Score-side Fisher weights used in `X'W(z-eta) - S beta`.
    pub solveweights: Array1<f64>,
    pub solveworking_response: Array1<f64>,
    pub solvemu: Array1<f64>,
    pub solve_dmu_deta: Array1<f64>,
    pub solve_d2mu_deta2: Array1<f64>,
    pub solve_d3mu_deta3: Array1<f64>,
    /// First eta-derivative of the diagonal Hessian curvature W_H(eta):
    /// c_i := dW_i/deta_i at the accepted PIRLS solution.
    ///
    /// Carries 3rd-order likelihood information, used in the exact dH/dρ terms
    /// of the outer LAML derivatives.
    pub solve_c_array: Array1<f64>,
    /// Second eta-derivative of the diagonal Hessian curvature W_H(eta):
    /// d_i := d²W_i/deta_i² at the accepted PIRLS solution.
    ///
    /// Carries 4th-order likelihood information, used in the exact d²H/dρ²
    /// terms of the outer LAML Hessian.
    pub solve_d_array: Array1<f64>,
    /// True when `solve_c_array` / `solve_d_array` are placeholders rather
    /// than supported likelihood derivatives.
    pub derivatives_unsupported: bool,

    // Convergence diagnostics of the inner loop.
    pub status: PirlsStatus,
    pub iteration: usize,
    pub max_abs_eta: f64,
    pub lastgradient_norm: f64,
    /// Natural scale of the penalized gradient at the accepted PIRLS state.
    ///
    /// Equals ‖Xᵀ(weighted residual)‖₂ + ‖Sβ‖₂, plus ridge·‖β‖₂ when active.
    ///
    /// Mirrors `WorkingState::gradient_natural_scale`. Callers that read
    /// `PirlsResult` directly (e.g. seed-screening cost augmentation) can then
    /// form the scale-invariant residual r_g = ‖g‖ / (1 + this) without
    /// rebuilding the score and penalty norms.
    pub gradient_natural_scale: f64,
    pub last_deviance_change: f64,
    pub last_step_halving: usize,
    /// Curvature used by the inner loop's last accepted step. See
    /// `exported_laplace_curvature` for what `finalweights` actually carries.
    pub hessian_curvature: HessianCurvatureKind,
    /// Which curvature kind was exported for the outer REML criterion.
    pub exported_laplace_curvature: ExportedLaplaceCurvature,
    /// Levenberg-Marquardt damping coefficient at the converged inner
    /// iteration.
    ///
    /// The REML runtime caches it, so the next PIRLS call in the same outer
    /// optimization can seed `λ_LM` with this value instead of cold-starting
    /// at `1e-6`.
    ///
    /// Mirrors `WorkingModelPirlsResult::final_lm_lambda`.
    pub final_lm_lambda: f64,
    /// Gain ratio of the last accepted LM step inside this PIRLS solve.
    ///
    /// `None` when no step was accepted (e.g. zero-iteration synthesis,
    /// rejection-exhausted, MaxIterations without acceptance).
    ///
    /// Mirrors `WorkingModelPirlsResult::final_accept_rho`. It is the
    /// programmatic counterpart of the `accept_rho` field in the
    /// per-iteration `[PIRLS lm-trajectory]` log line. Outer consumers (cap
    /// schedule, convergence guard) can query it for decisions about
    /// inner-Newton model fidelity.
    pub final_accept_rho: Option<f64>,
    /// Optional KKT diagnostics when inequality constraints were active.
    pub constraint_kkt: Option<ConstraintKktDiagnostics>,
    /// Linear inequality system enforced in transformed PIRLS coordinates:
    /// `A * beta_transformed >= b`.
    pub linear_constraints_transformed: Option<LinearInequalityConstraints>,

    /// The full reparameterization result, passed through for use in the
    /// gradient.
    pub reparam_result: ReparamResult,
    /// Cached X·Qs for this PIRLS result (transformed design matrix).
    pub x_transformed: DesignMatrix,
    pub coordinate_frame: PirlsCoordinateFrame,
    /// True when this fixed-rho inner solve completed on a GPU path.
    pub used_device: bool,
    /// True when this result was compacted for REML LRU storage.
    ///
    /// A compacted result needs its cold artifacts (for example
    /// `x_transformed`) rehydrated before exact bundle construction.
    pub cache_compacted: bool,
    /// Minimum penalized deviance observed across the inner P-IRLS loop.
    ///
    /// Mirrors `WorkingModelPirlsResult::min_penalized_deviance`. Used as the
    /// seed-screening ranking proxy.
    ///
    /// Penalized deviance descends monotonically along any inner descent path.
    /// The per-seed minimum therefore tells the outer cascade "how good a fit
    /// this rho's neighbourhood can support", even when the inner solver was
    /// capped before reaching the mode.
    pub min_penalized_deviance: f64,
}
464
465impl PirlsResult {
466 /// Export the stabilized transformed Hessian as an exact dense matrix for
467 /// downstream solve paths that require explicit Hessians.
468 ///
469 /// The returned matrix is the convergence Hessian already used by PIRLS and
470 /// REML (`X'W_HX + S_λ`, plus the explicit stabilization ridge when active).
471 /// Sparse-native fits are materialized from their assembled sparse Hessian;
472 /// no numerical Hessian approximation or compatibility fallback is used.
473 pub fn dense_stabilizedhessian_transformed(
474 &self,
475 context: &str,
476 ) -> Result<Array2<f64>, EstimationError> {
477 self.stabilizedhessian_transformed
478 .try_to_dense_exact(context)
479 .map_err(EstimationError::InvalidInput)
480 }
481
482 #[inline]
483 pub fn jeffreys_logdet(&self) -> Option<f64> {
484 self.firth.jeffreys_logdet()
485 }
486
487 /// Typed view of the Hessian-side working weight diagonal stored on this
488 /// result, sign-honest. `finalweights` carries the observed-information
489 /// diagonal whenever the model supports it (see `exported_laplace_curvature`),
490 /// and observed weights `W_obs = W_F - (y - μ) · B` can be negative for
491 /// non-canonical links. Consumers feeding this into the asymmetric
492 /// `X_iᵀ W X_j` path, `weighted_crossprod_dense_rows`, or
493 /// `xt_diag_x_signed_op` must use this typed view rather than borrowing
494 /// the raw `Array1<f64>` so the function-boundary type contract from
495 /// `linalg/matrix.rs` is construction-enforced.
496 #[inline]
497 pub fn final_weights_signed(&self) -> SignedWeightsView<'_> {
498 SignedWeightsView::from_array(&self.finalweights)
499 }
500
501 /// Typed view of the score-side Fisher weights `W_F = h'²/(φ V(μ)) ≥ 0`
502 /// stored on this result, PSD-by-construction. Used by PSD-Gram kernels
503 /// (`dense_xtwx_view`, `sparse_csr_weighted_xtwx_*`, `xt_diag_x_psd_op`)
504 /// without a runtime sign scan; the PSD obligation is discharged
505 /// algebraically by the Fisher formula at the construction site in
506 /// `solver/pirls/mod.rs`. New callers that need the same diagonal under
507 /// a sign-honest API should route through `as_signed()` on the returned
508 /// view rather than reconstructing from the raw array.
509 #[inline]
510 pub fn solve_weights_psd(&self) -> PsdWeightsView<'_> {
511 PsdWeightsView::from_view_unchecked(self.solveweights.view())
512 }
513
514 /// Scale-invariant relative gradient residual at the accepted PIRLS state.
515 ///
516 /// Returns ‖g‖ / (1 + ‖score‖ + ‖Sβ‖ + ridge·‖β‖). Numerator is
517 /// `lastgradient_norm`; denominator is `1 + gradient_natural_scale`.
518 /// This is the "r_g" used by seed-screening cost augmentation.
519 #[inline]
520 pub fn relative_gradient_norm(&self) -> f64 {
521 self.lastgradient_norm / (1.0 + self.gradient_natural_scale)
522 }
523
524 pub(crate) fn compact_for_reml_cache(&self) -> Self {
525 Self {
526 likelihood: self.likelihood.clone(),
527 beta_transformed: self.beta_transformed.clone(),
528 penalized_hessian_transformed: self.penalized_hessian_transformed.clone(),
529 stabilizedhessian_transformed: self.stabilizedhessian_transformed.clone(),
530 ridge_passport: self.ridge_passport,
531 ridge_used: self.ridge_used,
532 deviance: self.deviance,
533 edf: self.edf,
534 stable_penalty_term: self.stable_penalty_term,
535 firth: self.firth.clone(),
536 finalweights: Array1::zeros(0),
537 final_offset: Array1::zeros(0),
538 final_eta: self.final_eta.clone(),
539 finalmu: Array1::zeros(0),
540 solveweights: self.solveweights.clone(),
541 solveworking_response: self.solveworking_response.clone(),
542 solvemu: self.solvemu.clone(),
543 solve_dmu_deta: Array1::zeros(0),
544 solve_d2mu_deta2: Array1::zeros(0),
545 solve_d3mu_deta3: Array1::zeros(0),
546 solve_c_array: self.solve_c_array.clone(),
547 solve_d_array: self.solve_d_array.clone(),
548 derivatives_unsupported: self.derivatives_unsupported,
549 status: self.status,
550 iteration: self.iteration,
551 max_abs_eta: self.max_abs_eta,
552 lastgradient_norm: self.lastgradient_norm,
553 gradient_natural_scale: self.gradient_natural_scale,
554 last_deviance_change: self.last_deviance_change,
555 last_step_halving: self.last_step_halving,
556 hessian_curvature: self.hessian_curvature,
557 exported_laplace_curvature: self.exported_laplace_curvature.clone(),
558 final_lm_lambda: self.final_lm_lambda,
559 final_accept_rho: self.final_accept_rho,
560 constraint_kkt: self.constraint_kkt.clone(),
561 linear_constraints_transformed: self.linear_constraints_transformed.clone(),
562 reparam_result: self.reparam_result.clone(),
563 x_transformed: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
564 Array2::zeros((0, 0)),
565 )),
566 coordinate_frame: self.coordinate_frame,
567 used_device: self.used_device,
568 cache_compacted: true,
569 min_penalized_deviance: self.min_penalized_deviance,
570 }
571 }
572
573 pub(crate) fn rehydrate_after_reml_cache(
574 &self,
575 x_original: &DesignMatrix,
576 y: ArrayView1<'_, f64>,
577 priorweights: ArrayView1<'_, f64>,
578 offset: ArrayView1<'_, f64>,
579 inverse_link: &InverseLink,
580 ) -> Result<Self, EstimationError> {
581 if !self.cache_compacted {
582 return Ok(self.clone());
583 }
584
585 let (score_c_array, score_d_array, solve_dmu_deta, solve_d2mu_deta2, solve_d3mu_deta3) =
586 computeworkingweight_derivatives_from_eta(
587 &self.likelihood,
588 inverse_link,
589 &self.final_eta,
590 priorweights,
591 )?;
592 let (finalweights, solve_c_array, solve_d_array) =
593 if self.hessian_curvature == HessianCurvatureKind::Observed {
594 compute_observed_hessian_curvature_arrays(
595 &self.likelihood,
596 inverse_link,
597 &self.final_eta,
598 y,
599 &self.solveweights,
600 priorweights,
601 )?
602 } else {
603 (
604 self.solveweights.clone(),
605 score_c_array.clone(),
606 score_d_array.clone(),
607 )
608 };
609 // Lazy rehydration: wrap in ReparamOperator instead of materializing X·Qs.
610 let qs_arc = Arc::new(self.reparam_result.qs.clone());
611 Ok(Self {
612 likelihood: self.likelihood.clone(),
613 beta_transformed: self.beta_transformed.clone(),
614 penalized_hessian_transformed: self.penalized_hessian_transformed.clone(),
615 stabilizedhessian_transformed: self.stabilizedhessian_transformed.clone(),
616 ridge_passport: self.ridge_passport,
617 ridge_used: self.ridge_used,
618 used_device: self.used_device,
619 deviance: self.deviance,
620 edf: self.edf,
621 stable_penalty_term: self.stable_penalty_term,
622 firth: self.firth.clone(),
623 finalweights,
624 final_offset: offset.to_owned(),
625 final_eta: self.final_eta.clone(),
626 finalmu: self.solvemu.clone(),
627 solveweights: self.solveweights.clone(),
628 solveworking_response: self.solveworking_response.clone(),
629 solvemu: self.solvemu.clone(),
630 solve_dmu_deta,
631 solve_d2mu_deta2,
632 solve_d3mu_deta3,
633 solve_c_array,
634 solve_d_array,
635 derivatives_unsupported: self.derivatives_unsupported,
636 status: self.status,
637 iteration: self.iteration,
638 max_abs_eta: self.max_abs_eta,
639 lastgradient_norm: self.lastgradient_norm,
640 gradient_natural_scale: self.gradient_natural_scale,
641 last_deviance_change: self.last_deviance_change,
642 last_step_halving: self.last_step_halving,
643 hessian_curvature: self.hessian_curvature,
644 exported_laplace_curvature: self.exported_laplace_curvature.clone(),
645 final_lm_lambda: self.final_lm_lambda,
646 final_accept_rho: self.final_accept_rho,
647 constraint_kkt: self.constraint_kkt.clone(),
648 linear_constraints_transformed: self.linear_constraints_transformed.clone(),
649 reparam_result: self.reparam_result.clone(),
650 x_transformed: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(
651 ReparamOperator::new(x_original.clone(), qs_arc),
652 ))),
653 coordinate_frame: self.coordinate_frame,
654 cache_compacted: false,
655 min_penalized_deviance: self.min_penalized_deviance,
656 })
657 }
658}