gam_solve/pirls/loop_driver.rs
1//! Outer driver for a single fixed-ρ PIRLS fit.
2//!
3//! Owns:
4//! - `fit_model_for_fixed_rho` and `fit_model_for_fixed_rho_with_adaptive_kkt`
5//! — build the working model, run the inner LM loop, assemble the final result.
6//! - `PirlsProblem`, `PenaltyConfig`, `PirlsConfig` — the configuration types.
7//! - Helper functions exclusive to the fixed-ρ fitting path: constraint
8//! transformation, sparse-native decision, reparam materialisation, prior
9//! shift assembly, initial-β guess, Gaussian short-circuit assembly, etc.
//! - The two GPU dispatch blocks (Stage 3.3) that call into
//!   `crate::gpu::pirls_host_dispatch` (`try_gaussian_pls_gpu`, `try_pirls_loop_gpu`).
12
13use super::{
14 // state re-exports
15 AdaptiveKktTolerance,
16 ExportedLaplaceCurvature,
17 FirthDiagnostics,
18 GamWorkingModel,
19 HessianCurvatureKind,
20 // penalty types
21 KroneckerQsTransform,
22 LinearInequalityConstraints,
23 PirlsCoordinateFrame,
24 PirlsLinearSolvePath,
25 PirlsPenalty,
26 PirlsResult,
27 PirlsStatus,
28 PirlsWorkspace,
29 SparsePirlsDecision,
30 WorkingModelIterationInfo,
31 WorkingModelPirlsOptions,
32 WorkingModelPirlsResult,
33 WorkingReparamTransform,
34 WorkingState,
35 // misc helpers
36 array1_l2_norm,
37 attach_penalty_shift,
38 // compute functions
39 calculate_deviance,
40 // edf helpers
41 calculate_edf_with_penalty,
42 calculate_edfwithworkspace_with_penalty,
43 calculate_loglikelihood_omitting_constants,
44 compute_constraint_kkt_diagnostics,
45 computeworkingweight_derivatives_from_eta,
46 inf_norm,
47 runworking_model_pirls,
48 should_use_sparse_native_pirls,
49 solve_penalized_least_squares_implicit,
50 standard_inverse_link_jet,
51};
52use super::{
53 ArrowSchurInnerConfig, GamModelFinalState, effective_kkt_tolerance,
54 project_coefficients_to_lower_bounds,
55};
56use gam_terms::construction::{KroneckerReparamResult, ReparamResult};
57use crate::estimate::EstimationError;
58use gam_linalg::faer_ndarray::fast_ab;
59use gam_linalg::matrix::{DesignMatrix, LinearOperator, ReparamOperator, SymmetricMatrix};
60use crate::mixture_link::inverse_link_has_fisher_weight_jet;
61use gam_math::probability::standard_normal_quantile;
62use crate::active_set;
63use crate::gpu::pirls_host_dispatch::{try_gaussian_pls_gpu, try_pirls_loop_gpu};
64use gam_problem::{
65 Coefficients, GlmLikelihoodSpec, InverseLink, LinearPredictor, LinkFunction,
66 LogSmoothingParamsView, MixtureLinkState, ResponseFamily, RidgePassport, RidgePolicy,
67 SasLinkState, StandardLink,
68};
69use faer::sparse::{SparseColMat, Triplet};
70use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
71use std::borrow::Cow;
72use std::sync::Arc;
73
74pub(super) fn default_beta_guess_external(
75 p: usize,
76 link_function: LinkFunction,
77 y: ArrayView1<f64>,
78 priorweights: ArrayView1<f64>,
79 mixture_link_state: Option<&MixtureLinkState>,
80 sas_link_state: Option<&SasLinkState>,
81) -> Array1<f64> {
82 let mut beta = Array1::<f64>::zeros(p);
83 let intercept_col = 0usize;
84 match link_function {
85 LinkFunction::Logit
86 | LinkFunction::Probit
87 | LinkFunction::CLogLog
88 | LinkFunction::Sas
89 | LinkFunction::BetaLogistic => {
90 let mut weighted_sum = 0.0;
91 let mut totalweight = 0.0;
92 for (&yi, &wi) in y.iter().zip(priorweights.iter()) {
93 weighted_sum += wi * yi;
94 totalweight += wi;
95 }
96 if totalweight > 0.0 {
97 let prevalence =
98 ((weighted_sum + 0.5) / (totalweight + 1.0)).clamp(1e-6, 1.0 - 1e-6);
99 beta[intercept_col] = match link_function {
100 LinkFunction::Logit => (prevalence / (1.0 - prevalence)).ln(),
101 LinkFunction::Probit => {
102 standard_normal_quantile(prevalence).unwrap_or_else(|_| {
103 // `prevalence` is clamped to (0, 1); this fallback is
104 // only for defensive robustness under non-finite upstream inputs.
105 (prevalence / (1.0 - prevalence)).ln()
106 })
107 }
108 LinkFunction::CLogLog => (-(1.0 - prevalence).ln()).ln(),
109 LinkFunction::Sas => solve_intercept_for_prevalence(
110 link_function,
111 prevalence,
112 mixture_link_state,
113 sas_link_state,
114 )
115 .unwrap_or_else(|| {
116 standard_normal_quantile(prevalence)
117 .unwrap_or_else(|_| (prevalence / (1.0 - prevalence)).ln())
118 }),
119 LinkFunction::BetaLogistic => solve_intercept_for_prevalence(
120 link_function,
121 prevalence,
122 mixture_link_state,
123 sas_link_state,
124 )
125 .unwrap_or_else(|| {
126 standard_normal_quantile(prevalence)
127 .unwrap_or_else(|_| (prevalence / (1.0 - prevalence)).ln())
128 }),
129 // Outer arm guard already filtered out Log/Identity; fall
130 // back to the canonical logit transform for defensive safety
131 // if these are ever reached unexpectedly.
132 LinkFunction::Log | LinkFunction::Identity => {
133 (prevalence / (1.0 - prevalence)).ln()
134 }
135 };
136 if mixture_link_state.is_some() {
137 beta[intercept_col] = solve_intercept_for_prevalence(
138 link_function,
139 prevalence,
140 mixture_link_state,
141 sas_link_state,
142 )
143 .unwrap_or(beta[intercept_col]);
144 }
145 }
146 }
147 LinkFunction::Identity => {
148 let mut weighted_sum = 0.0;
149 let mut totalweight = 0.0;
150 for (&yi, &wi) in y.iter().zip(priorweights.iter()) {
151 weighted_sum += wi * yi;
152 totalweight += wi;
153 }
154 if totalweight > 0.0 {
155 beta[intercept_col] = weighted_sum / totalweight;
156 }
157 }
158 LinkFunction::Log => {
159 // For log link, intercept = ln(weighted mean of y)
160 let mut weighted_sum = 0.0;
161 let mut totalweight = 0.0;
162 for (&yi, &wi) in y.iter().zip(priorweights.iter()) {
163 weighted_sum += wi * yi;
164 totalweight += wi;
165 }
166 if totalweight > 0.0 {
167 let mean_y = (weighted_sum / totalweight).max(1e-10);
168 beta[intercept_col] = mean_y.ln();
169 }
170 }
171 }
172 beta
173}
174
175pub(super) fn solve_intercept_for_prevalence(
176 link_function: LinkFunction,
177 prevalence: f64,
178 mixture_link_state: Option<&MixtureLinkState>,
179 sas_link_state: Option<&SasLinkState>,
180) -> Option<f64> {
181 #[inline]
182 fn f_eta(
183 link_function: LinkFunction,
184 eta: f64,
185 prevalence: f64,
186 mixture_link_state: Option<&MixtureLinkState>,
187 sas_link_state: Option<&SasLinkState>,
188 ) -> f64 {
189 let inverse_link = if let Some(state) = mixture_link_state {
190 InverseLink::Mixture(state.clone())
191 } else if let Some(state) = sas_link_state {
192 match link_function {
193 LinkFunction::BetaLogistic => InverseLink::BetaLogistic(*state),
194 _ => InverseLink::Sas(*state),
195 }
196 } else {
197 // SAFETY: when `sas_link_state` is None, `solve_intercept_for_prevalence`
198 // is only invoked with the five legal `StandardLink` variants (the
199 // dispatch site at pirls.rs:4203 routes Sas/BetaLogistic into the
200 // Some branch above with state).
201 InverseLink::Standard(StandardLink::try_from(link_function).expect(
202 "state-bearing link reached state-less arm in solve_intercept_for_prevalence",
203 ))
204 };
205 standard_inverse_link_jet(&inverse_link, eta)
206 .map(|jet| jet.mu - prevalence)
207 .unwrap_or(f64::NAN)
208 }
209
210 let mut lo = -40.0;
211 let mut hi = 40.0;
212 let mut f_lo = f_eta(
213 link_function,
214 lo,
215 prevalence,
216 mixture_link_state,
217 sas_link_state,
218 );
219 let mut f_hi = f_eta(
220 link_function,
221 hi,
222 prevalence,
223 mixture_link_state,
224 sas_link_state,
225 );
226 if !(f_lo.is_finite() && f_hi.is_finite()) {
227 return None;
228 }
229 for _ in 0..8 {
230 if f_lo <= 0.0 && f_hi >= 0.0 {
231 break;
232 }
233 lo *= 2.0;
234 hi *= 2.0;
235 f_lo = f_eta(
236 link_function,
237 lo,
238 prevalence,
239 mixture_link_state,
240 sas_link_state,
241 );
242 f_hi = f_eta(
243 link_function,
244 hi,
245 prevalence,
246 mixture_link_state,
247 sas_link_state,
248 );
249 if !(f_lo.is_finite() && f_hi.is_finite()) {
250 return None;
251 }
252 }
253 if f_lo > 0.0 {
254 return Some(lo);
255 }
256 if f_hi < 0.0 {
257 return Some(hi);
258 }
259 for _ in 0..80 {
260 let mid = 0.5 * (lo + hi);
261 let f_mid = f_eta(
262 link_function,
263 mid,
264 prevalence,
265 mixture_link_state,
266 sas_link_state,
267 );
268 if !f_mid.is_finite() {
269 return None;
270 }
271 if f_mid > 0.0 {
272 hi = mid;
273 } else {
274 lo = mid;
275 }
276 }
277 Some(0.5 * (lo + hi))
278}
279
280pub(super) fn assemble_pirls_result(
281 working_summary: &WorkingModelPirlsResult,
282 likelihood: GlmLikelihoodSpec,
283 offset: ArrayView1<'_, f64>,
284 penalized_hessian_transformed: SymmetricMatrix,
285 stabilizedhessian_transformed: SymmetricMatrix,
286 edf: f64,
287 penalty_term: f64,
288 finalmu: &Array1<f64>,
289 finalweights: &Array1<f64>,
290 scoreweights: &Array1<f64>,
291 finalz: &Array1<f64>,
292 final_c: &Array1<f64>,
293 final_d: &Array1<f64>,
294 final_dmu_deta: &Array1<f64>,
295 final_d2mu_deta2: &Array1<f64>,
296 final_d3mu_deta3: &Array1<f64>,
297 status: PirlsStatus,
298 reparam_result: ReparamResult,
299 x_transformed: DesignMatrix,
300 coordinate_frame: PirlsCoordinateFrame,
301 linear_constraints_transformed: Option<LinearInequalityConstraints>,
302) -> PirlsResult {
303 let final_eta_arr = working_summary.state.eta.as_ref().clone();
304 PirlsResult {
305 likelihood,
306 beta_transformed: working_summary.beta.clone(),
307 penalized_hessian_transformed,
308 stabilizedhessian_transformed,
309 ridge_passport: RidgePassport::scaled_identity(
310 working_summary.state.ridge_used,
311 RidgePolicy::explicit_stabilization_full(),
312 ),
313 ridge_used: working_summary.state.ridge_used,
314 deviance: working_summary.state.deviance,
315 edf,
316 stable_penalty_term: penalty_term,
317 firth: working_summary.state.firth.clone(),
318 finalweights: finalweights.clone(),
319 final_offset: offset.to_owned(),
320 final_eta: final_eta_arr,
321 finalmu: finalmu.clone(),
322 solveweights: scoreweights.clone(),
323 solveworking_response: finalz.clone(),
324 solvemu: finalmu.clone(),
325 solve_dmu_deta: final_dmu_deta.clone(),
326 solve_d2mu_deta2: final_d2mu_deta2.clone(),
327 solve_d3mu_deta3: final_d3mu_deta3.clone(),
328 solve_c_array: final_c.clone(),
329 solve_d_array: final_d.clone(),
330 derivatives_unsupported: false,
331 status,
332 iteration: working_summary.iterations,
333 max_abs_eta: working_summary.max_abs_eta,
334 lastgradient_norm: working_summary.lastgradient_norm,
335 gradient_natural_scale: working_summary.state.gradient_natural_scale,
336 last_deviance_change: working_summary.last_deviance_change,
337 last_step_halving: working_summary.last_step_halving,
338 hessian_curvature: working_summary.state.hessian_curvature,
339 exported_laplace_curvature: working_summary.exported_laplace_curvature.clone(),
340 final_lm_lambda: working_summary.final_lm_lambda,
341 final_accept_rho: working_summary.final_accept_rho,
342 constraint_kkt: working_summary.constraint_kkt.clone(),
343 linear_constraints_transformed,
344 reparam_result,
345 x_transformed,
346 coordinate_frame,
347 used_device: false,
348 cache_compacted: false,
349 min_penalized_deviance: working_summary.min_penalized_deviance,
350 }
351}
352
353pub(super) fn detect_logit_instability(
354 link: LinkFunction,
355 response: &ResponseFamily,
356 has_penalty: bool,
357 firth_active: bool,
358 summary: &WorkingModelPirlsResult,
359 finalmu: &Array1<f64>,
360 finalweights: &Array1<f64>,
361 y: ArrayView1<'_, f64>,
362) -> bool {
363 // Perfect / quasi-perfect separation is a *Bernoulli/Binomial* pathology.
364 // Every heuristic below is binary-response–specific: saturation toward
365 // μ ∈ {0, 1}, the `yᵢ > 0.5` order-separation split, and working-weight
366 // collapse only carry meaning when each `yᵢ` is a 0/1 outcome (or a
367 // proportion of Bernoulli trials). The Beta family also fits through the
368 // logit link, but its response is *continuous* on (0, 1): a perfectly
369 // healthy monotone mean (μ increasing in a covariate ⇒ rows with y > 0.5
370 // sit at higher η than rows with y ≤ 0.5) trivially satisfies the
371 // `order_separated` test, so gating this detector on the logit link alone
372 // misclassifies well-behaved Beta fits as separated and forces a spurious
373 // inner-solve retreat at every smoothing-parameter seed (issue #499).
374 // Gate strictly on the Binomial response so only binary GLMs are screened.
375 if !matches!(response, ResponseFamily::Binomial) || link != LinkFunction::Logit || firth_active
376 {
377 return false;
378 }
379
380 // Separation-detection policy thresholds. Each is a heuristic cut-off, not
381 // a math identity: they decide when a binary-logit fit has drifted into the
382 // perfect/quasi-perfect separation regime and the inner solve must retreat.
383 //
384 // `ORDER_SEPARATION_ETA_GAP`: a strictly positive η-gap between the lowest
385 // η among y=1 rows and the highest among y=0 rows means the two classes
386 // are linearly separable on the linear predictor.
387 // `EXTREME_ETA`: |η| this large drives μ to within machine-ε of {0,1}.
388 // `SATURATION_FRACTION` / `SEVERE_SATURATION_FRACTION`: share of fitted μ
389 // pinned to the {0,1} boundary that flags (severe) saturation.
390 // `DEGENERATE_DEVIANCE_PER_SAMPLE` / `EXTREME_DEGENERATE_DEVIANCE_PER_SAMPLE`:
391 // near-zero per-sample deviance means the model fits the data perfectly.
392 // `EXTREME_BETA_NORM`: coefficient norm blow-up characteristic of the MLE
393 // escaping to infinity under separation.
394 // `WEIGHT_COLLAPSE_FRACTION`: share of working weights collapsed to ~0.
395 const ORDER_SEPARATION_ETA_GAP: f64 = 1e-3;
396 const EXTREME_ETA: f64 = 30.0;
397 const SATURATION_FRACTION: f64 = 0.98;
398 const SEVERE_SATURATION_FRACTION: f64 = 0.995;
399 const DEGENERATE_DEVIANCE_PER_SAMPLE: f64 = 1e-3;
400 const EXTREME_DEGENERATE_DEVIANCE_PER_SAMPLE: f64 = 1e-6;
401 const EXTREME_BETA_NORM: f64 = 1e4;
402 const WEIGHT_COLLAPSE_FRACTION: f64 = 0.98;
403
404 let n = y.len() as f64;
405 if n == 0.0 {
406 return false;
407 }
408
409 let max_abs_eta = summary.max_abs_eta;
410 let sat_fraction = {
411 const SAT_EPS: f64 = 1e-3;
412 finalmu
413 .iter()
414 .filter(|&&m| m <= SAT_EPS || m >= 1.0 - SAT_EPS)
415 .count() as f64
416 / n
417 };
418
419 let weight_collapse_fraction = {
420 const WEIGHT_EPS: f64 = 1e-8;
421 finalweights
422 .iter()
423 .filter(|&&w| w <= WEIGHT_EPS || !w.is_finite())
424 .count() as f64
425 / n
426 };
427
428 let beta_norm = summary.beta.as_ref().dot(summary.beta.as_ref()).sqrt();
429 let dev_per_sample = summary.state.deviance / n;
430
431 let mut has_pos = false;
432 let mut has_neg = false;
433 let mut min_eta_pos = f64::INFINITY;
434 let mut max_eta_neg = f64::NEG_INFINITY;
435 for (eta_i, &yi) in summary.state.eta.iter().zip(y.iter()) {
436 if yi > 0.5 {
437 has_pos = true;
438 if *eta_i < min_eta_pos {
439 min_eta_pos = *eta_i;
440 }
441 } else {
442 has_neg = true;
443 if *eta_i > max_eta_neg {
444 max_eta_neg = *eta_i;
445 }
446 }
447 }
448 let order_separated =
449 has_pos && has_neg && (min_eta_pos - max_eta_neg) > ORDER_SEPARATION_ETA_GAP;
450
451 let classic_signals = max_abs_eta > EXTREME_ETA
452 || sat_fraction > SATURATION_FRACTION
453 || dev_per_sample < DEGENERATE_DEVIANCE_PER_SAMPLE
454 || beta_norm > EXTREME_BETA_NORM;
455
456 if !has_penalty {
457 return classic_signals || order_separated;
458 }
459
460 let severe_saturation = sat_fraction > SEVERE_SATURATION_FRACTION && max_abs_eta > EXTREME_ETA;
461 let weights_collapsed = weight_collapse_fraction > WEIGHT_COLLAPSE_FRACTION;
462 let dev_extremely_small = dev_per_sample < EXTREME_DEGENERATE_DEVIANCE_PER_SAMPLE;
463
464 order_separated || severe_saturation || weights_collapsed || dev_extremely_small
465}
466
467/// Stack λ-weighted penalty roots from canonical penalties into a single
468/// `total_rank × p` matrix for PIRLS. Each block-local root is embedded
469/// into the full column space on-the-fly.
470pub(super) fn stack_lambdaweighted_penalty_root_canonical(
471 penalties: &[gam_terms::construction::CanonicalPenalty],
472 lambdas: &[f64],
473 p: usize,
474) -> Array2<f64> {
475 let totalrows: usize = penalties.iter().map(|cp| cp.rank()).sum();
476 if totalrows == 0 {
477 return Array2::zeros((0, p));
478 }
479 let mut e = Array2::<f64>::zeros((totalrows, p));
480 let mut row_start = 0usize;
481 for (k, cp) in penalties.iter().enumerate() {
482 let rows = cp.rank();
483 if rows == 0 {
484 continue;
485 }
486 let scale = lambdas.get(k).copied().unwrap_or(0.0).max(0.0).sqrt();
487 if scale != 0.0 {
488 // Embed block-local root (rank × block_dim) into full width (rank × p).
489 let r = &cp.col_range;
490 for row in 0..rows {
491 for col in 0..cp.block_dim() {
492 e[[row_start + row, r.start + col]] = scale * cp.root[[row, col]];
493 }
494 }
495 }
496 row_start += rows;
497 }
498 e
499}
500
501pub(super) fn build_sparse_native_reparam_result(
502 base: ReparamResult,
503 penalties: &[gam_terms::construction::CanonicalPenalty],
504 lambdas: &[f64],
505 p: usize,
506) -> ReparamResult {
507 // Map the engine penalty back into identity (original) coordinates. The
508 // engine returns `s_transformed = Qsᵀ S Qs` (and `e_transformed = E Qs`)
509 // with `S = S_λ + shrinkage·P_range` already folded in (so it matches the
510 // reported `log_det`/`det1`). With the sparse-native `qs = I` we need that
511 // SAME penalty expressed in original coordinates: `S_orig = Qs S_transformed
512 // Qsᵀ`. Rebuilding `S_orig` from the bare lambda-weighted canonical sum
513 // would DROP the shrinkage ridge and desync the inner penalized Hessian from
514 // the penalty log-determinant the REML criterion uses for this fit — the
515 // cross-backend λ-selection divergence (#1266 class). Round-tripping the
516 // engine penalty through `Qs` keeps the inner solve, EDF, and REML logdet on
517 // one penalty.
518 let qs = &base.qs;
519 let s_orig = if qs.nrows() == p && qs.ncols() == base.s_transformed.nrows() {
520 // S_orig = Qs · S_transformed · Qsᵀ
521 let qs_s = fast_ab(qs, &base.s_transformed);
522 qs_s.dot(&qs.t())
523 } else {
524 // Degenerate fallback (engine produced no transform): use the bare
525 // lambda-weighted sum. Shrinkage is zero in this branch by construction.
526 let mut s_original = Array2::<f64>::zeros((p, p));
527 for (k, cp) in penalties.iter().enumerate() {
528 let lambda_k = lambdas.get(k).copied().unwrap_or(0.0);
529 if lambda_k != 0.0 {
530 cp.accumulate_weighted(&mut s_original, lambda_k);
531 }
532 }
533 s_original
534 };
535 // E_orig = E_transformed · Qsᵀ (so that E_origᵀ E_orig = S_orig and the EDF
536 // augmented system matches the inner Hessian).
537 let e_orig = if qs.nrows() == p && base.e_transformed.ncols() == qs.ncols() {
538 base.e_transformed.dot(&qs.t())
539 } else {
540 stack_lambdaweighted_penalty_root_canonical(penalties, lambdas, p)
541 };
542 let u_original = if base.u_truncated.nrows() == p {
543 fast_ab(&base.qs, &base.u_truncated)
544 } else {
545 Array2::<f64>::eye(p)
546 };
547 // In the sparse-native path, qs = I, so the penalties are already in the
548 // right coordinate frame. We keep them as-is in canonical_transformed.
549 let canonical_transformed: Vec<gam_terms::construction::CanonicalPenalty> = penalties.to_vec();
550 ReparamResult {
551 penalty_shrinkage_ridge: base.penalty_shrinkage_ridge,
552 s_transformed: s_orig,
553 log_det: base.log_det,
554 det1: base.det1,
555 qs: Array2::<f64>::eye(p),
556 canonical_transformed,
557 e_transformed: e_orig,
558 u_truncated: u_original,
559 }
560}
561
562pub(super) fn build_diagonal_penalty_from_kronecker(
563 kron_result: &KroneckerReparamResult,
564 lambdas: &[f64],
565) -> PirlsPenalty {
566 let d = kron_result.marginal_dims.len();
567 let p: usize = kron_result.marginal_dims.iter().copied().product();
568 let mut diag = Array1::<f64>::zeros(p);
569 let mut positive_indices = Vec::new();
570
571 const KRONECKER_STRUCTURAL_ZERO_TOL: f64 = 1e-12;
572 let mut multi_idx = vec![0usize; d];
573 let mut flat = 0usize;
574 loop {
575 let mut sigma = 0.0;
576 let mut structural_sigma = 0.0;
577 for k in 0..d {
578 let marginal_eigenvalue = kron_result.marginal_eigenvalues[k][multi_idx[k]];
579 structural_sigma += marginal_eigenvalue;
580 sigma += lambdas[k] * marginal_eigenvalue;
581 }
582 let joint_null = structural_sigma <= KRONECKER_STRUCTURAL_ZERO_TOL;
583 if kron_result.has_double_penalty && lambdas.len() > d && joint_null {
584 sigma += lambdas[d];
585 }
586 if structural_sigma > KRONECKER_STRUCTURAL_ZERO_TOL {
587 sigma += kron_result.penalty_shrinkage_ridge;
588 }
589 diag[flat] = sigma;
590 if sigma > 0.0 {
591 positive_indices.push(flat);
592 }
593 flat += 1;
594
595 let mut carry = true;
596 for dim in (0..d).rev() {
597 if carry {
598 multi_idx[dim] += 1;
599 if multi_idx[dim] < kron_result.marginal_dims[dim] {
600 carry = false;
601 } else {
602 multi_idx[dim] = 0;
603 }
604 }
605 }
606 if carry {
607 break;
608 }
609 }
610
611 PirlsPenalty::Diagonal {
612 diag,
613 positive_indices,
614 linear_shift: Array1::zeros(p),
615 constant_shift: 0.0,
616 prior_mean_target: Array1::zeros(p),
617 }
618}
619
620pub(super) fn canonical_prior_shift(
621 penalties: &[gam_terms::construction::CanonicalPenalty],
622 lambdas: &[f64],
623 p: usize,
624) -> (Array1<f64>, f64) {
625 let mut linear = Array1::<f64>::zeros(p);
626 let mut constant = 0.0;
627 for (idx, cp) in penalties.iter().enumerate() {
628 let Some(&lambda) = lambdas.get(idx) else {
629 continue;
630 };
631 if lambda == 0.0 {
632 continue;
633 }
634 linear += &cp.prior_linear_shift(lambda);
635 constant += cp.prior_constant_shift(lambda);
636 }
637 (linear, constant)
638}
639
640/// Aggregate prior-mean target across canonical penalty blocks: the sum of
641/// each block's `full_width_prior_mean()`. Used by the PIRLS solve sites
642/// that add a fixed stabilization ridge `δI` to the penalized Hessian — they
643/// must also add `δ · prior_mean_target` to the RHS to keep `β = μ` recovery
644/// exact when the data carries no information (X'WX = 0). Equivalent to
645/// `canonical_prior_shift` with all λ = 1 and dropping `S_k` from the linear
646/// piece (i.e., raw μ rather than `S_k μ`). Returned in the *original*
647/// coordinates; callers transform if needed.
648pub(super) fn canonical_prior_mean_aggregate(
649 penalties: &[gam_terms::construction::CanonicalPenalty],
650 p: usize,
651) -> Array1<f64> {
652 let mut mean = Array1::<f64>::zeros(p);
653 for cp in penalties {
654 mean += &cp.full_width_prior_mean();
655 }
656 mean
657}
658
/// Data-side inputs for one fixed-ρ P-IRLS fit: the design, response, offset,
/// prior weights, and optional precomputed Gram caches.
pub struct PirlsProblem<'a, X> {
    /// Design matrix. The fixed-ρ entry points require
    /// `X: Into<DesignMatrix> + Clone`.
    pub x: X,
    /// Per-row offset. The Gaussian cache below is built from
    /// `XᵀW(y − offset)`, i.e. the offset enters the linear predictor
    /// additively.
    pub offset: ArrayView1<'a, f64>,
    /// Response vector.
    pub y: ArrayView1<'a, f64>,
    /// Per-row prior weights.
    pub priorweights: ArrayView1<'a, f64>,
    /// Optional per-row covariate standard errors.
    /// NOTE(review): the consumer is outside this view; confirm the exact
    /// semantics (e.g. a measurement-error model) before relying on it.
    pub covariate_se: Option<ArrayView1<'a, f64>>,
    /// When set, the inner PLS solver reuses the precomputed `XᵀWX` and
    /// `XᵀW(y − offset)` in *original* coordinates. This avoids streaming the
    /// O(N·p²) GEMM and the O(N·p) matvec on every outer REML iteration.
    ///
    /// Valid only when all of the following hold:
    /// - the family is Gaussian with the Identity link;
    /// - prior weights are constant across outer iterations (always true in
    ///   the REML outer loop);
    /// - no Firth bias reduction;
    /// - no inequality or lower-bound constraints.
    ///
    /// These are the same conditions as the Gaussian-identity short-circuit in
    /// the fixed-ρ fitting path. The penalty `λ·S` is still added per-λ on top
    /// of the cached `XᵀWX`.
    pub gaussian_fixed_cache: Option<&'a GaussianFixedCache>,
    /// Frozen-weight first-Fisher-step data-fit Gram `XᵀWX` for a GLM
    /// design-moving ψ-trial (#1111 / #1033 mechanism (c)). It is expressed in
    /// *original* (conditioned `x_fit`) coordinates.
    ///
    /// When set, the iterative GLM P-IRLS takes its FIRST Fisher-scoring
    /// iteration's `XᵀWX` from this matrix instead of streaming the O(N·p²)
    /// weighted cross-product. Every later iteration restreams the true moving
    /// `W`, so the converged β̂ is unchanged.
    ///
    /// This is distinct from `gaussian_fixed_cache`, which is the
    /// Gaussian-identity converged-objective short-circuit. This field is the
    /// GLM first-step lane and never short-circuits the iteration count.
    pub glm_first_step_gram: Option<&'a Array2<f64>>,
}
687
688// GaussianFixedCache is defined in pls_solver.
689pub use super::pls_solver::GaussianFixedCache;
690
/// Penalty-side inputs for one fixed-ρ P-IRLS fit: the canonical penalty
/// blocks, optional cached reparameterization data, and coefficient
/// constraints.
pub struct PenaltyConfig<'a> {
    /// Block-local canonical penalties with precomputed roots and spectral data.
    /// This is the single canonical penalty representation: no full-width
    /// `rank × p` roots are stored. When the reparameterization engine needs
    /// full-width roots, they are derived on the fly from these block-local
    /// roots.
    pub canonical_penalties: &'a [gam_terms::construction::CanonicalPenalty],
    /// Precomputed balanced penalty root. When `None`, the fit rebuilds it from
    /// `canonical_penalties`.
    pub balanced_penalty_root: Option<&'a Array2<f64>>,
    /// Optional cached reparameterization invariant.
    /// NOTE(review): the consumer is outside this view. Presumably it lets the
    /// reparameterization engine skip λ-independent work; confirm.
    pub reparam_invariant: Option<&'a gam_terms::construction::ReparamInvariant>,
    /// Coefficient dimension. The λ-weighted penalty sum is assembled as a
    /// `p × p` matrix.
    pub p: usize,
    /// Optional per-coefficient lower bounds, presumably in original
    /// coordinates. The fit transforms them into the working basis and merges
    /// them with `linear_constraints_original`.
    pub coefficient_lower_bounds: Option<&'a Array1<f64>>,
    /// Optional linear inequality constraints in original coordinates. The fit
    /// transforms them into the working basis.
    pub linear_constraints_original: Option<&'a LinearInequalityConstraints>,
    /// Relative shrinkage floor for eigenvalues of the penalized block.
    /// If `Some(epsilon)`, a ρ-independent ridge of
    /// `epsilon * max_balanced_eigenvalue` is added. This prevents
    /// barely-penalized directions from causing pathological non-Gaussianity in
    /// the posterior. Typical value: `1e-6`. `None` disables it.
    pub penalty_shrinkage_floor: Option<f64>,
    /// When set, the penalties have Kronecker (tensor-product) structure.
    /// The reparameterization engine then uses the factored
    /// `Qs = U_1 ⊗ ... ⊗ U_d` instead of eigendecomposing the full `p × p`
    /// balanced penalty.
    pub kronecker_factored: Option<&'a gam_terms::basis::KroneckerFactoredBasis>,
}
712
713/// P-IRLS solver that follows mgcv's architecture exactly
714///
715/// This function implements the complete algorithm from mgcv's gam.fit3 function
716/// for fitting a GAM model with a fixed set of smoothing parameters:
717///
718/// - Perform stable reparameterization ONCE at the beginning (mgcv's gam.reparam)
719/// - Transform the design matrix into this stable basis
720/// - Extract a single penalty square root from the transformed penalty
721/// - Run the P-IRLS loop entirely in the transformed basis
722/// - Transform the coefficients back to the original basis only when returning
723/// - Reuse a cached balanced penalty root when available to avoid repeated eigendecompositions
724///
725/// This architecture ensures optimal numerical stability throughout the entire
726/// fitting process by working in a well-conditioned parameter space.
727pub fn fit_model_for_fixed_rho<'a, X: Into<DesignMatrix> + Clone>(
728 rho: LogSmoothingParamsView<'_>,
729 problem: PirlsProblem<'a, X>,
730 penalty: PenaltyConfig<'_>,
731 config: &PirlsConfig,
732 warm_start_beta: Option<&Coefficients>,
733) -> Result<(PirlsResult, WorkingModelPirlsResult), EstimationError> {
734 fit_model_for_fixed_rho_with_adaptive_kkt(
735 rho,
736 problem,
737 penalty,
738 config,
739 warm_start_beta,
740 None,
741 false,
742 )
743}
744
745/// `refine_dispersion_at_converged_eta`: when `true`, after the inner P-IRLS
746/// solve converges, re-estimate the family's estimated dispersion nuisance — the
747/// Gamma shape ν = 1/φ or the Beta precision φ — at the *converged* linear
748/// predictor and iterate the (β, dispersion) pair to its joint fixed point at the
749/// current λ (see the in-body comments at each refresh loop). This is ON only for
750/// the single final, reported fit at the REML-selected λ (#678 for Gamma, #769
751/// for Beta). It is deliberately OFF for every REML cost / sigma-point evaluation:
752/// re-profiling the dispersion against each trial λ's converged residuals would
753/// couple the scale to the smoothing parameter (a flat over-smoothed μ inflates
754/// the deviance ⇒ a smaller effective precision ⇒ a smaller `deviance/(2φ)` REML
755/// term), perversely rewarding over-smoothing and biasing λ selection. mgcv
756/// likewise estimates the scale at the converged fit, not inside the λ search.
757///
758/// The Gamma and Beta cases differ in what the re-solve buys. For Gamma the shape
759/// is a pure nuisance — β̂ is essentially scale-free — so the re-solve only keeps
760/// the reported dispersion and SEs self-consistent. For Beta the precision φ
761/// enters the *mean* score through the digamma terms
762/// `μ*ᵢ = ψ(μᵢφ) − ψ((1−μᵢ)φ)`, so a φ measured at the cold null predictor
763/// (μ ≈ 0.5) attenuates every slope toward zero; here the fixed point is
764/// load-bearing — it is what recovers the correct mean coefficients (the betareg
765/// alternating mean-fit ↔ φ-estimate scheme).
766pub(crate) fn fit_model_for_fixed_rho_with_adaptive_kkt<'a, X: Into<DesignMatrix> + Clone>(
767 rho: LogSmoothingParamsView<'_>,
768 problem: PirlsProblem<'a, X>,
769 penalty: PenaltyConfig<'_>,
770 config: &PirlsConfig,
771 warm_start_beta: Option<&Coefficients>,
772 adaptive_kkt_tolerance: Option<AdaptiveKktTolerance>,
773 refine_dispersion_at_converged_eta: bool,
774) -> Result<(PirlsResult, WorkingModelPirlsResult), EstimationError> {
775 let PirlsProblem {
776 x,
777 offset,
778 y,
779 priorweights,
780 covariate_se,
781 gaussian_fixed_cache,
782 glm_first_step_gram,
783 } = problem;
784 let quadctx = crate::quadrature::QuadratureContext::new();
785 // gam#1379 — finite-ceiling λ = exp(ρ). When the outer REML / spatial-κ
786 // optimizer drives a redundant penalty direction's log-λ past ~709 (it does
787 // so deterministically on 1-D `matern(x)` / `bs="gp"` data whose kernel
788 // already controls the smoothness an operator block also penalizes, so REML
789 // wants λ → ∞), `exp(ρ)` overflows to `+∞`. A literal `+∞` λ then poisons
790 // every downstream consumer that forms `λ · S`: the range-penalty block
791 // assembled as `Σ λ_k S_k` hits `∞ · 0 = NaN` and the eigensolve aborts, and
792 // the final fit-result validation rejects the non-finite stored λ outright.
793 // `exp(709.78) ≈ 1.8e308` is already the largest finite f64; capping log-λ at
794 // a value whose `exp` stays finite pins the over-penalized direction exactly
795 // as hard as `+∞` would for every finite-arithmetic consumer (the penalized
796 // block is numerically a hard constraint at λ this large) while keeping
797 // `λ · 0 = 0`. Ordinary finite λ are untouched, so non-degenerate fits and
798 // their recorded λ̂ are bit-identical. `ln(1e300) ≈ 690.78` keeps this in lock
799 // step with the post-exp λ ceiling (`1e300`) used by the reparam range-block
800 // assembly and the stored fit result, so a fully-smoothed direction carries
801 // the SAME finite λ everywhere it is consumed.
802 const LOG_LAMBDA_CEILING: f64 = 690.0;
803 let lambdas = rho.mapv(|r| {
804 if r.is_nan() {
805 r
806 } else {
807 r.min(LOG_LAMBDA_CEILING).exp()
808 }
809 });
810 let lambdas_slice = lambdas.as_slice_memory_order().ok_or_else(|| {
811 EstimationError::InvalidInput("non-contiguous lambda storage".to_string())
812 })?;
813
814 let likelihood = &config.likelihood;
815 let link_function = config.link_function();
816
817 use gam_terms::construction::{
818 EngineDims, create_balanced_penalty_root_from_canonical,
819 stable_reparameterization_engine_canonical,
820 };
821
822 let eb_cow: Cow<'_, Array2<f64>> = if let Some(precomputed) = penalty.balanced_penalty_root {
823 Cow::Borrowed(precomputed)
824 } else {
825 Cow::Owned(create_balanced_penalty_root_from_canonical(
826 penalty.canonical_penalties,
827 penalty.p,
828 )?)
829 };
830 let eb: &Array2<f64> = eb_cow.as_ref();
831
832 // Build a cheap weighted penalty sum for the sparse-native decision
833 // WITHOUT running the expensive eigendecomposition engine.
834 // The full reparameterization is deferred until we know which path we need.
835 let cheap_s_lambda: Option<Array2<f64>> = if penalty.kronecker_factored.is_none() {
836 let mut s = Array2::<f64>::zeros((penalty.p, penalty.p));
837 for (k, cp) in penalty.canonical_penalties.iter().enumerate() {
838 let lam = lambdas_slice.get(k).copied().unwrap_or(0.0);
839 if lam != 0.0 {
840 cp.accumulate_weighted(&mut s, lam);
841 }
842 }
843 Some(s)
844 } else {
845 None
846 };
847 let kronecker_runtime = if let Some(kron) = penalty.kronecker_factored {
848 // The marginal eigensystems and reparameterized marginals depend only on
849 // the fixed marginal designs/penalties, not on λ = exp(ρ). Memoize them
850 // once per fit so each outer REML iterate reuses the eigendecomposition
851 // instead of recomputing `eigh()` + `B_k·U_k` every call; only the cheap
852 // λ-grid logdet/derivative sweep is redone here. Bit-identical to the
853 // unmemoized engine.
854 let invariant = kron.invariant_structure()?;
855 let kron_result = gam_terms::construction::kronecker_reparameterization_engine_with_invariant(
856 invariant.as_ref(),
857 &kron.marginal_dims,
858 lambdas_slice,
859 kron.has_double_penalty,
860 penalty.penalty_shrinkage_floor,
861 )?;
862 let transform = Arc::new(KroneckerQsTransform::new(&kron_result));
863 let penalty_diag = build_diagonal_penalty_from_kronecker(&kron_result, lambdas_slice);
864 Some((kron_result, transform, penalty_diag))
865 } else {
866 None
867 };
868 // Constraint transformation is deferred until after the sparse-native
869 // decision, because the dense reparameterization engine (which provides Qs)
870 // is now run lazily. Kronecker constraints can be built eagerly since
871 // the Kronecker transform is already available.
872 let kronecker_constraints = if let Some((_, transform, _)) = kronecker_runtime.as_ref() {
873 let tb = build_transformed_lower_bound_constraints_with_transform(
874 &WorkingReparamTransform::Kronecker(Arc::clone(transform)),
875 penalty.coefficient_lower_bounds,
876 );
877 let tl = build_transformed_linear_constraints_with_transform(
878 &WorkingReparamTransform::Kronecker(Arc::clone(transform)),
879 penalty.linear_constraints_original,
880 );
881 Some(merge_linear_constraints(tb, tl))
882 } else {
883 None
884 };
885
886 let x_original: DesignMatrix = x.into();
887 // Auto-detect sparse structure in dense designs so the sparse-native path
888 // can engage for structurally sparse models that happen to be stored dense.
889 let x_original = {
890 let auto_sparse = x_original
891 .as_dense()
892 .and_then(|dense| sparse_from_denseview(dense.view()));
893 auto_sparse.unwrap_or(x_original)
894 };
895 let ebrows = eb.nrows();
896 let erows = if let Some((_, _, penalty_diag)) = kronecker_runtime.as_ref() {
897 penalty_diag.rank()
898 } else {
899 // Compute penalty root rank cheaply from canonical penalties.
900 penalty
901 .canonical_penalties
902 .iter()
903 .map(|cp| cp.rank())
904 .sum::<usize>()
905 };
906 let mut workspace = PirlsWorkspace::new(x_original.nrows(), x_original.ncols(), ebrows, erows);
907 let solver_decision = if let Some((_, _, _)) = kronecker_runtime.as_ref() {
908 SparsePirlsDecision {
909 path: PirlsLinearSolvePath::DenseTransformed,
910 reason: "kronecker_runtime",
911 p: x_original.ncols(),
912 nnz_x: 0,
913 nnz_xtwx_symbolic: None,
914 nnz_s_lambda: 0,
915 nnz_h_est: None,
916 density_h_est: None,
917 }
918 } else {
919 should_use_sparse_native_pirls(
920 &mut workspace,
921 &x_original,
922 cheap_s_lambda
923 .as_ref()
924 .expect("cheap_s_lambda should be present outside Kronecker path"),
925 penalty.coefficient_lower_bounds,
926 penalty.linear_constraints_original,
927 )
928 };
929 solver_decision.log_once();
930
931 let use_sparse_native = matches!(solver_decision.path, PirlsLinearSolvePath::SparseNative);
932
933 // Run the eigendecomposition engine for the dense-transformed path. The
934 // sparse-native path also needs it, but only to obtain a penalty that is
935 // *consistent with the REML penalty log-determinant it reports* — see the
936 // sparse-native `reparam` below. The dense path keeps `qs ≠ I`; the
937 // sparse-native path discards `qs` (identity coords) and reuses only the
938 // shrinkage-folded `s_transformed`/`e_transformed`.
939 let dense_reparam_result = if !use_sparse_native && penalty.kronecker_factored.is_none() {
940 Some(stable_reparameterization_engine_canonical(
941 penalty.canonical_penalties,
942 lambdas_slice,
943 EngineDims::new(penalty.p, penalty.canonical_penalties.len()),
944 penalty.reparam_invariant,
945 penalty.penalty_shrinkage_floor,
946 )?)
947 } else {
948 None
949 };
950 // Sparse-native reparam result, in identity (original) coordinates with the
951 // penalty shrinkage floor folded in. This MUST drive the inner penalized
952 // solve too: when `penalty_shrinkage_floor` is active (default `Some(1e-6)`)
953 // the dense engine adds `shrinkage·P_range` to every penalized range
954 // direction of `S_λ` and rebuilds `s_transformed = EᵀE` from the floored
955 // roots, so `base.log_det` (the REML penalty pseudo-logdet) is the
956 // determinant of `S_λ + shrinkage·P_range`, NOT of the bare `S_λ`. Building
957 // the inner Hessian from an UN-shrunk `S_λ` (the previous behaviour, via the
958 // `cheap_s_lambda` row-sum) while reporting the shrunk `log_det` made the
959 // sparse-native REML surface internally inconsistent — the penalty-logdet
960 // term and the inner H / EDF / β̂ lived on different penalties — which biased
961 // λ-selection relative to the dense and Kronecker backends for the SAME
962 // model (the #1266 cross-backend divergence class). Reusing the engine's
963 // shrinkage-folded penalty here makes all three backends solve the same
964 // penalized objective.
965 let sparse_native_reparam = if use_sparse_native && penalty.kronecker_factored.is_none() {
966 let base = stable_reparameterization_engine_canonical(
967 penalty.canonical_penalties,
968 lambdas_slice,
969 EngineDims::new(penalty.p, penalty.canonical_penalties.len()),
970 penalty.reparam_invariant,
971 penalty.penalty_shrinkage_floor,
972 )?;
973 Some(build_sparse_native_reparam_result(
974 base,
975 penalty.canonical_penalties,
976 lambdas_slice,
977 penalty.p,
978 ))
979 } else {
980 None
981 };
982 let qs_arc = dense_reparam_result
983 .as_ref()
984 .map(|reparam_result| Arc::new(reparam_result.qs.clone()));
985 let transform_active = if let Some((_, transform, _)) = kronecker_runtime.as_ref() {
986 Some(WorkingReparamTransform::Kronecker(Arc::clone(transform)))
987 } else if use_sparse_native {
988 None
989 } else {
990 Some(WorkingReparamTransform::Dense(Arc::clone(
991 qs_arc
992 .as_ref()
993 .expect("dense Qs should exist for non-Kronecker transformed path"),
994 )))
995 };
996 let mut penalty_active = if let Some((_, _, penalty_diag)) = kronecker_runtime.as_ref() {
997 penalty_diag.clone()
998 } else if use_sparse_native {
999 // Sparse-native inner penalty in original (identity) coordinates. Use
1000 // the shrinkage-folded `s_transformed`/`e_transformed` from
1001 // `sparse_native_reparam` so the inner penalized Hessian
1002 // `H = XᵀWX + S` matches the penalty whose log-determinant the REML
1003 // criterion reports for this fit (`base.log_det`). Falling back to the
1004 // bare lambda-weighted sum here (the prior behaviour) omitted the
1005 // `penalty_shrinkage_floor` ridge and desynced the inner solve from the
1006 // REML logdet, biasing λ-selection vs the dense/Kronecker backends.
1007 let sparse_reparam = sparse_native_reparam
1008 .as_ref()
1009 .expect("sparse_native_reparam should be present for sparse-native path");
1010 PirlsPenalty::Dense {
1011 s_transformed: sparse_reparam.s_transformed.clone(),
1012 e_transformed: sparse_reparam.e_transformed.clone(),
1013 linear_shift: Array1::zeros(penalty.p),
1014 constant_shift: 0.0,
1015 prior_mean_target: Array1::zeros(penalty.p),
1016 }
1017 } else {
1018 let dense = dense_reparam_result
1019 .as_ref()
1020 .expect("dense reparam result should be present outside Kronecker path");
1021 PirlsPenalty::Dense {
1022 s_transformed: dense.s_transformed.clone(),
1023 e_transformed: dense.e_transformed.clone(),
1024 linear_shift: Array1::zeros(penalty.p),
1025 constant_shift: 0.0,
1026 prior_mean_target: Array1::zeros(penalty.p),
1027 }
1028 };
1029 let (shift_original, shift_constant) =
1030 canonical_prior_shift(penalty.canonical_penalties, lambdas_slice, penalty.p);
1031 let shift_active = transform_active
1032 .as_ref()
1033 .map(|transform| transform.apply_transpose(&shift_original))
1034 .unwrap_or(shift_original);
1035 let prior_mean_original =
1036 canonical_prior_mean_aggregate(penalty.canonical_penalties, penalty.p);
1037 let prior_mean_active = transform_active
1038 .as_ref()
1039 .map(|transform| transform.apply_transpose(&prior_mean_original))
1040 .unwrap_or(prior_mean_original);
1041 attach_penalty_shift(
1042 &mut penalty_active,
1043 shift_active,
1044 shift_constant,
1045 prior_mean_active,
1046 );
1047 // Build transformed constraints now that dense_reparam_result is available.
1048 let linear_constraints = if let Some(kc) = kronecker_constraints {
1049 kc
1050 } else if let Some(reparam) = dense_reparam_result.as_ref() {
1051 let tb = build_transformed_lower_bound_constraints(
1052 &reparam.qs,
1053 penalty.coefficient_lower_bounds,
1054 );
1055 let tl =
1056 build_transformed_linear_constraints(&reparam.qs, penalty.linear_constraints_original);
1057 merge_linear_constraints(tb, tl)
1058 } else {
1059 // Sparse-native without dense reparam: constraints stay in original
1060 // coordinates (identity Qs). Use an identity matrix of appropriate size.
1061 let p = penalty.p;
1062 let qs_identity = Array2::<f64>::eye(p);
1063 let tb = build_transformed_lower_bound_constraints(
1064 &qs_identity,
1065 penalty.coefficient_lower_bounds,
1066 );
1067 let tl =
1068 build_transformed_linear_constraints(&qs_identity, penalty.linear_constraints_original);
1069 merge_linear_constraints(tb, tl)
1070 };
1071
1072 let coordinate_frame = if use_sparse_native {
1073 PirlsCoordinateFrame::OriginalSparseNative
1074 } else {
1075 PirlsCoordinateFrame::TransformedQs
1076 };
1077 let materialize_final_reparam_result = || -> Result<ReparamResult, EstimationError> {
1078 if let Some((kron_result, _, _)) = kronecker_runtime.as_ref() {
1079 let rs_list: Vec<Array2<f64>> = penalty
1080 .canonical_penalties
1081 .iter()
1082 .map(|cp| cp.full_width_root())
1083 .collect();
1084 kron_result.materialize_dense_artifact_result(&rs_list, lambdas_slice, penalty.p)
1085 } else if use_sparse_native {
1086 // Sparse-native path: reuse the engine result already computed for
1087 // `penalty_active` (with the shrinkage floor folded in and mapped to
1088 // identity coordinates). This is both correct — the REML
1089 // log-determinant now matches the penalty the inner solve used — and
1090 // cheaper, since the eigendecomposition is no longer run twice.
1091 Ok(sparse_native_reparam
1092 .as_ref()
1093 .expect("sparse_native_reparam should be present for sparse-native path")
1094 .clone())
1095 } else {
1096 Ok(dense_reparam_result
1097 .as_ref()
1098 .expect("dense reparam result should be present outside Kronecker path")
1099 .clone())
1100 }
1101 };
1102
1103 // Stage 3.3-GI: GPU exact PLS dispatch — see pirls_host_dispatch::try_gaussian_pls_gpu.
1104 if let Some(result) = try_gaussian_pls_gpu(
1105 link_function,
1106 config,
1107 penalty.coefficient_lower_bounds,
1108 penalty.linear_constraints_original,
1109 gaussian_fixed_cache,
1110 &penalty_active,
1111 &qs_arc,
1112 &x_original,
1113 use_sparse_native,
1114 penalty.p,
1115 || materialize_final_reparam_result(),
1116 y,
1117 priorweights,
1118 offset,
1119 coordinate_frame,
1120 &linear_constraints,
1121 ) {
1122 return result;
1123 }
1124
1125 if matches!(link_function, LinkFunction::Identity) && linear_constraints.is_none() {
1126 // Gaussian-Identity zero-iteration exact solve. The unconstrained
1127 // penalized least-squares system is linear, so for an identity link a
1128 // single solve is the exact minimizer and no PIRLS iteration is needed.
1129 //
1130 // This shortcut is only valid in the *unconstrained* convex program.
1131 // When shape/box/linear inequality constraints are present (e.g. a
1132 // `shape=monotone_increasing` smooth, whose cumulative-sum box-reparam
1133 // bounds `γ_j ≥ 0` are folded into `linear_constraints` above), the
1134 // minimizer is the solution of an inequality-constrained QP, not the
1135 // plain normal-equations solve. Taking this branch then returns the
1136 // unconstrained β, which generically violates the constraints and is
1137 // rejected by the REML startup KKT gate (`enforce_constraint_kkt`),
1138 // aborting the whole fit. Gating on `linear_constraints.is_none()`
1139 // routes every constrained Identity fit to the iterative loop below,
1140 // which builds a feasible initial point and solves the exact QP via
1141 // the active-set solver — mirroring the gate already enforced on the
1142 // GPU Gaussian-PLS path in `try_gaussian_pls_gpu`.
1143 //
1144 // Apply the Gaussian-Identity fixed-data cache only when every
1145 // precondition for the short-circuit's exact reuse holds: the family
1146 // really is Gaussian (z = y), there is no Firth bias-reduction term,
1147 // no coefficient lower bounds, and no linear inequality constraints
1148 // — anything that would change the right-hand side or the system
1149 // beyond the additive penalty would invalidate the cache.
1150 let cache_eligible = gaussian_fixed_cache.is_some()
1151 && likelihood.spec.is_gaussian_identity()
1152 && !config.firth_bias_reduction
1153 && penalty.coefficient_lower_bounds.is_none()
1154 && penalty.linear_constraints_original.is_none();
1155 let cache_for_solve = if cache_eligible {
1156 gaussian_fixed_cache
1157 } else {
1158 None
1159 };
1160 let (pls_result, _) = solve_penalized_least_squares_implicit(
1161 &x_original,
1162 transform_active.as_ref(),
1163 y,
1164 priorweights,
1165 offset,
1166 &penalty_active,
1167 &mut workspace,
1168 y,
1169 link_function,
1170 cache_for_solve,
1171 )?;
1172
1173 let beta_transformed = pls_result.beta;
1174 let penalized_hessian = pls_result.penalized_hessian;
1175 let edf = pls_result.edf;
1176 let baseridge = pls_result.ridge_used;
1177
1178 let priorweights_owned = priorweights.to_owned();
1179 // eta = offset + X Qs beta (composed, no materialization) unless a
1180 // design-moving ψ tensor cache explicitly says the surface rows are a
1181 // stale reference. In that lane the Gaussian objective and gradient are
1182 // fully determined by (G, r, y'Wy), so applying `x_original` would both
1183 // reintroduce per-trial row work and evaluate the wrong ψ.
1184 let qbeta = transform_active
1185 .as_ref()
1186 .map(|transform| transform.apply(beta_transformed.as_ref()))
1187 .unwrap_or_else(|| beta_transformed.as_ref().clone());
1188 let stale_row_cache = cache_for_solve.filter(|cache| cache.row_prediction_is_stale);
1189 let (final_eta, finalmu, finalz, gradient_data, deviance, log_likelihood, max_abs_eta) =
1190 if let Some(cache) = stale_row_cache {
1191 let final_eta = offset.to_owned();
1192 let finalmu = final_eta.clone();
1193 let finalz = y.to_owned();
1194 let mut grad_orig = cache.xtwx_orig.dot(&qbeta);
1195 grad_orig -= &cache.xtwy_orig;
1196 let gradient_data = transform_active
1197 .as_ref()
1198 .map(|transform| transform.apply_transpose(&grad_orig))
1199 .unwrap_or(grad_orig);
1200 let weighted_rss = (cache.centered_weighted_y_sq
1201 - 2.0 * qbeta.dot(&cache.xtwy_orig)
1202 + qbeta.dot(&cache.xtwx_orig.dot(&qbeta)))
1203 .max(0.0);
1204 let phi = likelihood.scale.fixed_phi().unwrap_or(1.0);
1205 let deviance = if phi.is_finite() && phi > 0.0 {
1206 weighted_rss / phi
1207 } else {
1208 f64::NAN
1209 };
1210 let log_likelihood = -0.5 * deviance;
1211 let max_abs_eta = inf_norm(finalmu.iter().copied());
1212 (
1213 final_eta,
1214 finalmu,
1215 finalz,
1216 gradient_data,
1217 deviance,
1218 log_likelihood,
1219 max_abs_eta,
1220 )
1221 } else {
1222 let mut eta = offset.to_owned();
1223 eta += &x_original.apply(&qbeta);
1224 let final_eta = eta.clone();
1225 let finalmu = eta.clone();
1226 let finalz = y.to_owned();
1227
1228 let mut weighted_residual = finalmu.clone();
1229 weighted_residual -= &finalz;
1230 weighted_residual *= &priorweights_owned;
1231 // gradient = Qs^T X^T (w * residual) (composed)
1232 let xt_wr = x_original.apply_transpose(&weighted_residual);
1233 let gradient_data = transform_active
1234 .as_ref()
1235 .map(|transform| transform.apply_transpose(&xt_wr))
1236 .unwrap_or(xt_wr);
1237 let deviance = calculate_deviance(y, &finalmu, likelihood, priorweights);
1238 let log_likelihood = calculate_loglikelihood_omitting_constants(
1239 y,
1240 &finalmu,
1241 likelihood,
1242 priorweights,
1243 );
1244 let max_abs_eta = inf_norm(finalmu.iter().copied());
1245 (
1246 final_eta,
1247 finalmu,
1248 finalz,
1249 gradient_data,
1250 deviance,
1251 log_likelihood,
1252 max_abs_eta,
1253 )
1254 };
1255 let score_norm = array1_l2_norm(&gradient_data);
1256 let s_beta = penalty_active.shifted_gradient(beta_transformed.as_ref());
1257 let s_beta_norm = array1_l2_norm(&s_beta);
1258 let mut gradient = gradient_data;
1259 gradient += &s_beta;
1260 let mut penalty_term = penalty_active.shifted_quadratic(beta_transformed.as_ref());
1261 let ridge_used = baseridge;
1262 let stabilizedhessian = if ridge_used > 0.0 {
1263 penalized_hessian
1264 .addridge(ridge_used)
1265 .map_err(|e| EstimationError::InvalidInput(format!("ridge addition failed: {e}")))?
1266 } else {
1267 penalized_hessian.clone()
1268 };
1269 let mut ridge_grad_norm = 0.0;
1270 if ridge_used > 0.0 {
1271 let ridge_penalty =
1272 ridge_used * beta_transformed.as_ref().dot(beta_transformed.as_ref());
1273 penalty_term += ridge_penalty;
1274 gradient += &beta_transformed.as_ref().mapv(|v| ridge_used * v);
1275 ridge_grad_norm = ridge_used * array1_l2_norm(beta_transformed.as_ref());
1276 }
1277
1278 let gradient_norm = array1_l2_norm(&gradient);
1279 let working_state = WorkingState {
1280 eta: LinearPredictor::new(finalmu.clone()),
1281 gradient: gradient.clone(),
1282 hessian: penalized_hessian.clone(),
1283
1284 log_likelihood,
1285 deviance,
1286 penalty_term,
1287 firth: FirthDiagnostics::Inactive,
1288 ridge_used,
1289 hessian_curvature: HessianCurvatureKind::Fisher,
1290 gradient_natural_scale: score_norm + s_beta_norm + ridge_grad_norm,
1291 };
1292
1293 let zero_iter_penalized = deviance + penalty_term;
1294 let working_summary = WorkingModelPirlsResult {
1295 beta: beta_transformed.clone(),
1296 state: working_state,
1297 status: PirlsStatus::Converged,
1298 iterations: 1,
1299 lastgradient_norm: gradient_norm,
1300 last_deviance_change: 0.0,
1301 last_step_size: 1.0,
1302 last_step_halving: 0,
1303 max_abs_eta,
1304 constraint_kkt: linear_constraints.as_ref().map(|lin| {
1305 compute_constraint_kkt_diagnostics(beta_transformed.as_ref(), &gradient, lin)
1306 }),
1307 min_penalized_deviance: if zero_iter_penalized.is_finite() {
1308 zero_iter_penalized
1309 } else {
1310 f64::INFINITY
1311 },
1312 // Zero-iteration synthesis: no LM damping was exercised, so
1313 // hand the next solve the cold default.
1314 final_lm_lambda: 1e-6,
1315 // Zero-iteration synthesis: no LM gain ratio was measured.
1316 final_accept_rho: None,
1317 // Zero-iteration synthesis assembles the Hessian with prior
1318 // weights only; no observed-information re-evaluation has
1319 // happened. Label honestly as a Fisher-type surrogate so
1320 // outer Laplace consumers see the truth.
1321 exported_laplace_curvature: ExportedLaplaceCurvature::ExpectedInformationSurrogate,
1322 };
1323
1324 let (solve_c_array, solve_d_array, solve_dmu_deta, solve_d2mu_deta2, solve_d3mu_deta3) =
1325 computeworkingweight_derivatives_from_eta(
1326 &config.likelihood,
1327 &config.link_kind,
1328 &final_eta,
1329 priorweights_owned.view(),
1330 )?;
1331 let reparam_result = materialize_final_reparam_result()?;
1332 let qs_arc_final = Arc::new(reparam_result.qs.clone());
1333 let pirls_result = PirlsResult {
1334 likelihood: config.likelihood.clone(),
1335 beta_transformed,
1336 penalized_hessian_transformed: penalized_hessian,
1337 stabilizedhessian_transformed: stabilizedhessian,
1338 ridge_passport: RidgePassport::scaled_identity(
1339 ridge_used,
1340 RidgePolicy::explicit_stabilization_full(),
1341 ),
1342 ridge_used,
1343 deviance,
1344 edf,
1345 stable_penalty_term: penalty_term,
1346 firth: FirthDiagnostics::Inactive,
1347 finalweights: priorweights_owned.clone(),
1348 final_offset: offset.to_owned(),
1349 final_eta: final_eta.clone(),
1350 finalmu: finalmu.clone(),
1351 solveweights: priorweights_owned,
1352 solveworking_response: finalz.clone(),
1353 solvemu: finalmu.clone(),
1354 solve_dmu_deta,
1355 solve_d2mu_deta2,
1356 solve_d3mu_deta3,
1357 solve_c_array,
1358 solve_d_array,
1359 derivatives_unsupported: false,
1360 status: PirlsStatus::Converged,
1361 iteration: 1,
1362 max_abs_eta,
1363 lastgradient_norm: gradient_norm,
1364 gradient_natural_scale: score_norm + s_beta_norm + ridge_grad_norm,
1365 last_deviance_change: 0.0,
1366 last_step_halving: 0,
1367 hessian_curvature: HessianCurvatureKind::Fisher,
1368 exported_laplace_curvature: working_summary.exported_laplace_curvature.clone(),
1369 final_lm_lambda: working_summary.final_lm_lambda,
1370 final_accept_rho: working_summary.final_accept_rho,
1371 constraint_kkt: working_summary.constraint_kkt.clone(),
1372 linear_constraints_transformed: linear_constraints.clone(),
1373 reparam_result,
1374 x_transformed: make_reparam_operator(&x_original, &qs_arc_final, use_sparse_native),
1375 coordinate_frame,
1376 used_device: false,
1377 cache_compacted: false,
1378 min_penalized_deviance: working_summary.min_penalized_deviance,
1379 };
1380
1381 return Ok((pirls_result, working_summary));
1382 }
1383
1384 let x_original_for_result = x_original.clone();
1385 let mut working_model = GamWorkingModel::new(
1386 None, // No pre-materialized x_transformed: use implicit Qs composition
1387 x_original.clone(),
1388 coordinate_frame,
1389 offset,
1390 y,
1391 priorweights,
1392 penalty_active.clone(),
1393 workspace,
1394 config.likelihood.clone(),
1395 config.link_kind.clone(),
1396 // Inner Firth/Jeffreys activation must agree with the caller-requested
1397 // mode. The REML *outer* analytic derivative assembly only carries the
1398 // Jeffreys score/curvature term when `firth_bias_reduction` is set
1399 // (`reml_robust_jeffreys_link` returns `None` otherwise), so arming the
1400 // inner penalty unconditionally would converge the inner mode to the
1401 // Firth-penalized stationary point while the outer H/u/IFT stayed
1402 // non-Firth — the two would then disagree by exactly the Jeffreys
1403 // contribution (broken τ-τ Hessian-vs-FD and stationarity-cancellation
1404 // identities, #825). Gate on `firth_bias_reduction` so inner and outer
1405 // are the same objective.
1406 config.firth_bias_reduction
1407 && matches!(config.likelihood.spec.response, ResponseFamily::Binomial)
1408 && inverse_link_has_fisher_weight_jet(&config.link_kind),
1409 transform_active.clone(),
1410 quadctx,
1411 // #1111 / #1033 mechanism (c): frozen-W first-Fisher-step XᵀWX in the
1412 // original (conditioned x_fit) frame, served n-free on the first inner
1413 // iteration. Suppressed under Firth bias reduction, which shifts the
1414 // working response per iteration (the installer also gates Firth off).
1415 if config.firth_bias_reduction {
1416 None
1417 } else {
1418 glm_first_step_gram.cloned()
1419 },
1420 );
1421
1422 // Apply integrated (GHQ) likelihood if per-observation SE is provided.
1423 // This is used by the calibrator to coherently account for base prediction uncertainty.
1424 if let Some(se) = covariate_se {
1425 working_model = working_model.with_covariate_se(se.to_owned());
1426 }
1427
1428 let mut beta_guess_original = warm_start_beta
1429 .filter(|beta| beta.len() == penalty.p)
1430 .map(|beta| beta.to_owned())
1431 .unwrap_or_else(|| {
1432 Coefficients::new(default_beta_guess_external(
1433 penalty.p,
1434 link_function,
1435 y,
1436 priorweights,
1437 config.link_kind.mixture_state(),
1438 config.link_kind.sas_state(),
1439 ))
1440 });
1441 if let Some(lb) = penalty.coefficient_lower_bounds {
1442 project_coefficients_to_lower_bounds(&mut beta_guess_original.0, lb);
1443 }
1444 let initial_beta = transform_active
1445 .as_ref()
1446 .map(|transform| transform.apply_transpose(beta_guess_original.as_ref()))
1447 .unwrap_or_else(|| beta_guess_original.as_ref().clone());
1448 let initial_beta = if let Some(constraints) = linear_constraints.as_ref() {
1449 // Worst per-row *scaled* (geometric) slack of the current seed against the
1450 // constraint cone. Negative ⇒ the seed violates a row; ~0 ⇒ the seed sits
1451 // ON the boundary (for a homogeneous convex/concave second-difference
1452 // cone, `β = 0` — the unconstrained Gaussian seed — sits on EVERY row's
1453 // boundary, i.e. the cone vertex). Either way the seed must be pushed
1454 // strictly into the interior before P-IRLS starts.
1455 let mut min_scaled_slack = f64::INFINITY;
1456 for i in 0..constraints.a.nrows() {
1457 let norm = constraints.a.row(i).dot(&constraints.a.row(i)).sqrt();
1458 let inv = if norm > 0.0 { 1.0 / norm } else { 0.0 };
1459 let slack = (constraints.a.row(i).dot(&initial_beta) - constraints.b[i]) * inv;
1460 min_scaled_slack = min_scaled_slack.min(slack);
1461 }
1462 // Push the seed to the nearest STRICTLY-INTERIOR feasible point whenever
1463 // any row is tight or violated. A seed on the cone boundary (most acutely
1464 // the vertex `β = 0`) hands the inner active-set QP an all-rows-active
1465 // working set, where it stalls on a degenerate, non-stationary face — so
1466 // the fit silently diverges (or aborts in release) between a cold and a
1467 // warm warm-start cache (#873). A strictly-interior seed makes the QP's
1468 // initial active set empty; it then adds only the genuinely binding rows
1469 // and converges to the certified constrained optimum regardless of cache
1470 // state. The projection keeps the data-driven curvature of `initial_beta`
1471 // and falls back to the min-norm feasible point only if it cannot certify
1472 // a strictly-interior solution.
1473 //
1474 // The min-norm fallback (`feasible_point_for_linear_constraints`) is only
1475 // used for a NON-homogeneous cone (`b ≠ 0`), where it returns a genuine
1476 // interior-of-the-offset-polyhedron point. For a HOMOGENEOUS shape cone
1477 // (`b ≈ 0` — the convex/concave second-difference rows) that function
1478 // returns the minimum-norm feasible point `β = 0`, which is the cone
1479 // *vertex*: the exact all-rows-tight degenerate seed #873 is about. Taking
1480 // it would silently reintroduce the #873 pathology whenever the strict
1481 // projection rarely fails to certify. So for a homogeneous cone we skip the
1482 // vertex fallback entirely and prefer the data-driven `initial_beta`: it
1483 // violates at most *some* rows (a lower-dimensional, non-degenerate face the
1484 // inner active-set QP can recover from), strictly better than the vertex
1485 // where *every* row is simultaneously tight.
1486 let cone_is_homogeneous = constraints.b.iter().all(|v| v.abs() <= 1e-14);
1487 if min_scaled_slack < active_set::interior_seed_margin() {
1488 let projected =
1489 active_set::project_point_strictly_into_feasible_cone(&initial_beta, constraints)
1490 .or_else(|| {
1491 if cone_is_homogeneous {
1492 None
1493 } else {
1494 active_set::feasible_point_for_linear_constraints(
1495 constraints,
1496 initial_beta.len(),
1497 )
1498 }
1499 });
1500 projected.unwrap_or(initial_beta)
1501 } else {
1502 initial_beta
1503 }
1504 } else {
1505 initial_beta
1506 };
1507 // Inner P-IRLS Firth activation. The inner penalized objective must match
1508 // the objective the REML outer derivatives are assembled against: the outer
1509 // path carries the Jeffreys/Firth score+curvature only when the caller set
1510 // `firth_bias_reduction` (`reml_robust_jeffreys_link` is `None` otherwise),
1511 // so the inner Firth term is armed iff the caller requested it AND the link
1512 // exposes a Fisher-weight jet (#825). Forcing it on unconditionally desynced
1513 // the Firth-penalized inner mode from the non-Firth outer assembly.
1514 let firth_active = config.firth_bias_reduction
1515 && matches!(config.likelihood.spec.response, ResponseFamily::Binomial)
1516 && inverse_link_has_fisher_weight_jet(&config.link_kind);
1517 let base_max_step_halving = if firth_active { 60 } else { 30 };
1518 let options = WorkingModelPirlsOptions {
1519 // The Firth-penalized P-IRLS converges at the same iteration count as
1520 // the unpenalized fit — the Jeffreys term is a smooth, bounded addition
1521 // to a Newton system that is already well conditioned (the additional
1522 // per-iteration LM step-halving budget above absorbs the early-iteration
1523 // curvature change). Bumping the outer-iteration cap to mask a
1524 // mis-conditioned step would only hide non-convergence, so the cap stays
1525 // the caller's `max_iterations` and trips as a hard error if exceeded.
1526 max_iterations: config.max_iterations,
1527 convergence_tolerance: config.convergence_tolerance,
1528 adaptive_kkt_tolerance,
1529 // LM step-halving is a per-iteration damping retry budget; it is
1530 // independent of the total outer-iteration cap. Tying the two
1531 // together collapsed step halving to 3 under seed screening (where
1532 // max_iterations is intentionally capped low), turning recoverable
1533 // damping into spurious failures.
1534 max_step_halving: base_max_step_halving,
1535 min_step_size: if firth_active { 1e-12 } else { 1e-10 },
1536 firth_bias_reduction: firth_active,
1537 coefficient_lower_bounds: None,
1538 linear_constraints: linear_constraints.clone(),
1539 initial_lm_lambda: config.initial_lm_lambda,
1540 geodesic_acceleration: config.geodesic_acceleration,
1541 arrow_schur: config.arrow_schur.clone(),
1542 };
1543
1544 let mut iteration_logger = |info: &WorkingModelIterationInfo| {
1545 log::debug!(
1546 "[PIRLS] iter {:>3} | deviance {:.6e} | |grad| {:.3e} | step {:.3e} (halving {})",
1547 info.iteration,
1548 info.deviance,
1549 info.gradient_norm,
1550 info.step_size,
1551 info.step_halving
1552 );
1553 };
1554
1555 // Stage 3.3 GPU PIRLS-loop dispatch — see pirls_host_dispatch::try_pirls_loop_gpu.
1556 if let Some(result) = try_pirls_loop_gpu(
1557 config,
1558 &penalty_active,
1559 kronecker_runtime.is_none(),
1560 use_sparse_native,
1561 &linear_constraints,
1562 &x_original,
1563 &qs_arc,
1564 penalty.p,
1565 &x_original_for_result,
1566 || materialize_final_reparam_result(),
1567 y,
1568 priorweights,
1569 offset,
1570 &initial_beta,
1571 link_function,
1572 coordinate_frame,
1573 ) {
1574 return result;
1575 }
1576
1577 let mut working_summary = runworking_model_pirls(
1578 &mut working_model,
1579 Coefficients::new(initial_beta),
1580 &options,
1581 &mut iteration_logger,
1582 )?;
1583
1584 // ── Gamma dispersion: re-estimate the shape at the *converged* η (#678) ──
1585 //
1586 // The inner LM solve estimates the Gamma shape ν = 1/φ **once** from the
1587 // warm-start η and freezes it for the rest of the solve (see the
1588 // `gamma_shape_locked` doc on `GamWorkingModel`): holding ν fixed keeps the
1589 // product φ·λ — and hence the penalized argmin β̂ — a stationary LM target,
1590 // so the gain ratio compares one objective. That lock is correct *within* a
1591 // solve, but it pins ν to whatever η the solve started from. When the fit
1592 // cold-starts (the final dedicated fit at the converged ρ passes
1593 // `warm_start_beta = None`, and seed screening starts from a default guess),
1594 // that warm-start η has not yet captured the mean structure; the leftover
1595 // spread of μ inflates the Gamma deviance term `mean[y/μ − ln(y/μ) − 1]` and
1596 // biases ν **down** (φ up) by >2× whenever μ varies appreciably. The mean
1597 // surface still converges (β̂ is essentially scale-free here), but the frozen
1598 // ν that survives into `UnifiedFitResult::dispersion_phi()` — and from there
1599 // into every coefficient SE `Vb = H⁻¹·φ̂`, prediction interval, and
1600 // observation-noise interval — is the early, mean-spread-contaminated value.
1601 //
1602 // Fix: after the solve converges, re-estimate ν at the converged η. If it
1603 // moved, re-solve β (warm-started, ν held fixed at the refreshed value) and
1604 // repeat, driving the pair (β, ν) to their joint fixed point at the current
1605 // λ. At convergence the reported dispersion is the Gamma ML estimate at the
1606 // converged mean (mgcv's post-hoc Pearson/deviance scale), and the final
1607 // working state — `finalweights`, the penalized Hessian, the deviance, μ —
1608 // is rebuilt with that same ν, so `Vb = H⁻¹·φ̂` stays internally consistent.
1609 // Warm-started solves (every REML cost eval) already sit near the converged
1610 // η, so the first refresh check confirms ν and exits without a re-solve; the
1611 // added cost there is a single O(n) shape evaluation.
1612 if refine_dispersion_at_converged_eta
1613 && working_model.likelihood.scale.gamma_shape_is_estimated()
1614 {
1615 // A few passes suffice: the converged-η shape map is a strong
1616 // contraction (β̂ barely moves once the mean is captured), so cold
1617 // starts settle in 1–2 re-solves and warm starts in zero.
1618 const MAX_SHAPE_REFRESH: usize = 5;
1619 // Relative shape tolerance below which a re-solve cannot move any
1620 // reported quantity meaningfully (far under statistical resolution).
1621 const SHAPE_REFRESH_REL_TOL: f64 = 1e-4;
1622 for refresh_iter in 0..MAX_SHAPE_REFRESH {
1623 let refreshed_shape = super::estimate_gamma_shape_from_eta(
1624 y,
1625 working_summary.state.eta.as_ref(),
1626 priorweights,
1627 );
1628 let prior_shape = working_model.likelihood.gamma_shape().unwrap_or(1.0);
1629 let rel_change =
1630 (refreshed_shape - prior_shape).abs() / prior_shape.max(f64::MIN_POSITIVE);
1631 // Install the refreshed shape and hold it fixed for any re-solve so
1632 // the LM objective stays stationary (the lock is *re-armed*, not
1633 // released — the seed-from-warm-start branch in `update_with_curvature`
1634 // must not overwrite this deliberately chosen value). Because this
1635 // assignment evaluated the shape at the *current* converged η and no
1636 // re-solve follows it on the exit paths below, the reported shape
1637 // always equals `estimate_gamma_shape_from_eta(final_eta)` — the
1638 // self-consistency invariant the in-module Gamma unit test checks.
1639 working_model.likelihood = working_model
1640 .likelihood
1641 .clone()
1642 .with_gamma_shape(refreshed_shape);
1643 working_model.gamma_shape_locked = true;
1644 if rel_change <= SHAPE_REFRESH_REL_TOL {
1645 // Converged: the working-state buffers (weights, Hessian,
1646 // deviance) already reflect a shape within tolerance of
1647 // `refreshed_shape`, because the only way to reach here without
1648 // a re-solve is that the prior solve's shape already matched the
1649 // converged-η estimate. Nothing left to rebuild.
1650 break;
1651 }
1652 if refresh_iter + 1 == MAX_SHAPE_REFRESH {
1653 // Final allowed pass and the shape is still drifting (a
1654 // pathological non-contraction). Do NOT re-solve: re-solving
1655 // would advance `final_eta` past the η the just-installed shape
1656 // was evaluated at, breaking the stored-shape == estimate(final_eta)
1657 // invariant. Stopping here keeps the reported shape exactly the
1658 // ML estimate at the reported η; the residual weight/φ drift is
1659 // bounded by the last `rel_change` and never worse than the
1660 // pre-fix frozen-warm-start value.
1661 break;
1662 }
1663 // The shape moved: re-solve β at the corrected shape, warm-started
1664 // at the converged β, so the final working state is rebuilt with the
1665 // refreshed ν.
1666 working_summary = runworking_model_pirls(
1667 &mut working_model,
1668 working_summary.beta.clone(),
1669 &options,
1670 &mut iteration_logger,
1671 )?;
1672 }
1673 }
1674
1675 // ── Tweedie dispersion φ: re-estimate at the *converged* η (#771) ─────────
1676 //
1677 // Identical in spirit to the Gamma-shape refresh above: the inner LM solve
1678 // estimates φ **once** from the warm-start η and freezes it (the
1679 // `tweedie_phi_locked` lock), keeping the product φ·λ — and hence β̂ — a
1680 // stationary LM target. φ enters only the working weight `prior·μ^{2−p}/φ`
1681 // and not the working response, so (like the Gamma shape, and unlike the
1682 // Beta precision which couples through the digamma mean score) the mean
1683 // surface is essentially scale-free and β̂ barely moves when φ is corrected.
1684 // But the frozen warm-start φ is the value that survives into
1685 // `FitInference::dispersion` and the covariance `Vb = H⁻¹` (whose √φ scaling
1686 // lives in the weight); at a cold-started η ≈ 0 the Pearson residuals carry
1687 // the *marginal* spread of y, biasing the estimate. Re-estimating at the
1688 // converged η — re-solving β only if φ moved materially — drives (β, φ) to
1689 // their joint fixed point, so the reported φ is the converged-mean Pearson
1690 // estimate and the final weights/Hessian/SE are internally consistent with
1691 // it. Held OFF inside the REML λ search (the flag), φ is refreshed only at
1692 // the reported fit, so it cannot couple to the smoothing parameter.
1693 if refine_dispersion_at_converged_eta
1694 && working_model.likelihood.scale.tweedie_phi_is_estimated()
1695 {
1696 if let ResponseFamily::Tweedie { p } = working_model.likelihood.spec.response {
1697 // The converged-η Pearson map is a strong contraction (β̂ scale-free
1698 // here), so cold starts settle in 1–2 re-solves and warm starts in
1699 // zero.
1700 const MAX_PHI_REFRESH: usize = 5;
1701 // Relative φ tolerance below which a re-solve cannot move any reported
1702 // quantity meaningfully (far under statistical resolution).
1703 const PHI_REFRESH_REL_TOL: f64 = 1e-4;
1704 for refresh_iter in 0..MAX_PHI_REFRESH {
1705 let refreshed_phi = super::estimate_tweedie_phi_from_eta(
1706 y,
1707 working_summary.state.eta.as_ref(),
1708 priorweights,
1709 p,
1710 );
1711 let prior_phi = working_model.likelihood.fixed_phi().unwrap_or(1.0);
1712 let rel_change =
1713 (refreshed_phi - prior_phi).abs() / prior_phi.max(f64::MIN_POSITIVE);
1714 // Install the refreshed φ (the scale metadata the working weight
1715 // reads via `fixed_phi()`) and re-arm the lock so a following
1716 // re-solve does not overwrite this converged-η value. Because the
1717 // exit paths below evaluate φ at the *current* η with no following
1718 // re-solve, the reported φ always equals
1719 // `estimate_tweedie_phi_from_eta(final_eta)`.
1720 working_model.likelihood = working_model
1721 .likelihood
1722 .clone()
1723 .with_tweedie_phi(refreshed_phi);
1724 working_model.tweedie_phi_locked = true;
1725 if rel_change <= PHI_REFRESH_REL_TOL {
1726 // Converged: the working state already reflects a φ within
1727 // tolerance of `refreshed_phi`. Nothing left to rebuild.
1728 break;
1729 }
1730 if refresh_iter + 1 == MAX_PHI_REFRESH {
1731 // Final allowed pass and φ is still drifting. Do NOT re-solve:
1732 // re-solving would advance η past the point φ was evaluated at,
1733 // breaking the stored-φ == estimate(final_eta) invariant.
1734 break;
1735 }
1736 // φ moved materially: re-solve β at the corrected φ, warm-started
1737 // at the converged β, so the final working state is rebuilt with
1738 // the refreshed φ.
1739 working_summary = runworking_model_pirls(
1740 &mut working_model,
1741 working_summary.beta.clone(),
1742 &options,
1743 &mut iteration_logger,
1744 )?;
1745 }
1746 }
1747 }
1748
1749 // ── Beta precision φ: re-estimate at the *converged* η and drive (β, φ) to
1750 // their joint fixed point (#769) ──────────────────────────────────────
1751 //
1752 // Like the Gamma shape above, the inner LM solve estimates φ **once** from
1753 // the warm-start η and freezes it for the rest of the solve (the
1754 // `beta_phi_locked` doc on `GamWorkingModel`): holding φ fixed keeps the
1755 // penalized argmin β̂ a stationary LM target so the gain ratio compares one
1756 // objective. But that lock pins φ to whatever η the solve started from, and
1757 // for the final dedicated fit at the converged ρ the warm-start is the cold
1758 // default guess (η ≈ 0, μ ≈ 0.5 everywhere). At the null predictor the
1759 // Pearson residuals `(y−μ)²/(μ(1−μ))` capture the full *marginal* spread of
1760 // y rather than its *conditional* spread, so the moment estimator
1761 // `1+φ = Σw / Σ w·s` returns a precision far too small (≈3 when the truth is
1762 // ≈20 here).
1763 //
1764 // Crucially — and unlike the Gamma shape — φ does **not** factor out of the
1765 // Beta mean score. With the logit link the score for β is
1766 // ∂ℓ/∂β = φ · Σᵢ xᵢ (y*ᵢ − μ*ᵢ), y*ᵢ = logit(yᵢ),
1767 // μ*ᵢ = ψ(μᵢφ) − ψ((1−μᵢ)φ),
1768 // so the root β̂ depends on φ through the digamma terms. A φ that is too
1769 // small shrinks every fitted coefficient toward zero. So this refresh is not
1770 // cosmetic (as it is for Gamma): the re-solve is what *recovers the mean*.
1771 //
1772 // Fix: after the cold solve converges, re-estimate φ at the converged η,
1773 // re-solve β at the corrected φ (warm-started), and repeat. This is the
1774 // betareg alternating mean-fit ↔ φ-estimate scheme; the moment estimator is
1775 // a strong contraction once the mean has any structure, so the pair settles
1776 // in a handful of passes. Held OFF inside the REML λ search (see the flag
1777 // doc), φ is refreshed only here at the reported fit, so it cannot couple to
1778 // the smoothing parameter and reward over-smoothing. As with Gamma, every
1779 // exit path installs φ evaluated at the *current* η with no following
1780 // re-solve, so the reported φ (which flows into `EstimatedBetaPhi`, the
1781 // embedded `Beta { phi }`, `dispersion`, and every SE) always equals
1782 // `estimate_beta_phi_from_eta(final_eta)`.
1783 if refine_dispersion_at_converged_eta && working_model.likelihood.scale.beta_phi_is_estimated()
1784 {
1785 // The mean moves between passes (φ feeds back through the digamma
1786 // score), so allow a few more passes than the scale-free Gamma case;
1787 // the contraction is fast and warm-started re-solves are cheap.
1788 const MAX_PHI_REFRESH: usize = 30;
1789 // Relative φ tolerance below which a re-solve cannot move β̂ — and hence
1790 // any reported quantity — by a statistically meaningful amount.
1791 const PHI_REFRESH_REL_TOL: f64 = 1e-4;
1792 for refresh_iter in 0..MAX_PHI_REFRESH {
1793 let refreshed_phi = super::estimate_beta_phi_from_eta(
1794 y,
1795 working_summary.state.eta.as_ref(),
1796 priorweights,
1797 );
1798 let prior_phi = working_model.likelihood.fixed_phi().unwrap_or(1.0);
1799 let rel_change = (refreshed_phi - prior_phi).abs() / prior_phi.max(f64::MIN_POSITIVE);
1800 // Install the refreshed φ (updates BOTH the `Beta { phi }` family
1801 // variant every weight/deviance expression reads and the
1802 // `EstimatedBetaPhi` scale metadata) and re-arm the lock so a
1803 // following re-solve's `update_with_curvature` does not overwrite
1804 // this deliberately chosen value with a fresh cold estimate.
1805 working_model.likelihood = working_model
1806 .likelihood
1807 .clone()
1808 .with_beta_phi(refreshed_phi);
1809 working_model.beta_phi_locked = true;
1810 if rel_change <= PHI_REFRESH_REL_TOL {
1811 // Converged: the just-installed φ matches (to tolerance) the φ
1812 // the current working state was solved at, so β̂, the weights,
1813 // the Hessian and the deviance are already self-consistent with
1814 // the reported φ. Nothing left to rebuild.
1815 break;
1816 }
1817 if refresh_iter + 1 == MAX_PHI_REFRESH {
1818 // Final allowed pass and φ is still drifting. Do NOT re-solve:
1819 // re-solving would advance η past the point the just-installed φ
1820 // was evaluated at, breaking the stored-φ == estimate(final_eta)
1821 // invariant. Stop here so the reported φ is exactly the moment
1822 // estimate at the reported η.
1823 break;
1824 }
1825 // φ moved materially: re-solve β at the corrected φ, warm-started at
1826 // the converged β, so the mean is refit under the better precision
1827 // and the final working state is rebuilt consistently.
1828 working_summary = runworking_model_pirls(
1829 &mut working_model,
1830 working_summary.beta.clone(),
1831 &options,
1832 &mut iteration_logger,
1833 )?;
1834 }
1835 }
1836
1837 // ── Negative-Binomial overdispersion θ: re-estimate at the *converged* η and
1838 // drive (β, θ) to their joint fixed point (#802) ───────────────────────
1839 //
1840 // Identical in spirit to the Beta-precision refresh above. The inner LM solve
1841 // estimates θ **once** from the warm-start η and freezes it (the
1842 // `negbin_theta_locked` lock), keeping the penalized argmin β̂ a stationary LM
1843 // target. But that lock pins θ to whatever η the solve started from, and for
1844 // the final dedicated fit at the converged ρ the warm-start is the cold
1845 // default guess (η ≈ 0). At the null predictor the Pearson residuals carry
1846 // the *marginal* spread of y rather than its *conditional* spread, biasing
1847 // the moment seed — and the frozen θ is what survives into the working weight
1848 // `W = μθ/(θ+μ)`, the covariance `Vb = H⁻¹` (whose overdispersion scaling
1849 // lives in that weight, not a post-hoc multiply), and every reported SE /
1850 // interval / `generate` draw.
1851 //
1852 // Like the Beta precision — and unlike the scale-free Gamma shape / Tweedie φ
1853 // — θ enters the NB2 working *response*, not only the weight, so re-solving β
1854 // under the corrected θ is not cosmetic: it recovers the mean under the right
1855 // variance function. Re-estimating at the converged η, re-solving β
1856 // (warm-started), and repeating drives (β, θ) to their joint maximum-
1857 // likelihood fixed point. Held OFF inside the REML λ search (the flag), θ is
1858 // refreshed only here at the reported fit, so it cannot couple to the
1859 // smoothing parameter. Every exit path installs θ evaluated at the *current*
1860 // η with no following re-solve, so the reported θ (which flows into the
1861 // embedded `NegativeBinomial { theta }`, the `EstimatedNegBinTheta` scale
1862 // metadata, the predictive-interval variance, and every SE) always equals
1863 // `estimate_negbin_theta_from_eta(final_eta)`.
1864 if refine_dispersion_at_converged_eta
1865 && working_model.likelihood.scale.negbin_theta_is_estimated()
1866 {
1867 // θ feeds back through the working response, so allow a few more passes
1868 // than the scale-free Gamma case; the alternation is a strong contraction
1869 // and warm-started re-solves are cheap.
1870 const MAX_THETA_REFRESH: usize = 30;
1871 // Relative θ tolerance below which a re-solve cannot move β̂ — and hence
1872 // any reported quantity — by a statistically meaningful amount.
1873 const THETA_REFRESH_REL_TOL: f64 = 1e-4;
1874 for refresh_iter in 0..MAX_THETA_REFRESH {
1875 let refreshed_theta = super::estimate_negbin_theta_from_eta(
1876 y,
1877 working_summary.state.eta.as_ref(),
1878 priorweights,
1879 );
1880 let prior_theta = working_model.likelihood.negbin_theta().unwrap_or(1.0);
1881 let rel_change =
1882 (refreshed_theta - prior_theta).abs() / prior_theta.max(f64::MIN_POSITIVE);
1883 // Install the refreshed θ (updates BOTH the `NegativeBinomial { theta }`
1884 // family variant every weight/deviance expression reads and the
1885 // `EstimatedNegBinTheta` scale metadata) and re-arm the lock so a
1886 // following re-solve's `update_with_curvature` does not overwrite this
1887 // deliberately chosen value with a fresh cold estimate.
1888 working_model.likelihood = working_model
1889 .likelihood
1890 .clone()
1891 .with_negbin_theta(refreshed_theta);
1892 working_model.negbin_theta_locked = true;
1893 if rel_change <= THETA_REFRESH_REL_TOL {
1894 // Converged: the just-installed θ matches (to tolerance) the θ the
1895 // current working state was solved at, so β̂, the weights, the
1896 // Hessian and the deviance are already self-consistent with the
1897 // reported θ. Nothing left to rebuild.
1898 break;
1899 }
1900 if refresh_iter + 1 == MAX_THETA_REFRESH {
1901 // Final allowed pass and θ is still drifting. Do NOT re-solve:
1902 // re-solving would advance η past the point the just-installed θ
1903 // was evaluated at, breaking the stored-θ == estimate(final_eta)
1904 // invariant. Stop here so the reported θ is exactly the ML
1905 // estimate at the reported η.
1906 break;
1907 }
1908 // θ moved materially: re-solve β at the corrected θ, warm-started at
1909 // the converged β, so the mean is refit under the better variance
1910 // function and the final working state is rebuilt consistently.
1911 working_summary = runworking_model_pirls(
1912 &mut working_model,
1913 working_summary.beta.clone(),
1914 &options,
1915 &mut iteration_logger,
1916 )?;
1917 }
1918 }
1919
1920 // Extract workspace before consuming working_model so we can reuse
1921 // the pre-allocated buffers in calculate_edfwithworkspace_with_penalty.
1922 // into_final_state() drops the workspace field anyway (it uses `..` in
1923 // its destructure); we replace it with a zero-sized stub to satisfy the
1924 // borrow checker, then keep the real workspace alive for the EDF call.
1925 let mut saved_workspace = std::mem::replace(
1926 &mut working_model.workspace,
1927 PirlsWorkspace::new(0, 0, 0, 0),
1928 );
1929 let final_state = working_model.into_final_state();
1930 let GamModelFinalState {
1931 likelihood: final_likelihood,
1932 coordinate_frame,
1933 finalmu,
1934 finalweights,
1935 scoreweights,
1936 finalz,
1937 final_c,
1938 final_d,
1939 final_dmu_deta,
1940 final_d2mu_deta2,
1941 final_d3mu_deta3,
1942 penalty_term,
1943 ..
1944 } = final_state;
1945
1946 // Preserve the Hessian as-is (sparse or dense) — no densification.
1947 // P-IRLS already folded any stabilization ridge directly into the Hessian.
1948 // Keep that exact matrix so outer LAML derivatives stay consistent:
1949 // H_eff = X'W_H X + S_λ + ridge I (if ridge_used > 0).
1950 let penalized_hessian_transformed = working_summary.state.hessian.clone();
1951 let stabilizedhessian_transformed = penalized_hessian_transformed.clone();
1952 // Use the workspace-backed variant for the dense path to reuse the
1953 // `final_aug_matrix` allocation; the sparse path still allocates
1954 // internally because no pre-computed factor is available at this site.
1955 let mut edf = if let Some(dense_h) = penalized_hessian_transformed.as_dense() {
1956 calculate_edfwithworkspace_with_penalty(dense_h, &penalty_active, &mut saved_workspace)?
1957 } else {
1958 calculate_edf_with_penalty(&penalized_hessian_transformed, &penalty_active)?
1959 };
1960 if !edf.is_finite() || edf.is_nan() {
1961 let p = penalized_hessian_transformed.ncols() as f64;
1962 let r = penalty_active.rank() as f64;
1963 edf = (p - r).max(0.0);
1964 }
1965
1966 // Outer rescue: a fit that hit max-iterations may still be a usable
1967 // minimum if progress has effectively stopped (deviance plateaued or
1968 // step size collapsed to the floor) AND the projected gradient is in
1969 // the near-stationary band under the scale-invariant certificate.
1970 // Same logic for non-Firth and Firth paths; firth_active just gates
1971 // the second pass.
1972 let stalled_at_valid_minimum = |summary: &WorkingModelPirlsResult| -> bool {
1973 // Scale-equivariant deviance plateau band (issue #1127). The
1974 // `last_deviance_change` compared below and the deviance both scale as
1975 // `O(a²)` under a response rescaling `y → a·y` (the penalized normal
1976 // equations are linear in `y`, so `β → a·β` and the RSS-deviance
1977 // scales by `a²`). Keying the plateau band to the deviance's own
1978 // magnitude `+ |penalty|` makes the ratio `Δdev / dev_scale`
1979 // scale-invariant. The previous `.max(1.0)` absolute floor broke this:
1980 // for a micro-unit response (`a = 1e-6`) the deviance is `O(1e-12)`, so
1981 // the floor pinned the band at `1.0` — ~1e9× too loose — and this
1982 // max-iteration rescue declared `progress_stopped` at an over-smoothed
1983 // iterate, propagating an inflated `λ̂` to the outer REML loop. For a
1984 // well-scaled (`a ≳ 1`) or up-scaled (`a = 1e6`) objective the floor was
1985 // already a no-op, so those directions are byte-identical. A perfect
1986 // interpolating fit gives a `0` band, so the relative `Δdev` test cannot
1987 // fire spuriously and the scale-invariant `near_stationary_kkt`
1988 // certificate then governs acceptance.
1989 let dev_scale = summary.state.deviance.abs() + summary.state.penalty_term.abs();
1990 // Progress plateau uses the fixed solver tolerance; only the KKT band below adapts.
1991 let dev_tol = options.convergence_tolerance * dev_scale;
1992 let step_floor = options.min_step_size * 2.0;
1993 let progress_stopped =
1994 summary.last_deviance_change.abs() <= dev_tol || summary.last_step_size <= step_floor;
1995 let near_stationary = summary
1996 .state
1997 .near_stationary_kkt(summary.lastgradient_norm, effective_kkt_tolerance(&options));
1998 progress_stopped && near_stationary
1999 };
2000
2001 let mut status = working_summary.status;
2002 if status.is_failed_max_iterations() && stalled_at_valid_minimum(&working_summary) {
2003 status = PirlsStatus::StalledAtValidMinimum;
2004 working_summary.status = status;
2005 }
2006 if status.is_failed_max_iterations()
2007 && firth_active
2008 && stalled_at_valid_minimum(&working_summary)
2009 {
2010 // Firth-adjusted fits can stall; accept under the same dual-criterion
2011 // near-stationary band.
2012 status = PirlsStatus::StalledAtValidMinimum;
2013 working_summary.status = status;
2014 }
2015 let has_penalty = penalty_active.rank() > 0;
2016 let firth_active = options.firth_bias_reduction;
2017 if detect_logit_instability(
2018 link_function,
2019 &final_likelihood.spec.response,
2020 has_penalty,
2021 firth_active,
2022 &working_summary,
2023 &finalmu,
2024 &finalweights,
2025 y,
2026 ) {
2027 status = PirlsStatus::Unstable;
2028 working_summary.status = status;
2029 }
2030
2031 // Store a lazy ReparamOperator instead of materializing X·Qs.
2032 // Consumers that truly need dense access can call .to_dense() on demand.
2033 let reparam_result_final = materialize_final_reparam_result()?;
2034 let qs_arc_final = Arc::new(reparam_result_final.qs.clone());
2035 let x_transformed_final =
2036 make_reparam_operator(&x_original_for_result, &qs_arc_final, use_sparse_native);
2037
2038 let pirls_result = assemble_pirls_result(
2039 &working_summary,
2040 final_likelihood,
2041 offset,
2042 penalized_hessian_transformed,
2043 stabilizedhessian_transformed,
2044 edf,
2045 penalty_term,
2046 &finalmu,
2047 &finalweights,
2048 &scoreweights,
2049 &finalz,
2050 &final_c,
2051 &final_d,
2052 &final_dmu_deta,
2053 &final_d2mu_deta2,
2054 &final_d3mu_deta3,
2055 status,
2056 reparam_result_final,
2057 x_transformed_final,
2058 coordinate_frame,
2059 linear_constraints,
2060 );
2061
2062 Ok((pirls_result, working_summary))
2063}
2064
/// Configuration for a single fixed-ρ P-IRLS fit.
///
/// Bundles the likelihood, link, and inner-solver knobs consumed when the
/// fixed-ρ driver builds its `WorkingModelPirlsOptions`.
#[derive(Clone)]
pub struct PirlsConfig {
    /// GLM likelihood specification used for the fit.
    pub likelihood: GlmLikelihoodSpec,
    /// Inverse-link selection. [`PirlsConfig::link_function`] derives the
    /// corresponding `LinkFunction` from this value.
    pub link_kind: InverseLink,
    /// Iteration cap for the inner solve.
    ///
    /// NOTE(review): presumably forwarded to `WorkingModelPirlsOptions`
    /// alongside `initial_lm_lambda`; confirm at the construction site.
    pub max_iterations: usize,
    /// Convergence tolerance for the inner solve.
    ///
    /// NOTE(review): assumed to become `WorkingModelPirlsOptions::convergence_tolerance`,
    /// which the driver also uses to scale the deviance-plateau band. Confirm.
    pub convergence_tolerance: f64,
    /// Enable Firth bias-reduced fitting.
    ///
    /// NOTE(review): assumed to become `WorkingModelPirlsOptions::firth_bias_reduction`.
    pub firth_bias_reduction: bool,
    /// Optional warm-start hint for `WorkingModelPirlsOptions::initial_lm_lambda`.
    /// Forwarded directly when `fit_model_for_fixed_rho` builds its
    /// internal options. See the field doc on `WorkingModelPirlsOptions`
    /// for the seeding semantics.
    pub initial_lm_lambda: Option<f64>,
    /// Enable the Transtrum-Sethna geodesic-acceleration second-order
    /// correction on each accepted LM step. Forwarded to
    /// `WorkingModelPirlsOptions::geodesic_acceleration`; see that
    /// field's doc for the full semantics and cost model. Default
    /// `false`; opt-in until validated.
    pub geodesic_acceleration: bool,
    /// Optional arrow-Schur structured-inner-solve descriptor. When
    /// `Some`, forwarded to `WorkingModelPirlsOptions::arrow_schur` so
    /// each accepted LM step is solved by the per-observation
    /// arrow-Schur path
    /// ([`crate::arrow_schur::ArrowSchurSystem`]). When `None`
    /// (the default), the existing β-only path is used unchanged.
    ///
    /// See [`ArrowSchurInnerConfig`] for the closure contract.
    pub arrow_schur: Option<ArrowSchurInnerConfig>,
}
2093
2094impl PirlsConfig {
2095 #[inline]
2096 pub fn link_function(&self) -> LinkFunction {
2097 self.link_kind.link_function()
2098 }
2099}
2100
2101#[inline]
2102pub(super) fn max_symmetric_asymmetry(matrix: &Array2<f64>) -> f64 {
2103 let n = matrix.nrows().min(matrix.ncols());
2104 let mut max_asym = 0.0_f64;
2105 for i in 0..n {
2106 for j in 0..i {
2107 let diff = (matrix[[i, j]] - matrix[[j, i]]).abs();
2108 if diff > max_asym {
2109 max_asym = diff;
2110 }
2111 }
2112 }
2113 max_asym
2114}
2115
2116#[inline]
2117pub(super) fn assert_symmetric_tol(matrix: &Array2<f64>, label: &str, tol: f64) {
2118 let max_asym = max_symmetric_asymmetry(matrix);
2119 assert!(
2120 max_asym <= tol,
2121 "{} asymmetry too large: {:.3e} (tol {:.3e})",
2122 label,
2123 max_asym,
2124 tol
2125 );
2126}
2127
2128/// Build a DesignMatrix wrapping a lazy ReparamOperator (or the original for sparse-native).
2129pub(crate) fn make_reparam_operator(
2130 x_original: &DesignMatrix,
2131 qs_arc: &Arc<Array2<f64>>,
2132 use_sparse_native: bool,
2133) -> DesignMatrix {
2134 if use_sparse_native {
2135 x_original.clone()
2136 } else {
2137 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(
2138 ReparamOperator::new(x_original.clone(), Arc::clone(qs_arc)),
2139 )))
2140 }
2141}
2142
2143// solve_penalized_least_squares_implicit lives in pls_solver (imported above).
2144
2145pub(super) fn build_transformed_lower_bound_constraints(
2146 qs: &Array2<f64>,
2147 coefficient_lower_bounds: Option<&Array1<f64>>,
2148) -> Option<LinearInequalityConstraints> {
2149 let lb = coefficient_lower_bounds?;
2150 if lb.len() != qs.nrows() {
2151 return None;
2152 }
2153 let activerows: Vec<usize> = (0..lb.len()).filter(|&i| lb[i].is_finite()).collect();
2154 if activerows.is_empty() {
2155 return None;
2156 }
2157 let mut a = Array2::<f64>::zeros((activerows.len(), qs.ncols()));
2158 let mut b = Array1::<f64>::zeros(activerows.len());
2159 for (r, &idx) in activerows.iter().enumerate() {
2160 a.row_mut(r).assign(&qs.row(idx));
2161 b[r] = lb[idx];
2162 }
2163 Some(
2164 LinearInequalityConstraints::new(a, b)
2165 .expect("transformed lower-bound constraint shape invariant"),
2166 )
2167}
2168
2169pub(super) fn build_transformed_lower_bound_constraints_with_transform(
2170 transform: &WorkingReparamTransform,
2171 coefficient_lower_bounds: Option<&Array1<f64>>,
2172) -> Option<LinearInequalityConstraints> {
2173 let lb = coefficient_lower_bounds?;
2174 let p = match transform {
2175 WorkingReparamTransform::Dense(qs) => qs.nrows(),
2176 WorkingReparamTransform::Kronecker(kron) => kron.p,
2177 };
2178 if lb.len() != p {
2179 return None;
2180 }
2181 let activerows: Vec<usize> = (0..lb.len()).filter(|&i| lb[i].is_finite()).collect();
2182 if activerows.is_empty() {
2183 return None;
2184 }
2185 let mut a = Array2::<f64>::zeros((activerows.len(), p));
2186 let mut b = Array1::<f64>::zeros(activerows.len());
2187 for (r, &idx) in activerows.iter().enumerate() {
2188 let mut basis = Array1::<f64>::zeros(p);
2189 basis[idx] = 1.0;
2190 let row = transform.apply_transpose(&basis);
2191 a.row_mut(r).assign(&row);
2192 b[r] = lb[idx];
2193 }
2194 Some(
2195 LinearInequalityConstraints::new(a, b)
2196 .expect("transformed lower-bound constraint shape invariant"),
2197 )
2198}
2199
2200pub(super) fn build_transformed_linear_constraints(
2201 qs: &Array2<f64>,
2202 linear_constraints: Option<&LinearInequalityConstraints>,
2203) -> Option<LinearInequalityConstraints> {
2204 let lc = linear_constraints?;
2205 if lc.a.ncols() != qs.nrows() {
2206 return None;
2207 }
2208 Some(
2209 LinearInequalityConstraints::new(lc.a.dot(qs), lc.b.clone())
2210 .expect("transformed linear constraint shape invariant"),
2211 )
2212}
2213
2214pub(super) fn build_transformed_linear_constraints_with_transform(
2215 transform: &WorkingReparamTransform,
2216 linear_constraints: Option<&LinearInequalityConstraints>,
2217) -> Option<LinearInequalityConstraints> {
2218 let lc = linear_constraints?;
2219 let p = match transform {
2220 WorkingReparamTransform::Dense(qs) => qs.nrows(),
2221 WorkingReparamTransform::Kronecker(kron) => kron.p,
2222 };
2223 if lc.a.ncols() != p {
2224 return None;
2225 }
2226 let mut a = Array2::<f64>::zeros((lc.a.nrows(), p));
2227 for row in 0..lc.a.nrows() {
2228 let transformed = transform.apply_transpose(&lc.a.row(row).to_owned());
2229 a.row_mut(row).assign(&transformed);
2230 }
2231 Some(LinearInequalityConstraints { a, b: lc.b.clone() })
2232}
2233
2234pub(super) fn merge_linear_constraints(
2235 first: Option<LinearInequalityConstraints>,
2236 second: Option<LinearInequalityConstraints>,
2237) -> Option<LinearInequalityConstraints> {
2238 match (first, second) {
2239 (None, None) => None,
2240 (Some(c), None) | (None, Some(c)) => Some(c),
2241 (Some(c1), Some(c2)) => {
2242 if c1.a.ncols() != c2.a.ncols() {
2243 return None;
2244 }
2245 let rows = c1.a.nrows() + c2.a.nrows();
2246 let cols = c1.a.ncols();
2247 let mut a = Array2::<f64>::zeros((rows, cols));
2248 a.slice_mut(s![0..c1.a.nrows(), ..]).assign(&c1.a);
2249 a.slice_mut(s![c1.a.nrows()..rows, ..]).assign(&c2.a);
2250 let mut b = Array1::<f64>::zeros(rows);
2251 b.slice_mut(s![0..c1.b.len()]).assign(&c1.b);
2252 b.slice_mut(s![c1.b.len()..rows]).assign(&c2.b);
2253 Some(LinearInequalityConstraints { a, b })
2254 }
2255 }
2256}
2257
2258pub(super) fn sparse_from_denseview(x: ArrayView2<f64>) -> Option<DesignMatrix> {
2259 // Below this column count a dense factorization beats the sparse path even
2260 // at high sparsity, so skip the sparsity scan entirely for narrow designs.
2261 const DENSE_PREFERRED_MAX_COLS: usize = 32;
2262 // Sparse storage + sparse Cholesky only pays off below this density (nnz as
2263 // a fraction of all entries); denser matrices stay dense.
2264 const SPARSE_DENSITY_LIMIT: f64 = 0.20;
2265
2266 let nrows = x.nrows();
2267 let ncols = x.ncols();
2268 if nrows == 0 || ncols == 0 {
2269 return None;
2270 }
2271 // Narrow matrices are faster in dense form; avoid any sparsity scan overhead.
2272 if ncols <= DENSE_PREFERRED_MAX_COLS {
2273 return None;
2274 }
2275
2276 const ZERO_EPS: f64 = 1e-12;
2277 let total = nrows.saturating_mul(ncols);
2278 if total == 0 {
2279 return None;
2280 }
2281 // If a matrix exceeds this nnz count it is too dense for sparse path; bail early.
2282 let sparse_nnz_limit = ((total as f64) * SPARSE_DENSITY_LIMIT).floor() as usize;
2283 let mut nnz = 0usize;
2284 for &val in x.iter() {
2285 if val.abs() > ZERO_EPS {
2286 nnz += 1;
2287 if nnz > sparse_nnz_limit {
2288 return None;
2289 }
2290 }
2291 }
2292 let mut triplets = Vec::with_capacity(nnz);
2293 for (row_idx, row) in x.outer_iter().enumerate() {
2294 for (col_idx, &val) in row.iter().enumerate() {
2295 if val.abs() > ZERO_EPS {
2296 triplets.push(Triplet::new(row_idx, col_idx, val));
2297 }
2298 }
2299 }
2300 SparseColMat::try_new_from_triplets(nrows, ncols, &triplets)
2301 .ok()
2302 .map(DesignMatrix::from)
2303}
2304
#[cfg(test)]
mod tests {
    use super::{PirlsPenalty, build_diagonal_penalty_from_kronecker};
    use gam_terms::construction::KroneckerReparamResult;
    use ndarray::{Array1, Array2, array};

    /// Checks the diagonal penalty assembled for a 2×2 Kronecker tensor basis
    /// with `has_double_penalty = true`.
    ///
    /// Inputs:
    /// - marginal eigenvalues `[0, 2]` and `[0, 3]`;
    /// - smoothing weights `[5, 7, 11]`;
    /// - `penalty_shrinkage_ridge = 0.5`.
    ///
    /// One decomposition consistent with the expected diagonal (TODO confirm
    /// against `build_diagonal_penalty_from_kronecker`):
    /// - idx 0 (0,0): `11` — joint null space; only the third weight applies.
    /// - idx 1 (0,1): `7·3 + 0.5 = 21.5`
    /// - idx 2 (1,0): `5·2 + 0.5 = 10.5`
    /// - idx 3 (1,1): `5·2 + 7·3 + 0.5 = 31.5`
    ///
    /// Under that reading, the third weight hits only the joint-null entry,
    /// and the second marginal index varies fastest. All four diagonal entries
    /// are strictly positive, so every index is expected in `positive_indices`.
    #[test]
    fn kronecker_diagonal_double_penalty_hits_only_joint_null_space() {
        let kron_result = KroneckerReparamResult {
            reparameterized_marginals: std::sync::Arc::new(Vec::new()),
            marginal_eigenvalues: std::sync::Arc::new(vec![array![0.0, 2.0], array![0.0, 3.0]]),
            marginal_qs: std::sync::Arc::new(Vec::new()),
            log_det: 0.0,
            // Length 3 matches the three smoothing weights passed below.
            det1: Array1::zeros(3),
            det2: Array2::zeros((3, 3)),
            penalty_shrinkage_ridge: 0.5,
            has_double_penalty: true,
            marginal_dims: vec![2usize, 2usize],
        };
        let penalty = build_diagonal_penalty_from_kronecker(&kron_result, &[5.0, 7.0, 11.0]);

        let PirlsPenalty::Diagonal {
            diag,
            positive_indices,
            ..
        } = penalty
        else {
            panic!("expected diagonal Kronecker PIRLS penalty");
        };
        let expected = [11.0, 21.5, 10.5, 31.5];
        for (idx, expected_diag) in expected.iter().copied().enumerate() {
            assert!(
                (diag[idx] - expected_diag).abs() <= 1e-12,
                "diagonal {idx} got {}, expected {expected_diag}",
                diag[idx]
            );
        }
        assert_eq!(positive_indices, vec![0, 1, 2, 3]);
    }
}
2344}