Skip to main content

gam_solve/inference/
alo.rs

1use crate::estimate::EstimationError;
2use crate::estimate::{FitGeometry, UnifiedFitResult};
3use crate::pirls;
4use faer::Mat as FaerMat;
5use faer::linalg::matmul::matmul;
6use faer::prelude::ReborrowMut;
7use faer::{Accum, Par};
8use gam_linalg::faer_ndarray::{FaerArrayView, FaerCholesky};
9use gam_linalg::matrix::{PsdWeightsView, SignedWeightsView};
10use gam_linalg::utils::StableSolver;
11use gam_problem::LinkFunction;
12use ndarray::{Array1, Array2, ArrayView1, ShapeBuilder, s};
13use std::fmt;
14
/// Typed failures of the ALO (approximate leave-one-out) diagnostics module.
///
/// The public entry points still speak `Result<_, EstimationError>`. Leaf code
/// raises one of these variants, and the boundary converts it with
/// `From<AloError> for EstimationError`. That conversion keeps the rendered
/// text byte-identical to the earlier `EstimationError::InvalidInput(format!(...))`
/// and `ModelIsIllConditioned { ... }` messages.
#[derive(Debug, Clone)]
pub enum AloError {
    /// Structurally invalid configuration from the caller. Covers dimension
    /// mismatches, non-finite inputs other than weights/response, missing
    /// PIRLS or geometry artifacts, and out-of-range scalar parameters.
    InvalidInput { reason: String },
    /// A non-finite IRLS weight or working-response entry, or an otherwise
    /// invalid working response.
    WeightInvalid { reason: String },
    /// The dense design matrix needed by ALO could not be produced from the
    /// PIRLS artifact (for example, a sparse-only export).
    DesignDegenerate { reason: String },
    /// The penalized Hessian could not be factored, or NaNs appeared
    /// downstream, so the influence matrix cannot be trusted.
    InfluenceMatrixFailed { condition_number: f64 },
    /// A per-observation ALO quantity (variance, denominator, or corrected
    /// η̃) came out non-finite at convergence.
    LooComputationFailed { reason: String },
}

impl fmt::Display for AloError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only the influence-matrix variant has a formatted message; every
        // other variant renders its stored reason verbatim.
        let reason = match self {
            Self::InfluenceMatrixFailed { condition_number } => {
                return write!(
                    f,
                    "ALO influence matrix failed (condition number {condition_number:.3e})"
                );
            }
            Self::InvalidInput { reason }
            | Self::WeightInvalid { reason }
            | Self::DesignDegenerate { reason }
            | Self::LooComputationFailed { reason } => reason,
        };
        f.write_str(reason)
    }
}

impl std::error::Error for AloError {}
61
62impl From<AloError> for EstimationError {
63    fn from(err: AloError) -> EstimationError {
64        match err {
65            AloError::InvalidInput { reason }
66            | AloError::WeightInvalid { reason }
67            | AloError::DesignDegenerate { reason }
68            | AloError::LooComputationFailed { reason } => EstimationError::InvalidInput(reason),
69            AloError::InfluenceMatrixFailed { condition_number } => {
70                EstimationError::ModelIsIllConditioned { condition_number }
71            }
72        }
73    }
74}
75
76impl From<AloError> for String {
77    fn from(err: AloError) -> String {
78        err.to_string()
79    }
80}
81
/// Approximate leave-one-out diagnostics derived from a fitted model.
///
/// NOTE(review): every field appears to be indexed by observation (length n,
/// the design's row count); confirm for `pred_identity` and `fisherweights`,
/// whose construction is not visible in this section.
#[derive(Debug, Clone)]
pub struct AloDiagnostics {
    /// Leave-one-out corrected linear predictor η̃_i. It comes from the exact
    /// frozen-curvature scalar solve when a score/curvature evaluator is
    /// supplied (`AloInput::score_curvature`); otherwise it is the classical
    /// one-step ALO update.
    pub eta_tilde: Array1<f64>,
    /// Bayesian/conditional standard error on eta:
    /// sqrt(phi * x_i^T H^{-1} x_i).
    pub se_bayes: Array1<f64>,
    /// Frequentist sandwich-style standard error on eta:
    /// sqrt(phi * x_i^T H^{-1} X^T W X H^{-1} x_i).
    pub se_sandwich: Array1<f64>,
    /// NOTE(review): presumably the leave-one-out prediction on the response
    /// (identity) scale; confirm against where this field is filled.
    pub pred_identity: Array1<f64>,
    /// Per-observation leverage values. `log_leverage_diagnostics` treats
    /// non-finite values or values outside `[0, 1]` as invalid.
    pub leverage: Array1<f64>,
    /// NOTE(review): presumably the per-observation Fisher (IRLS) weights used
    /// in the computation; confirm against where this field is filled.
    pub fisherweights: Array1<f64>,
}
96
/// Classical single-Newton-step ALO update of one row's linear predictor.
///
/// The PIRLS working-response algebra is centered on the offset `o`. The row
/// score is therefore `W_S · ((η̂ − o) − (z − o))`, and the corrected predictor
/// is `o + (η̂ − o) + x_hinv_x · score / denom`.
#[inline]
fn alo_eta_updatewith_offset(
    eta_hat: f64,
    z: f64,
    offset: f64,
    x_hinv_x: f64,
    score_weight: f64,
    denom: f64,
) -> f64 {
    let (centered_eta, centered_z) = (eta_hat - offset, z - offset);
    let row_score = score_weight * (centered_eta - centered_z);
    let correction = x_hinv_x * row_score / denom;
    offset + centered_eta + correction
}
113
/// Per-row score and curvature of an observation's negative log-likelihood
/// contribution, as functions of that row's linear predictor `eta`.
///
/// Invoked as `f(i, eta)` and returns `(ℓ_i'(eta), ℓ_i''(eta))`. Here `ℓ_i` is
/// the (dispersion-scaled) negative log-likelihood of observation `i`, viewed
/// as a univariate function of `eta_i = x_i^T β`. The penalty is not part of
/// `ℓ_i`; it enters only through the frozen penalized Hessian. This is the
/// local family geometry that the ALO frozen-curvature fixed point
/// [`alo_eta_exact_frozen_curvature`] iterates to convergence. Supplying it
/// upgrades the single-Newton-step ALO correction to the exact leave-`i`-out
/// predictor under a frozen penalized Hessian. The `Sync` bound allows the
/// evaluator to be shared across threads.
pub type AloScalarScoreCurvature<'a> = dyn Fn(usize, f64) -> (f64, f64) + Sync + 'a;
124
125/// Maximum scalar Newton iterations for the exact frozen-curvature ALO fixed
126/// point. The map `r(η) = η − η̂ − a_ii ℓ_i'(η)` is one-dimensional and
127/// strongly contractive for the well-leveraged majority of points, so this
128/// caps the rare high-leverage / near-separation rows where convergence is
129/// slow without ever exceeding O(1) work per observation.
130const ALO_EXACT_SCALAR_MAX_ITERS: usize = 64;
131
132/// Absolute convergence tolerance on the scalar residual `r(η)` for the exact
133/// frozen-curvature ALO fixed point. Well below the `1e-2` predictive bar the
134/// LOO comparison asserts, so the refinement is not the limiting error term.
135const ALO_EXACT_SCALAR_TOL: f64 = 1e-12;
136
137/// Solve the frozen-curvature ALO leave-`i`-out fixed point exactly.
138///
139/// The leave-`i`-out optimum differs from the full fit only through the removed
140/// observation, whose gradient/Hessian depend on `β` solely via the scalar
141/// `η_i = x_i^T β`. Freezing the penalized Hessian `H` at its converged value
142/// reduces the exact leave-`i`-out condition to the scalar equation
143///
144///   η = η̂_i + a_ii · ℓ_i'(η),     a_ii = x_i^T H^{-1} x_i,
145///
146/// where `ℓ_i'(η)` is the row's NLL score (so that `∇F = ℓ_i'(η_i) x_i` at the
147/// leave-`i`-out point). The single-Newton-step ALO is exactly the first
148/// iterate of Newton's method on `r(η) = η − η̂_i − a_ii ℓ_i'(η)` started at
149/// `η̂_i`; iterating to convergence captures the change in the held-out point's
150/// likelihood curvature (the dominant first-order error on small-`n`, curved
151/// likelihoods such as binomial logistic regression near separation).
152///
153/// `score_curvature(eta)` returns `(ℓ_i'(eta), ℓ_i''(eta))`. The returned value
154/// is the corrected linear predictor `η̃_i`. Failure to reach the residual
155/// tolerance is reported to the caller; no one-step approximation is substituted
156/// for a failed exact solve.
157#[derive(Debug, Clone, Copy, PartialEq)]
158enum AloExactScalarError {
159    NonFiniteScoreCurvature {
160        eta: f64,
161        ell_prime: f64,
162        ell_double: f64,
163    },
164    DegenerateJacobian {
165        eta: f64,
166        jacobian: f64,
167    },
168    NonFiniteStep {
169        eta: f64,
170        residual: f64,
171        jacobian: f64,
172        next: f64,
173    },
174    MaxIterations {
175        iterations: usize,
176        residual: f64,
177        eta: f64,
178    },
179}
180
181impl fmt::Display for AloExactScalarError {
182    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
183        match *self {
184            AloExactScalarError::NonFiniteScoreCurvature {
185                eta,
186                ell_prime,
187                ell_double,
188            } => write!(
189                f,
190                "non-finite score/curvature at eta={eta:.6e}: ell_prime={ell_prime:.6e}, ell_double={ell_double:.6e}"
191            ),
192            AloExactScalarError::DegenerateJacobian { eta, jacobian } => write!(
193                f,
194                "degenerate Newton Jacobian at eta={eta:.6e}: jacobian={jacobian:.6e}, min={ALO_DENOMINATOR_MIN:.1e}"
195            ),
196            AloExactScalarError::NonFiniteStep {
197                eta,
198                residual,
199                jacobian,
200                next,
201            } => write!(
202                f,
203                "non-finite Newton step from eta={eta:.6e}: residual={residual:.6e}, jacobian={jacobian:.6e}, next={next:.6e}"
204            ),
205            AloExactScalarError::MaxIterations {
206                iterations,
207                residual,
208                eta,
209            } => write!(
210                f,
211                "did not converge within {iterations} iterations: residual={residual:.6e}, eta={eta:.6e}, tol={ALO_EXACT_SCALAR_TOL:.1e}"
212            ),
213        }
214    }
215}
216
217/// Maximum number of step halvings in the backtracking line search that
218/// globalizes the scalar Newton iteration. `2^{-40}` shrinks a unit step well
219/// below `ALO_EXACT_SCALAR_TOL` relative to any η of practical magnitude, so a
220/// row that cannot make progress within this budget is genuinely stalled rather
221/// than merely under-damped.
222const ALO_EXACT_SCALAR_BACKTRACKS: usize = 40;
223
224#[inline]
225fn alo_eta_exact_frozen_curvature(
226    eta_hat: f64,
227    a_ii: f64,
228    score_curvature: &dyn Fn(f64) -> (f64, f64),
229) -> Result<f64, AloExactScalarError> {
230    // Residual of the leave-i-out fixed point η = η̂ + a_ii ℓ'(η):
231    //   r(η) = η − η̂ − a_ii ℓ'(η),     r'(η) = 1 − a_ii ℓ''(η) = jac.
232    // For an exponential-family NLL score ℓ'(η) = c_i(μ(η) − y) on a non-linear
233    // (e.g. log) link the curvature ℓ''(η) = c_i μ'(η) grows without bound, so
234    // r(η) is concave with an interior maximum where the weighted leverage
235    // a_ii ℓ'' passes 1 (jac = 0): the leave-i-out root that limits to η̂ as
236    // a_ii → 0 sits on the jac > 0 branch anchored at η̂, while beyond the
237    // maximum r turns over and diverges as μ(η) explodes.
238    //
239    // Two safeguards make the scalar solve globally convergent to that root:
240    //
241    //   1. Anchor the iteration at η̂ itself, not at the classical one-step ALO
242    //      predictor. At η̂ the weighted leverage a_ii ℓ''(η̂) < 1, so jac ≈ 1
243    //      and we start strictly inside the correct basin; the brute-force
244    //      n-fold reference solves the identical fixed point anchored at η̂.
245    //      Seeding at the one-step predictor instead can land a high-leverage
246    //      row *past* the interior maximum on the runaway branch, from which no
247    //      Newton iteration returns (Poisson/log row 198: η ≈ 6.3, r ≈ −577).
248    //
249    //   2. Backtrack on the merit ½r(η)². The Newton direction d = −r/jac
250    //      satisfies (½r²)'·d = r·jac·(−r/jac) = −r² < 0 for any finite nonzero
251    //      jac, so halving the step until |r| strictly decreases never leaves
252    //      the basin even if a full step would overshoot the maximum.
253    let residual_and_jac = |eta: f64| -> Result<(f64, f64), AloExactScalarError> {
254        let (ell_prime, ell_double) = score_curvature(eta);
255        if !ell_prime.is_finite() || !ell_double.is_finite() {
256            return Err(AloExactScalarError::NonFiniteScoreCurvature {
257                eta,
258                ell_prime,
259                ell_double,
260            });
261        }
262        Ok((eta - eta_hat - a_ii * ell_prime, 1.0 - a_ii * ell_double))
263    };
264
265    let mut eta = eta_hat;
266    let (mut residual, mut jac) = residual_and_jac(eta)?;
267    for _ in 0..ALO_EXACT_SCALAR_MAX_ITERS {
268        if residual.abs() <= ALO_EXACT_SCALAR_TOL {
269            return Ok(eta);
270        }
271        if jac.abs() <= ALO_DENOMINATOR_MIN || !jac.is_finite() {
272            return Err(AloExactScalarError::DegenerateJacobian { eta, jacobian: jac });
273        }
274        let step = residual / jac;
275        if !step.is_finite() {
276            return Err(AloExactScalarError::NonFiniteStep {
277                eta,
278                residual,
279                jacobian: jac,
280                next: eta - step,
281            });
282        }
283        // Backtracking line search: take the longest damped Newton step
284        // 2^{-k} that strictly reduces the merit |r|. A non-finite trial
285        // (score/curvature evaluated in the runaway branch) is treated as no
286        // improvement and rejected, so the search retreats toward η̂.
287        let mut t = 1.0;
288        let mut advanced = false;
289        for _ in 0..ALO_EXACT_SCALAR_BACKTRACKS {
290            let trial = eta - t * step;
291            if let Ok((r_trial, j_trial)) = residual_and_jac(trial) {
292                if r_trial.abs() < residual.abs() {
293                    eta = trial;
294                    residual = r_trial;
295                    jac = j_trial;
296                    advanced = true;
297                    break;
298                }
299            }
300            t *= 0.5;
301        }
302        if !advanced {
303            break;
304        }
305    }
306    Err(AloExactScalarError::MaxIterations {
307        iterations: ALO_EXACT_SCALAR_MAX_ITERS,
308        residual,
309        eta,
310    })
311}
312
/// Bayesian (conditional) variance of the linear predictor:
/// `φ · x_i^T H^{-1} x_i`.
#[inline]
fn bayesvar_eta(phi: f64, x_hinv_x: f64) -> f64 {
    x_hinv_x * phi
}

/// Sandwich variance of the linear predictor, `φ · meat_quad`. The argument
/// `meat_quad` is presumably the quadratic form
/// `x_i^T H^{-1} X^T W X H^{-1} x_i`.
#[inline]
fn sandwichvar_eta_from_meat(phi: f64, meat_quad: f64) -> f64 {
    meat_quad * phi
}

/// Slack below zero tolerated for a variance produced by the cancellation
/// `x'H^{-1}x - ||E t||^2 - ridge||t||^2`. It is relative to `|scale|`, with a
/// floor of 1.
#[inline]
fn variance_negative_tolerance(scale: f64) -> f64 {
    let magnitude = scale.abs().max(1.0);
    magnitude * 1e-12
}
328
/// Leverage above which an observation counts as "high" in the aggregate
/// `warn!` summary emitted by `log_leverage_diagnostics`.
const LEVERAGE_HIGH_THRESHOLD: f64 = 0.99;
/// Leverage above which `log_leverage_diagnostics` emits a per-observation
/// `warn!`. It is only checked for leverages already inside `[0, 1]`.
const LEVERAGE_VERY_HIGH_THRESHOLD: f64 = 0.999;
/// Leverage cut-offs whose exceedance rates (percent of all observations)
/// are reported in the `info!` summary of `log_leverage_diagnostics`.
const LEVERAGE_RATE_THRESHOLDS: [f64; 3] = [0.90, 0.95, 0.99];
/// Quantiles (median, p95, p99) of the finite leverages reported by
/// `log_leverage_diagnostics`.
const LEVERAGE_PERCENTILES: [f64; 3] = [0.50, 0.95, 0.99];
/// Smallest magnitude accepted for a denominator-like quantity. It bounds the
/// scalar Newton Jacobian in the exact ALO solve and `dμ/dη` when recovering
/// the canonical per-row scale. NOTE(review): presumably it also guards the
/// one-step ALO denominator elsewhere in this module — confirm.
const ALO_DENOMINATOR_MIN: f64 = 1e-12;
/// Total scratch-memory budget (256 MiB) for the multi-block ALO leverage
/// pass. `multiblock_alo_parallel_leverage_chunk_size` divides it across the
/// concurrent workers to size each chunk.
const MULTIBLOCK_ALO_MEMORY_BUDGET_BYTES: usize = 256 * 1024 * 1024;

/// Number of observation columns solved per blocked right-hand-side batch in the
/// scalar-leverage path. Sizes the reusable `(p, .)` and `(e_rank, .)` scratch
/// buffers so the dense multi-RHS solve stays BLAS-3 (good cache reuse) without
/// materializing all `n` columns at once. The final batch is the remainder.
const ALO_RHS_BLOCK_COLS: usize = 8192;

/// Relative tolerance for accepting the input penalised Hessian `H` as
/// symmetric. We require `|H_ij − H_ji| ≤ HESSIAN_SYMMETRY_REL_TOL ·
/// max(|H_ij|, |H_ji|, 1)`. `1e-8` matches the loosest tolerance any
/// upstream symmetrisation pass leaves on the matrix and is tight enough
/// that a genuinely asymmetric Hessian (a real bug) is caught.
const HESSIAN_SYMMETRY_REL_TOL: f64 = 1e-8;

/// Diagonal ridge added to the local block precision when its LU pivot is
/// below [`LU_PIVOT_SINGULAR_TOL`]. Matches the legacy `eps = 1e-6`
/// regularisation in the prior `det_small < 1e-12` branch — bumping the
/// determinant of `I − W A` (or `I − A W`) safely off zero without
/// perturbing well-conditioned blocks.
const ALO_LOCAL_BLOCK_RIDGE: f64 = 1e-6;

/// Pivot magnitude below which [`lu_factor_in_place`] reports the block
/// `I − W A` as singular and triggers the ridge-regularised refactor.
/// Equivalent to the original `det_small < 1e-12` test on the unfactored
/// determinant.
const LU_PIVOT_SINGULAR_TOL: f64 = 1e-12;
361
/// Nearest-rank index of `quantile` in a sorted sample of `sample_size` values.
///
/// Samples of size 0 or 1 always map to index 0. Otherwise the index is
/// `round(quantile · (sample_size − 1))`, clamped to the last position.
/// Negative or NaN products saturate to 0 through the `as usize` cast.
#[inline]
fn percentile_index(sample_size: usize, quantile: f64) -> usize {
    match sample_size.checked_sub(1) {
        None | Some(0) => 0,
        Some(last) => ((quantile * last as f64).round() as usize).min(last),
    }
}

/// Value at the nearest-rank `quantile` of an ascending-sorted slice, or
/// `0.0` when the slice is empty.
#[inline]
fn percentile_from_sorted(sorted: &[f64], quantile: f64) -> f64 {
    sorted
        .get(percentile_index(sorted.len(), quantile))
        .copied()
        .unwrap_or(0.0)
}
379
380#[inline]
381fn multiblock_col_offsets(block_designs: &[Array2<f64>]) -> Vec<usize> {
382    let mut offsets = Vec::with_capacity(block_designs.len());
383    let mut off = 0usize;
384    for design in block_designs {
385        offsets.push(off);
386        off += design.ncols();
387    }
388    offsets
389}
390
391#[inline]
392fn multiblock_alo_parallel_leverage_chunk_size(
393    p_tot: usize,
394    n_blocks: usize,
395    n_obs: usize,
396    max_workers: usize,
397) -> usize {
398    if p_tot == 0 || n_blocks == 0 || n_obs == 0 {
399        return 1;
400    }
401
402    // Each parallel leverage chunk owns q_storage for all block RHS products
403    // (B * p_tot * chunk_len) plus one transposed design chunk across all
404    // blocks (p_tot * chunk_len).  Divide the global scratch budget by the
405    // maximum number of chunks Rayon can execute concurrently so total live
406    // per-chunk scratch remains bounded.
407    let workers = max_workers.max(1);
408    let per_worker_budget = (MULTIBLOCK_ALO_MEMORY_BUDGET_BYTES / workers).max(1);
409    let elem_count_per_obs = p_tot.saturating_mul(n_blocks.saturating_add(1)).max(1);
410    let bytes_per_obs = elem_count_per_obs
411        .saturating_mul(std::mem::size_of::<f64>())
412        .max(1);
413    let budget_obs = (per_worker_budget / bytes_per_obs).max(1);
414    budget_obs.min(n_obs)
415}
416
417fn compute_alo_diagnostics_from_pirls_impl(
418    base: &pirls::PirlsResult,
419    y: ArrayView1<f64>,
420    link: LinkFunction,
421) -> Result<AloDiagnostics, EstimationError> {
422    compute_alo_diagnostics_from_pirls_inner(base, y, link).map_err(EstimationError::from)
423}
424
425/// True when the fitted GLM uses a *curved* canonical link, so that the row NLL
426/// score and curvature satisfy `ℓ_i'(η) = c_i(μ(η)−y_i)` and `ℓ_i''(η) = c_i μ'(η)`
427/// with a single per-row scale `c_i = (prior weight)/φ`. This is the exact
428/// condition under which the frozen-curvature ALO scalar fixed point matches
429/// the leave-`i`-out refit; only these families enable the exact refinement.
430///
431/// Gaussian identity is canonical too, but its per-row curvature is *constant*
432/// (`μ'(η) ≡ 1`), so the classical Sherman–Morrison one-step ALO is already the
433/// exact frozen-Hessian leave-`i`-out solution. Routing it through the scalar
434/// Newton closure would only add an O(n) nonlinear solve to diagnostics and
435/// quality sweeps without changing the answer, so it is excluded here and falls
436/// back to the (exact, for this family) one-step formula.
437fn alo_link_needs_exact_curvature_refinement(likelihood: &gam_problem::GlmLikelihoodSpec) -> bool {
438    use gam_problem::ResponseFamily;
439    matches!(
440        (&likelihood.spec.response, likelihood.link_function()),
441        (ResponseFamily::Binomial, LinkFunction::Logit)
442            | (ResponseFamily::Poisson, LinkFunction::Log)
443    )
444}
445
/// Typed core of [`compute_alo_diagnostics_from_pirls_impl`]. Builds an
/// [`AloInput`] from a converged PIRLS fit and runs the ALO pipeline.
///
/// Steps, in order:
/// 1. Materialize the dense transformed design. Fails with
///    `AloError::DesignDegenerate` if only a non-dense form is available.
/// 2. Choose the dispersion `phi`. It is `1.0` for every non-identity link.
///    For the identity link it is the weighted residual sum of squares divided
///    by `max(n_pos - edf, 1)`, where `n_pos` counts rows with a positive
///    final weight.
/// 3. Materialize the dense stabilized penalized Hessian. Any failure becomes
///    `AloError::InvalidInput`.
/// 4. For Binomial/logit and Poisson/log fits only, build the per-row
///    score/curvature evaluator that enables the exact frozen-curvature
///    refinement. Other families use the one-step ALO formula.
/// 5. Delegate to `compute_alo_from_input_inner`, log leverage diagnostics,
///    and reject any result containing NaN.
///
/// # Errors
/// Propagates the errors from the steps above and from
/// `compute_alo_from_input_inner`. Returns
/// `AloError::InfluenceMatrixFailed { condition_number: f64::INFINITY }` when
/// any of `eta_tilde`, `se_bayes`, `se_sandwich`, or `leverage` contains NaN.
fn compute_alo_diagnostics_from_pirls_inner(
    base: &pirls::PirlsResult,
    y: ArrayView1<f64>,
    link: LinkFunction,
) -> Result<AloDiagnostics, AloError> {
    let x_dense_arc = base
        .x_transformed
        .try_to_dense_arc("ALO diagnostics require dense transformed design")
        .map_err(|reason| AloError::DesignDegenerate { reason })?;
    let x_dense = x_dense_arc.as_ref();
    let n = x_dense.nrows();

    // Compute dispersion parameter.
    // NOTE(review): every non-identity link is treated as fixed-dispersion
    // (phi = 1). That would be wrong for e.g. a Gamma/log fit, if such a
    // combination can reach this path — confirm.
    let phi = match link {
        LinkFunction::Log => 1.0,
        LinkFunction::Logit
        | LinkFunction::Probit
        | LinkFunction::CLogLog
        | LinkFunction::LogLog
        | LinkFunction::Cauchit
        | LinkFunction::Sas
        | LinkFunction::BetaLogistic => 1.0,
        LinkFunction::Identity => {
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            // Weighted RSS, summed in parallel. NOTE(review): a parallel
            // floating-point reduction may not be bit-reproducible across
            // thread counts.
            let rss: f64 = (0..n)
                .into_par_iter()
                .map(|i| {
                    let r = y[i] - base.finalmu[i];
                    base.finalweights[i] * r * r
                })
                .sum();
            // Effective sample size for dispersion (#584): a zero prior weight
            // makes w_i·r_i² = 0, so the row is already excluded from the RSS
            // numerator and must be excluded from the denominator too. Count only
            // positive-weight rows, exactly as the main optimizer path does
            // (optimizer.rs ~1567); using the raw row count over a zero-excluding
            // numerator biases φ̂ low and shrinks every ALO SE.
            let n_pos = (0..n).filter(|&i| base.finalweights[i] > 0.0).count();
            let dof = (n_pos as f64) - base.edf;
            // Floor at 1 so a fit with edf >= n_pos cannot divide by zero or
            // produce a negative dispersion.
            let denom = dof.max(1.0);
            rss / denom
        }
    };

    let e = &base.reparam_result.e_transformed;
    // A negative ridge is clamped to zero before being handed to ALO.
    let ridge = base.ridge_passport.laplacehessianridge().max(0.0);

    // ALO needs the exact penalized Hessian materialized densely for chunked
    // column solves via StableSolver.  The PIRLS export path validates the
    // matrix instead of falling back to a numerical Hessian approximation.
    let h_dense_for_alo = base
        .dense_stabilizedhessian_transformed(
            "ALO diagnostics require exact dense stabilized penalized Hessian",
        )
        .map_err(|e| match e {
            EstimationError::InvalidInput(reason) => AloError::InvalidInput { reason },
            other => AloError::InvalidInput {
                reason: format!("{other:?}"),
            },
        })?;

    // Exact frozen-curvature ALO refinement for canonical-link GLMs.
    //
    // For a canonical link the row NLL score and curvature are
    //   ℓ_i'(η)  = c_i · (μ(η) − y_i),     ℓ_i''(η) = c_i · μ'(η),
    // with c_i = (prior weight)/φ recovered from the converged geometry as
    // c_i = W_H[i] / μ'(η̂_i) (since W_H[i] = c_i μ'(η̂_i) at convergence).
    // Supplying this evaluator lets `compute_alo_from_input_inner` solve the
    // leave-i-out scalar fixed point η = η̂_i + a_ii ℓ_i'(η) exactly instead of
    // taking a single Newton step, removing the first-order linearization error
    // that dominates on small-n, strongly curved likelihoods (binomial logit).
    //
    // Restricted to canonical links because only there does the observed
    // curvature carried by the frozen Hessian (W_H) coincide with c_i μ'(η) for
    // every trial η; non-canonical links retain the classical one-step ALO.
    // Per-row scale c_i = W_H[i]/μ'(η̂_i). Rows whose μ'(η̂_i) is negligible
    // (saturated / near-separation) get c_i = NaN, which makes the exact solver
    // reject that row explicitly rather than substituting the classical one-step
    // ALO.
    let canonical_scale: Option<Array1<f64>> =
        if alo_link_needs_exact_curvature_refinement(&base.likelihood) {
            let mut c = Array1::<f64>::zeros(n);
            for i in 0..n {
                let dmu = base.solve_dmu_deta[i];
                let w_h = base.finalweights[i];
                c[i] = if dmu.abs() <= ALO_DENOMINATOR_MIN || !dmu.is_finite() || !w_h.is_finite() {
                    f64::NAN
                } else {
                    w_h / dmu
                };
            }
            Some(c)
        } else {
            None
        };

    // Evaluator (i, η) -> (c_i (μ(η) − y_i), c_i μ'(η)). If the inverse link
    // fails to evaluate, it yields (NaN, NaN). The exact scalar solver reports
    // that as a non-finite score/curvature instead of using the value.
    let inv_link_for_closure = base.likelihood.spec.link.clone();
    let score_curvature_closure = canonical_scale.as_ref().map(|scale| {
        move |i: usize, eta: f64| -> (f64, f64) {
            let (mu, dmu) = crate::mixture_link::inverse_link_mu_d1_for_inverse_link(
                &inv_link_for_closure,
                eta,
            )
            .unwrap_or((f64::NAN, f64::NAN));
            let c_i = scale[i];
            (c_i * (mu - y[i]), c_i * dmu)
        }
    });
    // Erase the concrete closure type into the trait object stored in AloInput.
    let score_curvature_ref: Option<&AloScalarScoreCurvature> = score_curvature_closure
        .as_ref()
        .map(|f| f as &AloScalarScoreCurvature);

    // Build model-agnostic AloInput from PIRLS geometry, then delegate.
    // An empty (zero-row) penalty root is passed as `None`; per the AloInput
    // contract, sandwich SE then equals Bayesian SE.
    let input = AloInput {
        design: x_dense,
        penalized_hessian: &h_dense_for_alo,
        hessian_weights: base.final_weights_signed(),
        score_weights: base.solve_weights_psd(),
        working_response: &base.solveworking_response,
        eta: &base.final_eta,
        offset: &base.final_offset,
        link,
        phi,
        penalty_root: if e.nrows() > 0 { Some(e) } else { None },
        ridge,
        score_curvature: score_curvature_ref,
    };

    let result = compute_alo_from_input_inner(&input)?;

    // PIRLS-specific post-hoc leverage diagnostics logging.
    log_leverage_diagnostics(&result.leverage, phi);

    // Final NaN guard with detailed error reporting. Only NaN is rejected
    // here; infinities pass this check.
    let has_nan_pred = result.eta_tilde.iter().any(|&x| x.is_nan());
    let has_nan_se_bayes = result.se_bayes.iter().any(|&x| x.is_nan());
    let has_nan_se_sandwich = result.se_sandwich.iter().any(|&x| x.is_nan());
    let has_nan_leverage = result.leverage.iter().any(|&x| x.is_nan());

    if has_nan_pred || has_nan_se_bayes || has_nan_se_sandwich || has_nan_leverage {
        log::error!("[GAM ALO] NaN values found in ALO diagnostics:");
        log::error!(
            "[GAM ALO] eta_tilde: {} NaN values",
            result.eta_tilde.iter().filter(|&&x| x.is_nan()).count()
        );
        log::error!(
            "[GAM ALO] se_bayes: {} NaN values",
            result.se_bayes.iter().filter(|&&x| x.is_nan()).count()
        );
        log::error!(
            "[GAM ALO] se_sandwich: {} NaN values",
            result.se_sandwich.iter().filter(|&&x| x.is_nan()).count()
        );
        log::error!(
            "[GAM ALO] leverage: {} NaN values",
            result.leverage.iter().filter(|&&x| x.is_nan()).count()
        );
        // No condition number is computed here; INFINITY marks the failure.
        return Err(AloError::InfluenceMatrixFailed {
            condition_number: f64::INFINITY,
        });
    }

    Ok(result)
}
610
611/// Log detailed leverage percentile diagnostics for a completed ALO computation.
612fn log_leverage_diagnostics(leverage: &Array1<f64>, phi: f64) {
613    let n = leverage.len();
614    if n == 0 {
615        return;
616    }
617
618    let mut invalid_count = 0usize;
619    let mut high_leverage_count = 0usize;
620    let mut threshold_counts = [0usize; LEVERAGE_RATE_THRESHOLDS.len()];
621    let mut finite_leverage = Vec::with_capacity(n);
622
623    for (obs, &ai) in leverage.iter().enumerate() {
624        if ai.is_finite() {
625            finite_leverage.push(ai);
626        }
627
628        if !(0.0..=1.0).contains(&ai) || !ai.is_finite() {
629            invalid_count += 1;
630            log::warn!("[GAM ALO] invalid leverage at i={}, a_ii={:.6e}", obs, ai);
631        } else if ai > LEVERAGE_HIGH_THRESHOLD {
632            high_leverage_count += 1;
633            if ai > LEVERAGE_VERY_HIGH_THRESHOLD {
634                log::warn!("[GAM ALO] very high leverage at i={}, a_ii={:.6e}", obs, ai);
635            }
636        }
637
638        for (idx, threshold) in LEVERAGE_RATE_THRESHOLDS.iter().enumerate() {
639            if ai > *threshold {
640                threshold_counts[idx] += 1;
641            }
642        }
643    }
644
645    if invalid_count > 0 || high_leverage_count > 0 {
646        log::warn!(
647            "[GAM ALO] leverage diagnostics: {} invalid values, {} high values (>0.99)",
648            invalid_count,
649            high_leverage_count
650        );
651    }
652
653    finite_leverage.sort_by(f64::total_cmp);
654
655    let finite_n = finite_leverage.len();
656    let a_mean = if finite_n > 0 {
657        finite_leverage.iter().copied().sum::<f64>() / finite_n as f64
658    } else {
659        0.0
660    };
661    let a_median = percentile_from_sorted(&finite_leverage, LEVERAGE_PERCENTILES[0]);
662    let a_p95 = percentile_from_sorted(&finite_leverage, LEVERAGE_PERCENTILES[1]);
663    let a_p99 = percentile_from_sorted(&finite_leverage, LEVERAGE_PERCENTILES[2]);
664    let a_max = finite_leverage.last().copied().unwrap_or(0.0);
665
666    // Routine per-ALO leverage summary: a diagnostic snapshot, not an
667    // anomaly. Emitted at `info!` so it is visible when the host raises
668    // verbosity (CLI `-v`; `gamfit.set_log_level("info")`) but silent at the
669    // default `Warn` level (genuine anomalies — invalid / very
670    // high leverage — are logged at `warn!` above and stay visible). This
671    // line fires once per ALO computation, which recurs across the outer
672    // smoothing loop, so at `warn!` it was a dominant source of stderr noise
673    // on perfectly healthy fits (#1689).
674    log::info!(
675        "[GAM ALO] leverage: n={}, mean={:.3e}, median={:.3e}, p95={:.3e}, p99={:.3e}, max={:.3e}",
676        n,
677        a_mean,
678        a_median,
679        a_p95,
680        a_p99,
681        a_max
682    );
683    log::info!(
684        "[GAM ALO] high-leverage: a>0.90: {:.2}%, a>0.95: {:.2}%, a>0.99: {:.2}%, dispersion phi={:.3e}",
685        100.0 * (threshold_counts[0] as f64) / n as f64,
686        100.0 * (threshold_counts[1] as f64) / n as f64,
687        100.0 * (threshold_counts[2] as f64) / n as f64,
688        phi
689    );
690}
691
/// Model-agnostic input for ALO diagnostics.
///
/// Any model with a design matrix, penalized Hessian, and IRLS geometry can
/// compute ALO leverages and leave-one-out predictions. This decouples ALO
/// from the single-block PIRLS solver and enables diagnostics for GAMLSS,
/// survival, and joint models.
///
/// Vector fields marked `(n)` are indexed by observation and must match the
/// design's row count. NOTE(review): these lengths are presumably validated
/// by the consumer (`compute_alo_from_input_inner`, which can raise
/// `AloError::InvalidInput` for dimension mismatches) — confirm.
pub struct AloInput<'a> {
    /// Dense design matrix X (n × p).
    pub design: &'a Array2<f64>,
    /// Penalized Hessian H = X'WX + S(λ) at convergence (p × p).
    pub penalized_hessian: &'a Array2<f64>,
    /// Hessian-side IRLS weights W_H at convergence (n). Sign-honest: for
    /// non-canonical links the observed-information diagonal can have negative
    /// entries, so the typed [`SignedWeightsView`] is the contract here. PSD
    /// callers needing to promote (e.g. the canonical-link case where the
    /// caller has discharged W_H ≥ 0 algebraically) can route through
    /// `SignedWeightsView::as_psd()` at the consumer.
    pub hessian_weights: SignedWeightsView<'a>,
    /// Score-side IRLS weights W_S paired with `working_response` (n).
    /// PSD-by-construction: the score-side Fisher weights `h'²/(φ V(μ)) ≥ 0`.
    pub score_weights: PsdWeightsView<'a>,
    /// IRLS working response at convergence (n).
    pub working_response: &'a Array1<f64>,
    /// Fitted linear predictor η̂ (n).
    pub eta: &'a Array1<f64>,
    /// Offset vector (n). Pass zeros if no offset.
    pub offset: &'a Array1<f64>,
    /// Link function (for phi determination).
    pub link: LinkFunction,
    /// Dispersion parameter φ. For non-Gaussian families this is 1.0. The
    /// PIRLS path estimates it from the weighted RSS only for the identity
    /// link and uses 1.0 for every other link.
    pub phi: f64,
    /// Optional penalty square root E with E^T E = S(λ) (rank × p) for sandwich SE.
    /// When `None`, sandwich SE is set equal to Bayesian SE.
    pub penalty_root: Option<&'a Array2<f64>>,
    /// Ridge added to the Hessian for logdet surface. The PIRLS path passes
    /// the Laplace-Hessian ridge clamped to be non-negative.
    pub ridge: f64,
    /// Optional per-row score/curvature evaluator `(i, η) → (ℓ_i'(η), ℓ_i''(η))`.
    ///
    /// When supplied, the leave-`i`-out predictor is obtained by solving the
    /// frozen-curvature scalar fixed point `η = η̂_i + a_ii ℓ_i'(η)` to
    /// convergence (see [`alo_eta_exact_frozen_curvature`]) instead of taking a
    /// single Newton step. This eliminates the first-order linearization error
    /// that the one-step ALO incurs on small-`n`, strongly curved likelihoods
    /// (e.g. binomial logistic regression). Non-convergence or invalid scalar
    /// Newton geometry is returned as an ALO error. When `None`, the classical
    /// single-Newton-step ALO formula is used. The evaluator must be consistent
    /// with `hessian_weights` at convergence: `ℓ_i''(η̂_i) = W_H[i]` and
    /// `ℓ_i'(η̂_i) = W_S[i]·((η̂_i−o_i) − (z_i−o_i))`.
    pub score_curvature: Option<&'a AloScalarScoreCurvature<'a>>,
}
742
743impl<'a> AloInput<'a> {
744    /// Build an `AloInput` from `FitGeometry` and associated vectors.
745    pub fn from_geometry(
746        geom: &'a FitGeometry,
747        design: &'a Array2<f64>,
748        eta: &'a Array1<f64>,
749        offset: &'a Array1<f64>,
750        link: LinkFunction,
751        phi: f64,
752    ) -> Self {
753        // FitGeometry stores one working-weight vector, so this constructor is
754        // exact only when the score- and Hessian-side IRLS weights coincide
755        // (canonical-link case where Fisher == Observed). In that path the
756        // diagonal is the Fisher weight `h'²/(φ V(μ)) ≥ 0`, so the PSD
757        // obligation is discharged algebraically without a runtime scan;
758        // `as_signed()` re-views the same buffer for the Hessian-side slot.
759        let psd_w = PsdWeightsView::from_view_unchecked(geom.working_weights.view());
760        Self {
761            design,
762            penalized_hessian: &geom.penalized_hessian,
763            hessian_weights: psd_w.as_signed(),
764            score_weights: psd_w,
765            working_response: &geom.working_response,
766            eta,
767            offset,
768            link,
769            phi,
770            penalty_root: None,
771            ridge: 0.0,
772            score_curvature: None,
773        }
774    }
775
776    /// Build an `AloInput` from a `FitGeometry`'s penalized Hessian plus
777    /// externally supplied working weights / working response.
778    ///
779    /// The row-sized IRLS working vectors are *derived* quantities: at
780    /// convergence they are deterministic functions of the linear predictor
781    /// `η̂ = Xβ̂`, the response `y`, and the family (`w_i = h'(η̂_i)²/(φ V(μ̂_i))·
782    /// prior_i`, `z_i = η̂_i + (y_i−μ̂_i)/h'(η̂_i)`). A size-compacted saved model
783    /// keeps the p×p `penalized_hessian` (n-independent) but drops those n-sized
784    /// vectors; a post-fit consumer such as `gam diagnose` reconstructs them from
785    /// the saved `β` by replaying the same PIRLS working-state update the fit
786    /// used, then feeds them here. This preserves the size win of dropping the
787    /// working vectors from persistence while still serving the exact geometry
788    /// ALO path (no refit, exact saved Hessian).
789    ///
790    /// Same canonical (Fisher == Observed) contract as [`from_geometry`]: the
791    /// supplied `working_weights` are the score-side Fisher weights and are
792    /// re-viewed for the Hessian-side slot via `as_signed()`.
793    ///
794    /// [`from_geometry`]: AloInput::from_geometry
795    pub fn from_geometry_with_working_state(
796        geom: &'a FitGeometry,
797        design: &'a Array2<f64>,
798        eta: &'a Array1<f64>,
799        offset: &'a Array1<f64>,
800        link: LinkFunction,
801        phi: f64,
802        working_weights: &'a Array1<f64>,
803        working_response: &'a Array1<f64>,
804    ) -> Self {
805        let psd_w = PsdWeightsView::from_view_unchecked(working_weights.view());
806        Self {
807            design,
808            penalized_hessian: &geom.penalized_hessian,
809            hessian_weights: psd_w.as_signed(),
810            score_weights: psd_w,
811            working_response,
812            eta,
813            offset,
814            link,
815            phi,
816            penalty_root: None,
817            ridge: 0.0,
818            score_curvature: None,
819        }
820    }
821}
822
823/// Compute ALO diagnostics from model-agnostic inputs.
824///
825/// This is the generalized entry point that works for any model type.
826/// For standard single-block GAMs, prefer `compute_alo_diagnostics_from_fit`
827/// which automatically extracts the PIRLS geometry (including sandwich SE).
828pub fn compute_alo_from_input(input: &AloInput) -> Result<AloDiagnostics, EstimationError> {
829    compute_alo_from_input_inner(input).map_err(EstimationError::from)
830}
831
832fn compute_alo_from_input_inner(input: &AloInput) -> Result<AloDiagnostics, AloError> {
833    let x_dense = input.design;
834    let n = x_dense.nrows();
835    let p = x_dense.ncols();
836    // Bind the underlying ArrayView1 once so the loop body can index and
837    // borrow as before; the sign-character contract lives in the
838    // `AloInput` field types, not in this local binding.
839    let w_h = input.hessian_weights.view();
840    let w_s = input.score_weights.view();
841
842    validate_alo_solve_setup(input, n, p)?;
843
844    let factor = StableSolver::new("alo penalized hessian")
845        .factorize(input.penalized_hessian)
846        .map_err(|_| AloError::InfluenceMatrixFailed {
847            condition_number: f64::INFINITY,
848        })?;
849
850    let xt = x_dense.t();
851    let phi = input.phi;
852
853    let mut aii = Array1::<f64>::zeros(n);
854    let mut x_hinv_x_diag = Array1::<f64>::zeros(n);
855    let mut se_bayes = Array1::<f64>::zeros(n);
856    let mut se_sandwich = Array1::<f64>::zeros(n);
857
858    let block_cols = ALO_RHS_BLOCK_COLS;
859    // Allocate the RHS scratch in column-major (Fortran) order so its column
860    // slices are contiguous and align with faer's column-major solve output.
861    // This removes redundant `xrow = x_dense.row(obs)` indirection inside the
862    // per-observation loop: rhs_chunk_buf already holds X^T at the right cols.
863    let mut rhs_chunk_buf = Array2::<f64>::zeros((p, block_cols).f());
864    // Reusable faer column-major buffer for X*S, where S = H^{-1}X_i for the
865    // current RHS chunk.  The sandwich SE must use the same frozen-curvature
866    // meat as the exact LOO reference, `X' W X`, directly; reconstructing it as
867    // `H - S_penalty - ridge*I` is brittle because the exported stabilized
868    // Hessian may include curvature/stabilization details that are not exactly
869    // represented by the penalty root plus public ridge scalar.
870    let mut xs_chunk_storage = FaerMat::<f64>::zeros(n, block_cols);
871    let x_dense_view = FaerArrayView::new(x_dense);
872
873    for chunk_start in (0..n).step_by(block_cols) {
874        let chunk_end = (chunk_start + block_cols).min(n);
875        let width = chunk_end - chunk_start;
876
877        rhs_chunk_buf
878            .slice_mut(s![.., ..width])
879            .assign(&xt.slice(s![.., chunk_start..chunk_end]));
880
881        let rhs_chunkview = rhs_chunk_buf.slice(s![.., ..width]);
882        let rhs_chunk = FaerArrayView::new(&rhs_chunkview);
883        // s_chunk is owned column-major faer storage; its column slices are
884        // contiguous and can be read directly via `col_as_slice` — no need to
885        // materialize a parallel ndarray copy.
886        let s_chunk = factor.solve(rhs_chunk.as_ref());
887
888        let mut xs_target = xs_chunk_storage.as_mut().subcols_mut(0, width);
889        matmul(
890            xs_target.rb_mut(),
891            Accum::Replace,
892            x_dense_view.as_ref(),
893            s_chunk.as_ref(),
894            1.0,
895            Par::Seq,
896        );
897
898        let rhs_view = rhs_chunk_buf.slice(s![.., ..width]);
899
900        for local_col in 0..width {
901            let obs = chunk_start + local_col;
902            // rhs is column-major Fortran ndarray; faer Mat columns are
903            // contiguous by construction. Both accesses borrow the existing
904            // storage directly — no per-column copy.
905            let rhs_col = rhs_view.column(local_col);
906            let rhs_slice = rhs_col.as_slice().expect("column-major col contiguous");
907            let s_slice = s_chunk.col_as_slice(local_col);
908
909            let mut x_hinv_x = 0.0f64;
910            // Fused dot product over the current solve column.
911            for k in 0..p {
912                let sval = s_slice[k];
913                let xval = rhs_slice[k];
914                x_hinv_x = sval.mul_add(xval, x_hinv_x);
915            }
916            let ai = w_h[obs].max(0.0) * x_hinv_x;
917            aii[obs] = ai;
918            x_hinv_x_diag[obs] = x_hinv_x;
919
920            let var_bayes = bayesvar_eta(phi, x_hinv_x);
921            let xs_slice = xs_chunk_storage.col_as_slice(local_col);
922            let mut meat_quad = 0.0f64;
923            for row in 0..n {
924                let xs = xs_slice[row];
925                // Sandwich meat is the SCORE covariance Xᵀ diag(W_S) X (Fisher,
926                // PSD by construction), not the observed-information Hessian
927                // weight W_H: the estimator is Var = H⁻¹·Cov(score)·H⁻¹ with the
928                // bread H = Xᵀ W_H X + S. For non-canonical links W_H ≠ W_S (and
929                // W_H can be negative), so using W_H here gives a wrong — even
930                // negative — sandwich SE. See `AloInput::score_weights`.
931                meat_quad += w_s[row] * xs * xs;
932            }
933            let var_sandwich = sandwichvar_eta_from_meat(phi, meat_quad);
934
935            if !var_bayes.is_finite() || !var_sandwich.is_finite() {
936                return Err(AloError::LooComputationFailed {
937                    reason: format!(
938                        "ALO variance is not finite at row {obs}: bayes={var_bayes:.6e}, sandwich={var_sandwich:.6e}"
939                    ),
940                });
941            }
942            let bayes_tol = variance_negative_tolerance(phi * x_hinv_x.abs());
943            if var_bayes < -bayes_tol {
944                return Err(AloError::LooComputationFailed {
945                    reason: format!(
946                        "ALO Bayesian variance is materially negative at row {obs}: var={var_bayes:.6e}, tol={bayes_tol:.6e}"
947                    ),
948                });
949            }
950            let sandwich_scale = phi * meat_quad.abs().max(x_hinv_x.abs());
951            let sandwich_tol = variance_negative_tolerance(sandwich_scale);
952            if var_sandwich < -sandwich_tol {
953                return Err(AloError::LooComputationFailed {
954                    reason: format!(
955                        "ALO sandwich variance is materially negative at row {obs}: var={var_sandwich:.6e}, tol={sandwich_tol:.6e}"
956                    ),
957                });
958            }
959
960            se_bayes[obs] = var_bayes.max(0.0).sqrt();
961            se_sandwich[obs] = var_sandwich.max(0.0).sqrt();
962        }
963    }
964
965    let eta_hat = input.eta;
966    let z = input.working_response;
967    let offset = input.offset;
968
969    use rayon::prelude::*;
970    let eta_tilde_vec: Vec<f64> = (0..n)
971        .into_par_iter()
972        .map(|i| {
973            let denom_raw = 1.0 - aii[i];
974            if denom_raw <= ALO_DENOMINATOR_MIN || !denom_raw.is_finite() {
975                return Err(AloError::LooComputationFailed {
976                    reason: format!(
977                        "ALO denominator is too small at row {i}: a_ii={:.6e}, 1-a_ii={:.6e}, min={:.1e}",
978                        aii[i], denom_raw, ALO_DENOMINATOR_MIN
979                    ),
980                });
981            }
982            let one_step = alo_eta_updatewith_offset(
983                eta_hat[i],
984                z[i],
985                offset[i],
986                x_hinv_x_diag[i],
987                w_s[i],
988                denom_raw,
989            );
990            // When the family score/curvature evaluator is supplied, solve the
991            // exact frozen-curvature leave-i-out fixed point (anchored at η̂_i,
992            // the basin that limits to the in-sample fit) instead of taking the
993            // single Newton step. a_ii here is the unweighted influence
994            // x_i^T H^{-1} x_i (= x_hinv_x_diag[i]); the per-row curvature
995            // W_H[i] = ℓ_i''(η̂_i) is folded into the scalar fixed point via
996            // score_curvature. Non-canonical links fall back to `one_step`.
997            let v = if let Some(score_curvature) = input.score_curvature {
998                alo_eta_exact_frozen_curvature(
999                    eta_hat[i],
1000                    x_hinv_x_diag[i],
1001                    &|eta| score_curvature(i, eta),
1002                )
1003                .map_err(|err| AloError::LooComputationFailed {
1004                    reason: format!(
1005                        "ALO exact frozen-curvature solve failed at row {i}: {err}"
1006                    ),
1007                })?
1008            } else {
1009                one_step
1010            };
1011            if !v.is_finite() {
1012                return Err(AloError::LooComputationFailed {
1013                    reason: format!("ALO eta_tilde is not finite at row {i}: eta_tilde={v}"),
1014                });
1015            }
1016            Ok(v)
1017        })
1018        .collect::<Result<_, _>>()?;
1019    let eta_tilde = Array1::from(eta_tilde_vec);
1020
1021    Ok(AloDiagnostics {
1022        eta_tilde,
1023        se_bayes,
1024        se_sandwich,
1025        pred_identity: eta_hat.clone(),
1026        leverage: aii,
1027        fisherweights: w_h.to_owned(),
1028    })
1029}
1030
1031fn validate_alo_solve_setup(input: &AloInput, n: usize, p: usize) -> Result<(), AloError> {
1032    let h = input.penalized_hessian;
1033    if h.nrows() != p || h.ncols() != p {
1034        return Err(AloError::InvalidInput {
1035            reason: format!(
1036                "ALO diagnostics require a dense exact penalized Hessian with shape {p}x{p}; got {}x{}",
1037                h.nrows(),
1038                h.ncols()
1039            ),
1040        });
1041    }
1042    if h.iter().any(|v| !v.is_finite()) {
1043        return Err(AloError::InvalidInput {
1044            reason: "ALO diagnostics require a finite dense exact penalized Hessian".to_string(),
1045        });
1046    }
1047    for i in 0..p {
1048        for j in 0..i {
1049            let a = h[[i, j]];
1050            let b = h[[j, i]];
1051            let scale = a.abs().max(b.abs()).max(1.0);
1052            if (a - b).abs() > HESSIAN_SYMMETRY_REL_TOL * scale {
1053                return Err(AloError::InvalidInput {
1054                    reason: format!(
1055                        "ALO diagnostics require a symmetric dense exact penalized Hessian; entries ({i},{j}) and ({j},{i}) differ by {:.3e}",
1056                        (a - b).abs()
1057                    ),
1058                });
1059            }
1060        }
1061    }
1062
1063    let vector_lengths = [
1064        ("hessian_weights", input.hessian_weights.len()),
1065        ("score_weights", input.score_weights.len()),
1066        ("working_response", input.working_response.len()),
1067        ("eta", input.eta.len()),
1068        ("offset", input.offset.len()),
1069    ];
1070    for (name, len) in vector_lengths {
1071        if len != n {
1072            return Err(AloError::InvalidInput {
1073                reason: format!("ALO diagnostics require {name} length {n}; got {len}"),
1074            });
1075        }
1076    }
1077    if input.hessian_weights.view().iter().any(|v| !v.is_finite()) {
1078        return Err(AloError::WeightInvalid {
1079            reason: "ALO diagnostics require finite Hessian-side weights".to_string(),
1080        });
1081    }
1082    if input.score_weights.view().iter().any(|v| !v.is_finite()) {
1083        return Err(AloError::WeightInvalid {
1084            reason: "ALO diagnostics require finite score-side weights".to_string(),
1085        });
1086    }
1087    if input.working_response.iter().any(|v| !v.is_finite()) {
1088        return Err(AloError::WeightInvalid {
1089            reason: "ALO diagnostics require finite working responses".to_string(),
1090        });
1091    }
1092    if input.eta.iter().any(|v| !v.is_finite()) || input.offset.iter().any(|v| !v.is_finite()) {
1093        return Err(AloError::InvalidInput {
1094            reason: "ALO diagnostics require finite linear predictors and offsets".to_string(),
1095        });
1096    }
1097    if !input.phi.is_finite() || input.phi <= 0.0 {
1098        return Err(AloError::InvalidInput {
1099            reason: format!(
1100                "ALO diagnostics require positive finite dispersion phi; got {}",
1101                input.phi
1102            ),
1103        });
1104    }
1105    if !input.ridge.is_finite() || input.ridge < 0.0 {
1106        return Err(AloError::InvalidInput {
1107            reason: format!(
1108                "ALO diagnostics require a finite non-negative Hessian ridge; got {}",
1109                input.ridge
1110            ),
1111        });
1112    }
1113    if let Some(e) = input.penalty_root {
1114        if e.ncols() != p {
1115            return Err(AloError::InvalidInput {
1116                reason: format!(
1117                    "ALO diagnostics require penalty root to have {p} columns; got {}",
1118                    e.ncols()
1119                ),
1120            });
1121        }
1122        if e.iter().any(|v| !v.is_finite()) {
1123            return Err(AloError::InvalidInput {
1124                reason: "ALO diagnostics require finite penalty-root entries".to_string(),
1125            });
1126        }
1127    }
1128    Ok(())
1129}
1130
1131/// Compute ALO diagnostics (eta_tilde, SE, leverage) from a fitted GAM result.
1132pub fn compute_alo_diagnostics_from_fit(
1133    fit: &UnifiedFitResult,
1134    y: ArrayView1<f64>,
1135    link: LinkFunction,
1136) -> Result<AloDiagnostics, EstimationError> {
1137    let pirls = fit
1138        .artifacts
1139        .pirls
1140        .as_ref()
1141        .ok_or_else(|| AloError::InvalidInput {
1142            reason:
1143                "ALO diagnostics require a PIRLS-backed fit; this fit does not expose PIRLS geometry"
1144                    .to_string(),
1145        })
1146        .map_err(EstimationError::from)?;
1147    compute_alo_diagnostics_from_pirls_impl(pirls, y, link)
1148}
1149
1150/// Compute ALO diagnostics from a `UnifiedFitResult`.
1151///
1152/// Extracts `FitGeometry` from `unified.geometry`, builds an `AloInput`
1153/// via `from_geometry`, and delegates to `compute_alo_from_input`.
1154/// This avoids requiring a full `UnifiedFitResult` with PIRLS artifacts.
1155pub fn compute_alo_diagnostics_from_unified(
1156    unified: &UnifiedFitResult,
1157    design: &Array2<f64>,
1158    eta: &Array1<f64>,
1159    offset: &Array1<f64>,
1160    link: LinkFunction,
1161    phi: f64,
1162) -> Result<AloDiagnostics, EstimationError> {
1163    let geom = unified
1164        .geometry
1165        .as_ref()
1166        .ok_or_else(|| AloError::InvalidInput {
1167            reason: "UnifiedFitResult does not contain working-set geometry; \
1168             ALO diagnostics require geometry at convergence"
1169                .to_string(),
1170        })
1171        .map_err(EstimationError::from)?;
1172    let input = AloInput::from_geometry(geom, design, eta, offset, link, phi);
1173    compute_alo_from_input(&input)
1174}
1175
1176/// Compute ALO diagnostics from a PIRLS result for lower-level callers.
1177pub fn compute_alo_diagnostics_from_pirls(
1178    base: &pirls::PirlsResult,
1179    y: ArrayView1<f64>,
1180    link: LinkFunction,
1181) -> Result<AloDiagnostics, EstimationError> {
1182    compute_alo_diagnostics_from_pirls_impl(base, y, link)
1183}
1184
1185/// Exact (one-step) case-deletion influence from a converged PIRLS fit, via
1186/// the one `FitSensitivity` operator (#935).
1187///
1188/// This is the diagnostic the sensitivity operator's `case_deletion` channel
1189/// was built to expose but had no production entry point for: per-observation
1190/// dfbetas `β̂ − β̂₍ᵢ₎`, hat-value leverage `h_ii = w_i x_iᵀ H⁻¹ x_i`, and
1191/// Cook's distance. It is the same factored inverse the REML gradient (IFT),
1192/// ALO, and the Riesz debias already contract — built once at the optimum,
1193/// asked in the leave-one-out direction — so no call site can disagree about
1194/// which `H⁻¹` is meant (the bug class #935 dismantles).
1195///
1196/// The penalized Hessian, design, working weights `w_i = W_H[i]` and working
1197/// residual `z_i − η̂_i` are read straight from the converged geometry — the
1198/// same PIRLS state [`compute_alo_diagnostics_from_pirls`] consumes — so the
1199/// IRLS reduction `scale = w_i r_i / (1 − h_ii)` is exact for the Gaussian
1200/// identity link and the one-step Newton deletion for canonical-link GLMs.
1201/// Returns `None` (rather than emitting `∞`) for any observation whose
1202/// leverage is one, or if the dense Hessian / design is unavailable.
1203pub fn compute_case_deletion_from_pirls(
1204    base: &pirls::PirlsResult,
1205    y: ArrayView1<f64>,
1206    link: LinkFunction,
1207) -> Result<Option<crate::sensitivity::CaseDeletionInfluence>, EstimationError> {
1208    let x_dense_arc = base
1209        .x_transformed
1210        .try_to_dense_arc("case-deletion diagnostics require dense transformed design")
1211        .map_err(|reason| EstimationError::InvalidInput(reason))?;
1212    let x_dense = x_dense_arc.as_ref();
1213    let n = x_dense.nrows();
1214    let p = x_dense.ncols();
1215    if n == 0 || p == 0 {
1216        return Ok(None);
1217    }
1218
1219    // Dispersion φ matches the ALO entry point: estimated RSS/(n−edf) for the
1220    // Gaussian identity link, fixed at 1 for the single-parameter families.
1221    let phi = match link {
1222        LinkFunction::Identity => {
1223            use rayon::iter::{IntoParallelIterator, ParallelIterator};
1224            let rss: f64 = (0..n)
1225                .into_par_iter()
1226                .map(|i| {
1227                    let r = y[i] - base.finalmu[i];
1228                    base.finalweights[i] * r * r
1229                })
1230                .sum();
1231            let dof = (n as f64) - base.edf;
1232            rss / dof.max(1.0)
1233        }
1234        _ => 1.0,
1235    };
1236    if !(phi.is_finite() && phi > 0.0) {
1237        return Ok(None);
1238    }
1239
1240    // The same dense stabilized penalized Hessian ALO materializes; the one
1241    // factored inverse every sensitivity channel shares.
1242    let h_dense = base
1243        .dense_stabilizedhessian_transformed(
1244            "case-deletion diagnostics require exact dense stabilized penalized Hessian",
1245        )
1246        .map_err(|e| match e {
1247            EstimationError::InvalidInput(reason) => EstimationError::InvalidInput(reason),
1248            other => EstimationError::InvalidInput(format!("{other:?}")),
1249        })?;
1250
1251    let factor = match h_dense.cholesky(faer::Side::Lower) {
1252        Ok(f) => f,
1253        // A non-SPD stabilized Hessian means the optimum is rank-deficient in a
1254        // way the dense Cholesky case-deletion path cannot invert; decline
1255        // rather than fabricate an influence diagnostic.
1256        Err(_) => return Ok(None),
1257    };
1258
1259    // Working weights and working residual straight from the IRLS reduction:
1260    // w_i = W_H[i] and r_i = z_i − η̂_i, so w_i r_i is the working score the
1261    // closed-form deletion `scale = w_i r_i / (1 − h_ii)` consumes.
1262    let working_weights = base.finalweights.clone();
1263    let working_residual = &base.solveworking_response - &base.final_eta;
1264
1265    let sensitivity = crate::sensitivity::FitSensitivity::from_faer_cholesky(&factor, p);
1266    Ok(sensitivity.case_deletion(
1267        x_dense,
1268        working_weights.view(),
1269        working_residual.view(),
1270        phi,
1271    ))
1272}
1273
1274// Multi-block ALO for multi-predictor models (GAMLSS, survival, joint)
1275
1276/// Diagnostics returned by multi-block ALO.
1277#[derive(Debug, Clone)]
1278pub struct MultiBlockAloDiagnostics {
1279    /// Corrected linear predictors η̃^{(-i)} for each observation.
1280    /// Outer length = n_obs, inner length = n_blocks (B).
1281    pub eta_tilde: Vec<Array1<f64>>,
1282    /// Per-observation leverage tr(H_ii) where H_ii is the B×B hat-matrix block.
1283    pub leverage: Array1<f64>,
1284    /// Per-observation ALO variance diagonals: for each observation i,
1285    /// Var(Δη_i) ≈ A_i (I - W_i A_i)⁻¹ W_i (I - A_i W_i)⁻¹ A_iᵀ.
1286    /// Outer length = n_obs, inner length = n_blocks (B) containing the
1287    /// diagonal entries of the variance matrix.
1288    pub alo_variance: Vec<Array1<f64>>,
1289    /// Cook-type ALO influence: D_i = Δη_iᵀ W_i Δη_i.
1290    /// Length = n_obs.
1291    pub cook_distance: Array1<f64>,
1292}
1293
1294/// Model-agnostic input for multi-predictor ALO diagnostics.
1295///
1296/// Generalises [`AloInput`] to models with B > 1 linear predictors per
1297/// observation (e.g. location-scale GAMLSS with B=2, or survival models
1298/// with time-dependent predictors).
1299///
1300/// # Mathematical setup
1301///
1302/// For observation i the per-observation Jacobian is a B × p_tot block matrix
1303/// X_i whose b-th row is the i-th row of `block_designs[b]`.  The joint
1304/// hat-matrix block is
1305///
1306///   H_ii = X_i H⁻¹ X_iᵀ W_i     (B × B)
1307///
1308/// where H = Σ_i X_iᵀ W_i X_i + S is the total penalized Hessian and W_i
1309/// is the B × B per-observation weight matrix (negative Hessian of the
1310/// log-likelihood w.r.t. the B predictors at observation i).
1311///
1312/// The ALO leave-one-out correction is
1313///
1314///   Δη_i^ALO = A_i (I_B − W_i A_i)⁻¹ s_i
1315///
1316/// where A_i = X_i H⁻¹ X_iᵀ (the B×B per-observation influence matrix),
1317/// W_i is the B×B per-observation NLL Hessian, and
1318/// s_i = ∇_{η_i} NLL_i(η̂_i) is the B-dimensional score vector.
1319/// This is algebraically equivalent to (I_B − H_ii)⁻¹ H_ii W_i⁻¹ s_i
1320/// but does NOT require W_i⁻¹, which is critical when W_i is singular
1321/// (e.g. at boundary observations in survival models).
1322/// For B = 1 this reduces to the classical scalar ALO formula.
1323pub struct MultiBlockAloInput<'a> {
1324    /// Number of observations.
1325    pub n_obs: usize,
1326    /// Number of predictors per observation (B).
1327    pub n_blocks: usize,
1328    /// B design matrices, each n_obs × p_b.  The total parameter count is
1329    /// p_tot = Σ_b p_b.
1330    pub block_designs: &'a [Array2<f64>],
1331    /// Inverse of the penalized Hessian, H⁻¹ (p_tot × p_tot).
1332    pub penalized_hessian_inv: &'a Array2<f64>,
1333    /// Per-observation weight matrices W_i (B × B).  Length = n_obs.
1334    pub block_weights: Vec<Array2<f64>>,
1335    /// Per-observation score vectors s_i = ∇_{η_i} NLL_i.  Length = n_obs,
1336    /// each entry is B-dimensional.
1337    pub scores: Vec<Array1<f64>>,
1338    /// Fitted linear predictor vectors η̂_i.  Length = n_obs, each entry is
1339    /// B-dimensional.
1340    pub eta_hat: Vec<Array1<f64>>,
1341}
1342
1343/// Compute multi-block ALO diagnostics: corrected η̃ and leverages.
1344///
1345/// # Optimisation note
1346///
1347/// The dominant cost is forming X_i H⁻¹ X_iᵀ for every observation.
1348/// Rather than forming the B × p_tot row-block X_i and multiplying naïvely,
1349/// we precompute for each block b the matrix
1350///
1351///   Q_b = H⁻¹ X_bᵀ      (p_tot × n)
1352///
1353/// Then the (a, b) entry of the B × B matrix X_i H⁻¹ X_iᵀ is simply
1354///
1355///   (X_i H⁻¹ X_iᵀ)_{a,b} = x_{a,i}ᵀ Q_b[:,i]
1356///                           = Σ_k  X_a[i,k] · Q_b[k,i]
1357///
1358/// where x_{a,i} is the i-th row of block-design a.  This turns the per-
1359/// observation work from O(B · p_tot²) into O(B² · p_tot), and the
1360/// precomputation is O(B · p_tot² · n) total via a single blocked solve.
1361pub fn compute_multiblock_alo(
1362    input: &MultiBlockAloInput,
1363) -> Result<MultiBlockAloDiagnostics, EstimationError> {
1364    compute_multiblock_alo_inner(input).map_err(EstimationError::from)
1365}
1366
1367fn compute_multiblock_alo_inner(
1368    input: &MultiBlockAloInput,
1369) -> Result<MultiBlockAloDiagnostics, AloError> {
1370    use rayon::prelude::*;
1371
1372    let n = input.n_obs;
1373    let b = input.n_blocks;
1374    let p_tot = input.penalized_hessian_inv.nrows();
1375
1376    // --- Validate dimensions ---
1377    if input.block_designs.len() != b {
1378        return Err(AloError::InvalidInput {
1379            reason: format!(
1380                "MultiBlockAloInput: expected {} block designs, got {}",
1381                b,
1382                input.block_designs.len()
1383            ),
1384        });
1385    }
1386
1387    // Verify total column count matches p_tot.
1388    let col_sum: usize = input.block_designs.iter().map(|d| d.ncols()).sum();
1389    if col_sum != p_tot {
1390        return Err(AloError::InvalidInput {
1391            reason: format!(
1392                "MultiBlockAloInput: total design columns ({}) != penalized_hessian_inv size ({})",
1393                col_sum, p_tot
1394            ),
1395        });
1396    }
1397
1398    let col_offsets = multiblock_col_offsets(input.block_designs);
1399    let (chunk_size, max_concurrent_chunks) = multiblock_alo_parallel_plan(p_tot, b, n);
1400    let chunk_starts: Vec<usize> = (0..n).step_by(chunk_size).collect();
1401
1402    // Each Rayon worker owns its small B×B/B-vector scratch buffers via
1403    // `map_init`, avoiding cross-thread mutation and avoiding per-observation
1404    // allocations.  The much larger Q panels are bounded by the parallel chunk
1405    // size and by wave-level concurrency, so at most roughly one global memory
1406    // budget worth of p_total × chunk_len panels can be live across workers.
1407    let mut chunk_results: Vec<Result<MultiBlockAloChunkDiagnostics, AloError>> =
1408        Vec::with_capacity(chunk_starts.len());
1409    for chunk_wave in chunk_starts.chunks(max_concurrent_chunks) {
1410        let mut wave_results: Vec<Result<MultiBlockAloChunkDiagnostics, AloError>> = chunk_wave
1411            .par_iter()
1412            .map_init(
1413                || MultiBlockAloScratch::new(b),
1414                |scratch, &chunk_start| {
1415                    let chunk_end = (chunk_start + chunk_size).min(n);
1416                    compute_multiblock_alo_chunk(
1417                        input,
1418                        &col_offsets,
1419                        chunk_start,
1420                        chunk_end,
1421                        scratch,
1422                    )
1423                },
1424            )
1425            .collect();
1426        chunk_results.append(&mut wave_results);
1427    }
1428
1429    let mut eta_tilde = Vec::with_capacity(n);
1430    let mut leverage = Array1::<f64>::zeros(n);
1431    let mut alo_variance = Vec::with_capacity(n);
1432    let mut cook_distance = Array1::<f64>::zeros(n);
1433
1434    let mut chunks = Vec::with_capacity(chunk_results.len());
1435    for result in chunk_results {
1436        chunks.push(result?);
1437    }
1438    chunks.sort_unstable_by_key(|chunk| chunk.chunk_start);
1439
1440    for chunk in chunks {
1441        let chunk_start = chunk.chunk_start;
1442        eta_tilde.extend(chunk.eta_tilde);
1443        alo_variance.extend(chunk.alo_variance);
1444        for (local_i, lev) in chunk.leverage.into_iter().enumerate() {
1445            leverage[chunk_start + local_i] = lev;
1446        }
1447        for (local_i, cook) in chunk.cook_distance.into_iter().enumerate() {
1448            cook_distance[chunk_start + local_i] = cook;
1449        }
1450    }
1451
1452    Ok(MultiBlockAloDiagnostics {
1453        eta_tilde,
1454        leverage,
1455        alo_variance,
1456        cook_distance,
1457    })
1458}
1459
1460#[inline]
1461fn multiblock_alo_parallel_plan(p_tot: usize, n_blocks: usize, n_obs: usize) -> (usize, usize) {
1462    if p_tot == 0 || n_blocks == 0 || n_obs == 0 {
1463        return (1, 1);
1464    }
1465    let bytes_per_obs = (p_tot * n_blocks * std::mem::size_of::<f64>()).max(1);
1466    let workers = rayon::current_num_threads().max(1);
1467    let max_concurrent_chunks = (MULTIBLOCK_ALO_MEMORY_BUDGET_BYTES / bytes_per_obs)
1468        .max(1)
1469        .min(workers);
1470    let per_worker_budget =
1471        (MULTIBLOCK_ALO_MEMORY_BUDGET_BYTES / max_concurrent_chunks).max(bytes_per_obs);
1472    let budget_obs = (per_worker_budget / bytes_per_obs).max(1);
1473    (budget_obs.min(n_obs), max_concurrent_chunks)
1474}
1475
/// Reusable buffers for the per-observation multi-block ALO loop.
///
/// B is the number of linear predictors. Every B × B matrix is stored
/// row-major in a flat `Vec` of length `B * B`, and every vector has length B.
/// Instances are created through Rayon's `map_init` and reused across
/// observations, so the inner loop never allocates scratch space.
struct MultiBlockAloScratch {
    /// A_i = X_i H⁻¹ X_iᵀ for the current observation.
    a_i: Vec<f64>,
    /// W_i · A_i.
    wa: Vec<f64>,
    /// A_i · W_i.
    aw: Vec<f64>,
    /// I − W_i A_i, overwritten in place by its LU factors.
    imwa: Vec<f64>,
    /// I − A_i W_i, overwritten in place by its LU factors.
    imaw: Vec<f64>,
    /// Row permutation recorded while factoring `imwa`.
    perm_imwa: Vec<usize>,
    /// Row permutation recorded while factoring `imaw`.
    perm_imaw: Vec<usize>,
    /// Linear-predictor correction δη_i.
    delta_eta: Vec<f64>,
    /// Right-hand side / solution buffer shared by the LU solves.
    rhs_buf: Vec<f64>,
    /// Holds W_i u_d, then the solve result that overwrites it.
    w_u: Vec<f64>,
    /// Per-predictor ALO variance of the current observation.
    var_diag_buf: Vec<f64>,
    /// W_i flattened row-major.
    w_flat: Vec<f64>,
    /// Forward-substitution workspace handed to `lu_solve_in_place`.
    lu_scratch: Vec<f64>,
}

impl MultiBlockAloScratch {
    /// Zero-initialised scratch for `b` linear predictors.
    ///
    /// Matrix buffers get `b * b` entries; vector and permutation buffers get `b`.
    fn new(b: usize) -> Self {
        let square = || vec![0.0f64; b * b];
        let vector = || vec![0.0f64; b];
        let pivots = || vec![0usize; b];
        Self {
            a_i: square(),
            wa: square(),
            aw: square(),
            imwa: square(),
            imaw: square(),
            perm_imwa: pivots(),
            perm_imaw: pivots(),
            delta_eta: vector(),
            rhs_buf: vector(),
            w_u: vector(),
            var_diag_buf: vector(),
            w_flat: square(),
            lu_scratch: vector(),
        }
    }
}
1512
/// Output of `compute_multiblock_alo_chunk` for one contiguous chunk of
/// observations.
///
/// Every vector is indexed by the observation's position within the chunk.
/// Global observation `chunk_start + local_i` lives at index `local_i`.
struct MultiBlockAloChunkDiagnostics {
    /// Global index of the first observation in this chunk.
    chunk_start: usize,
    /// Leave-one-out corrected linear predictors η̂_i + δη_i, each of length B.
    eta_tilde: Vec<Array1<f64>>,
    /// Leverage tr(A_i W_i) for each observation.
    leverage: Vec<f64>,
    /// Per-predictor ALO variance diagonal, each of length B, clamped at zero.
    alo_variance: Vec<Array1<f64>>,
    /// Cook's-distance quantity δη_iᵀ W_i δη_i for each observation.
    cook_distance: Vec<f64>,
}
1520
1521fn compute_multiblock_alo_chunk(
1522    input: &MultiBlockAloInput,
1523    col_offsets: &[usize],
1524    chunk_start: usize,
1525    chunk_end: usize,
1526    scratch: &mut MultiBlockAloScratch,
1527) -> Result<MultiBlockAloChunkDiagnostics, AloError> {
1528    let b = input.n_blocks;
1529    let chunk_len = chunk_end - chunk_start;
1530
1531    let mut q_blocks = Vec::with_capacity(b);
1532    for blk in 0..b {
1533        let x_chunk_t = input.block_designs[blk]
1534            .slice(s![chunk_start..chunk_end, ..])
1535            .t()
1536            .to_owned();
1537        let off_b = col_offsets[blk];
1538        let h_slice = input
1539            .penalized_hessian_inv
1540            .slice(s![.., off_b..off_b + x_chunk_t.nrows()])
1541            .to_owned();
1542        q_blocks.push(h_slice.dot(&x_chunk_t));
1543    }
1544
1545    let mut eta_tilde = Vec::with_capacity(chunk_len);
1546    let mut leverage = vec![0.0f64; chunk_len];
1547    let mut alo_variance = Vec::with_capacity(chunk_len);
1548    let mut cook_distance = vec![0.0f64; chunk_len];
1549
1550    for local_i in 0..chunk_len {
1551        let i = chunk_start + local_i;
1552        let w_i = &input.block_weights[i];
1553
1554        // Flatten W_i once per observation (row-major).
1555        for r in 0..b {
1556            for c in 0..b {
1557                scratch.w_flat[r * b + c] = w_i[(r, c)];
1558            }
1559        }
1560
1561        // --- Assemble A_i = X_i H⁻¹ X_iᵀ  (B × B), row-major flat. ---
1562        for a in 0..b {
1563            let x_a = &input.block_designs[a];
1564            let p_a = x_a.ncols();
1565            let off_a = col_offsets[a];
1566            let xa_row = x_a.row(i);
1567            for bb in 0..b {
1568                let q_bb = &q_blocks[bb];
1569                let mut dot = 0.0f64;
1570                for k in 0..p_a {
1571                    dot += xa_row[k] * q_bb[(off_a + k, local_i)];
1572                }
1573                scratch.a_i[a * b + bb] = dot;
1574            }
1575        }
1576
1577        // WA = W_i · A_i (row-major).
1578        mat_mul_flat(&scratch.w_flat, &scratch.a_i, &mut scratch.wa, b);
1579        // AW = A_i · W_i (row-major).
1580        mat_mul_flat(&scratch.a_i, &scratch.w_flat, &mut scratch.aw, b);
1581
1582        // Trace of H_ii = A_i W_i (= AW): leverage[i].
1583        // (Original code wrote H_ii = A · W — the same operator we already have in `aw`.)
1584        let mut tr = 0.0f64;
1585        for d in 0..b {
1586            tr += scratch.aw[d * b + d];
1587        }
1588        leverage[local_i] = tr;
1589
1590        // Build (I - W A) and (I - A W) into imwa/imaw.
1591        for r in 0..b {
1592            for c in 0..b {
1593                let idx = r * b + c;
1594                let id = if r == c { 1.0 } else { 0.0 };
1595                scratch.imwa[idx] = id - scratch.wa[idx];
1596                scratch.imaw[idx] = id - scratch.aw[idx];
1597            }
1598        }
1599
1600        // Factor in place with partial pivoting; ridge on the diagonal if singular.
1601        // Equivalence with original: original computed det via det_small, regularized
1602        // by adding eps=1e-6 to the diagonal when |det| < 1e-12, then re-factored on
1603        // the regularized matrix. Here we factor directly; if any pivot is below the
1604        // singular threshold we add the ridge once and re-factor — same numerical path.
1605        if !lu_factor_in_place(&mut scratch.imwa, &mut scratch.perm_imwa, b) {
1606            for r in 0..b {
1607                for c in 0..b {
1608                    let idx = r * b + c;
1609                    let id = if r == c { 1.0 } else { 0.0 };
1610                    scratch.imwa[idx] = id - scratch.wa[idx];
1611                }
1612            }
1613            for d in 0..b {
1614                scratch.imwa[d * b + d] += ALO_LOCAL_BLOCK_RIDGE;
1615            }
1616            let refactored = lu_factor_in_place(&mut scratch.imwa, &mut scratch.perm_imwa, b);
1617            assert!(
1618                refactored,
1619                "ALO local block remained singular after ridge regularization"
1620            );
1621        }
1622        if !lu_factor_in_place(&mut scratch.imaw, &mut scratch.perm_imaw, b) {
1623            for r in 0..b {
1624                for c in 0..b {
1625                    let idx = r * b + c;
1626                    let id = if r == c { 1.0 } else { 0.0 };
1627                    scratch.imaw[idx] = id - scratch.aw[idx];
1628                }
1629            }
1630            for d in 0..b {
1631                scratch.imaw[d * b + d] += ALO_LOCAL_BLOCK_RIDGE;
1632            }
1633            let refactored = lu_factor_in_place(&mut scratch.imaw, &mut scratch.perm_imaw, b);
1634            assert!(
1635                refactored,
1636                "ALO local variance block remained singular after ridge regularization"
1637            );
1638        }
1639
1640        // v_i = (I - W A)⁻¹ s_i  -- solve into rhs_buf.
1641        let s_i = &input.scores[i];
1642        for k in 0..b {
1643            scratch.rhs_buf[k] = s_i[k];
1644        }
1645        lu_solve_in_place(
1646            &scratch.imwa,
1647            &scratch.perm_imwa,
1648            &mut scratch.rhs_buf,
1649            &mut scratch.lu_scratch,
1650            b,
1651        );
1652        // delta_eta = A_i · v_i
1653        for r in 0..b {
1654            let mut acc = 0.0f64;
1655            let row_off = r * b;
1656            for k in 0..b {
1657                acc += scratch.a_i[row_off + k] * scratch.rhs_buf[k];
1658            }
1659            scratch.delta_eta[r] = acc;
1660        }
1661
1662        let eta_i = &input.eta_hat[i];
1663        let mut corrected = Array1::<f64>::zeros(b);
1664        for d in 0..b {
1665            corrected[d] = eta_i[d] + scratch.delta_eta[d];
1666        }
1667        eta_tilde.push(corrected);
1668
1669        // Cook's distance: δη^T W δη.
1670        let mut cook = 0.0f64;
1671        for r in 0..b {
1672            let mut w_delta_r = 0.0f64;
1673            let row_off = r * b;
1674            for k in 0..b {
1675                w_delta_r += scratch.w_flat[row_off + k] * scratch.delta_eta[k];
1676            }
1677            cook += scratch.delta_eta[r] * w_delta_r;
1678        }
1679        cook_distance[local_i] = cook;
1680
1681        // var_diag[d] = a_d^T (I-WA)⁻¹ W (I-AW)⁻¹ a_d
1682        // where a_d is the d-th row of A_i.
1683        // Reuses already-factored imwa and imaw (one LU factorization each, reused
1684        // across all B right-hand sides — major saving over the original which redid
1685        // both LU decompositions B times per observation).
1686        for d in 0..b {
1687            let row_off = d * b;
1688            // u_d = (I - A W)⁻¹ a_d
1689            for k in 0..b {
1690                scratch.rhs_buf[k] = scratch.a_i[row_off + k];
1691            }
1692            lu_solve_in_place(
1693                &scratch.imaw,
1694                &scratch.perm_imaw,
1695                &mut scratch.rhs_buf,
1696                &mut scratch.lu_scratch,
1697                b,
1698            );
1699            // w_u = W u_d
1700            for r in 0..b {
1701                let mut acc = 0.0f64;
1702                let wr = r * b;
1703                for k in 0..b {
1704                    acc += scratch.w_flat[wr + k] * scratch.rhs_buf[k];
1705                }
1706                scratch.w_u[r] = acc;
1707            }
1708            // t_d = (I - W A)⁻¹ w_u  (back-solve in place using w_u as RHS).
1709            lu_solve_in_place(
1710                &scratch.imwa,
1711                &scratch.perm_imwa,
1712                &mut scratch.w_u,
1713                &mut scratch.lu_scratch,
1714                b,
1715            );
1716            // v_dd = a_d^T t_d
1717            let mut v_dd = 0.0f64;
1718            for k in 0..b {
1719                v_dd += scratch.a_i[row_off + k] * scratch.w_u[k];
1720            }
1721            scratch.var_diag_buf[d] = v_dd.max(0.0);
1722        }
1723        let mut var_diag = Array1::<f64>::zeros(b);
1724        for d in 0..b {
1725            var_diag[d] = scratch.var_diag_buf[d];
1726        }
1727        alo_variance.push(var_diag);
1728    }
1729
1730    Ok(MultiBlockAloChunkDiagnostics {
1731        chunk_start,
1732        eta_tilde,
1733        leverage,
1734        alo_variance,
1735        cook_distance,
1736    })
1737}
1738
/// Row-major B × B product: `out = a · b_mat`.
///
/// Only the leading `b * b` entries of each slice are read or written. Each
/// output entry is accumulated left to right over the shared index, starting
/// from `0.0`.
#[inline]
fn mat_mul_flat(a: &[f64], b_mat: &[f64], out: &mut [f64], b: usize) {
    for row in 0..b {
        let lhs_row = &a[row * b..row * b + b];
        for col in 0..b {
            out[row * b + col] = lhs_row
                .iter()
                .enumerate()
                .fold(0.0f64, |acc, (k, &lhs)| acc + lhs * b_mat[k * b + col]);
        }
    }
}
1754
1755/// LU-decompose a B × B row-major matrix in place with partial pivoting and
1756/// physical row swaps. Returns false if any pivot |a_kk| < 1e-12 (singular).
1757/// On success, `m` holds L (strict lower, unit diag implicit) and U (upper, diag
1758/// included); `perm[k]` records the original-row index that ended up in physical
1759/// row k after pivoting. Pivot threshold matches the original `det_small < 1e-12`
1760/// path so the regularization branch fires under equivalent conditions.
1761fn lu_factor_in_place(m: &mut [f64], perm: &mut [usize], b: usize) -> bool {
1762    for i in 0..b {
1763        perm[i] = i;
1764    }
1765    for col in 0..b {
1766        // Partial pivot on column `col` over physical rows `[col..b]`.
1767        let mut max_val = m[col * b + col].abs();
1768        let mut max_idx = col;
1769        for row in (col + 1)..b {
1770            let v = m[row * b + col].abs();
1771            if v > max_val {
1772                max_val = v;
1773                max_idx = row;
1774            }
1775        }
1776        if max_val < LU_PIVOT_SINGULAR_TOL {
1777            return false;
1778        }
1779        if max_idx != col {
1780            // Physically swap rows `col` and `max_idx` (full row, all columns).
1781            for k in 0..b {
1782                m.swap(col * b + k, max_idx * b + k);
1783            }
1784            perm.swap(col, max_idx);
1785        }
1786        let pivot = m[col * b + col];
1787        for row in (col + 1)..b {
1788            let factor = m[row * b + col] / pivot;
1789            m[row * b + col] = factor; // store L below diag
1790            for k in (col + 1)..b {
1791                let upd = factor * m[col * b + k];
1792                m[row * b + k] -= upd;
1793            }
1794        }
1795    }
1796    true
1797}
1798
/// Solve `A x = rhs` from the packed LU factors produced by `lu_factor_in_place`.
///
/// `m` holds unit-lower L (strict lower part) and U (upper part including the
/// diagonal), and `perm[k]` is the original row sitting in physical row `k`.
/// The solution overwrites the first `b` entries of `rhs`. `scratch` is used
/// for the intermediate `y` and must have length ≥ `b`.
fn lu_solve_in_place(m: &[f64], perm: &[usize], rhs: &mut [f64], scratch: &mut [f64], b: usize) {
    let y = &mut scratch[..b];

    // Forward pass: L y = P·rhs (unit diagonal, so no division).
    for row in 0..b {
        let l_row = &m[row * b..row * b + row];
        let value = l_row
            .iter()
            .zip(y.iter())
            .fold(rhs[perm[row]], |acc, (&l, &yk)| acc - l * yk);
        y[row] = value;
    }

    // Backward pass: U x = y, filling rhs from the last row upwards.
    for row in (0..b).rev() {
        let u_tail = &m[row * b + row + 1..row * b + b];
        let numerator = u_tail
            .iter()
            .zip(&rhs[row + 1..b])
            .fold(y[row], |acc, (&u, &xk)| acc - u * xk);
        rhs[row] = numerator / m[row * b + row];
    }
}
1820
1821/// Compute only per-observation leverages tr(H_ii) for multi-predictor models.
1822///
1823/// This is cheaper than the full ALO correction when only EDF or leverage
1824/// diagnostics are needed (no scores or W⁻¹ computation required).
1825///
1826/// Returns an n-length array of leverages.  The total model EDF is the sum
1827/// of all leverages.
1828pub fn compute_multiblock_alo_leverages(
1829    n_obs: usize,
1830    n_blocks: usize,
1831    block_designs: &[Array2<f64>],
1832    penalized_hessian_inv: &Array2<f64>,
1833    block_weights: &[Array2<f64>],
1834) -> Result<Array1<f64>, EstimationError> {
1835    use rayon::prelude::*;
1836
1837    let n = n_obs;
1838    let b = n_blocks;
1839    let p_tot = penalized_hessian_inv.nrows();
1840
1841    let col_offsets = multiblock_col_offsets(block_designs);
1842    let max_workers = rayon::current_num_threads();
1843    let chunk_size = multiblock_alo_parallel_leverage_chunk_size(p_tot, b, n, max_workers);
1844
1845    let mut leverage = Array1::<f64>::zeros(n);
1846
1847    // Per-block H_inv stripe scratch (p_tot × p_blk) is read-only once built
1848    // and shared by the parallel chunks.  Only per-chunk q/XT/B×B scratch is
1849    // replicated across Rayon workers.
1850    let block_widths: Vec<usize> = block_designs.iter().map(|d| d.ncols()).collect();
1851    let mut h_stripes: Vec<FaerMat<f64>> = block_widths
1852        .iter()
1853        .map(|&p_blk| FaerMat::<f64>::zeros(p_tot, p_blk))
1854        .collect();
1855    // Populate the H_inv stripes once: each block reads a constant column slab
1856    // out of `penalized_hessian_inv` and copies it into a column-major faer Mat.
1857    for blk in 0..b {
1858        let off_b = col_offsets[blk];
1859        let p_blk = block_widths[blk];
1860        let stripe = &mut h_stripes[blk];
1861        for c in 0..p_blk {
1862            for r in 0..p_tot {
1863                stripe[(r, c)] = penalized_hessian_inv[(r, off_b + c)];
1864            }
1865        }
1866    }
1867
1868    leverage
1869        .as_slice_mut()
1870        .expect("newly allocated Array1 is contiguous")
1871        .par_chunks_mut(chunk_size)
1872        .enumerate()
1873        .for_each(|(chunk_idx, leverage_chunk)| {
1874            let chunk_start = chunk_idx * chunk_size;
1875            let chunk_len = leverage_chunk.len();
1876            let chunk_end = chunk_start + chunk_len;
1877
1878            // Chunk-local scratch: B×B flat row-major buffers for A_i, W_i
1879            // and AW = A·W.  Each worker writes only its `leverage_chunk`, so
1880            // output writes are disjoint and require no synchronization.
1881            let bb_sz = b * b;
1882            let mut a_i = vec![0.0f64; bb_sz];
1883            let mut aw = vec![0.0f64; bb_sz];
1884            let mut w_flat = vec![0.0f64; bb_sz];
1885
1886            // Column-major faer storage for q_blocks: q_k has shape
1887            // (p_tot, chunk_len) with contiguous columns, so
1888            // `col_as_slice(local_i)` is a direct stripe.
1889            let mut q_storage: Vec<FaerMat<f64>> = block_widths
1890                .iter()
1891                .map(|_| FaerMat::<f64>::zeros(p_tot, chunk_len))
1892                .collect();
1893
1894            // Per-block X^T scratch in column-major faer storage
1895            // (p_blk × chunk_len), owned by this chunk to keep the matmul input
1896            // contiguous without sharing mutable scratch across threads.
1897            let mut xt_storage: Vec<FaerMat<f64>> = block_widths
1898                .iter()
1899                .map(|&p_blk| FaerMat::<f64>::zeros(p_blk, chunk_len))
1900                .collect();
1901
1902            // Build q_blocks[blk] = H_inv[:, off..off+p_blk] · X_blk[chunk, :]^T
1903            // entirely in column-major faer storage so subsequent column reads
1904            // are contiguous f64 stripes — replaces the per-chunk `to_owned()`
1905            // ndarray slicing + row-major `dot()` from the original.
1906            for blk in 0..b {
1907                let p_blk = block_widths[blk];
1908
1909                let x_chunk = block_designs[blk].slice(s![chunk_start..chunk_end, ..]);
1910                let xt = &mut xt_storage[blk];
1911                for local_i in 0..chunk_len {
1912                    let row = x_chunk.row(local_i);
1913                    for j in 0..p_blk {
1914                        xt[(j, local_i)] = row[j];
1915                    }
1916                }
1917
1918                matmul(
1919                    q_storage[blk].as_mut(),
1920                    Accum::Replace,
1921                    h_stripes[blk].as_ref(),
1922                    xt_storage[blk].as_ref(),
1923                    1.0,
1924                    Par::Seq,
1925                );
1926            }
1927
1928            for local_i in 0..chunk_len {
1929                let i = chunk_start + local_i;
1930                let w_i = &block_weights[i];
1931
1932                // Flatten W_i once per observation (row-major).
1933                for r in 0..b {
1934                    for c in 0..b {
1935                        w_flat[r * b + c] = w_i[(r, c)];
1936                    }
1937                }
1938
1939                // Assemble A_i[a, k] = X_a[i, :] · q_k[off_a:off_a+p_a, local_i].
1940                // For each k, read its column once (contiguous f64 stripe), then
1941                // for each a take the matching offset slab.
1942                for r in 0..bb_sz {
1943                    a_i[r] = 0.0;
1944                }
1945                for k in 0..b {
1946                    let q_k = &q_storage[k];
1947                    let q_col = q_k.col_as_slice(local_i);
1948                    for a in 0..b {
1949                        let p_a = block_widths[a];
1950                        let off_a = col_offsets[a];
1951                        let xa_row = block_designs[a].row(i);
1952                        let mut dot = 0.0f64;
1953                        for j in 0..p_a {
1954                            dot = xa_row[j].mul_add(q_col[off_a + j], dot);
1955                        }
1956                        a_i[a * b + k] = dot;
1957                    }
1958                }
1959
1960                // AW = A_i · W_i (B×B), then leverage = trace(AW) = sum_{a,k} A[a,k]·W[k,a].
1961                mat_mul_flat(&a_i, &w_flat, &mut aw, b);
1962                let mut tr = 0.0f64;
1963                for d in 0..b {
1964                    tr += aw[d * b + d];
1965                }
1966                leverage_chunk[local_i] = tr;
1967            }
1968        });
1969
1970    Ok(leverage)
1971}
1972
1973// (Allocation-free, factor-once-reuse-many B×B LU helpers live next to the
1974// multi-block ALO callsite — see `lu_factor_in_place` and `lu_solve_in_place`.)
1975
1976#[cfg(test)]
1977mod tests {
1978    use super::{
1979        ALO_EXACT_SCALAR_MAX_ITERS, AloExactScalarError, AloInput, alo_eta_exact_frozen_curvature,
1980        alo_eta_updatewith_offset, bayesvar_eta, compute_alo_from_input_inner,
1981        percentile_from_sorted, percentile_index, sandwichvar_eta_from_meat,
1982    };
1983    use gam_linalg::matrix::{PsdWeightsView, SignedWeightsView};
1984    use gam_problem::LinkFunction;
1985
1986    #[test]
1987    fn alo_offset_update_matches_centered_algebra() {
1988        let eta_hat = 11.0;
1989        let z = 13.0;
1990        let offset = 10.0;
1991        let x_hinv_x = 0.2;
1992        let hessian_weight = 1.0;
1993        let score_weight = 1.0;
1994        // centered: eta~=off + ((eta-off)-a(z-off))/(1-a) when W_S = W_H.
1995        let leverage = hessian_weight * x_hinv_x;
1996        let expected = offset + ((eta_hat - offset) - leverage * (z - offset)) / (1.0 - leverage);
1997        let got =
1998            alo_eta_updatewith_offset(eta_hat, z, offset, x_hinv_x, score_weight, 1.0 - leverage);
1999        assert!((got - expected).abs() < 1e-12);
2000    }
2001
2002    #[test]
2003    fn alo_offset_update_reduces_to_classicwhen_offsetzero() {
2004        let eta_hat = 1.25;
2005        let z = -0.5;
2006        let x_hinv_x = 0.35;
2007        let hessian_weight = 1.0;
2008        let score_weight = 1.0;
2009        let leverage = hessian_weight * x_hinv_x;
2010        let expected = (eta_hat - leverage * z) / (1.0 - leverage);
2011        let got =
2012            alo_eta_updatewith_offset(eta_hat, z, 0.0, x_hinv_x, score_weight, 1.0 - leverage);
2013        assert!((got - expected).abs() < 1e-12);
2014    }
2015
2016    #[test]
2017    fn alo_offset_update_uses_distinct_score_and_hessian_weights() {
2018        let eta_hat = 1.7;
2019        let z = 0.4;
2020        let offset = -0.2;
2021        let x_hinv_x = 0.15;
2022        let hessian_weight = 3.0;
2023        let score_weight = 5.0;
2024        let expected = offset
2025            + (eta_hat - offset)
2026            + x_hinv_x * score_weight * ((eta_hat - offset) - (z - offset))
2027                / (1.0 - hessian_weight * x_hinv_x);
2028        let got = alo_eta_updatewith_offset(
2029            eta_hat,
2030            z,
2031            offset,
2032            x_hinv_x,
2033            score_weight,
2034            1.0 - hessian_weight * x_hinv_x,
2035        );
2036        assert!((got - expected).abs() < 1e-12);
2037    }
2038
2039    #[test]
2040    fn alo_offset_update_handles_zero_hessian_weight() {
2041        let eta_hat = 0.8;
2042        let z = -0.3;
2043        let offset = 0.1;
2044        let x_hinv_x = 0.4;
2045        let hessian_weight = 0.0;
2046        let score_weight = 2.5;
2047        let expected = offset
2048            + (eta_hat - offset)
2049            + x_hinv_x * score_weight * ((eta_hat - offset) - (z - offset));
2050        let got = alo_eta_updatewith_offset(
2051            eta_hat,
2052            z,
2053            offset,
2054            x_hinv_x,
2055            score_weight,
2056            1.0 - hessian_weight * x_hinv_x,
2057        );
2058        assert!((got - expected).abs() < 1e-12);
2059    }
2060
2061    #[test]
2062    fn alo_exact_frozen_curvature_converges_to_fixed_point() {
2063        let eta_hat = 1.0;
2064        let a_ii = 0.4;
2065        let got = alo_eta_exact_frozen_curvature(eta_hat, a_ii, &|eta| (0.5 * (eta - 2.0), 0.5))
2066            .expect("linear scalar fixed point should converge in one Newton step");
2067        assert!((got - 0.75).abs() < 1e-12);
2068    }
2069
2070    #[test]
2071    fn alo_exact_frozen_curvature_reports_nonconvergence() {
2072        let err = alo_eta_exact_frozen_curvature(0.0, 1.0, &|eta| (eta + 1.0, 0.0))
2073            .expect_err("constant residual should exhaust the scalar iteration budget");
2074        let AloExactScalarError::MaxIterations { iterations, .. } = err else {
2075            panic!("constant residual must report MaxIterations, got {err:?}");
2076        };
2077        assert_eq!(
2078            iterations, ALO_EXACT_SCALAR_MAX_ITERS,
2079            "non-convergence must report the full scalar iteration budget"
2080        );
2081    }
2082
2083    #[test]
2084    fn alo_input_reports_exact_scalar_nonconvergence_with_row_context() {
2085        let design = Array2::from_elem((1, 1), 1.0);
2086        let penalized_hessian = Array2::from_elem((1, 1), 1.0);
2087        let hessian_weights = Array1::from_vec(vec![0.0]);
2088        let score_weights = Array1::from_vec(vec![0.0]);
2089        let working_response = Array1::from_vec(vec![0.0]);
2090        let eta = Array1::from_vec(vec![0.0]);
2091        let offset = Array1::from_vec(vec![0.0]);
2092        let score_curvature = |_: usize, eta: f64| (eta + 1.0, 0.0);
2093        let input = AloInput {
2094            design: &design,
2095            penalized_hessian: &penalized_hessian,
2096            hessian_weights: SignedWeightsView::from_array(&hessian_weights),
2097            score_weights: PsdWeightsView::try_from_array(&score_weights).expect("psd weights"),
2098            working_response: &working_response,
2099            eta: &eta,
2100            offset: &offset,
2101            link: LinkFunction::Logit,
2102            phi: 1.0,
2103            penalty_root: None,
2104            ridge: 0.0,
2105            score_curvature: Some(&score_curvature),
2106        };
2107
2108        let err =
2109            compute_alo_from_input_inner(&input).expect_err("non-converged exact ALO must error");
2110        let msg = err.to_string();
2111        assert!(
2112            msg.contains("ALO exact frozen-curvature solve failed at row 0"),
2113            "missing row context in exact ALO error: {msg}"
2114        );
2115        assert!(
2116            msg.contains("did not converge within"),
2117            "missing non-convergence cause in exact ALO error: {msg}"
2118        );
2119    }
2120
2121    #[test]
2122    fn gaussian_unpenalized_direct_sandwich_equals_bayes() {
2123        // In a Gaussian linear model with H = X'WX, direct meat
2124        // x_i'H^{-1}X'WXH^{-1}x_i equals x_i'H^{-1}x_i.
2125        let phi = 2.5;
2126        let x_hinv_x = 0.3;
2127        let vb = bayesvar_eta(phi, x_hinv_x);
2128        let vs = sandwichvar_eta_from_meat(phi, x_hinv_x);
2129        assert!((vb - vs).abs() < 1e-12);
2130    }
2131
2132    #[test]
2133    fn sandwich_from_direct_meat_scales_by_phi() {
2134        let phi = 1.7;
2135        let meat_quad = 0.358;
2136        let got = sandwichvar_eta_from_meat(phi, meat_quad);
2137        let expected = phi * meat_quad;
2138        assert!((got - expected).abs() < 1e-12);
2139    }
2140
2141    #[test]
2142    fn sandwich_meat_uses_score_weights_not_hessian_weights_noncanonical() {
2143        // Regression for the sandwich-SE "meat" weight bug: the meat must be the
2144        // SCORE covariance Xᵀ diag(W_S) X (Fisher, PSD), NOT the observed-info
2145        // Hessian weight W_H (signed). This fixture mimics a non-canonical link
2146        // (W_H ≠ W_S) with mixed-sign observed curvature.
2147        //
2148        // Single column (p = 1) makes H a scalar, so the sandwich variance is
2149        // closed form: with H = Σ W_H·x² + s0 (> 0 after the penalty), the meat
2150        // for obs is x_obs²·H⁻²·Σ_row W_S·x_row², and se = sqrt(φ·meat).
2151        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 1.0, 2.0, 1.0]).unwrap();
2152        // Mixed-sign observed-information weights; the negative rows carry the
2153        // larger design values so Σ W_H·x² is NEGATIVE (see assert below).
2154        let w_h_vec = Array1::from_vec(vec![1.0, -1.0, 1.0, -1.0, 0.5]);
2155        // Score/Fisher weights are strictly positive (PSD by construction).
2156        let w_s_vec = Array1::from_vec(vec![1.0, 0.8, 1.2, 0.6, 0.9]);
2157        let phi = 1.3;
2158
2159        let n = x.nrows();
2160        let sum_wh_x2: f64 = (0..n).map(|i| w_h_vec[i] * x[[i, 0]] * x[[i, 0]]).sum();
2161        let sum_ws_x2: f64 = (0..n).map(|i| w_s_vec[i] * x[[i, 0]] * x[[i, 0]]).sum();
2162        // The whole point: Σ W_H·x² < 0 < Σ W_S·x². With W_H the meat is negative
2163        // and the "materially negative sandwich variance" guard would trip
2164        // (spurious LooComputationFailed); with W_S it is a valid PSD meat.
2165        assert!(sum_wh_x2 < 0.0, "fixture must exercise a negative W_H meat");
2166        assert!(sum_ws_x2 > 0.0);
2167
2168        // Penalize enough that the penalized Hessian is PD despite Σ W_H·x² < 0.
2169        let s0 = 8.0_f64;
2170        let h = s0 + sum_wh_x2; // = 2.5
2171        assert!(h > 0.0, "penalized Hessian must stay PD");
2172        let penalized_hessian = Array2::from_elem((1, 1), h);
2173
2174        // Pre-fix arithmetic check: the OLD W_H meat would be materially negative
2175        // for the larger-x rows, so the old code returned LooComputationFailed.
2176        let old_meat_obs1 = x[[1, 0]] * x[[1, 0]] / (h * h) * sum_wh_x2;
2177        assert!(
2178            phi * old_meat_obs1 < -super::variance_negative_tolerance(phi * old_meat_obs1.abs()),
2179            "the pre-fix W_H meat must be materially negative (guard would trip)"
2180        );
2181
2182        let working_response = Array1::from_vec(vec![0.3, -0.2, 0.5, 0.1, -0.4]);
2183        let eta = Array1::from_vec(vec![0.2, 0.1, 0.4, -0.1, 0.05]);
2184        let offset = Array1::zeros(n);
2185        let input = AloInput {
2186            design: &x,
2187            penalized_hessian: &penalized_hessian,
2188            hessian_weights: SignedWeightsView::from_array(&w_h_vec),
2189            score_weights: PsdWeightsView::try_from_array(&w_s_vec).expect("psd weights"),
2190            working_response: &working_response,
2191            eta: &eta,
2192            offset: &offset,
2193            link: LinkFunction::Probit,
2194            phi,
2195            penalty_root: None,
2196            ridge: 0.0,
2197            score_curvature: None,
2198        };
2199
2200        // The fix must let this succeed (no spurious negative-meat failure)...
2201        let diag = compute_alo_from_input_inner(&input)
2202            .expect("fixed sandwich meat (W_S) must not trip the negative-variance guard");
2203
2204        // ...and match the closed-form W_S reference for every row.
2205        for obs in 0..n {
2206            let expected =
2207                (phi * x[[obs, 0]] * x[[obs, 0]] / (h * h) * sum_ws_x2).sqrt();
2208            assert!(
2209                (diag.se_sandwich[obs] - expected).abs() <= 1e-10 * expected.max(1.0),
2210                "row {obs}: se_sandwich={} expected={expected}",
2211                diag.se_sandwich[obs]
2212            );
2213        }
2214    }
2215
2216    #[test]
2217    fn percentile_index_matches_expected_rounding() {
2218        assert_eq!(percentile_index(0, 0.95), 0);
2219        assert_eq!(percentile_index(1, 0.95), 0);
2220        assert_eq!(percentile_index(10, 0.50), 5);
2221        assert_eq!(percentile_index(10, 0.95), 9);
2222    }
2223
2224    #[test]
2225    fn percentile_from_sorted_returns_order_statistic() {
2226        let values = [1.0, 2.0, 3.0, 4.0, 5.0];
2227        assert_eq!(percentile_from_sorted(&values, 0.50), 3.0);
2228        assert_eq!(percentile_from_sorted(&values, 0.95), 5.0);
2229        assert_eq!(percentile_from_sorted(&[], 0.95), 0.0);
2230    }
2231
2232    // --- Multi-block ALO tests ---
2233
2234    use super::{MultiBlockAloInput, compute_multiblock_alo, compute_multiblock_alo_leverages};
2235    use ndarray::{Array1, Array2};
2236
2237    #[test]
2238    fn multiblock_b1_matches_scalar_leverage() {
2239        // With B=1 the multi-block formula should reduce to the scalar case.
2240        // H_ii = x_i^T H^{-1} x_i * w_i  (scalar).
2241        let n = 3;
2242        let p = 2;
2243        let x = Array2::from_shape_vec((n, p), vec![1.0, 0.5, 0.8, -0.3, 0.2, 1.1]).unwrap();
2244        // H = X'WX + I (simple regularisation).
2245        let w = [1.0, 2.0, 0.5];
2246        let mut h = Array2::<f64>::eye(p);
2247        for i in 0..n {
2248            for r in 0..p {
2249                for c in 0..p {
2250                    h[(r, c)] += w[i] * x[(i, r)] * x[(i, c)];
2251                }
2252            }
2253        }
2254        // Invert H (2x2).
2255        let det = h[(0, 0)] * h[(1, 1)] - h[(0, 1)] * h[(1, 0)];
2256        let mut h_inv = Array2::<f64>::zeros((p, p));
2257        h_inv[(0, 0)] = h[(1, 1)] / det;
2258        h_inv[(1, 1)] = h[(0, 0)] / det;
2259        h_inv[(0, 1)] = -h[(0, 1)] / det;
2260        h_inv[(1, 0)] = -h[(1, 0)] / det;
2261
2262        // Scalar leverages: a_ii = w_i * x_i^T H^{-1} x_i
2263        let mut scalar_lev = vec![0.0f64; n];
2264        for i in 0..n {
2265            let mut xhx = 0.0;
2266            for r in 0..p {
2267                for c in 0..p {
2268                    xhx += x[(i, r)] * h_inv[(r, c)] * x[(i, c)];
2269                }
2270            }
2271            scalar_lev[i] = w[i] * xhx;
2272        }
2273
2274        // Multi-block with B=1.
2275        let block_designs = vec![x.clone()];
2276        let block_weights: Vec<Array2<f64>> =
2277            w.iter().map(|&wi| Array2::from_elem((1, 1), wi)).collect();
2278        let scores: Vec<Array1<f64>> = (0..n).map(|_| Array1::from_vec(vec![0.1])).collect();
2279        let eta_hat: Vec<Array1<f64>> = (0..n).map(|i| Array1::from_vec(vec![i as f64])).collect();
2280
2281        let input = MultiBlockAloInput {
2282            n_obs: n,
2283            n_blocks: 1,
2284            block_designs: &block_designs,
2285            penalized_hessian_inv: &h_inv,
2286            block_weights,
2287            scores,
2288            eta_hat,
2289        };
2290
2291        let result = compute_multiblock_alo(&input).unwrap();
2292        for i in 0..n {
2293            assert!(
2294                (result.leverage[i] - scalar_lev[i]).abs() < 1e-10,
2295                "leverage mismatch at i={}: got {}, expected {}",
2296                i,
2297                result.leverage[i],
2298                scalar_lev[i]
2299            );
2300        }
2301    }
2302
2303    #[test]
2304    fn multiblock_leverage_only_matches_full() {
2305        // Verify that compute_multiblock_alo_leverages returns the same
2306        // leverages as compute_multiblock_alo.
2307        let n = 4;
2308        let p1 = 2;
2309        let p2 = 3;
2310        let x1 = Array2::from_shape_fn((n, p1), |(i, j)| (i + j + 1) as f64 * 0.3);
2311        let x2 = Array2::from_shape_fn((n, p2), |(i, j)| (i * 2 + j) as f64 * 0.2 - 0.1);
2312        let p_tot = p1 + p2;
2313        let h_inv = Array2::<f64>::eye(p_tot); // Simple identity for test.
2314        let block_weights: Vec<Array2<f64>> = (0..n)
2315            .map(|i| {
2316                let v = (i + 1) as f64;
2317                Array2::from_shape_vec((2, 2), vec![v, 0.1, 0.1, v * 0.5]).unwrap()
2318            })
2319            .collect();
2320        let scores: Vec<Array1<f64>> = (0..n).map(|_| Array1::from_vec(vec![0.0, 0.0])).collect();
2321        let eta_hat: Vec<Array1<f64>> = (0..n).map(|_| Array1::from_vec(vec![0.0, 0.0])).collect();
2322        let block_designs = vec![x1.clone(), x2.clone()];
2323
2324        let input = MultiBlockAloInput {
2325            n_obs: n,
2326            n_blocks: 2,
2327            block_designs: &block_designs,
2328            penalized_hessian_inv: &h_inv,
2329            block_weights: block_weights.clone(),
2330            scores,
2331            eta_hat,
2332        };
2333        let full = compute_multiblock_alo(&input).unwrap();
2334        let lev_only =
2335            compute_multiblock_alo_leverages(n, 2, &block_designs, &h_inv, &block_weights).unwrap();
2336
2337        for i in 0..n {
2338            assert!(
2339                (full.leverage[i] - lev_only[i]).abs() < 1e-12,
2340                "leverage mismatch at i={}: full={}, lev_only={}",
2341                i,
2342                full.leverage[i],
2343                lev_only[i]
2344            );
2345        }
2346    }
2347
2348    #[test]
2349    fn multiblock_singular_weight_still_corrects() {
2350        // When W_i = 0 (singular), the W_i⁻¹-free formula still works:
2351        // (I - W_i A_i)⁻¹ = I, so Δη = A_i s_i.
2352        // A_i = x H⁻¹ xᵀ = 1.0² + 0.5² = 1.25 (scalar, B=1).
2353        let n = 1;
2354        let p = 2;
2355        let x = Array2::from_shape_vec((1, p), vec![1.0, 0.5]).unwrap();
2356        let h_inv = Array2::eye(p);
2357        let block_designs = vec![x.clone()];
2358        let block_weights = vec![Array2::from_elem((1, 1), 0.0)]; // singular
2359        let scores = vec![Array1::from_vec(vec![1.0])];
2360        let eta_hat = vec![Array1::from_vec(vec![std::f64::consts::PI])];
2361
2362        let input = MultiBlockAloInput {
2363            n_obs: n,
2364            n_blocks: 1,
2365            block_designs: &block_designs,
2366            penalized_hessian_inv: &h_inv,
2367            block_weights,
2368            scores,
2369            eta_hat,
2370        };
2371        let result = compute_multiblock_alo(&input).unwrap();
2372        // Δη = A_i * s_i = 1.25 * 1.0 = 1.25
2373        let expected = std::f64::consts::PI + 1.25;
2374        assert!(
2375            (result.eta_tilde[0][0] - expected).abs() < 1e-12,
2376            "expected {}, got {}",
2377            expected,
2378            result.eta_tilde[0][0]
2379        );
2380        // Cook's distance should be 0 since W_i = 0.
2381        assert!(result.cook_distance[0].abs() < 1e-14);
2382        // ALO variance should be 0 since W_i = 0.
2383        assert!(result.alo_variance[0][0].abs() < 1e-14);
2384    }
2385
2386    #[test]
2387    fn multiblock_cook_and_variance_basic() {
2388        // B=1 with known values: verify Cook's distance and variance.
2389        let n = 1;
2390        let x = Array2::from_elem((1, 1), 1.0);
2391        // H⁻¹ = [[0.5]]
2392        let h_inv = Array2::from_elem((1, 1), 0.5);
2393        let block_designs = vec![x.clone()];
2394        let w_val = 2.0;
2395        let s_val = 0.4;
2396        let block_weights = vec![Array2::from_elem((1, 1), w_val)];
2397        let scores = vec![Array1::from_vec(vec![s_val])];
2398        let eta_hat = vec![Array1::from_vec(vec![1.0])];
2399
2400        let input = MultiBlockAloInput {
2401            n_obs: n,
2402            n_blocks: 1,
2403            block_designs: &block_designs,
2404            penalized_hessian_inv: &h_inv,
2405            block_weights,
2406            scores,
2407            eta_hat,
2408        };
2409        let result = compute_multiblock_alo(&input).unwrap();
2410
2411        // A_i = x H⁻¹ xᵀ = 1 * 0.5 * 1 = 0.5
2412        // (I - W A)⁻¹ = 1 / (1 - 2.0 * 0.5) = 1/0 => regularised
2413        // Actually 1 - w*a = 1 - 1.0 = 0.0, so det < 1e-12 => regularised with eps=1e-6
2414        // (I - W A + eps) = 1e-6, so v = s / 1e-6 = 4e5
2415        // delta_eta = A * v = 0.5 * 4e5 = 2e5
2416        // This is the regularised case; just check it doesn't panic and returns finite values.
2417        assert!(result.eta_tilde[0][0].is_finite());
2418        assert!(result.cook_distance[0].is_finite());
2419        assert!(result.alo_variance[0][0].is_finite());
2420    }
2421}