Skip to main content

gam_terms/
construction.rs

1use crate::EstimationError;
2use crate::basis::analyze_penalty_block;
3use crate::smooth::PenaltyStructureHint;
4use faer::linalg::matmul::matmul;
5use faer::{Accum, Mat, MatRef, Par, Side};
6use gam_linalg::faer_ndarray::{FaerEigh, FaerLinalgError, FaerSvd};
7use gam_linalg::matrix::symmetrize_in_place;
8use gam_linalg::utils::KahanSum;
9use ndarray::{Array1, Array2, ArrayView1, ArrayViewMut2, Axis, s};
10use rayon::iter::{
11    IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
12};
13use std::collections::{BTreeMap, HashSet};
14use std::ops::Range;
15use std::sync::Arc;
16
/// Relative floor below which an eigenvalue returned by the symmetric
/// eigensolver is treated as roundoff rather than genuine curvature (roughly
/// `sqrt(machine ε)`; #1619).
///
/// The same constant drives two decisions that must agree with each other:
/// - `classify_eigenvalues_strict` snaps every eigenvalue within
///   `REL_PSD_FLOOR * scale` of zero to exactly zero, i.e. declares it null;
/// - `subspace_split_is_consistent` derives from it how much per-mode relative
///   energy the transformed penalty root may keep on that null block.
///
/// Sharing one value guarantees the leakage guard never demands more precision
/// than the classifier used when it defined the null space.
const REL_PSD_FLOOR: f64 = 1e-8;
27
28#[derive(Clone)]
29pub enum PenaltyRepresentation {
30    Dense(Array2<f64>),
31    Banded {
32        bands: Vec<Array1<f64>>,
33        offsets: Vec<i32>,
34    },
35    Kronecker {
36        /// Full penalty-block Kronecker product `left ⊗ right`: each entry of
37        /// `left` scales an entire copy of `right` in the dense expansion.
38        ///
39        /// This is distinct from chunked kernel design assembly, where center
40        /// rows are kernel-evaluation arguments rather than matrix factors.
41        left: Array2<f64>,
42        right: Array2<f64>,
43    },
44}
45
46impl PenaltyRepresentation {
47    /// Side length of the square penalty block this representation expands to.
48    pub fn block_dimension(&self) -> usize {
49        match self {
50            PenaltyRepresentation::Dense(matrix) => matrix.nrows(),
51            PenaltyRepresentation::Banded { bands, offsets } => {
52                let mut dim = 0usize;
53                for (band, &offset) in bands.iter().zip(offsets.iter()) {
54                    let len = band.len();
55                    let extent = if offset >= 0 {
56                        len + offset as usize
57                    } else {
58                        len + (-offset) as usize
59                    };
60                    dim = dim.max(extent);
61                }
62                dim
63            }
64            PenaltyRepresentation::Kronecker { left, right } => left.nrows() * right.nrows(),
65        }
66    }
67
68    /// Materialize this representation (Dense / Banded / Kronecker) into a
69    /// single dense symmetric penalty block.
70    pub fn to_block_dense(&self) -> Array2<f64> {
71        match self {
72            PenaltyRepresentation::Dense(matrix) => matrix.clone(),
73            PenaltyRepresentation::Banded { bands, offsets } => {
74                let dim = self.block_dimension();
75                let mut dense = Array2::zeros((dim, dim));
76                let positive_offsets: HashSet<usize> = offsets
77                    .iter()
78                    .filter_map(|&off| (off >= 0).then_some(off as usize))
79                    .collect();
80                for (band, &offset) in bands.iter().zip(offsets.iter()) {
81                    let off = offset.unsigned_abs() as usize;
82                    if offset < 0 && positive_offsets.contains(&off) {
83                        continue;
84                    }
85                    for (idx, &value) in band.iter().enumerate() {
86                        let (i, j) = if offset >= 0 {
87                            (idx, idx + off)
88                        } else {
89                            (idx + off, idx)
90                        };
91                        if i >= dim || j >= dim {
92                            continue;
93                        }
94                        dense[[i, j]] = value;
95                        dense[[j, i]] = value;
96                    }
97                }
98                dense
99            }
100            PenaltyRepresentation::Kronecker { left, right } => {
101                let (lrows, l_cols) = left.dim();
102                let (rrows, r_cols) = right.dim();
103                let mut result = Array2::zeros((lrows * rrows, l_cols * r_cols));
104                for i in 0..lrows {
105                    for j in 0..l_cols {
106                        let scale = left[(i, j)];
107                        if scale == 0.0 {
108                            continue;
109                        }
110                        let mut block = result.slice_mut(s![
111                            i * rrows..(i + 1) * rrows,
112                            j * r_cols..(j + 1) * r_cols
113                        ]);
114                        block.assign(&(right * scale));
115                    }
116                }
117                result
118            }
119        }
120    }
121}
122
123#[derive(Clone)]
124pub struct PenaltyMatrix {
125    pub col_range: Range<usize>,
126    pub representation: PenaltyRepresentation,
127}
128
129impl PenaltyMatrix {
130    fn accumulate_into(&self, mut dest: ArrayViewMut2<'_, f64>, weight: f64) {
131        if weight == 0.0 {
132            return;
133        }
134        match &self.representation {
135            PenaltyRepresentation::Dense(block) => {
136                dest.scaled_add(weight, block);
137            }
138            PenaltyRepresentation::Banded { bands, offsets } => {
139                let positive_offsets: HashSet<usize> = offsets
140                    .iter()
141                    .filter_map(|&off| (off >= 0).then_some(off as usize))
142                    .collect();
143                for (band, &offset) in bands.iter().zip(offsets.iter()) {
144                    let off = offset.unsigned_abs() as usize;
145                    if offset < 0 && positive_offsets.contains(&off) {
146                        continue;
147                    }
148                    for (idx, &value) in band.iter().enumerate() {
149                        let (i, j) = if offset >= 0 {
150                            (idx, idx + off)
151                        } else {
152                            (idx + off, idx)
153                        };
154                        let Some(entry_ij) = dest.get_mut((i, j)) else {
155                            continue;
156                        };
157                        *entry_ij += weight * value;
158                        if i != j
159                            && let Some(entry_ji) = dest.get_mut((j, i))
160                        {
161                            *entry_ji += weight * value;
162                        }
163                    }
164                }
165            }
166            PenaltyRepresentation::Kronecker { left, right } => {
167                let (lrows, l_cols) = left.dim();
168                let (rrows, r_cols) = right.dim();
169                for i in 0..lrows {
170                    for j in 0..l_cols {
171                        let scale = left[(i, j)] * weight;
172                        if scale == 0.0 {
173                            continue;
174                        }
175                        let mut block = dest.slice_mut(s![
176                            i * rrows..(i + 1) * rrows,
177                            j * r_cols..(j + 1) * r_cols
178                        ]);
179                        block.scaled_add(scale, right);
180                    }
181                }
182            }
183        }
184    }
185
186    pub fn to_dense(&self, total_dim: usize) -> Array2<f64> {
187        let mut dense = Array2::<f64>::zeros((total_dim, total_dim));
188        self.accumulate_into(
189            dense.slice_mut(s![self.col_range.clone(), self.col_range.clone()]),
190            1.0,
191        );
192        dense
193    }
194}
195
196pub(crate) fn array_to_faer(array: &Array2<f64>) -> Mat<f64> {
197    let (rows, cols) = array.dim();
198    Mat::from_fn(rows, cols, |i, j| array[[i, j]])
199}
200
201pub(crate) fn mat_to_array(mat: &Mat<f64>) -> Array2<f64> {
202    let mut out = Array2::<f64>::zeros((mat.nrows(), mat.ncols()));
203    for i in 0..mat.nrows() {
204        for j in 0..mat.ncols() {
205            out[[i, j]] = mat[(i, j)];
206        }
207    }
208    out
209}
210
211fn mat_max_abs_element(matrix: MatRef<'_, f64>) -> f64 {
212    let (rows, cols) = matrix.shape();
213    let mut maxval = 0.0_f64;
214    for i in 0..rows {
215        for j in 0..cols {
216            let val = matrix[(i, j)];
217            if val.is_finite() {
218                maxval = maxval.max(val.abs());
219            }
220        }
221    }
222    maxval
223}
224
225fn sanitize_symmetric_faer(matrix: &Mat<f64>) -> Mat<f64> {
226    let (rows, cols) = matrix.as_ref().shape();
227    assert_eq!(rows, cols, "Matrix must be square for sanitization");
228
229    let mut sanitized = matrix.clone();
230
231    for i in 0..rows {
232        let diag = sanitized[(i, i)];
233        if !diag.is_finite() {
234            sanitized[(i, i)] = 0.0;
235        }
236        for j in (i + 1)..cols {
237            let mut upper = sanitized[(i, j)];
238            let mut lower = sanitized[(j, i)];
239            if !upper.is_finite() {
240                upper = 0.0;
241            }
242            if !lower.is_finite() {
243                lower = 0.0;
244            }
245            let avg = 0.5 * (upper + lower);
246            sanitized[(i, j)] = avg;
247            sanitized[(j, i)] = avg;
248        }
249    }
250
251    let scale = mat_max_abs_element(sanitized.as_ref());
252    let tiny = (scale * 1e-14).max(1e-30);
253    for i in 0..rows {
254        for j in 0..cols {
255            let val = sanitized[(i, j)];
256            if !val.is_finite() {
257                sanitized[(i, j)] = 0.0;
258            } else if val.abs() < tiny {
259                sanitized[(i, j)] = 0.0;
260            }
261        }
262    }
263
264    sanitized
265}
266
267fn penalty_from_root_faer(root: &Mat<f64>) -> Mat<f64> {
268    let cols = root.ncols();
269    let mut full = Mat::<f64>::zeros(cols, cols);
270    let root_ref = root.as_ref();
271    let root_t = root_ref.transpose();
272    matmul(
273        full.as_mut(),
274        Accum::Replace,
275        root_t,
276        root_ref,
277        1.0,
278        Par::Seq,
279    );
280    sanitize_symmetric_faer(&full)
281}
282
283fn symmetrize_faer_matrix_in_place(matrix: &mut Mat<f64>) {
284    let n = matrix.nrows().min(matrix.ncols());
285    for i in 0..n {
286        for j in 0..i {
287            let avg = 0.5 * (matrix[(i, j)] + matrix[(j, i)]);
288            matrix[(i, j)] = avg;
289            matrix[(j, i)] = avg;
290        }
291    }
292}
293
294fn orthogonal_similarity_transform_faer(
295    matrix: &Mat<f64>,
296    block_dim: usize,
297    orthogonal: &Mat<f64>,
298) -> Mat<f64> {
299    let matrix_block = matrix.as_ref().submatrix(0, 0, block_dim, block_dim);
300    let cols = orthogonal.ncols();
301    let mut temp = Mat::<f64>::zeros(block_dim, cols);
302    matmul(
303        temp.as_mut(),
304        Accum::Replace,
305        matrix_block,
306        orthogonal.as_ref(),
307        1.0,
308        Par::Seq,
309    );
310    let mut rotated = Mat::<f64>::zeros(cols, cols);
311    matmul(
312        rotated.as_mut(),
313        Accum::Replace,
314        orthogonal.transpose(),
315        temp.as_ref(),
316        1.0,
317        Par::Seq,
318    );
319    symmetrize_faer_matrix_in_place(&mut rotated);
320    rotated
321}
322
323fn trace_penalty_in_orthogonal_basis(
324    matrix: &Mat<f64>,
325    block_dim: usize,
326    orthogonal: &Mat<f64>,
327    rotated_eigenvalues: &[f64],
328    delta: f64,
329) -> f64 {
330    let matrix_block = matrix.as_ref().submatrix(0, 0, block_dim, block_dim);
331    let cols = orthogonal.ncols();
332    assert!(rotated_eigenvalues.len() >= cols);
333    let mut projected = Mat::<f64>::zeros(block_dim, cols);
334    matmul(
335        projected.as_mut(),
336        Accum::Replace,
337        matrix_block,
338        orthogonal.as_ref(),
339        1.0,
340        Par::Seq,
341    );
342    let mut trace = KahanSum::default();
343    for l in 0..cols {
344        let mut diag_ll = KahanSum::default();
345        for i in 0..block_dim {
346            diag_ll.add(orthogonal[(i, l)] * projected[(i, l)]);
347        }
348        trace.add(diag_ll.sum() / (rotated_eigenvalues[l] + delta));
349    }
350    trace.sum()
351}
352
353pub fn trace_reduced_penalty_covariance(
354    reduced_penalty: &Array2<f64>,
355    covariance_basis: &Array2<f64>,
356) -> f64 {
357    assert_eq!(
358        reduced_penalty.dim(),
359        covariance_basis.dim(),
360        "trace_reduced_penalty_covariance dimension mismatch"
361    );
362    let r = covariance_basis.nrows();
363    let mut trace = KahanSum::default();
364    for i in 0..r {
365        for j in 0..r {
366            trace.add(covariance_basis[[i, j]] * reduced_penalty[[j, i]]);
367        }
368    }
369    trace.sum()
370}
371
372pub fn trace_penalty_covariance_in_orthogonal_basis(
373    matrix: &Array2<f64>,
374    orthogonal: &Array2<f64>,
375    covariance_basis: &Array2<f64>,
376) -> f64 {
377    let reduced = gam_linalg::faer_ndarray::fast_ab(
378        &gam_linalg::faer_ndarray::fast_atb(orthogonal, matrix),
379        orthogonal,
380    );
381    trace_reduced_penalty_covariance(&reduced, covariance_basis)
382}
383
384/// Strict spectral classifier used as a final guard on penalty eigendecompositions.
385///
386/// Penalty matrices fed to the GAM solver are required to be PSD by construction.
387/// This routine snaps roundoff-zero eigenvalues to exact zero, accepts strictly
388/// positive eigenvalues, and rejects materially-indefinite or non-finite spectra
389/// with a hard error rather than silently rewriting them. The previous behaviour
390/// (mass-zeroing negative or non-finite eigenvalues) hid construction bugs and
391/// changed the optimisation objective downstream.
392///
393/// The acceptance tolerance is the larger of a machine-ε floor
394/// (`C_EPS_P_FACTOR * eps_machine * p * scale`, with `C_EPS_P_FACTOR = 64`
395/// absorbing the rounding accumulated in a symmetric eigendecomposition of a
396/// moderate-dimension matrix) and a relative "numerically PSD" floor
397/// `REL_PSD_FLOOR * scale`. The latter dominates for large high-rank penalties
398/// assembled / reparameterized at extreme λ, where roundoff produces
399/// ~1e-11-relative negative eigenvalues that are PSD to any reasonable precision
400/// yet exceeded the bare ~12×ε machine floor and spuriously failed the inner
401/// P-IRLS solve (#1619). Genuine indefiniteness is O(1) relative and is still
402/// rejected far above either floor.
403fn classify_eigenvalues_strict(
404    eigenvalues: &mut [f64],
405    context: &str,
406) -> Result<(), EstimationError> {
407    const C_EPS_P_FACTOR: f64 = 64.0;
408    // `REL_PSD_FLOOR` (module-level): the relative threshold below which a
409    // (possibly slightly negative) eigenvalue is roundoff and is snapped to zero
410    // rather than rejected. Shared with the subspace-leakage guard so the null
411    // definition and the leakage tolerance stay mutually consistent.
412    let p = eigenvalues.len();
413
414    let mut scale = 0.0_f64;
415    for (idx, &val) in eigenvalues.iter().enumerate() {
416        if !val.is_finite() {
417            return Err(EstimationError::PenaltySpectrumNonFinite {
418                context: context.to_string(),
419                index: idx,
420                value: val,
421            });
422        }
423        scale = scale.max(val.abs());
424    }
425
426    // p * eps captures the rounding floor of a symmetric eigendecomposition of a
427    // p-dimensional matrix; multiplying by `scale` lifts the floor to the actual
428    // magnitude of the spectrum. For large high-rank penalties assembled at
429    // extreme λ this machine floor (~12×ε relative) is tighter than the roundoff
430    // actually produced, so we take the larger of it and a relative numerically-PSD
431    // floor `REL_PSD_FLOOR * scale` (#1619).
432    let machine_floor = C_EPS_P_FACTOR * f64::EPSILON * (p.max(1) as f64) * scale;
433    let tolerance = machine_floor
434        .max(REL_PSD_FLOOR * scale)
435        .max(f64::MIN_POSITIVE);
436
437    for (idx, val) in eigenvalues.iter_mut().enumerate() {
438        if val.abs() <= tolerance {
439            *val = 0.0;
440        } else if *val < 0.0 {
441            return Err(EstimationError::PenaltySpectrumIndefinite {
442                context: context.to_string(),
443                index: idx,
444                value: *val,
445                tolerance,
446                scale,
447            });
448        }
449    }
450    Ok(())
451}
452
453fn robust_eighwith_policy<M, V, E, Validate, Sanitize, EigCall, MapErr>(
454    matrix: &M,
455    context: &str,
456    validate_input: Validate,
457    sanitize: Sanitize,
458    mut eig_call: EigCall,
459    map_error: MapErr,
460) -> Result<(Vec<f64>, V), EstimationError>
461where
462    Validate: Fn(&M, &str) -> Result<(), EstimationError>,
463    Sanitize: Fn(&M) -> M,
464    EigCall: FnMut(&M) -> Result<(Vec<f64>, V), E>,
465    MapErr: Fn(E, &str) -> EstimationError,
466{
467    validate_input(matrix, context)?;
468
469    // The sanitize step only enforces exact symmetry by averaging M and M^T and
470    // zeros sub-eps noise; it never adds a diagonal ridge. Adding ridge changes
471    // the matrix being decomposed, which silently changes the optimisation
472    // objective downstream. If eigh genuinely fails on a finite symmetric input,
473    // surface the error instead of mutating the spectrum.
474    let candidate = sanitize(matrix);
475    match eig_call(&candidate) {
476        Ok((mut eigenvalues, eigenvectors)) => {
477            classify_eigenvalues_strict(&mut eigenvalues, context)?;
478            Ok((eigenvalues, eigenvectors))
479        }
480        Err(err) => Err(map_error(err, context)),
481    }
482}
483
484pub(crate) fn robust_eigh_faer(
485    matrix: &Mat<f64>,
486    side: Side,
487    context: &str,
488) -> Result<(Vec<f64>, Mat<f64>), EstimationError> {
489    robust_eighwith_policy(
490        matrix,
491        context,
492        |mat, ctx| {
493            let (rows, cols) = mat.as_ref().shape();
494            for i in 0..rows {
495                for j in 0..cols {
496                    let val = mat[(i, j)];
497                    if !val.is_finite() {
498                        let max_abs = mat_max_abs_element(mat.as_ref());
499                        crate::bail_invalid_estim!(
500                            "{} contains non-finite entries (max finite magnitude {:.3e})",
501                            ctx,
502                            max_abs
503                        );
504                    }
505                }
506            }
507            Ok(())
508        },
509        sanitize_symmetric_faer,
510        |candidate| {
511            let eig = candidate.as_ref().self_adjoint_eigen(side)?;
512            let diag = eig.S();
513            let mut eigenvalues = Vec::with_capacity(diag.dim());
514            for idx in 0..diag.dim() {
515                eigenvalues.push(diag[idx]);
516            }
517
518            let vectors_ref = eig.U();
519            let mut eigenvectors = Mat::<f64>::zeros(vectors_ref.nrows(), vectors_ref.ncols());
520            for i in 0..vectors_ref.nrows() {
521                for j in 0..vectors_ref.ncols() {
522                    eigenvectors[(i, j)] = vectors_ref[(i, j)];
523                }
524            }
525            Ok((eigenvalues, eigenvectors))
526        },
527        |err, _ctx| {
528            EstimationError::EigendecompositionFailed(FaerLinalgError::SelfAdjointEigen(err))
529        },
530    )
531}
532
533fn robust_eigh(
534    matrix: &Array2<f64>,
535    side: Side,
536    context: &str,
537) -> Result<(Array1<f64>, Array2<f64>), EstimationError> {
538    let matrix_faer = array_to_faer(matrix);
539    let (eigenvalues, eigenvectors) = robust_eigh_faer(&matrix_faer, side, context)?;
540    Ok((Array1::from_vec(eigenvalues), mat_to_array(&eigenvectors)))
541}
542
543pub(crate) fn kronecker_marginal_eigensystems(
544    marginal_penalties: &[Array2<f64>],
545    context: &str,
546) -> Result<Vec<(Array1<f64>, Array2<f64>)>, EstimationError> {
547    let mut eigensystems = Vec::with_capacity(marginal_penalties.len());
548    for (k, penalty) in marginal_penalties.iter().enumerate() {
549        eigensystems.push(robust_eigh(
550            penalty,
551            Side::Lower,
552            &format!("{context} marginal {k}"),
553        )?);
554    }
555    Ok(eigensystems)
556}
557
/// Diagnostics describing how cleanly a transformed basis `Qs` separates its
/// penalized (range) columns from its null columns.
#[derive(Debug, Clone, Copy)]
struct SubspaceLeakageMetrics {
    /// Largest squared energy that any transformed penalty root places on the
    /// null columns.
    max_abs_sq: f64,
    /// Largest relative (null / total) squared-energy ratio reported by
    /// `assess_subspace_leakage`.
    max_rel_sq: f64,
    /// Index of the root reported as attaining `max_rel_sq`.
    worst_penalty: usize,
    /// Largest `|q_iᵀ q_j|` between a range column `i` and a null column `j`
    /// of `Qs`.
    max_cross_gram_abs: f64,
}
565
566fn assess_subspace_leakage(
567    qs: &Mat<f64>,
568    rs_transformed: &[Mat<f64>],
569    structural_rank: usize,
570    p: usize,
571) -> SubspaceLeakageMetrics {
572    let mut max_abs_sq = 0.0_f64;
573    let mut max_rel_sq = 0.0_f64;
574    let mut worst_penalty = 0usize;
575
576    for (k, rs) in rs_transformed.iter().enumerate() {
577        let rows = rs.nrows();
578        let cols = rs.ncols().min(p);
579        let null_start = structural_rank.min(cols);
580        let mut abs_sq = 0.0_f64;
581        let mut total_sq = 0.0_f64;
582        for i in 0..rows {
583            for j in 0..cols {
584                let v = rs[(i, j)];
585                let vv = v * v;
586                total_sq += vv;
587                if j >= null_start {
588                    abs_sq += vv;
589                }
590            }
591        }
592        let rel_sq = if total_sq > 0.0 {
593            abs_sq / total_sq
594        } else {
595            0.0
596        };
597        if rel_sq > max_rel_sq {
598            max_rel_sq = rel_sq;
599            worst_penalty = k;
600        }
601        max_abs_sq = max_abs_sq.max(abs_sq);
602    }
603
604    let mut max_cross_gram_abs = 0.0_f64;
605    let null_count = p.saturating_sub(structural_rank);
606    if structural_rank > 0 && null_count > 0 {
607        for i in 0..structural_rank {
608            for j in 0..null_count {
609                let qn_col = structural_rank + j;
610                let mut dot = 0.0_f64;
611                for r in 0..p {
612                    dot += qs[(r, i)] * qs[(r, qn_col)];
613                }
614                max_cross_gram_abs = max_cross_gram_abs.max(dot.abs());
615            }
616        }
617    }
618
619    SubspaceLeakageMetrics {
620        max_abs_sq,
621        max_rel_sq,
622        worst_penalty,
623        max_cross_gram_abs,
624    }
625}
626
627/// True when the penalized/null subspace split is numerically self-consistent.
628///
629/// The split has two independent invariants:
630///
631/// 1. **Orthogonality** — `Qs = [Q_p | Q_n]` must be orthonormal, so the range
632///    and null blocks share no direction (`max |Qp'Qn| ≤ orth_tol`). This is the
633///    structural correctness guarantee and is checked at machine precision.
634///
635/// 2. **Bounded root leakage** — the transformed penalty root must have
636///    negligible energy on the null columns. The admissible relative-energy
637///    leakage is DERIVED from `REL_PSD_FLOOR`, the same numerically-PSD floor
638///    that `classify_eigenvalues_strict` uses to decide which modes are null:
639///    a mode the classifier is entitled to call null can, by the symmetric
640///    eigensolver's own eigenvector accuracy on a small-gap spectrum, retain up
641///    to `REL_PSD_FLOOR` of the penalty root's relative energy in that direction.
642///    Summed over the at-most-`p` near-threshold null modes the relative leakage
643///    cannot exceed `p · REL_PSD_FLOOR` without signalling a genuine
644///    (non-numerical) inconsistency, so that is the derived tolerance — matching
645///    the `p`-scaling `classify_eigenvalues_strict` already applies to its
646///    machine floor. Demanding a leakage tighter than the very floor that
647///    defined the null space is self-contradictory: it rejects well-posed smooth
648///    manifold / Duchon / sphere penalties whose Laplace-Beltrami spectrum decays
649///    through the classification threshold with no clean rank gap (#1802), even
650///    though the downstream penalty (`E'E`) is rebuilt with an EXACTLY clean null
651///    block regardless. The absolute floor keeps a vanishing-scale penalty from
652///    tripping the relative test on pure roundoff.
653fn subspace_split_is_consistent(leakage: &SubspaceLeakageMetrics, p: usize) -> bool {
654    let leakage_rel_tol = (p.max(1) as f64) * REL_PSD_FLOOR;
655    let leakage_abs_tol = 1e-12;
656    let orth_tol = 1e-10;
657    let root_leaks =
658        leakage.max_rel_sq > leakage_rel_tol && leakage.max_abs_sq > leakage_abs_tol;
659    let split_nonorthogonal = leakage.max_cross_gram_abs > orth_tol;
660    !(root_leaks || split_nonorthogonal)
661}
662
663fn compose_qs_from_split(q_pen: &Mat<f64>, q_null: &Mat<f64>, p: usize) -> Mat<f64> {
664    let rank = q_pen.ncols();
665    let null_count = q_null.ncols();
666    let mut qs = Mat::<f64>::zeros(p, p);
667    for i in 0..p {
668        for j in 0..rank {
669            qs[(i, j)] = q_pen[(i, j)];
670        }
671        for j in 0..null_count {
672            qs[(i, rank + j)] = q_null[(i, j)];
673        }
674    }
675    qs
676}
677
678/// Computes the Kronecker product A ⊗ B for penalty matrix construction.
679/// This is used to create tensor product penalties that enforce smoothness
680/// in multiple dimensions for interaction terms.
681pub fn kronecker_product(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
682    let (arows, a_cols) = a.dim();
683    let (brows, b_cols) = b.dim();
684    if arows == 0 || a_cols == 0 || brows == 0 || b_cols == 0 {
685        return Array2::zeros((arows * brows, a_cols * b_cols));
686    }
687    let mut result = Array2::zeros((arows * brows, a_cols * b_cols));
688
689    result
690        .axis_chunks_iter_mut(Axis(0), brows)
691        .into_par_iter()
692        .enumerate()
693        .for_each(|(i, mut row_block)| {
694            let arow = a.row(i);
695            let col_chunks = row_block.axis_chunks_iter_mut(Axis(1), b_cols);
696            for (j, mut block) in col_chunks.into_iter().enumerate() {
697                let aval = arow[j];
698                if aval == 0.0 {
699                    continue;
700                }
701                for (dest, &src) in block.iter_mut().zip(b.iter()) {
702                    *dest = aval * src;
703                }
704            }
705        });
706
707    result
708}
709
710/// Result of the stable reparameterization algorithm from Wood (2011) Appendix B
711#[derive(Clone)]
712pub struct ReparamResult {
713    /// Penalty matrix in TRANSFORMED coefficient coordinates.
714    ///
715    /// This must be compatible with `beta_transformed` and `X_transformed = X * Qs`.
716    pub s_transformed: Array2<f64>,
717    /// Log-determinant of the penalty matrix (stable computation)
718    pub log_det: f64,
719    /// First derivatives of log-determinant w.r.t. log-smoothing parameters
720    pub det1: Array1<f64>,
721    /// Orthogonal transformation matrix Qs
722    pub qs: Array2<f64>,
723    /// Canonical penalties in the TRANSFORMED coordinate frame.
724    /// The single source of truth for penalty roots in the transformed frame.
725    /// Downstream consumers use these for block-local `PenaltyCoordinate`
726    /// construction, TK correction, and ext-coord paths.
727    pub canonical_transformed: Vec<CanonicalPenalty>,
728    /// Lambda-dependent penalty square root in TRANSFORMED coordinates (rank x p matrix).
729    /// This is used for applying the actual penalty in the least squares solve.
730    pub e_transformed: Array2<f64>,
731    /// Truncated eigenvectors (p × m where m = p - structural_rank).
732    ///
733    /// Coordinate frame note:
734    /// - This matrix is stored in the TRANSFORMED coefficient frame (post-`Qs`),
735    ///   i.e. it is compatible with `canonical_transformed`, `beta_transformed`,
736    ///   and transformed Hessians without additional coordinate mapping.
737    ///
738    /// These vectors span the structural null space used by positive-part
739    /// log-determinant conventions.
740    pub u_truncated: Array2<f64>,
741    /// The rho-independent shrinkage ridge magnitude that was added to each
742    /// eigenvalue of the penalized block. Zero means no shrinkage was applied.
743    pub penalty_shrinkage_ridge: f64,
744}
745
746// ---------------------------------------------------------------------------
747// Kronecker factor decomposition primitives
748// ---------------------------------------------------------------------------
749
750/// Per-factor decomposition result for Kronecker penalties.
751struct KroneckerFactorDecomp {
752    root: Array2<f64>,              // rank_j × q_j
753    positive_eigenvalues: Vec<f64>, // length = rank_j
754    rank: usize,
755    dim: usize,
756}
757
758/// Eigendecompose each Kronecker factor separately at O(Σ q_j³).
759/// Returns per-factor decompositions, or `None` if any factor is zero.
760fn decompose_kronecker_factors(
761    factors: &[Array2<f64>],
762    context: &str,
763) -> Result<Option<Vec<KroneckerFactorDecomp>>, EstimationError> {
764    let mut decomps = Vec::with_capacity(factors.len());
765    for (j, factor) in factors.iter().enumerate() {
766        let q_j = factor.nrows();
767        if q_j != factor.ncols() {
768            crate::bail_invalid_estim!(
769                "{context}: Kronecker factor {j} must be square, got {}x{}",
770                factor.nrows(),
771                factor.ncols()
772            );
773        }
774        let is_identity = {
775            let mut is_id = true;
776            'outer: for r in 0..q_j {
777                for c in 0..q_j {
778                    let expected = if r == c { 1.0 } else { 0.0 };
779                    if (factor[[r, c]] - expected).abs() > 1e-12 {
780                        is_id = false;
781                        break 'outer;
782                    }
783                }
784            }
785            is_id
786        };
787        if is_identity {
788            decomps.push(KroneckerFactorDecomp {
789                root: Array2::eye(q_j),
790                positive_eigenvalues: vec![1.0; q_j],
791                rank: q_j,
792                dim: q_j,
793            });
794            continue;
795        }
796        let analysis = analyze_penalty_block(factor).map_err(|err| {
797            EstimationError::InvalidInput(format!(
798                "{context}: Kronecker factor {j} eigendecomp failed: {err}"
799            ))
800        })?;
801        if analysis.rank == 0 {
802            return Ok(None);
803        }
804        // Build the factor root from ONLY the range (positive-curvature)
805        // directions via the canonical classifier — never the null or
806        // negative-curvature directions (#1425).
807        let factor_classes =
808            crate::basis::SpectralClassification::new(&analysis.eigenvalues, analysis.tol);
809        let mut root_j = Array2::zeros((analysis.rank, q_j));
810        let mut pos_eigs = Vec::with_capacity(analysis.rank);
811        for (row_idx, &i) in factor_classes.range_idx.iter().enumerate() {
812            let eigenval = analysis.eigenvalues[i];
813            let sqrt_ev = eigenval.sqrt();
814            let evec = analysis.eigenvectors.column(i);
815            for (col, &v) in evec.iter().enumerate() {
816                root_j[[row_idx, col]] = sqrt_ev * v;
817            }
818            pos_eigs.push(eigenval);
819        }
820        decomps.push(KroneckerFactorDecomp {
821            root: root_j,
822            positive_eigenvalues: pos_eigs,
823            rank: analysis.rank,
824            dim: q_j,
825        });
826    }
827    Ok(Some(decomps))
828}
829
830/// Build the block-local Kronecker root from pre-computed factor decompositions.
831fn assemble_kronecker_root_local(decomps: &[KroneckerFactorDecomp]) -> Array2<f64> {
832    let mut kron_root = decomps[0].root.clone();
833    for fr in &decomps[1..] {
834        let (r1, c1) = kron_root.dim();
835        let (r2, c2) = (fr.rank, fr.dim);
836        let mut new_root = Array2::zeros((r1 * r2, c1 * c2));
837        for i1 in 0..r1 {
838            for i2 in 0..r2 {
839                for j1 in 0..c1 {
840                    for j2 in 0..c2 {
841                        new_root[[i1 * r2 + i2, j1 * c2 + j2]] =
842                            kron_root[[i1, j1]] * fr.root[[i2, j2]];
843                    }
844                }
845            }
846        }
847        kron_root = new_root;
848    }
849    kron_root
850}
851
852/// Compute eigenvalues of the Kronecker product from per-factor eigenvalues.
853fn kronecker_eigenvalues(decomps: &[KroneckerFactorDecomp], block_dim: usize) -> (Vec<f64>, usize) {
854    let mut kron_eigs = decomps[0].positive_eigenvalues.clone();
855    for fd in &decomps[1..] {
856        let mut new_eigs = Vec::with_capacity(kron_eigs.len() * fd.positive_eigenvalues.len());
857        for &a in &kron_eigs {
858            for &b in &fd.positive_eigenvalues {
859                new_eigs.push(a * b);
860            }
861        }
862        kron_eigs = new_eigs;
863    }
864    let max_ev = kron_eigs.iter().copied().fold(0.0_f64, f64::max);
865    let tol = max_ev * 1e-10 * (block_dim as f64);
866    let positive: Vec<f64> = kron_eigs.into_iter().filter(|&ev| ev > tol).collect();
867    let nullity = block_dim - positive.len();
868    (positive, nullity)
869}
870
871// ---------------------------------------------------------------------------
872// CanonicalPenalty — block-local processed penalty for the solver
873// ---------------------------------------------------------------------------
874
/// A canonicalized penalty with block-local root, ready for the solver.
///
/// Instead of storing a full `p x p` penalty matrix, this stores only the
/// `rank x block_dim` root and the column range, enabling O(p_k^2) operations
/// instead of O(p^2).
///
/// The numerical rank of the penalty is the number of rows of `root`
/// (`rank()`), and `local` / `prior_mean` are expressed in block-local
/// coordinates (`col_range.len()` entries per axis).
#[derive(Clone)]
pub struct CanonicalPenalty {
    /// Square root matrix: S_k = root^T * root.
    /// Shape: `rank x block_dim` for block-local, `rank x p` for dense.
    pub root: Array2<f64>,
    /// Column range in the global coefficient vector [start..end).
    /// For dense penalties this is `0..p`.
    pub col_range: std::ops::Range<usize>,
    /// Full parameter dimension p.
    pub total_dim: usize,
    /// Structural nullity of the local penalty.
    /// NOTE: `from_dense_root_with_mean` records 0 here regardless of the
    /// root's rank.
    pub nullity: usize,
    /// The symmetrized block-local penalty matrix (block_dim × block_dim).
    /// Cached at construction time to avoid recomputing root^T * root
    /// in hot paths (penalty assembly, trace products).
    pub local: Array2<f64>,
    /// Block-local prior mean used to center this penalty.
    pub prior_mean: Array1<f64>,
    /// Positive eigenvalues of the local penalty matrix (length = rank).
    /// Cached at construction time for REML logdet block-factored paths.
    /// NOTE: left empty by `from_dense_root_with_mean` (dense wrappers for the
    /// TK paths), so the length = rank property only holds for penalties
    /// produced by `canonicalize_penalty_spec`.
    pub positive_eigenvalues: Vec<f64>,
    /// Optional operator-form handle bit-equivalent to `local`. Propagated
    /// from `PenaltySpec::Block.op`. Downstream PIRLS and REML exact operator
    /// algebra route through this for dense-Gram-free matvec when present.
    pub op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
}
906
907impl std::fmt::Debug for CanonicalPenalty {
908    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
909        f.debug_struct("CanonicalPenalty")
910            .field(
911                "root",
912                &format_args!("{}×{}", self.root.nrows(), self.root.ncols()),
913            )
914            .field("col_range", &self.col_range)
915            .field("total_dim", &self.total_dim)
916            .field("nullity", &self.nullity)
917            .field(
918                "local",
919                &format_args!("{}×{}", self.local.nrows(), self.local.ncols()),
920            )
921            .field("prior_mean_len", &self.prior_mean.len())
922            .field("positive_eigenvalues", &self.positive_eigenvalues)
923            .field("op", &self.op.as_ref().map(|o| o.dim()))
924            .finish()
925    }
926}
927
928impl CanonicalPenalty {
929    /// Construct a dense (full-width) canonical penalty from a `rank x p` root.
930    /// Used to wrap reparam-transformed roots for consumers that expect
931    /// `&[CanonicalPenalty]`.
932    pub fn from_dense_root(root: Array2<f64>, p: usize) -> Self {
933        Self::from_dense_root_with_mean(root, p, Array1::zeros(p))
934    }
935
936    pub fn from_dense_root_with_mean(root: Array2<f64>, p: usize, prior_mean: Array1<f64>) -> Self {
937        assert_eq!(prior_mean.len(), p);
938        let local = root.t().dot(&root);
939        let positive_eigenvalues = Vec::new(); // not needed for TK paths
940        Self {
941            root,
942            col_range: 0..p,
943            total_dim: p,
944            nullity: 0,
945            local,
946            prior_mean,
947            positive_eigenvalues,
948            op: None,
949        }
950    }
951
952    /// Embed the block-local root into a full-width `rank × total_dim` matrix.
953    /// For dense penalties (col_range = 0..p), returns the root unchanged.
954    pub fn full_width_root(&self) -> Array2<f64> {
955        if self.col_range.start == 0 && self.col_range.end == self.total_dim {
956            return self.root.clone();
957        }
958        let rank = self.root.nrows();
959        let mut full = Array2::<f64>::zeros((rank, self.total_dim));
960        full.slice_mut(ndarray::s![.., self.col_range.clone()])
961            .assign(&self.root);
962        full
963    }
964
965    /// Numerical rank of this penalty.
966    pub fn rank(&self) -> usize {
967        self.root.nrows()
968    }
969
970    /// Block dimension (number of columns this penalty covers).
971    pub fn block_dim(&self) -> usize {
972        self.col_range.len()
973    }
974
975    /// Whether this penalty is block-local (col_range != 0..total_dim).
976    pub const fn is_block_local(&self) -> bool {
977        self.col_range.start != 0 || self.col_range.end != self.total_dim
978    }
979
980    /// Return a reference to the cached local penalty matrix.
981    /// Shape: `block_dim x block_dim`.
982    pub fn local_ref(&self) -> &Array2<f64> {
983        &self.local
984    }
985
986    /// Return an owned copy of the local penalty matrix.
987    /// Prefer `local_ref()` when a reference suffices.
988    pub fn local_penalty(&self) -> Array2<f64> {
989        self.local.clone()
990    }
991
992    /// Accumulate lambda * S_k into a pre-allocated `p x p` target matrix.
993    /// Only touches the block [col_range × col_range].
994    pub fn accumulate_weighted(&self, target: &mut Array2<f64>, lambda: f64) {
995        if lambda == 0.0 || self.rank() == 0 {
996            return;
997        }
998        let r = &self.col_range;
999        target
1000            .slice_mut(s![r.start..r.end, r.start..r.end])
1001            .scaled_add(lambda, &self.local);
1002    }
1003
1004    /// Compute `scale * tr(M · S_k)` where M is a `p × p` dense matrix.
1005    /// Only reads `M[start..end, start..end]` — O(block_dim²) not O(p²).
1006    pub fn trace_product(&self, m: &Array2<f64>, scale: f64) -> f64 {
1007        if self.rank() == 0 || scale == 0.0 {
1008            return 0.0;
1009        }
1010        let r = &self.col_range;
1011        let m_block = m.slice(s![r.start..r.end, r.start..r.end]);
1012        let rm = self.root.dot(&m_block);
1013        scale
1014            * rm.iter()
1015                .zip(self.root.iter())
1016                .map(|(&a, &b)| a * b)
1017                .sum::<f64>()
1018    }
1019
1020    /// Compute `scale * v^T S_k v` (quadratic form).
1021    /// Only reads `v[start..end]` — O(rank × block_dim) not O(rank × p).
1022    pub fn quadratic(&self, v: &Array1<f64>, scale: f64) -> f64 {
1023        if self.rank() == 0 || scale == 0.0 {
1024            return 0.0;
1025        }
1026        let v_block = v.slice(s![self.col_range.start..self.col_range.end]);
1027        let rv = self.root.dot(&v_block);
1028        scale * rv.dot(&rv)
1029    }
1030
1031    /// Compute `scale * S_k * prior_mean` embedded into the global basis.
1032    pub fn prior_linear_shift(&self, scale: f64) -> Array1<f64> {
1033        let mut out = Array1::<f64>::zeros(self.total_dim);
1034        if self.rank() == 0 || scale == 0.0 || self.prior_mean.iter().all(|&v| v == 0.0) {
1035            return out;
1036        }
1037        let block = self.local.dot(&self.prior_mean) * scale;
1038        out.slice_mut(s![self.col_range.start..self.col_range.end])
1039            .assign(&block);
1040        out
1041    }
1042
1043    /// Compute `scale * prior_mean' S_k prior_mean`.
1044    pub fn prior_constant_shift(&self, scale: f64) -> f64 {
1045        if self.rank() == 0 || scale == 0.0 || self.prior_mean.iter().all(|&v| v == 0.0) {
1046            return 0.0;
1047        }
1048        scale * self.prior_mean.dot(&self.local.dot(&self.prior_mean))
1049    }
1050
1051    /// Embed this block's prior mean into the global coefficient basis.
1052    pub fn full_width_prior_mean(&self) -> Array1<f64> {
1053        if self.col_range.start == 0 && self.col_range.end == self.total_dim {
1054            return self.prior_mean.clone();
1055        }
1056        let mut out = Array1::<f64>::zeros(self.total_dim);
1057        out.slice_mut(s![self.col_range.start..self.col_range.end])
1058            .assign(&self.prior_mean);
1059        out
1060    }
1061
1062    /// Convert to a PenaltyCoordinate for the unified REML evaluator.
1063    pub fn to_penalty_coordinate(&self) -> gam_problem::PenaltyCoordinate {
1064        use gam_problem::PenaltyCoordinate;
1065        if self.is_block_local() {
1066            PenaltyCoordinate::from_block_root_with_mean(
1067                self.root.clone(),
1068                self.col_range.start,
1069                self.col_range.end,
1070                self.total_dim,
1071                self.prior_mean.clone(),
1072            )
1073        } else {
1074            PenaltyCoordinate::from_dense_root_with_mean(self.root.clone(), self.prior_mean.clone())
1075        }
1076    }
1077}
1078
1079/// Detect and report structurally identical (or near-identical) penalty pairs
1080/// in the canonical bundle.
1081///
1082/// Two penalties `S_i`, `S_j` with the same `col_range` are compared by their
1083/// matrix cosine:
1084///
1085///     cos(S_i, S_j) = tr(S_i S_j) / sqrt(tr(S_i^2) * tr(S_j^2))
1086///
1087/// Because `local` is symmetric, `tr(A·B) = sum_{r,c} A[r,c] * B[r,c]` — i.e.,
1088/// the Frobenius inner product. Pairs with different `col_range` cannot be
1089/// functionally identical and are skipped.
1090///
1091/// Logging policy:
1092/// - `cos > 1 - 1e-8` → `log::warn!` with `[PENALTY-REDUNDANCY]`. Every such
1093///   pair is emitted because it represents a structural model error (the LAML
1094///   cost has a Z₂-symmetric saddle that ARC's cubic regularization will
1095///   happily converge to under first-order stationarity).
1096/// - `0.99 < cos ≤ 1 - 1e-8` → `log::info!` with `[PENALTY-SIMILARITY]`.
1097///   At large scale (`k > 64`) only the top-3 highest-cosine such pairs are
1098///   logged to bound log volume.
1099///
1100/// Returns `Vec<(i, j, cos)>` for the **redundant** pairs (cos > 1 - 1e-8),
1101/// primarily to make this function unit-testable without a log capture.
1102///
1103/// Performance: this is O(k² · block_dim²); intended to be called exactly
1104/// once per fit (e.g. from `RemlState::newwith_offset_shared`).
1105pub fn report_penalty_pair_redundancy(canonical: &[CanonicalPenalty]) -> Vec<(usize, usize, f64)> {
1106    const REDUNDANCY_THRESHOLD: f64 = 1.0 - 1e-8;
1107    const SIMILARITY_THRESHOLD: f64 = 0.99;
1108    const LARGE_SCALE_K_THRESHOLD: usize = 64;
1109    const TOP_SIMILARITY_PAIRS: usize = 3;
1110
1111    let k = canonical.len();
1112    let mut redundant: Vec<(usize, usize, f64)> = Vec::new();
1113    let mut similar: Vec<(usize, usize, f64)> = Vec::new();
1114
1115    // Pre-compute tr(S_i^2) = sum of squares of S_i entries (Frobenius norm
1116    // squared). `local` is symmetric, so this equals tr(S_i^T S_i) = tr(S_i^2).
1117    let trace_sq: Vec<f64> = canonical
1118        .iter()
1119        .map(|p| p.local.iter().map(|&v| v * v).sum::<f64>())
1120        .collect();
1121
1122    for i in 0..k {
1123        if trace_sq[i] == 0.0 {
1124            continue;
1125        }
1126        for j in (i + 1)..k {
1127            if trace_sq[j] == 0.0 {
1128                continue;
1129            }
1130            // Different col_range → cannot be functionally identical by
1131            // construction (the block-local matrices live in disjoint or
1132            // mismatched parameter subspaces).
1133            if canonical[i].col_range != canonical[j].col_range {
1134                continue;
1135            }
1136            // Shapes must match — they do when col_range matches because
1137            // `local` is `block_dim × block_dim` and `block_dim = col_range.len()`.
1138            assert_eq!(canonical[i].local.dim(), canonical[j].local.dim());
1139
1140            let inner: f64 = canonical[i]
1141                .local
1142                .iter()
1143                .zip(canonical[j].local.iter())
1144                .map(|(&a, &b)| a * b)
1145                .sum();
1146            let denom = (trace_sq[i] * trace_sq[j]).sqrt();
1147            if denom == 0.0 {
1148                continue;
1149            }
1150            let cos = inner / denom;
1151
1152            if cos > REDUNDANCY_THRESHOLD {
1153                redundant.push((i, j, cos));
1154            } else if cos > SIMILARITY_THRESHOLD {
1155                similar.push((i, j, cos));
1156            }
1157        }
1158    }
1159
1160    // Always emit every redundancy — these are structural model errors.
1161    for &(i, j, cos) in &redundant {
1162        log::warn!(
1163            "[PENALTY-REDUNDANCY] penalties i={i} j={j} are structurally identical \
1164             (cos={cos:.6}) — model is over-parameterized along their antisymmetric \
1165             direction; expect a Z₂-symmetric saddle in the LAML cost. Consider \
1166             re-specifying (e.g. anisotropic→isotropic for spatial smoothers with \
1167             weak axis signal)."
1168        );
1169    }
1170
1171    // Cap similarity log volume at large scale.
1172    if k > LARGE_SCALE_K_THRESHOLD && similar.len() > TOP_SIMILARITY_PAIRS {
1173        similar.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
1174        similar.truncate(TOP_SIMILARITY_PAIRS);
1175    }
1176    for (i, j, cos) in similar {
1177        log::info!(
1178            "[PENALTY-SIMILARITY] penalties i={i} j={j} are near-identical \
1179             (cos={cos:.6}) — outer Hessian may be ill-conditioned along their \
1180             antisymmetric direction."
1181        );
1182    }
1183
1184    redundant
1185}
1186
1187/// Canonicalize a single `PenaltySpec` into a `CanonicalPenalty` by computing
1188/// the block-local eigendecomposition and extracting the root.
1189///
1190/// This is O(block_dim^3) instead of O(p^3) for block-local penalties.
1191/// Returns `None` if the penalty has rank zero (should be dropped).
1192pub fn canonicalize_penalty_spec(
1193    spec: &crate::PenaltySpec,
1194    p: usize,
1195    idx: usize,
1196    context: &str,
1197) -> Result<Option<CanonicalPenalty>, EstimationError> {
1198    use crate::PenaltySpec;
1199
1200    crate::validate_penalty_spec_shape(idx, spec, p, context)?;
1201
1202    let (local_matrix, col_range, prior_mean_spec, hint, op) = match spec {
1203        PenaltySpec::Block {
1204            local,
1205            col_range,
1206            prior_mean,
1207            structure_hint,
1208            op,
1209        } => (
1210            local.view(),
1211            col_range.clone(),
1212            prior_mean,
1213            structure_hint.as_ref(),
1214            op.clone(),
1215        ),
1216        PenaltySpec::Dense(m) => (
1217            m.view(),
1218            0..p,
1219            &gam_problem::CoefficientPriorMean::Zero,
1220            None,
1221            None,
1222        ),
1223        PenaltySpec::DenseWithMean { matrix, prior_mean } => {
1224            (matrix.view(), 0..p, prior_mean, None, None)
1225        }
1226    };
1227
1228    let block_dim = col_range.len();
1229    let prior_mean = prior_mean_spec
1230        .evaluate(block_dim, &format!("{context}: penalty {idx}"))
1231        .map_err(|e| EstimationError::InvalidInput(e.0))?;
1232
1233    // ── Ridge fast path: closed-form, no eigendecomposition ──
1234    if let Some(PenaltyStructureHint::Ridge(scale)) = hint {
1235        if *scale <= 0.0 {
1236            return Ok(None);
1237        }
1238        let sqrt_scale = scale.sqrt();
1239        let mut root = Array2::zeros((block_dim, block_dim));
1240        for i in 0..block_dim {
1241            root[[i, i]] = sqrt_scale;
1242        }
1243        // Ridge penalties are diagonal by construction, but still route through
1244        // the crate-wide ndarray symmetrizer so every construction variant uses
1245        // the same "average the transpose" cleanup instead of a local copy.
1246        let mut local_sym = local_matrix.to_owned();
1247        symmetrize_in_place(&mut local_sym);
1248        return Ok(Some(CanonicalPenalty {
1249            root,
1250            col_range,
1251            total_dim: p,
1252            nullity: 0,
1253            local: local_sym,
1254            prior_mean,
1255            positive_eigenvalues: vec![*scale; block_dim],
1256            op,
1257        }));
1258    }
1259
1260    // ── Kronecker fast path: single per-factor eigendecomposition ──
1261    if let Some(PenaltyStructureHint::Kronecker(factors)) = hint {
1262        let decomps =
1263            match decompose_kronecker_factors(factors, &format!("{context} penalty {idx}"))? {
1264                None => return Ok(None),
1265                Some(d) => d,
1266            };
1267        let (positive_eigenvalues, nullity) = kronecker_eigenvalues(&decomps, block_dim);
1268        if positive_eigenvalues.is_empty() {
1269            return Ok(None);
1270        }
1271        let root = assemble_kronecker_root_local(&decomps);
1272        let mut local_sym = local_matrix.to_owned();
1273        symmetrize_in_place(&mut local_sym);
1274        return Ok(Some(CanonicalPenalty {
1275            root,
1276            col_range,
1277            total_dim: p,
1278            nullity,
1279            local: local_sym,
1280            prior_mean,
1281            positive_eigenvalues,
1282            op,
1283        }));
1284    }
1285
1286    // ── Generic block-local path: eigendecompose at O(block_dim³) ──
1287    let local_owned = local_matrix.to_owned();
1288    let analysis = analyze_penalty_block(&local_owned).map_err(|err| {
1289        EstimationError::InvalidInput(format!(
1290            "{context}: penalty canonicalization failed at index {idx}: {err}"
1291        ))
1292    })?;
1293
1294    if analysis.rank == 0 {
1295        log::debug!(
1296            "Dropped inactive penalty block idx={idx} reason={}",
1297            if analysis.iszero {
1298                "ZeroMatrix"
1299            } else {
1300                "NumericalRankZero"
1301            }
1302        );
1303        return Ok(None);
1304    }
1305
1306    // Reuse the eigendecomposition from analyze_penalty_block and route the
1307    // range / null / negative-curvature split through the one canonical
1308    // classifier, so this root construction cannot disagree with the block's
1309    // own `rank` / `nullity` / `negative_dim` about which directions are
1310    // penalized, unpenalized, or non-PSD (#1425).
1311    let tolerance = analysis.tol;
1312    let classes = crate::basis::SpectralClassification::new(&analysis.eigenvalues, tolerance);
1313    let rank_k = classes.rank();
1314    assert_eq!(
1315        rank_k, analysis.rank,
1316        "penalty-root rank disagreement: SpectralClassification rank={rank_k} vs analyze_penalty_block rank={} (#1425 canonical-classifier invariant)",
1317        analysis.rank
1318    );
1319
1320    // Build the penalty root R from ONLY the range directions (positive
1321    // curvature): R has one row per range eigenpair, scaled by sqrt(ev), so
1322    // RᵀR reconstructs S on range(S). Null directions contribute nothing
1323    // (their eigenvalue is zero); negative-curvature directions are NEVER
1324    // square-rooted into R (their sqrt is imaginary) and are NOT null — they
1325    // are simply dropped from R, exactly as the closed-form Duchon kernels at
1326    // high d require to preserve the q_pen / q_null invariant downstream.
1327    let mut root = Array2::zeros((rank_k, block_dim));
1328    let mut positive_eigenvalues = Vec::with_capacity(rank_k);
1329    for (row_idx, &i) in classes.range_idx.iter().enumerate() {
1330        let eigenval = analysis.eigenvalues[i];
1331        let eigenvec = analysis.eigenvectors.column(i);
1332        root.row_mut(row_idx).assign(&(&eigenvec * eigenval.sqrt()));
1333        positive_eigenvalues.push(eigenval);
1334    }
1335
1336    // Surface any genuine negative curvature honestly: it is neither range
1337    // (dropped from R) nor null (excluded from `nullity`), so it would
1338    // otherwise vanish without a trace. A non-PSD penalty reaching this path
1339    // is a real geometric fact (e.g. high-d Duchon kernels) the operator
1340    // should be able to see.
1341    if classes.is_indefinite() {
1342        log::debug!(
1343            "{context}: penalty block idx={idx} carries {} negative-curvature \
1344             eigendirection(s) below -tol={tolerance:e}; dropped from the canonical \
1345             root and NOT counted as null space (rank={rank_k}, nullity={})",
1346            classes.negative_dim(),
1347            classes.nullity()
1348        );
1349    }
1350
1351    // Store the PSD reconstruction RᵀR rather than the raw symmetrised input so
1352    // the cached `local` matches the rank truncation embedded in `root`
1353    // (negative-curvature directions are excluded from both, as above).
1354    let local = root.t().dot(&root);
1355    Ok(Some(CanonicalPenalty {
1356        root,
1357        col_range,
1358        total_dim: p,
1359        nullity: classes.nullity(),
1360        local,
1361        prior_mean,
1362        positive_eigenvalues,
1363        op,
1364    }))
1365}
1366
1367/// Canonicalize a batch of penalty specs, dropping zero-rank penalties.
1368/// Returns (active_penalties, active_nullspace_dims).
1369pub fn canonicalize_penalty_specs(
1370    specs: &[crate::PenaltySpec],
1371    nullspace_dims: &[usize],
1372    p: usize,
1373    context: &str,
1374) -> Result<(Vec<CanonicalPenalty>, Vec<usize>), EstimationError> {
1375    if specs.len() != nullspace_dims.len() {
1376        crate::bail_invalid_estim!(
1377            "{context}: nullspace_dims length mismatch: penalties={}, nullspace_dims={}",
1378            specs.len(),
1379            nullspace_dims.len()
1380        );
1381    }
1382
1383    let mut active = Vec::with_capacity(specs.len());
1384    let mut active_nullspace = Vec::with_capacity(specs.len());
1385    for (idx, spec) in specs.iter().enumerate() {
1386        if let Some(canonical) = canonicalize_penalty_spec(spec, p, idx, context)? {
1387            active_nullspace.push(nullspace_dims[idx]);
1388            active.push(canonical);
1389        }
1390    }
1391    Ok((active, active_nullspace))
1392}
1393
/// Hard cap on the dimension `p` allowed to fall back to a dense p × p
/// eigendecomposition of the overlapping balanced penalty.
///
/// Beyond this cap the overlapping-penalty path returns a `LayoutError`. It
/// does not allocate an O(p²) workspace whose eigendecomposition would
/// dominate the solve. The threshold lives in a single named constant so
/// every overlapping-penalty branch that consults it shares the same limit.
///
/// TODO: move this cap into `ResourcePolicy` once that policy is threaded
/// through penalty construction.
pub(crate) const OVERLAPPING_PENALTY_DENSE_FALLBACK_MAX_P: usize = 4096;
1404
/// Creates a balanced penalty root from canonical penalties.
///
/// Each active penalty's cached `local` matrix is divided by its Frobenius norm
/// and the results are summed. Penalties whose norm is ≤ `1e-12` contribute
/// nothing. The sum is eigendecomposed, and every eigenpair above a relative
/// `1e-12` tolerance becomes one row `sqrt(λ) · vᵀ` of the returned
/// `rank × p` root.
///
/// When all penalties have non-overlapping col_ranges, the balanced sum is
/// block-diagonal and eigendecomposition is done per-block at O(Σ p_k³)
/// instead of the global O(p³). Falls back to the global path when penalties
/// overlap.
///
/// Returns a `0 × p` matrix when there are no active penalties or no
/// eigenvalue survives the tolerance.
///
/// # Errors
/// - `EstimationError::LayoutError` when ranges overlap and `p` exceeds
///   `OVERLAPPING_PENALTY_DENSE_FALLBACK_MAX_P`.
/// - Any error propagated from `robust_eigh`.
pub fn create_balanced_penalty_root_from_canonical(
    penalties: &[CanonicalPenalty],
    p: usize,
) -> Result<Array2<f64>, EstimationError> {
    if penalties.is_empty() {
        return Ok(Array2::zeros((0, p)));
    }

    // Group penalties by col_range.
    let mut block_groups: BTreeMap<(usize, usize), Vec<&CanonicalPenalty>> = BTreeMap::new();
    for cp in penalties {
        if cp.rank() == 0 {
            continue;
        }
        let key = (cp.col_range.start, cp.col_range.end);
        block_groups.entry(key).or_default().push(cp);
    }

    if block_groups.is_empty() {
        return Ok(Array2::zeros((0, p)));
    }

    // Check for overlapping ranges. BTreeMap keys iterate in ascending
    // (start, end) order, so each range's start is compared against its
    // predecessor's end.
    let ranges: Vec<(usize, usize)> = block_groups.keys().copied().collect();
    let mut overlapping = false;
    for i in 1..ranges.len() {
        if ranges[i].0 < ranges[i - 1].1 {
            overlapping = true;
            break;
        }
    }

    if overlapping {
        if p > OVERLAPPING_PENALTY_DENSE_FALLBACK_MAX_P {
            return Err(EstimationError::LayoutError(format!(
                "overlapping penalty root would require dense {}x{} eigendecomposition; \
                 large-model dense fallback is disabled. Keep penalties structured or \
                 extend the overlapping-penalty solver path",
                p, p
            )));
        }
        // Fallback: accumulate into p × p and eigendecompose globally.
        // Same Frobenius normalisation as the per-block path.
        let mut s_balanced = Array2::zeros((p, p));
        for cp in penalties {
            if cp.rank() == 0 {
                continue;
            }
            let local = cp.local_ref();
            let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
            if frob_norm > 1e-12 {
                let r = &cp.col_range;
                s_balanced
                    .slice_mut(s![r.start..r.end, r.start..r.end])
                    .scaled_add(1.0 / frob_norm, local);
            }
        }
        let (eigenvalues, eigenvectors) =
            robust_eigh(&s_balanced, Side::Lower, "balanced penalty matrix")?;
        // Relative cutoff against the largest eigenvalue; absolute 1e-12 when
        // the spectrum has no positive eigenvalue.
        let max_eig = eigenvalues.iter().fold(0.0f64, |max, &val| max.max(val));
        let tolerance = if max_eig > 0.0 {
            max_eig * 1e-12
        } else {
            1e-12
        };
        let penalty_rank = eigenvalues.iter().filter(|&&ev| ev > tolerance).count();
        if penalty_rank == 0 {
            return Ok(Array2::zeros((0, p)));
        }
        // `eb` is p × rank (one scaled eigenvector per column); it is
        // transposed on return to the rank × p root layout.
        let mut eb = Array2::zeros((p, penalty_rank));
        let mut col_idx = 0;
        for (i, &eigenval) in eigenvalues.iter().enumerate() {
            if eigenval > tolerance {
                let sqrt_ev = eigenval.sqrt();
                let evec = eigenvectors.column(i);
                eb.column_mut(col_idx).assign(&(&evec * sqrt_ev));
                col_idx += 1;
            }
        }
        return Ok(eb.t().to_owned());
    }

    // Non-overlapping: eigendecompose per block at O(Σ p_k³).
    struct BlockRoot {
        col_range: Range<usize>,
        root: Array2<f64>, // rank_b × block_dim
    }
    // Materialize the BTreeMap order first. Rayon preserves Vec collection
    // order for indexed parallel iterators, so assembly below remains stable by
    // ascending column range while independent block eigendecompositions run in
    // parallel.
    let ordered_blocks: Vec<((usize, usize), Vec<&CanonicalPenalty>)> =
        block_groups.into_iter().collect();
    let block_roots: Vec<BlockRoot> = ordered_blocks
        .into_par_iter()
        .map(
            |((start, end), cps)| -> Result<Option<BlockRoot>, EstimationError> {
                let block_dim = end - start;
                let mut s_balanced_local = Array2::zeros((block_dim, block_dim));

                // Sum of Frobenius-normalised penalties sharing this range.
                for cp in cps {
                    let local = cp.local_ref();
                    let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
                    if frob_norm > 1e-12 {
                        s_balanced_local.scaled_add(1.0 / frob_norm, local);
                    }
                }

                let (eigenvalues, eigenvectors) =
                    robust_eigh(&s_balanced_local, Side::Lower, "balanced penalty block")?;
                let max_eig = eigenvalues.iter().fold(0.0f64, |max, &val| max.max(val));
                let tolerance = if max_eig > 0.0 {
                    max_eig * 1e-12
                } else {
                    1e-12
                };
                let block_rank = eigenvalues.iter().filter(|&&ev| ev > tolerance).count();

                // A block with no surviving eigenvalue contributes no rows;
                // `None` entries are flattened away after collection.
                if block_rank == 0 {
                    return Ok(None);
                }

                let mut root = Array2::zeros((block_rank, block_dim));
                let mut row_idx = 0;
                for (i, &eigenval) in eigenvalues.iter().enumerate() {
                    if eigenval > tolerance {
                        let sqrt_ev = eigenval.sqrt();
                        let evec = eigenvectors.column(i);
                        root.row_mut(row_idx).assign(&(&evec * sqrt_ev));
                        row_idx += 1;
                    }
                }

                Ok(Some(BlockRoot {
                    col_range: start..end,
                    root,
                }))
            },
        )
        .collect::<Result<Vec<_>, _>>()?
        .into_iter()
        .flatten()
        .collect();
    let total_rank: usize = block_roots.iter().map(|br| br.root.nrows()).sum();

    if total_rank == 0 {
        return Ok(Array2::zeros((0, p)));
    }

    // Assemble global balanced root: total_rank × p. Block roots are stacked
    // vertically in ascending column-range order, each placed in its own
    // column range.
    let mut eb = Array2::zeros((total_rank, p));
    let mut row_offset = 0;
    for br in &block_roots {
        let rank_b = br.root.nrows();
        eb.slice_mut(s![
            row_offset..(row_offset + rank_b),
            br.col_range.start..br.col_range.end
        ])
        .assign(&br.root);
        row_offset += rank_b;
    }

    Ok(eb)
}
1574
1575/// Lambda-independent reparameterization invariants derived from penalty structure.
1576#[derive(Clone)]
1577struct SubspaceSplit {
1578    q_pen: Array2<f64>,
1579    q_null: Array2<f64>,
1580}
1581
1582impl SubspaceSplit {
1583    fn identity(p: usize) -> Self {
1584        Self {
1585            q_pen: Array2::zeros((p, 0)),
1586            q_null: Array2::eye(p),
1587        }
1588    }
1589
1590    fn from_ordered_qs(
1591        qs: &Mat<f64>,
1592        penalized_rank: usize,
1593        p: usize,
1594    ) -> Result<Self, EstimationError> {
1595        if qs.nrows() != p || qs.ncols() != p {
1596            return Err(EstimationError::LayoutError(format!(
1597                "Invalid Q basis dimensions: expected {p}x{p}, got {}x{}",
1598                qs.nrows(),
1599                qs.ncols()
1600            )));
1601        }
1602        if penalized_rank > p {
1603            return Err(EstimationError::LayoutError(format!(
1604                "Invalid penalized rank {penalized_rank} for p={p}"
1605            )));
1606        }
1607
1608        let null_count = p - penalized_rank;
1609        let mut q_pen = Array2::<f64>::zeros((p, penalized_rank));
1610        let mut q_null = Array2::<f64>::zeros((p, null_count));
1611        for i in 0..p {
1612            for j in 0..penalized_rank {
1613                q_pen[(i, j)] = qs[(i, j)];
1614            }
1615            for j in 0..null_count {
1616                q_null[(i, j)] = qs[(i, penalized_rank + j)];
1617            }
1618        }
1619
1620        Ok(Self { q_pen, q_null })
1621    }
1622
1623    fn rank(&self) -> usize {
1624        self.q_pen.ncols()
1625    }
1626
1627    fn p(&self) -> usize {
1628        self.q_pen.nrows()
1629    }
1630
1631    fn compose_qs(&self) -> Array2<f64> {
1632        let p = self.p();
1633        let rank = self.rank();
1634        let null_count = self.q_null.ncols();
1635        let mut qs = Array2::<f64>::zeros((p, p));
1636        for i in 0..p {
1637            for j in 0..rank {
1638                qs[(i, j)] = self.q_pen[(i, j)];
1639            }
1640            for j in 0..null_count {
1641                qs[(i, rank + j)] = self.q_null[(i, j)];
1642            }
1643        }
1644        qs
1645    }
1646}
1647
1648/// Lambda-independent reparameterization invariants derived from penalty structure.
1649#[derive(Clone)]
1650pub struct ReparamInvariant {
1651    split: SubspaceSplit,
1652    /// The balanced eigenvector matrix Q (p x p). Block-local roots are
1653    /// transformed on-the-fly as `R_block @ Q[start..end, :]` instead of
1654    /// storing pre-multiplied full-width roots.
1655    qs_base: Array2<f64>,
1656    has_nonzero: bool,
1657    /// Largest eigenvalue of the balanced (unit-Frobenius) penalty matrix.
1658    /// Used as the scale reference for the shrinkage floor.
1659    max_balanced_eigenvalue: f64,
1660}
1661
1662impl ReparamInvariant {
1663    /// Returns the largest eigenvalue of the balanced penalty matrix.
1664    /// This is lambda-independent and provides a natural scale for shrinkage.
1665    pub const fn max_balanced_eigenvalue(&self) -> f64 {
1666        self.max_balanced_eigenvalue
1667    }
1668}
1669
1670/// Precompute the lambda-invariant reparameterization structure from canonical penalties.
1671///
1672/// Uses block-local roots directly instead of requiring rank x p global roots.
1673/// Each `CanonicalPenalty` carries its own block-local root and column range,
1674/// so the balanced sum can be assembled without ever materializing full-size
1675/// penalty matrices.
1676pub fn precompute_reparam_invariant_from_canonical(
1677    penalties: &[CanonicalPenalty],
1678    p_total: usize,
1679) -> Result<ReparamInvariant, EstimationError> {
1680    use std::cmp::Ordering;
1681
1682    let m = penalties.len();
1683
1684    if m == 0 {
1685        return Ok(ReparamInvariant {
1686            split: SubspaceSplit::identity(p_total),
1687            qs_base: Array2::eye(p_total),
1688            has_nonzero: false,
1689            max_balanced_eigenvalue: 0.0,
1690        });
1691    }
1692
1693    // Group penalties by col_range to detect block-diagonal structure.
1694    struct PenRef {
1695        penalty_index: usize,
1696    }
1697    let mut block_groups: BTreeMap<(usize, usize), Vec<PenRef>> = BTreeMap::new();
1698    let mut has_nonzero = false;
1699    for (i, cp) in penalties.iter().enumerate() {
1700        if cp.rank() == 0 {
1701            continue;
1702        }
1703        let local = cp.local_ref();
1704        let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1705        if frob_norm > 1e-12 {
1706            has_nonzero = true;
1707        }
1708        let key = (cp.col_range.start, cp.col_range.end);
1709        block_groups
1710            .entry(key)
1711            .or_default()
1712            .push(PenRef { penalty_index: i });
1713    }
1714
1715    if !has_nonzero {
1716        return Ok(ReparamInvariant {
1717            split: SubspaceSplit::identity(p_total),
1718            qs_base: Array2::eye(p_total),
1719            has_nonzero: false,
1720            max_balanced_eigenvalue: 0.0,
1721        });
1722    }
1723
1724    // Check for overlapping ranges.
1725    let ranges: Vec<(usize, usize)> = block_groups.keys().copied().collect();
1726    let mut overlapping = false;
1727    for i in 1..ranges.len() {
1728        if ranges[i].0 < ranges[i - 1].1 {
1729            overlapping = true;
1730            break;
1731        }
1732    }
1733
1734    if overlapping {
1735        // Mirror the dense-fallback guard from
1736        // `create_balanced_penalty_root_from_canonical`. Without this, large-scale-
1737        // scale models with overlapping penalties allocated a full
1738        // p_total × p_total workspace and ran an O(p³) eigendecomposition
1739        // before any solver code saw the problem size.
1740        if p_total > OVERLAPPING_PENALTY_DENSE_FALLBACK_MAX_P {
1741            return Err(EstimationError::LayoutError(format!(
1742                "overlapping penalty reparameterization would require dense {}x{} eigendecomposition; \
1743                 large-model dense fallback is disabled. Keep penalties structured or \
1744                 extend the overlapping-penalty solver path",
1745                p_total, p_total
1746            )));
1747        }
1748        // Fallback: global p×p eigendecomposition.
1749        let mut s_balanced = Mat::<f64>::zeros(p_total, p_total);
1750        for cp in penalties {
1751            if cp.rank() == 0 {
1752                continue;
1753            }
1754            let local = cp.local_ref();
1755            let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1756            if frob_norm > 1e-12 {
1757                let scale = 1.0 / frob_norm;
1758                let r = &cp.col_range;
1759                for i in 0..local.nrows() {
1760                    for j in 0..local.ncols() {
1761                        s_balanced[(r.start + i, r.start + j)] += scale * local[[i, j]];
1762                    }
1763                }
1764            }
1765        }
1766
1767        let (bal_eigenvalues, bal_eigenvectors) =
1768            robust_eigh_faer(&s_balanced, Side::Lower, "balanced penalty matrix")?;
1769
1770        let mut order: Vec<usize> = (0..p_total).collect();
1771        order.sort_by(|&i, &j| {
1772            bal_eigenvalues[j]
1773                .partial_cmp(&bal_eigenvalues[i])
1774                .unwrap_or(Ordering::Equal)
1775                .then(i.cmp(&j))
1776        });
1777
1778        let mut qs = Mat::<f64>::zeros(p_total, p_total);
1779        for (col_idx, &idx) in order.iter().enumerate() {
1780            for row in 0..p_total {
1781                qs[(row, col_idx)] = bal_eigenvectors[(row, idx)];
1782            }
1783        }
1784
1785        let max_bal = order
1786            .iter()
1787            .map(|&idx| bal_eigenvalues[idx].abs())
1788            .fold(0.0_f64, f64::max);
1789        let rank_tol = if max_bal > 0.0 {
1790            max_bal * 1e-12
1791        } else {
1792            1e-12
1793        };
1794        let penalized_rank = order
1795            .iter()
1796            .take_while(|&&idx| bal_eigenvalues[idx] > rank_tol)
1797            .count();
1798        let split = SubspaceSplit::from_ordered_qs(&qs, penalized_rank, p_total)?;
1799
1800        return Ok(ReparamInvariant {
1801            split,
1802            qs_base: mat_to_array(&qs),
1803            has_nonzero,
1804            max_balanced_eigenvalue: max_bal,
1805        });
1806    }
1807
1808    // -----------------------------------------------------------------------
1809    // Non-overlapping: block-diagonal eigendecomposition at O(Σ p_k³).
1810    // -----------------------------------------------------------------------
1811    // The balanced sum is block-diagonal ⟹ its eigenvectors are block-local.
1812    // Q_pen and Q_null are assembled by embedding block-local eigenvectors.
1813
1814    // Track which columns are covered by any penalty.
1815    let mut covered = vec![false; p_total];
1816    for cp in penalties {
1817        for j in cp.col_range.clone() {
1818            covered[j] = true;
1819        }
1820    }
1821    let uncovered_cols: Vec<usize> = (0..p_total).filter(|j| !covered[*j]).collect();
1822
1823    struct BlockResult {
1824        col_range: Range<usize>,
1825        q_pen_local: Array2<f64>,  // block_dim × pen_rank
1826        q_null_local: Array2<f64>, // block_dim × null_rank
1827        /// Largest balanced eigenvalue contributed by this block.
1828        max_balanced_eigenvalue: f64,
1829        /// Column offset of this block's penalized directions within global Q_pen.
1830        pen_col_offset: usize,
1831        /// Column offset of this block's null directions within global Q_null.
1832        null_col_offset: usize,
1833    }
1834
1835    // BTreeMap iteration defines the deterministic block order; collecting the
1836    // indexed parallel iterator preserves that order while eigendecomposing each
1837    // independent canonical penalty block concurrently.
1838    let block_specs: Vec<_> = block_groups.iter().collect();
1839    let mut block_results: Vec<BlockResult> = block_specs
1840        .into_par_iter()
1841        .map(
1842            |(&(start, end), refs)| -> Result<BlockResult, EstimationError> {
1843                let block_dim = end - start;
1844
1845                // Build local balanced sum.
1846                let mut s_balanced_local = Array2::zeros((block_dim, block_dim));
1847                let mut block_has_nonzero = false;
1848                for pref in refs {
1849                    let cp = &penalties[pref.penalty_index];
1850                    let local = cp.local_ref();
1851                    let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1852                    if frob_norm > 1e-12 {
1853                        s_balanced_local.scaled_add(1.0 / frob_norm, local);
1854                        block_has_nonzero = true;
1855                    }
1856                }
1857
1858                if !block_has_nonzero {
1859                    return Ok(BlockResult {
1860                        col_range: start..end,
1861                        q_pen_local: Array2::zeros((block_dim, 0)),
1862                        q_null_local: Array2::eye(block_dim),
1863                        max_balanced_eigenvalue: 0.0,
1864                        pen_col_offset: 0,  // set later
1865                        null_col_offset: 0, // set later
1866                    });
1867                }
1868
1869                // Eigendecompose the local balanced penalty.
1870                let (bal_eigenvalues, bal_eigenvectors) =
1871                    robust_eigh(&s_balanced_local, Side::Lower, "balanced penalty block")?;
1872
1873                let mut order: Vec<usize> = (0..block_dim).collect();
1874                order.sort_by(|&i, &j| {
1875                    bal_eigenvalues[j]
1876                        .partial_cmp(&bal_eigenvalues[i])
1877                        .unwrap_or(Ordering::Equal)
1878                        .then(i.cmp(&j))
1879                });
1880
1881                let max_bal = order
1882                    .iter()
1883                    .map(|&idx| bal_eigenvalues[idx].abs())
1884                    .fold(0.0_f64, f64::max);
1885                let rank_tol = if max_bal > 0.0 {
1886                    max_bal * 1e-12
1887                } else {
1888                    1e-12
1889                };
1890                let penalized_rank = order
1891                    .iter()
1892                    .take_while(|&&idx| bal_eigenvalues[idx] > rank_tol)
1893                    .count();
1894                let null_count = block_dim - penalized_rank;
1895
1896                let mut q_pen_local = Array2::zeros((block_dim, penalized_rank));
1897                let mut q_null_local = Array2::zeros((block_dim, null_count));
1898                for (col_idx, &idx) in order.iter().enumerate() {
1899                    if col_idx < penalized_rank {
1900                        for row in 0..block_dim {
1901                            q_pen_local[[row, col_idx]] = bal_eigenvectors[[row, idx]];
1902                        }
1903                    } else {
1904                        let null_col = col_idx - penalized_rank;
1905                        for row in 0..block_dim {
1906                            q_null_local[[row, null_col]] = bal_eigenvectors[[row, idx]];
1907                        }
1908                    }
1909                }
1910
1911                Ok(BlockResult {
1912                    col_range: start..end,
1913                    q_pen_local,
1914                    q_null_local,
1915                    max_balanced_eigenvalue: max_bal,
1916                    pen_col_offset: 0,  // set later
1917                    null_col_offset: 0, // set later
1918                })
1919            },
1920        )
1921        .collect::<Result<_, _>>()?;
1922    let global_max_bal = block_results
1923        .iter()
1924        .map(|br| br.max_balanced_eigenvalue)
1925        .fold(0.0_f64, f64::max);
1926
1927    // Compute column offsets for each block in the global Q_pen / Q_null layout.
1928    let total_pen_rank: usize = block_results.iter().map(|br| br.q_pen_local.ncols()).sum();
1929    let total_null: usize = block_results
1930        .iter()
1931        .map(|br| br.q_null_local.ncols())
1932        .sum::<usize>()
1933        + uncovered_cols.len();
1934    {
1935        let mut pen_off = 0usize;
1936        let mut null_off = 0usize;
1937        for br in &mut block_results {
1938            br.pen_col_offset = pen_off;
1939            br.null_col_offset = null_off;
1940            pen_off += br.q_pen_local.ncols();
1941            null_off += br.q_null_local.ncols();
1942        }
1943    }
1944
1945    let mut q_pen = Array2::zeros((p_total, total_pen_rank));
1946    let mut q_null = Array2::zeros((p_total, total_null));
1947
1948    for br in &block_results {
1949        let start = br.col_range.start;
1950        let bd = br.q_pen_local.nrows();
1951        let pen_r = br.q_pen_local.ncols();
1952        let null_r = br.q_null_local.ncols();
1953        if pen_r > 0 {
1954            q_pen
1955                .slice_mut(s![
1956                    start..(start + bd),
1957                    br.pen_col_offset..(br.pen_col_offset + pen_r)
1958                ])
1959                .assign(&br.q_pen_local);
1960        }
1961        if null_r > 0 {
1962            q_null
1963                .slice_mut(s![
1964                    start..(start + bd),
1965                    br.null_col_offset..(br.null_col_offset + null_r)
1966                ])
1967                .assign(&br.q_null_local);
1968        }
1969    }
1970    let mut null_col = block_results
1971        .iter()
1972        .map(|br| br.q_null_local.ncols())
1973        .sum::<usize>();
1974    for &j in &uncovered_cols {
1975        q_null[[j, null_col]] = 1.0;
1976        null_col += 1;
1977    }
1978
1979    let split = SubspaceSplit { q_pen, q_null };
1980
1981    // Store the global Q_s = [Q_pen | Q_null] from the split.
1982    // Block-local roots are transformed on-the-fly as R_block @ Q[start..end, :]
1983    // inside the reparam engine, avoiding O(k * rank * p) storage.
1984    let qs_global = split.compose_qs();
1985
1986    Ok(ReparamInvariant {
1987        split,
1988        qs_base: qs_global,
1989        has_nonzero,
1990        max_balanced_eigenvalue: global_max_bal,
1991    })
1992}
1993
1994fn structurally_penalized_columns(penalties: &[CanonicalPenalty], p: usize) -> Vec<bool> {
1995    let mut active = vec![false; p];
1996    for cp in penalties {
1997        let local = cp.local_ref();
1998        let scale = local.iter().map(|&v| v.abs()).fold(0.0_f64, f64::max);
1999        if scale <= 0.0 {
2000            continue;
2001        }
2002        let tol = scale * 1e-12;
2003        for local_col in 0..cp.block_dim() {
2004            let mut column_active = false;
2005            for row in 0..cp.block_dim() {
2006                if local[[row, local_col]].abs() > tol || local[[local_col, row]].abs() > tol {
2007                    column_active = true;
2008                    break;
2009                }
2010            }
2011            if column_active {
2012                active[cp.col_range.start + local_col] = true;
2013            }
2014        }
2015    }
2016    active
2017}
2018
2019/// Apply stable reparameterization using precomputed lambda-invariant structures.
2020///
2021/// `penalty_shrinkage_floor`: optional relative shrinkage floor for eigenvalues
2022/// of the penalized block. If `Some(epsilon)`, a rho-independent ridge of
2023/// magnitude `epsilon * max_balanced_eigenvalue` is added to each eigenvalue
2024/// of the combined penalty on the penalized block. This prevents barely-penalized
2025/// directions from causing pathological non-Gaussianity in the posterior (e.g.,
2026/// extreme skewness under logit link with high-dimensional spatial smooths).
2027/// A typical value is `1e-6`. Set to `None` or `Some(0.0)` to disable.
2028pub fn stable_reparameterizationwith_invariant(
2029    penalties: &[CanonicalPenalty],
2030    lambdas: &[f64],
2031    p: usize,
2032    invariant: &ReparamInvariant,
2033    penalty_shrinkage_floor: Option<f64>,
2034) -> Result<ReparamResult, EstimationError> {
2035    let m = penalties.len();
2036
2037    if lambdas.len() != m {
2038        return Err(EstimationError::ParameterConstraintViolation(format!(
2039            "Lambda count mismatch: expected {} lambdas for {} penalties, got {}",
2040            m,
2041            m,
2042            lambdas.len()
2043        )));
2044    }
2045
2046    // No separate length check needed — penalties are matched against lambdas above,
2047    // and the invariant's qs_base is p x p (dimension-checked by the split).
2048
2049    // #1074: the gam#1379 finite-ceiling on λ_k = exp(ρ_k) (clamp to 1e300 to
2050    // avoid `∞·0 = NaN` when the outer optimizer drives a redundant penalty
2051    // direction's log-λ past ~709) was DELETED. It masked the real defect: the
2052    // optimizer drives a redundant/unidentified penalty direction off to ∞
2053    // instead of that direction being detected and dropped from the model.
2054    // The root fix (detect+drop the redundant penalty direction at construction)
2055    // is tracked separately; λ now passes through raw.
2056
2057    if m == 0 {
2058        return Ok(ReparamResult {
2059            s_transformed: Array2::zeros((p, p)),
2060            log_det: 0.0,
2061            det1: Array1::zeros(0),
2062            qs: Array2::eye(p),
2063            canonical_transformed: vec![],
2064            e_transformed: Array2::zeros((0, p)),
2065            // All modes truncated when no penalties; already in transformed frame.
2066            u_truncated: Array2::eye(p),
2067            penalty_shrinkage_ridge: 0.0,
2068        });
2069    }
2070
2071    if !invariant.has_nonzero {
2072        let qs = invariant.split.compose_qs();
2073        let u_truncated = qs.t().dot(&invariant.split.q_null);
2074        // All penalties are zero — canonical_transformed = originals (no rotation needed).
2075        let canonical_transformed: Vec<CanonicalPenalty> = penalties.to_vec();
2076        return Ok(ReparamResult {
2077            s_transformed: Array2::zeros((p, p)),
2078            log_det: 0.0,
2079            det1: Array1::zeros(m),
2080            qs,
2081            canonical_transformed,
2082            e_transformed: Array2::zeros((0, p)),
2083            u_truncated,
2084            penalty_shrinkage_ridge: 0.0,
2085        });
2086    }
2087
2088    let q_pen = array_to_faer(&invariant.split.q_pen);
2089    let q_null = array_to_faer(&invariant.split.q_null);
2090    let qs_base = array_to_faer(&invariant.qs_base);
2091    // Each penalty root transform is independent: R_k_block @ Q[start..end, :].
2092    // Run those per-penalty products (and their S_k = R_k'R_k caches) in
2093    // parallel, then collect in slice order so all downstream accumulation stays
2094    // deterministic and bit-for-bit stable with respect to penalty ordering.
2095    let penalty_transforms: Vec<(Mat<f64>, Mat<f64>)> = penalties
2096        .par_iter()
2097        .map(|cp| {
2098            let r = &cp.col_range;
2099            let root_faer = array_to_faer(&cp.root);
2100            let q_block = qs_base.submatrix(r.start, 0, cp.block_dim(), p);
2101            let mut product = Mat::<f64>::zeros(cp.rank(), p);
2102            matmul(
2103                product.as_mut(),
2104                Accum::Replace,
2105                root_faer.as_ref(),
2106                q_block,
2107                1.0,
2108                Par::Seq,
2109            );
2110            let s_k = penalty_from_root_faer(&product);
2111            (product, s_k)
2112        })
2113        .collect();
2114    let (rs_transformed, s_k_penalized_cache): (Vec<Mat<f64>>, Vec<Mat<f64>>) =
2115        penalty_transforms.into_iter().unzip();
2116
2117    let penalized_rank = invariant.split.rank();
2118
2119    let mut range_eigenvalues_sorted: Vec<f64> = Vec::new();
2120    let mut range_rotation = Mat::<f64>::zeros(penalized_rank, penalized_rank);
2121    if penalized_rank > 0 {
2122        let mut range_block = Mat::<f64>::zeros(penalized_rank, penalized_rank);
2123        // Deterministic assembly: the independent S_k transforms were computed in
2124        // parallel above, but the lambda-weighted sum is accumulated serially in
2125        // canonical penalty order to avoid order-dependent floating-point drift.
2126        for (lambda, s_k) in lambdas.iter().zip(s_k_penalized_cache.iter()) {
2127            for i in 0..penalized_rank {
2128                for j in 0..penalized_rank {
2129                    range_block[(i, j)] += *lambda * s_k[(i, j)];
2130                }
2131            }
2132        }
2133        let (range_eigenvalues, range_eigenvectors) =
2134            robust_eigh_faer(&range_block, Side::Lower, "range penalty block")?;
2135
2136        let mut range_order: Vec<usize> = (0..penalized_rank).collect();
2137        range_order.sort_by(|&i, &j| {
2138            range_eigenvalues[j]
2139                .partial_cmp(&range_eigenvalues[i])
2140                .unwrap_or(std::cmp::Ordering::Equal)
2141                .then(i.cmp(&j))
2142        });
2143        range_eigenvalues_sorted = range_order
2144            .iter()
2145            .map(|&idx| range_eigenvalues[idx])
2146            .collect();
2147
2148        // Build range_rotation = U (sorted eigenvectors) for E and S⁺
2149        // construction only.  DO NOT apply to q_pen or rs_transformed —
2150        // keeping Q_s lambda-independent prevents BFGS coordinate-system
2151        // drift when multiple penalties interact (the eigenvectors of
2152        // Σ λ_k S_k rotate with λ, breaking the quasi-Newton Hessian
2153        // approximation at eigenvalue crossings).
2154        for (col_idx, &idx) in range_order.iter().enumerate() {
2155            for row in 0..penalized_rank {
2156                range_rotation[(row, col_idx)] = range_eigenvectors[(row, idx)];
2157            }
2158        }
2159        // q_pen and rs_transformed stay in the lambda-independent
2160        // invariant basis.  E and S⁺ below are expressed in this same
2161        // basis using U from the eigendecomposition.
2162    }
2163
2164    // Subspace-invariant penalty spectral calculus:
2165    // - Penalized and null spaces are fixed by the lambda-invariant basis `qs_base`.
2166    // - Runtime lambda dependence only appears in the penalized block eigenvalues.
2167    // This avoids basis mixing inside the degenerate zero-eigenspace.
2168    let structural_rank = penalized_rank;
2169    let mut range_eigs_sorted: Vec<f64> = range_eigenvalues_sorted;
2170    let structurally_penalized_cols = structurally_penalized_columns(penalties, p);
2171
2172    // Shrinkage floor: add a rho-independent ridge to the penalized block eigenvalues.
2173    // This prevents barely-penalized directions from causing pathological non-Gaussianity
2174    // in the posterior (extreme skewness under non-canonical links like logit with
2175    // high-dimensional spatial smooths). The ridge magnitude is proportional to the
2176    // balanced penalty's max eigenvalue (lambda-independent scale), so LAML gradients
2177    // w.r.t. rho remain correct: d(epsilon * I)/d(rho_k) = 0.
2178    //
2179    // The shrinkage ridge is a real prior contribution: it changes the quadratic
2180    // form, the penalty pseudo-logdet, and downstream Hessians. It is currently
2181    // surfaced through `ReparamResult::penalty_shrinkage_ridge` and consumed by
2182    // PIRLS/REML via that per-call channel. The longer-term home is a
2183    // `RidgePassport` with `RidgePolicy::explicit_stabilization_full` scoped to
2184    // the penalized block so the same delta is reflected in serialization, the
2185    // Laplace Hessian, and the prior logdet without callers having to thread an
2186    // additional scalar — once the ridge-ledger plumbing covers per-block
2187    // ScaledIdentity passports.
2188    let shrinkage_ridge = penalty_shrinkage_floor
2189        .filter(|&eps| eps > 0.0)
2190        .map(|eps| eps * invariant.max_balanced_eigenvalue)
2191        .unwrap_or(0.0);
2192    if shrinkage_ridge > 0.0 {
2193        let min_eig_before = range_eigs_sorted
2194            .iter()
2195            .copied()
2196            .fold(f64::INFINITY, f64::min);
2197        let mut shrinkage_floor_applied = 0usize;
2198        for eig_idx in 0..range_eigs_sorted.len() {
2199            let mut penalized_energy = 0.0;
2200            for original_col in 0..p {
2201                if structurally_penalized_cols[original_col] {
2202                    let mut coordinate = 0.0;
2203                    for pen_col in 0..penalized_rank {
2204                        coordinate +=
2205                            q_pen[(original_col, pen_col)] * range_rotation[(pen_col, eig_idx)];
2206                    }
2207                    penalized_energy += coordinate * coordinate;
2208                }
2209            }
2210            if penalized_energy > 1e-8 {
2211                range_eigs_sorted[eig_idx] += shrinkage_ridge;
2212                shrinkage_floor_applied += 1;
2213            }
2214        }
2215        // Log when the floor materially changes the smallest eigenvalue (>1% relative shift).
2216        if min_eig_before > 0.0 && shrinkage_ridge / min_eig_before > 0.01 {
2217            log::debug!(
2218                "Penalty shrinkage floor active: ridge={:.3e} (min_eig_before={:.3e}, ratio={:.1e}, max_bal_eig={:.3e}, applied_dirs={})",
2219                shrinkage_ridge,
2220                min_eig_before,
2221                shrinkage_ridge / min_eig_before,
2222                invariant.max_balanced_eigenvalue,
2223                shrinkage_floor_applied,
2224            );
2225        }
2226    }
2227
2228    let eigenvalue_floor = invariant.max_balanced_eigenvalue.max(1.0) * 1e-12;
2229    let qs = compose_qs_from_split(&q_pen, &q_null, p);
2230
2231    // Guard against any accidental penalized/null mixing. The transformed penalty
2232    // roots must have negligible support on null columns by construction.
2233    let leakage = assess_subspace_leakage(&qs, &rs_transformed, structural_rank, p);
2234    if !subspace_split_is_consistent(&leakage, p) {
2235        return Err(EstimationError::LayoutError(format!(
2236            "Reparameterization subspace split is inconsistent: max null leakage {:.3e} (rel {:.3e}, worst penalty {}), max |Qp'Qn| {:.3e}",
2237            leakage.max_abs_sq.sqrt(),
2238            leakage.max_rel_sq.sqrt(),
2239            leakage.worst_penalty,
2240            leakage.max_cross_gram_abs,
2241        )));
2242    }
2243
2244    // Truncated basis in transformed coordinates:
2245    //   U_⊥^(t) = Qs^T U_⊥^(orig) = Qs^T Q_n.
2246    let mut u_truncated_mat = Mat::<f64>::zeros(p, q_null.ncols());
2247    matmul(
2248        u_truncated_mat.as_mut(),
2249        Accum::Replace,
2250        qs.transpose(),
2251        q_null.as_ref(),
2252        1.0,
2253        Par::Seq,
2254    );
2255
2256    // E is represented in TRANSFORMED coordinates (beta_t).  Because the
2257    // penalized subspace is NOT rotated by the lambda-dependent eigenvectors
2258    // (to keep Q_s stable across BFGS iterations), E is no longer diagonal.
2259    // Instead E = diag(√d) · U' embedded in structural_rank × p, so that
2260    // E'E = U diag(d) U' = Σ λ_k S_k in the invariant penalized basis.
2261    let mut e_transformed_mat = Mat::<f64>::zeros(structural_rank, p);
2262    for row_idx in 0..structural_rank {
2263        let safe_eigenval = range_eigs_sorted[row_idx].max(eigenvalue_floor);
2264        let sqrt_eigenval = safe_eigenval.sqrt();
2265        // E[row, j] = sqrt(d_row) * U'[row, j] = sqrt(d_row) * U[j, row]
2266        for j in 0..penalized_rank {
2267            e_transformed_mat[(row_idx, j)] = sqrt_eigenval * range_rotation[(j, row_idx)];
2268        }
2269    }
2270
2271    // Pseudo-logdet on the structural penalized block.  The null block is split
2272    // out above, so there is no nullspace normalization here.  Spectrally-noisy
2273    // directions (eigenvalues snapped to 0 by `classify_eigenvalues_strict`
2274    // because they fall below the `c * eps_machine * p * scale` tolerance in
2275    // the lambda-weighted sum) are floored to `eigenvalue_floor` to keep the
2276    // log-det finite and consistent with the floored values used to construct
2277    // `e_transformed_mat` above.  This avoids spurious P-IRLS failures when the
2278    // lambda dynamic range is wide (e.g. during BFGS line search probing extreme
2279    // rho candidates).  Materially negative or non-finite spectra are already
2280    // rejected by the strict classifier upstream; this loop re-checks the
2281    // *post-shrinkage* range eigenvalues against the same floor.
2282    //
2283    // The same floored spectrum is used in the trace formula tr(S⁺ S_k) below,
2284    // matching the rank structure embedded in `e_transformed_mat` and avoiding
2285    // a 1/0 in the trace contraction when an eigenvalue was floored to 0.
2286    let mut floored_eigs: Vec<f64> = Vec::with_capacity(range_eigs_sorted.len());
2287    let mut log_det_sum = KahanSum::default();
2288    for (idx, &ev) in range_eigs_sorted.iter().enumerate() {
2289        if !ev.is_finite() || ev < -eigenvalue_floor {
2290            return Err(EstimationError::LayoutError(format!(
2291                "Penalty pseudo-logdet has a non-finite or large-negative structural eigenvalue at index {idx}: {ev:.3e}"
2292            )));
2293        }
2294        let safe_ev = ev.max(eigenvalue_floor);
2295        floored_eigs.push(safe_ev);
2296        if idx < penalized_rank {
2297            log_det_sum.add(safe_ev.ln());
2298        }
2299    }
2300    let log_det = log_det_sum.sum();
2301    let delta = 0.0;
2302
2303    // The det1 contractions are independent once the eigensystem is fixed.  Use
2304    // indexed parallel collection so the output vector preserves lambda order.
2305    let det1vec: Vec<f64> = (0..lambdas.len())
2306        .into_par_iter()
2307        .map(|k| {
2308            let s_k = &s_k_penalized_cache[k];
2309            // Compute tr((S+δI)⁻¹ S_k) in the range eigenbasis without ever
2310            // materializing (S+δI)⁻¹. Using faer's matmul keeps this contraction
2311            // aligned with the orthogonal-similarity debug reference path.
2312            let trace = trace_penalty_in_orthogonal_basis(
2313                s_k,
2314                penalized_rank,
2315                &range_rotation,
2316                &floored_eigs,
2317                delta,
2318            );
2319            lambdas[k] * trace
2320        })
2321        .collect();
2322
2323    {
2324        // Guardrail: cross-check the primary Rayleigh-quotient contraction
2325        // against a full orthogonal similarity transform, while staying in
2326        // the same numerically stable eigenbasis coordinates.
2327        let mut maxdet1_mismatch = 0.0_f64;
2328        let mut det1_scale = 0.0_f64;
2329        for (k, lambda) in lambdas.iter().enumerate() {
2330            let s_k_penalized = &s_k_penalized_cache[k];
2331            let s_k_eigenbasis = orthogonal_similarity_transform_faer(
2332                s_k_penalized,
2333                penalized_rank,
2334                &range_rotation,
2335            );
2336            let mut trace = KahanSum::default();
2337            for l in 0..penalized_rank {
2338                trace.add(s_k_eigenbasis[(l, l)] / (floored_eigs[l] + delta));
2339            }
2340            let reference = *lambda * trace.sum();
2341            maxdet1_mismatch = maxdet1_mismatch.max((reference - det1vec[k]).abs());
2342            det1_scale = det1_scale.max(reference.abs()).max(det1vec[k].abs());
2343        }
2344        let det1_tolerance = 1e-7 * det1_scale.max(1.0);
2345        assert!(
2346            maxdet1_mismatch <= det1_tolerance,
2347            "det1 mismatch between optimized and reference formulas: max_abs={maxdet1_mismatch:.3e}, tol={det1_tolerance:.3e}"
2348        );
2349    }
2350
2351    // Rebuild s_transformed from e_transformed to ensure rank consistency.
2352    //
2353    // The sum of λ*S_k may contain numerical noise modes (eigenvalues ~1e-15) that
2354    // become significant when λ is large (e.g., 10^12). These modes would appear in H
2355    // but are truncated from log|S|_+, creating a "phantom penalty" in the objective.
2356    //
2357    // By reconstructing s_transformed = E^T * E, we force the penalty matrix used
2358    // in H to have the EXACT same rank structure as the one used for log|S|_+.
2359    // Any mode truncated from the prior is now strictly zero in the Hessian
2360    // calculation, ensuring mathematical consistency of the gradients.
2361    let mut s_truncated = Mat::<f64>::zeros(p, p);
2362    matmul(
2363        s_truncated.as_mut(),
2364        Accum::Replace,
2365        e_transformed_mat.transpose(),
2366        e_transformed_mat.as_ref(),
2367        1.0,
2368        Par::Seq,
2369    );
2370
2371    {
2372        // Structural check: transformed S must not leak into declared null coordinates.
2373        let mut max_null_diag = 0.0_f64;
2374        let mut max_null_offdiag = 0.0_f64;
2375        for i in structural_rank..p {
2376            max_null_diag = max_null_diag.max(s_truncated[(i, i)].abs());
2377            for j in 0..p {
2378                if i != j {
2379                    max_null_offdiag = max_null_offdiag.max(s_truncated[(i, j)].abs());
2380                }
2381            }
2382        }
2383        assert!(
2384            max_null_diag <= 1e-10 && max_null_offdiag <= 1e-10,
2385            "null-space leakage in transformed penalty: max_null_diag={max_null_diag:.3e}, max_null_offdiag={max_null_offdiag:.3e}"
2386        );
2387    }
2388
2389    let qs_array = mat_to_array(&qs);
2390    let canonical_transformed: Vec<CanonicalPenalty> = rs_transformed
2391        .par_iter()
2392        .zip(penalties.par_iter())
2393        .map(|(r, cp)| {
2394            let mean_transformed = qs_array.t().dot(&cp.full_width_prior_mean());
2395            CanonicalPenalty::from_dense_root_with_mean(mat_to_array(r), p, mean_transformed)
2396        })
2397        .collect();
2398    Ok(ReparamResult {
2399        s_transformed: mat_to_array(&s_truncated),
2400        log_det,
2401        det1: Array1::from(det1vec),
2402        qs: qs_array,
2403        canonical_transformed,
2404        e_transformed: mat_to_array(&e_transformed_mat),
2405        u_truncated: mat_to_array(&u_truncated_mat),
2406        penalty_shrinkage_ridge: shrinkage_ridge,
2407    })
2408}
2409
/// Lightweight `(p, k)` layout descriptor for the engine-facing
/// reparameterization entry points, so they need no domain-specific layout type.
///
/// `p` is forwarded as the coefficient width of the reparameterized penalty.
/// NOTE(review): `k` is carried along but is not read by
/// `stable_reparameterization_engine_canonical`; presumably the penalty count —
/// confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EngineDims {
    pub p: usize,
    pub k: usize,
}

impl EngineDims {
    /// Builds a descriptor from the width `p` and the count `k`.
    pub fn new(p: usize, k: usize) -> Self {
        EngineDims { k, p }
    }
}
2422
2423/// Engine-facing stable reparameterization API using only `(p, k)`.
2424///
2425/// When `cached_invariant` is `Some`, reuses the precomputed eigendecomposition
2426/// (the hot path inside the REML loop). When `None`, computes the invariant on
2427/// the fly (the post-REML refit path). Merging both cases into a single entry
2428/// point ensures `penalty_shrinkage_floor` is always applied regardless of
2429/// whether a cached invariant is available.
2430/// Stable reparameterization from block-local canonical penalties.
2431pub fn stable_reparameterization_engine_canonical(
2432    penalties: &[CanonicalPenalty],
2433    lambdas: &[f64],
2434    dims: EngineDims,
2435    cached_invariant: Option<&ReparamInvariant>,
2436    penalty_shrinkage_floor: Option<f64>,
2437) -> Result<ReparamResult, EstimationError> {
2438    let owned;
2439    let invariant = match cached_invariant {
2440        Some(inv) => inv,
2441        None => {
2442            owned = precompute_reparam_invariant_from_canonical(penalties, dims.p)?;
2443            &owned
2444        }
2445    };
2446    stable_reparameterizationwith_invariant(
2447        penalties,
2448        lambdas,
2449        dims.p,
2450        invariant,
2451        penalty_shrinkage_floor,
2452    )
2453}
2454
2455// ---------------------------------------------------------------------------
2456// Kronecker-factored reparameterization for tensor-product smooths
2457// ---------------------------------------------------------------------------
2458
/// Result of the Kronecker-factored reparameterization of a tensor-product
/// penalty.
///
/// Exploits the fact that for Kronecker-structured penalties the joint
/// eigenvector matrix is `U_1 ⊗ ... ⊗ U_d` and the reparameterized design
/// is a rowwise Kronecker product of the `(B_k U_k)` — all of which stay
/// factored here. Dense `p × p` artifacts are only produced on request by the
/// `materialize_*` methods.
#[derive(Clone)]
pub struct KroneckerReparamResult {
    /// Reparameterized marginal designs: `B_k · U_k` for each marginal k.
    ///
    /// `Arc`-shared with the λ-invariant cache so the per-outer-iterate
    /// memoized engine bumps a refcount instead of deep-copying the
    /// (n × q) reparameterized marginals every call.
    pub reparameterized_marginals: Arc<Vec<Array2<f64>>>,
    /// Eigenvalues μ_k of each marginal penalty's eigendecomposition, one
    /// vector per marginal (`Arc`-shared with the λ-invariant cache).
    pub marginal_eigenvalues: Arc<Vec<Array1<f64>>>,
    /// Marginal eigenvector matrices U_k (`Arc`-shared with the cache).
    pub marginal_qs: Arc<Vec<Array2<f64>>>,
    /// Pseudo-log-determinant log|S|₊ accumulated over the marginal
    /// eigenvalue grid.
    pub log_det: f64,
    /// First derivatives of log|S|₊ w.r.t. ρ_k = log(λ_k): one entry per
    /// marginal penalty, plus one trailing entry for the double penalty when
    /// it is present.
    pub det1: Array1<f64>,
    /// Second derivatives of log|S|₊ w.r.t. ρ (square, same ordering as `det1`).
    pub det2: Array2<f64>,
    /// Shrinkage ridge added to the eigenvalue of every structurally penalized
    /// cell; 0.0 when no shrinkage floor was requested.
    pub penalty_shrinkage_ridge: f64,
    /// Whether a double penalty (global ridge on the joint null space) is present.
    pub has_double_penalty: bool,
    /// Marginal basis dimensions q_k; the joint dimension is their product.
    pub marginal_dims: Vec<usize>,
}
2489
2490impl KroneckerReparamResult {
2491    /// Materialize the joint Qs matrix (U_1 ⊗ ... ⊗ U_d) as dense p×p.
2492    /// Only for fallback paths — avoid in hot loops.
2493    pub fn materialize_qs(&self) -> Array2<f64> {
2494        let mut qs = Array2::<f64>::eye(1);
2495        for u_k in self.marginal_qs.iter() {
2496            qs = kronecker_product(&qs, u_k);
2497        }
2498        qs
2499    }
2500
2501    /// Materialize s_transformed (the penalty in the reparameterized basis).
2502    /// In the eigenbasis, this is diagonal with entries Σ_k λ_k μ_{k,j_k}.
2503    pub fn materialize_s_transformed(&self, lambdas: &[f64]) -> Array2<f64> {
2504        let d = self.marginal_dims.len();
2505        let p: usize = self.marginal_dims.iter().copied().product();
2506        let mut s = Array2::<f64>::zeros((p, p));
2507
2508        // Delegate the per-cell tensor-penalty accumulation to the shared
2509        // `kronecker_cell_sigma` (#1172/#1185 single source of truth). Fold the
2510        // `lambdas.len() > d` guard into `has_double` to preserve exact gating.
2511        let eigenvalue_views: Vec<ArrayView1<'_, f64>> =
2512            self.marginal_eigenvalues.iter().map(|m| m.view()).collect();
2513        let has_double = self.has_double_penalty && lambdas.len() > d;
2514        let mut multi_idx = vec![0usize; d];
2515        let mut flat = 0usize;
2516        loop {
2517            let (sigma, _structural_sigma, _joint_null) = kronecker_cell_sigma(
2518                &eigenvalue_views,
2519                &multi_idx,
2520                lambdas,
2521                d,
2522                has_double,
2523                self.penalty_shrinkage_ridge,
2524            );
2525            s[[flat, flat]] = sigma;
2526            flat += 1;
2527
2528            if kronecker_multi_index_advance(&mut multi_idx, &self.marginal_dims) {
2529                break;
2530            }
2531        }
2532        s
2533    }
2534
2535    /// Explicitly materialize the dense artifact bundle expected by legacy
2536    /// downstream consumers. This is not part of the native Kronecker solve path.
2537    pub fn materialize_dense_artifact_result(
2538        &self,
2539        rs_list: &[Array2<f64>],
2540        lambdas: &[f64],
2541        p: usize,
2542    ) -> Result<ReparamResult, EstimationError> {
2543        const KRONECKER_DENSE_COMPAT_FALLBACK_MAX_P: usize = 4096;
2544        if p > KRONECKER_DENSE_COMPAT_FALLBACK_MAX_P {
2545            return Err(EstimationError::LayoutError(format!(
2546                "Kronecker reparameterization would materialize dense {}x{} compatibility tensors; \
2547                 large-model dense fallback is disabled. Wire the downstream solver to consume \
2548                 the factored Kronecker result directly",
2549                p, p
2550            )));
2551        }
2552        let qs = self.materialize_qs();
2553        let s_transformed = self.materialize_s_transformed(lambdas);
2554
2555        // Transform penalty roots: R_k_transformed = R_k · Qs
2556        let rs_transformed: Vec<Array2<f64>> = if rs_list.len() >= 2 {
2557            use rayon::prelude::*;
2558            rs_list
2559                .par_iter()
2560                .map(|r| gam_linalg::faer_ndarray::fast_ab(r, &qs))
2561                .collect()
2562        } else {
2563            rs_list
2564                .iter()
2565                .map(|r| gam_linalg::faer_ndarray::fast_ab(r, &qs))
2566                .collect()
2567        };
2568        // rs_transposed removed — canonical_transformed is the single source of truth.
2569
2570        // Build e_transformed: combined penalty square root in transformed coords.
2571        // For Kronecker structure, the penalty is diagonal in the eigenbasis.
2572        // e_transformed rows are the nonzero rows of sqrt(Σ_k λ_k S_k)^{1/2}.
2573        let d = self.marginal_dims.len();
2574        // Delegate the per-cell tensor-penalty accumulation to the shared
2575        // `kronecker_cell_sigma` (the #1172/#1185 single source of truth). The
2576        // double-penalty term is only valid when `lambdas` actually carries the
2577        // λ_d entry, so fold the original `lambdas.len() > d` guard into the
2578        // `has_double_penalty` flag passed to the helper — preserving the exact
2579        // gating behavior.
2580        let eigenvalue_views: Vec<ArrayView1<'_, f64>> =
2581            self.marginal_eigenvalues.iter().map(|m| m.view()).collect();
2582        let has_double = self.has_double_penalty && lambdas.len() > d;
2583        let diag_vals: Vec<f64> = {
2584            let mut vals = Vec::with_capacity(p);
2585            let mut multi_idx = vec![0usize; d];
2586            loop {
2587                let (sigma, _structural_sigma, _joint_null) = kronecker_cell_sigma(
2588                    &eigenvalue_views,
2589                    &multi_idx,
2590                    lambdas,
2591                    d,
2592                    has_double,
2593                    self.penalty_shrinkage_ridge,
2594                );
2595                vals.push(if sigma > 0.0 { sigma.sqrt() } else { 0.0 });
2596
2597                if kronecker_multi_index_advance(&mut multi_idx, &self.marginal_dims) {
2598                    break;
2599                }
2600            }
2601            vals
2602        };
2603        let rank = diag_vals.iter().filter(|&&v| v > 1e-12).count();
2604        let mut e_transformed = Array2::<f64>::zeros((rank, p));
2605        let mut row = 0;
2606        for (j, &v) in diag_vals.iter().enumerate() {
2607            if v > 1e-12 {
2608                e_transformed[[row, j]] = v;
2609                row += 1;
2610            }
2611        }
2612
2613        // u_truncated: null-space eigenvectors (columns with zero eigenvalue).
2614        let null_count = p - rank;
2615        let mut u_truncated = Array2::<f64>::zeros((p, null_count));
2616        let mut col = 0;
2617        for (j, &v) in diag_vals.iter().enumerate() {
2618            if v <= 1e-12 {
2619                u_truncated[[j, col]] = 1.0; // standard basis vector in eigenbasis
2620                col += 1;
2621            }
2622        }
2623
2624        let canonical_transformed: Vec<CanonicalPenalty> = rs_transformed
2625            .iter()
2626            .map(|r| CanonicalPenalty::from_dense_root(r.clone(), p))
2627            .collect();
2628        Ok(ReparamResult {
2629            s_transformed,
2630            log_det: self.log_det,
2631            det1: self.det1.clone(),
2632            qs,
2633            canonical_transformed,
2634            e_transformed,
2635            u_truncated,
2636            penalty_shrinkage_ridge: self.penalty_shrinkage_ridge,
2637        })
2638    }
2639}
2640
2641/// Compute `log|S|₊` and its first/second derivatives w.r.t. `ρ_k = log(λ_k)`
2642/// from factored marginal eigenvalues.
2643///
2644/// Shared implementation for `KroneckerPenaltySystem::logdet_and_derivatives`
2645/// and `kronecker_reparameterization_engine`.  Iterates over the ∏q_j
2646/// multi-index grid in O(d · ∏q_j) time with no O(p²) storage.
2647const KRONECKER_STRUCTURAL_ZERO_TOL: f64 = 1e-12;
2648
2649/// Per-cell Kronecker eigenvalue accumulation — the single source of truth for
2650/// the #1172/#1185 tensor-penalty math.
2651///
2652/// For the multi-index cell `multi_idx`, accumulates:
2653///   - `sigma`            = Σ_k λ_k · μ_k  (+ joint-null double-penalty term + ridge)
2654///   - `structural_sigma` = Σ_k μ_k        (unweighted; classifies joint-null cells)
2655///   - `joint_null`       = whether the cell lies in the joint null space
2656///
2657/// `marginal_eigenvalues[k][multi_idx[k]]` is the k-th marginal eigenvalue μ_k.
2658/// The double-penalty (global ridge) term `λ_d` is added only on joint-null
2659/// cells; the structural shrinkage `ridge` is added only on structurally
2660/// penalized cells. This mirrors the gated logic fixed in #1172/#1185 and MUST
2661/// be kept identical across every caller.
2662#[inline]
2663fn kronecker_cell_sigma(
2664    marginal_eigenvalues: &[ArrayView1<'_, f64>],
2665    multi_idx: &[usize],
2666    lambdas: &[f64],
2667    d: usize,
2668    has_double_penalty: bool,
2669    ridge: f64,
2670) -> (f64, f64, bool) {
2671    let mut sigma = 0.0;
2672    let mut structural_sigma = 0.0;
2673    for k in 0..d {
2674        let marginal_eigenvalue = marginal_eigenvalues[k][multi_idx[k]];
2675        structural_sigma += marginal_eigenvalue;
2676        sigma += lambdas[k] * marginal_eigenvalue;
2677    }
2678    let joint_null = structural_sigma <= KRONECKER_STRUCTURAL_ZERO_TOL;
2679    if has_double_penalty && joint_null {
2680        sigma += lambdas[d];
2681    }
2682    if structural_sigma > KRONECKER_STRUCTURAL_ZERO_TOL {
2683        sigma += ridge;
2684    }
2685    (sigma, structural_sigma, joint_null)
2686}
2687
2688/// Advance a row-major multi-index over the `dims` grid in place.
2689/// Returns `true` when the grid is exhausted (the index wrapped back to all-zero).
2690#[inline]
2691fn kronecker_multi_index_advance(multi_idx: &mut [usize], dims: &[usize]) -> bool {
2692    let mut carry = true;
2693    for dim in (0..dims.len()).rev() {
2694        if carry {
2695            multi_idx[dim] += 1;
2696            if multi_idx[dim] < dims[dim] {
2697                carry = false;
2698            } else {
2699                multi_idx[dim] = 0;
2700            }
2701        }
2702    }
2703    carry
2704}
2705
2706pub fn kronecker_logdet_and_derivatives(
2707    marginal_eigenvalues: &[ArrayView1<'_, f64>],
2708    marginal_dims: &[usize],
2709    lambdas: &[f64],
2710    has_double_penalty: bool,
2711    ridge: f64,
2712) -> (f64, Array1<f64>, Array2<f64>) {
2713    let d = marginal_dims.len();
2714    let n_pen = d + if has_double_penalty { 1 } else { 0 };
2715
2716    let mut logdet = 0.0;
2717    let mut grad = Array1::<f64>::zeros(n_pen);
2718    let mut hess = Array2::<f64>::zeros((n_pen, n_pen));
2719    let tol = 1e-12;
2720
2721    let mut multi_idx = vec![0usize; d];
2722    loop {
2723        let (sigma, _structural_sigma, joint_null) = kronecker_cell_sigma(
2724            marginal_eigenvalues,
2725            &multi_idx,
2726            lambdas,
2727            d,
2728            has_double_penalty,
2729            ridge,
2730        );
2731
2732        if sigma > tol {
2733            logdet += sigma.ln();
2734            let inv_sigma = 1.0 / sigma;
2735            let inv_sigma2 = inv_sigma * inv_sigma;
2736
2737            for k in 0..d {
2738                let ck = lambdas[k] * marginal_eigenvalues[k][multi_idx[k]];
2739                grad[k] += ck * inv_sigma;
2740            }
2741            if has_double_penalty && joint_null {
2742                grad[d] += lambdas[d] * inv_sigma;
2743            }
2744
2745            for k in 0..n_pen {
2746                let ck = if k < d {
2747                    lambdas[k] * marginal_eigenvalues[k][multi_idx[k]]
2748                } else if joint_null {
2749                    lambdas[d]
2750                } else {
2751                    0.0
2752                };
2753                // When ck == 0 (a zero λ, a zero marginal eigenvalue, or a cell
2754                // outside the joint null for the ridge penalty) every term this
2755                // index k contributes — `ck·inv_sigma − ck²·inv_sigma2` on the
2756                // diagonal and `−ck·cl·inv_sigma2` on every off-diagonal — is
2757                // exactly 0.0, so adding them to the finite running accumulators
2758                // is a bit-identical no-op. Skip the inner sweep entirely.
2759                if ck == 0.0 {
2760                    continue;
2761                }
2762                hess[[k, k]] += ck * inv_sigma - ck * ck * inv_sigma2;
2763                for l in (k + 1)..n_pen {
2764                    let cl = if l < d {
2765                        lambdas[l] * marginal_eigenvalues[l][multi_idx[l]]
2766                    } else if joint_null {
2767                        lambdas[d]
2768                    } else {
2769                        0.0
2770                    };
2771                    let off = -ck * cl * inv_sigma2;
2772                    hess[[k, l]] += off;
2773                    hess[[l, k]] += off;
2774                }
2775            }
2776        }
2777
2778        if kronecker_multi_index_advance(&mut multi_idx, marginal_dims) {
2779            break;
2780        }
2781    }
2782
2783    (logdet, grad, hess)
2784}
2785
2786// #1521: `KroneckerInvariantStructure` is defined once in `crate::kronecker`
2787// (the leaf data+compute module). The byte-identical copy that the carve left
2788// here is replaced by an import so the cache and this engine share one type.
2789use crate::kronecker::KroneckerInvariantStructure;
2790
2791/// Kronecker-factored reparameterization for tensor-product penalties.
2792///
2793/// Instead of eigendecomposing the full p×p balanced penalty (O(p³)), this
2794/// eigendecomposes each marginal penalty separately (O(Σ q_k³)) and computes
2795/// the joint eigensystem as the Kronecker product of marginal eigensystems.
2796pub fn kronecker_reparameterization_engine(
2797    marginal_designs: &[Array2<f64>],
2798    marginal_penalties: &[Array2<f64>],
2799    marginal_dims: &[usize],
2800    lambdas: &[f64],
2801    has_double_penalty: bool,
2802    penalty_shrinkage_floor: Option<f64>,
2803) -> Result<KroneckerReparamResult, EstimationError> {
2804    let d = marginal_dims.len();
2805    if marginal_designs.len() != d || marginal_penalties.len() != d {
2806        return Err(EstimationError::LayoutError(format!(
2807            "kronecker_reparameterization_engine: dimension mismatch: designs={}, penalties={}, dims={}",
2808            marginal_designs.len(),
2809            marginal_penalties.len(),
2810            d
2811        )));
2812    }
2813
2814    let invariant =
2815        KroneckerInvariantStructure::compute(marginal_designs, marginal_penalties, marginal_dims)?;
2816    kronecker_reparameterization_engine_with_invariant(
2817        &invariant,
2818        marginal_dims,
2819        lambdas,
2820        has_double_penalty,
2821        penalty_shrinkage_floor,
2822    )
2823}
2824
2825/// Kronecker-factored reparameterization reusing a precomputed λ-invariant
2826/// structure (eigensystems, reparameterized marginals, shrinkage scale).
2827///
2828/// Bit-identical to `kronecker_reparameterization_engine` for the same marginal
2829/// data — the only difference is that the `eigh()` / `B_k U_k` work was hoisted
2830/// out of the per-iterate path into the cached `invariant`. Only the λ-dependent
2831/// `kronecker_logdet_and_derivatives` sweep and `floor * max_bal` scaling run here.
2832pub fn kronecker_reparameterization_engine_with_invariant(
2833    invariant: &KroneckerInvariantStructure,
2834    marginal_dims: &[usize],
2835    lambdas: &[f64],
2836    has_double_penalty: bool,
2837    penalty_shrinkage_floor: Option<f64>,
2838) -> Result<KroneckerReparamResult, EstimationError> {
2839    // Arc refcount bumps — the underlying eigensystems / reparameterized
2840    // marginals are λ-invariant and shared with the cache, not deep-copied.
2841    let marginal_eigenvalues = Arc::clone(&invariant.marginal_eigenvalues);
2842    let marginal_qs = Arc::clone(&invariant.marginal_qs);
2843    let reparameterized_marginals = Arc::clone(&invariant.reparameterized_marginals);
2844
2845    // Compute shrinkage ridge from balanced penalty eigenvalue scale.
2846    let penalty_shrinkage_ridge = if let Some(floor) = penalty_shrinkage_floor {
2847        floor * invariant.max_balanced_eigenvalue
2848    } else {
2849        0.0
2850    };
2851
2852    let marginal_eigenvalue_views: Vec<_> = marginal_eigenvalues
2853        .iter()
2854        .map(|evals| evals.view())
2855        .collect();
2856    let (log_det, det1, det2) = kronecker_logdet_and_derivatives(
2857        &marginal_eigenvalue_views,
2858        marginal_dims,
2859        lambdas,
2860        has_double_penalty,
2861        penalty_shrinkage_ridge,
2862    );
2863
2864    Ok(KroneckerReparamResult {
2865        reparameterized_marginals,
2866        marginal_eigenvalues,
2867        marginal_qs,
2868        log_det,
2869        det1,
2870        det2,
2871        penalty_shrinkage_ridge,
2872        has_double_penalty,
2873        marginal_dims: marginal_dims.to_vec(),
2874    })
2875}
2876
2877/// Calculate the 2-norm condition number of a matrix.
2878///
2879/// For symmetric matrices (the dominant case for GAM Hessians/penalties),
2880/// this uses an eigenvalue path and computes:
2881///   cond_2(A) = max_i |lambda_i| / min_i |lambda_i|
2882/// which is exactly equal to the singular-value definition for symmetric A.
2883///
2884/// For non-symmetric matrices, this falls back to SVD:
2885///   cond_2(A) = sigma_max / sigma_min
2886///
2887/// This preserves semantics while avoiding full SVD in hot paths.
2888///
2889/// # Arguments
2890/// * `matrix` - The matrix to analyze
2891///
2892/// # Returns
2893/// * `Ok(condition_number)` - The condition number (max_sv / min_sv)
2894/// * `Ok(f64::INFINITY)` - If the matrix is effectively singular (min_sv < 1e-12)
2895/// * `Err` - If SVD computation fails
2896pub fn calculate_condition_number(matrix: &Array2<f64>) -> Result<f64, FaerLinalgError> {
2897    let (rows, cols) = matrix.dim();
2898    if rows == 0 || cols == 0 {
2899        return Ok(1.0);
2900    }
2901
2902    // Fast path for (near-)symmetric square matrices.
2903    if rows == cols {
2904        let mut max_abs = 0.0_f64;
2905        let mut max_asym = 0.0_f64;
2906        for i in 0..rows {
2907            for j in 0..cols {
2908                max_abs = max_abs.max(matrix[[i, j]].abs());
2909            }
2910            for j in 0..i {
2911                let diff = (matrix[[i, j]] - matrix[[j, i]]).abs();
2912                if diff > max_asym {
2913                    max_asym = diff;
2914                }
2915            }
2916        }
2917        let sym_tol = max_abs.max(1.0) * 1e-12;
2918        if max_asym <= sym_tol {
2919            let (evals, _) = matrix.eigh(Side::Lower)?;
2920            let mut max_abs_eval = 0.0_f64;
2921            let mut min_abs_eval = f64::INFINITY;
2922            for &lam in evals.iter() {
2923                let s = lam.abs();
2924                max_abs_eval = max_abs_eval.max(s);
2925                min_abs_eval = min_abs_eval.min(s);
2926            }
2927            if min_abs_eval < 1e-12 {
2928                return Ok(f64::INFINITY);
2929            }
2930            return Ok(max_abs_eval / min_abs_eval);
2931        }
2932    }
2933
2934    // General matrix fallback.
2935    let (_, s, _) = matrix.svd(false, false)?;
2936    let max_sv = s.iter().fold(0.0_f64, |max, &val| max.max(val));
2937    let min_sv = s.iter().fold(f64::INFINITY, |min, &val| min.min(val));
2938    if min_sv < 1e-12 {
2939        return Ok(f64::INFINITY);
2940    }
2941    Ok(max_sv / min_sv)
2942}
2943
2944#[cfg(test)]
2945mod tests {
2946    use super::{
2947        CanonicalPenalty, REL_PSD_FLOOR, SubspaceLeakageMetrics, assess_subspace_leakage,
2948        classify_eigenvalues_strict, precompute_reparam_invariant_from_canonical,
2949        report_penalty_pair_redundancy, stable_reparameterizationwith_invariant,
2950        subspace_split_is_consistent,
2951    };
2952    use crate::EstimationError;
2953    use crate::construction::kronecker_product;
2954    use faer::Mat;
2955    use gam_linalg::faer_ndarray::FaerEigh;
2956    use gam_linalg::utils::inf_norm;
2957    use ndarray::{Array1, Array2, array};
2958
2959    /// Build CanonicalPenalty values from full-width roots for tests.
2960    fn canonical_from_roots(rs_list: &[Array2<f64>], p: usize) -> Vec<CanonicalPenalty> {
2961        rs_list
2962            .iter()
2963            .map(|r| {
2964                let local = r.t().dot(r);
2965                CanonicalPenalty {
2966                    root: r.clone(),
2967                    col_range: 0..p,
2968                    total_dim: p,
2969                    nullity: 0,
2970                    local,
2971                    prior_mean: Array1::zeros(p),
2972                    positive_eigenvalues: Vec::new(),
2973                    op: None,
2974                }
2975            })
2976            .collect()
2977    }
2978
2979    fn metrics_for(
2980        qs: &Mat<f64>,
2981        rs: &[Mat<f64>],
2982        structural_rank: usize,
2983        p: usize,
2984    ) -> SubspaceLeakageMetrics {
2985        assess_subspace_leakage(qs, rs, structural_rank, p)
2986    }
2987
2988    #[test]
2989    fn subspace_leakage_iszero_for_clean_split() {
2990        let p = 4usize;
2991        let structural_rank = 2usize;
2992        let qs = Mat::<f64>::identity(p, p);
2993        let mut r0 = Mat::<f64>::zeros(2, p);
2994        r0[(0, 0)] = 1.0;
2995        r0[(1, 1)] = 2.0;
2996
2997        let m = metrics_for(&qs, &[r0], structural_rank, p);
2998        assert!(m.max_abs_sq <= 1e-16);
2999        assert!(m.max_rel_sq <= 1e-16);
3000        assert!(m.max_cross_gram_abs <= 1e-16);
3001    }
3002
3003    #[test]
3004    fn subspace_leakage_detects_null_column_energy() {
3005        let p = 4usize;
3006        let structural_rank = 2usize;
3007        let qs = Mat::<f64>::identity(p, p);
3008        let mut r0 = Mat::<f64>::zeros(1, p);
3009        r0[(0, 2)] = 3.0;
3010
3011        let m = metrics_for(&qs, &[r0], structural_rank, p);
3012        assert!(m.max_abs_sq > 0.0);
3013        assert!(m.max_rel_sq > 0.99);
3014    }
3015
3016    #[test]
3017    fn subspace_leakage_detects_qp_qn_nonorthogonality() {
3018        let p = 3usize;
3019        let structural_rank = 1usize;
3020        let mut qs = Mat::<f64>::identity(p, p);
3021        qs[(0, 1)] = 0.2;
3022        let r0 = Mat::<f64>::zeros(1, p);
3023
3024        let m = metrics_for(&qs, &[r0], structural_rank, p);
3025        assert!(m.max_cross_gram_abs > 1e-3);
3026    }
3027
3028    #[test]
3029    fn subspace_split_admits_near_threshold_manifold_leakage_1802() {
3030        // #1802: on sphere / Duchon / spline-on-sphere bases the REML outer
3031        // startup rejected EVERY candidate seed with
3032        //   "Reparameterization subspace split is inconsistent:
3033        //    max null leakage 1.174e-4 (rel 1.031e-4, worst penalty 0),
3034        //    max |Qp'Qn| 4.316e-16"
3035        // The split is perfectly ORTHOGONAL (|Qp'Qn| ≈ 4e-16); the only tripped
3036        // quantity is the transformed-root null-block leakage, whose RELATIVE
3037        // energy (1.031e-4)² ≈ 1.06e-8 sits right at `REL_PSD_FLOOR`. That is the
3038        // eigensolver's own numerically-PSD floor — the same floor
3039        // `classify_eigenvalues_strict` uses to declare a mode null — so a
3040        // manifold penalty whose Laplace-Beltrami spectrum decays through the
3041        // rank threshold with no clean gap MUST be admitted. Reproduce that exact
3042        // signature and assert the guard accepts it.
3043        let p = 40usize;
3044        let structural_rank = p - 1;
3045        // Unit-amplitude range column + a null column at √(REL_PSD_FLOOR)
3046        // amplitude, so the null-block relative energy lands at ~REL_PSD_FLOOR.
3047        let null_amp = (1.06e-8_f64).sqrt();
3048        let mut rs = Mat::<f64>::zeros(1, p);
3049        rs[(0, 0)] = 1.0;
3050        rs[(0, p - 1)] = null_amp;
3051        let qs = Mat::<f64>::identity(p, p);
3052        let leakage = metrics_for(&qs, &[rs], structural_rank, p);
3053        // The leakage reproduces the observed band: above the OLD fixed 1e-10
3054        // tolerance (which rejected the sphere fit) yet a benign ~REL_PSD_FLOOR.
3055        assert!(
3056            leakage.max_rel_sq > 1e-10 && leakage.max_rel_sq < 1e-6,
3057            "reproduced leakage should sit in the near-REL_PSD_FLOOR band, got {:.3e}",
3058            leakage.max_rel_sq
3059        );
3060        assert!(leakage.max_cross_gram_abs <= 1e-12);
3061        assert!(
3062            subspace_split_is_consistent(&leakage, p),
3063            "near-REL_PSD_FLOOR null leakage on a manifold basis must be admitted \
3064             (rel_sq={:.3e}, tol={:.3e})",
3065            leakage.max_rel_sq,
3066            (p as f64) * REL_PSD_FLOOR,
3067        );
3068    }
3069
3070    #[test]
3071    fn subspace_split_still_rejects_genuine_inconsistency_1802() {
3072        // The #1802 relaxation only widens the leakage tolerance to the
3073        // classifier's own `p · REL_PSD_FLOOR` floor; it must NOT admit a
3074        // genuinely broken split. A whole penalized mode dumped into the null
3075        // block (O(1) relative leakage) is still rejected.
3076        let p = 40usize;
3077        let structural_rank = p - 1;
3078        let mut rs = Mat::<f64>::zeros(1, p);
3079        rs[(0, p - 1)] = 1.0; // all penalty-root energy in the null column
3080        let qs = Mat::<f64>::identity(p, p);
3081        let leakage = metrics_for(&qs, &[rs], structural_rank, p);
3082        assert!(leakage.max_rel_sq > 0.99);
3083        assert!(
3084            !subspace_split_is_consistent(&leakage, p),
3085            "an O(1) null-block leakage is a real inconsistency and must be rejected"
3086        );
3087
3088        // A non-orthogonal split is rejected regardless of root leakage.
3089        let mut qs_bad = Mat::<f64>::identity(3, 3);
3090        qs_bad[(0, 1)] = 0.2;
3091        let clean = Mat::<f64>::zeros(1, 3);
3092        let leakage2 = metrics_for(&qs_bad, &[clean], 1, 3);
3093        assert!(leakage2.max_cross_gram_abs > 1e-3);
3094        assert!(
3095            !subspace_split_is_consistent(&leakage2, 3),
3096            "a non-orthogonal Qp/Qn split must be rejected"
3097        );
3098    }
3099
3100    #[test]
3101    fn u_truncated_is_transformed_frame_in_nonzero_case() {
3102        let p = 3usize;
3103        let rs_list = vec![array![[1.0, 0.0, 0.0]]];
3104        let canonical = canonical_from_roots(&rs_list, p);
3105        let lambdas = vec![2.0];
3106        let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3107            .expect("precompute invariant");
3108        let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
3109            .expect("stable reparam");
3110
3111        let expected = rep.qs.t().dot(&inv.split.q_null);
3112        let diff = &rep.u_truncated - &expected;
3113        let max_abs = inf_norm(diff.iter().copied());
3114        assert!(
3115            max_abs <= 1e-10,
3116            "u_truncated frame mismatch: max_abs={max_abs}"
3117        );
3118    }
3119
3120    #[test]
3121    fn infinite_lambda_keeps_range_penalty_block_finite_1379() {
3122        // gam#1379 / gam#1074: a genuinely infinite λ = exp(ρ) is NOT silently
3123        // clamped to a finite ceiling. The original #1379 fix added a 1e300
3124        // ceiling so `∞ · 0` could not poison the range block Σ_k λ_k S_k, but
3125        // #1074 DELETED that clamp on purpose (see the comment at the top of
3126        // `stable_reparameterizationwith_invariant`): masking ∞ hid the real
3127        // defect — the outer optimizer driving a redundant/unidentified penalty
3128        // direction off to ∞ instead of that direction being detected and
3129        // dropped. With the clamp gone, a literal `f64::INFINITY` λ surfaces as
3130        // a clean, detectable error (the eigensolver rejects the NaN-poisoned
3131        // block) rather than a silent finite success. Pin that contract: ∞ must
3132        // ERROR, not be quietly clamped.
3133        //
3134        // Fixture: two penalties on a 3-wide block. The first penalizes only
3135        // coordinate 0 (so its block S_k has structural zeros everywhere except
3136        // [0,0]); give it λ = +∞. The second penalizes coordinate 1 at a normal
3137        // λ.
3138        let p = 3usize;
3139        let rs_list = vec![array![[1.0, 0.0, 0.0]], array![[0.0, 1.0, 0.0]]];
3140        let canonical = canonical_from_roots(&rs_list, p);
3141        let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3142            .expect("precompute invariant");
3143
3144        let lambdas_inf = vec![f64::INFINITY, 3.0];
3145        let inf_result =
3146            stable_reparameterizationwith_invariant(&canonical, &lambdas_inf, p, &inv, None);
3147        assert!(
3148            inf_result.is_err(),
3149            "an infinite lambda must surface as an error, not be silently clamped (#1074)"
3150        );
3151
3152        // A finite (even very large) λ must still produce an all-finite reparam:
3153        // the function is robust to large-but-finite penalties; only the
3154        // non-finite input is rejected.
3155        let lambdas_big = vec![1e300_f64, 3.0];
3156        let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas_big, p, &inv, None)
3157            .expect("stable reparam at large-but-finite lambda");
3158        assert!(
3159            rep.s_transformed.iter().all(|v| v.is_finite()),
3160            "transformed penalty must be finite at large-but-finite lambda"
3161        );
3162        assert!(
3163            rep.qs.iter().all(|v| v.is_finite()),
3164            "reparam rotation must be finite at large-but-finite lambda"
3165        );
3166        assert!(
3167            rep.log_det.is_finite(),
3168            "penalty log-det must be finite at large-but-finite lambda"
3169        );
3170        assert!(
3171            rep.det1.iter().all(|v| v.is_finite()),
3172            "penalty log-det derivatives must be finite at large-but-finite lambda"
3173        );
3174    }
3175
3176    #[test]
3177    fn u_truncated_is_identitywhen_no_penalties() {
3178        let p = 4usize;
3179        let canonical: Vec<CanonicalPenalty> = Vec::new();
3180        let lambdas: Vec<f64> = Vec::new();
3181        let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3182            .expect("precompute invariant");
3183        let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
3184            .expect("stable reparam");
3185        assert_eq!(rep.u_truncated, Array2::<f64>::eye(p));
3186    }
3187
3188    #[test]
3189    fn dense_shrinkage_floor_skips_structurally_unpenalized_range_columns() {
3190        let p = 3usize;
3191        let canonical = canonical_from_roots(&[array![[1.0, 0.0, 0.0]]], p);
3192        let invariant = super::ReparamInvariant {
3193            split: super::SubspaceSplit {
3194                q_pen: array![[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]],
3195                q_null: array![[0.0], [0.0], [1.0]],
3196            },
3197            qs_base: Array2::eye(p),
3198            has_nonzero: true,
3199            max_balanced_eigenvalue: 1.0,
3200        };
3201
3202        let rep =
3203            stable_reparameterizationwith_invariant(&canonical, &[2.0], p, &invariant, Some(1e-6))
3204                .expect("stable reparameterization");
3205        assert!(rep.s_transformed[[0, 0]] > 2.0);
3206        assert!(
3207            rep.s_transformed[[1, 1]] <= 1e-11,
3208            "structurally unpenalized range coordinate received shrinkage ridge: {}",
3209            rep.s_transformed[[1, 1]]
3210        );
3211    }
3212
3213    #[test]
3214    fn kronecker_shrinkage_floor_preserves_joint_null_space() {
3215        let marginal_designs = vec![Array2::<f64>::eye(2), Array2::<f64>::eye(2)];
3216        let marginal_penalties = vec![
3217            array![[0.0, 0.0], [0.0, 2.0]],
3218            array![[0.0, 0.0], [0.0, 3.0]],
3219        ];
3220        let marginal_dims = vec![2usize, 2usize];
3221        let lambdas = vec![5.0, 7.0];
3222
3223        let rep = super::kronecker_reparameterization_engine(
3224            &marginal_designs,
3225            &marginal_penalties,
3226            &marginal_dims,
3227            &lambdas,
3228            false,
3229            Some(1e-6),
3230        )
3231        .expect("kronecker reparameterization");
3232        assert!(rep.penalty_shrinkage_ridge > 0.0);
3233
3234        let s = rep.materialize_s_transformed(&lambdas);
3235        assert!(
3236            s[[0, 0]].abs() <= 1e-14,
3237            "joint tensor null direction must remain unpenalized, got {}",
3238            s[[0, 0]]
3239        );
3240        assert!(s[[1, 1]] > lambdas[1] * 3.0);
3241        assert!(s[[2, 2]] > lambdas[0] * 2.0);
3242        assert!(s[[3, 3]] > lambdas[0] * 2.0 + lambdas[1] * 3.0);
3243
3244        let tensor_roots = vec![
3245            array![
3246                [0.0, 0.0, 2.0_f64.sqrt(), 0.0],
3247                [0.0, 0.0, 0.0, 2.0_f64.sqrt()]
3248            ],
3249            array![
3250                [0.0, 3.0_f64.sqrt(), 0.0, 0.0],
3251                [0.0, 0.0, 0.0, 3.0_f64.sqrt()]
3252            ],
3253        ];
3254        let dense = rep
3255            .materialize_dense_artifact_result(&tensor_roots, &lambdas, 4)
3256            .expect("dense artifact materialization");
3257        assert_eq!(dense.e_transformed.nrows(), 3);
3258        assert_eq!(dense.u_truncated.ncols(), 1);
3259    }
3260
3261    #[test]
3262    fn kronecker_memoized_invariant_is_bit_identical_to_unmemoized_engine() {
3263        // The hot-path memoization (compute the marginal eigensystems /
3264        // reparameterized marginals once, reuse across outer iterates) must
3265        // produce a KroneckerReparamResult that is *bit-identical* to the
3266        // unmemoized engine for the same marginal data and λ — the cached work
3267        // is literally the same eigendecomposition. Cover several λ on one fixed
3268        // invariant structure (the realistic outer-loop pattern).
3269        let marginal_designs = vec![
3270            array![[1.0, 0.3, -0.2], [0.4, 1.0, 0.1], [-0.1, 0.2, 1.0]],
3271            array![[1.0, -0.5], [0.2, 1.0], [0.7, 0.3]],
3272        ];
3273        let marginal_penalties = vec![
3274            array![[2.0, -1.0, 0.0], [-1.0, 2.0, -1.0], [0.0, -1.0, 1.0]],
3275            array![[3.0, -1.5], [-1.5, 3.0]],
3276        ];
3277        let marginal_dims = vec![3usize, 2usize];
3278
3279        let invariant = super::KroneckerInvariantStructure::compute(
3280            &marginal_designs,
3281            &marginal_penalties,
3282            &marginal_dims,
3283        )
3284        .expect("invariant structure");
3285
3286        for lambdas in [
3287            vec![5.0, 7.0],
3288            vec![0.0, 7.0],
3289            vec![5.0, 0.0],
3290            vec![1e-3, 1e3],
3291        ] {
3292            for floor in [None, Some(1e-6)] {
3293                let unmemoized = super::kronecker_reparameterization_engine(
3294                    &marginal_designs,
3295                    &marginal_penalties,
3296                    &marginal_dims,
3297                    &lambdas,
3298                    true,
3299                    floor,
3300                )
3301                .expect("unmemoized engine");
3302                let memoized = super::kronecker_reparameterization_engine_with_invariant(
3303                    &invariant,
3304                    &marginal_dims,
3305                    &lambdas,
3306                    true,
3307                    floor,
3308                )
3309                .expect("memoized engine");
3310
3311                assert_eq!(memoized.log_det.to_bits(), unmemoized.log_det.to_bits());
3312                assert_eq!(
3313                    memoized.penalty_shrinkage_ridge.to_bits(),
3314                    unmemoized.penalty_shrinkage_ridge.to_bits()
3315                );
3316                for (a, b) in memoized.det1.iter().zip(unmemoized.det1.iter()) {
3317                    assert_eq!(a.to_bits(), b.to_bits());
3318                }
3319                for (a, b) in memoized.det2.iter().zip(unmemoized.det2.iter()) {
3320                    assert_eq!(a.to_bits(), b.to_bits());
3321                }
3322                for (ma, ua) in memoized
3323                    .reparameterized_marginals
3324                    .iter()
3325                    .zip(unmemoized.reparameterized_marginals.iter())
3326                {
3327                    for (a, b) in ma.iter().zip(ua.iter()) {
3328                        assert_eq!(a.to_bits(), b.to_bits());
3329                    }
3330                }
3331                for (mq, uq) in memoized
3332                    .marginal_qs
3333                    .iter()
3334                    .zip(unmemoized.marginal_qs.iter())
3335                {
3336                    for (a, b) in mq.iter().zip(uq.iter()) {
3337                        assert_eq!(a.to_bits(), b.to_bits());
3338                    }
3339                }
3340            }
3341        }
3342    }
3343
3344    #[test]
3345    fn kronecker_double_penalty_shrinks_only_joint_null_space() {
3346        let marginal_designs = vec![Array2::<f64>::eye(2), Array2::<f64>::eye(2)];
3347        let marginal_penalties = vec![
3348            array![[0.0, 0.0], [0.0, 2.0]],
3349            array![[0.0, 0.0], [0.0, 3.0]],
3350        ];
3351        let marginal_dims = vec![2usize, 2usize];
3352        let lambdas = vec![5.0, 7.0, 11.0];
3353
3354        let rep = super::kronecker_reparameterization_engine(
3355            &marginal_designs,
3356            &marginal_penalties,
3357            &marginal_dims,
3358            &lambdas,
3359            true,
3360            None,
3361        )
3362        .expect("kronecker reparameterization");
3363
3364        let s = rep.materialize_s_transformed(&lambdas);
3365        let expected = [11.0, 21.0, 10.0, 31.0];
3366        for (idx, expected_diag) in expected.iter().copied().enumerate() {
3367            assert!(
3368                (s[[idx, idx]] - expected_diag).abs() <= 1e-12,
3369                "diagonal {idx} got {}, expected {expected_diag}",
3370                s[[idx, idx]]
3371            );
3372        }
3373
3374        let expected_logdet: f64 = expected.iter().map(|v| f64::ln(*v)).sum();
3375        assert!((rep.log_det - expected_logdet).abs() <= 1e-12);
3376        assert!(
3377            (rep.det1[2] - 1.0).abs() <= 1e-12,
3378            "double-penalty derivative must come only from the joint null mode, got {}",
3379            rep.det1[2]
3380        );
3381        assert!(rep.det2[[2, 2]].abs() <= 1e-12);
3382
3383        let tensor_roots = vec![
3384            array![
3385                [0.0, 0.0, 2.0_f64.sqrt(), 0.0],
3386                [0.0, 0.0, 0.0, 2.0_f64.sqrt()]
3387            ],
3388            array![
3389                [0.0, 3.0_f64.sqrt(), 0.0, 0.0],
3390                [0.0, 0.0, 0.0, 3.0_f64.sqrt()]
3391            ],
3392        ];
3393        let dense = rep
3394            .materialize_dense_artifact_result(&tensor_roots, &lambdas, 4)
3395            .expect("dense artifact materialization");
3396        for (idx, expected_diag) in expected.iter().copied().enumerate() {
3397            assert!(
3398                (dense.s_transformed[[idx, idx]] - expected_diag).abs() <= 1e-12,
3399                "dense artifact diagonal {idx} got {}, expected {expected_diag}",
3400                dense.s_transformed[[idx, idx]]
3401            );
3402        }
3403    }
3404
3405    #[test]
3406    fn transformed_penalty_is_diagonal_in_transformed_frame() {
3407        let p = 3usize;
3408        let inv_sqrt2 = 2.0_f64.sqrt().recip();
3409        // Penalize a rotated direction in original space so Qs is non-trivial.
3410        let rs_list = vec![array![[inv_sqrt2, inv_sqrt2, 0.0]]];
3411        let canonical = canonical_from_roots(&rs_list, p);
3412        let lambdas = vec![4.0];
3413        let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3414            .expect("precompute invariant");
3415        let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
3416            .expect("stable reparam");
3417
3418        assert_eq!(rep.e_transformed.nrows(), 1);
3419        assert!(rep.e_transformed[[0, 0]].abs() > 0.0);
3420        assert!(rep.e_transformed[[0, 1]].abs() <= 1e-12);
3421        assert!(rep.e_transformed[[0, 2]].abs() <= 1e-12);
3422        // Exact pseudo-logdet on the structural penalized block has no
3423        // delta-dependent nullspace normalization.
3424        let expected_det1 = 1.0_f64;
3425        assert!((rep.det1[0] - expected_det1).abs() <= 1e-12);
3426
3427        let s = rep.s_transformed;
3428        let mut max_offdiag = 0.0_f64;
3429        for i in 0..p {
3430            for j in 0..p {
3431                if i != j {
3432                    max_offdiag = max_offdiag.max(s[[i, j]].abs());
3433                }
3434            }
3435        }
3436        assert!(
3437            max_offdiag <= 1e-10,
3438            "transformed penalty should be diagonal, max offdiag={max_offdiag}"
3439        );
3440        assert!(s[[1, 1]].abs() <= 1e-10);
3441        assert!(s[[2, 2]].abs() <= 1e-10);
3442    }
3443
3444    #[test]
3445    fn det1_matches_rank_for_single_full_rank_penalty() {
3446        let p = 2usize;
3447        let inv_sqrt2 = 2.0_f64.sqrt().recip();
3448        // Q^T for a 45-degree rotation.
3449        let q_t = [[inv_sqrt2, inv_sqrt2], [-inv_sqrt2, inv_sqrt2]];
3450        // R = diag(3, 1) * Q^T gives S = Q * diag(9, 1) * Q^T.
3451        let rs = array![
3452            [3.0 * q_t[0][0], 3.0 * q_t[0][1]],
3453            [1.0 * q_t[1][0], 1.0 * q_t[1][1]]
3454        ];
3455        let rs_list = vec![rs];
3456        let canonical = canonical_from_roots(&rs_list, p);
3457        let lambdas = vec![5.0];
3458
3459        let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3460            .expect("precompute invariant");
3461        let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
3462            .expect("stable reparam");
3463
3464        assert_eq!(rep.e_transformed.nrows(), p);
3465        let det1 = rep.det1[0];
3466        // Exact pseudo-logdet on the structural penalized block:
3467        //   det1 = lambda * sum_l d_l / (lambda*d_l)
3468        // where d_l are eigenvalues of S_k.
3469        let s_k_eigs = [9.0_f64, 1.0_f64];
3470        let lambda = 5.0_f64;
3471        let expected_det1: f64 = s_k_eigs.iter().map(|&d| lambda * d / (lambda * d)).sum();
3472        assert!(
3473            (det1 - expected_det1).abs() <= 1e-12,
3474            "expected det1={expected_det1}, got {det1}",
3475        );
3476
3477        let s = rep.s_transformed;
3478        assert!(s[[0, 1]].abs() <= 1e-10);
3479        assert!(s[[1, 0]].abs() <= 1e-10);
3480        assert!(s[[0, 0]] > 0.0);
3481        assert!(s[[1, 1]] > 0.0);
3482    }
3483
3484    #[test]
3485    fn kronecker_reparam_logdet_matches_dense() {
3486        // 2D tensor product: q1=3, q2=4.
3487        // Marginal penalties: second-order difference matrices.
3488        let q1 = 3;
3489        let q2 = 4;
3490        let s1 = {
3491            let mut s = Array2::<f64>::zeros((q1, q1));
3492            // D2' D2 for order 2 on 3 points: [[1,-2,1],[-2,4,-2],[1,-2,1]]... simplified
3493            s[[0, 0]] = 1.0;
3494            s[[0, 1]] = -1.0;
3495            s[[1, 0]] = -1.0;
3496            s[[1, 1]] = 2.0;
3497            s[[1, 2]] = -1.0;
3498            s[[2, 1]] = -1.0;
3499            s[[2, 2]] = 1.0;
3500            s
3501        };
3502        let s2 = {
3503            let mut s = Array2::<f64>::zeros((q2, q2));
3504            s[[0, 0]] = 1.0;
3505            s[[0, 1]] = -1.0;
3506            s[[1, 0]] = -1.0;
3507            s[[1, 1]] = 2.0;
3508            s[[1, 2]] = -1.0;
3509            s[[2, 1]] = -1.0;
3510            s[[2, 2]] = 2.0;
3511            s[[2, 3]] = -1.0;
3512            s[[3, 2]] = -1.0;
3513            s[[3, 3]] = 1.0;
3514            s
3515        };
3516
3517        let lambdas = [2.5, 1.3];
3518        // Build dense Kronecker penalty: λ1 (S1⊗I) + λ2 (I⊗S2).
3519        let p = q1 * q2;
3520        let i1 = Array2::<f64>::eye(q1);
3521        let i2 = Array2::<f64>::eye(q2);
3522        let pen0 = kronecker_product(&s1, &i2);
3523        let pen1 = kronecker_product(&i1, &s2);
3524        let mut s_dense = Array2::<f64>::zeros((p, p));
3525        s_dense.scaled_add(lambdas[0], &pen0);
3526        s_dense.scaled_add(lambdas[1], &pen1);
3527
3528        // Dense eigendecomposition for reference pseudo-logdet.
3529        let (evals_dense, _): (ndarray::Array1<f64>, ndarray::Array2<f64>) =
3530            s_dense.eigh(faer::Side::Lower).unwrap();
3531        let tol = 1e-12;
3532        let ref_logdet: f64 = evals_dense
3533            .iter()
3534            .filter(|&&v: &&f64| v > tol)
3535            .map(|&v: &f64| v.ln())
3536            .sum();
3537
3538        // Kronecker reparameterization engine.
3539        let marginal_designs = vec![
3540            Array2::<f64>::eye(q1), // dummy designs
3541            Array2::<f64>::eye(q2),
3542        ];
3543        let marginal_penalties = vec![s1, s2];
3544        let kron_result = super::kronecker_reparameterization_engine(
3545            &marginal_designs,
3546            &marginal_penalties,
3547            &[q1, q2],
3548            &lambdas,
3549            false,
3550            None,
3551        )
3552        .unwrap();
3553
3554        let diff = (kron_result.log_det - ref_logdet).abs();
3555        assert!(
3556            diff < 1e-8,
3557            "Kronecker logdet {:.10} vs dense {:.10}, diff={:.3e}",
3558            kron_result.log_det,
3559            ref_logdet,
3560            diff,
3561        );
3562
3563        // Check derivatives via central FD in rho-space (rho = log lambda).
3564        let rhos: Vec<f64> = lambdas.iter().map(|&l| l.ln()).collect();
3565        let eps = 1e-5;
3566        for k in 0..2 {
3567            let mut rho_plus = rhos.clone();
3568            rho_plus[k] += eps;
3569            let mut rho_minus = rhos.clone();
3570            rho_minus[k] -= eps;
3571            let lam_plus: Vec<f64> = rho_plus.iter().map(|&r| r.exp()).collect();
3572            let lam_minus: Vec<f64> = rho_minus.iter().map(|&r| r.exp()).collect();
3573            let result_plus = super::kronecker_reparameterization_engine(
3574                &marginal_designs,
3575                &marginal_penalties,
3576                &[q1, q2],
3577                &lam_plus,
3578                false,
3579                None,
3580            )
3581            .unwrap();
3582            let result_minus = super::kronecker_reparameterization_engine(
3583                &marginal_designs,
3584                &marginal_penalties,
3585                &[q1, q2],
3586                &lam_minus,
3587                false,
3588                None,
3589            )
3590            .unwrap();
3591            let fd_deriv = (result_plus.log_det - result_minus.log_det) / (2.0 * eps);
3592            let analytic_deriv = kron_result.det1[k];
3593            let rel_err = if analytic_deriv.abs() > 1e-10 {
3594                (fd_deriv - analytic_deriv).abs() / analytic_deriv.abs()
3595            } else {
3596                (fd_deriv - analytic_deriv).abs()
3597            };
3598            assert!(
3599                rel_err < 1e-4,
3600                "det1[{k}] mismatch: analytic={:.8}, fd={:.8}, rel_err={:.3e}",
3601                analytic_deriv,
3602                fd_deriv,
3603                rel_err,
3604            );
3605        }
3606    }
3607
3608    #[test]
3609    fn classify_strict_rejects_nan_eigenvalue() {
3610        let mut eigs = [1.0, f64::NAN, 0.5];
3611        match classify_eigenvalues_strict(&mut eigs, "test_nan") {
3612            Err(EstimationError::PenaltySpectrumNonFinite {
3613                context,
3614                index,
3615                value,
3616            }) => {
3617                assert_eq!(context, "test_nan");
3618                assert_eq!(index, 1);
3619                assert!(value.is_nan());
3620            }
3621            other => panic!("expected PenaltySpectrumNonFinite, got {:?}", other),
3622        }
3623    }
3624
3625    #[test]
3626    fn classify_strict_rejects_inf_eigenvalue() {
3627        let mut eigs = [1.0, 0.5, f64::INFINITY];
3628        match classify_eigenvalues_strict(&mut eigs, "test_inf") {
3629            Err(EstimationError::PenaltySpectrumNonFinite { index, value, .. }) => {
3630                assert_eq!(index, 2);
3631                assert!(value.is_infinite());
3632            }
3633            other => panic!("expected PenaltySpectrumNonFinite, got {:?}", other),
3634        }
3635    }
3636
3637    #[test]
3638    fn classify_strict_rejects_materially_indefinite() {
3639        // -1e-2 with scale ~1.0 is well above any reasonable roundoff tolerance.
3640        let mut eigs = [1.0, -1e-2, 0.5];
3641        match classify_eigenvalues_strict(&mut eigs, "test_indef") {
3642            Err(EstimationError::PenaltySpectrumIndefinite {
3643                context,
3644                index,
3645                value,
3646                ..
3647            }) => {
3648                assert_eq!(context, "test_indef");
3649                assert_eq!(index, 1);
3650                assert!((value + 1e-2).abs() <= 1e-15);
3651            }
3652            other => panic!("expected PenaltySpectrumIndefinite, got {:?}", other),
3653        }
3654    }
3655
3656    #[test]
3657    fn classify_strict_accepts_roundoff_negative() {
3658        // -1e-16 * scale is well within tol = 64 * eps * p * scale.
3659        let scale = 1.0_f64;
3660        let roundoff = -1e-16 * scale;
3661        let mut eigs = [scale, 0.5 * scale, roundoff, 0.25 * scale];
3662        classify_eigenvalues_strict(&mut eigs, "test_roundoff").expect("roundoff must classify");
3663        // The roundoff eigenvalue is snapped to exact zero.
3664        assert_eq!(eigs[2], 0.0);
3665        // Strictly positive entries must be preserved.
3666        assert!(eigs[0] > 0.0 && eigs[1] > 0.0 && eigs[3] > 0.0);
3667    }
3668
3669    #[test]
3670    fn classify_strict_accepts_extreme_lambda_assembly_noise_1619() {
3671        // #1619: high-rank thin-plate / Duchon penalties (p≈200) assembled and
3672        // reparameterized at extreme λ produce float-noise negative eigenvalues at
3673        // ~1e-11 relative to the spectrum scale (~1e13 there). The bare machine-ε
3674        // floor (~12×ε relative ≈ 1e-12) rejected these as "indefinite" and
3675        // spuriously failed the inner P-IRLS solve. They are PSD to numerical
3676        // precision and must be snapped to zero, not rejected.
3677        let scale = 8.509e12_f64;
3678        // Worst (eig, scale) pair observed in the issue: eig/scale ≈ -7.7e-11.
3679        let noise = -6.546e2_f64;
3680        assert!(
3681            (noise.abs() / scale) < 1.0e-10,
3682            "fixture must reproduce the ~1e-11-relative noise from #1619"
3683        );
3684        let mut eigs = vec![scale, 0.5 * scale, noise, 0.1 * scale];
3685        classify_eigenvalues_strict(&mut eigs, "range penalty block")
3686            .expect("a ~1e-11-relative roundoff-negative eigenvalue must be accepted (#1619)");
3687        // The roundoff-negative eigenvalue is snapped to exact zero.
3688        assert_eq!(eigs[2], 0.0);
3689        // Strictly positive entries are preserved.
3690        assert!(eigs[0] > 0.0 && eigs[1] > 0.0 && eigs[3] > 0.0);
3691    }
3692
3693    #[test]
3694    fn classify_strict_snaps_subtol_positive_to_zero() {
3695        // Positive eigenvalues below the tolerance are also snapped to exact 0
3696        // so downstream rank counts and pseudo-logdets are deterministic.
3697        let scale = 10.0_f64;
3698        let subtol = 1e-15 * scale;
3699        let mut eigs = [scale, subtol];
3700        classify_eigenvalues_strict(&mut eigs, "test_sub_pos").expect("sub-tol positive ok");
3701        assert_eq!(eigs[1], 0.0);
3702    }
3703
3704    /// Build a `CanonicalPenalty` directly from a symmetric `local` matrix.
3705    /// Bypasses root extraction — the redundancy diagnostic only reads `local`
3706    /// and `col_range`, so the rest is filler.
3707    fn canonical_from_local(
3708        local: Array2<f64>,
3709        col_range: std::ops::Range<usize>,
3710        total_dim: usize,
3711    ) -> CanonicalPenalty {
3712        let block_dim = local.nrows();
3713        // A trivially valid root: zero rank. The diagnostic doesn't read root.
3714        let root = Array2::<f64>::zeros((0, block_dim));
3715        CanonicalPenalty {
3716            root,
3717            col_range,
3718            total_dim,
3719            nullity: 0,
3720            local,
3721            prior_mean: Array1::zeros(block_dim),
3722            positive_eigenvalues: Vec::new(),
3723            op: None,
3724        }
3725    }
3726
3727    #[test]
3728    fn report_penalty_pair_redundancy_detects_identical_pair() {
3729        // Penalty 0: a "generic" SPD matrix on cols 0..3.
3730        let s0 = ndarray::array![[2.0, 0.5, 0.0], [0.5, 1.0, 0.25], [0.0, 0.25, 1.5],];
3731        // Penalties 1 and 2: identical block-local penalty on the SAME col_range.
3732        // This is the Z₂-symmetric saddle scenario.
3733        let s_shared = ndarray::array![[1.0, -0.5, 0.0], [-0.5, 2.0, -0.5], [0.0, -0.5, 1.0],];
3734
3735        let bundle = vec![
3736            canonical_from_local(s0, 0..3, 3),
3737            canonical_from_local(s_shared.clone(), 0..3, 3),
3738            canonical_from_local(s_shared, 0..3, 3),
3739        ];
3740
3741        let redundant = report_penalty_pair_redundancy(&bundle);
3742
3743        // Exactly one redundant pair: (1, 2). Pairs (0, 1) and (0, 2) involve
3744        // distinct matrices and must NOT be flagged.
3745        assert_eq!(
3746            redundant.len(),
3747            1,
3748            "expected exactly one redundant pair, got {:?}",
3749            redundant
3750        );
3751        let (i, j, cos) = redundant[0];
3752        assert_eq!((i, j), (1, 2));
3753        assert!(
3754            cos > 1.0 - 1e-12,
3755            "cosine for identical penalties should be ~1.0, got {cos}"
3756        );
3757    }
3758
3759    #[test]
3760    fn report_penalty_pair_redundancy_skips_different_col_ranges() {
3761        // Two identical local matrices but on disjoint col_ranges. The
3762        // function must NOT flag them — they live in different parameter
3763        // subspaces by construction.
3764        let s = ndarray::array![[1.0, 0.0], [0.0, 1.0]];
3765        let bundle = vec![
3766            canonical_from_local(s.clone(), 0..2, 4),
3767            canonical_from_local(s, 2..4, 4),
3768        ];
3769        let redundant = report_penalty_pair_redundancy(&bundle);
3770        assert!(
3771            redundant.is_empty(),
3772            "different col_ranges must not be flagged"
3773        );
3774    }
3775}