Skip to main content

gam_terms/
construction.rs

1use crate::basis::analyze_penalty_block;
2use crate::EstimationError;
3use crate::smooth::PenaltyStructureHint;
4use faer::linalg::matmul::matmul;
5use faer::{Accum, Mat, MatRef, Par, Side};
6use gam_linalg::faer_ndarray::{FaerEigh, FaerLinalgError, FaerSvd};
7use gam_linalg::matrix::symmetrize_in_place;
8use gam_linalg::utils::KahanSum;
9use ndarray::{Array1, Array2, ArrayView1, ArrayViewMut2, Axis, s};
10use rayon::iter::{
11    IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
12};
13use std::collections::{BTreeMap, HashSet};
14use std::ops::Range;
15use std::sync::Arc;
16
/// Storage format for one square penalty block.
///
/// Every variant can be expanded to a dense matrix with
/// `PenaltyRepresentation::to_block_dense`; the side length of that matrix is
/// reported by `PenaltyRepresentation::block_dimension`.
#[derive(Clone)]
pub enum PenaltyRepresentation {
    /// The penalty block stored explicitly as a dense matrix.
    Dense(Array2<f64>),
    /// The penalty block stored diagonal by diagonal.
    ///
    /// `bands[k]` holds the entries of the diagonal selected by `offsets[k]`:
    /// a non-negative offset places entry `idx` at `(idx, idx + offset)`, a
    /// negative offset places it at `(idx + |offset|, idx)`. When expanded,
    /// each entry is mirrored across the main diagonal, and a negative offset
    /// is skipped if the matching positive offset is also listed.
    Banded {
        bands: Vec<Array1<f64>>,
        offsets: Vec<i32>,
    },
    Kronecker {
        /// Full penalty-block Kronecker product `left ⊗ right`: each entry of
        /// `left` scales an entire copy of `right` in the dense expansion.
        ///
        /// This is distinct from chunked kernel design assembly, where center
        /// rows are kernel-evaluation arguments rather than matrix factors.
        left: Array2<f64>,
        right: Array2<f64>,
    },
}
34
35impl PenaltyRepresentation {
36    /// Side length of the square penalty block this representation expands to.
37    pub fn block_dimension(&self) -> usize {
38        match self {
39            PenaltyRepresentation::Dense(matrix) => matrix.nrows(),
40            PenaltyRepresentation::Banded { bands, offsets } => {
41                let mut dim = 0usize;
42                for (band, &offset) in bands.iter().zip(offsets.iter()) {
43                    let len = band.len();
44                    let extent = if offset >= 0 {
45                        len + offset as usize
46                    } else {
47                        len + (-offset) as usize
48                    };
49                    dim = dim.max(extent);
50                }
51                dim
52            }
53            PenaltyRepresentation::Kronecker { left, right } => left.nrows() * right.nrows(),
54        }
55    }
56
57    /// Materialize this representation (Dense / Banded / Kronecker) into a
58    /// single dense symmetric penalty block.
59    pub fn to_block_dense(&self) -> Array2<f64> {
60        match self {
61            PenaltyRepresentation::Dense(matrix) => matrix.clone(),
62            PenaltyRepresentation::Banded { bands, offsets } => {
63                let dim = self.block_dimension();
64                let mut dense = Array2::zeros((dim, dim));
65                let positive_offsets: HashSet<usize> = offsets
66                    .iter()
67                    .filter_map(|&off| (off >= 0).then_some(off as usize))
68                    .collect();
69                for (band, &offset) in bands.iter().zip(offsets.iter()) {
70                    let off = offset.unsigned_abs() as usize;
71                    if offset < 0 && positive_offsets.contains(&off) {
72                        continue;
73                    }
74                    for (idx, &value) in band.iter().enumerate() {
75                        let (i, j) = if offset >= 0 {
76                            (idx, idx + off)
77                        } else {
78                            (idx + off, idx)
79                        };
80                        if i >= dim || j >= dim {
81                            continue;
82                        }
83                        dense[[i, j]] = value;
84                        dense[[j, i]] = value;
85                    }
86                }
87                dense
88            }
89            PenaltyRepresentation::Kronecker { left, right } => {
90                let (lrows, l_cols) = left.dim();
91                let (rrows, r_cols) = right.dim();
92                let mut result = Array2::zeros((lrows * rrows, l_cols * r_cols));
93                for i in 0..lrows {
94                    for j in 0..l_cols {
95                        let scale = left[(i, j)];
96                        if scale == 0.0 {
97                            continue;
98                        }
99                        let mut block = result.slice_mut(s![
100                            i * rrows..(i + 1) * rrows,
101                            j * r_cols..(j + 1) * r_cols
102                        ]);
103                        block.assign(&(right * scale));
104                    }
105                }
106                result
107            }
108        }
109    }
110}
111
/// A penalty block together with its position in the full coefficient vector.
#[derive(Clone)]
pub struct PenaltyMatrix {
    /// Rows/columns of the full penalty matrix covered by this block; the
    /// same range is used for both axes when the block is expanded.
    pub col_range: Range<usize>,
    /// Storage format of the block itself.
    pub representation: PenaltyRepresentation,
}
117
118impl PenaltyMatrix {
119    fn accumulate_into(&self, mut dest: ArrayViewMut2<'_, f64>, weight: f64) {
120        if weight == 0.0 {
121            return;
122        }
123        match &self.representation {
124            PenaltyRepresentation::Dense(block) => {
125                dest.scaled_add(weight, block);
126            }
127            PenaltyRepresentation::Banded { bands, offsets } => {
128                let positive_offsets: HashSet<usize> = offsets
129                    .iter()
130                    .filter_map(|&off| (off >= 0).then_some(off as usize))
131                    .collect();
132                for (band, &offset) in bands.iter().zip(offsets.iter()) {
133                    let off = offset.unsigned_abs() as usize;
134                    if offset < 0 && positive_offsets.contains(&off) {
135                        continue;
136                    }
137                    for (idx, &value) in band.iter().enumerate() {
138                        let (i, j) = if offset >= 0 {
139                            (idx, idx + off)
140                        } else {
141                            (idx + off, idx)
142                        };
143                        let Some(entry_ij) = dest.get_mut((i, j)) else {
144                            continue;
145                        };
146                        *entry_ij += weight * value;
147                        if i != j
148                            && let Some(entry_ji) = dest.get_mut((j, i))
149                        {
150                            *entry_ji += weight * value;
151                        }
152                    }
153                }
154            }
155            PenaltyRepresentation::Kronecker { left, right } => {
156                let (lrows, l_cols) = left.dim();
157                let (rrows, r_cols) = right.dim();
158                for i in 0..lrows {
159                    for j in 0..l_cols {
160                        let scale = left[(i, j)] * weight;
161                        if scale == 0.0 {
162                            continue;
163                        }
164                        let mut block = dest.slice_mut(s![
165                            i * rrows..(i + 1) * rrows,
166                            j * r_cols..(j + 1) * r_cols
167                        ]);
168                        block.scaled_add(scale, right);
169                    }
170                }
171            }
172        }
173    }
174
175    pub fn to_dense(&self, total_dim: usize) -> Array2<f64> {
176        let mut dense = Array2::<f64>::zeros((total_dim, total_dim));
177        self.accumulate_into(
178            dense.slice_mut(s![self.col_range.clone(), self.col_range.clone()]),
179            1.0,
180        );
181        dense
182    }
183}
184
185pub(crate) fn array_to_faer(array: &Array2<f64>) -> Mat<f64> {
186    let (rows, cols) = array.dim();
187    Mat::from_fn(rows, cols, |i, j| array[[i, j]])
188}
189
190pub(crate) fn mat_to_array(mat: &Mat<f64>) -> Array2<f64> {
191    let mut out = Array2::<f64>::zeros((mat.nrows(), mat.ncols()));
192    for i in 0..mat.nrows() {
193        for j in 0..mat.ncols() {
194            out[[i, j]] = mat[(i, j)];
195        }
196    }
197    out
198}
199
200fn mat_max_abs_element(matrix: MatRef<'_, f64>) -> f64 {
201    let (rows, cols) = matrix.shape();
202    let mut maxval = 0.0_f64;
203    for i in 0..rows {
204        for j in 0..cols {
205            let val = matrix[(i, j)];
206            if val.is_finite() {
207                maxval = maxval.max(val.abs());
208            }
209        }
210    }
211    maxval
212}
213
214fn sanitize_symmetric_faer(matrix: &Mat<f64>) -> Mat<f64> {
215    let (rows, cols) = matrix.as_ref().shape();
216    assert_eq!(rows, cols, "Matrix must be square for sanitization");
217
218    let mut sanitized = matrix.clone();
219
220    for i in 0..rows {
221        let diag = sanitized[(i, i)];
222        if !diag.is_finite() {
223            sanitized[(i, i)] = 0.0;
224        }
225        for j in (i + 1)..cols {
226            let mut upper = sanitized[(i, j)];
227            let mut lower = sanitized[(j, i)];
228            if !upper.is_finite() {
229                upper = 0.0;
230            }
231            if !lower.is_finite() {
232                lower = 0.0;
233            }
234            let avg = 0.5 * (upper + lower);
235            sanitized[(i, j)] = avg;
236            sanitized[(j, i)] = avg;
237        }
238    }
239
240    let scale = mat_max_abs_element(sanitized.as_ref());
241    let tiny = (scale * 1e-14).max(1e-30);
242    for i in 0..rows {
243        for j in 0..cols {
244            let val = sanitized[(i, j)];
245            if !val.is_finite() {
246                sanitized[(i, j)] = 0.0;
247            } else if val.abs() < tiny {
248                sanitized[(i, j)] = 0.0;
249            }
250        }
251    }
252
253    sanitized
254}
255
256fn penalty_from_root_faer(root: &Mat<f64>) -> Mat<f64> {
257    let cols = root.ncols();
258    let mut full = Mat::<f64>::zeros(cols, cols);
259    let root_ref = root.as_ref();
260    let root_t = root_ref.transpose();
261    matmul(
262        full.as_mut(),
263        Accum::Replace,
264        root_t,
265        root_ref,
266        1.0,
267        Par::Seq,
268    );
269    sanitize_symmetric_faer(&full)
270}
271
272fn symmetrize_faer_matrix_in_place(matrix: &mut Mat<f64>) {
273    let n = matrix.nrows().min(matrix.ncols());
274    for i in 0..n {
275        for j in 0..i {
276            let avg = 0.5 * (matrix[(i, j)] + matrix[(j, i)]);
277            matrix[(i, j)] = avg;
278            matrix[(j, i)] = avg;
279        }
280    }
281}
282
283fn orthogonal_similarity_transform_faer(
284    matrix: &Mat<f64>,
285    block_dim: usize,
286    orthogonal: &Mat<f64>,
287) -> Mat<f64> {
288    let matrix_block = matrix.as_ref().submatrix(0, 0, block_dim, block_dim);
289    let cols = orthogonal.ncols();
290    let mut temp = Mat::<f64>::zeros(block_dim, cols);
291    matmul(
292        temp.as_mut(),
293        Accum::Replace,
294        matrix_block,
295        orthogonal.as_ref(),
296        1.0,
297        Par::Seq,
298    );
299    let mut rotated = Mat::<f64>::zeros(cols, cols);
300    matmul(
301        rotated.as_mut(),
302        Accum::Replace,
303        orthogonal.transpose(),
304        temp.as_ref(),
305        1.0,
306        Par::Seq,
307    );
308    symmetrize_faer_matrix_in_place(&mut rotated);
309    rotated
310}
311
312fn trace_penalty_in_orthogonal_basis(
313    matrix: &Mat<f64>,
314    block_dim: usize,
315    orthogonal: &Mat<f64>,
316    rotated_eigenvalues: &[f64],
317    delta: f64,
318) -> f64 {
319    let matrix_block = matrix.as_ref().submatrix(0, 0, block_dim, block_dim);
320    let cols = orthogonal.ncols();
321    assert!(rotated_eigenvalues.len() >= cols);
322    let mut projected = Mat::<f64>::zeros(block_dim, cols);
323    matmul(
324        projected.as_mut(),
325        Accum::Replace,
326        matrix_block,
327        orthogonal.as_ref(),
328        1.0,
329        Par::Seq,
330    );
331    let mut trace = KahanSum::default();
332    for l in 0..cols {
333        let mut diag_ll = KahanSum::default();
334        for i in 0..block_dim {
335            diag_ll.add(orthogonal[(i, l)] * projected[(i, l)]);
336        }
337        trace.add(diag_ll.sum() / (rotated_eigenvalues[l] + delta));
338    }
339    trace.sum()
340}
341
342pub fn trace_reduced_penalty_covariance(
343    reduced_penalty: &Array2<f64>,
344    covariance_basis: &Array2<f64>,
345) -> f64 {
346    assert_eq!(
347        reduced_penalty.dim(),
348        covariance_basis.dim(),
349        "trace_reduced_penalty_covariance dimension mismatch"
350    );
351    let r = covariance_basis.nrows();
352    let mut trace = KahanSum::default();
353    for i in 0..r {
354        for j in 0..r {
355            trace.add(covariance_basis[[i, j]] * reduced_penalty[[j, i]]);
356        }
357    }
358    trace.sum()
359}
360
361pub fn trace_penalty_covariance_in_orthogonal_basis(
362    matrix: &Array2<f64>,
363    orthogonal: &Array2<f64>,
364    covariance_basis: &Array2<f64>,
365) -> f64 {
366    let reduced = gam_linalg::faer_ndarray::fast_ab(
367        &gam_linalg::faer_ndarray::fast_atb(orthogonal, matrix),
368        orthogonal,
369    );
370    trace_reduced_penalty_covariance(&reduced, covariance_basis)
371}
372
373/// Strict spectral classifier used as a final guard on penalty eigendecompositions.
374///
375/// Penalty matrices fed to the GAM solver are required to be PSD by construction.
376/// This routine snaps roundoff-zero eigenvalues to exact zero, accepts strictly
377/// positive eigenvalues, and rejects materially-indefinite or non-finite spectra
378/// with a hard error rather than silently rewriting them. The previous behaviour
379/// (mass-zeroing negative or non-finite eigenvalues) hid construction bugs and
380/// changed the optimisation objective downstream.
381///
382/// The acceptance tolerance is the larger of a machine-ε floor
383/// (`C_EPS_P_FACTOR * eps_machine * p * scale`, with `C_EPS_P_FACTOR = 64`
384/// absorbing the rounding accumulated in a symmetric eigendecomposition of a
385/// moderate-dimension matrix) and a relative "numerically PSD" floor
386/// `REL_PSD_FLOOR * scale`. The latter dominates for large high-rank penalties
387/// assembled / reparameterized at extreme λ, where roundoff produces
388/// ~1e-11-relative negative eigenvalues that are PSD to any reasonable precision
389/// yet exceeded the bare ~12×ε machine floor and spuriously failed the inner
390/// P-IRLS solve (#1619). Genuine indefiniteness is O(1) relative and is still
391/// rejected far above either floor.
392fn classify_eigenvalues_strict(
393    eigenvalues: &mut [f64],
394    context: &str,
395) -> Result<(), EstimationError> {
396    const C_EPS_P_FACTOR: f64 = 64.0;
397    // Standard "numerically PSD" relative threshold (~sqrt(machine ε)). A negative
398    // eigenvalue smaller than this fraction of the spectrum scale is roundoff, not
399    // genuine negative curvature, and is snapped to zero rather than rejected.
400    const REL_PSD_FLOOR: f64 = 1.0e-8;
401    let p = eigenvalues.len();
402
403    let mut scale = 0.0_f64;
404    for (idx, &val) in eigenvalues.iter().enumerate() {
405        if !val.is_finite() {
406            return Err(EstimationError::PenaltySpectrumNonFinite {
407                context: context.to_string(),
408                index: idx,
409                value: val,
410            });
411        }
412        scale = scale.max(val.abs());
413    }
414
415    // p * eps captures the rounding floor of a symmetric eigendecomposition of a
416    // p-dimensional matrix; multiplying by `scale` lifts the floor to the actual
417    // magnitude of the spectrum. For large high-rank penalties assembled at
418    // extreme λ this machine floor (~12×ε relative) is tighter than the roundoff
419    // actually produced, so we take the larger of it and a relative numerically-PSD
420    // floor `REL_PSD_FLOOR * scale` (#1619).
421    let machine_floor = C_EPS_P_FACTOR * f64::EPSILON * (p.max(1) as f64) * scale;
422    let tolerance = machine_floor
423        .max(REL_PSD_FLOOR * scale)
424        .max(f64::MIN_POSITIVE);
425
426    for (idx, val) in eigenvalues.iter_mut().enumerate() {
427        if val.abs() <= tolerance {
428            *val = 0.0;
429        } else if *val < 0.0 {
430            return Err(EstimationError::PenaltySpectrumIndefinite {
431                context: context.to_string(),
432                index: idx,
433                value: *val,
434                tolerance,
435                scale,
436            });
437        }
438    }
439    Ok(())
440}
441
442fn robust_eighwith_policy<M, V, E, Validate, Sanitize, EigCall, MapErr>(
443    matrix: &M,
444    context: &str,
445    validate_input: Validate,
446    sanitize: Sanitize,
447    mut eig_call: EigCall,
448    map_error: MapErr,
449) -> Result<(Vec<f64>, V), EstimationError>
450where
451    Validate: Fn(&M, &str) -> Result<(), EstimationError>,
452    Sanitize: Fn(&M) -> M,
453    EigCall: FnMut(&M) -> Result<(Vec<f64>, V), E>,
454    MapErr: Fn(E, &str) -> EstimationError,
455{
456    validate_input(matrix, context)?;
457
458    // The sanitize step only enforces exact symmetry by averaging M and M^T and
459    // zeros sub-eps noise; it never adds a diagonal ridge. Adding ridge changes
460    // the matrix being decomposed, which silently changes the optimisation
461    // objective downstream. If eigh genuinely fails on a finite symmetric input,
462    // surface the error instead of mutating the spectrum.
463    let candidate = sanitize(matrix);
464    match eig_call(&candidate) {
465        Ok((mut eigenvalues, eigenvectors)) => {
466            classify_eigenvalues_strict(&mut eigenvalues, context)?;
467            Ok((eigenvalues, eigenvectors))
468        }
469        Err(err) => Err(map_error(err, context)),
470    }
471}
472
473pub(crate) fn robust_eigh_faer(
474    matrix: &Mat<f64>,
475    side: Side,
476    context: &str,
477) -> Result<(Vec<f64>, Mat<f64>), EstimationError> {
478    robust_eighwith_policy(
479        matrix,
480        context,
481        |mat, ctx| {
482            let (rows, cols) = mat.as_ref().shape();
483            for i in 0..rows {
484                for j in 0..cols {
485                    let val = mat[(i, j)];
486                    if !val.is_finite() {
487                        let max_abs = mat_max_abs_element(mat.as_ref());
488                        crate::bail_invalid_estim!(
489                            "{} contains non-finite entries (max finite magnitude {:.3e})",
490                            ctx,
491                            max_abs
492                        );
493                    }
494                }
495            }
496            Ok(())
497        },
498        sanitize_symmetric_faer,
499        |candidate| {
500            let eig = candidate.as_ref().self_adjoint_eigen(side)?;
501            let diag = eig.S();
502            let mut eigenvalues = Vec::with_capacity(diag.dim());
503            for idx in 0..diag.dim() {
504                eigenvalues.push(diag[idx]);
505            }
506
507            let vectors_ref = eig.U();
508            let mut eigenvectors = Mat::<f64>::zeros(vectors_ref.nrows(), vectors_ref.ncols());
509            for i in 0..vectors_ref.nrows() {
510                for j in 0..vectors_ref.ncols() {
511                    eigenvectors[(i, j)] = vectors_ref[(i, j)];
512                }
513            }
514            Ok((eigenvalues, eigenvectors))
515        },
516        |err, _ctx| {
517            EstimationError::EigendecompositionFailed(FaerLinalgError::SelfAdjointEigen(err))
518        },
519    )
520}
521
522fn robust_eigh(
523    matrix: &Array2<f64>,
524    side: Side,
525    context: &str,
526) -> Result<(Array1<f64>, Array2<f64>), EstimationError> {
527    let matrix_faer = array_to_faer(matrix);
528    let (eigenvalues, eigenvectors) = robust_eigh_faer(&matrix_faer, side, context)?;
529    Ok((Array1::from_vec(eigenvalues), mat_to_array(&eigenvectors)))
530}
531
532pub(crate) fn kronecker_marginal_eigensystems(
533    marginal_penalties: &[Array2<f64>],
534    context: &str,
535) -> Result<Vec<(Array1<f64>, Array2<f64>)>, EstimationError> {
536    let mut eigensystems = Vec::with_capacity(marginal_penalties.len());
537    for (k, penalty) in marginal_penalties.iter().enumerate() {
538        eigensystems.push(robust_eigh(
539            penalty,
540            Side::Lower,
541            &format!("{context} marginal {k}"),
542        )?);
543    }
544    Ok(eigensystems)
545}
546
/// Diagnostics produced by `assess_subspace_leakage`.
#[derive(Debug, Clone, Copy)]
struct SubspaceLeakageMetrics {
    /// Largest, over all penalty roots, sum of squared entries that lie in
    /// the columns at or beyond the structural rank.
    max_abs_sq: f64,
    /// Largest, over all penalty roots, ratio of that leaked sum of squares
    /// to the root's total sum of squares (zero for an all-zero root).
    max_rel_sq: f64,
    /// Index of the penalty root that attains `max_rel_sq`.
    worst_penalty: usize,
    /// Largest absolute dot product between one of the first
    /// `structural_rank` columns of `qs` and one of the remaining columns.
    max_cross_gram_abs: f64,
}
554
555fn assess_subspace_leakage(
556    qs: &Mat<f64>,
557    rs_transformed: &[Mat<f64>],
558    structural_rank: usize,
559    p: usize,
560) -> SubspaceLeakageMetrics {
561    let mut max_abs_sq = 0.0_f64;
562    let mut max_rel_sq = 0.0_f64;
563    let mut worst_penalty = 0usize;
564
565    for (k, rs) in rs_transformed.iter().enumerate() {
566        let rows = rs.nrows();
567        let cols = rs.ncols().min(p);
568        let null_start = structural_rank.min(cols);
569        let mut abs_sq = 0.0_f64;
570        let mut total_sq = 0.0_f64;
571        for i in 0..rows {
572            for j in 0..cols {
573                let v = rs[(i, j)];
574                let vv = v * v;
575                total_sq += vv;
576                if j >= null_start {
577                    abs_sq += vv;
578                }
579            }
580        }
581        let rel_sq = if total_sq > 0.0 {
582            abs_sq / total_sq
583        } else {
584            0.0
585        };
586        if rel_sq > max_rel_sq {
587            max_rel_sq = rel_sq;
588            worst_penalty = k;
589        }
590        max_abs_sq = max_abs_sq.max(abs_sq);
591    }
592
593    let mut max_cross_gram_abs = 0.0_f64;
594    let null_count = p.saturating_sub(structural_rank);
595    if structural_rank > 0 && null_count > 0 {
596        for i in 0..structural_rank {
597            for j in 0..null_count {
598                let qn_col = structural_rank + j;
599                let mut dot = 0.0_f64;
600                for r in 0..p {
601                    dot += qs[(r, i)] * qs[(r, qn_col)];
602                }
603                max_cross_gram_abs = max_cross_gram_abs.max(dot.abs());
604            }
605        }
606    }
607
608    SubspaceLeakageMetrics {
609        max_abs_sq,
610        max_rel_sq,
611        worst_penalty,
612        max_cross_gram_abs,
613    }
614}
615
616fn compose_qs_from_split(q_pen: &Mat<f64>, q_null: &Mat<f64>, p: usize) -> Mat<f64> {
617    let rank = q_pen.ncols();
618    let null_count = q_null.ncols();
619    let mut qs = Mat::<f64>::zeros(p, p);
620    for i in 0..p {
621        for j in 0..rank {
622            qs[(i, j)] = q_pen[(i, j)];
623        }
624        for j in 0..null_count {
625            qs[(i, rank + j)] = q_null[(i, j)];
626        }
627    }
628    qs
629}
630
631/// Computes the Kronecker product A ⊗ B for penalty matrix construction.
632/// This is used to create tensor product penalties that enforce smoothness
633/// in multiple dimensions for interaction terms.
634pub fn kronecker_product(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
635    let (arows, a_cols) = a.dim();
636    let (brows, b_cols) = b.dim();
637    if arows == 0 || a_cols == 0 || brows == 0 || b_cols == 0 {
638        return Array2::zeros((arows * brows, a_cols * b_cols));
639    }
640    let mut result = Array2::zeros((arows * brows, a_cols * b_cols));
641
642    result
643        .axis_chunks_iter_mut(Axis(0), brows)
644        .into_par_iter()
645        .enumerate()
646        .for_each(|(i, mut row_block)| {
647            let arow = a.row(i);
648            let col_chunks = row_block.axis_chunks_iter_mut(Axis(1), b_cols);
649            for (j, mut block) in col_chunks.into_iter().enumerate() {
650                let aval = arow[j];
651                if aval == 0.0 {
652                    continue;
653                }
654                for (dest, &src) in block.iter_mut().zip(b.iter()) {
655                    *dest = aval * src;
656                }
657            }
658        });
659
660    result
661}
662
/// Result of the stable reparameterization algorithm from Wood (2011),
/// Appendix B.
///
/// All matrix-valued fields except `qs` itself are expressed in the
/// TRANSFORMED coefficient frame obtained by applying `qs`.
#[derive(Clone)]
pub struct ReparamResult {
    /// Penalty matrix in TRANSFORMED coefficient coordinates.
    ///
    /// This must be compatible with `beta_transformed` and `X_transformed = X * Qs`.
    pub s_transformed: Array2<f64>,
    /// Log-determinant of the penalty matrix (stable computation).
    pub log_det: f64,
    /// First derivatives of the log-determinant w.r.t. the log-smoothing
    /// parameters.
    pub det1: Array1<f64>,
    /// Orthogonal transformation matrix Qs.
    pub qs: Array2<f64>,
    /// Canonical penalties in the TRANSFORMED coordinate frame.
    /// The single source of truth for penalty roots in the transformed frame.
    /// Downstream consumers use these for block-local `PenaltyCoordinate`
    /// construction, TK correction, and ext-coord paths.
    pub canonical_transformed: Vec<CanonicalPenalty>,
    /// Lambda-dependent penalty square root in TRANSFORMED coordinates
    /// (rank × p matrix).
    /// This is used for applying the actual penalty in the least squares solve.
    pub e_transformed: Array2<f64>,
    /// Truncated eigenvectors (p × m where m = p - structural_rank).
    ///
    /// Coordinate frame note:
    /// - This matrix is stored in the TRANSFORMED coefficient frame (post-`Qs`),
    ///   i.e. it is compatible with `canonical_transformed`, `beta_transformed`,
    ///   and transformed Hessians without additional coordinate mapping.
    ///
    /// These vectors span the structural null space used by positive-part
    /// log-determinant conventions.
    pub u_truncated: Array2<f64>,
    /// The rho-independent shrinkage ridge magnitude that was added to each
    /// eigenvalue of the penalized block. Zero means no shrinkage was applied.
    pub penalty_shrinkage_ridge: f64,
}
698
699// ---------------------------------------------------------------------------
700// Kronecker factor decomposition primitives
701// ---------------------------------------------------------------------------
702
/// Per-factor decomposition result for Kronecker penalties.
///
/// Produced by `decompose_kronecker_factors`, one per Kronecker factor.
struct KroneckerFactorDecomp {
    /// Root of the factor, one row per positive-curvature direction, each
    /// scaled by the square root of its eigenvalue.
    root: Array2<f64>,              // rank_j × q_j
    /// Eigenvalues of the factor's range directions.
    positive_eigenvalues: Vec<f64>, // length = rank_j
    /// Rank of the factor (number of rows of `root`).
    rank: usize,
    /// Side length `q_j` of the square factor.
    dim: usize,
}
710
711/// Eigendecompose each Kronecker factor separately at O(Σ q_j³).
712/// Returns per-factor decompositions, or `None` if any factor is zero.
713fn decompose_kronecker_factors(
714    factors: &[Array2<f64>],
715    context: &str,
716) -> Result<Option<Vec<KroneckerFactorDecomp>>, EstimationError> {
717    let mut decomps = Vec::with_capacity(factors.len());
718    for (j, factor) in factors.iter().enumerate() {
719        let q_j = factor.nrows();
720        if q_j != factor.ncols() {
721            crate::bail_invalid_estim!(
722                "{context}: Kronecker factor {j} must be square, got {}x{}",
723                factor.nrows(),
724                factor.ncols()
725            );
726        }
727        let is_identity = {
728            let mut is_id = true;
729            'outer: for r in 0..q_j {
730                for c in 0..q_j {
731                    let expected = if r == c { 1.0 } else { 0.0 };
732                    if (factor[[r, c]] - expected).abs() > 1e-12 {
733                        is_id = false;
734                        break 'outer;
735                    }
736                }
737            }
738            is_id
739        };
740        if is_identity {
741            decomps.push(KroneckerFactorDecomp {
742                root: Array2::eye(q_j),
743                positive_eigenvalues: vec![1.0; q_j],
744                rank: q_j,
745                dim: q_j,
746            });
747            continue;
748        }
749        let analysis = analyze_penalty_block(factor).map_err(|err| {
750            EstimationError::InvalidInput(format!(
751                "{context}: Kronecker factor {j} eigendecomp failed: {err}"
752            ))
753        })?;
754        if analysis.rank == 0 {
755            return Ok(None);
756        }
757        // Build the factor root from ONLY the range (positive-curvature)
758        // directions via the canonical classifier — never the null or
759        // negative-curvature directions (#1425).
760        let factor_classes =
761            crate::basis::SpectralClassification::new(&analysis.eigenvalues, analysis.tol);
762        let mut root_j = Array2::zeros((analysis.rank, q_j));
763        let mut pos_eigs = Vec::with_capacity(analysis.rank);
764        for (row_idx, &i) in factor_classes.range_idx.iter().enumerate() {
765            let eigenval = analysis.eigenvalues[i];
766            let sqrt_ev = eigenval.sqrt();
767            let evec = analysis.eigenvectors.column(i);
768            for (col, &v) in evec.iter().enumerate() {
769                root_j[[row_idx, col]] = sqrt_ev * v;
770            }
771            pos_eigs.push(eigenval);
772        }
773        decomps.push(KroneckerFactorDecomp {
774            root: root_j,
775            positive_eigenvalues: pos_eigs,
776            rank: analysis.rank,
777            dim: q_j,
778        });
779    }
780    Ok(Some(decomps))
781}
782
783/// Build the block-local Kronecker root from pre-computed factor decompositions.
784fn assemble_kronecker_root_local(decomps: &[KroneckerFactorDecomp]) -> Array2<f64> {
785    let mut kron_root = decomps[0].root.clone();
786    for fr in &decomps[1..] {
787        let (r1, c1) = kron_root.dim();
788        let (r2, c2) = (fr.rank, fr.dim);
789        let mut new_root = Array2::zeros((r1 * r2, c1 * c2));
790        for i1 in 0..r1 {
791            for i2 in 0..r2 {
792                for j1 in 0..c1 {
793                    for j2 in 0..c2 {
794                        new_root[[i1 * r2 + i2, j1 * c2 + j2]] =
795                            kron_root[[i1, j1]] * fr.root[[i2, j2]];
796                    }
797                }
798            }
799        }
800        kron_root = new_root;
801    }
802    kron_root
803}
804
805/// Compute eigenvalues of the Kronecker product from per-factor eigenvalues.
806fn kronecker_eigenvalues(decomps: &[KroneckerFactorDecomp], block_dim: usize) -> (Vec<f64>, usize) {
807    let mut kron_eigs = decomps[0].positive_eigenvalues.clone();
808    for fd in &decomps[1..] {
809        let mut new_eigs = Vec::with_capacity(kron_eigs.len() * fd.positive_eigenvalues.len());
810        for &a in &kron_eigs {
811            for &b in &fd.positive_eigenvalues {
812                new_eigs.push(a * b);
813            }
814        }
815        kron_eigs = new_eigs;
816    }
817    let max_ev = kron_eigs.iter().copied().fold(0.0_f64, f64::max);
818    let tol = max_ev * 1e-10 * (block_dim as f64);
819    let positive: Vec<f64> = kron_eigs.into_iter().filter(|&ev| ev > tol).collect();
820    let nullity = block_dim - positive.len();
821    (positive, nullity)
822}
823
824// ---------------------------------------------------------------------------
825// CanonicalPenalty — block-local processed penalty for the solver
826// ---------------------------------------------------------------------------
827
/// A canonicalized penalty with block-local root, ready for the solver.
///
/// Instead of storing a full `p x p` penalty matrix, this stores only the
/// `rank x block_dim` root and the column range, enabling O(p_k^2) operations
/// instead of O(p^2).
#[derive(Clone)]
pub struct CanonicalPenalty {
    /// Square root matrix: S_k = root^T * root.
    /// Shape: `rank x block_dim` for block-local, `rank x p` for dense.
    pub root: Array2<f64>,
    /// Column range in the global coefficient vector [start..end).
    /// For dense penalties this is `0..p`.
    pub col_range: std::ops::Range<usize>,
    /// Full parameter dimension p.
    pub total_dim: usize,
    /// Structural nullity of the local penalty.
    pub nullity: usize,
    /// The symmetrized block-local penalty matrix (block_dim × block_dim).
    /// Cached at construction time to avoid recomputing root^T * root
    /// in hot paths (penalty assembly, trace products).
    pub local: Array2<f64>,
    /// Block-local prior mean used to center this penalty.
    pub prior_mean: Array1<f64>,
    /// Positive eigenvalues of the local penalty matrix (length = rank).
    /// Cached at construction time for REML logdet block-factored paths.
    pub positive_eigenvalues: Vec<f64>,
    /// Optional operator-form handle bit-equivalent to `local`. Propagated
    /// from `PenaltySpec::Block.op`. Downstream PIRLS and REML exact operator
    /// algebra route through this for dense-Gram-free matvec when present.
    /// `None` means only the dense `local` matrix is available.
    pub op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
}
859
860impl std::fmt::Debug for CanonicalPenalty {
861    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
862        f.debug_struct("CanonicalPenalty")
863            .field(
864                "root",
865                &format_args!("{}×{}", self.root.nrows(), self.root.ncols()),
866            )
867            .field("col_range", &self.col_range)
868            .field("total_dim", &self.total_dim)
869            .field("nullity", &self.nullity)
870            .field(
871                "local",
872                &format_args!("{}×{}", self.local.nrows(), self.local.ncols()),
873            )
874            .field("prior_mean_len", &self.prior_mean.len())
875            .field("positive_eigenvalues", &self.positive_eigenvalues)
876            .field("op", &self.op.as_ref().map(|o| o.dim()))
877            .finish()
878    }
879}
880
881impl CanonicalPenalty {
882    /// Construct a dense (full-width) canonical penalty from a `rank x p` root.
883    /// Used to wrap reparam-transformed roots for consumers that expect
884    /// `&[CanonicalPenalty]`.
885    pub fn from_dense_root(root: Array2<f64>, p: usize) -> Self {
886        Self::from_dense_root_with_mean(root, p, Array1::zeros(p))
887    }
888
889    pub fn from_dense_root_with_mean(root: Array2<f64>, p: usize, prior_mean: Array1<f64>) -> Self {
890        assert_eq!(prior_mean.len(), p);
891        let local = root.t().dot(&root);
892        let positive_eigenvalues = Vec::new(); // not needed for TK paths
893        Self {
894            root,
895            col_range: 0..p,
896            total_dim: p,
897            nullity: 0,
898            local,
899            prior_mean,
900            positive_eigenvalues,
901            op: None,
902        }
903    }
904
905    /// Embed the block-local root into a full-width `rank × total_dim` matrix.
906    /// For dense penalties (col_range = 0..p), returns the root unchanged.
907    pub fn full_width_root(&self) -> Array2<f64> {
908        if self.col_range.start == 0 && self.col_range.end == self.total_dim {
909            return self.root.clone();
910        }
911        let rank = self.root.nrows();
912        let mut full = Array2::<f64>::zeros((rank, self.total_dim));
913        full.slice_mut(ndarray::s![.., self.col_range.clone()])
914            .assign(&self.root);
915        full
916    }
917
918    /// Numerical rank of this penalty.
919    pub fn rank(&self) -> usize {
920        self.root.nrows()
921    }
922
923    /// Block dimension (number of columns this penalty covers).
924    pub fn block_dim(&self) -> usize {
925        self.col_range.len()
926    }
927
928    /// Whether this penalty is block-local (col_range != 0..total_dim).
929    pub const fn is_block_local(&self) -> bool {
930        self.col_range.start != 0 || self.col_range.end != self.total_dim
931    }
932
933    /// Return a reference to the cached local penalty matrix.
934    /// Shape: `block_dim x block_dim`.
935    pub fn local_ref(&self) -> &Array2<f64> {
936        &self.local
937    }
938
939    /// Return an owned copy of the local penalty matrix.
940    /// Prefer `local_ref()` when a reference suffices.
941    pub fn local_penalty(&self) -> Array2<f64> {
942        self.local.clone()
943    }
944
945    /// Accumulate lambda * S_k into a pre-allocated `p x p` target matrix.
946    /// Only touches the block [col_range × col_range].
947    pub fn accumulate_weighted(&self, target: &mut Array2<f64>, lambda: f64) {
948        if lambda == 0.0 || self.rank() == 0 {
949            return;
950        }
951        let r = &self.col_range;
952        target
953            .slice_mut(s![r.start..r.end, r.start..r.end])
954            .scaled_add(lambda, &self.local);
955    }
956
957    /// Compute `scale * tr(M · S_k)` where M is a `p × p` dense matrix.
958    /// Only reads `M[start..end, start..end]` — O(block_dim²) not O(p²).
959    pub fn trace_product(&self, m: &Array2<f64>, scale: f64) -> f64 {
960        if self.rank() == 0 || scale == 0.0 {
961            return 0.0;
962        }
963        let r = &self.col_range;
964        let m_block = m.slice(s![r.start..r.end, r.start..r.end]);
965        let rm = self.root.dot(&m_block);
966        scale
967            * rm.iter()
968                .zip(self.root.iter())
969                .map(|(&a, &b)| a * b)
970                .sum::<f64>()
971    }
972
973    /// Compute `scale * v^T S_k v` (quadratic form).
974    /// Only reads `v[start..end]` — O(rank × block_dim) not O(rank × p).
975    pub fn quadratic(&self, v: &Array1<f64>, scale: f64) -> f64 {
976        if self.rank() == 0 || scale == 0.0 {
977            return 0.0;
978        }
979        let v_block = v.slice(s![self.col_range.start..self.col_range.end]);
980        let rv = self.root.dot(&v_block);
981        scale * rv.dot(&rv)
982    }
983
984    /// Compute `scale * S_k * prior_mean` embedded into the global basis.
985    pub fn prior_linear_shift(&self, scale: f64) -> Array1<f64> {
986        let mut out = Array1::<f64>::zeros(self.total_dim);
987        if self.rank() == 0 || scale == 0.0 || self.prior_mean.iter().all(|&v| v == 0.0) {
988            return out;
989        }
990        let block = self.local.dot(&self.prior_mean) * scale;
991        out.slice_mut(s![self.col_range.start..self.col_range.end])
992            .assign(&block);
993        out
994    }
995
996    /// Compute `scale * prior_mean' S_k prior_mean`.
997    pub fn prior_constant_shift(&self, scale: f64) -> f64 {
998        if self.rank() == 0 || scale == 0.0 || self.prior_mean.iter().all(|&v| v == 0.0) {
999            return 0.0;
1000        }
1001        scale * self.prior_mean.dot(&self.local.dot(&self.prior_mean))
1002    }
1003
1004    /// Embed this block's prior mean into the global coefficient basis.
1005    pub fn full_width_prior_mean(&self) -> Array1<f64> {
1006        if self.col_range.start == 0 && self.col_range.end == self.total_dim {
1007            return self.prior_mean.clone();
1008        }
1009        let mut out = Array1::<f64>::zeros(self.total_dim);
1010        out.slice_mut(s![self.col_range.start..self.col_range.end])
1011            .assign(&self.prior_mean);
1012        out
1013    }
1014
1015    /// Convert to a PenaltyCoordinate for the unified REML evaluator.
1016    pub fn to_penalty_coordinate(
1017        &self,
1018    ) -> gam_problem::PenaltyCoordinate {
1019        use gam_problem::PenaltyCoordinate;
1020        if self.is_block_local() {
1021            PenaltyCoordinate::from_block_root_with_mean(
1022                self.root.clone(),
1023                self.col_range.start,
1024                self.col_range.end,
1025                self.total_dim,
1026                self.prior_mean.clone(),
1027            )
1028        } else {
1029            PenaltyCoordinate::from_dense_root_with_mean(self.root.clone(), self.prior_mean.clone())
1030        }
1031    }
1032}
1033
1034/// Detect and report structurally identical (or near-identical) penalty pairs
1035/// in the canonical bundle.
1036///
1037/// Two penalties `S_i`, `S_j` with the same `col_range` are compared by their
1038/// matrix cosine:
1039///
1040///     cos(S_i, S_j) = tr(S_i S_j) / sqrt(tr(S_i^2) * tr(S_j^2))
1041///
1042/// Because `local` is symmetric, `tr(A·B) = sum_{r,c} A[r,c] * B[r,c]` — i.e.,
1043/// the Frobenius inner product. Pairs with different `col_range` cannot be
1044/// functionally identical and are skipped.
1045///
1046/// Logging policy:
1047/// - `cos > 1 - 1e-8` → `log::warn!` with `[PENALTY-REDUNDANCY]`. Every such
1048///   pair is emitted because it represents a structural model error (the LAML
1049///   cost has a Z₂-symmetric saddle that ARC's cubic regularization will
1050///   happily converge to under first-order stationarity).
1051/// - `0.99 < cos ≤ 1 - 1e-8` → `log::info!` with `[PENALTY-SIMILARITY]`.
1052///   At large scale (`k > 64`) only the top-3 highest-cosine such pairs are
1053///   logged to bound log volume.
1054///
1055/// Returns `Vec<(i, j, cos)>` for the **redundant** pairs (cos > 1 - 1e-8),
1056/// primarily to make this function unit-testable without a log capture.
1057///
1058/// Performance: this is O(k² · block_dim²); intended to be called exactly
1059/// once per fit (e.g. from `RemlState::newwith_offset_shared`).
1060pub fn report_penalty_pair_redundancy(canonical: &[CanonicalPenalty]) -> Vec<(usize, usize, f64)> {
1061    const REDUNDANCY_THRESHOLD: f64 = 1.0 - 1e-8;
1062    const SIMILARITY_THRESHOLD: f64 = 0.99;
1063    const LARGE_SCALE_K_THRESHOLD: usize = 64;
1064    const TOP_SIMILARITY_PAIRS: usize = 3;
1065
1066    let k = canonical.len();
1067    let mut redundant: Vec<(usize, usize, f64)> = Vec::new();
1068    let mut similar: Vec<(usize, usize, f64)> = Vec::new();
1069
1070    // Pre-compute tr(S_i^2) = sum of squares of S_i entries (Frobenius norm
1071    // squared). `local` is symmetric, so this equals tr(S_i^T S_i) = tr(S_i^2).
1072    let trace_sq: Vec<f64> = canonical
1073        .iter()
1074        .map(|p| p.local.iter().map(|&v| v * v).sum::<f64>())
1075        .collect();
1076
1077    for i in 0..k {
1078        if trace_sq[i] == 0.0 {
1079            continue;
1080        }
1081        for j in (i + 1)..k {
1082            if trace_sq[j] == 0.0 {
1083                continue;
1084            }
1085            // Different col_range → cannot be functionally identical by
1086            // construction (the block-local matrices live in disjoint or
1087            // mismatched parameter subspaces).
1088            if canonical[i].col_range != canonical[j].col_range {
1089                continue;
1090            }
1091            // Shapes must match — they do when col_range matches because
1092            // `local` is `block_dim × block_dim` and `block_dim = col_range.len()`.
1093            assert_eq!(canonical[i].local.dim(), canonical[j].local.dim());
1094
1095            let inner: f64 = canonical[i]
1096                .local
1097                .iter()
1098                .zip(canonical[j].local.iter())
1099                .map(|(&a, &b)| a * b)
1100                .sum();
1101            let denom = (trace_sq[i] * trace_sq[j]).sqrt();
1102            if denom == 0.0 {
1103                continue;
1104            }
1105            let cos = inner / denom;
1106
1107            if cos > REDUNDANCY_THRESHOLD {
1108                redundant.push((i, j, cos));
1109            } else if cos > SIMILARITY_THRESHOLD {
1110                similar.push((i, j, cos));
1111            }
1112        }
1113    }
1114
1115    // Always emit every redundancy — these are structural model errors.
1116    for &(i, j, cos) in &redundant {
1117        log::warn!(
1118            "[PENALTY-REDUNDANCY] penalties i={i} j={j} are structurally identical \
1119             (cos={cos:.6}) — model is over-parameterized along their antisymmetric \
1120             direction; expect a Z₂-symmetric saddle in the LAML cost. Consider \
1121             re-specifying (e.g. anisotropic→isotropic for spatial smoothers with \
1122             weak axis signal)."
1123        );
1124    }
1125
1126    // Cap similarity log volume at large scale.
1127    if k > LARGE_SCALE_K_THRESHOLD && similar.len() > TOP_SIMILARITY_PAIRS {
1128        similar.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
1129        similar.truncate(TOP_SIMILARITY_PAIRS);
1130    }
1131    for (i, j, cos) in similar {
1132        log::info!(
1133            "[PENALTY-SIMILARITY] penalties i={i} j={j} are near-identical \
1134             (cos={cos:.6}) — outer Hessian may be ill-conditioned along their \
1135             antisymmetric direction."
1136        );
1137    }
1138
1139    redundant
1140}
1141
1142/// Canonicalize a single `PenaltySpec` into a `CanonicalPenalty` by computing
1143/// the block-local eigendecomposition and extracting the root.
1144///
1145/// This is O(block_dim^3) instead of O(p^3) for block-local penalties.
1146/// Returns `None` if the penalty has rank zero (should be dropped).
1147pub fn canonicalize_penalty_spec(
1148    spec: &crate::PenaltySpec,
1149    p: usize,
1150    idx: usize,
1151    context: &str,
1152) -> Result<Option<CanonicalPenalty>, EstimationError> {
1153    use crate::PenaltySpec;
1154
1155    crate::validate_penalty_spec_shape(idx, spec, p, context)?;
1156
1157    let (local_matrix, col_range, prior_mean_spec, hint, op) = match spec {
1158        PenaltySpec::Block {
1159            local,
1160            col_range,
1161            prior_mean,
1162            structure_hint,
1163            op,
1164        } => (
1165            local.view(),
1166            col_range.clone(),
1167            prior_mean,
1168            structure_hint.as_ref(),
1169            op.clone(),
1170        ),
1171        PenaltySpec::Dense(m) => (
1172            m.view(),
1173            0..p,
1174            &gam_problem::CoefficientPriorMean::Zero,
1175            None,
1176            None,
1177        ),
1178        PenaltySpec::DenseWithMean { matrix, prior_mean } => {
1179            (matrix.view(), 0..p, prior_mean, None, None)
1180        }
1181    };
1182
1183    let block_dim = col_range.len();
1184    let prior_mean = prior_mean_spec
1185        .evaluate(block_dim, &format!("{context}: penalty {idx}"))
1186        .map_err(|e| EstimationError::InvalidInput(e.0))?;
1187
1188    // ── Ridge fast path: closed-form, no eigendecomposition ──
1189    if let Some(PenaltyStructureHint::Ridge(scale)) = hint {
1190        if *scale <= 0.0 {
1191            return Ok(None);
1192        }
1193        let sqrt_scale = scale.sqrt();
1194        let mut root = Array2::zeros((block_dim, block_dim));
1195        for i in 0..block_dim {
1196            root[[i, i]] = sqrt_scale;
1197        }
1198        // Ridge penalties are diagonal by construction, but still route through
1199        // the crate-wide ndarray symmetrizer so every construction variant uses
1200        // the same "average the transpose" cleanup instead of a local copy.
1201        let mut local_sym = local_matrix.to_owned();
1202        symmetrize_in_place(&mut local_sym);
1203        return Ok(Some(CanonicalPenalty {
1204            root,
1205            col_range,
1206            total_dim: p,
1207            nullity: 0,
1208            local: local_sym,
1209            prior_mean,
1210            positive_eigenvalues: vec![*scale; block_dim],
1211            op,
1212        }));
1213    }
1214
1215    // ── Kronecker fast path: single per-factor eigendecomposition ──
1216    if let Some(PenaltyStructureHint::Kronecker(factors)) = hint {
1217        let decomps =
1218            match decompose_kronecker_factors(factors, &format!("{context} penalty {idx}"))? {
1219                None => return Ok(None),
1220                Some(d) => d,
1221            };
1222        let (positive_eigenvalues, nullity) = kronecker_eigenvalues(&decomps, block_dim);
1223        if positive_eigenvalues.is_empty() {
1224            return Ok(None);
1225        }
1226        let root = assemble_kronecker_root_local(&decomps);
1227        let mut local_sym = local_matrix.to_owned();
1228        symmetrize_in_place(&mut local_sym);
1229        return Ok(Some(CanonicalPenalty {
1230            root,
1231            col_range,
1232            total_dim: p,
1233            nullity,
1234            local: local_sym,
1235            prior_mean,
1236            positive_eigenvalues,
1237            op,
1238        }));
1239    }
1240
1241    // ── Generic block-local path: eigendecompose at O(block_dim³) ──
1242    let local_owned = local_matrix.to_owned();
1243    let analysis = analyze_penalty_block(&local_owned).map_err(|err| {
1244        EstimationError::InvalidInput(format!(
1245            "{context}: penalty canonicalization failed at index {idx}: {err}"
1246        ))
1247    })?;
1248
1249    if analysis.rank == 0 {
1250        log::debug!(
1251            "Dropped inactive penalty block idx={idx} reason={}",
1252            if analysis.iszero {
1253                "ZeroMatrix"
1254            } else {
1255                "NumericalRankZero"
1256            }
1257        );
1258        return Ok(None);
1259    }
1260
1261    // Reuse the eigendecomposition from analyze_penalty_block and route the
1262    // range / null / negative-curvature split through the one canonical
1263    // classifier, so this root construction cannot disagree with the block's
1264    // own `rank` / `nullity` / `negative_dim` about which directions are
1265    // penalized, unpenalized, or non-PSD (#1425).
1266    let tolerance = analysis.tol;
1267    let classes = crate::basis::SpectralClassification::new(&analysis.eigenvalues, tolerance);
1268    let rank_k = classes.rank();
1269    assert_eq!(
1270        rank_k, analysis.rank,
1271        "penalty-root rank disagreement: SpectralClassification rank={rank_k} vs analyze_penalty_block rank={} (#1425 canonical-classifier invariant)",
1272        analysis.rank
1273    );
1274
1275    // Build the penalty root R from ONLY the range directions (positive
1276    // curvature): R has one row per range eigenpair, scaled by sqrt(ev), so
1277    // RᵀR reconstructs S on range(S). Null directions contribute nothing
1278    // (their eigenvalue is zero); negative-curvature directions are NEVER
1279    // square-rooted into R (their sqrt is imaginary) and are NOT null — they
1280    // are simply dropped from R, exactly as the closed-form Duchon kernels at
1281    // high d require to preserve the q_pen / q_null invariant downstream.
1282    let mut root = Array2::zeros((rank_k, block_dim));
1283    let mut positive_eigenvalues = Vec::with_capacity(rank_k);
1284    for (row_idx, &i) in classes.range_idx.iter().enumerate() {
1285        let eigenval = analysis.eigenvalues[i];
1286        let eigenvec = analysis.eigenvectors.column(i);
1287        root.row_mut(row_idx).assign(&(&eigenvec * eigenval.sqrt()));
1288        positive_eigenvalues.push(eigenval);
1289    }
1290
1291    // Surface any genuine negative curvature honestly: it is neither range
1292    // (dropped from R) nor null (excluded from `nullity`), so it would
1293    // otherwise vanish without a trace. A non-PSD penalty reaching this path
1294    // is a real geometric fact (e.g. high-d Duchon kernels) the operator
1295    // should be able to see.
1296    if classes.is_indefinite() {
1297        log::debug!(
1298            "{context}: penalty block idx={idx} carries {} negative-curvature \
1299             eigendirection(s) below -tol={tolerance:e}; dropped from the canonical \
1300             root and NOT counted as null space (rank={rank_k}, nullity={})",
1301            classes.negative_dim(),
1302            classes.nullity()
1303        );
1304    }
1305
1306    // Store the PSD reconstruction RᵀR rather than the raw symmetrised input so
1307    // the cached `local` matches the rank truncation embedded in `root`
1308    // (negative-curvature directions are excluded from both, as above).
1309    let local = root.t().dot(&root);
1310    Ok(Some(CanonicalPenalty {
1311        root,
1312        col_range,
1313        total_dim: p,
1314        nullity: classes.nullity(),
1315        local,
1316        prior_mean,
1317        positive_eigenvalues,
1318        op,
1319    }))
1320}
1321
1322/// Canonicalize a batch of penalty specs, dropping zero-rank penalties.
1323/// Returns (active_penalties, active_nullspace_dims).
1324pub fn canonicalize_penalty_specs(
1325    specs: &[crate::PenaltySpec],
1326    nullspace_dims: &[usize],
1327    p: usize,
1328    context: &str,
1329) -> Result<(Vec<CanonicalPenalty>, Vec<usize>), EstimationError> {
1330    if specs.len() != nullspace_dims.len() {
1331        crate::bail_invalid_estim!(
1332            "{context}: nullspace_dims length mismatch: penalties={}, nullspace_dims={}",
1333            specs.len(),
1334            nullspace_dims.len()
1335        );
1336    }
1337
1338    let mut active = Vec::with_capacity(specs.len());
1339    let mut active_nullspace = Vec::with_capacity(specs.len());
1340    for (idx, spec) in specs.iter().enumerate() {
1341        if let Some(canonical) = canonicalize_penalty_spec(spec, p, idx, context)? {
1342            active_nullspace.push(nullspace_dims[idx]);
1343            active.push(canonical);
1344        }
1345    }
1346    Ok((active, active_nullspace))
1347}
1348
/// Largest dimension `p` for which the overlapping balanced penalty may fall
/// back to a dense p × p eigendecomposition.
///
/// Above this limit the overlapping-penalty path returns an error rather than
/// allocating an O(p²) workspace whose eigendecomposition would dominate the
/// solve. The cap is meant to move into ResourcePolicy eventually (the
/// resource_serialize agent is widening ResourcePolicy coverage); until then
/// both overlapping branches read this one constant so they cannot drift
/// apart.
pub(crate) const OVERLAPPING_PENALTY_DENSE_FALLBACK_MAX_P: usize = 4_096;
1359
1360/// Creates a balanced penalty root from canonical penalties.
1361///
1362/// When all penalties have non-overlapping col_ranges, the balanced sum is
1363/// block-diagonal and eigendecomposition is done per-block at O(Σ p_k³)
1364/// instead of the global O(p³). Falls back to the global path when penalties
1365/// overlap.
1366pub fn create_balanced_penalty_root_from_canonical(
1367    penalties: &[CanonicalPenalty],
1368    p: usize,
1369) -> Result<Array2<f64>, EstimationError> {
1370    if penalties.is_empty() {
1371        return Ok(Array2::zeros((0, p)));
1372    }
1373
1374    // Group penalties by col_range.
1375    let mut block_groups: BTreeMap<(usize, usize), Vec<&CanonicalPenalty>> = BTreeMap::new();
1376    for cp in penalties {
1377        if cp.rank() == 0 {
1378            continue;
1379        }
1380        let key = (cp.col_range.start, cp.col_range.end);
1381        block_groups.entry(key).or_default().push(cp);
1382    }
1383
1384    if block_groups.is_empty() {
1385        return Ok(Array2::zeros((0, p)));
1386    }
1387
1388    // Check for overlapping ranges.
1389    let ranges: Vec<(usize, usize)> = block_groups.keys().copied().collect();
1390    let mut overlapping = false;
1391    for i in 1..ranges.len() {
1392        if ranges[i].0 < ranges[i - 1].1 {
1393            overlapping = true;
1394            break;
1395        }
1396    }
1397
1398    if overlapping {
1399        if p > OVERLAPPING_PENALTY_DENSE_FALLBACK_MAX_P {
1400            return Err(EstimationError::LayoutError(format!(
1401                "overlapping penalty root would require dense {}x{} eigendecomposition; \
1402                 large-model dense fallback is disabled. Keep penalties structured or \
1403                 extend the overlapping-penalty solver path",
1404                p, p
1405            )));
1406        }
1407        // Fallback: accumulate into p × p and eigendecompose globally.
1408        let mut s_balanced = Array2::zeros((p, p));
1409        for cp in penalties {
1410            if cp.rank() == 0 {
1411                continue;
1412            }
1413            let local = cp.local_ref();
1414            let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1415            if frob_norm > 1e-12 {
1416                let r = &cp.col_range;
1417                s_balanced
1418                    .slice_mut(s![r.start..r.end, r.start..r.end])
1419                    .scaled_add(1.0 / frob_norm, local);
1420            }
1421        }
1422        let (eigenvalues, eigenvectors) =
1423            robust_eigh(&s_balanced, Side::Lower, "balanced penalty matrix")?;
1424        let max_eig = eigenvalues.iter().fold(0.0f64, |max, &val| max.max(val));
1425        let tolerance = if max_eig > 0.0 {
1426            max_eig * 1e-12
1427        } else {
1428            1e-12
1429        };
1430        let penalty_rank = eigenvalues.iter().filter(|&&ev| ev > tolerance).count();
1431        if penalty_rank == 0 {
1432            return Ok(Array2::zeros((0, p)));
1433        }
1434        let mut eb = Array2::zeros((p, penalty_rank));
1435        let mut col_idx = 0;
1436        for (i, &eigenval) in eigenvalues.iter().enumerate() {
1437            if eigenval > tolerance {
1438                let sqrt_ev = eigenval.sqrt();
1439                let evec = eigenvectors.column(i);
1440                eb.column_mut(col_idx).assign(&(&evec * sqrt_ev));
1441                col_idx += 1;
1442            }
1443        }
1444        return Ok(eb.t().to_owned());
1445    }
1446
1447    // Non-overlapping: eigendecompose per block at O(Σ p_k³).
1448    struct BlockRoot {
1449        col_range: Range<usize>,
1450        root: Array2<f64>, // rank_b × block_dim
1451    }
1452    // Materialize the BTreeMap order first. Rayon preserves Vec collection
1453    // order for indexed parallel iterators, so assembly below remains stable by
1454    // ascending column range while independent block eigendecompositions run in
1455    // parallel.
1456    let ordered_blocks: Vec<((usize, usize), Vec<&CanonicalPenalty>)> =
1457        block_groups.into_iter().collect();
1458    let block_roots: Vec<BlockRoot> = ordered_blocks
1459        .into_par_iter()
1460        .map(
1461            |((start, end), cps)| -> Result<Option<BlockRoot>, EstimationError> {
1462                let block_dim = end - start;
1463                let mut s_balanced_local = Array2::zeros((block_dim, block_dim));
1464
1465                for cp in cps {
1466                    let local = cp.local_ref();
1467                    let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1468                    if frob_norm > 1e-12 {
1469                        s_balanced_local.scaled_add(1.0 / frob_norm, local);
1470                    }
1471                }
1472
1473                let (eigenvalues, eigenvectors) =
1474                    robust_eigh(&s_balanced_local, Side::Lower, "balanced penalty block")?;
1475                let max_eig = eigenvalues.iter().fold(0.0f64, |max, &val| max.max(val));
1476                let tolerance = if max_eig > 0.0 {
1477                    max_eig * 1e-12
1478                } else {
1479                    1e-12
1480                };
1481                let block_rank = eigenvalues.iter().filter(|&&ev| ev > tolerance).count();
1482
1483                if block_rank == 0 {
1484                    return Ok(None);
1485                }
1486
1487                let mut root = Array2::zeros((block_rank, block_dim));
1488                let mut row_idx = 0;
1489                for (i, &eigenval) in eigenvalues.iter().enumerate() {
1490                    if eigenval > tolerance {
1491                        let sqrt_ev = eigenval.sqrt();
1492                        let evec = eigenvectors.column(i);
1493                        root.row_mut(row_idx).assign(&(&evec * sqrt_ev));
1494                        row_idx += 1;
1495                    }
1496                }
1497
1498                Ok(Some(BlockRoot {
1499                    col_range: start..end,
1500                    root,
1501                }))
1502            },
1503        )
1504        .collect::<Result<Vec<_>, _>>()?
1505        .into_iter()
1506        .flatten()
1507        .collect();
1508    let total_rank: usize = block_roots.iter().map(|br| br.root.nrows()).sum();
1509
1510    if total_rank == 0 {
1511        return Ok(Array2::zeros((0, p)));
1512    }
1513
1514    // Assemble global balanced root: total_rank × p
1515    let mut eb = Array2::zeros((total_rank, p));
1516    let mut row_offset = 0;
1517    for br in &block_roots {
1518        let rank_b = br.root.nrows();
1519        eb.slice_mut(s![
1520            row_offset..(row_offset + rank_b),
1521            br.col_range.start..br.col_range.end
1522        ])
1523        .assign(&br.root);
1524        row_offset += rank_b;
1525    }
1526
1527    Ok(eb)
1528}
1529
1530/// Lambda-independent reparameterization invariants derived from penalty structure.
1531#[derive(Clone)]
1532struct SubspaceSplit {
1533    q_pen: Array2<f64>,
1534    q_null: Array2<f64>,
1535}
1536
1537impl SubspaceSplit {
1538    fn identity(p: usize) -> Self {
1539        Self {
1540            q_pen: Array2::zeros((p, 0)),
1541            q_null: Array2::eye(p),
1542        }
1543    }
1544
1545    fn from_ordered_qs(
1546        qs: &Mat<f64>,
1547        penalized_rank: usize,
1548        p: usize,
1549    ) -> Result<Self, EstimationError> {
1550        if qs.nrows() != p || qs.ncols() != p {
1551            return Err(EstimationError::LayoutError(format!(
1552                "Invalid Q basis dimensions: expected {p}x{p}, got {}x{}",
1553                qs.nrows(),
1554                qs.ncols()
1555            )));
1556        }
1557        if penalized_rank > p {
1558            return Err(EstimationError::LayoutError(format!(
1559                "Invalid penalized rank {penalized_rank} for p={p}"
1560            )));
1561        }
1562
1563        let null_count = p - penalized_rank;
1564        let mut q_pen = Array2::<f64>::zeros((p, penalized_rank));
1565        let mut q_null = Array2::<f64>::zeros((p, null_count));
1566        for i in 0..p {
1567            for j in 0..penalized_rank {
1568                q_pen[(i, j)] = qs[(i, j)];
1569            }
1570            for j in 0..null_count {
1571                q_null[(i, j)] = qs[(i, penalized_rank + j)];
1572            }
1573        }
1574
1575        Ok(Self { q_pen, q_null })
1576    }
1577
1578    fn rank(&self) -> usize {
1579        self.q_pen.ncols()
1580    }
1581
1582    fn p(&self) -> usize {
1583        self.q_pen.nrows()
1584    }
1585
1586    fn compose_qs(&self) -> Array2<f64> {
1587        let p = self.p();
1588        let rank = self.rank();
1589        let null_count = self.q_null.ncols();
1590        let mut qs = Array2::<f64>::zeros((p, p));
1591        for i in 0..p {
1592            for j in 0..rank {
1593                qs[(i, j)] = self.q_pen[(i, j)];
1594            }
1595            for j in 0..null_count {
1596                qs[(i, rank + j)] = self.q_null[(i, j)];
1597            }
1598        }
1599        qs
1600    }
1601}
1602
/// Lambda-independent reparameterization invariants derived from penalty structure.
///
/// Produced by `precompute_reparam_invariant_from_canonical` and consumed by
/// `stable_reparameterizationwith_invariant`.
#[derive(Clone)]
pub struct ReparamInvariant {
    /// Split of the basis into penalized columns (`q_pen`) and null-space
    /// columns (`q_null`).
    split: SubspaceSplit,
    /// The balanced eigenvector matrix Q (p x p). Block-local roots are
    /// transformed on-the-fly as `R_block @ Q[start..end, :]` instead of
    /// storing pre-multiplied full-width roots.
    qs_base: Array2<f64>,
    /// `true` when at least one penalty of nonzero rank has a block-local
    /// Frobenius norm above `1e-12`. When `false`, the split is the identity
    /// split and `max_balanced_eigenvalue` is `0.0`.
    has_nonzero: bool,
    /// Largest eigenvalue of the balanced (unit-Frobenius) penalty matrix.
    /// Used as the scale reference for the shrinkage floor.
    max_balanced_eigenvalue: f64,
}
1616
1617impl ReparamInvariant {
1618    /// Returns the largest eigenvalue of the balanced penalty matrix.
1619    /// This is lambda-independent and provides a natural scale for shrinkage.
1620    pub const fn max_balanced_eigenvalue(&self) -> f64 {
1621        self.max_balanced_eigenvalue
1622    }
1623}
1624
1625/// Precompute the lambda-invariant reparameterization structure from canonical penalties.
1626///
1627/// Uses block-local roots directly instead of requiring rank x p global roots.
1628/// Each `CanonicalPenalty` carries its own block-local root and column range,
1629/// so the balanced sum can be assembled without ever materializing full-size
1630/// penalty matrices.
1631pub fn precompute_reparam_invariant_from_canonical(
1632    penalties: &[CanonicalPenalty],
1633    p_total: usize,
1634) -> Result<ReparamInvariant, EstimationError> {
1635    use std::cmp::Ordering;
1636
1637    let m = penalties.len();
1638
1639    if m == 0 {
1640        return Ok(ReparamInvariant {
1641            split: SubspaceSplit::identity(p_total),
1642            qs_base: Array2::eye(p_total),
1643            has_nonzero: false,
1644            max_balanced_eigenvalue: 0.0,
1645        });
1646    }
1647
1648    // Group penalties by col_range to detect block-diagonal structure.
1649    struct PenRef {
1650        penalty_index: usize,
1651    }
1652    let mut block_groups: BTreeMap<(usize, usize), Vec<PenRef>> = BTreeMap::new();
1653    let mut has_nonzero = false;
1654    for (i, cp) in penalties.iter().enumerate() {
1655        if cp.rank() == 0 {
1656            continue;
1657        }
1658        let local = cp.local_ref();
1659        let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1660        if frob_norm > 1e-12 {
1661            has_nonzero = true;
1662        }
1663        let key = (cp.col_range.start, cp.col_range.end);
1664        block_groups
1665            .entry(key)
1666            .or_default()
1667            .push(PenRef { penalty_index: i });
1668    }
1669
1670    if !has_nonzero {
1671        return Ok(ReparamInvariant {
1672            split: SubspaceSplit::identity(p_total),
1673            qs_base: Array2::eye(p_total),
1674            has_nonzero: false,
1675            max_balanced_eigenvalue: 0.0,
1676        });
1677    }
1678
1679    // Check for overlapping ranges.
1680    let ranges: Vec<(usize, usize)> = block_groups.keys().copied().collect();
1681    let mut overlapping = false;
1682    for i in 1..ranges.len() {
1683        if ranges[i].0 < ranges[i - 1].1 {
1684            overlapping = true;
1685            break;
1686        }
1687    }
1688
1689    if overlapping {
1690        // Mirror the dense-fallback guard from
1691        // `create_balanced_penalty_root_from_canonical`. Without this, large-scale-
1692        // scale models with overlapping penalties allocated a full
1693        // p_total × p_total workspace and ran an O(p³) eigendecomposition
1694        // before any solver code saw the problem size.
1695        if p_total > OVERLAPPING_PENALTY_DENSE_FALLBACK_MAX_P {
1696            return Err(EstimationError::LayoutError(format!(
1697                "overlapping penalty reparameterization would require dense {}x{} eigendecomposition; \
1698                 large-model dense fallback is disabled. Keep penalties structured or \
1699                 extend the overlapping-penalty solver path",
1700                p_total, p_total
1701            )));
1702        }
1703        // Fallback: global p×p eigendecomposition.
1704        let mut s_balanced = Mat::<f64>::zeros(p_total, p_total);
1705        for cp in penalties {
1706            if cp.rank() == 0 {
1707                continue;
1708            }
1709            let local = cp.local_ref();
1710            let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1711            if frob_norm > 1e-12 {
1712                let scale = 1.0 / frob_norm;
1713                let r = &cp.col_range;
1714                for i in 0..local.nrows() {
1715                    for j in 0..local.ncols() {
1716                        s_balanced[(r.start + i, r.start + j)] += scale * local[[i, j]];
1717                    }
1718                }
1719            }
1720        }
1721
1722        let (bal_eigenvalues, bal_eigenvectors) =
1723            robust_eigh_faer(&s_balanced, Side::Lower, "balanced penalty matrix")?;
1724
1725        let mut order: Vec<usize> = (0..p_total).collect();
1726        order.sort_by(|&i, &j| {
1727            bal_eigenvalues[j]
1728                .partial_cmp(&bal_eigenvalues[i])
1729                .unwrap_or(Ordering::Equal)
1730                .then(i.cmp(&j))
1731        });
1732
1733        let mut qs = Mat::<f64>::zeros(p_total, p_total);
1734        for (col_idx, &idx) in order.iter().enumerate() {
1735            for row in 0..p_total {
1736                qs[(row, col_idx)] = bal_eigenvectors[(row, idx)];
1737            }
1738        }
1739
1740        let max_bal = order
1741            .iter()
1742            .map(|&idx| bal_eigenvalues[idx].abs())
1743            .fold(0.0_f64, f64::max);
1744        let rank_tol = if max_bal > 0.0 {
1745            max_bal * 1e-12
1746        } else {
1747            1e-12
1748        };
1749        let penalized_rank = order
1750            .iter()
1751            .take_while(|&&idx| bal_eigenvalues[idx] > rank_tol)
1752            .count();
1753        let split = SubspaceSplit::from_ordered_qs(&qs, penalized_rank, p_total)?;
1754
1755        return Ok(ReparamInvariant {
1756            split,
1757            qs_base: mat_to_array(&qs),
1758            has_nonzero,
1759            max_balanced_eigenvalue: max_bal,
1760        });
1761    }
1762
1763    // -----------------------------------------------------------------------
1764    // Non-overlapping: block-diagonal eigendecomposition at O(Σ p_k³).
1765    // -----------------------------------------------------------------------
1766    // The balanced sum is block-diagonal ⟹ its eigenvectors are block-local.
1767    // Q_pen and Q_null are assembled by embedding block-local eigenvectors.
1768
1769    // Track which columns are covered by any penalty.
1770    let mut covered = vec![false; p_total];
1771    for cp in penalties {
1772        for j in cp.col_range.clone() {
1773            covered[j] = true;
1774        }
1775    }
1776    let uncovered_cols: Vec<usize> = (0..p_total).filter(|j| !covered[*j]).collect();
1777
1778    struct BlockResult {
1779        col_range: Range<usize>,
1780        q_pen_local: Array2<f64>,  // block_dim × pen_rank
1781        q_null_local: Array2<f64>, // block_dim × null_rank
1782        /// Largest balanced eigenvalue contributed by this block.
1783        max_balanced_eigenvalue: f64,
1784        /// Column offset of this block's penalized directions within global Q_pen.
1785        pen_col_offset: usize,
1786        /// Column offset of this block's null directions within global Q_null.
1787        null_col_offset: usize,
1788    }
1789
1790    // BTreeMap iteration defines the deterministic block order; collecting the
1791    // indexed parallel iterator preserves that order while eigendecomposing each
1792    // independent canonical penalty block concurrently.
1793    let block_specs: Vec<_> = block_groups.iter().collect();
1794    let mut block_results: Vec<BlockResult> = block_specs
1795        .into_par_iter()
1796        .map(
1797            |(&(start, end), refs)| -> Result<BlockResult, EstimationError> {
1798                let block_dim = end - start;
1799
1800                // Build local balanced sum.
1801                let mut s_balanced_local = Array2::zeros((block_dim, block_dim));
1802                let mut block_has_nonzero = false;
1803                for pref in refs {
1804                    let cp = &penalties[pref.penalty_index];
1805                    let local = cp.local_ref();
1806                    let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1807                    if frob_norm > 1e-12 {
1808                        s_balanced_local.scaled_add(1.0 / frob_norm, local);
1809                        block_has_nonzero = true;
1810                    }
1811                }
1812
1813                if !block_has_nonzero {
1814                    return Ok(BlockResult {
1815                        col_range: start..end,
1816                        q_pen_local: Array2::zeros((block_dim, 0)),
1817                        q_null_local: Array2::eye(block_dim),
1818                        max_balanced_eigenvalue: 0.0,
1819                        pen_col_offset: 0,  // set later
1820                        null_col_offset: 0, // set later
1821                    });
1822                }
1823
1824                // Eigendecompose the local balanced penalty.
1825                let (bal_eigenvalues, bal_eigenvectors) =
1826                    robust_eigh(&s_balanced_local, Side::Lower, "balanced penalty block")?;
1827
1828                let mut order: Vec<usize> = (0..block_dim).collect();
1829                order.sort_by(|&i, &j| {
1830                    bal_eigenvalues[j]
1831                        .partial_cmp(&bal_eigenvalues[i])
1832                        .unwrap_or(Ordering::Equal)
1833                        .then(i.cmp(&j))
1834                });
1835
1836                let max_bal = order
1837                    .iter()
1838                    .map(|&idx| bal_eigenvalues[idx].abs())
1839                    .fold(0.0_f64, f64::max);
1840                let rank_tol = if max_bal > 0.0 {
1841                    max_bal * 1e-12
1842                } else {
1843                    1e-12
1844                };
1845                let penalized_rank = order
1846                    .iter()
1847                    .take_while(|&&idx| bal_eigenvalues[idx] > rank_tol)
1848                    .count();
1849                let null_count = block_dim - penalized_rank;
1850
1851                let mut q_pen_local = Array2::zeros((block_dim, penalized_rank));
1852                let mut q_null_local = Array2::zeros((block_dim, null_count));
1853                for (col_idx, &idx) in order.iter().enumerate() {
1854                    if col_idx < penalized_rank {
1855                        for row in 0..block_dim {
1856                            q_pen_local[[row, col_idx]] = bal_eigenvectors[[row, idx]];
1857                        }
1858                    } else {
1859                        let null_col = col_idx - penalized_rank;
1860                        for row in 0..block_dim {
1861                            q_null_local[[row, null_col]] = bal_eigenvectors[[row, idx]];
1862                        }
1863                    }
1864                }
1865
1866                Ok(BlockResult {
1867                    col_range: start..end,
1868                    q_pen_local,
1869                    q_null_local,
1870                    max_balanced_eigenvalue: max_bal,
1871                    pen_col_offset: 0,  // set later
1872                    null_col_offset: 0, // set later
1873                })
1874            },
1875        )
1876        .collect::<Result<_, _>>()?;
1877    let global_max_bal = block_results
1878        .iter()
1879        .map(|br| br.max_balanced_eigenvalue)
1880        .fold(0.0_f64, f64::max);
1881
1882    // Compute column offsets for each block in the global Q_pen / Q_null layout.
1883    let total_pen_rank: usize = block_results.iter().map(|br| br.q_pen_local.ncols()).sum();
1884    let total_null: usize = block_results
1885        .iter()
1886        .map(|br| br.q_null_local.ncols())
1887        .sum::<usize>()
1888        + uncovered_cols.len();
1889    {
1890        let mut pen_off = 0usize;
1891        let mut null_off = 0usize;
1892        for br in &mut block_results {
1893            br.pen_col_offset = pen_off;
1894            br.null_col_offset = null_off;
1895            pen_off += br.q_pen_local.ncols();
1896            null_off += br.q_null_local.ncols();
1897        }
1898    }
1899
1900    let mut q_pen = Array2::zeros((p_total, total_pen_rank));
1901    let mut q_null = Array2::zeros((p_total, total_null));
1902
1903    for br in &block_results {
1904        let start = br.col_range.start;
1905        let bd = br.q_pen_local.nrows();
1906        let pen_r = br.q_pen_local.ncols();
1907        let null_r = br.q_null_local.ncols();
1908        if pen_r > 0 {
1909            q_pen
1910                .slice_mut(s![
1911                    start..(start + bd),
1912                    br.pen_col_offset..(br.pen_col_offset + pen_r)
1913                ])
1914                .assign(&br.q_pen_local);
1915        }
1916        if null_r > 0 {
1917            q_null
1918                .slice_mut(s![
1919                    start..(start + bd),
1920                    br.null_col_offset..(br.null_col_offset + null_r)
1921                ])
1922                .assign(&br.q_null_local);
1923        }
1924    }
1925    let mut null_col = block_results
1926        .iter()
1927        .map(|br| br.q_null_local.ncols())
1928        .sum::<usize>();
1929    for &j in &uncovered_cols {
1930        q_null[[j, null_col]] = 1.0;
1931        null_col += 1;
1932    }
1933
1934    let split = SubspaceSplit { q_pen, q_null };
1935
1936    // Store the global Q_s = [Q_pen | Q_null] from the split.
1937    // Block-local roots are transformed on-the-fly as R_block @ Q[start..end, :]
1938    // inside the reparam engine, avoiding O(k * rank * p) storage.
1939    let qs_global = split.compose_qs();
1940
1941    Ok(ReparamInvariant {
1942        split,
1943        qs_base: qs_global,
1944        has_nonzero,
1945        max_balanced_eigenvalue: global_max_bal,
1946    })
1947}
1948
1949fn structurally_penalized_columns(penalties: &[CanonicalPenalty], p: usize) -> Vec<bool> {
1950    let mut active = vec![false; p];
1951    for cp in penalties {
1952        let local = cp.local_ref();
1953        let scale = local.iter().map(|&v| v.abs()).fold(0.0_f64, f64::max);
1954        if scale <= 0.0 {
1955            continue;
1956        }
1957        let tol = scale * 1e-12;
1958        for local_col in 0..cp.block_dim() {
1959            let mut column_active = false;
1960            for row in 0..cp.block_dim() {
1961                if local[[row, local_col]].abs() > tol || local[[local_col, row]].abs() > tol {
1962                    column_active = true;
1963                    break;
1964                }
1965            }
1966            if column_active {
1967                active[cp.col_range.start + local_col] = true;
1968            }
1969        }
1970    }
1971    active
1972}
1973
1974/// Apply stable reparameterization using precomputed lambda-invariant structures.
1975///
1976/// `penalty_shrinkage_floor`: optional relative shrinkage floor for eigenvalues
1977/// of the penalized block. If `Some(epsilon)`, a rho-independent ridge of
1978/// magnitude `epsilon * max_balanced_eigenvalue` is added to each eigenvalue
1979/// of the combined penalty on the penalized block. This prevents barely-penalized
1980/// directions from causing pathological non-Gaussianity in the posterior (e.g.,
1981/// extreme skewness under logit link with high-dimensional spatial smooths).
1982/// A typical value is `1e-6`. Set to `None` or `Some(0.0)` to disable.
1983pub fn stable_reparameterizationwith_invariant(
1984    penalties: &[CanonicalPenalty],
1985    lambdas: &[f64],
1986    p: usize,
1987    invariant: &ReparamInvariant,
1988    penalty_shrinkage_floor: Option<f64>,
1989) -> Result<ReparamResult, EstimationError> {
1990    let m = penalties.len();
1991
1992    if lambdas.len() != m {
1993        return Err(EstimationError::ParameterConstraintViolation(format!(
1994            "Lambda count mismatch: expected {} lambdas for {} penalties, got {}",
1995            m,
1996            m,
1997            lambdas.len()
1998        )));
1999    }
2000
2001    // No separate length check needed — penalties are matched against lambdas above,
2002    // and the invariant's qs_base is p x p (dimension-checked by the split).
2003
2004    // #1074: the gam#1379 finite-ceiling on λ_k = exp(ρ_k) (clamp to 1e300 to
2005    // avoid `∞·0 = NaN` when the outer optimizer drives a redundant penalty
2006    // direction's log-λ past ~709) was DELETED. It masked the real defect: the
2007    // optimizer drives a redundant/unidentified penalty direction off to ∞
2008    // instead of that direction being detected and dropped from the model.
2009    // The root fix (detect+drop the redundant penalty direction at construction)
2010    // is tracked separately; λ now passes through raw.
2011
2012    if m == 0 {
2013        return Ok(ReparamResult {
2014            s_transformed: Array2::zeros((p, p)),
2015            log_det: 0.0,
2016            det1: Array1::zeros(0),
2017            qs: Array2::eye(p),
2018            canonical_transformed: vec![],
2019            e_transformed: Array2::zeros((0, p)),
2020            // All modes truncated when no penalties; already in transformed frame.
2021            u_truncated: Array2::eye(p),
2022            penalty_shrinkage_ridge: 0.0,
2023        });
2024    }
2025
2026    if !invariant.has_nonzero {
2027        let qs = invariant.split.compose_qs();
2028        let u_truncated = qs.t().dot(&invariant.split.q_null);
2029        // All penalties are zero — canonical_transformed = originals (no rotation needed).
2030        let canonical_transformed: Vec<CanonicalPenalty> = penalties.to_vec();
2031        return Ok(ReparamResult {
2032            s_transformed: Array2::zeros((p, p)),
2033            log_det: 0.0,
2034            det1: Array1::zeros(m),
2035            qs,
2036            canonical_transformed,
2037            e_transformed: Array2::zeros((0, p)),
2038            u_truncated,
2039            penalty_shrinkage_ridge: 0.0,
2040        });
2041    }
2042
2043    let q_pen = array_to_faer(&invariant.split.q_pen);
2044    let q_null = array_to_faer(&invariant.split.q_null);
2045    let qs_base = array_to_faer(&invariant.qs_base);
2046    // Each penalty root transform is independent: R_k_block @ Q[start..end, :].
2047    // Run those per-penalty products (and their S_k = R_k'R_k caches) in
2048    // parallel, then collect in slice order so all downstream accumulation stays
2049    // deterministic and bit-for-bit stable with respect to penalty ordering.
2050    let penalty_transforms: Vec<(Mat<f64>, Mat<f64>)> = penalties
2051        .par_iter()
2052        .map(|cp| {
2053            let r = &cp.col_range;
2054            let root_faer = array_to_faer(&cp.root);
2055            let q_block = qs_base.submatrix(r.start, 0, cp.block_dim(), p);
2056            let mut product = Mat::<f64>::zeros(cp.rank(), p);
2057            matmul(
2058                product.as_mut(),
2059                Accum::Replace,
2060                root_faer.as_ref(),
2061                q_block,
2062                1.0,
2063                Par::Seq,
2064            );
2065            let s_k = penalty_from_root_faer(&product);
2066            (product, s_k)
2067        })
2068        .collect();
2069    let (rs_transformed, s_k_penalized_cache): (Vec<Mat<f64>>, Vec<Mat<f64>>) =
2070        penalty_transforms.into_iter().unzip();
2071
2072    let penalized_rank = invariant.split.rank();
2073
2074    let mut range_eigenvalues_sorted: Vec<f64> = Vec::new();
2075    let mut range_rotation = Mat::<f64>::zeros(penalized_rank, penalized_rank);
2076    if penalized_rank > 0 {
2077        let mut range_block = Mat::<f64>::zeros(penalized_rank, penalized_rank);
2078        // Deterministic assembly: the independent S_k transforms were computed in
2079        // parallel above, but the lambda-weighted sum is accumulated serially in
2080        // canonical penalty order to avoid order-dependent floating-point drift.
2081        for (lambda, s_k) in lambdas.iter().zip(s_k_penalized_cache.iter()) {
2082            for i in 0..penalized_rank {
2083                for j in 0..penalized_rank {
2084                    range_block[(i, j)] += *lambda * s_k[(i, j)];
2085                }
2086            }
2087        }
2088        let (range_eigenvalues, range_eigenvectors) =
2089            robust_eigh_faer(&range_block, Side::Lower, "range penalty block")?;
2090
2091        let mut range_order: Vec<usize> = (0..penalized_rank).collect();
2092        range_order.sort_by(|&i, &j| {
2093            range_eigenvalues[j]
2094                .partial_cmp(&range_eigenvalues[i])
2095                .unwrap_or(std::cmp::Ordering::Equal)
2096                .then(i.cmp(&j))
2097        });
2098        range_eigenvalues_sorted = range_order
2099            .iter()
2100            .map(|&idx| range_eigenvalues[idx])
2101            .collect();
2102
2103        // Build range_rotation = U (sorted eigenvectors) for E and S⁺
2104        // construction only.  DO NOT apply to q_pen or rs_transformed —
2105        // keeping Q_s lambda-independent prevents BFGS coordinate-system
2106        // drift when multiple penalties interact (the eigenvectors of
2107        // Σ λ_k S_k rotate with λ, breaking the quasi-Newton Hessian
2108        // approximation at eigenvalue crossings).
2109        for (col_idx, &idx) in range_order.iter().enumerate() {
2110            for row in 0..penalized_rank {
2111                range_rotation[(row, col_idx)] = range_eigenvectors[(row, idx)];
2112            }
2113        }
2114        // q_pen and rs_transformed stay in the lambda-independent
2115        // invariant basis.  E and S⁺ below are expressed in this same
2116        // basis using U from the eigendecomposition.
2117    }
2118
2119    // Subspace-invariant penalty spectral calculus:
2120    // - Penalized and null spaces are fixed by the lambda-invariant basis `qs_base`.
2121    // - Runtime lambda dependence only appears in the penalized block eigenvalues.
2122    // This avoids basis mixing inside the degenerate zero-eigenspace.
2123    let structural_rank = penalized_rank;
2124    let mut range_eigs_sorted: Vec<f64> = range_eigenvalues_sorted;
2125    let structurally_penalized_cols = structurally_penalized_columns(penalties, p);
2126
2127    // Shrinkage floor: add a rho-independent ridge to the penalized block eigenvalues.
2128    // This prevents barely-penalized directions from causing pathological non-Gaussianity
2129    // in the posterior (extreme skewness under non-canonical links like logit with
2130    // high-dimensional spatial smooths). The ridge magnitude is proportional to the
2131    // balanced penalty's max eigenvalue (lambda-independent scale), so LAML gradients
2132    // w.r.t. rho remain correct: d(epsilon * I)/d(rho_k) = 0.
2133    //
2134    // The shrinkage ridge is a real prior contribution: it changes the quadratic
2135    // form, the penalty pseudo-logdet, and downstream Hessians. It is currently
2136    // surfaced through `ReparamResult::penalty_shrinkage_ridge` and consumed by
2137    // PIRLS/REML via that per-call channel. The longer-term home is a
2138    // `RidgePassport` with `RidgePolicy::explicit_stabilization_full` scoped to
2139    // the penalized block so the same delta is reflected in serialization, the
2140    // Laplace Hessian, and the prior logdet without callers having to thread an
2141    // additional scalar — once the ridge-ledger plumbing covers per-block
2142    // ScaledIdentity passports.
2143    let shrinkage_ridge = penalty_shrinkage_floor
2144        .filter(|&eps| eps > 0.0)
2145        .map(|eps| eps * invariant.max_balanced_eigenvalue)
2146        .unwrap_or(0.0);
2147    if shrinkage_ridge > 0.0 {
2148        let min_eig_before = range_eigs_sorted
2149            .iter()
2150            .copied()
2151            .fold(f64::INFINITY, f64::min);
2152        let mut shrinkage_floor_applied = 0usize;
2153        for eig_idx in 0..range_eigs_sorted.len() {
2154            let mut penalized_energy = 0.0;
2155            for original_col in 0..p {
2156                if structurally_penalized_cols[original_col] {
2157                    let mut coordinate = 0.0;
2158                    for pen_col in 0..penalized_rank {
2159                        coordinate +=
2160                            q_pen[(original_col, pen_col)] * range_rotation[(pen_col, eig_idx)];
2161                    }
2162                    penalized_energy += coordinate * coordinate;
2163                }
2164            }
2165            if penalized_energy > 1e-8 {
2166                range_eigs_sorted[eig_idx] += shrinkage_ridge;
2167                shrinkage_floor_applied += 1;
2168            }
2169        }
2170        // Log when the floor materially changes the smallest eigenvalue (>1% relative shift).
2171        if min_eig_before > 0.0 && shrinkage_ridge / min_eig_before > 0.01 {
2172            log::debug!(
2173                "Penalty shrinkage floor active: ridge={:.3e} (min_eig_before={:.3e}, ratio={:.1e}, max_bal_eig={:.3e}, applied_dirs={})",
2174                shrinkage_ridge,
2175                min_eig_before,
2176                shrinkage_ridge / min_eig_before,
2177                invariant.max_balanced_eigenvalue,
2178                shrinkage_floor_applied,
2179            );
2180        }
2181    }
2182
2183    let eigenvalue_floor = invariant.max_balanced_eigenvalue.max(1.0) * 1e-12;
2184    let qs = compose_qs_from_split(&q_pen, &q_null, p);
2185
2186    // Guard against any accidental penalized/null mixing. The transformed penalty
2187    // roots must have negligible support on null columns by construction.
2188    let leakage = assess_subspace_leakage(&qs, &rs_transformed, structural_rank, p);
2189    let leakage_rel_tol = 1e-10;
2190    let leakage_abs_tol = 1e-12;
2191    let orth_tol = 1e-10;
2192    if leakage.max_rel_sq > leakage_rel_tol && leakage.max_abs_sq > leakage_abs_tol
2193        || leakage.max_cross_gram_abs > orth_tol
2194    {
2195        return Err(EstimationError::LayoutError(format!(
2196            "Reparameterization subspace split is inconsistent: max null leakage {:.3e} (rel {:.3e}, worst penalty {}), max |Qp'Qn| {:.3e}",
2197            leakage.max_abs_sq.sqrt(),
2198            leakage.max_rel_sq.sqrt(),
2199            leakage.worst_penalty,
2200            leakage.max_cross_gram_abs,
2201        )));
2202    }
2203
2204    // Truncated basis in transformed coordinates:
2205    //   U_⊥^(t) = Qs^T U_⊥^(orig) = Qs^T Q_n.
2206    let mut u_truncated_mat = Mat::<f64>::zeros(p, q_null.ncols());
2207    matmul(
2208        u_truncated_mat.as_mut(),
2209        Accum::Replace,
2210        qs.transpose(),
2211        q_null.as_ref(),
2212        1.0,
2213        Par::Seq,
2214    );
2215
2216    // E is represented in TRANSFORMED coordinates (beta_t).  Because the
2217    // penalized subspace is NOT rotated by the lambda-dependent eigenvectors
2218    // (to keep Q_s stable across BFGS iterations), E is no longer diagonal.
2219    // Instead E = diag(√d) · U' embedded in structural_rank × p, so that
2220    // E'E = U diag(d) U' = Σ λ_k S_k in the invariant penalized basis.
2221    let mut e_transformed_mat = Mat::<f64>::zeros(structural_rank, p);
2222    for row_idx in 0..structural_rank {
2223        let safe_eigenval = range_eigs_sorted[row_idx].max(eigenvalue_floor);
2224        let sqrt_eigenval = safe_eigenval.sqrt();
2225        // E[row, j] = sqrt(d_row) * U'[row, j] = sqrt(d_row) * U[j, row]
2226        for j in 0..penalized_rank {
2227            e_transformed_mat[(row_idx, j)] = sqrt_eigenval * range_rotation[(j, row_idx)];
2228        }
2229    }
2230
2231    // Pseudo-logdet on the structural penalized block.  The null block is split
2232    // out above, so there is no nullspace normalization here.  Spectrally-noisy
2233    // directions (eigenvalues snapped to 0 by `classify_eigenvalues_strict`
2234    // because they fall below the `c * eps_machine * p * scale` tolerance in
2235    // the lambda-weighted sum) are floored to `eigenvalue_floor` to keep the
2236    // log-det finite and consistent with the floored values used to construct
2237    // `e_transformed_mat` above.  This avoids spurious P-IRLS failures when the
2238    // lambda dynamic range is wide (e.g. during BFGS line search probing extreme
2239    // rho candidates).  Materially negative or non-finite spectra are already
2240    // rejected by the strict classifier upstream; this loop re-checks the
2241    // *post-shrinkage* range eigenvalues against the same floor.
2242    //
2243    // The same floored spectrum is used in the trace formula tr(S⁺ S_k) below,
2244    // matching the rank structure embedded in `e_transformed_mat` and avoiding
2245    // a 1/0 in the trace contraction when an eigenvalue was floored to 0.
2246    let mut floored_eigs: Vec<f64> = Vec::with_capacity(range_eigs_sorted.len());
2247    let mut log_det_sum = KahanSum::default();
2248    for (idx, &ev) in range_eigs_sorted.iter().enumerate() {
2249        if !ev.is_finite() || ev < -eigenvalue_floor {
2250            return Err(EstimationError::LayoutError(format!(
2251                "Penalty pseudo-logdet has a non-finite or large-negative structural eigenvalue at index {idx}: {ev:.3e}"
2252            )));
2253        }
2254        let safe_ev = ev.max(eigenvalue_floor);
2255        floored_eigs.push(safe_ev);
2256        if idx < penalized_rank {
2257            log_det_sum.add(safe_ev.ln());
2258        }
2259    }
2260    let log_det = log_det_sum.sum();
2261    let delta = 0.0;
2262
2263    // The det1 contractions are independent once the eigensystem is fixed.  Use
2264    // indexed parallel collection so the output vector preserves lambda order.
2265    let det1vec: Vec<f64> = (0..lambdas.len())
2266        .into_par_iter()
2267        .map(|k| {
2268            let s_k = &s_k_penalized_cache[k];
2269            // Compute tr((S+δI)⁻¹ S_k) in the range eigenbasis without ever
2270            // materializing (S+δI)⁻¹. Using faer's matmul keeps this contraction
2271            // aligned with the orthogonal-similarity debug reference path.
2272            let trace = trace_penalty_in_orthogonal_basis(
2273                s_k,
2274                penalized_rank,
2275                &range_rotation,
2276                &floored_eigs,
2277                delta,
2278            );
2279            lambdas[k] * trace
2280        })
2281        .collect();
2282
2283    {
2284        // Guardrail: cross-check the primary Rayleigh-quotient contraction
2285        // against a full orthogonal similarity transform, while staying in
2286        // the same numerically stable eigenbasis coordinates.
2287        let mut maxdet1_mismatch = 0.0_f64;
2288        let mut det1_scale = 0.0_f64;
2289        for (k, lambda) in lambdas.iter().enumerate() {
2290            let s_k_penalized = &s_k_penalized_cache[k];
2291            let s_k_eigenbasis = orthogonal_similarity_transform_faer(
2292                s_k_penalized,
2293                penalized_rank,
2294                &range_rotation,
2295            );
2296            let mut trace = KahanSum::default();
2297            for l in 0..penalized_rank {
2298                trace.add(s_k_eigenbasis[(l, l)] / (floored_eigs[l] + delta));
2299            }
2300            let reference = *lambda * trace.sum();
2301            maxdet1_mismatch = maxdet1_mismatch.max((reference - det1vec[k]).abs());
2302            det1_scale = det1_scale.max(reference.abs()).max(det1vec[k].abs());
2303        }
2304        let det1_tolerance = 1e-7 * det1_scale.max(1.0);
2305        assert!(
2306            maxdet1_mismatch <= det1_tolerance,
2307            "det1 mismatch between optimized and reference formulas: max_abs={maxdet1_mismatch:.3e}, tol={det1_tolerance:.3e}"
2308        );
2309    }
2310
2311    // Rebuild s_transformed from e_transformed to ensure rank consistency.
2312    //
2313    // The sum of λ*S_k may contain numerical noise modes (eigenvalues ~1e-15) that
2314    // become significant when λ is large (e.g., 10^12). These modes would appear in H
2315    // but are truncated from log|S|_+, creating a "phantom penalty" in the objective.
2316    //
2317    // By reconstructing s_transformed = E^T * E, we force the penalty matrix used
2318    // in H to have the EXACT same rank structure as the one used for log|S|_+.
2319    // Any mode truncated from the prior is now strictly zero in the Hessian
2320    // calculation, ensuring mathematical consistency of the gradients.
2321    let mut s_truncated = Mat::<f64>::zeros(p, p);
2322    matmul(
2323        s_truncated.as_mut(),
2324        Accum::Replace,
2325        e_transformed_mat.transpose(),
2326        e_transformed_mat.as_ref(),
2327        1.0,
2328        Par::Seq,
2329    );
2330
2331    {
2332        // Structural check: transformed S must not leak into declared null coordinates.
2333        let mut max_null_diag = 0.0_f64;
2334        let mut max_null_offdiag = 0.0_f64;
2335        for i in structural_rank..p {
2336            max_null_diag = max_null_diag.max(s_truncated[(i, i)].abs());
2337            for j in 0..p {
2338                if i != j {
2339                    max_null_offdiag = max_null_offdiag.max(s_truncated[(i, j)].abs());
2340                }
2341            }
2342        }
2343        assert!(
2344            max_null_diag <= 1e-10 && max_null_offdiag <= 1e-10,
2345            "null-space leakage in transformed penalty: max_null_diag={max_null_diag:.3e}, max_null_offdiag={max_null_offdiag:.3e}"
2346        );
2347    }
2348
2349    let qs_array = mat_to_array(&qs);
2350    let canonical_transformed: Vec<CanonicalPenalty> = rs_transformed
2351        .par_iter()
2352        .zip(penalties.par_iter())
2353        .map(|(r, cp)| {
2354            let mean_transformed = qs_array.t().dot(&cp.full_width_prior_mean());
2355            CanonicalPenalty::from_dense_root_with_mean(mat_to_array(r), p, mean_transformed)
2356        })
2357        .collect();
2358    Ok(ReparamResult {
2359        s_transformed: mat_to_array(&s_truncated),
2360        log_det,
2361        det1: Array1::from(det1vec),
2362        qs: qs_array,
2363        canonical_transformed,
2364        e_transformed: mat_to_array(&e_transformed_mat),
2365        u_truncated: mat_to_array(&u_truncated_mat),
2366        penalty_shrinkage_ridge: shrinkage_ridge,
2367    })
2368}
2369
/// Minimal engine layout descriptor that avoids domain-specific layout coupling.
///
/// A plain, copyable pair of sizes; two descriptors compare equal exactly when
/// both sizes match.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EngineDims {
    /// Width forwarded as `p` to the reparameterization routines by
    /// `stable_reparameterization_engine_canonical`.
    pub p: usize,
    /// Second layout size.
    /// NOTE(review): presumably the number of penalties / smoothing
    /// parameters — confirm against callers.
    pub k: usize,
}

impl EngineDims {
    /// Bundle the two layout sizes into a descriptor.
    pub fn new(p: usize, k: usize) -> Self {
        EngineDims { p, k }
    }
}
2382
2383/// Engine-facing stable reparameterization API using only `(p, k)`.
2384///
2385/// When `cached_invariant` is `Some`, reuses the precomputed eigendecomposition
2386/// (the hot path inside the REML loop). When `None`, computes the invariant on
2387/// the fly (the post-REML refit path). Merging both cases into a single entry
2388/// point ensures `penalty_shrinkage_floor` is always applied regardless of
2389/// whether a cached invariant is available.
2390/// Stable reparameterization from block-local canonical penalties.
2391pub fn stable_reparameterization_engine_canonical(
2392    penalties: &[CanonicalPenalty],
2393    lambdas: &[f64],
2394    dims: EngineDims,
2395    cached_invariant: Option<&ReparamInvariant>,
2396    penalty_shrinkage_floor: Option<f64>,
2397) -> Result<ReparamResult, EstimationError> {
2398    let owned;
2399    let invariant = match cached_invariant {
2400        Some(inv) => inv,
2401        None => {
2402            owned = precompute_reparam_invariant_from_canonical(penalties, dims.p)?;
2403            &owned
2404        }
2405    };
2406    stable_reparameterizationwith_invariant(
2407        penalties,
2408        lambdas,
2409        dims.p,
2410        invariant,
2411        penalty_shrinkage_floor,
2412    )
2413}
2414
2415// ---------------------------------------------------------------------------
2416// Kronecker-factored reparameterization for tensor-product smooths
2417// ---------------------------------------------------------------------------
2418
/// Result of Kronecker-factored reparameterization.
///
/// Exploits the fact that for Kronecker-structured penalties, the joint
/// eigenvector matrix is `U_1 ⊗ ... ⊗ U_d` and the reparameterized design
/// is a rowwise Kronecker of `(B_k U_k)` — all remaining factored.
///
/// The three `Arc` fields are shared with the λ-invariant structure the result
/// was built from; only the log-determinant quantities and the ridge depend on
/// the smoothing parameters.
#[derive(Clone)]
pub struct KroneckerReparamResult {
    /// Reparameterized marginal designs: `B_k · U_k` for each marginal k.
    ///
    /// `Arc`-shared with the λ-invariant cache so the per-outer-iterate
    /// memoized engine bumps a refcount instead of deep-copying the
    /// (n × q) reparameterized marginals every call.
    pub reparameterized_marginals: Arc<Vec<Array2<f64>>>,
    /// Marginal eigenvalues from each marginal penalty eigendecomposition.
    pub marginal_eigenvalues: Arc<Vec<Array1<f64>>>,
    /// Marginal eigenvector matrices U_k.
    pub marginal_qs: Arc<Vec<Array2<f64>>>,
    /// log|S|₊ computed from marginal eigenvalue grid.
    pub log_det: f64,
    /// First derivatives of log|S|₊ w.r.t. ρ_k = log(λ_k); one entry per
    /// marginal, plus a trailing entry when `has_double_penalty` is set.
    pub det1: Array1<f64>,
    /// Second derivatives of log|S|₊ w.r.t. ρ, indexed like `det1` on both axes.
    pub det2: Array2<f64>,
    /// Shrinkage ridge added to the eigenvalues of structurally penalized
    /// cells; `0.0` when no shrinkage floor was requested.
    pub penalty_shrinkage_ridge: f64,
    /// Whether a double penalty (global ridge) is present.
    pub has_double_penalty: bool,
    /// Marginal basis dimensions.
    pub marginal_dims: Vec<usize>,
}
2449
2450impl KroneckerReparamResult {
2451    /// Materialize the joint Qs matrix (U_1 ⊗ ... ⊗ U_d) as dense p×p.
2452    /// Only for fallback paths — avoid in hot loops.
2453    pub fn materialize_qs(&self) -> Array2<f64> {
2454        let mut qs = Array2::<f64>::eye(1);
2455        for u_k in self.marginal_qs.iter() {
2456            qs = kronecker_product(&qs, u_k);
2457        }
2458        qs
2459    }
2460
2461    /// Materialize s_transformed (the penalty in the reparameterized basis).
2462    /// In the eigenbasis, this is diagonal with entries Σ_k λ_k μ_{k,j_k}.
2463    pub fn materialize_s_transformed(&self, lambdas: &[f64]) -> Array2<f64> {
2464        let d = self.marginal_dims.len();
2465        let p: usize = self.marginal_dims.iter().copied().product();
2466        let mut s = Array2::<f64>::zeros((p, p));
2467
2468        // Delegate the per-cell tensor-penalty accumulation to the shared
2469        // `kronecker_cell_sigma` (#1172/#1185 single source of truth). Fold the
2470        // `lambdas.len() > d` guard into `has_double` to preserve exact gating.
2471        let eigenvalue_views: Vec<ArrayView1<'_, f64>> =
2472            self.marginal_eigenvalues.iter().map(|m| m.view()).collect();
2473        let has_double = self.has_double_penalty && lambdas.len() > d;
2474        let mut multi_idx = vec![0usize; d];
2475        let mut flat = 0usize;
2476        loop {
2477            let (sigma, _structural_sigma, _joint_null) = kronecker_cell_sigma(
2478                &eigenvalue_views,
2479                &multi_idx,
2480                lambdas,
2481                d,
2482                has_double,
2483                self.penalty_shrinkage_ridge,
2484            );
2485            s[[flat, flat]] = sigma;
2486            flat += 1;
2487
2488            if kronecker_multi_index_advance(&mut multi_idx, &self.marginal_dims) {
2489                break;
2490            }
2491        }
2492        s
2493    }
2494
2495    /// Explicitly materialize the dense artifact bundle expected by legacy
2496    /// downstream consumers. This is not part of the native Kronecker solve path.
2497    pub fn materialize_dense_artifact_result(
2498        &self,
2499        rs_list: &[Array2<f64>],
2500        lambdas: &[f64],
2501        p: usize,
2502    ) -> Result<ReparamResult, EstimationError> {
2503        const KRONECKER_DENSE_COMPAT_FALLBACK_MAX_P: usize = 4096;
2504        if p > KRONECKER_DENSE_COMPAT_FALLBACK_MAX_P {
2505            return Err(EstimationError::LayoutError(format!(
2506                "Kronecker reparameterization would materialize dense {}x{} compatibility tensors; \
2507                 large-model dense fallback is disabled. Wire the downstream solver to consume \
2508                 the factored Kronecker result directly",
2509                p, p
2510            )));
2511        }
2512        let qs = self.materialize_qs();
2513        let s_transformed = self.materialize_s_transformed(lambdas);
2514
2515        // Transform penalty roots: R_k_transformed = R_k · Qs
2516        let rs_transformed: Vec<Array2<f64>> = if rs_list.len() >= 2 {
2517            use rayon::prelude::*;
2518            rs_list
2519                .par_iter()
2520                .map(|r| gam_linalg::faer_ndarray::fast_ab(r, &qs))
2521                .collect()
2522        } else {
2523            rs_list
2524                .iter()
2525                .map(|r| gam_linalg::faer_ndarray::fast_ab(r, &qs))
2526                .collect()
2527        };
2528        // rs_transposed removed — canonical_transformed is the single source of truth.
2529
2530        // Build e_transformed: combined penalty square root in transformed coords.
2531        // For Kronecker structure, the penalty is diagonal in the eigenbasis.
2532        // e_transformed rows are the nonzero rows of sqrt(Σ_k λ_k S_k)^{1/2}.
2533        let d = self.marginal_dims.len();
2534        // Delegate the per-cell tensor-penalty accumulation to the shared
2535        // `kronecker_cell_sigma` (the #1172/#1185 single source of truth). The
2536        // double-penalty term is only valid when `lambdas` actually carries the
2537        // λ_d entry, so fold the original `lambdas.len() > d` guard into the
2538        // `has_double_penalty` flag passed to the helper — preserving the exact
2539        // gating behavior.
2540        let eigenvalue_views: Vec<ArrayView1<'_, f64>> =
2541            self.marginal_eigenvalues.iter().map(|m| m.view()).collect();
2542        let has_double = self.has_double_penalty && lambdas.len() > d;
2543        let diag_vals: Vec<f64> = {
2544            let mut vals = Vec::with_capacity(p);
2545            let mut multi_idx = vec![0usize; d];
2546            loop {
2547                let (sigma, _structural_sigma, _joint_null) = kronecker_cell_sigma(
2548                    &eigenvalue_views,
2549                    &multi_idx,
2550                    lambdas,
2551                    d,
2552                    has_double,
2553                    self.penalty_shrinkage_ridge,
2554                );
2555                vals.push(if sigma > 0.0 { sigma.sqrt() } else { 0.0 });
2556
2557                if kronecker_multi_index_advance(&mut multi_idx, &self.marginal_dims) {
2558                    break;
2559                }
2560            }
2561            vals
2562        };
2563        let rank = diag_vals.iter().filter(|&&v| v > 1e-12).count();
2564        let mut e_transformed = Array2::<f64>::zeros((rank, p));
2565        let mut row = 0;
2566        for (j, &v) in diag_vals.iter().enumerate() {
2567            if v > 1e-12 {
2568                e_transformed[[row, j]] = v;
2569                row += 1;
2570            }
2571        }
2572
2573        // u_truncated: null-space eigenvectors (columns with zero eigenvalue).
2574        let null_count = p - rank;
2575        let mut u_truncated = Array2::<f64>::zeros((p, null_count));
2576        let mut col = 0;
2577        for (j, &v) in diag_vals.iter().enumerate() {
2578            if v <= 1e-12 {
2579                u_truncated[[j, col]] = 1.0; // standard basis vector in eigenbasis
2580                col += 1;
2581            }
2582        }
2583
2584        let canonical_transformed: Vec<CanonicalPenalty> = rs_transformed
2585            .iter()
2586            .map(|r| CanonicalPenalty::from_dense_root(r.clone(), p))
2587            .collect();
2588        Ok(ReparamResult {
2589            s_transformed,
2590            log_det: self.log_det,
2591            det1: self.det1.clone(),
2592            qs,
2593            canonical_transformed,
2594            e_transformed,
2595            u_truncated,
2596            penalty_shrinkage_ridge: self.penalty_shrinkage_ridge,
2597        })
2598    }
2599}
2600
/// Absolute tolerance at or below which a grid cell's unweighted marginal
/// eigenvalue sum `Σ_k μ_k` is treated as structurally zero.
///
/// `kronecker_cell_sigma` marks a cell as joint-null when the sum is `<=` this
/// value and adds the shrinkage ridge only when the sum is `>` it.
//
// NOTE(review): the doc comment previously attached here described the
// `log|S|₊` sweep (`kronecker_logdet_and_derivatives`: computes `log|S|₊` and
// its first/second derivatives w.r.t. `ρ_k = log(λ_k)` from factored marginal
// eigenvalues, shared by `KroneckerPenaltySystem::logdet_and_derivatives` and
// `kronecker_reparameterization_engine`, iterating over the ∏q_j multi-index
// grid with no O(p²) storage) — not this constant.
const KRONECKER_STRUCTURAL_ZERO_TOL: f64 = 1e-12;
2608
2609/// Per-cell Kronecker eigenvalue accumulation — the single source of truth for
2610/// the #1172/#1185 tensor-penalty math.
2611///
2612/// For the multi-index cell `multi_idx`, accumulates:
2613///   - `sigma`            = Σ_k λ_k · μ_k  (+ joint-null double-penalty term + ridge)
2614///   - `structural_sigma` = Σ_k μ_k        (unweighted; classifies joint-null cells)
2615///   - `joint_null`       = whether the cell lies in the joint null space
2616///
2617/// `marginal_eigenvalues[k][multi_idx[k]]` is the k-th marginal eigenvalue μ_k.
2618/// The double-penalty (global ridge) term `λ_d` is added only on joint-null
2619/// cells; the structural shrinkage `ridge` is added only on structurally
2620/// penalized cells. This mirrors the gated logic fixed in #1172/#1185 and MUST
2621/// be kept identical across every caller.
2622#[inline]
2623fn kronecker_cell_sigma(
2624    marginal_eigenvalues: &[ArrayView1<'_, f64>],
2625    multi_idx: &[usize],
2626    lambdas: &[f64],
2627    d: usize,
2628    has_double_penalty: bool,
2629    ridge: f64,
2630) -> (f64, f64, bool) {
2631    let mut sigma = 0.0;
2632    let mut structural_sigma = 0.0;
2633    for k in 0..d {
2634        let marginal_eigenvalue = marginal_eigenvalues[k][multi_idx[k]];
2635        structural_sigma += marginal_eigenvalue;
2636        sigma += lambdas[k] * marginal_eigenvalue;
2637    }
2638    let joint_null = structural_sigma <= KRONECKER_STRUCTURAL_ZERO_TOL;
2639    if has_double_penalty && joint_null {
2640        sigma += lambdas[d];
2641    }
2642    if structural_sigma > KRONECKER_STRUCTURAL_ZERO_TOL {
2643        sigma += ridge;
2644    }
2645    (sigma, structural_sigma, joint_null)
2646}
2647
/// Step a row-major multi-index to the next cell of the `dims` grid, in place.
///
/// The last axis varies fastest. Returns `true` once the grid is exhausted,
/// i.e. the increment carried out of the leading axis and the index wrapped
/// back to all-zero; returns `false` otherwise.
#[inline]
fn kronecker_multi_index_advance(multi_idx: &mut [usize], dims: &[usize]) -> bool {
    for axis in (0..dims.len()).rev() {
        multi_idx[axis] += 1;
        if multi_idx[axis] < dims[axis] {
            // No carry: this axis absorbed the increment.
            return false;
        }
        // Axis overflowed: reset it and carry into the next-slower axis.
        multi_idx[axis] = 0;
    }
    true
}
2665
2666pub fn kronecker_logdet_and_derivatives(
2667    marginal_eigenvalues: &[ArrayView1<'_, f64>],
2668    marginal_dims: &[usize],
2669    lambdas: &[f64],
2670    has_double_penalty: bool,
2671    ridge: f64,
2672) -> (f64, Array1<f64>, Array2<f64>) {
2673    let d = marginal_dims.len();
2674    let n_pen = d + if has_double_penalty { 1 } else { 0 };
2675
2676    let mut logdet = 0.0;
2677    let mut grad = Array1::<f64>::zeros(n_pen);
2678    let mut hess = Array2::<f64>::zeros((n_pen, n_pen));
2679    let tol = 1e-12;
2680
2681    let mut multi_idx = vec![0usize; d];
2682    loop {
2683        let (sigma, _structural_sigma, joint_null) = kronecker_cell_sigma(
2684            marginal_eigenvalues,
2685            &multi_idx,
2686            lambdas,
2687            d,
2688            has_double_penalty,
2689            ridge,
2690        );
2691
2692        if sigma > tol {
2693            logdet += sigma.ln();
2694            let inv_sigma = 1.0 / sigma;
2695            let inv_sigma2 = inv_sigma * inv_sigma;
2696
2697            for k in 0..d {
2698                let ck = lambdas[k] * marginal_eigenvalues[k][multi_idx[k]];
2699                grad[k] += ck * inv_sigma;
2700            }
2701            if has_double_penalty && joint_null {
2702                grad[d] += lambdas[d] * inv_sigma;
2703            }
2704
2705            for k in 0..n_pen {
2706                let ck = if k < d {
2707                    lambdas[k] * marginal_eigenvalues[k][multi_idx[k]]
2708                } else if joint_null {
2709                    lambdas[d]
2710                } else {
2711                    0.0
2712                };
2713                // When ck == 0 (a zero λ, a zero marginal eigenvalue, or a cell
2714                // outside the joint null for the ridge penalty) every term this
2715                // index k contributes — `ck·inv_sigma − ck²·inv_sigma2` on the
2716                // diagonal and `−ck·cl·inv_sigma2` on every off-diagonal — is
2717                // exactly 0.0, so adding them to the finite running accumulators
2718                // is a bit-identical no-op. Skip the inner sweep entirely.
2719                if ck == 0.0 {
2720                    continue;
2721                }
2722                hess[[k, k]] += ck * inv_sigma - ck * ck * inv_sigma2;
2723                for l in (k + 1)..n_pen {
2724                    let cl = if l < d {
2725                        lambdas[l] * marginal_eigenvalues[l][multi_idx[l]]
2726                    } else if joint_null {
2727                        lambdas[d]
2728                    } else {
2729                        0.0
2730                    };
2731                    let off = -ck * cl * inv_sigma2;
2732                    hess[[k, l]] += off;
2733                    hess[[l, k]] += off;
2734                }
2735            }
2736        }
2737
2738        if kronecker_multi_index_advance(&mut multi_idx, marginal_dims) {
2739            break;
2740        }
2741    }
2742
2743    (logdet, grad, hess)
2744}
2745
2746// #1521: `KroneckerInvariantStructure` is defined once in `crate::kronecker`
2747// (the leaf data+compute module). The byte-identical copy that the carve left
2748// here is replaced by an import so the cache and this engine share one type.
2749use crate::kronecker::KroneckerInvariantStructure;
2750
2751/// Kronecker-factored reparameterization for tensor-product penalties.
2752///
2753/// Instead of eigendecomposing the full p×p balanced penalty (O(p³)), this
2754/// eigendecomposes each marginal penalty separately (O(Σ q_k³)) and computes
2755/// the joint eigensystem as the Kronecker product of marginal eigensystems.
2756pub fn kronecker_reparameterization_engine(
2757    marginal_designs: &[Array2<f64>],
2758    marginal_penalties: &[Array2<f64>],
2759    marginal_dims: &[usize],
2760    lambdas: &[f64],
2761    has_double_penalty: bool,
2762    penalty_shrinkage_floor: Option<f64>,
2763) -> Result<KroneckerReparamResult, EstimationError> {
2764    let d = marginal_dims.len();
2765    if marginal_designs.len() != d || marginal_penalties.len() != d {
2766        return Err(EstimationError::LayoutError(format!(
2767            "kronecker_reparameterization_engine: dimension mismatch: designs={}, penalties={}, dims={}",
2768            marginal_designs.len(),
2769            marginal_penalties.len(),
2770            d
2771        )));
2772    }
2773
2774    let invariant =
2775        KroneckerInvariantStructure::compute(marginal_designs, marginal_penalties, marginal_dims)?;
2776    kronecker_reparameterization_engine_with_invariant(
2777        &invariant,
2778        marginal_dims,
2779        lambdas,
2780        has_double_penalty,
2781        penalty_shrinkage_floor,
2782    )
2783}
2784
2785/// Kronecker-factored reparameterization reusing a precomputed λ-invariant
2786/// structure (eigensystems, reparameterized marginals, shrinkage scale).
2787///
2788/// Bit-identical to `kronecker_reparameterization_engine` for the same marginal
2789/// data — the only difference is that the `eigh()` / `B_k U_k` work was hoisted
2790/// out of the per-iterate path into the cached `invariant`. Only the λ-dependent
2791/// `kronecker_logdet_and_derivatives` sweep and `floor * max_bal` scaling run here.
2792pub fn kronecker_reparameterization_engine_with_invariant(
2793    invariant: &KroneckerInvariantStructure,
2794    marginal_dims: &[usize],
2795    lambdas: &[f64],
2796    has_double_penalty: bool,
2797    penalty_shrinkage_floor: Option<f64>,
2798) -> Result<KroneckerReparamResult, EstimationError> {
2799    // Arc refcount bumps — the underlying eigensystems / reparameterized
2800    // marginals are λ-invariant and shared with the cache, not deep-copied.
2801    let marginal_eigenvalues = Arc::clone(&invariant.marginal_eigenvalues);
2802    let marginal_qs = Arc::clone(&invariant.marginal_qs);
2803    let reparameterized_marginals = Arc::clone(&invariant.reparameterized_marginals);
2804
2805    // Compute shrinkage ridge from balanced penalty eigenvalue scale.
2806    let penalty_shrinkage_ridge = if let Some(floor) = penalty_shrinkage_floor {
2807        floor * invariant.max_balanced_eigenvalue
2808    } else {
2809        0.0
2810    };
2811
2812    let marginal_eigenvalue_views: Vec<_> = marginal_eigenvalues
2813        .iter()
2814        .map(|evals| evals.view())
2815        .collect();
2816    let (log_det, det1, det2) = kronecker_logdet_and_derivatives(
2817        &marginal_eigenvalue_views,
2818        marginal_dims,
2819        lambdas,
2820        has_double_penalty,
2821        penalty_shrinkage_ridge,
2822    );
2823
2824    Ok(KroneckerReparamResult {
2825        reparameterized_marginals,
2826        marginal_eigenvalues,
2827        marginal_qs,
2828        log_det,
2829        det1,
2830        det2,
2831        penalty_shrinkage_ridge,
2832        has_double_penalty,
2833        marginal_dims: marginal_dims.to_vec(),
2834    })
2835}
2836
2837/// Calculate the 2-norm condition number of a matrix.
2838///
2839/// For symmetric matrices (the dominant case for GAM Hessians/penalties),
2840/// this uses an eigenvalue path and computes:
2841///   cond_2(A) = max_i |lambda_i| / min_i |lambda_i|
2842/// which is exactly equal to the singular-value definition for symmetric A.
2843///
2844/// For non-symmetric matrices, this falls back to SVD:
2845///   cond_2(A) = sigma_max / sigma_min
2846///
2847/// This preserves semantics while avoiding full SVD in hot paths.
2848///
2849/// # Arguments
2850/// * `matrix` - The matrix to analyze
2851///
2852/// # Returns
2853/// * `Ok(condition_number)` - The condition number (max_sv / min_sv)
2854/// * `Ok(f64::INFINITY)` - If the matrix is effectively singular (min_sv < 1e-12)
2855/// * `Err` - If SVD computation fails
2856pub fn calculate_condition_number(matrix: &Array2<f64>) -> Result<f64, FaerLinalgError> {
2857    let (rows, cols) = matrix.dim();
2858    if rows == 0 || cols == 0 {
2859        return Ok(1.0);
2860    }
2861
2862    // Fast path for (near-)symmetric square matrices.
2863    if rows == cols {
2864        let mut max_abs = 0.0_f64;
2865        let mut max_asym = 0.0_f64;
2866        for i in 0..rows {
2867            for j in 0..cols {
2868                max_abs = max_abs.max(matrix[[i, j]].abs());
2869            }
2870            for j in 0..i {
2871                let diff = (matrix[[i, j]] - matrix[[j, i]]).abs();
2872                if diff > max_asym {
2873                    max_asym = diff;
2874                }
2875            }
2876        }
2877        let sym_tol = max_abs.max(1.0) * 1e-12;
2878        if max_asym <= sym_tol {
2879            let (evals, _) = matrix.eigh(Side::Lower)?;
2880            let mut max_abs_eval = 0.0_f64;
2881            let mut min_abs_eval = f64::INFINITY;
2882            for &lam in evals.iter() {
2883                let s = lam.abs();
2884                max_abs_eval = max_abs_eval.max(s);
2885                min_abs_eval = min_abs_eval.min(s);
2886            }
2887            if min_abs_eval < 1e-12 {
2888                return Ok(f64::INFINITY);
2889            }
2890            return Ok(max_abs_eval / min_abs_eval);
2891        }
2892    }
2893
2894    // General matrix fallback.
2895    let (_, s, _) = matrix.svd(false, false)?;
2896    let max_sv = s.iter().fold(0.0_f64, |max, &val| max.max(val));
2897    let min_sv = s.iter().fold(f64::INFINITY, |min, &val| min.min(val));
2898    if min_sv < 1e-12 {
2899        return Ok(f64::INFINITY);
2900    }
2901    Ok(max_sv / min_sv)
2902}
2903
2904#[cfg(test)]
2905mod tests {
2906    use super::{
2907        CanonicalPenalty, SubspaceLeakageMetrics, assess_subspace_leakage,
2908        classify_eigenvalues_strict, precompute_reparam_invariant_from_canonical,
2909        report_penalty_pair_redundancy, stable_reparameterizationwith_invariant,
2910    };
2911    use crate::construction::kronecker_product;
2912    use crate::EstimationError;
2913    use faer::Mat;
2914    use gam_linalg::faer_ndarray::FaerEigh;
2915    use gam_linalg::utils::inf_norm;
2916    use ndarray::{Array1, Array2, array};
2917
2918    /// Build CanonicalPenalty values from full-width roots for tests.
2919    fn canonical_from_roots(rs_list: &[Array2<f64>], p: usize) -> Vec<CanonicalPenalty> {
2920        rs_list
2921            .iter()
2922            .map(|r| {
2923                let local = r.t().dot(r);
2924                CanonicalPenalty {
2925                    root: r.clone(),
2926                    col_range: 0..p,
2927                    total_dim: p,
2928                    nullity: 0,
2929                    local,
2930                    prior_mean: Array1::zeros(p),
2931                    positive_eigenvalues: Vec::new(),
2932                    op: None,
2933                }
2934            })
2935            .collect()
2936    }
2937
2938    fn metrics_for(
2939        qs: &Mat<f64>,
2940        rs: &[Mat<f64>],
2941        structural_rank: usize,
2942        p: usize,
2943    ) -> SubspaceLeakageMetrics {
2944        assess_subspace_leakage(qs, rs, structural_rank, p)
2945    }
2946
/// A root supported only on the first `structural_rank` coordinates, under
/// an identity rotation, must report no leakage at all.
#[test]
fn subspace_leakage_iszero_for_clean_split() {
    let (p, structural_rank) = (4usize, 2usize);

    // Root with weights 1 and 2 on the two penalized coordinates only.
    let mut r0 = Mat::<f64>::zeros(2, p);
    r0[(1, 1)] = 2.0;
    r0[(0, 0)] = 1.0;
    let qs = Mat::<f64>::identity(p, p);

    let m = metrics_for(&qs, &[r0], structural_rank, p);
    assert!(m.max_abs_sq <= 1e-16);
    assert!(m.max_rel_sq <= 1e-16);
    assert!(m.max_cross_gram_abs <= 1e-16);
}
2961
/// A root whose only entry sits in a declared-null column must be flagged:
/// nonzero absolute leakage and a relative leakage close to one.
#[test]
fn subspace_leakage_detects_null_column_energy() {
    let (p, structural_rank) = (4usize, 2usize);

    // All of the root's energy lives in column 2, past the structural rank.
    let mut r0 = Mat::<f64>::zeros(1, p);
    r0[(0, 2)] = 3.0;
    let qs = Mat::<f64>::identity(p, p);

    let m = metrics_for(&qs, &[r0], structural_rank, p);
    assert!(m.max_abs_sq > 0.0);
    assert!(m.max_rel_sq > 0.99);
}
2974
/// Perturbing the identity rotation so the penalized and null column blocks
/// are no longer orthogonal must show up in the cross-Gram metric.
#[test]
fn subspace_leakage_detects_qp_qn_nonorthogonality() {
    let (p, structural_rank) = (3usize, 1usize);

    let r0 = Mat::<f64>::zeros(1, p);
    // Tilt a null column towards the penalized direction.
    let mut qs = Mat::<f64>::identity(p, p);
    qs[(0, 1)] = 0.2;

    let m = metrics_for(&qs, &[r0], structural_rank, p);
    assert!(m.max_cross_gram_abs > 1e-3);
}
2986
/// With a nonzero penalty, `u_truncated` must be the null frame expressed in
/// the transformed coordinates, i.e. `Qsᵀ · Q_null`.
#[test]
fn u_truncated_is_transformed_frame_in_nonzero_case() {
    let p = 3usize;
    let canonical = canonical_from_roots(&[array![[1.0, 0.0, 0.0]]], p);
    let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
        .expect("precompute invariant");
    let rep = stable_reparameterizationwith_invariant(&canonical, &[2.0], p, &inv, None)
        .expect("stable reparam");

    let expected = rep.qs.t().dot(&inv.split.q_null);
    let max_abs = inf_norm((&rep.u_truncated - &expected).iter().copied());
    assert!(
        max_abs <= 1e-10,
        "u_truncated frame mismatch: max_abs={max_abs}"
    );
}
3006
#[test]
fn infinite_lambda_keeps_range_penalty_block_finite_1379() {
    // gam#1379 / gam#1074 contract: a genuinely infinite λ = exp(ρ) is NOT
    // silently clamped to a finite ceiling.
    //
    // History: the original #1379 fix introduced a 1e300 ceiling so that
    // `∞ · 0` could not poison the range block Σ_k λ_k S_k. #1074 then
    // DELETED that clamp on purpose (see the comment at the top of
    // `stable_reparameterizationwith_invariant`), because masking ∞ hid the
    // real defect — the outer optimizer driving a redundant/unidentified
    // penalty direction off to ∞ instead of that direction being detected and
    // dropped. Without the clamp, a literal `f64::INFINITY` λ surfaces as a
    // clean, detectable error (the eigensolver rejects the NaN-poisoned
    // block) rather than a silent finite success. This test pins that
    // contract: ∞ must ERROR, not be quietly clamped.
    //
    // Fixture: two penalties on a 3-wide block. The first penalizes only
    // coordinate 0 (its block S_k is structurally zero except at [0,0]) and
    // receives λ = +∞; the second penalizes coordinate 1 at a normal λ.
    let p = 3usize;
    let canonical = canonical_from_roots(
        &[array![[1.0, 0.0, 0.0]], array![[0.0, 1.0, 0.0]]],
        p,
    );
    let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
        .expect("precompute invariant");

    let inf_result = stable_reparameterizationwith_invariant(
        &canonical,
        &[f64::INFINITY, 3.0],
        p,
        &inv,
        None,
    );
    assert!(
        inf_result.is_err(),
        "an infinite lambda must surface as an error, not be silently clamped (#1074)"
    );

    // Large-but-finite λ is still fully supported: only the non-finite input
    // is rejected, and every output of the reparameterization stays finite.
    let rep =
        stable_reparameterizationwith_invariant(&canonical, &[1e300_f64, 3.0], p, &inv, None)
            .expect("stable reparam at large-but-finite lambda");
    assert!(
        rep.s_transformed.iter().all(|v| v.is_finite()),
        "transformed penalty must be finite at large-but-finite lambda"
    );
    assert!(
        rep.qs.iter().all(|v| v.is_finite()),
        "reparam rotation must be finite at large-but-finite lambda"
    );
    assert!(
        rep.log_det.is_finite(),
        "penalty log-det must be finite at large-but-finite lambda"
    );
    assert!(
        rep.det1.iter().all(|v| v.is_finite()),
        "penalty log-det derivatives must be finite at large-but-finite lambda"
    );
}
3062
/// With no penalties at all, every direction is unpenalized, so the
/// truncated null frame must be the full `p × p` identity.
#[test]
fn u_truncated_is_identitywhen_no_penalties() {
    let p = 4usize;
    let no_penalties: Vec<CanonicalPenalty> = Vec::new();
    let no_lambdas: Vec<f64> = Vec::new();

    let inv = precompute_reparam_invariant_from_canonical(&no_penalties, p)
        .expect("precompute invariant");
    let rep = stable_reparameterizationwith_invariant(&no_penalties, &no_lambdas, p, &inv, None)
        .expect("stable reparam");

    assert_eq!(rep.u_truncated, Array2::<f64>::eye(p));
}
3074
3075    #[test]
3076    fn dense_shrinkage_floor_skips_structurally_unpenalized_range_columns() {
3077        let p = 3usize;
3078        let canonical = canonical_from_roots(&[array![[1.0, 0.0, 0.0]]], p);
3079        let invariant = super::ReparamInvariant {
3080            split: super::SubspaceSplit {
3081                q_pen: array![[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]],
3082                q_null: array![[0.0], [0.0], [1.0]],
3083            },
3084            qs_base: Array2::eye(p),
3085            has_nonzero: true,
3086            max_balanced_eigenvalue: 1.0,
3087        };
3088
3089        let rep =
3090            stable_reparameterizationwith_invariant(&canonical, &[2.0], p, &invariant, Some(1e-6))
3091                .expect("stable reparameterization");
3092        assert!(rep.s_transformed[[0, 0]] > 2.0);
3093        assert!(
3094            rep.s_transformed[[1, 1]] <= 1e-11,
3095            "structurally unpenalized range coordinate received shrinkage ridge: {}",
3096            rep.s_transformed[[1, 1]]
3097        );
3098    }
3099
3100    #[test]
3101    fn kronecker_shrinkage_floor_preserves_joint_null_space() {
3102        let marginal_designs = vec![Array2::<f64>::eye(2), Array2::<f64>::eye(2)];
3103        let marginal_penalties = vec![
3104            array![[0.0, 0.0], [0.0, 2.0]],
3105            array![[0.0, 0.0], [0.0, 3.0]],
3106        ];
3107        let marginal_dims = vec![2usize, 2usize];
3108        let lambdas = vec![5.0, 7.0];
3109
3110        let rep = super::kronecker_reparameterization_engine(
3111            &marginal_designs,
3112            &marginal_penalties,
3113            &marginal_dims,
3114            &lambdas,
3115            false,
3116            Some(1e-6),
3117        )
3118        .expect("kronecker reparameterization");
3119        assert!(rep.penalty_shrinkage_ridge > 0.0);
3120
3121        let s = rep.materialize_s_transformed(&lambdas);
3122        assert!(
3123            s[[0, 0]].abs() <= 1e-14,
3124            "joint tensor null direction must remain unpenalized, got {}",
3125            s[[0, 0]]
3126        );
3127        assert!(s[[1, 1]] > lambdas[1] * 3.0);
3128        assert!(s[[2, 2]] > lambdas[0] * 2.0);
3129        assert!(s[[3, 3]] > lambdas[0] * 2.0 + lambdas[1] * 3.0);
3130
3131        let tensor_roots = vec![
3132            array![
3133                [0.0, 0.0, 2.0_f64.sqrt(), 0.0],
3134                [0.0, 0.0, 0.0, 2.0_f64.sqrt()]
3135            ],
3136            array![
3137                [0.0, 3.0_f64.sqrt(), 0.0, 0.0],
3138                [0.0, 0.0, 0.0, 3.0_f64.sqrt()]
3139            ],
3140        ];
3141        let dense = rep
3142            .materialize_dense_artifact_result(&tensor_roots, &lambdas, 4)
3143            .expect("dense artifact materialization");
3144        assert_eq!(dense.e_transformed.nrows(), 3);
3145        assert_eq!(dense.u_truncated.ncols(), 1);
3146    }
3147
3148    #[test]
3149    fn kronecker_memoized_invariant_is_bit_identical_to_unmemoized_engine() {
3150        // The hot-path memoization (compute the marginal eigensystems /
3151        // reparameterized marginals once, reuse across outer iterates) must
3152        // produce a KroneckerReparamResult that is *bit-identical* to the
3153        // unmemoized engine for the same marginal data and λ — the cached work
3154        // is literally the same eigendecomposition. Cover several λ on one fixed
3155        // invariant structure (the realistic outer-loop pattern).
3156        let marginal_designs = vec![
3157            array![[1.0, 0.3, -0.2], [0.4, 1.0, 0.1], [-0.1, 0.2, 1.0]],
3158            array![[1.0, -0.5], [0.2, 1.0], [0.7, 0.3]],
3159        ];
3160        let marginal_penalties = vec![
3161            array![[2.0, -1.0, 0.0], [-1.0, 2.0, -1.0], [0.0, -1.0, 1.0]],
3162            array![[3.0, -1.5], [-1.5, 3.0]],
3163        ];
3164        let marginal_dims = vec![3usize, 2usize];
3165
3166        let invariant = super::KroneckerInvariantStructure::compute(
3167            &marginal_designs,
3168            &marginal_penalties,
3169            &marginal_dims,
3170        )
3171        .expect("invariant structure");
3172
3173        for lambdas in [
3174            vec![5.0, 7.0],
3175            vec![0.0, 7.0],
3176            vec![5.0, 0.0],
3177            vec![1e-3, 1e3],
3178        ] {
3179            for floor in [None, Some(1e-6)] {
3180                let unmemoized = super::kronecker_reparameterization_engine(
3181                    &marginal_designs,
3182                    &marginal_penalties,
3183                    &marginal_dims,
3184                    &lambdas,
3185                    true,
3186                    floor,
3187                )
3188                .expect("unmemoized engine");
3189                let memoized = super::kronecker_reparameterization_engine_with_invariant(
3190                    &invariant,
3191                    &marginal_dims,
3192                    &lambdas,
3193                    true,
3194                    floor,
3195                )
3196                .expect("memoized engine");
3197
3198                assert_eq!(memoized.log_det.to_bits(), unmemoized.log_det.to_bits());
3199                assert_eq!(
3200                    memoized.penalty_shrinkage_ridge.to_bits(),
3201                    unmemoized.penalty_shrinkage_ridge.to_bits()
3202                );
3203                for (a, b) in memoized.det1.iter().zip(unmemoized.det1.iter()) {
3204                    assert_eq!(a.to_bits(), b.to_bits());
3205                }
3206                for (a, b) in memoized.det2.iter().zip(unmemoized.det2.iter()) {
3207                    assert_eq!(a.to_bits(), b.to_bits());
3208                }
3209                for (ma, ua) in memoized
3210                    .reparameterized_marginals
3211                    .iter()
3212                    .zip(unmemoized.reparameterized_marginals.iter())
3213                {
3214                    for (a, b) in ma.iter().zip(ua.iter()) {
3215                        assert_eq!(a.to_bits(), b.to_bits());
3216                    }
3217                }
3218                for (mq, uq) in memoized
3219                    .marginal_qs
3220                    .iter()
3221                    .zip(unmemoized.marginal_qs.iter())
3222                {
3223                    for (a, b) in mq.iter().zip(uq.iter()) {
3224                        assert_eq!(a.to_bits(), b.to_bits());
3225                    }
3226                }
3227            }
3228        }
3229    }
3230
3231    #[test]
3232    fn kronecker_double_penalty_shrinks_only_joint_null_space() {
3233        let marginal_designs = vec![Array2::<f64>::eye(2), Array2::<f64>::eye(2)];
3234        let marginal_penalties = vec![
3235            array![[0.0, 0.0], [0.0, 2.0]],
3236            array![[0.0, 0.0], [0.0, 3.0]],
3237        ];
3238        let marginal_dims = vec![2usize, 2usize];
3239        let lambdas = vec![5.0, 7.0, 11.0];
3240
3241        let rep = super::kronecker_reparameterization_engine(
3242            &marginal_designs,
3243            &marginal_penalties,
3244            &marginal_dims,
3245            &lambdas,
3246            true,
3247            None,
3248        )
3249        .expect("kronecker reparameterization");
3250
3251        let s = rep.materialize_s_transformed(&lambdas);
3252        let expected = [11.0, 21.0, 10.0, 31.0];
3253        for (idx, expected_diag) in expected.iter().copied().enumerate() {
3254            assert!(
3255                (s[[idx, idx]] - expected_diag).abs() <= 1e-12,
3256                "diagonal {idx} got {}, expected {expected_diag}",
3257                s[[idx, idx]]
3258            );
3259        }
3260
3261        let expected_logdet: f64 = expected.iter().map(|v| f64::ln(*v)).sum();
3262        assert!((rep.log_det - expected_logdet).abs() <= 1e-12);
3263        assert!(
3264            (rep.det1[2] - 1.0).abs() <= 1e-12,
3265            "double-penalty derivative must come only from the joint null mode, got {}",
3266            rep.det1[2]
3267        );
3268        assert!(rep.det2[[2, 2]].abs() <= 1e-12);
3269
3270        let tensor_roots = vec![
3271            array![
3272                [0.0, 0.0, 2.0_f64.sqrt(), 0.0],
3273                [0.0, 0.0, 0.0, 2.0_f64.sqrt()]
3274            ],
3275            array![
3276                [0.0, 3.0_f64.sqrt(), 0.0, 0.0],
3277                [0.0, 0.0, 0.0, 3.0_f64.sqrt()]
3278            ],
3279        ];
3280        let dense = rep
3281            .materialize_dense_artifact_result(&tensor_roots, &lambdas, 4)
3282            .expect("dense artifact materialization");
3283        for (idx, expected_diag) in expected.iter().copied().enumerate() {
3284            assert!(
3285                (dense.s_transformed[[idx, idx]] - expected_diag).abs() <= 1e-12,
3286                "dense artifact diagonal {idx} got {}, expected {expected_diag}",
3287                dense.s_transformed[[idx, idx]]
3288            );
3289        }
3290    }
3291
3292    #[test]
3293    fn transformed_penalty_is_diagonal_in_transformed_frame() {
3294        let p = 3usize;
3295        let inv_sqrt2 = 2.0_f64.sqrt().recip();
3296        // Penalize a rotated direction in original space so Qs is non-trivial.
3297        let rs_list = vec![array![[inv_sqrt2, inv_sqrt2, 0.0]]];
3298        let canonical = canonical_from_roots(&rs_list, p);
3299        let lambdas = vec![4.0];
3300        let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3301            .expect("precompute invariant");
3302        let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
3303            .expect("stable reparam");
3304
3305        assert_eq!(rep.e_transformed.nrows(), 1);
3306        assert!(rep.e_transformed[[0, 0]].abs() > 0.0);
3307        assert!(rep.e_transformed[[0, 1]].abs() <= 1e-12);
3308        assert!(rep.e_transformed[[0, 2]].abs() <= 1e-12);
3309        // Exact pseudo-logdet on the structural penalized block has no
3310        // delta-dependent nullspace normalization.
3311        let expected_det1 = 1.0_f64;
3312        assert!((rep.det1[0] - expected_det1).abs() <= 1e-12);
3313
3314        let s = rep.s_transformed;
3315        let mut max_offdiag = 0.0_f64;
3316        for i in 0..p {
3317            for j in 0..p {
3318                if i != j {
3319                    max_offdiag = max_offdiag.max(s[[i, j]].abs());
3320                }
3321            }
3322        }
3323        assert!(
3324            max_offdiag <= 1e-10,
3325            "transformed penalty should be diagonal, max offdiag={max_offdiag}"
3326        );
3327        assert!(s[[1, 1]].abs() <= 1e-10);
3328        assert!(s[[2, 2]].abs() <= 1e-10);
3329    }
3330
3331    #[test]
3332    fn det1_matches_rank_for_single_full_rank_penalty() {
3333        let p = 2usize;
3334        let inv_sqrt2 = 2.0_f64.sqrt().recip();
3335        // Q^T for a 45-degree rotation.
3336        let q_t = [[inv_sqrt2, inv_sqrt2], [-inv_sqrt2, inv_sqrt2]];
3337        // R = diag(3, 1) * Q^T gives S = Q * diag(9, 1) * Q^T.
3338        let rs = array![
3339            [3.0 * q_t[0][0], 3.0 * q_t[0][1]],
3340            [1.0 * q_t[1][0], 1.0 * q_t[1][1]]
3341        ];
3342        let rs_list = vec![rs];
3343        let canonical = canonical_from_roots(&rs_list, p);
3344        let lambdas = vec![5.0];
3345
3346        let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3347            .expect("precompute invariant");
3348        let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
3349            .expect("stable reparam");
3350
3351        assert_eq!(rep.e_transformed.nrows(), p);
3352        let det1 = rep.det1[0];
3353        // Exact pseudo-logdet on the structural penalized block:
3354        //   det1 = lambda * sum_l d_l / (lambda*d_l)
3355        // where d_l are eigenvalues of S_k.
3356        let s_k_eigs = [9.0_f64, 1.0_f64];
3357        let lambda = 5.0_f64;
3358        let expected_det1: f64 = s_k_eigs.iter().map(|&d| lambda * d / (lambda * d)).sum();
3359        assert!(
3360            (det1 - expected_det1).abs() <= 1e-12,
3361            "expected det1={expected_det1}, got {det1}",
3362        );
3363
3364        let s = rep.s_transformed;
3365        assert!(s[[0, 1]].abs() <= 1e-10);
3366        assert!(s[[1, 0]].abs() <= 1e-10);
3367        assert!(s[[0, 0]] > 0.0);
3368        assert!(s[[1, 1]] > 0.0);
3369    }
3370
3371    #[test]
3372    fn kronecker_reparam_logdet_matches_dense() {
3373        // 2D tensor product: q1=3, q2=4.
3374        // Marginal penalties: second-order difference matrices.
3375        let q1 = 3;
3376        let q2 = 4;
3377        let s1 = {
3378            let mut s = Array2::<f64>::zeros((q1, q1));
3379            // D2' D2 for order 2 on 3 points: [[1,-2,1],[-2,4,-2],[1,-2,1]]... simplified
3380            s[[0, 0]] = 1.0;
3381            s[[0, 1]] = -1.0;
3382            s[[1, 0]] = -1.0;
3383            s[[1, 1]] = 2.0;
3384            s[[1, 2]] = -1.0;
3385            s[[2, 1]] = -1.0;
3386            s[[2, 2]] = 1.0;
3387            s
3388        };
3389        let s2 = {
3390            let mut s = Array2::<f64>::zeros((q2, q2));
3391            s[[0, 0]] = 1.0;
3392            s[[0, 1]] = -1.0;
3393            s[[1, 0]] = -1.0;
3394            s[[1, 1]] = 2.0;
3395            s[[1, 2]] = -1.0;
3396            s[[2, 1]] = -1.0;
3397            s[[2, 2]] = 2.0;
3398            s[[2, 3]] = -1.0;
3399            s[[3, 2]] = -1.0;
3400            s[[3, 3]] = 1.0;
3401            s
3402        };
3403
3404        let lambdas = [2.5, 1.3];
3405        // Build dense Kronecker penalty: λ1 (S1⊗I) + λ2 (I⊗S2).
3406        let p = q1 * q2;
3407        let i1 = Array2::<f64>::eye(q1);
3408        let i2 = Array2::<f64>::eye(q2);
3409        let pen0 = kronecker_product(&s1, &i2);
3410        let pen1 = kronecker_product(&i1, &s2);
3411        let mut s_dense = Array2::<f64>::zeros((p, p));
3412        s_dense.scaled_add(lambdas[0], &pen0);
3413        s_dense.scaled_add(lambdas[1], &pen1);
3414
3415        // Dense eigendecomposition for reference pseudo-logdet.
3416        let (evals_dense, _): (ndarray::Array1<f64>, ndarray::Array2<f64>) =
3417            s_dense.eigh(faer::Side::Lower).unwrap();
3418        let tol = 1e-12;
3419        let ref_logdet: f64 = evals_dense
3420            .iter()
3421            .filter(|&&v: &&f64| v > tol)
3422            .map(|&v: &f64| v.ln())
3423            .sum();
3424
3425        // Kronecker reparameterization engine.
3426        let marginal_designs = vec![
3427            Array2::<f64>::eye(q1), // dummy designs
3428            Array2::<f64>::eye(q2),
3429        ];
3430        let marginal_penalties = vec![s1, s2];
3431        let kron_result = super::kronecker_reparameterization_engine(
3432            &marginal_designs,
3433            &marginal_penalties,
3434            &[q1, q2],
3435            &lambdas,
3436            false,
3437            None,
3438        )
3439        .unwrap();
3440
3441        let diff = (kron_result.log_det - ref_logdet).abs();
3442        assert!(
3443            diff < 1e-8,
3444            "Kronecker logdet {:.10} vs dense {:.10}, diff={:.3e}",
3445            kron_result.log_det,
3446            ref_logdet,
3447            diff,
3448        );
3449
3450        // Check derivatives via central FD in rho-space (rho = log lambda).
3451        let rhos: Vec<f64> = lambdas.iter().map(|&l| l.ln()).collect();
3452        let eps = 1e-5;
3453        for k in 0..2 {
3454            let mut rho_plus = rhos.clone();
3455            rho_plus[k] += eps;
3456            let mut rho_minus = rhos.clone();
3457            rho_minus[k] -= eps;
3458            let lam_plus: Vec<f64> = rho_plus.iter().map(|&r| r.exp()).collect();
3459            let lam_minus: Vec<f64> = rho_minus.iter().map(|&r| r.exp()).collect();
3460            let result_plus = super::kronecker_reparameterization_engine(
3461                &marginal_designs,
3462                &marginal_penalties,
3463                &[q1, q2],
3464                &lam_plus,
3465                false,
3466                None,
3467            )
3468            .unwrap();
3469            let result_minus = super::kronecker_reparameterization_engine(
3470                &marginal_designs,
3471                &marginal_penalties,
3472                &[q1, q2],
3473                &lam_minus,
3474                false,
3475                None,
3476            )
3477            .unwrap();
3478            let fd_deriv = (result_plus.log_det - result_minus.log_det) / (2.0 * eps);
3479            let analytic_deriv = kron_result.det1[k];
3480            let rel_err = if analytic_deriv.abs() > 1e-10 {
3481                (fd_deriv - analytic_deriv).abs() / analytic_deriv.abs()
3482            } else {
3483                (fd_deriv - analytic_deriv).abs()
3484            };
3485            assert!(
3486                rel_err < 1e-4,
3487                "det1[{k}] mismatch: analytic={:.8}, fd={:.8}, rel_err={:.3e}",
3488                analytic_deriv,
3489                fd_deriv,
3490                rel_err,
3491            );
3492        }
3493    }
3494
3495    #[test]
3496    fn classify_strict_rejects_nan_eigenvalue() {
3497        let mut eigs = [1.0, f64::NAN, 0.5];
3498        match classify_eigenvalues_strict(&mut eigs, "test_nan") {
3499            Err(EstimationError::PenaltySpectrumNonFinite {
3500                context,
3501                index,
3502                value,
3503            }) => {
3504                assert_eq!(context, "test_nan");
3505                assert_eq!(index, 1);
3506                assert!(value.is_nan());
3507            }
3508            other => panic!("expected PenaltySpectrumNonFinite, got {:?}", other),
3509        }
3510    }
3511
3512    #[test]
3513    fn classify_strict_rejects_inf_eigenvalue() {
3514        let mut eigs = [1.0, 0.5, f64::INFINITY];
3515        match classify_eigenvalues_strict(&mut eigs, "test_inf") {
3516            Err(EstimationError::PenaltySpectrumNonFinite { index, value, .. }) => {
3517                assert_eq!(index, 2);
3518                assert!(value.is_infinite());
3519            }
3520            other => panic!("expected PenaltySpectrumNonFinite, got {:?}", other),
3521        }
3522    }
3523
3524    #[test]
3525    fn classify_strict_rejects_materially_indefinite() {
3526        // -1e-2 with scale ~1.0 is well above any reasonable roundoff tolerance.
3527        let mut eigs = [1.0, -1e-2, 0.5];
3528        match classify_eigenvalues_strict(&mut eigs, "test_indef") {
3529            Err(EstimationError::PenaltySpectrumIndefinite {
3530                context,
3531                index,
3532                value,
3533                ..
3534            }) => {
3535                assert_eq!(context, "test_indef");
3536                assert_eq!(index, 1);
3537                assert!((value + 1e-2).abs() <= 1e-15);
3538            }
3539            other => panic!("expected PenaltySpectrumIndefinite, got {:?}", other),
3540        }
3541    }
3542
3543    #[test]
3544    fn classify_strict_accepts_roundoff_negative() {
3545        // -1e-16 * scale is well within tol = 64 * eps * p * scale.
3546        let scale = 1.0_f64;
3547        let roundoff = -1e-16 * scale;
3548        let mut eigs = [scale, 0.5 * scale, roundoff, 0.25 * scale];
3549        classify_eigenvalues_strict(&mut eigs, "test_roundoff").expect("roundoff must classify");
3550        // The roundoff eigenvalue is snapped to exact zero.
3551        assert_eq!(eigs[2], 0.0);
3552        // Strictly positive entries must be preserved.
3553        assert!(eigs[0] > 0.0 && eigs[1] > 0.0 && eigs[3] > 0.0);
3554    }
3555
3556    #[test]
3557    fn classify_strict_accepts_extreme_lambda_assembly_noise_1619() {
3558        // #1619: high-rank thin-plate / Duchon penalties (p≈200) assembled and
3559        // reparameterized at extreme λ produce float-noise negative eigenvalues at
3560        // ~1e-11 relative to the spectrum scale (~1e13 there). The bare machine-ε
3561        // floor (~12×ε relative ≈ 1e-12) rejected these as "indefinite" and
3562        // spuriously failed the inner P-IRLS solve. They are PSD to numerical
3563        // precision and must be snapped to zero, not rejected.
3564        let scale = 8.509e12_f64;
3565        // Worst (eig, scale) pair observed in the issue: eig/scale ≈ -7.7e-11.
3566        let noise = -6.546e2_f64;
3567        assert!(
3568            (noise.abs() / scale) < 1.0e-10,
3569            "fixture must reproduce the ~1e-11-relative noise from #1619"
3570        );
3571        let mut eigs = vec![scale, 0.5 * scale, noise, 0.1 * scale];
3572        classify_eigenvalues_strict(&mut eigs, "range penalty block")
3573            .expect("a ~1e-11-relative roundoff-negative eigenvalue must be accepted (#1619)");
3574        // The roundoff-negative eigenvalue is snapped to exact zero.
3575        assert_eq!(eigs[2], 0.0);
3576        // Strictly positive entries are preserved.
3577        assert!(eigs[0] > 0.0 && eigs[1] > 0.0 && eigs[3] > 0.0);
3578    }
3579
3580    #[test]
3581    fn classify_strict_snaps_subtol_positive_to_zero() {
3582        // Positive eigenvalues below the tolerance are also snapped to exact 0
3583        // so downstream rank counts and pseudo-logdets are deterministic.
3584        let scale = 10.0_f64;
3585        let subtol = 1e-15 * scale;
3586        let mut eigs = [scale, subtol];
3587        classify_eigenvalues_strict(&mut eigs, "test_sub_pos").expect("sub-tol positive ok");
3588        assert_eq!(eigs[1], 0.0);
3589    }
3590
3591    /// Build a `CanonicalPenalty` directly from a symmetric `local` matrix.
3592    /// Bypasses root extraction — the redundancy diagnostic only reads `local`
3593    /// and `col_range`, so the rest is filler.
3594    fn canonical_from_local(
3595        local: Array2<f64>,
3596        col_range: std::ops::Range<usize>,
3597        total_dim: usize,
3598    ) -> CanonicalPenalty {
3599        let block_dim = local.nrows();
3600        // A trivially valid root: zero rank. The diagnostic doesn't read root.
3601        let root = Array2::<f64>::zeros((0, block_dim));
3602        CanonicalPenalty {
3603            root,
3604            col_range,
3605            total_dim,
3606            nullity: 0,
3607            local,
3608            prior_mean: Array1::zeros(block_dim),
3609            positive_eigenvalues: Vec::new(),
3610            op: None,
3611        }
3612    }
3613
3614    #[test]
3615    fn report_penalty_pair_redundancy_detects_identical_pair() {
3616        // Penalty 0: a "generic" SPD matrix on cols 0..3.
3617        let s0 = ndarray::array![[2.0, 0.5, 0.0], [0.5, 1.0, 0.25], [0.0, 0.25, 1.5],];
3618        // Penalties 1 and 2: identical block-local penalty on the SAME col_range.
3619        // This is the Z₂-symmetric saddle scenario.
3620        let s_shared = ndarray::array![[1.0, -0.5, 0.0], [-0.5, 2.0, -0.5], [0.0, -0.5, 1.0],];
3621
3622        let bundle = vec![
3623            canonical_from_local(s0, 0..3, 3),
3624            canonical_from_local(s_shared.clone(), 0..3, 3),
3625            canonical_from_local(s_shared, 0..3, 3),
3626        ];
3627
3628        let redundant = report_penalty_pair_redundancy(&bundle);
3629
3630        // Exactly one redundant pair: (1, 2). Pairs (0, 1) and (0, 2) involve
3631        // distinct matrices and must NOT be flagged.
3632        assert_eq!(
3633            redundant.len(),
3634            1,
3635            "expected exactly one redundant pair, got {:?}",
3636            redundant
3637        );
3638        let (i, j, cos) = redundant[0];
3639        assert_eq!((i, j), (1, 2));
3640        assert!(
3641            cos > 1.0 - 1e-12,
3642            "cosine for identical penalties should be ~1.0, got {cos}"
3643        );
3644    }
3645
3646    #[test]
3647    fn report_penalty_pair_redundancy_skips_different_col_ranges() {
3648        // Two identical local matrices but on disjoint col_ranges. The
3649        // function must NOT flag them — they live in different parameter
3650        // subspaces by construction.
3651        let s = ndarray::array![[1.0, 0.0], [0.0, 1.0]];
3652        let bundle = vec![
3653            canonical_from_local(s.clone(), 0..2, 4),
3654            canonical_from_local(s, 2..4, 4),
3655        ];
3656        let redundant = report_penalty_pair_redundancy(&bundle);
3657        assert!(
3658            redundant.is_empty(),
3659            "different col_ranges must not be flagged"
3660        );
3661    }
3662}