gam_solve/pirls/loop_driver.rs
1//! Outer driver for a single fixed-ρ PIRLS fit.
2//!
3//! Owns:
4//! - `fit_model_for_fixed_rho` and `fit_model_for_fixed_rho_with_adaptive_kkt`
5//! — build the working model, run the inner LM loop, assemble the final result.
6//! - `PirlsProblem`, `PenaltyConfig`, `PirlsConfig` — the configuration types.
7//! - Helper functions exclusive to the fixed-ρ fitting path: constraint
8//! transformation, sparse-native decision, reparam materialisation, prior
9//! shift assembly, initial-β guess, Gaussian short-circuit assembly, etc.
//! - The two GPU dispatch blocks (Stage 3.3) that call into
//!   `crate::gpu::pirls_host_dispatch` (`try_gaussian_pls_gpu` and
//!   `try_pirls_loop_gpu`).
12
13use super::{
14 // state re-exports
15 AdaptiveKktTolerance,
16 ExportedLaplaceCurvature,
17 FirthDiagnostics,
18 GamWorkingModel,
19 HessianCurvatureKind,
20 // penalty types
21 KroneckerQsTransform,
22 LinearInequalityConstraints,
23 PirlsCoordinateFrame,
24 PirlsLinearSolvePath,
25 PirlsPenalty,
26 PirlsResult,
27 PirlsStatus,
28 PirlsWorkspace,
29 SparsePirlsDecision,
30 WorkingModelIterationInfo,
31 WorkingModelPirlsOptions,
32 WorkingModelPirlsResult,
33 WorkingReparamTransform,
34 WorkingState,
35 // misc helpers
36 array1_l2_norm,
37 attach_penalty_shift,
38 // compute functions
39 calculate_deviance,
40 // edf helpers
41 calculate_edf_with_penalty,
42 calculate_edfwithworkspace_with_penalty,
43 calculate_loglikelihood,
44 compute_constraint_kkt_diagnostics,
45 computeworkingweight_derivatives_from_eta,
46 inf_norm,
47 runworking_model_pirls,
48 should_use_sparse_native_pirls,
49 solve_penalized_least_squares_implicit,
50 standard_inverse_link_jet,
51};
52use super::{
53 ArrowSchurInnerConfig, GamModelFinalState, effective_kkt_tolerance,
54 project_coefficients_to_lower_bounds,
55};
56use crate::active_set;
57use crate::estimate::EstimationError;
58use crate::gpu::pirls_host_dispatch::{try_gaussian_pls_gpu, try_pirls_loop_gpu};
59use crate::mixture_link::inverse_link_has_fisher_weight_jet;
60use faer::sparse::{SparseColMat, Triplet};
61use gam_linalg::faer_ndarray::fast_ab;
62use gam_linalg::matrix::{DesignMatrix, LinearOperator, ReparamOperator, SymmetricMatrix};
63use gam_math::probability::standard_normal_quantile;
64use gam_problem::{
65 Coefficients, GlmLikelihoodSpec, InverseLink, LinearPredictor, LinkFunction,
66 LogSmoothingParamsView, MixtureLinkState, ResponseFamily, RidgePassport, RidgePolicy,
67 SasLinkState, StandardLink,
68};
69use gam_terms::construction::{KroneckerReparamResult, ReparamResult};
70use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
71use std::borrow::Cow;
72use std::sync::Arc;
73
74pub(super) fn default_beta_guess_external(
75 p: usize,
76 link_function: LinkFunction,
77 y: ArrayView1<f64>,
78 priorweights: ArrayView1<f64>,
79 mixture_link_state: Option<&MixtureLinkState>,
80 sas_link_state: Option<&SasLinkState>,
81) -> Array1<f64> {
82 let mut beta = Array1::<f64>::zeros(p);
83 let intercept_col = 0usize;
84 match link_function {
85 LinkFunction::Logit
86 | LinkFunction::Probit
87 | LinkFunction::CLogLog
88 | LinkFunction::LogLog
89 | LinkFunction::Cauchit
90 | LinkFunction::Sas
91 | LinkFunction::BetaLogistic => {
92 let mut weighted_sum = 0.0;
93 let mut totalweight = 0.0;
94 for (&yi, &wi) in y.iter().zip(priorweights.iter()) {
95 weighted_sum += wi * yi;
96 totalweight += wi;
97 }
98 if totalweight > 0.0 {
99 let prevalence =
100 ((weighted_sum + 0.5) / (totalweight + 1.0)).clamp(1e-6, 1.0 - 1e-6);
101 beta[intercept_col] = match link_function {
102 LinkFunction::Logit => (prevalence / (1.0 - prevalence)).ln(),
103 LinkFunction::Probit => {
104 standard_normal_quantile(prevalence).unwrap_or_else(|_| {
105 // `prevalence` is clamped to (0, 1); this fallback is
106 // only for defensive robustness under non-finite upstream inputs.
107 (prevalence / (1.0 - prevalence)).ln()
108 })
109 }
110 LinkFunction::CLogLog => (-(1.0 - prevalence).ln()).ln(),
111 LinkFunction::LogLog => -(-prevalence.ln()).ln(),
112 LinkFunction::Cauchit => (std::f64::consts::PI * (prevalence - 0.5)).tan(),
113 LinkFunction::Sas => solve_intercept_for_prevalence(
114 link_function,
115 prevalence,
116 mixture_link_state,
117 sas_link_state,
118 )
119 .unwrap_or_else(|| {
120 standard_normal_quantile(prevalence)
121 .unwrap_or_else(|_| (prevalence / (1.0 - prevalence)).ln())
122 }),
123 LinkFunction::BetaLogistic => solve_intercept_for_prevalence(
124 link_function,
125 prevalence,
126 mixture_link_state,
127 sas_link_state,
128 )
129 .unwrap_or_else(|| {
130 standard_normal_quantile(prevalence)
131 .unwrap_or_else(|_| (prevalence / (1.0 - prevalence)).ln())
132 }),
133 // Outer arm guard already filtered out Log/Identity; fall
134 // back to the canonical logit transform for defensive safety
135 // if these are ever reached unexpectedly.
136 LinkFunction::Log | LinkFunction::Identity => {
137 (prevalence / (1.0 - prevalence)).ln()
138 }
139 };
140 if mixture_link_state.is_some() {
141 beta[intercept_col] = solve_intercept_for_prevalence(
142 link_function,
143 prevalence,
144 mixture_link_state,
145 sas_link_state,
146 )
147 .unwrap_or(beta[intercept_col]);
148 }
149 }
150 }
151 LinkFunction::Identity => {
152 let mut weighted_sum = 0.0;
153 let mut totalweight = 0.0;
154 for (&yi, &wi) in y.iter().zip(priorweights.iter()) {
155 weighted_sum += wi * yi;
156 totalweight += wi;
157 }
158 if totalweight > 0.0 {
159 beta[intercept_col] = weighted_sum / totalweight;
160 }
161 }
162 LinkFunction::Log => {
163 // For log link, intercept = ln(weighted mean of y)
164 let mut weighted_sum = 0.0;
165 let mut totalweight = 0.0;
166 for (&yi, &wi) in y.iter().zip(priorweights.iter()) {
167 weighted_sum += wi * yi;
168 totalweight += wi;
169 }
170 if totalweight > 0.0 {
171 let mean_y = (weighted_sum / totalweight).max(1e-10);
172 beta[intercept_col] = mean_y.ln();
173 }
174 }
175 }
176 beta
177}
178
179pub(super) fn solve_intercept_for_prevalence(
180 link_function: LinkFunction,
181 prevalence: f64,
182 mixture_link_state: Option<&MixtureLinkState>,
183 sas_link_state: Option<&SasLinkState>,
184) -> Option<f64> {
185 #[inline]
186 fn f_eta(
187 link_function: LinkFunction,
188 eta: f64,
189 prevalence: f64,
190 mixture_link_state: Option<&MixtureLinkState>,
191 sas_link_state: Option<&SasLinkState>,
192 ) -> f64 {
193 let inverse_link = if let Some(state) = mixture_link_state {
194 InverseLink::Mixture(state.clone())
195 } else if let Some(state) = sas_link_state {
196 match link_function {
197 LinkFunction::BetaLogistic => InverseLink::BetaLogistic(*state),
198 _ => InverseLink::Sas(*state),
199 }
200 } else {
201 // SAFETY: when `sas_link_state` is None, `solve_intercept_for_prevalence`
202 // is only invoked with the five legal `StandardLink` variants (the
203 // dispatch site at pirls.rs:4203 routes Sas/BetaLogistic into the
204 // Some branch above with state).
205 InverseLink::Standard(StandardLink::try_from(link_function).expect(
206 "state-bearing link reached state-less arm in solve_intercept_for_prevalence",
207 ))
208 };
209 standard_inverse_link_jet(&inverse_link, eta)
210 .map(|jet| jet.mu - prevalence)
211 .unwrap_or(f64::NAN)
212 }
213
214 let mut lo = -40.0;
215 let mut hi = 40.0;
216 let mut f_lo = f_eta(
217 link_function,
218 lo,
219 prevalence,
220 mixture_link_state,
221 sas_link_state,
222 );
223 let mut f_hi = f_eta(
224 link_function,
225 hi,
226 prevalence,
227 mixture_link_state,
228 sas_link_state,
229 );
230 if !(f_lo.is_finite() && f_hi.is_finite()) {
231 return None;
232 }
233 for _ in 0..8 {
234 if f_lo <= 0.0 && f_hi >= 0.0 {
235 break;
236 }
237 lo *= 2.0;
238 hi *= 2.0;
239 f_lo = f_eta(
240 link_function,
241 lo,
242 prevalence,
243 mixture_link_state,
244 sas_link_state,
245 );
246 f_hi = f_eta(
247 link_function,
248 hi,
249 prevalence,
250 mixture_link_state,
251 sas_link_state,
252 );
253 if !(f_lo.is_finite() && f_hi.is_finite()) {
254 return None;
255 }
256 }
257 if f_lo > 0.0 {
258 return Some(lo);
259 }
260 if f_hi < 0.0 {
261 return Some(hi);
262 }
263 for _ in 0..80 {
264 let mid = 0.5 * (lo + hi);
265 let f_mid = f_eta(
266 link_function,
267 mid,
268 prevalence,
269 mixture_link_state,
270 sas_link_state,
271 );
272 if !f_mid.is_finite() {
273 return None;
274 }
275 if f_mid > 0.0 {
276 hi = mid;
277 } else {
278 lo = mid;
279 }
280 }
281 Some(0.5 * (lo + hi))
282}
283
284pub(super) fn assemble_pirls_result(
285 working_summary: &WorkingModelPirlsResult,
286 likelihood: GlmLikelihoodSpec,
287 offset: ArrayView1<'_, f64>,
288 penalized_hessian_transformed: SymmetricMatrix,
289 stabilizedhessian_transformed: SymmetricMatrix,
290 edf: f64,
291 penalty_term: f64,
292 finalmu: &Array1<f64>,
293 finalweights: &Array1<f64>,
294 scoreweights: &Array1<f64>,
295 finalz: &Array1<f64>,
296 final_c: &Array1<f64>,
297 final_d: &Array1<f64>,
298 final_dmu_deta: &Array1<f64>,
299 final_d2mu_deta2: &Array1<f64>,
300 final_d3mu_deta3: &Array1<f64>,
301 status: PirlsStatus,
302 reparam_result: ReparamResult,
303 x_transformed: DesignMatrix,
304 coordinate_frame: PirlsCoordinateFrame,
305 linear_constraints_transformed: Option<LinearInequalityConstraints>,
306) -> PirlsResult {
307 let final_eta_arr = working_summary.state.eta.as_ref().clone();
308 PirlsResult {
309 likelihood,
310 beta_transformed: working_summary.beta.clone(),
311 penalized_hessian_transformed,
312 stabilizedhessian_transformed,
313 ridge_passport: RidgePassport::scaled_identity(
314 working_summary.state.ridge_used,
315 RidgePolicy::explicit_stabilization_full(),
316 ),
317 ridge_used: working_summary.state.ridge_used,
318 deviance: working_summary.state.deviance,
319 edf,
320 stable_penalty_term: penalty_term,
321 firth: working_summary.state.firth.clone(),
322 finalweights: finalweights.clone(),
323 final_offset: offset.to_owned(),
324 final_eta: final_eta_arr,
325 finalmu: finalmu.clone(),
326 solveweights: scoreweights.clone(),
327 solveworking_response: finalz.clone(),
328 solvemu: finalmu.clone(),
329 solve_dmu_deta: final_dmu_deta.clone(),
330 solve_d2mu_deta2: final_d2mu_deta2.clone(),
331 solve_d3mu_deta3: final_d3mu_deta3.clone(),
332 solve_c_array: final_c.clone(),
333 solve_d_array: final_d.clone(),
334 derivatives_unsupported: false,
335 status,
336 iteration: working_summary.iterations,
337 max_abs_eta: working_summary.max_abs_eta,
338 lastgradient_norm: working_summary.lastgradient_norm,
339 gradient_natural_scale: working_summary.state.gradient_natural_scale,
340 last_deviance_change: working_summary.last_deviance_change,
341 last_step_halving: working_summary.last_step_halving,
342 hessian_curvature: working_summary.state.hessian_curvature,
343 exported_laplace_curvature: working_summary.exported_laplace_curvature.clone(),
344 final_lm_lambda: working_summary.final_lm_lambda,
345 final_accept_rho: working_summary.final_accept_rho,
346 constraint_kkt: working_summary.constraint_kkt.clone(),
347 linear_constraints_transformed,
348 reparam_result,
349 x_transformed,
350 coordinate_frame,
351 used_device: false,
352 cache_compacted: false,
353 min_penalized_deviance: working_summary.min_penalized_deviance,
354 }
355}
356
357pub(super) fn detect_logit_instability(
358 link: LinkFunction,
359 response: &ResponseFamily,
360 has_penalty: bool,
361 firth_active: bool,
362 summary: &WorkingModelPirlsResult,
363 finalmu: &Array1<f64>,
364 finalweights: &Array1<f64>,
365 y: ArrayView1<'_, f64>,
366) -> bool {
367 // Perfect / quasi-perfect separation is a *Bernoulli/Binomial* pathology.
368 // Every heuristic below is binary-response–specific: saturation toward
369 // μ ∈ {0, 1}, the `yᵢ > 0.5` order-separation split, and working-weight
370 // collapse only carry meaning when each `yᵢ` is a 0/1 outcome (or a
371 // proportion of Bernoulli trials). The Beta family also fits through the
372 // logit link, but its response is *continuous* on (0, 1): a perfectly
373 // healthy monotone mean (μ increasing in a covariate ⇒ rows with y > 0.5
374 // sit at higher η than rows with y ≤ 0.5) trivially satisfies the
375 // `order_separated` test, so gating this detector on the logit link alone
376 // misclassifies well-behaved Beta fits as separated and forces a spurious
377 // inner-solve retreat at every smoothing-parameter seed (issue #499).
378 // Gate strictly on the Binomial response so only binary GLMs are screened.
379 if !matches!(response, ResponseFamily::Binomial) || link != LinkFunction::Logit || firth_active
380 {
381 return false;
382 }
383
384 // Separation-detection policy thresholds. Each is a heuristic cut-off, not
385 // a math identity: they decide when a binary-logit fit has drifted into the
386 // perfect/quasi-perfect separation regime and the inner solve must retreat.
387 //
388 // `ORDER_SEPARATION_ETA_GAP`: a strictly positive η-gap between the lowest
389 // η among y=1 rows and the highest among y=0 rows means the two classes
390 // are linearly separable on the linear predictor.
391 // `EXTREME_ETA`: |η| this large drives μ to within machine-ε of {0,1}.
392 // `SATURATION_FRACTION` / `SEVERE_SATURATION_FRACTION`: share of fitted μ
393 // pinned to the {0,1} boundary that flags (severe) saturation.
394 // `DEGENERATE_DEVIANCE_PER_SAMPLE` / `EXTREME_DEGENERATE_DEVIANCE_PER_SAMPLE`:
395 // near-zero per-sample deviance means the model fits the data perfectly.
396 // `EXTREME_BETA_NORM`: coefficient norm blow-up characteristic of the MLE
397 // escaping to infinity under separation.
398 // `WEIGHT_COLLAPSE_FRACTION`: share of working weights collapsed to ~0.
399 const ORDER_SEPARATION_ETA_GAP: f64 = 1e-3;
400 const EXTREME_ETA: f64 = 30.0;
401 const SATURATION_FRACTION: f64 = 0.98;
402 const SEVERE_SATURATION_FRACTION: f64 = 0.995;
403 const DEGENERATE_DEVIANCE_PER_SAMPLE: f64 = 1e-3;
404 const EXTREME_DEGENERATE_DEVIANCE_PER_SAMPLE: f64 = 1e-6;
405 const EXTREME_BETA_NORM: f64 = 1e4;
406 const WEIGHT_COLLAPSE_FRACTION: f64 = 0.98;
407
408 let n = y.len() as f64;
409 if n == 0.0 {
410 return false;
411 }
412
413 let max_abs_eta = summary.max_abs_eta;
414 let sat_fraction = {
415 const SAT_EPS: f64 = 1e-3;
416 finalmu
417 .iter()
418 .filter(|&&m| m <= SAT_EPS || m >= 1.0 - SAT_EPS)
419 .count() as f64
420 / n
421 };
422
423 let weight_collapse_fraction = {
424 const WEIGHT_EPS: f64 = 1e-8;
425 finalweights
426 .iter()
427 .filter(|&&w| w <= WEIGHT_EPS || !w.is_finite())
428 .count() as f64
429 / n
430 };
431
432 let beta_norm = summary.beta.as_ref().dot(summary.beta.as_ref()).sqrt();
433 let dev_per_sample = summary.state.deviance / n;
434
435 let mut has_pos = false;
436 let mut has_neg = false;
437 let mut min_eta_pos = f64::INFINITY;
438 let mut max_eta_neg = f64::NEG_INFINITY;
439 for (eta_i, &yi) in summary.state.eta.iter().zip(y.iter()) {
440 if yi > 0.5 {
441 has_pos = true;
442 if *eta_i < min_eta_pos {
443 min_eta_pos = *eta_i;
444 }
445 } else {
446 has_neg = true;
447 if *eta_i > max_eta_neg {
448 max_eta_neg = *eta_i;
449 }
450 }
451 }
452 let order_separated =
453 has_pos && has_neg && (min_eta_pos - max_eta_neg) > ORDER_SEPARATION_ETA_GAP;
454
455 let classic_signals = max_abs_eta > EXTREME_ETA
456 || sat_fraction > SATURATION_FRACTION
457 || dev_per_sample < DEGENERATE_DEVIANCE_PER_SAMPLE
458 || beta_norm > EXTREME_BETA_NORM;
459
460 if !has_penalty {
461 return classic_signals || order_separated;
462 }
463
464 let severe_saturation = sat_fraction > SEVERE_SATURATION_FRACTION && max_abs_eta > EXTREME_ETA;
465 let weights_collapsed = weight_collapse_fraction > WEIGHT_COLLAPSE_FRACTION;
466 let dev_extremely_small = dev_per_sample < EXTREME_DEGENERATE_DEVIANCE_PER_SAMPLE;
467
468 order_separated || severe_saturation || weights_collapsed || dev_extremely_small
469}
470
471/// Stack λ-weighted penalty roots from canonical penalties into a single
472/// `total_rank × p` matrix for PIRLS. Each block-local root is embedded
473/// into the full column space on-the-fly.
474pub(super) fn stack_lambdaweighted_penalty_root_canonical(
475 penalties: &[gam_terms::construction::CanonicalPenalty],
476 lambdas: &[f64],
477 p: usize,
478) -> Array2<f64> {
479 let totalrows: usize = penalties.iter().map(|cp| cp.rank()).sum();
480 if totalrows == 0 {
481 return Array2::zeros((0, p));
482 }
483 let mut e = Array2::<f64>::zeros((totalrows, p));
484 let mut row_start = 0usize;
485 for (k, cp) in penalties.iter().enumerate() {
486 let rows = cp.rank();
487 if rows == 0 {
488 continue;
489 }
490 let scale = lambdas.get(k).copied().unwrap_or(0.0).max(0.0).sqrt();
491 if scale != 0.0 {
492 // Embed block-local root (rank × block_dim) into full width (rank × p).
493 let r = &cp.col_range;
494 for row in 0..rows {
495 for col in 0..cp.block_dim() {
496 e[[row_start + row, r.start + col]] = scale * cp.root[[row, col]];
497 }
498 }
499 }
500 row_start += rows;
501 }
502 e
503}
504
505pub(super) fn build_sparse_native_reparam_result(
506 base: ReparamResult,
507 penalties: &[gam_terms::construction::CanonicalPenalty],
508 lambdas: &[f64],
509 p: usize,
510) -> ReparamResult {
511 // Map the engine penalty back into identity (original) coordinates. The
512 // engine returns `s_transformed = Qsᵀ S Qs` (and `e_transformed = E Qs`)
513 // with `S = S_λ + shrinkage·P_range` already folded in (so it matches the
514 // reported `log_det`/`det1`). With the sparse-native `qs = I` we need that
515 // SAME penalty expressed in original coordinates: `S_orig = Qs S_transformed
516 // Qsᵀ`. Rebuilding `S_orig` from the bare lambda-weighted canonical sum
517 // would DROP the shrinkage ridge and desync the inner penalized Hessian from
518 // the penalty log-determinant the REML criterion uses for this fit — the
519 // cross-backend λ-selection divergence (#1266 class). Round-tripping the
520 // engine penalty through `Qs` keeps the inner solve, EDF, and REML logdet on
521 // one penalty.
522 let qs = &base.qs;
523 let s_orig = if qs.nrows() == p && qs.ncols() == base.s_transformed.nrows() {
524 // S_orig = Qs · S_transformed · Qsᵀ
525 let qs_s = fast_ab(qs, &base.s_transformed);
526 qs_s.dot(&qs.t())
527 } else {
528 // Degenerate fallback (engine produced no transform): use the bare
529 // lambda-weighted sum. Shrinkage is zero in this branch by construction.
530 let mut s_original = Array2::<f64>::zeros((p, p));
531 for (k, cp) in penalties.iter().enumerate() {
532 let lambda_k = lambdas.get(k).copied().unwrap_or(0.0);
533 if lambda_k != 0.0 {
534 cp.accumulate_weighted(&mut s_original, lambda_k);
535 }
536 }
537 s_original
538 };
539 // E_orig = E_transformed · Qsᵀ (so that E_origᵀ E_orig = S_orig and the EDF
540 // augmented system matches the inner Hessian).
541 let e_orig = if qs.nrows() == p && base.e_transformed.ncols() == qs.ncols() {
542 base.e_transformed.dot(&qs.t())
543 } else {
544 stack_lambdaweighted_penalty_root_canonical(penalties, lambdas, p)
545 };
546 let u_original = if base.u_truncated.nrows() == p {
547 fast_ab(&base.qs, &base.u_truncated)
548 } else {
549 Array2::<f64>::eye(p)
550 };
551 // In the sparse-native path, qs = I, so the penalties are already in the
552 // right coordinate frame. We keep them as-is in canonical_transformed.
553 let canonical_transformed: Vec<gam_terms::construction::CanonicalPenalty> = penalties.to_vec();
554 ReparamResult {
555 penalty_shrinkage_ridge: base.penalty_shrinkage_ridge,
556 s_transformed: s_orig,
557 log_det: base.log_det,
558 det1: base.det1,
559 qs: Array2::<f64>::eye(p),
560 canonical_transformed,
561 e_transformed: e_orig,
562 u_truncated: u_original,
563 }
564}
565
566pub(super) fn build_diagonal_penalty_from_kronecker(
567 kron_result: &KroneckerReparamResult,
568 lambdas: &[f64],
569) -> PirlsPenalty {
570 let d = kron_result.marginal_dims.len();
571 let p: usize = kron_result.marginal_dims.iter().copied().product();
572 let mut diag = Array1::<f64>::zeros(p);
573 let mut positive_indices = Vec::new();
574
575 const KRONECKER_STRUCTURAL_ZERO_TOL: f64 = 1e-12;
576 let mut multi_idx = vec![0usize; d];
577 let mut flat = 0usize;
578 loop {
579 let mut sigma = 0.0;
580 let mut structural_sigma = 0.0;
581 for k in 0..d {
582 let marginal_eigenvalue = kron_result.marginal_eigenvalues[k][multi_idx[k]];
583 structural_sigma += marginal_eigenvalue;
584 sigma += lambdas[k] * marginal_eigenvalue;
585 }
586 let joint_null = structural_sigma <= KRONECKER_STRUCTURAL_ZERO_TOL;
587 if kron_result.has_double_penalty && lambdas.len() > d && joint_null {
588 sigma += lambdas[d];
589 }
590 if structural_sigma > KRONECKER_STRUCTURAL_ZERO_TOL {
591 sigma += kron_result.penalty_shrinkage_ridge;
592 }
593 diag[flat] = sigma;
594 if sigma > 0.0 {
595 positive_indices.push(flat);
596 }
597 flat += 1;
598
599 let mut carry = true;
600 for dim in (0..d).rev() {
601 if carry {
602 multi_idx[dim] += 1;
603 if multi_idx[dim] < kron_result.marginal_dims[dim] {
604 carry = false;
605 } else {
606 multi_idx[dim] = 0;
607 }
608 }
609 }
610 if carry {
611 break;
612 }
613 }
614
615 PirlsPenalty::Diagonal {
616 diag,
617 positive_indices,
618 linear_shift: Array1::zeros(p),
619 constant_shift: 0.0,
620 prior_mean_target: Array1::zeros(p),
621 }
622}
623
624pub(super) fn canonical_prior_shift(
625 penalties: &[gam_terms::construction::CanonicalPenalty],
626 lambdas: &[f64],
627 p: usize,
628) -> (Array1<f64>, f64) {
629 let mut linear = Array1::<f64>::zeros(p);
630 let mut constant = 0.0;
631 for (idx, cp) in penalties.iter().enumerate() {
632 let Some(&lambda) = lambdas.get(idx) else {
633 continue;
634 };
635 if lambda == 0.0 {
636 continue;
637 }
638 linear += &cp.prior_linear_shift(lambda);
639 constant += cp.prior_constant_shift(lambda);
640 }
641 (linear, constant)
642}
643
644/// Aggregate prior-mean target across canonical penalty blocks: the sum of
645/// each block's `full_width_prior_mean()`. Used by the PIRLS solve sites
646/// that add a fixed stabilization ridge `δI` to the penalized Hessian — they
647/// must also add `δ · prior_mean_target` to the RHS to keep `β = μ` recovery
648/// exact when the data carries no information (X'WX = 0). Equivalent to
649/// `canonical_prior_shift` with all λ = 1 and dropping `S_k` from the linear
650/// piece (i.e., raw μ rather than `S_k μ`). Returned in the *original*
651/// coordinates; callers transform if needed.
652pub(super) fn canonical_prior_mean_aggregate(
653 penalties: &[gam_terms::construction::CanonicalPenalty],
654 p: usize,
655) -> Array1<f64> {
656 let mut mean = Array1::<f64>::zeros(p);
657 for cp in penalties {
658 mean += &cp.full_width_prior_mean();
659 }
660 mean
661}
662
/// Borrowed data for one fixed-ρ PIRLS fit: design, offset, response, prior
/// weights, and optional precomputed Gram caches.
///
/// `X` is the design in whatever form the caller holds it. The fit entry
/// points require `X: Into<DesignMatrix> + Clone` and convert it with `into()`.
pub struct PirlsProblem<'a, X> {
    /// Design matrix, converted to a `DesignMatrix` by the fit entry points.
    pub x: X,
    /// Fixed offset. The Gaussian cache below is defined in terms of
    /// `y − offset`. NOTE(review): presumably one entry per observation; confirm.
    pub offset: ArrayView1<'a, f64>,
    /// Response vector.
    pub y: ArrayView1<'a, f64>,
    /// Prior (case) weights.
    pub priorweights: ArrayView1<'a, f64>,
    /// Optional per-observation covariate standard errors.
    /// NOTE(review): not consumed in the visible part of this file; confirm
    /// semantics at the consumer.
    pub covariate_se: Option<ArrayView1<'a, f64>>,
    /// When set, the inner PLS solver reuses the precomputed `XᵀWX` and
    /// `XᵀW(y − offset)` in *original* coordinates instead of streaming the
    /// O(N·p²) GEMM and the O(N·p) matvec on every outer REML iteration.
    ///
    /// Valid only when all of the following hold:
    /// - the family is Gaussian with the Identity link;
    /// - prior weights are constant across outer iterations (always true in
    ///   the REML outer loop);
    /// - there is no Firth bias reduction;
    /// - there are no inequality or lower-bound constraints.
    ///
    /// These are the same conditions as the existing Gaussian/Identity
    /// short-circuit in the PIRLS solve path. The penalty `λ·S` is still
    /// added per-λ on top of the cached `XᵀWX`.
    pub gaussian_fixed_cache: Option<&'a GaussianFixedCache>,
    /// Frozen-weight first-Fisher-step data-fit Gram `XᵀWX` for a GLM
    /// design-moving ψ-trial (#1111 / #1033 mechanism (c)), in *original*
    /// (conditioned `x_fit`) coordinates.
    ///
    /// When set, the iterative GLM P-IRLS serves its FIRST Fisher-scoring
    /// iteration's `XᵀWX` from this matrix instead of streaming the O(N·p²)
    /// weighted cross-product. Every later iteration restreams the true
    /// moving `W`, so the converged β̂ is unchanged.
    ///
    /// This is distinct from `gaussian_fixed_cache`, which is the
    /// Gaussian-identity converged-objective short-circuit. This field is the
    /// GLM first-step lane and never short-circuits the iteration count.
    pub glm_first_step_gram: Option<&'a Array2<f64>>,
}
691
692// GaussianFixedCache is defined in pls_solver.
693pub use super::pls_solver::GaussianFixedCache;
694
/// Penalty-side inputs for a fixed-ρ PIRLS fit: the canonical penalties plus
/// optional caches, bounds, constraints and structure hints.
pub struct PenaltyConfig<'a> {
    /// Block-local canonical penalties with precomputed roots and spectral
    /// data.
    ///
    /// This is the single canonical penalty representation: no full-width
    /// `rank × p` roots are stored. When the reparameterization engine needs
    /// full-width roots, they are derived on the fly from these block-local
    /// roots.
    pub canonical_penalties: &'a [gam_terms::construction::CanonicalPenalty],
    /// Precomputed balanced penalty root. When `None`, the fit builds one
    /// from `canonical_penalties` with
    /// `create_balanced_penalty_root_from_canonical`.
    pub balanced_penalty_root: Option<&'a Array2<f64>>,
    /// Optional precomputed reparameterization invariant.
    /// NOTE(review): not consumed in the visible part of the fit driver;
    /// confirm how it is used.
    pub reparam_invariant: Option<&'a gam_terms::construction::ReparamInvariant>,
    /// Number of coefficients, i.e. the width of the full coefficient space
    /// (the penalty sum is built as a `p × p` matrix).
    pub p: usize,
    /// Optional per-coefficient lower bounds. They are transformed into the
    /// reparameterized frame before fitting.
    pub coefficient_lower_bounds: Option<&'a Array1<f64>>,
    /// Optional linear inequality constraints in original coordinates. They
    /// are transformed into the reparameterized frame and merged with the
    /// lower-bound constraints.
    pub linear_constraints_original: Option<&'a LinearInequalityConstraints>,
    /// Relative shrinkage floor for eigenvalues of the penalized block.
    ///
    /// If `Some(epsilon)`, a ρ-independent ridge of
    /// `epsilon * max_balanced_eigenvalue` is added. This keeps
    /// barely-penalized directions from causing pathological non-Gaussianity
    /// in the posterior. Typical value: `1e-6`. `None` disables the floor.
    pub penalty_shrinkage_floor: Option<f64>,
    /// When set, the penalties have Kronecker (tensor-product) structure. The
    /// reparameterization engine then uses a factored `Qs = U_1 ⊗ ... ⊗ U_d`
    /// instead of eigendecomposing the full p×p balanced penalty.
    pub kronecker_factored: Option<&'a gam_terms::basis::KroneckerFactoredBasis>,
}
716
/// P-IRLS solver modelled on mgcv's architecture.
///
/// Fits a GAM with a fixed set of smoothing parameters, following the
/// structure of mgcv's `gam.fit3`:
///
/// - Perform the stable reparameterization ONCE at the beginning (mgcv's
///   `gam.reparam`).
/// - Transform the design matrix into this stable basis.
/// - Extract a single penalty square root from the transformed penalty.
/// - Run the P-IRLS loop entirely in the transformed basis.
/// - Transform the coefficients back to the original basis only when
///   returning.
/// - Reuse a cached balanced penalty root when available, to avoid repeated
///   eigendecompositions.
///
/// Working in this well-conditioned parameter space keeps the whole fitting
/// process numerically stable.
///
/// This is the plain entry point. It forwards to
/// `fit_model_for_fixed_rho_with_adaptive_kkt` with no adaptive KKT tolerance
/// (`None`) and with converged-η dispersion refinement disabled (`false`).
/// That is the configuration used for REML cost / sigma-point evaluations,
/// as opposed to the single final reported fit.
///
/// # Errors
/// Propagates any `EstimationError` returned by
/// `fit_model_for_fixed_rho_with_adaptive_kkt`.
pub fn fit_model_for_fixed_rho<'a, X: Into<DesignMatrix> + Clone>(
    rho: LogSmoothingParamsView<'_>,
    problem: PirlsProblem<'a, X>,
    penalty: PenaltyConfig<'_>,
    config: &PirlsConfig,
    warm_start_beta: Option<&Coefficients>,
) -> Result<(PirlsResult, WorkingModelPirlsResult), EstimationError> {
    fit_model_for_fixed_rho_with_adaptive_kkt(
        rho,
        problem,
        penalty,
        config,
        warm_start_beta,
        // No adaptive KKT tolerance.
        None,
        // Do not refine the dispersion at the converged η.
        false,
    )
}
748
749/// `refine_dispersion_at_converged_eta`: when `true`, after the inner P-IRLS
750/// solve converges, re-estimate the family's estimated dispersion nuisance — the
751/// Gamma shape ν = 1/φ or the Beta precision φ — at the *converged* linear
752/// predictor and iterate the (β, dispersion) pair to its joint fixed point at the
753/// current λ (see the in-body comments at each refresh loop). This is ON only for
754/// the single final, reported fit at the REML-selected λ (#678 for Gamma, #769
755/// for Beta). It is deliberately OFF for every REML cost / sigma-point evaluation:
756/// re-profiling the dispersion against each trial λ's converged residuals would
757/// couple the scale to the smoothing parameter (a flat over-smoothed μ inflates
758/// the deviance ⇒ a smaller effective precision ⇒ a smaller `deviance/(2φ)` REML
759/// term), perversely rewarding over-smoothing and biasing λ selection. mgcv
760/// likewise estimates the scale at the converged fit, not inside the λ search.
761///
762/// The Gamma and Beta cases differ in what the re-solve buys. For Gamma the shape
763/// is a pure nuisance — β̂ is essentially scale-free — so the re-solve only keeps
764/// the reported dispersion and SEs self-consistent. For Beta the precision φ
765/// enters the *mean* score through the digamma terms
766/// `μ*ᵢ = ψ(μᵢφ) − ψ((1−μᵢ)φ)`, so a φ measured at the cold null predictor
767/// (μ ≈ 0.5) attenuates every slope toward zero; here the fixed point is
768/// load-bearing — it is what recovers the correct mean coefficients (the betareg
769/// alternating mean-fit ↔ φ-estimate scheme).
770pub(crate) fn fit_model_for_fixed_rho_with_adaptive_kkt<'a, X: Into<DesignMatrix> + Clone>(
771 rho: LogSmoothingParamsView<'_>,
772 problem: PirlsProblem<'a, X>,
773 penalty: PenaltyConfig<'_>,
774 config: &PirlsConfig,
775 warm_start_beta: Option<&Coefficients>,
776 adaptive_kkt_tolerance: Option<AdaptiveKktTolerance>,
777 refine_dispersion_at_converged_eta: bool,
778) -> Result<(PirlsResult, WorkingModelPirlsResult), EstimationError> {
779 let PirlsProblem {
780 x,
781 offset,
782 y,
783 priorweights,
784 covariate_se,
785 gaussian_fixed_cache,
786 glm_first_step_gram,
787 } = problem;
788 let quadctx = crate::quadrature::QuadratureContext::new();
789 // gam#1379 — finite-ceiling λ = exp(ρ). When the outer REML / spatial-κ
790 // optimizer drives a redundant penalty direction's log-λ past ~709 (it does
791 // so deterministically on 1-D `matern(x)` / `bs="gp"` data whose kernel
792 // already controls the smoothness an operator block also penalizes, so REML
793 // wants λ → ∞), `exp(ρ)` overflows to `+∞`. A literal `+∞` λ then poisons
794 // every downstream consumer that forms `λ · S`: the range-penalty block
795 // assembled as `Σ λ_k S_k` hits `∞ · 0 = NaN` and the eigensolve aborts, and
796 // the final fit-result validation rejects the non-finite stored λ outright.
797 // `exp(709.78) ≈ 1.8e308` is already the largest finite f64; capping log-λ at
798 // a value whose `exp` stays finite pins the over-penalized direction exactly
799 // as hard as `+∞` would for every finite-arithmetic consumer (the penalized
800 // block is numerically a hard constraint at λ this large) while keeping
801 // `λ · 0 = 0`. Ordinary finite λ are untouched, so non-degenerate fits and
802 // their recorded λ̂ are bit-identical. `ln(1e300) ≈ 690.78` keeps this in lock
803 // step with the post-exp λ ceiling (`1e300`) used by the reparam range-block
804 // assembly and the stored fit result, so a fully-smoothed direction carries
805 // the SAME finite λ everywhere it is consumed.
806 const LOG_LAMBDA_CEILING: f64 = 690.0;
807 let lambdas = rho.mapv(|r| {
808 if r.is_nan() {
809 r
810 } else {
811 r.min(LOG_LAMBDA_CEILING).exp()
812 }
813 });
814 let lambdas_slice = lambdas.as_slice_memory_order().ok_or_else(|| {
815 EstimationError::InvalidInput("non-contiguous lambda storage".to_string())
816 })?;
817
818 let likelihood = &config.likelihood;
819 let link_function = config.link_function();
820
821 use gam_terms::construction::{
822 EngineDims, create_balanced_penalty_root_from_canonical,
823 stable_reparameterization_engine_canonical,
824 };
825
826 let eb_cow: Cow<'_, Array2<f64>> = if let Some(precomputed) = penalty.balanced_penalty_root {
827 Cow::Borrowed(precomputed)
828 } else {
829 Cow::Owned(create_balanced_penalty_root_from_canonical(
830 penalty.canonical_penalties,
831 penalty.p,
832 )?)
833 };
834 let eb: &Array2<f64> = eb_cow.as_ref();
835
836 // Build a cheap weighted penalty sum for the sparse-native decision
837 // WITHOUT running the expensive eigendecomposition engine.
838 // The full reparameterization is deferred until we know which path we need.
839 let cheap_s_lambda: Option<Array2<f64>> = if penalty.kronecker_factored.is_none() {
840 let mut s = Array2::<f64>::zeros((penalty.p, penalty.p));
841 for (k, cp) in penalty.canonical_penalties.iter().enumerate() {
842 let lam = lambdas_slice.get(k).copied().unwrap_or(0.0);
843 if lam != 0.0 {
844 cp.accumulate_weighted(&mut s, lam);
845 }
846 }
847 Some(s)
848 } else {
849 None
850 };
851 let kronecker_runtime = if let Some(kron) = penalty.kronecker_factored {
852 // The marginal eigensystems and reparameterized marginals depend only on
853 // the fixed marginal designs/penalties, not on λ = exp(ρ). Memoize them
854 // once per fit so each outer REML iterate reuses the eigendecomposition
855 // instead of recomputing `eigh()` + `B_k·U_k` every call; only the cheap
856 // λ-grid logdet/derivative sweep is redone here. Bit-identical to the
857 // unmemoized engine.
858 let invariant = kron.invariant_structure()?;
859 let kron_result =
860 gam_terms::construction::kronecker_reparameterization_engine_with_invariant(
861 invariant.as_ref(),
862 &kron.marginal_dims,
863 lambdas_slice,
864 kron.has_double_penalty,
865 penalty.penalty_shrinkage_floor,
866 )?;
867 let transform = Arc::new(KroneckerQsTransform::new(&kron_result));
868 let penalty_diag = build_diagonal_penalty_from_kronecker(&kron_result, lambdas_slice);
869 Some((kron_result, transform, penalty_diag))
870 } else {
871 None
872 };
873 // Constraint transformation is deferred until after the sparse-native
874 // decision, because the dense reparameterization engine (which provides Qs)
875 // is now run lazily. Kronecker constraints can be built eagerly since
876 // the Kronecker transform is already available.
877 let kronecker_constraints = if let Some((_, transform, _)) = kronecker_runtime.as_ref() {
878 let tb = build_transformed_lower_bound_constraints_with_transform(
879 &WorkingReparamTransform::Kronecker(Arc::clone(transform)),
880 penalty.coefficient_lower_bounds,
881 );
882 let tl = build_transformed_linear_constraints_with_transform(
883 &WorkingReparamTransform::Kronecker(Arc::clone(transform)),
884 penalty.linear_constraints_original,
885 );
886 Some(merge_linear_constraints(tb, tl))
887 } else {
888 None
889 };
890
891 let x_original: DesignMatrix = x.into();
892 // Auto-detect sparse structure in dense designs so the sparse-native path
893 // can engage for structurally sparse models that happen to be stored dense.
894 let x_original = {
895 let auto_sparse = x_original
896 .as_dense()
897 .and_then(|dense| sparse_from_denseview(dense.view()));
898 auto_sparse.unwrap_or(x_original)
899 };
900 let ebrows = eb.nrows();
901 let erows = if let Some((_, _, penalty_diag)) = kronecker_runtime.as_ref() {
902 penalty_diag.rank()
903 } else {
904 // Compute penalty root rank cheaply from canonical penalties.
905 penalty
906 .canonical_penalties
907 .iter()
908 .map(|cp| cp.rank())
909 .sum::<usize>()
910 };
911 let mut workspace = PirlsWorkspace::new(x_original.nrows(), x_original.ncols(), ebrows, erows);
912 let solver_decision = if let Some((_, _, _)) = kronecker_runtime.as_ref() {
913 SparsePirlsDecision {
914 path: PirlsLinearSolvePath::DenseTransformed,
915 reason: "kronecker_runtime",
916 p: x_original.ncols(),
917 nnz_x: 0,
918 nnz_xtwx_symbolic: None,
919 nnz_s_lambda: 0,
920 nnz_h_est: None,
921 density_h_est: None,
922 }
923 } else {
924 should_use_sparse_native_pirls(
925 &mut workspace,
926 &x_original,
927 cheap_s_lambda
928 .as_ref()
929 .expect("cheap_s_lambda should be present outside Kronecker path"),
930 penalty.coefficient_lower_bounds,
931 penalty.linear_constraints_original,
932 )
933 };
934 solver_decision.log_once();
935
936 let use_sparse_native = matches!(solver_decision.path, PirlsLinearSolvePath::SparseNative);
937
938 // Run the eigendecomposition engine for the dense-transformed path. The
939 // sparse-native path also needs it, but only to obtain a penalty that is
940 // *consistent with the REML penalty log-determinant it reports* — see the
941 // sparse-native `reparam` below. The dense path keeps `qs ≠ I`; the
942 // sparse-native path discards `qs` (identity coords) and reuses only the
943 // shrinkage-folded `s_transformed`/`e_transformed`.
944 let dense_reparam_result = if !use_sparse_native && penalty.kronecker_factored.is_none() {
945 Some(stable_reparameterization_engine_canonical(
946 penalty.canonical_penalties,
947 lambdas_slice,
948 EngineDims::new(penalty.p, penalty.canonical_penalties.len()),
949 penalty.reparam_invariant,
950 penalty.penalty_shrinkage_floor,
951 )?)
952 } else {
953 None
954 };
955 // Sparse-native reparam result, in identity (original) coordinates with the
956 // penalty shrinkage floor folded in. This MUST drive the inner penalized
957 // solve too: when `penalty_shrinkage_floor` is active (default `Some(1e-6)`)
958 // the dense engine adds `shrinkage·P_range` to every penalized range
959 // direction of `S_λ` and rebuilds `s_transformed = EᵀE` from the floored
960 // roots, so `base.log_det` (the REML penalty pseudo-logdet) is the
961 // determinant of `S_λ + shrinkage·P_range`, NOT of the bare `S_λ`. Building
962 // the inner Hessian from an UN-shrunk `S_λ` (the previous behaviour, via the
963 // `cheap_s_lambda` row-sum) while reporting the shrunk `log_det` made the
964 // sparse-native REML surface internally inconsistent — the penalty-logdet
965 // term and the inner H / EDF / β̂ lived on different penalties — which biased
966 // λ-selection relative to the dense and Kronecker backends for the SAME
967 // model (the #1266 cross-backend divergence class). Reusing the engine's
968 // shrinkage-folded penalty here makes all three backends solve the same
969 // penalized objective.
970 let sparse_native_reparam = if use_sparse_native && penalty.kronecker_factored.is_none() {
971 let base = stable_reparameterization_engine_canonical(
972 penalty.canonical_penalties,
973 lambdas_slice,
974 EngineDims::new(penalty.p, penalty.canonical_penalties.len()),
975 penalty.reparam_invariant,
976 penalty.penalty_shrinkage_floor,
977 )?;
978 Some(build_sparse_native_reparam_result(
979 base,
980 penalty.canonical_penalties,
981 lambdas_slice,
982 penalty.p,
983 ))
984 } else {
985 None
986 };
987 let qs_arc = dense_reparam_result
988 .as_ref()
989 .map(|reparam_result| Arc::new(reparam_result.qs.clone()));
990 let transform_active = if let Some((_, transform, _)) = kronecker_runtime.as_ref() {
991 Some(WorkingReparamTransform::Kronecker(Arc::clone(transform)))
992 } else if use_sparse_native {
993 None
994 } else {
995 Some(WorkingReparamTransform::Dense(Arc::clone(
996 qs_arc
997 .as_ref()
998 .expect("dense Qs should exist for non-Kronecker transformed path"),
999 )))
1000 };
1001 let mut penalty_active = if let Some((_, _, penalty_diag)) = kronecker_runtime.as_ref() {
1002 penalty_diag.clone()
1003 } else if use_sparse_native {
1004 // Sparse-native inner penalty in original (identity) coordinates. Use
1005 // the shrinkage-folded `s_transformed`/`e_transformed` from
1006 // `sparse_native_reparam` so the inner penalized Hessian
1007 // `H = XᵀWX + S` matches the penalty whose log-determinant the REML
1008 // criterion reports for this fit (`base.log_det`). Falling back to the
1009 // bare lambda-weighted sum here (the prior behaviour) omitted the
1010 // `penalty_shrinkage_floor` ridge and desynced the inner solve from the
1011 // REML logdet, biasing λ-selection vs the dense/Kronecker backends.
1012 let sparse_reparam = sparse_native_reparam
1013 .as_ref()
1014 .expect("sparse_native_reparam should be present for sparse-native path");
1015 PirlsPenalty::Dense {
1016 s_transformed: sparse_reparam.s_transformed.clone(),
1017 e_transformed: sparse_reparam.e_transformed.clone(),
1018 linear_shift: Array1::zeros(penalty.p),
1019 constant_shift: 0.0,
1020 prior_mean_target: Array1::zeros(penalty.p),
1021 }
1022 } else {
1023 let dense = dense_reparam_result
1024 .as_ref()
1025 .expect("dense reparam result should be present outside Kronecker path");
1026 PirlsPenalty::Dense {
1027 s_transformed: dense.s_transformed.clone(),
1028 e_transformed: dense.e_transformed.clone(),
1029 linear_shift: Array1::zeros(penalty.p),
1030 constant_shift: 0.0,
1031 prior_mean_target: Array1::zeros(penalty.p),
1032 }
1033 };
1034 let (shift_original, shift_constant) =
1035 canonical_prior_shift(penalty.canonical_penalties, lambdas_slice, penalty.p);
1036 let shift_active = transform_active
1037 .as_ref()
1038 .map(|transform| transform.apply_transpose(&shift_original))
1039 .unwrap_or(shift_original);
1040 let prior_mean_original =
1041 canonical_prior_mean_aggregate(penalty.canonical_penalties, penalty.p);
1042 let prior_mean_active = transform_active
1043 .as_ref()
1044 .map(|transform| transform.apply_transpose(&prior_mean_original))
1045 .unwrap_or(prior_mean_original);
1046 attach_penalty_shift(
1047 &mut penalty_active,
1048 shift_active,
1049 shift_constant,
1050 prior_mean_active,
1051 );
1052 // Build transformed constraints now that dense_reparam_result is available.
1053 let linear_constraints = if let Some(kc) = kronecker_constraints {
1054 kc
1055 } else if let Some(reparam) = dense_reparam_result.as_ref() {
1056 let tb = build_transformed_lower_bound_constraints(
1057 &reparam.qs,
1058 penalty.coefficient_lower_bounds,
1059 );
1060 let tl =
1061 build_transformed_linear_constraints(&reparam.qs, penalty.linear_constraints_original);
1062 merge_linear_constraints(tb, tl)
1063 } else {
1064 // Sparse-native without dense reparam: constraints stay in original
1065 // coordinates (identity Qs). Use an identity matrix of appropriate size.
1066 let p = penalty.p;
1067 let qs_identity = Array2::<f64>::eye(p);
1068 let tb = build_transformed_lower_bound_constraints(
1069 &qs_identity,
1070 penalty.coefficient_lower_bounds,
1071 );
1072 let tl =
1073 build_transformed_linear_constraints(&qs_identity, penalty.linear_constraints_original);
1074 merge_linear_constraints(tb, tl)
1075 };
1076
1077 let coordinate_frame = if use_sparse_native {
1078 PirlsCoordinateFrame::OriginalSparseNative
1079 } else {
1080 PirlsCoordinateFrame::TransformedQs
1081 };
1082 let materialize_final_reparam_result = || -> Result<ReparamResult, EstimationError> {
1083 if let Some((kron_result, _, _)) = kronecker_runtime.as_ref() {
1084 let rs_list: Vec<Array2<f64>> = penalty
1085 .canonical_penalties
1086 .iter()
1087 .map(|cp| cp.full_width_root())
1088 .collect();
1089 kron_result.materialize_dense_artifact_result(&rs_list, lambdas_slice, penalty.p)
1090 } else if use_sparse_native {
1091 // Sparse-native path: reuse the engine result already computed for
1092 // `penalty_active` (with the shrinkage floor folded in and mapped to
1093 // identity coordinates). This is both correct — the REML
1094 // log-determinant now matches the penalty the inner solve used — and
1095 // cheaper, since the eigendecomposition is no longer run twice.
1096 Ok(sparse_native_reparam
1097 .as_ref()
1098 .expect("sparse_native_reparam should be present for sparse-native path")
1099 .clone())
1100 } else {
1101 Ok(dense_reparam_result
1102 .as_ref()
1103 .expect("dense reparam result should be present outside Kronecker path")
1104 .clone())
1105 }
1106 };
1107
1108 // Stage 3.3-GI: GPU exact PLS dispatch — see pirls_host_dispatch::try_gaussian_pls_gpu.
1109 if let Some(result) = try_gaussian_pls_gpu(
1110 link_function,
1111 config,
1112 penalty.coefficient_lower_bounds,
1113 penalty.linear_constraints_original,
1114 gaussian_fixed_cache,
1115 &penalty_active,
1116 &qs_arc,
1117 &x_original,
1118 use_sparse_native,
1119 penalty.p,
1120 || materialize_final_reparam_result(),
1121 y,
1122 priorweights,
1123 offset,
1124 coordinate_frame,
1125 &linear_constraints,
1126 ) {
1127 return result;
1128 }
1129
1130 if matches!(link_function, LinkFunction::Identity) && linear_constraints.is_none() {
1131 // Gaussian-Identity zero-iteration exact solve. The unconstrained
1132 // penalized least-squares system is linear, so for an identity link a
1133 // single solve is the exact minimizer and no PIRLS iteration is needed.
1134 //
1135 // This shortcut is only valid in the *unconstrained* convex program.
1136 // When shape/box/linear inequality constraints are present (e.g. a
1137 // `shape=monotone_increasing` smooth, whose cumulative-sum box-reparam
1138 // bounds `γ_j ≥ 0` are folded into `linear_constraints` above), the
1139 // minimizer is the solution of an inequality-constrained QP, not the
1140 // plain normal-equations solve. Taking this branch then returns the
1141 // unconstrained β, which generically violates the constraints and is
1142 // rejected by the REML startup KKT gate (`enforce_constraint_kkt`),
1143 // aborting the whole fit. Gating on `linear_constraints.is_none()`
1144 // routes every constrained Identity fit to the iterative loop below,
1145 // which builds a feasible initial point and solves the exact QP via
1146 // the active-set solver — mirroring the gate already enforced on the
1147 // GPU Gaussian-PLS path in `try_gaussian_pls_gpu`.
1148 //
1149 // Apply the Gaussian-Identity fixed-data cache only when every
1150 // precondition for the short-circuit's exact reuse holds: the family
1151 // really is Gaussian (z = y), there is no Firth bias-reduction term,
1152 // no coefficient lower bounds, and no linear inequality constraints
1153 // — anything that would change the right-hand side or the system
1154 // beyond the additive penalty would invalidate the cache.
1155 let cache_eligible = gaussian_fixed_cache.is_some()
1156 && likelihood.spec.is_gaussian_identity()
1157 && !config.firth_bias_reduction
1158 && penalty.coefficient_lower_bounds.is_none()
1159 && penalty.linear_constraints_original.is_none();
1160 let cache_for_solve = if cache_eligible {
1161 gaussian_fixed_cache
1162 } else {
1163 None
1164 };
1165 let (pls_result, _) = solve_penalized_least_squares_implicit(
1166 &x_original,
1167 transform_active.as_ref(),
1168 y,
1169 priorweights,
1170 offset,
1171 &penalty_active,
1172 &mut workspace,
1173 y,
1174 link_function,
1175 cache_for_solve,
1176 )?;
1177
1178 let beta_transformed = pls_result.beta;
1179 let penalized_hessian = pls_result.penalized_hessian;
1180 let edf = pls_result.edf;
1181 let baseridge = pls_result.ridge_used;
1182
1183 let priorweights_owned = priorweights.to_owned();
1184 // eta = offset + X Qs beta (composed, no materialization) unless a
1185 // design-moving ψ tensor cache explicitly says the surface rows are a
1186 // stale reference. In that lane the Gaussian objective and gradient are
1187 // fully determined by (G, r, y'Wy), so applying `x_original` would both
1188 // reintroduce per-trial row work and evaluate the wrong ψ.
1189 let qbeta = transform_active
1190 .as_ref()
1191 .map(|transform| transform.apply(beta_transformed.as_ref()))
1192 .unwrap_or_else(|| beta_transformed.as_ref().clone());
1193 let stale_row_cache = cache_for_solve.filter(|cache| cache.row_prediction_is_stale);
1194 let (final_eta, finalmu, finalz, gradient_data, deviance, log_likelihood, max_abs_eta) =
1195 if let Some(cache) = stale_row_cache {
1196 let final_eta = offset.to_owned();
1197 let finalmu = final_eta.clone();
1198 let finalz = y.to_owned();
1199 let mut grad_orig = cache.xtwx_orig.dot(&qbeta);
1200 grad_orig -= &cache.xtwy_orig;
1201 let gradient_data = transform_active
1202 .as_ref()
1203 .map(|transform| transform.apply_transpose(&grad_orig))
1204 .unwrap_or(grad_orig);
1205 let weighted_rss = (cache.centered_weighted_y_sq
1206 - 2.0 * qbeta.dot(&cache.xtwy_orig)
1207 + qbeta.dot(&cache.xtwx_orig.dot(&qbeta)))
1208 .max(0.0);
1209 let phi = likelihood.scale.fixed_phi().unwrap_or(1.0);
1210 let deviance = if phi.is_finite() && phi > 0.0 {
1211 weighted_rss / phi
1212 } else {
1213 f64::NAN
1214 };
1215 let log_likelihood = calculate_loglikelihood(y, &finalmu, likelihood, priorweights);
1216 let max_abs_eta = inf_norm(finalmu.iter().copied());
1217 (
1218 final_eta,
1219 finalmu,
1220 finalz,
1221 gradient_data,
1222 deviance,
1223 log_likelihood,
1224 max_abs_eta,
1225 )
1226 } else {
1227 let mut eta = offset.to_owned();
1228 eta += &x_original.apply(&qbeta);
1229 let final_eta = eta.clone();
1230 let finalmu = eta.clone();
1231 let finalz = y.to_owned();
1232
1233 let mut weighted_residual = finalmu.clone();
1234 weighted_residual -= &finalz;
1235 weighted_residual *= &priorweights_owned;
1236 // gradient = Qs^T X^T (w * residual) (composed)
1237 let xt_wr = x_original.apply_transpose(&weighted_residual);
1238 let gradient_data = transform_active
1239 .as_ref()
1240 .map(|transform| transform.apply_transpose(&xt_wr))
1241 .unwrap_or(xt_wr);
1242 let deviance = calculate_deviance(y, &finalmu, likelihood, priorweights);
1243 let log_likelihood = calculate_loglikelihood(y, &finalmu, likelihood, priorweights);
1244 let max_abs_eta = inf_norm(finalmu.iter().copied());
1245 (
1246 final_eta,
1247 finalmu,
1248 finalz,
1249 gradient_data,
1250 deviance,
1251 log_likelihood,
1252 max_abs_eta,
1253 )
1254 };
1255 let score_norm = array1_l2_norm(&gradient_data);
1256 let s_beta = penalty_active.shifted_gradient(beta_transformed.as_ref());
1257 let s_beta_norm = array1_l2_norm(&s_beta);
1258 let mut gradient = gradient_data;
1259 gradient += &s_beta;
1260 let mut penalty_term = penalty_active.shifted_quadratic(beta_transformed.as_ref());
1261 let ridge_used = baseridge;
1262 let stabilizedhessian = if ridge_used > 0.0 {
1263 penalized_hessian
1264 .addridge(ridge_used)
1265 .map_err(|e| EstimationError::InvalidInput(format!("ridge addition failed: {e}")))?
1266 } else {
1267 penalized_hessian.clone()
1268 };
1269 let mut ridge_grad_norm = 0.0;
1270 if ridge_used > 0.0 {
1271 let ridge_penalty =
1272 ridge_used * beta_transformed.as_ref().dot(beta_transformed.as_ref());
1273 penalty_term += ridge_penalty;
1274 gradient += &beta_transformed.as_ref().mapv(|v| ridge_used * v);
1275 ridge_grad_norm = ridge_used * array1_l2_norm(beta_transformed.as_ref());
1276 }
1277
1278 let gradient_norm = array1_l2_norm(&gradient);
1279 let working_state = WorkingState {
1280 eta: LinearPredictor::new(finalmu.clone()),
1281 gradient: gradient.clone(),
1282 hessian: penalized_hessian.clone(),
1283
1284 log_likelihood,
1285 deviance,
1286 penalty_term,
1287 firth: FirthDiagnostics::Inactive,
1288 ridge_used,
1289 hessian_curvature: HessianCurvatureKind::Fisher,
1290 gradient_natural_scale: score_norm + s_beta_norm + ridge_grad_norm,
1291 };
1292
1293 let zero_iter_penalized = deviance + penalty_term;
1294 let working_summary = WorkingModelPirlsResult {
1295 beta: beta_transformed.clone(),
1296 state: working_state,
1297 status: PirlsStatus::Converged,
1298 iterations: 1,
1299 lastgradient_norm: gradient_norm,
1300 last_deviance_change: 0.0,
1301 last_step_size: 1.0,
1302 last_step_halving: 0,
1303 max_abs_eta,
1304 constraint_kkt: linear_constraints.as_ref().map(|lin| {
1305 compute_constraint_kkt_diagnostics(beta_transformed.as_ref(), &gradient, lin)
1306 }),
1307 min_penalized_deviance: if zero_iter_penalized.is_finite() {
1308 zero_iter_penalized
1309 } else {
1310 f64::INFINITY
1311 },
1312 // Zero-iteration synthesis: no LM damping was exercised, so
1313 // hand the next solve the cold default.
1314 final_lm_lambda: 1e-6,
1315 // Zero-iteration synthesis: no LM gain ratio was measured.
1316 final_accept_rho: None,
1317 // Zero-iteration synthesis assembles the Hessian with prior
1318 // weights only; no observed-information re-evaluation has
1319 // happened. Label honestly as a Fisher-type surrogate so
1320 // outer Laplace consumers see the truth.
1321 exported_laplace_curvature: ExportedLaplaceCurvature::ExpectedInformationSurrogate,
1322 };
1323
1324 let (solve_c_array, solve_d_array, solve_dmu_deta, solve_d2mu_deta2, solve_d3mu_deta3) =
1325 computeworkingweight_derivatives_from_eta(
1326 &config.likelihood,
1327 &config.link_kind,
1328 &final_eta,
1329 priorweights_owned.view(),
1330 )?;
1331 let reparam_result = materialize_final_reparam_result()?;
1332 let qs_arc_final = Arc::new(reparam_result.qs.clone());
1333 let pirls_result = PirlsResult {
1334 likelihood: config.likelihood.clone(),
1335 beta_transformed,
1336 penalized_hessian_transformed: penalized_hessian,
1337 stabilizedhessian_transformed: stabilizedhessian,
1338 ridge_passport: RidgePassport::scaled_identity(
1339 ridge_used,
1340 RidgePolicy::explicit_stabilization_full(),
1341 ),
1342 ridge_used,
1343 deviance,
1344 edf,
1345 stable_penalty_term: penalty_term,
1346 firth: FirthDiagnostics::Inactive,
1347 finalweights: priorweights_owned.clone(),
1348 final_offset: offset.to_owned(),
1349 final_eta: final_eta.clone(),
1350 finalmu: finalmu.clone(),
1351 solveweights: priorweights_owned,
1352 solveworking_response: finalz.clone(),
1353 solvemu: finalmu.clone(),
1354 solve_dmu_deta,
1355 solve_d2mu_deta2,
1356 solve_d3mu_deta3,
1357 solve_c_array,
1358 solve_d_array,
1359 derivatives_unsupported: false,
1360 status: PirlsStatus::Converged,
1361 iteration: 1,
1362 max_abs_eta,
1363 lastgradient_norm: gradient_norm,
1364 gradient_natural_scale: score_norm + s_beta_norm + ridge_grad_norm,
1365 last_deviance_change: 0.0,
1366 last_step_halving: 0,
1367 hessian_curvature: HessianCurvatureKind::Fisher,
1368 exported_laplace_curvature: working_summary.exported_laplace_curvature.clone(),
1369 final_lm_lambda: working_summary.final_lm_lambda,
1370 final_accept_rho: working_summary.final_accept_rho,
1371 constraint_kkt: working_summary.constraint_kkt.clone(),
1372 linear_constraints_transformed: linear_constraints.clone(),
1373 reparam_result,
1374 x_transformed: make_reparam_operator(&x_original, &qs_arc_final, use_sparse_native),
1375 coordinate_frame,
1376 used_device: false,
1377 cache_compacted: false,
1378 min_penalized_deviance: working_summary.min_penalized_deviance,
1379 };
1380
1381 return Ok((pirls_result, working_summary));
1382 }
1383
1384 let x_original_for_result = x_original.clone();
1385 let mut working_model = GamWorkingModel::new(
1386 None, // No pre-materialized x_transformed: use implicit Qs composition
1387 x_original.clone(),
1388 coordinate_frame,
1389 offset,
1390 y,
1391 priorweights,
1392 penalty_active.clone(),
1393 workspace,
1394 config.likelihood.clone(),
1395 config.link_kind.clone(),
1396 // Inner Firth/Jeffreys activation must agree with the caller-requested
1397 // mode. The REML *outer* analytic derivative assembly only carries the
1398 // Jeffreys score/curvature term when `firth_bias_reduction` is set
1399 // (`reml_robust_jeffreys_link` returns `None` otherwise), so arming the
1400 // inner penalty unconditionally would converge the inner mode to the
1401 // Firth-penalized stationary point while the outer H/u/IFT stayed
1402 // non-Firth — the two would then disagree by exactly the Jeffreys
1403 // contribution (broken τ-τ Hessian-vs-FD and stationarity-cancellation
1404 // identities, #825). Gate on `firth_bias_reduction` so inner and outer
1405 // are the same objective.
1406 config.firth_bias_reduction
1407 && matches!(config.likelihood.spec.response, ResponseFamily::Binomial)
1408 && inverse_link_has_fisher_weight_jet(&config.link_kind),
1409 transform_active.clone(),
1410 quadctx,
1411 // #1111 / #1033 mechanism (c): frozen-W first-Fisher-step XᵀWX in the
1412 // original (conditioned x_fit) frame, served n-free on the first inner
1413 // iteration. Suppressed under Firth bias reduction, which shifts the
1414 // working response per iteration (the installer also gates Firth off).
1415 if config.firth_bias_reduction {
1416 None
1417 } else {
1418 glm_first_step_gram.cloned()
1419 },
1420 );
1421
1422 // Apply integrated (GHQ) likelihood if per-observation SE is provided.
1423 // This is used by the calibrator to coherently account for base prediction uncertainty.
1424 if let Some(se) = covariate_se {
1425 working_model = working_model.with_covariate_se(se.to_owned());
1426 }
1427
1428 let mut beta_guess_original = warm_start_beta
1429 .filter(|beta| beta.len() == penalty.p)
1430 .map(|beta| beta.to_owned())
1431 .unwrap_or_else(|| {
1432 Coefficients::new(default_beta_guess_external(
1433 penalty.p,
1434 link_function,
1435 y,
1436 priorweights,
1437 config.link_kind.mixture_state(),
1438 config.link_kind.sas_state(),
1439 ))
1440 });
1441 if let Some(lb) = penalty.coefficient_lower_bounds {
1442 project_coefficients_to_lower_bounds(&mut beta_guess_original.0, lb);
1443 }
1444 let initial_beta = transform_active
1445 .as_ref()
1446 .map(|transform| transform.apply_transpose(beta_guess_original.as_ref()))
1447 .unwrap_or_else(|| beta_guess_original.as_ref().clone());
1448 let initial_beta = if let Some(constraints) = linear_constraints.as_ref() {
1449 // Worst per-row *scaled* (geometric) slack of the current seed against the
1450 // constraint cone. Negative ⇒ the seed violates a row; ~0 ⇒ the seed sits
1451 // ON the boundary (for a homogeneous convex/concave second-difference
1452 // cone, `β = 0` — the unconstrained Gaussian seed — sits on EVERY row's
1453 // boundary, i.e. the cone vertex). Either way the seed must be pushed
1454 // strictly into the interior before P-IRLS starts.
1455 let mut min_scaled_slack = f64::INFINITY;
1456 for i in 0..constraints.a.nrows() {
1457 let norm = constraints.a.row(i).dot(&constraints.a.row(i)).sqrt();
1458 let inv = if norm > 0.0 { 1.0 / norm } else { 0.0 };
1459 let slack = (constraints.a.row(i).dot(&initial_beta) - constraints.b[i]) * inv;
1460 min_scaled_slack = min_scaled_slack.min(slack);
1461 }
1462 // Push the seed to the nearest STRICTLY-INTERIOR feasible point whenever
1463 // any row is tight or violated. A seed on the cone boundary (most acutely
1464 // the vertex `β = 0`) hands the inner active-set QP an all-rows-active
1465 // working set, where it stalls on a degenerate, non-stationary face — so
1466 // the fit silently diverges (or aborts in release) between a cold and a
1467 // warm warm-start cache (#873). A strictly-interior seed makes the QP's
1468 // initial active set empty; it then adds only the genuinely binding rows
1469 // and converges to the certified constrained optimum regardless of cache
1470 // state. The projection keeps the data-driven curvature of `initial_beta`
1471 // and falls back to the min-norm feasible point only if it cannot certify
1472 // a strictly-interior solution.
1473 //
1474 // The min-norm fallback (`feasible_point_for_linear_constraints`) is only
1475 // used for a NON-homogeneous cone (`b ≠ 0`), where it returns a genuine
1476 // interior-of-the-offset-polyhedron point. For a HOMOGENEOUS shape cone
1477 // (`b ≈ 0` — the convex/concave second-difference rows) that function
1478 // returns the minimum-norm feasible point `β = 0`, which is the cone
1479 // *vertex*: the exact all-rows-tight degenerate seed #873 is about. Taking
1480 // it would silently reintroduce the #873 pathology whenever the strict
1481 // projection rarely fails to certify. So for a homogeneous cone we skip the
1482 // vertex fallback entirely and prefer the data-driven `initial_beta`: it
1483 // violates at most *some* rows (a lower-dimensional, non-degenerate face the
1484 // inner active-set QP can recover from), strictly better than the vertex
1485 // where *every* row is simultaneously tight.
1486 let cone_is_homogeneous = constraints.b.iter().all(|v| v.abs() <= 1e-14);
1487 if min_scaled_slack < active_set::interior_seed_margin() {
1488 let projected =
1489 active_set::project_point_strictly_into_feasible_cone(&initial_beta, constraints)
1490 .or_else(|| {
1491 if cone_is_homogeneous {
1492 None
1493 } else {
1494 active_set::feasible_point_for_linear_constraints(
1495 constraints,
1496 initial_beta.len(),
1497 )
1498 }
1499 });
1500 projected.unwrap_or(initial_beta)
1501 } else {
1502 initial_beta
1503 }
1504 } else {
1505 initial_beta
1506 };
1507 // Inner P-IRLS Firth activation. The inner penalized objective must match
1508 // the objective the REML outer derivatives are assembled against: the outer
1509 // path carries the Jeffreys/Firth score+curvature only when the caller set
1510 // `firth_bias_reduction` (`reml_robust_jeffreys_link` is `None` otherwise),
1511 // so the inner Firth term is armed iff the caller requested it AND the link
1512 // exposes a Fisher-weight jet (#825). Forcing it on unconditionally desynced
1513 // the Firth-penalized inner mode from the non-Firth outer assembly.
1514 let firth_active = config.firth_bias_reduction
1515 && matches!(config.likelihood.spec.response, ResponseFamily::Binomial)
1516 && inverse_link_has_fisher_weight_jet(&config.link_kind);
1517 let base_max_step_halving = if firth_active { 60 } else { 30 };
1518 let options = WorkingModelPirlsOptions {
1519 // The Firth-penalized P-IRLS converges at the same iteration count as
1520 // the unpenalized fit — the Jeffreys term is a smooth, bounded addition
1521 // to a Newton system that is already well conditioned (the additional
1522 // per-iteration LM step-halving budget above absorbs the early-iteration
1523 // curvature change). Bumping the outer-iteration cap to mask a
1524 // mis-conditioned step would only hide non-convergence, so the cap stays
1525 // the caller's `max_iterations` and trips as a hard error if exceeded.
1526 max_iterations: config.max_iterations,
1527 convergence_tolerance: config.convergence_tolerance,
1528 adaptive_kkt_tolerance,
1529 // LM step-halving is a per-iteration damping retry budget; it is
1530 // independent of the total outer-iteration cap. Tying the two
1531 // together collapsed step halving to 3 under seed screening (where
1532 // max_iterations is intentionally capped low), turning recoverable
1533 // damping into spurious failures.
1534 max_step_halving: base_max_step_halving,
1535 min_step_size: if firth_active { 1e-12 } else { 1e-10 },
1536 firth_bias_reduction: firth_active,
1537 coefficient_lower_bounds: None,
1538 linear_constraints: linear_constraints.clone(),
1539 initial_lm_lambda: config.initial_lm_lambda,
1540 geodesic_acceleration: config.geodesic_acceleration,
1541 arrow_schur: config.arrow_schur.clone(),
1542 };
1543
1544 let mut iteration_logger = |info: &WorkingModelIterationInfo| {
1545 log::debug!(
1546 "[PIRLS] iter {:>3} | deviance {:.6e} | |grad| {:.3e} | step {:.3e} (halving {})",
1547 info.iteration,
1548 info.deviance,
1549 info.gradient_norm,
1550 info.step_size,
1551 info.step_halving
1552 );
1553 };
1554
1555 // Stage 3.3 GPU PIRLS-loop dispatch — see pirls_host_dispatch::try_pirls_loop_gpu.
1556 if let Some(result) = try_pirls_loop_gpu(
1557 config,
1558 &penalty_active,
1559 kronecker_runtime.is_none(),
1560 use_sparse_native,
1561 &linear_constraints,
1562 &x_original,
1563 &qs_arc,
1564 penalty.p,
1565 &x_original_for_result,
1566 || materialize_final_reparam_result(),
1567 y,
1568 priorweights,
1569 offset,
1570 &initial_beta,
1571 link_function,
1572 coordinate_frame,
1573 ) {
1574 return result;
1575 }
1576
1577 let mut working_summary = runworking_model_pirls(
1578 &mut working_model,
1579 Coefficients::new(initial_beta),
1580 &options,
1581 &mut iteration_logger,
1582 )?;
1583
1584 // ── Gamma dispersion: re-estimate the shape at the *converged* η (#678) ──
1585 //
1586 // The inner LM solve estimates the Gamma shape ν = 1/φ **once** from the
1587 // warm-start η and freezes it for the rest of the solve (see the
1588 // `gamma_shape_locked` doc on `GamWorkingModel`): holding ν fixed keeps the
1589 // product φ·λ — and hence the penalized argmin β̂ — a stationary LM target,
1590 // so the gain ratio compares one objective. That lock is correct *within* a
1591 // solve, but it pins ν to whatever η the solve started from. When the fit
1592 // cold-starts (the final dedicated fit at the converged ρ passes
1593 // `warm_start_beta = None`, and seed screening starts from a default guess),
1594 // that warm-start η has not yet captured the mean structure; the leftover
1595 // spread of μ inflates the Gamma deviance term `mean[y/μ − ln(y/μ) − 1]` and
1596 // biases ν **down** (φ up) by >2× whenever μ varies appreciably. The mean
1597 // surface still converges (β̂ is essentially scale-free here), but the frozen
1598 // ν that survives into `UnifiedFitResult::dispersion_phi()` — and from there
1599 // into every coefficient SE `Vb = H⁻¹·φ̂`, prediction interval, and
1600 // observation-noise interval — is the early, mean-spread-contaminated value.
1601 //
1602 // Fix: after the solve converges, re-estimate ν at the converged η. If it
1603 // moved, re-solve β (warm-started, ν held fixed at the refreshed value) and
1604 // repeat, driving the pair (β, ν) to their joint fixed point at the current
1605 // λ. At convergence the reported dispersion is the Gamma ML estimate at the
1606 // converged mean (mgcv's post-hoc Pearson/deviance scale), and the final
1607 // working state — `finalweights`, the penalized Hessian, the deviance, μ —
1608 // is rebuilt with that same ν, so `Vb = H⁻¹·φ̂` stays internally consistent.
1609 // Warm-started solves (every REML cost eval) already sit near the converged
1610 // η, so the first refresh check confirms ν and exits without a re-solve; the
1611 // added cost there is a single O(n) shape evaluation.
1612 if refine_dispersion_at_converged_eta
1613 && working_model.likelihood.scale.gamma_shape_is_estimated()
1614 {
1615 // A few passes suffice: the converged-η shape map is a strong
1616 // contraction (β̂ barely moves once the mean is captured), so cold
1617 // starts settle in 1–2 re-solves and warm starts in zero.
1618 const MAX_SHAPE_REFRESH: usize = 5;
1619 // Relative shape tolerance below which a re-solve cannot move any
1620 // reported quantity meaningfully (far under statistical resolution).
1621 const SHAPE_REFRESH_REL_TOL: f64 = 1e-4;
1622 for refresh_iter in 0..MAX_SHAPE_REFRESH {
1623 let refreshed_shape = super::estimate_gamma_shape_from_eta(
1624 y,
1625 working_summary.state.eta.as_ref(),
1626 priorweights,
1627 );
1628 let prior_shape = working_model.likelihood.gamma_shape().unwrap_or(1.0);
1629 let rel_change =
1630 (refreshed_shape - prior_shape).abs() / prior_shape.max(f64::MIN_POSITIVE);
1631 // Install the refreshed shape and hold it fixed for any re-solve so
1632 // the LM objective stays stationary (the lock is *re-armed*, not
1633 // released — the seed-from-warm-start branch in `update_with_curvature`
1634 // must not overwrite this deliberately chosen value). Because this
1635 // assignment evaluated the shape at the *current* converged η and no
1636 // re-solve follows it on the exit paths below, the reported shape
1637 // always equals `estimate_gamma_shape_from_eta(final_eta)` — the
1638 // self-consistency invariant the in-module Gamma unit test checks.
1639 working_model.likelihood = working_model
1640 .likelihood
1641 .clone()
1642 .with_gamma_shape(refreshed_shape);
1643 working_model.gamma_shape_locked = true;
1644 if rel_change <= SHAPE_REFRESH_REL_TOL {
1645 // Converged: the working-state buffers (weights, Hessian,
1646 // deviance) already reflect a shape within tolerance of
1647 // `refreshed_shape`, because the only way to reach here without
1648 // a re-solve is that the prior solve's shape already matched the
1649 // converged-η estimate. Nothing left to rebuild.
1650 break;
1651 }
1652 if refresh_iter + 1 == MAX_SHAPE_REFRESH {
1653 // Final allowed pass and the shape is still drifting (a
1654 // pathological non-contraction). Do NOT re-solve: re-solving
1655 // would advance `final_eta` past the η the just-installed shape
1656 // was evaluated at, breaking the stored-shape == estimate(final_eta)
1657 // invariant. Stopping here keeps the reported shape exactly the
1658 // ML estimate at the reported η; the residual weight/φ drift is
1659 // bounded by the last `rel_change` and never worse than the
1660 // pre-fix frozen-warm-start value.
1661 break;
1662 }
1663 // The shape moved: re-solve β at the corrected shape, warm-started
1664 // at the converged β, so the final working state is rebuilt with the
1665 // refreshed ν.
1666 working_summary = runworking_model_pirls(
1667 &mut working_model,
1668 working_summary.beta.clone(),
1669 &options,
1670 &mut iteration_logger,
1671 )?;
1672 }
1673 }
1674
1675 // ── Tweedie dispersion φ: re-estimate at the *converged* η (#771) ─────────
1676 //
1677 // Identical in spirit to the Gamma-shape refresh above: the inner LM solve
1678 // estimates φ **once** from the warm-start η and freezes it (the
1679 // `tweedie_phi_locked` lock), keeping the product φ·λ — and hence β̂ — a
1680 // stationary LM target. φ enters only the working weight `prior·μ^{2−p}/φ`
1681 // and not the working response, so (like the Gamma shape, and unlike the
1682 // Beta precision which couples through the digamma mean score) the mean
1683 // surface is essentially scale-free and β̂ barely moves when φ is corrected.
1684 // But the frozen warm-start φ is the value that survives into
1685 // `FitInference::dispersion` and the covariance `Vb = H⁻¹` (whose √φ scaling
1686 // lives in the weight); at a cold-started η ≈ 0 the Pearson residuals carry
1687 // the *marginal* spread of y, biasing the estimate. Re-estimating at the
1688 // converged η — re-solving β only if φ moved materially — drives (β, φ) to
1689 // their joint fixed point, so the reported φ is the converged-mean Pearson
1690 // estimate and the final weights/Hessian/SE are internally consistent with
1691 // it. Held OFF inside the REML λ search (the flag), φ is refreshed only at
1692 // the reported fit, so it cannot couple to the smoothing parameter.
1693 if refine_dispersion_at_converged_eta
1694 && working_model.likelihood.scale.tweedie_phi_is_estimated()
1695 {
1696 if let ResponseFamily::Tweedie { p } = working_model.likelihood.spec.response {
1697 // The converged-η Pearson map is a strong contraction (β̂ scale-free
1698 // here), so cold starts settle in 1–2 re-solves and warm starts in
1699 // zero.
1700 const MAX_PHI_REFRESH: usize = 5;
1701 // Relative φ tolerance below which a re-solve cannot move any reported
1702 // quantity meaningfully (far under statistical resolution).
1703 const PHI_REFRESH_REL_TOL: f64 = 1e-4;
1704 for refresh_iter in 0..MAX_PHI_REFRESH {
1705 let refreshed_phi = super::estimate_tweedie_phi_from_eta(
1706 y,
1707 working_summary.state.eta.as_ref(),
1708 priorweights,
1709 p,
1710 );
1711 let prior_phi = working_model.likelihood.fixed_phi().unwrap_or(1.0);
1712 let rel_change =
1713 (refreshed_phi - prior_phi).abs() / prior_phi.max(f64::MIN_POSITIVE);
1714 // Install the refreshed φ (the scale metadata the working weight
1715 // reads via `fixed_phi()`) and re-arm the lock so a following
1716 // re-solve does not overwrite this converged-η value. Because the
1717 // exit paths below evaluate φ at the *current* η with no following
1718 // re-solve, the reported φ always equals
1719 // `estimate_tweedie_phi_from_eta(final_eta)`.
1720 working_model.likelihood = working_model
1721 .likelihood
1722 .clone()
1723 .with_tweedie_phi(refreshed_phi);
1724 working_model.tweedie_phi_locked = true;
1725 if rel_change <= PHI_REFRESH_REL_TOL {
1726 // Converged: the working state already reflects a φ within
1727 // tolerance of `refreshed_phi`. Nothing left to rebuild.
1728 break;
1729 }
1730 if refresh_iter + 1 == MAX_PHI_REFRESH {
1731 // Final allowed pass and φ is still drifting. Do NOT re-solve:
1732 // re-solving would advance η past the point φ was evaluated at,
1733 // breaking the stored-φ == estimate(final_eta) invariant.
1734 break;
1735 }
1736 // φ moved materially: re-solve β at the corrected φ, warm-started
1737 // at the converged β, so the final working state is rebuilt with
1738 // the refreshed φ.
1739 working_summary = runworking_model_pirls(
1740 &mut working_model,
1741 working_summary.beta.clone(),
1742 &options,
1743 &mut iteration_logger,
1744 )?;
1745 }
1746 }
1747 }
1748
1749 // ── Beta precision φ: re-estimate at the *converged* η and drive (β, φ) to
1750 // their joint fixed point (#769) ──────────────────────────────────────
1751 //
1752 // Like the Gamma shape above, the inner LM solve estimates φ **once** from
1753 // the warm-start η and freezes it for the rest of the solve (the
1754 // `beta_phi_locked` doc on `GamWorkingModel`): holding φ fixed keeps the
1755 // penalized argmin β̂ a stationary LM target so the gain ratio compares one
1756 // objective. But that lock pins φ to whatever η the solve started from, and
1757 // for the final dedicated fit at the converged ρ the warm-start is the cold
1758 // default guess (η ≈ 0, μ ≈ 0.5 everywhere). At the null predictor the
1759 // Pearson residuals `(y−μ)²/(μ(1−μ))` capture the full *marginal* spread of
1760 // y rather than its *conditional* spread, so the moment estimator
1761 // `1+φ = Σw / Σ w·s` returns a precision far too small (≈3 when the truth is
1762 // ≈20 here).
1763 //
1764 // Crucially — and unlike the Gamma shape — φ does **not** factor out of the
1765 // Beta mean score. With the logit link the score for β is
1766 // ∂ℓ/∂β = φ · Σᵢ xᵢ (y*ᵢ − μ*ᵢ), y*ᵢ = logit(yᵢ),
1767 // μ*ᵢ = ψ(μᵢφ) − ψ((1−μᵢ)φ),
1768 // so the root β̂ depends on φ through the digamma terms. A φ that is too
1769 // small shrinks every fitted coefficient toward zero. So this refresh is not
1770 // cosmetic (as it is for Gamma): the re-solve is what *recovers the mean*.
1771 //
1772 // Fix: after the cold solve converges, re-estimate φ at the converged η,
1773 // re-solve β at the corrected φ (warm-started), and repeat. This is the
1774 // betareg alternating mean-fit ↔ φ-estimate scheme; the moment estimator is
1775 // a strong contraction once the mean has any structure, so the pair settles
1776 // in a handful of passes. Held OFF inside the REML λ search (see the flag
1777 // doc), φ is refreshed only here at the reported fit, so it cannot couple to
1778 // the smoothing parameter and reward over-smoothing. As with Gamma, every
1779 // exit path installs φ evaluated at the *current* η with no following
1780 // re-solve, so the reported φ (which flows into `EstimatedBetaPhi`, the
1781 // embedded `Beta { phi }`, `dispersion`, and every SE) always equals
1782 // `estimate_beta_phi_from_eta(final_eta)`.
1783 if refine_dispersion_at_converged_eta && working_model.likelihood.scale.beta_phi_is_estimated()
1784 {
1785 // The mean moves between passes (φ feeds back through the digamma
1786 // score), so allow a few more passes than the scale-free Gamma case;
1787 // the contraction is fast and warm-started re-solves are cheap.
1788 const MAX_PHI_REFRESH: usize = 30;
1789 // Relative φ tolerance below which a re-solve cannot move β̂ — and hence
1790 // any reported quantity — by a statistically meaningful amount.
1791 const PHI_REFRESH_REL_TOL: f64 = 1e-4;
1792 for refresh_iter in 0..MAX_PHI_REFRESH {
1793 let refreshed_phi = super::estimate_beta_phi_from_eta(
1794 y,
1795 working_summary.state.eta.as_ref(),
1796 priorweights,
1797 );
1798 let prior_phi = working_model.likelihood.fixed_phi().unwrap_or(1.0);
1799 let rel_change = (refreshed_phi - prior_phi).abs() / prior_phi.max(f64::MIN_POSITIVE);
1800 // Install the refreshed φ (updates BOTH the `Beta { phi }` family
1801 // variant every weight/deviance expression reads and the
1802 // `EstimatedBetaPhi` scale metadata) and re-arm the lock so a
1803 // following re-solve's `update_with_curvature` does not overwrite
1804 // this deliberately chosen value with a fresh cold estimate.
1805 working_model.likelihood = working_model
1806 .likelihood
1807 .clone()
1808 .with_beta_phi(refreshed_phi);
1809 working_model.beta_phi_locked = true;
1810 if rel_change <= PHI_REFRESH_REL_TOL {
1811 // Converged: the just-installed φ matches (to tolerance) the φ
1812 // the current working state was solved at, so β̂, the weights,
1813 // the Hessian and the deviance are already self-consistent with
1814 // the reported φ. Nothing left to rebuild.
1815 break;
1816 }
1817 if refresh_iter + 1 == MAX_PHI_REFRESH {
1818 // Final allowed pass and φ is still drifting. Do NOT re-solve:
1819 // re-solving would advance η past the point the just-installed φ
1820 // was evaluated at, breaking the stored-φ == estimate(final_eta)
1821 // invariant. Stop here so the reported φ is exactly the moment
1822 // estimate at the reported η.
1823 break;
1824 }
1825 // φ moved materially: re-solve β at the corrected φ, warm-started at
1826 // the converged β, so the mean is refit under the better precision
1827 // and the final working state is rebuilt consistently.
1828 working_summary = runworking_model_pirls(
1829 &mut working_model,
1830 working_summary.beta.clone(),
1831 &options,
1832 &mut iteration_logger,
1833 )?;
1834 }
1835 }
1836
1837 // ── Negative-Binomial overdispersion θ: re-estimate at the *converged* η and
1838 // drive (β, θ) to their joint fixed point (#802) ───────────────────────
1839 //
1840 // Identical in spirit to the Beta-precision refresh above. The inner LM solve
1841 // estimates θ **once** from the warm-start η and freezes it (the
1842 // `negbin_theta_locked` lock), keeping the penalized argmin β̂ a stationary LM
1843 // target. But that lock pins θ to whatever η the solve started from, and for
1844 // the final dedicated fit at the converged ρ the warm-start is the cold
1845 // default guess (η ≈ 0). At the null predictor the Pearson residuals carry
1846 // the *marginal* spread of y rather than its *conditional* spread, biasing
1847 // the moment seed — and the frozen θ is what survives into the working weight
1848 // `W = μθ/(θ+μ)`, the covariance `Vb = H⁻¹` (whose overdispersion scaling
1849 // lives in that weight, not a post-hoc multiply), and every reported SE /
1850 // interval / `generate` draw.
1851 //
1852 // Like the Beta precision — and unlike the scale-free Gamma shape / Tweedie φ
1853 // — θ enters the NB2 working *response*, not only the weight, so re-solving β
1854 // under the corrected θ is not cosmetic: it recovers the mean under the right
1855 // variance function. Re-estimating at the converged η, re-solving β
1856 // (warm-started), and repeating drives (β, θ) to their joint maximum-
1857 // likelihood fixed point. Held OFF inside the REML λ search (the flag), θ is
1858 // refreshed only here at the reported fit, so it cannot couple to the
1859 // smoothing parameter. Every exit path installs θ evaluated at the *current*
1860 // η with no following re-solve, so the reported θ (which flows into the
1861 // embedded `NegativeBinomial { theta }`, the `EstimatedNegBinTheta` scale
1862 // metadata, the predictive-interval variance, and every SE) always equals
1863 // `estimate_negbin_theta_from_eta(final_eta)`.
1864 if refine_dispersion_at_converged_eta
1865 && working_model.likelihood.scale.negbin_theta_is_estimated()
1866 {
1867 // θ feeds back through the working response, so allow a few more passes
1868 // than the scale-free Gamma case; the alternation is a strong contraction
1869 // and warm-started re-solves are cheap.
1870 const MAX_THETA_REFRESH: usize = 30;
1871 // Relative θ tolerance below which a re-solve cannot move β̂ — and hence
1872 // any reported quantity — by a statistically meaningful amount.
1873 const THETA_REFRESH_REL_TOL: f64 = 1e-4;
1874 for refresh_iter in 0..MAX_THETA_REFRESH {
1875 let refreshed_theta = super::estimate_negbin_theta_from_eta(
1876 y,
1877 working_summary.state.eta.as_ref(),
1878 priorweights,
1879 );
1880 let prior_theta = working_model.likelihood.negbin_theta().unwrap_or(1.0);
1881 let rel_change =
1882 (refreshed_theta - prior_theta).abs() / prior_theta.max(f64::MIN_POSITIVE);
1883 // Install the refreshed θ (updates BOTH the `NegativeBinomial { theta }`
1884 // family variant every weight/deviance expression reads and the
1885 // `EstimatedNegBinTheta` scale metadata) and re-arm the lock so a
1886 // following re-solve's `update_with_curvature` does not overwrite this
1887 // deliberately chosen value with a fresh cold estimate.
1888 working_model.likelihood = working_model
1889 .likelihood
1890 .clone()
1891 .with_negbin_theta(refreshed_theta);
1892 working_model.negbin_theta_locked = true;
1893 if rel_change <= THETA_REFRESH_REL_TOL {
1894 // Converged: the just-installed θ matches (to tolerance) the θ the
1895 // current working state was solved at, so β̂, the weights, the
1896 // Hessian and the deviance are already self-consistent with the
1897 // reported θ. Nothing left to rebuild.
1898 break;
1899 }
1900 if refresh_iter + 1 == MAX_THETA_REFRESH {
1901 // Final allowed pass and θ is still drifting. Do NOT re-solve:
1902 // re-solving would advance η past the point the just-installed θ
1903 // was evaluated at, breaking the stored-θ == estimate(final_eta)
1904 // invariant. Stop here so the reported θ is exactly the ML
1905 // estimate at the reported η.
1906 break;
1907 }
1908 // θ moved materially: re-solve β at the corrected θ, warm-started at
1909 // the converged β, so the mean is refit under the better variance
1910 // function and the final working state is rebuilt consistently.
1911 working_summary = runworking_model_pirls(
1912 &mut working_model,
1913 working_summary.beta.clone(),
1914 &options,
1915 &mut iteration_logger,
1916 )?;
1917 }
1918 }
1919
1920 // Extract workspace before consuming working_model so we can reuse
1921 // the pre-allocated buffers in calculate_edfwithworkspace_with_penalty.
1922 // into_final_state() drops the workspace field anyway (it uses `..` in
1923 // its destructure); we replace it with a zero-sized stub to satisfy the
1924 // borrow checker, then keep the real workspace alive for the EDF call.
1925 let mut saved_workspace = std::mem::replace(
1926 &mut working_model.workspace,
1927 PirlsWorkspace::new(0, 0, 0, 0),
1928 );
1929 let final_state = working_model.into_final_state();
1930 let GamModelFinalState {
1931 likelihood: final_likelihood,
1932 coordinate_frame,
1933 finalmu,
1934 finalweights,
1935 scoreweights,
1936 finalz,
1937 final_c,
1938 final_d,
1939 final_dmu_deta,
1940 final_d2mu_deta2,
1941 final_d3mu_deta3,
1942 penalty_term,
1943 ..
1944 } = final_state;
1945
1946 // Preserve the Hessian as-is (sparse or dense) — no densification.
1947 // P-IRLS already folded any stabilization ridge directly into the Hessian.
1948 // Keep that exact matrix so outer LAML derivatives stay consistent:
1949 // H_eff = X'W_H X + S_λ + ridge I (if ridge_used > 0).
1950 let penalized_hessian_transformed = working_summary.state.hessian.clone();
1951 let stabilizedhessian_transformed = penalized_hessian_transformed.clone();
1952 // Use the workspace-backed variant for the dense path to reuse the
1953 // `final_aug_matrix` allocation; the sparse path still allocates
1954 // internally because no pre-computed factor is available at this site.
1955 let mut edf = if let Some(dense_h) = penalized_hessian_transformed.as_dense() {
1956 calculate_edfwithworkspace_with_penalty(dense_h, &penalty_active, &mut saved_workspace)?
1957 } else {
1958 calculate_edf_with_penalty(&penalized_hessian_transformed, &penalty_active)?
1959 };
1960 if !edf.is_finite() || edf.is_nan() {
1961 let p = penalized_hessian_transformed.ncols() as f64;
1962 let r = penalty_active.rank() as f64;
1963 edf = (p - r).max(0.0);
1964 }
1965
1966 // Outer rescue: a fit that hit max-iterations may still be a usable
1967 // minimum if progress has effectively stopped (deviance plateaued or
1968 // step size collapsed to the floor) AND the projected gradient is in
1969 // the near-stationary band under the scale-invariant certificate.
1970 // Same logic for non-Firth and Firth paths; firth_active just gates
1971 // the second pass.
1972 let stalled_at_valid_minimum = |summary: &WorkingModelPirlsResult| -> bool {
1973 // Scale-equivariant deviance plateau band (issue #1127). The
1974 // `last_deviance_change` compared below and the deviance both scale as
1975 // `O(a²)` under a response rescaling `y → a·y` (the penalized normal
1976 // equations are linear in `y`, so `β → a·β` and the RSS-deviance
1977 // scales by `a²`). Keying the plateau band to the deviance's own
1978 // magnitude `+ |penalty|` makes the ratio `Δdev / dev_scale`
1979 // scale-invariant. The previous `.max(1.0)` absolute floor broke this:
1980 // for a micro-unit response (`a = 1e-6`) the deviance is `O(1e-12)`, so
1981 // the floor pinned the band at `1.0` — ~1e9× too loose — and this
1982 // max-iteration rescue declared `progress_stopped` at an over-smoothed
1983 // iterate, propagating an inflated `λ̂` to the outer REML loop. For a
1984 // well-scaled (`a ≳ 1`) or up-scaled (`a = 1e6`) objective the floor was
1985 // already a no-op, so those directions are byte-identical. A perfect
1986 // interpolating fit gives a `0` band, so the relative `Δdev` test cannot
1987 // fire spuriously and the scale-invariant `near_stationary_kkt`
1988 // certificate then governs acceptance.
1989 let dev_scale = summary.state.deviance.abs() + summary.state.penalty_term.abs();
1990 // Progress plateau uses the fixed solver tolerance; only the KKT band below adapts.
1991 let dev_tol = options.convergence_tolerance * dev_scale;
1992 let step_floor = options.min_step_size * 2.0;
1993 let progress_stopped =
1994 summary.last_deviance_change.abs() <= dev_tol || summary.last_step_size <= step_floor;
1995 let near_stationary = summary
1996 .state
1997 .near_stationary_kkt(summary.lastgradient_norm, effective_kkt_tolerance(&options));
1998 progress_stopped && near_stationary
1999 };
2000
2001 let mut status = working_summary.status;
2002 if status.is_failed_max_iterations() && stalled_at_valid_minimum(&working_summary) {
2003 status = PirlsStatus::StalledAtValidMinimum;
2004 working_summary.status = status;
2005 }
2006 if status.is_failed_max_iterations()
2007 && firth_active
2008 && stalled_at_valid_minimum(&working_summary)
2009 {
2010 // Firth-adjusted fits can stall; accept under the same dual-criterion
2011 // near-stationary band.
2012 status = PirlsStatus::StalledAtValidMinimum;
2013 working_summary.status = status;
2014 }
2015 let has_penalty = penalty_active.rank() > 0;
2016 let firth_active = options.firth_bias_reduction;
2017 if detect_logit_instability(
2018 link_function,
2019 &final_likelihood.spec.response,
2020 has_penalty,
2021 firth_active,
2022 &working_summary,
2023 &finalmu,
2024 &finalweights,
2025 y,
2026 ) {
2027 status = PirlsStatus::Unstable;
2028 working_summary.status = status;
2029 }
2030
2031 // Store a lazy ReparamOperator instead of materializing X·Qs.
2032 // Consumers that truly need dense access can call .to_dense() on demand.
2033 let reparam_result_final = materialize_final_reparam_result()?;
2034 let qs_arc_final = Arc::new(reparam_result_final.qs.clone());
2035 let x_transformed_final =
2036 make_reparam_operator(&x_original_for_result, &qs_arc_final, use_sparse_native);
2037
2038 let pirls_result = assemble_pirls_result(
2039 &working_summary,
2040 final_likelihood,
2041 offset,
2042 penalized_hessian_transformed,
2043 stabilizedhessian_transformed,
2044 edf,
2045 penalty_term,
2046 &finalmu,
2047 &finalweights,
2048 &scoreweights,
2049 &finalz,
2050 &final_c,
2051 &final_d,
2052 &final_dmu_deta,
2053 &final_d2mu_deta2,
2054 &final_d3mu_deta3,
2055 status,
2056 reparam_result_final,
2057 x_transformed_final,
2058 coordinate_frame,
2059 linear_constraints,
2060 );
2061
2062 Ok((pirls_result, working_summary))
2063}
2064
/// Configuration for one fixed-ρ PIRLS fit: the likelihood/link pair plus the
/// inner-loop controls.
///
/// `initial_lm_lambda`, `geodesic_acceleration` and `arrow_schur` document
/// their forwarding into `WorkingModelPirlsOptions` below. The remaining scalar
/// controls are presumably forwarded the same way. NOTE(review): confirm at the
/// options-construction site in `fit_model_for_fixed_rho`.
#[derive(Clone)]
pub struct PirlsConfig {
    /// Likelihood specification for the fit (presumably carries the response
    /// family — TODO confirm against `GlmLikelihoodSpec`).
    pub likelihood: GlmLikelihoodSpec,
    /// Inverse-link description. [`PirlsConfig::link_function`] derives the
    /// plain [`LinkFunction`] from it.
    pub link_kind: InverseLink,
    /// Cap on inner PIRLS iterations.
    pub max_iterations: usize,
    /// Convergence tolerance for the inner solve.
    pub convergence_tolerance: f64,
    /// Enable Firth bias reduction in the inner solve.
    pub firth_bias_reduction: bool,
    /// Optional warm-start hint for `WorkingModelPirlsOptions::initial_lm_lambda`.
    /// Forwarded directly when `fit_model_for_fixed_rho` builds its
    /// internal options. See the field doc on `WorkingModelPirlsOptions`
    /// for the seeding semantics.
    pub initial_lm_lambda: Option<f64>,
    /// Enable the Transtrum-Sethna geodesic-acceleration second-order
    /// correction on each accepted LM step. Forwarded to
    /// `WorkingModelPirlsOptions::geodesic_acceleration`; see that
    /// field's doc for the full semantics and cost model. Default
    /// `false`; opt-in until validated.
    pub geodesic_acceleration: bool,
    /// Optional arrow-Schur structured-inner-solve descriptor. When
    /// `Some`, forwarded to `WorkingModelPirlsOptions::arrow_schur` so
    /// each accepted LM step is solved by the per-observation
    /// arrow-Schur path
    /// ([`crate::arrow_schur::ArrowSchurSystem`]). When `None`
    /// (the default), the existing β-only path is used unchanged.
    ///
    /// See [`ArrowSchurInnerConfig`] for the closure contract.
    pub arrow_schur: Option<ArrowSchurInnerConfig>,
}
2093
2094impl PirlsConfig {
2095 #[inline]
2096 pub fn link_function(&self) -> LinkFunction {
2097 self.link_kind.link_function()
2098 }
2099}
2100
2101#[inline]
2102pub(super) fn max_symmetric_asymmetry(matrix: &Array2<f64>) -> f64 {
2103 let n = matrix.nrows().min(matrix.ncols());
2104 let mut max_asym = 0.0_f64;
2105 for i in 0..n {
2106 for j in 0..i {
2107 let diff = (matrix[[i, j]] - matrix[[j, i]]).abs();
2108 if diff > max_asym {
2109 max_asym = diff;
2110 }
2111 }
2112 }
2113 max_asym
2114}
2115
2116#[inline]
2117pub(super) fn assert_symmetric_tol(matrix: &Array2<f64>, label: &str, tol: f64) {
2118 let max_asym = max_symmetric_asymmetry(matrix);
2119 assert!(
2120 max_asym <= tol,
2121 "{} asymmetry too large: {:.3e} (tol {:.3e})",
2122 label,
2123 max_asym,
2124 tol
2125 );
2126}
2127
2128/// Build a DesignMatrix wrapping a lazy ReparamOperator (or the original for sparse-native).
2129pub(crate) fn make_reparam_operator(
2130 x_original: &DesignMatrix,
2131 qs_arc: &Arc<Array2<f64>>,
2132 use_sparse_native: bool,
2133) -> DesignMatrix {
2134 if use_sparse_native {
2135 x_original.clone()
2136 } else {
2137 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(
2138 ReparamOperator::new(x_original.clone(), Arc::clone(qs_arc)),
2139 )))
2140 }
2141}
2142
2143// solve_penalized_least_squares_implicit lives in pls_solver (imported above).
2144
2145pub(super) fn build_transformed_lower_bound_constraints(
2146 qs: &Array2<f64>,
2147 coefficient_lower_bounds: Option<&Array1<f64>>,
2148) -> Option<LinearInequalityConstraints> {
2149 let lb = coefficient_lower_bounds?;
2150 if lb.len() != qs.nrows() {
2151 return None;
2152 }
2153 let activerows: Vec<usize> = (0..lb.len()).filter(|&i| lb[i].is_finite()).collect();
2154 if activerows.is_empty() {
2155 return None;
2156 }
2157 let mut a = Array2::<f64>::zeros((activerows.len(), qs.ncols()));
2158 let mut b = Array1::<f64>::zeros(activerows.len());
2159 for (r, &idx) in activerows.iter().enumerate() {
2160 a.row_mut(r).assign(&qs.row(idx));
2161 b[r] = lb[idx];
2162 }
2163 Some(
2164 LinearInequalityConstraints::new(a, b)
2165 .expect("transformed lower-bound constraint shape invariant"),
2166 )
2167}
2168
2169pub(super) fn build_transformed_lower_bound_constraints_with_transform(
2170 transform: &WorkingReparamTransform,
2171 coefficient_lower_bounds: Option<&Array1<f64>>,
2172) -> Option<LinearInequalityConstraints> {
2173 let lb = coefficient_lower_bounds?;
2174 let p = match transform {
2175 WorkingReparamTransform::Dense(qs) => qs.nrows(),
2176 WorkingReparamTransform::Kronecker(kron) => kron.p,
2177 };
2178 if lb.len() != p {
2179 return None;
2180 }
2181 let activerows: Vec<usize> = (0..lb.len()).filter(|&i| lb[i].is_finite()).collect();
2182 if activerows.is_empty() {
2183 return None;
2184 }
2185 let mut a = Array2::<f64>::zeros((activerows.len(), p));
2186 let mut b = Array1::<f64>::zeros(activerows.len());
2187 for (r, &idx) in activerows.iter().enumerate() {
2188 let mut basis = Array1::<f64>::zeros(p);
2189 basis[idx] = 1.0;
2190 let row = transform.apply_transpose(&basis);
2191 a.row_mut(r).assign(&row);
2192 b[r] = lb[idx];
2193 }
2194 Some(
2195 LinearInequalityConstraints::new(a, b)
2196 .expect("transformed lower-bound constraint shape invariant"),
2197 )
2198}
2199
2200pub(super) fn build_transformed_linear_constraints(
2201 qs: &Array2<f64>,
2202 linear_constraints: Option<&LinearInequalityConstraints>,
2203) -> Option<LinearInequalityConstraints> {
2204 let lc = linear_constraints?;
2205 if lc.a.ncols() != qs.nrows() {
2206 return None;
2207 }
2208 Some(
2209 LinearInequalityConstraints::new(lc.a.dot(qs), lc.b.clone())
2210 .expect("transformed linear constraint shape invariant"),
2211 )
2212}
2213
2214pub(super) fn build_transformed_linear_constraints_with_transform(
2215 transform: &WorkingReparamTransform,
2216 linear_constraints: Option<&LinearInequalityConstraints>,
2217) -> Option<LinearInequalityConstraints> {
2218 let lc = linear_constraints?;
2219 let p = match transform {
2220 WorkingReparamTransform::Dense(qs) => qs.nrows(),
2221 WorkingReparamTransform::Kronecker(kron) => kron.p,
2222 };
2223 if lc.a.ncols() != p {
2224 return None;
2225 }
2226 let mut a = Array2::<f64>::zeros((lc.a.nrows(), p));
2227 for row in 0..lc.a.nrows() {
2228 let transformed = transform.apply_transpose(&lc.a.row(row).to_owned());
2229 a.row_mut(row).assign(&transformed);
2230 }
2231 Some(LinearInequalityConstraints { a, b: lc.b.clone() })
2232}
2233
2234pub(super) fn merge_linear_constraints(
2235 first: Option<LinearInequalityConstraints>,
2236 second: Option<LinearInequalityConstraints>,
2237) -> Option<LinearInequalityConstraints> {
2238 match (first, second) {
2239 (None, None) => None,
2240 (Some(c), None) | (None, Some(c)) => Some(c),
2241 (Some(c1), Some(c2)) => {
2242 if c1.a.ncols() != c2.a.ncols() {
2243 return None;
2244 }
2245 let rows = c1.a.nrows() + c2.a.nrows();
2246 let cols = c1.a.ncols();
2247 let mut a = Array2::<f64>::zeros((rows, cols));
2248 a.slice_mut(s![0..c1.a.nrows(), ..]).assign(&c1.a);
2249 a.slice_mut(s![c1.a.nrows()..rows, ..]).assign(&c2.a);
2250 let mut b = Array1::<f64>::zeros(rows);
2251 b.slice_mut(s![0..c1.b.len()]).assign(&c1.b);
2252 b.slice_mut(s![c1.b.len()..rows]).assign(&c2.b);
2253 Some(LinearInequalityConstraints { a, b })
2254 }
2255 }
2256}
2257
2258pub(super) fn sparse_from_denseview(x: ArrayView2<f64>) -> Option<DesignMatrix> {
2259 // Below this column count a dense factorization beats the sparse path even
2260 // at high sparsity, so skip the sparsity scan entirely for narrow designs.
2261 const DENSE_PREFERRED_MAX_COLS: usize = 32;
2262 // Sparse storage + sparse Cholesky only pays off below this density (nnz as
2263 // a fraction of all entries); denser matrices stay dense.
2264 const SPARSE_DENSITY_LIMIT: f64 = 0.20;
2265
2266 let nrows = x.nrows();
2267 let ncols = x.ncols();
2268 if nrows == 0 || ncols == 0 {
2269 return None;
2270 }
2271 // Narrow matrices are faster in dense form; avoid any sparsity scan overhead.
2272 if ncols <= DENSE_PREFERRED_MAX_COLS {
2273 return None;
2274 }
2275
2276 const ZERO_EPS: f64 = 1e-12;
2277 let total = nrows.saturating_mul(ncols);
2278 if total == 0 {
2279 return None;
2280 }
2281 // If a matrix exceeds this nnz count it is too dense for sparse path; bail early.
2282 let sparse_nnz_limit = ((total as f64) * SPARSE_DENSITY_LIMIT).floor() as usize;
2283 let mut nnz = 0usize;
2284 for &val in x.iter() {
2285 if val.abs() > ZERO_EPS {
2286 nnz += 1;
2287 if nnz > sparse_nnz_limit {
2288 return None;
2289 }
2290 }
2291 }
2292 let mut triplets = Vec::with_capacity(nnz);
2293 for (row_idx, row) in x.outer_iter().enumerate() {
2294 for (col_idx, &val) in row.iter().enumerate() {
2295 if val.abs() > ZERO_EPS {
2296 triplets.push(Triplet::new(row_idx, col_idx, val));
2297 }
2298 }
2299 }
2300 SparseColMat::try_new_from_triplets(nrows, ncols, &triplets)
2301 .ok()
2302 .map(DesignMatrix::from)
2303}
2304
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use gam_terms::construction::KroneckerReparamResult;
    use ndarray::{Array1, Array2, array};

    use super::{PirlsPenalty, build_diagonal_penalty_from_kronecker};

    /// Two 2-dimensional marginals, each with one zero eigenvalue, under a
    /// double penalty with shrinkage ridge `0.5` and λ = `[5, 7, 11]`.
    ///
    /// The expected diagonal, indexed row-major over marginal index pairs, is:
    /// - off the joint null space: `5·e₀ + 7·e₁ + 0.5`, giving `21.5`,
    ///   `10.5` and `31.5`;
    /// - at the joint-null entry `(0, 0)`: only the double-penalty weight `11`.
    ///
    /// In other words, the double penalty touches nothing but the joint null
    /// space.
    #[test]
    fn kronecker_diagonal_double_penalty_hits_only_joint_null_space() {
        let reparam = KroneckerReparamResult {
            reparameterized_marginals: Arc::new(Vec::new()),
            marginal_eigenvalues: Arc::new(vec![array![0.0, 2.0], array![0.0, 3.0]]),
            marginal_qs: Arc::new(Vec::new()),
            log_det: 0.0,
            det1: Array1::zeros(3),
            det2: Array2::zeros((3, 3)),
            penalty_shrinkage_ridge: 0.5,
            has_double_penalty: true,
            marginal_dims: vec![2usize, 2usize],
        };
        let lambdas = [5.0, 7.0, 11.0];
        let penalty = build_diagonal_penalty_from_kronecker(&reparam, &lambdas);

        let PirlsPenalty::Diagonal {
            diag,
            positive_indices,
            ..
        } = penalty
        else {
            panic!("expected diagonal Kronecker PIRLS penalty");
        };

        let expected = [11.0, 21.5, 10.5, 31.5];
        for (idx, &expected_diag) in expected.iter().enumerate() {
            let got = diag[idx];
            assert!(
                (got - expected_diag).abs() <= 1e-12,
                "diagonal {idx} got {}, expected {expected_diag}",
                got
            );
        }
        assert_eq!(positive_indices, vec![0, 1, 2, 3]);
    }
}