Skip to main content

gam_solve/pirls/
loop_driver.rs

//! Outer driver for a single fixed-ρ PIRLS fit.
//!
//! Owns:
//! - `fit_model_for_fixed_rho` and `fit_model_for_fixed_rho_with_adaptive_kkt`
//!   — build the working model, run the inner LM loop, and assemble the final result.
//! - `PirlsProblem`, `PenaltyConfig`, `PirlsConfig` — the configuration types.
//! - Helper functions exclusive to the fixed-ρ fitting path: constraint
//!   transformation, the sparse-native decision, reparam materialisation, prior
//!   shift assembly, the initial-β guess, Gaussian short-circuit assembly, etc.
//! - The two GPU dispatch blocks (Stage 3.3) that call into
//!   `crate::gpu::pirls_host_dispatch` (`try_gaussian_pls_gpu`, `try_pirls_loop_gpu`).
12
13use super::{
14    // state re-exports
15    AdaptiveKktTolerance,
16    ExportedLaplaceCurvature,
17    FirthDiagnostics,
18    GamWorkingModel,
19    HessianCurvatureKind,
20    // penalty types
21    KroneckerQsTransform,
22    LinearInequalityConstraints,
23    PirlsCoordinateFrame,
24    PirlsLinearSolvePath,
25    PirlsPenalty,
26    PirlsResult,
27    PirlsStatus,
28    PirlsWorkspace,
29    SparsePirlsDecision,
30    WorkingModelIterationInfo,
31    WorkingModelPirlsOptions,
32    WorkingModelPirlsResult,
33    WorkingReparamTransform,
34    WorkingState,
35    // misc helpers
36    array1_l2_norm,
37    attach_penalty_shift,
38    // compute functions
39    calculate_deviance,
40    // edf helpers
41    calculate_edf_with_penalty,
42    calculate_edfwithworkspace_with_penalty,
43    calculate_loglikelihood_omitting_constants,
44    compute_constraint_kkt_diagnostics,
45    computeworkingweight_derivatives_from_eta,
46    inf_norm,
47    runworking_model_pirls,
48    should_use_sparse_native_pirls,
49    solve_penalized_least_squares_implicit,
50    standard_inverse_link_jet,
51};
52use super::{
53    ArrowSchurInnerConfig, GamModelFinalState, effective_kkt_tolerance,
54    project_coefficients_to_lower_bounds,
55};
56use gam_terms::construction::{KroneckerReparamResult, ReparamResult};
57use crate::estimate::EstimationError;
58use gam_linalg::faer_ndarray::fast_ab;
59use gam_linalg::matrix::{DesignMatrix, LinearOperator, ReparamOperator, SymmetricMatrix};
60use crate::mixture_link::inverse_link_has_fisher_weight_jet;
61use gam_math::probability::standard_normal_quantile;
62use crate::active_set;
63use crate::gpu::pirls_host_dispatch::{try_gaussian_pls_gpu, try_pirls_loop_gpu};
64use gam_problem::{
65    Coefficients, GlmLikelihoodSpec, InverseLink, LinearPredictor, LinkFunction,
66    LogSmoothingParamsView, MixtureLinkState, ResponseFamily, RidgePassport, RidgePolicy,
67    SasLinkState, StandardLink,
68};
69use faer::sparse::{SparseColMat, Triplet};
70use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
71use std::borrow::Cow;
72use std::sync::Arc;
73
74pub(super) fn default_beta_guess_external(
75    p: usize,
76    link_function: LinkFunction,
77    y: ArrayView1<f64>,
78    priorweights: ArrayView1<f64>,
79    mixture_link_state: Option<&MixtureLinkState>,
80    sas_link_state: Option<&SasLinkState>,
81) -> Array1<f64> {
82    let mut beta = Array1::<f64>::zeros(p);
83    let intercept_col = 0usize;
84    match link_function {
85        LinkFunction::Logit
86        | LinkFunction::Probit
87        | LinkFunction::CLogLog
88        | LinkFunction::Sas
89        | LinkFunction::BetaLogistic => {
90            let mut weighted_sum = 0.0;
91            let mut totalweight = 0.0;
92            for (&yi, &wi) in y.iter().zip(priorweights.iter()) {
93                weighted_sum += wi * yi;
94                totalweight += wi;
95            }
96            if totalweight > 0.0 {
97                let prevalence =
98                    ((weighted_sum + 0.5) / (totalweight + 1.0)).clamp(1e-6, 1.0 - 1e-6);
99                beta[intercept_col] = match link_function {
100                    LinkFunction::Logit => (prevalence / (1.0 - prevalence)).ln(),
101                    LinkFunction::Probit => {
102                        standard_normal_quantile(prevalence).unwrap_or_else(|_| {
103                            // `prevalence` is clamped to (0, 1); this fallback is
104                            // only for defensive robustness under non-finite upstream inputs.
105                            (prevalence / (1.0 - prevalence)).ln()
106                        })
107                    }
108                    LinkFunction::CLogLog => (-(1.0 - prevalence).ln()).ln(),
109                    LinkFunction::Sas => solve_intercept_for_prevalence(
110                        link_function,
111                        prevalence,
112                        mixture_link_state,
113                        sas_link_state,
114                    )
115                    .unwrap_or_else(|| {
116                        standard_normal_quantile(prevalence)
117                            .unwrap_or_else(|_| (prevalence / (1.0 - prevalence)).ln())
118                    }),
119                    LinkFunction::BetaLogistic => solve_intercept_for_prevalence(
120                        link_function,
121                        prevalence,
122                        mixture_link_state,
123                        sas_link_state,
124                    )
125                    .unwrap_or_else(|| {
126                        standard_normal_quantile(prevalence)
127                            .unwrap_or_else(|_| (prevalence / (1.0 - prevalence)).ln())
128                    }),
129                    // Outer arm guard already filtered out Log/Identity; fall
130                    // back to the canonical logit transform for defensive safety
131                    // if these are ever reached unexpectedly.
132                    LinkFunction::Log | LinkFunction::Identity => {
133                        (prevalence / (1.0 - prevalence)).ln()
134                    }
135                };
136                if mixture_link_state.is_some() {
137                    beta[intercept_col] = solve_intercept_for_prevalence(
138                        link_function,
139                        prevalence,
140                        mixture_link_state,
141                        sas_link_state,
142                    )
143                    .unwrap_or(beta[intercept_col]);
144                }
145            }
146        }
147        LinkFunction::Identity => {
148            let mut weighted_sum = 0.0;
149            let mut totalweight = 0.0;
150            for (&yi, &wi) in y.iter().zip(priorweights.iter()) {
151                weighted_sum += wi * yi;
152                totalweight += wi;
153            }
154            if totalweight > 0.0 {
155                beta[intercept_col] = weighted_sum / totalweight;
156            }
157        }
158        LinkFunction::Log => {
159            // For log link, intercept = ln(weighted mean of y)
160            let mut weighted_sum = 0.0;
161            let mut totalweight = 0.0;
162            for (&yi, &wi) in y.iter().zip(priorweights.iter()) {
163                weighted_sum += wi * yi;
164                totalweight += wi;
165            }
166            if totalweight > 0.0 {
167                let mean_y = (weighted_sum / totalweight).max(1e-10);
168                beta[intercept_col] = mean_y.ln();
169            }
170        }
171    }
172    beta
173}
174
175pub(super) fn solve_intercept_for_prevalence(
176    link_function: LinkFunction,
177    prevalence: f64,
178    mixture_link_state: Option<&MixtureLinkState>,
179    sas_link_state: Option<&SasLinkState>,
180) -> Option<f64> {
181    #[inline]
182    fn f_eta(
183        link_function: LinkFunction,
184        eta: f64,
185        prevalence: f64,
186        mixture_link_state: Option<&MixtureLinkState>,
187        sas_link_state: Option<&SasLinkState>,
188    ) -> f64 {
189        let inverse_link = if let Some(state) = mixture_link_state {
190            InverseLink::Mixture(state.clone())
191        } else if let Some(state) = sas_link_state {
192            match link_function {
193                LinkFunction::BetaLogistic => InverseLink::BetaLogistic(*state),
194                _ => InverseLink::Sas(*state),
195            }
196        } else {
197            // SAFETY: when `sas_link_state` is None, `solve_intercept_for_prevalence`
198            // is only invoked with the five legal `StandardLink` variants (the
199            // dispatch site at pirls.rs:4203 routes Sas/BetaLogistic into the
200            // Some branch above with state).
201            InverseLink::Standard(StandardLink::try_from(link_function).expect(
202                "state-bearing link reached state-less arm in solve_intercept_for_prevalence",
203            ))
204        };
205        standard_inverse_link_jet(&inverse_link, eta)
206            .map(|jet| jet.mu - prevalence)
207            .unwrap_or(f64::NAN)
208    }
209
210    let mut lo = -40.0;
211    let mut hi = 40.0;
212    let mut f_lo = f_eta(
213        link_function,
214        lo,
215        prevalence,
216        mixture_link_state,
217        sas_link_state,
218    );
219    let mut f_hi = f_eta(
220        link_function,
221        hi,
222        prevalence,
223        mixture_link_state,
224        sas_link_state,
225    );
226    if !(f_lo.is_finite() && f_hi.is_finite()) {
227        return None;
228    }
229    for _ in 0..8 {
230        if f_lo <= 0.0 && f_hi >= 0.0 {
231            break;
232        }
233        lo *= 2.0;
234        hi *= 2.0;
235        f_lo = f_eta(
236            link_function,
237            lo,
238            prevalence,
239            mixture_link_state,
240            sas_link_state,
241        );
242        f_hi = f_eta(
243            link_function,
244            hi,
245            prevalence,
246            mixture_link_state,
247            sas_link_state,
248        );
249        if !(f_lo.is_finite() && f_hi.is_finite()) {
250            return None;
251        }
252    }
253    if f_lo > 0.0 {
254        return Some(lo);
255    }
256    if f_hi < 0.0 {
257        return Some(hi);
258    }
259    for _ in 0..80 {
260        let mid = 0.5 * (lo + hi);
261        let f_mid = f_eta(
262            link_function,
263            mid,
264            prevalence,
265            mixture_link_state,
266            sas_link_state,
267        );
268        if !f_mid.is_finite() {
269            return None;
270        }
271        if f_mid > 0.0 {
272            hi = mid;
273        } else {
274            lo = mid;
275        }
276    }
277    Some(0.5 * (lo + hi))
278}
279
280pub(super) fn assemble_pirls_result(
281    working_summary: &WorkingModelPirlsResult,
282    likelihood: GlmLikelihoodSpec,
283    offset: ArrayView1<'_, f64>,
284    penalized_hessian_transformed: SymmetricMatrix,
285    stabilizedhessian_transformed: SymmetricMatrix,
286    edf: f64,
287    penalty_term: f64,
288    finalmu: &Array1<f64>,
289    finalweights: &Array1<f64>,
290    scoreweights: &Array1<f64>,
291    finalz: &Array1<f64>,
292    final_c: &Array1<f64>,
293    final_d: &Array1<f64>,
294    final_dmu_deta: &Array1<f64>,
295    final_d2mu_deta2: &Array1<f64>,
296    final_d3mu_deta3: &Array1<f64>,
297    status: PirlsStatus,
298    reparam_result: ReparamResult,
299    x_transformed: DesignMatrix,
300    coordinate_frame: PirlsCoordinateFrame,
301    linear_constraints_transformed: Option<LinearInequalityConstraints>,
302) -> PirlsResult {
303    let final_eta_arr = working_summary.state.eta.as_ref().clone();
304    PirlsResult {
305        likelihood,
306        beta_transformed: working_summary.beta.clone(),
307        penalized_hessian_transformed,
308        stabilizedhessian_transformed,
309        ridge_passport: RidgePassport::scaled_identity(
310            working_summary.state.ridge_used,
311            RidgePolicy::explicit_stabilization_full(),
312        ),
313        ridge_used: working_summary.state.ridge_used,
314        deviance: working_summary.state.deviance,
315        edf,
316        stable_penalty_term: penalty_term,
317        firth: working_summary.state.firth.clone(),
318        finalweights: finalweights.clone(),
319        final_offset: offset.to_owned(),
320        final_eta: final_eta_arr,
321        finalmu: finalmu.clone(),
322        solveweights: scoreweights.clone(),
323        solveworking_response: finalz.clone(),
324        solvemu: finalmu.clone(),
325        solve_dmu_deta: final_dmu_deta.clone(),
326        solve_d2mu_deta2: final_d2mu_deta2.clone(),
327        solve_d3mu_deta3: final_d3mu_deta3.clone(),
328        solve_c_array: final_c.clone(),
329        solve_d_array: final_d.clone(),
330        derivatives_unsupported: false,
331        status,
332        iteration: working_summary.iterations,
333        max_abs_eta: working_summary.max_abs_eta,
334        lastgradient_norm: working_summary.lastgradient_norm,
335        gradient_natural_scale: working_summary.state.gradient_natural_scale,
336        last_deviance_change: working_summary.last_deviance_change,
337        last_step_halving: working_summary.last_step_halving,
338        hessian_curvature: working_summary.state.hessian_curvature,
339        exported_laplace_curvature: working_summary.exported_laplace_curvature.clone(),
340        final_lm_lambda: working_summary.final_lm_lambda,
341        final_accept_rho: working_summary.final_accept_rho,
342        constraint_kkt: working_summary.constraint_kkt.clone(),
343        linear_constraints_transformed,
344        reparam_result,
345        x_transformed,
346        coordinate_frame,
347        used_device: false,
348        cache_compacted: false,
349        min_penalized_deviance: working_summary.min_penalized_deviance,
350    }
351}
352
353pub(super) fn detect_logit_instability(
354    link: LinkFunction,
355    response: &ResponseFamily,
356    has_penalty: bool,
357    firth_active: bool,
358    summary: &WorkingModelPirlsResult,
359    finalmu: &Array1<f64>,
360    finalweights: &Array1<f64>,
361    y: ArrayView1<'_, f64>,
362) -> bool {
363    // Perfect / quasi-perfect separation is a *Bernoulli/Binomial* pathology.
364    // Every heuristic below is binary-response–specific: saturation toward
365    // μ ∈ {0, 1}, the `yᵢ > 0.5` order-separation split, and working-weight
366    // collapse only carry meaning when each `yᵢ` is a 0/1 outcome (or a
367    // proportion of Bernoulli trials). The Beta family also fits through the
368    // logit link, but its response is *continuous* on (0, 1): a perfectly
369    // healthy monotone mean (μ increasing in a covariate ⇒ rows with y > 0.5
370    // sit at higher η than rows with y ≤ 0.5) trivially satisfies the
371    // `order_separated` test, so gating this detector on the logit link alone
372    // misclassifies well-behaved Beta fits as separated and forces a spurious
373    // inner-solve retreat at every smoothing-parameter seed (issue #499).
374    // Gate strictly on the Binomial response so only binary GLMs are screened.
375    if !matches!(response, ResponseFamily::Binomial) || link != LinkFunction::Logit || firth_active
376    {
377        return false;
378    }
379
380    // Separation-detection policy thresholds. Each is a heuristic cut-off, not
381    // a math identity: they decide when a binary-logit fit has drifted into the
382    // perfect/quasi-perfect separation regime and the inner solve must retreat.
383    //
384    // `ORDER_SEPARATION_ETA_GAP`: a strictly positive η-gap between the lowest
385    //   η among y=1 rows and the highest among y=0 rows means the two classes
386    //   are linearly separable on the linear predictor.
387    // `EXTREME_ETA`: |η| this large drives μ to within machine-ε of {0,1}.
388    // `SATURATION_FRACTION` / `SEVERE_SATURATION_FRACTION`: share of fitted μ
389    //   pinned to the {0,1} boundary that flags (severe) saturation.
390    // `DEGENERATE_DEVIANCE_PER_SAMPLE` / `EXTREME_DEGENERATE_DEVIANCE_PER_SAMPLE`:
391    //   near-zero per-sample deviance means the model fits the data perfectly.
392    // `EXTREME_BETA_NORM`: coefficient norm blow-up characteristic of the MLE
393    //   escaping to infinity under separation.
394    // `WEIGHT_COLLAPSE_FRACTION`: share of working weights collapsed to ~0.
395    const ORDER_SEPARATION_ETA_GAP: f64 = 1e-3;
396    const EXTREME_ETA: f64 = 30.0;
397    const SATURATION_FRACTION: f64 = 0.98;
398    const SEVERE_SATURATION_FRACTION: f64 = 0.995;
399    const DEGENERATE_DEVIANCE_PER_SAMPLE: f64 = 1e-3;
400    const EXTREME_DEGENERATE_DEVIANCE_PER_SAMPLE: f64 = 1e-6;
401    const EXTREME_BETA_NORM: f64 = 1e4;
402    const WEIGHT_COLLAPSE_FRACTION: f64 = 0.98;
403
404    let n = y.len() as f64;
405    if n == 0.0 {
406        return false;
407    }
408
409    let max_abs_eta = summary.max_abs_eta;
410    let sat_fraction = {
411        const SAT_EPS: f64 = 1e-3;
412        finalmu
413            .iter()
414            .filter(|&&m| m <= SAT_EPS || m >= 1.0 - SAT_EPS)
415            .count() as f64
416            / n
417    };
418
419    let weight_collapse_fraction = {
420        const WEIGHT_EPS: f64 = 1e-8;
421        finalweights
422            .iter()
423            .filter(|&&w| w <= WEIGHT_EPS || !w.is_finite())
424            .count() as f64
425            / n
426    };
427
428    let beta_norm = summary.beta.as_ref().dot(summary.beta.as_ref()).sqrt();
429    let dev_per_sample = summary.state.deviance / n;
430
431    let mut has_pos = false;
432    let mut has_neg = false;
433    let mut min_eta_pos = f64::INFINITY;
434    let mut max_eta_neg = f64::NEG_INFINITY;
435    for (eta_i, &yi) in summary.state.eta.iter().zip(y.iter()) {
436        if yi > 0.5 {
437            has_pos = true;
438            if *eta_i < min_eta_pos {
439                min_eta_pos = *eta_i;
440            }
441        } else {
442            has_neg = true;
443            if *eta_i > max_eta_neg {
444                max_eta_neg = *eta_i;
445            }
446        }
447    }
448    let order_separated =
449        has_pos && has_neg && (min_eta_pos - max_eta_neg) > ORDER_SEPARATION_ETA_GAP;
450
451    let classic_signals = max_abs_eta > EXTREME_ETA
452        || sat_fraction > SATURATION_FRACTION
453        || dev_per_sample < DEGENERATE_DEVIANCE_PER_SAMPLE
454        || beta_norm > EXTREME_BETA_NORM;
455
456    if !has_penalty {
457        return classic_signals || order_separated;
458    }
459
460    let severe_saturation = sat_fraction > SEVERE_SATURATION_FRACTION && max_abs_eta > EXTREME_ETA;
461    let weights_collapsed = weight_collapse_fraction > WEIGHT_COLLAPSE_FRACTION;
462    let dev_extremely_small = dev_per_sample < EXTREME_DEGENERATE_DEVIANCE_PER_SAMPLE;
463
464    order_separated || severe_saturation || weights_collapsed || dev_extremely_small
465}
466
467/// Stack λ-weighted penalty roots from canonical penalties into a single
468/// `total_rank × p` matrix for PIRLS. Each block-local root is embedded
469/// into the full column space on-the-fly.
470pub(super) fn stack_lambdaweighted_penalty_root_canonical(
471    penalties: &[gam_terms::construction::CanonicalPenalty],
472    lambdas: &[f64],
473    p: usize,
474) -> Array2<f64> {
475    let totalrows: usize = penalties.iter().map(|cp| cp.rank()).sum();
476    if totalrows == 0 {
477        return Array2::zeros((0, p));
478    }
479    let mut e = Array2::<f64>::zeros((totalrows, p));
480    let mut row_start = 0usize;
481    for (k, cp) in penalties.iter().enumerate() {
482        let rows = cp.rank();
483        if rows == 0 {
484            continue;
485        }
486        let scale = lambdas.get(k).copied().unwrap_or(0.0).max(0.0).sqrt();
487        if scale != 0.0 {
488            // Embed block-local root (rank × block_dim) into full width (rank × p).
489            let r = &cp.col_range;
490            for row in 0..rows {
491                for col in 0..cp.block_dim() {
492                    e[[row_start + row, r.start + col]] = scale * cp.root[[row, col]];
493                }
494            }
495        }
496        row_start += rows;
497    }
498    e
499}
500
501pub(super) fn build_sparse_native_reparam_result(
502    base: ReparamResult,
503    penalties: &[gam_terms::construction::CanonicalPenalty],
504    lambdas: &[f64],
505    p: usize,
506) -> ReparamResult {
507    // Map the engine penalty back into identity (original) coordinates. The
508    // engine returns `s_transformed = Qsᵀ S Qs` (and `e_transformed = E Qs`)
509    // with `S = S_λ + shrinkage·P_range` already folded in (so it matches the
510    // reported `log_det`/`det1`). With the sparse-native `qs = I` we need that
511    // SAME penalty expressed in original coordinates: `S_orig = Qs S_transformed
512    // Qsᵀ`. Rebuilding `S_orig` from the bare lambda-weighted canonical sum
513    // would DROP the shrinkage ridge and desync the inner penalized Hessian from
514    // the penalty log-determinant the REML criterion uses for this fit — the
515    // cross-backend λ-selection divergence (#1266 class). Round-tripping the
516    // engine penalty through `Qs` keeps the inner solve, EDF, and REML logdet on
517    // one penalty.
518    let qs = &base.qs;
519    let s_orig = if qs.nrows() == p && qs.ncols() == base.s_transformed.nrows() {
520        // S_orig = Qs · S_transformed · Qsᵀ
521        let qs_s = fast_ab(qs, &base.s_transformed);
522        qs_s.dot(&qs.t())
523    } else {
524        // Degenerate fallback (engine produced no transform): use the bare
525        // lambda-weighted sum. Shrinkage is zero in this branch by construction.
526        let mut s_original = Array2::<f64>::zeros((p, p));
527        for (k, cp) in penalties.iter().enumerate() {
528            let lambda_k = lambdas.get(k).copied().unwrap_or(0.0);
529            if lambda_k != 0.0 {
530                cp.accumulate_weighted(&mut s_original, lambda_k);
531            }
532        }
533        s_original
534    };
535    // E_orig = E_transformed · Qsᵀ  (so that E_origᵀ E_orig = S_orig and the EDF
536    // augmented system matches the inner Hessian).
537    let e_orig = if qs.nrows() == p && base.e_transformed.ncols() == qs.ncols() {
538        base.e_transformed.dot(&qs.t())
539    } else {
540        stack_lambdaweighted_penalty_root_canonical(penalties, lambdas, p)
541    };
542    let u_original = if base.u_truncated.nrows() == p {
543        fast_ab(&base.qs, &base.u_truncated)
544    } else {
545        Array2::<f64>::eye(p)
546    };
547    // In the sparse-native path, qs = I, so the penalties are already in the
548    // right coordinate frame. We keep them as-is in canonical_transformed.
549    let canonical_transformed: Vec<gam_terms::construction::CanonicalPenalty> = penalties.to_vec();
550    ReparamResult {
551        penalty_shrinkage_ridge: base.penalty_shrinkage_ridge,
552        s_transformed: s_orig,
553        log_det: base.log_det,
554        det1: base.det1,
555        qs: Array2::<f64>::eye(p),
556        canonical_transformed,
557        e_transformed: e_orig,
558        u_truncated: u_original,
559    }
560}
561
562pub(super) fn build_diagonal_penalty_from_kronecker(
563    kron_result: &KroneckerReparamResult,
564    lambdas: &[f64],
565) -> PirlsPenalty {
566    let d = kron_result.marginal_dims.len();
567    let p: usize = kron_result.marginal_dims.iter().copied().product();
568    let mut diag = Array1::<f64>::zeros(p);
569    let mut positive_indices = Vec::new();
570
571    const KRONECKER_STRUCTURAL_ZERO_TOL: f64 = 1e-12;
572    let mut multi_idx = vec![0usize; d];
573    let mut flat = 0usize;
574    loop {
575        let mut sigma = 0.0;
576        let mut structural_sigma = 0.0;
577        for k in 0..d {
578            let marginal_eigenvalue = kron_result.marginal_eigenvalues[k][multi_idx[k]];
579            structural_sigma += marginal_eigenvalue;
580            sigma += lambdas[k] * marginal_eigenvalue;
581        }
582        let joint_null = structural_sigma <= KRONECKER_STRUCTURAL_ZERO_TOL;
583        if kron_result.has_double_penalty && lambdas.len() > d && joint_null {
584            sigma += lambdas[d];
585        }
586        if structural_sigma > KRONECKER_STRUCTURAL_ZERO_TOL {
587            sigma += kron_result.penalty_shrinkage_ridge;
588        }
589        diag[flat] = sigma;
590        if sigma > 0.0 {
591            positive_indices.push(flat);
592        }
593        flat += 1;
594
595        let mut carry = true;
596        for dim in (0..d).rev() {
597            if carry {
598                multi_idx[dim] += 1;
599                if multi_idx[dim] < kron_result.marginal_dims[dim] {
600                    carry = false;
601                } else {
602                    multi_idx[dim] = 0;
603                }
604            }
605        }
606        if carry {
607            break;
608        }
609    }
610
611    PirlsPenalty::Diagonal {
612        diag,
613        positive_indices,
614        linear_shift: Array1::zeros(p),
615        constant_shift: 0.0,
616        prior_mean_target: Array1::zeros(p),
617    }
618}
619
620pub(super) fn canonical_prior_shift(
621    penalties: &[gam_terms::construction::CanonicalPenalty],
622    lambdas: &[f64],
623    p: usize,
624) -> (Array1<f64>, f64) {
625    let mut linear = Array1::<f64>::zeros(p);
626    let mut constant = 0.0;
627    for (idx, cp) in penalties.iter().enumerate() {
628        let Some(&lambda) = lambdas.get(idx) else {
629            continue;
630        };
631        if lambda == 0.0 {
632            continue;
633        }
634        linear += &cp.prior_linear_shift(lambda);
635        constant += cp.prior_constant_shift(lambda);
636    }
637    (linear, constant)
638}
639
640/// Aggregate prior-mean target across canonical penalty blocks: the sum of
641/// each block's `full_width_prior_mean()`. Used by the PIRLS solve sites
642/// that add a fixed stabilization ridge `δI` to the penalized Hessian — they
643/// must also add `δ · prior_mean_target` to the RHS to keep `β = μ` recovery
644/// exact when the data carries no information (X'WX = 0). Equivalent to
645/// `canonical_prior_shift` with all λ = 1 and dropping `S_k` from the linear
646/// piece (i.e., raw μ rather than `S_k μ`). Returned in the *original*
647/// coordinates; callers transform if needed.
648pub(super) fn canonical_prior_mean_aggregate(
649    penalties: &[gam_terms::construction::CanonicalPenalty],
650    p: usize,
651) -> Array1<f64> {
652    let mut mean = Array1::<f64>::zeros(p);
653    for cp in penalties {
654        mean += &cp.full_width_prior_mean();
655    }
656    mean
657}
658
/// Observation-level inputs for one fixed-ρ PIRLS fit (the `problem` argument
/// of `fit_model_for_fixed_rho`): design, response, weights, and optional
/// precomputed caches.
pub struct PirlsProblem<'a, X> {
    /// Design matrix. `fit_model_for_fixed_rho` requires
    /// `X: Into<DesignMatrix> + Clone`.
    pub x: X,
    /// Offset on the linear predictor (the Gaussian cache below is built from
    /// `XᵀW(y − offset)`).
    pub offset: ArrayView1<'a, f64>,
    /// Response values.
    pub y: ArrayView1<'a, f64>,
    /// Prior observation weights.
    pub priorweights: ArrayView1<'a, f64>,
    /// Optional covariate standard errors. NOTE(review): not consumed in the
    /// code visible here (presumably used for measurement-error handling) —
    /// confirm.
    pub covariate_se: Option<ArrayView1<'a, f64>>,
    /// When set, the inner PLS solver reuses the precomputed `XᵀWX` and
    /// `XᵀW(y − offset)` in *original* coordinates, instead of streaming the
    /// O(N·p²) GEMM and the O(N·p) matvec on every outer REML iteration.
    ///
    /// Valid only when the family is Gaussian with the Identity link, prior
    /// weights are constant across outer iterations (always true in the REML
    /// outer loop), Firth bias reduction is off, and there are no inequality /
    /// lower-bound constraints (matching the existing Identity short-circuit at
    /// `pirls.rs:6237`). The penalty `λ·S` is still added per-λ on top of the
    /// cached `XᵀWX`.
    pub gaussian_fixed_cache: Option<&'a GaussianFixedCache>,
    /// Frozen-weight first-Fisher-step data-fit Gram `XᵀWX` for a GLM
    /// design-moving ψ-trial (#1111 / #1033 mechanism (c)), in *original*
    /// (conditioned `x_fit`) coordinates. When set, the iterative GLM P-IRLS
    /// serves its FIRST Fisher-scoring iteration's `XᵀWX` from this matrix
    /// instead of streaming the O(N·p²) weighted cross-product. Every later
    /// iteration restreams the true moving `W`, so the converged β̂ is
    /// unchanged. Distinct from `gaussian_fixed_cache` (the Gaussian-identity
    /// converged-objective short-circuit): this is the GLM first-step lane and
    /// never short-circuits the iteration count.
    pub glm_first_step_gram: Option<&'a Array2<f64>>,
}
687
688// GaussianFixedCache is defined in pls_solver.
689pub use super::pls_solver::GaussianFixedCache;
690
/// Penalty and coefficient-constraint configuration for a fixed-ρ PIRLS fit
/// (the `penalty` argument of `fit_model_for_fixed_rho`).
pub struct PenaltyConfig<'a> {
    /// Block-local canonical penalties with precomputed roots and spectral data.
    /// This is the single canonical penalty representation; no full-width
    /// `rank × p` roots are stored. When the reparameterization engine needs
    /// full-width roots, they are derived on the fly from these block-local roots.
    pub canonical_penalties: &'a [gam_terms::construction::CanonicalPenalty],
    /// Optional precomputed balanced penalty root. NOTE(review): presumably the
    /// cached root that lets the reparameterization skip a repeated
    /// eigendecomposition (as described on `fit_model_for_fixed_rho`) — confirm.
    pub balanced_penalty_root: Option<&'a Array2<f64>>,
    /// Optional reparameterization invariant. NOTE(review): semantics are not
    /// visible in this file — confirm against `gam_terms::construction`.
    pub reparam_invariant: Option<&'a gam_terms::construction::ReparamInvariant>,
    /// Number of model coefficients. NOTE(review): presumably the design's
    /// column count — confirm.
    pub p: usize,
    /// Optional per-coefficient lower bounds. NOTE(review): presumably enforced
    /// through `project_coefficients_to_lower_bounds` — confirm.
    pub coefficient_lower_bounds: Option<&'a Array1<f64>>,
    /// Optional linear inequality constraints, in original (untransformed)
    /// coefficient coordinates per the field name — NOTE(review): confirm.
    pub linear_constraints_original: Option<&'a LinearInequalityConstraints>,
    /// Relative shrinkage floor for eigenvalues of the penalized block.
    /// If `Some(epsilon)`, a rho-independent ridge of `epsilon * max_balanced_eigenvalue`
    /// is added, so that barely-penalized directions do not cause pathological
    /// non-Gaussianity in the posterior. Typical value: `1e-6`. `None` disables it.
    pub penalty_shrinkage_floor: Option<f64>,
    /// When set, the penalties have Kronecker (tensor-product) structure.
    /// The reparameterization engine then uses a factored Qs = U_1 ⊗ ... ⊗ U_d
    /// instead of eigendecomposing the full p×p balanced penalty.
    pub kronecker_factored: Option<&'a gam_terms::basis::KroneckerFactoredBasis>,
}
712
713/// P-IRLS solver that follows mgcv's architecture exactly
714///
715/// This function implements the complete algorithm from mgcv's gam.fit3 function
716/// for fitting a GAM model with a fixed set of smoothing parameters:
717///
718/// - Perform stable reparameterization ONCE at the beginning (mgcv's gam.reparam)
719/// - Transform the design matrix into this stable basis
720/// - Extract a single penalty square root from the transformed penalty
721/// - Run the P-IRLS loop entirely in the transformed basis
722/// - Transform the coefficients back to the original basis only when returning
723/// - Reuse a cached balanced penalty root when available to avoid repeated eigendecompositions
724///
725/// This architecture ensures optimal numerical stability throughout the entire
726/// fitting process by working in a well-conditioned parameter space.
727pub fn fit_model_for_fixed_rho<'a, X: Into<DesignMatrix> + Clone>(
728    rho: LogSmoothingParamsView<'_>,
729    problem: PirlsProblem<'a, X>,
730    penalty: PenaltyConfig<'_>,
731    config: &PirlsConfig,
732    warm_start_beta: Option<&Coefficients>,
733) -> Result<(PirlsResult, WorkingModelPirlsResult), EstimationError> {
734    fit_model_for_fixed_rho_with_adaptive_kkt(
735        rho,
736        problem,
737        penalty,
738        config,
739        warm_start_beta,
740        None,
741        false,
742    )
743}
744
745/// `refine_dispersion_at_converged_eta`: when `true`, after the inner P-IRLS
746/// solve converges, re-estimate the family's estimated dispersion nuisance — the
747/// Gamma shape ν = 1/φ or the Beta precision φ — at the *converged* linear
748/// predictor and iterate the (β, dispersion) pair to its joint fixed point at the
749/// current λ (see the in-body comments at each refresh loop). This is ON only for
750/// the single final, reported fit at the REML-selected λ (#678 for Gamma, #769
751/// for Beta). It is deliberately OFF for every REML cost / sigma-point evaluation:
752/// re-profiling the dispersion against each trial λ's converged residuals would
753/// couple the scale to the smoothing parameter (a flat over-smoothed μ inflates
754/// the deviance ⇒ a smaller effective precision ⇒ a smaller `deviance/(2φ)` REML
755/// term), perversely rewarding over-smoothing and biasing λ selection. mgcv
756/// likewise estimates the scale at the converged fit, not inside the λ search.
757///
758/// The Gamma and Beta cases differ in what the re-solve buys. For Gamma the shape
759/// is a pure nuisance — β̂ is essentially scale-free — so the re-solve only keeps
760/// the reported dispersion and SEs self-consistent. For Beta the precision φ
761/// enters the *mean* score through the digamma terms
762/// `μ*ᵢ = ψ(μᵢφ) − ψ((1−μᵢ)φ)`, so a φ measured at the cold null predictor
763/// (μ ≈ 0.5) attenuates every slope toward zero; here the fixed point is
764/// load-bearing — it is what recovers the correct mean coefficients (the betareg
765/// alternating mean-fit ↔ φ-estimate scheme).
766pub(crate) fn fit_model_for_fixed_rho_with_adaptive_kkt<'a, X: Into<DesignMatrix> + Clone>(
767    rho: LogSmoothingParamsView<'_>,
768    problem: PirlsProblem<'a, X>,
769    penalty: PenaltyConfig<'_>,
770    config: &PirlsConfig,
771    warm_start_beta: Option<&Coefficients>,
772    adaptive_kkt_tolerance: Option<AdaptiveKktTolerance>,
773    refine_dispersion_at_converged_eta: bool,
774) -> Result<(PirlsResult, WorkingModelPirlsResult), EstimationError> {
775    let PirlsProblem {
776        x,
777        offset,
778        y,
779        priorweights,
780        covariate_se,
781        gaussian_fixed_cache,
782        glm_first_step_gram,
783    } = problem;
784    let quadctx = crate::quadrature::QuadratureContext::new();
785    // gam#1379 — finite-ceiling λ = exp(ρ). When the outer REML / spatial-κ
786    // optimizer drives a redundant penalty direction's log-λ past ~709 (it does
787    // so deterministically on 1-D `matern(x)` / `bs="gp"` data whose kernel
788    // already controls the smoothness an operator block also penalizes, so REML
789    // wants λ → ∞), `exp(ρ)` overflows to `+∞`. A literal `+∞` λ then poisons
790    // every downstream consumer that forms `λ · S`: the range-penalty block
791    // assembled as `Σ λ_k S_k` hits `∞ · 0 = NaN` and the eigensolve aborts, and
792    // the final fit-result validation rejects the non-finite stored λ outright.
793    // `exp(709.78) ≈ 1.8e308` is already the largest finite f64; capping log-λ at
794    // a value whose `exp` stays finite pins the over-penalized direction exactly
795    // as hard as `+∞` would for every finite-arithmetic consumer (the penalized
796    // block is numerically a hard constraint at λ this large) while keeping
797    // `λ · 0 = 0`. Ordinary finite λ are untouched, so non-degenerate fits and
798    // their recorded λ̂ are bit-identical. `ln(1e300) ≈ 690.78` keeps this in lock
799    // step with the post-exp λ ceiling (`1e300`) used by the reparam range-block
800    // assembly and the stored fit result, so a fully-smoothed direction carries
801    // the SAME finite λ everywhere it is consumed.
802    const LOG_LAMBDA_CEILING: f64 = 690.0;
803    let lambdas = rho.mapv(|r| {
804        if r.is_nan() {
805            r
806        } else {
807            r.min(LOG_LAMBDA_CEILING).exp()
808        }
809    });
810    let lambdas_slice = lambdas.as_slice_memory_order().ok_or_else(|| {
811        EstimationError::InvalidInput("non-contiguous lambda storage".to_string())
812    })?;
813
814    let likelihood = &config.likelihood;
815    let link_function = config.link_function();
816
817    use gam_terms::construction::{
818        EngineDims, create_balanced_penalty_root_from_canonical,
819        stable_reparameterization_engine_canonical,
820    };
821
822    let eb_cow: Cow<'_, Array2<f64>> = if let Some(precomputed) = penalty.balanced_penalty_root {
823        Cow::Borrowed(precomputed)
824    } else {
825        Cow::Owned(create_balanced_penalty_root_from_canonical(
826            penalty.canonical_penalties,
827            penalty.p,
828        )?)
829    };
830    let eb: &Array2<f64> = eb_cow.as_ref();
831
832    // Build a cheap weighted penalty sum for the sparse-native decision
833    // WITHOUT running the expensive eigendecomposition engine.
834    // The full reparameterization is deferred until we know which path we need.
835    let cheap_s_lambda: Option<Array2<f64>> = if penalty.kronecker_factored.is_none() {
836        let mut s = Array2::<f64>::zeros((penalty.p, penalty.p));
837        for (k, cp) in penalty.canonical_penalties.iter().enumerate() {
838            let lam = lambdas_slice.get(k).copied().unwrap_or(0.0);
839            if lam != 0.0 {
840                cp.accumulate_weighted(&mut s, lam);
841            }
842        }
843        Some(s)
844    } else {
845        None
846    };
847    let kronecker_runtime = if let Some(kron) = penalty.kronecker_factored {
848        // The marginal eigensystems and reparameterized marginals depend only on
849        // the fixed marginal designs/penalties, not on λ = exp(ρ). Memoize them
850        // once per fit so each outer REML iterate reuses the eigendecomposition
851        // instead of recomputing `eigh()` + `B_k·U_k` every call; only the cheap
852        // λ-grid logdet/derivative sweep is redone here. Bit-identical to the
853        // unmemoized engine.
854        let invariant = kron.invariant_structure()?;
855        let kron_result = gam_terms::construction::kronecker_reparameterization_engine_with_invariant(
856            invariant.as_ref(),
857            &kron.marginal_dims,
858            lambdas_slice,
859            kron.has_double_penalty,
860            penalty.penalty_shrinkage_floor,
861        )?;
862        let transform = Arc::new(KroneckerQsTransform::new(&kron_result));
863        let penalty_diag = build_diagonal_penalty_from_kronecker(&kron_result, lambdas_slice);
864        Some((kron_result, transform, penalty_diag))
865    } else {
866        None
867    };
868    // Constraint transformation is deferred until after the sparse-native
869    // decision, because the dense reparameterization engine (which provides Qs)
870    // is now run lazily.  Kronecker constraints can be built eagerly since
871    // the Kronecker transform is already available.
872    let kronecker_constraints = if let Some((_, transform, _)) = kronecker_runtime.as_ref() {
873        let tb = build_transformed_lower_bound_constraints_with_transform(
874            &WorkingReparamTransform::Kronecker(Arc::clone(transform)),
875            penalty.coefficient_lower_bounds,
876        );
877        let tl = build_transformed_linear_constraints_with_transform(
878            &WorkingReparamTransform::Kronecker(Arc::clone(transform)),
879            penalty.linear_constraints_original,
880        );
881        Some(merge_linear_constraints(tb, tl))
882    } else {
883        None
884    };
885
886    let x_original: DesignMatrix = x.into();
887    // Auto-detect sparse structure in dense designs so the sparse-native path
888    // can engage for structurally sparse models that happen to be stored dense.
889    let x_original = {
890        let auto_sparse = x_original
891            .as_dense()
892            .and_then(|dense| sparse_from_denseview(dense.view()));
893        auto_sparse.unwrap_or(x_original)
894    };
895    let ebrows = eb.nrows();
896    let erows = if let Some((_, _, penalty_diag)) = kronecker_runtime.as_ref() {
897        penalty_diag.rank()
898    } else {
899        // Compute penalty root rank cheaply from canonical penalties.
900        penalty
901            .canonical_penalties
902            .iter()
903            .map(|cp| cp.rank())
904            .sum::<usize>()
905    };
906    let mut workspace = PirlsWorkspace::new(x_original.nrows(), x_original.ncols(), ebrows, erows);
907    let solver_decision = if let Some((_, _, _)) = kronecker_runtime.as_ref() {
908        SparsePirlsDecision {
909            path: PirlsLinearSolvePath::DenseTransformed,
910            reason: "kronecker_runtime",
911            p: x_original.ncols(),
912            nnz_x: 0,
913            nnz_xtwx_symbolic: None,
914            nnz_s_lambda: 0,
915            nnz_h_est: None,
916            density_h_est: None,
917        }
918    } else {
919        should_use_sparse_native_pirls(
920            &mut workspace,
921            &x_original,
922            cheap_s_lambda
923                .as_ref()
924                .expect("cheap_s_lambda should be present outside Kronecker path"),
925            penalty.coefficient_lower_bounds,
926            penalty.linear_constraints_original,
927        )
928    };
929    solver_decision.log_once();
930
931    let use_sparse_native = matches!(solver_decision.path, PirlsLinearSolvePath::SparseNative);
932
933    // Run the eigendecomposition engine for the dense-transformed path. The
934    // sparse-native path also needs it, but only to obtain a penalty that is
935    // *consistent with the REML penalty log-determinant it reports* — see the
936    // sparse-native `reparam` below. The dense path keeps `qs ≠ I`; the
937    // sparse-native path discards `qs` (identity coords) and reuses only the
938    // shrinkage-folded `s_transformed`/`e_transformed`.
939    let dense_reparam_result = if !use_sparse_native && penalty.kronecker_factored.is_none() {
940        Some(stable_reparameterization_engine_canonical(
941            penalty.canonical_penalties,
942            lambdas_slice,
943            EngineDims::new(penalty.p, penalty.canonical_penalties.len()),
944            penalty.reparam_invariant,
945            penalty.penalty_shrinkage_floor,
946        )?)
947    } else {
948        None
949    };
950    // Sparse-native reparam result, in identity (original) coordinates with the
951    // penalty shrinkage floor folded in. This MUST drive the inner penalized
952    // solve too: when `penalty_shrinkage_floor` is active (default `Some(1e-6)`)
953    // the dense engine adds `shrinkage·P_range` to every penalized range
954    // direction of `S_λ` and rebuilds `s_transformed = EᵀE` from the floored
955    // roots, so `base.log_det` (the REML penalty pseudo-logdet) is the
956    // determinant of `S_λ + shrinkage·P_range`, NOT of the bare `S_λ`. Building
957    // the inner Hessian from an UN-shrunk `S_λ` (the previous behaviour, via the
958    // `cheap_s_lambda` row-sum) while reporting the shrunk `log_det` made the
959    // sparse-native REML surface internally inconsistent — the penalty-logdet
960    // term and the inner H / EDF / β̂ lived on different penalties — which biased
961    // λ-selection relative to the dense and Kronecker backends for the SAME
962    // model (the #1266 cross-backend divergence class). Reusing the engine's
963    // shrinkage-folded penalty here makes all three backends solve the same
964    // penalized objective.
965    let sparse_native_reparam = if use_sparse_native && penalty.kronecker_factored.is_none() {
966        let base = stable_reparameterization_engine_canonical(
967            penalty.canonical_penalties,
968            lambdas_slice,
969            EngineDims::new(penalty.p, penalty.canonical_penalties.len()),
970            penalty.reparam_invariant,
971            penalty.penalty_shrinkage_floor,
972        )?;
973        Some(build_sparse_native_reparam_result(
974            base,
975            penalty.canonical_penalties,
976            lambdas_slice,
977            penalty.p,
978        ))
979    } else {
980        None
981    };
982    let qs_arc = dense_reparam_result
983        .as_ref()
984        .map(|reparam_result| Arc::new(reparam_result.qs.clone()));
985    let transform_active = if let Some((_, transform, _)) = kronecker_runtime.as_ref() {
986        Some(WorkingReparamTransform::Kronecker(Arc::clone(transform)))
987    } else if use_sparse_native {
988        None
989    } else {
990        Some(WorkingReparamTransform::Dense(Arc::clone(
991            qs_arc
992                .as_ref()
993                .expect("dense Qs should exist for non-Kronecker transformed path"),
994        )))
995    };
996    let mut penalty_active = if let Some((_, _, penalty_diag)) = kronecker_runtime.as_ref() {
997        penalty_diag.clone()
998    } else if use_sparse_native {
999        // Sparse-native inner penalty in original (identity) coordinates. Use
1000        // the shrinkage-folded `s_transformed`/`e_transformed` from
1001        // `sparse_native_reparam` so the inner penalized Hessian
1002        // `H = XᵀWX + S` matches the penalty whose log-determinant the REML
1003        // criterion reports for this fit (`base.log_det`). Falling back to the
1004        // bare lambda-weighted sum here (the prior behaviour) omitted the
1005        // `penalty_shrinkage_floor` ridge and desynced the inner solve from the
1006        // REML logdet, biasing λ-selection vs the dense/Kronecker backends.
1007        let sparse_reparam = sparse_native_reparam
1008            .as_ref()
1009            .expect("sparse_native_reparam should be present for sparse-native path");
1010        PirlsPenalty::Dense {
1011            s_transformed: sparse_reparam.s_transformed.clone(),
1012            e_transformed: sparse_reparam.e_transformed.clone(),
1013            linear_shift: Array1::zeros(penalty.p),
1014            constant_shift: 0.0,
1015            prior_mean_target: Array1::zeros(penalty.p),
1016        }
1017    } else {
1018        let dense = dense_reparam_result
1019            .as_ref()
1020            .expect("dense reparam result should be present outside Kronecker path");
1021        PirlsPenalty::Dense {
1022            s_transformed: dense.s_transformed.clone(),
1023            e_transformed: dense.e_transformed.clone(),
1024            linear_shift: Array1::zeros(penalty.p),
1025            constant_shift: 0.0,
1026            prior_mean_target: Array1::zeros(penalty.p),
1027        }
1028    };
1029    let (shift_original, shift_constant) =
1030        canonical_prior_shift(penalty.canonical_penalties, lambdas_slice, penalty.p);
1031    let shift_active = transform_active
1032        .as_ref()
1033        .map(|transform| transform.apply_transpose(&shift_original))
1034        .unwrap_or(shift_original);
1035    let prior_mean_original =
1036        canonical_prior_mean_aggregate(penalty.canonical_penalties, penalty.p);
1037    let prior_mean_active = transform_active
1038        .as_ref()
1039        .map(|transform| transform.apply_transpose(&prior_mean_original))
1040        .unwrap_or(prior_mean_original);
1041    attach_penalty_shift(
1042        &mut penalty_active,
1043        shift_active,
1044        shift_constant,
1045        prior_mean_active,
1046    );
1047    // Build transformed constraints now that dense_reparam_result is available.
1048    let linear_constraints = if let Some(kc) = kronecker_constraints {
1049        kc
1050    } else if let Some(reparam) = dense_reparam_result.as_ref() {
1051        let tb = build_transformed_lower_bound_constraints(
1052            &reparam.qs,
1053            penalty.coefficient_lower_bounds,
1054        );
1055        let tl =
1056            build_transformed_linear_constraints(&reparam.qs, penalty.linear_constraints_original);
1057        merge_linear_constraints(tb, tl)
1058    } else {
1059        // Sparse-native without dense reparam: constraints stay in original
1060        // coordinates (identity Qs).  Use an identity matrix of appropriate size.
1061        let p = penalty.p;
1062        let qs_identity = Array2::<f64>::eye(p);
1063        let tb = build_transformed_lower_bound_constraints(
1064            &qs_identity,
1065            penalty.coefficient_lower_bounds,
1066        );
1067        let tl =
1068            build_transformed_linear_constraints(&qs_identity, penalty.linear_constraints_original);
1069        merge_linear_constraints(tb, tl)
1070    };
1071
1072    let coordinate_frame = if use_sparse_native {
1073        PirlsCoordinateFrame::OriginalSparseNative
1074    } else {
1075        PirlsCoordinateFrame::TransformedQs
1076    };
1077    let materialize_final_reparam_result = || -> Result<ReparamResult, EstimationError> {
1078        if let Some((kron_result, _, _)) = kronecker_runtime.as_ref() {
1079            let rs_list: Vec<Array2<f64>> = penalty
1080                .canonical_penalties
1081                .iter()
1082                .map(|cp| cp.full_width_root())
1083                .collect();
1084            kron_result.materialize_dense_artifact_result(&rs_list, lambdas_slice, penalty.p)
1085        } else if use_sparse_native {
1086            // Sparse-native path: reuse the engine result already computed for
1087            // `penalty_active` (with the shrinkage floor folded in and mapped to
1088            // identity coordinates). This is both correct — the REML
1089            // log-determinant now matches the penalty the inner solve used — and
1090            // cheaper, since the eigendecomposition is no longer run twice.
1091            Ok(sparse_native_reparam
1092                .as_ref()
1093                .expect("sparse_native_reparam should be present for sparse-native path")
1094                .clone())
1095        } else {
1096            Ok(dense_reparam_result
1097                .as_ref()
1098                .expect("dense reparam result should be present outside Kronecker path")
1099                .clone())
1100        }
1101    };
1102
1103    // Stage 3.3-GI: GPU exact PLS dispatch — see pirls_host_dispatch::try_gaussian_pls_gpu.
1104    if let Some(result) = try_gaussian_pls_gpu(
1105        link_function,
1106        config,
1107        penalty.coefficient_lower_bounds,
1108        penalty.linear_constraints_original,
1109        gaussian_fixed_cache,
1110        &penalty_active,
1111        &qs_arc,
1112        &x_original,
1113        use_sparse_native,
1114        penalty.p,
1115        || materialize_final_reparam_result(),
1116        y,
1117        priorweights,
1118        offset,
1119        coordinate_frame,
1120        &linear_constraints,
1121    ) {
1122        return result;
1123    }
1124
1125    if matches!(link_function, LinkFunction::Identity) && linear_constraints.is_none() {
1126        // Gaussian-Identity zero-iteration exact solve. The unconstrained
1127        // penalized least-squares system is linear, so for an identity link a
1128        // single solve is the exact minimizer and no PIRLS iteration is needed.
1129        //
1130        // This shortcut is only valid in the *unconstrained* convex program.
1131        // When shape/box/linear inequality constraints are present (e.g. a
1132        // `shape=monotone_increasing` smooth, whose cumulative-sum box-reparam
1133        // bounds `γ_j ≥ 0` are folded into `linear_constraints` above), the
1134        // minimizer is the solution of an inequality-constrained QP, not the
1135        // plain normal-equations solve. Taking this branch then returns the
1136        // unconstrained β, which generically violates the constraints and is
1137        // rejected by the REML startup KKT gate (`enforce_constraint_kkt`),
1138        // aborting the whole fit. Gating on `linear_constraints.is_none()`
1139        // routes every constrained Identity fit to the iterative loop below,
1140        // which builds a feasible initial point and solves the exact QP via
1141        // the active-set solver — mirroring the gate already enforced on the
1142        // GPU Gaussian-PLS path in `try_gaussian_pls_gpu`.
1143        //
1144        // Apply the Gaussian-Identity fixed-data cache only when every
1145        // precondition for the short-circuit's exact reuse holds: the family
1146        // really is Gaussian (z = y), there is no Firth bias-reduction term,
1147        // no coefficient lower bounds, and no linear inequality constraints
1148        // — anything that would change the right-hand side or the system
1149        // beyond the additive penalty would invalidate the cache.
1150        let cache_eligible = gaussian_fixed_cache.is_some()
1151            && likelihood.spec.is_gaussian_identity()
1152            && !config.firth_bias_reduction
1153            && penalty.coefficient_lower_bounds.is_none()
1154            && penalty.linear_constraints_original.is_none();
1155        let cache_for_solve = if cache_eligible {
1156            gaussian_fixed_cache
1157        } else {
1158            None
1159        };
1160        let (pls_result, _) = solve_penalized_least_squares_implicit(
1161            &x_original,
1162            transform_active.as_ref(),
1163            y,
1164            priorweights,
1165            offset,
1166            &penalty_active,
1167            &mut workspace,
1168            y,
1169            link_function,
1170            cache_for_solve,
1171        )?;
1172
1173        let beta_transformed = pls_result.beta;
1174        let penalized_hessian = pls_result.penalized_hessian;
1175        let edf = pls_result.edf;
1176        let baseridge = pls_result.ridge_used;
1177
1178        let priorweights_owned = priorweights.to_owned();
1179        // eta = offset + X Qs beta (composed, no materialization) unless a
1180        // design-moving ψ tensor cache explicitly says the surface rows are a
1181        // stale reference. In that lane the Gaussian objective and gradient are
1182        // fully determined by (G, r, y'Wy), so applying `x_original` would both
1183        // reintroduce per-trial row work and evaluate the wrong ψ.
1184        let qbeta = transform_active
1185            .as_ref()
1186            .map(|transform| transform.apply(beta_transformed.as_ref()))
1187            .unwrap_or_else(|| beta_transformed.as_ref().clone());
1188        let stale_row_cache = cache_for_solve.filter(|cache| cache.row_prediction_is_stale);
1189        let (final_eta, finalmu, finalz, gradient_data, deviance, log_likelihood, max_abs_eta) =
1190            if let Some(cache) = stale_row_cache {
1191                let final_eta = offset.to_owned();
1192                let finalmu = final_eta.clone();
1193                let finalz = y.to_owned();
1194                let mut grad_orig = cache.xtwx_orig.dot(&qbeta);
1195                grad_orig -= &cache.xtwy_orig;
1196                let gradient_data = transform_active
1197                    .as_ref()
1198                    .map(|transform| transform.apply_transpose(&grad_orig))
1199                    .unwrap_or(grad_orig);
1200                let weighted_rss = (cache.centered_weighted_y_sq
1201                    - 2.0 * qbeta.dot(&cache.xtwy_orig)
1202                    + qbeta.dot(&cache.xtwx_orig.dot(&qbeta)))
1203                .max(0.0);
1204                let phi = likelihood.scale.fixed_phi().unwrap_or(1.0);
1205                let deviance = if phi.is_finite() && phi > 0.0 {
1206                    weighted_rss / phi
1207                } else {
1208                    f64::NAN
1209                };
1210                let log_likelihood = -0.5 * deviance;
1211                let max_abs_eta = inf_norm(finalmu.iter().copied());
1212                (
1213                    final_eta,
1214                    finalmu,
1215                    finalz,
1216                    gradient_data,
1217                    deviance,
1218                    log_likelihood,
1219                    max_abs_eta,
1220                )
1221            } else {
1222                let mut eta = offset.to_owned();
1223                eta += &x_original.apply(&qbeta);
1224                let final_eta = eta.clone();
1225                let finalmu = eta.clone();
1226                let finalz = y.to_owned();
1227
1228                let mut weighted_residual = finalmu.clone();
1229                weighted_residual -= &finalz;
1230                weighted_residual *= &priorweights_owned;
1231                // gradient = Qs^T X^T (w * residual) (composed)
1232                let xt_wr = x_original.apply_transpose(&weighted_residual);
1233                let gradient_data = transform_active
1234                    .as_ref()
1235                    .map(|transform| transform.apply_transpose(&xt_wr))
1236                    .unwrap_or(xt_wr);
1237                let deviance = calculate_deviance(y, &finalmu, likelihood, priorweights);
1238                let log_likelihood = calculate_loglikelihood_omitting_constants(
1239                    y,
1240                    &finalmu,
1241                    likelihood,
1242                    priorweights,
1243                );
1244                let max_abs_eta = inf_norm(finalmu.iter().copied());
1245                (
1246                    final_eta,
1247                    finalmu,
1248                    finalz,
1249                    gradient_data,
1250                    deviance,
1251                    log_likelihood,
1252                    max_abs_eta,
1253                )
1254            };
1255        let score_norm = array1_l2_norm(&gradient_data);
1256        let s_beta = penalty_active.shifted_gradient(beta_transformed.as_ref());
1257        let s_beta_norm = array1_l2_norm(&s_beta);
1258        let mut gradient = gradient_data;
1259        gradient += &s_beta;
1260        let mut penalty_term = penalty_active.shifted_quadratic(beta_transformed.as_ref());
1261        let ridge_used = baseridge;
1262        let stabilizedhessian = if ridge_used > 0.0 {
1263            penalized_hessian
1264                .addridge(ridge_used)
1265                .map_err(|e| EstimationError::InvalidInput(format!("ridge addition failed: {e}")))?
1266        } else {
1267            penalized_hessian.clone()
1268        };
1269        let mut ridge_grad_norm = 0.0;
1270        if ridge_used > 0.0 {
1271            let ridge_penalty =
1272                ridge_used * beta_transformed.as_ref().dot(beta_transformed.as_ref());
1273            penalty_term += ridge_penalty;
1274            gradient += &beta_transformed.as_ref().mapv(|v| ridge_used * v);
1275            ridge_grad_norm = ridge_used * array1_l2_norm(beta_transformed.as_ref());
1276        }
1277
1278        let gradient_norm = array1_l2_norm(&gradient);
1279        let working_state = WorkingState {
1280            eta: LinearPredictor::new(finalmu.clone()),
1281            gradient: gradient.clone(),
1282            hessian: penalized_hessian.clone(),
1283
1284            log_likelihood,
1285            deviance,
1286            penalty_term,
1287            firth: FirthDiagnostics::Inactive,
1288            ridge_used,
1289            hessian_curvature: HessianCurvatureKind::Fisher,
1290            gradient_natural_scale: score_norm + s_beta_norm + ridge_grad_norm,
1291        };
1292
1293        let zero_iter_penalized = deviance + penalty_term;
1294        let working_summary = WorkingModelPirlsResult {
1295            beta: beta_transformed.clone(),
1296            state: working_state,
1297            status: PirlsStatus::Converged,
1298            iterations: 1,
1299            lastgradient_norm: gradient_norm,
1300            last_deviance_change: 0.0,
1301            last_step_size: 1.0,
1302            last_step_halving: 0,
1303            max_abs_eta,
1304            constraint_kkt: linear_constraints.as_ref().map(|lin| {
1305                compute_constraint_kkt_diagnostics(beta_transformed.as_ref(), &gradient, lin)
1306            }),
1307            min_penalized_deviance: if zero_iter_penalized.is_finite() {
1308                zero_iter_penalized
1309            } else {
1310                f64::INFINITY
1311            },
1312            // Zero-iteration synthesis: no LM damping was exercised, so
1313            // hand the next solve the cold default.
1314            final_lm_lambda: 1e-6,
1315            // Zero-iteration synthesis: no LM gain ratio was measured.
1316            final_accept_rho: None,
1317            // Zero-iteration synthesis assembles the Hessian with prior
1318            // weights only; no observed-information re-evaluation has
1319            // happened. Label honestly as a Fisher-type surrogate so
1320            // outer Laplace consumers see the truth.
1321            exported_laplace_curvature: ExportedLaplaceCurvature::ExpectedInformationSurrogate,
1322        };
1323
1324        let (solve_c_array, solve_d_array, solve_dmu_deta, solve_d2mu_deta2, solve_d3mu_deta3) =
1325            computeworkingweight_derivatives_from_eta(
1326                &config.likelihood,
1327                &config.link_kind,
1328                &final_eta,
1329                priorweights_owned.view(),
1330            )?;
1331        let reparam_result = materialize_final_reparam_result()?;
1332        let qs_arc_final = Arc::new(reparam_result.qs.clone());
1333        let pirls_result = PirlsResult {
1334            likelihood: config.likelihood.clone(),
1335            beta_transformed,
1336            penalized_hessian_transformed: penalized_hessian,
1337            stabilizedhessian_transformed: stabilizedhessian,
1338            ridge_passport: RidgePassport::scaled_identity(
1339                ridge_used,
1340                RidgePolicy::explicit_stabilization_full(),
1341            ),
1342            ridge_used,
1343            deviance,
1344            edf,
1345            stable_penalty_term: penalty_term,
1346            firth: FirthDiagnostics::Inactive,
1347            finalweights: priorweights_owned.clone(),
1348            final_offset: offset.to_owned(),
1349            final_eta: final_eta.clone(),
1350            finalmu: finalmu.clone(),
1351            solveweights: priorweights_owned,
1352            solveworking_response: finalz.clone(),
1353            solvemu: finalmu.clone(),
1354            solve_dmu_deta,
1355            solve_d2mu_deta2,
1356            solve_d3mu_deta3,
1357            solve_c_array,
1358            solve_d_array,
1359            derivatives_unsupported: false,
1360            status: PirlsStatus::Converged,
1361            iteration: 1,
1362            max_abs_eta,
1363            lastgradient_norm: gradient_norm,
1364            gradient_natural_scale: score_norm + s_beta_norm + ridge_grad_norm,
1365            last_deviance_change: 0.0,
1366            last_step_halving: 0,
1367            hessian_curvature: HessianCurvatureKind::Fisher,
1368            exported_laplace_curvature: working_summary.exported_laplace_curvature.clone(),
1369            final_lm_lambda: working_summary.final_lm_lambda,
1370            final_accept_rho: working_summary.final_accept_rho,
1371            constraint_kkt: working_summary.constraint_kkt.clone(),
1372            linear_constraints_transformed: linear_constraints.clone(),
1373            reparam_result,
1374            x_transformed: make_reparam_operator(&x_original, &qs_arc_final, use_sparse_native),
1375            coordinate_frame,
1376            used_device: false,
1377            cache_compacted: false,
1378            min_penalized_deviance: working_summary.min_penalized_deviance,
1379        };
1380
1381        return Ok((pirls_result, working_summary));
1382    }
1383
1384    let x_original_for_result = x_original.clone();
1385    let mut working_model = GamWorkingModel::new(
1386        None, // No pre-materialized x_transformed: use implicit Qs composition
1387        x_original.clone(),
1388        coordinate_frame,
1389        offset,
1390        y,
1391        priorweights,
1392        penalty_active.clone(),
1393        workspace,
1394        config.likelihood.clone(),
1395        config.link_kind.clone(),
1396        // Inner Firth/Jeffreys activation must agree with the caller-requested
1397        // mode. The REML *outer* analytic derivative assembly only carries the
1398        // Jeffreys score/curvature term when `firth_bias_reduction` is set
1399        // (`reml_robust_jeffreys_link` returns `None` otherwise), so arming the
1400        // inner penalty unconditionally would converge the inner mode to the
1401        // Firth-penalized stationary point while the outer H/u/IFT stayed
1402        // non-Firth — the two would then disagree by exactly the Jeffreys
1403        // contribution (broken τ-τ Hessian-vs-FD and stationarity-cancellation
1404        // identities, #825). Gate on `firth_bias_reduction` so inner and outer
1405        // are the same objective.
1406        config.firth_bias_reduction
1407            && matches!(config.likelihood.spec.response, ResponseFamily::Binomial)
1408            && inverse_link_has_fisher_weight_jet(&config.link_kind),
1409        transform_active.clone(),
1410        quadctx,
1411        // #1111 / #1033 mechanism (c): frozen-W first-Fisher-step XᵀWX in the
1412        // original (conditioned x_fit) frame, served n-free on the first inner
1413        // iteration. Suppressed under Firth bias reduction, which shifts the
1414        // working response per iteration (the installer also gates Firth off).
1415        if config.firth_bias_reduction {
1416            None
1417        } else {
1418            glm_first_step_gram.cloned()
1419        },
1420    );
1421
1422    // Apply integrated (GHQ) likelihood if per-observation SE is provided.
1423    // This is used by the calibrator to coherently account for base prediction uncertainty.
1424    if let Some(se) = covariate_se {
1425        working_model = working_model.with_covariate_se(se.to_owned());
1426    }
1427
1428    let mut beta_guess_original = warm_start_beta
1429        .filter(|beta| beta.len() == penalty.p)
1430        .map(|beta| beta.to_owned())
1431        .unwrap_or_else(|| {
1432            Coefficients::new(default_beta_guess_external(
1433                penalty.p,
1434                link_function,
1435                y,
1436                priorweights,
1437                config.link_kind.mixture_state(),
1438                config.link_kind.sas_state(),
1439            ))
1440        });
1441    if let Some(lb) = penalty.coefficient_lower_bounds {
1442        project_coefficients_to_lower_bounds(&mut beta_guess_original.0, lb);
1443    }
1444    let initial_beta = transform_active
1445        .as_ref()
1446        .map(|transform| transform.apply_transpose(beta_guess_original.as_ref()))
1447        .unwrap_or_else(|| beta_guess_original.as_ref().clone());
1448    let initial_beta = if let Some(constraints) = linear_constraints.as_ref() {
1449        // Worst per-row *scaled* (geometric) slack of the current seed against the
1450        // constraint cone. Negative ⇒ the seed violates a row; ~0 ⇒ the seed sits
1451        // ON the boundary (for a homogeneous convex/concave second-difference
1452        // cone, `β = 0` — the unconstrained Gaussian seed — sits on EVERY row's
1453        // boundary, i.e. the cone vertex). Either way the seed must be pushed
1454        // strictly into the interior before P-IRLS starts.
1455        let mut min_scaled_slack = f64::INFINITY;
1456        for i in 0..constraints.a.nrows() {
1457            let norm = constraints.a.row(i).dot(&constraints.a.row(i)).sqrt();
1458            let inv = if norm > 0.0 { 1.0 / norm } else { 0.0 };
1459            let slack = (constraints.a.row(i).dot(&initial_beta) - constraints.b[i]) * inv;
1460            min_scaled_slack = min_scaled_slack.min(slack);
1461        }
1462        // Push the seed to the nearest STRICTLY-INTERIOR feasible point whenever
1463        // any row is tight or violated. A seed on the cone boundary (most acutely
1464        // the vertex `β = 0`) hands the inner active-set QP an all-rows-active
1465        // working set, where it stalls on a degenerate, non-stationary face — so
1466        // the fit silently diverges (or aborts in release) between a cold and a
1467        // warm warm-start cache (#873). A strictly-interior seed makes the QP's
1468        // initial active set empty; it then adds only the genuinely binding rows
1469        // and converges to the certified constrained optimum regardless of cache
1470        // state. The projection keeps the data-driven curvature of `initial_beta`
1471        // and falls back to the min-norm feasible point only if it cannot certify
1472        // a strictly-interior solution.
1473        //
1474        // The min-norm fallback (`feasible_point_for_linear_constraints`) is only
1475        // used for a NON-homogeneous cone (`b ≠ 0`), where it returns a genuine
1476        // interior-of-the-offset-polyhedron point. For a HOMOGENEOUS shape cone
1477        // (`b ≈ 0` — the convex/concave second-difference rows) that function
1478        // returns the minimum-norm feasible point `β = 0`, which is the cone
1479        // *vertex*: the exact all-rows-tight degenerate seed #873 is about. Taking
1480        // it would silently reintroduce the #873 pathology whenever the strict
1481        // projection rarely fails to certify. So for a homogeneous cone we skip the
1482        // vertex fallback entirely and prefer the data-driven `initial_beta`: it
1483        // violates at most *some* rows (a lower-dimensional, non-degenerate face the
1484        // inner active-set QP can recover from), strictly better than the vertex
1485        // where *every* row is simultaneously tight.
1486        let cone_is_homogeneous = constraints.b.iter().all(|v| v.abs() <= 1e-14);
1487        if min_scaled_slack < active_set::interior_seed_margin() {
1488            let projected =
1489                active_set::project_point_strictly_into_feasible_cone(&initial_beta, constraints)
1490                    .or_else(|| {
1491                        if cone_is_homogeneous {
1492                            None
1493                        } else {
1494                            active_set::feasible_point_for_linear_constraints(
1495                                constraints,
1496                                initial_beta.len(),
1497                            )
1498                        }
1499                    });
1500            projected.unwrap_or(initial_beta)
1501        } else {
1502            initial_beta
1503        }
1504    } else {
1505        initial_beta
1506    };
1507    // Inner P-IRLS Firth activation. The inner penalized objective must match
1508    // the objective the REML outer derivatives are assembled against: the outer
1509    // path carries the Jeffreys/Firth score+curvature only when the caller set
1510    // `firth_bias_reduction` (`reml_robust_jeffreys_link` is `None` otherwise),
1511    // so the inner Firth term is armed iff the caller requested it AND the link
1512    // exposes a Fisher-weight jet (#825). Forcing it on unconditionally desynced
1513    // the Firth-penalized inner mode from the non-Firth outer assembly.
1514    let firth_active = config.firth_bias_reduction
1515        && matches!(config.likelihood.spec.response, ResponseFamily::Binomial)
1516        && inverse_link_has_fisher_weight_jet(&config.link_kind);
1517    let base_max_step_halving = if firth_active { 60 } else { 30 };
1518    let options = WorkingModelPirlsOptions {
1519        // The Firth-penalized P-IRLS converges at the same iteration count as
1520        // the unpenalized fit — the Jeffreys term is a smooth, bounded addition
1521        // to a Newton system that is already well conditioned (the additional
1522        // per-iteration LM step-halving budget above absorbs the early-iteration
1523        // curvature change). Bumping the outer-iteration cap to mask a
1524        // mis-conditioned step would only hide non-convergence, so the cap stays
1525        // the caller's `max_iterations` and trips as a hard error if exceeded.
1526        max_iterations: config.max_iterations,
1527        convergence_tolerance: config.convergence_tolerance,
1528        adaptive_kkt_tolerance,
1529        // LM step-halving is a per-iteration damping retry budget; it is
1530        // independent of the total outer-iteration cap. Tying the two
1531        // together collapsed step halving to 3 under seed screening (where
1532        // max_iterations is intentionally capped low), turning recoverable
1533        // damping into spurious failures.
1534        max_step_halving: base_max_step_halving,
1535        min_step_size: if firth_active { 1e-12 } else { 1e-10 },
1536        firth_bias_reduction: firth_active,
1537        coefficient_lower_bounds: None,
1538        linear_constraints: linear_constraints.clone(),
1539        initial_lm_lambda: config.initial_lm_lambda,
1540        geodesic_acceleration: config.geodesic_acceleration,
1541        arrow_schur: config.arrow_schur.clone(),
1542    };
1543
1544    let mut iteration_logger = |info: &WorkingModelIterationInfo| {
1545        log::debug!(
1546            "[PIRLS] iter {:>3} | deviance {:.6e} | |grad| {:.3e} | step {:.3e} (halving {})",
1547            info.iteration,
1548            info.deviance,
1549            info.gradient_norm,
1550            info.step_size,
1551            info.step_halving
1552        );
1553    };
1554
1555    // Stage 3.3 GPU PIRLS-loop dispatch — see pirls_host_dispatch::try_pirls_loop_gpu.
1556    if let Some(result) = try_pirls_loop_gpu(
1557        config,
1558        &penalty_active,
1559        kronecker_runtime.is_none(),
1560        use_sparse_native,
1561        &linear_constraints,
1562        &x_original,
1563        &qs_arc,
1564        penalty.p,
1565        &x_original_for_result,
1566        || materialize_final_reparam_result(),
1567        y,
1568        priorweights,
1569        offset,
1570        &initial_beta,
1571        link_function,
1572        coordinate_frame,
1573    ) {
1574        return result;
1575    }
1576
1577    let mut working_summary = runworking_model_pirls(
1578        &mut working_model,
1579        Coefficients::new(initial_beta),
1580        &options,
1581        &mut iteration_logger,
1582    )?;
1583
1584    // ── Gamma dispersion: re-estimate the shape at the *converged* η (#678) ──
1585    //
1586    // The inner LM solve estimates the Gamma shape ν = 1/φ **once** from the
1587    // warm-start η and freezes it for the rest of the solve (see the
1588    // `gamma_shape_locked` doc on `GamWorkingModel`): holding ν fixed keeps the
1589    // product φ·λ — and hence the penalized argmin β̂ — a stationary LM target,
1590    // so the gain ratio compares one objective. That lock is correct *within* a
1591    // solve, but it pins ν to whatever η the solve started from. When the fit
1592    // cold-starts (the final dedicated fit at the converged ρ passes
1593    // `warm_start_beta = None`, and seed screening starts from a default guess),
1594    // that warm-start η has not yet captured the mean structure; the leftover
1595    // spread of μ inflates the Gamma deviance term `mean[y/μ − ln(y/μ) − 1]` and
1596    // biases ν **down** (φ up) by >2× whenever μ varies appreciably. The mean
1597    // surface still converges (β̂ is essentially scale-free here), but the frozen
1598    // ν that survives into `UnifiedFitResult::dispersion_phi()` — and from there
1599    // into every coefficient SE `Vb = H⁻¹·φ̂`, prediction interval, and
1600    // observation-noise interval — is the early, mean-spread-contaminated value.
1601    //
1602    // Fix: after the solve converges, re-estimate ν at the converged η. If it
1603    // moved, re-solve β (warm-started, ν held fixed at the refreshed value) and
1604    // repeat, driving the pair (β, ν) to their joint fixed point at the current
1605    // λ. At convergence the reported dispersion is the Gamma ML estimate at the
1606    // converged mean (mgcv's post-hoc Pearson/deviance scale), and the final
1607    // working state — `finalweights`, the penalized Hessian, the deviance, μ —
1608    // is rebuilt with that same ν, so `Vb = H⁻¹·φ̂` stays internally consistent.
1609    // Warm-started solves (every REML cost eval) already sit near the converged
1610    // η, so the first refresh check confirms ν and exits without a re-solve; the
1611    // added cost there is a single O(n) shape evaluation.
1612    if refine_dispersion_at_converged_eta
1613        && working_model.likelihood.scale.gamma_shape_is_estimated()
1614    {
1615        // A few passes suffice: the converged-η shape map is a strong
1616        // contraction (β̂ barely moves once the mean is captured), so cold
1617        // starts settle in 1–2 re-solves and warm starts in zero.
1618        const MAX_SHAPE_REFRESH: usize = 5;
1619        // Relative shape tolerance below which a re-solve cannot move any
1620        // reported quantity meaningfully (far under statistical resolution).
1621        const SHAPE_REFRESH_REL_TOL: f64 = 1e-4;
1622        for refresh_iter in 0..MAX_SHAPE_REFRESH {
1623            let refreshed_shape = super::estimate_gamma_shape_from_eta(
1624                y,
1625                working_summary.state.eta.as_ref(),
1626                priorweights,
1627            );
1628            let prior_shape = working_model.likelihood.gamma_shape().unwrap_or(1.0);
1629            let rel_change =
1630                (refreshed_shape - prior_shape).abs() / prior_shape.max(f64::MIN_POSITIVE);
1631            // Install the refreshed shape and hold it fixed for any re-solve so
1632            // the LM objective stays stationary (the lock is *re-armed*, not
1633            // released — the seed-from-warm-start branch in `update_with_curvature`
1634            // must not overwrite this deliberately chosen value). Because this
1635            // assignment evaluated the shape at the *current* converged η and no
1636            // re-solve follows it on the exit paths below, the reported shape
1637            // always equals `estimate_gamma_shape_from_eta(final_eta)` — the
1638            // self-consistency invariant the in-module Gamma unit test checks.
1639            working_model.likelihood = working_model
1640                .likelihood
1641                .clone()
1642                .with_gamma_shape(refreshed_shape);
1643            working_model.gamma_shape_locked = true;
1644            if rel_change <= SHAPE_REFRESH_REL_TOL {
1645                // Converged: the working-state buffers (weights, Hessian,
1646                // deviance) already reflect a shape within tolerance of
1647                // `refreshed_shape`, because the only way to reach here without
1648                // a re-solve is that the prior solve's shape already matched the
1649                // converged-η estimate. Nothing left to rebuild.
1650                break;
1651            }
1652            if refresh_iter + 1 == MAX_SHAPE_REFRESH {
1653                // Final allowed pass and the shape is still drifting (a
1654                // pathological non-contraction). Do NOT re-solve: re-solving
1655                // would advance `final_eta` past the η the just-installed shape
1656                // was evaluated at, breaking the stored-shape == estimate(final_eta)
1657                // invariant. Stopping here keeps the reported shape exactly the
1658                // ML estimate at the reported η; the residual weight/φ drift is
1659                // bounded by the last `rel_change` and never worse than the
1660                // pre-fix frozen-warm-start value.
1661                break;
1662            }
1663            // The shape moved: re-solve β at the corrected shape, warm-started
1664            // at the converged β, so the final working state is rebuilt with the
1665            // refreshed ν.
1666            working_summary = runworking_model_pirls(
1667                &mut working_model,
1668                working_summary.beta.clone(),
1669                &options,
1670                &mut iteration_logger,
1671            )?;
1672        }
1673    }
1674
1675    // ── Tweedie dispersion φ: re-estimate at the *converged* η (#771) ─────────
1676    //
1677    // Identical in spirit to the Gamma-shape refresh above: the inner LM solve
1678    // estimates φ **once** from the warm-start η and freezes it (the
1679    // `tweedie_phi_locked` lock), keeping the product φ·λ — and hence β̂ — a
1680    // stationary LM target. φ enters only the working weight `prior·μ^{2−p}/φ`
1681    // and not the working response, so (like the Gamma shape, and unlike the
1682    // Beta precision which couples through the digamma mean score) the mean
1683    // surface is essentially scale-free and β̂ barely moves when φ is corrected.
1684    // But the frozen warm-start φ is the value that survives into
1685    // `FitInference::dispersion` and the covariance `Vb = H⁻¹` (whose √φ scaling
1686    // lives in the weight); at a cold-started η ≈ 0 the Pearson residuals carry
1687    // the *marginal* spread of y, biasing the estimate. Re-estimating at the
1688    // converged η — re-solving β only if φ moved materially — drives (β, φ) to
1689    // their joint fixed point, so the reported φ is the converged-mean Pearson
1690    // estimate and the final weights/Hessian/SE are internally consistent with
1691    // it. Held OFF inside the REML λ search (the flag), φ is refreshed only at
1692    // the reported fit, so it cannot couple to the smoothing parameter.
1693    if refine_dispersion_at_converged_eta
1694        && working_model.likelihood.scale.tweedie_phi_is_estimated()
1695    {
1696        if let ResponseFamily::Tweedie { p } = working_model.likelihood.spec.response {
1697            // The converged-η Pearson map is a strong contraction (β̂ scale-free
1698            // here), so cold starts settle in 1–2 re-solves and warm starts in
1699            // zero.
1700            const MAX_PHI_REFRESH: usize = 5;
1701            // Relative φ tolerance below which a re-solve cannot move any reported
1702            // quantity meaningfully (far under statistical resolution).
1703            const PHI_REFRESH_REL_TOL: f64 = 1e-4;
1704            for refresh_iter in 0..MAX_PHI_REFRESH {
1705                let refreshed_phi = super::estimate_tweedie_phi_from_eta(
1706                    y,
1707                    working_summary.state.eta.as_ref(),
1708                    priorweights,
1709                    p,
1710                );
1711                let prior_phi = working_model.likelihood.fixed_phi().unwrap_or(1.0);
1712                let rel_change =
1713                    (refreshed_phi - prior_phi).abs() / prior_phi.max(f64::MIN_POSITIVE);
1714                // Install the refreshed φ (the scale metadata the working weight
1715                // reads via `fixed_phi()`) and re-arm the lock so a following
1716                // re-solve does not overwrite this converged-η value. Because the
1717                // exit paths below evaluate φ at the *current* η with no following
1718                // re-solve, the reported φ always equals
1719                // `estimate_tweedie_phi_from_eta(final_eta)`.
1720                working_model.likelihood = working_model
1721                    .likelihood
1722                    .clone()
1723                    .with_tweedie_phi(refreshed_phi);
1724                working_model.tweedie_phi_locked = true;
1725                if rel_change <= PHI_REFRESH_REL_TOL {
1726                    // Converged: the working state already reflects a φ within
1727                    // tolerance of `refreshed_phi`. Nothing left to rebuild.
1728                    break;
1729                }
1730                if refresh_iter + 1 == MAX_PHI_REFRESH {
1731                    // Final allowed pass and φ is still drifting. Do NOT re-solve:
1732                    // re-solving would advance η past the point φ was evaluated at,
1733                    // breaking the stored-φ == estimate(final_eta) invariant.
1734                    break;
1735                }
1736                // φ moved materially: re-solve β at the corrected φ, warm-started
1737                // at the converged β, so the final working state is rebuilt with
1738                // the refreshed φ.
1739                working_summary = runworking_model_pirls(
1740                    &mut working_model,
1741                    working_summary.beta.clone(),
1742                    &options,
1743                    &mut iteration_logger,
1744                )?;
1745            }
1746        }
1747    }
1748
1749    // ── Beta precision φ: re-estimate at the *converged* η and drive (β, φ) to
1750    //    their joint fixed point (#769) ──────────────────────────────────────
1751    //
1752    // Like the Gamma shape above, the inner LM solve estimates φ **once** from
1753    // the warm-start η and freezes it for the rest of the solve (the
1754    // `beta_phi_locked` doc on `GamWorkingModel`): holding φ fixed keeps the
1755    // penalized argmin β̂ a stationary LM target so the gain ratio compares one
1756    // objective. But that lock pins φ to whatever η the solve started from, and
1757    // for the final dedicated fit at the converged ρ the warm-start is the cold
1758    // default guess (η ≈ 0, μ ≈ 0.5 everywhere). At the null predictor the
1759    // Pearson residuals `(y−μ)²/(μ(1−μ))` capture the full *marginal* spread of
1760    // y rather than its *conditional* spread, so the moment estimator
1761    // `1+φ = Σw / Σ w·s` returns a precision far too small (≈3 when the truth is
1762    // ≈20 here).
1763    //
1764    // Crucially — and unlike the Gamma shape — φ does **not** factor out of the
1765    // Beta mean score. With the logit link the score for β is
1766    //     ∂ℓ/∂β = φ · Σᵢ xᵢ (y*ᵢ − μ*ᵢ),   y*ᵢ = logit(yᵢ),
1767    //     μ*ᵢ = ψ(μᵢφ) − ψ((1−μᵢ)φ),
1768    // so the root β̂ depends on φ through the digamma terms. A φ that is too
1769    // small shrinks every fitted coefficient toward zero. So this refresh is not
1770    // cosmetic (as it is for Gamma): the re-solve is what *recovers the mean*.
1771    //
1772    // Fix: after the cold solve converges, re-estimate φ at the converged η,
1773    // re-solve β at the corrected φ (warm-started), and repeat. This is the
1774    // betareg alternating mean-fit ↔ φ-estimate scheme; the moment estimator is
1775    // a strong contraction once the mean has any structure, so the pair settles
1776    // in a handful of passes. Held OFF inside the REML λ search (see the flag
1777    // doc), φ is refreshed only here at the reported fit, so it cannot couple to
1778    // the smoothing parameter and reward over-smoothing. As with Gamma, every
1779    // exit path installs φ evaluated at the *current* η with no following
1780    // re-solve, so the reported φ (which flows into `EstimatedBetaPhi`, the
1781    // embedded `Beta { phi }`, `dispersion`, and every SE) always equals
1782    // `estimate_beta_phi_from_eta(final_eta)`.
1783    if refine_dispersion_at_converged_eta && working_model.likelihood.scale.beta_phi_is_estimated()
1784    {
1785        // The mean moves between passes (φ feeds back through the digamma
1786        // score), so allow a few more passes than the scale-free Gamma case;
1787        // the contraction is fast and warm-started re-solves are cheap.
1788        const MAX_PHI_REFRESH: usize = 30;
1789        // Relative φ tolerance below which a re-solve cannot move β̂ — and hence
1790        // any reported quantity — by a statistically meaningful amount.
1791        const PHI_REFRESH_REL_TOL: f64 = 1e-4;
1792        for refresh_iter in 0..MAX_PHI_REFRESH {
1793            let refreshed_phi = super::estimate_beta_phi_from_eta(
1794                y,
1795                working_summary.state.eta.as_ref(),
1796                priorweights,
1797            );
1798            let prior_phi = working_model.likelihood.fixed_phi().unwrap_or(1.0);
1799            let rel_change = (refreshed_phi - prior_phi).abs() / prior_phi.max(f64::MIN_POSITIVE);
1800            // Install the refreshed φ (updates BOTH the `Beta { phi }` family
1801            // variant every weight/deviance expression reads and the
1802            // `EstimatedBetaPhi` scale metadata) and re-arm the lock so a
1803            // following re-solve's `update_with_curvature` does not overwrite
1804            // this deliberately chosen value with a fresh cold estimate.
1805            working_model.likelihood = working_model
1806                .likelihood
1807                .clone()
1808                .with_beta_phi(refreshed_phi);
1809            working_model.beta_phi_locked = true;
1810            if rel_change <= PHI_REFRESH_REL_TOL {
1811                // Converged: the just-installed φ matches (to tolerance) the φ
1812                // the current working state was solved at, so β̂, the weights,
1813                // the Hessian and the deviance are already self-consistent with
1814                // the reported φ. Nothing left to rebuild.
1815                break;
1816            }
1817            if refresh_iter + 1 == MAX_PHI_REFRESH {
1818                // Final allowed pass and φ is still drifting. Do NOT re-solve:
1819                // re-solving would advance η past the point the just-installed φ
1820                // was evaluated at, breaking the stored-φ == estimate(final_eta)
1821                // invariant. Stop here so the reported φ is exactly the moment
1822                // estimate at the reported η.
1823                break;
1824            }
1825            // φ moved materially: re-solve β at the corrected φ, warm-started at
1826            // the converged β, so the mean is refit under the better precision
1827            // and the final working state is rebuilt consistently.
1828            working_summary = runworking_model_pirls(
1829                &mut working_model,
1830                working_summary.beta.clone(),
1831                &options,
1832                &mut iteration_logger,
1833            )?;
1834        }
1835    }
1836
1837    // ── Negative-Binomial overdispersion θ: re-estimate at the *converged* η and
1838    //    drive (β, θ) to their joint fixed point (#802) ───────────────────────
1839    //
1840    // Identical in spirit to the Beta-precision refresh above. The inner LM solve
1841    // estimates θ **once** from the warm-start η and freezes it (the
1842    // `negbin_theta_locked` lock), keeping the penalized argmin β̂ a stationary LM
1843    // target. But that lock pins θ to whatever η the solve started from, and for
1844    // the final dedicated fit at the converged ρ the warm-start is the cold
1845    // default guess (η ≈ 0). At the null predictor the Pearson residuals carry
1846    // the *marginal* spread of y rather than its *conditional* spread, biasing
1847    // the moment seed — and the frozen θ is what survives into the working weight
1848    // `W = μθ/(θ+μ)`, the covariance `Vb = H⁻¹` (whose overdispersion scaling
1849    // lives in that weight, not a post-hoc multiply), and every reported SE /
1850    // interval / `generate` draw.
1851    //
1852    // Like the Beta precision — and unlike the scale-free Gamma shape / Tweedie φ
1853    // — θ enters the NB2 working *response*, not only the weight, so re-solving β
1854    // under the corrected θ is not cosmetic: it recovers the mean under the right
1855    // variance function. Re-estimating at the converged η, re-solving β
1856    // (warm-started), and repeating drives (β, θ) to their joint maximum-
1857    // likelihood fixed point. Held OFF inside the REML λ search (the flag), θ is
1858    // refreshed only here at the reported fit, so it cannot couple to the
1859    // smoothing parameter. Every exit path installs θ evaluated at the *current*
1860    // η with no following re-solve, so the reported θ (which flows into the
1861    // embedded `NegativeBinomial { theta }`, the `EstimatedNegBinTheta` scale
1862    // metadata, the predictive-interval variance, and every SE) always equals
1863    // `estimate_negbin_theta_from_eta(final_eta)`.
1864    if refine_dispersion_at_converged_eta
1865        && working_model.likelihood.scale.negbin_theta_is_estimated()
1866    {
1867        // θ feeds back through the working response, so allow a few more passes
1868        // than the scale-free Gamma case; the alternation is a strong contraction
1869        // and warm-started re-solves are cheap.
1870        const MAX_THETA_REFRESH: usize = 30;
1871        // Relative θ tolerance below which a re-solve cannot move β̂ — and hence
1872        // any reported quantity — by a statistically meaningful amount.
1873        const THETA_REFRESH_REL_TOL: f64 = 1e-4;
1874        for refresh_iter in 0..MAX_THETA_REFRESH {
1875            let refreshed_theta = super::estimate_negbin_theta_from_eta(
1876                y,
1877                working_summary.state.eta.as_ref(),
1878                priorweights,
1879            );
1880            let prior_theta = working_model.likelihood.negbin_theta().unwrap_or(1.0);
1881            let rel_change =
1882                (refreshed_theta - prior_theta).abs() / prior_theta.max(f64::MIN_POSITIVE);
1883            // Install the refreshed θ (updates BOTH the `NegativeBinomial { theta }`
1884            // family variant every weight/deviance expression reads and the
1885            // `EstimatedNegBinTheta` scale metadata) and re-arm the lock so a
1886            // following re-solve's `update_with_curvature` does not overwrite this
1887            // deliberately chosen value with a fresh cold estimate.
1888            working_model.likelihood = working_model
1889                .likelihood
1890                .clone()
1891                .with_negbin_theta(refreshed_theta);
1892            working_model.negbin_theta_locked = true;
1893            if rel_change <= THETA_REFRESH_REL_TOL {
1894                // Converged: the just-installed θ matches (to tolerance) the θ the
1895                // current working state was solved at, so β̂, the weights, the
1896                // Hessian and the deviance are already self-consistent with the
1897                // reported θ. Nothing left to rebuild.
1898                break;
1899            }
1900            if refresh_iter + 1 == MAX_THETA_REFRESH {
1901                // Final allowed pass and θ is still drifting. Do NOT re-solve:
1902                // re-solving would advance η past the point the just-installed θ
1903                // was evaluated at, breaking the stored-θ == estimate(final_eta)
1904                // invariant. Stop here so the reported θ is exactly the ML
1905                // estimate at the reported η.
1906                break;
1907            }
1908            // θ moved materially: re-solve β at the corrected θ, warm-started at
1909            // the converged β, so the mean is refit under the better variance
1910            // function and the final working state is rebuilt consistently.
1911            working_summary = runworking_model_pirls(
1912                &mut working_model,
1913                working_summary.beta.clone(),
1914                &options,
1915                &mut iteration_logger,
1916            )?;
1917        }
1918    }
1919
1920    // Extract workspace before consuming working_model so we can reuse
1921    // the pre-allocated buffers in calculate_edfwithworkspace_with_penalty.
1922    // into_final_state() drops the workspace field anyway (it uses `..` in
1923    // its destructure); we replace it with a zero-sized stub to satisfy the
1924    // borrow checker, then keep the real workspace alive for the EDF call.
1925    let mut saved_workspace = std::mem::replace(
1926        &mut working_model.workspace,
1927        PirlsWorkspace::new(0, 0, 0, 0),
1928    );
1929    let final_state = working_model.into_final_state();
1930    let GamModelFinalState {
1931        likelihood: final_likelihood,
1932        coordinate_frame,
1933        finalmu,
1934        finalweights,
1935        scoreweights,
1936        finalz,
1937        final_c,
1938        final_d,
1939        final_dmu_deta,
1940        final_d2mu_deta2,
1941        final_d3mu_deta3,
1942        penalty_term,
1943        ..
1944    } = final_state;
1945
1946    // Preserve the Hessian as-is (sparse or dense) — no densification.
1947    // P-IRLS already folded any stabilization ridge directly into the Hessian.
1948    // Keep that exact matrix so outer LAML derivatives stay consistent:
1949    // H_eff = X'W_H X + S_λ + ridge I (if ridge_used > 0).
1950    let penalized_hessian_transformed = working_summary.state.hessian.clone();
1951    let stabilizedhessian_transformed = penalized_hessian_transformed.clone();
1952    // Use the workspace-backed variant for the dense path to reuse the
1953    // `final_aug_matrix` allocation; the sparse path still allocates
1954    // internally because no pre-computed factor is available at this site.
1955    let mut edf = if let Some(dense_h) = penalized_hessian_transformed.as_dense() {
1956        calculate_edfwithworkspace_with_penalty(dense_h, &penalty_active, &mut saved_workspace)?
1957    } else {
1958        calculate_edf_with_penalty(&penalized_hessian_transformed, &penalty_active)?
1959    };
1960    if !edf.is_finite() || edf.is_nan() {
1961        let p = penalized_hessian_transformed.ncols() as f64;
1962        let r = penalty_active.rank() as f64;
1963        edf = (p - r).max(0.0);
1964    }
1965
1966    // Outer rescue: a fit that hit max-iterations may still be a usable
1967    // minimum if progress has effectively stopped (deviance plateaued or
1968    // step size collapsed to the floor) AND the projected gradient is in
1969    // the near-stationary band under the scale-invariant certificate.
1970    // Same logic for non-Firth and Firth paths; firth_active just gates
1971    // the second pass.
1972    let stalled_at_valid_minimum = |summary: &WorkingModelPirlsResult| -> bool {
1973        // Scale-equivariant deviance plateau band (issue #1127). The
1974        // `last_deviance_change` compared below and the deviance both scale as
1975        // `O(a²)` under a response rescaling `y → a·y` (the penalized normal
1976        // equations are linear in `y`, so `β → a·β` and the RSS-deviance
1977        // scales by `a²`). Keying the plateau band to the deviance's own
1978        // magnitude `+ |penalty|` makes the ratio `Δdev / dev_scale`
1979        // scale-invariant. The previous `.max(1.0)` absolute floor broke this:
1980        // for a micro-unit response (`a = 1e-6`) the deviance is `O(1e-12)`, so
1981        // the floor pinned the band at `1.0` — ~1e9× too loose — and this
1982        // max-iteration rescue declared `progress_stopped` at an over-smoothed
1983        // iterate, propagating an inflated `λ̂` to the outer REML loop. For a
1984        // well-scaled (`a ≳ 1`) or up-scaled (`a = 1e6`) objective the floor was
1985        // already a no-op, so those directions are byte-identical. A perfect
1986        // interpolating fit gives a `0` band, so the relative `Δdev` test cannot
1987        // fire spuriously and the scale-invariant `near_stationary_kkt`
1988        // certificate then governs acceptance.
1989        let dev_scale = summary.state.deviance.abs() + summary.state.penalty_term.abs();
1990        // Progress plateau uses the fixed solver tolerance; only the KKT band below adapts.
1991        let dev_tol = options.convergence_tolerance * dev_scale;
1992        let step_floor = options.min_step_size * 2.0;
1993        let progress_stopped =
1994            summary.last_deviance_change.abs() <= dev_tol || summary.last_step_size <= step_floor;
1995        let near_stationary = summary
1996            .state
1997            .near_stationary_kkt(summary.lastgradient_norm, effective_kkt_tolerance(&options));
1998        progress_stopped && near_stationary
1999    };
2000
2001    let mut status = working_summary.status;
2002    if status.is_failed_max_iterations() && stalled_at_valid_minimum(&working_summary) {
2003        status = PirlsStatus::StalledAtValidMinimum;
2004        working_summary.status = status;
2005    }
2006    if status.is_failed_max_iterations()
2007        && firth_active
2008        && stalled_at_valid_minimum(&working_summary)
2009    {
2010        // Firth-adjusted fits can stall; accept under the same dual-criterion
2011        // near-stationary band.
2012        status = PirlsStatus::StalledAtValidMinimum;
2013        working_summary.status = status;
2014    }
2015    let has_penalty = penalty_active.rank() > 0;
2016    let firth_active = options.firth_bias_reduction;
2017    if detect_logit_instability(
2018        link_function,
2019        &final_likelihood.spec.response,
2020        has_penalty,
2021        firth_active,
2022        &working_summary,
2023        &finalmu,
2024        &finalweights,
2025        y,
2026    ) {
2027        status = PirlsStatus::Unstable;
2028        working_summary.status = status;
2029    }
2030
2031    // Store a lazy ReparamOperator instead of materializing X·Qs.
2032    // Consumers that truly need dense access can call .to_dense() on demand.
2033    let reparam_result_final = materialize_final_reparam_result()?;
2034    let qs_arc_final = Arc::new(reparam_result_final.qs.clone());
2035    let x_transformed_final =
2036        make_reparam_operator(&x_original_for_result, &qs_arc_final, use_sparse_native);
2037
2038    let pirls_result = assemble_pirls_result(
2039        &working_summary,
2040        final_likelihood,
2041        offset,
2042        penalized_hessian_transformed,
2043        stabilizedhessian_transformed,
2044        edf,
2045        penalty_term,
2046        &finalmu,
2047        &finalweights,
2048        &scoreweights,
2049        &finalz,
2050        &final_c,
2051        &final_d,
2052        &final_dmu_deta,
2053        &final_d2mu_deta2,
2054        &final_d3mu_deta3,
2055        status,
2056        reparam_result_final,
2057        x_transformed_final,
2058        coordinate_frame,
2059        linear_constraints,
2060    );
2061
2062    Ok((pirls_result, working_summary))
2063}
2064
/// Configuration for a single fixed-ρ PIRLS fit.
///
/// NOTE(review): the fields without their own doc comment below appear to be
/// forwarded into the internal `WorkingModelPirlsOptions` the same way
/// `initial_lm_lambda` is documented to be — confirm against
/// `fit_model_for_fixed_rho`.
#[derive(Clone)]
pub struct PirlsConfig {
    /// Likelihood specification (`GlmLikelihoodSpec`) of the model being fitted.
    pub likelihood: GlmLikelihoodSpec,
    /// Inverse-link descriptor; [`PirlsConfig::link_function`] derives the
    /// corresponding [`LinkFunction`] from it.
    pub link_kind: InverseLink,
    /// Iteration cap for the PIRLS loop (presumably forwarded to the inner
    /// working-model loop — TODO confirm).
    pub max_iterations: usize,
    /// Convergence tolerance for the PIRLS loop (presumably forwarded to the
    /// inner working-model loop — TODO confirm).
    pub convergence_tolerance: f64,
    /// Whether Firth bias reduction is requested (presumably forwarded to the
    /// inner working-model loop — TODO confirm).
    pub firth_bias_reduction: bool,
    /// Optional warm-start hint for `WorkingModelPirlsOptions::initial_lm_lambda`.
    /// Forwarded directly when `fit_model_for_fixed_rho` builds its
    /// internal options. See the field doc on `WorkingModelPirlsOptions`
    /// for the seeding semantics.
    pub initial_lm_lambda: Option<f64>,
    /// Enable the Transtrum-Sethna geodesic-acceleration second-order
    /// correction on each accepted LM step. Forwarded to
    /// `WorkingModelPirlsOptions::geodesic_acceleration`; see that
    /// field's doc for the full semantics and cost model. Default
    /// `false`; opt-in until validated.
    pub geodesic_acceleration: bool,
    /// Optional arrow-Schur structured-inner-solve descriptor. When
    /// `Some`, forwarded to `WorkingModelPirlsOptions::arrow_schur` so
    /// each accepted LM step is solved by the per-observation
    /// arrow-Schur path
    /// ([`crate::arrow_schur::ArrowSchurSystem`]). When `None`
    /// (the default), the existing β-only path is used unchanged.
    ///
    /// See [`ArrowSchurInnerConfig`] for the closure contract.
    pub arrow_schur: Option<ArrowSchurInnerConfig>,
}
2093
2094impl PirlsConfig {
2095    #[inline]
2096    pub fn link_function(&self) -> LinkFunction {
2097        self.link_kind.link_function()
2098    }
2099}
2100
2101#[inline]
2102pub(super) fn max_symmetric_asymmetry(matrix: &Array2<f64>) -> f64 {
2103    let n = matrix.nrows().min(matrix.ncols());
2104    let mut max_asym = 0.0_f64;
2105    for i in 0..n {
2106        for j in 0..i {
2107            let diff = (matrix[[i, j]] - matrix[[j, i]]).abs();
2108            if diff > max_asym {
2109                max_asym = diff;
2110            }
2111        }
2112    }
2113    max_asym
2114}
2115
2116#[inline]
2117pub(super) fn assert_symmetric_tol(matrix: &Array2<f64>, label: &str, tol: f64) {
2118    let max_asym = max_symmetric_asymmetry(matrix);
2119    assert!(
2120        max_asym <= tol,
2121        "{} asymmetry too large: {:.3e} (tol {:.3e})",
2122        label,
2123        max_asym,
2124        tol
2125    );
2126}
2127
2128/// Build a DesignMatrix wrapping a lazy ReparamOperator (or the original for sparse-native).
2129pub(crate) fn make_reparam_operator(
2130    x_original: &DesignMatrix,
2131    qs_arc: &Arc<Array2<f64>>,
2132    use_sparse_native: bool,
2133) -> DesignMatrix {
2134    if use_sparse_native {
2135        x_original.clone()
2136    } else {
2137        DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(
2138            ReparamOperator::new(x_original.clone(), Arc::clone(qs_arc)),
2139        )))
2140    }
2141}
2142
2143// solve_penalized_least_squares_implicit lives in pls_solver (imported above).
2144
2145pub(super) fn build_transformed_lower_bound_constraints(
2146    qs: &Array2<f64>,
2147    coefficient_lower_bounds: Option<&Array1<f64>>,
2148) -> Option<LinearInequalityConstraints> {
2149    let lb = coefficient_lower_bounds?;
2150    if lb.len() != qs.nrows() {
2151        return None;
2152    }
2153    let activerows: Vec<usize> = (0..lb.len()).filter(|&i| lb[i].is_finite()).collect();
2154    if activerows.is_empty() {
2155        return None;
2156    }
2157    let mut a = Array2::<f64>::zeros((activerows.len(), qs.ncols()));
2158    let mut b = Array1::<f64>::zeros(activerows.len());
2159    for (r, &idx) in activerows.iter().enumerate() {
2160        a.row_mut(r).assign(&qs.row(idx));
2161        b[r] = lb[idx];
2162    }
2163    Some(
2164        LinearInequalityConstraints::new(a, b)
2165            .expect("transformed lower-bound constraint shape invariant"),
2166    )
2167}
2168
2169pub(super) fn build_transformed_lower_bound_constraints_with_transform(
2170    transform: &WorkingReparamTransform,
2171    coefficient_lower_bounds: Option<&Array1<f64>>,
2172) -> Option<LinearInequalityConstraints> {
2173    let lb = coefficient_lower_bounds?;
2174    let p = match transform {
2175        WorkingReparamTransform::Dense(qs) => qs.nrows(),
2176        WorkingReparamTransform::Kronecker(kron) => kron.p,
2177    };
2178    if lb.len() != p {
2179        return None;
2180    }
2181    let activerows: Vec<usize> = (0..lb.len()).filter(|&i| lb[i].is_finite()).collect();
2182    if activerows.is_empty() {
2183        return None;
2184    }
2185    let mut a = Array2::<f64>::zeros((activerows.len(), p));
2186    let mut b = Array1::<f64>::zeros(activerows.len());
2187    for (r, &idx) in activerows.iter().enumerate() {
2188        let mut basis = Array1::<f64>::zeros(p);
2189        basis[idx] = 1.0;
2190        let row = transform.apply_transpose(&basis);
2191        a.row_mut(r).assign(&row);
2192        b[r] = lb[idx];
2193    }
2194    Some(
2195        LinearInequalityConstraints::new(a, b)
2196            .expect("transformed lower-bound constraint shape invariant"),
2197    )
2198}
2199
2200pub(super) fn build_transformed_linear_constraints(
2201    qs: &Array2<f64>,
2202    linear_constraints: Option<&LinearInequalityConstraints>,
2203) -> Option<LinearInequalityConstraints> {
2204    let lc = linear_constraints?;
2205    if lc.a.ncols() != qs.nrows() {
2206        return None;
2207    }
2208    Some(
2209        LinearInequalityConstraints::new(lc.a.dot(qs), lc.b.clone())
2210            .expect("transformed linear constraint shape invariant"),
2211    )
2212}
2213
2214pub(super) fn build_transformed_linear_constraints_with_transform(
2215    transform: &WorkingReparamTransform,
2216    linear_constraints: Option<&LinearInequalityConstraints>,
2217) -> Option<LinearInequalityConstraints> {
2218    let lc = linear_constraints?;
2219    let p = match transform {
2220        WorkingReparamTransform::Dense(qs) => qs.nrows(),
2221        WorkingReparamTransform::Kronecker(kron) => kron.p,
2222    };
2223    if lc.a.ncols() != p {
2224        return None;
2225    }
2226    let mut a = Array2::<f64>::zeros((lc.a.nrows(), p));
2227    for row in 0..lc.a.nrows() {
2228        let transformed = transform.apply_transpose(&lc.a.row(row).to_owned());
2229        a.row_mut(row).assign(&transformed);
2230    }
2231    Some(LinearInequalityConstraints { a, b: lc.b.clone() })
2232}
2233
2234pub(super) fn merge_linear_constraints(
2235    first: Option<LinearInequalityConstraints>,
2236    second: Option<LinearInequalityConstraints>,
2237) -> Option<LinearInequalityConstraints> {
2238    match (first, second) {
2239        (None, None) => None,
2240        (Some(c), None) | (None, Some(c)) => Some(c),
2241        (Some(c1), Some(c2)) => {
2242            if c1.a.ncols() != c2.a.ncols() {
2243                return None;
2244            }
2245            let rows = c1.a.nrows() + c2.a.nrows();
2246            let cols = c1.a.ncols();
2247            let mut a = Array2::<f64>::zeros((rows, cols));
2248            a.slice_mut(s![0..c1.a.nrows(), ..]).assign(&c1.a);
2249            a.slice_mut(s![c1.a.nrows()..rows, ..]).assign(&c2.a);
2250            let mut b = Array1::<f64>::zeros(rows);
2251            b.slice_mut(s![0..c1.b.len()]).assign(&c1.b);
2252            b.slice_mut(s![c1.b.len()..rows]).assign(&c2.b);
2253            Some(LinearInequalityConstraints { a, b })
2254        }
2255    }
2256}
2257
2258pub(super) fn sparse_from_denseview(x: ArrayView2<f64>) -> Option<DesignMatrix> {
2259    // Below this column count a dense factorization beats the sparse path even
2260    // at high sparsity, so skip the sparsity scan entirely for narrow designs.
2261    const DENSE_PREFERRED_MAX_COLS: usize = 32;
2262    // Sparse storage + sparse Cholesky only pays off below this density (nnz as
2263    // a fraction of all entries); denser matrices stay dense.
2264    const SPARSE_DENSITY_LIMIT: f64 = 0.20;
2265
2266    let nrows = x.nrows();
2267    let ncols = x.ncols();
2268    if nrows == 0 || ncols == 0 {
2269        return None;
2270    }
2271    // Narrow matrices are faster in dense form; avoid any sparsity scan overhead.
2272    if ncols <= DENSE_PREFERRED_MAX_COLS {
2273        return None;
2274    }
2275
2276    const ZERO_EPS: f64 = 1e-12;
2277    let total = nrows.saturating_mul(ncols);
2278    if total == 0 {
2279        return None;
2280    }
2281    // If a matrix exceeds this nnz count it is too dense for sparse path; bail early.
2282    let sparse_nnz_limit = ((total as f64) * SPARSE_DENSITY_LIMIT).floor() as usize;
2283    let mut nnz = 0usize;
2284    for &val in x.iter() {
2285        if val.abs() > ZERO_EPS {
2286            nnz += 1;
2287            if nnz > sparse_nnz_limit {
2288                return None;
2289            }
2290        }
2291    }
2292    let mut triplets = Vec::with_capacity(nnz);
2293    for (row_idx, row) in x.outer_iter().enumerate() {
2294        for (col_idx, &val) in row.iter().enumerate() {
2295            if val.abs() > ZERO_EPS {
2296                triplets.push(Triplet::new(row_idx, col_idx, val));
2297            }
2298        }
2299    }
2300    SparseColMat::try_new_from_triplets(nrows, ncols, &triplets)
2301        .ok()
2302        .map(DesignMatrix::from)
2303}
2304
2305#[cfg(test)]
2306mod tests {
2307    use super::{PirlsPenalty, build_diagonal_penalty_from_kronecker};
2308    use gam_terms::construction::KroneckerReparamResult;
2309    use ndarray::{Array1, Array2, array};
2310
2311    #[test]
2312    fn kronecker_diagonal_double_penalty_hits_only_joint_null_space() {
2313        let kron_result = KroneckerReparamResult {
2314            reparameterized_marginals: std::sync::Arc::new(Vec::new()),
2315            marginal_eigenvalues: std::sync::Arc::new(vec![array![0.0, 2.0], array![0.0, 3.0]]),
2316            marginal_qs: std::sync::Arc::new(Vec::new()),
2317            log_det: 0.0,
2318            det1: Array1::zeros(3),
2319            det2: Array2::zeros((3, 3)),
2320            penalty_shrinkage_ridge: 0.5,
2321            has_double_penalty: true,
2322            marginal_dims: vec![2usize, 2usize],
2323        };
2324        let penalty = build_diagonal_penalty_from_kronecker(&kron_result, &[5.0, 7.0, 11.0]);
2325
2326        let PirlsPenalty::Diagonal {
2327            diag,
2328            positive_indices,
2329            ..
2330        } = penalty
2331        else {
2332            panic!("expected diagonal Kronecker PIRLS penalty");
2333        };
2334        let expected = [11.0, 21.5, 10.5, 31.5];
2335        for (idx, expected_diag) in expected.iter().copied().enumerate() {
2336            assert!(
2337                (diag[idx] - expected_diag).abs() <= 1e-12,
2338                "diagonal {idx} got {}, expected {expected_diag}",
2339                diag[idx]
2340            );
2341        }
2342        assert_eq!(positive_indices, vec![0, 1, 2, 3]);
2343    }
2344}