1use crate::EstimationError;
2use crate::basis::analyze_penalty_block;
3use crate::smooth::PenaltyStructureHint;
4use faer::linalg::matmul::matmul;
5use faer::{Accum, Mat, MatRef, Par, Side};
6use gam_linalg::faer_ndarray::{FaerEigh, FaerLinalgError, FaerSvd};
7use gam_linalg::matrix::symmetrize_in_place;
8use gam_linalg::utils::KahanSum;
9use ndarray::{Array1, Array2, ArrayView1, ArrayViewMut2, Axis, s};
10use rayon::iter::{
11 IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
12};
13use std::collections::{BTreeMap, HashSet};
14use std::ops::Range;
15use std::sync::Arc;
16
/// Relative floor, multiplied by a magnitude scale, below which a spectral or
/// leakage quantity is treated as numerically zero.
const REL_PSD_FLOOR: f64 = 1e-8;
27
28#[derive(Clone)]
29pub enum PenaltyRepresentation {
30 Dense(Array2<f64>),
31 Banded {
32 bands: Vec<Array1<f64>>,
33 offsets: Vec<i32>,
34 },
35 Kronecker {
36 left: Array2<f64>,
42 right: Array2<f64>,
43 },
44}
45
46impl PenaltyRepresentation {
47 pub fn block_dimension(&self) -> usize {
49 match self {
50 PenaltyRepresentation::Dense(matrix) => matrix.nrows(),
51 PenaltyRepresentation::Banded { bands, offsets } => {
52 let mut dim = 0usize;
53 for (band, &offset) in bands.iter().zip(offsets.iter()) {
54 let len = band.len();
55 let extent = if offset >= 0 {
56 len + offset as usize
57 } else {
58 len + (-offset) as usize
59 };
60 dim = dim.max(extent);
61 }
62 dim
63 }
64 PenaltyRepresentation::Kronecker { left, right } => left.nrows() * right.nrows(),
65 }
66 }
67
68 pub fn to_block_dense(&self) -> Array2<f64> {
71 match self {
72 PenaltyRepresentation::Dense(matrix) => matrix.clone(),
73 PenaltyRepresentation::Banded { bands, offsets } => {
74 let dim = self.block_dimension();
75 let mut dense = Array2::zeros((dim, dim));
76 let positive_offsets: HashSet<usize> = offsets
77 .iter()
78 .filter_map(|&off| (off >= 0).then_some(off as usize))
79 .collect();
80 for (band, &offset) in bands.iter().zip(offsets.iter()) {
81 let off = offset.unsigned_abs() as usize;
82 if offset < 0 && positive_offsets.contains(&off) {
83 continue;
84 }
85 for (idx, &value) in band.iter().enumerate() {
86 let (i, j) = if offset >= 0 {
87 (idx, idx + off)
88 } else {
89 (idx + off, idx)
90 };
91 if i >= dim || j >= dim {
92 continue;
93 }
94 dense[[i, j]] = value;
95 dense[[j, i]] = value;
96 }
97 }
98 dense
99 }
100 PenaltyRepresentation::Kronecker { left, right } => {
101 let (lrows, l_cols) = left.dim();
102 let (rrows, r_cols) = right.dim();
103 let mut result = Array2::zeros((lrows * rrows, l_cols * r_cols));
104 for i in 0..lrows {
105 for j in 0..l_cols {
106 let scale = left[(i, j)];
107 if scale == 0.0 {
108 continue;
109 }
110 let mut block = result.slice_mut(s![
111 i * rrows..(i + 1) * rrows,
112 j * r_cols..(j + 1) * r_cols
113 ]);
114 block.assign(&(right * scale));
115 }
116 }
117 result
118 }
119 }
120 }
121}
122
123#[derive(Clone)]
124pub struct PenaltyMatrix {
125 pub col_range: Range<usize>,
126 pub representation: PenaltyRepresentation,
127}
128
129impl PenaltyMatrix {
130 fn accumulate_into(&self, mut dest: ArrayViewMut2<'_, f64>, weight: f64) {
131 if weight == 0.0 {
132 return;
133 }
134 match &self.representation {
135 PenaltyRepresentation::Dense(block) => {
136 dest.scaled_add(weight, block);
137 }
138 PenaltyRepresentation::Banded { bands, offsets } => {
139 let positive_offsets: HashSet<usize> = offsets
140 .iter()
141 .filter_map(|&off| (off >= 0).then_some(off as usize))
142 .collect();
143 for (band, &offset) in bands.iter().zip(offsets.iter()) {
144 let off = offset.unsigned_abs() as usize;
145 if offset < 0 && positive_offsets.contains(&off) {
146 continue;
147 }
148 for (idx, &value) in band.iter().enumerate() {
149 let (i, j) = if offset >= 0 {
150 (idx, idx + off)
151 } else {
152 (idx + off, idx)
153 };
154 let Some(entry_ij) = dest.get_mut((i, j)) else {
155 continue;
156 };
157 *entry_ij += weight * value;
158 if i != j
159 && let Some(entry_ji) = dest.get_mut((j, i))
160 {
161 *entry_ji += weight * value;
162 }
163 }
164 }
165 }
166 PenaltyRepresentation::Kronecker { left, right } => {
167 let (lrows, l_cols) = left.dim();
168 let (rrows, r_cols) = right.dim();
169 for i in 0..lrows {
170 for j in 0..l_cols {
171 let scale = left[(i, j)] * weight;
172 if scale == 0.0 {
173 continue;
174 }
175 let mut block = dest.slice_mut(s![
176 i * rrows..(i + 1) * rrows,
177 j * r_cols..(j + 1) * r_cols
178 ]);
179 block.scaled_add(scale, right);
180 }
181 }
182 }
183 }
184 }
185
186 pub fn to_dense(&self, total_dim: usize) -> Array2<f64> {
187 let mut dense = Array2::<f64>::zeros((total_dim, total_dim));
188 self.accumulate_into(
189 dense.slice_mut(s![self.col_range.clone(), self.col_range.clone()]),
190 1.0,
191 );
192 dense
193 }
194}
195
196pub(crate) fn array_to_faer(array: &Array2<f64>) -> Mat<f64> {
197 let (rows, cols) = array.dim();
198 Mat::from_fn(rows, cols, |i, j| array[[i, j]])
199}
200
201pub(crate) fn mat_to_array(mat: &Mat<f64>) -> Array2<f64> {
202 let mut out = Array2::<f64>::zeros((mat.nrows(), mat.ncols()));
203 for i in 0..mat.nrows() {
204 for j in 0..mat.ncols() {
205 out[[i, j]] = mat[(i, j)];
206 }
207 }
208 out
209}
210
211fn mat_max_abs_element(matrix: MatRef<'_, f64>) -> f64 {
212 let (rows, cols) = matrix.shape();
213 let mut maxval = 0.0_f64;
214 for i in 0..rows {
215 for j in 0..cols {
216 let val = matrix[(i, j)];
217 if val.is_finite() {
218 maxval = maxval.max(val.abs());
219 }
220 }
221 }
222 maxval
223}
224
225fn sanitize_symmetric_faer(matrix: &Mat<f64>) -> Mat<f64> {
226 let (rows, cols) = matrix.as_ref().shape();
227 assert_eq!(rows, cols, "Matrix must be square for sanitization");
228
229 let mut sanitized = matrix.clone();
230
231 for i in 0..rows {
232 let diag = sanitized[(i, i)];
233 if !diag.is_finite() {
234 sanitized[(i, i)] = 0.0;
235 }
236 for j in (i + 1)..cols {
237 let mut upper = sanitized[(i, j)];
238 let mut lower = sanitized[(j, i)];
239 if !upper.is_finite() {
240 upper = 0.0;
241 }
242 if !lower.is_finite() {
243 lower = 0.0;
244 }
245 let avg = 0.5 * (upper + lower);
246 sanitized[(i, j)] = avg;
247 sanitized[(j, i)] = avg;
248 }
249 }
250
251 let scale = mat_max_abs_element(sanitized.as_ref());
252 let tiny = (scale * 1e-14).max(1e-30);
253 for i in 0..rows {
254 for j in 0..cols {
255 let val = sanitized[(i, j)];
256 if !val.is_finite() {
257 sanitized[(i, j)] = 0.0;
258 } else if val.abs() < tiny {
259 sanitized[(i, j)] = 0.0;
260 }
261 }
262 }
263
264 sanitized
265}
266
267fn penalty_from_root_faer(root: &Mat<f64>) -> Mat<f64> {
268 let cols = root.ncols();
269 let mut full = Mat::<f64>::zeros(cols, cols);
270 let root_ref = root.as_ref();
271 let root_t = root_ref.transpose();
272 matmul(
273 full.as_mut(),
274 Accum::Replace,
275 root_t,
276 root_ref,
277 1.0,
278 Par::Seq,
279 );
280 sanitize_symmetric_faer(&full)
281}
282
283fn symmetrize_faer_matrix_in_place(matrix: &mut Mat<f64>) {
284 let n = matrix.nrows().min(matrix.ncols());
285 for i in 0..n {
286 for j in 0..i {
287 let avg = 0.5 * (matrix[(i, j)] + matrix[(j, i)]);
288 matrix[(i, j)] = avg;
289 matrix[(j, i)] = avg;
290 }
291 }
292}
293
294fn orthogonal_similarity_transform_faer(
295 matrix: &Mat<f64>,
296 block_dim: usize,
297 orthogonal: &Mat<f64>,
298) -> Mat<f64> {
299 let matrix_block = matrix.as_ref().submatrix(0, 0, block_dim, block_dim);
300 let cols = orthogonal.ncols();
301 let mut temp = Mat::<f64>::zeros(block_dim, cols);
302 matmul(
303 temp.as_mut(),
304 Accum::Replace,
305 matrix_block,
306 orthogonal.as_ref(),
307 1.0,
308 Par::Seq,
309 );
310 let mut rotated = Mat::<f64>::zeros(cols, cols);
311 matmul(
312 rotated.as_mut(),
313 Accum::Replace,
314 orthogonal.transpose(),
315 temp.as_ref(),
316 1.0,
317 Par::Seq,
318 );
319 symmetrize_faer_matrix_in_place(&mut rotated);
320 rotated
321}
322
323fn trace_penalty_in_orthogonal_basis(
324 matrix: &Mat<f64>,
325 block_dim: usize,
326 orthogonal: &Mat<f64>,
327 rotated_eigenvalues: &[f64],
328 delta: f64,
329) -> f64 {
330 let matrix_block = matrix.as_ref().submatrix(0, 0, block_dim, block_dim);
331 let cols = orthogonal.ncols();
332 assert!(rotated_eigenvalues.len() >= cols);
333 let mut projected = Mat::<f64>::zeros(block_dim, cols);
334 matmul(
335 projected.as_mut(),
336 Accum::Replace,
337 matrix_block,
338 orthogonal.as_ref(),
339 1.0,
340 Par::Seq,
341 );
342 let mut trace = KahanSum::default();
343 for l in 0..cols {
344 let mut diag_ll = KahanSum::default();
345 for i in 0..block_dim {
346 diag_ll.add(orthogonal[(i, l)] * projected[(i, l)]);
347 }
348 trace.add(diag_ll.sum() / (rotated_eigenvalues[l] + delta));
349 }
350 trace.sum()
351}
352
353pub fn trace_reduced_penalty_covariance(
354 reduced_penalty: &Array2<f64>,
355 covariance_basis: &Array2<f64>,
356) -> f64 {
357 assert_eq!(
358 reduced_penalty.dim(),
359 covariance_basis.dim(),
360 "trace_reduced_penalty_covariance dimension mismatch"
361 );
362 let r = covariance_basis.nrows();
363 let mut trace = KahanSum::default();
364 for i in 0..r {
365 for j in 0..r {
366 trace.add(covariance_basis[[i, j]] * reduced_penalty[[j, i]]);
367 }
368 }
369 trace.sum()
370}
371
372pub fn trace_penalty_covariance_in_orthogonal_basis(
373 matrix: &Array2<f64>,
374 orthogonal: &Array2<f64>,
375 covariance_basis: &Array2<f64>,
376) -> f64 {
377 let reduced = gam_linalg::faer_ndarray::fast_ab(
378 &gam_linalg::faer_ndarray::fast_atb(orthogonal, matrix),
379 orthogonal,
380 );
381 trace_reduced_penalty_covariance(&reduced, covariance_basis)
382}
383
384fn classify_eigenvalues_strict(
404 eigenvalues: &mut [f64],
405 context: &str,
406) -> Result<(), EstimationError> {
407 const C_EPS_P_FACTOR: f64 = 64.0;
408 let p = eigenvalues.len();
413
414 let mut scale = 0.0_f64;
415 for (idx, &val) in eigenvalues.iter().enumerate() {
416 if !val.is_finite() {
417 return Err(EstimationError::PenaltySpectrumNonFinite {
418 context: context.to_string(),
419 index: idx,
420 value: val,
421 });
422 }
423 scale = scale.max(val.abs());
424 }
425
426 let machine_floor = C_EPS_P_FACTOR * f64::EPSILON * (p.max(1) as f64) * scale;
433 let tolerance = machine_floor
434 .max(REL_PSD_FLOOR * scale)
435 .max(f64::MIN_POSITIVE);
436
437 for (idx, val) in eigenvalues.iter_mut().enumerate() {
438 if val.abs() <= tolerance {
439 *val = 0.0;
440 } else if *val < 0.0 {
441 return Err(EstimationError::PenaltySpectrumIndefinite {
442 context: context.to_string(),
443 index: idx,
444 value: *val,
445 tolerance,
446 scale,
447 });
448 }
449 }
450 Ok(())
451}
452
453fn robust_eighwith_policy<M, V, E, Validate, Sanitize, EigCall, MapErr>(
454 matrix: &M,
455 context: &str,
456 validate_input: Validate,
457 sanitize: Sanitize,
458 mut eig_call: EigCall,
459 map_error: MapErr,
460) -> Result<(Vec<f64>, V), EstimationError>
461where
462 Validate: Fn(&M, &str) -> Result<(), EstimationError>,
463 Sanitize: Fn(&M) -> M,
464 EigCall: FnMut(&M) -> Result<(Vec<f64>, V), E>,
465 MapErr: Fn(E, &str) -> EstimationError,
466{
467 validate_input(matrix, context)?;
468
469 let candidate = sanitize(matrix);
475 match eig_call(&candidate) {
476 Ok((mut eigenvalues, eigenvectors)) => {
477 classify_eigenvalues_strict(&mut eigenvalues, context)?;
478 Ok((eigenvalues, eigenvectors))
479 }
480 Err(err) => Err(map_error(err, context)),
481 }
482}
483
484pub(crate) fn robust_eigh_faer(
485 matrix: &Mat<f64>,
486 side: Side,
487 context: &str,
488) -> Result<(Vec<f64>, Mat<f64>), EstimationError> {
489 robust_eighwith_policy(
490 matrix,
491 context,
492 |mat, ctx| {
493 let (rows, cols) = mat.as_ref().shape();
494 for i in 0..rows {
495 for j in 0..cols {
496 let val = mat[(i, j)];
497 if !val.is_finite() {
498 let max_abs = mat_max_abs_element(mat.as_ref());
499 crate::bail_invalid_estim!(
500 "{} contains non-finite entries (max finite magnitude {:.3e})",
501 ctx,
502 max_abs
503 );
504 }
505 }
506 }
507 Ok(())
508 },
509 sanitize_symmetric_faer,
510 |candidate| {
511 let eig = candidate.as_ref().self_adjoint_eigen(side)?;
512 let diag = eig.S();
513 let mut eigenvalues = Vec::with_capacity(diag.dim());
514 for idx in 0..diag.dim() {
515 eigenvalues.push(diag[idx]);
516 }
517
518 let vectors_ref = eig.U();
519 let mut eigenvectors = Mat::<f64>::zeros(vectors_ref.nrows(), vectors_ref.ncols());
520 for i in 0..vectors_ref.nrows() {
521 for j in 0..vectors_ref.ncols() {
522 eigenvectors[(i, j)] = vectors_ref[(i, j)];
523 }
524 }
525 Ok((eigenvalues, eigenvectors))
526 },
527 |err, _ctx| {
528 EstimationError::EigendecompositionFailed(FaerLinalgError::SelfAdjointEigen(err))
529 },
530 )
531}
532
533fn robust_eigh(
534 matrix: &Array2<f64>,
535 side: Side,
536 context: &str,
537) -> Result<(Array1<f64>, Array2<f64>), EstimationError> {
538 let matrix_faer = array_to_faer(matrix);
539 let (eigenvalues, eigenvectors) = robust_eigh_faer(&matrix_faer, side, context)?;
540 Ok((Array1::from_vec(eigenvalues), mat_to_array(&eigenvectors)))
541}
542
543pub(crate) fn kronecker_marginal_eigensystems(
544 marginal_penalties: &[Array2<f64>],
545 context: &str,
546) -> Result<Vec<(Array1<f64>, Array2<f64>)>, EstimationError> {
547 let mut eigensystems = Vec::with_capacity(marginal_penalties.len());
548 for (k, penalty) in marginal_penalties.iter().enumerate() {
549 eigensystems.push(robust_eigh(
550 penalty,
551 Side::Lower,
552 &format!("{context} marginal {k}"),
553 )?);
554 }
555 Ok(eigensystems)
556}
557
/// Diagnostics produced by `assess_subspace_leakage`.
#[derive(Debug, Clone, Copy)]
struct SubspaceLeakageMetrics {
    /// Largest sum of squared entries found in the null-space columns of any
    /// transformed root.
    max_abs_sq: f64,
    /// Largest ratio of null-column squared mass to total squared mass over
    /// all transformed roots.
    max_rel_sq: f64,
    /// Index of the root that attains `max_rel_sq` (0 when no root has a
    /// positive ratio).
    worst_penalty: usize,
    /// Largest absolute inner product between a penalised-range column and a
    /// null-space column of `qs`.
    max_cross_gram_abs: f64,
}
565
566fn assess_subspace_leakage(
567 qs: &Mat<f64>,
568 rs_transformed: &[Mat<f64>],
569 structural_rank: usize,
570 p: usize,
571) -> SubspaceLeakageMetrics {
572 let mut max_abs_sq = 0.0_f64;
573 let mut max_rel_sq = 0.0_f64;
574 let mut worst_penalty = 0usize;
575
576 for (k, rs) in rs_transformed.iter().enumerate() {
577 let rows = rs.nrows();
578 let cols = rs.ncols().min(p);
579 let null_start = structural_rank.min(cols);
580 let mut abs_sq = 0.0_f64;
581 let mut total_sq = 0.0_f64;
582 for i in 0..rows {
583 for j in 0..cols {
584 let v = rs[(i, j)];
585 let vv = v * v;
586 total_sq += vv;
587 if j >= null_start {
588 abs_sq += vv;
589 }
590 }
591 }
592 let rel_sq = if total_sq > 0.0 {
593 abs_sq / total_sq
594 } else {
595 0.0
596 };
597 if rel_sq > max_rel_sq {
598 max_rel_sq = rel_sq;
599 worst_penalty = k;
600 }
601 max_abs_sq = max_abs_sq.max(abs_sq);
602 }
603
604 let mut max_cross_gram_abs = 0.0_f64;
605 let null_count = p.saturating_sub(structural_rank);
606 if structural_rank > 0 && null_count > 0 {
607 for i in 0..structural_rank {
608 for j in 0..null_count {
609 let qn_col = structural_rank + j;
610 let mut dot = 0.0_f64;
611 for r in 0..p {
612 dot += qs[(r, i)] * qs[(r, qn_col)];
613 }
614 max_cross_gram_abs = max_cross_gram_abs.max(dot.abs());
615 }
616 }
617 }
618
619 SubspaceLeakageMetrics {
620 max_abs_sq,
621 max_rel_sq,
622 worst_penalty,
623 max_cross_gram_abs,
624 }
625}
626
627fn subspace_split_is_consistent(leakage: &SubspaceLeakageMetrics, p: usize) -> bool {
654 let leakage_rel_tol = (p.max(1) as f64) * REL_PSD_FLOOR;
655 let leakage_abs_tol = 1e-12;
656 let orth_tol = 1e-10;
657 let root_leaks =
658 leakage.max_rel_sq > leakage_rel_tol && leakage.max_abs_sq > leakage_abs_tol;
659 let split_nonorthogonal = leakage.max_cross_gram_abs > orth_tol;
660 !(root_leaks || split_nonorthogonal)
661}
662
663fn compose_qs_from_split(q_pen: &Mat<f64>, q_null: &Mat<f64>, p: usize) -> Mat<f64> {
664 let rank = q_pen.ncols();
665 let null_count = q_null.ncols();
666 let mut qs = Mat::<f64>::zeros(p, p);
667 for i in 0..p {
668 for j in 0..rank {
669 qs[(i, j)] = q_pen[(i, j)];
670 }
671 for j in 0..null_count {
672 qs[(i, rank + j)] = q_null[(i, j)];
673 }
674 }
675 qs
676}
677
678pub fn kronecker_product(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
682 let (arows, a_cols) = a.dim();
683 let (brows, b_cols) = b.dim();
684 if arows == 0 || a_cols == 0 || brows == 0 || b_cols == 0 {
685 return Array2::zeros((arows * brows, a_cols * b_cols));
686 }
687 let mut result = Array2::zeros((arows * brows, a_cols * b_cols));
688
689 result
690 .axis_chunks_iter_mut(Axis(0), brows)
691 .into_par_iter()
692 .enumerate()
693 .for_each(|(i, mut row_block)| {
694 let arow = a.row(i);
695 let col_chunks = row_block.axis_chunks_iter_mut(Axis(1), b_cols);
696 for (j, mut block) in col_chunks.into_iter().enumerate() {
697 let aval = arow[j];
698 if aval == 0.0 {
699 continue;
700 }
701 for (dest, &src) in block.iter_mut().zip(b.iter()) {
702 *dest = aval * src;
703 }
704 }
705 });
706
707 result
708}
709
710#[derive(Clone)]
712pub struct ReparamResult {
713 pub s_transformed: Array2<f64>,
717 pub log_det: f64,
719 pub det1: Array1<f64>,
721 pub qs: Array2<f64>,
723 pub canonical_transformed: Vec<CanonicalPenalty>,
728 pub e_transformed: Array2<f64>,
731 pub u_truncated: Array2<f64>,
741 pub penalty_shrinkage_ridge: f64,
744}
745
746struct KroneckerFactorDecomp {
752 root: Array2<f64>, positive_eigenvalues: Vec<f64>, rank: usize,
755 dim: usize,
756}
757
758fn decompose_kronecker_factors(
761 factors: &[Array2<f64>],
762 context: &str,
763) -> Result<Option<Vec<KroneckerFactorDecomp>>, EstimationError> {
764 let mut decomps = Vec::with_capacity(factors.len());
765 for (j, factor) in factors.iter().enumerate() {
766 let q_j = factor.nrows();
767 if q_j != factor.ncols() {
768 crate::bail_invalid_estim!(
769 "{context}: Kronecker factor {j} must be square, got {}x{}",
770 factor.nrows(),
771 factor.ncols()
772 );
773 }
774 let is_identity = {
775 let mut is_id = true;
776 'outer: for r in 0..q_j {
777 for c in 0..q_j {
778 let expected = if r == c { 1.0 } else { 0.0 };
779 if (factor[[r, c]] - expected).abs() > 1e-12 {
780 is_id = false;
781 break 'outer;
782 }
783 }
784 }
785 is_id
786 };
787 if is_identity {
788 decomps.push(KroneckerFactorDecomp {
789 root: Array2::eye(q_j),
790 positive_eigenvalues: vec![1.0; q_j],
791 rank: q_j,
792 dim: q_j,
793 });
794 continue;
795 }
796 let analysis = analyze_penalty_block(factor).map_err(|err| {
797 EstimationError::InvalidInput(format!(
798 "{context}: Kronecker factor {j} eigendecomp failed: {err}"
799 ))
800 })?;
801 if analysis.rank == 0 {
802 return Ok(None);
803 }
804 let factor_classes =
808 crate::basis::SpectralClassification::new(&analysis.eigenvalues, analysis.tol);
809 let mut root_j = Array2::zeros((analysis.rank, q_j));
810 let mut pos_eigs = Vec::with_capacity(analysis.rank);
811 for (row_idx, &i) in factor_classes.range_idx.iter().enumerate() {
812 let eigenval = analysis.eigenvalues[i];
813 let sqrt_ev = eigenval.sqrt();
814 let evec = analysis.eigenvectors.column(i);
815 for (col, &v) in evec.iter().enumerate() {
816 root_j[[row_idx, col]] = sqrt_ev * v;
817 }
818 pos_eigs.push(eigenval);
819 }
820 decomps.push(KroneckerFactorDecomp {
821 root: root_j,
822 positive_eigenvalues: pos_eigs,
823 rank: analysis.rank,
824 dim: q_j,
825 });
826 }
827 Ok(Some(decomps))
828}
829
830fn assemble_kronecker_root_local(decomps: &[KroneckerFactorDecomp]) -> Array2<f64> {
832 let mut kron_root = decomps[0].root.clone();
833 for fr in &decomps[1..] {
834 let (r1, c1) = kron_root.dim();
835 let (r2, c2) = (fr.rank, fr.dim);
836 let mut new_root = Array2::zeros((r1 * r2, c1 * c2));
837 for i1 in 0..r1 {
838 for i2 in 0..r2 {
839 for j1 in 0..c1 {
840 for j2 in 0..c2 {
841 new_root[[i1 * r2 + i2, j1 * c2 + j2]] =
842 kron_root[[i1, j1]] * fr.root[[i2, j2]];
843 }
844 }
845 }
846 }
847 kron_root = new_root;
848 }
849 kron_root
850}
851
852fn kronecker_eigenvalues(decomps: &[KroneckerFactorDecomp], block_dim: usize) -> (Vec<f64>, usize) {
854 let mut kron_eigs = decomps[0].positive_eigenvalues.clone();
855 for fd in &decomps[1..] {
856 let mut new_eigs = Vec::with_capacity(kron_eigs.len() * fd.positive_eigenvalues.len());
857 for &a in &kron_eigs {
858 for &b in &fd.positive_eigenvalues {
859 new_eigs.push(a * b);
860 }
861 }
862 kron_eigs = new_eigs;
863 }
864 let max_ev = kron_eigs.iter().copied().fold(0.0_f64, f64::max);
865 let tol = max_ev * 1e-10 * (block_dim as f64);
866 let positive: Vec<f64> = kron_eigs.into_iter().filter(|&ev| ev > tol).collect();
867 let nullity = block_dim - positive.len();
868 (positive, nullity)
869}
870
871#[derive(Clone)]
881pub struct CanonicalPenalty {
882 pub root: Array2<f64>,
885 pub col_range: std::ops::Range<usize>,
888 pub total_dim: usize,
890 pub nullity: usize,
892 pub local: Array2<f64>,
896 pub prior_mean: Array1<f64>,
898 pub positive_eigenvalues: Vec<f64>,
901 pub op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
905}
906
907impl std::fmt::Debug for CanonicalPenalty {
908 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
909 f.debug_struct("CanonicalPenalty")
910 .field(
911 "root",
912 &format_args!("{}×{}", self.root.nrows(), self.root.ncols()),
913 )
914 .field("col_range", &self.col_range)
915 .field("total_dim", &self.total_dim)
916 .field("nullity", &self.nullity)
917 .field(
918 "local",
919 &format_args!("{}×{}", self.local.nrows(), self.local.ncols()),
920 )
921 .field("prior_mean_len", &self.prior_mean.len())
922 .field("positive_eigenvalues", &self.positive_eigenvalues)
923 .field("op", &self.op.as_ref().map(|o| o.dim()))
924 .finish()
925 }
926}
927
928impl CanonicalPenalty {
929 pub fn from_dense_root(root: Array2<f64>, p: usize) -> Self {
933 Self::from_dense_root_with_mean(root, p, Array1::zeros(p))
934 }
935
936 pub fn from_dense_root_with_mean(root: Array2<f64>, p: usize, prior_mean: Array1<f64>) -> Self {
937 assert_eq!(prior_mean.len(), p);
938 let local = root.t().dot(&root);
939 let positive_eigenvalues = Vec::new(); Self {
941 root,
942 col_range: 0..p,
943 total_dim: p,
944 nullity: 0,
945 local,
946 prior_mean,
947 positive_eigenvalues,
948 op: None,
949 }
950 }
951
952 pub fn full_width_root(&self) -> Array2<f64> {
955 if self.col_range.start == 0 && self.col_range.end == self.total_dim {
956 return self.root.clone();
957 }
958 let rank = self.root.nrows();
959 let mut full = Array2::<f64>::zeros((rank, self.total_dim));
960 full.slice_mut(ndarray::s![.., self.col_range.clone()])
961 .assign(&self.root);
962 full
963 }
964
965 pub fn rank(&self) -> usize {
967 self.root.nrows()
968 }
969
970 pub fn block_dim(&self) -> usize {
972 self.col_range.len()
973 }
974
975 pub const fn is_block_local(&self) -> bool {
977 self.col_range.start != 0 || self.col_range.end != self.total_dim
978 }
979
980 pub fn local_ref(&self) -> &Array2<f64> {
983 &self.local
984 }
985
986 pub fn local_penalty(&self) -> Array2<f64> {
989 self.local.clone()
990 }
991
992 pub fn accumulate_weighted(&self, target: &mut Array2<f64>, lambda: f64) {
995 if lambda == 0.0 || self.rank() == 0 {
996 return;
997 }
998 let r = &self.col_range;
999 target
1000 .slice_mut(s![r.start..r.end, r.start..r.end])
1001 .scaled_add(lambda, &self.local);
1002 }
1003
1004 pub fn trace_product(&self, m: &Array2<f64>, scale: f64) -> f64 {
1007 if self.rank() == 0 || scale == 0.0 {
1008 return 0.0;
1009 }
1010 let r = &self.col_range;
1011 let m_block = m.slice(s![r.start..r.end, r.start..r.end]);
1012 let rm = self.root.dot(&m_block);
1013 scale
1014 * rm.iter()
1015 .zip(self.root.iter())
1016 .map(|(&a, &b)| a * b)
1017 .sum::<f64>()
1018 }
1019
1020 pub fn quadratic(&self, v: &Array1<f64>, scale: f64) -> f64 {
1023 if self.rank() == 0 || scale == 0.0 {
1024 return 0.0;
1025 }
1026 let v_block = v.slice(s![self.col_range.start..self.col_range.end]);
1027 let rv = self.root.dot(&v_block);
1028 scale * rv.dot(&rv)
1029 }
1030
1031 pub fn prior_linear_shift(&self, scale: f64) -> Array1<f64> {
1033 let mut out = Array1::<f64>::zeros(self.total_dim);
1034 if self.rank() == 0 || scale == 0.0 || self.prior_mean.iter().all(|&v| v == 0.0) {
1035 return out;
1036 }
1037 let block = self.local.dot(&self.prior_mean) * scale;
1038 out.slice_mut(s![self.col_range.start..self.col_range.end])
1039 .assign(&block);
1040 out
1041 }
1042
1043 pub fn prior_constant_shift(&self, scale: f64) -> f64 {
1045 if self.rank() == 0 || scale == 0.0 || self.prior_mean.iter().all(|&v| v == 0.0) {
1046 return 0.0;
1047 }
1048 scale * self.prior_mean.dot(&self.local.dot(&self.prior_mean))
1049 }
1050
1051 pub fn full_width_prior_mean(&self) -> Array1<f64> {
1053 if self.col_range.start == 0 && self.col_range.end == self.total_dim {
1054 return self.prior_mean.clone();
1055 }
1056 let mut out = Array1::<f64>::zeros(self.total_dim);
1057 out.slice_mut(s![self.col_range.start..self.col_range.end])
1058 .assign(&self.prior_mean);
1059 out
1060 }
1061
1062 pub fn to_penalty_coordinate(&self) -> gam_problem::PenaltyCoordinate {
1064 use gam_problem::PenaltyCoordinate;
1065 if self.is_block_local() {
1066 PenaltyCoordinate::from_block_root_with_mean(
1067 self.root.clone(),
1068 self.col_range.start,
1069 self.col_range.end,
1070 self.total_dim,
1071 self.prior_mean.clone(),
1072 )
1073 } else {
1074 PenaltyCoordinate::from_dense_root_with_mean(self.root.clone(), self.prior_mean.clone())
1075 }
1076 }
1077}
1078
1079pub fn report_penalty_pair_redundancy(canonical: &[CanonicalPenalty]) -> Vec<(usize, usize, f64)> {
1106 const REDUNDANCY_THRESHOLD: f64 = 1.0 - 1e-8;
1107 const SIMILARITY_THRESHOLD: f64 = 0.99;
1108 const LARGE_SCALE_K_THRESHOLD: usize = 64;
1109 const TOP_SIMILARITY_PAIRS: usize = 3;
1110
1111 let k = canonical.len();
1112 let mut redundant: Vec<(usize, usize, f64)> = Vec::new();
1113 let mut similar: Vec<(usize, usize, f64)> = Vec::new();
1114
1115 let trace_sq: Vec<f64> = canonical
1118 .iter()
1119 .map(|p| p.local.iter().map(|&v| v * v).sum::<f64>())
1120 .collect();
1121
1122 for i in 0..k {
1123 if trace_sq[i] == 0.0 {
1124 continue;
1125 }
1126 for j in (i + 1)..k {
1127 if trace_sq[j] == 0.0 {
1128 continue;
1129 }
1130 if canonical[i].col_range != canonical[j].col_range {
1134 continue;
1135 }
1136 assert_eq!(canonical[i].local.dim(), canonical[j].local.dim());
1139
1140 let inner: f64 = canonical[i]
1141 .local
1142 .iter()
1143 .zip(canonical[j].local.iter())
1144 .map(|(&a, &b)| a * b)
1145 .sum();
1146 let denom = (trace_sq[i] * trace_sq[j]).sqrt();
1147 if denom == 0.0 {
1148 continue;
1149 }
1150 let cos = inner / denom;
1151
1152 if cos > REDUNDANCY_THRESHOLD {
1153 redundant.push((i, j, cos));
1154 } else if cos > SIMILARITY_THRESHOLD {
1155 similar.push((i, j, cos));
1156 }
1157 }
1158 }
1159
1160 for &(i, j, cos) in &redundant {
1162 log::warn!(
1163 "[PENALTY-REDUNDANCY] penalties i={i} j={j} are structurally identical \
1164 (cos={cos:.6}) — model is over-parameterized along their antisymmetric \
1165 direction; expect a Z₂-symmetric saddle in the LAML cost. Consider \
1166 re-specifying (e.g. anisotropic→isotropic for spatial smoothers with \
1167 weak axis signal)."
1168 );
1169 }
1170
1171 if k > LARGE_SCALE_K_THRESHOLD && similar.len() > TOP_SIMILARITY_PAIRS {
1173 similar.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
1174 similar.truncate(TOP_SIMILARITY_PAIRS);
1175 }
1176 for (i, j, cos) in similar {
1177 log::info!(
1178 "[PENALTY-SIMILARITY] penalties i={i} j={j} are near-identical \
1179 (cos={cos:.6}) — outer Hessian may be ill-conditioned along their \
1180 antisymmetric direction."
1181 );
1182 }
1183
1184 redundant
1185}
1186
1187pub fn canonicalize_penalty_spec(
1193 spec: &crate::PenaltySpec,
1194 p: usize,
1195 idx: usize,
1196 context: &str,
1197) -> Result<Option<CanonicalPenalty>, EstimationError> {
1198 use crate::PenaltySpec;
1199
1200 crate::validate_penalty_spec_shape(idx, spec, p, context)?;
1201
1202 let (local_matrix, col_range, prior_mean_spec, hint, op) = match spec {
1203 PenaltySpec::Block {
1204 local,
1205 col_range,
1206 prior_mean,
1207 structure_hint,
1208 op,
1209 } => (
1210 local.view(),
1211 col_range.clone(),
1212 prior_mean,
1213 structure_hint.as_ref(),
1214 op.clone(),
1215 ),
1216 PenaltySpec::Dense(m) => (
1217 m.view(),
1218 0..p,
1219 &gam_problem::CoefficientPriorMean::Zero,
1220 None,
1221 None,
1222 ),
1223 PenaltySpec::DenseWithMean { matrix, prior_mean } => {
1224 (matrix.view(), 0..p, prior_mean, None, None)
1225 }
1226 };
1227
1228 let block_dim = col_range.len();
1229 let prior_mean = prior_mean_spec
1230 .evaluate(block_dim, &format!("{context}: penalty {idx}"))
1231 .map_err(|e| EstimationError::InvalidInput(e.0))?;
1232
1233 if let Some(PenaltyStructureHint::Ridge(scale)) = hint {
1235 if *scale <= 0.0 {
1236 return Ok(None);
1237 }
1238 let sqrt_scale = scale.sqrt();
1239 let mut root = Array2::zeros((block_dim, block_dim));
1240 for i in 0..block_dim {
1241 root[[i, i]] = sqrt_scale;
1242 }
1243 let mut local_sym = local_matrix.to_owned();
1247 symmetrize_in_place(&mut local_sym);
1248 return Ok(Some(CanonicalPenalty {
1249 root,
1250 col_range,
1251 total_dim: p,
1252 nullity: 0,
1253 local: local_sym,
1254 prior_mean,
1255 positive_eigenvalues: vec![*scale; block_dim],
1256 op,
1257 }));
1258 }
1259
1260 if let Some(PenaltyStructureHint::Kronecker(factors)) = hint {
1262 let decomps =
1263 match decompose_kronecker_factors(factors, &format!("{context} penalty {idx}"))? {
1264 None => return Ok(None),
1265 Some(d) => d,
1266 };
1267 let (positive_eigenvalues, nullity) = kronecker_eigenvalues(&decomps, block_dim);
1268 if positive_eigenvalues.is_empty() {
1269 return Ok(None);
1270 }
1271 let root = assemble_kronecker_root_local(&decomps);
1272 let mut local_sym = local_matrix.to_owned();
1273 symmetrize_in_place(&mut local_sym);
1274 return Ok(Some(CanonicalPenalty {
1275 root,
1276 col_range,
1277 total_dim: p,
1278 nullity,
1279 local: local_sym,
1280 prior_mean,
1281 positive_eigenvalues,
1282 op,
1283 }));
1284 }
1285
1286 let local_owned = local_matrix.to_owned();
1288 let analysis = analyze_penalty_block(&local_owned).map_err(|err| {
1289 EstimationError::InvalidInput(format!(
1290 "{context}: penalty canonicalization failed at index {idx}: {err}"
1291 ))
1292 })?;
1293
1294 if analysis.rank == 0 {
1295 log::debug!(
1296 "Dropped inactive penalty block idx={idx} reason={}",
1297 if analysis.iszero {
1298 "ZeroMatrix"
1299 } else {
1300 "NumericalRankZero"
1301 }
1302 );
1303 return Ok(None);
1304 }
1305
1306 let tolerance = analysis.tol;
1312 let classes = crate::basis::SpectralClassification::new(&analysis.eigenvalues, tolerance);
1313 let rank_k = classes.rank();
1314 assert_eq!(
1315 rank_k, analysis.rank,
1316 "penalty-root rank disagreement: SpectralClassification rank={rank_k} vs analyze_penalty_block rank={} (#1425 canonical-classifier invariant)",
1317 analysis.rank
1318 );
1319
1320 let mut root = Array2::zeros((rank_k, block_dim));
1328 let mut positive_eigenvalues = Vec::with_capacity(rank_k);
1329 for (row_idx, &i) in classes.range_idx.iter().enumerate() {
1330 let eigenval = analysis.eigenvalues[i];
1331 let eigenvec = analysis.eigenvectors.column(i);
1332 root.row_mut(row_idx).assign(&(&eigenvec * eigenval.sqrt()));
1333 positive_eigenvalues.push(eigenval);
1334 }
1335
1336 if classes.is_indefinite() {
1342 log::debug!(
1343 "{context}: penalty block idx={idx} carries {} negative-curvature \
1344 eigendirection(s) below -tol={tolerance:e}; dropped from the canonical \
1345 root and NOT counted as null space (rank={rank_k}, nullity={})",
1346 classes.negative_dim(),
1347 classes.nullity()
1348 );
1349 }
1350
1351 let local = root.t().dot(&root);
1355 Ok(Some(CanonicalPenalty {
1356 root,
1357 col_range,
1358 total_dim: p,
1359 nullity: classes.nullity(),
1360 local,
1361 prior_mean,
1362 positive_eigenvalues,
1363 op,
1364 }))
1365}
1366
1367pub fn canonicalize_penalty_specs(
1370 specs: &[crate::PenaltySpec],
1371 nullspace_dims: &[usize],
1372 p: usize,
1373 context: &str,
1374) -> Result<(Vec<CanonicalPenalty>, Vec<usize>), EstimationError> {
1375 if specs.len() != nullspace_dims.len() {
1376 crate::bail_invalid_estim!(
1377 "{context}: nullspace_dims length mismatch: penalties={}, nullspace_dims={}",
1378 specs.len(),
1379 nullspace_dims.len()
1380 );
1381 }
1382
1383 let mut active = Vec::with_capacity(specs.len());
1384 let mut active_nullspace = Vec::with_capacity(specs.len());
1385 for (idx, spec) in specs.iter().enumerate() {
1386 if let Some(canonical) = canonicalize_penalty_spec(spec, p, idx, context)? {
1387 active_nullspace.push(nullspace_dims[idx]);
1388 active.push(canonical);
1389 }
1390 }
1391 Ok((active, active_nullspace))
1392}
1393
/// Largest coefficient count `p` for which the overlapping-penalty code paths
/// in this file still assemble and eigendecompose a dense `p x p` balanced
/// penalty; for larger `p` they return a `LayoutError` instead.
pub(crate) const OVERLAPPING_PENALTY_DENSE_FALLBACK_MAX_P: usize = 4_096;
1404
1405pub fn create_balanced_penalty_root_from_canonical(
1412 penalties: &[CanonicalPenalty],
1413 p: usize,
1414) -> Result<Array2<f64>, EstimationError> {
1415 if penalties.is_empty() {
1416 return Ok(Array2::zeros((0, p)));
1417 }
1418
1419 let mut block_groups: BTreeMap<(usize, usize), Vec<&CanonicalPenalty>> = BTreeMap::new();
1421 for cp in penalties {
1422 if cp.rank() == 0 {
1423 continue;
1424 }
1425 let key = (cp.col_range.start, cp.col_range.end);
1426 block_groups.entry(key).or_default().push(cp);
1427 }
1428
1429 if block_groups.is_empty() {
1430 return Ok(Array2::zeros((0, p)));
1431 }
1432
1433 let ranges: Vec<(usize, usize)> = block_groups.keys().copied().collect();
1435 let mut overlapping = false;
1436 for i in 1..ranges.len() {
1437 if ranges[i].0 < ranges[i - 1].1 {
1438 overlapping = true;
1439 break;
1440 }
1441 }
1442
1443 if overlapping {
1444 if p > OVERLAPPING_PENALTY_DENSE_FALLBACK_MAX_P {
1445 return Err(EstimationError::LayoutError(format!(
1446 "overlapping penalty root would require dense {}x{} eigendecomposition; \
1447 large-model dense fallback is disabled. Keep penalties structured or \
1448 extend the overlapping-penalty solver path",
1449 p, p
1450 )));
1451 }
1452 let mut s_balanced = Array2::zeros((p, p));
1454 for cp in penalties {
1455 if cp.rank() == 0 {
1456 continue;
1457 }
1458 let local = cp.local_ref();
1459 let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1460 if frob_norm > 1e-12 {
1461 let r = &cp.col_range;
1462 s_balanced
1463 .slice_mut(s![r.start..r.end, r.start..r.end])
1464 .scaled_add(1.0 / frob_norm, local);
1465 }
1466 }
1467 let (eigenvalues, eigenvectors) =
1468 robust_eigh(&s_balanced, Side::Lower, "balanced penalty matrix")?;
1469 let max_eig = eigenvalues.iter().fold(0.0f64, |max, &val| max.max(val));
1470 let tolerance = if max_eig > 0.0 {
1471 max_eig * 1e-12
1472 } else {
1473 1e-12
1474 };
1475 let penalty_rank = eigenvalues.iter().filter(|&&ev| ev > tolerance).count();
1476 if penalty_rank == 0 {
1477 return Ok(Array2::zeros((0, p)));
1478 }
1479 let mut eb = Array2::zeros((p, penalty_rank));
1480 let mut col_idx = 0;
1481 for (i, &eigenval) in eigenvalues.iter().enumerate() {
1482 if eigenval > tolerance {
1483 let sqrt_ev = eigenval.sqrt();
1484 let evec = eigenvectors.column(i);
1485 eb.column_mut(col_idx).assign(&(&evec * sqrt_ev));
1486 col_idx += 1;
1487 }
1488 }
1489 return Ok(eb.t().to_owned());
1490 }
1491
1492 struct BlockRoot {
1494 col_range: Range<usize>,
1495 root: Array2<f64>, }
1497 let ordered_blocks: Vec<((usize, usize), Vec<&CanonicalPenalty>)> =
1502 block_groups.into_iter().collect();
1503 let block_roots: Vec<BlockRoot> = ordered_blocks
1504 .into_par_iter()
1505 .map(
1506 |((start, end), cps)| -> Result<Option<BlockRoot>, EstimationError> {
1507 let block_dim = end - start;
1508 let mut s_balanced_local = Array2::zeros((block_dim, block_dim));
1509
1510 for cp in cps {
1511 let local = cp.local_ref();
1512 let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1513 if frob_norm > 1e-12 {
1514 s_balanced_local.scaled_add(1.0 / frob_norm, local);
1515 }
1516 }
1517
1518 let (eigenvalues, eigenvectors) =
1519 robust_eigh(&s_balanced_local, Side::Lower, "balanced penalty block")?;
1520 let max_eig = eigenvalues.iter().fold(0.0f64, |max, &val| max.max(val));
1521 let tolerance = if max_eig > 0.0 {
1522 max_eig * 1e-12
1523 } else {
1524 1e-12
1525 };
1526 let block_rank = eigenvalues.iter().filter(|&&ev| ev > tolerance).count();
1527
1528 if block_rank == 0 {
1529 return Ok(None);
1530 }
1531
1532 let mut root = Array2::zeros((block_rank, block_dim));
1533 let mut row_idx = 0;
1534 for (i, &eigenval) in eigenvalues.iter().enumerate() {
1535 if eigenval > tolerance {
1536 let sqrt_ev = eigenval.sqrt();
1537 let evec = eigenvectors.column(i);
1538 root.row_mut(row_idx).assign(&(&evec * sqrt_ev));
1539 row_idx += 1;
1540 }
1541 }
1542
1543 Ok(Some(BlockRoot {
1544 col_range: start..end,
1545 root,
1546 }))
1547 },
1548 )
1549 .collect::<Result<Vec<_>, _>>()?
1550 .into_iter()
1551 .flatten()
1552 .collect();
1553 let total_rank: usize = block_roots.iter().map(|br| br.root.nrows()).sum();
1554
1555 if total_rank == 0 {
1556 return Ok(Array2::zeros((0, p)));
1557 }
1558
1559 let mut eb = Array2::zeros((total_rank, p));
1561 let mut row_offset = 0;
1562 for br in &block_roots {
1563 let rank_b = br.root.nrows();
1564 eb.slice_mut(s![
1565 row_offset..(row_offset + rank_b),
1566 br.col_range.start..br.col_range.end
1567 ])
1568 .assign(&br.root);
1569 row_offset += rank_b;
1570 }
1571
1572 Ok(eb)
1573}
1574
1575#[derive(Clone)]
1577struct SubspaceSplit {
1578 q_pen: Array2<f64>,
1579 q_null: Array2<f64>,
1580}
1581
1582impl SubspaceSplit {
1583 fn identity(p: usize) -> Self {
1584 Self {
1585 q_pen: Array2::zeros((p, 0)),
1586 q_null: Array2::eye(p),
1587 }
1588 }
1589
1590 fn from_ordered_qs(
1591 qs: &Mat<f64>,
1592 penalized_rank: usize,
1593 p: usize,
1594 ) -> Result<Self, EstimationError> {
1595 if qs.nrows() != p || qs.ncols() != p {
1596 return Err(EstimationError::LayoutError(format!(
1597 "Invalid Q basis dimensions: expected {p}x{p}, got {}x{}",
1598 qs.nrows(),
1599 qs.ncols()
1600 )));
1601 }
1602 if penalized_rank > p {
1603 return Err(EstimationError::LayoutError(format!(
1604 "Invalid penalized rank {penalized_rank} for p={p}"
1605 )));
1606 }
1607
1608 let null_count = p - penalized_rank;
1609 let mut q_pen = Array2::<f64>::zeros((p, penalized_rank));
1610 let mut q_null = Array2::<f64>::zeros((p, null_count));
1611 for i in 0..p {
1612 for j in 0..penalized_rank {
1613 q_pen[(i, j)] = qs[(i, j)];
1614 }
1615 for j in 0..null_count {
1616 q_null[(i, j)] = qs[(i, penalized_rank + j)];
1617 }
1618 }
1619
1620 Ok(Self { q_pen, q_null })
1621 }
1622
1623 fn rank(&self) -> usize {
1624 self.q_pen.ncols()
1625 }
1626
1627 fn p(&self) -> usize {
1628 self.q_pen.nrows()
1629 }
1630
1631 fn compose_qs(&self) -> Array2<f64> {
1632 let p = self.p();
1633 let rank = self.rank();
1634 let null_count = self.q_null.ncols();
1635 let mut qs = Array2::<f64>::zeros((p, p));
1636 for i in 0..p {
1637 for j in 0..rank {
1638 qs[(i, j)] = self.q_pen[(i, j)];
1639 }
1640 for j in 0..null_count {
1641 qs[(i, rank + j)] = self.q_null[(i, j)];
1642 }
1643 }
1644 qs
1645 }
1646}
1647
1648#[derive(Clone)]
1650pub struct ReparamInvariant {
1651 split: SubspaceSplit,
1652 qs_base: Array2<f64>,
1656 has_nonzero: bool,
1657 max_balanced_eigenvalue: f64,
1660}
1661
1662impl ReparamInvariant {
1663 pub const fn max_balanced_eigenvalue(&self) -> f64 {
1666 self.max_balanced_eigenvalue
1667 }
1668}
1669
1670pub fn precompute_reparam_invariant_from_canonical(
1677 penalties: &[CanonicalPenalty],
1678 p_total: usize,
1679) -> Result<ReparamInvariant, EstimationError> {
1680 use std::cmp::Ordering;
1681
1682 let m = penalties.len();
1683
1684 if m == 0 {
1685 return Ok(ReparamInvariant {
1686 split: SubspaceSplit::identity(p_total),
1687 qs_base: Array2::eye(p_total),
1688 has_nonzero: false,
1689 max_balanced_eigenvalue: 0.0,
1690 });
1691 }
1692
1693 struct PenRef {
1695 penalty_index: usize,
1696 }
1697 let mut block_groups: BTreeMap<(usize, usize), Vec<PenRef>> = BTreeMap::new();
1698 let mut has_nonzero = false;
1699 for (i, cp) in penalties.iter().enumerate() {
1700 if cp.rank() == 0 {
1701 continue;
1702 }
1703 let local = cp.local_ref();
1704 let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1705 if frob_norm > 1e-12 {
1706 has_nonzero = true;
1707 }
1708 let key = (cp.col_range.start, cp.col_range.end);
1709 block_groups
1710 .entry(key)
1711 .or_default()
1712 .push(PenRef { penalty_index: i });
1713 }
1714
1715 if !has_nonzero {
1716 return Ok(ReparamInvariant {
1717 split: SubspaceSplit::identity(p_total),
1718 qs_base: Array2::eye(p_total),
1719 has_nonzero: false,
1720 max_balanced_eigenvalue: 0.0,
1721 });
1722 }
1723
1724 let ranges: Vec<(usize, usize)> = block_groups.keys().copied().collect();
1726 let mut overlapping = false;
1727 for i in 1..ranges.len() {
1728 if ranges[i].0 < ranges[i - 1].1 {
1729 overlapping = true;
1730 break;
1731 }
1732 }
1733
1734 if overlapping {
1735 if p_total > OVERLAPPING_PENALTY_DENSE_FALLBACK_MAX_P {
1741 return Err(EstimationError::LayoutError(format!(
1742 "overlapping penalty reparameterization would require dense {}x{} eigendecomposition; \
1743 large-model dense fallback is disabled. Keep penalties structured or \
1744 extend the overlapping-penalty solver path",
1745 p_total, p_total
1746 )));
1747 }
1748 let mut s_balanced = Mat::<f64>::zeros(p_total, p_total);
1750 for cp in penalties {
1751 if cp.rank() == 0 {
1752 continue;
1753 }
1754 let local = cp.local_ref();
1755 let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1756 if frob_norm > 1e-12 {
1757 let scale = 1.0 / frob_norm;
1758 let r = &cp.col_range;
1759 for i in 0..local.nrows() {
1760 for j in 0..local.ncols() {
1761 s_balanced[(r.start + i, r.start + j)] += scale * local[[i, j]];
1762 }
1763 }
1764 }
1765 }
1766
1767 let (bal_eigenvalues, bal_eigenvectors) =
1768 robust_eigh_faer(&s_balanced, Side::Lower, "balanced penalty matrix")?;
1769
1770 let mut order: Vec<usize> = (0..p_total).collect();
1771 order.sort_by(|&i, &j| {
1772 bal_eigenvalues[j]
1773 .partial_cmp(&bal_eigenvalues[i])
1774 .unwrap_or(Ordering::Equal)
1775 .then(i.cmp(&j))
1776 });
1777
1778 let mut qs = Mat::<f64>::zeros(p_total, p_total);
1779 for (col_idx, &idx) in order.iter().enumerate() {
1780 for row in 0..p_total {
1781 qs[(row, col_idx)] = bal_eigenvectors[(row, idx)];
1782 }
1783 }
1784
1785 let max_bal = order
1786 .iter()
1787 .map(|&idx| bal_eigenvalues[idx].abs())
1788 .fold(0.0_f64, f64::max);
1789 let rank_tol = if max_bal > 0.0 {
1790 max_bal * 1e-12
1791 } else {
1792 1e-12
1793 };
1794 let penalized_rank = order
1795 .iter()
1796 .take_while(|&&idx| bal_eigenvalues[idx] > rank_tol)
1797 .count();
1798 let split = SubspaceSplit::from_ordered_qs(&qs, penalized_rank, p_total)?;
1799
1800 return Ok(ReparamInvariant {
1801 split,
1802 qs_base: mat_to_array(&qs),
1803 has_nonzero,
1804 max_balanced_eigenvalue: max_bal,
1805 });
1806 }
1807
1808 let mut covered = vec![false; p_total];
1816 for cp in penalties {
1817 for j in cp.col_range.clone() {
1818 covered[j] = true;
1819 }
1820 }
1821 let uncovered_cols: Vec<usize> = (0..p_total).filter(|j| !covered[*j]).collect();
1822
1823 struct BlockResult {
1824 col_range: Range<usize>,
1825 q_pen_local: Array2<f64>, q_null_local: Array2<f64>, max_balanced_eigenvalue: f64,
1829 pen_col_offset: usize,
1831 null_col_offset: usize,
1833 }
1834
1835 let block_specs: Vec<_> = block_groups.iter().collect();
1839 let mut block_results: Vec<BlockResult> = block_specs
1840 .into_par_iter()
1841 .map(
1842 |(&(start, end), refs)| -> Result<BlockResult, EstimationError> {
1843 let block_dim = end - start;
1844
1845 let mut s_balanced_local = Array2::zeros((block_dim, block_dim));
1847 let mut block_has_nonzero = false;
1848 for pref in refs {
1849 let cp = &penalties[pref.penalty_index];
1850 let local = cp.local_ref();
1851 let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1852 if frob_norm > 1e-12 {
1853 s_balanced_local.scaled_add(1.0 / frob_norm, local);
1854 block_has_nonzero = true;
1855 }
1856 }
1857
1858 if !block_has_nonzero {
1859 return Ok(BlockResult {
1860 col_range: start..end,
1861 q_pen_local: Array2::zeros((block_dim, 0)),
1862 q_null_local: Array2::eye(block_dim),
1863 max_balanced_eigenvalue: 0.0,
1864 pen_col_offset: 0, null_col_offset: 0, });
1867 }
1868
1869 let (bal_eigenvalues, bal_eigenvectors) =
1871 robust_eigh(&s_balanced_local, Side::Lower, "balanced penalty block")?;
1872
1873 let mut order: Vec<usize> = (0..block_dim).collect();
1874 order.sort_by(|&i, &j| {
1875 bal_eigenvalues[j]
1876 .partial_cmp(&bal_eigenvalues[i])
1877 .unwrap_or(Ordering::Equal)
1878 .then(i.cmp(&j))
1879 });
1880
1881 let max_bal = order
1882 .iter()
1883 .map(|&idx| bal_eigenvalues[idx].abs())
1884 .fold(0.0_f64, f64::max);
1885 let rank_tol = if max_bal > 0.0 {
1886 max_bal * 1e-12
1887 } else {
1888 1e-12
1889 };
1890 let penalized_rank = order
1891 .iter()
1892 .take_while(|&&idx| bal_eigenvalues[idx] > rank_tol)
1893 .count();
1894 let null_count = block_dim - penalized_rank;
1895
1896 let mut q_pen_local = Array2::zeros((block_dim, penalized_rank));
1897 let mut q_null_local = Array2::zeros((block_dim, null_count));
1898 for (col_idx, &idx) in order.iter().enumerate() {
1899 if col_idx < penalized_rank {
1900 for row in 0..block_dim {
1901 q_pen_local[[row, col_idx]] = bal_eigenvectors[[row, idx]];
1902 }
1903 } else {
1904 let null_col = col_idx - penalized_rank;
1905 for row in 0..block_dim {
1906 q_null_local[[row, null_col]] = bal_eigenvectors[[row, idx]];
1907 }
1908 }
1909 }
1910
1911 Ok(BlockResult {
1912 col_range: start..end,
1913 q_pen_local,
1914 q_null_local,
1915 max_balanced_eigenvalue: max_bal,
1916 pen_col_offset: 0, null_col_offset: 0, })
1919 },
1920 )
1921 .collect::<Result<_, _>>()?;
1922 let global_max_bal = block_results
1923 .iter()
1924 .map(|br| br.max_balanced_eigenvalue)
1925 .fold(0.0_f64, f64::max);
1926
1927 let total_pen_rank: usize = block_results.iter().map(|br| br.q_pen_local.ncols()).sum();
1929 let total_null: usize = block_results
1930 .iter()
1931 .map(|br| br.q_null_local.ncols())
1932 .sum::<usize>()
1933 + uncovered_cols.len();
1934 {
1935 let mut pen_off = 0usize;
1936 let mut null_off = 0usize;
1937 for br in &mut block_results {
1938 br.pen_col_offset = pen_off;
1939 br.null_col_offset = null_off;
1940 pen_off += br.q_pen_local.ncols();
1941 null_off += br.q_null_local.ncols();
1942 }
1943 }
1944
1945 let mut q_pen = Array2::zeros((p_total, total_pen_rank));
1946 let mut q_null = Array2::zeros((p_total, total_null));
1947
1948 for br in &block_results {
1949 let start = br.col_range.start;
1950 let bd = br.q_pen_local.nrows();
1951 let pen_r = br.q_pen_local.ncols();
1952 let null_r = br.q_null_local.ncols();
1953 if pen_r > 0 {
1954 q_pen
1955 .slice_mut(s![
1956 start..(start + bd),
1957 br.pen_col_offset..(br.pen_col_offset + pen_r)
1958 ])
1959 .assign(&br.q_pen_local);
1960 }
1961 if null_r > 0 {
1962 q_null
1963 .slice_mut(s![
1964 start..(start + bd),
1965 br.null_col_offset..(br.null_col_offset + null_r)
1966 ])
1967 .assign(&br.q_null_local);
1968 }
1969 }
1970 let mut null_col = block_results
1971 .iter()
1972 .map(|br| br.q_null_local.ncols())
1973 .sum::<usize>();
1974 for &j in &uncovered_cols {
1975 q_null[[j, null_col]] = 1.0;
1976 null_col += 1;
1977 }
1978
1979 let split = SubspaceSplit { q_pen, q_null };
1980
1981 let qs_global = split.compose_qs();
1985
1986 Ok(ReparamInvariant {
1987 split,
1988 qs_base: qs_global,
1989 has_nonzero,
1990 max_balanced_eigenvalue: global_max_bal,
1991 })
1992}
1993
1994fn structurally_penalized_columns(penalties: &[CanonicalPenalty], p: usize) -> Vec<bool> {
1995 let mut active = vec![false; p];
1996 for cp in penalties {
1997 let local = cp.local_ref();
1998 let scale = local.iter().map(|&v| v.abs()).fold(0.0_f64, f64::max);
1999 if scale <= 0.0 {
2000 continue;
2001 }
2002 let tol = scale * 1e-12;
2003 for local_col in 0..cp.block_dim() {
2004 let mut column_active = false;
2005 for row in 0..cp.block_dim() {
2006 if local[[row, local_col]].abs() > tol || local[[local_col, row]].abs() > tol {
2007 column_active = true;
2008 break;
2009 }
2010 }
2011 if column_active {
2012 active[cp.col_range.start + local_col] = true;
2013 }
2014 }
2015 }
2016 active
2017}
2018
2019pub fn stable_reparameterizationwith_invariant(
2029 penalties: &[CanonicalPenalty],
2030 lambdas: &[f64],
2031 p: usize,
2032 invariant: &ReparamInvariant,
2033 penalty_shrinkage_floor: Option<f64>,
2034) -> Result<ReparamResult, EstimationError> {
2035 let m = penalties.len();
2036
2037 if lambdas.len() != m {
2038 return Err(EstimationError::ParameterConstraintViolation(format!(
2039 "Lambda count mismatch: expected {} lambdas for {} penalties, got {}",
2040 m,
2041 m,
2042 lambdas.len()
2043 )));
2044 }
2045
2046 if m == 0 {
2058 return Ok(ReparamResult {
2059 s_transformed: Array2::zeros((p, p)),
2060 log_det: 0.0,
2061 det1: Array1::zeros(0),
2062 qs: Array2::eye(p),
2063 canonical_transformed: vec![],
2064 e_transformed: Array2::zeros((0, p)),
2065 u_truncated: Array2::eye(p),
2067 penalty_shrinkage_ridge: 0.0,
2068 });
2069 }
2070
2071 if !invariant.has_nonzero {
2072 let qs = invariant.split.compose_qs();
2073 let u_truncated = qs.t().dot(&invariant.split.q_null);
2074 let canonical_transformed: Vec<CanonicalPenalty> = penalties.to_vec();
2076 return Ok(ReparamResult {
2077 s_transformed: Array2::zeros((p, p)),
2078 log_det: 0.0,
2079 det1: Array1::zeros(m),
2080 qs,
2081 canonical_transformed,
2082 e_transformed: Array2::zeros((0, p)),
2083 u_truncated,
2084 penalty_shrinkage_ridge: 0.0,
2085 });
2086 }
2087
2088 let q_pen = array_to_faer(&invariant.split.q_pen);
2089 let q_null = array_to_faer(&invariant.split.q_null);
2090 let qs_base = array_to_faer(&invariant.qs_base);
2091 let penalty_transforms: Vec<(Mat<f64>, Mat<f64>)> = penalties
2096 .par_iter()
2097 .map(|cp| {
2098 let r = &cp.col_range;
2099 let root_faer = array_to_faer(&cp.root);
2100 let q_block = qs_base.submatrix(r.start, 0, cp.block_dim(), p);
2101 let mut product = Mat::<f64>::zeros(cp.rank(), p);
2102 matmul(
2103 product.as_mut(),
2104 Accum::Replace,
2105 root_faer.as_ref(),
2106 q_block,
2107 1.0,
2108 Par::Seq,
2109 );
2110 let s_k = penalty_from_root_faer(&product);
2111 (product, s_k)
2112 })
2113 .collect();
2114 let (rs_transformed, s_k_penalized_cache): (Vec<Mat<f64>>, Vec<Mat<f64>>) =
2115 penalty_transforms.into_iter().unzip();
2116
2117 let penalized_rank = invariant.split.rank();
2118
2119 let mut range_eigenvalues_sorted: Vec<f64> = Vec::new();
2120 let mut range_rotation = Mat::<f64>::zeros(penalized_rank, penalized_rank);
2121 if penalized_rank > 0 {
2122 let mut range_block = Mat::<f64>::zeros(penalized_rank, penalized_rank);
2123 for (lambda, s_k) in lambdas.iter().zip(s_k_penalized_cache.iter()) {
2127 for i in 0..penalized_rank {
2128 for j in 0..penalized_rank {
2129 range_block[(i, j)] += *lambda * s_k[(i, j)];
2130 }
2131 }
2132 }
2133 let (range_eigenvalues, range_eigenvectors) =
2134 robust_eigh_faer(&range_block, Side::Lower, "range penalty block")?;
2135
2136 let mut range_order: Vec<usize> = (0..penalized_rank).collect();
2137 range_order.sort_by(|&i, &j| {
2138 range_eigenvalues[j]
2139 .partial_cmp(&range_eigenvalues[i])
2140 .unwrap_or(std::cmp::Ordering::Equal)
2141 .then(i.cmp(&j))
2142 });
2143 range_eigenvalues_sorted = range_order
2144 .iter()
2145 .map(|&idx| range_eigenvalues[idx])
2146 .collect();
2147
2148 for (col_idx, &idx) in range_order.iter().enumerate() {
2155 for row in 0..penalized_rank {
2156 range_rotation[(row, col_idx)] = range_eigenvectors[(row, idx)];
2157 }
2158 }
2159 }
2163
2164 let structural_rank = penalized_rank;
2169 let mut range_eigs_sorted: Vec<f64> = range_eigenvalues_sorted;
2170 let structurally_penalized_cols = structurally_penalized_columns(penalties, p);
2171
2172 let shrinkage_ridge = penalty_shrinkage_floor
2189 .filter(|&eps| eps > 0.0)
2190 .map(|eps| eps * invariant.max_balanced_eigenvalue)
2191 .unwrap_or(0.0);
2192 if shrinkage_ridge > 0.0 {
2193 let min_eig_before = range_eigs_sorted
2194 .iter()
2195 .copied()
2196 .fold(f64::INFINITY, f64::min);
2197 let mut shrinkage_floor_applied = 0usize;
2198 for eig_idx in 0..range_eigs_sorted.len() {
2199 let mut penalized_energy = 0.0;
2200 for original_col in 0..p {
2201 if structurally_penalized_cols[original_col] {
2202 let mut coordinate = 0.0;
2203 for pen_col in 0..penalized_rank {
2204 coordinate +=
2205 q_pen[(original_col, pen_col)] * range_rotation[(pen_col, eig_idx)];
2206 }
2207 penalized_energy += coordinate * coordinate;
2208 }
2209 }
2210 if penalized_energy > 1e-8 {
2211 range_eigs_sorted[eig_idx] += shrinkage_ridge;
2212 shrinkage_floor_applied += 1;
2213 }
2214 }
2215 if min_eig_before > 0.0 && shrinkage_ridge / min_eig_before > 0.01 {
2217 log::debug!(
2218 "Penalty shrinkage floor active: ridge={:.3e} (min_eig_before={:.3e}, ratio={:.1e}, max_bal_eig={:.3e}, applied_dirs={})",
2219 shrinkage_ridge,
2220 min_eig_before,
2221 shrinkage_ridge / min_eig_before,
2222 invariant.max_balanced_eigenvalue,
2223 shrinkage_floor_applied,
2224 );
2225 }
2226 }
2227
2228 let eigenvalue_floor = invariant.max_balanced_eigenvalue.max(1.0) * 1e-12;
2229 let qs = compose_qs_from_split(&q_pen, &q_null, p);
2230
2231 let leakage = assess_subspace_leakage(&qs, &rs_transformed, structural_rank, p);
2234 if !subspace_split_is_consistent(&leakage, p) {
2235 return Err(EstimationError::LayoutError(format!(
2236 "Reparameterization subspace split is inconsistent: max null leakage {:.3e} (rel {:.3e}, worst penalty {}), max |Qp'Qn| {:.3e}",
2237 leakage.max_abs_sq.sqrt(),
2238 leakage.max_rel_sq.sqrt(),
2239 leakage.worst_penalty,
2240 leakage.max_cross_gram_abs,
2241 )));
2242 }
2243
2244 let mut u_truncated_mat = Mat::<f64>::zeros(p, q_null.ncols());
2247 matmul(
2248 u_truncated_mat.as_mut(),
2249 Accum::Replace,
2250 qs.transpose(),
2251 q_null.as_ref(),
2252 1.0,
2253 Par::Seq,
2254 );
2255
2256 let mut e_transformed_mat = Mat::<f64>::zeros(structural_rank, p);
2262 for row_idx in 0..structural_rank {
2263 let safe_eigenval = range_eigs_sorted[row_idx].max(eigenvalue_floor);
2264 let sqrt_eigenval = safe_eigenval.sqrt();
2265 for j in 0..penalized_rank {
2267 e_transformed_mat[(row_idx, j)] = sqrt_eigenval * range_rotation[(j, row_idx)];
2268 }
2269 }
2270
2271 let mut floored_eigs: Vec<f64> = Vec::with_capacity(range_eigs_sorted.len());
2287 let mut log_det_sum = KahanSum::default();
2288 for (idx, &ev) in range_eigs_sorted.iter().enumerate() {
2289 if !ev.is_finite() || ev < -eigenvalue_floor {
2290 return Err(EstimationError::LayoutError(format!(
2291 "Penalty pseudo-logdet has a non-finite or large-negative structural eigenvalue at index {idx}: {ev:.3e}"
2292 )));
2293 }
2294 let safe_ev = ev.max(eigenvalue_floor);
2295 floored_eigs.push(safe_ev);
2296 if idx < penalized_rank {
2297 log_det_sum.add(safe_ev.ln());
2298 }
2299 }
2300 let log_det = log_det_sum.sum();
2301 let delta = 0.0;
2302
2303 let det1vec: Vec<f64> = (0..lambdas.len())
2306 .into_par_iter()
2307 .map(|k| {
2308 let s_k = &s_k_penalized_cache[k];
2309 let trace = trace_penalty_in_orthogonal_basis(
2313 s_k,
2314 penalized_rank,
2315 &range_rotation,
2316 &floored_eigs,
2317 delta,
2318 );
2319 lambdas[k] * trace
2320 })
2321 .collect();
2322
2323 {
2324 let mut maxdet1_mismatch = 0.0_f64;
2328 let mut det1_scale = 0.0_f64;
2329 for (k, lambda) in lambdas.iter().enumerate() {
2330 let s_k_penalized = &s_k_penalized_cache[k];
2331 let s_k_eigenbasis = orthogonal_similarity_transform_faer(
2332 s_k_penalized,
2333 penalized_rank,
2334 &range_rotation,
2335 );
2336 let mut trace = KahanSum::default();
2337 for l in 0..penalized_rank {
2338 trace.add(s_k_eigenbasis[(l, l)] / (floored_eigs[l] + delta));
2339 }
2340 let reference = *lambda * trace.sum();
2341 maxdet1_mismatch = maxdet1_mismatch.max((reference - det1vec[k]).abs());
2342 det1_scale = det1_scale.max(reference.abs()).max(det1vec[k].abs());
2343 }
2344 let det1_tolerance = 1e-7 * det1_scale.max(1.0);
2345 assert!(
2346 maxdet1_mismatch <= det1_tolerance,
2347 "det1 mismatch between optimized and reference formulas: max_abs={maxdet1_mismatch:.3e}, tol={det1_tolerance:.3e}"
2348 );
2349 }
2350
2351 let mut s_truncated = Mat::<f64>::zeros(p, p);
2362 matmul(
2363 s_truncated.as_mut(),
2364 Accum::Replace,
2365 e_transformed_mat.transpose(),
2366 e_transformed_mat.as_ref(),
2367 1.0,
2368 Par::Seq,
2369 );
2370
2371 {
2372 let mut max_null_diag = 0.0_f64;
2374 let mut max_null_offdiag = 0.0_f64;
2375 for i in structural_rank..p {
2376 max_null_diag = max_null_diag.max(s_truncated[(i, i)].abs());
2377 for j in 0..p {
2378 if i != j {
2379 max_null_offdiag = max_null_offdiag.max(s_truncated[(i, j)].abs());
2380 }
2381 }
2382 }
2383 assert!(
2384 max_null_diag <= 1e-10 && max_null_offdiag <= 1e-10,
2385 "null-space leakage in transformed penalty: max_null_diag={max_null_diag:.3e}, max_null_offdiag={max_null_offdiag:.3e}"
2386 );
2387 }
2388
2389 let qs_array = mat_to_array(&qs);
2390 let canonical_transformed: Vec<CanonicalPenalty> = rs_transformed
2391 .par_iter()
2392 .zip(penalties.par_iter())
2393 .map(|(r, cp)| {
2394 let mean_transformed = qs_array.t().dot(&cp.full_width_prior_mean());
2395 CanonicalPenalty::from_dense_root_with_mean(mat_to_array(r), p, mean_transformed)
2396 })
2397 .collect();
2398 Ok(ReparamResult {
2399 s_transformed: mat_to_array(&s_truncated),
2400 log_det,
2401 det1: Array1::from(det1vec),
2402 qs: qs_array,
2403 canonical_transformed,
2404 e_transformed: mat_to_array(&e_transformed_mat),
2405 u_truncated: mat_to_array(&u_truncated_mat),
2406 penalty_shrinkage_ridge: shrinkage_ridge,
2407 })
2408}
2409
/// Pair of dimensions handed to the reparameterization engine.
///
/// `stable_reparameterization_engine_canonical` reads `p` as the total number
/// of coefficients; `k` is carried alongside it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EngineDims {
    /// Total number of coefficients.
    pub p: usize,
    /// Second dimension stored with `p` (presumably the penalty count — confirm
    /// against callers).
    pub k: usize,
}

impl EngineDims {
    /// Creates the dimension pair.
    ///
    /// `const` so the constructor can also be used for constants and statics.
    #[must_use]
    pub const fn new(p: usize, k: usize) -> Self {
        Self { p, k }
    }
}
2422
2423pub fn stable_reparameterization_engine_canonical(
2432 penalties: &[CanonicalPenalty],
2433 lambdas: &[f64],
2434 dims: EngineDims,
2435 cached_invariant: Option<&ReparamInvariant>,
2436 penalty_shrinkage_floor: Option<f64>,
2437) -> Result<ReparamResult, EstimationError> {
2438 let owned;
2439 let invariant = match cached_invariant {
2440 Some(inv) => inv,
2441 None => {
2442 owned = precompute_reparam_invariant_from_canonical(penalties, dims.p)?;
2443 &owned
2444 }
2445 };
2446 stable_reparameterizationwith_invariant(
2447 penalties,
2448 lambdas,
2449 dims.p,
2450 invariant,
2451 penalty_shrinkage_floor,
2452 )
2453}
2454
2455#[derive(Clone)]
2465pub struct KroneckerReparamResult {
2466 pub reparameterized_marginals: Arc<Vec<Array2<f64>>>,
2472 pub marginal_eigenvalues: Arc<Vec<Array1<f64>>>,
2474 pub marginal_qs: Arc<Vec<Array2<f64>>>,
2476 pub log_det: f64,
2478 pub det1: Array1<f64>,
2480 pub det2: Array2<f64>,
2482 pub penalty_shrinkage_ridge: f64,
2484 pub has_double_penalty: bool,
2486 pub marginal_dims: Vec<usize>,
2488}
2489
2490impl KroneckerReparamResult {
2491 pub fn materialize_qs(&self) -> Array2<f64> {
2494 let mut qs = Array2::<f64>::eye(1);
2495 for u_k in self.marginal_qs.iter() {
2496 qs = kronecker_product(&qs, u_k);
2497 }
2498 qs
2499 }
2500
2501 pub fn materialize_s_transformed(&self, lambdas: &[f64]) -> Array2<f64> {
2504 let d = self.marginal_dims.len();
2505 let p: usize = self.marginal_dims.iter().copied().product();
2506 let mut s = Array2::<f64>::zeros((p, p));
2507
2508 let eigenvalue_views: Vec<ArrayView1<'_, f64>> =
2512 self.marginal_eigenvalues.iter().map(|m| m.view()).collect();
2513 let has_double = self.has_double_penalty && lambdas.len() > d;
2514 let mut multi_idx = vec![0usize; d];
2515 let mut flat = 0usize;
2516 loop {
2517 let (sigma, _structural_sigma, _joint_null) = kronecker_cell_sigma(
2518 &eigenvalue_views,
2519 &multi_idx,
2520 lambdas,
2521 d,
2522 has_double,
2523 self.penalty_shrinkage_ridge,
2524 );
2525 s[[flat, flat]] = sigma;
2526 flat += 1;
2527
2528 if kronecker_multi_index_advance(&mut multi_idx, &self.marginal_dims) {
2529 break;
2530 }
2531 }
2532 s
2533 }
2534
2535 pub fn materialize_dense_artifact_result(
2538 &self,
2539 rs_list: &[Array2<f64>],
2540 lambdas: &[f64],
2541 p: usize,
2542 ) -> Result<ReparamResult, EstimationError> {
2543 const KRONECKER_DENSE_COMPAT_FALLBACK_MAX_P: usize = 4096;
2544 if p > KRONECKER_DENSE_COMPAT_FALLBACK_MAX_P {
2545 return Err(EstimationError::LayoutError(format!(
2546 "Kronecker reparameterization would materialize dense {}x{} compatibility tensors; \
2547 large-model dense fallback is disabled. Wire the downstream solver to consume \
2548 the factored Kronecker result directly",
2549 p, p
2550 )));
2551 }
2552 let qs = self.materialize_qs();
2553 let s_transformed = self.materialize_s_transformed(lambdas);
2554
2555 let rs_transformed: Vec<Array2<f64>> = if rs_list.len() >= 2 {
2557 use rayon::prelude::*;
2558 rs_list
2559 .par_iter()
2560 .map(|r| gam_linalg::faer_ndarray::fast_ab(r, &qs))
2561 .collect()
2562 } else {
2563 rs_list
2564 .iter()
2565 .map(|r| gam_linalg::faer_ndarray::fast_ab(r, &qs))
2566 .collect()
2567 };
2568 let d = self.marginal_dims.len();
2574 let eigenvalue_views: Vec<ArrayView1<'_, f64>> =
2581 self.marginal_eigenvalues.iter().map(|m| m.view()).collect();
2582 let has_double = self.has_double_penalty && lambdas.len() > d;
2583 let diag_vals: Vec<f64> = {
2584 let mut vals = Vec::with_capacity(p);
2585 let mut multi_idx = vec![0usize; d];
2586 loop {
2587 let (sigma, _structural_sigma, _joint_null) = kronecker_cell_sigma(
2588 &eigenvalue_views,
2589 &multi_idx,
2590 lambdas,
2591 d,
2592 has_double,
2593 self.penalty_shrinkage_ridge,
2594 );
2595 vals.push(if sigma > 0.0 { sigma.sqrt() } else { 0.0 });
2596
2597 if kronecker_multi_index_advance(&mut multi_idx, &self.marginal_dims) {
2598 break;
2599 }
2600 }
2601 vals
2602 };
2603 let rank = diag_vals.iter().filter(|&&v| v > 1e-12).count();
2604 let mut e_transformed = Array2::<f64>::zeros((rank, p));
2605 let mut row = 0;
2606 for (j, &v) in diag_vals.iter().enumerate() {
2607 if v > 1e-12 {
2608 e_transformed[[row, j]] = v;
2609 row += 1;
2610 }
2611 }
2612
2613 let null_count = p - rank;
2615 let mut u_truncated = Array2::<f64>::zeros((p, null_count));
2616 let mut col = 0;
2617 for (j, &v) in diag_vals.iter().enumerate() {
2618 if v <= 1e-12 {
2619 u_truncated[[j, col]] = 1.0; col += 1;
2621 }
2622 }
2623
2624 let canonical_transformed: Vec<CanonicalPenalty> = rs_transformed
2625 .iter()
2626 .map(|r| CanonicalPenalty::from_dense_root(r.clone(), p))
2627 .collect();
2628 Ok(ReparamResult {
2629 s_transformed,
2630 log_det: self.log_det,
2631 det1: self.det1.clone(),
2632 qs,
2633 canonical_transformed,
2634 e_transformed,
2635 u_truncated,
2636 penalty_shrinkage_ridge: self.penalty_shrinkage_ridge,
2637 })
2638 }
2639}
2640
/// Threshold on the unweighted sum of marginal eigenvalues of a grid cell: at
/// or below it the cell is treated as jointly null, above it as penalized.
const KRONECKER_STRUCTURAL_ZERO_TOL: f64 = 1.0e-12;
2648
2649#[inline]
2663fn kronecker_cell_sigma(
2664 marginal_eigenvalues: &[ArrayView1<'_, f64>],
2665 multi_idx: &[usize],
2666 lambdas: &[f64],
2667 d: usize,
2668 has_double_penalty: bool,
2669 ridge: f64,
2670) -> (f64, f64, bool) {
2671 let mut sigma = 0.0;
2672 let mut structural_sigma = 0.0;
2673 for k in 0..d {
2674 let marginal_eigenvalue = marginal_eigenvalues[k][multi_idx[k]];
2675 structural_sigma += marginal_eigenvalue;
2676 sigma += lambdas[k] * marginal_eigenvalue;
2677 }
2678 let joint_null = structural_sigma <= KRONECKER_STRUCTURAL_ZERO_TOL;
2679 if has_double_penalty && joint_null {
2680 sigma += lambdas[d];
2681 }
2682 if structural_sigma > KRONECKER_STRUCTURAL_ZERO_TOL {
2683 sigma += ridge;
2684 }
2685 (sigma, structural_sigma, joint_null)
2686}
2687
/// Advances `multi_idx` to the next cell in row-major order (the last axis
/// moves fastest), like an odometer with per-axis limits `dims`.
///
/// Returns `true` when the index wrapped around to all zeros, i.e. the cell
/// just left was the last one; returns `false` otherwise.
#[inline]
fn kronecker_multi_index_advance(multi_idx: &mut [usize], dims: &[usize]) -> bool {
    for axis in (0..dims.len()).rev() {
        multi_idx[axis] += 1;
        if multi_idx[axis] < dims[axis] {
            // No carry into the next axis: done.
            return false;
        }
        multi_idx[axis] = 0;
    }
    true
}
2705
2706pub fn kronecker_logdet_and_derivatives(
2707 marginal_eigenvalues: &[ArrayView1<'_, f64>],
2708 marginal_dims: &[usize],
2709 lambdas: &[f64],
2710 has_double_penalty: bool,
2711 ridge: f64,
2712) -> (f64, Array1<f64>, Array2<f64>) {
2713 let d = marginal_dims.len();
2714 let n_pen = d + if has_double_penalty { 1 } else { 0 };
2715
2716 let mut logdet = 0.0;
2717 let mut grad = Array1::<f64>::zeros(n_pen);
2718 let mut hess = Array2::<f64>::zeros((n_pen, n_pen));
2719 let tol = 1e-12;
2720
2721 let mut multi_idx = vec![0usize; d];
2722 loop {
2723 let (sigma, _structural_sigma, joint_null) = kronecker_cell_sigma(
2724 marginal_eigenvalues,
2725 &multi_idx,
2726 lambdas,
2727 d,
2728 has_double_penalty,
2729 ridge,
2730 );
2731
2732 if sigma > tol {
2733 logdet += sigma.ln();
2734 let inv_sigma = 1.0 / sigma;
2735 let inv_sigma2 = inv_sigma * inv_sigma;
2736
2737 for k in 0..d {
2738 let ck = lambdas[k] * marginal_eigenvalues[k][multi_idx[k]];
2739 grad[k] += ck * inv_sigma;
2740 }
2741 if has_double_penalty && joint_null {
2742 grad[d] += lambdas[d] * inv_sigma;
2743 }
2744
2745 for k in 0..n_pen {
2746 let ck = if k < d {
2747 lambdas[k] * marginal_eigenvalues[k][multi_idx[k]]
2748 } else if joint_null {
2749 lambdas[d]
2750 } else {
2751 0.0
2752 };
2753 if ck == 0.0 {
2760 continue;
2761 }
2762 hess[[k, k]] += ck * inv_sigma - ck * ck * inv_sigma2;
2763 for l in (k + 1)..n_pen {
2764 let cl = if l < d {
2765 lambdas[l] * marginal_eigenvalues[l][multi_idx[l]]
2766 } else if joint_null {
2767 lambdas[d]
2768 } else {
2769 0.0
2770 };
2771 let off = -ck * cl * inv_sigma2;
2772 hess[[k, l]] += off;
2773 hess[[l, k]] += off;
2774 }
2775 }
2776 }
2777
2778 if kronecker_multi_index_advance(&mut multi_idx, marginal_dims) {
2779 break;
2780 }
2781 }
2782
2783 (logdet, grad, hess)
2784}
2785
2786use crate::kronecker::KroneckerInvariantStructure;
2790
2791pub fn kronecker_reparameterization_engine(
2797 marginal_designs: &[Array2<f64>],
2798 marginal_penalties: &[Array2<f64>],
2799 marginal_dims: &[usize],
2800 lambdas: &[f64],
2801 has_double_penalty: bool,
2802 penalty_shrinkage_floor: Option<f64>,
2803) -> Result<KroneckerReparamResult, EstimationError> {
2804 let d = marginal_dims.len();
2805 if marginal_designs.len() != d || marginal_penalties.len() != d {
2806 return Err(EstimationError::LayoutError(format!(
2807 "kronecker_reparameterization_engine: dimension mismatch: designs={}, penalties={}, dims={}",
2808 marginal_designs.len(),
2809 marginal_penalties.len(),
2810 d
2811 )));
2812 }
2813
2814 let invariant =
2815 KroneckerInvariantStructure::compute(marginal_designs, marginal_penalties, marginal_dims)?;
2816 kronecker_reparameterization_engine_with_invariant(
2817 &invariant,
2818 marginal_dims,
2819 lambdas,
2820 has_double_penalty,
2821 penalty_shrinkage_floor,
2822 )
2823}
2824
2825pub fn kronecker_reparameterization_engine_with_invariant(
2833 invariant: &KroneckerInvariantStructure,
2834 marginal_dims: &[usize],
2835 lambdas: &[f64],
2836 has_double_penalty: bool,
2837 penalty_shrinkage_floor: Option<f64>,
2838) -> Result<KroneckerReparamResult, EstimationError> {
2839 let marginal_eigenvalues = Arc::clone(&invariant.marginal_eigenvalues);
2842 let marginal_qs = Arc::clone(&invariant.marginal_qs);
2843 let reparameterized_marginals = Arc::clone(&invariant.reparameterized_marginals);
2844
2845 let penalty_shrinkage_ridge = if let Some(floor) = penalty_shrinkage_floor {
2847 floor * invariant.max_balanced_eigenvalue
2848 } else {
2849 0.0
2850 };
2851
2852 let marginal_eigenvalue_views: Vec<_> = marginal_eigenvalues
2853 .iter()
2854 .map(|evals| evals.view())
2855 .collect();
2856 let (log_det, det1, det2) = kronecker_logdet_and_derivatives(
2857 &marginal_eigenvalue_views,
2858 marginal_dims,
2859 lambdas,
2860 has_double_penalty,
2861 penalty_shrinkage_ridge,
2862 );
2863
2864 Ok(KroneckerReparamResult {
2865 reparameterized_marginals,
2866 marginal_eigenvalues,
2867 marginal_qs,
2868 log_det,
2869 det1,
2870 det2,
2871 penalty_shrinkage_ridge,
2872 has_double_penalty,
2873 marginal_dims: marginal_dims.to_vec(),
2874 })
2875}
2876
2877pub fn calculate_condition_number(matrix: &Array2<f64>) -> Result<f64, FaerLinalgError> {
2897 let (rows, cols) = matrix.dim();
2898 if rows == 0 || cols == 0 {
2899 return Ok(1.0);
2900 }
2901
2902 if rows == cols {
2904 let mut max_abs = 0.0_f64;
2905 let mut max_asym = 0.0_f64;
2906 for i in 0..rows {
2907 for j in 0..cols {
2908 max_abs = max_abs.max(matrix[[i, j]].abs());
2909 }
2910 for j in 0..i {
2911 let diff = (matrix[[i, j]] - matrix[[j, i]]).abs();
2912 if diff > max_asym {
2913 max_asym = diff;
2914 }
2915 }
2916 }
2917 let sym_tol = max_abs.max(1.0) * 1e-12;
2918 if max_asym <= sym_tol {
2919 let (evals, _) = matrix.eigh(Side::Lower)?;
2920 let mut max_abs_eval = 0.0_f64;
2921 let mut min_abs_eval = f64::INFINITY;
2922 for &lam in evals.iter() {
2923 let s = lam.abs();
2924 max_abs_eval = max_abs_eval.max(s);
2925 min_abs_eval = min_abs_eval.min(s);
2926 }
2927 if min_abs_eval < 1e-12 {
2928 return Ok(f64::INFINITY);
2929 }
2930 return Ok(max_abs_eval / min_abs_eval);
2931 }
2932 }
2933
2934 let (_, s, _) = matrix.svd(false, false)?;
2936 let max_sv = s.iter().fold(0.0_f64, |max, &val| max.max(val));
2937 let min_sv = s.iter().fold(f64::INFINITY, |min, &val| min.min(val));
2938 if min_sv < 1e-12 {
2939 return Ok(f64::INFINITY);
2940 }
2941 Ok(max_sv / min_sv)
2942}
2943
2944#[cfg(test)]
2945mod tests {
2946 use super::{
2947 CanonicalPenalty, REL_PSD_FLOOR, SubspaceLeakageMetrics, assess_subspace_leakage,
2948 classify_eigenvalues_strict, precompute_reparam_invariant_from_canonical,
2949 report_penalty_pair_redundancy, stable_reparameterizationwith_invariant,
2950 subspace_split_is_consistent,
2951 };
2952 use crate::EstimationError;
2953 use crate::construction::kronecker_product;
2954 use faer::Mat;
2955 use gam_linalg::faer_ndarray::FaerEigh;
2956 use gam_linalg::utils::inf_norm;
2957 use ndarray::{Array1, Array2, array};
2958
2959 fn canonical_from_roots(rs_list: &[Array2<f64>], p: usize) -> Vec<CanonicalPenalty> {
2961 rs_list
2962 .iter()
2963 .map(|r| {
2964 let local = r.t().dot(r);
2965 CanonicalPenalty {
2966 root: r.clone(),
2967 col_range: 0..p,
2968 total_dim: p,
2969 nullity: 0,
2970 local,
2971 prior_mean: Array1::zeros(p),
2972 positive_eigenvalues: Vec::new(),
2973 op: None,
2974 }
2975 })
2976 .collect()
2977 }
2978
    /// Test shorthand: forwards its arguments unchanged to `assess_subspace_leakage`
    /// and returns the resulting metrics.
    fn metrics_for(
        qs: &Mat<f64>,
        rs: &[Mat<f64>],
        structural_rank: usize,
        p: usize,
    ) -> SubspaceLeakageMetrics {
        assess_subspace_leakage(qs, rs, structural_rank, p)
    }
2987
2988 #[test]
2989 fn subspace_leakage_iszero_for_clean_split() {
2990 let p = 4usize;
2991 let structural_rank = 2usize;
2992 let qs = Mat::<f64>::identity(p, p);
2993 let mut r0 = Mat::<f64>::zeros(2, p);
2994 r0[(0, 0)] = 1.0;
2995 r0[(1, 1)] = 2.0;
2996
2997 let m = metrics_for(&qs, &[r0], structural_rank, p);
2998 assert!(m.max_abs_sq <= 1e-16);
2999 assert!(m.max_rel_sq <= 1e-16);
3000 assert!(m.max_cross_gram_abs <= 1e-16);
3001 }
3002
3003 #[test]
3004 fn subspace_leakage_detects_null_column_energy() {
3005 let p = 4usize;
3006 let structural_rank = 2usize;
3007 let qs = Mat::<f64>::identity(p, p);
3008 let mut r0 = Mat::<f64>::zeros(1, p);
3009 r0[(0, 2)] = 3.0;
3010
3011 let m = metrics_for(&qs, &[r0], structural_rank, p);
3012 assert!(m.max_abs_sq > 0.0);
3013 assert!(m.max_rel_sq > 0.99);
3014 }
3015
3016 #[test]
3017 fn subspace_leakage_detects_qp_qn_nonorthogonality() {
3018 let p = 3usize;
3019 let structural_rank = 1usize;
3020 let mut qs = Mat::<f64>::identity(p, p);
3021 qs[(0, 1)] = 0.2;
3022 let r0 = Mat::<f64>::zeros(1, p);
3023
3024 let m = metrics_for(&qs, &[r0], structural_rank, p);
3025 assert!(m.max_cross_gram_abs > 1e-3);
3026 }
3027
3028 #[test]
3029 fn subspace_split_admits_near_threshold_manifold_leakage_1802() {
3030 let p = 40usize;
3044 let structural_rank = p - 1;
3045 let null_amp = (1.06e-8_f64).sqrt();
3048 let mut rs = Mat::<f64>::zeros(1, p);
3049 rs[(0, 0)] = 1.0;
3050 rs[(0, p - 1)] = null_amp;
3051 let qs = Mat::<f64>::identity(p, p);
3052 let leakage = metrics_for(&qs, &[rs], structural_rank, p);
3053 assert!(
3056 leakage.max_rel_sq > 1e-10 && leakage.max_rel_sq < 1e-6,
3057 "reproduced leakage should sit in the near-REL_PSD_FLOOR band, got {:.3e}",
3058 leakage.max_rel_sq
3059 );
3060 assert!(leakage.max_cross_gram_abs <= 1e-12);
3061 assert!(
3062 subspace_split_is_consistent(&leakage, p),
3063 "near-REL_PSD_FLOOR null leakage on a manifold basis must be admitted \
3064 (rel_sq={:.3e}, tol={:.3e})",
3065 leakage.max_rel_sq,
3066 (p as f64) * REL_PSD_FLOOR,
3067 );
3068 }
3069
3070 #[test]
3071 fn subspace_split_still_rejects_genuine_inconsistency_1802() {
3072 let p = 40usize;
3077 let structural_rank = p - 1;
3078 let mut rs = Mat::<f64>::zeros(1, p);
3079 rs[(0, p - 1)] = 1.0; let qs = Mat::<f64>::identity(p, p);
3081 let leakage = metrics_for(&qs, &[rs], structural_rank, p);
3082 assert!(leakage.max_rel_sq > 0.99);
3083 assert!(
3084 !subspace_split_is_consistent(&leakage, p),
3085 "an O(1) null-block leakage is a real inconsistency and must be rejected"
3086 );
3087
3088 let mut qs_bad = Mat::<f64>::identity(3, 3);
3090 qs_bad[(0, 1)] = 0.2;
3091 let clean = Mat::<f64>::zeros(1, 3);
3092 let leakage2 = metrics_for(&qs_bad, &[clean], 1, 3);
3093 assert!(leakage2.max_cross_gram_abs > 1e-3);
3094 assert!(
3095 !subspace_split_is_consistent(&leakage2, 3),
3096 "a non-orthogonal Qp/Qn split must be rejected"
3097 );
3098 }
3099
3100 #[test]
3101 fn u_truncated_is_transformed_frame_in_nonzero_case() {
3102 let p = 3usize;
3103 let rs_list = vec![array![[1.0, 0.0, 0.0]]];
3104 let canonical = canonical_from_roots(&rs_list, p);
3105 let lambdas = vec![2.0];
3106 let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3107 .expect("precompute invariant");
3108 let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
3109 .expect("stable reparam");
3110
3111 let expected = rep.qs.t().dot(&inv.split.q_null);
3112 let diff = &rep.u_truncated - &expected;
3113 let max_abs = inf_norm(diff.iter().copied());
3114 assert!(
3115 max_abs <= 1e-10,
3116 "u_truncated frame mismatch: max_abs={max_abs}"
3117 );
3118 }
3119
3120 #[test]
3121 fn infinite_lambda_keeps_range_penalty_block_finite_1379() {
3122 let p = 3usize;
3139 let rs_list = vec![array![[1.0, 0.0, 0.0]], array![[0.0, 1.0, 0.0]]];
3140 let canonical = canonical_from_roots(&rs_list, p);
3141 let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3142 .expect("precompute invariant");
3143
3144 let lambdas_inf = vec![f64::INFINITY, 3.0];
3145 let inf_result =
3146 stable_reparameterizationwith_invariant(&canonical, &lambdas_inf, p, &inv, None);
3147 assert!(
3148 inf_result.is_err(),
3149 "an infinite lambda must surface as an error, not be silently clamped (#1074)"
3150 );
3151
3152 let lambdas_big = vec![1e300_f64, 3.0];
3156 let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas_big, p, &inv, None)
3157 .expect("stable reparam at large-but-finite lambda");
3158 assert!(
3159 rep.s_transformed.iter().all(|v| v.is_finite()),
3160 "transformed penalty must be finite at large-but-finite lambda"
3161 );
3162 assert!(
3163 rep.qs.iter().all(|v| v.is_finite()),
3164 "reparam rotation must be finite at large-but-finite lambda"
3165 );
3166 assert!(
3167 rep.log_det.is_finite(),
3168 "penalty log-det must be finite at large-but-finite lambda"
3169 );
3170 assert!(
3171 rep.det1.iter().all(|v| v.is_finite()),
3172 "penalty log-det derivatives must be finite at large-but-finite lambda"
3173 );
3174 }
3175
3176 #[test]
3177 fn u_truncated_is_identitywhen_no_penalties() {
3178 let p = 4usize;
3179 let canonical: Vec<CanonicalPenalty> = Vec::new();
3180 let lambdas: Vec<f64> = Vec::new();
3181 let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3182 .expect("precompute invariant");
3183 let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
3184 .expect("stable reparam");
3185 assert_eq!(rep.u_truncated, Array2::<f64>::eye(p));
3186 }
3187
3188 #[test]
3189 fn dense_shrinkage_floor_skips_structurally_unpenalized_range_columns() {
3190 let p = 3usize;
3191 let canonical = canonical_from_roots(&[array![[1.0, 0.0, 0.0]]], p);
3192 let invariant = super::ReparamInvariant {
3193 split: super::SubspaceSplit {
3194 q_pen: array![[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]],
3195 q_null: array![[0.0], [0.0], [1.0]],
3196 },
3197 qs_base: Array2::eye(p),
3198 has_nonzero: true,
3199 max_balanced_eigenvalue: 1.0,
3200 };
3201
3202 let rep =
3203 stable_reparameterizationwith_invariant(&canonical, &[2.0], p, &invariant, Some(1e-6))
3204 .expect("stable reparameterization");
3205 assert!(rep.s_transformed[[0, 0]] > 2.0);
3206 assert!(
3207 rep.s_transformed[[1, 1]] <= 1e-11,
3208 "structurally unpenalized range coordinate received shrinkage ridge: {}",
3209 rep.s_transformed[[1, 1]]
3210 );
3211 }
3212
3213 #[test]
3214 fn kronecker_shrinkage_floor_preserves_joint_null_space() {
3215 let marginal_designs = vec![Array2::<f64>::eye(2), Array2::<f64>::eye(2)];
3216 let marginal_penalties = vec![
3217 array![[0.0, 0.0], [0.0, 2.0]],
3218 array![[0.0, 0.0], [0.0, 3.0]],
3219 ];
3220 let marginal_dims = vec![2usize, 2usize];
3221 let lambdas = vec![5.0, 7.0];
3222
3223 let rep = super::kronecker_reparameterization_engine(
3224 &marginal_designs,
3225 &marginal_penalties,
3226 &marginal_dims,
3227 &lambdas,
3228 false,
3229 Some(1e-6),
3230 )
3231 .expect("kronecker reparameterization");
3232 assert!(rep.penalty_shrinkage_ridge > 0.0);
3233
3234 let s = rep.materialize_s_transformed(&lambdas);
3235 assert!(
3236 s[[0, 0]].abs() <= 1e-14,
3237 "joint tensor null direction must remain unpenalized, got {}",
3238 s[[0, 0]]
3239 );
3240 assert!(s[[1, 1]] > lambdas[1] * 3.0);
3241 assert!(s[[2, 2]] > lambdas[0] * 2.0);
3242 assert!(s[[3, 3]] > lambdas[0] * 2.0 + lambdas[1] * 3.0);
3243
3244 let tensor_roots = vec![
3245 array![
3246 [0.0, 0.0, 2.0_f64.sqrt(), 0.0],
3247 [0.0, 0.0, 0.0, 2.0_f64.sqrt()]
3248 ],
3249 array![
3250 [0.0, 3.0_f64.sqrt(), 0.0, 0.0],
3251 [0.0, 0.0, 0.0, 3.0_f64.sqrt()]
3252 ],
3253 ];
3254 let dense = rep
3255 .materialize_dense_artifact_result(&tensor_roots, &lambdas, 4)
3256 .expect("dense artifact materialization");
3257 assert_eq!(dense.e_transformed.nrows(), 3);
3258 assert_eq!(dense.u_truncated.ncols(), 1);
3259 }
3260
3261 #[test]
3262 fn kronecker_memoized_invariant_is_bit_identical_to_unmemoized_engine() {
3263 let marginal_designs = vec![
3270 array![[1.0, 0.3, -0.2], [0.4, 1.0, 0.1], [-0.1, 0.2, 1.0]],
3271 array![[1.0, -0.5], [0.2, 1.0], [0.7, 0.3]],
3272 ];
3273 let marginal_penalties = vec![
3274 array![[2.0, -1.0, 0.0], [-1.0, 2.0, -1.0], [0.0, -1.0, 1.0]],
3275 array![[3.0, -1.5], [-1.5, 3.0]],
3276 ];
3277 let marginal_dims = vec![3usize, 2usize];
3278
3279 let invariant = super::KroneckerInvariantStructure::compute(
3280 &marginal_designs,
3281 &marginal_penalties,
3282 &marginal_dims,
3283 )
3284 .expect("invariant structure");
3285
3286 for lambdas in [
3287 vec![5.0, 7.0],
3288 vec![0.0, 7.0],
3289 vec![5.0, 0.0],
3290 vec![1e-3, 1e3],
3291 ] {
3292 for floor in [None, Some(1e-6)] {
3293 let unmemoized = super::kronecker_reparameterization_engine(
3294 &marginal_designs,
3295 &marginal_penalties,
3296 &marginal_dims,
3297 &lambdas,
3298 true,
3299 floor,
3300 )
3301 .expect("unmemoized engine");
3302 let memoized = super::kronecker_reparameterization_engine_with_invariant(
3303 &invariant,
3304 &marginal_dims,
3305 &lambdas,
3306 true,
3307 floor,
3308 )
3309 .expect("memoized engine");
3310
3311 assert_eq!(memoized.log_det.to_bits(), unmemoized.log_det.to_bits());
3312 assert_eq!(
3313 memoized.penalty_shrinkage_ridge.to_bits(),
3314 unmemoized.penalty_shrinkage_ridge.to_bits()
3315 );
3316 for (a, b) in memoized.det1.iter().zip(unmemoized.det1.iter()) {
3317 assert_eq!(a.to_bits(), b.to_bits());
3318 }
3319 for (a, b) in memoized.det2.iter().zip(unmemoized.det2.iter()) {
3320 assert_eq!(a.to_bits(), b.to_bits());
3321 }
3322 for (ma, ua) in memoized
3323 .reparameterized_marginals
3324 .iter()
3325 .zip(unmemoized.reparameterized_marginals.iter())
3326 {
3327 for (a, b) in ma.iter().zip(ua.iter()) {
3328 assert_eq!(a.to_bits(), b.to_bits());
3329 }
3330 }
3331 for (mq, uq) in memoized
3332 .marginal_qs
3333 .iter()
3334 .zip(unmemoized.marginal_qs.iter())
3335 {
3336 for (a, b) in mq.iter().zip(uq.iter()) {
3337 assert_eq!(a.to_bits(), b.to_bits());
3338 }
3339 }
3340 }
3341 }
3342 }
3343
3344 #[test]
3345 fn kronecker_double_penalty_shrinks_only_joint_null_space() {
3346 let marginal_designs = vec![Array2::<f64>::eye(2), Array2::<f64>::eye(2)];
3347 let marginal_penalties = vec![
3348 array![[0.0, 0.0], [0.0, 2.0]],
3349 array![[0.0, 0.0], [0.0, 3.0]],
3350 ];
3351 let marginal_dims = vec![2usize, 2usize];
3352 let lambdas = vec![5.0, 7.0, 11.0];
3353
3354 let rep = super::kronecker_reparameterization_engine(
3355 &marginal_designs,
3356 &marginal_penalties,
3357 &marginal_dims,
3358 &lambdas,
3359 true,
3360 None,
3361 )
3362 .expect("kronecker reparameterization");
3363
3364 let s = rep.materialize_s_transformed(&lambdas);
3365 let expected = [11.0, 21.0, 10.0, 31.0];
3366 for (idx, expected_diag) in expected.iter().copied().enumerate() {
3367 assert!(
3368 (s[[idx, idx]] - expected_diag).abs() <= 1e-12,
3369 "diagonal {idx} got {}, expected {expected_diag}",
3370 s[[idx, idx]]
3371 );
3372 }
3373
3374 let expected_logdet: f64 = expected.iter().map(|v| f64::ln(*v)).sum();
3375 assert!((rep.log_det - expected_logdet).abs() <= 1e-12);
3376 assert!(
3377 (rep.det1[2] - 1.0).abs() <= 1e-12,
3378 "double-penalty derivative must come only from the joint null mode, got {}",
3379 rep.det1[2]
3380 );
3381 assert!(rep.det2[[2, 2]].abs() <= 1e-12);
3382
3383 let tensor_roots = vec![
3384 array![
3385 [0.0, 0.0, 2.0_f64.sqrt(), 0.0],
3386 [0.0, 0.0, 0.0, 2.0_f64.sqrt()]
3387 ],
3388 array![
3389 [0.0, 3.0_f64.sqrt(), 0.0, 0.0],
3390 [0.0, 0.0, 0.0, 3.0_f64.sqrt()]
3391 ],
3392 ];
3393 let dense = rep
3394 .materialize_dense_artifact_result(&tensor_roots, &lambdas, 4)
3395 .expect("dense artifact materialization");
3396 for (idx, expected_diag) in expected.iter().copied().enumerate() {
3397 assert!(
3398 (dense.s_transformed[[idx, idx]] - expected_diag).abs() <= 1e-12,
3399 "dense artifact diagonal {idx} got {}, expected {expected_diag}",
3400 dense.s_transformed[[idx, idx]]
3401 );
3402 }
3403 }
3404
3405 #[test]
3406 fn transformed_penalty_is_diagonal_in_transformed_frame() {
3407 let p = 3usize;
3408 let inv_sqrt2 = 2.0_f64.sqrt().recip();
3409 let rs_list = vec![array![[inv_sqrt2, inv_sqrt2, 0.0]]];
3411 let canonical = canonical_from_roots(&rs_list, p);
3412 let lambdas = vec![4.0];
3413 let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3414 .expect("precompute invariant");
3415 let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
3416 .expect("stable reparam");
3417
3418 assert_eq!(rep.e_transformed.nrows(), 1);
3419 assert!(rep.e_transformed[[0, 0]].abs() > 0.0);
3420 assert!(rep.e_transformed[[0, 1]].abs() <= 1e-12);
3421 assert!(rep.e_transformed[[0, 2]].abs() <= 1e-12);
3422 let expected_det1 = 1.0_f64;
3425 assert!((rep.det1[0] - expected_det1).abs() <= 1e-12);
3426
3427 let s = rep.s_transformed;
3428 let mut max_offdiag = 0.0_f64;
3429 for i in 0..p {
3430 for j in 0..p {
3431 if i != j {
3432 max_offdiag = max_offdiag.max(s[[i, j]].abs());
3433 }
3434 }
3435 }
3436 assert!(
3437 max_offdiag <= 1e-10,
3438 "transformed penalty should be diagonal, max offdiag={max_offdiag}"
3439 );
3440 assert!(s[[1, 1]].abs() <= 1e-10);
3441 assert!(s[[2, 2]].abs() <= 1e-10);
3442 }
3443
3444 #[test]
3445 fn det1_matches_rank_for_single_full_rank_penalty() {
3446 let p = 2usize;
3447 let inv_sqrt2 = 2.0_f64.sqrt().recip();
3448 let q_t = [[inv_sqrt2, inv_sqrt2], [-inv_sqrt2, inv_sqrt2]];
3450 let rs = array![
3452 [3.0 * q_t[0][0], 3.0 * q_t[0][1]],
3453 [1.0 * q_t[1][0], 1.0 * q_t[1][1]]
3454 ];
3455 let rs_list = vec![rs];
3456 let canonical = canonical_from_roots(&rs_list, p);
3457 let lambdas = vec![5.0];
3458
3459 let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3460 .expect("precompute invariant");
3461 let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
3462 .expect("stable reparam");
3463
3464 assert_eq!(rep.e_transformed.nrows(), p);
3465 let det1 = rep.det1[0];
3466 let s_k_eigs = [9.0_f64, 1.0_f64];
3470 let lambda = 5.0_f64;
3471 let expected_det1: f64 = s_k_eigs.iter().map(|&d| lambda * d / (lambda * d)).sum();
3472 assert!(
3473 (det1 - expected_det1).abs() <= 1e-12,
3474 "expected det1={expected_det1}, got {det1}",
3475 );
3476
3477 let s = rep.s_transformed;
3478 assert!(s[[0, 1]].abs() <= 1e-10);
3479 assert!(s[[1, 0]].abs() <= 1e-10);
3480 assert!(s[[0, 0]] > 0.0);
3481 assert!(s[[1, 1]] > 0.0);
3482 }
3483
3484 #[test]
3485 fn kronecker_reparam_logdet_matches_dense() {
3486 let q1 = 3;
3489 let q2 = 4;
3490 let s1 = {
3491 let mut s = Array2::<f64>::zeros((q1, q1));
3492 s[[0, 0]] = 1.0;
3494 s[[0, 1]] = -1.0;
3495 s[[1, 0]] = -1.0;
3496 s[[1, 1]] = 2.0;
3497 s[[1, 2]] = -1.0;
3498 s[[2, 1]] = -1.0;
3499 s[[2, 2]] = 1.0;
3500 s
3501 };
3502 let s2 = {
3503 let mut s = Array2::<f64>::zeros((q2, q2));
3504 s[[0, 0]] = 1.0;
3505 s[[0, 1]] = -1.0;
3506 s[[1, 0]] = -1.0;
3507 s[[1, 1]] = 2.0;
3508 s[[1, 2]] = -1.0;
3509 s[[2, 1]] = -1.0;
3510 s[[2, 2]] = 2.0;
3511 s[[2, 3]] = -1.0;
3512 s[[3, 2]] = -1.0;
3513 s[[3, 3]] = 1.0;
3514 s
3515 };
3516
3517 let lambdas = [2.5, 1.3];
3518 let p = q1 * q2;
3520 let i1 = Array2::<f64>::eye(q1);
3521 let i2 = Array2::<f64>::eye(q2);
3522 let pen0 = kronecker_product(&s1, &i2);
3523 let pen1 = kronecker_product(&i1, &s2);
3524 let mut s_dense = Array2::<f64>::zeros((p, p));
3525 s_dense.scaled_add(lambdas[0], &pen0);
3526 s_dense.scaled_add(lambdas[1], &pen1);
3527
3528 let (evals_dense, _): (ndarray::Array1<f64>, ndarray::Array2<f64>) =
3530 s_dense.eigh(faer::Side::Lower).unwrap();
3531 let tol = 1e-12;
3532 let ref_logdet: f64 = evals_dense
3533 .iter()
3534 .filter(|&&v: &&f64| v > tol)
3535 .map(|&v: &f64| v.ln())
3536 .sum();
3537
3538 let marginal_designs = vec![
3540 Array2::<f64>::eye(q1), Array2::<f64>::eye(q2),
3542 ];
3543 let marginal_penalties = vec![s1, s2];
3544 let kron_result = super::kronecker_reparameterization_engine(
3545 &marginal_designs,
3546 &marginal_penalties,
3547 &[q1, q2],
3548 &lambdas,
3549 false,
3550 None,
3551 )
3552 .unwrap();
3553
3554 let diff = (kron_result.log_det - ref_logdet).abs();
3555 assert!(
3556 diff < 1e-8,
3557 "Kronecker logdet {:.10} vs dense {:.10}, diff={:.3e}",
3558 kron_result.log_det,
3559 ref_logdet,
3560 diff,
3561 );
3562
3563 let rhos: Vec<f64> = lambdas.iter().map(|&l| l.ln()).collect();
3565 let eps = 1e-5;
3566 for k in 0..2 {
3567 let mut rho_plus = rhos.clone();
3568 rho_plus[k] += eps;
3569 let mut rho_minus = rhos.clone();
3570 rho_minus[k] -= eps;
3571 let lam_plus: Vec<f64> = rho_plus.iter().map(|&r| r.exp()).collect();
3572 let lam_minus: Vec<f64> = rho_minus.iter().map(|&r| r.exp()).collect();
3573 let result_plus = super::kronecker_reparameterization_engine(
3574 &marginal_designs,
3575 &marginal_penalties,
3576 &[q1, q2],
3577 &lam_plus,
3578 false,
3579 None,
3580 )
3581 .unwrap();
3582 let result_minus = super::kronecker_reparameterization_engine(
3583 &marginal_designs,
3584 &marginal_penalties,
3585 &[q1, q2],
3586 &lam_minus,
3587 false,
3588 None,
3589 )
3590 .unwrap();
3591 let fd_deriv = (result_plus.log_det - result_minus.log_det) / (2.0 * eps);
3592 let analytic_deriv = kron_result.det1[k];
3593 let rel_err = if analytic_deriv.abs() > 1e-10 {
3594 (fd_deriv - analytic_deriv).abs() / analytic_deriv.abs()
3595 } else {
3596 (fd_deriv - analytic_deriv).abs()
3597 };
3598 assert!(
3599 rel_err < 1e-4,
3600 "det1[{k}] mismatch: analytic={:.8}, fd={:.8}, rel_err={:.3e}",
3601 analytic_deriv,
3602 fd_deriv,
3603 rel_err,
3604 );
3605 }
3606 }
3607
3608 #[test]
3609 fn classify_strict_rejects_nan_eigenvalue() {
3610 let mut eigs = [1.0, f64::NAN, 0.5];
3611 match classify_eigenvalues_strict(&mut eigs, "test_nan") {
3612 Err(EstimationError::PenaltySpectrumNonFinite {
3613 context,
3614 index,
3615 value,
3616 }) => {
3617 assert_eq!(context, "test_nan");
3618 assert_eq!(index, 1);
3619 assert!(value.is_nan());
3620 }
3621 other => panic!("expected PenaltySpectrumNonFinite, got {:?}", other),
3622 }
3623 }
3624
3625 #[test]
3626 fn classify_strict_rejects_inf_eigenvalue() {
3627 let mut eigs = [1.0, 0.5, f64::INFINITY];
3628 match classify_eigenvalues_strict(&mut eigs, "test_inf") {
3629 Err(EstimationError::PenaltySpectrumNonFinite { index, value, .. }) => {
3630 assert_eq!(index, 2);
3631 assert!(value.is_infinite());
3632 }
3633 other => panic!("expected PenaltySpectrumNonFinite, got {:?}", other),
3634 }
3635 }
3636
3637 #[test]
3638 fn classify_strict_rejects_materially_indefinite() {
3639 let mut eigs = [1.0, -1e-2, 0.5];
3641 match classify_eigenvalues_strict(&mut eigs, "test_indef") {
3642 Err(EstimationError::PenaltySpectrumIndefinite {
3643 context,
3644 index,
3645 value,
3646 ..
3647 }) => {
3648 assert_eq!(context, "test_indef");
3649 assert_eq!(index, 1);
3650 assert!((value + 1e-2).abs() <= 1e-15);
3651 }
3652 other => panic!("expected PenaltySpectrumIndefinite, got {:?}", other),
3653 }
3654 }
3655
3656 #[test]
3657 fn classify_strict_accepts_roundoff_negative() {
3658 let scale = 1.0_f64;
3660 let roundoff = -1e-16 * scale;
3661 let mut eigs = [scale, 0.5 * scale, roundoff, 0.25 * scale];
3662 classify_eigenvalues_strict(&mut eigs, "test_roundoff").expect("roundoff must classify");
3663 assert_eq!(eigs[2], 0.0);
3665 assert!(eigs[0] > 0.0 && eigs[1] > 0.0 && eigs[3] > 0.0);
3667 }
3668
3669 #[test]
3670 fn classify_strict_accepts_extreme_lambda_assembly_noise_1619() {
3671 let scale = 8.509e12_f64;
3678 let noise = -6.546e2_f64;
3680 assert!(
3681 (noise.abs() / scale) < 1.0e-10,
3682 "fixture must reproduce the ~1e-11-relative noise from #1619"
3683 );
3684 let mut eigs = vec![scale, 0.5 * scale, noise, 0.1 * scale];
3685 classify_eigenvalues_strict(&mut eigs, "range penalty block")
3686 .expect("a ~1e-11-relative roundoff-negative eigenvalue must be accepted (#1619)");
3687 assert_eq!(eigs[2], 0.0);
3689 assert!(eigs[0] > 0.0 && eigs[1] > 0.0 && eigs[3] > 0.0);
3691 }
3692
3693 #[test]
3694 fn classify_strict_snaps_subtol_positive_to_zero() {
3695 let scale = 10.0_f64;
3698 let subtol = 1e-15 * scale;
3699 let mut eigs = [scale, subtol];
3700 classify_eigenvalues_strict(&mut eigs, "test_sub_pos").expect("sub-tol positive ok");
3701 assert_eq!(eigs[1], 0.0);
3702 }
3703
3704 fn canonical_from_local(
3708 local: Array2<f64>,
3709 col_range: std::ops::Range<usize>,
3710 total_dim: usize,
3711 ) -> CanonicalPenalty {
3712 let block_dim = local.nrows();
3713 let root = Array2::<f64>::zeros((0, block_dim));
3715 CanonicalPenalty {
3716 root,
3717 col_range,
3718 total_dim,
3719 nullity: 0,
3720 local,
3721 prior_mean: Array1::zeros(block_dim),
3722 positive_eigenvalues: Vec::new(),
3723 op: None,
3724 }
3725 }
3726
3727 #[test]
3728 fn report_penalty_pair_redundancy_detects_identical_pair() {
3729 let s0 = ndarray::array![[2.0, 0.5, 0.0], [0.5, 1.0, 0.25], [0.0, 0.25, 1.5],];
3731 let s_shared = ndarray::array![[1.0, -0.5, 0.0], [-0.5, 2.0, -0.5], [0.0, -0.5, 1.0],];
3734
3735 let bundle = vec![
3736 canonical_from_local(s0, 0..3, 3),
3737 canonical_from_local(s_shared.clone(), 0..3, 3),
3738 canonical_from_local(s_shared, 0..3, 3),
3739 ];
3740
3741 let redundant = report_penalty_pair_redundancy(&bundle);
3742
3743 assert_eq!(
3746 redundant.len(),
3747 1,
3748 "expected exactly one redundant pair, got {:?}",
3749 redundant
3750 );
3751 let (i, j, cos) = redundant[0];
3752 assert_eq!((i, j), (1, 2));
3753 assert!(
3754 cos > 1.0 - 1e-12,
3755 "cosine for identical penalties should be ~1.0, got {cos}"
3756 );
3757 }
3758
3759 #[test]
3760 fn report_penalty_pair_redundancy_skips_different_col_ranges() {
3761 let s = ndarray::array![[1.0, 0.0], [0.0, 1.0]];
3765 let bundle = vec![
3766 canonical_from_local(s.clone(), 0..2, 4),
3767 canonical_from_local(s, 2..4, 4),
3768 ];
3769 let redundant = report_penalty_pair_redundancy(&bundle);
3770 assert!(
3771 redundant.is_empty(),
3772 "different col_ranges must not be flagged"
3773 );
3774 }
3775}