Skip to main content

gam_terms/
construction.rs

1use crate::basis::analyze_penalty_block;
2use crate::EstimationError;
3use crate::smooth::PenaltyStructureHint;
4use faer::linalg::matmul::matmul;
5use faer::{Accum, Mat, MatRef, Par, Side};
6use gam_linalg::faer_ndarray::{FaerEigh, FaerLinalgError, FaerSvd};
7use gam_linalg::matrix::symmetrize_in_place;
8use gam_linalg::utils::KahanSum;
9use ndarray::{Array1, Array2, ArrayView1, ArrayViewMut2, Axis, s};
10use rayon::iter::{
11    IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
12};
13use std::collections::{BTreeMap, HashSet};
14use std::ops::Range;
15use std::sync::Arc;
16
/// Storage formats for one square penalty block.
///
/// Each variant can be expanded to a dense matrix with `to_block_dense`.
#[derive(Clone)]
pub enum PenaltyRepresentation {
    /// Block stored as an explicit dense matrix.
    Dense(Array2<f64>),
    /// Block stored by diagonals: `bands[k]` holds the entries of the
    /// diagonal at `offsets[k]`. Entry `idx` of a band with a non-negative
    /// offset sits at `(idx, idx + offset)`; with a negative offset it sits at
    /// `(idx + |offset|, idx)`.
    Banded {
        bands: Vec<Array1<f64>>,
        offsets: Vec<i32>,
    },
    /// Block stored as two Kronecker factors.
    Kronecker {
        /// Full penalty-block Kronecker product `left ⊗ right`: each entry of
        /// `left` scales an entire copy of `right` in the dense expansion.
        ///
        /// This is distinct from chunked kernel design assembly, where center
        /// rows are kernel-evaluation arguments rather than matrix factors.
        left: Array2<f64>,
        right: Array2<f64>,
    },
}
34
35impl PenaltyRepresentation {
36    /// Side length of the square penalty block this representation expands to.
37    pub fn block_dimension(&self) -> usize {
38        match self {
39            PenaltyRepresentation::Dense(matrix) => matrix.nrows(),
40            PenaltyRepresentation::Banded { bands, offsets } => {
41                let mut dim = 0usize;
42                for (band, &offset) in bands.iter().zip(offsets.iter()) {
43                    let len = band.len();
44                    let extent = if offset >= 0 {
45                        len + offset as usize
46                    } else {
47                        len + (-offset) as usize
48                    };
49                    dim = dim.max(extent);
50                }
51                dim
52            }
53            PenaltyRepresentation::Kronecker { left, right } => left.nrows() * right.nrows(),
54        }
55    }
56
57    /// Materialize this representation (Dense / Banded / Kronecker) into a
58    /// single dense symmetric penalty block.
59    pub fn to_block_dense(&self) -> Array2<f64> {
60        match self {
61            PenaltyRepresentation::Dense(matrix) => matrix.clone(),
62            PenaltyRepresentation::Banded { bands, offsets } => {
63                let dim = self.block_dimension();
64                let mut dense = Array2::zeros((dim, dim));
65                let positive_offsets: HashSet<usize> = offsets
66                    .iter()
67                    .filter_map(|&off| (off >= 0).then_some(off as usize))
68                    .collect();
69                for (band, &offset) in bands.iter().zip(offsets.iter()) {
70                    let off = offset.unsigned_abs() as usize;
71                    if offset < 0 && positive_offsets.contains(&off) {
72                        continue;
73                    }
74                    for (idx, &value) in band.iter().enumerate() {
75                        let (i, j) = if offset >= 0 {
76                            (idx, idx + off)
77                        } else {
78                            (idx + off, idx)
79                        };
80                        if i >= dim || j >= dim {
81                            continue;
82                        }
83                        dense[[i, j]] = value;
84                        dense[[j, i]] = value;
85                    }
86                }
87                dense
88            }
89            PenaltyRepresentation::Kronecker { left, right } => {
90                let (lrows, l_cols) = left.dim();
91                let (rrows, r_cols) = right.dim();
92                let mut result = Array2::zeros((lrows * rrows, l_cols * r_cols));
93                for i in 0..lrows {
94                    for j in 0..l_cols {
95                        let scale = left[(i, j)];
96                        if scale == 0.0 {
97                            continue;
98                        }
99                        let mut block = result.slice_mut(s![
100                            i * rrows..(i + 1) * rrows,
101                            j * r_cols..(j + 1) * r_cols
102                        ]);
103                        block.assign(&(right * scale));
104                    }
105                }
106                result
107            }
108        }
109    }
110}
111
/// A penalty block together with its position in the full coefficient space.
#[derive(Clone)]
pub struct PenaltyMatrix {
    /// Index range of the full matrix occupied by this block; `to_dense`
    /// writes the block into rows and columns `col_range`.
    pub col_range: Range<usize>,
    /// Storage format of the block itself.
    pub representation: PenaltyRepresentation,
}
117
118impl PenaltyMatrix {
119    fn accumulate_into(&self, mut dest: ArrayViewMut2<'_, f64>, weight: f64) {
120        if weight == 0.0 {
121            return;
122        }
123        match &self.representation {
124            PenaltyRepresentation::Dense(block) => {
125                dest.scaled_add(weight, block);
126            }
127            PenaltyRepresentation::Banded { bands, offsets } => {
128                let positive_offsets: HashSet<usize> = offsets
129                    .iter()
130                    .filter_map(|&off| (off >= 0).then_some(off as usize))
131                    .collect();
132                for (band, &offset) in bands.iter().zip(offsets.iter()) {
133                    let off = offset.unsigned_abs() as usize;
134                    if offset < 0 && positive_offsets.contains(&off) {
135                        continue;
136                    }
137                    for (idx, &value) in band.iter().enumerate() {
138                        let (i, j) = if offset >= 0 {
139                            (idx, idx + off)
140                        } else {
141                            (idx + off, idx)
142                        };
143                        let Some(entry_ij) = dest.get_mut((i, j)) else {
144                            continue;
145                        };
146                        *entry_ij += weight * value;
147                        if i != j
148                            && let Some(entry_ji) = dest.get_mut((j, i))
149                        {
150                            *entry_ji += weight * value;
151                        }
152                    }
153                }
154            }
155            PenaltyRepresentation::Kronecker { left, right } => {
156                let (lrows, l_cols) = left.dim();
157                let (rrows, r_cols) = right.dim();
158                for i in 0..lrows {
159                    for j in 0..l_cols {
160                        let scale = left[(i, j)] * weight;
161                        if scale == 0.0 {
162                            continue;
163                        }
164                        let mut block = dest.slice_mut(s![
165                            i * rrows..(i + 1) * rrows,
166                            j * r_cols..(j + 1) * r_cols
167                        ]);
168                        block.scaled_add(scale, right);
169                    }
170                }
171            }
172        }
173    }
174
175    pub fn to_dense(&self, total_dim: usize) -> Array2<f64> {
176        let mut dense = Array2::<f64>::zeros((total_dim, total_dim));
177        self.accumulate_into(
178            dense.slice_mut(s![self.col_range.clone(), self.col_range.clone()]),
179            1.0,
180        );
181        dense
182    }
183}
184
185pub(crate) fn array_to_faer(array: &Array2<f64>) -> Mat<f64> {
186    let (rows, cols) = array.dim();
187    Mat::from_fn(rows, cols, |i, j| array[[i, j]])
188}
189
190pub(crate) fn mat_to_array(mat: &Mat<f64>) -> Array2<f64> {
191    let mut out = Array2::<f64>::zeros((mat.nrows(), mat.ncols()));
192    for i in 0..mat.nrows() {
193        for j in 0..mat.ncols() {
194            out[[i, j]] = mat[(i, j)];
195        }
196    }
197    out
198}
199
200fn mat_max_abs_element(matrix: MatRef<'_, f64>) -> f64 {
201    let (rows, cols) = matrix.shape();
202    let mut maxval = 0.0_f64;
203    for i in 0..rows {
204        for j in 0..cols {
205            let val = matrix[(i, j)];
206            if val.is_finite() {
207                maxval = maxval.max(val.abs());
208            }
209        }
210    }
211    maxval
212}
213
214fn sanitize_symmetric_faer(matrix: &Mat<f64>) -> Mat<f64> {
215    let (rows, cols) = matrix.as_ref().shape();
216    assert_eq!(rows, cols, "Matrix must be square for sanitization");
217
218    let mut sanitized = matrix.clone();
219
220    for i in 0..rows {
221        let diag = sanitized[(i, i)];
222        if !diag.is_finite() {
223            sanitized[(i, i)] = 0.0;
224        }
225        for j in (i + 1)..cols {
226            let mut upper = sanitized[(i, j)];
227            let mut lower = sanitized[(j, i)];
228            if !upper.is_finite() {
229                upper = 0.0;
230            }
231            if !lower.is_finite() {
232                lower = 0.0;
233            }
234            let avg = 0.5 * (upper + lower);
235            sanitized[(i, j)] = avg;
236            sanitized[(j, i)] = avg;
237        }
238    }
239
240    let scale = mat_max_abs_element(sanitized.as_ref());
241    let tiny = (scale * 1e-14).max(1e-30);
242    for i in 0..rows {
243        for j in 0..cols {
244            let val = sanitized[(i, j)];
245            if !val.is_finite() {
246                sanitized[(i, j)] = 0.0;
247            } else if val.abs() < tiny {
248                sanitized[(i, j)] = 0.0;
249            }
250        }
251    }
252
253    sanitized
254}
255
256fn penalty_from_root_faer(root: &Mat<f64>) -> Mat<f64> {
257    let cols = root.ncols();
258    let mut full = Mat::<f64>::zeros(cols, cols);
259    let root_ref = root.as_ref();
260    let root_t = root_ref.transpose();
261    matmul(
262        full.as_mut(),
263        Accum::Replace,
264        root_t,
265        root_ref,
266        1.0,
267        Par::Seq,
268    );
269    sanitize_symmetric_faer(&full)
270}
271
272fn symmetrize_faer_matrix_in_place(matrix: &mut Mat<f64>) {
273    let n = matrix.nrows().min(matrix.ncols());
274    for i in 0..n {
275        for j in 0..i {
276            let avg = 0.5 * (matrix[(i, j)] + matrix[(j, i)]);
277            matrix[(i, j)] = avg;
278            matrix[(j, i)] = avg;
279        }
280    }
281}
282
283fn orthogonal_similarity_transform_faer(
284    matrix: &Mat<f64>,
285    block_dim: usize,
286    orthogonal: &Mat<f64>,
287) -> Mat<f64> {
288    let matrix_block = matrix.as_ref().submatrix(0, 0, block_dim, block_dim);
289    let cols = orthogonal.ncols();
290    let mut temp = Mat::<f64>::zeros(block_dim, cols);
291    matmul(
292        temp.as_mut(),
293        Accum::Replace,
294        matrix_block,
295        orthogonal.as_ref(),
296        1.0,
297        Par::Seq,
298    );
299    let mut rotated = Mat::<f64>::zeros(cols, cols);
300    matmul(
301        rotated.as_mut(),
302        Accum::Replace,
303        orthogonal.transpose(),
304        temp.as_ref(),
305        1.0,
306        Par::Seq,
307    );
308    symmetrize_faer_matrix_in_place(&mut rotated);
309    rotated
310}
311
312fn trace_penalty_in_orthogonal_basis(
313    matrix: &Mat<f64>,
314    block_dim: usize,
315    orthogonal: &Mat<f64>,
316    rotated_eigenvalues: &[f64],
317    delta: f64,
318) -> f64 {
319    let matrix_block = matrix.as_ref().submatrix(0, 0, block_dim, block_dim);
320    let cols = orthogonal.ncols();
321    assert!(rotated_eigenvalues.len() >= cols);
322    let mut projected = Mat::<f64>::zeros(block_dim, cols);
323    matmul(
324        projected.as_mut(),
325        Accum::Replace,
326        matrix_block,
327        orthogonal.as_ref(),
328        1.0,
329        Par::Seq,
330    );
331    let mut trace = KahanSum::default();
332    for l in 0..cols {
333        let mut diag_ll = KahanSum::default();
334        for i in 0..block_dim {
335            diag_ll.add(orthogonal[(i, l)] * projected[(i, l)]);
336        }
337        trace.add(diag_ll.sum() / (rotated_eigenvalues[l] + delta));
338    }
339    trace.sum()
340}
341
342pub fn trace_reduced_penalty_covariance(
343    reduced_penalty: &Array2<f64>,
344    covariance_basis: &Array2<f64>,
345) -> f64 {
346    assert_eq!(
347        reduced_penalty.dim(),
348        covariance_basis.dim(),
349        "trace_reduced_penalty_covariance dimension mismatch"
350    );
351    let r = covariance_basis.nrows();
352    let mut trace = KahanSum::default();
353    for i in 0..r {
354        for j in 0..r {
355            trace.add(covariance_basis[[i, j]] * reduced_penalty[[j, i]]);
356        }
357    }
358    trace.sum()
359}
360
361pub fn trace_penalty_covariance_in_orthogonal_basis(
362    matrix: &Array2<f64>,
363    orthogonal: &Array2<f64>,
364    covariance_basis: &Array2<f64>,
365) -> f64 {
366    let reduced = gam_linalg::faer_ndarray::fast_ab(
367        &gam_linalg::faer_ndarray::fast_atb(orthogonal, matrix),
368        orthogonal,
369    );
370    trace_reduced_penalty_covariance(&reduced, covariance_basis)
371}
372
373/// Strict spectral classifier used as a final guard on penalty eigendecompositions.
374///
375/// Penalty matrices fed to the GAM solver are required to be PSD by construction.
376/// This routine snaps roundoff-zero eigenvalues to exact zero, accepts strictly
377/// positive eigenvalues, and rejects materially-indefinite or non-finite spectra
378/// with a hard error rather than silently rewriting them. The previous behaviour
379/// (mass-zeroing negative or non-finite eigenvalues) hid construction bugs and
380/// changed the optimisation objective downstream.
381///
382/// `C_EPS_P_FACTOR = 64` chooses the multiplier `c` in
383/// `tol = c * eps_machine * p * scale`: 64 absorbs the rounding accumulated in a
384/// symmetric eigendecomposition of a moderate-dimension matrix while still
385/// rejecting the 1e-12 * scale magnitudes that previously slipped through.
386fn classify_eigenvalues_strict(
387    eigenvalues: &mut [f64],
388    context: &str,
389) -> Result<(), EstimationError> {
390    const C_EPS_P_FACTOR: f64 = 64.0;
391    let p = eigenvalues.len();
392
393    let mut scale = 0.0_f64;
394    for (idx, &val) in eigenvalues.iter().enumerate() {
395        if !val.is_finite() {
396            return Err(EstimationError::PenaltySpectrumNonFinite {
397                context: context.to_string(),
398                index: idx,
399                value: val,
400            });
401        }
402        scale = scale.max(val.abs());
403    }
404
405    // p * eps captures the rounding floor of a symmetric eigendecomposition of a
406    // p-dimensional matrix; multiplying by `scale` lifts the floor to the actual
407    // magnitude of the spectrum. The constant `C_EPS_P_FACTOR` provides headroom
408    // for the residual rounding in subsequent matmuls.
409    let tolerance =
410        (C_EPS_P_FACTOR * f64::EPSILON * (p.max(1) as f64) * scale).max(f64::MIN_POSITIVE);
411
412    for (idx, val) in eigenvalues.iter_mut().enumerate() {
413        if val.abs() <= tolerance {
414            *val = 0.0;
415        } else if *val < 0.0 {
416            return Err(EstimationError::PenaltySpectrumIndefinite {
417                context: context.to_string(),
418                index: idx,
419                value: *val,
420                tolerance,
421                scale,
422            });
423        }
424    }
425    Ok(())
426}
427
428fn robust_eighwith_policy<M, V, E, Validate, Sanitize, EigCall, MapErr>(
429    matrix: &M,
430    context: &str,
431    validate_input: Validate,
432    sanitize: Sanitize,
433    mut eig_call: EigCall,
434    map_error: MapErr,
435) -> Result<(Vec<f64>, V), EstimationError>
436where
437    Validate: Fn(&M, &str) -> Result<(), EstimationError>,
438    Sanitize: Fn(&M) -> M,
439    EigCall: FnMut(&M) -> Result<(Vec<f64>, V), E>,
440    MapErr: Fn(E, &str) -> EstimationError,
441{
442    validate_input(matrix, context)?;
443
444    // The sanitize step only enforces exact symmetry by averaging M and M^T and
445    // zeros sub-eps noise; it never adds a diagonal ridge. Adding ridge changes
446    // the matrix being decomposed, which silently changes the optimisation
447    // objective downstream. If eigh genuinely fails on a finite symmetric input,
448    // surface the error instead of mutating the spectrum.
449    let candidate = sanitize(matrix);
450    match eig_call(&candidate) {
451        Ok((mut eigenvalues, eigenvectors)) => {
452            classify_eigenvalues_strict(&mut eigenvalues, context)?;
453            Ok((eigenvalues, eigenvectors))
454        }
455        Err(err) => Err(map_error(err, context)),
456    }
457}
458
459pub(crate) fn robust_eigh_faer(
460    matrix: &Mat<f64>,
461    side: Side,
462    context: &str,
463) -> Result<(Vec<f64>, Mat<f64>), EstimationError> {
464    robust_eighwith_policy(
465        matrix,
466        context,
467        |mat, ctx| {
468            let (rows, cols) = mat.as_ref().shape();
469            for i in 0..rows {
470                for j in 0..cols {
471                    let val = mat[(i, j)];
472                    if !val.is_finite() {
473                        let max_abs = mat_max_abs_element(mat.as_ref());
474                        crate::bail_invalid_estim!(
475                            "{} contains non-finite entries (max finite magnitude {:.3e})",
476                            ctx,
477                            max_abs
478                        );
479                    }
480                }
481            }
482            Ok(())
483        },
484        sanitize_symmetric_faer,
485        |candidate| {
486            let eig = candidate.as_ref().self_adjoint_eigen(side)?;
487            let diag = eig.S();
488            let mut eigenvalues = Vec::with_capacity(diag.dim());
489            for idx in 0..diag.dim() {
490                eigenvalues.push(diag[idx]);
491            }
492
493            let vectors_ref = eig.U();
494            let mut eigenvectors = Mat::<f64>::zeros(vectors_ref.nrows(), vectors_ref.ncols());
495            for i in 0..vectors_ref.nrows() {
496                for j in 0..vectors_ref.ncols() {
497                    eigenvectors[(i, j)] = vectors_ref[(i, j)];
498                }
499            }
500            Ok((eigenvalues, eigenvectors))
501        },
502        |err, _ctx| {
503            EstimationError::EigendecompositionFailed(FaerLinalgError::SelfAdjointEigen(err))
504        },
505    )
506}
507
508fn robust_eigh(
509    matrix: &Array2<f64>,
510    side: Side,
511    context: &str,
512) -> Result<(Array1<f64>, Array2<f64>), EstimationError> {
513    let matrix_faer = array_to_faer(matrix);
514    let (eigenvalues, eigenvectors) = robust_eigh_faer(&matrix_faer, side, context)?;
515    Ok((Array1::from_vec(eigenvalues), mat_to_array(&eigenvectors)))
516}
517
518pub(crate) fn kronecker_marginal_eigensystems(
519    marginal_penalties: &[Array2<f64>],
520    context: &str,
521) -> Result<Vec<(Array1<f64>, Array2<f64>)>, EstimationError> {
522    let mut eigensystems = Vec::with_capacity(marginal_penalties.len());
523    for (k, penalty) in marginal_penalties.iter().enumerate() {
524        eigensystems.push(robust_eigh(
525            penalty,
526            Side::Lower,
527            &format!("{context} marginal {k}"),
528        )?);
529    }
530    Ok(eigensystems)
531}
532
/// Diagnostics produced by `assess_subspace_leakage`.
#[derive(Debug, Clone, Copy)]
struct SubspaceLeakageMetrics {
    /// Largest, over all penalties, sum of squared root entries lying in the
    /// columns at or beyond the structural rank.
    max_abs_sq: f64,
    /// Largest, over all penalties, ratio of that leaked sum of squares to the
    /// penalty root's total sum of squares.
    max_rel_sq: f64,
    /// Index of the penalty that attains `max_rel_sq` (0 when no penalty has a
    /// positive ratio).
    worst_penalty: usize,
    /// Largest absolute inner product between one of the first
    /// `structural_rank` columns of `qs` and one of the remaining columns.
    max_cross_gram_abs: f64,
}
540
541fn assess_subspace_leakage(
542    qs: &Mat<f64>,
543    rs_transformed: &[Mat<f64>],
544    structural_rank: usize,
545    p: usize,
546) -> SubspaceLeakageMetrics {
547    let mut max_abs_sq = 0.0_f64;
548    let mut max_rel_sq = 0.0_f64;
549    let mut worst_penalty = 0usize;
550
551    for (k, rs) in rs_transformed.iter().enumerate() {
552        let rows = rs.nrows();
553        let cols = rs.ncols().min(p);
554        let null_start = structural_rank.min(cols);
555        let mut abs_sq = 0.0_f64;
556        let mut total_sq = 0.0_f64;
557        for i in 0..rows {
558            for j in 0..cols {
559                let v = rs[(i, j)];
560                let vv = v * v;
561                total_sq += vv;
562                if j >= null_start {
563                    abs_sq += vv;
564                }
565            }
566        }
567        let rel_sq = if total_sq > 0.0 {
568            abs_sq / total_sq
569        } else {
570            0.0
571        };
572        if rel_sq > max_rel_sq {
573            max_rel_sq = rel_sq;
574            worst_penalty = k;
575        }
576        max_abs_sq = max_abs_sq.max(abs_sq);
577    }
578
579    let mut max_cross_gram_abs = 0.0_f64;
580    let null_count = p.saturating_sub(structural_rank);
581    if structural_rank > 0 && null_count > 0 {
582        for i in 0..structural_rank {
583            for j in 0..null_count {
584                let qn_col = structural_rank + j;
585                let mut dot = 0.0_f64;
586                for r in 0..p {
587                    dot += qs[(r, i)] * qs[(r, qn_col)];
588                }
589                max_cross_gram_abs = max_cross_gram_abs.max(dot.abs());
590            }
591        }
592    }
593
594    SubspaceLeakageMetrics {
595        max_abs_sq,
596        max_rel_sq,
597        worst_penalty,
598        max_cross_gram_abs,
599    }
600}
601
602fn compose_qs_from_split(q_pen: &Mat<f64>, q_null: &Mat<f64>, p: usize) -> Mat<f64> {
603    let rank = q_pen.ncols();
604    let null_count = q_null.ncols();
605    let mut qs = Mat::<f64>::zeros(p, p);
606    for i in 0..p {
607        for j in 0..rank {
608            qs[(i, j)] = q_pen[(i, j)];
609        }
610        for j in 0..null_count {
611            qs[(i, rank + j)] = q_null[(i, j)];
612        }
613    }
614    qs
615}
616
617/// Computes the Kronecker product A ⊗ B for penalty matrix construction.
618/// This is used to create tensor product penalties that enforce smoothness
619/// in multiple dimensions for interaction terms.
620pub fn kronecker_product(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
621    let (arows, a_cols) = a.dim();
622    let (brows, b_cols) = b.dim();
623    if arows == 0 || a_cols == 0 || brows == 0 || b_cols == 0 {
624        return Array2::zeros((arows * brows, a_cols * b_cols));
625    }
626    let mut result = Array2::zeros((arows * brows, a_cols * b_cols));
627
628    result
629        .axis_chunks_iter_mut(Axis(0), brows)
630        .into_par_iter()
631        .enumerate()
632        .for_each(|(i, mut row_block)| {
633            let arow = a.row(i);
634            let col_chunks = row_block.axis_chunks_iter_mut(Axis(1), b_cols);
635            for (j, mut block) in col_chunks.into_iter().enumerate() {
636                let aval = arow[j];
637                if aval == 0.0 {
638                    continue;
639                }
640                for (dest, &src) in block.iter_mut().zip(b.iter()) {
641                    *dest = aval * src;
642                }
643            }
644        });
645
646    result
647}
648
/// Result of the stable reparameterization algorithm from Wood (2011),
/// Appendix B.
#[derive(Clone)]
pub struct ReparamResult {
    /// Penalty matrix in TRANSFORMED coefficient coordinates.
    ///
    /// This must be compatible with `beta_transformed` and `X_transformed = X * Qs`.
    pub s_transformed: Array2<f64>,
    /// Log-determinant of the penalty matrix (stable computation).
    pub log_det: f64,
    /// First derivatives of the log-determinant w.r.t. the log-smoothing
    /// parameters.
    pub det1: Array1<f64>,
    /// Orthogonal transformation matrix `Qs`.
    pub qs: Array2<f64>,
    /// Canonical penalties in the TRANSFORMED coordinate frame.
    /// The single source of truth for penalty roots in the transformed frame.
    /// Downstream consumers use these for block-local `PenaltyCoordinate`
    /// construction, TK correction, and ext-coord paths.
    pub canonical_transformed: Vec<CanonicalPenalty>,
    /// Lambda-dependent penalty square root in TRANSFORMED coordinates
    /// (rank × p matrix). This is used for applying the actual penalty in the
    /// least squares solve.
    pub e_transformed: Array2<f64>,
    /// Truncated eigenvectors (p × m where m = p - structural_rank).
    ///
    /// Coordinate frame note:
    /// - This matrix is stored in the TRANSFORMED coefficient frame (post-`Qs`),
    ///   i.e. it is compatible with `canonical_transformed`, `beta_transformed`,
    ///   and transformed Hessians without additional coordinate mapping.
    ///
    /// These vectors span the structural null space used by positive-part
    /// log-determinant conventions.
    pub u_truncated: Array2<f64>,
    /// The rho-independent shrinkage ridge magnitude that was added to each
    /// eigenvalue of the penalized block. Zero means no shrinkage was applied.
    pub penalty_shrinkage_ridge: f64,
}
684
685// ---------------------------------------------------------------------------
686// Kronecker factor decomposition primitives
687// ---------------------------------------------------------------------------
688
/// Per-factor decomposition result for Kronecker penalties.
struct KroneckerFactorDecomp {
    /// Root of the factor, shape `rank_j × q_j`.
    root: Array2<f64>,
    /// Positive eigenvalues of the factor, length `rank_j`.
    positive_eigenvalues: Vec<f64>,
    /// Number of rows of `root` (`rank_j`).
    rank: usize,
    /// Side length of the factor (`q_j`).
    dim: usize,
}
696
697/// Eigendecompose each Kronecker factor separately at O(Σ q_j³).
698/// Returns per-factor decompositions, or `None` if any factor is zero.
699fn decompose_kronecker_factors(
700    factors: &[Array2<f64>],
701    context: &str,
702) -> Result<Option<Vec<KroneckerFactorDecomp>>, EstimationError> {
703    let mut decomps = Vec::with_capacity(factors.len());
704    for (j, factor) in factors.iter().enumerate() {
705        let q_j = factor.nrows();
706        if q_j != factor.ncols() {
707            crate::bail_invalid_estim!(
708                "{context}: Kronecker factor {j} must be square, got {}x{}",
709                factor.nrows(),
710                factor.ncols()
711            );
712        }
713        let is_identity = {
714            let mut is_id = true;
715            'outer: for r in 0..q_j {
716                for c in 0..q_j {
717                    let expected = if r == c { 1.0 } else { 0.0 };
718                    if (factor[[r, c]] - expected).abs() > 1e-12 {
719                        is_id = false;
720                        break 'outer;
721                    }
722                }
723            }
724            is_id
725        };
726        if is_identity {
727            decomps.push(KroneckerFactorDecomp {
728                root: Array2::eye(q_j),
729                positive_eigenvalues: vec![1.0; q_j],
730                rank: q_j,
731                dim: q_j,
732            });
733            continue;
734        }
735        let analysis = analyze_penalty_block(factor).map_err(|err| {
736            EstimationError::InvalidInput(format!(
737                "{context}: Kronecker factor {j} eigendecomp failed: {err}"
738            ))
739        })?;
740        if analysis.rank == 0 {
741            return Ok(None);
742        }
743        // Build the factor root from ONLY the range (positive-curvature)
744        // directions via the canonical classifier — never the null or
745        // negative-curvature directions (#1425).
746        let factor_classes =
747            crate::basis::SpectralClassification::new(&analysis.eigenvalues, analysis.tol);
748        let mut root_j = Array2::zeros((analysis.rank, q_j));
749        let mut pos_eigs = Vec::with_capacity(analysis.rank);
750        for (row_idx, &i) in factor_classes.range_idx.iter().enumerate() {
751            let eigenval = analysis.eigenvalues[i];
752            let sqrt_ev = eigenval.sqrt();
753            let evec = analysis.eigenvectors.column(i);
754            for (col, &v) in evec.iter().enumerate() {
755                root_j[[row_idx, col]] = sqrt_ev * v;
756            }
757            pos_eigs.push(eigenval);
758        }
759        decomps.push(KroneckerFactorDecomp {
760            root: root_j,
761            positive_eigenvalues: pos_eigs,
762            rank: analysis.rank,
763            dim: q_j,
764        });
765    }
766    Ok(Some(decomps))
767}
768
769/// Build the block-local Kronecker root from pre-computed factor decompositions.
770fn assemble_kronecker_root_local(decomps: &[KroneckerFactorDecomp]) -> Array2<f64> {
771    let mut kron_root = decomps[0].root.clone();
772    for fr in &decomps[1..] {
773        let (r1, c1) = kron_root.dim();
774        let (r2, c2) = (fr.rank, fr.dim);
775        let mut new_root = Array2::zeros((r1 * r2, c1 * c2));
776        for i1 in 0..r1 {
777            for i2 in 0..r2 {
778                for j1 in 0..c1 {
779                    for j2 in 0..c2 {
780                        new_root[[i1 * r2 + i2, j1 * c2 + j2]] =
781                            kron_root[[i1, j1]] * fr.root[[i2, j2]];
782                    }
783                }
784            }
785        }
786        kron_root = new_root;
787    }
788    kron_root
789}
790
791/// Compute eigenvalues of the Kronecker product from per-factor eigenvalues.
792fn kronecker_eigenvalues(decomps: &[KroneckerFactorDecomp], block_dim: usize) -> (Vec<f64>, usize) {
793    let mut kron_eigs = decomps[0].positive_eigenvalues.clone();
794    for fd in &decomps[1..] {
795        let mut new_eigs = Vec::with_capacity(kron_eigs.len() * fd.positive_eigenvalues.len());
796        for &a in &kron_eigs {
797            for &b in &fd.positive_eigenvalues {
798                new_eigs.push(a * b);
799            }
800        }
801        kron_eigs = new_eigs;
802    }
803    let max_ev = kron_eigs.iter().copied().fold(0.0_f64, f64::max);
804    let tol = max_ev * 1e-10 * (block_dim as f64);
805    let positive: Vec<f64> = kron_eigs.into_iter().filter(|&ev| ev > tol).collect();
806    let nullity = block_dim - positive.len();
807    (positive, nullity)
808}
809
810// ---------------------------------------------------------------------------
811// CanonicalPenalty — block-local processed penalty for the solver
812// ---------------------------------------------------------------------------
813
/// A canonicalized penalty with block-local root, ready for the solver.
///
/// Instead of storing a full `p x p` penalty matrix, this stores only the
/// `rank x block_dim` root and the column range, enabling O(p_k^2) operations
/// instead of O(p^2).
#[derive(Clone)]
pub struct CanonicalPenalty {
    /// Square root matrix: `S_k = root^T * root`.
    /// Shape: `rank x block_dim` for block-local, `rank x p` for dense.
    pub root: Array2<f64>,
    /// Column range in the global coefficient vector `[start..end)`.
    /// For dense penalties this is `0..p`.
    pub col_range: std::ops::Range<usize>,
    /// Full parameter dimension `p`.
    pub total_dim: usize,
    /// Structural nullity of the local penalty.
    pub nullity: usize,
    /// The symmetrized block-local penalty matrix (`block_dim × block_dim`).
    /// Cached at construction time to avoid recomputing `root^T * root`
    /// in hot paths (penalty assembly, trace products).
    pub local: Array2<f64>,
    /// Block-local prior mean used to center this penalty.
    pub prior_mean: Array1<f64>,
    /// Positive eigenvalues of the local penalty matrix (length = rank).
    /// Cached at construction time for REML logdet block-factored paths.
    pub positive_eigenvalues: Vec<f64>,
    /// Optional operator-form handle bit-equivalent to `local`. Propagated
    /// from `PenaltySpec::Block.op`. Downstream PIRLS and REML exact operator
    /// algebra route through this for dense-Gram-free matvec when present.
    pub op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
}
845
846impl std::fmt::Debug for CanonicalPenalty {
847    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
848        f.debug_struct("CanonicalPenalty")
849            .field(
850                "root",
851                &format_args!("{}×{}", self.root.nrows(), self.root.ncols()),
852            )
853            .field("col_range", &self.col_range)
854            .field("total_dim", &self.total_dim)
855            .field("nullity", &self.nullity)
856            .field(
857                "local",
858                &format_args!("{}×{}", self.local.nrows(), self.local.ncols()),
859            )
860            .field("prior_mean_len", &self.prior_mean.len())
861            .field("positive_eigenvalues", &self.positive_eigenvalues)
862            .field("op", &self.op.as_ref().map(|o| o.dim()))
863            .finish()
864    }
865}
866
867impl CanonicalPenalty {
868    /// Construct a dense (full-width) canonical penalty from a `rank x p` root.
869    /// Used to wrap reparam-transformed roots for consumers that expect
870    /// `&[CanonicalPenalty]`.
871    pub fn from_dense_root(root: Array2<f64>, p: usize) -> Self {
872        Self::from_dense_root_with_mean(root, p, Array1::zeros(p))
873    }
874
875    pub fn from_dense_root_with_mean(root: Array2<f64>, p: usize, prior_mean: Array1<f64>) -> Self {
876        assert_eq!(prior_mean.len(), p);
877        let local = root.t().dot(&root);
878        let positive_eigenvalues = Vec::new(); // not needed for TK paths
879        Self {
880            root,
881            col_range: 0..p,
882            total_dim: p,
883            nullity: 0,
884            local,
885            prior_mean,
886            positive_eigenvalues,
887            op: None,
888        }
889    }
890
891    /// Embed the block-local root into a full-width `rank × total_dim` matrix.
892    /// For dense penalties (col_range = 0..p), returns the root unchanged.
893    pub fn full_width_root(&self) -> Array2<f64> {
894        if self.col_range.start == 0 && self.col_range.end == self.total_dim {
895            return self.root.clone();
896        }
897        let rank = self.root.nrows();
898        let mut full = Array2::<f64>::zeros((rank, self.total_dim));
899        full.slice_mut(ndarray::s![.., self.col_range.clone()])
900            .assign(&self.root);
901        full
902    }
903
904    /// Numerical rank of this penalty.
905    pub fn rank(&self) -> usize {
906        self.root.nrows()
907    }
908
909    /// Block dimension (number of columns this penalty covers).
910    pub fn block_dim(&self) -> usize {
911        self.col_range.len()
912    }
913
914    /// Whether this penalty is block-local (col_range != 0..total_dim).
915    pub const fn is_block_local(&self) -> bool {
916        self.col_range.start != 0 || self.col_range.end != self.total_dim
917    }
918
919    /// Return a reference to the cached local penalty matrix.
920    /// Shape: `block_dim x block_dim`.
921    pub fn local_ref(&self) -> &Array2<f64> {
922        &self.local
923    }
924
925    /// Return an owned copy of the local penalty matrix.
926    /// Prefer `local_ref()` when a reference suffices.
927    pub fn local_penalty(&self) -> Array2<f64> {
928        self.local.clone()
929    }
930
931    /// Accumulate lambda * S_k into a pre-allocated `p x p` target matrix.
932    /// Only touches the block [col_range × col_range].
933    pub fn accumulate_weighted(&self, target: &mut Array2<f64>, lambda: f64) {
934        if lambda == 0.0 || self.rank() == 0 {
935            return;
936        }
937        let r = &self.col_range;
938        target
939            .slice_mut(s![r.start..r.end, r.start..r.end])
940            .scaled_add(lambda, &self.local);
941    }
942
943    /// Compute `scale * tr(M · S_k)` where M is a `p × p` dense matrix.
944    /// Only reads `M[start..end, start..end]` — O(block_dim²) not O(p²).
945    pub fn trace_product(&self, m: &Array2<f64>, scale: f64) -> f64 {
946        if self.rank() == 0 || scale == 0.0 {
947            return 0.0;
948        }
949        let r = &self.col_range;
950        let m_block = m.slice(s![r.start..r.end, r.start..r.end]);
951        let rm = self.root.dot(&m_block);
952        scale
953            * rm.iter()
954                .zip(self.root.iter())
955                .map(|(&a, &b)| a * b)
956                .sum::<f64>()
957    }
958
959    /// Compute `scale * v^T S_k v` (quadratic form).
960    /// Only reads `v[start..end]` — O(rank × block_dim) not O(rank × p).
961    pub fn quadratic(&self, v: &Array1<f64>, scale: f64) -> f64 {
962        if self.rank() == 0 || scale == 0.0 {
963            return 0.0;
964        }
965        let v_block = v.slice(s![self.col_range.start..self.col_range.end]);
966        let rv = self.root.dot(&v_block);
967        scale * rv.dot(&rv)
968    }
969
970    /// Compute `scale * S_k * prior_mean` embedded into the global basis.
971    pub fn prior_linear_shift(&self, scale: f64) -> Array1<f64> {
972        let mut out = Array1::<f64>::zeros(self.total_dim);
973        if self.rank() == 0 || scale == 0.0 || self.prior_mean.iter().all(|&v| v == 0.0) {
974            return out;
975        }
976        let block = self.local.dot(&self.prior_mean) * scale;
977        out.slice_mut(s![self.col_range.start..self.col_range.end])
978            .assign(&block);
979        out
980    }
981
982    /// Compute `scale * prior_mean' S_k prior_mean`.
983    pub fn prior_constant_shift(&self, scale: f64) -> f64 {
984        if self.rank() == 0 || scale == 0.0 || self.prior_mean.iter().all(|&v| v == 0.0) {
985            return 0.0;
986        }
987        scale * self.prior_mean.dot(&self.local.dot(&self.prior_mean))
988    }
989
990    /// Embed this block's prior mean into the global coefficient basis.
991    pub fn full_width_prior_mean(&self) -> Array1<f64> {
992        if self.col_range.start == 0 && self.col_range.end == self.total_dim {
993            return self.prior_mean.clone();
994        }
995        let mut out = Array1::<f64>::zeros(self.total_dim);
996        out.slice_mut(s![self.col_range.start..self.col_range.end])
997            .assign(&self.prior_mean);
998        out
999    }
1000
1001    /// Convert to a PenaltyCoordinate for the unified REML evaluator.
1002    pub fn to_penalty_coordinate(
1003        &self,
1004    ) -> gam_problem::PenaltyCoordinate {
1005        use gam_problem::PenaltyCoordinate;
1006        if self.is_block_local() {
1007            PenaltyCoordinate::from_block_root_with_mean(
1008                self.root.clone(),
1009                self.col_range.start,
1010                self.col_range.end,
1011                self.total_dim,
1012                self.prior_mean.clone(),
1013            )
1014        } else {
1015            PenaltyCoordinate::from_dense_root_with_mean(self.root.clone(), self.prior_mean.clone())
1016        }
1017    }
1018}
1019
1020/// Detect and report structurally identical (or near-identical) penalty pairs
1021/// in the canonical bundle.
1022///
1023/// Two penalties `S_i`, `S_j` with the same `col_range` are compared by their
1024/// matrix cosine:
1025///
1026///     cos(S_i, S_j) = tr(S_i S_j) / sqrt(tr(S_i^2) * tr(S_j^2))
1027///
1028/// Because `local` is symmetric, `tr(A·B) = sum_{r,c} A[r,c] * B[r,c]` — i.e.,
1029/// the Frobenius inner product. Pairs with different `col_range` cannot be
1030/// functionally identical and are skipped.
1031///
1032/// Logging policy:
1033/// - `cos > 1 - 1e-8` → `log::warn!` with `[PENALTY-REDUNDANCY]`. Every such
1034///   pair is emitted because it represents a structural model error (the LAML
1035///   cost has a Z₂-symmetric saddle that ARC's cubic regularization will
1036///   happily converge to under first-order stationarity).
1037/// - `0.99 < cos ≤ 1 - 1e-8` → `log::info!` with `[PENALTY-SIMILARITY]`.
1038///   At large scale (`k > 64`) only the top-3 highest-cosine such pairs are
1039///   logged to bound log volume.
1040///
1041/// Returns `Vec<(i, j, cos)>` for the **redundant** pairs (cos > 1 - 1e-8),
1042/// primarily to make this function unit-testable without a log capture.
1043///
1044/// Performance: this is O(k² · block_dim²); intended to be called exactly
1045/// once per fit (e.g. from `RemlState::newwith_offset_shared`).
1046pub fn report_penalty_pair_redundancy(canonical: &[CanonicalPenalty]) -> Vec<(usize, usize, f64)> {
1047    const REDUNDANCY_THRESHOLD: f64 = 1.0 - 1e-8;
1048    const SIMILARITY_THRESHOLD: f64 = 0.99;
1049    const LARGE_SCALE_K_THRESHOLD: usize = 64;
1050    const TOP_SIMILARITY_PAIRS: usize = 3;
1051
1052    let k = canonical.len();
1053    let mut redundant: Vec<(usize, usize, f64)> = Vec::new();
1054    let mut similar: Vec<(usize, usize, f64)> = Vec::new();
1055
1056    // Pre-compute tr(S_i^2) = sum of squares of S_i entries (Frobenius norm
1057    // squared). `local` is symmetric, so this equals tr(S_i^T S_i) = tr(S_i^2).
1058    let trace_sq: Vec<f64> = canonical
1059        .iter()
1060        .map(|p| p.local.iter().map(|&v| v * v).sum::<f64>())
1061        .collect();
1062
1063    for i in 0..k {
1064        if trace_sq[i] == 0.0 {
1065            continue;
1066        }
1067        for j in (i + 1)..k {
1068            if trace_sq[j] == 0.0 {
1069                continue;
1070            }
1071            // Different col_range → cannot be functionally identical by
1072            // construction (the block-local matrices live in disjoint or
1073            // mismatched parameter subspaces).
1074            if canonical[i].col_range != canonical[j].col_range {
1075                continue;
1076            }
1077            // Shapes must match — they do when col_range matches because
1078            // `local` is `block_dim × block_dim` and `block_dim = col_range.len()`.
1079            assert_eq!(canonical[i].local.dim(), canonical[j].local.dim());
1080
1081            let inner: f64 = canonical[i]
1082                .local
1083                .iter()
1084                .zip(canonical[j].local.iter())
1085                .map(|(&a, &b)| a * b)
1086                .sum();
1087            let denom = (trace_sq[i] * trace_sq[j]).sqrt();
1088            if denom == 0.0 {
1089                continue;
1090            }
1091            let cos = inner / denom;
1092
1093            if cos > REDUNDANCY_THRESHOLD {
1094                redundant.push((i, j, cos));
1095            } else if cos > SIMILARITY_THRESHOLD {
1096                similar.push((i, j, cos));
1097            }
1098        }
1099    }
1100
1101    // Always emit every redundancy — these are structural model errors.
1102    for &(i, j, cos) in &redundant {
1103        log::warn!(
1104            "[PENALTY-REDUNDANCY] penalties i={i} j={j} are structurally identical \
1105             (cos={cos:.6}) — model is over-parameterized along their antisymmetric \
1106             direction; expect a Z₂-symmetric saddle in the LAML cost. Consider \
1107             re-specifying (e.g. anisotropic→isotropic for spatial smoothers with \
1108             weak axis signal)."
1109        );
1110    }
1111
1112    // Cap similarity log volume at large scale.
1113    if k > LARGE_SCALE_K_THRESHOLD && similar.len() > TOP_SIMILARITY_PAIRS {
1114        similar.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
1115        similar.truncate(TOP_SIMILARITY_PAIRS);
1116    }
1117    for (i, j, cos) in similar {
1118        log::info!(
1119            "[PENALTY-SIMILARITY] penalties i={i} j={j} are near-identical \
1120             (cos={cos:.6}) — outer Hessian may be ill-conditioned along their \
1121             antisymmetric direction."
1122        );
1123    }
1124
1125    redundant
1126}
1127
1128/// Canonicalize a single `PenaltySpec` into a `CanonicalPenalty` by computing
1129/// the block-local eigendecomposition and extracting the root.
1130///
1131/// This is O(block_dim^3) instead of O(p^3) for block-local penalties.
1132/// Returns `None` if the penalty has rank zero (should be dropped).
1133pub fn canonicalize_penalty_spec(
1134    spec: &crate::PenaltySpec,
1135    p: usize,
1136    idx: usize,
1137    context: &str,
1138) -> Result<Option<CanonicalPenalty>, EstimationError> {
1139    use crate::PenaltySpec;
1140
1141    crate::validate_penalty_spec_shape(idx, spec, p, context)?;
1142
1143    let (local_matrix, col_range, prior_mean_spec, hint, op) = match spec {
1144        PenaltySpec::Block {
1145            local,
1146            col_range,
1147            prior_mean,
1148            structure_hint,
1149            op,
1150        } => (
1151            local.view(),
1152            col_range.clone(),
1153            prior_mean,
1154            structure_hint.as_ref(),
1155            op.clone(),
1156        ),
1157        PenaltySpec::Dense(m) => (
1158            m.view(),
1159            0..p,
1160            &gam_problem::CoefficientPriorMean::Zero,
1161            None,
1162            None,
1163        ),
1164        PenaltySpec::DenseWithMean { matrix, prior_mean } => {
1165            (matrix.view(), 0..p, prior_mean, None, None)
1166        }
1167    };
1168
1169    let block_dim = col_range.len();
1170    let prior_mean = prior_mean_spec
1171        .evaluate(block_dim, &format!("{context}: penalty {idx}"))
1172        .map_err(|e| EstimationError::InvalidInput(e.0))?;
1173
1174    // ── Ridge fast path: closed-form, no eigendecomposition ──
1175    if let Some(PenaltyStructureHint::Ridge(scale)) = hint {
1176        if *scale <= 0.0 {
1177            return Ok(None);
1178        }
1179        let sqrt_scale = scale.sqrt();
1180        let mut root = Array2::zeros((block_dim, block_dim));
1181        for i in 0..block_dim {
1182            root[[i, i]] = sqrt_scale;
1183        }
1184        // Ridge penalties are diagonal by construction, but still route through
1185        // the crate-wide ndarray symmetrizer so every construction variant uses
1186        // the same "average the transpose" cleanup instead of a local copy.
1187        let mut local_sym = local_matrix.to_owned();
1188        symmetrize_in_place(&mut local_sym);
1189        return Ok(Some(CanonicalPenalty {
1190            root,
1191            col_range,
1192            total_dim: p,
1193            nullity: 0,
1194            local: local_sym,
1195            prior_mean,
1196            positive_eigenvalues: vec![*scale; block_dim],
1197            op,
1198        }));
1199    }
1200
1201    // ── Kronecker fast path: single per-factor eigendecomposition ──
1202    if let Some(PenaltyStructureHint::Kronecker(factors)) = hint {
1203        let decomps =
1204            match decompose_kronecker_factors(factors, &format!("{context} penalty {idx}"))? {
1205                None => return Ok(None),
1206                Some(d) => d,
1207            };
1208        let (positive_eigenvalues, nullity) = kronecker_eigenvalues(&decomps, block_dim);
1209        if positive_eigenvalues.is_empty() {
1210            return Ok(None);
1211        }
1212        let root = assemble_kronecker_root_local(&decomps);
1213        let mut local_sym = local_matrix.to_owned();
1214        symmetrize_in_place(&mut local_sym);
1215        return Ok(Some(CanonicalPenalty {
1216            root,
1217            col_range,
1218            total_dim: p,
1219            nullity,
1220            local: local_sym,
1221            prior_mean,
1222            positive_eigenvalues,
1223            op,
1224        }));
1225    }
1226
1227    // ── Generic block-local path: eigendecompose at O(block_dim³) ──
1228    let local_owned = local_matrix.to_owned();
1229    let analysis = analyze_penalty_block(&local_owned).map_err(|err| {
1230        EstimationError::InvalidInput(format!(
1231            "{context}: penalty canonicalization failed at index {idx}: {err}"
1232        ))
1233    })?;
1234
1235    if analysis.rank == 0 {
1236        log::debug!(
1237            "Dropped inactive penalty block idx={idx} reason={}",
1238            if analysis.iszero {
1239                "ZeroMatrix"
1240            } else {
1241                "NumericalRankZero"
1242            }
1243        );
1244        return Ok(None);
1245    }
1246
1247    // Reuse the eigendecomposition from analyze_penalty_block and route the
1248    // range / null / negative-curvature split through the one canonical
1249    // classifier, so this root construction cannot disagree with the block's
1250    // own `rank` / `nullity` / `negative_dim` about which directions are
1251    // penalized, unpenalized, or non-PSD (#1425).
1252    let tolerance = analysis.tol;
1253    let classes = crate::basis::SpectralClassification::new(&analysis.eigenvalues, tolerance);
1254    let rank_k = classes.rank();
1255    assert_eq!(
1256        rank_k, analysis.rank,
1257        "penalty-root rank disagreement: SpectralClassification rank={rank_k} vs analyze_penalty_block rank={} (#1425 canonical-classifier invariant)",
1258        analysis.rank
1259    );
1260
1261    // Build the penalty root R from ONLY the range directions (positive
1262    // curvature): R has one row per range eigenpair, scaled by sqrt(ev), so
1263    // RᵀR reconstructs S on range(S). Null directions contribute nothing
1264    // (their eigenvalue is zero); negative-curvature directions are NEVER
1265    // square-rooted into R (their sqrt is imaginary) and are NOT null — they
1266    // are simply dropped from R, exactly as the closed-form Duchon kernels at
1267    // high d require to preserve the q_pen / q_null invariant downstream.
1268    let mut root = Array2::zeros((rank_k, block_dim));
1269    let mut positive_eigenvalues = Vec::with_capacity(rank_k);
1270    for (row_idx, &i) in classes.range_idx.iter().enumerate() {
1271        let eigenval = analysis.eigenvalues[i];
1272        let eigenvec = analysis.eigenvectors.column(i);
1273        root.row_mut(row_idx).assign(&(&eigenvec * eigenval.sqrt()));
1274        positive_eigenvalues.push(eigenval);
1275    }
1276
1277    // Surface any genuine negative curvature honestly: it is neither range
1278    // (dropped from R) nor null (excluded from `nullity`), so it would
1279    // otherwise vanish without a trace. A non-PSD penalty reaching this path
1280    // is a real geometric fact (e.g. high-d Duchon kernels) the operator
1281    // should be able to see.
1282    if classes.is_indefinite() {
1283        log::debug!(
1284            "{context}: penalty block idx={idx} carries {} negative-curvature \
1285             eigendirection(s) below -tol={tolerance:e}; dropped from the canonical \
1286             root and NOT counted as null space (rank={rank_k}, nullity={})",
1287            classes.negative_dim(),
1288            classes.nullity()
1289        );
1290    }
1291
1292    // Store the PSD reconstruction RᵀR rather than the raw symmetrised input so
1293    // the cached `local` matches the rank truncation embedded in `root`
1294    // (negative-curvature directions are excluded from both, as above).
1295    let local = root.t().dot(&root);
1296    Ok(Some(CanonicalPenalty {
1297        root,
1298        col_range,
1299        total_dim: p,
1300        nullity: classes.nullity(),
1301        local,
1302        prior_mean,
1303        positive_eigenvalues,
1304        op,
1305    }))
1306}
1307
1308/// Canonicalize a batch of penalty specs, dropping zero-rank penalties.
1309/// Returns (active_penalties, active_nullspace_dims).
1310pub fn canonicalize_penalty_specs(
1311    specs: &[crate::PenaltySpec],
1312    nullspace_dims: &[usize],
1313    p: usize,
1314    context: &str,
1315) -> Result<(Vec<CanonicalPenalty>, Vec<usize>), EstimationError> {
1316    if specs.len() != nullspace_dims.len() {
1317        crate::bail_invalid_estim!(
1318            "{context}: nullspace_dims length mismatch: penalties={}, nullspace_dims={}",
1319            specs.len(),
1320            nullspace_dims.len()
1321        );
1322    }
1323
1324    let mut active = Vec::with_capacity(specs.len());
1325    let mut active_nullspace = Vec::with_capacity(specs.len());
1326    for (idx, spec) in specs.iter().enumerate() {
1327        if let Some(canonical) = canonicalize_penalty_spec(spec, p, idx, context)? {
1328            active_nullspace.push(nullspace_dims[idx]);
1329            active.push(canonical);
1330        }
1331    }
1332    Ok((active, active_nullspace))
1333}
1334
/// Largest coefficient dimension `p` for which the overlapping balanced
/// penalty may fall back to a dense p × p eigendecomposition.
///
/// Above this limit the overlapping-penalty path returns an error rather
/// than allocating an O(p²) workspace whose eigendecomposition would dominate
/// the solve. The cap is expected to move into ResourcePolicy eventually;
/// until then both overlapping branches read this one constant so they
/// cannot drift apart.
pub(crate) const OVERLAPPING_PENALTY_DENSE_FALLBACK_MAX_P: usize = 4_096;
1345
1346/// Creates a balanced penalty root from canonical penalties.
1347///
1348/// When all penalties have non-overlapping col_ranges, the balanced sum is
1349/// block-diagonal and eigendecomposition is done per-block at O(Σ p_k³)
1350/// instead of the global O(p³). Falls back to the global path when penalties
1351/// overlap.
1352pub fn create_balanced_penalty_root_from_canonical(
1353    penalties: &[CanonicalPenalty],
1354    p: usize,
1355) -> Result<Array2<f64>, EstimationError> {
1356    if penalties.is_empty() {
1357        return Ok(Array2::zeros((0, p)));
1358    }
1359
1360    // Group penalties by col_range.
1361    let mut block_groups: BTreeMap<(usize, usize), Vec<&CanonicalPenalty>> = BTreeMap::new();
1362    for cp in penalties {
1363        if cp.rank() == 0 {
1364            continue;
1365        }
1366        let key = (cp.col_range.start, cp.col_range.end);
1367        block_groups.entry(key).or_default().push(cp);
1368    }
1369
1370    if block_groups.is_empty() {
1371        return Ok(Array2::zeros((0, p)));
1372    }
1373
1374    // Check for overlapping ranges.
1375    let ranges: Vec<(usize, usize)> = block_groups.keys().copied().collect();
1376    let mut overlapping = false;
1377    for i in 1..ranges.len() {
1378        if ranges[i].0 < ranges[i - 1].1 {
1379            overlapping = true;
1380            break;
1381        }
1382    }
1383
1384    if overlapping {
1385        if p > OVERLAPPING_PENALTY_DENSE_FALLBACK_MAX_P {
1386            return Err(EstimationError::LayoutError(format!(
1387                "overlapping penalty root would require dense {}x{} eigendecomposition; \
1388                 large-model dense fallback is disabled. Keep penalties structured or \
1389                 extend the overlapping-penalty solver path",
1390                p, p
1391            )));
1392        }
1393        // Fallback: accumulate into p × p and eigendecompose globally.
1394        let mut s_balanced = Array2::zeros((p, p));
1395        for cp in penalties {
1396            if cp.rank() == 0 {
1397                continue;
1398            }
1399            let local = cp.local_ref();
1400            let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1401            if frob_norm > 1e-12 {
1402                let r = &cp.col_range;
1403                s_balanced
1404                    .slice_mut(s![r.start..r.end, r.start..r.end])
1405                    .scaled_add(1.0 / frob_norm, local);
1406            }
1407        }
1408        let (eigenvalues, eigenvectors) =
1409            robust_eigh(&s_balanced, Side::Lower, "balanced penalty matrix")?;
1410        let max_eig = eigenvalues.iter().fold(0.0f64, |max, &val| max.max(val));
1411        let tolerance = if max_eig > 0.0 {
1412            max_eig * 1e-12
1413        } else {
1414            1e-12
1415        };
1416        let penalty_rank = eigenvalues.iter().filter(|&&ev| ev > tolerance).count();
1417        if penalty_rank == 0 {
1418            return Ok(Array2::zeros((0, p)));
1419        }
1420        let mut eb = Array2::zeros((p, penalty_rank));
1421        let mut col_idx = 0;
1422        for (i, &eigenval) in eigenvalues.iter().enumerate() {
1423            if eigenval > tolerance {
1424                let sqrt_ev = eigenval.sqrt();
1425                let evec = eigenvectors.column(i);
1426                eb.column_mut(col_idx).assign(&(&evec * sqrt_ev));
1427                col_idx += 1;
1428            }
1429        }
1430        return Ok(eb.t().to_owned());
1431    }
1432
1433    // Non-overlapping: eigendecompose per block at O(Σ p_k³).
1434    struct BlockRoot {
1435        col_range: Range<usize>,
1436        root: Array2<f64>, // rank_b × block_dim
1437    }
1438    // Materialize the BTreeMap order first. Rayon preserves Vec collection
1439    // order for indexed parallel iterators, so assembly below remains stable by
1440    // ascending column range while independent block eigendecompositions run in
1441    // parallel.
1442    let ordered_blocks: Vec<((usize, usize), Vec<&CanonicalPenalty>)> =
1443        block_groups.into_iter().collect();
1444    let block_roots: Vec<BlockRoot> = ordered_blocks
1445        .into_par_iter()
1446        .map(
1447            |((start, end), cps)| -> Result<Option<BlockRoot>, EstimationError> {
1448                let block_dim = end - start;
1449                let mut s_balanced_local = Array2::zeros((block_dim, block_dim));
1450
1451                for cp in cps {
1452                    let local = cp.local_ref();
1453                    let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1454                    if frob_norm > 1e-12 {
1455                        s_balanced_local.scaled_add(1.0 / frob_norm, local);
1456                    }
1457                }
1458
1459                let (eigenvalues, eigenvectors) =
1460                    robust_eigh(&s_balanced_local, Side::Lower, "balanced penalty block")?;
1461                let max_eig = eigenvalues.iter().fold(0.0f64, |max, &val| max.max(val));
1462                let tolerance = if max_eig > 0.0 {
1463                    max_eig * 1e-12
1464                } else {
1465                    1e-12
1466                };
1467                let block_rank = eigenvalues.iter().filter(|&&ev| ev > tolerance).count();
1468
1469                if block_rank == 0 {
1470                    return Ok(None);
1471                }
1472
1473                let mut root = Array2::zeros((block_rank, block_dim));
1474                let mut row_idx = 0;
1475                for (i, &eigenval) in eigenvalues.iter().enumerate() {
1476                    if eigenval > tolerance {
1477                        let sqrt_ev = eigenval.sqrt();
1478                        let evec = eigenvectors.column(i);
1479                        root.row_mut(row_idx).assign(&(&evec * sqrt_ev));
1480                        row_idx += 1;
1481                    }
1482                }
1483
1484                Ok(Some(BlockRoot {
1485                    col_range: start..end,
1486                    root,
1487                }))
1488            },
1489        )
1490        .collect::<Result<Vec<_>, _>>()?
1491        .into_iter()
1492        .flatten()
1493        .collect();
1494    let total_rank: usize = block_roots.iter().map(|br| br.root.nrows()).sum();
1495
1496    if total_rank == 0 {
1497        return Ok(Array2::zeros((0, p)));
1498    }
1499
1500    // Assemble global balanced root: total_rank × p
1501    let mut eb = Array2::zeros((total_rank, p));
1502    let mut row_offset = 0;
1503    for br in &block_roots {
1504        let rank_b = br.root.nrows();
1505        eb.slice_mut(s![
1506            row_offset..(row_offset + rank_b),
1507            br.col_range.start..br.col_range.end
1508        ])
1509        .assign(&br.root);
1510        row_offset += rank_b;
1511    }
1512
1513    Ok(eb)
1514}
1515
1516/// Lambda-independent reparameterization invariants derived from penalty structure.
1517#[derive(Clone)]
1518struct SubspaceSplit {
1519    q_pen: Array2<f64>,
1520    q_null: Array2<f64>,
1521}
1522
1523impl SubspaceSplit {
1524    fn identity(p: usize) -> Self {
1525        Self {
1526            q_pen: Array2::zeros((p, 0)),
1527            q_null: Array2::eye(p),
1528        }
1529    }
1530
1531    fn from_ordered_qs(
1532        qs: &Mat<f64>,
1533        penalized_rank: usize,
1534        p: usize,
1535    ) -> Result<Self, EstimationError> {
1536        if qs.nrows() != p || qs.ncols() != p {
1537            return Err(EstimationError::LayoutError(format!(
1538                "Invalid Q basis dimensions: expected {p}x{p}, got {}x{}",
1539                qs.nrows(),
1540                qs.ncols()
1541            )));
1542        }
1543        if penalized_rank > p {
1544            return Err(EstimationError::LayoutError(format!(
1545                "Invalid penalized rank {penalized_rank} for p={p}"
1546            )));
1547        }
1548
1549        let null_count = p - penalized_rank;
1550        let mut q_pen = Array2::<f64>::zeros((p, penalized_rank));
1551        let mut q_null = Array2::<f64>::zeros((p, null_count));
1552        for i in 0..p {
1553            for j in 0..penalized_rank {
1554                q_pen[(i, j)] = qs[(i, j)];
1555            }
1556            for j in 0..null_count {
1557                q_null[(i, j)] = qs[(i, penalized_rank + j)];
1558            }
1559        }
1560
1561        Ok(Self { q_pen, q_null })
1562    }
1563
1564    fn rank(&self) -> usize {
1565        self.q_pen.ncols()
1566    }
1567
1568    fn p(&self) -> usize {
1569        self.q_pen.nrows()
1570    }
1571
1572    fn compose_qs(&self) -> Array2<f64> {
1573        let p = self.p();
1574        let rank = self.rank();
1575        let null_count = self.q_null.ncols();
1576        let mut qs = Array2::<f64>::zeros((p, p));
1577        for i in 0..p {
1578            for j in 0..rank {
1579                qs[(i, j)] = self.q_pen[(i, j)];
1580            }
1581            for j in 0..null_count {
1582                qs[(i, rank + j)] = self.q_null[(i, j)];
1583            }
1584        }
1585        qs
1586    }
1587}
1588
1589/// Lambda-independent reparameterization invariants derived from penalty structure.
1590#[derive(Clone)]
1591pub struct ReparamInvariant {
1592    split: SubspaceSplit,
1593    /// The balanced eigenvector matrix Q (p x p). Block-local roots are
1594    /// transformed on-the-fly as `R_block @ Q[start..end, :]` instead of
1595    /// storing pre-multiplied full-width roots.
1596    qs_base: Array2<f64>,
1597    has_nonzero: bool,
1598    /// Largest eigenvalue of the balanced (unit-Frobenius) penalty matrix.
1599    /// Used as the scale reference for the shrinkage floor.
1600    max_balanced_eigenvalue: f64,
1601}
1602
1603impl ReparamInvariant {
1604    /// Returns the largest eigenvalue of the balanced penalty matrix.
1605    /// This is lambda-independent and provides a natural scale for shrinkage.
1606    pub const fn max_balanced_eigenvalue(&self) -> f64 {
1607        self.max_balanced_eigenvalue
1608    }
1609}
1610
1611/// Precompute the lambda-invariant reparameterization structure from canonical penalties.
1612///
1613/// Uses block-local roots directly instead of requiring rank x p global roots.
1614/// Each `CanonicalPenalty` carries its own block-local root and column range,
1615/// so the balanced sum can be assembled without ever materializing full-size
1616/// penalty matrices.
1617pub fn precompute_reparam_invariant_from_canonical(
1618    penalties: &[CanonicalPenalty],
1619    p_total: usize,
1620) -> Result<ReparamInvariant, EstimationError> {
1621    use std::cmp::Ordering;
1622
1623    let m = penalties.len();
1624
1625    if m == 0 {
1626        return Ok(ReparamInvariant {
1627            split: SubspaceSplit::identity(p_total),
1628            qs_base: Array2::eye(p_total),
1629            has_nonzero: false,
1630            max_balanced_eigenvalue: 0.0,
1631        });
1632    }
1633
1634    // Group penalties by col_range to detect block-diagonal structure.
1635    struct PenRef {
1636        penalty_index: usize,
1637    }
1638    let mut block_groups: BTreeMap<(usize, usize), Vec<PenRef>> = BTreeMap::new();
1639    let mut has_nonzero = false;
1640    for (i, cp) in penalties.iter().enumerate() {
1641        if cp.rank() == 0 {
1642            continue;
1643        }
1644        let local = cp.local_ref();
1645        let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1646        if frob_norm > 1e-12 {
1647            has_nonzero = true;
1648        }
1649        let key = (cp.col_range.start, cp.col_range.end);
1650        block_groups
1651            .entry(key)
1652            .or_default()
1653            .push(PenRef { penalty_index: i });
1654    }
1655
1656    if !has_nonzero {
1657        return Ok(ReparamInvariant {
1658            split: SubspaceSplit::identity(p_total),
1659            qs_base: Array2::eye(p_total),
1660            has_nonzero: false,
1661            max_balanced_eigenvalue: 0.0,
1662        });
1663    }
1664
1665    // Check for overlapping ranges.
1666    let ranges: Vec<(usize, usize)> = block_groups.keys().copied().collect();
1667    let mut overlapping = false;
1668    for i in 1..ranges.len() {
1669        if ranges[i].0 < ranges[i - 1].1 {
1670            overlapping = true;
1671            break;
1672        }
1673    }
1674
1675    if overlapping {
1676        // Mirror the dense-fallback guard from
1677        // `create_balanced_penalty_root_from_canonical`. Without this, large-scale-
1678        // scale models with overlapping penalties allocated a full
1679        // p_total × p_total workspace and ran an O(p³) eigendecomposition
1680        // before any solver code saw the problem size.
1681        if p_total > OVERLAPPING_PENALTY_DENSE_FALLBACK_MAX_P {
1682            return Err(EstimationError::LayoutError(format!(
1683                "overlapping penalty reparameterization would require dense {}x{} eigendecomposition; \
1684                 large-model dense fallback is disabled. Keep penalties structured or \
1685                 extend the overlapping-penalty solver path",
1686                p_total, p_total
1687            )));
1688        }
1689        // Fallback: global p×p eigendecomposition.
1690        let mut s_balanced = Mat::<f64>::zeros(p_total, p_total);
1691        for cp in penalties {
1692            if cp.rank() == 0 {
1693                continue;
1694            }
1695            let local = cp.local_ref();
1696            let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1697            if frob_norm > 1e-12 {
1698                let scale = 1.0 / frob_norm;
1699                let r = &cp.col_range;
1700                for i in 0..local.nrows() {
1701                    for j in 0..local.ncols() {
1702                        s_balanced[(r.start + i, r.start + j)] += scale * local[[i, j]];
1703                    }
1704                }
1705            }
1706        }
1707
1708        let (bal_eigenvalues, bal_eigenvectors) =
1709            robust_eigh_faer(&s_balanced, Side::Lower, "balanced penalty matrix")?;
1710
1711        let mut order: Vec<usize> = (0..p_total).collect();
1712        order.sort_by(|&i, &j| {
1713            bal_eigenvalues[j]
1714                .partial_cmp(&bal_eigenvalues[i])
1715                .unwrap_or(Ordering::Equal)
1716                .then(i.cmp(&j))
1717        });
1718
1719        let mut qs = Mat::<f64>::zeros(p_total, p_total);
1720        for (col_idx, &idx) in order.iter().enumerate() {
1721            for row in 0..p_total {
1722                qs[(row, col_idx)] = bal_eigenvectors[(row, idx)];
1723            }
1724        }
1725
1726        let max_bal = order
1727            .iter()
1728            .map(|&idx| bal_eigenvalues[idx].abs())
1729            .fold(0.0_f64, f64::max);
1730        let rank_tol = if max_bal > 0.0 {
1731            max_bal * 1e-12
1732        } else {
1733            1e-12
1734        };
1735        let penalized_rank = order
1736            .iter()
1737            .take_while(|&&idx| bal_eigenvalues[idx] > rank_tol)
1738            .count();
1739        let split = SubspaceSplit::from_ordered_qs(&qs, penalized_rank, p_total)?;
1740
1741        return Ok(ReparamInvariant {
1742            split,
1743            qs_base: mat_to_array(&qs),
1744            has_nonzero,
1745            max_balanced_eigenvalue: max_bal,
1746        });
1747    }
1748
1749    // -----------------------------------------------------------------------
1750    // Non-overlapping: block-diagonal eigendecomposition at O(Σ p_k³).
1751    // -----------------------------------------------------------------------
1752    // The balanced sum is block-diagonal ⟹ its eigenvectors are block-local.
1753    // Q_pen and Q_null are assembled by embedding block-local eigenvectors.
1754
1755    // Track which columns are covered by any penalty.
1756    let mut covered = vec![false; p_total];
1757    for cp in penalties {
1758        for j in cp.col_range.clone() {
1759            covered[j] = true;
1760        }
1761    }
1762    let uncovered_cols: Vec<usize> = (0..p_total).filter(|j| !covered[*j]).collect();
1763
1764    struct BlockResult {
1765        col_range: Range<usize>,
1766        q_pen_local: Array2<f64>,  // block_dim × pen_rank
1767        q_null_local: Array2<f64>, // block_dim × null_rank
1768        /// Largest balanced eigenvalue contributed by this block.
1769        max_balanced_eigenvalue: f64,
1770        /// Column offset of this block's penalized directions within global Q_pen.
1771        pen_col_offset: usize,
1772        /// Column offset of this block's null directions within global Q_null.
1773        null_col_offset: usize,
1774    }
1775
1776    // BTreeMap iteration defines the deterministic block order; collecting the
1777    // indexed parallel iterator preserves that order while eigendecomposing each
1778    // independent canonical penalty block concurrently.
1779    let block_specs: Vec<_> = block_groups.iter().collect();
1780    let mut block_results: Vec<BlockResult> = block_specs
1781        .into_par_iter()
1782        .map(
1783            |(&(start, end), refs)| -> Result<BlockResult, EstimationError> {
1784                let block_dim = end - start;
1785
1786                // Build local balanced sum.
1787                let mut s_balanced_local = Array2::zeros((block_dim, block_dim));
1788                let mut block_has_nonzero = false;
1789                for pref in refs {
1790                    let cp = &penalties[pref.penalty_index];
1791                    let local = cp.local_ref();
1792                    let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1793                    if frob_norm > 1e-12 {
1794                        s_balanced_local.scaled_add(1.0 / frob_norm, local);
1795                        block_has_nonzero = true;
1796                    }
1797                }
1798
1799                if !block_has_nonzero {
1800                    return Ok(BlockResult {
1801                        col_range: start..end,
1802                        q_pen_local: Array2::zeros((block_dim, 0)),
1803                        q_null_local: Array2::eye(block_dim),
1804                        max_balanced_eigenvalue: 0.0,
1805                        pen_col_offset: 0,  // set later
1806                        null_col_offset: 0, // set later
1807                    });
1808                }
1809
1810                // Eigendecompose the local balanced penalty.
1811                let (bal_eigenvalues, bal_eigenvectors) =
1812                    robust_eigh(&s_balanced_local, Side::Lower, "balanced penalty block")?;
1813
1814                let mut order: Vec<usize> = (0..block_dim).collect();
1815                order.sort_by(|&i, &j| {
1816                    bal_eigenvalues[j]
1817                        .partial_cmp(&bal_eigenvalues[i])
1818                        .unwrap_or(Ordering::Equal)
1819                        .then(i.cmp(&j))
1820                });
1821
1822                let max_bal = order
1823                    .iter()
1824                    .map(|&idx| bal_eigenvalues[idx].abs())
1825                    .fold(0.0_f64, f64::max);
1826                let rank_tol = if max_bal > 0.0 {
1827                    max_bal * 1e-12
1828                } else {
1829                    1e-12
1830                };
1831                let penalized_rank = order
1832                    .iter()
1833                    .take_while(|&&idx| bal_eigenvalues[idx] > rank_tol)
1834                    .count();
1835                let null_count = block_dim - penalized_rank;
1836
1837                let mut q_pen_local = Array2::zeros((block_dim, penalized_rank));
1838                let mut q_null_local = Array2::zeros((block_dim, null_count));
1839                for (col_idx, &idx) in order.iter().enumerate() {
1840                    if col_idx < penalized_rank {
1841                        for row in 0..block_dim {
1842                            q_pen_local[[row, col_idx]] = bal_eigenvectors[[row, idx]];
1843                        }
1844                    } else {
1845                        let null_col = col_idx - penalized_rank;
1846                        for row in 0..block_dim {
1847                            q_null_local[[row, null_col]] = bal_eigenvectors[[row, idx]];
1848                        }
1849                    }
1850                }
1851
1852                Ok(BlockResult {
1853                    col_range: start..end,
1854                    q_pen_local,
1855                    q_null_local,
1856                    max_balanced_eigenvalue: max_bal,
1857                    pen_col_offset: 0,  // set later
1858                    null_col_offset: 0, // set later
1859                })
1860            },
1861        )
1862        .collect::<Result<_, _>>()?;
1863    let global_max_bal = block_results
1864        .iter()
1865        .map(|br| br.max_balanced_eigenvalue)
1866        .fold(0.0_f64, f64::max);
1867
1868    // Compute column offsets for each block in the global Q_pen / Q_null layout.
1869    let total_pen_rank: usize = block_results.iter().map(|br| br.q_pen_local.ncols()).sum();
1870    let total_null: usize = block_results
1871        .iter()
1872        .map(|br| br.q_null_local.ncols())
1873        .sum::<usize>()
1874        + uncovered_cols.len();
1875    {
1876        let mut pen_off = 0usize;
1877        let mut null_off = 0usize;
1878        for br in &mut block_results {
1879            br.pen_col_offset = pen_off;
1880            br.null_col_offset = null_off;
1881            pen_off += br.q_pen_local.ncols();
1882            null_off += br.q_null_local.ncols();
1883        }
1884    }
1885
1886    let mut q_pen = Array2::zeros((p_total, total_pen_rank));
1887    let mut q_null = Array2::zeros((p_total, total_null));
1888
1889    for br in &block_results {
1890        let start = br.col_range.start;
1891        let bd = br.q_pen_local.nrows();
1892        let pen_r = br.q_pen_local.ncols();
1893        let null_r = br.q_null_local.ncols();
1894        if pen_r > 0 {
1895            q_pen
1896                .slice_mut(s![
1897                    start..(start + bd),
1898                    br.pen_col_offset..(br.pen_col_offset + pen_r)
1899                ])
1900                .assign(&br.q_pen_local);
1901        }
1902        if null_r > 0 {
1903            q_null
1904                .slice_mut(s![
1905                    start..(start + bd),
1906                    br.null_col_offset..(br.null_col_offset + null_r)
1907                ])
1908                .assign(&br.q_null_local);
1909        }
1910    }
1911    let mut null_col = block_results
1912        .iter()
1913        .map(|br| br.q_null_local.ncols())
1914        .sum::<usize>();
1915    for &j in &uncovered_cols {
1916        q_null[[j, null_col]] = 1.0;
1917        null_col += 1;
1918    }
1919
1920    let split = SubspaceSplit { q_pen, q_null };
1921
1922    // Store the global Q_s = [Q_pen | Q_null] from the split.
1923    // Block-local roots are transformed on-the-fly as R_block @ Q[start..end, :]
1924    // inside the reparam engine, avoiding O(k * rank * p) storage.
1925    let qs_global = split.compose_qs();
1926
1927    Ok(ReparamInvariant {
1928        split,
1929        qs_base: qs_global,
1930        has_nonzero,
1931        max_balanced_eigenvalue: global_max_bal,
1932    })
1933}
1934
1935fn structurally_penalized_columns(penalties: &[CanonicalPenalty], p: usize) -> Vec<bool> {
1936    let mut active = vec![false; p];
1937    for cp in penalties {
1938        let local = cp.local_ref();
1939        let scale = local.iter().map(|&v| v.abs()).fold(0.0_f64, f64::max);
1940        if scale <= 0.0 {
1941            continue;
1942        }
1943        let tol = scale * 1e-12;
1944        for local_col in 0..cp.block_dim() {
1945            let mut column_active = false;
1946            for row in 0..cp.block_dim() {
1947                if local[[row, local_col]].abs() > tol || local[[local_col, row]].abs() > tol {
1948                    column_active = true;
1949                    break;
1950                }
1951            }
1952            if column_active {
1953                active[cp.col_range.start + local_col] = true;
1954            }
1955        }
1956    }
1957    active
1958}
1959
/// Apply stable reparameterization using precomputed lambda-invariant structures.
///
/// The basis `Q_s = [Q_pen | Q_null]` comes from `invariant` and does not
/// depend on `lambdas`; only the spectrum of the lambda-weighted penalty on the
/// penalized block is recomputed here.
///
/// `penalty_shrinkage_floor`: optional relative shrinkage floor for eigenvalues
/// of the penalized block. If `Some(epsilon)`, a rho-independent ridge of
/// magnitude `epsilon * max_balanced_eigenvalue` is added to each eigenvalue
/// of the combined penalty on the penalized block. This prevents barely-penalized
/// directions from causing pathological non-Gaussianity in the posterior (e.g.,
/// extreme skewness under logit link with high-dimensional spatial smooths).
/// A typical value is `1e-6`. Set to `None` or `Some(0.0)` to disable
/// (any non-positive value disables it as well).
///
/// # Errors
///
/// - `EstimationError::ParameterConstraintViolation` when `lambdas.len()`
///   differs from `penalties.len()`.
/// - `EstimationError::LayoutError` when the transformed roots leak into the
///   null columns, when `Q_pen` and `Q_null` are not orthogonal within
///   tolerance, or when a structural eigenvalue is non-finite or materially
///   negative.
/// - Any error propagated from the eigendecomposition of the range block.
///
/// # Panics
///
/// Panics if the optimized `det1` contraction disagrees with the reference
/// similarity-transform formula beyond tolerance, or if the rebuilt
/// transformed penalty has entries above `1e-10` in its null rows.
pub fn stable_reparameterizationwith_invariant(
    penalties: &[CanonicalPenalty],
    lambdas: &[f64],
    p: usize,
    invariant: &ReparamInvariant,
    penalty_shrinkage_floor: Option<f64>,
) -> Result<ReparamResult, EstimationError> {
    let m = penalties.len();

    if lambdas.len() != m {
        return Err(EstimationError::ParameterConstraintViolation(format!(
            "Lambda count mismatch: expected {} lambdas for {} penalties, got {}",
            m,
            m,
            lambdas.len()
        )));
    }

    // No separate length check needed — penalties are matched against lambdas above,
    // and the invariant's qs_base is p x p (dimension-checked by the split).

    // #1074: the gam#1379 finite-ceiling on λ_k = exp(ρ_k) (clamp to 1e300 to
    // avoid `∞·0 = NaN` when the outer optimizer drives a redundant penalty
    // direction's log-λ past ~709) was DELETED. It masked the real defect: the
    // optimizer drives a redundant/unidentified penalty direction off to ∞
    // instead of that direction being detected and dropped from the model.
    // The root fix (detect+drop the redundant penalty direction at construction)
    // is tracked separately; λ now passes through raw.

    // No penalties at all: identity basis, empty penalty, zero log-determinant.
    if m == 0 {
        return Ok(ReparamResult {
            s_transformed: Array2::zeros((p, p)),
            log_det: 0.0,
            det1: Array1::zeros(0),
            qs: Array2::eye(p),
            canonical_transformed: vec![],
            e_transformed: Array2::zeros((0, p)),
            // All modes truncated when no penalties; already in transformed frame.
            u_truncated: Array2::eye(p),
            penalty_shrinkage_ridge: 0.0,
        });
    }

    // Penalties exist but are all numerically zero: keep the invariant's basis
    // and report a zero penalty with one zero derivative per lambda.
    if !invariant.has_nonzero {
        let qs = invariant.split.compose_qs();
        let u_truncated = qs.t().dot(&invariant.split.q_null);
        // All penalties are zero — canonical_transformed = originals (no rotation needed).
        let canonical_transformed: Vec<CanonicalPenalty> = penalties.to_vec();
        return Ok(ReparamResult {
            s_transformed: Array2::zeros((p, p)),
            log_det: 0.0,
            det1: Array1::zeros(m),
            qs,
            canonical_transformed,
            e_transformed: Array2::zeros((0, p)),
            u_truncated,
            penalty_shrinkage_ridge: 0.0,
        });
    }

    let q_pen = array_to_faer(&invariant.split.q_pen);
    let q_null = array_to_faer(&invariant.split.q_null);
    let qs_base = array_to_faer(&invariant.qs_base);
    // Each penalty root transform is independent: R_k_block @ Q[start..end, :].
    // Run those per-penalty products (and their S_k = R_k'R_k caches) in
    // parallel, then collect in slice order so all downstream accumulation stays
    // deterministic and bit-for-bit stable with respect to penalty ordering.
    let penalty_transforms: Vec<(Mat<f64>, Mat<f64>)> = penalties
        .par_iter()
        .map(|cp| {
            let r = &cp.col_range;
            let root_faer = array_to_faer(&cp.root);
            let q_block = qs_base.submatrix(r.start, 0, cp.block_dim(), p);
            let mut product = Mat::<f64>::zeros(cp.rank(), p);
            matmul(
                product.as_mut(),
                Accum::Replace,
                root_faer.as_ref(),
                q_block,
                1.0,
                Par::Seq,
            );
            let s_k = penalty_from_root_faer(&product);
            (product, s_k)
        })
        .collect();
    let (rs_transformed, s_k_penalized_cache): (Vec<Mat<f64>>, Vec<Mat<f64>>) =
        penalty_transforms.into_iter().unzip();

    let penalized_rank = invariant.split.rank();

    // Spectrum (descending) and eigenvectors of Σ λ_k S_k restricted to the
    // leading penalized_rank × penalized_rank block. Both stay empty / zero
    // when there is no penalized direction.
    let mut range_eigenvalues_sorted: Vec<f64> = Vec::new();
    let mut range_rotation = Mat::<f64>::zeros(penalized_rank, penalized_rank);
    if penalized_rank > 0 {
        let mut range_block = Mat::<f64>::zeros(penalized_rank, penalized_rank);
        // Deterministic assembly: the independent S_k transforms were computed in
        // parallel above, but the lambda-weighted sum is accumulated serially in
        // canonical penalty order to avoid order-dependent floating-point drift.
        for (lambda, s_k) in lambdas.iter().zip(s_k_penalized_cache.iter()) {
            for i in 0..penalized_rank {
                for j in 0..penalized_rank {
                    range_block[(i, j)] += *lambda * s_k[(i, j)];
                }
            }
        }
        let (range_eigenvalues, range_eigenvectors) =
            robust_eigh_faer(&range_block, Side::Lower, "range penalty block")?;

        // Descending eigenvalue order; ties broken by index for determinism.
        let mut range_order: Vec<usize> = (0..penalized_rank).collect();
        range_order.sort_by(|&i, &j| {
            range_eigenvalues[j]
                .partial_cmp(&range_eigenvalues[i])
                .unwrap_or(std::cmp::Ordering::Equal)
                .then(i.cmp(&j))
        });
        range_eigenvalues_sorted = range_order
            .iter()
            .map(|&idx| range_eigenvalues[idx])
            .collect();

        // Build range_rotation = U (sorted eigenvectors) for E and S⁺
        // construction only.  DO NOT apply to q_pen or rs_transformed —
        // keeping Q_s lambda-independent prevents BFGS coordinate-system
        // drift when multiple penalties interact (the eigenvectors of
        // Σ λ_k S_k rotate with λ, breaking the quasi-Newton Hessian
        // approximation at eigenvalue crossings).
        for (col_idx, &idx) in range_order.iter().enumerate() {
            for row in 0..penalized_rank {
                range_rotation[(row, col_idx)] = range_eigenvectors[(row, idx)];
            }
        }
        // q_pen and rs_transformed stay in the lambda-independent
        // invariant basis.  E and S⁺ below are expressed in this same
        // basis using U from the eigendecomposition.
    }

    // Subspace-invariant penalty spectral calculus:
    // - Penalized and null spaces are fixed by the lambda-invariant basis `qs_base`.
    // - Runtime lambda dependence only appears in the penalized block eigenvalues.
    // This avoids basis mixing inside the degenerate zero-eigenspace.
    let structural_rank = penalized_rank;
    let mut range_eigs_sorted: Vec<f64> = range_eigenvalues_sorted;
    let structurally_penalized_cols = structurally_penalized_columns(penalties, p);

    // Shrinkage floor: add a rho-independent ridge to the penalized block eigenvalues.
    // This prevents barely-penalized directions from causing pathological non-Gaussianity
    // in the posterior (extreme skewness under non-canonical links like logit with
    // high-dimensional spatial smooths). The ridge magnitude is proportional to the
    // balanced penalty's max eigenvalue (lambda-independent scale), so LAML gradients
    // w.r.t. rho remain correct: d(epsilon * I)/d(rho_k) = 0.
    //
    // The shrinkage ridge is a real prior contribution: it changes the quadratic
    // form, the penalty pseudo-logdet, and downstream Hessians. It is currently
    // surfaced through `ReparamResult::penalty_shrinkage_ridge` and consumed by
    // PIRLS/REML via that per-call channel. The longer-term home is a
    // `RidgePassport` with `RidgePolicy::explicit_stabilization_full` scoped to
    // the penalized block so the same delta is reflected in serialization, the
    // Laplace Hessian, and the prior logdet without callers having to thread an
    // additional scalar — once the ridge-ledger plumbing covers per-block
    // ScaledIdentity passports.
    let shrinkage_ridge = penalty_shrinkage_floor
        .filter(|&eps| eps > 0.0)
        .map(|eps| eps * invariant.max_balanced_eigenvalue)
        .unwrap_or(0.0);
    if shrinkage_ridge > 0.0 {
        let min_eig_before = range_eigs_sorted
            .iter()
            .copied()
            .fold(f64::INFINITY, f64::min);
        let mut shrinkage_floor_applied = 0usize;
        for eig_idx in 0..range_eigs_sorted.len() {
            // Squared norm of eigen-direction `eig_idx`, mapped back to original
            // coordinates via Q_pen · U and restricted to the columns flagged
            // as structurally penalized.
            let mut penalized_energy = 0.0;
            for original_col in 0..p {
                if structurally_penalized_cols[original_col] {
                    let mut coordinate = 0.0;
                    for pen_col in 0..penalized_rank {
                        coordinate +=
                            q_pen[(original_col, pen_col)] * range_rotation[(pen_col, eig_idx)];
                    }
                    penalized_energy += coordinate * coordinate;
                }
            }
            // Only directions with non-negligible weight on penalized columns
            // receive the ridge.
            if penalized_energy > 1e-8 {
                range_eigs_sorted[eig_idx] += shrinkage_ridge;
                shrinkage_floor_applied += 1;
            }
        }
        // Log when the floor materially changes the smallest eigenvalue (>1% relative shift).
        if min_eig_before > 0.0 && shrinkage_ridge / min_eig_before > 0.01 {
            log::debug!(
                "Penalty shrinkage floor active: ridge={:.3e} (min_eig_before={:.3e}, ratio={:.1e}, max_bal_eig={:.3e}, applied_dirs={})",
                shrinkage_ridge,
                min_eig_before,
                shrinkage_ridge / min_eig_before,
                invariant.max_balanced_eigenvalue,
                shrinkage_floor_applied,
            );
        }
    }

    // Absolute floor for structural eigenvalues, scaled by the balanced
    // spectrum but never below 1e-12.
    let eigenvalue_floor = invariant.max_balanced_eigenvalue.max(1.0) * 1e-12;
    let qs = compose_qs_from_split(&q_pen, &q_null, p);

    // Guard against any accidental penalized/null mixing. The transformed penalty
    // roots must have negligible support on null columns by construction.
    let leakage = assess_subspace_leakage(&qs, &rs_transformed, structural_rank, p);
    let leakage_rel_tol = 1e-10;
    let leakage_abs_tol = 1e-12;
    let orth_tol = 1e-10;
    // `&&` binds tighter than `||`: leakage must exceed BOTH the relative and
    // the absolute tolerance to be rejected, whereas a cross-Gram violation
    // alone is enough.
    if leakage.max_rel_sq > leakage_rel_tol && leakage.max_abs_sq > leakage_abs_tol
        || leakage.max_cross_gram_abs > orth_tol
    {
        return Err(EstimationError::LayoutError(format!(
            "Reparameterization subspace split is inconsistent: max null leakage {:.3e} (rel {:.3e}, worst penalty {}), max |Qp'Qn| {:.3e}",
            leakage.max_abs_sq.sqrt(),
            leakage.max_rel_sq.sqrt(),
            leakage.worst_penalty,
            leakage.max_cross_gram_abs,
        )));
    }

    // Truncated basis in transformed coordinates:
    //   U_⊥^(t) = Qs^T U_⊥^(orig) = Qs^T Q_n.
    let mut u_truncated_mat = Mat::<f64>::zeros(p, q_null.ncols());
    matmul(
        u_truncated_mat.as_mut(),
        Accum::Replace,
        qs.transpose(),
        q_null.as_ref(),
        1.0,
        Par::Seq,
    );

    // E is represented in TRANSFORMED coordinates (beta_t).  Because the
    // penalized subspace is NOT rotated by the lambda-dependent eigenvectors
    // (to keep Q_s stable across BFGS iterations), E is no longer diagonal.
    // Instead E = diag(√d) · U' embedded in structural_rank × p, so that
    // E'E = U diag(d) U' = Σ λ_k S_k in the invariant penalized basis.
    let mut e_transformed_mat = Mat::<f64>::zeros(structural_rank, p);
    for row_idx in 0..structural_rank {
        let safe_eigenval = range_eigs_sorted[row_idx].max(eigenvalue_floor);
        let sqrt_eigenval = safe_eigenval.sqrt();
        // E[row, j] = sqrt(d_row) * U'[row, j] = sqrt(d_row) * U[j, row]
        for j in 0..penalized_rank {
            e_transformed_mat[(row_idx, j)] = sqrt_eigenval * range_rotation[(j, row_idx)];
        }
    }

    // Pseudo-logdet on the structural penalized block.  The null block is split
    // out above, so there is no nullspace normalization here.  Spectrally-noisy
    // directions (eigenvalues snapped to 0 by `classify_eigenvalues_strict`
    // because they fall below the `c * eps_machine * p * scale` tolerance in
    // the lambda-weighted sum) are floored to `eigenvalue_floor` to keep the
    // log-det finite and consistent with the floored values used to construct
    // `e_transformed_mat` above.  This avoids spurious P-IRLS failures when the
    // lambda dynamic range is wide (e.g. during BFGS line search probing extreme
    // rho candidates).  Materially negative or non-finite spectra are already
    // rejected by the strict classifier upstream; this loop re-checks the
    // *post-shrinkage* range eigenvalues against the same floor.
    //
    // The same floored spectrum is used in the trace formula tr(S⁺ S_k) below,
    // matching the rank structure embedded in `e_transformed_mat` and avoiding
    // a 1/0 in the trace contraction when an eigenvalue was floored to 0.
    let mut floored_eigs: Vec<f64> = Vec::with_capacity(range_eigs_sorted.len());
    let mut log_det_sum = KahanSum::default();
    for (idx, &ev) in range_eigs_sorted.iter().enumerate() {
        if !ev.is_finite() || ev < -eigenvalue_floor {
            return Err(EstimationError::LayoutError(format!(
                "Penalty pseudo-logdet has a non-finite or large-negative structural eigenvalue at index {idx}: {ev:.3e}"
            )));
        }
        let safe_ev = ev.max(eigenvalue_floor);
        floored_eigs.push(safe_ev);
        if idx < penalized_rank {
            log_det_sum.add(safe_ev.ln());
        }
    }
    let log_det = log_det_sum.sum();
    // Extra diagonal shift δ used in the trace denominators; fixed at zero here.
    let delta = 0.0;

    // The det1 contractions are independent once the eigensystem is fixed.  Use
    // indexed parallel collection so the output vector preserves lambda order.
    let det1vec: Vec<f64> = (0..lambdas.len())
        .into_par_iter()
        .map(|k| {
            let s_k = &s_k_penalized_cache[k];
            // Compute tr((S+δI)⁻¹ S_k) in the range eigenbasis without ever
            // materializing (S+δI)⁻¹. Using faer's matmul keeps this contraction
            // aligned with the orthogonal-similarity debug reference path.
            let trace = trace_penalty_in_orthogonal_basis(
                s_k,
                penalized_rank,
                &range_rotation,
                &floored_eigs,
                delta,
            );
            lambdas[k] * trace
        })
        .collect();

    {
        // Guardrail: cross-check the primary Rayleigh-quotient contraction
        // against a full orthogonal similarity transform, while staying in
        // the same numerically stable eigenbasis coordinates.
        let mut maxdet1_mismatch = 0.0_f64;
        let mut det1_scale = 0.0_f64;
        for (k, lambda) in lambdas.iter().enumerate() {
            let s_k_penalized = &s_k_penalized_cache[k];
            let s_k_eigenbasis = orthogonal_similarity_transform_faer(
                s_k_penalized,
                penalized_rank,
                &range_rotation,
            );
            let mut trace = KahanSum::default();
            for l in 0..penalized_rank {
                trace.add(s_k_eigenbasis[(l, l)] / (floored_eigs[l] + delta));
            }
            let reference = *lambda * trace.sum();
            maxdet1_mismatch = maxdet1_mismatch.max((reference - det1vec[k]).abs());
            det1_scale = det1_scale.max(reference.abs()).max(det1vec[k].abs());
        }
        // Tolerance is relative to the largest derivative seen, with an
        // absolute floor of 1e-7.
        let det1_tolerance = 1e-7 * det1_scale.max(1.0);
        assert!(
            maxdet1_mismatch <= det1_tolerance,
            "det1 mismatch between optimized and reference formulas: max_abs={maxdet1_mismatch:.3e}, tol={det1_tolerance:.3e}"
        );
    }

    // Rebuild s_transformed from e_transformed to ensure rank consistency.
    //
    // The sum of λ*S_k may contain numerical noise modes (eigenvalues ~1e-15) that
    // become significant when λ is large (e.g., 10^12). These modes would appear in H
    // but are truncated from log|S|_+, creating a "phantom penalty" in the objective.
    //
    // By reconstructing s_transformed = E^T * E, we force the penalty matrix used
    // in H to have the EXACT same rank structure as the one used for log|S|_+.
    // Any mode truncated from the prior is now strictly zero in the Hessian
    // calculation, ensuring mathematical consistency of the gradients.
    let mut s_truncated = Mat::<f64>::zeros(p, p);
    matmul(
        s_truncated.as_mut(),
        Accum::Replace,
        e_transformed_mat.transpose(),
        e_transformed_mat.as_ref(),
        1.0,
        Par::Seq,
    );

    {
        // Structural check: transformed S must not leak into declared null coordinates.
        // Rows `structural_rank..p` are the null coordinates; every entry in
        // those rows (diagonal and off-diagonal) has to be negligible.
        let mut max_null_diag = 0.0_f64;
        let mut max_null_offdiag = 0.0_f64;
        for i in structural_rank..p {
            max_null_diag = max_null_diag.max(s_truncated[(i, i)].abs());
            for j in 0..p {
                if i != j {
                    max_null_offdiag = max_null_offdiag.max(s_truncated[(i, j)].abs());
                }
            }
        }
        assert!(
            max_null_diag <= 1e-10 && max_null_offdiag <= 1e-10,
            "null-space leakage in transformed penalty: max_null_diag={max_null_diag:.3e}, max_null_offdiag={max_null_offdiag:.3e}"
        );
    }

    // Re-express every penalty in the transformed frame: its root is the
    // rotated root computed above and its prior mean is rotated by Qs^T.
    let qs_array = mat_to_array(&qs);
    let canonical_transformed: Vec<CanonicalPenalty> = rs_transformed
        .par_iter()
        .zip(penalties.par_iter())
        .map(|(r, cp)| {
            let mean_transformed = qs_array.t().dot(&cp.full_width_prior_mean());
            CanonicalPenalty::from_dense_root_with_mean(mat_to_array(r), p, mean_transformed)
        })
        .collect();
    Ok(ReparamResult {
        s_transformed: mat_to_array(&s_truncated),
        log_det,
        det1: Array1::from(det1vec),
        qs: qs_array,
        canonical_transformed,
        e_transformed: mat_to_array(&e_transformed_mat),
        u_truncated: mat_to_array(&u_truncated_mat),
        penalty_shrinkage_ridge: shrinkage_ridge,
    })
}
2355
/// Plain size descriptor handed to the reparameterization engine.
///
/// Carries only two counts so the engine entry points do not depend on any
/// domain-specific layout type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EngineDims {
    /// Width forwarded as `p` by the engine entry points.
    pub p: usize,
    /// Second count carried alongside `p`.
    pub k: usize,
}

impl EngineDims {
    /// Bundle the two counts into an [`EngineDims`] value.
    pub fn new(p: usize, k: usize) -> Self {
        EngineDims { p, k }
    }
}
2368
2369/// Engine-facing stable reparameterization API using only `(p, k)`.
2370///
2371/// When `cached_invariant` is `Some`, reuses the precomputed eigendecomposition
2372/// (the hot path inside the REML loop). When `None`, computes the invariant on
2373/// the fly (the post-REML refit path). Merging both cases into a single entry
2374/// point ensures `penalty_shrinkage_floor` is always applied regardless of
2375/// whether a cached invariant is available.
2376/// Stable reparameterization from block-local canonical penalties.
2377pub fn stable_reparameterization_engine_canonical(
2378    penalties: &[CanonicalPenalty],
2379    lambdas: &[f64],
2380    dims: EngineDims,
2381    cached_invariant: Option<&ReparamInvariant>,
2382    penalty_shrinkage_floor: Option<f64>,
2383) -> Result<ReparamResult, EstimationError> {
2384    let owned;
2385    let invariant = match cached_invariant {
2386        Some(inv) => inv,
2387        None => {
2388            owned = precompute_reparam_invariant_from_canonical(penalties, dims.p)?;
2389            &owned
2390        }
2391    };
2392    stable_reparameterizationwith_invariant(
2393        penalties,
2394        lambdas,
2395        dims.p,
2396        invariant,
2397        penalty_shrinkage_floor,
2398    )
2399}
2400
2401// ---------------------------------------------------------------------------
2402// Kronecker-factored reparameterization for tensor-product smooths
2403// ---------------------------------------------------------------------------
2404
/// Result of Kronecker-factored reparameterization.
///
/// Exploits the fact that for Kronecker-structured penalties, the joint
/// eigenvector matrix is `U_1 ⊗ ... ⊗ U_d` and the reparameterized design
/// is a rowwise Kronecker of `(B_k U_k)` — all remaining factored.
///
/// The three `Arc` fields are refcount-shared clones of the corresponding
/// fields of the `KroneckerInvariantStructure` the result was built from; the
/// remaining fields are computed per call from the smoothing parameters.
#[derive(Clone)]
pub struct KroneckerReparamResult {
    /// Reparameterized marginal designs: `B_k · U_k` for each marginal k.
    ///
    /// `Arc`-shared with the λ-invariant cache so the per-outer-iterate
    /// memoized engine bumps a refcount instead of deep-copying the
    /// (n × q) reparameterized marginals every call.
    pub reparameterized_marginals: Arc<Vec<Array2<f64>>>,
    /// Marginal eigenvalues from each marginal penalty eigendecomposition.
    /// Entry `k` is indexed by the k-th component of a grid multi-index.
    pub marginal_eigenvalues: Arc<Vec<Array1<f64>>>,
    /// Marginal eigenvector matrices U_k.
    pub marginal_qs: Arc<Vec<Array2<f64>>>,
    /// log|S|₊ computed from marginal eigenvalue grid.
    pub log_det: f64,
    /// First derivatives of log|S|₊ w.r.t. ρ_k = log(λ_k).
    /// Has one entry per marginal, plus one trailing entry when
    /// `has_double_penalty` is set.
    pub det1: Array1<f64>,
    /// Second derivatives of log|S|₊ w.r.t. ρ (square, same order as `det1`).
    pub det2: Array2<f64>,
    /// Shrinkage ridge added to eigenvalues (if any); `0.0` when no shrinkage
    /// floor was requested.
    pub penalty_shrinkage_ridge: f64,
    /// Whether a double penalty (global ridge) is present.
    pub has_double_penalty: bool,
    /// Marginal basis dimensions; their product is the joint grid size.
    pub marginal_dims: Vec<usize>,
}
2435
2436impl KroneckerReparamResult {
2437    /// Materialize the joint Qs matrix (U_1 ⊗ ... ⊗ U_d) as dense p×p.
2438    /// Only for fallback paths — avoid in hot loops.
2439    pub fn materialize_qs(&self) -> Array2<f64> {
2440        let mut qs = Array2::<f64>::eye(1);
2441        for u_k in self.marginal_qs.iter() {
2442            qs = kronecker_product(&qs, u_k);
2443        }
2444        qs
2445    }
2446
2447    /// Materialize s_transformed (the penalty in the reparameterized basis).
2448    /// In the eigenbasis, this is diagonal with entries Σ_k λ_k μ_{k,j_k}.
2449    pub fn materialize_s_transformed(&self, lambdas: &[f64]) -> Array2<f64> {
2450        let d = self.marginal_dims.len();
2451        let p: usize = self.marginal_dims.iter().copied().product();
2452        let mut s = Array2::<f64>::zeros((p, p));
2453
2454        // Delegate the per-cell tensor-penalty accumulation to the shared
2455        // `kronecker_cell_sigma` (#1172/#1185 single source of truth). Fold the
2456        // `lambdas.len() > d` guard into `has_double` to preserve exact gating.
2457        let eigenvalue_views: Vec<ArrayView1<'_, f64>> =
2458            self.marginal_eigenvalues.iter().map(|m| m.view()).collect();
2459        let has_double = self.has_double_penalty && lambdas.len() > d;
2460        let mut multi_idx = vec![0usize; d];
2461        let mut flat = 0usize;
2462        loop {
2463            let (sigma, _structural_sigma, _joint_null) = kronecker_cell_sigma(
2464                &eigenvalue_views,
2465                &multi_idx,
2466                lambdas,
2467                d,
2468                has_double,
2469                self.penalty_shrinkage_ridge,
2470            );
2471            s[[flat, flat]] = sigma;
2472            flat += 1;
2473
2474            if kronecker_multi_index_advance(&mut multi_idx, &self.marginal_dims) {
2475                break;
2476            }
2477        }
2478        s
2479    }
2480
2481    /// Explicitly materialize the dense artifact bundle expected by legacy
2482    /// downstream consumers. This is not part of the native Kronecker solve path.
2483    pub fn materialize_dense_artifact_result(
2484        &self,
2485        rs_list: &[Array2<f64>],
2486        lambdas: &[f64],
2487        p: usize,
2488    ) -> Result<ReparamResult, EstimationError> {
2489        const KRONECKER_DENSE_COMPAT_FALLBACK_MAX_P: usize = 4096;
2490        if p > KRONECKER_DENSE_COMPAT_FALLBACK_MAX_P {
2491            return Err(EstimationError::LayoutError(format!(
2492                "Kronecker reparameterization would materialize dense {}x{} compatibility tensors; \
2493                 large-model dense fallback is disabled. Wire the downstream solver to consume \
2494                 the factored Kronecker result directly",
2495                p, p
2496            )));
2497        }
2498        let qs = self.materialize_qs();
2499        let s_transformed = self.materialize_s_transformed(lambdas);
2500
2501        // Transform penalty roots: R_k_transformed = R_k · Qs
2502        let rs_transformed: Vec<Array2<f64>> = if rs_list.len() >= 2 {
2503            use rayon::prelude::*;
2504            rs_list
2505                .par_iter()
2506                .map(|r| gam_linalg::faer_ndarray::fast_ab(r, &qs))
2507                .collect()
2508        } else {
2509            rs_list
2510                .iter()
2511                .map(|r| gam_linalg::faer_ndarray::fast_ab(r, &qs))
2512                .collect()
2513        };
2514        // rs_transposed removed — canonical_transformed is the single source of truth.
2515
2516        // Build e_transformed: combined penalty square root in transformed coords.
2517        // For Kronecker structure, the penalty is diagonal in the eigenbasis.
2518        // e_transformed rows are the nonzero rows of sqrt(Σ_k λ_k S_k)^{1/2}.
2519        let d = self.marginal_dims.len();
2520        // Delegate the per-cell tensor-penalty accumulation to the shared
2521        // `kronecker_cell_sigma` (the #1172/#1185 single source of truth). The
2522        // double-penalty term is only valid when `lambdas` actually carries the
2523        // λ_d entry, so fold the original `lambdas.len() > d` guard into the
2524        // `has_double_penalty` flag passed to the helper — preserving the exact
2525        // gating behavior.
2526        let eigenvalue_views: Vec<ArrayView1<'_, f64>> =
2527            self.marginal_eigenvalues.iter().map(|m| m.view()).collect();
2528        let has_double = self.has_double_penalty && lambdas.len() > d;
2529        let diag_vals: Vec<f64> = {
2530            let mut vals = Vec::with_capacity(p);
2531            let mut multi_idx = vec![0usize; d];
2532            loop {
2533                let (sigma, _structural_sigma, _joint_null) = kronecker_cell_sigma(
2534                    &eigenvalue_views,
2535                    &multi_idx,
2536                    lambdas,
2537                    d,
2538                    has_double,
2539                    self.penalty_shrinkage_ridge,
2540                );
2541                vals.push(if sigma > 0.0 { sigma.sqrt() } else { 0.0 });
2542
2543                if kronecker_multi_index_advance(&mut multi_idx, &self.marginal_dims) {
2544                    break;
2545                }
2546            }
2547            vals
2548        };
2549        let rank = diag_vals.iter().filter(|&&v| v > 1e-12).count();
2550        let mut e_transformed = Array2::<f64>::zeros((rank, p));
2551        let mut row = 0;
2552        for (j, &v) in diag_vals.iter().enumerate() {
2553            if v > 1e-12 {
2554                e_transformed[[row, j]] = v;
2555                row += 1;
2556            }
2557        }
2558
2559        // u_truncated: null-space eigenvectors (columns with zero eigenvalue).
2560        let null_count = p - rank;
2561        let mut u_truncated = Array2::<f64>::zeros((p, null_count));
2562        let mut col = 0;
2563        for (j, &v) in diag_vals.iter().enumerate() {
2564            if v <= 1e-12 {
2565                u_truncated[[j, col]] = 1.0; // standard basis vector in eigenbasis
2566                col += 1;
2567            }
2568        }
2569
2570        let canonical_transformed: Vec<CanonicalPenalty> = rs_transformed
2571            .iter()
2572            .map(|r| CanonicalPenalty::from_dense_root(r.clone(), p))
2573            .collect();
2574        Ok(ReparamResult {
2575            s_transformed,
2576            log_det: self.log_det,
2577            det1: self.det1.clone(),
2578            qs,
2579            canonical_transformed,
2580            e_transformed,
2581            u_truncated,
2582            penalty_shrinkage_ridge: self.penalty_shrinkage_ridge,
2583        })
2584    }
2585}
2586
/// Absolute threshold for treating a grid cell's unweighted marginal
/// eigenvalue sum as zero.
///
/// `kronecker_cell_sigma` compares `structural_sigma = Σ_k μ_k` against this
/// value: cells at or below it are classified as joint-null, and only cells
/// strictly above it receive the shrinkage ridge.
///
/// NOTE(review): the doc comment that previously sat on this constant
/// described `kronecker_logdet_and_derivatives` rather than the constant —
/// "Compute `log|S|₊` and its first/second derivatives w.r.t.
/// `ρ_k = log(λ_k)` from factored marginal eigenvalues; shared implementation
/// for `KroneckerPenaltySystem::logdet_and_derivatives` and
/// `kronecker_reparameterization_engine`, iterating over the ∏q_j multi-index
/// grid with no O(p²) storage." It appears to have been separated from that
/// function when this constant and the helpers below were inserted.
const KRONECKER_STRUCTURAL_ZERO_TOL: f64 = 1e-12;
2594
2595/// Per-cell Kronecker eigenvalue accumulation — the single source of truth for
2596/// the #1172/#1185 tensor-penalty math.
2597///
2598/// For the multi-index cell `multi_idx`, accumulates:
2599///   - `sigma`            = Σ_k λ_k · μ_k  (+ joint-null double-penalty term + ridge)
2600///   - `structural_sigma` = Σ_k μ_k        (unweighted; classifies joint-null cells)
2601///   - `joint_null`       = whether the cell lies in the joint null space
2602///
2603/// `marginal_eigenvalues[k][multi_idx[k]]` is the k-th marginal eigenvalue μ_k.
2604/// The double-penalty (global ridge) term `λ_d` is added only on joint-null
2605/// cells; the structural shrinkage `ridge` is added only on structurally
2606/// penalized cells. This mirrors the gated logic fixed in #1172/#1185 and MUST
2607/// be kept identical across every caller.
2608#[inline]
2609fn kronecker_cell_sigma(
2610    marginal_eigenvalues: &[ArrayView1<'_, f64>],
2611    multi_idx: &[usize],
2612    lambdas: &[f64],
2613    d: usize,
2614    has_double_penalty: bool,
2615    ridge: f64,
2616) -> (f64, f64, bool) {
2617    let mut sigma = 0.0;
2618    let mut structural_sigma = 0.0;
2619    for k in 0..d {
2620        let marginal_eigenvalue = marginal_eigenvalues[k][multi_idx[k]];
2621        structural_sigma += marginal_eigenvalue;
2622        sigma += lambdas[k] * marginal_eigenvalue;
2623    }
2624    let joint_null = structural_sigma <= KRONECKER_STRUCTURAL_ZERO_TOL;
2625    if has_double_penalty && joint_null {
2626        sigma += lambdas[d];
2627    }
2628    if structural_sigma > KRONECKER_STRUCTURAL_ZERO_TOL {
2629        sigma += ridge;
2630    }
2631    (sigma, structural_sigma, joint_null)
2632}
2633
/// Step `multi_idx` to the next cell of the `dims` grid, in place.
///
/// The last axis varies fastest (row-major order). Returns `true` once the
/// grid is exhausted, at which point the index has wrapped back to all zeros;
/// returns `false` while cells remain. An empty `dims` is exhausted at once.
#[inline]
fn kronecker_multi_index_advance(multi_idx: &mut [usize], dims: &[usize]) -> bool {
    for axis in (0..dims.len()).rev() {
        multi_idx[axis] += 1;
        if multi_idx[axis] < dims[axis] {
            // This axis absorbed the increment; nothing carries further.
            return false;
        }
        // Axis overflowed: reset it and carry into the next slower axis.
        multi_idx[axis] = 0;
    }
    true
}
2651
2652pub fn kronecker_logdet_and_derivatives(
2653    marginal_eigenvalues: &[ArrayView1<'_, f64>],
2654    marginal_dims: &[usize],
2655    lambdas: &[f64],
2656    has_double_penalty: bool,
2657    ridge: f64,
2658) -> (f64, Array1<f64>, Array2<f64>) {
2659    let d = marginal_dims.len();
2660    let n_pen = d + if has_double_penalty { 1 } else { 0 };
2661
2662    let mut logdet = 0.0;
2663    let mut grad = Array1::<f64>::zeros(n_pen);
2664    let mut hess = Array2::<f64>::zeros((n_pen, n_pen));
2665    let tol = 1e-12;
2666
2667    let mut multi_idx = vec![0usize; d];
2668    loop {
2669        let (sigma, _structural_sigma, joint_null) = kronecker_cell_sigma(
2670            marginal_eigenvalues,
2671            &multi_idx,
2672            lambdas,
2673            d,
2674            has_double_penalty,
2675            ridge,
2676        );
2677
2678        if sigma > tol {
2679            logdet += sigma.ln();
2680            let inv_sigma = 1.0 / sigma;
2681            let inv_sigma2 = inv_sigma * inv_sigma;
2682
2683            for k in 0..d {
2684                let ck = lambdas[k] * marginal_eigenvalues[k][multi_idx[k]];
2685                grad[k] += ck * inv_sigma;
2686            }
2687            if has_double_penalty && joint_null {
2688                grad[d] += lambdas[d] * inv_sigma;
2689            }
2690
2691            for k in 0..n_pen {
2692                let ck = if k < d {
2693                    lambdas[k] * marginal_eigenvalues[k][multi_idx[k]]
2694                } else if joint_null {
2695                    lambdas[d]
2696                } else {
2697                    0.0
2698                };
2699                // When ck == 0 (a zero λ, a zero marginal eigenvalue, or a cell
2700                // outside the joint null for the ridge penalty) every term this
2701                // index k contributes — `ck·inv_sigma − ck²·inv_sigma2` on the
2702                // diagonal and `−ck·cl·inv_sigma2` on every off-diagonal — is
2703                // exactly 0.0, so adding them to the finite running accumulators
2704                // is a bit-identical no-op. Skip the inner sweep entirely.
2705                if ck == 0.0 {
2706                    continue;
2707                }
2708                hess[[k, k]] += ck * inv_sigma - ck * ck * inv_sigma2;
2709                for l in (k + 1)..n_pen {
2710                    let cl = if l < d {
2711                        lambdas[l] * marginal_eigenvalues[l][multi_idx[l]]
2712                    } else if joint_null {
2713                        lambdas[d]
2714                    } else {
2715                        0.0
2716                    };
2717                    let off = -ck * cl * inv_sigma2;
2718                    hess[[k, l]] += off;
2719                    hess[[l, k]] += off;
2720                }
2721            }
2722        }
2723
2724        if kronecker_multi_index_advance(&mut multi_idx, marginal_dims) {
2725            break;
2726        }
2727    }
2728
2729    (logdet, grad, hess)
2730}
2731
2732// #1521: `KroneckerInvariantStructure` is defined once in `crate::kronecker`
2733// (the leaf data+compute module). The byte-identical copy that the carve left
2734// here is replaced by an import so the cache and this engine share one type.
2735use crate::kronecker::KroneckerInvariantStructure;
2736
2737/// Kronecker-factored reparameterization for tensor-product penalties.
2738///
2739/// Instead of eigendecomposing the full p×p balanced penalty (O(p³)), this
2740/// eigendecomposes each marginal penalty separately (O(Σ q_k³)) and computes
2741/// the joint eigensystem as the Kronecker product of marginal eigensystems.
2742pub fn kronecker_reparameterization_engine(
2743    marginal_designs: &[Array2<f64>],
2744    marginal_penalties: &[Array2<f64>],
2745    marginal_dims: &[usize],
2746    lambdas: &[f64],
2747    has_double_penalty: bool,
2748    penalty_shrinkage_floor: Option<f64>,
2749) -> Result<KroneckerReparamResult, EstimationError> {
2750    let d = marginal_dims.len();
2751    if marginal_designs.len() != d || marginal_penalties.len() != d {
2752        return Err(EstimationError::LayoutError(format!(
2753            "kronecker_reparameterization_engine: dimension mismatch: designs={}, penalties={}, dims={}",
2754            marginal_designs.len(),
2755            marginal_penalties.len(),
2756            d
2757        )));
2758    }
2759
2760    let invariant =
2761        KroneckerInvariantStructure::compute(marginal_designs, marginal_penalties, marginal_dims)?;
2762    kronecker_reparameterization_engine_with_invariant(
2763        &invariant,
2764        marginal_dims,
2765        lambdas,
2766        has_double_penalty,
2767        penalty_shrinkage_floor,
2768    )
2769}
2770
2771/// Kronecker-factored reparameterization reusing a precomputed λ-invariant
2772/// structure (eigensystems, reparameterized marginals, shrinkage scale).
2773///
2774/// Bit-identical to `kronecker_reparameterization_engine` for the same marginal
2775/// data — the only difference is that the `eigh()` / `B_k U_k` work was hoisted
2776/// out of the per-iterate path into the cached `invariant`. Only the λ-dependent
2777/// `kronecker_logdet_and_derivatives` sweep and `floor * max_bal` scaling run here.
2778pub fn kronecker_reparameterization_engine_with_invariant(
2779    invariant: &KroneckerInvariantStructure,
2780    marginal_dims: &[usize],
2781    lambdas: &[f64],
2782    has_double_penalty: bool,
2783    penalty_shrinkage_floor: Option<f64>,
2784) -> Result<KroneckerReparamResult, EstimationError> {
2785    // Arc refcount bumps — the underlying eigensystems / reparameterized
2786    // marginals are λ-invariant and shared with the cache, not deep-copied.
2787    let marginal_eigenvalues = Arc::clone(&invariant.marginal_eigenvalues);
2788    let marginal_qs = Arc::clone(&invariant.marginal_qs);
2789    let reparameterized_marginals = Arc::clone(&invariant.reparameterized_marginals);
2790
2791    // Compute shrinkage ridge from balanced penalty eigenvalue scale.
2792    let penalty_shrinkage_ridge = if let Some(floor) = penalty_shrinkage_floor {
2793        floor * invariant.max_balanced_eigenvalue
2794    } else {
2795        0.0
2796    };
2797
2798    let marginal_eigenvalue_views: Vec<_> = marginal_eigenvalues
2799        .iter()
2800        .map(|evals| evals.view())
2801        .collect();
2802    let (log_det, det1, det2) = kronecker_logdet_and_derivatives(
2803        &marginal_eigenvalue_views,
2804        marginal_dims,
2805        lambdas,
2806        has_double_penalty,
2807        penalty_shrinkage_ridge,
2808    );
2809
2810    Ok(KroneckerReparamResult {
2811        reparameterized_marginals,
2812        marginal_eigenvalues,
2813        marginal_qs,
2814        log_det,
2815        det1,
2816        det2,
2817        penalty_shrinkage_ridge,
2818        has_double_penalty,
2819        marginal_dims: marginal_dims.to_vec(),
2820    })
2821}
2822
2823/// Calculate the 2-norm condition number of a matrix.
2824///
2825/// For symmetric matrices (the dominant case for GAM Hessians/penalties),
2826/// this uses an eigenvalue path and computes:
2827///   cond_2(A) = max_i |lambda_i| / min_i |lambda_i|
2828/// which is exactly equal to the singular-value definition for symmetric A.
2829///
2830/// For non-symmetric matrices, this falls back to SVD:
2831///   cond_2(A) = sigma_max / sigma_min
2832///
2833/// This preserves semantics while avoiding full SVD in hot paths.
2834///
2835/// # Arguments
2836/// * `matrix` - The matrix to analyze
2837///
2838/// # Returns
2839/// * `Ok(condition_number)` - The condition number (max_sv / min_sv)
2840/// * `Ok(f64::INFINITY)` - If the matrix is effectively singular (min_sv < 1e-12)
2841/// * `Err` - If SVD computation fails
2842pub fn calculate_condition_number(matrix: &Array2<f64>) -> Result<f64, FaerLinalgError> {
2843    let (rows, cols) = matrix.dim();
2844    if rows == 0 || cols == 0 {
2845        return Ok(1.0);
2846    }
2847
2848    // Fast path for (near-)symmetric square matrices.
2849    if rows == cols {
2850        let mut max_abs = 0.0_f64;
2851        let mut max_asym = 0.0_f64;
2852        for i in 0..rows {
2853            for j in 0..cols {
2854                max_abs = max_abs.max(matrix[[i, j]].abs());
2855            }
2856            for j in 0..i {
2857                let diff = (matrix[[i, j]] - matrix[[j, i]]).abs();
2858                if diff > max_asym {
2859                    max_asym = diff;
2860                }
2861            }
2862        }
2863        let sym_tol = max_abs.max(1.0) * 1e-12;
2864        if max_asym <= sym_tol {
2865            let (evals, _) = matrix.eigh(Side::Lower)?;
2866            let mut max_abs_eval = 0.0_f64;
2867            let mut min_abs_eval = f64::INFINITY;
2868            for &lam in evals.iter() {
2869                let s = lam.abs();
2870                max_abs_eval = max_abs_eval.max(s);
2871                min_abs_eval = min_abs_eval.min(s);
2872            }
2873            if min_abs_eval < 1e-12 {
2874                return Ok(f64::INFINITY);
2875            }
2876            return Ok(max_abs_eval / min_abs_eval);
2877        }
2878    }
2879
2880    // General matrix fallback.
2881    let (_, s, _) = matrix.svd(false, false)?;
2882    let max_sv = s.iter().fold(0.0_f64, |max, &val| max.max(val));
2883    let min_sv = s.iter().fold(f64::INFINITY, |min, &val| min.min(val));
2884    if min_sv < 1e-12 {
2885        return Ok(f64::INFINITY);
2886    }
2887    Ok(max_sv / min_sv)
2888}
2889
2890#[cfg(test)]
2891mod tests {
2892    use super::{
2893        CanonicalPenalty, SubspaceLeakageMetrics, assess_subspace_leakage,
2894        classify_eigenvalues_strict, precompute_reparam_invariant_from_canonical,
2895        report_penalty_pair_redundancy, stable_reparameterizationwith_invariant,
2896    };
2897    use crate::construction::kronecker_product;
2898    use crate::EstimationError;
2899    use faer::Mat;
2900    use gam_linalg::faer_ndarray::FaerEigh;
2901    use gam_linalg::utils::inf_norm;
2902    use ndarray::{Array1, Array2, array};
2903
2904    /// Build CanonicalPenalty values from full-width roots for tests.
2905    fn canonical_from_roots(rs_list: &[Array2<f64>], p: usize) -> Vec<CanonicalPenalty> {
2906        rs_list
2907            .iter()
2908            .map(|r| {
2909                let local = r.t().dot(r);
2910                CanonicalPenalty {
2911                    root: r.clone(),
2912                    col_range: 0..p,
2913                    total_dim: p,
2914                    nullity: 0,
2915                    local,
2916                    prior_mean: Array1::zeros(p),
2917                    positive_eigenvalues: Vec::new(),
2918                    op: None,
2919                }
2920            })
2921            .collect()
2922    }
2923
    /// Test helper: forwards its arguments unchanged to
    /// `assess_subspace_leakage` and returns the resulting metrics.
    fn metrics_for(
        qs: &Mat<f64>,
        rs: &[Mat<f64>],
        structural_rank: usize,
        p: usize,
    ) -> SubspaceLeakageMetrics {
        assess_subspace_leakage(qs, rs, structural_rank, p)
    }
2932
2933    #[test]
2934    fn subspace_leakage_iszero_for_clean_split() {
2935        let p = 4usize;
2936        let structural_rank = 2usize;
2937        let qs = Mat::<f64>::identity(p, p);
2938        let mut r0 = Mat::<f64>::zeros(2, p);
2939        r0[(0, 0)] = 1.0;
2940        r0[(1, 1)] = 2.0;
2941
2942        let m = metrics_for(&qs, &[r0], structural_rank, p);
2943        assert!(m.max_abs_sq <= 1e-16);
2944        assert!(m.max_rel_sq <= 1e-16);
2945        assert!(m.max_cross_gram_abs <= 1e-16);
2946    }
2947
2948    #[test]
2949    fn subspace_leakage_detects_null_column_energy() {
2950        let p = 4usize;
2951        let structural_rank = 2usize;
2952        let qs = Mat::<f64>::identity(p, p);
2953        let mut r0 = Mat::<f64>::zeros(1, p);
2954        r0[(0, 2)] = 3.0;
2955
2956        let m = metrics_for(&qs, &[r0], structural_rank, p);
2957        assert!(m.max_abs_sq > 0.0);
2958        assert!(m.max_rel_sq > 0.99);
2959    }
2960
2961    #[test]
2962    fn subspace_leakage_detects_qp_qn_nonorthogonality() {
2963        let p = 3usize;
2964        let structural_rank = 1usize;
2965        let mut qs = Mat::<f64>::identity(p, p);
2966        qs[(0, 1)] = 0.2;
2967        let r0 = Mat::<f64>::zeros(1, p);
2968
2969        let m = metrics_for(&qs, &[r0], structural_rank, p);
2970        assert!(m.max_cross_gram_abs > 1e-3);
2971    }
2972
2973    #[test]
2974    fn u_truncated_is_transformed_frame_in_nonzero_case() {
2975        let p = 3usize;
2976        let rs_list = vec![array![[1.0, 0.0, 0.0]]];
2977        let canonical = canonical_from_roots(&rs_list, p);
2978        let lambdas = vec![2.0];
2979        let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
2980            .expect("precompute invariant");
2981        let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
2982            .expect("stable reparam");
2983
2984        let expected = rep.qs.t().dot(&inv.split.q_null);
2985        let diff = &rep.u_truncated - &expected;
2986        let max_abs = inf_norm(diff.iter().copied());
2987        assert!(
2988            max_abs <= 1e-10,
2989            "u_truncated frame mismatch: max_abs={max_abs}"
2990        );
2991    }
2992
2993    #[test]
2994    fn infinite_lambda_keeps_range_penalty_block_finite_1379() {
2995        // gam#1379 / gam#1074: a genuinely infinite λ = exp(ρ) is NOT silently
2996        // clamped to a finite ceiling. The original #1379 fix added a 1e300
2997        // ceiling so `∞ · 0` could not poison the range block Σ_k λ_k S_k, but
2998        // #1074 DELETED that clamp on purpose (see the comment at the top of
2999        // `stable_reparameterizationwith_invariant`): masking ∞ hid the real
3000        // defect — the outer optimizer driving a redundant/unidentified penalty
3001        // direction off to ∞ instead of that direction being detected and
3002        // dropped. With the clamp gone, a literal `f64::INFINITY` λ surfaces as
3003        // a clean, detectable error (the eigensolver rejects the NaN-poisoned
3004        // block) rather than a silent finite success. Pin that contract: ∞ must
3005        // ERROR, not be quietly clamped.
3006        //
3007        // Fixture: two penalties on a 3-wide block. The first penalizes only
3008        // coordinate 0 (so its block S_k has structural zeros everywhere except
3009        // [0,0]); give it λ = +∞. The second penalizes coordinate 1 at a normal
3010        // λ.
3011        let p = 3usize;
3012        let rs_list = vec![array![[1.0, 0.0, 0.0]], array![[0.0, 1.0, 0.0]]];
3013        let canonical = canonical_from_roots(&rs_list, p);
3014        let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3015            .expect("precompute invariant");
3016
3017        let lambdas_inf = vec![f64::INFINITY, 3.0];
3018        let inf_result =
3019            stable_reparameterizationwith_invariant(&canonical, &lambdas_inf, p, &inv, None);
3020        assert!(
3021            inf_result.is_err(),
3022            "an infinite lambda must surface as an error, not be silently clamped (#1074)"
3023        );
3024
3025        // A finite (even very large) λ must still produce an all-finite reparam:
3026        // the function is robust to large-but-finite penalties; only the
3027        // non-finite input is rejected.
3028        let lambdas_big = vec![1e300_f64, 3.0];
3029        let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas_big, p, &inv, None)
3030            .expect("stable reparam at large-but-finite lambda");
3031        assert!(
3032            rep.s_transformed.iter().all(|v| v.is_finite()),
3033            "transformed penalty must be finite at large-but-finite lambda"
3034        );
3035        assert!(
3036            rep.qs.iter().all(|v| v.is_finite()),
3037            "reparam rotation must be finite at large-but-finite lambda"
3038        );
3039        assert!(
3040            rep.log_det.is_finite(),
3041            "penalty log-det must be finite at large-but-finite lambda"
3042        );
3043        assert!(
3044            rep.det1.iter().all(|v| v.is_finite()),
3045            "penalty log-det derivatives must be finite at large-but-finite lambda"
3046        );
3047    }
3048
3049    #[test]
3050    fn u_truncated_is_identitywhen_no_penalties() {
3051        let p = 4usize;
3052        let canonical: Vec<CanonicalPenalty> = Vec::new();
3053        let lambdas: Vec<f64> = Vec::new();
3054        let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3055            .expect("precompute invariant");
3056        let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
3057            .expect("stable reparam");
3058        assert_eq!(rep.u_truncated, Array2::<f64>::eye(p));
3059    }
3060
/// A hand-built invariant whose "penalized" basis has two columns, while the
/// single rank-one penalty only touches the first coordinate. With a shrinkage
/// floor of `1e-6`, coordinate 0 must exceed `2.0` and coordinate 1 must stay
/// (numerically) unpenalized.
#[test]
fn dense_shrinkage_floor_skips_structurally_unpenalized_range_columns() {
    let p = 3usize;
    let canonical = canonical_from_roots(&[array![[1.0, 0.0, 0.0]]], p);

    // Assemble the subspace split first, then wrap it in the invariant.
    let split = super::SubspaceSplit {
        q_pen: array![[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]],
        q_null: array![[0.0], [0.0], [1.0]],
    };
    let invariant = super::ReparamInvariant {
        split,
        qs_base: Array2::eye(p),
        has_nonzero: true,
        max_balanced_eigenvalue: 1.0,
    };

    let rep =
        stable_reparameterizationwith_invariant(&canonical, &[2.0], p, &invariant, Some(1e-6))
            .expect("stable reparameterization");

    let penalized_diag = rep.s_transformed[[0, 0]];
    let unpenalized_diag = rep.s_transformed[[1, 1]];
    assert!(penalized_diag > 2.0);
    assert!(
        unpenalized_diag <= 1e-11,
        "structurally unpenalized range coordinate received shrinkage ridge: {}",
        unpenalized_diag
    );
}
3085
/// Two 2x2 marginals, each with a one-dimensional null space, run through the
/// Kronecker engine with a shrinkage floor of `1e-6`. The ridge must be active,
/// the joint null direction (index 0) must stay unpenalized, and every other
/// diagonal entry must exceed its un-ridged value.
#[test]
fn kronecker_shrinkage_floor_preserves_joint_null_space() {
    let identity = Array2::<f64>::eye(2);
    let marginal_designs = vec![identity.clone(), identity];
    let marginal_penalties = vec![
        array![[0.0, 0.0], [0.0, 2.0]],
        array![[0.0, 0.0], [0.0, 3.0]],
    ];
    let marginal_dims = vec![2usize, 2usize];
    let lambdas = vec![5.0, 7.0];

    let rep = super::kronecker_reparameterization_engine(
        &marginal_designs,
        &marginal_penalties,
        &marginal_dims,
        &lambdas,
        false,
        Some(1e-6),
    )
    .expect("kronecker reparameterization");
    assert!(rep.penalty_shrinkage_ridge > 0.0);

    let s = rep.materialize_s_transformed(&lambdas);
    let joint_null_diag = s[[0, 0]];
    assert!(
        joint_null_diag.abs() <= 1e-14,
        "joint tensor null direction must remain unpenalized, got {}",
        joint_null_diag
    );
    // Strict inequalities: the ridge adds on top of the plain lambda-weighted sums.
    assert!(s[[1, 1]] > lambdas[1] * 3.0);
    assert!(s[[2, 2]] > lambdas[0] * 2.0);
    assert!(s[[3, 3]] > lambdas[0] * 2.0 + lambdas[1] * 3.0);

    // Penalty roots written out in the 4-dimensional tensor coordinates.
    let (root2, root3) = (2.0_f64.sqrt(), 3.0_f64.sqrt());
    let tensor_roots = vec![
        array![[0.0, 0.0, root2, 0.0], [0.0, 0.0, 0.0, root2]],
        array![[0.0, root3, 0.0, 0.0], [0.0, 0.0, 0.0, root3]],
    ];
    let dense = rep
        .materialize_dense_artifact_result(&tensor_roots, &lambdas, 4)
        .expect("dense artifact materialization");
    assert_eq!(dense.e_transformed.nrows(), 3);
    assert_eq!(dense.u_truncated.ncols(), 1);
}
3133
/// The memoized Kronecker engine (marginal eigensystems computed once and
/// reused) must return results that are bit-identical to the unmemoized engine
/// for the same marginal data and lambdas.
///
/// Several lambda vectors are checked against one fixed invariant structure,
/// each with and without a shrinkage floor. Container lengths are asserted
/// before every element-wise comparison, because `zip` stops at the shorter
/// side and would otherwise let a truncated result pass.
#[test]
fn kronecker_memoized_invariant_is_bit_identical_to_unmemoized_engine() {
    let marginal_designs = vec![
        array![[1.0, 0.3, -0.2], [0.4, 1.0, 0.1], [-0.1, 0.2, 1.0]],
        array![[1.0, -0.5], [0.2, 1.0], [0.7, 0.3]],
    ];
    let marginal_penalties = vec![
        array![[2.0, -1.0, 0.0], [-1.0, 2.0, -1.0], [0.0, -1.0, 1.0]],
        array![[3.0, -1.5], [-1.5, 3.0]],
    ];
    let marginal_dims = vec![3usize, 2usize];

    let invariant = super::KroneckerInvariantStructure::compute(
        &marginal_designs,
        &marginal_penalties,
        &marginal_dims,
    )
    .expect("invariant structure");

    for lambdas in [
        vec![5.0, 7.0],
        vec![0.0, 7.0],
        vec![5.0, 0.0],
        vec![1e-3, 1e3],
    ] {
        for floor in [None, Some(1e-6)] {
            let unmemoized = super::kronecker_reparameterization_engine(
                &marginal_designs,
                &marginal_penalties,
                &marginal_dims,
                &lambdas,
                true,
                floor,
            )
            .expect("unmemoized engine");
            let memoized = super::kronecker_reparameterization_engine_with_invariant(
                &invariant,
                &marginal_dims,
                &lambdas,
                true,
                floor,
            )
            .expect("memoized engine");

            // Scalars: compare raw bit patterns, not approximate equality.
            assert_eq!(memoized.log_det.to_bits(), unmemoized.log_det.to_bits());
            assert_eq!(
                memoized.penalty_shrinkage_ridge.to_bits(),
                unmemoized.penalty_shrinkage_ridge.to_bits()
            );

            // First derivatives of the log-determinant.
            assert_eq!(
                memoized.det1.len(),
                unmemoized.det1.len(),
                "det1 length differs between memoized and unmemoized engines"
            );
            for (a, b) in memoized.det1.iter().zip(unmemoized.det1.iter()) {
                assert_eq!(a.to_bits(), b.to_bits());
            }

            // Second derivatives of the log-determinant.
            assert_eq!(
                memoized.det2.len(),
                unmemoized.det2.len(),
                "det2 element count differs between memoized and unmemoized engines"
            );
            for (a, b) in memoized.det2.iter().zip(unmemoized.det2.iter()) {
                assert_eq!(a.to_bits(), b.to_bits());
            }

            // Reparameterized marginals: same count, then same entries per marginal.
            assert_eq!(
                memoized.reparameterized_marginals.len(),
                unmemoized.reparameterized_marginals.len(),
                "number of reparameterized marginals differs"
            );
            for (ma, ua) in memoized
                .reparameterized_marginals
                .iter()
                .zip(unmemoized.reparameterized_marginals.iter())
            {
                assert_eq!(
                    ma.len(),
                    ua.len(),
                    "reparameterized marginal element count differs"
                );
                for (a, b) in ma.iter().zip(ua.iter()) {
                    assert_eq!(a.to_bits(), b.to_bits());
                }
            }

            // Marginal rotations: same count, then same entries per rotation.
            assert_eq!(
                memoized.marginal_qs.len(),
                unmemoized.marginal_qs.len(),
                "number of marginal rotations differs"
            );
            for (mq, uq) in memoized
                .marginal_qs
                .iter()
                .zip(unmemoized.marginal_qs.iter())
            {
                assert_eq!(mq.len(), uq.len(), "marginal rotation element count differs");
                for (a, b) in mq.iter().zip(uq.iter()) {
                    assert_eq!(a.to_bits(), b.to_bits());
                }
            }
        }
    }
}
3216
/// Three lambdas on two marginals: the third (`11.0`) must show up only on the
/// joint null mode (diagonal index 0), both in the factored result and in the
/// dense artifact built from explicit tensor roots.
#[test]
fn kronecker_double_penalty_shrinks_only_joint_null_space() {
    let identity = Array2::<f64>::eye(2);
    let marginal_designs = vec![identity.clone(), identity];
    let marginal_penalties = vec![
        array![[0.0, 0.0], [0.0, 2.0]],
        array![[0.0, 0.0], [0.0, 3.0]],
    ];
    let marginal_dims = vec![2usize, 2usize];
    let lambdas = vec![5.0, 7.0, 11.0];

    let rep = super::kronecker_reparameterization_engine(
        &marginal_designs,
        &marginal_penalties,
        &marginal_dims,
        &lambdas,
        true,
        None,
    )
    .expect("kronecker reparameterization");

    // Expected diagonal: [11, 7*3, 5*2, 5*2 + 7*3].
    let expected: [f64; 4] = [11.0, 21.0, 10.0, 31.0];

    let s = rep.materialize_s_transformed(&lambdas);
    for (idx, &expected_diag) in expected.iter().enumerate() {
        assert!(
            (s[[idx, idx]] - expected_diag).abs() <= 1e-12,
            "diagonal {idx} got {}, expected {expected_diag}",
            s[[idx, idx]]
        );
    }

    let expected_logdet = expected.iter().map(|v| v.ln()).sum::<f64>();
    assert!((rep.log_det - expected_logdet).abs() <= 1e-12);

    let joint_null_det1 = rep.det1[2];
    assert!(
        (joint_null_det1 - 1.0).abs() <= 1e-12,
        "double-penalty derivative must come only from the joint null mode, got {}",
        joint_null_det1
    );
    assert!(rep.det2[[2, 2]].abs() <= 1e-12);

    // Same diagonal must come out of the dense materialization path.
    let (root2, root3) = (2.0_f64.sqrt(), 3.0_f64.sqrt());
    let tensor_roots = vec![
        array![[0.0, 0.0, root2, 0.0], [0.0, 0.0, 0.0, root2]],
        array![[0.0, root3, 0.0, 0.0], [0.0, 0.0, 0.0, root3]],
    ];
    let dense = rep
        .materialize_dense_artifact_result(&tensor_roots, &lambdas, 4)
        .expect("dense artifact materialization");
    for (idx, &expected_diag) in expected.iter().enumerate() {
        assert!(
            (dense.s_transformed[[idx, idx]] - expected_diag).abs() <= 1e-12,
            "dense artifact diagonal {idx} got {}, expected {expected_diag}",
            dense.s_transformed[[idx, idx]]
        );
    }
}
3278    #[test]
3279    fn transformed_penalty_is_diagonal_in_transformed_frame() {
3280        let p = 3usize;
3281        let inv_sqrt2 = 2.0_f64.sqrt().recip();
3282        // Penalize a rotated direction in original space so Qs is non-trivial.
3283        let rs_list = vec![array![[inv_sqrt2, inv_sqrt2, 0.0]]];
3284        let canonical = canonical_from_roots(&rs_list, p);
3285        let lambdas = vec![4.0];
3286        let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3287            .expect("precompute invariant");
3288        let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
3289            .expect("stable reparam");
3290
3291        assert_eq!(rep.e_transformed.nrows(), 1);
3292        assert!(rep.e_transformed[[0, 0]].abs() > 0.0);
3293        assert!(rep.e_transformed[[0, 1]].abs() <= 1e-12);
3294        assert!(rep.e_transformed[[0, 2]].abs() <= 1e-12);
3295        // Exact pseudo-logdet on the structural penalized block has no
3296        // delta-dependent nullspace normalization.
3297        let expected_det1 = 1.0_f64;
3298        assert!((rep.det1[0] - expected_det1).abs() <= 1e-12);
3299
3300        let s = rep.s_transformed;
3301        let mut max_offdiag = 0.0_f64;
3302        for i in 0..p {
3303            for j in 0..p {
3304                if i != j {
3305                    max_offdiag = max_offdiag.max(s[[i, j]].abs());
3306                }
3307            }
3308        }
3309        assert!(
3310            max_offdiag <= 1e-10,
3311            "transformed penalty should be diagonal, max offdiag={max_offdiag}"
3312        );
3313        assert!(s[[1, 1]].abs() <= 1e-10);
3314        assert!(s[[2, 2]].abs() <= 1e-10);
3315    }
3316
3317    #[test]
3318    fn det1_matches_rank_for_single_full_rank_penalty() {
3319        let p = 2usize;
3320        let inv_sqrt2 = 2.0_f64.sqrt().recip();
3321        // Q^T for a 45-degree rotation.
3322        let q_t = [[inv_sqrt2, inv_sqrt2], [-inv_sqrt2, inv_sqrt2]];
3323        // R = diag(3, 1) * Q^T gives S = Q * diag(9, 1) * Q^T.
3324        let rs = array![
3325            [3.0 * q_t[0][0], 3.0 * q_t[0][1]],
3326            [1.0 * q_t[1][0], 1.0 * q_t[1][1]]
3327        ];
3328        let rs_list = vec![rs];
3329        let canonical = canonical_from_roots(&rs_list, p);
3330        let lambdas = vec![5.0];
3331
3332        let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3333            .expect("precompute invariant");
3334        let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
3335            .expect("stable reparam");
3336
3337        assert_eq!(rep.e_transformed.nrows(), p);
3338        let det1 = rep.det1[0];
3339        // Exact pseudo-logdet on the structural penalized block:
3340        //   det1 = lambda * sum_l d_l / (lambda*d_l)
3341        // where d_l are eigenvalues of S_k.
3342        let s_k_eigs = [9.0_f64, 1.0_f64];
3343        let lambda = 5.0_f64;
3344        let expected_det1: f64 = s_k_eigs.iter().map(|&d| lambda * d / (lambda * d)).sum();
3345        assert!(
3346            (det1 - expected_det1).abs() <= 1e-12,
3347            "expected det1={expected_det1}, got {det1}",
3348        );
3349
3350        let s = rep.s_transformed;
3351        assert!(s[[0, 1]].abs() <= 1e-10);
3352        assert!(s[[1, 0]].abs() <= 1e-10);
3353        assert!(s[[0, 0]] > 0.0);
3354        assert!(s[[1, 1]] > 0.0);
3355    }
3356
3357    #[test]
3358    fn kronecker_reparam_logdet_matches_dense() {
3359        // 2D tensor product: q1=3, q2=4.
3360        // Marginal penalties: second-order difference matrices.
3361        let q1 = 3;
3362        let q2 = 4;
3363        let s1 = {
3364            let mut s = Array2::<f64>::zeros((q1, q1));
3365            // D2' D2 for order 2 on 3 points: [[1,-2,1],[-2,4,-2],[1,-2,1]]... simplified
3366            s[[0, 0]] = 1.0;
3367            s[[0, 1]] = -1.0;
3368            s[[1, 0]] = -1.0;
3369            s[[1, 1]] = 2.0;
3370            s[[1, 2]] = -1.0;
3371            s[[2, 1]] = -1.0;
3372            s[[2, 2]] = 1.0;
3373            s
3374        };
3375        let s2 = {
3376            let mut s = Array2::<f64>::zeros((q2, q2));
3377            s[[0, 0]] = 1.0;
3378            s[[0, 1]] = -1.0;
3379            s[[1, 0]] = -1.0;
3380            s[[1, 1]] = 2.0;
3381            s[[1, 2]] = -1.0;
3382            s[[2, 1]] = -1.0;
3383            s[[2, 2]] = 2.0;
3384            s[[2, 3]] = -1.0;
3385            s[[3, 2]] = -1.0;
3386            s[[3, 3]] = 1.0;
3387            s
3388        };
3389
3390        let lambdas = [2.5, 1.3];
3391        // Build dense Kronecker penalty: λ1 (S1⊗I) + λ2 (I⊗S2).
3392        let p = q1 * q2;
3393        let i1 = Array2::<f64>::eye(q1);
3394        let i2 = Array2::<f64>::eye(q2);
3395        let pen0 = kronecker_product(&s1, &i2);
3396        let pen1 = kronecker_product(&i1, &s2);
3397        let mut s_dense = Array2::<f64>::zeros((p, p));
3398        s_dense.scaled_add(lambdas[0], &pen0);
3399        s_dense.scaled_add(lambdas[1], &pen1);
3400
3401        // Dense eigendecomposition for reference pseudo-logdet.
3402        let (evals_dense, _): (ndarray::Array1<f64>, ndarray::Array2<f64>) =
3403            s_dense.eigh(faer::Side::Lower).unwrap();
3404        let tol = 1e-12;
3405        let ref_logdet: f64 = evals_dense
3406            .iter()
3407            .filter(|&&v: &&f64| v > tol)
3408            .map(|&v: &f64| v.ln())
3409            .sum();
3410
3411        // Kronecker reparameterization engine.
3412        let marginal_designs = vec![
3413            Array2::<f64>::eye(q1), // dummy designs
3414            Array2::<f64>::eye(q2),
3415        ];
3416        let marginal_penalties = vec![s1, s2];
3417        let kron_result = super::kronecker_reparameterization_engine(
3418            &marginal_designs,
3419            &marginal_penalties,
3420            &[q1, q2],
3421            &lambdas,
3422            false,
3423            None,
3424        )
3425        .unwrap();
3426
3427        let diff = (kron_result.log_det - ref_logdet).abs();
3428        assert!(
3429            diff < 1e-8,
3430            "Kronecker logdet {:.10} vs dense {:.10}, diff={:.3e}",
3431            kron_result.log_det,
3432            ref_logdet,
3433            diff,
3434        );
3435
3436        // Check derivatives via central FD in rho-space (rho = log lambda).
3437        let rhos: Vec<f64> = lambdas.iter().map(|&l| l.ln()).collect();
3438        let eps = 1e-5;
3439        for k in 0..2 {
3440            let mut rho_plus = rhos.clone();
3441            rho_plus[k] += eps;
3442            let mut rho_minus = rhos.clone();
3443            rho_minus[k] -= eps;
3444            let lam_plus: Vec<f64> = rho_plus.iter().map(|&r| r.exp()).collect();
3445            let lam_minus: Vec<f64> = rho_minus.iter().map(|&r| r.exp()).collect();
3446            let result_plus = super::kronecker_reparameterization_engine(
3447                &marginal_designs,
3448                &marginal_penalties,
3449                &[q1, q2],
3450                &lam_plus,
3451                false,
3452                None,
3453            )
3454            .unwrap();
3455            let result_minus = super::kronecker_reparameterization_engine(
3456                &marginal_designs,
3457                &marginal_penalties,
3458                &[q1, q2],
3459                &lam_minus,
3460                false,
3461                None,
3462            )
3463            .unwrap();
3464            let fd_deriv = (result_plus.log_det - result_minus.log_det) / (2.0 * eps);
3465            let analytic_deriv = kron_result.det1[k];
3466            let rel_err = if analytic_deriv.abs() > 1e-10 {
3467                (fd_deriv - analytic_deriv).abs() / analytic_deriv.abs()
3468            } else {
3469                (fd_deriv - analytic_deriv).abs()
3470            };
3471            assert!(
3472                rel_err < 1e-4,
3473                "det1[{k}] mismatch: analytic={:.8}, fd={:.8}, rel_err={:.3e}",
3474                analytic_deriv,
3475                fd_deriv,
3476                rel_err,
3477            );
3478        }
3479    }
3480
3481    #[test]
3482    fn classify_strict_rejects_nan_eigenvalue() {
3483        let mut eigs = [1.0, f64::NAN, 0.5];
3484        match classify_eigenvalues_strict(&mut eigs, "test_nan") {
3485            Err(EstimationError::PenaltySpectrumNonFinite {
3486                context,
3487                index,
3488                value,
3489            }) => {
3490                assert_eq!(context, "test_nan");
3491                assert_eq!(index, 1);
3492                assert!(value.is_nan());
3493            }
3494            other => panic!("expected PenaltySpectrumNonFinite, got {:?}", other),
3495        }
3496    }
3497
/// An infinite eigenvalue must be rejected as `PenaltySpectrumNonFinite`,
/// pointing at index 2 and carrying the infinite value.
#[test]
fn classify_strict_rejects_inf_eigenvalue() {
    let mut eigs = [1.0, 0.5, f64::INFINITY];
    let outcome = classify_eigenvalues_strict(&mut eigs, "test_inf");

    let (index, value) = match outcome {
        Err(EstimationError::PenaltySpectrumNonFinite { index, value, .. }) => (index, value),
        other => panic!("expected PenaltySpectrumNonFinite, got {:?}", other),
    };
    assert_eq!(index, 2);
    assert!(value.is_infinite());
}
3509
/// An eigenvalue of `-1e-2` next to a unit-scale spectrum is far too negative
/// to be roundoff and must be rejected as `PenaltySpectrumIndefinite`.
#[test]
fn classify_strict_rejects_materially_indefinite() {
    let mut eigs = [1.0, -1e-2, 0.5];
    let outcome = classify_eigenvalues_strict(&mut eigs, "test_indef");

    let (context, index, value) = match outcome {
        Err(EstimationError::PenaltySpectrumIndefinite {
            context,
            index,
            value,
            ..
        }) => (context, index, value),
        other => panic!("expected PenaltySpectrumIndefinite, got {:?}", other),
    };
    assert_eq!(context, "test_indef");
    assert_eq!(index, 1);
    // The reported value is the offending eigenvalue, -1e-2.
    assert!((value + 1e-2).abs() <= 1e-15);
}
3528
/// A negative eigenvalue of roundoff size (`-1e-16` at unit scale) must be
/// accepted and snapped to exactly zero, while the positive entries survive.
#[test]
fn classify_strict_accepts_roundoff_negative() {
    // NOTE(review): the original test comment gives the tolerance as
    // 64 * eps * p * scale; the classifier is not visible here, so confirm there.
    let scale = 1.0_f64;
    let roundoff = -1e-16 * scale;
    let mut eigs = [scale, 0.5 * scale, roundoff, 0.25 * scale];

    classify_eigenvalues_strict(&mut eigs, "test_roundoff").expect("roundoff must classify");

    // The roundoff entry becomes exactly zero.
    assert_eq!(eigs[2], 0.0);
    // Every other entry stays strictly positive.
    let positives_kept = [0usize, 1, 3].iter().all(|&idx| eigs[idx] > 0.0);
    assert!(positives_kept);
}
3541
/// A tiny positive eigenvalue (`1e-15` times the scale) is snapped to exactly
/// zero as well, so rank counts and pseudo-logdets downstream are deterministic.
#[test]
fn classify_strict_snaps_subtol_positive_to_zero() {
    let scale = 10.0_f64;
    let mut eigs = [scale, 1e-15 * scale];

    classify_eigenvalues_strict(&mut eigs, "test_sub_pos").expect("sub-tol positive ok");

    let snapped = eigs[1];
    assert_eq!(snapped, 0.0);
}
3552
3553    /// Build a `CanonicalPenalty` directly from a symmetric `local` matrix.
3554    /// Bypasses root extraction — the redundancy diagnostic only reads `local`
3555    /// and `col_range`, so the rest is filler.
3556    fn canonical_from_local(
3557        local: Array2<f64>,
3558        col_range: std::ops::Range<usize>,
3559        total_dim: usize,
3560    ) -> CanonicalPenalty {
3561        let block_dim = local.nrows();
3562        // A trivially valid root: zero rank. The diagnostic doesn't read root.
3563        let root = Array2::<f64>::zeros((0, block_dim));
3564        CanonicalPenalty {
3565            root,
3566            col_range,
3567            total_dim,
3568            nullity: 0,
3569            local,
3570            prior_mean: Array1::zeros(block_dim),
3571            positive_eigenvalues: Vec::new(),
3572            op: None,
3573        }
3574    }
3575
/// Three penalties share columns `0..3`; penalties 1 and 2 use the same local
/// matrix. Exactly that pair must be reported, with a cosine of about 1.
#[test]
fn report_penalty_pair_redundancy_detects_identical_pair() {
    // Penalty 0: a symmetric matrix that differs from the shared one.
    let distinct = ndarray::array![[2.0, 0.5, 0.0], [0.5, 1.0, 0.25], [0.0, 0.25, 1.5]];
    // Penalties 1 and 2: one and the same block-local matrix on the same
    // column range (the Z2-symmetric saddle scenario).
    let shared = ndarray::array![[1.0, -0.5, 0.0], [-0.5, 2.0, -0.5], [0.0, -0.5, 1.0]];

    let bundle = vec![
        canonical_from_local(distinct, 0..3, 3),
        canonical_from_local(shared.clone(), 0..3, 3),
        canonical_from_local(shared, 0..3, 3),
    ];
    let redundant = report_penalty_pair_redundancy(&bundle);

    // Only (1, 2) may be flagged; (0, 1) and (0, 2) involve different matrices.
    assert_eq!(
        redundant.len(),
        1,
        "expected exactly one redundant pair, got {:?}",
        redundant
    );
    let (i, j, cos) = redundant[0];
    assert_eq!((i, j), (1, 2));
    assert!(
        cos > 1.0 - 1e-12,
        "cosine for identical penalties should be ~1.0, got {cos}"
    );
}
3607
/// Identical local matrices placed on disjoint column ranges (`0..2` and
/// `2..4`) must not be reported as redundant.
#[test]
fn report_penalty_pair_redundancy_skips_different_col_ranges() {
    let local = ndarray::array![[1.0, 0.0], [0.0, 1.0]];
    let first = canonical_from_local(local.clone(), 0..2, 4);
    let second = canonical_from_local(local, 2..4, 4);
    let bundle = vec![first, second];

    let redundant = report_penalty_pair_redundancy(&bundle);
    assert!(
        redundant.is_empty(),
        "different col_ranges must not be flagged"
    );
}
3624}