1use crate::basis::analyze_penalty_block;
2use crate::EstimationError;
3use crate::smooth::PenaltyStructureHint;
4use faer::linalg::matmul::matmul;
5use faer::{Accum, Mat, MatRef, Par, Side};
6use gam_linalg::faer_ndarray::{FaerEigh, FaerLinalgError, FaerSvd};
7use gam_linalg::matrix::symmetrize_in_place;
8use gam_linalg::utils::KahanSum;
9use ndarray::{Array1, Array2, ArrayView1, ArrayViewMut2, Axis, s};
10use rayon::iter::{
11 IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
12};
13use std::collections::{BTreeMap, HashSet};
14use std::ops::Range;
15use std::sync::Arc;
16
17#[derive(Clone)]
18pub enum PenaltyRepresentation {
19 Dense(Array2<f64>),
20 Banded {
21 bands: Vec<Array1<f64>>,
22 offsets: Vec<i32>,
23 },
24 Kronecker {
25 left: Array2<f64>,
31 right: Array2<f64>,
32 },
33}
34
35impl PenaltyRepresentation {
36 pub fn block_dimension(&self) -> usize {
38 match self {
39 PenaltyRepresentation::Dense(matrix) => matrix.nrows(),
40 PenaltyRepresentation::Banded { bands, offsets } => {
41 let mut dim = 0usize;
42 for (band, &offset) in bands.iter().zip(offsets.iter()) {
43 let len = band.len();
44 let extent = if offset >= 0 {
45 len + offset as usize
46 } else {
47 len + (-offset) as usize
48 };
49 dim = dim.max(extent);
50 }
51 dim
52 }
53 PenaltyRepresentation::Kronecker { left, right } => left.nrows() * right.nrows(),
54 }
55 }
56
57 pub fn to_block_dense(&self) -> Array2<f64> {
60 match self {
61 PenaltyRepresentation::Dense(matrix) => matrix.clone(),
62 PenaltyRepresentation::Banded { bands, offsets } => {
63 let dim = self.block_dimension();
64 let mut dense = Array2::zeros((dim, dim));
65 let positive_offsets: HashSet<usize> = offsets
66 .iter()
67 .filter_map(|&off| (off >= 0).then_some(off as usize))
68 .collect();
69 for (band, &offset) in bands.iter().zip(offsets.iter()) {
70 let off = offset.unsigned_abs() as usize;
71 if offset < 0 && positive_offsets.contains(&off) {
72 continue;
73 }
74 for (idx, &value) in band.iter().enumerate() {
75 let (i, j) = if offset >= 0 {
76 (idx, idx + off)
77 } else {
78 (idx + off, idx)
79 };
80 if i >= dim || j >= dim {
81 continue;
82 }
83 dense[[i, j]] = value;
84 dense[[j, i]] = value;
85 }
86 }
87 dense
88 }
89 PenaltyRepresentation::Kronecker { left, right } => {
90 let (lrows, l_cols) = left.dim();
91 let (rrows, r_cols) = right.dim();
92 let mut result = Array2::zeros((lrows * rrows, l_cols * r_cols));
93 for i in 0..lrows {
94 for j in 0..l_cols {
95 let scale = left[(i, j)];
96 if scale == 0.0 {
97 continue;
98 }
99 let mut block = result.slice_mut(s![
100 i * rrows..(i + 1) * rrows,
101 j * r_cols..(j + 1) * r_cols
102 ]);
103 block.assign(&(right * scale));
104 }
105 }
106 result
107 }
108 }
109 }
110}
111
112#[derive(Clone)]
113pub struct PenaltyMatrix {
114 pub col_range: Range<usize>,
115 pub representation: PenaltyRepresentation,
116}
117
118impl PenaltyMatrix {
119 fn accumulate_into(&self, mut dest: ArrayViewMut2<'_, f64>, weight: f64) {
120 if weight == 0.0 {
121 return;
122 }
123 match &self.representation {
124 PenaltyRepresentation::Dense(block) => {
125 dest.scaled_add(weight, block);
126 }
127 PenaltyRepresentation::Banded { bands, offsets } => {
128 let positive_offsets: HashSet<usize> = offsets
129 .iter()
130 .filter_map(|&off| (off >= 0).then_some(off as usize))
131 .collect();
132 for (band, &offset) in bands.iter().zip(offsets.iter()) {
133 let off = offset.unsigned_abs() as usize;
134 if offset < 0 && positive_offsets.contains(&off) {
135 continue;
136 }
137 for (idx, &value) in band.iter().enumerate() {
138 let (i, j) = if offset >= 0 {
139 (idx, idx + off)
140 } else {
141 (idx + off, idx)
142 };
143 let Some(entry_ij) = dest.get_mut((i, j)) else {
144 continue;
145 };
146 *entry_ij += weight * value;
147 if i != j
148 && let Some(entry_ji) = dest.get_mut((j, i))
149 {
150 *entry_ji += weight * value;
151 }
152 }
153 }
154 }
155 PenaltyRepresentation::Kronecker { left, right } => {
156 let (lrows, l_cols) = left.dim();
157 let (rrows, r_cols) = right.dim();
158 for i in 0..lrows {
159 for j in 0..l_cols {
160 let scale = left[(i, j)] * weight;
161 if scale == 0.0 {
162 continue;
163 }
164 let mut block = dest.slice_mut(s![
165 i * rrows..(i + 1) * rrows,
166 j * r_cols..(j + 1) * r_cols
167 ]);
168 block.scaled_add(scale, right);
169 }
170 }
171 }
172 }
173 }
174
175 pub fn to_dense(&self, total_dim: usize) -> Array2<f64> {
176 let mut dense = Array2::<f64>::zeros((total_dim, total_dim));
177 self.accumulate_into(
178 dense.slice_mut(s![self.col_range.clone(), self.col_range.clone()]),
179 1.0,
180 );
181 dense
182 }
183}
184
185pub(crate) fn array_to_faer(array: &Array2<f64>) -> Mat<f64> {
186 let (rows, cols) = array.dim();
187 Mat::from_fn(rows, cols, |i, j| array[[i, j]])
188}
189
190pub(crate) fn mat_to_array(mat: &Mat<f64>) -> Array2<f64> {
191 let mut out = Array2::<f64>::zeros((mat.nrows(), mat.ncols()));
192 for i in 0..mat.nrows() {
193 for j in 0..mat.ncols() {
194 out[[i, j]] = mat[(i, j)];
195 }
196 }
197 out
198}
199
200fn mat_max_abs_element(matrix: MatRef<'_, f64>) -> f64 {
201 let (rows, cols) = matrix.shape();
202 let mut maxval = 0.0_f64;
203 for i in 0..rows {
204 for j in 0..cols {
205 let val = matrix[(i, j)];
206 if val.is_finite() {
207 maxval = maxval.max(val.abs());
208 }
209 }
210 }
211 maxval
212}
213
214fn sanitize_symmetric_faer(matrix: &Mat<f64>) -> Mat<f64> {
215 let (rows, cols) = matrix.as_ref().shape();
216 assert_eq!(rows, cols, "Matrix must be square for sanitization");
217
218 let mut sanitized = matrix.clone();
219
220 for i in 0..rows {
221 let diag = sanitized[(i, i)];
222 if !diag.is_finite() {
223 sanitized[(i, i)] = 0.0;
224 }
225 for j in (i + 1)..cols {
226 let mut upper = sanitized[(i, j)];
227 let mut lower = sanitized[(j, i)];
228 if !upper.is_finite() {
229 upper = 0.0;
230 }
231 if !lower.is_finite() {
232 lower = 0.0;
233 }
234 let avg = 0.5 * (upper + lower);
235 sanitized[(i, j)] = avg;
236 sanitized[(j, i)] = avg;
237 }
238 }
239
240 let scale = mat_max_abs_element(sanitized.as_ref());
241 let tiny = (scale * 1e-14).max(1e-30);
242 for i in 0..rows {
243 for j in 0..cols {
244 let val = sanitized[(i, j)];
245 if !val.is_finite() {
246 sanitized[(i, j)] = 0.0;
247 } else if val.abs() < tiny {
248 sanitized[(i, j)] = 0.0;
249 }
250 }
251 }
252
253 sanitized
254}
255
256fn penalty_from_root_faer(root: &Mat<f64>) -> Mat<f64> {
257 let cols = root.ncols();
258 let mut full = Mat::<f64>::zeros(cols, cols);
259 let root_ref = root.as_ref();
260 let root_t = root_ref.transpose();
261 matmul(
262 full.as_mut(),
263 Accum::Replace,
264 root_t,
265 root_ref,
266 1.0,
267 Par::Seq,
268 );
269 sanitize_symmetric_faer(&full)
270}
271
272fn symmetrize_faer_matrix_in_place(matrix: &mut Mat<f64>) {
273 let n = matrix.nrows().min(matrix.ncols());
274 for i in 0..n {
275 for j in 0..i {
276 let avg = 0.5 * (matrix[(i, j)] + matrix[(j, i)]);
277 matrix[(i, j)] = avg;
278 matrix[(j, i)] = avg;
279 }
280 }
281}
282
283fn orthogonal_similarity_transform_faer(
284 matrix: &Mat<f64>,
285 block_dim: usize,
286 orthogonal: &Mat<f64>,
287) -> Mat<f64> {
288 let matrix_block = matrix.as_ref().submatrix(0, 0, block_dim, block_dim);
289 let cols = orthogonal.ncols();
290 let mut temp = Mat::<f64>::zeros(block_dim, cols);
291 matmul(
292 temp.as_mut(),
293 Accum::Replace,
294 matrix_block,
295 orthogonal.as_ref(),
296 1.0,
297 Par::Seq,
298 );
299 let mut rotated = Mat::<f64>::zeros(cols, cols);
300 matmul(
301 rotated.as_mut(),
302 Accum::Replace,
303 orthogonal.transpose(),
304 temp.as_ref(),
305 1.0,
306 Par::Seq,
307 );
308 symmetrize_faer_matrix_in_place(&mut rotated);
309 rotated
310}
311
312fn trace_penalty_in_orthogonal_basis(
313 matrix: &Mat<f64>,
314 block_dim: usize,
315 orthogonal: &Mat<f64>,
316 rotated_eigenvalues: &[f64],
317 delta: f64,
318) -> f64 {
319 let matrix_block = matrix.as_ref().submatrix(0, 0, block_dim, block_dim);
320 let cols = orthogonal.ncols();
321 assert!(rotated_eigenvalues.len() >= cols);
322 let mut projected = Mat::<f64>::zeros(block_dim, cols);
323 matmul(
324 projected.as_mut(),
325 Accum::Replace,
326 matrix_block,
327 orthogonal.as_ref(),
328 1.0,
329 Par::Seq,
330 );
331 let mut trace = KahanSum::default();
332 for l in 0..cols {
333 let mut diag_ll = KahanSum::default();
334 for i in 0..block_dim {
335 diag_ll.add(orthogonal[(i, l)] * projected[(i, l)]);
336 }
337 trace.add(diag_ll.sum() / (rotated_eigenvalues[l] + delta));
338 }
339 trace.sum()
340}
341
342pub fn trace_reduced_penalty_covariance(
343 reduced_penalty: &Array2<f64>,
344 covariance_basis: &Array2<f64>,
345) -> f64 {
346 assert_eq!(
347 reduced_penalty.dim(),
348 covariance_basis.dim(),
349 "trace_reduced_penalty_covariance dimension mismatch"
350 );
351 let r = covariance_basis.nrows();
352 let mut trace = KahanSum::default();
353 for i in 0..r {
354 for j in 0..r {
355 trace.add(covariance_basis[[i, j]] * reduced_penalty[[j, i]]);
356 }
357 }
358 trace.sum()
359}
360
361pub fn trace_penalty_covariance_in_orthogonal_basis(
362 matrix: &Array2<f64>,
363 orthogonal: &Array2<f64>,
364 covariance_basis: &Array2<f64>,
365) -> f64 {
366 let reduced = gam_linalg::faer_ndarray::fast_ab(
367 &gam_linalg::faer_ndarray::fast_atb(orthogonal, matrix),
368 orthogonal,
369 );
370 trace_reduced_penalty_covariance(&reduced, covariance_basis)
371}
372
373fn classify_eigenvalues_strict(
393 eigenvalues: &mut [f64],
394 context: &str,
395) -> Result<(), EstimationError> {
396 const C_EPS_P_FACTOR: f64 = 64.0;
397 const REL_PSD_FLOOR: f64 = 1.0e-8;
401 let p = eigenvalues.len();
402
403 let mut scale = 0.0_f64;
404 for (idx, &val) in eigenvalues.iter().enumerate() {
405 if !val.is_finite() {
406 return Err(EstimationError::PenaltySpectrumNonFinite {
407 context: context.to_string(),
408 index: idx,
409 value: val,
410 });
411 }
412 scale = scale.max(val.abs());
413 }
414
415 let machine_floor = C_EPS_P_FACTOR * f64::EPSILON * (p.max(1) as f64) * scale;
422 let tolerance = machine_floor
423 .max(REL_PSD_FLOOR * scale)
424 .max(f64::MIN_POSITIVE);
425
426 for (idx, val) in eigenvalues.iter_mut().enumerate() {
427 if val.abs() <= tolerance {
428 *val = 0.0;
429 } else if *val < 0.0 {
430 return Err(EstimationError::PenaltySpectrumIndefinite {
431 context: context.to_string(),
432 index: idx,
433 value: *val,
434 tolerance,
435 scale,
436 });
437 }
438 }
439 Ok(())
440}
441
442fn robust_eighwith_policy<M, V, E, Validate, Sanitize, EigCall, MapErr>(
443 matrix: &M,
444 context: &str,
445 validate_input: Validate,
446 sanitize: Sanitize,
447 mut eig_call: EigCall,
448 map_error: MapErr,
449) -> Result<(Vec<f64>, V), EstimationError>
450where
451 Validate: Fn(&M, &str) -> Result<(), EstimationError>,
452 Sanitize: Fn(&M) -> M,
453 EigCall: FnMut(&M) -> Result<(Vec<f64>, V), E>,
454 MapErr: Fn(E, &str) -> EstimationError,
455{
456 validate_input(matrix, context)?;
457
458 let candidate = sanitize(matrix);
464 match eig_call(&candidate) {
465 Ok((mut eigenvalues, eigenvectors)) => {
466 classify_eigenvalues_strict(&mut eigenvalues, context)?;
467 Ok((eigenvalues, eigenvectors))
468 }
469 Err(err) => Err(map_error(err, context)),
470 }
471}
472
473pub(crate) fn robust_eigh_faer(
474 matrix: &Mat<f64>,
475 side: Side,
476 context: &str,
477) -> Result<(Vec<f64>, Mat<f64>), EstimationError> {
478 robust_eighwith_policy(
479 matrix,
480 context,
481 |mat, ctx| {
482 let (rows, cols) = mat.as_ref().shape();
483 for i in 0..rows {
484 for j in 0..cols {
485 let val = mat[(i, j)];
486 if !val.is_finite() {
487 let max_abs = mat_max_abs_element(mat.as_ref());
488 crate::bail_invalid_estim!(
489 "{} contains non-finite entries (max finite magnitude {:.3e})",
490 ctx,
491 max_abs
492 );
493 }
494 }
495 }
496 Ok(())
497 },
498 sanitize_symmetric_faer,
499 |candidate| {
500 let eig = candidate.as_ref().self_adjoint_eigen(side)?;
501 let diag = eig.S();
502 let mut eigenvalues = Vec::with_capacity(diag.dim());
503 for idx in 0..diag.dim() {
504 eigenvalues.push(diag[idx]);
505 }
506
507 let vectors_ref = eig.U();
508 let mut eigenvectors = Mat::<f64>::zeros(vectors_ref.nrows(), vectors_ref.ncols());
509 for i in 0..vectors_ref.nrows() {
510 for j in 0..vectors_ref.ncols() {
511 eigenvectors[(i, j)] = vectors_ref[(i, j)];
512 }
513 }
514 Ok((eigenvalues, eigenvectors))
515 },
516 |err, _ctx| {
517 EstimationError::EigendecompositionFailed(FaerLinalgError::SelfAdjointEigen(err))
518 },
519 )
520}
521
522fn robust_eigh(
523 matrix: &Array2<f64>,
524 side: Side,
525 context: &str,
526) -> Result<(Array1<f64>, Array2<f64>), EstimationError> {
527 let matrix_faer = array_to_faer(matrix);
528 let (eigenvalues, eigenvectors) = robust_eigh_faer(&matrix_faer, side, context)?;
529 Ok((Array1::from_vec(eigenvalues), mat_to_array(&eigenvectors)))
530}
531
532pub(crate) fn kronecker_marginal_eigensystems(
533 marginal_penalties: &[Array2<f64>],
534 context: &str,
535) -> Result<Vec<(Array1<f64>, Array2<f64>)>, EstimationError> {
536 let mut eigensystems = Vec::with_capacity(marginal_penalties.len());
537 for (k, penalty) in marginal_penalties.iter().enumerate() {
538 eigensystems.push(robust_eigh(
539 penalty,
540 Side::Lower,
541 &format!("{context} marginal {k}"),
542 )?);
543 }
544 Ok(eigensystems)
545}
546
/// Diagnostics produced by `assess_subspace_leakage`.
#[derive(Debug, Clone, Copy)]
struct SubspaceLeakageMetrics {
    /// Largest, over all inspected matrices, sum of squared entries lying in
    /// columns at or beyond the structural rank.
    max_abs_sq: f64,
    /// Largest ratio of that sum to the matrix's total sum of squares.
    max_rel_sq: f64,
    /// Index of the matrix attaining `max_rel_sq` (0 if no ratio exceeded zero).
    worst_penalty: usize,
    /// Largest absolute dot product between a leading column of `qs`
    /// (index below the structural rank) and a trailing column.
    max_cross_gram_abs: f64,
}
554
555fn assess_subspace_leakage(
556 qs: &Mat<f64>,
557 rs_transformed: &[Mat<f64>],
558 structural_rank: usize,
559 p: usize,
560) -> SubspaceLeakageMetrics {
561 let mut max_abs_sq = 0.0_f64;
562 let mut max_rel_sq = 0.0_f64;
563 let mut worst_penalty = 0usize;
564
565 for (k, rs) in rs_transformed.iter().enumerate() {
566 let rows = rs.nrows();
567 let cols = rs.ncols().min(p);
568 let null_start = structural_rank.min(cols);
569 let mut abs_sq = 0.0_f64;
570 let mut total_sq = 0.0_f64;
571 for i in 0..rows {
572 for j in 0..cols {
573 let v = rs[(i, j)];
574 let vv = v * v;
575 total_sq += vv;
576 if j >= null_start {
577 abs_sq += vv;
578 }
579 }
580 }
581 let rel_sq = if total_sq > 0.0 {
582 abs_sq / total_sq
583 } else {
584 0.0
585 };
586 if rel_sq > max_rel_sq {
587 max_rel_sq = rel_sq;
588 worst_penalty = k;
589 }
590 max_abs_sq = max_abs_sq.max(abs_sq);
591 }
592
593 let mut max_cross_gram_abs = 0.0_f64;
594 let null_count = p.saturating_sub(structural_rank);
595 if structural_rank > 0 && null_count > 0 {
596 for i in 0..structural_rank {
597 for j in 0..null_count {
598 let qn_col = structural_rank + j;
599 let mut dot = 0.0_f64;
600 for r in 0..p {
601 dot += qs[(r, i)] * qs[(r, qn_col)];
602 }
603 max_cross_gram_abs = max_cross_gram_abs.max(dot.abs());
604 }
605 }
606 }
607
608 SubspaceLeakageMetrics {
609 max_abs_sq,
610 max_rel_sq,
611 worst_penalty,
612 max_cross_gram_abs,
613 }
614}
615
616fn compose_qs_from_split(q_pen: &Mat<f64>, q_null: &Mat<f64>, p: usize) -> Mat<f64> {
617 let rank = q_pen.ncols();
618 let null_count = q_null.ncols();
619 let mut qs = Mat::<f64>::zeros(p, p);
620 for i in 0..p {
621 for j in 0..rank {
622 qs[(i, j)] = q_pen[(i, j)];
623 }
624 for j in 0..null_count {
625 qs[(i, rank + j)] = q_null[(i, j)];
626 }
627 }
628 qs
629}
630
631pub fn kronecker_product(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
635 let (arows, a_cols) = a.dim();
636 let (brows, b_cols) = b.dim();
637 if arows == 0 || a_cols == 0 || brows == 0 || b_cols == 0 {
638 return Array2::zeros((arows * brows, a_cols * b_cols));
639 }
640 let mut result = Array2::zeros((arows * brows, a_cols * b_cols));
641
642 result
643 .axis_chunks_iter_mut(Axis(0), brows)
644 .into_par_iter()
645 .enumerate()
646 .for_each(|(i, mut row_block)| {
647 let arow = a.row(i);
648 let col_chunks = row_block.axis_chunks_iter_mut(Axis(1), b_cols);
649 for (j, mut block) in col_chunks.into_iter().enumerate() {
650 let aval = arow[j];
651 if aval == 0.0 {
652 continue;
653 }
654 for (dest, &src) in block.iter_mut().zip(b.iter()) {
655 *dest = aval * src;
656 }
657 }
658 });
659
660 result
661}
662
/// Bundle of outputs from a penalty reparameterization.
///
/// NOTE(review): the code producing this struct is not visible in this part
/// of the file; the per-field notes below are inferred from names and types
/// only and should be confirmed against the producer.
#[derive(Clone)]
pub struct ReparamResult {
    /// Presumably the combined penalty matrix in the transformed basis — TODO confirm.
    pub s_transformed: Array2<f64>,
    /// Presumably a log-determinant of the penalty — TODO confirm.
    pub log_det: f64,
    /// Presumably first derivatives of `log_det`, one per penalty — TODO confirm.
    pub det1: Array1<f64>,
    /// Presumably the basis-change matrix — TODO confirm.
    pub qs: Array2<f64>,
    /// Canonical penalties, presumably expressed in the transformed basis — TODO confirm.
    pub canonical_transformed: Vec<CanonicalPenalty>,
    /// Presumably a penalty root in the transformed basis — TODO confirm.
    pub e_transformed: Array2<f64>,
    /// Presumably a truncated eigenvector basis — TODO confirm.
    pub u_truncated: Array2<f64>,
    /// Presumably the size of a ridge added to the penalty — TODO confirm.
    pub penalty_shrinkage_ridge: f64,
}
698
699struct KroneckerFactorDecomp {
705 root: Array2<f64>, positive_eigenvalues: Vec<f64>, rank: usize,
708 dim: usize,
709}
710
711fn decompose_kronecker_factors(
714 factors: &[Array2<f64>],
715 context: &str,
716) -> Result<Option<Vec<KroneckerFactorDecomp>>, EstimationError> {
717 let mut decomps = Vec::with_capacity(factors.len());
718 for (j, factor) in factors.iter().enumerate() {
719 let q_j = factor.nrows();
720 if q_j != factor.ncols() {
721 crate::bail_invalid_estim!(
722 "{context}: Kronecker factor {j} must be square, got {}x{}",
723 factor.nrows(),
724 factor.ncols()
725 );
726 }
727 let is_identity = {
728 let mut is_id = true;
729 'outer: for r in 0..q_j {
730 for c in 0..q_j {
731 let expected = if r == c { 1.0 } else { 0.0 };
732 if (factor[[r, c]] - expected).abs() > 1e-12 {
733 is_id = false;
734 break 'outer;
735 }
736 }
737 }
738 is_id
739 };
740 if is_identity {
741 decomps.push(KroneckerFactorDecomp {
742 root: Array2::eye(q_j),
743 positive_eigenvalues: vec![1.0; q_j],
744 rank: q_j,
745 dim: q_j,
746 });
747 continue;
748 }
749 let analysis = analyze_penalty_block(factor).map_err(|err| {
750 EstimationError::InvalidInput(format!(
751 "{context}: Kronecker factor {j} eigendecomp failed: {err}"
752 ))
753 })?;
754 if analysis.rank == 0 {
755 return Ok(None);
756 }
757 let factor_classes =
761 crate::basis::SpectralClassification::new(&analysis.eigenvalues, analysis.tol);
762 let mut root_j = Array2::zeros((analysis.rank, q_j));
763 let mut pos_eigs = Vec::with_capacity(analysis.rank);
764 for (row_idx, &i) in factor_classes.range_idx.iter().enumerate() {
765 let eigenval = analysis.eigenvalues[i];
766 let sqrt_ev = eigenval.sqrt();
767 let evec = analysis.eigenvectors.column(i);
768 for (col, &v) in evec.iter().enumerate() {
769 root_j[[row_idx, col]] = sqrt_ev * v;
770 }
771 pos_eigs.push(eigenval);
772 }
773 decomps.push(KroneckerFactorDecomp {
774 root: root_j,
775 positive_eigenvalues: pos_eigs,
776 rank: analysis.rank,
777 dim: q_j,
778 });
779 }
780 Ok(Some(decomps))
781}
782
783fn assemble_kronecker_root_local(decomps: &[KroneckerFactorDecomp]) -> Array2<f64> {
785 let mut kron_root = decomps[0].root.clone();
786 for fr in &decomps[1..] {
787 let (r1, c1) = kron_root.dim();
788 let (r2, c2) = (fr.rank, fr.dim);
789 let mut new_root = Array2::zeros((r1 * r2, c1 * c2));
790 for i1 in 0..r1 {
791 for i2 in 0..r2 {
792 for j1 in 0..c1 {
793 for j2 in 0..c2 {
794 new_root[[i1 * r2 + i2, j1 * c2 + j2]] =
795 kron_root[[i1, j1]] * fr.root[[i2, j2]];
796 }
797 }
798 }
799 }
800 kron_root = new_root;
801 }
802 kron_root
803}
804
805fn kronecker_eigenvalues(decomps: &[KroneckerFactorDecomp], block_dim: usize) -> (Vec<f64>, usize) {
807 let mut kron_eigs = decomps[0].positive_eigenvalues.clone();
808 for fd in &decomps[1..] {
809 let mut new_eigs = Vec::with_capacity(kron_eigs.len() * fd.positive_eigenvalues.len());
810 for &a in &kron_eigs {
811 for &b in &fd.positive_eigenvalues {
812 new_eigs.push(a * b);
813 }
814 }
815 kron_eigs = new_eigs;
816 }
817 let max_ev = kron_eigs.iter().copied().fold(0.0_f64, f64::max);
818 let tol = max_ev * 1e-10 * (block_dim as f64);
819 let positive: Vec<f64> = kron_eigs.into_iter().filter(|&ev| ev > tol).collect();
820 let nullity = block_dim - positive.len();
821 (positive, nullity)
822}
823
824#[derive(Clone)]
834pub struct CanonicalPenalty {
835 pub root: Array2<f64>,
838 pub col_range: std::ops::Range<usize>,
841 pub total_dim: usize,
843 pub nullity: usize,
845 pub local: Array2<f64>,
849 pub prior_mean: Array1<f64>,
851 pub positive_eigenvalues: Vec<f64>,
854 pub op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
858}
859
860impl std::fmt::Debug for CanonicalPenalty {
861 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
862 f.debug_struct("CanonicalPenalty")
863 .field(
864 "root",
865 &format_args!("{}×{}", self.root.nrows(), self.root.ncols()),
866 )
867 .field("col_range", &self.col_range)
868 .field("total_dim", &self.total_dim)
869 .field("nullity", &self.nullity)
870 .field(
871 "local",
872 &format_args!("{}×{}", self.local.nrows(), self.local.ncols()),
873 )
874 .field("prior_mean_len", &self.prior_mean.len())
875 .field("positive_eigenvalues", &self.positive_eigenvalues)
876 .field("op", &self.op.as_ref().map(|o| o.dim()))
877 .finish()
878 }
879}
880
881impl CanonicalPenalty {
882 pub fn from_dense_root(root: Array2<f64>, p: usize) -> Self {
886 Self::from_dense_root_with_mean(root, p, Array1::zeros(p))
887 }
888
889 pub fn from_dense_root_with_mean(root: Array2<f64>, p: usize, prior_mean: Array1<f64>) -> Self {
890 assert_eq!(prior_mean.len(), p);
891 let local = root.t().dot(&root);
892 let positive_eigenvalues = Vec::new(); Self {
894 root,
895 col_range: 0..p,
896 total_dim: p,
897 nullity: 0,
898 local,
899 prior_mean,
900 positive_eigenvalues,
901 op: None,
902 }
903 }
904
905 pub fn full_width_root(&self) -> Array2<f64> {
908 if self.col_range.start == 0 && self.col_range.end == self.total_dim {
909 return self.root.clone();
910 }
911 let rank = self.root.nrows();
912 let mut full = Array2::<f64>::zeros((rank, self.total_dim));
913 full.slice_mut(ndarray::s![.., self.col_range.clone()])
914 .assign(&self.root);
915 full
916 }
917
918 pub fn rank(&self) -> usize {
920 self.root.nrows()
921 }
922
923 pub fn block_dim(&self) -> usize {
925 self.col_range.len()
926 }
927
928 pub const fn is_block_local(&self) -> bool {
930 self.col_range.start != 0 || self.col_range.end != self.total_dim
931 }
932
933 pub fn local_ref(&self) -> &Array2<f64> {
936 &self.local
937 }
938
939 pub fn local_penalty(&self) -> Array2<f64> {
942 self.local.clone()
943 }
944
945 pub fn accumulate_weighted(&self, target: &mut Array2<f64>, lambda: f64) {
948 if lambda == 0.0 || self.rank() == 0 {
949 return;
950 }
951 let r = &self.col_range;
952 target
953 .slice_mut(s![r.start..r.end, r.start..r.end])
954 .scaled_add(lambda, &self.local);
955 }
956
957 pub fn trace_product(&self, m: &Array2<f64>, scale: f64) -> f64 {
960 if self.rank() == 0 || scale == 0.0 {
961 return 0.0;
962 }
963 let r = &self.col_range;
964 let m_block = m.slice(s![r.start..r.end, r.start..r.end]);
965 let rm = self.root.dot(&m_block);
966 scale
967 * rm.iter()
968 .zip(self.root.iter())
969 .map(|(&a, &b)| a * b)
970 .sum::<f64>()
971 }
972
973 pub fn quadratic(&self, v: &Array1<f64>, scale: f64) -> f64 {
976 if self.rank() == 0 || scale == 0.0 {
977 return 0.0;
978 }
979 let v_block = v.slice(s![self.col_range.start..self.col_range.end]);
980 let rv = self.root.dot(&v_block);
981 scale * rv.dot(&rv)
982 }
983
984 pub fn prior_linear_shift(&self, scale: f64) -> Array1<f64> {
986 let mut out = Array1::<f64>::zeros(self.total_dim);
987 if self.rank() == 0 || scale == 0.0 || self.prior_mean.iter().all(|&v| v == 0.0) {
988 return out;
989 }
990 let block = self.local.dot(&self.prior_mean) * scale;
991 out.slice_mut(s![self.col_range.start..self.col_range.end])
992 .assign(&block);
993 out
994 }
995
996 pub fn prior_constant_shift(&self, scale: f64) -> f64 {
998 if self.rank() == 0 || scale == 0.0 || self.prior_mean.iter().all(|&v| v == 0.0) {
999 return 0.0;
1000 }
1001 scale * self.prior_mean.dot(&self.local.dot(&self.prior_mean))
1002 }
1003
1004 pub fn full_width_prior_mean(&self) -> Array1<f64> {
1006 if self.col_range.start == 0 && self.col_range.end == self.total_dim {
1007 return self.prior_mean.clone();
1008 }
1009 let mut out = Array1::<f64>::zeros(self.total_dim);
1010 out.slice_mut(s![self.col_range.start..self.col_range.end])
1011 .assign(&self.prior_mean);
1012 out
1013 }
1014
1015 pub fn to_penalty_coordinate(
1017 &self,
1018 ) -> gam_problem::PenaltyCoordinate {
1019 use gam_problem::PenaltyCoordinate;
1020 if self.is_block_local() {
1021 PenaltyCoordinate::from_block_root_with_mean(
1022 self.root.clone(),
1023 self.col_range.start,
1024 self.col_range.end,
1025 self.total_dim,
1026 self.prior_mean.clone(),
1027 )
1028 } else {
1029 PenaltyCoordinate::from_dense_root_with_mean(self.root.clone(), self.prior_mean.clone())
1030 }
1031 }
1032}
1033
1034pub fn report_penalty_pair_redundancy(canonical: &[CanonicalPenalty]) -> Vec<(usize, usize, f64)> {
1061 const REDUNDANCY_THRESHOLD: f64 = 1.0 - 1e-8;
1062 const SIMILARITY_THRESHOLD: f64 = 0.99;
1063 const LARGE_SCALE_K_THRESHOLD: usize = 64;
1064 const TOP_SIMILARITY_PAIRS: usize = 3;
1065
1066 let k = canonical.len();
1067 let mut redundant: Vec<(usize, usize, f64)> = Vec::new();
1068 let mut similar: Vec<(usize, usize, f64)> = Vec::new();
1069
1070 let trace_sq: Vec<f64> = canonical
1073 .iter()
1074 .map(|p| p.local.iter().map(|&v| v * v).sum::<f64>())
1075 .collect();
1076
1077 for i in 0..k {
1078 if trace_sq[i] == 0.0 {
1079 continue;
1080 }
1081 for j in (i + 1)..k {
1082 if trace_sq[j] == 0.0 {
1083 continue;
1084 }
1085 if canonical[i].col_range != canonical[j].col_range {
1089 continue;
1090 }
1091 assert_eq!(canonical[i].local.dim(), canonical[j].local.dim());
1094
1095 let inner: f64 = canonical[i]
1096 .local
1097 .iter()
1098 .zip(canonical[j].local.iter())
1099 .map(|(&a, &b)| a * b)
1100 .sum();
1101 let denom = (trace_sq[i] * trace_sq[j]).sqrt();
1102 if denom == 0.0 {
1103 continue;
1104 }
1105 let cos = inner / denom;
1106
1107 if cos > REDUNDANCY_THRESHOLD {
1108 redundant.push((i, j, cos));
1109 } else if cos > SIMILARITY_THRESHOLD {
1110 similar.push((i, j, cos));
1111 }
1112 }
1113 }
1114
1115 for &(i, j, cos) in &redundant {
1117 log::warn!(
1118 "[PENALTY-REDUNDANCY] penalties i={i} j={j} are structurally identical \
1119 (cos={cos:.6}) — model is over-parameterized along their antisymmetric \
1120 direction; expect a Z₂-symmetric saddle in the LAML cost. Consider \
1121 re-specifying (e.g. anisotropic→isotropic for spatial smoothers with \
1122 weak axis signal)."
1123 );
1124 }
1125
1126 if k > LARGE_SCALE_K_THRESHOLD && similar.len() > TOP_SIMILARITY_PAIRS {
1128 similar.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
1129 similar.truncate(TOP_SIMILARITY_PAIRS);
1130 }
1131 for (i, j, cos) in similar {
1132 log::info!(
1133 "[PENALTY-SIMILARITY] penalties i={i} j={j} are near-identical \
1134 (cos={cos:.6}) — outer Hessian may be ill-conditioned along their \
1135 antisymmetric direction."
1136 );
1137 }
1138
1139 redundant
1140}
1141
1142pub fn canonicalize_penalty_spec(
1148 spec: &crate::PenaltySpec,
1149 p: usize,
1150 idx: usize,
1151 context: &str,
1152) -> Result<Option<CanonicalPenalty>, EstimationError> {
1153 use crate::PenaltySpec;
1154
1155 crate::validate_penalty_spec_shape(idx, spec, p, context)?;
1156
1157 let (local_matrix, col_range, prior_mean_spec, hint, op) = match spec {
1158 PenaltySpec::Block {
1159 local,
1160 col_range,
1161 prior_mean,
1162 structure_hint,
1163 op,
1164 } => (
1165 local.view(),
1166 col_range.clone(),
1167 prior_mean,
1168 structure_hint.as_ref(),
1169 op.clone(),
1170 ),
1171 PenaltySpec::Dense(m) => (
1172 m.view(),
1173 0..p,
1174 &gam_problem::CoefficientPriorMean::Zero,
1175 None,
1176 None,
1177 ),
1178 PenaltySpec::DenseWithMean { matrix, prior_mean } => {
1179 (matrix.view(), 0..p, prior_mean, None, None)
1180 }
1181 };
1182
1183 let block_dim = col_range.len();
1184 let prior_mean = prior_mean_spec
1185 .evaluate(block_dim, &format!("{context}: penalty {idx}"))
1186 .map_err(|e| EstimationError::InvalidInput(e.0))?;
1187
1188 if let Some(PenaltyStructureHint::Ridge(scale)) = hint {
1190 if *scale <= 0.0 {
1191 return Ok(None);
1192 }
1193 let sqrt_scale = scale.sqrt();
1194 let mut root = Array2::zeros((block_dim, block_dim));
1195 for i in 0..block_dim {
1196 root[[i, i]] = sqrt_scale;
1197 }
1198 let mut local_sym = local_matrix.to_owned();
1202 symmetrize_in_place(&mut local_sym);
1203 return Ok(Some(CanonicalPenalty {
1204 root,
1205 col_range,
1206 total_dim: p,
1207 nullity: 0,
1208 local: local_sym,
1209 prior_mean,
1210 positive_eigenvalues: vec![*scale; block_dim],
1211 op,
1212 }));
1213 }
1214
1215 if let Some(PenaltyStructureHint::Kronecker(factors)) = hint {
1217 let decomps =
1218 match decompose_kronecker_factors(factors, &format!("{context} penalty {idx}"))? {
1219 None => return Ok(None),
1220 Some(d) => d,
1221 };
1222 let (positive_eigenvalues, nullity) = kronecker_eigenvalues(&decomps, block_dim);
1223 if positive_eigenvalues.is_empty() {
1224 return Ok(None);
1225 }
1226 let root = assemble_kronecker_root_local(&decomps);
1227 let mut local_sym = local_matrix.to_owned();
1228 symmetrize_in_place(&mut local_sym);
1229 return Ok(Some(CanonicalPenalty {
1230 root,
1231 col_range,
1232 total_dim: p,
1233 nullity,
1234 local: local_sym,
1235 prior_mean,
1236 positive_eigenvalues,
1237 op,
1238 }));
1239 }
1240
1241 let local_owned = local_matrix.to_owned();
1243 let analysis = analyze_penalty_block(&local_owned).map_err(|err| {
1244 EstimationError::InvalidInput(format!(
1245 "{context}: penalty canonicalization failed at index {idx}: {err}"
1246 ))
1247 })?;
1248
1249 if analysis.rank == 0 {
1250 log::debug!(
1251 "Dropped inactive penalty block idx={idx} reason={}",
1252 if analysis.iszero {
1253 "ZeroMatrix"
1254 } else {
1255 "NumericalRankZero"
1256 }
1257 );
1258 return Ok(None);
1259 }
1260
1261 let tolerance = analysis.tol;
1267 let classes = crate::basis::SpectralClassification::new(&analysis.eigenvalues, tolerance);
1268 let rank_k = classes.rank();
1269 assert_eq!(
1270 rank_k, analysis.rank,
1271 "penalty-root rank disagreement: SpectralClassification rank={rank_k} vs analyze_penalty_block rank={} (#1425 canonical-classifier invariant)",
1272 analysis.rank
1273 );
1274
1275 let mut root = Array2::zeros((rank_k, block_dim));
1283 let mut positive_eigenvalues = Vec::with_capacity(rank_k);
1284 for (row_idx, &i) in classes.range_idx.iter().enumerate() {
1285 let eigenval = analysis.eigenvalues[i];
1286 let eigenvec = analysis.eigenvectors.column(i);
1287 root.row_mut(row_idx).assign(&(&eigenvec * eigenval.sqrt()));
1288 positive_eigenvalues.push(eigenval);
1289 }
1290
1291 if classes.is_indefinite() {
1297 log::debug!(
1298 "{context}: penalty block idx={idx} carries {} negative-curvature \
1299 eigendirection(s) below -tol={tolerance:e}; dropped from the canonical \
1300 root and NOT counted as null space (rank={rank_k}, nullity={})",
1301 classes.negative_dim(),
1302 classes.nullity()
1303 );
1304 }
1305
1306 let local = root.t().dot(&root);
1310 Ok(Some(CanonicalPenalty {
1311 root,
1312 col_range,
1313 total_dim: p,
1314 nullity: classes.nullity(),
1315 local,
1316 prior_mean,
1317 positive_eigenvalues,
1318 op,
1319 }))
1320}
1321
1322pub fn canonicalize_penalty_specs(
1325 specs: &[crate::PenaltySpec],
1326 nullspace_dims: &[usize],
1327 p: usize,
1328 context: &str,
1329) -> Result<(Vec<CanonicalPenalty>, Vec<usize>), EstimationError> {
1330 if specs.len() != nullspace_dims.len() {
1331 crate::bail_invalid_estim!(
1332 "{context}: nullspace_dims length mismatch: penalties={}, nullspace_dims={}",
1333 specs.len(),
1334 nullspace_dims.len()
1335 );
1336 }
1337
1338 let mut active = Vec::with_capacity(specs.len());
1339 let mut active_nullspace = Vec::with_capacity(specs.len());
1340 for (idx, spec) in specs.iter().enumerate() {
1341 if let Some(canonical) = canonicalize_penalty_spec(spec, p, idx, context)? {
1342 active_nullspace.push(nullspace_dims[idx]);
1343 active.push(canonical);
1344 }
1345 }
1346 Ok((active, active_nullspace))
1347}
1348
/// Largest coefficient count `p` for which overlapping penalty blocks may
/// fall back to a dense `p × p` eigendecomposition; above it,
/// `create_balanced_penalty_root_from_canonical` returns a layout error
/// instead.
pub(crate) const OVERLAPPING_PENALTY_DENSE_FALLBACK_MAX_P: usize = 4096;
1359
1360pub fn create_balanced_penalty_root_from_canonical(
1367 penalties: &[CanonicalPenalty],
1368 p: usize,
1369) -> Result<Array2<f64>, EstimationError> {
1370 if penalties.is_empty() {
1371 return Ok(Array2::zeros((0, p)));
1372 }
1373
1374 let mut block_groups: BTreeMap<(usize, usize), Vec<&CanonicalPenalty>> = BTreeMap::new();
1376 for cp in penalties {
1377 if cp.rank() == 0 {
1378 continue;
1379 }
1380 let key = (cp.col_range.start, cp.col_range.end);
1381 block_groups.entry(key).or_default().push(cp);
1382 }
1383
1384 if block_groups.is_empty() {
1385 return Ok(Array2::zeros((0, p)));
1386 }
1387
1388 let ranges: Vec<(usize, usize)> = block_groups.keys().copied().collect();
1390 let mut overlapping = false;
1391 for i in 1..ranges.len() {
1392 if ranges[i].0 < ranges[i - 1].1 {
1393 overlapping = true;
1394 break;
1395 }
1396 }
1397
1398 if overlapping {
1399 if p > OVERLAPPING_PENALTY_DENSE_FALLBACK_MAX_P {
1400 return Err(EstimationError::LayoutError(format!(
1401 "overlapping penalty root would require dense {}x{} eigendecomposition; \
1402 large-model dense fallback is disabled. Keep penalties structured or \
1403 extend the overlapping-penalty solver path",
1404 p, p
1405 )));
1406 }
1407 let mut s_balanced = Array2::zeros((p, p));
1409 for cp in penalties {
1410 if cp.rank() == 0 {
1411 continue;
1412 }
1413 let local = cp.local_ref();
1414 let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1415 if frob_norm > 1e-12 {
1416 let r = &cp.col_range;
1417 s_balanced
1418 .slice_mut(s![r.start..r.end, r.start..r.end])
1419 .scaled_add(1.0 / frob_norm, local);
1420 }
1421 }
1422 let (eigenvalues, eigenvectors) =
1423 robust_eigh(&s_balanced, Side::Lower, "balanced penalty matrix")?;
1424 let max_eig = eigenvalues.iter().fold(0.0f64, |max, &val| max.max(val));
1425 let tolerance = if max_eig > 0.0 {
1426 max_eig * 1e-12
1427 } else {
1428 1e-12
1429 };
1430 let penalty_rank = eigenvalues.iter().filter(|&&ev| ev > tolerance).count();
1431 if penalty_rank == 0 {
1432 return Ok(Array2::zeros((0, p)));
1433 }
1434 let mut eb = Array2::zeros((p, penalty_rank));
1435 let mut col_idx = 0;
1436 for (i, &eigenval) in eigenvalues.iter().enumerate() {
1437 if eigenval > tolerance {
1438 let sqrt_ev = eigenval.sqrt();
1439 let evec = eigenvectors.column(i);
1440 eb.column_mut(col_idx).assign(&(&evec * sqrt_ev));
1441 col_idx += 1;
1442 }
1443 }
1444 return Ok(eb.t().to_owned());
1445 }
1446
1447 struct BlockRoot {
1449 col_range: Range<usize>,
1450 root: Array2<f64>, }
1452 let ordered_blocks: Vec<((usize, usize), Vec<&CanonicalPenalty>)> =
1457 block_groups.into_iter().collect();
1458 let block_roots: Vec<BlockRoot> = ordered_blocks
1459 .into_par_iter()
1460 .map(
1461 |((start, end), cps)| -> Result<Option<BlockRoot>, EstimationError> {
1462 let block_dim = end - start;
1463 let mut s_balanced_local = Array2::zeros((block_dim, block_dim));
1464
1465 for cp in cps {
1466 let local = cp.local_ref();
1467 let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1468 if frob_norm > 1e-12 {
1469 s_balanced_local.scaled_add(1.0 / frob_norm, local);
1470 }
1471 }
1472
1473 let (eigenvalues, eigenvectors) =
1474 robust_eigh(&s_balanced_local, Side::Lower, "balanced penalty block")?;
1475 let max_eig = eigenvalues.iter().fold(0.0f64, |max, &val| max.max(val));
1476 let tolerance = if max_eig > 0.0 {
1477 max_eig * 1e-12
1478 } else {
1479 1e-12
1480 };
1481 let block_rank = eigenvalues.iter().filter(|&&ev| ev > tolerance).count();
1482
1483 if block_rank == 0 {
1484 return Ok(None);
1485 }
1486
1487 let mut root = Array2::zeros((block_rank, block_dim));
1488 let mut row_idx = 0;
1489 for (i, &eigenval) in eigenvalues.iter().enumerate() {
1490 if eigenval > tolerance {
1491 let sqrt_ev = eigenval.sqrt();
1492 let evec = eigenvectors.column(i);
1493 root.row_mut(row_idx).assign(&(&evec * sqrt_ev));
1494 row_idx += 1;
1495 }
1496 }
1497
1498 Ok(Some(BlockRoot {
1499 col_range: start..end,
1500 root,
1501 }))
1502 },
1503 )
1504 .collect::<Result<Vec<_>, _>>()?
1505 .into_iter()
1506 .flatten()
1507 .collect();
1508 let total_rank: usize = block_roots.iter().map(|br| br.root.nrows()).sum();
1509
1510 if total_rank == 0 {
1511 return Ok(Array2::zeros((0, p)));
1512 }
1513
1514 let mut eb = Array2::zeros((total_rank, p));
1516 let mut row_offset = 0;
1517 for br in &block_roots {
1518 let rank_b = br.root.nrows();
1519 eb.slice_mut(s![
1520 row_offset..(row_offset + rank_b),
1521 br.col_range.start..br.col_range.end
1522 ])
1523 .assign(&br.root);
1524 row_offset += rank_b;
1525 }
1526
1527 Ok(eb)
1528}
1529
/// A basis split into a "penalized" part and a "null" part, stored as two column blocks.
///
/// Both matrices have one row per coefficient. `compose_qs` concatenates them as
/// `[q_pen | q_null]`.
#[derive(Clone)]
struct SubspaceSplit {
    /// Basis columns of the penalized subspace; its column count is reported by `rank()`
    /// and its row count by `p()`.
    q_pen: Array2<f64>,
    /// Basis columns of the remaining (null) subspace.
    q_null: Array2<f64>,
}
1536
1537impl SubspaceSplit {
1538 fn identity(p: usize) -> Self {
1539 Self {
1540 q_pen: Array2::zeros((p, 0)),
1541 q_null: Array2::eye(p),
1542 }
1543 }
1544
1545 fn from_ordered_qs(
1546 qs: &Mat<f64>,
1547 penalized_rank: usize,
1548 p: usize,
1549 ) -> Result<Self, EstimationError> {
1550 if qs.nrows() != p || qs.ncols() != p {
1551 return Err(EstimationError::LayoutError(format!(
1552 "Invalid Q basis dimensions: expected {p}x{p}, got {}x{}",
1553 qs.nrows(),
1554 qs.ncols()
1555 )));
1556 }
1557 if penalized_rank > p {
1558 return Err(EstimationError::LayoutError(format!(
1559 "Invalid penalized rank {penalized_rank} for p={p}"
1560 )));
1561 }
1562
1563 let null_count = p - penalized_rank;
1564 let mut q_pen = Array2::<f64>::zeros((p, penalized_rank));
1565 let mut q_null = Array2::<f64>::zeros((p, null_count));
1566 for i in 0..p {
1567 for j in 0..penalized_rank {
1568 q_pen[(i, j)] = qs[(i, j)];
1569 }
1570 for j in 0..null_count {
1571 q_null[(i, j)] = qs[(i, penalized_rank + j)];
1572 }
1573 }
1574
1575 Ok(Self { q_pen, q_null })
1576 }
1577
1578 fn rank(&self) -> usize {
1579 self.q_pen.ncols()
1580 }
1581
1582 fn p(&self) -> usize {
1583 self.q_pen.nrows()
1584 }
1585
1586 fn compose_qs(&self) -> Array2<f64> {
1587 let p = self.p();
1588 let rank = self.rank();
1589 let null_count = self.q_null.ncols();
1590 let mut qs = Array2::<f64>::zeros((p, p));
1591 for i in 0..p {
1592 for j in 0..rank {
1593 qs[(i, j)] = self.q_pen[(i, j)];
1594 }
1595 for j in 0..null_count {
1596 qs[(i, rank + j)] = self.q_null[(i, j)];
1597 }
1598 }
1599 qs
1600 }
1601}
1602
/// Lambda-independent data for the reparameterization, produced by
/// `precompute_reparam_invariant_from_canonical` and consumed by
/// `stable_reparameterizationwith_invariant`.
#[derive(Clone)]
pub struct ReparamInvariant {
    /// Penalized / null split of the basis derived from the balanced penalty.
    split: SubspaceSplit,
    /// Full `p x p` basis with the penalized columns first, followed by the null columns.
    qs_base: Array2<f64>,
    /// `true` when at least one rank-carrying penalty has a local matrix whose Frobenius
    /// norm exceeds `1e-12`.
    has_nonzero: bool,
    /// Largest absolute eigenvalue found when decomposing the balanced penalty
    /// (`0.0` when there is nothing to decompose).
    max_balanced_eigenvalue: f64,
}

impl ReparamInvariant {
    /// Returns the stored largest absolute eigenvalue of the balanced penalty.
    pub const fn max_balanced_eigenvalue(&self) -> f64 {
        self.max_balanced_eigenvalue
    }
}
1624
/// Precomputes the part of the reparameterization that does not depend on the smoothing
/// parameters.
///
/// The *balanced* penalty is the sum of every rank-carrying penalty's local matrix divided
/// by its Frobenius norm (locals with norm `<= 1e-12` are left out). Its eigenvectors,
/// ordered by descending eigenvalue with ties broken by index, are split into a penalized
/// part (eigenvalue above `1e-12` times the largest absolute eigenvalue) and a null part.
///
/// * `penalties` - canonical penalties, each acting on `col_range`.
/// * `p_total` - total number of coefficients (rows of the resulting basis).
///
/// With no penalties, or when no penalty has a non-negligible local matrix, the result is
/// the identity split with `has_nonzero == false`.
///
/// # Errors
///
/// Returns `LayoutError` when penalty column ranges overlap and `p_total` exceeds
/// `OVERLAPPING_PENALTY_DENSE_FALLBACK_MAX_P`; eigendecomposition errors are propagated.
pub fn precompute_reparam_invariant_from_canonical(
    penalties: &[CanonicalPenalty],
    p_total: usize,
) -> Result<ReparamInvariant, EstimationError> {
    use std::cmp::Ordering;

    let m = penalties.len();

    // No penalties at all: everything is null space.
    if m == 0 {
        return Ok(ReparamInvariant {
            split: SubspaceSplit::identity(p_total),
            qs_base: Array2::eye(p_total),
            has_nonzero: false,
            max_balanced_eigenvalue: 0.0,
        });
    }

    // Index of a penalty inside `penalties`; stored instead of a reference.
    struct PenRef {
        penalty_index: usize,
    }
    // Group rank-carrying penalties by the exact column range they act on (sorted by key).
    let mut block_groups: BTreeMap<(usize, usize), Vec<PenRef>> = BTreeMap::new();
    let mut has_nonzero = false;
    for (i, cp) in penalties.iter().enumerate() {
        if cp.rank() == 0 {
            continue;
        }
        let local = cp.local_ref();
        let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
        if frob_norm > 1e-12 {
            has_nonzero = true;
        }
        let key = (cp.col_range.start, cp.col_range.end);
        block_groups
            .entry(key)
            .or_default()
            .push(PenRef { penalty_index: i });
    }

    // Every penalty is numerically zero: same result as having no penalties.
    if !has_nonzero {
        return Ok(ReparamInvariant {
            split: SubspaceSplit::identity(p_total),
            qs_base: Array2::eye(p_total),
            has_nonzero: false,
            max_balanced_eigenvalue: 0.0,
        });
    }

    // Keys are sorted by start, so comparing neighbours is enough to detect any overlap.
    let ranges: Vec<(usize, usize)> = block_groups.keys().copied().collect();
    let mut overlapping = false;
    for i in 1..ranges.len() {
        if ranges[i].0 < ranges[i - 1].1 {
            overlapping = true;
            break;
        }
    }

    if overlapping {
        // Overlapping ranges need one dense p_total x p_total decomposition; refuse it for
        // large models.
        if p_total > OVERLAPPING_PENALTY_DENSE_FALLBACK_MAX_P {
            return Err(EstimationError::LayoutError(format!(
                "overlapping penalty reparameterization would require dense {}x{} eigendecomposition; \
                 large-model dense fallback is disabled. Keep penalties structured or \
                 extend the overlapping-penalty solver path",
                p_total, p_total
            )));
        }
        // Accumulate each normalized local block at its position in the full matrix.
        let mut s_balanced = Mat::<f64>::zeros(p_total, p_total);
        for cp in penalties {
            if cp.rank() == 0 {
                continue;
            }
            let local = cp.local_ref();
            let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
            if frob_norm > 1e-12 {
                let scale = 1.0 / frob_norm;
                let r = &cp.col_range;
                for i in 0..local.nrows() {
                    for j in 0..local.ncols() {
                        s_balanced[(r.start + i, r.start + j)] += scale * local[[i, j]];
                    }
                }
            }
        }

        let (bal_eigenvalues, bal_eigenvectors) =
            robust_eigh_faer(&s_balanced, Side::Lower, "balanced penalty matrix")?;

        // Descending eigenvalue order; incomparable values (NaN) count as equal and ties
        // fall back to the original index so the ordering is deterministic.
        let mut order: Vec<usize> = (0..p_total).collect();
        order.sort_by(|&i, &j| {
            bal_eigenvalues[j]
                .partial_cmp(&bal_eigenvalues[i])
                .unwrap_or(Ordering::Equal)
                .then(i.cmp(&j))
        });

        // Reorder the eigenvector columns accordingly.
        let mut qs = Mat::<f64>::zeros(p_total, p_total);
        for (col_idx, &idx) in order.iter().enumerate() {
            for row in 0..p_total {
                qs[(row, col_idx)] = bal_eigenvectors[(row, idx)];
            }
        }

        // The rank tolerance is relative to the largest absolute eigenvalue.
        let max_bal = order
            .iter()
            .map(|&idx| bal_eigenvalues[idx].abs())
            .fold(0.0_f64, f64::max);
        let rank_tol = if max_bal > 0.0 {
            max_bal * 1e-12
        } else {
            1e-12
        };
        // Leading run of the sorted eigenvalues that clears the tolerance.
        let penalized_rank = order
            .iter()
            .take_while(|&&idx| bal_eigenvalues[idx] > rank_tol)
            .count();
        let split = SubspaceSplit::from_ordered_qs(&qs, penalized_rank, p_total)?;

        return Ok(ReparamInvariant {
            split,
            qs_base: mat_to_array(&qs),
            has_nonzero,
            max_balanced_eigenvalue: max_bal,
        });
    }

    // Disjoint ranges from here on. Columns not covered by any penalty range (including
    // ranges of rank-zero penalties) become unit null-space directions below.
    let mut covered = vec![false; p_total];
    for cp in penalties {
        for j in cp.col_range.clone() {
            covered[j] = true;
        }
    }
    let uncovered_cols: Vec<usize> = (0..p_total).filter(|j| !covered[*j]).collect();

    // Per-block decomposition result plus the column offsets at which the block's local
    // bases are placed in the global `q_pen` / `q_null` (filled in after all blocks finish).
    struct BlockResult {
        col_range: Range<usize>,
        q_pen_local: Array2<f64>,
        q_null_local: Array2<f64>,
        max_balanced_eigenvalue: f64,
        pen_col_offset: usize,
        null_col_offset: usize,
    }

    let block_specs: Vec<_> = block_groups.iter().collect();
    // Each block is decomposed independently, in parallel.
    let mut block_results: Vec<BlockResult> = block_specs
        .into_par_iter()
        .map(
            |(&(start, end), refs)| -> Result<BlockResult, EstimationError> {
                let block_dim = end - start;

                // Balanced penalty restricted to this block.
                let mut s_balanced_local = Array2::zeros((block_dim, block_dim));
                let mut block_has_nonzero = false;
                for pref in refs {
                    let cp = &penalties[pref.penalty_index];
                    let local = cp.local_ref();
                    let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
                    if frob_norm > 1e-12 {
                        s_balanced_local.scaled_add(1.0 / frob_norm, local);
                        block_has_nonzero = true;
                    }
                }

                // Nothing to decompose: the whole block is null space.
                if !block_has_nonzero {
                    return Ok(BlockResult {
                        col_range: start..end,
                        q_pen_local: Array2::zeros((block_dim, 0)),
                        q_null_local: Array2::eye(block_dim),
                        max_balanced_eigenvalue: 0.0,
                        pen_col_offset: 0,
                        null_col_offset: 0,
                    });
                }

                let (bal_eigenvalues, bal_eigenvectors) =
                    robust_eigh(&s_balanced_local, Side::Lower, "balanced penalty block")?;

                // Same deterministic descending order as in the dense path.
                let mut order: Vec<usize> = (0..block_dim).collect();
                order.sort_by(|&i, &j| {
                    bal_eigenvalues[j]
                        .partial_cmp(&bal_eigenvalues[i])
                        .unwrap_or(Ordering::Equal)
                        .then(i.cmp(&j))
                });

                let max_bal = order
                    .iter()
                    .map(|&idx| bal_eigenvalues[idx].abs())
                    .fold(0.0_f64, f64::max);
                // The tolerance is relative to this block's own largest absolute eigenvalue.
                let rank_tol = if max_bal > 0.0 {
                    max_bal * 1e-12
                } else {
                    1e-12
                };
                let penalized_rank = order
                    .iter()
                    .take_while(|&&idx| bal_eigenvalues[idx] > rank_tol)
                    .count();
                let null_count = block_dim - penalized_rank;

                // The first `penalized_rank` sorted eigenvectors span the penalized part,
                // the rest the block-local null part.
                let mut q_pen_local = Array2::zeros((block_dim, penalized_rank));
                let mut q_null_local = Array2::zeros((block_dim, null_count));
                for (col_idx, &idx) in order.iter().enumerate() {
                    if col_idx < penalized_rank {
                        for row in 0..block_dim {
                            q_pen_local[[row, col_idx]] = bal_eigenvectors[[row, idx]];
                        }
                    } else {
                        let null_col = col_idx - penalized_rank;
                        for row in 0..block_dim {
                            q_null_local[[row, null_col]] = bal_eigenvectors[[row, idx]];
                        }
                    }
                }

                Ok(BlockResult {
                    col_range: start..end,
                    q_pen_local,
                    q_null_local,
                    max_balanced_eigenvalue: max_bal,
                    pen_col_offset: 0,
                    null_col_offset: 0,
                })
            },
        )
        .collect::<Result<_, _>>()?;
    let global_max_bal = block_results
        .iter()
        .map(|br| br.max_balanced_eigenvalue)
        .fold(0.0_f64, f64::max);

    let total_pen_rank: usize = block_results.iter().map(|br| br.q_pen_local.ncols()).sum();
    // The null space is every block's local null space plus one direction per uncovered column.
    let total_null: usize = block_results
        .iter()
        .map(|br| br.q_null_local.ncols())
        .sum::<usize>()
        + uncovered_cols.len();
    {
        // Running column offsets, assigned in block order.
        let mut pen_off = 0usize;
        let mut null_off = 0usize;
        for br in &mut block_results {
            br.pen_col_offset = pen_off;
            br.null_col_offset = null_off;
            pen_off += br.q_pen_local.ncols();
            null_off += br.q_null_local.ncols();
        }
    }

    let mut q_pen = Array2::zeros((p_total, total_pen_rank));
    let mut q_null = Array2::zeros((p_total, total_null));

    // Embed each block's local bases at the block's rows and its assigned columns.
    for br in &block_results {
        let start = br.col_range.start;
        let bd = br.q_pen_local.nrows();
        let pen_r = br.q_pen_local.ncols();
        let null_r = br.q_null_local.ncols();
        if pen_r > 0 {
            q_pen
                .slice_mut(s![
                    start..(start + bd),
                    br.pen_col_offset..(br.pen_col_offset + pen_r)
                ])
                .assign(&br.q_pen_local);
        }
        if null_r > 0 {
            q_null
                .slice_mut(s![
                    start..(start + bd),
                    br.null_col_offset..(br.null_col_offset + null_r)
                ])
                .assign(&br.q_null_local);
        }
    }
    // Uncovered columns follow the block null columns, one unit vector each.
    let mut null_col = block_results
        .iter()
        .map(|br| br.q_null_local.ncols())
        .sum::<usize>();
    for &j in &uncovered_cols {
        q_null[[j, null_col]] = 1.0;
        null_col += 1;
    }

    let split = SubspaceSplit { q_pen, q_null };

    // Full basis [q_pen | q_null].
    let qs_global = split.compose_qs();

    Ok(ReparamInvariant {
        split,
        qs_base: qs_global,
        has_nonzero,
        max_balanced_eigenvalue: global_max_bal,
    })
}
1948
1949fn structurally_penalized_columns(penalties: &[CanonicalPenalty], p: usize) -> Vec<bool> {
1950 let mut active = vec![false; p];
1951 for cp in penalties {
1952 let local = cp.local_ref();
1953 let scale = local.iter().map(|&v| v.abs()).fold(0.0_f64, f64::max);
1954 if scale <= 0.0 {
1955 continue;
1956 }
1957 let tol = scale * 1e-12;
1958 for local_col in 0..cp.block_dim() {
1959 let mut column_active = false;
1960 for row in 0..cp.block_dim() {
1961 if local[[row, local_col]].abs() > tol || local[[local_col, row]].abs() > tol {
1962 column_active = true;
1963 break;
1964 }
1965 }
1966 if column_active {
1967 active[cp.col_range.start + local_col] = true;
1968 }
1969 }
1970 }
1971 active
1972}
1973
/// Computes the reparameterized penalty quantities for one set of smoothing parameters,
/// reusing a precomputed [`ReparamInvariant`].
///
/// * `penalties` - canonical penalties; `lambdas[k]` weights `penalties[k]`.
/// * `lambdas` - one smoothing parameter per penalty.
/// * `p` - total number of coefficients.
/// * `invariant` - lambda-independent basis split for these penalties.
/// * `penalty_shrinkage_floor` - optional factor; when positive, a ridge equal to the
///   factor times `invariant.max_balanced_eigenvalue` is added to selected range
///   eigenvalues.
///
/// The returned [`ReparamResult`] carries the transformed total penalty, its
/// log-determinant over the penalized directions, the per-lambda derivative terms `det1`,
/// the basis `qs`, the per-penalty transformed roots, the root `e_transformed`, the
/// null-space coordinates `u_truncated`, and the ridge that was applied.
///
/// # Errors
///
/// * `ParameterConstraintViolation` when `lambdas.len() != penalties.len()`.
/// * `LayoutError` when the subspace split fails the leakage / orthogonality check, or
///   when a range eigenvalue is non-finite or more negative than the eigenvalue floor.
/// * Eigendecomposition errors are propagated.
///
/// # Panics
///
/// Panics (via `assert!`) when the two `det1` computations disagree beyond tolerance or
/// when the transformed penalty has entries above `1e-10` in its null rows.
pub fn stable_reparameterizationwith_invariant(
    penalties: &[CanonicalPenalty],
    lambdas: &[f64],
    p: usize,
    invariant: &ReparamInvariant,
    penalty_shrinkage_floor: Option<f64>,
) -> Result<ReparamResult, EstimationError> {
    let m = penalties.len();

    if lambdas.len() != m {
        return Err(EstimationError::ParameterConstraintViolation(format!(
            "Lambda count mismatch: expected {} lambdas for {} penalties, got {}",
            m,
            m,
            lambdas.len()
        )));
    }

    // No penalties: zero penalty, identity basis, every direction is null space.
    if m == 0 {
        return Ok(ReparamResult {
            s_transformed: Array2::zeros((p, p)),
            log_det: 0.0,
            det1: Array1::zeros(0),
            qs: Array2::eye(p),
            canonical_transformed: vec![],
            e_transformed: Array2::zeros((0, p)),
            u_truncated: Array2::eye(p),
            penalty_shrinkage_ridge: 0.0,
        });
    }

    // Penalties exist but none is numerically non-zero: return a zero penalty and pass
    // the penalties through unchanged.
    if !invariant.has_nonzero {
        let qs = invariant.split.compose_qs();
        let u_truncated = qs.t().dot(&invariant.split.q_null);
        let canonical_transformed: Vec<CanonicalPenalty> = penalties.to_vec();
        return Ok(ReparamResult {
            s_transformed: Array2::zeros((p, p)),
            log_det: 0.0,
            det1: Array1::zeros(m),
            qs,
            canonical_transformed,
            e_transformed: Array2::zeros((0, p)),
            u_truncated,
            penalty_shrinkage_ridge: 0.0,
        });
    }

    let q_pen = array_to_faer(&invariant.split.q_pen);
    let q_null = array_to_faer(&invariant.split.q_null);
    let qs_base = array_to_faer(&invariant.qs_base);
    // For every penalty: multiply its root by the rows of `qs_base` that belong to the
    // penalty's column range, then form the penalty matrix from that transformed root
    // (NOTE(review): `penalty_from_root_faer` presumably returns root' * root - confirm).
    let penalty_transforms: Vec<(Mat<f64>, Mat<f64>)> = penalties
        .par_iter()
        .map(|cp| {
            let r = &cp.col_range;
            let root_faer = array_to_faer(&cp.root);
            let q_block = qs_base.submatrix(r.start, 0, cp.block_dim(), p);
            let mut product = Mat::<f64>::zeros(cp.rank(), p);
            matmul(
                product.as_mut(),
                Accum::Replace,
                root_faer.as_ref(),
                q_block,
                1.0,
                Par::Seq,
            );
            let s_k = penalty_from_root_faer(&product);
            (product, s_k)
        })
        .collect();
    let (rs_transformed, s_k_penalized_cache): (Vec<Mat<f64>>, Vec<Mat<f64>>) =
        penalty_transforms.into_iter().unzip();

    let penalized_rank = invariant.split.rank();

    let mut range_eigenvalues_sorted: Vec<f64> = Vec::new();
    let mut range_rotation = Mat::<f64>::zeros(penalized_rank, penalized_rank);
    if penalized_rank > 0 {
        // Lambda-weighted sum of the leading penalized_rank x penalized_rank corner of
        // every transformed penalty.
        let mut range_block = Mat::<f64>::zeros(penalized_rank, penalized_rank);
        for (lambda, s_k) in lambdas.iter().zip(s_k_penalized_cache.iter()) {
            for i in 0..penalized_rank {
                for j in 0..penalized_rank {
                    range_block[(i, j)] += *lambda * s_k[(i, j)];
                }
            }
        }
        let (range_eigenvalues, range_eigenvectors) =
            robust_eigh_faer(&range_block, Side::Lower, "range penalty block")?;

        // Descending eigenvalue order; incomparable values count as equal and ties are
        // broken by index.
        let mut range_order: Vec<usize> = (0..penalized_rank).collect();
        range_order.sort_by(|&i, &j| {
            range_eigenvalues[j]
                .partial_cmp(&range_eigenvalues[i])
                .unwrap_or(std::cmp::Ordering::Equal)
                .then(i.cmp(&j))
        });
        range_eigenvalues_sorted = range_order
            .iter()
            .map(|&idx| range_eigenvalues[idx])
            .collect();

        // Columns of `range_rotation` are the eigenvectors in that sorted order.
        for (col_idx, &idx) in range_order.iter().enumerate() {
            for row in 0..penalized_rank {
                range_rotation[(row, col_idx)] = range_eigenvectors[(row, idx)];
            }
        }
    }

    let structural_rank = penalized_rank;
    let mut range_eigs_sorted: Vec<f64> = range_eigenvalues_sorted;
    let structurally_penalized_cols = structurally_penalized_columns(penalties, p);

    // Ridge = floor factor * largest balanced eigenvalue; non-positive or missing factors
    // disable it.
    let shrinkage_ridge = penalty_shrinkage_floor
        .filter(|&eps| eps > 0.0)
        .map(|eps| eps * invariant.max_balanced_eigenvalue)
        .unwrap_or(0.0);
    if shrinkage_ridge > 0.0 {
        let min_eig_before = range_eigs_sorted
            .iter()
            .copied()
            .fold(f64::INFINITY, f64::min);
        let mut shrinkage_floor_applied = 0usize;
        for eig_idx in 0..range_eigs_sorted.len() {
            // Squared mass of this eigen-direction (expressed in original coordinates via
            // q_pen * range_rotation) on the structurally penalized columns.
            let mut penalized_energy = 0.0;
            for original_col in 0..p {
                if structurally_penalized_cols[original_col] {
                    let mut coordinate = 0.0;
                    for pen_col in 0..penalized_rank {
                        coordinate +=
                            q_pen[(original_col, pen_col)] * range_rotation[(pen_col, eig_idx)];
                    }
                    penalized_energy += coordinate * coordinate;
                }
            }
            // Only directions with non-negligible mass there receive the ridge.
            if penalized_energy > 1e-8 {
                range_eigs_sorted[eig_idx] += shrinkage_ridge;
                shrinkage_floor_applied += 1;
            }
        }
        // Log only when the ridge is more than 1% of the smallest pre-ridge eigenvalue.
        if min_eig_before > 0.0 && shrinkage_ridge / min_eig_before > 0.01 {
            log::debug!(
                "Penalty shrinkage floor active: ridge={:.3e} (min_eig_before={:.3e}, ratio={:.1e}, max_bal_eig={:.3e}, applied_dirs={})",
                shrinkage_ridge,
                min_eig_before,
                shrinkage_ridge / min_eig_before,
                invariant.max_balanced_eigenvalue,
                shrinkage_floor_applied,
            );
        }
    }

    // Lower clamp applied to range eigenvalues before taking square roots / logarithms.
    let eigenvalue_floor = invariant.max_balanced_eigenvalue.max(1.0) * 1e-12;
    let qs = compose_qs_from_split(&q_pen, &q_null, p);

    // Consistency check of the split. Note the precedence: the condition is
    // (relative AND absolute leakage too large) OR (cross-Gram entry too large).
    let leakage = assess_subspace_leakage(&qs, &rs_transformed, structural_rank, p);
    let leakage_rel_tol = 1e-10;
    let leakage_abs_tol = 1e-12;
    let orth_tol = 1e-10;
    if leakage.max_rel_sq > leakage_rel_tol && leakage.max_abs_sq > leakage_abs_tol
        || leakage.max_cross_gram_abs > orth_tol
    {
        return Err(EstimationError::LayoutError(format!(
            "Reparameterization subspace split is inconsistent: max null leakage {:.3e} (rel {:.3e}, worst penalty {}), max |Qp'Qn| {:.3e}",
            leakage.max_abs_sq.sqrt(),
            leakage.max_rel_sq.sqrt(),
            leakage.worst_penalty,
            leakage.max_cross_gram_abs,
        )));
    }

    // u_truncated = qs' * q_null: the null basis expressed in the new coordinates.
    let mut u_truncated_mat = Mat::<f64>::zeros(p, q_null.ncols());
    matmul(
        u_truncated_mat.as_mut(),
        Accum::Replace,
        qs.transpose(),
        q_null.as_ref(),
        1.0,
        Par::Seq,
    );

    // Root of the transformed penalty: row i is sqrt(clamped eigenvalue i) times the i-th
    // sorted eigenvector, written into the first `penalized_rank` columns; the remaining
    // columns stay zero.
    let mut e_transformed_mat = Mat::<f64>::zeros(structural_rank, p);
    for row_idx in 0..structural_rank {
        let safe_eigenval = range_eigs_sorted[row_idx].max(eigenvalue_floor);
        let sqrt_eigenval = safe_eigenval.sqrt();
        for j in 0..penalized_rank {
            e_transformed_mat[(row_idx, j)] = sqrt_eigenval * range_rotation[(j, row_idx)];
        }
    }

    // Log-determinant over the penalized directions: sum of ln(clamped eigenvalue), using
    // compensated summation. Non-finite or clearly negative eigenvalues are rejected.
    let mut floored_eigs: Vec<f64> = Vec::with_capacity(range_eigs_sorted.len());
    let mut log_det_sum = KahanSum::default();
    for (idx, &ev) in range_eigs_sorted.iter().enumerate() {
        if !ev.is_finite() || ev < -eigenvalue_floor {
            return Err(EstimationError::LayoutError(format!(
                "Penalty pseudo-logdet has a non-finite or large-negative structural eigenvalue at index {idx}: {ev:.3e}"
            )));
        }
        let safe_ev = ev.max(eigenvalue_floor);
        floored_eigs.push(safe_ev);
        if idx < penalized_rank {
            log_det_sum.add(safe_ev.ln());
        }
    }
    let log_det = log_det_sum.sum();
    // No extra diagonal shift is added to the eigenvalues in the trace terms below.
    let delta = 0.0;

    // det1[k] = lambda_k * trace term for penalty k, computed by the helper
    // (NOTE(review): helper semantics are not visible here; the block below recomputes the
    // same quantity as sum_l diag(R' S_k R)[l] / (eig_l + delta) and asserts agreement).
    let det1vec: Vec<f64> = (0..lambdas.len())
        .into_par_iter()
        .map(|k| {
            let s_k = &s_k_penalized_cache[k];
            let trace = trace_penalty_in_orthogonal_basis(
                s_k,
                penalized_rank,
                &range_rotation,
                &floored_eigs,
                delta,
            );
            lambdas[k] * trace
        })
        .collect();

    {
        // Cross-check `det1vec` against a direct reference formula.
        let mut maxdet1_mismatch = 0.0_f64;
        let mut det1_scale = 0.0_f64;
        for (k, lambda) in lambdas.iter().enumerate() {
            let s_k_penalized = &s_k_penalized_cache[k];
            let s_k_eigenbasis = orthogonal_similarity_transform_faer(
                s_k_penalized,
                penalized_rank,
                &range_rotation,
            );
            let mut trace = KahanSum::default();
            for l in 0..penalized_rank {
                trace.add(s_k_eigenbasis[(l, l)] / (floored_eigs[l] + delta));
            }
            let reference = *lambda * trace.sum();
            maxdet1_mismatch = maxdet1_mismatch.max((reference - det1vec[k]).abs());
            det1_scale = det1_scale.max(reference.abs()).max(det1vec[k].abs());
        }
        // Tolerance is relative to the largest magnitude seen, with a floor of 1.
        let det1_tolerance = 1e-7 * det1_scale.max(1.0);
        assert!(
            maxdet1_mismatch <= det1_tolerance,
            "det1 mismatch between optimized and reference formulas: max_abs={maxdet1_mismatch:.3e}, tol={det1_tolerance:.3e}"
        );
    }

    // Transformed total penalty: E' * E.
    let mut s_truncated = Mat::<f64>::zeros(p, p);
    matmul(
        s_truncated.as_mut(),
        Accum::Replace,
        e_transformed_mat.transpose(),
        e_transformed_mat.as_ref(),
        1.0,
        Par::Seq,
    );

    {
        // Rows belonging to the null directions must be (numerically) zero.
        let mut max_null_diag = 0.0_f64;
        let mut max_null_offdiag = 0.0_f64;
        for i in structural_rank..p {
            max_null_diag = max_null_diag.max(s_truncated[(i, i)].abs());
            for j in 0..p {
                if i != j {
                    max_null_offdiag = max_null_offdiag.max(s_truncated[(i, j)].abs());
                }
            }
        }
        assert!(
            max_null_diag <= 1e-10 && max_null_offdiag <= 1e-10,
            "null-space leakage in transformed penalty: max_null_diag={max_null_diag:.3e}, max_null_offdiag={max_null_offdiag:.3e}"
        );
    }

    let qs_array = mat_to_array(&qs);
    // Rebuild each penalty from its transformed root, rotating its prior mean by qs'.
    let canonical_transformed: Vec<CanonicalPenalty> = rs_transformed
        .par_iter()
        .zip(penalties.par_iter())
        .map(|(r, cp)| {
            let mean_transformed = qs_array.t().dot(&cp.full_width_prior_mean());
            CanonicalPenalty::from_dense_root_with_mean(mat_to_array(r), p, mean_transformed)
        })
        .collect();
    Ok(ReparamResult {
        s_transformed: mat_to_array(&s_truncated),
        log_det,
        det1: Array1::from(det1vec),
        qs: qs_array,
        canonical_transformed,
        e_transformed: mat_to_array(&e_transformed_mat),
        u_truncated: mat_to_array(&u_truncated_mat),
        penalty_shrinkage_ridge: shrinkage_ridge,
    })
}
2369
/// Pair of problem dimensions handed to the reparameterization engine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EngineDims {
    /// Total number of coefficients; forwarded as `p` by
    /// `stable_reparameterization_engine_canonical`.
    pub p: usize,
    /// Second dimension carried alongside `p`.
    /// NOTE(review): presumably the number of penalties - confirm against callers.
    pub k: usize,
}

impl EngineDims {
    /// Creates the dimension pair from `p` and `k`.
    ///
    /// This is a `const fn`, so it can also be used to initialise constants and statics.
    #[must_use]
    pub const fn new(p: usize, k: usize) -> Self {
        Self { p, k }
    }
}
2382
2383pub fn stable_reparameterization_engine_canonical(
2392 penalties: &[CanonicalPenalty],
2393 lambdas: &[f64],
2394 dims: EngineDims,
2395 cached_invariant: Option<&ReparamInvariant>,
2396 penalty_shrinkage_floor: Option<f64>,
2397) -> Result<ReparamResult, EstimationError> {
2398 let owned;
2399 let invariant = match cached_invariant {
2400 Some(inv) => inv,
2401 None => {
2402 owned = precompute_reparam_invariant_from_canonical(penalties, dims.p)?;
2403 &owned
2404 }
2405 };
2406 stable_reparameterizationwith_invariant(
2407 penalties,
2408 lambdas,
2409 dims.p,
2410 invariant,
2411 penalty_shrinkage_floor,
2412 )
2413}
2414
/// Reparameterization result kept in factored (per-marginal) Kronecker form.
///
/// The `materialize_*` methods expand it into dense matrices when a dense consumer needs
/// them.
#[derive(Clone)]
pub struct KroneckerReparamResult {
    /// Per-marginal reparameterized matrices.
    /// NOTE(review): not read in this part of the file - presumably the marginal designs
    /// after rotation; confirm against the producer.
    pub reparameterized_marginals: Arc<Vec<Array2<f64>>>,
    /// Eigenvalues of each marginal penalty; entry `k` is indexed by the `k`-th component
    /// of a cell's multi-index when the cell's diagonal value is computed.
    pub marginal_eigenvalues: Arc<Vec<Array1<f64>>>,
    /// Per-marginal basis matrices; `materialize_qs` forms their Kronecker product in order.
    pub marginal_qs: Arc<Vec<Array2<f64>>>,
    /// Log-determinant term, copied unchanged into the dense result.
    pub log_det: f64,
    /// Per-lambda first-derivative terms, copied unchanged into the dense result.
    pub det1: Array1<f64>,
    /// NOTE(review): presumably the matching second-derivative terms - not read in this
    /// part of the file; confirm against the producer.
    pub det2: Array2<f64>,
    /// Ridge added to the diagonal value of every cell that is not jointly null.
    pub penalty_shrinkage_ridge: f64,
    /// Whether an extra smoothing parameter (stored after the per-marginal ones) applies to
    /// the jointly-null cells.
    pub has_double_penalty: bool,
    /// Size of each marginal; the full dimension is their product.
    pub marginal_dims: Vec<usize>,
}
2449
2450impl KroneckerReparamResult {
2451 pub fn materialize_qs(&self) -> Array2<f64> {
2454 let mut qs = Array2::<f64>::eye(1);
2455 for u_k in self.marginal_qs.iter() {
2456 qs = kronecker_product(&qs, u_k);
2457 }
2458 qs
2459 }
2460
2461 pub fn materialize_s_transformed(&self, lambdas: &[f64]) -> Array2<f64> {
2464 let d = self.marginal_dims.len();
2465 let p: usize = self.marginal_dims.iter().copied().product();
2466 let mut s = Array2::<f64>::zeros((p, p));
2467
2468 let eigenvalue_views: Vec<ArrayView1<'_, f64>> =
2472 self.marginal_eigenvalues.iter().map(|m| m.view()).collect();
2473 let has_double = self.has_double_penalty && lambdas.len() > d;
2474 let mut multi_idx = vec![0usize; d];
2475 let mut flat = 0usize;
2476 loop {
2477 let (sigma, _structural_sigma, _joint_null) = kronecker_cell_sigma(
2478 &eigenvalue_views,
2479 &multi_idx,
2480 lambdas,
2481 d,
2482 has_double,
2483 self.penalty_shrinkage_ridge,
2484 );
2485 s[[flat, flat]] = sigma;
2486 flat += 1;
2487
2488 if kronecker_multi_index_advance(&mut multi_idx, &self.marginal_dims) {
2489 break;
2490 }
2491 }
2492 s
2493 }
2494
2495 pub fn materialize_dense_artifact_result(
2498 &self,
2499 rs_list: &[Array2<f64>],
2500 lambdas: &[f64],
2501 p: usize,
2502 ) -> Result<ReparamResult, EstimationError> {
2503 const KRONECKER_DENSE_COMPAT_FALLBACK_MAX_P: usize = 4096;
2504 if p > KRONECKER_DENSE_COMPAT_FALLBACK_MAX_P {
2505 return Err(EstimationError::LayoutError(format!(
2506 "Kronecker reparameterization would materialize dense {}x{} compatibility tensors; \
2507 large-model dense fallback is disabled. Wire the downstream solver to consume \
2508 the factored Kronecker result directly",
2509 p, p
2510 )));
2511 }
2512 let qs = self.materialize_qs();
2513 let s_transformed = self.materialize_s_transformed(lambdas);
2514
2515 let rs_transformed: Vec<Array2<f64>> = if rs_list.len() >= 2 {
2517 use rayon::prelude::*;
2518 rs_list
2519 .par_iter()
2520 .map(|r| gam_linalg::faer_ndarray::fast_ab(r, &qs))
2521 .collect()
2522 } else {
2523 rs_list
2524 .iter()
2525 .map(|r| gam_linalg::faer_ndarray::fast_ab(r, &qs))
2526 .collect()
2527 };
2528 let d = self.marginal_dims.len();
2534 let eigenvalue_views: Vec<ArrayView1<'_, f64>> =
2541 self.marginal_eigenvalues.iter().map(|m| m.view()).collect();
2542 let has_double = self.has_double_penalty && lambdas.len() > d;
2543 let diag_vals: Vec<f64> = {
2544 let mut vals = Vec::with_capacity(p);
2545 let mut multi_idx = vec![0usize; d];
2546 loop {
2547 let (sigma, _structural_sigma, _joint_null) = kronecker_cell_sigma(
2548 &eigenvalue_views,
2549 &multi_idx,
2550 lambdas,
2551 d,
2552 has_double,
2553 self.penalty_shrinkage_ridge,
2554 );
2555 vals.push(if sigma > 0.0 { sigma.sqrt() } else { 0.0 });
2556
2557 if kronecker_multi_index_advance(&mut multi_idx, &self.marginal_dims) {
2558 break;
2559 }
2560 }
2561 vals
2562 };
2563 let rank = diag_vals.iter().filter(|&&v| v > 1e-12).count();
2564 let mut e_transformed = Array2::<f64>::zeros((rank, p));
2565 let mut row = 0;
2566 for (j, &v) in diag_vals.iter().enumerate() {
2567 if v > 1e-12 {
2568 e_transformed[[row, j]] = v;
2569 row += 1;
2570 }
2571 }
2572
2573 let null_count = p - rank;
2575 let mut u_truncated = Array2::<f64>::zeros((p, null_count));
2576 let mut col = 0;
2577 for (j, &v) in diag_vals.iter().enumerate() {
2578 if v <= 1e-12 {
2579 u_truncated[[j, col]] = 1.0; col += 1;
2581 }
2582 }
2583
2584 let canonical_transformed: Vec<CanonicalPenalty> = rs_transformed
2585 .iter()
2586 .map(|r| CanonicalPenalty::from_dense_root(r.clone(), p))
2587 .collect();
2588 Ok(ReparamResult {
2589 s_transformed,
2590 log_det: self.log_det,
2591 det1: self.det1.clone(),
2592 qs,
2593 canonical_transformed,
2594 e_transformed,
2595 u_truncated,
2596 penalty_shrinkage_ridge: self.penalty_shrinkage_ridge,
2597 })
2598 }
2599}
2600
/// Threshold on a cell's unweighted sum of marginal eigenvalues: at or below it the cell
/// is treated as jointly null, strictly above it the cell receives the shrinkage ridge.
const KRONECKER_STRUCTURAL_ZERO_TOL: f64 = 1e-12;
2608
2609#[inline]
2623fn kronecker_cell_sigma(
2624 marginal_eigenvalues: &[ArrayView1<'_, f64>],
2625 multi_idx: &[usize],
2626 lambdas: &[f64],
2627 d: usize,
2628 has_double_penalty: bool,
2629 ridge: f64,
2630) -> (f64, f64, bool) {
2631 let mut sigma = 0.0;
2632 let mut structural_sigma = 0.0;
2633 for k in 0..d {
2634 let marginal_eigenvalue = marginal_eigenvalues[k][multi_idx[k]];
2635 structural_sigma += marginal_eigenvalue;
2636 sigma += lambdas[k] * marginal_eigenvalue;
2637 }
2638 let joint_null = structural_sigma <= KRONECKER_STRUCTURAL_ZERO_TOL;
2639 if has_double_penalty && joint_null {
2640 sigma += lambdas[d];
2641 }
2642 if structural_sigma > KRONECKER_STRUCTURAL_ZERO_TOL {
2643 sigma += ridge;
2644 }
2645 (sigma, structural_sigma, joint_null)
2646}
2647
/// Steps `multi_idx` to the next cell, odometer style, with the last position varying
/// fastest and position `i` bounded by `dims[i]`.
///
/// Returns `true` when the index wrapped around to all zeros (the enumeration is done) and
/// `false` when it moved to a fresh cell. An empty `dims` wraps immediately.
#[inline]
fn kronecker_multi_index_advance(multi_idx: &mut [usize], dims: &[usize]) -> bool {
    for position in (0..dims.len()).rev() {
        multi_idx[position] += 1;
        if multi_idx[position] < dims[position] {
            // No carry out of this position: the new index is valid.
            return false;
        }
        // Overflowed this position: reset it and carry into the next one to the left.
        multi_idx[position] = 0;
    }
    true
}
2665
2666pub fn kronecker_logdet_and_derivatives(
2667 marginal_eigenvalues: &[ArrayView1<'_, f64>],
2668 marginal_dims: &[usize],
2669 lambdas: &[f64],
2670 has_double_penalty: bool,
2671 ridge: f64,
2672) -> (f64, Array1<f64>, Array2<f64>) {
2673 let d = marginal_dims.len();
2674 let n_pen = d + if has_double_penalty { 1 } else { 0 };
2675
2676 let mut logdet = 0.0;
2677 let mut grad = Array1::<f64>::zeros(n_pen);
2678 let mut hess = Array2::<f64>::zeros((n_pen, n_pen));
2679 let tol = 1e-12;
2680
2681 let mut multi_idx = vec![0usize; d];
2682 loop {
2683 let (sigma, _structural_sigma, joint_null) = kronecker_cell_sigma(
2684 marginal_eigenvalues,
2685 &multi_idx,
2686 lambdas,
2687 d,
2688 has_double_penalty,
2689 ridge,
2690 );
2691
2692 if sigma > tol {
2693 logdet += sigma.ln();
2694 let inv_sigma = 1.0 / sigma;
2695 let inv_sigma2 = inv_sigma * inv_sigma;
2696
2697 for k in 0..d {
2698 let ck = lambdas[k] * marginal_eigenvalues[k][multi_idx[k]];
2699 grad[k] += ck * inv_sigma;
2700 }
2701 if has_double_penalty && joint_null {
2702 grad[d] += lambdas[d] * inv_sigma;
2703 }
2704
2705 for k in 0..n_pen {
2706 let ck = if k < d {
2707 lambdas[k] * marginal_eigenvalues[k][multi_idx[k]]
2708 } else if joint_null {
2709 lambdas[d]
2710 } else {
2711 0.0
2712 };
2713 if ck == 0.0 {
2720 continue;
2721 }
2722 hess[[k, k]] += ck * inv_sigma - ck * ck * inv_sigma2;
2723 for l in (k + 1)..n_pen {
2724 let cl = if l < d {
2725 lambdas[l] * marginal_eigenvalues[l][multi_idx[l]]
2726 } else if joint_null {
2727 lambdas[d]
2728 } else {
2729 0.0
2730 };
2731 let off = -ck * cl * inv_sigma2;
2732 hess[[k, l]] += off;
2733 hess[[l, k]] += off;
2734 }
2735 }
2736 }
2737
2738 if kronecker_multi_index_advance(&mut multi_idx, marginal_dims) {
2739 break;
2740 }
2741 }
2742
2743 (logdet, grad, hess)
2744}
2745
2746use crate::kronecker::KroneckerInvariantStructure;
2750
2751pub fn kronecker_reparameterization_engine(
2757 marginal_designs: &[Array2<f64>],
2758 marginal_penalties: &[Array2<f64>],
2759 marginal_dims: &[usize],
2760 lambdas: &[f64],
2761 has_double_penalty: bool,
2762 penalty_shrinkage_floor: Option<f64>,
2763) -> Result<KroneckerReparamResult, EstimationError> {
2764 let d = marginal_dims.len();
2765 if marginal_designs.len() != d || marginal_penalties.len() != d {
2766 return Err(EstimationError::LayoutError(format!(
2767 "kronecker_reparameterization_engine: dimension mismatch: designs={}, penalties={}, dims={}",
2768 marginal_designs.len(),
2769 marginal_penalties.len(),
2770 d
2771 )));
2772 }
2773
2774 let invariant =
2775 KroneckerInvariantStructure::compute(marginal_designs, marginal_penalties, marginal_dims)?;
2776 kronecker_reparameterization_engine_with_invariant(
2777 &invariant,
2778 marginal_dims,
2779 lambdas,
2780 has_double_penalty,
2781 penalty_shrinkage_floor,
2782 )
2783}
2784
2785pub fn kronecker_reparameterization_engine_with_invariant(
2793 invariant: &KroneckerInvariantStructure,
2794 marginal_dims: &[usize],
2795 lambdas: &[f64],
2796 has_double_penalty: bool,
2797 penalty_shrinkage_floor: Option<f64>,
2798) -> Result<KroneckerReparamResult, EstimationError> {
2799 let marginal_eigenvalues = Arc::clone(&invariant.marginal_eigenvalues);
2802 let marginal_qs = Arc::clone(&invariant.marginal_qs);
2803 let reparameterized_marginals = Arc::clone(&invariant.reparameterized_marginals);
2804
2805 let penalty_shrinkage_ridge = if let Some(floor) = penalty_shrinkage_floor {
2807 floor * invariant.max_balanced_eigenvalue
2808 } else {
2809 0.0
2810 };
2811
2812 let marginal_eigenvalue_views: Vec<_> = marginal_eigenvalues
2813 .iter()
2814 .map(|evals| evals.view())
2815 .collect();
2816 let (log_det, det1, det2) = kronecker_logdet_and_derivatives(
2817 &marginal_eigenvalue_views,
2818 marginal_dims,
2819 lambdas,
2820 has_double_penalty,
2821 penalty_shrinkage_ridge,
2822 );
2823
2824 Ok(KroneckerReparamResult {
2825 reparameterized_marginals,
2826 marginal_eigenvalues,
2827 marginal_qs,
2828 log_det,
2829 det1,
2830 det2,
2831 penalty_shrinkage_ridge,
2832 has_double_penalty,
2833 marginal_dims: marginal_dims.to_vec(),
2834 })
2835}
2836
2837pub fn calculate_condition_number(matrix: &Array2<f64>) -> Result<f64, FaerLinalgError> {
2857 let (rows, cols) = matrix.dim();
2858 if rows == 0 || cols == 0 {
2859 return Ok(1.0);
2860 }
2861
2862 if rows == cols {
2864 let mut max_abs = 0.0_f64;
2865 let mut max_asym = 0.0_f64;
2866 for i in 0..rows {
2867 for j in 0..cols {
2868 max_abs = max_abs.max(matrix[[i, j]].abs());
2869 }
2870 for j in 0..i {
2871 let diff = (matrix[[i, j]] - matrix[[j, i]]).abs();
2872 if diff > max_asym {
2873 max_asym = diff;
2874 }
2875 }
2876 }
2877 let sym_tol = max_abs.max(1.0) * 1e-12;
2878 if max_asym <= sym_tol {
2879 let (evals, _) = matrix.eigh(Side::Lower)?;
2880 let mut max_abs_eval = 0.0_f64;
2881 let mut min_abs_eval = f64::INFINITY;
2882 for &lam in evals.iter() {
2883 let s = lam.abs();
2884 max_abs_eval = max_abs_eval.max(s);
2885 min_abs_eval = min_abs_eval.min(s);
2886 }
2887 if min_abs_eval < 1e-12 {
2888 return Ok(f64::INFINITY);
2889 }
2890 return Ok(max_abs_eval / min_abs_eval);
2891 }
2892 }
2893
2894 let (_, s, _) = matrix.svd(false, false)?;
2896 let max_sv = s.iter().fold(0.0_f64, |max, &val| max.max(val));
2897 let min_sv = s.iter().fold(f64::INFINITY, |min, &val| min.min(val));
2898 if min_sv < 1e-12 {
2899 return Ok(f64::INFINITY);
2900 }
2901 Ok(max_sv / min_sv)
2902}
2903
2904#[cfg(test)]
2905mod tests {
2906 use super::{
2907 CanonicalPenalty, SubspaceLeakageMetrics, assess_subspace_leakage,
2908 classify_eigenvalues_strict, precompute_reparam_invariant_from_canonical,
2909 report_penalty_pair_redundancy, stable_reparameterizationwith_invariant,
2910 };
2911 use crate::construction::kronecker_product;
2912 use crate::EstimationError;
2913 use faer::Mat;
2914 use gam_linalg::faer_ndarray::FaerEigh;
2915 use gam_linalg::utils::inf_norm;
2916 use ndarray::{Array1, Array2, array};
2917
2918 fn canonical_from_roots(rs_list: &[Array2<f64>], p: usize) -> Vec<CanonicalPenalty> {
2920 rs_list
2921 .iter()
2922 .map(|r| {
2923 let local = r.t().dot(r);
2924 CanonicalPenalty {
2925 root: r.clone(),
2926 col_range: 0..p,
2927 total_dim: p,
2928 nullity: 0,
2929 local,
2930 prior_mean: Array1::zeros(p),
2931 positive_eigenvalues: Vec::new(),
2932 op: None,
2933 }
2934 })
2935 .collect()
2936 }
2937
    /// Test shorthand that forwards its arguments unchanged to
    /// [`assess_subspace_leakage`] and returns the resulting metrics.
    fn metrics_for(
        qs: &Mat<f64>,
        rs: &[Mat<f64>],
        structural_rank: usize,
        p: usize,
    ) -> SubspaceLeakageMetrics {
        assess_subspace_leakage(qs, rs, structural_rank, p)
    }
2946
2947 #[test]
2948 fn subspace_leakage_iszero_for_clean_split() {
2949 let p = 4usize;
2950 let structural_rank = 2usize;
2951 let qs = Mat::<f64>::identity(p, p);
2952 let mut r0 = Mat::<f64>::zeros(2, p);
2953 r0[(0, 0)] = 1.0;
2954 r0[(1, 1)] = 2.0;
2955
2956 let m = metrics_for(&qs, &[r0], structural_rank, p);
2957 assert!(m.max_abs_sq <= 1e-16);
2958 assert!(m.max_rel_sq <= 1e-16);
2959 assert!(m.max_cross_gram_abs <= 1e-16);
2960 }
2961
2962 #[test]
2963 fn subspace_leakage_detects_null_column_energy() {
2964 let p = 4usize;
2965 let structural_rank = 2usize;
2966 let qs = Mat::<f64>::identity(p, p);
2967 let mut r0 = Mat::<f64>::zeros(1, p);
2968 r0[(0, 2)] = 3.0;
2969
2970 let m = metrics_for(&qs, &[r0], structural_rank, p);
2971 assert!(m.max_abs_sq > 0.0);
2972 assert!(m.max_rel_sq > 0.99);
2973 }
2974
2975 #[test]
2976 fn subspace_leakage_detects_qp_qn_nonorthogonality() {
2977 let p = 3usize;
2978 let structural_rank = 1usize;
2979 let mut qs = Mat::<f64>::identity(p, p);
2980 qs[(0, 1)] = 0.2;
2981 let r0 = Mat::<f64>::zeros(1, p);
2982
2983 let m = metrics_for(&qs, &[r0], structural_rank, p);
2984 assert!(m.max_cross_gram_abs > 1e-3);
2985 }
2986
2987 #[test]
2988 fn u_truncated_is_transformed_frame_in_nonzero_case() {
2989 let p = 3usize;
2990 let rs_list = vec![array![[1.0, 0.0, 0.0]]];
2991 let canonical = canonical_from_roots(&rs_list, p);
2992 let lambdas = vec![2.0];
2993 let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
2994 .expect("precompute invariant");
2995 let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
2996 .expect("stable reparam");
2997
2998 let expected = rep.qs.t().dot(&inv.split.q_null);
2999 let diff = &rep.u_truncated - &expected;
3000 let max_abs = inf_norm(diff.iter().copied());
3001 assert!(
3002 max_abs <= 1e-10,
3003 "u_truncated frame mismatch: max_abs={max_abs}"
3004 );
3005 }
3006
3007 #[test]
3008 fn infinite_lambda_keeps_range_penalty_block_finite_1379() {
3009 let p = 3usize;
3026 let rs_list = vec![array![[1.0, 0.0, 0.0]], array![[0.0, 1.0, 0.0]]];
3027 let canonical = canonical_from_roots(&rs_list, p);
3028 let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3029 .expect("precompute invariant");
3030
3031 let lambdas_inf = vec![f64::INFINITY, 3.0];
3032 let inf_result =
3033 stable_reparameterizationwith_invariant(&canonical, &lambdas_inf, p, &inv, None);
3034 assert!(
3035 inf_result.is_err(),
3036 "an infinite lambda must surface as an error, not be silently clamped (#1074)"
3037 );
3038
3039 let lambdas_big = vec![1e300_f64, 3.0];
3043 let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas_big, p, &inv, None)
3044 .expect("stable reparam at large-but-finite lambda");
3045 assert!(
3046 rep.s_transformed.iter().all(|v| v.is_finite()),
3047 "transformed penalty must be finite at large-but-finite lambda"
3048 );
3049 assert!(
3050 rep.qs.iter().all(|v| v.is_finite()),
3051 "reparam rotation must be finite at large-but-finite lambda"
3052 );
3053 assert!(
3054 rep.log_det.is_finite(),
3055 "penalty log-det must be finite at large-but-finite lambda"
3056 );
3057 assert!(
3058 rep.det1.iter().all(|v| v.is_finite()),
3059 "penalty log-det derivatives must be finite at large-but-finite lambda"
3060 );
3061 }
3062
3063 #[test]
3064 fn u_truncated_is_identitywhen_no_penalties() {
3065 let p = 4usize;
3066 let canonical: Vec<CanonicalPenalty> = Vec::new();
3067 let lambdas: Vec<f64> = Vec::new();
3068 let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3069 .expect("precompute invariant");
3070 let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
3071 .expect("stable reparam");
3072 assert_eq!(rep.u_truncated, Array2::<f64>::eye(p));
3073 }
3074
3075 #[test]
3076 fn dense_shrinkage_floor_skips_structurally_unpenalized_range_columns() {
3077 let p = 3usize;
3078 let canonical = canonical_from_roots(&[array![[1.0, 0.0, 0.0]]], p);
3079 let invariant = super::ReparamInvariant {
3080 split: super::SubspaceSplit {
3081 q_pen: array![[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]],
3082 q_null: array![[0.0], [0.0], [1.0]],
3083 },
3084 qs_base: Array2::eye(p),
3085 has_nonzero: true,
3086 max_balanced_eigenvalue: 1.0,
3087 };
3088
3089 let rep =
3090 stable_reparameterizationwith_invariant(&canonical, &[2.0], p, &invariant, Some(1e-6))
3091 .expect("stable reparameterization");
3092 assert!(rep.s_transformed[[0, 0]] > 2.0);
3093 assert!(
3094 rep.s_transformed[[1, 1]] <= 1e-11,
3095 "structurally unpenalized range coordinate received shrinkage ridge: {}",
3096 rep.s_transformed[[1, 1]]
3097 );
3098 }
3099
3100 #[test]
3101 fn kronecker_shrinkage_floor_preserves_joint_null_space() {
3102 let marginal_designs = vec![Array2::<f64>::eye(2), Array2::<f64>::eye(2)];
3103 let marginal_penalties = vec![
3104 array![[0.0, 0.0], [0.0, 2.0]],
3105 array![[0.0, 0.0], [0.0, 3.0]],
3106 ];
3107 let marginal_dims = vec![2usize, 2usize];
3108 let lambdas = vec![5.0, 7.0];
3109
3110 let rep = super::kronecker_reparameterization_engine(
3111 &marginal_designs,
3112 &marginal_penalties,
3113 &marginal_dims,
3114 &lambdas,
3115 false,
3116 Some(1e-6),
3117 )
3118 .expect("kronecker reparameterization");
3119 assert!(rep.penalty_shrinkage_ridge > 0.0);
3120
3121 let s = rep.materialize_s_transformed(&lambdas);
3122 assert!(
3123 s[[0, 0]].abs() <= 1e-14,
3124 "joint tensor null direction must remain unpenalized, got {}",
3125 s[[0, 0]]
3126 );
3127 assert!(s[[1, 1]] > lambdas[1] * 3.0);
3128 assert!(s[[2, 2]] > lambdas[0] * 2.0);
3129 assert!(s[[3, 3]] > lambdas[0] * 2.0 + lambdas[1] * 3.0);
3130
3131 let tensor_roots = vec![
3132 array![
3133 [0.0, 0.0, 2.0_f64.sqrt(), 0.0],
3134 [0.0, 0.0, 0.0, 2.0_f64.sqrt()]
3135 ],
3136 array![
3137 [0.0, 3.0_f64.sqrt(), 0.0, 0.0],
3138 [0.0, 0.0, 0.0, 3.0_f64.sqrt()]
3139 ],
3140 ];
3141 let dense = rep
3142 .materialize_dense_artifact_result(&tensor_roots, &lambdas, 4)
3143 .expect("dense artifact materialization");
3144 assert_eq!(dense.e_transformed.nrows(), 3);
3145 assert_eq!(dense.u_truncated.ncols(), 1);
3146 }
3147
3148 #[test]
3149 fn kronecker_memoized_invariant_is_bit_identical_to_unmemoized_engine() {
3150 let marginal_designs = vec![
3157 array![[1.0, 0.3, -0.2], [0.4, 1.0, 0.1], [-0.1, 0.2, 1.0]],
3158 array![[1.0, -0.5], [0.2, 1.0], [0.7, 0.3]],
3159 ];
3160 let marginal_penalties = vec![
3161 array![[2.0, -1.0, 0.0], [-1.0, 2.0, -1.0], [0.0, -1.0, 1.0]],
3162 array![[3.0, -1.5], [-1.5, 3.0]],
3163 ];
3164 let marginal_dims = vec![3usize, 2usize];
3165
3166 let invariant = super::KroneckerInvariantStructure::compute(
3167 &marginal_designs,
3168 &marginal_penalties,
3169 &marginal_dims,
3170 )
3171 .expect("invariant structure");
3172
3173 for lambdas in [
3174 vec![5.0, 7.0],
3175 vec![0.0, 7.0],
3176 vec![5.0, 0.0],
3177 vec![1e-3, 1e3],
3178 ] {
3179 for floor in [None, Some(1e-6)] {
3180 let unmemoized = super::kronecker_reparameterization_engine(
3181 &marginal_designs,
3182 &marginal_penalties,
3183 &marginal_dims,
3184 &lambdas,
3185 true,
3186 floor,
3187 )
3188 .expect("unmemoized engine");
3189 let memoized = super::kronecker_reparameterization_engine_with_invariant(
3190 &invariant,
3191 &marginal_dims,
3192 &lambdas,
3193 true,
3194 floor,
3195 )
3196 .expect("memoized engine");
3197
3198 assert_eq!(memoized.log_det.to_bits(), unmemoized.log_det.to_bits());
3199 assert_eq!(
3200 memoized.penalty_shrinkage_ridge.to_bits(),
3201 unmemoized.penalty_shrinkage_ridge.to_bits()
3202 );
3203 for (a, b) in memoized.det1.iter().zip(unmemoized.det1.iter()) {
3204 assert_eq!(a.to_bits(), b.to_bits());
3205 }
3206 for (a, b) in memoized.det2.iter().zip(unmemoized.det2.iter()) {
3207 assert_eq!(a.to_bits(), b.to_bits());
3208 }
3209 for (ma, ua) in memoized
3210 .reparameterized_marginals
3211 .iter()
3212 .zip(unmemoized.reparameterized_marginals.iter())
3213 {
3214 for (a, b) in ma.iter().zip(ua.iter()) {
3215 assert_eq!(a.to_bits(), b.to_bits());
3216 }
3217 }
3218 for (mq, uq) in memoized
3219 .marginal_qs
3220 .iter()
3221 .zip(unmemoized.marginal_qs.iter())
3222 {
3223 for (a, b) in mq.iter().zip(uq.iter()) {
3224 assert_eq!(a.to_bits(), b.to_bits());
3225 }
3226 }
3227 }
3228 }
3229 }
3230
3231 #[test]
3232 fn kronecker_double_penalty_shrinks_only_joint_null_space() {
3233 let marginal_designs = vec![Array2::<f64>::eye(2), Array2::<f64>::eye(2)];
3234 let marginal_penalties = vec![
3235 array![[0.0, 0.0], [0.0, 2.0]],
3236 array![[0.0, 0.0], [0.0, 3.0]],
3237 ];
3238 let marginal_dims = vec![2usize, 2usize];
3239 let lambdas = vec![5.0, 7.0, 11.0];
3240
3241 let rep = super::kronecker_reparameterization_engine(
3242 &marginal_designs,
3243 &marginal_penalties,
3244 &marginal_dims,
3245 &lambdas,
3246 true,
3247 None,
3248 )
3249 .expect("kronecker reparameterization");
3250
3251 let s = rep.materialize_s_transformed(&lambdas);
3252 let expected = [11.0, 21.0, 10.0, 31.0];
3253 for (idx, expected_diag) in expected.iter().copied().enumerate() {
3254 assert!(
3255 (s[[idx, idx]] - expected_diag).abs() <= 1e-12,
3256 "diagonal {idx} got {}, expected {expected_diag}",
3257 s[[idx, idx]]
3258 );
3259 }
3260
3261 let expected_logdet: f64 = expected.iter().map(|v| f64::ln(*v)).sum();
3262 assert!((rep.log_det - expected_logdet).abs() <= 1e-12);
3263 assert!(
3264 (rep.det1[2] - 1.0).abs() <= 1e-12,
3265 "double-penalty derivative must come only from the joint null mode, got {}",
3266 rep.det1[2]
3267 );
3268 assert!(rep.det2[[2, 2]].abs() <= 1e-12);
3269
3270 let tensor_roots = vec![
3271 array![
3272 [0.0, 0.0, 2.0_f64.sqrt(), 0.0],
3273 [0.0, 0.0, 0.0, 2.0_f64.sqrt()]
3274 ],
3275 array![
3276 [0.0, 3.0_f64.sqrt(), 0.0, 0.0],
3277 [0.0, 0.0, 0.0, 3.0_f64.sqrt()]
3278 ],
3279 ];
3280 let dense = rep
3281 .materialize_dense_artifact_result(&tensor_roots, &lambdas, 4)
3282 .expect("dense artifact materialization");
3283 for (idx, expected_diag) in expected.iter().copied().enumerate() {
3284 assert!(
3285 (dense.s_transformed[[idx, idx]] - expected_diag).abs() <= 1e-12,
3286 "dense artifact diagonal {idx} got {}, expected {expected_diag}",
3287 dense.s_transformed[[idx, idx]]
3288 );
3289 }
3290 }
3291
3292 #[test]
3293 fn transformed_penalty_is_diagonal_in_transformed_frame() {
3294 let p = 3usize;
3295 let inv_sqrt2 = 2.0_f64.sqrt().recip();
3296 let rs_list = vec![array![[inv_sqrt2, inv_sqrt2, 0.0]]];
3298 let canonical = canonical_from_roots(&rs_list, p);
3299 let lambdas = vec![4.0];
3300 let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3301 .expect("precompute invariant");
3302 let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
3303 .expect("stable reparam");
3304
3305 assert_eq!(rep.e_transformed.nrows(), 1);
3306 assert!(rep.e_transformed[[0, 0]].abs() > 0.0);
3307 assert!(rep.e_transformed[[0, 1]].abs() <= 1e-12);
3308 assert!(rep.e_transformed[[0, 2]].abs() <= 1e-12);
3309 let expected_det1 = 1.0_f64;
3312 assert!((rep.det1[0] - expected_det1).abs() <= 1e-12);
3313
3314 let s = rep.s_transformed;
3315 let mut max_offdiag = 0.0_f64;
3316 for i in 0..p {
3317 for j in 0..p {
3318 if i != j {
3319 max_offdiag = max_offdiag.max(s[[i, j]].abs());
3320 }
3321 }
3322 }
3323 assert!(
3324 max_offdiag <= 1e-10,
3325 "transformed penalty should be diagonal, max offdiag={max_offdiag}"
3326 );
3327 assert!(s[[1, 1]].abs() <= 1e-10);
3328 assert!(s[[2, 2]].abs() <= 1e-10);
3329 }
3330
3331 #[test]
3332 fn det1_matches_rank_for_single_full_rank_penalty() {
3333 let p = 2usize;
3334 let inv_sqrt2 = 2.0_f64.sqrt().recip();
3335 let q_t = [[inv_sqrt2, inv_sqrt2], [-inv_sqrt2, inv_sqrt2]];
3337 let rs = array![
3339 [3.0 * q_t[0][0], 3.0 * q_t[0][1]],
3340 [1.0 * q_t[1][0], 1.0 * q_t[1][1]]
3341 ];
3342 let rs_list = vec![rs];
3343 let canonical = canonical_from_roots(&rs_list, p);
3344 let lambdas = vec![5.0];
3345
3346 let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3347 .expect("precompute invariant");
3348 let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
3349 .expect("stable reparam");
3350
3351 assert_eq!(rep.e_transformed.nrows(), p);
3352 let det1 = rep.det1[0];
3353 let s_k_eigs = [9.0_f64, 1.0_f64];
3357 let lambda = 5.0_f64;
3358 let expected_det1: f64 = s_k_eigs.iter().map(|&d| lambda * d / (lambda * d)).sum();
3359 assert!(
3360 (det1 - expected_det1).abs() <= 1e-12,
3361 "expected det1={expected_det1}, got {det1}",
3362 );
3363
3364 let s = rep.s_transformed;
3365 assert!(s[[0, 1]].abs() <= 1e-10);
3366 assert!(s[[1, 0]].abs() <= 1e-10);
3367 assert!(s[[0, 0]] > 0.0);
3368 assert!(s[[1, 1]] > 0.0);
3369 }
3370
3371 #[test]
3372 fn kronecker_reparam_logdet_matches_dense() {
3373 let q1 = 3;
3376 let q2 = 4;
3377 let s1 = {
3378 let mut s = Array2::<f64>::zeros((q1, q1));
3379 s[[0, 0]] = 1.0;
3381 s[[0, 1]] = -1.0;
3382 s[[1, 0]] = -1.0;
3383 s[[1, 1]] = 2.0;
3384 s[[1, 2]] = -1.0;
3385 s[[2, 1]] = -1.0;
3386 s[[2, 2]] = 1.0;
3387 s
3388 };
3389 let s2 = {
3390 let mut s = Array2::<f64>::zeros((q2, q2));
3391 s[[0, 0]] = 1.0;
3392 s[[0, 1]] = -1.0;
3393 s[[1, 0]] = -1.0;
3394 s[[1, 1]] = 2.0;
3395 s[[1, 2]] = -1.0;
3396 s[[2, 1]] = -1.0;
3397 s[[2, 2]] = 2.0;
3398 s[[2, 3]] = -1.0;
3399 s[[3, 2]] = -1.0;
3400 s[[3, 3]] = 1.0;
3401 s
3402 };
3403
3404 let lambdas = [2.5, 1.3];
3405 let p = q1 * q2;
3407 let i1 = Array2::<f64>::eye(q1);
3408 let i2 = Array2::<f64>::eye(q2);
3409 let pen0 = kronecker_product(&s1, &i2);
3410 let pen1 = kronecker_product(&i1, &s2);
3411 let mut s_dense = Array2::<f64>::zeros((p, p));
3412 s_dense.scaled_add(lambdas[0], &pen0);
3413 s_dense.scaled_add(lambdas[1], &pen1);
3414
3415 let (evals_dense, _): (ndarray::Array1<f64>, ndarray::Array2<f64>) =
3417 s_dense.eigh(faer::Side::Lower).unwrap();
3418 let tol = 1e-12;
3419 let ref_logdet: f64 = evals_dense
3420 .iter()
3421 .filter(|&&v: &&f64| v > tol)
3422 .map(|&v: &f64| v.ln())
3423 .sum();
3424
3425 let marginal_designs = vec![
3427 Array2::<f64>::eye(q1), Array2::<f64>::eye(q2),
3429 ];
3430 let marginal_penalties = vec![s1, s2];
3431 let kron_result = super::kronecker_reparameterization_engine(
3432 &marginal_designs,
3433 &marginal_penalties,
3434 &[q1, q2],
3435 &lambdas,
3436 false,
3437 None,
3438 )
3439 .unwrap();
3440
3441 let diff = (kron_result.log_det - ref_logdet).abs();
3442 assert!(
3443 diff < 1e-8,
3444 "Kronecker logdet {:.10} vs dense {:.10}, diff={:.3e}",
3445 kron_result.log_det,
3446 ref_logdet,
3447 diff,
3448 );
3449
3450 let rhos: Vec<f64> = lambdas.iter().map(|&l| l.ln()).collect();
3452 let eps = 1e-5;
3453 for k in 0..2 {
3454 let mut rho_plus = rhos.clone();
3455 rho_plus[k] += eps;
3456 let mut rho_minus = rhos.clone();
3457 rho_minus[k] -= eps;
3458 let lam_plus: Vec<f64> = rho_plus.iter().map(|&r| r.exp()).collect();
3459 let lam_minus: Vec<f64> = rho_minus.iter().map(|&r| r.exp()).collect();
3460 let result_plus = super::kronecker_reparameterization_engine(
3461 &marginal_designs,
3462 &marginal_penalties,
3463 &[q1, q2],
3464 &lam_plus,
3465 false,
3466 None,
3467 )
3468 .unwrap();
3469 let result_minus = super::kronecker_reparameterization_engine(
3470 &marginal_designs,
3471 &marginal_penalties,
3472 &[q1, q2],
3473 &lam_minus,
3474 false,
3475 None,
3476 )
3477 .unwrap();
3478 let fd_deriv = (result_plus.log_det - result_minus.log_det) / (2.0 * eps);
3479 let analytic_deriv = kron_result.det1[k];
3480 let rel_err = if analytic_deriv.abs() > 1e-10 {
3481 (fd_deriv - analytic_deriv).abs() / analytic_deriv.abs()
3482 } else {
3483 (fd_deriv - analytic_deriv).abs()
3484 };
3485 assert!(
3486 rel_err < 1e-4,
3487 "det1[{k}] mismatch: analytic={:.8}, fd={:.8}, rel_err={:.3e}",
3488 analytic_deriv,
3489 fd_deriv,
3490 rel_err,
3491 );
3492 }
3493 }
3494
3495 #[test]
3496 fn classify_strict_rejects_nan_eigenvalue() {
3497 let mut eigs = [1.0, f64::NAN, 0.5];
3498 match classify_eigenvalues_strict(&mut eigs, "test_nan") {
3499 Err(EstimationError::PenaltySpectrumNonFinite {
3500 context,
3501 index,
3502 value,
3503 }) => {
3504 assert_eq!(context, "test_nan");
3505 assert_eq!(index, 1);
3506 assert!(value.is_nan());
3507 }
3508 other => panic!("expected PenaltySpectrumNonFinite, got {:?}", other),
3509 }
3510 }
3511
3512 #[test]
3513 fn classify_strict_rejects_inf_eigenvalue() {
3514 let mut eigs = [1.0, 0.5, f64::INFINITY];
3515 match classify_eigenvalues_strict(&mut eigs, "test_inf") {
3516 Err(EstimationError::PenaltySpectrumNonFinite { index, value, .. }) => {
3517 assert_eq!(index, 2);
3518 assert!(value.is_infinite());
3519 }
3520 other => panic!("expected PenaltySpectrumNonFinite, got {:?}", other),
3521 }
3522 }
3523
3524 #[test]
3525 fn classify_strict_rejects_materially_indefinite() {
3526 let mut eigs = [1.0, -1e-2, 0.5];
3528 match classify_eigenvalues_strict(&mut eigs, "test_indef") {
3529 Err(EstimationError::PenaltySpectrumIndefinite {
3530 context,
3531 index,
3532 value,
3533 ..
3534 }) => {
3535 assert_eq!(context, "test_indef");
3536 assert_eq!(index, 1);
3537 assert!((value + 1e-2).abs() <= 1e-15);
3538 }
3539 other => panic!("expected PenaltySpectrumIndefinite, got {:?}", other),
3540 }
3541 }
3542
3543 #[test]
3544 fn classify_strict_accepts_roundoff_negative() {
3545 let scale = 1.0_f64;
3547 let roundoff = -1e-16 * scale;
3548 let mut eigs = [scale, 0.5 * scale, roundoff, 0.25 * scale];
3549 classify_eigenvalues_strict(&mut eigs, "test_roundoff").expect("roundoff must classify");
3550 assert_eq!(eigs[2], 0.0);
3552 assert!(eigs[0] > 0.0 && eigs[1] > 0.0 && eigs[3] > 0.0);
3554 }
3555
3556 #[test]
3557 fn classify_strict_accepts_extreme_lambda_assembly_noise_1619() {
3558 let scale = 8.509e12_f64;
3565 let noise = -6.546e2_f64;
3567 assert!(
3568 (noise.abs() / scale) < 1.0e-10,
3569 "fixture must reproduce the ~1e-11-relative noise from #1619"
3570 );
3571 let mut eigs = vec![scale, 0.5 * scale, noise, 0.1 * scale];
3572 classify_eigenvalues_strict(&mut eigs, "range penalty block")
3573 .expect("a ~1e-11-relative roundoff-negative eigenvalue must be accepted (#1619)");
3574 assert_eq!(eigs[2], 0.0);
3576 assert!(eigs[0] > 0.0 && eigs[1] > 0.0 && eigs[3] > 0.0);
3578 }
3579
3580 #[test]
3581 fn classify_strict_snaps_subtol_positive_to_zero() {
3582 let scale = 10.0_f64;
3585 let subtol = 1e-15 * scale;
3586 let mut eigs = [scale, subtol];
3587 classify_eigenvalues_strict(&mut eigs, "test_sub_pos").expect("sub-tol positive ok");
3588 assert_eq!(eigs[1], 0.0);
3589 }
3590
3591 fn canonical_from_local(
3595 local: Array2<f64>,
3596 col_range: std::ops::Range<usize>,
3597 total_dim: usize,
3598 ) -> CanonicalPenalty {
3599 let block_dim = local.nrows();
3600 let root = Array2::<f64>::zeros((0, block_dim));
3602 CanonicalPenalty {
3603 root,
3604 col_range,
3605 total_dim,
3606 nullity: 0,
3607 local,
3608 prior_mean: Array1::zeros(block_dim),
3609 positive_eigenvalues: Vec::new(),
3610 op: None,
3611 }
3612 }
3613
3614 #[test]
3615 fn report_penalty_pair_redundancy_detects_identical_pair() {
3616 let s0 = ndarray::array![[2.0, 0.5, 0.0], [0.5, 1.0, 0.25], [0.0, 0.25, 1.5],];
3618 let s_shared = ndarray::array![[1.0, -0.5, 0.0], [-0.5, 2.0, -0.5], [0.0, -0.5, 1.0],];
3621
3622 let bundle = vec![
3623 canonical_from_local(s0, 0..3, 3),
3624 canonical_from_local(s_shared.clone(), 0..3, 3),
3625 canonical_from_local(s_shared, 0..3, 3),
3626 ];
3627
3628 let redundant = report_penalty_pair_redundancy(&bundle);
3629
3630 assert_eq!(
3633 redundant.len(),
3634 1,
3635 "expected exactly one redundant pair, got {:?}",
3636 redundant
3637 );
3638 let (i, j, cos) = redundant[0];
3639 assert_eq!((i, j), (1, 2));
3640 assert!(
3641 cos > 1.0 - 1e-12,
3642 "cosine for identical penalties should be ~1.0, got {cos}"
3643 );
3644 }
3645
3646 #[test]
3647 fn report_penalty_pair_redundancy_skips_different_col_ranges() {
3648 let s = ndarray::array![[1.0, 0.0], [0.0, 1.0]];
3652 let bundle = vec![
3653 canonical_from_local(s.clone(), 0..2, 4),
3654 canonical_from_local(s, 2..4, 4),
3655 ];
3656 let redundant = report_penalty_pair_redundancy(&bundle);
3657 assert!(
3658 redundant.is_empty(),
3659 "different col_ranges must not be flagged"
3660 );
3661 }
3662}