Skip to main content

gam/linalg/
utils.rs

1use crate::construction::calculate_condition_number;
2use crate::estimate::EstimationError;
3use crate::faer_ndarray::{
4    FaerArrayView, FaerLinalgError, array2_to_matmut, factorize_symmetricwith_fallback,
5};
6use crate::faer_ndarray::{FaerCholesky, FaerEigh};
7use crate::matrix::symmetrize_in_place;
8use faer::Side;
9use ndarray::{
10    Array1, Array2, Array3, ArrayBase, ArrayView1, ArrayView2, ArrayView3, Data, Dimension, Zip, s,
11};
12
/// SplitMix64 step: advance `state` by the golden-ratio increment and return
/// a well-mixed 64-bit value derived from the new state.
///
/// This is the streaming form — the caller keeps `state` and calls again for
/// the next value. It is the single shared copy of an implementation that
/// used to be duplicated in eight modules (gpu/kernels/hutchpp,
/// terms/analytic_penalties, solver/evidence, solver/reml/unified,
/// inference/sample, inference/hmc, families/cubic_cell_kernel,
/// families/marginal_slope_shared), all using these same constants. For a
/// stateless `u64 -> u64` hash use [`splitmix64_hash`].
#[inline]
pub(crate) const fn splitmix64(state: &mut u64) -> u64 {
    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    const MIX_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MIX_B: u64 = 0x94D0_49BB_1331_11EB;

    let advanced = (*state).wrapping_add(GOLDEN_GAMMA);
    *state = advanced;
    let stage1 = (advanced ^ (advanced >> 30)).wrapping_mul(MIX_A);
    let stage2 = (stage1 ^ (stage1 >> 27)).wrapping_mul(MIX_B);
    stage2 ^ (stage2 >> 31)
}
30
31/// Pure-hash flavour of [`splitmix64`]: takes a single `u64` seed and
32/// returns a mixed value without persisting state. Equivalent to
33/// `{ let mut s = x; splitmix64(&mut s) }`.
34#[inline]
35pub(crate) const fn splitmix64_hash(x: u64) -> u64 {
36    let mut state = x;
37    splitmix64(&mut state)
38}
39
40/// Vertically concatenate 1D blocks into a single contiguous vector.
41///
42/// Blocks are copied in order into a freshly allocated `Array1` whose length
43/// is the sum of the block lengths. Canonical home for the implementation that
44/// previously lived as identical module-local copies in
45/// `families/latent_survival.rs` and `families/survival_location_scale.rs`,
46/// where it stacks per-segment offset vectors (entry / exit / derivative) into
47/// one design offset.
48pub(crate) fn stack_offsets(blocks: &[&Array1<f64>]) -> Array1<f64> {
49    let total: usize = blocks.iter().map(|block| block.len()).sum();
50    let mut out = Array1::<f64>::zeros(total);
51    let mut row = 0usize;
52    for block in blocks {
53        let end = row + block.len();
54        out.slice_mut(ndarray::s![row..end]).assign(block);
55        row = end;
56    }
57    out
58}
59
/// Numerically stable softplus, `log(1 + exp(x))`.
///
/// Evaluated as `max(x, 0) + log1p(exp(-|x|))`: the exponent is never
/// positive, so `exp` cannot overflow for large positive `x`, and `ln_1p`
/// keeps full precision when `exp(x)` is tiny (large negative `x`), where
/// forming `1 + exp(x)` would round to 1. Replaces the former
/// `stable_softplus` copies in `terms/smooth.rs` and `families/gamlss.rs`.
#[inline]
pub(crate) fn stable_softplus(x: f64) -> f64 {
    // Split into the linear part and a guaranteed non-positive exponent.
    let (linear, exponent) = if x > 0.0 { (x, -x) } else { (0.0, x) };
    linear + exponent.exp().ln_1p()
}
74
/// Numerically stable logistic function `σ(x) = 1 / (1 + exp(-x))`.
///
/// The formula is chosen by the sign of `x` so the argument of `exp` is
/// never positive and cannot overflow:
///   σ(x) = exp(x) / (1 + exp(x))   for x < 0,
///   σ(x) = 1 / (1 + exp(-x))       otherwise.
///
/// Shared replacement for the bit-identical `logistic`
/// (`terms/analytic_penalties.rs`), `sigmoid_stable` (`inference/hmc.rs`),
/// and `sigmoid_scalar` (`terms/sae_manifold.rs`).
#[inline]
pub(crate) fn stable_logistic(x: f64) -> f64 {
    if x < 0.0 {
        let e = x.exp();
        return e / (1.0 + e);
    }
    1.0 / (1.0 + (-x).exp())
}
94
95/// Generic finiteness check for any `f64` ndarray view (1-D, 2-D, etc.).
96#[inline]
97pub(crate) fn array_is_finite<S, D>(values: &ArrayBase<S, D>) -> bool
98where
99    S: Data<Elem = f64>,
100    D: Dimension,
101{
102    values.iter().all(|v| v.is_finite())
103}
104
/// Infinity norm `max |x|` of a stream of `f64` values.
///
/// Shared replacement for the `iter().fold(0.0, |a, b| a.max(b.abs()))`
/// idiom previously repeated in `solver/pirls.rs`,
/// `inference/predict_input.rs`, and `terms/construction.rs`. Returns `0.0`
/// for an empty input. Because `f64::max` returns the non-NaN operand, NaN
/// entries are skipped rather than propagated, so check finiteness separately
/// when that matters.
#[inline]
pub(crate) fn inf_norm<I: IntoIterator<Item = f64>>(values: I) -> f64 {
    let mut largest = 0.0_f64;
    for value in values {
        largest = largest.max(value.abs());
    }
    largest
}
113
/// Target condition number for ridge stabilization. Not referenced by the
/// code visible in this section — presumably consumed by `RidgePlanner`,
/// which is defined elsewhere; TODO confirm.
const HESSIAN_CONDITION_TARGET: f64 = 1e10;
/// Ridge-escalation cap for `StableSolver::factorize_with_ridge_plan`: it
/// gives up once `planner.attempts()` reaches this value.
const MAX_FACTORIZATION_ATTEMPTS: usize = 4;
/// Maximum number of ridge levels tried by
/// `StableSolver::solvevectorwithridge_retries` (ridge grows 10x per retry).
const MAX_SOLVE_RETRIES: usize = 8;
117
/// Compensated (Kahan) summation accumulator.
///
/// Carries a correction term with the running total, so the rounding error
/// from each addition is fed back into the next one instead of being lost.
/// `KahanSum::default()` starts at zero.
#[derive(Default, Clone, Copy)]
pub(crate) struct KahanSum {
    /// Running (rounded) total.
    sum: f64,
    /// Running compensation: the low-order bits lost by previous additions.
    c: f64,
}

impl KahanSum {
    /// Add `value` to the accumulator with error compensation.
    #[inline]
    pub(crate) fn add(&mut self, value: f64) {
        let corrected = value - self.c;
        let next_total = self.sum + corrected;
        // `(next_total - sum)` is the part of `corrected` that actually
        // landed in the total; subtracting `corrected` isolates what was
        // rounded away, to be re-applied on the next call.
        self.c = (next_total - self.sum) - corrected;
        self.sum = next_total;
    }

    /// Current compensated total.
    #[inline]
    pub(crate) fn sum(self) -> f64 {
        let Self { sum, .. } = self;
        sum
    }
}
138
139/// Compute `matrix^{-1}` with a stabilization ridge added solely to make
140/// the Cholesky factorization succeed.
141///
142/// **Stabilization semantics:** the ridge applied here is a
143/// [`StabilizationKind::NumericalPerturbation`](crate::types::StabilizationKind)
144/// — it does NOT change the model, the objective, the gradient, or
145/// anything serialized. The returned matrix is treated by callers as
146/// `(matrix)^{-1}`, not `(matrix + δ I)^{-1}`. Callers that genuinely
147/// need a model-level prior must build that prior into `matrix` *before*
148/// calling this function and pass through a `RidgePassport` /
149/// `StabilizationLedger::explicit_prior` so the same δ is also accounted
150/// for in objective, logdet, and saved state.
151pub(crate) fn matrix_inversewith_regularization(
152    matrix: &Array2<f64>,
153    label: &str,
154) -> Option<Array2<f64>> {
155    StableSolver::new(label).inversewith_regularization(matrix)
156}
157
158pub(crate) struct StableSolver<'a> {
159    label: &'a str,
160}
161
162impl<'a> StableSolver<'a> {
163    pub(crate) fn new(label: &'a str) -> Self {
164        Self { label }
165    }
166
167    pub(crate) fn factorize(
168        &self,
169        matrix: &Array2<f64>,
170    ) -> Result<crate::faer_ndarray::FaerSymmetricFactor, FaerLinalgError> {
171        let view = FaerArrayView::new(matrix);
172        factorize_symmetricwith_fallback(view.as_ref(), Side::Lower)
173    }
174
175    /// Generic factorize accepting any 2-D ndarray storage (owned or view).
176    /// Useful for hot loops that solve a contiguous subblock of a hoisted
177    /// workspace buffer without reallocating an owned `Array2`.
178    pub(crate) fn factorize_any<S>(
179        &self,
180        matrix: &ArrayBase<S, ndarray::Ix2>,
181    ) -> Result<crate::faer_ndarray::FaerSymmetricFactor, FaerLinalgError>
182    where
183        S: Data<Elem = f64>,
184    {
185        let view = FaerArrayView::new(matrix);
186        factorize_symmetricwith_fallback(view.as_ref(), Side::Lower)
187    }
188
189    pub(crate) fn inversewith_regularization(&self, matrix: &Array2<f64>) -> Option<Array2<f64>> {
190        let p = matrix.nrows();
191        if p == 0 || matrix.ncols() != p {
192            return None;
193        }
194
195        let mut planner = RidgePlanner::new(matrix);
196        let (factor, _, regularized) = self.factorize_with_ridge_plan(matrix, &mut planner)?;
197        let mut inv = Array2::<f64>::eye(p);
198        let mut invview = array2_to_matmut(&mut inv);
199        factor.solve_in_place(invview.as_mut());
200
201        if !inv.iter().all(|v| v.is_finite()) {
202            log::warn!("Non-finite inverse produced for {}", self.label);
203            return None;
204        }
205
206        // Numerical solves can leave tiny asymmetry; enforce symmetry explicitly.
207        for i in 0..p {
208            for j in (i + 1)..p {
209                let avg = 0.5 * (inv[[i, j]] + inv[[j, i]]);
210                inv[[i, j]] = avg;
211                inv[[j, i]] = avg;
212            }
213        }
214        assert_eq!(regularized.nrows(), p);
215        Some(inv)
216    }
217
218    pub(crate) fn solvevectorwithridge_retries(
219        &self,
220        matrix: &Array2<f64>,
221        rhs: &Array1<f64>,
222        baseridge: f64,
223    ) -> Option<Array1<f64>> {
224        let p = matrix.nrows();
225        if matrix.ncols() != p || rhs.len() != p {
226            return None;
227        }
228
229        // Scale the ridge by the matrix's diagonal magnitude so it is
230        // *rank-revealing* rather than absolute. A fixed `baseridge = 1e-10`
231        // is meaningless for a Hessian whose largest diagonal is `O(1e8)`
232        // (relative perturbation `1e-18` — well below f64 round-off) and
233        // simultaneously over-regularises a diagonal of `O(1e-5)`. Anchoring
234        // the ridge to `max_abs_diag(H)` makes the relative regularisation
235        // strength independent of how the family scales its likelihood, so
236        // null directions (eigenvalues < ridge) get treated consistently
237        // across blocks. Without this, the joint-Newton solver returns
238        // proposals with `|prop|∞ ≈ |g|/σ_min(H) = O(1e5–1e12)` because the
239        // absolute `1e-10` ridge cannot reach the smallest eigenvalue of an
240        // O(1e-5)-scale block while the largest block has `σ_max = 1e8`.
241        let diag_scale = max_abs_diag(matrix);
242        for retry in 0..MAX_SOLVE_RETRIES {
243            let ridge = if baseridge > 0.0 {
244                baseridge * diag_scale * 10f64.powi(retry as i32)
245            } else {
246                0.0
247            };
248            let h = addridge(matrix, ridge);
249            let factor = match self.factorize(&h) {
250                Ok(f) => f,
251                Err(_) => continue,
252            };
253            let mut out = rhs.clone();
254            let mut out_mat = crate::faer_ndarray::array1_to_col_matmut(&mut out);
255            factor.solve_in_place(out_mat.as_mut());
256            if out.iter().all(|v| v.is_finite()) {
257                return Some(out);
258            }
259        }
260        None
261    }
262
263    /// Solve `matrix · δ = rhs` with a rank-revealing fallback for the
264    /// case where `matrix` has a near-null subspace aligned with `rhs`.
265    ///
266    /// First attempts the regularised Cholesky path
267    /// (`solvevectorwithridge_retries`). If the produced δ satisfies the
268    /// linear equation well (`‖matrix·δ − rhs‖∞ / (1 + ‖rhs‖∞) < rel_tol`),
269    /// returns it. Otherwise the matrix has a real null subspace and the
270    /// Tikhonov-regularised Newton step leaves a residual of magnitude
271    /// ≈ ‖rhs_null‖ — the joint-Newton convergence test then fails
272    /// (`linearized_rel ≈ 1`) and the seed is rejected.
273    ///
274    /// In that case we fall back to the truncated-eigendecomposition
275    /// pseudoinverse:
276    ///
277    ///     δ = Σ_k (uₖᵀ rhs / λₖ) · uₖ      for k with |λₖ| > cutoff
278    ///
279    /// where `(λₖ, uₖ)` are the eigenpairs of `matrix` (assumed symmetric).
280    /// Components in `null(matrix)` (i.e. |λₖ| ≤ cutoff) are *excluded* from
281    /// the sum. This is the unique minimum-norm least-squares solution to
282    /// `matrix · δ ≈ rhs`. For components of `rhs` in `range(matrix)`, δ
283    /// solves the equation exactly; for components in `null(matrix)`, δ has
284    /// zero contribution (no spurious huge step) and the joint-Newton's
285    /// constrained-stationary certificate sees a *correctly small*
286    /// projected residual.
287    ///
288    /// The cutoff is `rank_tol × max(|λ|)`, the standard rank-revealing
289    /// threshold. For p ≲ a few hundred (joint Newton at biobank scale
290    /// has p = 33) the eigendecomposition is sub-millisecond and saves
291    /// the entire outer optimisation from rejecting ill-conditioned ρ.
292    pub(crate) fn solve_with_pseudoinverse_fallback(
293        &self,
294        matrix: &Array2<f64>,
295        rhs: &Array1<f64>,
296        baseridge: f64,
297        rel_tol: f64,
298        rank_tol: f64,
299    ) -> Option<Array1<f64>> {
300        use crate::faer_ndarray::FaerEigh;
301        use faer::Side;
302
303        let p = matrix.nrows();
304        if matrix.ncols() != p || rhs.len() != p {
305            return None;
306        }
307
308        // First try the regularised Cholesky path.
309        let delta = self.solvevectorwithridge_retries(matrix, rhs, baseridge)?;
310
311        // Compute the linear residual ‖matrix·δ − rhs‖∞ / (1 + ‖rhs‖∞)
312        // — the same quantity the joint-Newton convergence test reads off as
313        // `linearized_next_kkt_inf` / (1 + `old_kkt_inf`).
314        let matrix_delta = matrix.dot(&delta);
315        let residual_inf = matrix_delta
316            .iter()
317            .zip(rhs.iter())
318            .map(|(h, r)| (h - r).abs())
319            .fold(0.0_f64, f64::max);
320        let rhs_inf = rhs.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
321        let rel = residual_inf / (1.0 + rhs_inf);
322
323        if rel.is_finite() && rel < rel_tol {
324            return Some(delta);
325        }
326
327        // Rank-deficient. Use truncated eigendecomposition pseudoinverse.
328        let (eigvals, eigvecs) = matrix.eigh(Side::Lower).ok()?;
329        let max_abs_eig = eigvals.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
330        if !max_abs_eig.is_finite() || max_abs_eig <= 0.0 {
331            return Some(delta);
332        }
333        let cutoff = rank_tol * max_abs_eig;
334
335        let mut pseudo = Array1::<f64>::zeros(p);
336        let mut excluded = 0usize;
337        for k in 0..p {
338            let lam = eigvals[k];
339            if !lam.is_finite() || lam.abs() <= cutoff {
340                excluded += 1;
341                continue;
342            }
343            let u_k = eigvecs.column(k);
344            let proj = u_k.iter().zip(rhs.iter()).map(|(u, r)| u * r).sum::<f64>();
345            let scale = proj / lam;
346            for i in 0..p {
347                pseudo[i] += scale * u_k[i];
348            }
349        }
350
351        if !pseudo.iter().all(|v| v.is_finite()) {
352            return Some(delta);
353        }
354
355        log::debug!(
356            "[{}] pseudoinverse fallback engaged: rel = {:.3e} > rel_tol = {:.3e}, \
357             excluded {} of {} eigenvalues below cutoff = {:.3e} × max |λ| = {:.3e}",
358            self.label,
359            rel,
360            rel_tol,
361            excluded,
362            p,
363            rank_tol,
364            max_abs_eig,
365        );
366
367        Some(pseudo)
368    }
369
370    fn factorize_with_ridge_plan(
371        &self,
372        matrix: &Array2<f64>,
373        planner: &mut RidgePlanner,
374    ) -> Option<(crate::faer_ndarray::FaerSymmetricFactor, f64, Array2<f64>)> {
375        loop {
376            let ridge = planner.ridge();
377            let h_eff = addridge(matrix, ridge);
378            if let Ok(factor) = self.factorize(&h_eff) {
379                return Some((factor, ridge, h_eff));
380            }
381            if planner.attempts() >= MAX_FACTORIZATION_ATTEMPTS {
382                log::warn!(
383                    "Failed to factorize {} after ridge {:.3e}",
384                    self.label,
385                    ridge
386                );
387                return None;
388            }
389            planner.bumpwith_matrix(matrix);
390        }
391    }
392}
393
394pub(crate) fn max_abs_diag(matrix: &Array2<f64>) -> f64 {
395    matrix
396        .diag()
397        .iter()
398        .copied()
399        .map(f64::abs)
400        .fold(0.0, f64::max)
401        .max(1.0)
402}
403
/// Describe a row-count disagreement between the response `y`, the weights
/// `w`, the design matrix `X`, and the offset vector.
///
/// Returns `None` when all four lengths agree. Otherwise returns a message
/// listing every length, so the caller can see which input is off.
pub(crate) fn row_mismatch_message(
    y_len: usize,
    w_len: usize,
    x_rows: usize,
    offset_len: usize,
) -> Option<String> {
    let all_agree = [w_len, x_rows, offset_len].iter().all(|&len| len == y_len);
    if all_agree {
        return None;
    }
    Some(format!(
        "Row mismatch: y={}, w={}, X.rows={}, offset={}",
        y_len, w_len, x_rows, offset_len
    ))
}
419
/// Validate the shapes passed to `predict_gam`.
///
/// The column/coefficient check runs first, so when both checks fail only
/// the column message is reported. Returns `None` when `X` has `beta_len`
/// columns and `offset_len` rows.
pub(crate) fn predict_gam_dimension_mismatch_message(
    x_rows: usize,
    x_cols: usize,
    beta_len: usize,
    offset_len: usize,
) -> Option<String> {
    if x_cols != beta_len {
        Some(format!(
            "predict_gam dimension mismatch: X has {} columns but beta has length {}",
            x_cols, beta_len
        ))
    } else if x_rows != offset_len {
        Some(format!(
            "predict_gam dimension mismatch: X has {} rows but offset has length {}",
            x_rows, offset_len
        ))
    } else {
        None
    }
}
440
441pub(crate) fn add_relative_diag_ridge(matrix: &mut Array2<f64>, scale: f64, floor: f64) -> f64 {
442    let ridge = scale
443        * matrix
444            .diag()
445            .iter()
446            .map(|&value| value.abs())
447            .fold(0.0, f64::max)
448            .max(floor);
449    for idx in 0..matrix.nrows() {
450        matrix[[idx, idx]] += ridge;
451    }
452    ridge
453}
454
455pub(crate) fn boundary_hit_indices(
456    values: ArrayView1<'_, f64>,
457    bound: f64,
458    tolerance: f64,
459) -> (Vec<usize>, Vec<usize>) {
460    let at_lower = values
461        .iter()
462        .enumerate()
463        .filter_map(|(idx, &value)| (value <= -bound + tolerance).then_some(idx))
464        .collect();
465    let at_upper = values
466        .iter()
467        .enumerate()
468        .filter_map(|(idx, &value)| (value >= bound - tolerance).then_some(idx))
469        .collect();
470    (at_lower, at_upper)
471}
472
473/// SPD-only spectrum condition number: λ_max / λ_min on the principal
474/// (positive-eigenvalue) spectrum.
475///
476/// **Invariant:** caller must have already established the matrix is
477/// positive definite. For indefinite matrices λ_min may be negative or
478/// zero and the ratio max/min becomes meaningless (it can be negative or
479/// infinite even when the matrix is well-scaled). When the spectrum sign is
480/// unknown, inspect inertia directly via [`symmetric_extremes`].
481pub(crate) fn symmetric_spectrum_condition_number(matrix: &Array2<f64>) -> f64 {
482    matrix
483        .eigh(Side::Lower)
484        .ok()
485        .map(|(evals, _)| {
486            let min = evals
487                .iter()
488                .fold(f64::INFINITY, |acc, &value| acc.min(value));
489            let max = evals
490                .iter()
491                .fold(f64::NEG_INFINITY, |acc, &value| acc.max(value));
492            max / min.max(1e-12)
493        })
494        .unwrap_or(f64::NAN)
495}
496
497/// Estimate min/max eigenvalues of a symmetric matrix via a short
498/// `eigh` call. Used by the inertia-aware stabilization rule below.
499/// Returns `None` if the eigensolver fails.
500pub(crate) fn symmetric_extremes(matrix: &Array2<f64>) -> Option<(f64, f64)> {
501    let (evals, _) = matrix.eigh(Side::Lower).ok()?;
502    let mut min = f64::INFINITY;
503    let mut max = f64::NEG_INFINITY;
504    for &v in evals.iter() {
505        if v < min {
506            min = v;
507        }
508        if v > max {
509            max = v;
510        }
511    }
512    if min.is_finite() && max.is_finite() {
513        Some((min, max))
514    } else {
515        None
516    }
517}
518
519/// Enforce exact symmetry on a square matrix by averaging off-diagonal pairs.
520/// Canonical in-place symmetrizer for a dense square `ndarray` matrix.
521///
522/// Replaces each off-diagonal pair `(i, j)`/`(j, i)` with their arithmetic
523/// mean, leaving the diagonal untouched. This is the single source of truth
524/// for the ndarray symmetrize operation; the `gam-pyffi` crate routes its
525/// former local `symmetrize_in_place` through this function. The faer-typed
526/// equivalent (`symmetrize_faer_matrix_in_place` in `terms/construction.rs`)
527/// is kept separate because it operates on `faer::Mat`.
528pub fn enforce_symmetry(matrix: &mut Array2<f64>) {
529    let n = matrix.nrows();
530    assert_eq!(n, matrix.ncols());
531    for i in 0..n {
532        for j in i + 1..n {
533            let avg = 0.5 * (matrix[[i, j]] + matrix[[j, i]]);
534            matrix[[i, j]] = avg;
535            matrix[[j, i]] = avg;
536        }
537    }
538}
539
540pub(crate) fn addridge(matrix: &Array2<f64>, ridge: f64) -> Array2<f64> {
541    if ridge <= 0.0 {
542        return matrix.clone();
543    }
544    let mut regularized = matrix.clone();
545    let n = regularized.nrows();
546    for i in 0..n {
547        regularized[[i, i]] += ridge;
548    }
549    regularized
550}
551
/// Step length at which a linearly decreasing slack reaches zero, if that
/// happens strictly before `current_step_limit`.
///
/// `slack` is the current distance to a boundary. `directional_slack_change`
/// is its derivative along the step direction.
///
/// Returns `None` in these cases:
/// - an input is non-finite, or `current_step_limit <= 0`;
/// - the slack is not decreasing by more than a small scale-relative
///   tolerance, `max(64·ε·max(|slack|, |change|, |limit|, 1), 1e-14)`;
/// - the hit would occur at or beyond the limit.
///
/// Negative slack (already past the boundary) yields `Some(0.0)`.
pub(crate) fn boundary_hit_step_fraction(
    slack: f64,
    directional_slack_change: f64,
    current_step_limit: f64,
) -> Option<f64> {
    let inputs_usable = slack.is_finite()
        && directional_slack_change.is_finite()
        && current_step_limit.is_finite()
        && current_step_limit > 0.0;
    if !inputs_usable {
        return None;
    }

    let magnitude = [
        slack.abs(),
        directional_slack_change.abs(),
        current_step_limit.abs(),
    ]
    .iter()
    .fold(1.0_f64, |acc, &v| acc.max(v));
    let directional_tol = (64.0 * f64::EPSILON * magnitude).max(1e-14);
    let decreasing = directional_slack_change < -directional_tol;
    if !decreasing {
        return None;
    }

    let step = (slack / -directional_slack_change).max(0.0);
    (step.is_finite() && step < current_step_limit).then_some(step)
}
581
/// Convergence diagnostics reported by the SPD PCG solvers.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PcgSolveInfo {
    /// Number of CG iterations performed (0 when `rhs` is exactly zero).
    pub iterations: usize,
    /// Whether the residual tolerance was met.
    pub converged: bool,
    /// `final_residual_norm / max(‖rhs‖₂, 1)` — the same scaling used to
    /// form the convergence tolerance.
    pub relative_residual_norm: f64,
    /// Residual norm at the start of the solve (`‖rhs‖₂`, since the
    /// iteration starts from `x = 0`).
    pub initial_residual_norm: f64,
    /// Euclidean norm of the solver's residual vector when it stopped.
    pub final_residual_norm: f64,
    /// `final_residual_norm / initial_residual_norm`, or `0.0` when the
    /// initial residual is zero.
    pub residual_reduction: f64,
    /// Ritz-value estimate of the preconditioned operator's condition number,
    /// taken from the CG Lanczos tridiagonal; `None` when it could not be
    /// formed.
    pub condition_estimate: Option<f64>,
}
592
593#[derive(Debug, Clone)]
594struct PcgDiagnostics {
595    residuals: Vec<f64>,
596    alpha: Vec<f64>,
597    beta: Vec<f64>,
598}
599
600impl PcgDiagnostics {
601    fn new(initial_residual_norm: f64) -> Self {
602        Self {
603            residuals: vec![initial_residual_norm],
604            alpha: Vec::new(),
605            beta: Vec::new(),
606        }
607    }
608
609    fn push_iteration(&mut self, alpha: f64, beta: Option<f64>, residual_norm: f64) {
610        self.alpha.push(alpha);
611        if let Some(beta) = beta {
612            self.beta.push(beta);
613        }
614        self.residuals.push(residual_norm);
615    }
616
617    fn condition_estimate(&self) -> Option<f64> {
618        // Build the CG Lanczos tridiagonal for the preconditioned operator.
619        // For SPD CG, T has diagonal 1/a_i + b_{i-1}/a_{i-1} and off-diagonal
620        // sqrt(b_i)/a_i. Its eigenvalues are the Ritz estimates of the
621        // preconditioned operator's spectrum; cond ≈ λ_max(T) / λ_min(T).
622        //
623        // Previous code substituted Gershgorin disc bounds for the Ritz
624        // values. Those bounds are guaranteed *enclosures*, not estimates:
625        // they are systematically pessimistic and frequently produce a
626        // negative lower bound even for SPD T, which then collapsed the
627        // condition estimate to `None` and lost the diagnostic. With k ≤ 256
628        // a direct symmetric eigensolve is microseconds and yields the
629        // genuine Ritz values.
630        let k = self.alpha.len();
631        if k == 0 || k > 256 {
632            return None;
633        }
634        let mut t = ndarray::Array2::<f64>::zeros((k, k));
635        for i in 0..k {
636            let alpha_i = self.alpha[i];
637            if !alpha_i.is_finite() || alpha_i <= 0.0 {
638                return None;
639            }
640            let mut diag = 1.0 / alpha_i;
641            if i > 0 {
642                let beta_prev = self.beta.get(i - 1).copied()?;
643                if !beta_prev.is_finite() || beta_prev < 0.0 {
644                    return None;
645                }
646                diag += beta_prev / self.alpha[i - 1];
647            }
648            t[[i, i]] = diag;
649            if i + 1 < k {
650                let beta_i = self.beta.get(i).copied().unwrap_or(0.0);
651                if !beta_i.is_finite() || beta_i < 0.0 {
652                    return None;
653                }
654                let off = beta_i.sqrt() / alpha_i;
655                t[[i, i + 1]] = off;
656                t[[i + 1, i]] = off;
657            }
658        }
659        let (evals, _) = t.eigh(Side::Lower).ok()?;
660        let mut lower = f64::INFINITY;
661        let mut upper = f64::NEG_INFINITY;
662        for &v in evals.iter() {
663            if !v.is_finite() {
664                return None;
665            }
666            if v < lower {
667                lower = v;
668            }
669            if v > upper {
670                upper = v;
671            }
672        }
673        if lower > 0.0 && upper > 0.0 {
674            Some(upper / lower)
675        } else {
676            None
677        }
678    }
679
680    fn info(
681        &self,
682        iterations: usize,
683        converged: bool,
684        rhs_norm: f64,
685        final_residual_norm: f64,
686    ) -> PcgSolveInfo {
687        let initial = self.residuals.first().copied().unwrap_or(rhs_norm);
688        PcgSolveInfo {
689            iterations,
690            converged,
691            relative_residual_norm: final_residual_norm / rhs_norm.max(1.0),
692            initial_residual_norm: initial,
693            final_residual_norm,
694            residual_reduction: if initial > 0.0 {
695                final_residual_norm / initial
696            } else {
697                0.0
698            },
699            condition_estimate: self.condition_estimate(),
700        }
701    }
702}
703
704pub fn solve_spd_pcg_with_info<F>(
705    apply: F,
706    rhs: &Array1<f64>,
707    preconditioner_diag: &Array1<f64>,
708    rel_tol: f64,
709    max_iter: usize,
710) -> Option<(Array1<f64>, PcgSolveInfo)>
711where
712    F: Fn(&Array1<f64>) -> Array1<f64>,
713{
714    let p = rhs.len();
715    if p == 0 || preconditioner_diag.len() != p || max_iter == 0 {
716        return None;
717    }
718    let rhs_norm = rhs.dot(rhs).sqrt();
719    if !rhs_norm.is_finite() {
720        return None;
721    }
722    if rhs_norm == 0.0 {
723        return Some((
724            Array1::<f64>::zeros(p),
725            PcgSolveInfo {
726                iterations: 0,
727                converged: true,
728                relative_residual_norm: 0.0,
729                initial_residual_norm: 0.0,
730                final_residual_norm: 0.0,
731                residual_reduction: 0.0,
732                condition_estimate: None,
733            },
734        ));
735    }
736
737    let tol = rel_tol.max(1e-12) * rhs_norm.max(1.0);
738    let mut x = Array1::<f64>::zeros(p);
739    let mut r = rhs.clone();
740    let mut diagnostics = PcgDiagnostics::new(rhs_norm);
741
742    // Precompute reciprocal preconditioner once. Each PCG iteration applies
743    // M^{-1} via a single elementwise multiply (z = inv_m * r).
744    //
745    // SPD-PCG requires a strictly positive preconditioner (M ≻ 0). A
746    // non-positive diagonal entry is a contract violation by the caller —
747    // either the matrix is not actually SPD, or it has a structural zero.
748    // Silently `abs()`-ing the value (the historical behavior) hides this
749    // and produces a "solution" that does not minimize the SPD energy.
750    // Instead, fall through to `None` so the caller routes to a
751    // direct-factorization or indefinite Krylov path. We still tolerate
752    // very small positive values via a 1e-12 floor for numerical noise.
753    let mut inv_m = Array1::<f64>::zeros(p);
754    let mut bad_diag = false;
755    for (slot, &m) in inv_m.iter_mut().zip(preconditioner_diag.iter()) {
756        if !m.is_finite() || m < 0.0 {
757            // Negative or non-finite preconditioner diagonal violates the
758            // SPD-PCG contract (M ≻ 0). Hard error rather than silent
759            // `abs()`: caller should route to a direct factorization or
760            // indefinite Krylov path. Exactly-zero entries are treated as
761            // numerical noise and floored to 1e-12.
762            bad_diag = true;
763            break;
764        }
765        *slot = 1.0 / m.max(1e-12);
766    }
767    if bad_diag {
768        log::warn!(
769            "SPD PCG rejected: preconditioner diagonal contained a negative or \
770             non-finite entry; caller should route to a direct factorization \
771             or indefinite Krylov path."
772        );
773        return None;
774    }
775
776    let mut z = Array1::<f64>::zeros(p);
777    Zip::from(&mut z)
778        .and(&r)
779        .and(&inv_m)
780        .par_for_each(|zi, &ri, &im| {
781            *zi = ri * im;
782        });
783    let mut p_dir = z.clone();
784    let mut rz_old = r.dot(&z);
785    if !rz_old.is_finite() || rz_old <= 0.0 {
786        return None;
787    }
788
789    for iter in 0..max_iter {
790        let ap = apply(&p_dir);
791        if ap.len() != p {
792            return None;
793        }
794        let denom = p_dir.dot(&ap);
795        if !denom.is_finite() || denom <= 0.0 {
796            return None;
797        }
798        let alpha = rz_old / denom;
799        if !alpha.is_finite() {
800            return None;
801        }
802        x.scaled_add(alpha, &p_dir);
803        r.scaled_add(-alpha, &ap);
804        if (iter + 1) % 32 == 0 {
805            // Periodic residual refresh: r <- rhs - A x. Done in-place via
806            // assign + scaled_add to avoid the prior fresh-allocation pattern
807            // (`r = rhs - &ax;`) inside the hot loop.
808            let ax = apply(&x);
809            if ax.len() != p {
810                return None;
811            }
812            r.assign(rhs);
813            r.scaled_add(-1.0, &ax);
814        }
815        let r_norm = r.dot(&r).sqrt();
816        if r_norm.is_finite() && r_norm <= tol {
817            diagnostics.push_iteration(alpha, None, r_norm);
818            return x
819                .iter()
820                .all(|v| v.is_finite())
821                .then_some((x, diagnostics.info(iter + 1, true, rhs_norm, r_norm)));
822        }
823        Zip::from(&mut z)
824            .and(&r)
825            .and(&inv_m)
826            .par_for_each(|zi, &ri, &im| {
827                *zi = ri * im;
828            });
829        let rz_new = r.dot(&z);
830        if !rz_new.is_finite() || rz_new <= 0.0 {
831            return None;
832        }
833        let beta = rz_new / rz_old;
834        if !beta.is_finite() {
835            return None;
836        }
837        diagnostics.push_iteration(alpha, Some(beta), r_norm);
838        // p <- z + beta * p (fused, SIMD-friendly via ndarray::Zip; parallel
839        // over coefficient dimension at biobank-scale p).
840        Zip::from(&mut p_dir).and(&z).par_for_each(|pi, &zi| {
841            *pi = zi + beta * *pi;
842        });
843        rz_old = rz_new;
844    }
845    None
846}
847
848pub fn solve_spd_pcg<F>(
849    apply: F,
850    rhs: &Array1<f64>,
851    preconditioner_diag: &Array1<f64>,
852    rel_tol: f64,
853    max_iter: usize,
854) -> Option<Array1<f64>>
855where
856    F: Fn(&Array1<f64>) -> Array1<f64>,
857{
858    solve_spd_pcg_with_info(apply, rhs, preconditioner_diag, rel_tol, max_iter)
859        .map(|(solution, _)| solution)
860}
861
/// Write-into variant of `solve_spd_pcg_with_info` that takes an apply closure
/// of the form `Fn(&Array1<f64>, &mut Array1<f64>)` so the matvec can write into
/// a caller-owned buffer. This eliminates the per-iteration `Array1::<f64>`
/// allocation for the matvec result that the legacy closure-returning variant
/// forces. See commit 83369abb for the analogous penalty-vector elimination.
///
/// Solves `A x = rhs` by Jacobi-preconditioned conjugate gradients, where
/// `apply(v, out)` must store `A v` into `out`.
///
/// # `apply` contract
/// A single scratch buffer is reused for every call and is NOT zeroed between
/// calls. It holds the previous matvec result, so `apply` must overwrite every
/// entry rather than accumulate into `out`. If `apply` leaves `out` with a
/// length other than `rhs.len()`, the solve returns `None`.
///
/// # Tolerance
/// Converged when `‖r‖₂ <= max(rel_tol, 1e-12) * max(‖rhs‖₂, 1)`. For
/// `‖rhs‖₂ < 1` this is an absolute tolerance, not a relative one.
///
/// # Returns
/// * `Some((zeros, info))` with `converged: true` and zero iterations when
///   `rhs` is exactly zero.
/// * `Some((x, info))` on convergence, only if every entry of `x` is finite.
/// * `None` when:
///   - `rhs` is empty, or `preconditioner_diag.len() != rhs.len()`, or
///     `max_iter == 0`;
///   - `rhs` has a non-finite norm;
///   - any preconditioner entry is non-finite or `<= 0`;
///   - a CG breakdown occurs (non-finite or non-positive `rᵀz` or `pᵀAp`,
///     which indicates a non-SPD operator or numerical failure);
///   - `max_iter` iterations run without convergence.
pub fn solve_spd_pcg_with_info_into<F>(
    apply: F,
    rhs: &Array1<f64>,
    preconditioner_diag: &Array1<f64>,
    rel_tol: f64,
    max_iter: usize,
) -> Option<(Array1<f64>, PcgSolveInfo)>
where
    F: Fn(&Array1<f64>, &mut Array1<f64>),
{
    let p = rhs.len();
    if p == 0 || preconditioner_diag.len() != p || max_iter == 0 {
        return None;
    }
    let rhs_norm = rhs.dot(rhs).sqrt();
    if !rhs_norm.is_finite() {
        return None;
    }
    // Trivial system: x = 0 is exact. Reported as converged in zero iterations.
    if rhs_norm == 0.0 {
        return Some((
            Array1::<f64>::zeros(p),
            PcgSolveInfo {
                iterations: 0,
                converged: true,
                relative_residual_norm: 0.0,
                initial_residual_norm: 0.0,
                final_residual_norm: 0.0,
                residual_reduction: 0.0,
                condition_estimate: None,
            },
        ));
    }

    // Absolute stopping threshold. The `max(1.0)` makes it absolute for small rhs.
    let tol = rel_tol.max(1e-12) * rhs_norm.max(1.0);
    // Start from x0 = 0, so the initial residual is rhs itself.
    let mut x = Array1::<f64>::zeros(p);
    let mut r = rhs.clone();
    let mut diagnostics = PcgDiagnostics::new(rhs_norm);

    if preconditioner_diag
        .iter()
        .any(|&m| !m.is_finite() || m <= 0.0)
    {
        return None;
    }
    // Inverse Jacobi diagonal. Entries are clamped from below at 1e-12 so
    // tiny positive diagonals do not produce huge inverses.
    let mut inv_m = Array1::<f64>::zeros(p);
    Zip::from(&mut inv_m)
        .and(preconditioner_diag)
        .par_for_each(|inv, &m| {
            *inv = 1.0 / m.max(1e-12);
        });

    // z0 = M⁻¹ r0 and initial search direction p0 = z0.
    let mut z = Array1::<f64>::zeros(p);
    Zip::from(&mut z)
        .and(&r)
        .and(&inv_m)
        .par_for_each(|zi, &ri, &im| {
            *zi = ri * im;
        });
    let mut p_dir = z.clone();
    let mut rz_old = r.dot(&z);
    if !rz_old.is_finite() || rz_old <= 0.0 {
        return None;
    }

    // Reusable matvec scratch (filled by `apply`). Never re-zeroed; see the
    // `apply` contract above.
    let mut ap = Array1::<f64>::zeros(p);

    for iter in 0..max_iter {
        apply(&p_dir, &mut ap);
        if ap.len() != p {
            return None;
        }
        // Curvature pᵀAp must be strictly positive for an SPD operator.
        let denom = p_dir.dot(&ap);
        if !denom.is_finite() || denom <= 0.0 {
            return None;
        }
        let alpha = rz_old / denom;
        if !alpha.is_finite() {
            return None;
        }
        x.scaled_add(alpha, &p_dir);
        r.scaled_add(-alpha, &ap);
        if (iter + 1) % 32 == 0 {
            // Periodic residual refresh: r <- rhs - A x. Reuse `ap` as scratch
            // for A x to avoid an extra allocation. This limits drift between
            // the recursively updated residual and the true one.
            apply(&x, &mut ap);
            if ap.len() != p {
                return None;
            }
            r.assign(rhs);
            r.scaled_add(-1.0, &ap);
        }
        let r_norm = r.dot(&r).sqrt();
        if r_norm.is_finite() && r_norm <= tol {
            diagnostics.push_iteration(alpha, None, r_norm);
            // Reject the converged iterate if any entry is non-finite.
            return x
                .iter()
                .all(|v| v.is_finite())
                .then_some((x, diagnostics.info(iter + 1, true, rhs_norm, r_norm)));
        }
        // Apply the Jacobi preconditioner: z = M⁻¹ r.
        Zip::from(&mut z)
            .and(&r)
            .and(&inv_m)
            .par_for_each(|zi, &ri, &im| {
                *zi = ri * im;
            });
        let rz_new = r.dot(&z);
        if !rz_new.is_finite() || rz_new <= 0.0 {
            return None;
        }
        let beta = rz_new / rz_old;
        if !beta.is_finite() {
            return None;
        }
        diagnostics.push_iteration(alpha, Some(beta), r_norm);
        // p <- z + beta * p.
        Zip::from(&mut p_dir).and(&z).par_for_each(|pi, &zi| {
            *pi = zi + beta * *pi;
        });
        rz_old = rz_new;
    }
    // Iteration budget exhausted without meeting `tol`.
    None
}
989
/// Escalating diagonal-ridge schedule used to make a symmetric matrix
/// factorizable (e.g. by Cholesky) via `H + δ I`.
///
/// Tracks:
/// * the current ridge `δ`;
/// * how many times it has been bumped;
/// * the matrix scale the step sizes are relative to;
/// * the latest condition-number estimate from the fallback rule.
#[derive(Clone)]
pub(crate) struct RidgePlanner {
    // Condition-number estimate written by the fallback rule in
    // `bumpwith_matrix`. It is only written when the eigenvalue-based rule
    // could not produce a ridge.
    cond_estimate: Option<f64>,
    // Current diagonal perturbation δ.
    ridge: f64,
    // Number of `bumpwith_matrix` calls so far.
    attempts: usize,
    // `max_abs_diag(matrix)` captured at construction (presumably the largest
    // |H_ii| — confirm against `max_abs_diag`). All step sizes are multiples of it.
    scale: f64,
}

impl RidgePlanner {
    /// Seed a planner for `matrix`. Records `max_abs_diag(matrix)` as the scale
    /// and starts the ridge at `scale * 1e-10`, with zero attempts. No
    /// decomposition is performed here.
    pub(crate) fn new(matrix: &Array2<f64>) -> Self {
        let scale = max_abs_diag(matrix);
        let min_step = scale * 1e-10;
        // Most Hessians factorize on the first attempt. Avoid an eager exact
        // condition-number decomposition here and only pay for spectral
        // diagnostics after an actual factorization failure.
        //
        // RidgePlanner is *strictly* a numerical-perturbation device: the
        // perturbation is applied so a Cholesky factorization succeeds for
        // an inverse / linear solve, and the matrix the caller hands back
        // to the rest of the system is the unperturbed one.
        Self {
            cond_estimate: None,
            ridge: min_step,
            attempts: 0,
            scale,
        }
    }

    /// Current ridge δ to add to the diagonal.
    pub(crate) fn ridge(&self) -> f64 {
        self.ridge
    }

    /// Condition number of `matrix` regularized by `ridge` (via `addridge`
    /// when `ridge > 0`, otherwise the matrix as-is). Returns `None` if the
    /// estimate fails or is not a finite positive number.
    #[inline]
    fn estimate_conditionwithridge(&self, matrix: &Array2<f64>, ridge: f64) -> Option<f64> {
        let regularized = if ridge > 0.0 {
            addridge(matrix, ridge)
        } else {
            matrix.clone()
        };
        calculate_condition_number(&regularized)
            .ok()
            .filter(|c| c.is_finite() && *c > 0.0)
    }

    /// Record one more failed attempt and choose a larger ridge for `matrix`.
    ///
    /// With `base = max(ridge, scale·1e-10)`, the primary rule proposes
    /// `max(1.5·(τ − λ_min)₊, 1.5·base, scale·1e-10)` with `τ = scale·1e-8`,
    /// capped at `10·base`. When `symmetric_extremes` yields no spectrum, a
    /// condition-number heuristic is used instead. A non-finite or
    /// non-positive result falls back to `scale`.
    ///
    /// NOTE(review): if `scale == 0`, then `min_step`, `base` (on the first
    /// call) and the `10·base` cap are all 0. Every branch therefore produces 0,
    /// and the final guard resets to `scale`, which is also 0, so the ridge can
    /// never grow. Confirm that `max_abs_diag` cannot return 0, or add an
    /// absolute floor.
    pub(crate) fn bumpwith_matrix(&mut self, matrix: &Array2<f64>) {
        self.attempts += 1;
        let min_step = self.scale * 1e-10;
        let base = self.ridge.max(min_step);

        // Primary rule: inertia-target. Estimate λ_min(H) on the unperturbed
        // matrix; pick δ so that λ_min(H + δ I) ≥ τ for an SPD floor τ tied
        // to the matrix scale. This is a defensible "make it positive
        // definite by exactly the amount needed" rule, in contrast with
        // condition-number sqrt heuristics that happen to land in the same
        // ballpark only by coincidence.
        let spd_floor = self.scale * 1e-8;
        let mut next_ridge = if let Some((lam_min, _lam_max)) = symmetric_extremes(matrix) {
            // δ = max(min_step, τ - λ_min). Multiply by a small safety
            // factor (1.5×) on the deficit so a single eigensolver round-off
            // does not leave us a hair below τ on the first retry.
            let deficit = (spd_floor - lam_min).max(0.0);
            let proposal = (1.5 * deficit).max(base * 1.5).max(min_step);
            // Cap escalation per attempt so we don't shoot past what's
            // needed when λ_min is wildly negative; the surrounding loop
            // will re-bump up to MAX_FACTORIZATION_ATTEMPTS times.
            proposal.min(base * 10.0)
        } else {
            // NaN sentinel routes control into the fallback rule below.
            f64::NAN
        };

        // Fallback rule: condition-number heuristic. Used only when the
        // eigensolver itself failed (rare, usually means a non-finite
        // matrix or extreme scaling).
        if !next_ridge.is_finite() {
            let cond_now = self.estimate_conditionwithridge(matrix, base);
            self.cond_estimate = cond_now;
            next_ridge = if let Some(cond) = cond_now {
                // Grow by sqrt(cond / target) when over target; otherwise by
                // an attempt-dependent factor. Either way the factor is
                // clamped to at most 10×.
                let ratio = cond / HESSIAN_CONDITION_TARGET;
                let mut multiplier = if ratio > 1.0 {
                    ratio.sqrt().clamp(1.5, 10.0)
                } else {
                    (2.0 + self.attempts as f64).clamp(3.0, 10.0)
                };
                let mut proposal = base * multiplier;
                // If the proposed ridge barely improves conditioning (less
                // than a 10% drop), escalate the multiplier further.
                if let Some(cond_next) = self.estimate_conditionwithridge(matrix, proposal)
                    && cond_next > cond * 0.9
                    && ratio > 1.0
                {
                    multiplier = (multiplier * 1.8).clamp(2.0, 10.0);
                    proposal = base * multiplier;
                }
                proposal.max(min_step)
            } else if self.ridge <= 0.0 {
                min_step
            } else {
                (base * 10.0).max(min_step)
            };
        }

        // Last-resort guard: never store a non-finite or non-positive ridge.
        if !next_ridge.is_finite() || next_ridge <= 0.0 {
            next_ridge = self.scale;
        }

        self.ridge = next_ridge;
    }

    /// Number of times `bumpwith_matrix` has been called.
    pub(crate) fn attempts(&self) -> usize {
        self.attempts
    }
}
1100
1101/// Weighted ridge (penalized least-squares) solve for a multi-output Gaussian
1102/// response. Forms the weighted normal equations `XᵀWX (+ λ·penalty) β = XᵀWY`
1103/// (row weights `W = diag(weights)`), factorizes the symmetric system via the
1104/// Cholesky-with-fallback path, solves for the coefficients `(p, d)`, and
1105/// returns `(coefficients, fitted = Xβ)`. Single source of truth shared by the
1106/// `gaussian_weighted_ridge` FFI shim and any core consumer.
1107pub fn gaussian_weighted_ridge(
1108    x: ArrayView2<'_, f64>,
1109    y: ArrayView2<'_, f64>,
1110    penalty: ArrayView2<'_, f64>,
1111    weights: ArrayView1<'_, f64>,
1112    ridge_lambda: f64,
1113) -> Result<(Array2<f64>, Array2<f64>), String> {
1114    let n = x.nrows();
1115    let p = x.ncols();
1116    if n == 0 || p == 0 {
1117        return Err("X cannot be empty".to_string());
1118    }
1119    if y.nrows() != n {
1120        return Err(format!(
1121            "X/Y row mismatch: X has {n} rows but Y has {} rows",
1122            y.nrows()
1123        ));
1124    }
1125    if y.ncols() == 0 {
1126        return Err("Y must have at least one column".to_string());
1127    }
1128    if weights.len() != n {
1129        return Err(format!(
1130            "weights length mismatch: expected {n}, got {}",
1131            weights.len()
1132        ));
1133    }
1134    if penalty.nrows() != p || penalty.ncols() != p {
1135        return Err(format!(
1136            "penalty shape mismatch: expected {p}x{p}, got {}x{}",
1137            penalty.nrows(),
1138            penalty.ncols()
1139        ));
1140    }
1141    if !ridge_lambda.is_finite() || ridge_lambda < 0.0 {
1142        return Err(format!(
1143            "ridge_lambda must be finite and non-negative; got {ridge_lambda}"
1144        ));
1145    }
1146    if x.iter()
1147        .chain(y.iter())
1148        .chain(penalty.iter())
1149        .chain(weights.iter())
1150        .any(|value| !value.is_finite())
1151    {
1152        return Err("weighted ridge inputs must be finite".to_string());
1153    }
1154    if weights.iter().any(|value| *value < 0.0) {
1155        return Err("weights must be non-negative likelihood row weights".to_string());
1156    }
1157
1158    let mut wx = x.to_owned();
1159    let mut wy = y.to_owned();
1160    for i in 0..n {
1161        let wi = weights[i];
1162        wx.row_mut(i).iter_mut().for_each(|value| *value *= wi);
1163        wy.row_mut(i).iter_mut().for_each(|value| *value *= wi);
1164    }
1165    let mut system = x.t().dot(&wx);
1166    if ridge_lambda > 0.0 {
1167        system += &(penalty.to_owned() * ridge_lambda);
1168    }
1169    let rhs = x.t().dot(&wy);
1170    let factor =
1171        factorize_symmetricwith_fallback(FaerArrayView::new(&system).as_ref(), Side::Lower)
1172            .map_err(|err| format!("weighted ridge factorization failed: {err}"))?;
1173    let mut coefficients = rhs;
1174    let mut coefficients_view = array2_to_matmut(&mut coefficients);
1175    factor.solve_in_place(coefficients_view.as_mut());
1176    if coefficients.iter().any(|value| !value.is_finite()) {
1177        return Err("weighted ridge solve produced non-finite coefficients".to_string());
1178    }
1179    let fitted = x.dot(&coefficients);
1180    Ok((coefficients, fitted))
1181}
1182
1183/// Batched [`gaussian_weighted_ridge`]: solve one independent weighted-ridge fit
1184/// per leading-axis slice of the padded `(K, N_max, p)` design / `(K, N_max, d)`
1185/// response, honoring optional per-batch active `row_counts`. Runs the
1186/// per-batch solves in parallel and scatters results back into dense
1187/// `(K, p, d)` coefficients and `(K, N_max, d)` fitted arrays (padding rows
1188/// left zero).
1189pub fn gaussian_weighted_ridge_batch(
1190    x: ArrayView3<'_, f64>,
1191    y: ArrayView3<'_, f64>,
1192    penalty: ArrayView2<'_, f64>,
1193    weights: ArrayView2<'_, f64>,
1194    ridge_lambda: f64,
1195    row_counts: Option<ArrayView1<'_, usize>>,
1196) -> Result<(Array3<f64>, Array3<f64>), String> {
1197    use rayon::iter::{IntoParallelIterator, ParallelIterator};
1198
1199    let (batch, n_max, p) = x.dim();
1200    let (y_batch, y_n_max, d) = y.dim();
1201    if batch == 0 || n_max == 0 || p == 0 {
1202        return Err("batched X must have non-empty K, N, and coefficient dimensions".to_string());
1203    }
1204    if y_batch != batch || y_n_max != n_max {
1205        return Err(format!(
1206            "batched X/Y shape mismatch: X is ({batch}, {n_max}, {p}) but Y is ({y_batch}, {y_n_max}, {d})"
1207        ));
1208    }
1209    if d == 0 {
1210        return Err("batched Y must have at least one output column".to_string());
1211    }
1212    if weights.nrows() != batch || weights.ncols() != n_max {
1213        return Err(format!(
1214            "batched weights shape mismatch: expected ({batch}, {n_max}), got ({}, {})",
1215            weights.nrows(),
1216            weights.ncols()
1217        ));
1218    }
1219    if penalty.nrows() != p || penalty.ncols() != p {
1220        return Err(format!(
1221            "penalty shape mismatch: expected {p}x{p}, got {}x{}",
1222            penalty.nrows(),
1223            penalty.ncols()
1224        ));
1225    }
1226    if !ridge_lambda.is_finite() || ridge_lambda < 0.0 {
1227        return Err(format!(
1228            "ridge_lambda must be finite and non-negative; got {ridge_lambda}"
1229        ));
1230    }
1231    if x.iter()
1232        .chain(y.iter())
1233        .chain(penalty.iter())
1234        .chain(weights.iter())
1235        .any(|value| !value.is_finite())
1236    {
1237        return Err("batched weighted ridge inputs must be finite".to_string());
1238    }
1239    if weights.iter().any(|value| *value < 0.0) {
1240        return Err("batched weights must be non-negative likelihood row weights".to_string());
1241    }
1242
1243    let active_rows: Vec<usize> = match row_counts {
1244        Some(counts) => {
1245            if counts.len() != batch {
1246                return Err(format!(
1247                    "row_counts length mismatch: expected {batch}, got {}",
1248                    counts.len()
1249                ));
1250            }
1251            counts.to_vec()
1252        }
1253        None => vec![n_max; batch],
1254    };
1255    for (b, &n_rows) in active_rows.iter().enumerate() {
1256        if n_rows > n_max {
1257            return Err(format!(
1258                "row_counts[{b}]={n_rows} exceeds padded row count {n_max}"
1259            ));
1260        }
1261    }
1262
1263    let results: Vec<Result<(usize, Array2<f64>, Array2<f64>), String>> = (0..batch)
1264        .into_par_iter()
1265        .map(|b| {
1266            let n_rows = active_rows[b];
1267            if n_rows == 0 {
1268                return Ok((
1269                    b,
1270                    Array2::<f64>::zeros((p, d)),
1271                    Array2::<f64>::zeros((0, d)),
1272                ));
1273            }
1274            gaussian_weighted_ridge(
1275                x.slice(s![b, 0..n_rows, ..]),
1276                y.slice(s![b, 0..n_rows, ..]),
1277                penalty,
1278                weights.slice(s![b, 0..n_rows]),
1279                ridge_lambda,
1280            )
1281            .map(|(coefficients, fitted)| (b, coefficients, fitted))
1282            .map_err(|err| format!("batched weighted ridge fit {b} failed: {err}"))
1283        })
1284        .collect();
1285
1286    let mut coefficients = Array3::<f64>::zeros((batch, p, d));
1287    let mut fitted = Array3::<f64>::zeros((batch, n_max, d));
1288    for result in results {
1289        let (b, fit_coefficients, fit_fitted) = result?;
1290        coefficients
1291            .slice_mut(s![b, .., ..])
1292            .assign(&fit_coefficients);
1293        let n_rows = fit_fitted.nrows();
1294        if n_rows > 0 {
1295            fitted.slice_mut(s![b, 0..n_rows, ..]).assign(&fit_fitted);
1296        }
1297    }
1298    Ok((coefficients, fitted))
1299}
1300
1301/// Rank and Moore–Penrose pseudoinverse of a symmetric PSD penalty matrix via
1302/// its eigendecomposition, keeping eigenpairs whose eigenvalue exceeds a
1303/// relative tolerance. Returns `(rank, pinv)`.
1304pub fn block_penalty_rank_and_pinv(
1305    penalty: &Array2<f64>,
1306) -> Result<(usize, Array2<f64>), EstimationError> {
1307    let (eigs, vecs) = penalty.to_owned().eigh(Side::Lower).map_err(|_| {
1308        EstimationError::ModelIsIllConditioned {
1309            condition_number: f64::INFINITY,
1310        }
1311    })?;
1312    let max_abs = eigs.iter().fold(0.0_f64, |m, &v| m.max(v.abs()));
1313    let tol = (1.0e-10 * max_abs).max(1.0e-14);
1314    let mut rank = 0_usize;
1315    let mut scaled = Array2::<f64>::zeros(vecs.dim());
1316    for col in 0..eigs.len() {
1317        if eigs[col] > tol {
1318            rank += 1;
1319            for row in 0..vecs.nrows() {
1320                scaled[[row, col]] = vecs[[row, col]] / eigs[col];
1321            }
1322        }
1323    }
1324    Ok((rank, scaled.dot(&vecs.t())))
1325}
1326
1327/// Invert a symmetric positive-definite matrix, escalating a relative diagonal
1328/// ridge until the Cholesky factorization succeeds (robust SPD inverse).
1329pub fn invert_spd_with_ridge(
1330    matrix: &Array2<f64>,
1331    ridge_rel: f64,
1332) -> Result<Array2<f64>, EstimationError> {
1333    let n = matrix.nrows();
1334    let eye = Array2::<f64>::eye(n);
1335    let scale = (0..n).map(|i| matrix[[i, i]].abs()).fold(1.0_f64, f64::max);
1336    let ridges = [0.0, ridge_rel, 1.0e-10, 1.0e-8, 1.0e-6, 1.0e-4];
1337    for rel in ridges {
1338        let mut candidate = matrix.clone();
1339        if rel > 0.0 {
1340            for i in 0..n {
1341                candidate[[i, i]] += rel * scale;
1342            }
1343        }
1344        if let Ok(chol) = candidate.cholesky(Side::Lower) {
1345            return Ok(chol.solve_mat(&eye));
1346        }
1347    }
1348    Err(EstimationError::ModelIsIllConditioned {
1349        condition_number: f64::INFINITY,
1350    })
1351}
1352
1353/// Solve a symmetric (possibly indefinite/ill-conditioned) linear system via
1354/// eigendecomposition with a spectral floor: eigenvalues below the floor are
1355/// clamped (preserving sign) before inversion, stabilizing the solve.
1356pub fn solve_symmetric_vector_with_floor(
1357    matrix: &Array2<f64>,
1358    rhs: &Array1<f64>,
1359    ridge_rel: f64,
1360) -> Result<Array1<f64>, EstimationError> {
1361    let n = matrix.nrows();
1362    let mut sym = matrix.clone();
1363    symmetrize_in_place(&mut sym);
1364    let (eigs, vecs) =
1365        sym.eigh(Side::Lower)
1366            .map_err(|_| EstimationError::ModelIsIllConditioned {
1367                condition_number: f64::INFINITY,
1368            })?;
1369    let max_eig = eigs.iter().fold(0.0_f64, |m, &v| m.max(v.abs()));
1370    let floor = (ridge_rel * max_eig.max(1.0)).max(1.0e-12);
1371    let projected = vecs.t().dot(rhs);
1372    let mut scaled = Array1::<f64>::zeros(n);
1373    for i in 0..n {
1374        let denom = if eigs[i].abs() >= floor {
1375            eigs[i]
1376        } else if eigs[i].is_sign_negative() {
1377            -floor
1378        } else {
1379            floor
1380        };
1381        scaled[i] = projected[i] / denom;
1382    }
1383    let out = vecs.dot(&scaled);
1384    if out.iter().all(|value| value.is_finite()) {
1385        Ok(out)
1386    } else {
1387        Err(EstimationError::ModelIsIllConditioned {
1388            condition_number: f64::INFINITY,
1389        })
1390    }
1391}
1392
1393/// Solve a symmetric dense block system `H x = rhs` (single right-hand side)
1394/// via the Cholesky-with-fallback factorization, returning the solution vector.
1395/// `context` labels errors.
1396pub fn solve_dense_block_system(
1397    hessian: &Array2<f64>,
1398    rhs: &Array1<f64>,
1399    context: &str,
1400) -> Result<Array1<f64>, String> {
1401    let mut rhs2 = Array2::<f64>::zeros((rhs.len(), 1));
1402    for i in 0..rhs.len() {
1403        rhs2[[i, 0]] = rhs[i];
1404    }
1405    let factor =
1406        factorize_symmetricwith_fallback(FaerArrayView::new(hessian).as_ref(), Side::Lower)
1407            .map_err(|err| format!("{context} factorization failed: {err}"))?;
1408    {
1409        let mut rhs_view = array2_to_matmut(&mut rhs2);
1410        factor.solve_in_place(rhs_view.as_mut());
1411    }
1412    let mut out = Array1::<f64>::zeros(rhs.len());
1413    for i in 0..rhs.len() {
1414        out[i] = rhs2[[i, 0]];
1415    }
1416    if out.iter().any(|v| !v.is_finite()) {
1417        return Err(format!("{context} solve produced non-finite coefficients"));
1418    }
1419    Ok(out)
1420}
1421
#[cfg(test)]
mod ridge_tests {
    use super::{gaussian_weighted_ridge, gaussian_weighted_ridge_batch};
    use ndarray::{Array2, Array3, ArrayView2, array, s};

    /// Assert `lhs` and `rhs` share a shape and agree element-wise within `tol`.
    fn assert_close(lhs: ArrayView2<'_, f64>, rhs: ArrayView2<'_, f64>, tol: f64) {
        assert_eq!(lhs.dim(), rhs.dim());
        lhs.indexed_iter().for_each(|((i, j), &left)| {
            let right = rhs[[i, j]];
            let diff = (left - right).abs();
            assert!(
                diff <= tol,
                "matrix mismatch at ({i}, {j}): lhs={}, rhs={}, diff={diff}",
                left,
                right
            );
        });
    }

    #[test]
    fn weighted_ridge_batch_matches_single_fit_on_active_rows() {
        const LAMBDA: f64 = 0.25;
        const TOL: f64 = 1.0e-10;

        // Batch 0 uses all three rows. Batch 1 uses two; its third (padding)
        // row carries junk values that must not influence the fit.
        let x = Array3::from_shape_vec(
            (2, 3, 2),
            vec![1.0, 0.0, 1.0, 1.0, 0.5, 1.0, 2.0, 1.0, 0.0, 1.0, 9.0, 9.0],
        )
        .unwrap();
        let y = Array3::from_shape_vec((2, 3, 1), vec![1.0, 2.0, 1.5, 2.5, -0.5, 99.0]).unwrap();
        let weights = array![[1.0, 0.5, 2.0], [1.0, 3.0, 0.0]];
        let penalty = Array2::eye(2);
        let row_counts = array![3_usize, 2_usize];

        let (coefficients, fitted) = gaussian_weighted_ridge_batch(
            x.view(),
            y.view(),
            penalty.view(),
            weights.view(),
            LAMBDA,
            Some(row_counts.view()),
        )
        .unwrap();

        for (b, &n) in row_counts.iter().enumerate() {
            let (expected_coefficients, expected_fitted) = gaussian_weighted_ridge(
                x.slice(s![b, 0..n, ..]),
                y.slice(s![b, 0..n, ..]),
                penalty.view(),
                weights.slice(s![b, 0..n]),
                LAMBDA,
            )
            .unwrap();
            assert_close(
                coefficients.slice(s![b, .., ..]),
                expected_coefficients.view(),
                TOL,
            );
            assert_close(fitted.slice(s![b, 0..n, ..]), expected_fitted.view(), TOL);
        }
        // The padding row of batch 1 must be left zero in the fitted output.
        assert_eq!(fitted[[1, 2, 0]], 0.0);
    }
}
1486
#[cfg(test)]
mod tests {
    use super::{
        boundary_hit_step_fraction, solve_spd_pcg, solve_spd_pcg_with_info,
        solve_spd_pcg_with_info_into,
    };
    use ndarray::{Array1, array};

    /// The 2x2 SPD system `[[4, 1], [1, 3]] x = [1, 2]` with Jacobi diagonal
    /// `[4, 3]`. Its exact solution is `[1/11, 7/11]`.
    fn small_spd_system() -> (ndarray::Array2<f64>, Array1<f64>, Array1<f64>) {
        (
            array![[4.0, 1.0], [1.0, 3.0]],
            array![1.0, 2.0],
            Array1::from_vec(vec![4.0, 3.0]),
        )
    }

    #[test]
    fn boundary_hit_step_fraction_ignores_near_tangential_direction() {
        assert_eq!(boundary_hit_step_fraction(1.0, -1e-16, 1.0), None);
    }

    #[test]
    fn boundary_hit_step_fraction_returns_first_finite_hit() {
        assert_eq!(boundary_hit_step_fraction(0.25, -0.5, 1.0), Some(0.5));
    }

    #[test]
    fn boundary_hit_step_fraction_rejects_non_finite_candidate() {
        assert_eq!(
            boundary_hit_step_fraction(1.0, f64::NEG_INFINITY, 1.0),
            None
        );
    }

    #[test]
    fn solve_spd_pcg_matches_reference_solution() {
        let (h, b, m) = small_spd_system();
        let x = solve_spd_pcg(|v| h.dot(v), &b, &m, 1e-10, 20).expect("pcg solve");
        assert!((x[0] - 0.0909090909).abs() < 1e-8);
        assert!((x[1] - 0.6363636363).abs() < 1e-8);
    }

    #[test]
    fn solve_spd_pcg_rejects_zero_iteration_budget() {
        let (h, b, m) = small_spd_system();
        assert!(solve_spd_pcg_with_info(|v| h.dot(v), &b, &m, 1e-10, 0).is_none());
        assert!(solve_spd_pcg(|v| h.dot(v), &b, &m, 1e-10, 0).is_none());
    }

    #[test]
    fn matrix_free_qp_beta_matches_dense_reference_with_diagnostics() {
        // Small synthetic stand-in for the FLEX marginal-slope joint system:
        // a coupled SPD Hessian plus a penalty/ridge Jacobi preconditioner. The
        // matrix-free solve must return the same beta as the dense reference,
        // while surfacing bounded iteration/residual diagnostics for cycle-0
        // triage.
        let h = array![
            [12.0, 2.0, 0.5, 0.0],
            [2.0, 9.0, 1.25, 0.25],
            [0.5, 1.25, 7.0, 1.5],
            [0.0, 0.25, 1.5, 5.0],
        ];
        let rhs = array![1.0, -0.5, 2.0, 0.75];
        let precond = h.diag().to_owned();
        let budget = 4 * rhs.len();

        // Dense reference: factorize once and solve in place.
        let factor = super::StableSolver::new("synthetic dense reference")
            .factorize(&h)
            .expect("dense SPD reference");
        let mut dense = rhs.clone();
        let mut dense_view = crate::faer_ndarray::array1_to_col_matmut(&mut dense);
        factor.solve_in_place(dense_view.as_mut());

        let (pcg, info) = solve_spd_pcg_with_info_into(
            |v, out| out.assign(&h.dot(v)),
            &rhs,
            &precond,
            1e-12,
            budget,
        )
        .expect("matrix-free pcg");

        assert!(info.converged);
        assert!(info.iterations <= budget);
        assert!(info.final_residual_norm < info.initial_residual_norm);
        assert!(info.residual_reduction < 1e-10);
        assert!(info.condition_estimate.is_some());
        dense.iter().zip(pcg.iter()).for_each(|(reference, actual)| {
            assert!(
                (reference - actual).abs() < 1e-10,
                "dense={reference} pcg={actual}"
            );
        });
    }

    #[test]
    fn solve_spd_pcg_with_info_into_rejects_zero_iteration_budget() {
        let (h, b, m) = small_spd_system();
        let outcome =
            solve_spd_pcg_with_info_into(|v, out| out.assign(&h.dot(v)), &b, &m, 1e-10, 0);
        assert!(outcome.is_none());
    }
}