Skip to main content

gam_solve/pirls/
loop_driver.rs

//! Outer driver for a single fixed-ρ PIRLS fit.
//!
//! Owns:
//! - `fit_model_for_fixed_rho` and `fit_model_for_fixed_rho_with_adaptive_kkt`
//!   — build the working model, run the inner LM loop, assemble the final result.
//! - `PirlsProblem`, `PenaltyConfig`, `PirlsConfig` — the configuration types.
//! - Helper functions exclusive to the fixed-ρ fitting path: constraint
//!   transformation, sparse-native decision, reparam materialisation, prior
//!   shift assembly, initial-β guess, Gaussian short-circuit assembly, etc.
//! - The two GPU dispatch blocks (Stage 3.3) that call into
//!   `crate::gpu::pirls_host_dispatch` (`try_gaussian_pls_gpu`, `try_pirls_loop_gpu`).
12
13use super::{
14    // state re-exports
15    AdaptiveKktTolerance,
16    ExportedLaplaceCurvature,
17    FirthDiagnostics,
18    GamWorkingModel,
19    HessianCurvatureKind,
20    // penalty types
21    KroneckerQsTransform,
22    LinearInequalityConstraints,
23    PirlsCoordinateFrame,
24    PirlsLinearSolvePath,
25    PirlsPenalty,
26    PirlsResult,
27    PirlsStatus,
28    PirlsWorkspace,
29    SparsePirlsDecision,
30    WorkingModelIterationInfo,
31    WorkingModelPirlsOptions,
32    WorkingModelPirlsResult,
33    WorkingReparamTransform,
34    WorkingState,
35    // misc helpers
36    array1_l2_norm,
37    attach_penalty_shift,
38    // compute functions
39    calculate_deviance,
40    // edf helpers
41    calculate_edf_with_penalty,
42    calculate_edfwithworkspace_with_penalty,
43    calculate_loglikelihood,
44    compute_constraint_kkt_diagnostics,
45    computeworkingweight_derivatives_from_eta,
46    inf_norm,
47    runworking_model_pirls,
48    should_use_sparse_native_pirls,
49    solve_penalized_least_squares_implicit,
50    standard_inverse_link_jet,
51};
52use super::{
53    ArrowSchurInnerConfig, GamModelFinalState, effective_kkt_tolerance,
54    project_coefficients_to_lower_bounds,
55};
56use crate::active_set;
57use crate::estimate::EstimationError;
58use crate::gpu::pirls_host_dispatch::{try_gaussian_pls_gpu, try_pirls_loop_gpu};
59use crate::mixture_link::inverse_link_has_fisher_weight_jet;
60use faer::sparse::{SparseColMat, Triplet};
61use gam_linalg::faer_ndarray::fast_ab;
62use gam_linalg::matrix::{DesignMatrix, LinearOperator, ReparamOperator, SymmetricMatrix};
63use gam_math::probability::standard_normal_quantile;
64use gam_problem::{
65    Coefficients, GlmLikelihoodSpec, InverseLink, LinearPredictor, LinkFunction,
66    LogSmoothingParamsView, MixtureLinkState, ResponseFamily, RidgePassport, RidgePolicy,
67    SasLinkState, StandardLink,
68};
69use gam_terms::construction::{KroneckerReparamResult, ReparamResult};
70use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
71use std::borrow::Cow;
72use std::sync::Arc;
73
74pub(super) fn default_beta_guess_external(
75    p: usize,
76    link_function: LinkFunction,
77    y: ArrayView1<f64>,
78    priorweights: ArrayView1<f64>,
79    mixture_link_state: Option<&MixtureLinkState>,
80    sas_link_state: Option<&SasLinkState>,
81) -> Array1<f64> {
82    let mut beta = Array1::<f64>::zeros(p);
83    let intercept_col = 0usize;
84    match link_function {
85        LinkFunction::Logit
86        | LinkFunction::Probit
87        | LinkFunction::CLogLog
88        | LinkFunction::LogLog
89        | LinkFunction::Cauchit
90        | LinkFunction::Sas
91        | LinkFunction::BetaLogistic => {
92            let mut weighted_sum = 0.0;
93            let mut totalweight = 0.0;
94            for (&yi, &wi) in y.iter().zip(priorweights.iter()) {
95                weighted_sum += wi * yi;
96                totalweight += wi;
97            }
98            if totalweight > 0.0 {
99                let prevalence =
100                    ((weighted_sum + 0.5) / (totalweight + 1.0)).clamp(1e-6, 1.0 - 1e-6);
101                beta[intercept_col] = match link_function {
102                    LinkFunction::Logit => (prevalence / (1.0 - prevalence)).ln(),
103                    LinkFunction::Probit => {
104                        standard_normal_quantile(prevalence).unwrap_or_else(|_| {
105                            // `prevalence` is clamped to (0, 1); this fallback is
106                            // only for defensive robustness under non-finite upstream inputs.
107                            (prevalence / (1.0 - prevalence)).ln()
108                        })
109                    }
110                    LinkFunction::CLogLog => (-(1.0 - prevalence).ln()).ln(),
111                    LinkFunction::LogLog => -(-prevalence.ln()).ln(),
112                    LinkFunction::Cauchit => (std::f64::consts::PI * (prevalence - 0.5)).tan(),
113                    LinkFunction::Sas => solve_intercept_for_prevalence(
114                        link_function,
115                        prevalence,
116                        mixture_link_state,
117                        sas_link_state,
118                    )
119                    .unwrap_or_else(|| {
120                        standard_normal_quantile(prevalence)
121                            .unwrap_or_else(|_| (prevalence / (1.0 - prevalence)).ln())
122                    }),
123                    LinkFunction::BetaLogistic => solve_intercept_for_prevalence(
124                        link_function,
125                        prevalence,
126                        mixture_link_state,
127                        sas_link_state,
128                    )
129                    .unwrap_or_else(|| {
130                        standard_normal_quantile(prevalence)
131                            .unwrap_or_else(|_| (prevalence / (1.0 - prevalence)).ln())
132                    }),
133                    // Outer arm guard already filtered out Log/Identity; fall
134                    // back to the canonical logit transform for defensive safety
135                    // if these are ever reached unexpectedly.
136                    LinkFunction::Log | LinkFunction::Identity => {
137                        (prevalence / (1.0 - prevalence)).ln()
138                    }
139                };
140                if mixture_link_state.is_some() {
141                    beta[intercept_col] = solve_intercept_for_prevalence(
142                        link_function,
143                        prevalence,
144                        mixture_link_state,
145                        sas_link_state,
146                    )
147                    .unwrap_or(beta[intercept_col]);
148                }
149            }
150        }
151        LinkFunction::Identity => {
152            let mut weighted_sum = 0.0;
153            let mut totalweight = 0.0;
154            for (&yi, &wi) in y.iter().zip(priorweights.iter()) {
155                weighted_sum += wi * yi;
156                totalweight += wi;
157            }
158            if totalweight > 0.0 {
159                beta[intercept_col] = weighted_sum / totalweight;
160            }
161        }
162        LinkFunction::Log => {
163            // For log link, intercept = ln(weighted mean of y)
164            let mut weighted_sum = 0.0;
165            let mut totalweight = 0.0;
166            for (&yi, &wi) in y.iter().zip(priorweights.iter()) {
167                weighted_sum += wi * yi;
168                totalweight += wi;
169            }
170            if totalweight > 0.0 {
171                let mean_y = (weighted_sum / totalweight).max(1e-10);
172                beta[intercept_col] = mean_y.ln();
173            }
174        }
175    }
176    beta
177}
178
179pub(super) fn solve_intercept_for_prevalence(
180    link_function: LinkFunction,
181    prevalence: f64,
182    mixture_link_state: Option<&MixtureLinkState>,
183    sas_link_state: Option<&SasLinkState>,
184) -> Option<f64> {
185    #[inline]
186    fn f_eta(
187        link_function: LinkFunction,
188        eta: f64,
189        prevalence: f64,
190        mixture_link_state: Option<&MixtureLinkState>,
191        sas_link_state: Option<&SasLinkState>,
192    ) -> f64 {
193        let inverse_link = if let Some(state) = mixture_link_state {
194            InverseLink::Mixture(state.clone())
195        } else if let Some(state) = sas_link_state {
196            match link_function {
197                LinkFunction::BetaLogistic => InverseLink::BetaLogistic(*state),
198                _ => InverseLink::Sas(*state),
199            }
200        } else {
201            // SAFETY: when `sas_link_state` is None, `solve_intercept_for_prevalence`
202            // is only invoked with the five legal `StandardLink` variants (the
203            // dispatch site at pirls.rs:4203 routes Sas/BetaLogistic into the
204            // Some branch above with state).
205            InverseLink::Standard(StandardLink::try_from(link_function).expect(
206                "state-bearing link reached state-less arm in solve_intercept_for_prevalence",
207            ))
208        };
209        standard_inverse_link_jet(&inverse_link, eta)
210            .map(|jet| jet.mu - prevalence)
211            .unwrap_or(f64::NAN)
212    }
213
214    let mut lo = -40.0;
215    let mut hi = 40.0;
216    let mut f_lo = f_eta(
217        link_function,
218        lo,
219        prevalence,
220        mixture_link_state,
221        sas_link_state,
222    );
223    let mut f_hi = f_eta(
224        link_function,
225        hi,
226        prevalence,
227        mixture_link_state,
228        sas_link_state,
229    );
230    if !(f_lo.is_finite() && f_hi.is_finite()) {
231        return None;
232    }
233    for _ in 0..8 {
234        if f_lo <= 0.0 && f_hi >= 0.0 {
235            break;
236        }
237        lo *= 2.0;
238        hi *= 2.0;
239        f_lo = f_eta(
240            link_function,
241            lo,
242            prevalence,
243            mixture_link_state,
244            sas_link_state,
245        );
246        f_hi = f_eta(
247            link_function,
248            hi,
249            prevalence,
250            mixture_link_state,
251            sas_link_state,
252        );
253        if !(f_lo.is_finite() && f_hi.is_finite()) {
254            return None;
255        }
256    }
257    if f_lo > 0.0 {
258        return Some(lo);
259    }
260    if f_hi < 0.0 {
261        return Some(hi);
262    }
263    for _ in 0..80 {
264        let mid = 0.5 * (lo + hi);
265        let f_mid = f_eta(
266            link_function,
267            mid,
268            prevalence,
269            mixture_link_state,
270            sas_link_state,
271        );
272        if !f_mid.is_finite() {
273            return None;
274        }
275        if f_mid > 0.0 {
276            hi = mid;
277        } else {
278            lo = mid;
279        }
280    }
281    Some(0.5 * (lo + hi))
282}
283
284pub(super) fn assemble_pirls_result(
285    working_summary: &WorkingModelPirlsResult,
286    likelihood: GlmLikelihoodSpec,
287    offset: ArrayView1<'_, f64>,
288    penalized_hessian_transformed: SymmetricMatrix,
289    stabilizedhessian_transformed: SymmetricMatrix,
290    edf: f64,
291    penalty_term: f64,
292    finalmu: &Array1<f64>,
293    finalweights: &Array1<f64>,
294    scoreweights: &Array1<f64>,
295    finalz: &Array1<f64>,
296    final_c: &Array1<f64>,
297    final_d: &Array1<f64>,
298    final_dmu_deta: &Array1<f64>,
299    final_d2mu_deta2: &Array1<f64>,
300    final_d3mu_deta3: &Array1<f64>,
301    status: PirlsStatus,
302    reparam_result: ReparamResult,
303    x_transformed: DesignMatrix,
304    coordinate_frame: PirlsCoordinateFrame,
305    linear_constraints_transformed: Option<LinearInequalityConstraints>,
306) -> PirlsResult {
307    let final_eta_arr = working_summary.state.eta.as_ref().clone();
308    PirlsResult {
309        likelihood,
310        beta_transformed: working_summary.beta.clone(),
311        penalized_hessian_transformed,
312        stabilizedhessian_transformed,
313        ridge_passport: RidgePassport::scaled_identity(
314            working_summary.state.ridge_used,
315            RidgePolicy::explicit_stabilization_full(),
316        ),
317        ridge_used: working_summary.state.ridge_used,
318        deviance: working_summary.state.deviance,
319        edf,
320        stable_penalty_term: penalty_term,
321        firth: working_summary.state.firth.clone(),
322        finalweights: finalweights.clone(),
323        final_offset: offset.to_owned(),
324        final_eta: final_eta_arr,
325        finalmu: finalmu.clone(),
326        solveweights: scoreweights.clone(),
327        solveworking_response: finalz.clone(),
328        solvemu: finalmu.clone(),
329        solve_dmu_deta: final_dmu_deta.clone(),
330        solve_d2mu_deta2: final_d2mu_deta2.clone(),
331        solve_d3mu_deta3: final_d3mu_deta3.clone(),
332        solve_c_array: final_c.clone(),
333        solve_d_array: final_d.clone(),
334        derivatives_unsupported: false,
335        status,
336        iteration: working_summary.iterations,
337        max_abs_eta: working_summary.max_abs_eta,
338        lastgradient_norm: working_summary.lastgradient_norm,
339        gradient_natural_scale: working_summary.state.gradient_natural_scale,
340        last_deviance_change: working_summary.last_deviance_change,
341        last_step_halving: working_summary.last_step_halving,
342        hessian_curvature: working_summary.state.hessian_curvature,
343        exported_laplace_curvature: working_summary.exported_laplace_curvature.clone(),
344        final_lm_lambda: working_summary.final_lm_lambda,
345        final_accept_rho: working_summary.final_accept_rho,
346        constraint_kkt: working_summary.constraint_kkt.clone(),
347        linear_constraints_transformed,
348        reparam_result,
349        x_transformed,
350        coordinate_frame,
351        used_device: false,
352        cache_compacted: false,
353        min_penalized_deviance: working_summary.min_penalized_deviance,
354    }
355}
356
357pub(super) fn detect_logit_instability(
358    link: LinkFunction,
359    response: &ResponseFamily,
360    has_penalty: bool,
361    firth_active: bool,
362    summary: &WorkingModelPirlsResult,
363    finalmu: &Array1<f64>,
364    finalweights: &Array1<f64>,
365    y: ArrayView1<'_, f64>,
366) -> bool {
367    // Perfect / quasi-perfect separation is a *Bernoulli/Binomial* pathology.
368    // Every heuristic below is binary-response–specific: saturation toward
369    // μ ∈ {0, 1}, the `yᵢ > 0.5` order-separation split, and working-weight
370    // collapse only carry meaning when each `yᵢ` is a 0/1 outcome (or a
371    // proportion of Bernoulli trials). The Beta family also fits through the
372    // logit link, but its response is *continuous* on (0, 1): a perfectly
373    // healthy monotone mean (μ increasing in a covariate ⇒ rows with y > 0.5
374    // sit at higher η than rows with y ≤ 0.5) trivially satisfies the
375    // `order_separated` test, so gating this detector on the logit link alone
376    // misclassifies well-behaved Beta fits as separated and forces a spurious
377    // inner-solve retreat at every smoothing-parameter seed (issue #499).
378    // Gate strictly on the Binomial response so only binary GLMs are screened.
379    if !matches!(response, ResponseFamily::Binomial) || link != LinkFunction::Logit || firth_active
380    {
381        return false;
382    }
383
384    // Separation-detection policy thresholds. Each is a heuristic cut-off, not
385    // a math identity: they decide when a binary-logit fit has drifted into the
386    // perfect/quasi-perfect separation regime and the inner solve must retreat.
387    //
388    // `ORDER_SEPARATION_ETA_GAP`: a strictly positive η-gap between the lowest
389    //   η among y=1 rows and the highest among y=0 rows means the two classes
390    //   are linearly separable on the linear predictor.
391    // `EXTREME_ETA`: |η| this large drives μ to within machine-ε of {0,1}.
392    // `SATURATION_FRACTION` / `SEVERE_SATURATION_FRACTION`: share of fitted μ
393    //   pinned to the {0,1} boundary that flags (severe) saturation.
394    // `DEGENERATE_DEVIANCE_PER_SAMPLE` / `EXTREME_DEGENERATE_DEVIANCE_PER_SAMPLE`:
395    //   near-zero per-sample deviance means the model fits the data perfectly.
396    // `EXTREME_BETA_NORM`: coefficient norm blow-up characteristic of the MLE
397    //   escaping to infinity under separation.
398    // `WEIGHT_COLLAPSE_FRACTION`: share of working weights collapsed to ~0.
399    const ORDER_SEPARATION_ETA_GAP: f64 = 1e-3;
400    const EXTREME_ETA: f64 = 30.0;
401    const SATURATION_FRACTION: f64 = 0.98;
402    const SEVERE_SATURATION_FRACTION: f64 = 0.995;
403    const DEGENERATE_DEVIANCE_PER_SAMPLE: f64 = 1e-3;
404    const EXTREME_DEGENERATE_DEVIANCE_PER_SAMPLE: f64 = 1e-6;
405    const EXTREME_BETA_NORM: f64 = 1e4;
406    const WEIGHT_COLLAPSE_FRACTION: f64 = 0.98;
407
408    let n = y.len() as f64;
409    if n == 0.0 {
410        return false;
411    }
412
413    let max_abs_eta = summary.max_abs_eta;
414    let sat_fraction = {
415        const SAT_EPS: f64 = 1e-3;
416        finalmu
417            .iter()
418            .filter(|&&m| m <= SAT_EPS || m >= 1.0 - SAT_EPS)
419            .count() as f64
420            / n
421    };
422
423    let weight_collapse_fraction = {
424        const WEIGHT_EPS: f64 = 1e-8;
425        finalweights
426            .iter()
427            .filter(|&&w| w <= WEIGHT_EPS || !w.is_finite())
428            .count() as f64
429            / n
430    };
431
432    let beta_norm = summary.beta.as_ref().dot(summary.beta.as_ref()).sqrt();
433    let dev_per_sample = summary.state.deviance / n;
434
435    let mut has_pos = false;
436    let mut has_neg = false;
437    let mut min_eta_pos = f64::INFINITY;
438    let mut max_eta_neg = f64::NEG_INFINITY;
439    for (eta_i, &yi) in summary.state.eta.iter().zip(y.iter()) {
440        if yi > 0.5 {
441            has_pos = true;
442            if *eta_i < min_eta_pos {
443                min_eta_pos = *eta_i;
444            }
445        } else {
446            has_neg = true;
447            if *eta_i > max_eta_neg {
448                max_eta_neg = *eta_i;
449            }
450        }
451    }
452    let order_separated =
453        has_pos && has_neg && (min_eta_pos - max_eta_neg) > ORDER_SEPARATION_ETA_GAP;
454
455    let classic_signals = max_abs_eta > EXTREME_ETA
456        || sat_fraction > SATURATION_FRACTION
457        || dev_per_sample < DEGENERATE_DEVIANCE_PER_SAMPLE
458        || beta_norm > EXTREME_BETA_NORM;
459
460    if !has_penalty {
461        return classic_signals || order_separated;
462    }
463
464    let severe_saturation = sat_fraction > SEVERE_SATURATION_FRACTION && max_abs_eta > EXTREME_ETA;
465    let weights_collapsed = weight_collapse_fraction > WEIGHT_COLLAPSE_FRACTION;
466    let dev_extremely_small = dev_per_sample < EXTREME_DEGENERATE_DEVIANCE_PER_SAMPLE;
467
468    order_separated || severe_saturation || weights_collapsed || dev_extremely_small
469}
470
471/// Stack λ-weighted penalty roots from canonical penalties into a single
472/// `total_rank × p` matrix for PIRLS. Each block-local root is embedded
473/// into the full column space on-the-fly.
474pub(super) fn stack_lambdaweighted_penalty_root_canonical(
475    penalties: &[gam_terms::construction::CanonicalPenalty],
476    lambdas: &[f64],
477    p: usize,
478) -> Array2<f64> {
479    let totalrows: usize = penalties.iter().map(|cp| cp.rank()).sum();
480    if totalrows == 0 {
481        return Array2::zeros((0, p));
482    }
483    let mut e = Array2::<f64>::zeros((totalrows, p));
484    let mut row_start = 0usize;
485    for (k, cp) in penalties.iter().enumerate() {
486        let rows = cp.rank();
487        if rows == 0 {
488            continue;
489        }
490        let scale = lambdas.get(k).copied().unwrap_or(0.0).max(0.0).sqrt();
491        if scale != 0.0 {
492            // Embed block-local root (rank × block_dim) into full width (rank × p).
493            let r = &cp.col_range;
494            for row in 0..rows {
495                for col in 0..cp.block_dim() {
496                    e[[row_start + row, r.start + col]] = scale * cp.root[[row, col]];
497                }
498            }
499        }
500        row_start += rows;
501    }
502    e
503}
504
505pub(super) fn build_sparse_native_reparam_result(
506    base: ReparamResult,
507    penalties: &[gam_terms::construction::CanonicalPenalty],
508    lambdas: &[f64],
509    p: usize,
510) -> ReparamResult {
511    // Map the engine penalty back into identity (original) coordinates. The
512    // engine returns `s_transformed = Qsᵀ S Qs` (and `e_transformed = E Qs`)
513    // with `S = S_λ + shrinkage·P_range` already folded in (so it matches the
514    // reported `log_det`/`det1`). With the sparse-native `qs = I` we need that
515    // SAME penalty expressed in original coordinates: `S_orig = Qs S_transformed
516    // Qsᵀ`. Rebuilding `S_orig` from the bare lambda-weighted canonical sum
517    // would DROP the shrinkage ridge and desync the inner penalized Hessian from
518    // the penalty log-determinant the REML criterion uses for this fit — the
519    // cross-backend λ-selection divergence (#1266 class). Round-tripping the
520    // engine penalty through `Qs` keeps the inner solve, EDF, and REML logdet on
521    // one penalty.
522    let qs = &base.qs;
523    let s_orig = if qs.nrows() == p && qs.ncols() == base.s_transformed.nrows() {
524        // S_orig = Qs · S_transformed · Qsᵀ
525        let qs_s = fast_ab(qs, &base.s_transformed);
526        qs_s.dot(&qs.t())
527    } else {
528        // Degenerate fallback (engine produced no transform): use the bare
529        // lambda-weighted sum. Shrinkage is zero in this branch by construction.
530        let mut s_original = Array2::<f64>::zeros((p, p));
531        for (k, cp) in penalties.iter().enumerate() {
532            let lambda_k = lambdas.get(k).copied().unwrap_or(0.0);
533            if lambda_k != 0.0 {
534                cp.accumulate_weighted(&mut s_original, lambda_k);
535            }
536        }
537        s_original
538    };
539    // E_orig = E_transformed · Qsᵀ  (so that E_origᵀ E_orig = S_orig and the EDF
540    // augmented system matches the inner Hessian).
541    let e_orig = if qs.nrows() == p && base.e_transformed.ncols() == qs.ncols() {
542        base.e_transformed.dot(&qs.t())
543    } else {
544        stack_lambdaweighted_penalty_root_canonical(penalties, lambdas, p)
545    };
546    let u_original = if base.u_truncated.nrows() == p {
547        fast_ab(&base.qs, &base.u_truncated)
548    } else {
549        Array2::<f64>::eye(p)
550    };
551    // In the sparse-native path, qs = I, so the penalties are already in the
552    // right coordinate frame. We keep them as-is in canonical_transformed.
553    let canonical_transformed: Vec<gam_terms::construction::CanonicalPenalty> = penalties.to_vec();
554    ReparamResult {
555        penalty_shrinkage_ridge: base.penalty_shrinkage_ridge,
556        s_transformed: s_orig,
557        log_det: base.log_det,
558        det1: base.det1,
559        qs: Array2::<f64>::eye(p),
560        canonical_transformed,
561        e_transformed: e_orig,
562        u_truncated: u_original,
563    }
564}
565
566pub(super) fn build_diagonal_penalty_from_kronecker(
567    kron_result: &KroneckerReparamResult,
568    lambdas: &[f64],
569) -> PirlsPenalty {
570    let d = kron_result.marginal_dims.len();
571    let p: usize = kron_result.marginal_dims.iter().copied().product();
572    let mut diag = Array1::<f64>::zeros(p);
573    let mut positive_indices = Vec::new();
574
575    const KRONECKER_STRUCTURAL_ZERO_TOL: f64 = 1e-12;
576    let mut multi_idx = vec![0usize; d];
577    let mut flat = 0usize;
578    loop {
579        let mut sigma = 0.0;
580        let mut structural_sigma = 0.0;
581        for k in 0..d {
582            let marginal_eigenvalue = kron_result.marginal_eigenvalues[k][multi_idx[k]];
583            structural_sigma += marginal_eigenvalue;
584            sigma += lambdas[k] * marginal_eigenvalue;
585        }
586        let joint_null = structural_sigma <= KRONECKER_STRUCTURAL_ZERO_TOL;
587        if kron_result.has_double_penalty && lambdas.len() > d && joint_null {
588            sigma += lambdas[d];
589        }
590        if structural_sigma > KRONECKER_STRUCTURAL_ZERO_TOL {
591            sigma += kron_result.penalty_shrinkage_ridge;
592        }
593        diag[flat] = sigma;
594        if sigma > 0.0 {
595            positive_indices.push(flat);
596        }
597        flat += 1;
598
599        let mut carry = true;
600        for dim in (0..d).rev() {
601            if carry {
602                multi_idx[dim] += 1;
603                if multi_idx[dim] < kron_result.marginal_dims[dim] {
604                    carry = false;
605                } else {
606                    multi_idx[dim] = 0;
607                }
608            }
609        }
610        if carry {
611            break;
612        }
613    }
614
615    PirlsPenalty::Diagonal {
616        diag,
617        positive_indices,
618        linear_shift: Array1::zeros(p),
619        constant_shift: 0.0,
620        prior_mean_target: Array1::zeros(p),
621    }
622}
623
624pub(super) fn canonical_prior_shift(
625    penalties: &[gam_terms::construction::CanonicalPenalty],
626    lambdas: &[f64],
627    p: usize,
628) -> (Array1<f64>, f64) {
629    let mut linear = Array1::<f64>::zeros(p);
630    let mut constant = 0.0;
631    for (idx, cp) in penalties.iter().enumerate() {
632        let Some(&lambda) = lambdas.get(idx) else {
633            continue;
634        };
635        if lambda == 0.0 {
636            continue;
637        }
638        linear += &cp.prior_linear_shift(lambda);
639        constant += cp.prior_constant_shift(lambda);
640    }
641    (linear, constant)
642}
643
644/// Aggregate prior-mean target across canonical penalty blocks: the sum of
645/// each block's `full_width_prior_mean()`. Used by the PIRLS solve sites
646/// that add a fixed stabilization ridge `δI` to the penalized Hessian — they
647/// must also add `δ · prior_mean_target` to the RHS to keep `β = μ` recovery
648/// exact when the data carries no information (X'WX = 0). Equivalent to
649/// `canonical_prior_shift` with all λ = 1 and dropping `S_k` from the linear
650/// piece (i.e., raw μ rather than `S_k μ`). Returned in the *original*
651/// coordinates; callers transform if needed.
652pub(super) fn canonical_prior_mean_aggregate(
653    penalties: &[gam_terms::construction::CanonicalPenalty],
654    p: usize,
655) -> Array1<f64> {
656    let mut mean = Array1::<f64>::zeros(p);
657    for cp in penalties {
658        mean += &cp.full_width_prior_mean();
659    }
660    mean
661}
662
663pub struct PirlsProblem<'a, X> {
664    pub x: X,
665    pub offset: ArrayView1<'a, f64>,
666    pub y: ArrayView1<'a, f64>,
667    pub priorweights: ArrayView1<'a, f64>,
668    pub covariate_se: Option<ArrayView1<'a, f64>>,
669    /// When set, the inner PLS solver reuses the precomputed `XᵀWX` and
670    /// `XᵀW(y − offset)` in *original* coordinates instead of streaming the
671    /// O(N·p²) GEMM and the O(N·p) matvec on every outer REML iteration.
672    ///
673    /// Valid only when the family is Gaussian + Identity link, prior weights
674    /// are constant across outer iterations (always true in the REML outer
675    /// loop), no Firth bias reduction, and no inequality / lower-bound
676    /// constraints (matching the existing Identity short-circuit at
677    /// `pirls.rs:6237`). The penalty `λ·S` is still added per-λ on top of
678    /// the cached `XᵀWX`.
679    pub gaussian_fixed_cache: Option<&'a GaussianFixedCache>,
680    /// Frozen-weight first-Fisher-step data-fit Gram `XᵀWX` for a GLM
681    /// design-moving ψ-trial (#1111 / #1033 mechanism (c)), in *original*
682    /// (conditioned `x_fit`) coordinates. When set, the iterative GLM P-IRLS
683    /// serves its FIRST Fisher-scoring iteration's `XᵀWX` from this matrix
684    /// instead of streaming the O(N·p²) weighted cross-product; every later
685    /// iteration restreams the true moving `W`, so the converged β̂ is
686    /// unchanged. Mutually distinct from `gaussian_fixed_cache` (which is the
687    /// Gaussian-identity converged-objective short-circuit); this is the GLM
688    /// first-step lane and never short-circuits the iteration count.
689    pub glm_first_step_gram: Option<&'a Array2<f64>>,
690}
691
692// GaussianFixedCache is defined in pls_solver.
693pub use super::pls_solver::GaussianFixedCache;
694
695pub struct PenaltyConfig<'a> {
696    /// Block-local canonical penalties with precomputed roots and spectral data.
697    /// This is the single canonical penalty representation — no full-width
698    /// `rank × p` roots are stored. When the reparameterization engine needs
699    /// full-width roots, they are derived on-the-fly from these block-local roots.
700    pub canonical_penalties: &'a [gam_terms::construction::CanonicalPenalty],
701    pub balanced_penalty_root: Option<&'a Array2<f64>>,
702    pub reparam_invariant: Option<&'a gam_terms::construction::ReparamInvariant>,
703    pub p: usize,
704    pub coefficient_lower_bounds: Option<&'a Array1<f64>>,
705    pub linear_constraints_original: Option<&'a LinearInequalityConstraints>,
706    /// Relative shrinkage floor for eigenvalues of the penalized block.
707    /// If `Some(epsilon)`, a rho-independent ridge of `epsilon * max_balanced_eigenvalue`
708    /// is added to prevent barely-penalized directions from causing pathological
709    /// non-Gaussianity in the posterior. Typical value: `1e-6`. `None` disables.
710    pub penalty_shrinkage_floor: Option<f64>,
711    /// When set, the penalties have Kronecker (tensor-product) structure.
712    /// The reparameterization engine will use factored Qs = U_1 ⊗ ... ⊗ U_d
713    /// instead of eigendecomposing the full p×p balanced penalty.
714    pub kronecker_factored: Option<&'a gam_terms::basis::KroneckerFactoredBasis>,
715}
716
717/// P-IRLS solver that follows mgcv's architecture exactly
718///
719/// This function implements the complete algorithm from mgcv's gam.fit3 function
720/// for fitting a GAM model with a fixed set of smoothing parameters:
721///
722/// - Perform stable reparameterization ONCE at the beginning (mgcv's gam.reparam)
723/// - Transform the design matrix into this stable basis
724/// - Extract a single penalty square root from the transformed penalty
725/// - Run the P-IRLS loop entirely in the transformed basis
726/// - Transform the coefficients back to the original basis only when returning
727/// - Reuse a cached balanced penalty root when available to avoid repeated eigendecompositions
728///
729/// This architecture ensures optimal numerical stability throughout the entire
730/// fitting process by working in a well-conditioned parameter space.
731pub fn fit_model_for_fixed_rho<'a, X: Into<DesignMatrix> + Clone>(
732    rho: LogSmoothingParamsView<'_>,
733    problem: PirlsProblem<'a, X>,
734    penalty: PenaltyConfig<'_>,
735    config: &PirlsConfig,
736    warm_start_beta: Option<&Coefficients>,
737) -> Result<(PirlsResult, WorkingModelPirlsResult), EstimationError> {
738    fit_model_for_fixed_rho_with_adaptive_kkt(
739        rho,
740        problem,
741        penalty,
742        config,
743        warm_start_beta,
744        None,
745        false,
746    )
747}
748
749/// `refine_dispersion_at_converged_eta`: when `true`, after the inner P-IRLS
750/// solve converges, re-estimate the family's estimated dispersion nuisance — the
751/// Gamma shape ν = 1/φ or the Beta precision φ — at the *converged* linear
752/// predictor and iterate the (β, dispersion) pair to its joint fixed point at the
753/// current λ (see the in-body comments at each refresh loop). This is ON only for
754/// the single final, reported fit at the REML-selected λ (#678 for Gamma, #769
755/// for Beta). It is deliberately OFF for every REML cost / sigma-point evaluation:
756/// re-profiling the dispersion against each trial λ's converged residuals would
757/// couple the scale to the smoothing parameter (a flat over-smoothed μ inflates
758/// the deviance ⇒ a smaller effective precision ⇒ a smaller `deviance/(2φ)` REML
759/// term), perversely rewarding over-smoothing and biasing λ selection. mgcv
760/// likewise estimates the scale at the converged fit, not inside the λ search.
761///
762/// The Gamma and Beta cases differ in what the re-solve buys. For Gamma the shape
763/// is a pure nuisance — β̂ is essentially scale-free — so the re-solve only keeps
764/// the reported dispersion and SEs self-consistent. For Beta the precision φ
765/// enters the *mean* score through the digamma terms
766/// `μ*ᵢ = ψ(μᵢφ) − ψ((1−μᵢ)φ)`, so a φ measured at the cold null predictor
767/// (μ ≈ 0.5) attenuates every slope toward zero; here the fixed point is
768/// load-bearing — it is what recovers the correct mean coefficients (the betareg
769/// alternating mean-fit ↔ φ-estimate scheme).
770pub(crate) fn fit_model_for_fixed_rho_with_adaptive_kkt<'a, X: Into<DesignMatrix> + Clone>(
771    rho: LogSmoothingParamsView<'_>,
772    problem: PirlsProblem<'a, X>,
773    penalty: PenaltyConfig<'_>,
774    config: &PirlsConfig,
775    warm_start_beta: Option<&Coefficients>,
776    adaptive_kkt_tolerance: Option<AdaptiveKktTolerance>,
777    refine_dispersion_at_converged_eta: bool,
778) -> Result<(PirlsResult, WorkingModelPirlsResult), EstimationError> {
779    let PirlsProblem {
780        x,
781        offset,
782        y,
783        priorweights,
784        covariate_se,
785        gaussian_fixed_cache,
786        glm_first_step_gram,
787    } = problem;
788    let quadctx = crate::quadrature::QuadratureContext::new();
789    // gam#1379 — finite-ceiling λ = exp(ρ). When the outer REML / spatial-κ
790    // optimizer drives a redundant penalty direction's log-λ past ~709 (it does
791    // so deterministically on 1-D `matern(x)` / `bs="gp"` data whose kernel
792    // already controls the smoothness an operator block also penalizes, so REML
793    // wants λ → ∞), `exp(ρ)` overflows to `+∞`. A literal `+∞` λ then poisons
794    // every downstream consumer that forms `λ · S`: the range-penalty block
795    // assembled as `Σ λ_k S_k` hits `∞ · 0 = NaN` and the eigensolve aborts, and
796    // the final fit-result validation rejects the non-finite stored λ outright.
797    // `exp(709.78) ≈ 1.8e308` is already the largest finite f64; capping log-λ at
798    // a value whose `exp` stays finite pins the over-penalized direction exactly
799    // as hard as `+∞` would for every finite-arithmetic consumer (the penalized
800    // block is numerically a hard constraint at λ this large) while keeping
801    // `λ · 0 = 0`. Ordinary finite λ are untouched, so non-degenerate fits and
802    // their recorded λ̂ are bit-identical. `ln(1e300) ≈ 690.78` keeps this in lock
803    // step with the post-exp λ ceiling (`1e300`) used by the reparam range-block
804    // assembly and the stored fit result, so a fully-smoothed direction carries
805    // the SAME finite λ everywhere it is consumed.
806    const LOG_LAMBDA_CEILING: f64 = 690.0;
807    let lambdas = rho.mapv(|r| {
808        if r.is_nan() {
809            r
810        } else {
811            r.min(LOG_LAMBDA_CEILING).exp()
812        }
813    });
814    let lambdas_slice = lambdas.as_slice_memory_order().ok_or_else(|| {
815        EstimationError::InvalidInput("non-contiguous lambda storage".to_string())
816    })?;
817
818    let likelihood = &config.likelihood;
819    let link_function = config.link_function();
820
821    use gam_terms::construction::{
822        EngineDims, create_balanced_penalty_root_from_canonical,
823        stable_reparameterization_engine_canonical,
824    };
825
826    let eb_cow: Cow<'_, Array2<f64>> = if let Some(precomputed) = penalty.balanced_penalty_root {
827        Cow::Borrowed(precomputed)
828    } else {
829        Cow::Owned(create_balanced_penalty_root_from_canonical(
830            penalty.canonical_penalties,
831            penalty.p,
832        )?)
833    };
834    let eb: &Array2<f64> = eb_cow.as_ref();
835
836    // Build a cheap weighted penalty sum for the sparse-native decision
837    // WITHOUT running the expensive eigendecomposition engine.
838    // The full reparameterization is deferred until we know which path we need.
839    let cheap_s_lambda: Option<Array2<f64>> = if penalty.kronecker_factored.is_none() {
840        let mut s = Array2::<f64>::zeros((penalty.p, penalty.p));
841        for (k, cp) in penalty.canonical_penalties.iter().enumerate() {
842            let lam = lambdas_slice.get(k).copied().unwrap_or(0.0);
843            if lam != 0.0 {
844                cp.accumulate_weighted(&mut s, lam);
845            }
846        }
847        Some(s)
848    } else {
849        None
850    };
851    let kronecker_runtime = if let Some(kron) = penalty.kronecker_factored {
852        // The marginal eigensystems and reparameterized marginals depend only on
853        // the fixed marginal designs/penalties, not on λ = exp(ρ). Memoize them
854        // once per fit so each outer REML iterate reuses the eigendecomposition
855        // instead of recomputing `eigh()` + `B_k·U_k` every call; only the cheap
856        // λ-grid logdet/derivative sweep is redone here. Bit-identical to the
857        // unmemoized engine.
858        let invariant = kron.invariant_structure()?;
859        let kron_result =
860            gam_terms::construction::kronecker_reparameterization_engine_with_invariant(
861                invariant.as_ref(),
862                &kron.marginal_dims,
863                lambdas_slice,
864                kron.has_double_penalty,
865                penalty.penalty_shrinkage_floor,
866            )?;
867        let transform = Arc::new(KroneckerQsTransform::new(&kron_result));
868        let penalty_diag = build_diagonal_penalty_from_kronecker(&kron_result, lambdas_slice);
869        Some((kron_result, transform, penalty_diag))
870    } else {
871        None
872    };
873    // Constraint transformation is deferred until after the sparse-native
874    // decision, because the dense reparameterization engine (which provides Qs)
875    // is now run lazily.  Kronecker constraints can be built eagerly since
876    // the Kronecker transform is already available.
877    let kronecker_constraints = if let Some((_, transform, _)) = kronecker_runtime.as_ref() {
878        let tb = build_transformed_lower_bound_constraints_with_transform(
879            &WorkingReparamTransform::Kronecker(Arc::clone(transform)),
880            penalty.coefficient_lower_bounds,
881        );
882        let tl = build_transformed_linear_constraints_with_transform(
883            &WorkingReparamTransform::Kronecker(Arc::clone(transform)),
884            penalty.linear_constraints_original,
885        );
886        Some(merge_linear_constraints(tb, tl))
887    } else {
888        None
889    };
890
891    let x_original: DesignMatrix = x.into();
892    // Auto-detect sparse structure in dense designs so the sparse-native path
893    // can engage for structurally sparse models that happen to be stored dense.
894    let x_original = {
895        let auto_sparse = x_original
896            .as_dense()
897            .and_then(|dense| sparse_from_denseview(dense.view()));
898        auto_sparse.unwrap_or(x_original)
899    };
900    let ebrows = eb.nrows();
901    let erows = if let Some((_, _, penalty_diag)) = kronecker_runtime.as_ref() {
902        penalty_diag.rank()
903    } else {
904        // Compute penalty root rank cheaply from canonical penalties.
905        penalty
906            .canonical_penalties
907            .iter()
908            .map(|cp| cp.rank())
909            .sum::<usize>()
910    };
911    let mut workspace = PirlsWorkspace::new(x_original.nrows(), x_original.ncols(), ebrows, erows);
912    let solver_decision = if let Some((_, _, _)) = kronecker_runtime.as_ref() {
913        SparsePirlsDecision {
914            path: PirlsLinearSolvePath::DenseTransformed,
915            reason: "kronecker_runtime",
916            p: x_original.ncols(),
917            nnz_x: 0,
918            nnz_xtwx_symbolic: None,
919            nnz_s_lambda: 0,
920            nnz_h_est: None,
921            density_h_est: None,
922        }
923    } else {
924        should_use_sparse_native_pirls(
925            &mut workspace,
926            &x_original,
927            cheap_s_lambda
928                .as_ref()
929                .expect("cheap_s_lambda should be present outside Kronecker path"),
930            penalty.coefficient_lower_bounds,
931            penalty.linear_constraints_original,
932        )
933    };
934    solver_decision.log_once();
935
936    let use_sparse_native = matches!(solver_decision.path, PirlsLinearSolvePath::SparseNative);
937
938    // Run the eigendecomposition engine for the dense-transformed path. The
939    // sparse-native path also needs it, but only to obtain a penalty that is
940    // *consistent with the REML penalty log-determinant it reports* — see the
941    // sparse-native `reparam` below. The dense path keeps `qs ≠ I`; the
942    // sparse-native path discards `qs` (identity coords) and reuses only the
943    // shrinkage-folded `s_transformed`/`e_transformed`.
944    let dense_reparam_result = if !use_sparse_native && penalty.kronecker_factored.is_none() {
945        Some(stable_reparameterization_engine_canonical(
946            penalty.canonical_penalties,
947            lambdas_slice,
948            EngineDims::new(penalty.p, penalty.canonical_penalties.len()),
949            penalty.reparam_invariant,
950            penalty.penalty_shrinkage_floor,
951        )?)
952    } else {
953        None
954    };
955    // Sparse-native reparam result, in identity (original) coordinates with the
956    // penalty shrinkage floor folded in. This MUST drive the inner penalized
957    // solve too: when `penalty_shrinkage_floor` is active (default `Some(1e-6)`)
958    // the dense engine adds `shrinkage·P_range` to every penalized range
959    // direction of `S_λ` and rebuilds `s_transformed = EᵀE` from the floored
960    // roots, so `base.log_det` (the REML penalty pseudo-logdet) is the
961    // determinant of `S_λ + shrinkage·P_range`, NOT of the bare `S_λ`. Building
962    // the inner Hessian from an UN-shrunk `S_λ` (the previous behaviour, via the
963    // `cheap_s_lambda` row-sum) while reporting the shrunk `log_det` made the
964    // sparse-native REML surface internally inconsistent — the penalty-logdet
965    // term and the inner H / EDF / β̂ lived on different penalties — which biased
966    // λ-selection relative to the dense and Kronecker backends for the SAME
967    // model (the #1266 cross-backend divergence class). Reusing the engine's
968    // shrinkage-folded penalty here makes all three backends solve the same
969    // penalized objective.
970    let sparse_native_reparam = if use_sparse_native && penalty.kronecker_factored.is_none() {
971        let base = stable_reparameterization_engine_canonical(
972            penalty.canonical_penalties,
973            lambdas_slice,
974            EngineDims::new(penalty.p, penalty.canonical_penalties.len()),
975            penalty.reparam_invariant,
976            penalty.penalty_shrinkage_floor,
977        )?;
978        Some(build_sparse_native_reparam_result(
979            base,
980            penalty.canonical_penalties,
981            lambdas_slice,
982            penalty.p,
983        ))
984    } else {
985        None
986    };
987    let qs_arc = dense_reparam_result
988        .as_ref()
989        .map(|reparam_result| Arc::new(reparam_result.qs.clone()));
990    let transform_active = if let Some((_, transform, _)) = kronecker_runtime.as_ref() {
991        Some(WorkingReparamTransform::Kronecker(Arc::clone(transform)))
992    } else if use_sparse_native {
993        None
994    } else {
995        Some(WorkingReparamTransform::Dense(Arc::clone(
996            qs_arc
997                .as_ref()
998                .expect("dense Qs should exist for non-Kronecker transformed path"),
999        )))
1000    };
1001    let mut penalty_active = if let Some((_, _, penalty_diag)) = kronecker_runtime.as_ref() {
1002        penalty_diag.clone()
1003    } else if use_sparse_native {
1004        // Sparse-native inner penalty in original (identity) coordinates. Use
1005        // the shrinkage-folded `s_transformed`/`e_transformed` from
1006        // `sparse_native_reparam` so the inner penalized Hessian
1007        // `H = XᵀWX + S` matches the penalty whose log-determinant the REML
1008        // criterion reports for this fit (`base.log_det`). Falling back to the
1009        // bare lambda-weighted sum here (the prior behaviour) omitted the
1010        // `penalty_shrinkage_floor` ridge and desynced the inner solve from the
1011        // REML logdet, biasing λ-selection vs the dense/Kronecker backends.
1012        let sparse_reparam = sparse_native_reparam
1013            .as_ref()
1014            .expect("sparse_native_reparam should be present for sparse-native path");
1015        PirlsPenalty::Dense {
1016            s_transformed: sparse_reparam.s_transformed.clone(),
1017            e_transformed: sparse_reparam.e_transformed.clone(),
1018            linear_shift: Array1::zeros(penalty.p),
1019            constant_shift: 0.0,
1020            prior_mean_target: Array1::zeros(penalty.p),
1021        }
1022    } else {
1023        let dense = dense_reparam_result
1024            .as_ref()
1025            .expect("dense reparam result should be present outside Kronecker path");
1026        PirlsPenalty::Dense {
1027            s_transformed: dense.s_transformed.clone(),
1028            e_transformed: dense.e_transformed.clone(),
1029            linear_shift: Array1::zeros(penalty.p),
1030            constant_shift: 0.0,
1031            prior_mean_target: Array1::zeros(penalty.p),
1032        }
1033    };
1034    let (shift_original, shift_constant) =
1035        canonical_prior_shift(penalty.canonical_penalties, lambdas_slice, penalty.p);
1036    let shift_active = transform_active
1037        .as_ref()
1038        .map(|transform| transform.apply_transpose(&shift_original))
1039        .unwrap_or(shift_original);
1040    let prior_mean_original =
1041        canonical_prior_mean_aggregate(penalty.canonical_penalties, penalty.p);
1042    let prior_mean_active = transform_active
1043        .as_ref()
1044        .map(|transform| transform.apply_transpose(&prior_mean_original))
1045        .unwrap_or(prior_mean_original);
1046    attach_penalty_shift(
1047        &mut penalty_active,
1048        shift_active,
1049        shift_constant,
1050        prior_mean_active,
1051    );
1052    // Build transformed constraints now that dense_reparam_result is available.
1053    let linear_constraints = if let Some(kc) = kronecker_constraints {
1054        kc
1055    } else if let Some(reparam) = dense_reparam_result.as_ref() {
1056        let tb = build_transformed_lower_bound_constraints(
1057            &reparam.qs,
1058            penalty.coefficient_lower_bounds,
1059        );
1060        let tl =
1061            build_transformed_linear_constraints(&reparam.qs, penalty.linear_constraints_original);
1062        merge_linear_constraints(tb, tl)
1063    } else {
1064        // Sparse-native without dense reparam: constraints stay in original
1065        // coordinates (identity Qs).  Use an identity matrix of appropriate size.
1066        let p = penalty.p;
1067        let qs_identity = Array2::<f64>::eye(p);
1068        let tb = build_transformed_lower_bound_constraints(
1069            &qs_identity,
1070            penalty.coefficient_lower_bounds,
1071        );
1072        let tl =
1073            build_transformed_linear_constraints(&qs_identity, penalty.linear_constraints_original);
1074        merge_linear_constraints(tb, tl)
1075    };
1076
1077    let coordinate_frame = if use_sparse_native {
1078        PirlsCoordinateFrame::OriginalSparseNative
1079    } else {
1080        PirlsCoordinateFrame::TransformedQs
1081    };
1082    let materialize_final_reparam_result = || -> Result<ReparamResult, EstimationError> {
1083        if let Some((kron_result, _, _)) = kronecker_runtime.as_ref() {
1084            let rs_list: Vec<Array2<f64>> = penalty
1085                .canonical_penalties
1086                .iter()
1087                .map(|cp| cp.full_width_root())
1088                .collect();
1089            kron_result.materialize_dense_artifact_result(&rs_list, lambdas_slice, penalty.p)
1090        } else if use_sparse_native {
1091            // Sparse-native path: reuse the engine result already computed for
1092            // `penalty_active` (with the shrinkage floor folded in and mapped to
1093            // identity coordinates). This is both correct — the REML
1094            // log-determinant now matches the penalty the inner solve used — and
1095            // cheaper, since the eigendecomposition is no longer run twice.
1096            Ok(sparse_native_reparam
1097                .as_ref()
1098                .expect("sparse_native_reparam should be present for sparse-native path")
1099                .clone())
1100        } else {
1101            Ok(dense_reparam_result
1102                .as_ref()
1103                .expect("dense reparam result should be present outside Kronecker path")
1104                .clone())
1105        }
1106    };
1107
1108    // Stage 3.3-GI: GPU exact PLS dispatch — see pirls_host_dispatch::try_gaussian_pls_gpu.
1109    if let Some(result) = try_gaussian_pls_gpu(
1110        link_function,
1111        config,
1112        penalty.coefficient_lower_bounds,
1113        penalty.linear_constraints_original,
1114        gaussian_fixed_cache,
1115        &penalty_active,
1116        &qs_arc,
1117        &x_original,
1118        use_sparse_native,
1119        penalty.p,
1120        || materialize_final_reparam_result(),
1121        y,
1122        priorweights,
1123        offset,
1124        coordinate_frame,
1125        &linear_constraints,
1126    ) {
1127        return result;
1128    }
1129
1130    if matches!(link_function, LinkFunction::Identity) && linear_constraints.is_none() {
1131        // Gaussian-Identity zero-iteration exact solve. The unconstrained
1132        // penalized least-squares system is linear, so for an identity link a
1133        // single solve is the exact minimizer and no PIRLS iteration is needed.
1134        //
1135        // This shortcut is only valid in the *unconstrained* convex program.
1136        // When shape/box/linear inequality constraints are present (e.g. a
1137        // `shape=monotone_increasing` smooth, whose cumulative-sum box-reparam
1138        // bounds `γ_j ≥ 0` are folded into `linear_constraints` above), the
1139        // minimizer is the solution of an inequality-constrained QP, not the
1140        // plain normal-equations solve. Taking this branch then returns the
1141        // unconstrained β, which generically violates the constraints and is
1142        // rejected by the REML startup KKT gate (`enforce_constraint_kkt`),
1143        // aborting the whole fit. Gating on `linear_constraints.is_none()`
1144        // routes every constrained Identity fit to the iterative loop below,
1145        // which builds a feasible initial point and solves the exact QP via
1146        // the active-set solver — mirroring the gate already enforced on the
1147        // GPU Gaussian-PLS path in `try_gaussian_pls_gpu`.
1148        //
1149        // Apply the Gaussian-Identity fixed-data cache only when every
1150        // precondition for the short-circuit's exact reuse holds: the family
1151        // really is Gaussian (z = y), there is no Firth bias-reduction term,
1152        // no coefficient lower bounds, and no linear inequality constraints
1153        // — anything that would change the right-hand side or the system
1154        // beyond the additive penalty would invalidate the cache.
1155        let cache_eligible = gaussian_fixed_cache.is_some()
1156            && likelihood.spec.is_gaussian_identity()
1157            && !config.firth_bias_reduction
1158            && penalty.coefficient_lower_bounds.is_none()
1159            && penalty.linear_constraints_original.is_none();
1160        let cache_for_solve = if cache_eligible {
1161            gaussian_fixed_cache
1162        } else {
1163            None
1164        };
1165        let (pls_result, _) = solve_penalized_least_squares_implicit(
1166            &x_original,
1167            transform_active.as_ref(),
1168            y,
1169            priorweights,
1170            offset,
1171            &penalty_active,
1172            &mut workspace,
1173            y,
1174            link_function,
1175            cache_for_solve,
1176        )?;
1177
1178        let beta_transformed = pls_result.beta;
1179        let penalized_hessian = pls_result.penalized_hessian;
1180        let edf = pls_result.edf;
1181        let baseridge = pls_result.ridge_used;
1182
1183        let priorweights_owned = priorweights.to_owned();
1184        // eta = offset + X Qs beta (composed, no materialization) unless a
1185        // design-moving ψ tensor cache explicitly says the surface rows are a
1186        // stale reference. In that lane the Gaussian objective and gradient are
1187        // fully determined by (G, r, y'Wy), so applying `x_original` would both
1188        // reintroduce per-trial row work and evaluate the wrong ψ.
1189        let qbeta = transform_active
1190            .as_ref()
1191            .map(|transform| transform.apply(beta_transformed.as_ref()))
1192            .unwrap_or_else(|| beta_transformed.as_ref().clone());
1193        let stale_row_cache = cache_for_solve.filter(|cache| cache.row_prediction_is_stale);
1194        let (final_eta, finalmu, finalz, gradient_data, deviance, log_likelihood, max_abs_eta) =
1195            if let Some(cache) = stale_row_cache {
1196                let final_eta = offset.to_owned();
1197                let finalmu = final_eta.clone();
1198                let finalz = y.to_owned();
1199                let mut grad_orig = cache.xtwx_orig.dot(&qbeta);
1200                grad_orig -= &cache.xtwy_orig;
1201                let gradient_data = transform_active
1202                    .as_ref()
1203                    .map(|transform| transform.apply_transpose(&grad_orig))
1204                    .unwrap_or(grad_orig);
1205                let weighted_rss = (cache.centered_weighted_y_sq
1206                    - 2.0 * qbeta.dot(&cache.xtwy_orig)
1207                    + qbeta.dot(&cache.xtwx_orig.dot(&qbeta)))
1208                .max(0.0);
1209                let phi = likelihood.scale.fixed_phi().unwrap_or(1.0);
1210                let deviance = if phi.is_finite() && phi > 0.0 {
1211                    weighted_rss / phi
1212                } else {
1213                    f64::NAN
1214                };
1215                let log_likelihood = calculate_loglikelihood(y, &finalmu, likelihood, priorweights);
1216                let max_abs_eta = inf_norm(finalmu.iter().copied());
1217                (
1218                    final_eta,
1219                    finalmu,
1220                    finalz,
1221                    gradient_data,
1222                    deviance,
1223                    log_likelihood,
1224                    max_abs_eta,
1225                )
1226            } else {
1227                let mut eta = offset.to_owned();
1228                eta += &x_original.apply(&qbeta);
1229                let final_eta = eta.clone();
1230                let finalmu = eta.clone();
1231                let finalz = y.to_owned();
1232
1233                let mut weighted_residual = finalmu.clone();
1234                weighted_residual -= &finalz;
1235                weighted_residual *= &priorweights_owned;
1236                // gradient = Qs^T X^T (w * residual) (composed)
1237                let xt_wr = x_original.apply_transpose(&weighted_residual);
1238                let gradient_data = transform_active
1239                    .as_ref()
1240                    .map(|transform| transform.apply_transpose(&xt_wr))
1241                    .unwrap_or(xt_wr);
1242                let deviance = calculate_deviance(y, &finalmu, likelihood, priorweights);
1243                let log_likelihood = calculate_loglikelihood(y, &finalmu, likelihood, priorweights);
1244                let max_abs_eta = inf_norm(finalmu.iter().copied());
1245                (
1246                    final_eta,
1247                    finalmu,
1248                    finalz,
1249                    gradient_data,
1250                    deviance,
1251                    log_likelihood,
1252                    max_abs_eta,
1253                )
1254            };
1255        let score_norm = array1_l2_norm(&gradient_data);
1256        let s_beta = penalty_active.shifted_gradient(beta_transformed.as_ref());
1257        let s_beta_norm = array1_l2_norm(&s_beta);
1258        let mut gradient = gradient_data;
1259        gradient += &s_beta;
1260        let mut penalty_term = penalty_active.shifted_quadratic(beta_transformed.as_ref());
1261        let ridge_used = baseridge;
1262        let stabilizedhessian = if ridge_used > 0.0 {
1263            penalized_hessian
1264                .addridge(ridge_used)
1265                .map_err(|e| EstimationError::InvalidInput(format!("ridge addition failed: {e}")))?
1266        } else {
1267            penalized_hessian.clone()
1268        };
1269        let mut ridge_grad_norm = 0.0;
1270        if ridge_used > 0.0 {
1271            let ridge_penalty =
1272                ridge_used * beta_transformed.as_ref().dot(beta_transformed.as_ref());
1273            penalty_term += ridge_penalty;
1274            gradient += &beta_transformed.as_ref().mapv(|v| ridge_used * v);
1275            ridge_grad_norm = ridge_used * array1_l2_norm(beta_transformed.as_ref());
1276        }
1277
1278        let gradient_norm = array1_l2_norm(&gradient);
1279        let working_state = WorkingState {
1280            eta: LinearPredictor::new(finalmu.clone()),
1281            gradient: gradient.clone(),
1282            hessian: penalized_hessian.clone(),
1283
1284            log_likelihood,
1285            deviance,
1286            penalty_term,
1287            firth: FirthDiagnostics::Inactive,
1288            ridge_used,
1289            hessian_curvature: HessianCurvatureKind::Fisher,
1290            gradient_natural_scale: score_norm + s_beta_norm + ridge_grad_norm,
1291        };
1292
1293        let zero_iter_penalized = deviance + penalty_term;
1294        let working_summary = WorkingModelPirlsResult {
1295            beta: beta_transformed.clone(),
1296            state: working_state,
1297            status: PirlsStatus::Converged,
1298            iterations: 1,
1299            lastgradient_norm: gradient_norm,
1300            last_deviance_change: 0.0,
1301            last_step_size: 1.0,
1302            last_step_halving: 0,
1303            max_abs_eta,
1304            constraint_kkt: linear_constraints.as_ref().map(|lin| {
1305                compute_constraint_kkt_diagnostics(beta_transformed.as_ref(), &gradient, lin)
1306            }),
1307            min_penalized_deviance: if zero_iter_penalized.is_finite() {
1308                zero_iter_penalized
1309            } else {
1310                f64::INFINITY
1311            },
1312            // Zero-iteration synthesis: no LM damping was exercised, so
1313            // hand the next solve the cold default.
1314            final_lm_lambda: 1e-6,
1315            // Zero-iteration synthesis: no LM gain ratio was measured.
1316            final_accept_rho: None,
1317            // Zero-iteration synthesis assembles the Hessian with prior
1318            // weights only; no observed-information re-evaluation has
1319            // happened. Label honestly as a Fisher-type surrogate so
1320            // outer Laplace consumers see the truth.
1321            exported_laplace_curvature: ExportedLaplaceCurvature::ExpectedInformationSurrogate,
1322        };
1323
1324        let (solve_c_array, solve_d_array, solve_dmu_deta, solve_d2mu_deta2, solve_d3mu_deta3) =
1325            computeworkingweight_derivatives_from_eta(
1326                &config.likelihood,
1327                &config.link_kind,
1328                &final_eta,
1329                priorweights_owned.view(),
1330            )?;
1331        let reparam_result = materialize_final_reparam_result()?;
1332        let qs_arc_final = Arc::new(reparam_result.qs.clone());
1333        let pirls_result = PirlsResult {
1334            likelihood: config.likelihood.clone(),
1335            beta_transformed,
1336            penalized_hessian_transformed: penalized_hessian,
1337            stabilizedhessian_transformed: stabilizedhessian,
1338            ridge_passport: RidgePassport::scaled_identity(
1339                ridge_used,
1340                RidgePolicy::explicit_stabilization_full(),
1341            ),
1342            ridge_used,
1343            deviance,
1344            edf,
1345            stable_penalty_term: penalty_term,
1346            firth: FirthDiagnostics::Inactive,
1347            finalweights: priorweights_owned.clone(),
1348            final_offset: offset.to_owned(),
1349            final_eta: final_eta.clone(),
1350            finalmu: finalmu.clone(),
1351            solveweights: priorweights_owned,
1352            solveworking_response: finalz.clone(),
1353            solvemu: finalmu.clone(),
1354            solve_dmu_deta,
1355            solve_d2mu_deta2,
1356            solve_d3mu_deta3,
1357            solve_c_array,
1358            solve_d_array,
1359            derivatives_unsupported: false,
1360            status: PirlsStatus::Converged,
1361            iteration: 1,
1362            max_abs_eta,
1363            lastgradient_norm: gradient_norm,
1364            gradient_natural_scale: score_norm + s_beta_norm + ridge_grad_norm,
1365            last_deviance_change: 0.0,
1366            last_step_halving: 0,
1367            hessian_curvature: HessianCurvatureKind::Fisher,
1368            exported_laplace_curvature: working_summary.exported_laplace_curvature.clone(),
1369            final_lm_lambda: working_summary.final_lm_lambda,
1370            final_accept_rho: working_summary.final_accept_rho,
1371            constraint_kkt: working_summary.constraint_kkt.clone(),
1372            linear_constraints_transformed: linear_constraints.clone(),
1373            reparam_result,
1374            x_transformed: make_reparam_operator(&x_original, &qs_arc_final, use_sparse_native),
1375            coordinate_frame,
1376            used_device: false,
1377            cache_compacted: false,
1378            min_penalized_deviance: working_summary.min_penalized_deviance,
1379        };
1380
1381        return Ok((pirls_result, working_summary));
1382    }
1383
1384    let x_original_for_result = x_original.clone();
1385    let mut working_model = GamWorkingModel::new(
1386        None, // No pre-materialized x_transformed: use implicit Qs composition
1387        x_original.clone(),
1388        coordinate_frame,
1389        offset,
1390        y,
1391        priorweights,
1392        penalty_active.clone(),
1393        workspace,
1394        config.likelihood.clone(),
1395        config.link_kind.clone(),
1396        // Inner Firth/Jeffreys activation must agree with the caller-requested
1397        // mode. The REML *outer* analytic derivative assembly only carries the
1398        // Jeffreys score/curvature term when `firth_bias_reduction` is set
1399        // (`reml_robust_jeffreys_link` returns `None` otherwise), so arming the
1400        // inner penalty unconditionally would converge the inner mode to the
1401        // Firth-penalized stationary point while the outer H/u/IFT stayed
1402        // non-Firth — the two would then disagree by exactly the Jeffreys
1403        // contribution (broken τ-τ Hessian-vs-FD and stationarity-cancellation
1404        // identities, #825). Gate on `firth_bias_reduction` so inner and outer
1405        // are the same objective.
1406        config.firth_bias_reduction
1407            && matches!(config.likelihood.spec.response, ResponseFamily::Binomial)
1408            && inverse_link_has_fisher_weight_jet(&config.link_kind),
1409        transform_active.clone(),
1410        quadctx,
1411        // #1111 / #1033 mechanism (c): frozen-W first-Fisher-step XᵀWX in the
1412        // original (conditioned x_fit) frame, served n-free on the first inner
1413        // iteration. Suppressed under Firth bias reduction, which shifts the
1414        // working response per iteration (the installer also gates Firth off).
1415        if config.firth_bias_reduction {
1416            None
1417        } else {
1418            glm_first_step_gram.cloned()
1419        },
1420    );
1421
1422    // Apply integrated (GHQ) likelihood if per-observation SE is provided.
1423    // This is used by the calibrator to coherently account for base prediction uncertainty.
1424    if let Some(se) = covariate_se {
1425        working_model = working_model.with_covariate_se(se.to_owned());
1426    }
1427
1428    let mut beta_guess_original = warm_start_beta
1429        .filter(|beta| beta.len() == penalty.p)
1430        .map(|beta| beta.to_owned())
1431        .unwrap_or_else(|| {
1432            Coefficients::new(default_beta_guess_external(
1433                penalty.p,
1434                link_function,
1435                y,
1436                priorweights,
1437                config.link_kind.mixture_state(),
1438                config.link_kind.sas_state(),
1439            ))
1440        });
1441    if let Some(lb) = penalty.coefficient_lower_bounds {
1442        project_coefficients_to_lower_bounds(&mut beta_guess_original.0, lb);
1443    }
1444    let initial_beta = transform_active
1445        .as_ref()
1446        .map(|transform| transform.apply_transpose(beta_guess_original.as_ref()))
1447        .unwrap_or_else(|| beta_guess_original.as_ref().clone());
1448    let initial_beta = if let Some(constraints) = linear_constraints.as_ref() {
1449        // Worst per-row *scaled* (geometric) slack of the current seed against the
1450        // constraint cone. Negative ⇒ the seed violates a row; ~0 ⇒ the seed sits
1451        // ON the boundary (for a homogeneous convex/concave second-difference
1452        // cone, `β = 0` — the unconstrained Gaussian seed — sits on EVERY row's
1453        // boundary, i.e. the cone vertex). Either way the seed must be pushed
1454        // strictly into the interior before P-IRLS starts.
1455        let mut min_scaled_slack = f64::INFINITY;
1456        for i in 0..constraints.a.nrows() {
1457            let norm = constraints.a.row(i).dot(&constraints.a.row(i)).sqrt();
1458            let inv = if norm > 0.0 { 1.0 / norm } else { 0.0 };
1459            let slack = (constraints.a.row(i).dot(&initial_beta) - constraints.b[i]) * inv;
1460            min_scaled_slack = min_scaled_slack.min(slack);
1461        }
1462        // Push the seed to the nearest STRICTLY-INTERIOR feasible point whenever
1463        // any row is tight or violated. A seed on the cone boundary (most acutely
1464        // the vertex `β = 0`) hands the inner active-set QP an all-rows-active
1465        // working set, where it stalls on a degenerate, non-stationary face — so
1466        // the fit silently diverges (or aborts in release) between a cold and a
1467        // warm warm-start cache (#873). A strictly-interior seed makes the QP's
1468        // initial active set empty; it then adds only the genuinely binding rows
1469        // and converges to the certified constrained optimum regardless of cache
1470        // state. The projection keeps the data-driven curvature of `initial_beta`
1471        // and falls back to the min-norm feasible point only if it cannot certify
1472        // a strictly-interior solution.
1473        //
1474        // The min-norm fallback (`feasible_point_for_linear_constraints`) is only
1475        // used for a NON-homogeneous cone (`b ≠ 0`), where it returns a genuine
1476        // interior-of-the-offset-polyhedron point. For a HOMOGENEOUS shape cone
1477        // (`b ≈ 0` — the convex/concave second-difference rows) that function
1478        // returns the minimum-norm feasible point `β = 0`, which is the cone
1479        // *vertex*: the exact all-rows-tight degenerate seed #873 is about. Taking
1480        // it would silently reintroduce the #873 pathology whenever the strict
1481        // projection rarely fails to certify. So for a homogeneous cone we skip the
1482        // vertex fallback entirely and prefer the data-driven `initial_beta`: it
1483        // violates at most *some* rows (a lower-dimensional, non-degenerate face the
1484        // inner active-set QP can recover from), strictly better than the vertex
1485        // where *every* row is simultaneously tight.
1486        let cone_is_homogeneous = constraints.b.iter().all(|v| v.abs() <= 1e-14);
1487        if min_scaled_slack < active_set::interior_seed_margin() {
1488            let projected =
1489                active_set::project_point_strictly_into_feasible_cone(&initial_beta, constraints)
1490                    .or_else(|| {
1491                        if cone_is_homogeneous {
1492                            None
1493                        } else {
1494                            active_set::feasible_point_for_linear_constraints(
1495                                constraints,
1496                                initial_beta.len(),
1497                            )
1498                        }
1499                    });
1500            projected.unwrap_or(initial_beta)
1501        } else {
1502            initial_beta
1503        }
1504    } else {
1505        initial_beta
1506    };
1507    // Inner P-IRLS Firth activation. The inner penalized objective must match
1508    // the objective the REML outer derivatives are assembled against: the outer
1509    // path carries the Jeffreys/Firth score+curvature only when the caller set
1510    // `firth_bias_reduction` (`reml_robust_jeffreys_link` is `None` otherwise),
1511    // so the inner Firth term is armed iff the caller requested it AND the link
1512    // exposes a Fisher-weight jet (#825). Forcing it on unconditionally desynced
1513    // the Firth-penalized inner mode from the non-Firth outer assembly.
1514    let firth_active = config.firth_bias_reduction
1515        && matches!(config.likelihood.spec.response, ResponseFamily::Binomial)
1516        && inverse_link_has_fisher_weight_jet(&config.link_kind);
1517    let base_max_step_halving = if firth_active { 60 } else { 30 };
1518    let options = WorkingModelPirlsOptions {
1519        // The Firth-penalized P-IRLS converges at the same iteration count as
1520        // the unpenalized fit — the Jeffreys term is a smooth, bounded addition
1521        // to a Newton system that is already well conditioned (the additional
1522        // per-iteration LM step-halving budget above absorbs the early-iteration
1523        // curvature change). Bumping the outer-iteration cap to mask a
1524        // mis-conditioned step would only hide non-convergence, so the cap stays
1525        // the caller's `max_iterations` and trips as a hard error if exceeded.
1526        max_iterations: config.max_iterations,
1527        convergence_tolerance: config.convergence_tolerance,
1528        adaptive_kkt_tolerance,
1529        // LM step-halving is a per-iteration damping retry budget; it is
1530        // independent of the total outer-iteration cap. Tying the two
1531        // together collapsed step halving to 3 under seed screening (where
1532        // max_iterations is intentionally capped low), turning recoverable
1533        // damping into spurious failures.
1534        max_step_halving: base_max_step_halving,
1535        min_step_size: if firth_active { 1e-12 } else { 1e-10 },
1536        firth_bias_reduction: firth_active,
1537        coefficient_lower_bounds: None,
1538        linear_constraints: linear_constraints.clone(),
1539        initial_lm_lambda: config.initial_lm_lambda,
1540        geodesic_acceleration: config.geodesic_acceleration,
1541        arrow_schur: config.arrow_schur.clone(),
1542    };
1543
    // Per-iteration progress callback for the inner P-IRLS/LM solve. It is
    // passed as `&mut iteration_logger` to `runworking_model_pirls` for the
    // initial solve and for every dispersion-refresh re-solve, which is why the
    // binding is `mut` (a `&mut` borrow needs a mutable binding). The body only
    // emits one `debug`-level log line per iteration (iteration index, deviance,
    // gradient norm, step size, and step-halving count). It does not feed
    // anything back into the solve.
    let mut iteration_logger = |info: &WorkingModelIterationInfo| {
        log::debug!(
            "[PIRLS] iter {:>3} | deviance {:.6e} | |grad| {:.3e} | step {:.3e} (halving {})",
            info.iteration,
            info.deviance,
            info.gradient_norm,
            info.step_size,
            info.step_halving
        );
    };
1549            info.gradient_norm,
1550            info.step_size,
1551            info.step_halving
1552        );
1553    };
1554
1555    // Stage 3.3 GPU PIRLS-loop dispatch — see pirls_host_dispatch::try_pirls_loop_gpu.
1556    if let Some(result) = try_pirls_loop_gpu(
1557        config,
1558        &penalty_active,
1559        kronecker_runtime.is_none(),
1560        use_sparse_native,
1561        &linear_constraints,
1562        &x_original,
1563        &qs_arc,
1564        penalty.p,
1565        &x_original_for_result,
1566        || materialize_final_reparam_result(),
1567        y,
1568        priorweights,
1569        offset,
1570        &initial_beta,
1571        link_function,
1572        coordinate_frame,
1573    ) {
1574        return result;
1575    }
1576
1577    let mut working_summary = runworking_model_pirls(
1578        &mut working_model,
1579        Coefficients::new(initial_beta),
1580        &options,
1581        &mut iteration_logger,
1582    )?;
1583
1584    // ── Gamma dispersion: re-estimate the shape at the *converged* η (#678) ──
1585    //
1586    // The inner LM solve estimates the Gamma shape ν = 1/φ **once** from the
1587    // warm-start η and freezes it for the rest of the solve (see the
1588    // `gamma_shape_locked` doc on `GamWorkingModel`): holding ν fixed keeps the
1589    // product φ·λ — and hence the penalized argmin β̂ — a stationary LM target,
1590    // so the gain ratio compares one objective. That lock is correct *within* a
1591    // solve, but it pins ν to whatever η the solve started from. When the fit
1592    // cold-starts (the final dedicated fit at the converged ρ passes
1593    // `warm_start_beta = None`, and seed screening starts from a default guess),
1594    // that warm-start η has not yet captured the mean structure; the leftover
1595    // spread of μ inflates the Gamma deviance term `mean[y/μ − ln(y/μ) − 1]` and
1596    // biases ν **down** (φ up) by >2× whenever μ varies appreciably. The mean
1597    // surface still converges (β̂ is essentially scale-free here), but the frozen
1598    // ν that survives into `UnifiedFitResult::dispersion_phi()` — and from there
1599    // into every coefficient SE `Vb = H⁻¹·φ̂`, prediction interval, and
1600    // observation-noise interval — is the early, mean-spread-contaminated value.
1601    //
1602    // Fix: after the solve converges, re-estimate ν at the converged η. If it
1603    // moved, re-solve β (warm-started, ν held fixed at the refreshed value) and
1604    // repeat, driving the pair (β, ν) to their joint fixed point at the current
1605    // λ. At convergence the reported dispersion is the Gamma ML estimate at the
1606    // converged mean (mgcv's post-hoc Pearson/deviance scale), and the final
1607    // working state — `finalweights`, the penalized Hessian, the deviance, μ —
1608    // is rebuilt with that same ν, so `Vb = H⁻¹·φ̂` stays internally consistent.
1609    // Warm-started solves (every REML cost eval) already sit near the converged
1610    // η, so the first refresh check confirms ν and exits without a re-solve; the
1611    // added cost there is a single O(n) shape evaluation.
1612    if refine_dispersion_at_converged_eta
1613        && working_model.likelihood.scale.gamma_shape_is_estimated()
1614    {
1615        // A few passes suffice: the converged-η shape map is a strong
1616        // contraction (β̂ barely moves once the mean is captured), so cold
1617        // starts settle in 1–2 re-solves and warm starts in zero.
1618        const MAX_SHAPE_REFRESH: usize = 5;
1619        // Relative shape tolerance below which a re-solve cannot move any
1620        // reported quantity meaningfully (far under statistical resolution).
1621        const SHAPE_REFRESH_REL_TOL: f64 = 1e-4;
1622        for refresh_iter in 0..MAX_SHAPE_REFRESH {
1623            let refreshed_shape = super::estimate_gamma_shape_from_eta(
1624                y,
1625                working_summary.state.eta.as_ref(),
1626                priorweights,
1627            );
1628            let prior_shape = working_model.likelihood.gamma_shape().unwrap_or(1.0);
1629            let rel_change =
1630                (refreshed_shape - prior_shape).abs() / prior_shape.max(f64::MIN_POSITIVE);
1631            // Install the refreshed shape and hold it fixed for any re-solve so
1632            // the LM objective stays stationary (the lock is *re-armed*, not
1633            // released — the seed-from-warm-start branch in `update_with_curvature`
1634            // must not overwrite this deliberately chosen value). Because this
1635            // assignment evaluated the shape at the *current* converged η and no
1636            // re-solve follows it on the exit paths below, the reported shape
1637            // always equals `estimate_gamma_shape_from_eta(final_eta)` — the
1638            // self-consistency invariant the in-module Gamma unit test checks.
1639            working_model.likelihood = working_model
1640                .likelihood
1641                .clone()
1642                .with_gamma_shape(refreshed_shape);
1643            working_model.gamma_shape_locked = true;
1644            if rel_change <= SHAPE_REFRESH_REL_TOL {
1645                // Converged: the working-state buffers (weights, Hessian,
1646                // deviance) already reflect a shape within tolerance of
1647                // `refreshed_shape`, because the only way to reach here without
1648                // a re-solve is that the prior solve's shape already matched the
1649                // converged-η estimate. Nothing left to rebuild.
1650                break;
1651            }
1652            if refresh_iter + 1 == MAX_SHAPE_REFRESH {
1653                // Final allowed pass and the shape is still drifting (a
1654                // pathological non-contraction). Do NOT re-solve: re-solving
1655                // would advance `final_eta` past the η the just-installed shape
1656                // was evaluated at, breaking the stored-shape == estimate(final_eta)
1657                // invariant. Stopping here keeps the reported shape exactly the
1658                // ML estimate at the reported η; the residual weight/φ drift is
1659                // bounded by the last `rel_change` and never worse than the
1660                // pre-fix frozen-warm-start value.
1661                break;
1662            }
1663            // The shape moved: re-solve β at the corrected shape, warm-started
1664            // at the converged β, so the final working state is rebuilt with the
1665            // refreshed ν.
1666            working_summary = runworking_model_pirls(
1667                &mut working_model,
1668                working_summary.beta.clone(),
1669                &options,
1670                &mut iteration_logger,
1671            )?;
1672        }
1673    }
1674
1675    // ── Tweedie dispersion φ: re-estimate at the *converged* η (#771) ─────────
1676    //
1677    // Identical in spirit to the Gamma-shape refresh above: the inner LM solve
1678    // estimates φ **once** from the warm-start η and freezes it (the
1679    // `tweedie_phi_locked` lock), keeping the product φ·λ — and hence β̂ — a
1680    // stationary LM target. φ enters only the working weight `prior·μ^{2−p}/φ`
1681    // and not the working response, so (like the Gamma shape, and unlike the
1682    // Beta precision which couples through the digamma mean score) the mean
1683    // surface is essentially scale-free and β̂ barely moves when φ is corrected.
1684    // But the frozen warm-start φ is the value that survives into
1685    // `FitInference::dispersion` and the covariance `Vb = H⁻¹` (whose √φ scaling
1686    // lives in the weight); at a cold-started η ≈ 0 the Pearson residuals carry
1687    // the *marginal* spread of y, biasing the estimate. Re-estimating at the
1688    // converged η — re-solving β only if φ moved materially — drives (β, φ) to
1689    // their joint fixed point, so the reported φ is the converged-mean Pearson
1690    // estimate and the final weights/Hessian/SE are internally consistent with
1691    // it. Held OFF inside the REML λ search (the flag), φ is refreshed only at
1692    // the reported fit, so it cannot couple to the smoothing parameter.
1693    if refine_dispersion_at_converged_eta
1694        && working_model.likelihood.scale.tweedie_phi_is_estimated()
1695    {
1696        if let ResponseFamily::Tweedie { p } = working_model.likelihood.spec.response {
1697            // The converged-η Pearson map is a strong contraction (β̂ scale-free
1698            // here), so cold starts settle in 1–2 re-solves and warm starts in
1699            // zero.
1700            const MAX_PHI_REFRESH: usize = 5;
1701            // Relative φ tolerance below which a re-solve cannot move any reported
1702            // quantity meaningfully (far under statistical resolution).
1703            const PHI_REFRESH_REL_TOL: f64 = 1e-4;
1704            for refresh_iter in 0..MAX_PHI_REFRESH {
1705                let refreshed_phi = super::estimate_tweedie_phi_from_eta(
1706                    y,
1707                    working_summary.state.eta.as_ref(),
1708                    priorweights,
1709                    p,
1710                );
1711                let prior_phi = working_model.likelihood.fixed_phi().unwrap_or(1.0);
1712                let rel_change =
1713                    (refreshed_phi - prior_phi).abs() / prior_phi.max(f64::MIN_POSITIVE);
1714                // Install the refreshed φ (the scale metadata the working weight
1715                // reads via `fixed_phi()`) and re-arm the lock so a following
1716                // re-solve does not overwrite this converged-η value. Because the
1717                // exit paths below evaluate φ at the *current* η with no following
1718                // re-solve, the reported φ always equals
1719                // `estimate_tweedie_phi_from_eta(final_eta)`.
1720                working_model.likelihood = working_model
1721                    .likelihood
1722                    .clone()
1723                    .with_tweedie_phi(refreshed_phi);
1724                working_model.tweedie_phi_locked = true;
1725                if rel_change <= PHI_REFRESH_REL_TOL {
1726                    // Converged: the working state already reflects a φ within
1727                    // tolerance of `refreshed_phi`. Nothing left to rebuild.
1728                    break;
1729                }
1730                if refresh_iter + 1 == MAX_PHI_REFRESH {
1731                    // Final allowed pass and φ is still drifting. Do NOT re-solve:
1732                    // re-solving would advance η past the point φ was evaluated at,
1733                    // breaking the stored-φ == estimate(final_eta) invariant.
1734                    break;
1735                }
1736                // φ moved materially: re-solve β at the corrected φ, warm-started
1737                // at the converged β, so the final working state is rebuilt with
1738                // the refreshed φ.
1739                working_summary = runworking_model_pirls(
1740                    &mut working_model,
1741                    working_summary.beta.clone(),
1742                    &options,
1743                    &mut iteration_logger,
1744                )?;
1745            }
1746        }
1747    }
1748
1749    // ── Beta precision φ: re-estimate at the *converged* η and drive (β, φ) to
1750    //    their joint fixed point (#769) ──────────────────────────────────────
1751    //
1752    // Like the Gamma shape above, the inner LM solve estimates φ **once** from
1753    // the warm-start η and freezes it for the rest of the solve (the
1754    // `beta_phi_locked` doc on `GamWorkingModel`): holding φ fixed keeps the
1755    // penalized argmin β̂ a stationary LM target so the gain ratio compares one
1756    // objective. But that lock pins φ to whatever η the solve started from, and
1757    // for the final dedicated fit at the converged ρ the warm-start is the cold
1758    // default guess (η ≈ 0, μ ≈ 0.5 everywhere). At the null predictor the
1759    // Pearson residuals `(y−μ)²/(μ(1−μ))` capture the full *marginal* spread of
1760    // y rather than its *conditional* spread, so the moment estimator
1761    // `1+φ = Σw / Σ w·s` returns a precision far too small (≈3 when the truth is
1762    // ≈20 here).
1763    //
1764    // Crucially — and unlike the Gamma shape — φ does **not** factor out of the
1765    // Beta mean score. With the logit link the score for β is
1766    //     ∂ℓ/∂β = φ · Σᵢ xᵢ (y*ᵢ − μ*ᵢ),   y*ᵢ = logit(yᵢ),
1767    //     μ*ᵢ = ψ(μᵢφ) − ψ((1−μᵢ)φ),
1768    // so the root β̂ depends on φ through the digamma terms. A φ that is too
1769    // small shrinks every fitted coefficient toward zero. So this refresh is not
1770    // cosmetic (as it is for Gamma): the re-solve is what *recovers the mean*.
1771    //
1772    // Fix: after the cold solve converges, re-estimate φ at the converged η,
1773    // re-solve β at the corrected φ (warm-started), and repeat. This is the
1774    // betareg alternating mean-fit ↔ φ-estimate scheme; the moment estimator is
1775    // a strong contraction once the mean has any structure, so the pair settles
1776    // in a handful of passes. Held OFF inside the REML λ search (see the flag
1777    // doc), φ is refreshed only here at the reported fit, so it cannot couple to
1778    // the smoothing parameter and reward over-smoothing. As with Gamma, every
1779    // exit path installs φ evaluated at the *current* η with no following
1780    // re-solve, so the reported φ (which flows into `EstimatedBetaPhi`, the
1781    // embedded `Beta { phi }`, `dispersion`, and every SE) always equals
1782    // `estimate_beta_phi_from_eta(final_eta)`.
1783    if refine_dispersion_at_converged_eta && working_model.likelihood.scale.beta_phi_is_estimated()
1784    {
1785        // The mean moves between passes (φ feeds back through the digamma
1786        // score), so allow a few more passes than the scale-free Gamma case;
1787        // the contraction is fast and warm-started re-solves are cheap.
1788        const MAX_PHI_REFRESH: usize = 30;
1789        // Relative φ tolerance below which a re-solve cannot move β̂ — and hence
1790        // any reported quantity — by a statistically meaningful amount.
1791        const PHI_REFRESH_REL_TOL: f64 = 1e-4;
1792        for refresh_iter in 0..MAX_PHI_REFRESH {
1793            let refreshed_phi = super::estimate_beta_phi_from_eta(
1794                y,
1795                working_summary.state.eta.as_ref(),
1796                priorweights,
1797            );
1798            let prior_phi = working_model.likelihood.fixed_phi().unwrap_or(1.0);
1799            let rel_change = (refreshed_phi - prior_phi).abs() / prior_phi.max(f64::MIN_POSITIVE);
1800            // Install the refreshed φ (updates BOTH the `Beta { phi }` family
1801            // variant every weight/deviance expression reads and the
1802            // `EstimatedBetaPhi` scale metadata) and re-arm the lock so a
1803            // following re-solve's `update_with_curvature` does not overwrite
1804            // this deliberately chosen value with a fresh cold estimate.
1805            working_model.likelihood = working_model
1806                .likelihood
1807                .clone()
1808                .with_beta_phi(refreshed_phi);
1809            working_model.beta_phi_locked = true;
1810            if rel_change <= PHI_REFRESH_REL_TOL {
1811                // Converged: the just-installed φ matches (to tolerance) the φ
1812                // the current working state was solved at, so β̂, the weights,
1813                // the Hessian and the deviance are already self-consistent with
1814                // the reported φ. Nothing left to rebuild.
1815                break;
1816            }
1817            if refresh_iter + 1 == MAX_PHI_REFRESH {
1818                // Final allowed pass and φ is still drifting. Do NOT re-solve:
1819                // re-solving would advance η past the point the just-installed φ
1820                // was evaluated at, breaking the stored-φ == estimate(final_eta)
1821                // invariant. Stop here so the reported φ is exactly the moment
1822                // estimate at the reported η.
1823                break;
1824            }
1825            // φ moved materially: re-solve β at the corrected φ, warm-started at
1826            // the converged β, so the mean is refit under the better precision
1827            // and the final working state is rebuilt consistently.
1828            working_summary = runworking_model_pirls(
1829                &mut working_model,
1830                working_summary.beta.clone(),
1831                &options,
1832                &mut iteration_logger,
1833            )?;
1834        }
1835    }
1836
1837    // ── Negative-Binomial overdispersion θ: re-estimate at the *converged* η and
1838    //    drive (β, θ) to their joint fixed point (#802) ───────────────────────
1839    //
1840    // Identical in spirit to the Beta-precision refresh above. The inner LM solve
1841    // estimates θ **once** from the warm-start η and freezes it (the
1842    // `negbin_theta_locked` lock), keeping the penalized argmin β̂ a stationary LM
1843    // target. But that lock pins θ to whatever η the solve started from, and for
1844    // the final dedicated fit at the converged ρ the warm-start is the cold
1845    // default guess (η ≈ 0). At the null predictor the Pearson residuals carry
1846    // the *marginal* spread of y rather than its *conditional* spread, biasing
1847    // the moment seed — and the frozen θ is what survives into the working weight
1848    // `W = μθ/(θ+μ)`, the covariance `Vb = H⁻¹` (whose overdispersion scaling
1849    // lives in that weight, not a post-hoc multiply), and every reported SE /
1850    // interval / `generate` draw.
1851    //
1852    // Like the Beta precision — and unlike the scale-free Gamma shape / Tweedie φ
1853    // — θ enters the NB2 working *response*, not only the weight, so re-solving β
1854    // under the corrected θ is not cosmetic: it recovers the mean under the right
1855    // variance function. Re-estimating at the converged η, re-solving β
1856    // (warm-started), and repeating drives (β, θ) to their joint maximum-
1857    // likelihood fixed point. Held OFF inside the REML λ search (the flag), θ is
1858    // refreshed only here at the reported fit, so it cannot couple to the
1859    // smoothing parameter. Every exit path installs θ evaluated at the *current*
1860    // η with no following re-solve, so the reported θ (which flows into the
1861    // embedded `NegativeBinomial { theta }`, the `EstimatedNegBinTheta` scale
1862    // metadata, the predictive-interval variance, and every SE) always equals
1863    // `estimate_negbin_theta_from_eta(final_eta)`.
1864    if refine_dispersion_at_converged_eta
1865        && working_model.likelihood.scale.negbin_theta_is_estimated()
1866    {
1867        // θ feeds back through the working response, so allow a few more passes
1868        // than the scale-free Gamma case; the alternation is a strong contraction
1869        // and warm-started re-solves are cheap.
1870        const MAX_THETA_REFRESH: usize = 30;
1871        // Relative θ tolerance below which a re-solve cannot move β̂ — and hence
1872        // any reported quantity — by a statistically meaningful amount.
1873        const THETA_REFRESH_REL_TOL: f64 = 1e-4;
1874        for refresh_iter in 0..MAX_THETA_REFRESH {
1875            let refreshed_theta = super::estimate_negbin_theta_from_eta(
1876                y,
1877                working_summary.state.eta.as_ref(),
1878                priorweights,
1879            );
1880            let prior_theta = working_model.likelihood.negbin_theta().unwrap_or(1.0);
1881            let rel_change =
1882                (refreshed_theta - prior_theta).abs() / prior_theta.max(f64::MIN_POSITIVE);
1883            // Install the refreshed θ (updates BOTH the `NegativeBinomial { theta }`
1884            // family variant every weight/deviance expression reads and the
1885            // `EstimatedNegBinTheta` scale metadata) and re-arm the lock so a
1886            // following re-solve's `update_with_curvature` does not overwrite this
1887            // deliberately chosen value with a fresh cold estimate.
1888            working_model.likelihood = working_model
1889                .likelihood
1890                .clone()
1891                .with_negbin_theta(refreshed_theta);
1892            working_model.negbin_theta_locked = true;
1893            if rel_change <= THETA_REFRESH_REL_TOL {
1894                // Converged: the just-installed θ matches (to tolerance) the θ the
1895                // current working state was solved at, so β̂, the weights, the
1896                // Hessian and the deviance are already self-consistent with the
1897                // reported θ. Nothing left to rebuild.
1898                break;
1899            }
1900            if refresh_iter + 1 == MAX_THETA_REFRESH {
1901                // Final allowed pass and θ is still drifting. Do NOT re-solve:
1902                // re-solving would advance η past the point the just-installed θ
1903                // was evaluated at, breaking the stored-θ == estimate(final_eta)
1904                // invariant. Stop here so the reported θ is exactly the ML
1905                // estimate at the reported η.
1906                break;
1907            }
1908            // θ moved materially: re-solve β at the corrected θ, warm-started at
1909            // the converged β, so the mean is refit under the better variance
1910            // function and the final working state is rebuilt consistently.
1911            working_summary = runworking_model_pirls(
1912                &mut working_model,
1913                working_summary.beta.clone(),
1914                &options,
1915                &mut iteration_logger,
1916            )?;
1917        }
1918    }
1919
1920    // Extract workspace before consuming working_model so we can reuse
1921    // the pre-allocated buffers in calculate_edfwithworkspace_with_penalty.
1922    // into_final_state() drops the workspace field anyway (it uses `..` in
1923    // its destructure); we replace it with a zero-sized stub to satisfy the
1924    // borrow checker, then keep the real workspace alive for the EDF call.
1925    let mut saved_workspace = std::mem::replace(
1926        &mut working_model.workspace,
1927        PirlsWorkspace::new(0, 0, 0, 0),
1928    );
1929    let final_state = working_model.into_final_state();
1930    let GamModelFinalState {
1931        likelihood: final_likelihood,
1932        coordinate_frame,
1933        finalmu,
1934        finalweights,
1935        scoreweights,
1936        finalz,
1937        final_c,
1938        final_d,
1939        final_dmu_deta,
1940        final_d2mu_deta2,
1941        final_d3mu_deta3,
1942        penalty_term,
1943        ..
1944    } = final_state;
1945
1946    // Preserve the Hessian as-is (sparse or dense) — no densification.
1947    // P-IRLS already folded any stabilization ridge directly into the Hessian.
1948    // Keep that exact matrix so outer LAML derivatives stay consistent:
1949    // H_eff = X'W_H X + S_λ + ridge I (if ridge_used > 0).
1950    let penalized_hessian_transformed = working_summary.state.hessian.clone();
1951    let stabilizedhessian_transformed = penalized_hessian_transformed.clone();
1952    // Use the workspace-backed variant for the dense path to reuse the
1953    // `final_aug_matrix` allocation; the sparse path still allocates
1954    // internally because no pre-computed factor is available at this site.
1955    let mut edf = if let Some(dense_h) = penalized_hessian_transformed.as_dense() {
1956        calculate_edfwithworkspace_with_penalty(dense_h, &penalty_active, &mut saved_workspace)?
1957    } else {
1958        calculate_edf_with_penalty(&penalized_hessian_transformed, &penalty_active)?
1959    };
1960    if !edf.is_finite() || edf.is_nan() {
1961        let p = penalized_hessian_transformed.ncols() as f64;
1962        let r = penalty_active.rank() as f64;
1963        edf = (p - r).max(0.0);
1964    }
1965
1966    // Outer rescue: a fit that hit max-iterations may still be a usable
1967    // minimum if progress has effectively stopped (deviance plateaued or
1968    // step size collapsed to the floor) AND the projected gradient is in
1969    // the near-stationary band under the scale-invariant certificate.
1970    // Same logic for non-Firth and Firth paths; firth_active just gates
1971    // the second pass.
1972    let stalled_at_valid_minimum = |summary: &WorkingModelPirlsResult| -> bool {
1973        // Scale-equivariant deviance plateau band (issue #1127). The
1974        // `last_deviance_change` compared below and the deviance both scale as
1975        // `O(a²)` under a response rescaling `y → a·y` (the penalized normal
1976        // equations are linear in `y`, so `β → a·β` and the RSS-deviance
1977        // scales by `a²`). Keying the plateau band to the deviance's own
1978        // magnitude `+ |penalty|` makes the ratio `Δdev / dev_scale`
1979        // scale-invariant. The previous `.max(1.0)` absolute floor broke this:
1980        // for a micro-unit response (`a = 1e-6`) the deviance is `O(1e-12)`, so
1981        // the floor pinned the band at `1.0` — ~1e9× too loose — and this
1982        // max-iteration rescue declared `progress_stopped` at an over-smoothed
1983        // iterate, propagating an inflated `λ̂` to the outer REML loop. For a
1984        // well-scaled (`a ≳ 1`) or up-scaled (`a = 1e6`) objective the floor was
1985        // already a no-op, so those directions are byte-identical. A perfect
1986        // interpolating fit gives a `0` band, so the relative `Δdev` test cannot
1987        // fire spuriously and the scale-invariant `near_stationary_kkt`
1988        // certificate then governs acceptance.
1989        let dev_scale = summary.state.deviance.abs() + summary.state.penalty_term.abs();
1990        // Progress plateau uses the fixed solver tolerance; only the KKT band below adapts.
1991        let dev_tol = options.convergence_tolerance * dev_scale;
1992        let step_floor = options.min_step_size * 2.0;
1993        let progress_stopped =
1994            summary.last_deviance_change.abs() <= dev_tol || summary.last_step_size <= step_floor;
1995        let near_stationary = summary
1996            .state
1997            .near_stationary_kkt(summary.lastgradient_norm, effective_kkt_tolerance(&options));
1998        progress_stopped && near_stationary
1999    };
2000
2001    let mut status = working_summary.status;
2002    if status.is_failed_max_iterations() && stalled_at_valid_minimum(&working_summary) {
2003        status = PirlsStatus::StalledAtValidMinimum;
2004        working_summary.status = status;
2005    }
2006    if status.is_failed_max_iterations()
2007        && firth_active
2008        && stalled_at_valid_minimum(&working_summary)
2009    {
2010        // Firth-adjusted fits can stall; accept under the same dual-criterion
2011        // near-stationary band.
2012        status = PirlsStatus::StalledAtValidMinimum;
2013        working_summary.status = status;
2014    }
2015    let has_penalty = penalty_active.rank() > 0;
2016    let firth_active = options.firth_bias_reduction;
2017    if detect_logit_instability(
2018        link_function,
2019        &final_likelihood.spec.response,
2020        has_penalty,
2021        firth_active,
2022        &working_summary,
2023        &finalmu,
2024        &finalweights,
2025        y,
2026    ) {
2027        status = PirlsStatus::Unstable;
2028        working_summary.status = status;
2029    }
2030
2031    // Store a lazy ReparamOperator instead of materializing X·Qs.
2032    // Consumers that truly need dense access can call .to_dense() on demand.
2033    let reparam_result_final = materialize_final_reparam_result()?;
2034    let qs_arc_final = Arc::new(reparam_result_final.qs.clone());
2035    let x_transformed_final =
2036        make_reparam_operator(&x_original_for_result, &qs_arc_final, use_sparse_native);
2037
2038    let pirls_result = assemble_pirls_result(
2039        &working_summary,
2040        final_likelihood,
2041        offset,
2042        penalized_hessian_transformed,
2043        stabilizedhessian_transformed,
2044        edf,
2045        penalty_term,
2046        &finalmu,
2047        &finalweights,
2048        &scoreweights,
2049        &finalz,
2050        &final_c,
2051        &final_d,
2052        &final_dmu_deta,
2053        &final_d2mu_deta2,
2054        &final_d3mu_deta3,
2055        status,
2056        reparam_result_final,
2057        x_transformed_final,
2058        coordinate_frame,
2059        linear_constraints,
2060    );
2061
2062    Ok((pirls_result, working_summary))
2063}
2064
2065#[derive(Clone)]
2066pub struct PirlsConfig {
2067    pub likelihood: GlmLikelihoodSpec,
2068    pub link_kind: InverseLink,
2069    pub max_iterations: usize,
2070    pub convergence_tolerance: f64,
2071    pub firth_bias_reduction: bool,
2072    /// Optional warm-start hint for `WorkingModelPirlsOptions::initial_lm_lambda`.
2073    /// Forwarded directly when `fit_model_for_fixed_rho` builds its
2074    /// internal options. See the field doc on `WorkingModelPirlsOptions`
2075    /// for the seeding semantics.
2076    pub initial_lm_lambda: Option<f64>,
2077    /// Enable the Transtrum-Sethna geodesic-acceleration second-order
2078    /// correction on each accepted LM step. Forwarded to
2079    /// `WorkingModelPirlsOptions::geodesic_acceleration`; see that
2080    /// field's doc for the full semantics and cost model. Default
2081    /// `false`; opt-in until validated.
2082    pub geodesic_acceleration: bool,
2083    /// Optional arrow-Schur structured-inner-solve descriptor. When
2084    /// `Some`, forwarded to `WorkingModelPirlsOptions::arrow_schur` so
2085    /// each accepted LM step is solved by the per-observation
2086    /// arrow-Schur path
2087    /// ([`crate::arrow_schur::ArrowSchurSystem`]). When `None`
2088    /// (the default), the existing β-only path is used unchanged.
2089    ///
2090    /// See [`ArrowSchurInnerConfig`] for the closure contract.
2091    pub arrow_schur: Option<ArrowSchurInnerConfig>,
2092}
2093
2094impl PirlsConfig {
2095    #[inline]
2096    pub fn link_function(&self) -> LinkFunction {
2097        self.link_kind.link_function()
2098    }
2099}
2100
2101#[inline]
2102pub(super) fn max_symmetric_asymmetry(matrix: &Array2<f64>) -> f64 {
2103    let n = matrix.nrows().min(matrix.ncols());
2104    let mut max_asym = 0.0_f64;
2105    for i in 0..n {
2106        for j in 0..i {
2107            let diff = (matrix[[i, j]] - matrix[[j, i]]).abs();
2108            if diff > max_asym {
2109                max_asym = diff;
2110            }
2111        }
2112    }
2113    max_asym
2114}
2115
2116#[inline]
2117pub(super) fn assert_symmetric_tol(matrix: &Array2<f64>, label: &str, tol: f64) {
2118    let max_asym = max_symmetric_asymmetry(matrix);
2119    assert!(
2120        max_asym <= tol,
2121        "{} asymmetry too large: {:.3e} (tol {:.3e})",
2122        label,
2123        max_asym,
2124        tol
2125    );
2126}
2127
2128/// Build a DesignMatrix wrapping a lazy ReparamOperator (or the original for sparse-native).
2129pub(crate) fn make_reparam_operator(
2130    x_original: &DesignMatrix,
2131    qs_arc: &Arc<Array2<f64>>,
2132    use_sparse_native: bool,
2133) -> DesignMatrix {
2134    if use_sparse_native {
2135        x_original.clone()
2136    } else {
2137        DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(
2138            ReparamOperator::new(x_original.clone(), Arc::clone(qs_arc)),
2139        )))
2140    }
2141}
2142
2143// solve_penalized_least_squares_implicit lives in pls_solver (imported above).
2144
2145pub(super) fn build_transformed_lower_bound_constraints(
2146    qs: &Array2<f64>,
2147    coefficient_lower_bounds: Option<&Array1<f64>>,
2148) -> Option<LinearInequalityConstraints> {
2149    let lb = coefficient_lower_bounds?;
2150    if lb.len() != qs.nrows() {
2151        return None;
2152    }
2153    let activerows: Vec<usize> = (0..lb.len()).filter(|&i| lb[i].is_finite()).collect();
2154    if activerows.is_empty() {
2155        return None;
2156    }
2157    let mut a = Array2::<f64>::zeros((activerows.len(), qs.ncols()));
2158    let mut b = Array1::<f64>::zeros(activerows.len());
2159    for (r, &idx) in activerows.iter().enumerate() {
2160        a.row_mut(r).assign(&qs.row(idx));
2161        b[r] = lb[idx];
2162    }
2163    Some(
2164        LinearInequalityConstraints::new(a, b)
2165            .expect("transformed lower-bound constraint shape invariant"),
2166    )
2167}
2168
2169pub(super) fn build_transformed_lower_bound_constraints_with_transform(
2170    transform: &WorkingReparamTransform,
2171    coefficient_lower_bounds: Option<&Array1<f64>>,
2172) -> Option<LinearInequalityConstraints> {
2173    let lb = coefficient_lower_bounds?;
2174    let p = match transform {
2175        WorkingReparamTransform::Dense(qs) => qs.nrows(),
2176        WorkingReparamTransform::Kronecker(kron) => kron.p,
2177    };
2178    if lb.len() != p {
2179        return None;
2180    }
2181    let activerows: Vec<usize> = (0..lb.len()).filter(|&i| lb[i].is_finite()).collect();
2182    if activerows.is_empty() {
2183        return None;
2184    }
2185    let mut a = Array2::<f64>::zeros((activerows.len(), p));
2186    let mut b = Array1::<f64>::zeros(activerows.len());
2187    for (r, &idx) in activerows.iter().enumerate() {
2188        let mut basis = Array1::<f64>::zeros(p);
2189        basis[idx] = 1.0;
2190        let row = transform.apply_transpose(&basis);
2191        a.row_mut(r).assign(&row);
2192        b[r] = lb[idx];
2193    }
2194    Some(
2195        LinearInequalityConstraints::new(a, b)
2196            .expect("transformed lower-bound constraint shape invariant"),
2197    )
2198}
2199
2200pub(super) fn build_transformed_linear_constraints(
2201    qs: &Array2<f64>,
2202    linear_constraints: Option<&LinearInequalityConstraints>,
2203) -> Option<LinearInequalityConstraints> {
2204    let lc = linear_constraints?;
2205    if lc.a.ncols() != qs.nrows() {
2206        return None;
2207    }
2208    Some(
2209        LinearInequalityConstraints::new(lc.a.dot(qs), lc.b.clone())
2210            .expect("transformed linear constraint shape invariant"),
2211    )
2212}
2213
2214pub(super) fn build_transformed_linear_constraints_with_transform(
2215    transform: &WorkingReparamTransform,
2216    linear_constraints: Option<&LinearInequalityConstraints>,
2217) -> Option<LinearInequalityConstraints> {
2218    let lc = linear_constraints?;
2219    let p = match transform {
2220        WorkingReparamTransform::Dense(qs) => qs.nrows(),
2221        WorkingReparamTransform::Kronecker(kron) => kron.p,
2222    };
2223    if lc.a.ncols() != p {
2224        return None;
2225    }
2226    let mut a = Array2::<f64>::zeros((lc.a.nrows(), p));
2227    for row in 0..lc.a.nrows() {
2228        let transformed = transform.apply_transpose(&lc.a.row(row).to_owned());
2229        a.row_mut(row).assign(&transformed);
2230    }
2231    Some(LinearInequalityConstraints { a, b: lc.b.clone() })
2232}
2233
2234pub(super) fn merge_linear_constraints(
2235    first: Option<LinearInequalityConstraints>,
2236    second: Option<LinearInequalityConstraints>,
2237) -> Option<LinearInequalityConstraints> {
2238    match (first, second) {
2239        (None, None) => None,
2240        (Some(c), None) | (None, Some(c)) => Some(c),
2241        (Some(c1), Some(c2)) => {
2242            if c1.a.ncols() != c2.a.ncols() {
2243                return None;
2244            }
2245            let rows = c1.a.nrows() + c2.a.nrows();
2246            let cols = c1.a.ncols();
2247            let mut a = Array2::<f64>::zeros((rows, cols));
2248            a.slice_mut(s![0..c1.a.nrows(), ..]).assign(&c1.a);
2249            a.slice_mut(s![c1.a.nrows()..rows, ..]).assign(&c2.a);
2250            let mut b = Array1::<f64>::zeros(rows);
2251            b.slice_mut(s![0..c1.b.len()]).assign(&c1.b);
2252            b.slice_mut(s![c1.b.len()..rows]).assign(&c2.b);
2253            Some(LinearInequalityConstraints { a, b })
2254        }
2255    }
2256}
2257
2258pub(super) fn sparse_from_denseview(x: ArrayView2<f64>) -> Option<DesignMatrix> {
2259    // Below this column count a dense factorization beats the sparse path even
2260    // at high sparsity, so skip the sparsity scan entirely for narrow designs.
2261    const DENSE_PREFERRED_MAX_COLS: usize = 32;
2262    // Sparse storage + sparse Cholesky only pays off below this density (nnz as
2263    // a fraction of all entries); denser matrices stay dense.
2264    const SPARSE_DENSITY_LIMIT: f64 = 0.20;
2265
2266    let nrows = x.nrows();
2267    let ncols = x.ncols();
2268    if nrows == 0 || ncols == 0 {
2269        return None;
2270    }
2271    // Narrow matrices are faster in dense form; avoid any sparsity scan overhead.
2272    if ncols <= DENSE_PREFERRED_MAX_COLS {
2273        return None;
2274    }
2275
2276    const ZERO_EPS: f64 = 1e-12;
2277    let total = nrows.saturating_mul(ncols);
2278    if total == 0 {
2279        return None;
2280    }
2281    // If a matrix exceeds this nnz count it is too dense for sparse path; bail early.
2282    let sparse_nnz_limit = ((total as f64) * SPARSE_DENSITY_LIMIT).floor() as usize;
2283    let mut nnz = 0usize;
2284    for &val in x.iter() {
2285        if val.abs() > ZERO_EPS {
2286            nnz += 1;
2287            if nnz > sparse_nnz_limit {
2288                return None;
2289            }
2290        }
2291    }
2292    let mut triplets = Vec::with_capacity(nnz);
2293    for (row_idx, row) in x.outer_iter().enumerate() {
2294        for (col_idx, &val) in row.iter().enumerate() {
2295            if val.abs() > ZERO_EPS {
2296                triplets.push(Triplet::new(row_idx, col_idx, val));
2297            }
2298        }
2299    }
2300    SparseColMat::try_new_from_triplets(nrows, ncols, &triplets)
2301        .ok()
2302        .map(DesignMatrix::from)
2303}
2304
2305#[cfg(test)]
2306mod tests {
2307    use super::{PirlsPenalty, build_diagonal_penalty_from_kronecker};
2308    use gam_terms::construction::KroneckerReparamResult;
2309    use ndarray::{Array1, Array2, array};
2310
2311    #[test]
2312    fn kronecker_diagonal_double_penalty_hits_only_joint_null_space() {
2313        let kron_result = KroneckerReparamResult {
2314            reparameterized_marginals: std::sync::Arc::new(Vec::new()),
2315            marginal_eigenvalues: std::sync::Arc::new(vec![array![0.0, 2.0], array![0.0, 3.0]]),
2316            marginal_qs: std::sync::Arc::new(Vec::new()),
2317            log_det: 0.0,
2318            det1: Array1::zeros(3),
2319            det2: Array2::zeros((3, 3)),
2320            penalty_shrinkage_ridge: 0.5,
2321            has_double_penalty: true,
2322            marginal_dims: vec![2usize, 2usize],
2323        };
2324        let penalty = build_diagonal_penalty_from_kronecker(&kron_result, &[5.0, 7.0, 11.0]);
2325
2326        let PirlsPenalty::Diagonal {
2327            diag,
2328            positive_indices,
2329            ..
2330        } = penalty
2331        else {
2332            panic!("expected diagonal Kronecker PIRLS penalty");
2333        };
2334        let expected = [11.0, 21.5, 10.5, 31.5];
2335        for (idx, expected_diag) in expected.iter().copied().enumerate() {
2336            assert!(
2337                (diag[idx] - expected_diag).abs() <= 1e-12,
2338                "diagonal {idx} got {}, expected {expected_diag}",
2339                diag[idx]
2340            );
2341        }
2342        assert_eq!(positive_indices, vec![0, 1, 2, 3]);
2343    }
2344}