1use crate::basis::analyze_penalty_block;
2use crate::EstimationError;
3use crate::smooth::PenaltyStructureHint;
4use faer::linalg::matmul::matmul;
5use faer::{Accum, Mat, MatRef, Par, Side};
6use gam_linalg::faer_ndarray::{FaerEigh, FaerLinalgError, FaerSvd};
7use gam_linalg::matrix::symmetrize_in_place;
8use gam_linalg::utils::KahanSum;
9use ndarray::{Array1, Array2, ArrayView1, ArrayViewMut2, Axis, s};
10use rayon::iter::{
11 IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
12};
13use std::collections::{BTreeMap, HashSet};
14use std::ops::Range;
15use std::sync::Arc;
16
/// Storage layouts for one square penalty block.
#[derive(Clone)]
pub enum PenaltyRepresentation {
    /// The block stored explicitly as a dense matrix.
    Dense(Array2<f64>),
    /// The block stored by diagonals: `bands[k]` holds the entries of the
    /// diagonal at `offsets[k]`.
    ///
    /// As expanded by `to_block_dense`, entry `idx` of a band with offset
    /// `off >= 0` lands at `(idx, idx + off)` and, for `off < 0`, at
    /// `(idx + |off|, idx)`; every entry is mirrored across the main diagonal.
    Banded {
        bands: Vec<Array1<f64>>,
        offsets: Vec<i32>,
    },
    /// The block stored as the two factors of a Kronecker product
    /// `left ⊗ right`.
    Kronecker {
        left: Array2<f64>,
        right: Array2<f64>,
    },
}
34
35impl PenaltyRepresentation {
36 pub fn block_dimension(&self) -> usize {
38 match self {
39 PenaltyRepresentation::Dense(matrix) => matrix.nrows(),
40 PenaltyRepresentation::Banded { bands, offsets } => {
41 let mut dim = 0usize;
42 for (band, &offset) in bands.iter().zip(offsets.iter()) {
43 let len = band.len();
44 let extent = if offset >= 0 {
45 len + offset as usize
46 } else {
47 len + (-offset) as usize
48 };
49 dim = dim.max(extent);
50 }
51 dim
52 }
53 PenaltyRepresentation::Kronecker { left, right } => left.nrows() * right.nrows(),
54 }
55 }
56
57 pub fn to_block_dense(&self) -> Array2<f64> {
60 match self {
61 PenaltyRepresentation::Dense(matrix) => matrix.clone(),
62 PenaltyRepresentation::Banded { bands, offsets } => {
63 let dim = self.block_dimension();
64 let mut dense = Array2::zeros((dim, dim));
65 let positive_offsets: HashSet<usize> = offsets
66 .iter()
67 .filter_map(|&off| (off >= 0).then_some(off as usize))
68 .collect();
69 for (band, &offset) in bands.iter().zip(offsets.iter()) {
70 let off = offset.unsigned_abs() as usize;
71 if offset < 0 && positive_offsets.contains(&off) {
72 continue;
73 }
74 for (idx, &value) in band.iter().enumerate() {
75 let (i, j) = if offset >= 0 {
76 (idx, idx + off)
77 } else {
78 (idx + off, idx)
79 };
80 if i >= dim || j >= dim {
81 continue;
82 }
83 dense[[i, j]] = value;
84 dense[[j, i]] = value;
85 }
86 }
87 dense
88 }
89 PenaltyRepresentation::Kronecker { left, right } => {
90 let (lrows, l_cols) = left.dim();
91 let (rrows, r_cols) = right.dim();
92 let mut result = Array2::zeros((lrows * rrows, l_cols * r_cols));
93 for i in 0..lrows {
94 for j in 0..l_cols {
95 let scale = left[(i, j)];
96 if scale == 0.0 {
97 continue;
98 }
99 let mut block = result.slice_mut(s![
100 i * rrows..(i + 1) * rrows,
101 j * r_cols..(j + 1) * r_cols
102 ]);
103 block.assign(&(right * scale));
104 }
105 }
106 result
107 }
108 }
109 }
110}
111
/// A penalty block together with the coefficient range it occupies.
#[derive(Clone)]
pub struct PenaltyMatrix {
    /// Rows/columns of the full matrix that the block is written into
    /// (see `to_dense`).
    pub col_range: Range<usize>,
    /// How the block's entries are stored.
    pub representation: PenaltyRepresentation,
}
117
118impl PenaltyMatrix {
119 fn accumulate_into(&self, mut dest: ArrayViewMut2<'_, f64>, weight: f64) {
120 if weight == 0.0 {
121 return;
122 }
123 match &self.representation {
124 PenaltyRepresentation::Dense(block) => {
125 dest.scaled_add(weight, block);
126 }
127 PenaltyRepresentation::Banded { bands, offsets } => {
128 let positive_offsets: HashSet<usize> = offsets
129 .iter()
130 .filter_map(|&off| (off >= 0).then_some(off as usize))
131 .collect();
132 for (band, &offset) in bands.iter().zip(offsets.iter()) {
133 let off = offset.unsigned_abs() as usize;
134 if offset < 0 && positive_offsets.contains(&off) {
135 continue;
136 }
137 for (idx, &value) in band.iter().enumerate() {
138 let (i, j) = if offset >= 0 {
139 (idx, idx + off)
140 } else {
141 (idx + off, idx)
142 };
143 let Some(entry_ij) = dest.get_mut((i, j)) else {
144 continue;
145 };
146 *entry_ij += weight * value;
147 if i != j
148 && let Some(entry_ji) = dest.get_mut((j, i))
149 {
150 *entry_ji += weight * value;
151 }
152 }
153 }
154 }
155 PenaltyRepresentation::Kronecker { left, right } => {
156 let (lrows, l_cols) = left.dim();
157 let (rrows, r_cols) = right.dim();
158 for i in 0..lrows {
159 for j in 0..l_cols {
160 let scale = left[(i, j)] * weight;
161 if scale == 0.0 {
162 continue;
163 }
164 let mut block = dest.slice_mut(s![
165 i * rrows..(i + 1) * rrows,
166 j * r_cols..(j + 1) * r_cols
167 ]);
168 block.scaled_add(scale, right);
169 }
170 }
171 }
172 }
173 }
174
175 pub fn to_dense(&self, total_dim: usize) -> Array2<f64> {
176 let mut dense = Array2::<f64>::zeros((total_dim, total_dim));
177 self.accumulate_into(
178 dense.slice_mut(s![self.col_range.clone(), self.col_range.clone()]),
179 1.0,
180 );
181 dense
182 }
183}
184
185pub(crate) fn array_to_faer(array: &Array2<f64>) -> Mat<f64> {
186 let (rows, cols) = array.dim();
187 Mat::from_fn(rows, cols, |i, j| array[[i, j]])
188}
189
190pub(crate) fn mat_to_array(mat: &Mat<f64>) -> Array2<f64> {
191 let mut out = Array2::<f64>::zeros((mat.nrows(), mat.ncols()));
192 for i in 0..mat.nrows() {
193 for j in 0..mat.ncols() {
194 out[[i, j]] = mat[(i, j)];
195 }
196 }
197 out
198}
199
200fn mat_max_abs_element(matrix: MatRef<'_, f64>) -> f64 {
201 let (rows, cols) = matrix.shape();
202 let mut maxval = 0.0_f64;
203 for i in 0..rows {
204 for j in 0..cols {
205 let val = matrix[(i, j)];
206 if val.is_finite() {
207 maxval = maxval.max(val.abs());
208 }
209 }
210 }
211 maxval
212}
213
214fn sanitize_symmetric_faer(matrix: &Mat<f64>) -> Mat<f64> {
215 let (rows, cols) = matrix.as_ref().shape();
216 assert_eq!(rows, cols, "Matrix must be square for sanitization");
217
218 let mut sanitized = matrix.clone();
219
220 for i in 0..rows {
221 let diag = sanitized[(i, i)];
222 if !diag.is_finite() {
223 sanitized[(i, i)] = 0.0;
224 }
225 for j in (i + 1)..cols {
226 let mut upper = sanitized[(i, j)];
227 let mut lower = sanitized[(j, i)];
228 if !upper.is_finite() {
229 upper = 0.0;
230 }
231 if !lower.is_finite() {
232 lower = 0.0;
233 }
234 let avg = 0.5 * (upper + lower);
235 sanitized[(i, j)] = avg;
236 sanitized[(j, i)] = avg;
237 }
238 }
239
240 let scale = mat_max_abs_element(sanitized.as_ref());
241 let tiny = (scale * 1e-14).max(1e-30);
242 for i in 0..rows {
243 for j in 0..cols {
244 let val = sanitized[(i, j)];
245 if !val.is_finite() {
246 sanitized[(i, j)] = 0.0;
247 } else if val.abs() < tiny {
248 sanitized[(i, j)] = 0.0;
249 }
250 }
251 }
252
253 sanitized
254}
255
256fn penalty_from_root_faer(root: &Mat<f64>) -> Mat<f64> {
257 let cols = root.ncols();
258 let mut full = Mat::<f64>::zeros(cols, cols);
259 let root_ref = root.as_ref();
260 let root_t = root_ref.transpose();
261 matmul(
262 full.as_mut(),
263 Accum::Replace,
264 root_t,
265 root_ref,
266 1.0,
267 Par::Seq,
268 );
269 sanitize_symmetric_faer(&full)
270}
271
272fn symmetrize_faer_matrix_in_place(matrix: &mut Mat<f64>) {
273 let n = matrix.nrows().min(matrix.ncols());
274 for i in 0..n {
275 for j in 0..i {
276 let avg = 0.5 * (matrix[(i, j)] + matrix[(j, i)]);
277 matrix[(i, j)] = avg;
278 matrix[(j, i)] = avg;
279 }
280 }
281}
282
283fn orthogonal_similarity_transform_faer(
284 matrix: &Mat<f64>,
285 block_dim: usize,
286 orthogonal: &Mat<f64>,
287) -> Mat<f64> {
288 let matrix_block = matrix.as_ref().submatrix(0, 0, block_dim, block_dim);
289 let cols = orthogonal.ncols();
290 let mut temp = Mat::<f64>::zeros(block_dim, cols);
291 matmul(
292 temp.as_mut(),
293 Accum::Replace,
294 matrix_block,
295 orthogonal.as_ref(),
296 1.0,
297 Par::Seq,
298 );
299 let mut rotated = Mat::<f64>::zeros(cols, cols);
300 matmul(
301 rotated.as_mut(),
302 Accum::Replace,
303 orthogonal.transpose(),
304 temp.as_ref(),
305 1.0,
306 Par::Seq,
307 );
308 symmetrize_faer_matrix_in_place(&mut rotated);
309 rotated
310}
311
312fn trace_penalty_in_orthogonal_basis(
313 matrix: &Mat<f64>,
314 block_dim: usize,
315 orthogonal: &Mat<f64>,
316 rotated_eigenvalues: &[f64],
317 delta: f64,
318) -> f64 {
319 let matrix_block = matrix.as_ref().submatrix(0, 0, block_dim, block_dim);
320 let cols = orthogonal.ncols();
321 assert!(rotated_eigenvalues.len() >= cols);
322 let mut projected = Mat::<f64>::zeros(block_dim, cols);
323 matmul(
324 projected.as_mut(),
325 Accum::Replace,
326 matrix_block,
327 orthogonal.as_ref(),
328 1.0,
329 Par::Seq,
330 );
331 let mut trace = KahanSum::default();
332 for l in 0..cols {
333 let mut diag_ll = KahanSum::default();
334 for i in 0..block_dim {
335 diag_ll.add(orthogonal[(i, l)] * projected[(i, l)]);
336 }
337 trace.add(diag_ll.sum() / (rotated_eigenvalues[l] + delta));
338 }
339 trace.sum()
340}
341
342pub fn trace_reduced_penalty_covariance(
343 reduced_penalty: &Array2<f64>,
344 covariance_basis: &Array2<f64>,
345) -> f64 {
346 assert_eq!(
347 reduced_penalty.dim(),
348 covariance_basis.dim(),
349 "trace_reduced_penalty_covariance dimension mismatch"
350 );
351 let r = covariance_basis.nrows();
352 let mut trace = KahanSum::default();
353 for i in 0..r {
354 for j in 0..r {
355 trace.add(covariance_basis[[i, j]] * reduced_penalty[[j, i]]);
356 }
357 }
358 trace.sum()
359}
360
361pub fn trace_penalty_covariance_in_orthogonal_basis(
362 matrix: &Array2<f64>,
363 orthogonal: &Array2<f64>,
364 covariance_basis: &Array2<f64>,
365) -> f64 {
366 let reduced = gam_linalg::faer_ndarray::fast_ab(
367 &gam_linalg::faer_ndarray::fast_atb(orthogonal, matrix),
368 orthogonal,
369 );
370 trace_reduced_penalty_covariance(&reduced, covariance_basis)
371}
372
373fn classify_eigenvalues_strict(
387 eigenvalues: &mut [f64],
388 context: &str,
389) -> Result<(), EstimationError> {
390 const C_EPS_P_FACTOR: f64 = 64.0;
391 let p = eigenvalues.len();
392
393 let mut scale = 0.0_f64;
394 for (idx, &val) in eigenvalues.iter().enumerate() {
395 if !val.is_finite() {
396 return Err(EstimationError::PenaltySpectrumNonFinite {
397 context: context.to_string(),
398 index: idx,
399 value: val,
400 });
401 }
402 scale = scale.max(val.abs());
403 }
404
405 let tolerance =
410 (C_EPS_P_FACTOR * f64::EPSILON * (p.max(1) as f64) * scale).max(f64::MIN_POSITIVE);
411
412 for (idx, val) in eigenvalues.iter_mut().enumerate() {
413 if val.abs() <= tolerance {
414 *val = 0.0;
415 } else if *val < 0.0 {
416 return Err(EstimationError::PenaltySpectrumIndefinite {
417 context: context.to_string(),
418 index: idx,
419 value: *val,
420 tolerance,
421 scale,
422 });
423 }
424 }
425 Ok(())
426}
427
428fn robust_eighwith_policy<M, V, E, Validate, Sanitize, EigCall, MapErr>(
429 matrix: &M,
430 context: &str,
431 validate_input: Validate,
432 sanitize: Sanitize,
433 mut eig_call: EigCall,
434 map_error: MapErr,
435) -> Result<(Vec<f64>, V), EstimationError>
436where
437 Validate: Fn(&M, &str) -> Result<(), EstimationError>,
438 Sanitize: Fn(&M) -> M,
439 EigCall: FnMut(&M) -> Result<(Vec<f64>, V), E>,
440 MapErr: Fn(E, &str) -> EstimationError,
441{
442 validate_input(matrix, context)?;
443
444 let candidate = sanitize(matrix);
450 match eig_call(&candidate) {
451 Ok((mut eigenvalues, eigenvectors)) => {
452 classify_eigenvalues_strict(&mut eigenvalues, context)?;
453 Ok((eigenvalues, eigenvectors))
454 }
455 Err(err) => Err(map_error(err, context)),
456 }
457}
458
/// Symmetric eigendecomposition of a faer matrix, run through
/// `robust_eighwith_policy`.
///
/// The four closures supply, in order: an input check that rejects any
/// non-finite entry, the sanitizer (`sanitize_symmetric_faer`), the call to
/// faer's `self_adjoint_eigen(side)`, and the error conversion.
///
/// Returns the eigenvalues (after `classify_eigenvalues_strict`, applied by
/// the policy function) together with an owned copy of the eigenvectors.
///
/// # Errors
/// * Whatever `bail_invalid_estim!` produces when an entry is non-finite
///   (NOTE(review): the macro is defined elsewhere; presumably an early
///   `Err` return — confirm).
/// * `EstimationError::EigendecompositionFailed` when faer's solver fails.
/// * Any error from eigenvalue classification.
pub(crate) fn robust_eigh_faer(
    matrix: &Mat<f64>,
    side: Side,
    context: &str,
) -> Result<(Vec<f64>, Mat<f64>), EstimationError> {
    robust_eighwith_policy(
        matrix,
        context,
        // Validation: fail on the first non-finite entry, reporting the largest
        // finite magnitude for diagnostics.
        |mat, ctx| {
            let (rows, cols) = mat.as_ref().shape();
            for i in 0..rows {
                for j in 0..cols {
                    let val = mat[(i, j)];
                    if !val.is_finite() {
                        let max_abs = mat_max_abs_element(mat.as_ref());
                        crate::bail_invalid_estim!(
                            "{} contains non-finite entries (max finite magnitude {:.3e})",
                            ctx,
                            max_abs
                        );
                    }
                }
            }
            Ok(())
        },
        sanitize_symmetric_faer,
        // Decomposition: copy faer's eigenvalues (`S`) and eigenvectors (`U`)
        // into owned containers.
        |candidate| {
            let eig = candidate.as_ref().self_adjoint_eigen(side)?;
            let diag = eig.S();
            let mut eigenvalues = Vec::with_capacity(diag.dim());
            for idx in 0..diag.dim() {
                eigenvalues.push(diag[idx]);
            }

            let vectors_ref = eig.U();
            let mut eigenvectors = Mat::<f64>::zeros(vectors_ref.nrows(), vectors_ref.ncols());
            for i in 0..vectors_ref.nrows() {
                for j in 0..vectors_ref.ncols() {
                    eigenvectors[(i, j)] = vectors_ref[(i, j)];
                }
            }
            Ok((eigenvalues, eigenvectors))
        },
        // Error mapping: wrap the solver error; the context string is unused here.
        |err, _ctx| {
            EstimationError::EigendecompositionFailed(FaerLinalgError::SelfAdjointEigen(err))
        },
    )
}
507
508fn robust_eigh(
509 matrix: &Array2<f64>,
510 side: Side,
511 context: &str,
512) -> Result<(Array1<f64>, Array2<f64>), EstimationError> {
513 let matrix_faer = array_to_faer(matrix);
514 let (eigenvalues, eigenvectors) = robust_eigh_faer(&matrix_faer, side, context)?;
515 Ok((Array1::from_vec(eigenvalues), mat_to_array(&eigenvectors)))
516}
517
518pub(crate) fn kronecker_marginal_eigensystems(
519 marginal_penalties: &[Array2<f64>],
520 context: &str,
521) -> Result<Vec<(Array1<f64>, Array2<f64>)>, EstimationError> {
522 let mut eigensystems = Vec::with_capacity(marginal_penalties.len());
523 for (k, penalty) in marginal_penalties.iter().enumerate() {
524 eigensystems.push(robust_eigh(
525 penalty,
526 Side::Lower,
527 &format!("{context} marginal {k}"),
528 )?);
529 }
530 Ok(eigensystems)
531}
532
/// Diagnostics produced by `assess_subspace_leakage`.
#[derive(Debug, Clone, Copy)]
struct SubspaceLeakageMetrics {
    /// Largest, over all penalties, sum of squared root entries lying in
    /// columns at or beyond the structural rank.
    max_abs_sq: f64,
    /// Largest, over all penalties, ratio of that sum to the sum of all
    /// squared entries considered (0 for an all-zero root).
    max_rel_sq: f64,
    /// Index of the penalty attaining `max_rel_sq` (0 when no ratio is positive).
    worst_penalty: usize,
    /// Largest absolute dot product between one of the first
    /// `structural_rank` columns of `qs` and one of the remaining columns.
    max_cross_gram_abs: f64,
}
540
541fn assess_subspace_leakage(
542 qs: &Mat<f64>,
543 rs_transformed: &[Mat<f64>],
544 structural_rank: usize,
545 p: usize,
546) -> SubspaceLeakageMetrics {
547 let mut max_abs_sq = 0.0_f64;
548 let mut max_rel_sq = 0.0_f64;
549 let mut worst_penalty = 0usize;
550
551 for (k, rs) in rs_transformed.iter().enumerate() {
552 let rows = rs.nrows();
553 let cols = rs.ncols().min(p);
554 let null_start = structural_rank.min(cols);
555 let mut abs_sq = 0.0_f64;
556 let mut total_sq = 0.0_f64;
557 for i in 0..rows {
558 for j in 0..cols {
559 let v = rs[(i, j)];
560 let vv = v * v;
561 total_sq += vv;
562 if j >= null_start {
563 abs_sq += vv;
564 }
565 }
566 }
567 let rel_sq = if total_sq > 0.0 {
568 abs_sq / total_sq
569 } else {
570 0.0
571 };
572 if rel_sq > max_rel_sq {
573 max_rel_sq = rel_sq;
574 worst_penalty = k;
575 }
576 max_abs_sq = max_abs_sq.max(abs_sq);
577 }
578
579 let mut max_cross_gram_abs = 0.0_f64;
580 let null_count = p.saturating_sub(structural_rank);
581 if structural_rank > 0 && null_count > 0 {
582 for i in 0..structural_rank {
583 for j in 0..null_count {
584 let qn_col = structural_rank + j;
585 let mut dot = 0.0_f64;
586 for r in 0..p {
587 dot += qs[(r, i)] * qs[(r, qn_col)];
588 }
589 max_cross_gram_abs = max_cross_gram_abs.max(dot.abs());
590 }
591 }
592 }
593
594 SubspaceLeakageMetrics {
595 max_abs_sq,
596 max_rel_sq,
597 worst_penalty,
598 max_cross_gram_abs,
599 }
600}
601
602fn compose_qs_from_split(q_pen: &Mat<f64>, q_null: &Mat<f64>, p: usize) -> Mat<f64> {
603 let rank = q_pen.ncols();
604 let null_count = q_null.ncols();
605 let mut qs = Mat::<f64>::zeros(p, p);
606 for i in 0..p {
607 for j in 0..rank {
608 qs[(i, j)] = q_pen[(i, j)];
609 }
610 for j in 0..null_count {
611 qs[(i, rank + j)] = q_null[(i, j)];
612 }
613 }
614 qs
615}
616
617pub fn kronecker_product(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
621 let (arows, a_cols) = a.dim();
622 let (brows, b_cols) = b.dim();
623 if arows == 0 || a_cols == 0 || brows == 0 || b_cols == 0 {
624 return Array2::zeros((arows * brows, a_cols * b_cols));
625 }
626 let mut result = Array2::zeros((arows * brows, a_cols * b_cols));
627
628 result
629 .axis_chunks_iter_mut(Axis(0), brows)
630 .into_par_iter()
631 .enumerate()
632 .for_each(|(i, mut row_block)| {
633 let arow = a.row(i);
634 let col_chunks = row_block.axis_chunks_iter_mut(Axis(1), b_cols);
635 for (j, mut block) in col_chunks.into_iter().enumerate() {
636 let aval = arow[j];
637 if aval == 0.0 {
638 continue;
639 }
640 for (dest, &src) in block.iter_mut().zip(b.iter()) {
641 *dest = aval * src;
642 }
643 }
644 });
645
646 result
647}
648
/// Bundle of quantities describing a penalty reparameterization.
///
/// NOTE(review): the code that fills this struct is outside the visible
/// source; the field notes below are inferred from names and types only and
/// should be confirmed against the producer.
#[derive(Clone)]
pub struct ReparamResult {
    /// Presumably the combined penalty matrix in the transformed basis — confirm.
    pub s_transformed: Array2<f64>,
    /// Presumably a log-determinant term of the penalty — confirm.
    pub log_det: f64,
    /// Presumably first derivatives of `log_det`, one entry per penalty — confirm.
    pub det1: Array1<f64>,
    /// Presumably the basis-change matrix — confirm.
    pub qs: Array2<f64>,
    /// Canonical penalties; presumably expressed in the transformed basis — confirm.
    pub canonical_transformed: Vec<CanonicalPenalty>,
    /// Presumably a penalty root in the transformed basis — confirm.
    pub e_transformed: Array2<f64>,
    /// Presumably a truncated set of basis vectors — confirm.
    pub u_truncated: Array2<f64>,
    /// Presumably a ridge magnitude used for penalty shrinkage — confirm.
    pub penalty_shrinkage_ridge: f64,
}
684
/// Spectral summary of one Kronecker factor, as built by
/// `decompose_kronecker_factors`.
struct KroneckerFactorDecomp {
    /// `rank × dim` root: each row is an eigenvector scaled by the square root
    /// of its eigenvalue (the identity matrix for an identity factor).
    root: Array2<f64>,
    /// Eigenvalues matching the rows of `root`, in the same order.
    positive_eigenvalues: Vec<f64>,
    /// Number of rows of `root`.
    rank: usize,
    /// Side length of the (square) factor.
    dim: usize,
}
696
697fn decompose_kronecker_factors(
700 factors: &[Array2<f64>],
701 context: &str,
702) -> Result<Option<Vec<KroneckerFactorDecomp>>, EstimationError> {
703 let mut decomps = Vec::with_capacity(factors.len());
704 for (j, factor) in factors.iter().enumerate() {
705 let q_j = factor.nrows();
706 if q_j != factor.ncols() {
707 crate::bail_invalid_estim!(
708 "{context}: Kronecker factor {j} must be square, got {}x{}",
709 factor.nrows(),
710 factor.ncols()
711 );
712 }
713 let is_identity = {
714 let mut is_id = true;
715 'outer: for r in 0..q_j {
716 for c in 0..q_j {
717 let expected = if r == c { 1.0 } else { 0.0 };
718 if (factor[[r, c]] - expected).abs() > 1e-12 {
719 is_id = false;
720 break 'outer;
721 }
722 }
723 }
724 is_id
725 };
726 if is_identity {
727 decomps.push(KroneckerFactorDecomp {
728 root: Array2::eye(q_j),
729 positive_eigenvalues: vec![1.0; q_j],
730 rank: q_j,
731 dim: q_j,
732 });
733 continue;
734 }
735 let analysis = analyze_penalty_block(factor).map_err(|err| {
736 EstimationError::InvalidInput(format!(
737 "{context}: Kronecker factor {j} eigendecomp failed: {err}"
738 ))
739 })?;
740 if analysis.rank == 0 {
741 return Ok(None);
742 }
743 let factor_classes =
747 crate::basis::SpectralClassification::new(&analysis.eigenvalues, analysis.tol);
748 let mut root_j = Array2::zeros((analysis.rank, q_j));
749 let mut pos_eigs = Vec::with_capacity(analysis.rank);
750 for (row_idx, &i) in factor_classes.range_idx.iter().enumerate() {
751 let eigenval = analysis.eigenvalues[i];
752 let sqrt_ev = eigenval.sqrt();
753 let evec = analysis.eigenvectors.column(i);
754 for (col, &v) in evec.iter().enumerate() {
755 root_j[[row_idx, col]] = sqrt_ev * v;
756 }
757 pos_eigs.push(eigenval);
758 }
759 decomps.push(KroneckerFactorDecomp {
760 root: root_j,
761 positive_eigenvalues: pos_eigs,
762 rank: analysis.rank,
763 dim: q_j,
764 });
765 }
766 Ok(Some(decomps))
767}
768
769fn assemble_kronecker_root_local(decomps: &[KroneckerFactorDecomp]) -> Array2<f64> {
771 let mut kron_root = decomps[0].root.clone();
772 for fr in &decomps[1..] {
773 let (r1, c1) = kron_root.dim();
774 let (r2, c2) = (fr.rank, fr.dim);
775 let mut new_root = Array2::zeros((r1 * r2, c1 * c2));
776 for i1 in 0..r1 {
777 for i2 in 0..r2 {
778 for j1 in 0..c1 {
779 for j2 in 0..c2 {
780 new_root[[i1 * r2 + i2, j1 * c2 + j2]] =
781 kron_root[[i1, j1]] * fr.root[[i2, j2]];
782 }
783 }
784 }
785 }
786 kron_root = new_root;
787 }
788 kron_root
789}
790
791fn kronecker_eigenvalues(decomps: &[KroneckerFactorDecomp], block_dim: usize) -> (Vec<f64>, usize) {
793 let mut kron_eigs = decomps[0].positive_eigenvalues.clone();
794 for fd in &decomps[1..] {
795 let mut new_eigs = Vec::with_capacity(kron_eigs.len() * fd.positive_eigenvalues.len());
796 for &a in &kron_eigs {
797 for &b in &fd.positive_eigenvalues {
798 new_eigs.push(a * b);
799 }
800 }
801 kron_eigs = new_eigs;
802 }
803 let max_ev = kron_eigs.iter().copied().fold(0.0_f64, f64::max);
804 let tol = max_ev * 1e-10 * (block_dim as f64);
805 let positive: Vec<f64> = kron_eigs.into_iter().filter(|&ev| ev > tol).collect();
806 let nullity = block_dim - positive.len();
807 (positive, nullity)
808}
809
/// A penalty in canonical, block-local form.
#[derive(Clone)]
pub struct CanonicalPenalty {
    /// Root matrix; its row count is the penalty rank (`rank()`) and its
    /// columns correspond to `col_range`.
    pub root: Array2<f64>,
    /// Coefficient indices this penalty acts on.
    pub col_range: std::ops::Range<usize>,
    /// Total number of coefficients in the model.
    pub total_dim: usize,
    /// Null-space dimension recorded when the penalty was canonicalized.
    pub nullity: usize,
    /// Block-local penalty matrix; `accumulate_weighted` adds it to the
    /// `col_range × col_range` block of its target.
    pub local: Array2<f64>,
    /// Prior mean on the block's coefficients (`full_width_prior_mean`
    /// embeds it at `col_range`).
    pub prior_mean: Array1<f64>,
    /// Eigenvalues retained when the root was built; empty for penalties
    /// created by `from_dense_root_with_mean`.
    pub positive_eigenvalues: Vec<f64>,
    /// Optional penalty operator carried over from the penalty spec.
    /// NOTE(review): its role is not visible in this part of the file — confirm.
    pub op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
}
845
846impl std::fmt::Debug for CanonicalPenalty {
847 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
848 f.debug_struct("CanonicalPenalty")
849 .field(
850 "root",
851 &format_args!("{}×{}", self.root.nrows(), self.root.ncols()),
852 )
853 .field("col_range", &self.col_range)
854 .field("total_dim", &self.total_dim)
855 .field("nullity", &self.nullity)
856 .field(
857 "local",
858 &format_args!("{}×{}", self.local.nrows(), self.local.ncols()),
859 )
860 .field("prior_mean_len", &self.prior_mean.len())
861 .field("positive_eigenvalues", &self.positive_eigenvalues)
862 .field("op", &self.op.as_ref().map(|o| o.dim()))
863 .finish()
864 }
865}
866
867impl CanonicalPenalty {
868 pub fn from_dense_root(root: Array2<f64>, p: usize) -> Self {
872 Self::from_dense_root_with_mean(root, p, Array1::zeros(p))
873 }
874
875 pub fn from_dense_root_with_mean(root: Array2<f64>, p: usize, prior_mean: Array1<f64>) -> Self {
876 assert_eq!(prior_mean.len(), p);
877 let local = root.t().dot(&root);
878 let positive_eigenvalues = Vec::new(); Self {
880 root,
881 col_range: 0..p,
882 total_dim: p,
883 nullity: 0,
884 local,
885 prior_mean,
886 positive_eigenvalues,
887 op: None,
888 }
889 }
890
891 pub fn full_width_root(&self) -> Array2<f64> {
894 if self.col_range.start == 0 && self.col_range.end == self.total_dim {
895 return self.root.clone();
896 }
897 let rank = self.root.nrows();
898 let mut full = Array2::<f64>::zeros((rank, self.total_dim));
899 full.slice_mut(ndarray::s![.., self.col_range.clone()])
900 .assign(&self.root);
901 full
902 }
903
904 pub fn rank(&self) -> usize {
906 self.root.nrows()
907 }
908
909 pub fn block_dim(&self) -> usize {
911 self.col_range.len()
912 }
913
914 pub const fn is_block_local(&self) -> bool {
916 self.col_range.start != 0 || self.col_range.end != self.total_dim
917 }
918
919 pub fn local_ref(&self) -> &Array2<f64> {
922 &self.local
923 }
924
925 pub fn local_penalty(&self) -> Array2<f64> {
928 self.local.clone()
929 }
930
931 pub fn accumulate_weighted(&self, target: &mut Array2<f64>, lambda: f64) {
934 if lambda == 0.0 || self.rank() == 0 {
935 return;
936 }
937 let r = &self.col_range;
938 target
939 .slice_mut(s![r.start..r.end, r.start..r.end])
940 .scaled_add(lambda, &self.local);
941 }
942
943 pub fn trace_product(&self, m: &Array2<f64>, scale: f64) -> f64 {
946 if self.rank() == 0 || scale == 0.0 {
947 return 0.0;
948 }
949 let r = &self.col_range;
950 let m_block = m.slice(s![r.start..r.end, r.start..r.end]);
951 let rm = self.root.dot(&m_block);
952 scale
953 * rm.iter()
954 .zip(self.root.iter())
955 .map(|(&a, &b)| a * b)
956 .sum::<f64>()
957 }
958
959 pub fn quadratic(&self, v: &Array1<f64>, scale: f64) -> f64 {
962 if self.rank() == 0 || scale == 0.0 {
963 return 0.0;
964 }
965 let v_block = v.slice(s![self.col_range.start..self.col_range.end]);
966 let rv = self.root.dot(&v_block);
967 scale * rv.dot(&rv)
968 }
969
970 pub fn prior_linear_shift(&self, scale: f64) -> Array1<f64> {
972 let mut out = Array1::<f64>::zeros(self.total_dim);
973 if self.rank() == 0 || scale == 0.0 || self.prior_mean.iter().all(|&v| v == 0.0) {
974 return out;
975 }
976 let block = self.local.dot(&self.prior_mean) * scale;
977 out.slice_mut(s![self.col_range.start..self.col_range.end])
978 .assign(&block);
979 out
980 }
981
982 pub fn prior_constant_shift(&self, scale: f64) -> f64 {
984 if self.rank() == 0 || scale == 0.0 || self.prior_mean.iter().all(|&v| v == 0.0) {
985 return 0.0;
986 }
987 scale * self.prior_mean.dot(&self.local.dot(&self.prior_mean))
988 }
989
990 pub fn full_width_prior_mean(&self) -> Array1<f64> {
992 if self.col_range.start == 0 && self.col_range.end == self.total_dim {
993 return self.prior_mean.clone();
994 }
995 let mut out = Array1::<f64>::zeros(self.total_dim);
996 out.slice_mut(s![self.col_range.start..self.col_range.end])
997 .assign(&self.prior_mean);
998 out
999 }
1000
1001 pub fn to_penalty_coordinate(
1003 &self,
1004 ) -> gam_problem::PenaltyCoordinate {
1005 use gam_problem::PenaltyCoordinate;
1006 if self.is_block_local() {
1007 PenaltyCoordinate::from_block_root_with_mean(
1008 self.root.clone(),
1009 self.col_range.start,
1010 self.col_range.end,
1011 self.total_dim,
1012 self.prior_mean.clone(),
1013 )
1014 } else {
1015 PenaltyCoordinate::from_dense_root_with_mean(self.root.clone(), self.prior_mean.clone())
1016 }
1017 }
1018}
1019
1020pub fn report_penalty_pair_redundancy(canonical: &[CanonicalPenalty]) -> Vec<(usize, usize, f64)> {
1047 const REDUNDANCY_THRESHOLD: f64 = 1.0 - 1e-8;
1048 const SIMILARITY_THRESHOLD: f64 = 0.99;
1049 const LARGE_SCALE_K_THRESHOLD: usize = 64;
1050 const TOP_SIMILARITY_PAIRS: usize = 3;
1051
1052 let k = canonical.len();
1053 let mut redundant: Vec<(usize, usize, f64)> = Vec::new();
1054 let mut similar: Vec<(usize, usize, f64)> = Vec::new();
1055
1056 let trace_sq: Vec<f64> = canonical
1059 .iter()
1060 .map(|p| p.local.iter().map(|&v| v * v).sum::<f64>())
1061 .collect();
1062
1063 for i in 0..k {
1064 if trace_sq[i] == 0.0 {
1065 continue;
1066 }
1067 for j in (i + 1)..k {
1068 if trace_sq[j] == 0.0 {
1069 continue;
1070 }
1071 if canonical[i].col_range != canonical[j].col_range {
1075 continue;
1076 }
1077 assert_eq!(canonical[i].local.dim(), canonical[j].local.dim());
1080
1081 let inner: f64 = canonical[i]
1082 .local
1083 .iter()
1084 .zip(canonical[j].local.iter())
1085 .map(|(&a, &b)| a * b)
1086 .sum();
1087 let denom = (trace_sq[i] * trace_sq[j]).sqrt();
1088 if denom == 0.0 {
1089 continue;
1090 }
1091 let cos = inner / denom;
1092
1093 if cos > REDUNDANCY_THRESHOLD {
1094 redundant.push((i, j, cos));
1095 } else if cos > SIMILARITY_THRESHOLD {
1096 similar.push((i, j, cos));
1097 }
1098 }
1099 }
1100
1101 for &(i, j, cos) in &redundant {
1103 log::warn!(
1104 "[PENALTY-REDUNDANCY] penalties i={i} j={j} are structurally identical \
1105 (cos={cos:.6}) — model is over-parameterized along their antisymmetric \
1106 direction; expect a Z₂-symmetric saddle in the LAML cost. Consider \
1107 re-specifying (e.g. anisotropic→isotropic for spatial smoothers with \
1108 weak axis signal)."
1109 );
1110 }
1111
1112 if k > LARGE_SCALE_K_THRESHOLD && similar.len() > TOP_SIMILARITY_PAIRS {
1114 similar.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
1115 similar.truncate(TOP_SIMILARITY_PAIRS);
1116 }
1117 for (i, j, cos) in similar {
1118 log::info!(
1119 "[PENALTY-SIMILARITY] penalties i={i} j={j} are near-identical \
1120 (cos={cos:.6}) — outer Hessian may be ill-conditioned along their \
1121 antisymmetric direction."
1122 );
1123 }
1124
1125 redundant
1126}
1127
/// Converts one `PenaltySpec` into a `CanonicalPenalty`.
///
/// After shape validation the spec is handled by one of three paths:
///
/// 1. **Ridge hint** — the root is `sqrt(scale) · I`; a scale `<= 0` drops the
///    penalty.
/// 2. **Kronecker hint** — the root and eigenvalues are assembled from
///    per-factor decompositions.
/// 3. **General** — the local matrix is analysed with `analyze_penalty_block`
///    and the root is built from the eigenpairs in the classification's
///    `range_idx`.
///
/// In the two hinted paths `local` is the symmetrized input matrix; in the
/// general path it is rebuilt as `rootᵀ · root`.
///
/// Returns `Ok(None)` when the penalty is inactive (non-positive ridge scale,
/// a rank-zero Kronecker factor, no surviving Kronecker eigenvalues, or
/// numerical rank zero).
///
/// # Errors
/// Fails when shape validation, prior-mean evaluation, Kronecker factor
/// decomposition, or the penalty analysis fails.
///
/// # Panics
/// Panics when the spectral classification's rank disagrees with the rank
/// reported by `analyze_penalty_block`.
pub fn canonicalize_penalty_spec(
    spec: &crate::PenaltySpec,
    p: usize,
    idx: usize,
    context: &str,
) -> Result<Option<CanonicalPenalty>, EstimationError> {
    use crate::PenaltySpec;

    crate::validate_penalty_spec_shape(idx, spec, p, context)?;

    // Normalise the three spec variants to one tuple; dense variants span all
    // `p` coefficients and carry neither a structure hint nor an operator.
    let (local_matrix, col_range, prior_mean_spec, hint, op) = match spec {
        PenaltySpec::Block {
            local,
            col_range,
            prior_mean,
            structure_hint,
            op,
        } => (
            local.view(),
            col_range.clone(),
            prior_mean,
            structure_hint.as_ref(),
            op.clone(),
        ),
        PenaltySpec::Dense(m) => (
            m.view(),
            0..p,
            &gam_problem::CoefficientPriorMean::Zero,
            None,
            None,
        ),
        PenaltySpec::DenseWithMean { matrix, prior_mean } => {
            (matrix.view(), 0..p, prior_mean, None, None)
        }
    };

    let block_dim = col_range.len();
    let prior_mean = prior_mean_spec
        .evaluate(block_dim, &format!("{context}: penalty {idx}"))
        .map_err(|e| EstimationError::InvalidInput(e.0))?;

    // Path 1: ridge hint — full rank, no eigendecomposition required.
    if let Some(PenaltyStructureHint::Ridge(scale)) = hint {
        if *scale <= 0.0 {
            return Ok(None);
        }
        let sqrt_scale = scale.sqrt();
        let mut root = Array2::zeros((block_dim, block_dim));
        for i in 0..block_dim {
            root[[i, i]] = sqrt_scale;
        }
        let mut local_sym = local_matrix.to_owned();
        symmetrize_in_place(&mut local_sym);
        return Ok(Some(CanonicalPenalty {
            root,
            col_range,
            total_dim: p,
            nullity: 0,
            local: local_sym,
            prior_mean,
            positive_eigenvalues: vec![*scale; block_dim],
            op,
        }));
    }

    // Path 2: Kronecker hint — work factor by factor.
    if let Some(PenaltyStructureHint::Kronecker(factors)) = hint {
        let decomps =
            match decompose_kronecker_factors(factors, &format!("{context} penalty {idx}"))? {
                None => return Ok(None),
                Some(d) => d,
            };
        let (positive_eigenvalues, nullity) = kronecker_eigenvalues(&decomps, block_dim);
        if positive_eigenvalues.is_empty() {
            return Ok(None);
        }
        let root = assemble_kronecker_root_local(&decomps);
        let mut local_sym = local_matrix.to_owned();
        symmetrize_in_place(&mut local_sym);
        return Ok(Some(CanonicalPenalty {
            root,
            col_range,
            total_dim: p,
            nullity,
            local: local_sym,
            prior_mean,
            positive_eigenvalues,
            op,
        }));
    }

    // Path 3: general penalty — spectral analysis of the local block.
    let local_owned = local_matrix.to_owned();
    let analysis = analyze_penalty_block(&local_owned).map_err(|err| {
        EstimationError::InvalidInput(format!(
            "{context}: penalty canonicalization failed at index {idx}: {err}"
        ))
    })?;

    if analysis.rank == 0 {
        log::debug!(
            "Dropped inactive penalty block idx={idx} reason={}",
            if analysis.iszero {
                "ZeroMatrix"
            } else {
                "NumericalRankZero"
            }
        );
        return Ok(None);
    }

    // Classify with the same tolerance the analysis used; the two rank counts
    // are required to agree.
    let tolerance = analysis.tol;
    let classes = crate::basis::SpectralClassification::new(&analysis.eigenvalues, tolerance);
    let rank_k = classes.rank();
    assert_eq!(
        rank_k, analysis.rank,
        "penalty-root rank disagreement: SpectralClassification rank={rank_k} vs analyze_penalty_block rank={} (#1425 canonical-classifier invariant)",
        analysis.rank
    );

    // Each root row is sqrt(eigenvalue) · eigenvector for a retained eigenpair.
    let mut root = Array2::zeros((rank_k, block_dim));
    let mut positive_eigenvalues = Vec::with_capacity(rank_k);
    for (row_idx, &i) in classes.range_idx.iter().enumerate() {
        let eigenval = analysis.eigenvalues[i];
        let eigenvec = analysis.eigenvectors.column(i);
        root.row_mut(row_idx).assign(&(&eigenvec * eigenval.sqrt()));
        positive_eigenvalues.push(eigenval);
    }

    if classes.is_indefinite() {
        log::debug!(
            "{context}: penalty block idx={idx} carries {} negative-curvature \
             eigendirection(s) below -tol={tolerance:e}; dropped from the canonical \
             root and NOT counted as null space (rank={rank_k}, nullity={})",
            classes.negative_dim(),
            classes.nullity()
        );
    }

    // Rebuild the local matrix from the root so the two are consistent.
    let local = root.t().dot(&root);
    Ok(Some(CanonicalPenalty {
        root,
        col_range,
        total_dim: p,
        nullity: classes.nullity(),
        local,
        prior_mean,
        positive_eigenvalues,
        op,
    }))
}
1307
1308pub fn canonicalize_penalty_specs(
1311 specs: &[crate::PenaltySpec],
1312 nullspace_dims: &[usize],
1313 p: usize,
1314 context: &str,
1315) -> Result<(Vec<CanonicalPenalty>, Vec<usize>), EstimationError> {
1316 if specs.len() != nullspace_dims.len() {
1317 crate::bail_invalid_estim!(
1318 "{context}: nullspace_dims length mismatch: penalties={}, nullspace_dims={}",
1319 specs.len(),
1320 nullspace_dims.len()
1321 );
1322 }
1323
1324 let mut active = Vec::with_capacity(specs.len());
1325 let mut active_nullspace = Vec::with_capacity(specs.len());
1326 for (idx, spec) in specs.iter().enumerate() {
1327 if let Some(canonical) = canonicalize_penalty_spec(spec, p, idx, context)? {
1328 active_nullspace.push(nullspace_dims[idx]);
1329 active.push(canonical);
1330 }
1331 }
1332 Ok((active, active_nullspace))
1333}
1334
/// Largest coefficient count `p` for which
/// `create_balanced_penalty_root_from_canonical` will take the dense
/// `p × p` eigendecomposition path when penalty column ranges overlap; above
/// this it returns a `LayoutError` instead.
pub(crate) const OVERLAPPING_PENALTY_DENSE_FALLBACK_MAX_P: usize = 4096;
1345
1346pub fn create_balanced_penalty_root_from_canonical(
1353 penalties: &[CanonicalPenalty],
1354 p: usize,
1355) -> Result<Array2<f64>, EstimationError> {
1356 if penalties.is_empty() {
1357 return Ok(Array2::zeros((0, p)));
1358 }
1359
1360 let mut block_groups: BTreeMap<(usize, usize), Vec<&CanonicalPenalty>> = BTreeMap::new();
1362 for cp in penalties {
1363 if cp.rank() == 0 {
1364 continue;
1365 }
1366 let key = (cp.col_range.start, cp.col_range.end);
1367 block_groups.entry(key).or_default().push(cp);
1368 }
1369
1370 if block_groups.is_empty() {
1371 return Ok(Array2::zeros((0, p)));
1372 }
1373
1374 let ranges: Vec<(usize, usize)> = block_groups.keys().copied().collect();
1376 let mut overlapping = false;
1377 for i in 1..ranges.len() {
1378 if ranges[i].0 < ranges[i - 1].1 {
1379 overlapping = true;
1380 break;
1381 }
1382 }
1383
1384 if overlapping {
1385 if p > OVERLAPPING_PENALTY_DENSE_FALLBACK_MAX_P {
1386 return Err(EstimationError::LayoutError(format!(
1387 "overlapping penalty root would require dense {}x{} eigendecomposition; \
1388 large-model dense fallback is disabled. Keep penalties structured or \
1389 extend the overlapping-penalty solver path",
1390 p, p
1391 )));
1392 }
1393 let mut s_balanced = Array2::zeros((p, p));
1395 for cp in penalties {
1396 if cp.rank() == 0 {
1397 continue;
1398 }
1399 let local = cp.local_ref();
1400 let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1401 if frob_norm > 1e-12 {
1402 let r = &cp.col_range;
1403 s_balanced
1404 .slice_mut(s![r.start..r.end, r.start..r.end])
1405 .scaled_add(1.0 / frob_norm, local);
1406 }
1407 }
1408 let (eigenvalues, eigenvectors) =
1409 robust_eigh(&s_balanced, Side::Lower, "balanced penalty matrix")?;
1410 let max_eig = eigenvalues.iter().fold(0.0f64, |max, &val| max.max(val));
1411 let tolerance = if max_eig > 0.0 {
1412 max_eig * 1e-12
1413 } else {
1414 1e-12
1415 };
1416 let penalty_rank = eigenvalues.iter().filter(|&&ev| ev > tolerance).count();
1417 if penalty_rank == 0 {
1418 return Ok(Array2::zeros((0, p)));
1419 }
1420 let mut eb = Array2::zeros((p, penalty_rank));
1421 let mut col_idx = 0;
1422 for (i, &eigenval) in eigenvalues.iter().enumerate() {
1423 if eigenval > tolerance {
1424 let sqrt_ev = eigenval.sqrt();
1425 let evec = eigenvectors.column(i);
1426 eb.column_mut(col_idx).assign(&(&evec * sqrt_ev));
1427 col_idx += 1;
1428 }
1429 }
1430 return Ok(eb.t().to_owned());
1431 }
1432
1433 struct BlockRoot {
1435 col_range: Range<usize>,
1436 root: Array2<f64>, }
1438 let ordered_blocks: Vec<((usize, usize), Vec<&CanonicalPenalty>)> =
1443 block_groups.into_iter().collect();
1444 let block_roots: Vec<BlockRoot> = ordered_blocks
1445 .into_par_iter()
1446 .map(
1447 |((start, end), cps)| -> Result<Option<BlockRoot>, EstimationError> {
1448 let block_dim = end - start;
1449 let mut s_balanced_local = Array2::zeros((block_dim, block_dim));
1450
1451 for cp in cps {
1452 let local = cp.local_ref();
1453 let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1454 if frob_norm > 1e-12 {
1455 s_balanced_local.scaled_add(1.0 / frob_norm, local);
1456 }
1457 }
1458
1459 let (eigenvalues, eigenvectors) =
1460 robust_eigh(&s_balanced_local, Side::Lower, "balanced penalty block")?;
1461 let max_eig = eigenvalues.iter().fold(0.0f64, |max, &val| max.max(val));
1462 let tolerance = if max_eig > 0.0 {
1463 max_eig * 1e-12
1464 } else {
1465 1e-12
1466 };
1467 let block_rank = eigenvalues.iter().filter(|&&ev| ev > tolerance).count();
1468
1469 if block_rank == 0 {
1470 return Ok(None);
1471 }
1472
1473 let mut root = Array2::zeros((block_rank, block_dim));
1474 let mut row_idx = 0;
1475 for (i, &eigenval) in eigenvalues.iter().enumerate() {
1476 if eigenval > tolerance {
1477 let sqrt_ev = eigenval.sqrt();
1478 let evec = eigenvectors.column(i);
1479 root.row_mut(row_idx).assign(&(&evec * sqrt_ev));
1480 row_idx += 1;
1481 }
1482 }
1483
1484 Ok(Some(BlockRoot {
1485 col_range: start..end,
1486 root,
1487 }))
1488 },
1489 )
1490 .collect::<Result<Vec<_>, _>>()?
1491 .into_iter()
1492 .flatten()
1493 .collect();
1494 let total_rank: usize = block_roots.iter().map(|br| br.root.nrows()).sum();
1495
1496 if total_rank == 0 {
1497 return Ok(Array2::zeros((0, p)));
1498 }
1499
1500 let mut eb = Array2::zeros((total_rank, p));
1502 let mut row_offset = 0;
1503 for br in &block_roots {
1504 let rank_b = br.root.nrows();
1505 eb.slice_mut(s![
1506 row_offset..(row_offset + rank_b),
1507 br.col_range.start..br.col_range.end
1508 ])
1509 .assign(&br.root);
1510 row_offset += rank_b;
1511 }
1512
1513 Ok(eb)
1514}
1515
1516#[derive(Clone)]
1518struct SubspaceSplit {
1519 q_pen: Array2<f64>,
1520 q_null: Array2<f64>,
1521}
1522
1523impl SubspaceSplit {
1524 fn identity(p: usize) -> Self {
1525 Self {
1526 q_pen: Array2::zeros((p, 0)),
1527 q_null: Array2::eye(p),
1528 }
1529 }
1530
1531 fn from_ordered_qs(
1532 qs: &Mat<f64>,
1533 penalized_rank: usize,
1534 p: usize,
1535 ) -> Result<Self, EstimationError> {
1536 if qs.nrows() != p || qs.ncols() != p {
1537 return Err(EstimationError::LayoutError(format!(
1538 "Invalid Q basis dimensions: expected {p}x{p}, got {}x{}",
1539 qs.nrows(),
1540 qs.ncols()
1541 )));
1542 }
1543 if penalized_rank > p {
1544 return Err(EstimationError::LayoutError(format!(
1545 "Invalid penalized rank {penalized_rank} for p={p}"
1546 )));
1547 }
1548
1549 let null_count = p - penalized_rank;
1550 let mut q_pen = Array2::<f64>::zeros((p, penalized_rank));
1551 let mut q_null = Array2::<f64>::zeros((p, null_count));
1552 for i in 0..p {
1553 for j in 0..penalized_rank {
1554 q_pen[(i, j)] = qs[(i, j)];
1555 }
1556 for j in 0..null_count {
1557 q_null[(i, j)] = qs[(i, penalized_rank + j)];
1558 }
1559 }
1560
1561 Ok(Self { q_pen, q_null })
1562 }
1563
1564 fn rank(&self) -> usize {
1565 self.q_pen.ncols()
1566 }
1567
1568 fn p(&self) -> usize {
1569 self.q_pen.nrows()
1570 }
1571
1572 fn compose_qs(&self) -> Array2<f64> {
1573 let p = self.p();
1574 let rank = self.rank();
1575 let null_count = self.q_null.ncols();
1576 let mut qs = Array2::<f64>::zeros((p, p));
1577 for i in 0..p {
1578 for j in 0..rank {
1579 qs[(i, j)] = self.q_pen[(i, j)];
1580 }
1581 for j in 0..null_count {
1582 qs[(i, rank + j)] = self.q_null[(i, j)];
1583 }
1584 }
1585 qs
1586 }
1587}
1588
/// Smoothing-parameter-independent part of the stable reparameterization, as
/// produced by `precompute_reparam_invariant_from_canonical`.
#[derive(Clone)]
pub struct ReparamInvariant {
    /// Penalized / null-space column bases derived from the balanced penalty.
    split: SubspaceSplit,
    /// The full `p x p` basis (penalized columns first, then null-space columns).
    qs_base: Array2<f64>,
    /// `true` when at least one positive-rank penalty has a Frobenius norm above `1e-12`.
    has_nonzero: bool,
    /// Largest absolute eigenvalue of the balanced penalty (maximum over blocks
    /// when the penalties are decomposed block by block).
    max_balanced_eigenvalue: f64,
}

impl ReparamInvariant {
    /// Returns the largest absolute eigenvalue recorded for the balanced penalty.
    pub const fn max_balanced_eigenvalue(&self) -> f64 {
        self.max_balanced_eigenvalue
    }
}
1610
1611pub fn precompute_reparam_invariant_from_canonical(
1618 penalties: &[CanonicalPenalty],
1619 p_total: usize,
1620) -> Result<ReparamInvariant, EstimationError> {
1621 use std::cmp::Ordering;
1622
1623 let m = penalties.len();
1624
1625 if m == 0 {
1626 return Ok(ReparamInvariant {
1627 split: SubspaceSplit::identity(p_total),
1628 qs_base: Array2::eye(p_total),
1629 has_nonzero: false,
1630 max_balanced_eigenvalue: 0.0,
1631 });
1632 }
1633
1634 struct PenRef {
1636 penalty_index: usize,
1637 }
1638 let mut block_groups: BTreeMap<(usize, usize), Vec<PenRef>> = BTreeMap::new();
1639 let mut has_nonzero = false;
1640 for (i, cp) in penalties.iter().enumerate() {
1641 if cp.rank() == 0 {
1642 continue;
1643 }
1644 let local = cp.local_ref();
1645 let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1646 if frob_norm > 1e-12 {
1647 has_nonzero = true;
1648 }
1649 let key = (cp.col_range.start, cp.col_range.end);
1650 block_groups
1651 .entry(key)
1652 .or_default()
1653 .push(PenRef { penalty_index: i });
1654 }
1655
1656 if !has_nonzero {
1657 return Ok(ReparamInvariant {
1658 split: SubspaceSplit::identity(p_total),
1659 qs_base: Array2::eye(p_total),
1660 has_nonzero: false,
1661 max_balanced_eigenvalue: 0.0,
1662 });
1663 }
1664
1665 let ranges: Vec<(usize, usize)> = block_groups.keys().copied().collect();
1667 let mut overlapping = false;
1668 for i in 1..ranges.len() {
1669 if ranges[i].0 < ranges[i - 1].1 {
1670 overlapping = true;
1671 break;
1672 }
1673 }
1674
1675 if overlapping {
1676 if p_total > OVERLAPPING_PENALTY_DENSE_FALLBACK_MAX_P {
1682 return Err(EstimationError::LayoutError(format!(
1683 "overlapping penalty reparameterization would require dense {}x{} eigendecomposition; \
1684 large-model dense fallback is disabled. Keep penalties structured or \
1685 extend the overlapping-penalty solver path",
1686 p_total, p_total
1687 )));
1688 }
1689 let mut s_balanced = Mat::<f64>::zeros(p_total, p_total);
1691 for cp in penalties {
1692 if cp.rank() == 0 {
1693 continue;
1694 }
1695 let local = cp.local_ref();
1696 let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1697 if frob_norm > 1e-12 {
1698 let scale = 1.0 / frob_norm;
1699 let r = &cp.col_range;
1700 for i in 0..local.nrows() {
1701 for j in 0..local.ncols() {
1702 s_balanced[(r.start + i, r.start + j)] += scale * local[[i, j]];
1703 }
1704 }
1705 }
1706 }
1707
1708 let (bal_eigenvalues, bal_eigenvectors) =
1709 robust_eigh_faer(&s_balanced, Side::Lower, "balanced penalty matrix")?;
1710
1711 let mut order: Vec<usize> = (0..p_total).collect();
1712 order.sort_by(|&i, &j| {
1713 bal_eigenvalues[j]
1714 .partial_cmp(&bal_eigenvalues[i])
1715 .unwrap_or(Ordering::Equal)
1716 .then(i.cmp(&j))
1717 });
1718
1719 let mut qs = Mat::<f64>::zeros(p_total, p_total);
1720 for (col_idx, &idx) in order.iter().enumerate() {
1721 for row in 0..p_total {
1722 qs[(row, col_idx)] = bal_eigenvectors[(row, idx)];
1723 }
1724 }
1725
1726 let max_bal = order
1727 .iter()
1728 .map(|&idx| bal_eigenvalues[idx].abs())
1729 .fold(0.0_f64, f64::max);
1730 let rank_tol = if max_bal > 0.0 {
1731 max_bal * 1e-12
1732 } else {
1733 1e-12
1734 };
1735 let penalized_rank = order
1736 .iter()
1737 .take_while(|&&idx| bal_eigenvalues[idx] > rank_tol)
1738 .count();
1739 let split = SubspaceSplit::from_ordered_qs(&qs, penalized_rank, p_total)?;
1740
1741 return Ok(ReparamInvariant {
1742 split,
1743 qs_base: mat_to_array(&qs),
1744 has_nonzero,
1745 max_balanced_eigenvalue: max_bal,
1746 });
1747 }
1748
1749 let mut covered = vec![false; p_total];
1757 for cp in penalties {
1758 for j in cp.col_range.clone() {
1759 covered[j] = true;
1760 }
1761 }
1762 let uncovered_cols: Vec<usize> = (0..p_total).filter(|j| !covered[*j]).collect();
1763
1764 struct BlockResult {
1765 col_range: Range<usize>,
1766 q_pen_local: Array2<f64>, q_null_local: Array2<f64>, max_balanced_eigenvalue: f64,
1770 pen_col_offset: usize,
1772 null_col_offset: usize,
1774 }
1775
1776 let block_specs: Vec<_> = block_groups.iter().collect();
1780 let mut block_results: Vec<BlockResult> = block_specs
1781 .into_par_iter()
1782 .map(
1783 |(&(start, end), refs)| -> Result<BlockResult, EstimationError> {
1784 let block_dim = end - start;
1785
1786 let mut s_balanced_local = Array2::zeros((block_dim, block_dim));
1788 let mut block_has_nonzero = false;
1789 for pref in refs {
1790 let cp = &penalties[pref.penalty_index];
1791 let local = cp.local_ref();
1792 let frob_norm = local.iter().map(|&x| x * x).sum::<f64>().sqrt();
1793 if frob_norm > 1e-12 {
1794 s_balanced_local.scaled_add(1.0 / frob_norm, local);
1795 block_has_nonzero = true;
1796 }
1797 }
1798
1799 if !block_has_nonzero {
1800 return Ok(BlockResult {
1801 col_range: start..end,
1802 q_pen_local: Array2::zeros((block_dim, 0)),
1803 q_null_local: Array2::eye(block_dim),
1804 max_balanced_eigenvalue: 0.0,
1805 pen_col_offset: 0, null_col_offset: 0, });
1808 }
1809
1810 let (bal_eigenvalues, bal_eigenvectors) =
1812 robust_eigh(&s_balanced_local, Side::Lower, "balanced penalty block")?;
1813
1814 let mut order: Vec<usize> = (0..block_dim).collect();
1815 order.sort_by(|&i, &j| {
1816 bal_eigenvalues[j]
1817 .partial_cmp(&bal_eigenvalues[i])
1818 .unwrap_or(Ordering::Equal)
1819 .then(i.cmp(&j))
1820 });
1821
1822 let max_bal = order
1823 .iter()
1824 .map(|&idx| bal_eigenvalues[idx].abs())
1825 .fold(0.0_f64, f64::max);
1826 let rank_tol = if max_bal > 0.0 {
1827 max_bal * 1e-12
1828 } else {
1829 1e-12
1830 };
1831 let penalized_rank = order
1832 .iter()
1833 .take_while(|&&idx| bal_eigenvalues[idx] > rank_tol)
1834 .count();
1835 let null_count = block_dim - penalized_rank;
1836
1837 let mut q_pen_local = Array2::zeros((block_dim, penalized_rank));
1838 let mut q_null_local = Array2::zeros((block_dim, null_count));
1839 for (col_idx, &idx) in order.iter().enumerate() {
1840 if col_idx < penalized_rank {
1841 for row in 0..block_dim {
1842 q_pen_local[[row, col_idx]] = bal_eigenvectors[[row, idx]];
1843 }
1844 } else {
1845 let null_col = col_idx - penalized_rank;
1846 for row in 0..block_dim {
1847 q_null_local[[row, null_col]] = bal_eigenvectors[[row, idx]];
1848 }
1849 }
1850 }
1851
1852 Ok(BlockResult {
1853 col_range: start..end,
1854 q_pen_local,
1855 q_null_local,
1856 max_balanced_eigenvalue: max_bal,
1857 pen_col_offset: 0, null_col_offset: 0, })
1860 },
1861 )
1862 .collect::<Result<_, _>>()?;
1863 let global_max_bal = block_results
1864 .iter()
1865 .map(|br| br.max_balanced_eigenvalue)
1866 .fold(0.0_f64, f64::max);
1867
1868 let total_pen_rank: usize = block_results.iter().map(|br| br.q_pen_local.ncols()).sum();
1870 let total_null: usize = block_results
1871 .iter()
1872 .map(|br| br.q_null_local.ncols())
1873 .sum::<usize>()
1874 + uncovered_cols.len();
1875 {
1876 let mut pen_off = 0usize;
1877 let mut null_off = 0usize;
1878 for br in &mut block_results {
1879 br.pen_col_offset = pen_off;
1880 br.null_col_offset = null_off;
1881 pen_off += br.q_pen_local.ncols();
1882 null_off += br.q_null_local.ncols();
1883 }
1884 }
1885
1886 let mut q_pen = Array2::zeros((p_total, total_pen_rank));
1887 let mut q_null = Array2::zeros((p_total, total_null));
1888
1889 for br in &block_results {
1890 let start = br.col_range.start;
1891 let bd = br.q_pen_local.nrows();
1892 let pen_r = br.q_pen_local.ncols();
1893 let null_r = br.q_null_local.ncols();
1894 if pen_r > 0 {
1895 q_pen
1896 .slice_mut(s![
1897 start..(start + bd),
1898 br.pen_col_offset..(br.pen_col_offset + pen_r)
1899 ])
1900 .assign(&br.q_pen_local);
1901 }
1902 if null_r > 0 {
1903 q_null
1904 .slice_mut(s![
1905 start..(start + bd),
1906 br.null_col_offset..(br.null_col_offset + null_r)
1907 ])
1908 .assign(&br.q_null_local);
1909 }
1910 }
1911 let mut null_col = block_results
1912 .iter()
1913 .map(|br| br.q_null_local.ncols())
1914 .sum::<usize>();
1915 for &j in &uncovered_cols {
1916 q_null[[j, null_col]] = 1.0;
1917 null_col += 1;
1918 }
1919
1920 let split = SubspaceSplit { q_pen, q_null };
1921
1922 let qs_global = split.compose_qs();
1926
1927 Ok(ReparamInvariant {
1928 split,
1929 qs_base: qs_global,
1930 has_nonzero,
1931 max_balanced_eigenvalue: global_max_bal,
1932 })
1933}
1934
1935fn structurally_penalized_columns(penalties: &[CanonicalPenalty], p: usize) -> Vec<bool> {
1936 let mut active = vec![false; p];
1937 for cp in penalties {
1938 let local = cp.local_ref();
1939 let scale = local.iter().map(|&v| v.abs()).fold(0.0_f64, f64::max);
1940 if scale <= 0.0 {
1941 continue;
1942 }
1943 let tol = scale * 1e-12;
1944 for local_col in 0..cp.block_dim() {
1945 let mut column_active = false;
1946 for row in 0..cp.block_dim() {
1947 if local[[row, local_col]].abs() > tol || local[[local_col, row]].abs() > tol {
1948 column_active = true;
1949 break;
1950 }
1951 }
1952 if column_active {
1953 active[cp.col_range.start + local_col] = true;
1954 }
1955 }
1956 }
1957 active
1958}
1959
/// Computes the reparameterized penalty quantities for the given smoothing
/// parameters, reusing a precomputed [`ReparamInvariant`].
///
/// The basis `Qs = [q_pen | q_null]` comes from the invariant. Each penalty
/// root is rotated into that basis, the weighted sum of penalties is
/// eigendecomposed on the penalized block, and from that the function derives
/// the transformed total penalty, its root, the log-determinant over the
/// penalized block and that log-determinant's per-penalty derivative terms.
///
/// * `penalties` / `lambdas` - canonical penalties and one weight per penalty.
/// * `p` - total coefficient count.
/// * `invariant` - split precomputed for the same penalties (NOTE(review): its
///   dimensions are not re-validated against `p` here - confirm callers).
/// * `penalty_shrinkage_floor` - optional positive factor; when present, a
///   ridge of `factor * max_balanced_eigenvalue` is added to selected
///   eigenvalues (see below).
///
/// # Errors
/// * `ParameterConstraintViolation` when `lambdas.len() != penalties.len()`.
/// * `LayoutError` when the subspace split fails the leakage/orthogonality
///   check, or a structural eigenvalue is non-finite or strongly negative.
/// * Any error propagated from the eigendecomposition.
///
/// # Panics
/// Panics (via `assert!`) when the two `det1` formulas disagree beyond
/// tolerance, or when the transformed penalty has entries above `1e-10` in the
/// rows belonging to the null space.
pub fn stable_reparameterizationwith_invariant(
    penalties: &[CanonicalPenalty],
    lambdas: &[f64],
    p: usize,
    invariant: &ReparamInvariant,
    penalty_shrinkage_floor: Option<f64>,
) -> Result<ReparamResult, EstimationError> {
    let m = penalties.len();

    if lambdas.len() != m {
        return Err(EstimationError::ParameterConstraintViolation(format!(
            "Lambda count mismatch: expected {} lambdas for {} penalties, got {}",
            m,
            m,
            lambdas.len()
        )));
    }

    // No penalties at all: identity basis, empty root, zero penalty.
    if m == 0 {
        return Ok(ReparamResult {
            s_transformed: Array2::zeros((p, p)),
            log_det: 0.0,
            det1: Array1::zeros(0),
            qs: Array2::eye(p),
            canonical_transformed: vec![],
            e_transformed: Array2::zeros((0, p)),
            u_truncated: Array2::eye(p),
            penalty_shrinkage_ridge: 0.0,
        });
    }

    // Penalties exist but none has non-negligible norm: everything is null
    // space, and the penalties are passed through untransformed.
    if !invariant.has_nonzero {
        let qs = invariant.split.compose_qs();
        let u_truncated = qs.t().dot(&invariant.split.q_null);
        let canonical_transformed: Vec<CanonicalPenalty> = penalties.to_vec();
        return Ok(ReparamResult {
            s_transformed: Array2::zeros((p, p)),
            log_det: 0.0,
            det1: Array1::zeros(m),
            qs,
            canonical_transformed,
            e_transformed: Array2::zeros((0, p)),
            u_truncated,
            penalty_shrinkage_ridge: 0.0,
        });
    }

    let q_pen = array_to_faer(&invariant.split.q_pen);
    let q_null = array_to_faer(&invariant.split.q_null);
    let qs_base = array_to_faer(&invariant.qs_base);
    // For every penalty: rotate its root into the Qs basis (root * rows of
    // qs_base for its column range) and form the matching penalty matrix via
    // `penalty_from_root_faer` (presumably root^T * root - confirm).
    let penalty_transforms: Vec<(Mat<f64>, Mat<f64>)> = penalties
        .par_iter()
        .map(|cp| {
            let r = &cp.col_range;
            let root_faer = array_to_faer(&cp.root);
            let q_block = qs_base.submatrix(r.start, 0, cp.block_dim(), p);
            let mut product = Mat::<f64>::zeros(cp.rank(), p);
            matmul(
                product.as_mut(),
                Accum::Replace,
                root_faer.as_ref(),
                q_block,
                1.0,
                Par::Seq,
            );
            let s_k = penalty_from_root_faer(&product);
            (product, s_k)
        })
        .collect();
    let (rs_transformed, s_k_penalized_cache): (Vec<Mat<f64>>, Vec<Mat<f64>>) =
        penalty_transforms.into_iter().unzip();

    let penalized_rank = invariant.split.rank();

    // Eigendecompose sum_k lambda_k * S_k restricted to the leading
    // penalized_rank x penalized_rank block, sorted by decreasing eigenvalue.
    let mut range_eigenvalues_sorted: Vec<f64> = Vec::new();
    let mut range_rotation = Mat::<f64>::zeros(penalized_rank, penalized_rank);
    if penalized_rank > 0 {
        let mut range_block = Mat::<f64>::zeros(penalized_rank, penalized_rank);
        for (lambda, s_k) in lambdas.iter().zip(s_k_penalized_cache.iter()) {
            for i in 0..penalized_rank {
                for j in 0..penalized_rank {
                    range_block[(i, j)] += *lambda * s_k[(i, j)];
                }
            }
        }
        let (range_eigenvalues, range_eigenvectors) =
            robust_eigh_faer(&range_block, Side::Lower, "range penalty block")?;

        // Decreasing order, ties broken by index for determinism.
        let mut range_order: Vec<usize> = (0..penalized_rank).collect();
        range_order.sort_by(|&i, &j| {
            range_eigenvalues[j]
                .partial_cmp(&range_eigenvalues[i])
                .unwrap_or(std::cmp::Ordering::Equal)
                .then(i.cmp(&j))
        });
        range_eigenvalues_sorted = range_order
            .iter()
            .map(|&idx| range_eigenvalues[idx])
            .collect();

        // Columns of range_rotation are the eigenvectors in sorted order.
        for (col_idx, &idx) in range_order.iter().enumerate() {
            for row in 0..penalized_rank {
                range_rotation[(row, col_idx)] = range_eigenvectors[(row, idx)];
            }
        }
    }

    let structural_rank = penalized_rank;
    let mut range_eigs_sorted: Vec<f64> = range_eigenvalues_sorted;
    let structurally_penalized_cols = structurally_penalized_columns(penalties, p);

    // Ridge magnitude: floor factor (only if positive) times the largest
    // balanced eigenvalue; zero when no floor is requested.
    let shrinkage_ridge = penalty_shrinkage_floor
        .filter(|&eps| eps > 0.0)
        .map(|eps| eps * invariant.max_balanced_eigenvalue)
        .unwrap_or(0.0);
    if shrinkage_ridge > 0.0 {
        let min_eig_before = range_eigs_sorted
            .iter()
            .copied()
            .fold(f64::INFINITY, f64::min);
        let mut shrinkage_floor_applied = 0usize;
        for eig_idx in 0..range_eigs_sorted.len() {
            // Squared norm of the eigendirection (expressed in original
            // coordinates via q_pen) restricted to structurally penalized columns.
            let mut penalized_energy = 0.0;
            for original_col in 0..p {
                if structurally_penalized_cols[original_col] {
                    let mut coordinate = 0.0;
                    for pen_col in 0..penalized_rank {
                        coordinate +=
                            q_pen[(original_col, pen_col)] * range_rotation[(pen_col, eig_idx)];
                    }
                    penalized_energy += coordinate * coordinate;
                }
            }
            // Only directions with noticeable weight on penalized columns get the ridge.
            if penalized_energy > 1e-8 {
                range_eigs_sorted[eig_idx] += shrinkage_ridge;
                shrinkage_floor_applied += 1;
            }
        }
        // Log only when the ridge is more than 1% of the smallest eigenvalue.
        if min_eig_before > 0.0 && shrinkage_ridge / min_eig_before > 0.01 {
            log::debug!(
                "Penalty shrinkage floor active: ridge={:.3e} (min_eig_before={:.3e}, ratio={:.1e}, max_bal_eig={:.3e}, applied_dirs={})",
                shrinkage_ridge,
                min_eig_before,
                shrinkage_ridge / min_eig_before,
                invariant.max_balanced_eigenvalue,
                shrinkage_floor_applied,
            );
        }
    }

    // Lower bound applied to eigenvalues before taking square roots / logs.
    let eigenvalue_floor = invariant.max_balanced_eigenvalue.max(1.0) * 1e-12;
    let qs = compose_qs_from_split(&q_pen, &q_null, p);

    // Consistency check of the split. Note the grouping: the condition is
    // (relative AND absolute leakage exceeded) OR (cross-Gram term exceeded).
    let leakage = assess_subspace_leakage(&qs, &rs_transformed, structural_rank, p);
    let leakage_rel_tol = 1e-10;
    let leakage_abs_tol = 1e-12;
    let orth_tol = 1e-10;
    if leakage.max_rel_sq > leakage_rel_tol && leakage.max_abs_sq > leakage_abs_tol
        || leakage.max_cross_gram_abs > orth_tol
    {
        return Err(EstimationError::LayoutError(format!(
            "Reparameterization subspace split is inconsistent: max null leakage {:.3e} (rel {:.3e}, worst penalty {}), max |Qp'Qn| {:.3e}",
            leakage.max_abs_sq.sqrt(),
            leakage.max_rel_sq.sqrt(),
            leakage.worst_penalty,
            leakage.max_cross_gram_abs,
        )));
    }

    // u_truncated = Qs^T * q_null: the null-space basis in transformed coordinates.
    let mut u_truncated_mat = Mat::<f64>::zeros(p, q_null.ncols());
    matmul(
        u_truncated_mat.as_mut(),
        Accum::Replace,
        qs.transpose(),
        q_null.as_ref(),
        1.0,
        Par::Seq,
    );

    // Root of the total penalty: row i is sqrt(floored eigenvalue i) times
    // eigenvector i, placed in the leading penalized_rank columns.
    let mut e_transformed_mat = Mat::<f64>::zeros(structural_rank, p);
    for row_idx in 0..structural_rank {
        let safe_eigenval = range_eigs_sorted[row_idx].max(eigenvalue_floor);
        let sqrt_eigenval = safe_eigenval.sqrt();
        for j in 0..penalized_rank {
            e_transformed_mat[(row_idx, j)] = sqrt_eigenval * range_rotation[(j, row_idx)];
        }
    }

    // Log-determinant over the penalized block, using floored eigenvalues and
    // compensated summation. Non-finite or strongly negative values are errors.
    let mut floored_eigs: Vec<f64> = Vec::with_capacity(range_eigs_sorted.len());
    let mut log_det_sum = KahanSum::default();
    for (idx, &ev) in range_eigs_sorted.iter().enumerate() {
        if !ev.is_finite() || ev < -eigenvalue_floor {
            return Err(EstimationError::LayoutError(format!(
                "Penalty pseudo-logdet has a non-finite or large-negative structural eigenvalue at index {idx}: {ev:.3e}"
            )));
        }
        let safe_ev = ev.max(eigenvalue_floor);
        floored_eigs.push(safe_ev);
        if idx < penalized_rank {
            log_det_sum.add(safe_ev.ln());
        }
    }
    let log_det = log_det_sum.sum();
    let delta = 0.0;

    // det1[k] = lambda_k * trace term computed by
    // `trace_penalty_in_orthogonal_basis` (defined elsewhere in the file).
    let det1vec: Vec<f64> = (0..lambdas.len())
        .into_par_iter()
        .map(|k| {
            let s_k = &s_k_penalized_cache[k];
            let trace = trace_penalty_in_orthogonal_basis(
                s_k,
                penalized_rank,
                &range_rotation,
                &floored_eigs,
                delta,
            );
            lambdas[k] * trace
        })
        .collect();

    // Self-check: recompute det1 via an explicit similarity transform and
    // compare against the optimized values; a mismatch panics.
    {
        let mut maxdet1_mismatch = 0.0_f64;
        let mut det1_scale = 0.0_f64;
        for (k, lambda) in lambdas.iter().enumerate() {
            let s_k_penalized = &s_k_penalized_cache[k];
            let s_k_eigenbasis = orthogonal_similarity_transform_faer(
                s_k_penalized,
                penalized_rank,
                &range_rotation,
            );
            let mut trace = KahanSum::default();
            for l in 0..penalized_rank {
                trace.add(s_k_eigenbasis[(l, l)] / (floored_eigs[l] + delta));
            }
            let reference = *lambda * trace.sum();
            maxdet1_mismatch = maxdet1_mismatch.max((reference - det1vec[k]).abs());
            det1_scale = det1_scale.max(reference.abs()).max(det1vec[k].abs());
        }
        let det1_tolerance = 1e-7 * det1_scale.max(1.0);
        assert!(
            maxdet1_mismatch <= det1_tolerance,
            "det1 mismatch between optimized and reference formulas: max_abs={maxdet1_mismatch:.3e}, tol={det1_tolerance:.3e}"
        );
    }

    // Transformed total penalty: E^T * E.
    let mut s_truncated = Mat::<f64>::zeros(p, p);
    matmul(
        s_truncated.as_mut(),
        Accum::Replace,
        e_transformed_mat.transpose(),
        e_transformed_mat.as_ref(),
        1.0,
        Par::Seq,
    );

    // Self-check: rows belonging to the null space (index >= structural_rank)
    // must be numerically zero; otherwise panic.
    {
        let mut max_null_diag = 0.0_f64;
        let mut max_null_offdiag = 0.0_f64;
        for i in structural_rank..p {
            max_null_diag = max_null_diag.max(s_truncated[(i, i)].abs());
            for j in 0..p {
                if i != j {
                    max_null_offdiag = max_null_offdiag.max(s_truncated[(i, j)].abs());
                }
            }
        }
        assert!(
            max_null_diag <= 1e-10 && max_null_offdiag <= 1e-10,
            "null-space leakage in transformed penalty: max_null_diag={max_null_diag:.3e}, max_null_offdiag={max_null_offdiag:.3e}"
        );
    }

    // Rebuild canonical penalties from the rotated roots, rotating each prior
    // mean by Qs^T as well.
    let qs_array = mat_to_array(&qs);
    let canonical_transformed: Vec<CanonicalPenalty> = rs_transformed
        .par_iter()
        .zip(penalties.par_iter())
        .map(|(r, cp)| {
            let mean_transformed = qs_array.t().dot(&cp.full_width_prior_mean());
            CanonicalPenalty::from_dense_root_with_mean(mat_to_array(r), p, mean_transformed)
        })
        .collect();
    Ok(ReparamResult {
        s_transformed: mat_to_array(&s_truncated),
        log_det,
        det1: Array1::from(det1vec),
        qs: qs_array,
        canonical_transformed,
        e_transformed: mat_to_array(&e_transformed_mat),
        u_truncated: mat_to_array(&u_truncated_mat),
        penalty_shrinkage_ridge: shrinkage_ridge,
    })
}
2355
/// Problem dimensions handed to the reparameterization engine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EngineDims {
    /// Total coefficient count; forwarded as `p` to the reparameterization.
    pub p: usize,
    /// NOTE(review): not read in this part of the file - presumably the number
    /// of penalties; confirm against callers.
    pub k: usize,
}

impl EngineDims {
    /// Bundles the two dimensions into an `EngineDims`.
    pub fn new(p: usize, k: usize) -> Self {
        Self { p, k }
    }
}
2368
2369pub fn stable_reparameterization_engine_canonical(
2378 penalties: &[CanonicalPenalty],
2379 lambdas: &[f64],
2380 dims: EngineDims,
2381 cached_invariant: Option<&ReparamInvariant>,
2382 penalty_shrinkage_floor: Option<f64>,
2383) -> Result<ReparamResult, EstimationError> {
2384 let owned;
2385 let invariant = match cached_invariant {
2386 Some(inv) => inv,
2387 None => {
2388 owned = precompute_reparam_invariant_from_canonical(penalties, dims.p)?;
2389 &owned
2390 }
2391 };
2392 stable_reparameterizationwith_invariant(
2393 penalties,
2394 lambdas,
2395 dims.p,
2396 invariant,
2397 penalty_shrinkage_floor,
2398 )
2399}
2400
/// Factored (per-marginal) result of the Kronecker reparameterization.
///
/// The dense quantities can be rebuilt on demand through the `materialize_*`
/// methods.
#[derive(Clone)]
pub struct KroneckerReparamResult {
    /// NOTE(review): not read in this part of the file - presumably the
    /// marginal matrices after reparameterization; confirm at the producer.
    pub reparameterized_marginals: Arc<Vec<Array2<f64>>>,
    /// One eigenvalue vector per marginal; indexed by the per-dimension
    /// multi-index when forming each diagonal cell of the penalty.
    pub marginal_eigenvalues: Arc<Vec<Array1<f64>>>,
    /// Per-marginal factors whose Kronecker product is the full basis `Qs`.
    pub marginal_qs: Arc<Vec<Array2<f64>>>,
    /// Log-determinant reported in the dense result as `log_det`.
    pub log_det: f64,
    /// Derivative terms reported in the dense result as `det1`.
    pub det1: Array1<f64>,
    /// NOTE(review): not read in this part of the file - presumably the
    /// second-derivative matrix of the log-determinant; confirm.
    pub det2: Array2<f64>,
    /// Ridge passed to `kronecker_cell_sigma` when forming diagonal cells.
    pub penalty_shrinkage_ridge: f64,
    /// Whether an extra (double) penalty weight follows the per-marginal
    /// weights; only honored when more weights than marginals are supplied.
    pub has_double_penalty: bool,
    /// Size of each marginal; their product is the total dimension.
    pub marginal_dims: Vec<usize>,
}
2435
2436impl KroneckerReparamResult {
2437 pub fn materialize_qs(&self) -> Array2<f64> {
2440 let mut qs = Array2::<f64>::eye(1);
2441 for u_k in self.marginal_qs.iter() {
2442 qs = kronecker_product(&qs, u_k);
2443 }
2444 qs
2445 }
2446
2447 pub fn materialize_s_transformed(&self, lambdas: &[f64]) -> Array2<f64> {
2450 let d = self.marginal_dims.len();
2451 let p: usize = self.marginal_dims.iter().copied().product();
2452 let mut s = Array2::<f64>::zeros((p, p));
2453
2454 let eigenvalue_views: Vec<ArrayView1<'_, f64>> =
2458 self.marginal_eigenvalues.iter().map(|m| m.view()).collect();
2459 let has_double = self.has_double_penalty && lambdas.len() > d;
2460 let mut multi_idx = vec![0usize; d];
2461 let mut flat = 0usize;
2462 loop {
2463 let (sigma, _structural_sigma, _joint_null) = kronecker_cell_sigma(
2464 &eigenvalue_views,
2465 &multi_idx,
2466 lambdas,
2467 d,
2468 has_double,
2469 self.penalty_shrinkage_ridge,
2470 );
2471 s[[flat, flat]] = sigma;
2472 flat += 1;
2473
2474 if kronecker_multi_index_advance(&mut multi_idx, &self.marginal_dims) {
2475 break;
2476 }
2477 }
2478 s
2479 }
2480
2481 pub fn materialize_dense_artifact_result(
2484 &self,
2485 rs_list: &[Array2<f64>],
2486 lambdas: &[f64],
2487 p: usize,
2488 ) -> Result<ReparamResult, EstimationError> {
2489 const KRONECKER_DENSE_COMPAT_FALLBACK_MAX_P: usize = 4096;
2490 if p > KRONECKER_DENSE_COMPAT_FALLBACK_MAX_P {
2491 return Err(EstimationError::LayoutError(format!(
2492 "Kronecker reparameterization would materialize dense {}x{} compatibility tensors; \
2493 large-model dense fallback is disabled. Wire the downstream solver to consume \
2494 the factored Kronecker result directly",
2495 p, p
2496 )));
2497 }
2498 let qs = self.materialize_qs();
2499 let s_transformed = self.materialize_s_transformed(lambdas);
2500
2501 let rs_transformed: Vec<Array2<f64>> = if rs_list.len() >= 2 {
2503 use rayon::prelude::*;
2504 rs_list
2505 .par_iter()
2506 .map(|r| gam_linalg::faer_ndarray::fast_ab(r, &qs))
2507 .collect()
2508 } else {
2509 rs_list
2510 .iter()
2511 .map(|r| gam_linalg::faer_ndarray::fast_ab(r, &qs))
2512 .collect()
2513 };
2514 let d = self.marginal_dims.len();
2520 let eigenvalue_views: Vec<ArrayView1<'_, f64>> =
2527 self.marginal_eigenvalues.iter().map(|m| m.view()).collect();
2528 let has_double = self.has_double_penalty && lambdas.len() > d;
2529 let diag_vals: Vec<f64> = {
2530 let mut vals = Vec::with_capacity(p);
2531 let mut multi_idx = vec![0usize; d];
2532 loop {
2533 let (sigma, _structural_sigma, _joint_null) = kronecker_cell_sigma(
2534 &eigenvalue_views,
2535 &multi_idx,
2536 lambdas,
2537 d,
2538 has_double,
2539 self.penalty_shrinkage_ridge,
2540 );
2541 vals.push(if sigma > 0.0 { sigma.sqrt() } else { 0.0 });
2542
2543 if kronecker_multi_index_advance(&mut multi_idx, &self.marginal_dims) {
2544 break;
2545 }
2546 }
2547 vals
2548 };
2549 let rank = diag_vals.iter().filter(|&&v| v > 1e-12).count();
2550 let mut e_transformed = Array2::<f64>::zeros((rank, p));
2551 let mut row = 0;
2552 for (j, &v) in diag_vals.iter().enumerate() {
2553 if v > 1e-12 {
2554 e_transformed[[row, j]] = v;
2555 row += 1;
2556 }
2557 }
2558
2559 let null_count = p - rank;
2561 let mut u_truncated = Array2::<f64>::zeros((p, null_count));
2562 let mut col = 0;
2563 for (j, &v) in diag_vals.iter().enumerate() {
2564 if v <= 1e-12 {
2565 u_truncated[[j, col]] = 1.0; col += 1;
2567 }
2568 }
2569
2570 let canonical_transformed: Vec<CanonicalPenalty> = rs_transformed
2571 .iter()
2572 .map(|r| CanonicalPenalty::from_dense_root(r.clone(), p))
2573 .collect();
2574 Ok(ReparamResult {
2575 s_transformed,
2576 log_det: self.log_det,
2577 det1: self.det1.clone(),
2578 qs,
2579 canonical_transformed,
2580 e_transformed,
2581 u_truncated,
2582 penalty_shrinkage_ridge: self.penalty_shrinkage_ridge,
2583 })
2584 }
2585}
2586
/// Threshold on the unweighted sum of marginal eigenvalues of a tensor-product
/// cell: at or below it the cell is treated as jointly unpenalized, above it
/// the cell receives the shrinkage ridge.
const KRONECKER_STRUCTURAL_ZERO_TOL: f64 = 1e-12;
2594
2595#[inline]
2609fn kronecker_cell_sigma(
2610 marginal_eigenvalues: &[ArrayView1<'_, f64>],
2611 multi_idx: &[usize],
2612 lambdas: &[f64],
2613 d: usize,
2614 has_double_penalty: bool,
2615 ridge: f64,
2616) -> (f64, f64, bool) {
2617 let mut sigma = 0.0;
2618 let mut structural_sigma = 0.0;
2619 for k in 0..d {
2620 let marginal_eigenvalue = marginal_eigenvalues[k][multi_idx[k]];
2621 structural_sigma += marginal_eigenvalue;
2622 sigma += lambdas[k] * marginal_eigenvalue;
2623 }
2624 let joint_null = structural_sigma <= KRONECKER_STRUCTURAL_ZERO_TOL;
2625 if has_double_penalty && joint_null {
2626 sigma += lambdas[d];
2627 }
2628 if structural_sigma > KRONECKER_STRUCTURAL_ZERO_TOL {
2629 sigma += ridge;
2630 }
2631 (sigma, structural_sigma, joint_null)
2632}
2633
/// Steps `multi_idx` to the next cell of the grid described by `dims`, with the
/// last dimension varying fastest.
///
/// Returns `true` when the index wrapped around past the final cell (every
/// component is reset to zero), and `false` otherwise. Only the first
/// `dims.len()` entries of `multi_idx` are touched.
#[inline]
fn kronecker_multi_index_advance(multi_idx: &mut [usize], dims: &[usize]) -> bool {
    let digits = &mut multi_idx[..dims.len()];
    for (digit, &extent) in digits.iter_mut().zip(dims).rev() {
        *digit += 1;
        if *digit < extent {
            // No carry needed: higher-order digits stay untouched.
            return false;
        }
        // This digit overflowed; reset it and carry into the next one.
        *digit = 0;
    }
    true
}
2651
2652pub fn kronecker_logdet_and_derivatives(
2653 marginal_eigenvalues: &[ArrayView1<'_, f64>],
2654 marginal_dims: &[usize],
2655 lambdas: &[f64],
2656 has_double_penalty: bool,
2657 ridge: f64,
2658) -> (f64, Array1<f64>, Array2<f64>) {
2659 let d = marginal_dims.len();
2660 let n_pen = d + if has_double_penalty { 1 } else { 0 };
2661
2662 let mut logdet = 0.0;
2663 let mut grad = Array1::<f64>::zeros(n_pen);
2664 let mut hess = Array2::<f64>::zeros((n_pen, n_pen));
2665 let tol = 1e-12;
2666
2667 let mut multi_idx = vec![0usize; d];
2668 loop {
2669 let (sigma, _structural_sigma, joint_null) = kronecker_cell_sigma(
2670 marginal_eigenvalues,
2671 &multi_idx,
2672 lambdas,
2673 d,
2674 has_double_penalty,
2675 ridge,
2676 );
2677
2678 if sigma > tol {
2679 logdet += sigma.ln();
2680 let inv_sigma = 1.0 / sigma;
2681 let inv_sigma2 = inv_sigma * inv_sigma;
2682
2683 for k in 0..d {
2684 let ck = lambdas[k] * marginal_eigenvalues[k][multi_idx[k]];
2685 grad[k] += ck * inv_sigma;
2686 }
2687 if has_double_penalty && joint_null {
2688 grad[d] += lambdas[d] * inv_sigma;
2689 }
2690
2691 for k in 0..n_pen {
2692 let ck = if k < d {
2693 lambdas[k] * marginal_eigenvalues[k][multi_idx[k]]
2694 } else if joint_null {
2695 lambdas[d]
2696 } else {
2697 0.0
2698 };
2699 if ck == 0.0 {
2706 continue;
2707 }
2708 hess[[k, k]] += ck * inv_sigma - ck * ck * inv_sigma2;
2709 for l in (k + 1)..n_pen {
2710 let cl = if l < d {
2711 lambdas[l] * marginal_eigenvalues[l][multi_idx[l]]
2712 } else if joint_null {
2713 lambdas[d]
2714 } else {
2715 0.0
2716 };
2717 let off = -ck * cl * inv_sigma2;
2718 hess[[k, l]] += off;
2719 hess[[l, k]] += off;
2720 }
2721 }
2722 }
2723
2724 if kronecker_multi_index_advance(&mut multi_idx, marginal_dims) {
2725 break;
2726 }
2727 }
2728
2729 (logdet, grad, hess)
2730}
2731
2732use crate::kronecker::KroneckerInvariantStructure;
2736
2737pub fn kronecker_reparameterization_engine(
2743 marginal_designs: &[Array2<f64>],
2744 marginal_penalties: &[Array2<f64>],
2745 marginal_dims: &[usize],
2746 lambdas: &[f64],
2747 has_double_penalty: bool,
2748 penalty_shrinkage_floor: Option<f64>,
2749) -> Result<KroneckerReparamResult, EstimationError> {
2750 let d = marginal_dims.len();
2751 if marginal_designs.len() != d || marginal_penalties.len() != d {
2752 return Err(EstimationError::LayoutError(format!(
2753 "kronecker_reparameterization_engine: dimension mismatch: designs={}, penalties={}, dims={}",
2754 marginal_designs.len(),
2755 marginal_penalties.len(),
2756 d
2757 )));
2758 }
2759
2760 let invariant =
2761 KroneckerInvariantStructure::compute(marginal_designs, marginal_penalties, marginal_dims)?;
2762 kronecker_reparameterization_engine_with_invariant(
2763 &invariant,
2764 marginal_dims,
2765 lambdas,
2766 has_double_penalty,
2767 penalty_shrinkage_floor,
2768 )
2769}
2770
2771pub fn kronecker_reparameterization_engine_with_invariant(
2779 invariant: &KroneckerInvariantStructure,
2780 marginal_dims: &[usize],
2781 lambdas: &[f64],
2782 has_double_penalty: bool,
2783 penalty_shrinkage_floor: Option<f64>,
2784) -> Result<KroneckerReparamResult, EstimationError> {
2785 let marginal_eigenvalues = Arc::clone(&invariant.marginal_eigenvalues);
2788 let marginal_qs = Arc::clone(&invariant.marginal_qs);
2789 let reparameterized_marginals = Arc::clone(&invariant.reparameterized_marginals);
2790
2791 let penalty_shrinkage_ridge = if let Some(floor) = penalty_shrinkage_floor {
2793 floor * invariant.max_balanced_eigenvalue
2794 } else {
2795 0.0
2796 };
2797
2798 let marginal_eigenvalue_views: Vec<_> = marginal_eigenvalues
2799 .iter()
2800 .map(|evals| evals.view())
2801 .collect();
2802 let (log_det, det1, det2) = kronecker_logdet_and_derivatives(
2803 &marginal_eigenvalue_views,
2804 marginal_dims,
2805 lambdas,
2806 has_double_penalty,
2807 penalty_shrinkage_ridge,
2808 );
2809
2810 Ok(KroneckerReparamResult {
2811 reparameterized_marginals,
2812 marginal_eigenvalues,
2813 marginal_qs,
2814 log_det,
2815 det1,
2816 det2,
2817 penalty_shrinkage_ridge,
2818 has_double_penalty,
2819 marginal_dims: marginal_dims.to_vec(),
2820 })
2821}
2822
2823pub fn calculate_condition_number(matrix: &Array2<f64>) -> Result<f64, FaerLinalgError> {
2843 let (rows, cols) = matrix.dim();
2844 if rows == 0 || cols == 0 {
2845 return Ok(1.0);
2846 }
2847
2848 if rows == cols {
2850 let mut max_abs = 0.0_f64;
2851 let mut max_asym = 0.0_f64;
2852 for i in 0..rows {
2853 for j in 0..cols {
2854 max_abs = max_abs.max(matrix[[i, j]].abs());
2855 }
2856 for j in 0..i {
2857 let diff = (matrix[[i, j]] - matrix[[j, i]]).abs();
2858 if diff > max_asym {
2859 max_asym = diff;
2860 }
2861 }
2862 }
2863 let sym_tol = max_abs.max(1.0) * 1e-12;
2864 if max_asym <= sym_tol {
2865 let (evals, _) = matrix.eigh(Side::Lower)?;
2866 let mut max_abs_eval = 0.0_f64;
2867 let mut min_abs_eval = f64::INFINITY;
2868 for &lam in evals.iter() {
2869 let s = lam.abs();
2870 max_abs_eval = max_abs_eval.max(s);
2871 min_abs_eval = min_abs_eval.min(s);
2872 }
2873 if min_abs_eval < 1e-12 {
2874 return Ok(f64::INFINITY);
2875 }
2876 return Ok(max_abs_eval / min_abs_eval);
2877 }
2878 }
2879
2880 let (_, s, _) = matrix.svd(false, false)?;
2882 let max_sv = s.iter().fold(0.0_f64, |max, &val| max.max(val));
2883 let min_sv = s.iter().fold(f64::INFINITY, |min, &val| min.min(val));
2884 if min_sv < 1e-12 {
2885 return Ok(f64::INFINITY);
2886 }
2887 Ok(max_sv / min_sv)
2888}
2889
2890#[cfg(test)]
2891mod tests {
2892 use super::{
2893 CanonicalPenalty, SubspaceLeakageMetrics, assess_subspace_leakage,
2894 classify_eigenvalues_strict, precompute_reparam_invariant_from_canonical,
2895 report_penalty_pair_redundancy, stable_reparameterizationwith_invariant,
2896 };
2897 use crate::construction::kronecker_product;
2898 use crate::EstimationError;
2899 use faer::Mat;
2900 use gam_linalg::faer_ndarray::FaerEigh;
2901 use gam_linalg::utils::inf_norm;
2902 use ndarray::{Array1, Array2, array};
2903
2904 fn canonical_from_roots(rs_list: &[Array2<f64>], p: usize) -> Vec<CanonicalPenalty> {
2906 rs_list
2907 .iter()
2908 .map(|r| {
2909 let local = r.t().dot(r);
2910 CanonicalPenalty {
2911 root: r.clone(),
2912 col_range: 0..p,
2913 total_dim: p,
2914 nullity: 0,
2915 local,
2916 prior_mean: Array1::zeros(p),
2917 positive_eigenvalues: Vec::new(),
2918 op: None,
2919 }
2920 })
2921 .collect()
2922 }
2923
    /// Test helper: forwards its arguments unchanged to `assess_subspace_leakage`
    /// and returns the resulting metrics.
    fn metrics_for(
        qs: &Mat<f64>,
        rs: &[Mat<f64>],
        structural_rank: usize,
        p: usize,
    ) -> SubspaceLeakageMetrics {
        assess_subspace_leakage(qs, rs, structural_rank, p)
    }
2932
2933 #[test]
2934 fn subspace_leakage_iszero_for_clean_split() {
2935 let p = 4usize;
2936 let structural_rank = 2usize;
2937 let qs = Mat::<f64>::identity(p, p);
2938 let mut r0 = Mat::<f64>::zeros(2, p);
2939 r0[(0, 0)] = 1.0;
2940 r0[(1, 1)] = 2.0;
2941
2942 let m = metrics_for(&qs, &[r0], structural_rank, p);
2943 assert!(m.max_abs_sq <= 1e-16);
2944 assert!(m.max_rel_sq <= 1e-16);
2945 assert!(m.max_cross_gram_abs <= 1e-16);
2946 }
2947
2948 #[test]
2949 fn subspace_leakage_detects_null_column_energy() {
2950 let p = 4usize;
2951 let structural_rank = 2usize;
2952 let qs = Mat::<f64>::identity(p, p);
2953 let mut r0 = Mat::<f64>::zeros(1, p);
2954 r0[(0, 2)] = 3.0;
2955
2956 let m = metrics_for(&qs, &[r0], structural_rank, p);
2957 assert!(m.max_abs_sq > 0.0);
2958 assert!(m.max_rel_sq > 0.99);
2959 }
2960
2961 #[test]
2962 fn subspace_leakage_detects_qp_qn_nonorthogonality() {
2963 let p = 3usize;
2964 let structural_rank = 1usize;
2965 let mut qs = Mat::<f64>::identity(p, p);
2966 qs[(0, 1)] = 0.2;
2967 let r0 = Mat::<f64>::zeros(1, p);
2968
2969 let m = metrics_for(&qs, &[r0], structural_rank, p);
2970 assert!(m.max_cross_gram_abs > 1e-3);
2971 }
2972
2973 #[test]
2974 fn u_truncated_is_transformed_frame_in_nonzero_case() {
2975 let p = 3usize;
2976 let rs_list = vec![array![[1.0, 0.0, 0.0]]];
2977 let canonical = canonical_from_roots(&rs_list, p);
2978 let lambdas = vec![2.0];
2979 let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
2980 .expect("precompute invariant");
2981 let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
2982 .expect("stable reparam");
2983
2984 let expected = rep.qs.t().dot(&inv.split.q_null);
2985 let diff = &rep.u_truncated - &expected;
2986 let max_abs = inf_norm(diff.iter().copied());
2987 assert!(
2988 max_abs <= 1e-10,
2989 "u_truncated frame mismatch: max_abs={max_abs}"
2990 );
2991 }
2992
2993 #[test]
2994 fn infinite_lambda_keeps_range_penalty_block_finite_1379() {
2995 let p = 3usize;
3008 let rs_list = vec![array![[1.0, 0.0, 0.0]], array![[0.0, 1.0, 0.0]]];
3009 let canonical = canonical_from_roots(&rs_list, p);
3010 let lambdas = vec![f64::INFINITY, 3.0];
3011 let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3012 .expect("precompute invariant");
3013 let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
3014 .expect("stable reparam must not abort on an infinite lambda (gam#1379)");
3015
3016 assert!(
3017 rep.s_transformed.iter().all(|v| v.is_finite()),
3018 "transformed penalty must be finite with an infinite lambda"
3019 );
3020 assert!(
3021 rep.qs.iter().all(|v| v.is_finite()),
3022 "reparam rotation must be finite with an infinite lambda"
3023 );
3024 assert!(
3025 rep.log_det.is_finite(),
3026 "penalty log-det must be finite with an infinite lambda"
3027 );
3028 assert!(
3029 rep.det1.iter().all(|v| v.is_finite()),
3030 "penalty log-det derivatives must be finite with an infinite lambda"
3031 );
3032 }
3033
3034 #[test]
3035 fn u_truncated_is_identitywhen_no_penalties() {
3036 let p = 4usize;
3037 let canonical: Vec<CanonicalPenalty> = Vec::new();
3038 let lambdas: Vec<f64> = Vec::new();
3039 let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3040 .expect("precompute invariant");
3041 let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
3042 .expect("stable reparam");
3043 assert_eq!(rep.u_truncated, Array2::<f64>::eye(p));
3044 }
3045
3046 #[test]
3047 fn dense_shrinkage_floor_skips_structurally_unpenalized_range_columns() {
3048 let p = 3usize;
3049 let canonical = canonical_from_roots(&[array![[1.0, 0.0, 0.0]]], p);
3050 let invariant = super::ReparamInvariant {
3051 split: super::SubspaceSplit {
3052 q_pen: array![[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]],
3053 q_null: array![[0.0], [0.0], [1.0]],
3054 },
3055 qs_base: Array2::eye(p),
3056 has_nonzero: true,
3057 max_balanced_eigenvalue: 1.0,
3058 };
3059
3060 let rep =
3061 stable_reparameterizationwith_invariant(&canonical, &[2.0], p, &invariant, Some(1e-6))
3062 .expect("stable reparameterization");
3063 assert!(rep.s_transformed[[0, 0]] > 2.0);
3064 assert!(
3065 rep.s_transformed[[1, 1]] <= 1e-11,
3066 "structurally unpenalized range coordinate received shrinkage ridge: {}",
3067 rep.s_transformed[[1, 1]]
3068 );
3069 }
3070
3071 #[test]
3072 fn kronecker_shrinkage_floor_preserves_joint_null_space() {
3073 let marginal_designs = vec![Array2::<f64>::eye(2), Array2::<f64>::eye(2)];
3074 let marginal_penalties = vec![
3075 array![[0.0, 0.0], [0.0, 2.0]],
3076 array![[0.0, 0.0], [0.0, 3.0]],
3077 ];
3078 let marginal_dims = vec![2usize, 2usize];
3079 let lambdas = vec![5.0, 7.0];
3080
3081 let rep = super::kronecker_reparameterization_engine(
3082 &marginal_designs,
3083 &marginal_penalties,
3084 &marginal_dims,
3085 &lambdas,
3086 false,
3087 Some(1e-6),
3088 )
3089 .expect("kronecker reparameterization");
3090 assert!(rep.penalty_shrinkage_ridge > 0.0);
3091
3092 let s = rep.materialize_s_transformed(&lambdas);
3093 assert!(
3094 s[[0, 0]].abs() <= 1e-14,
3095 "joint tensor null direction must remain unpenalized, got {}",
3096 s[[0, 0]]
3097 );
3098 assert!(s[[1, 1]] > lambdas[1] * 3.0);
3099 assert!(s[[2, 2]] > lambdas[0] * 2.0);
3100 assert!(s[[3, 3]] > lambdas[0] * 2.0 + lambdas[1] * 3.0);
3101
3102 let tensor_roots = vec![
3103 array![
3104 [0.0, 0.0, 2.0_f64.sqrt(), 0.0],
3105 [0.0, 0.0, 0.0, 2.0_f64.sqrt()]
3106 ],
3107 array![
3108 [0.0, 3.0_f64.sqrt(), 0.0, 0.0],
3109 [0.0, 0.0, 0.0, 3.0_f64.sqrt()]
3110 ],
3111 ];
3112 let dense = rep
3113 .materialize_dense_artifact_result(&tensor_roots, &lambdas, 4)
3114 .expect("dense artifact materialization");
3115 assert_eq!(dense.e_transformed.nrows(), 3);
3116 assert_eq!(dense.u_truncated.ncols(), 1);
3117 }
3118
    /// Checks that the memoized entry point reproduces the unmemoized engine
    /// bit for bit.
    ///
    /// The invariant structure is computed once. For every combination of
    /// lambda vector and shrinkage floor (`None` / `Some(1e-6)`), both engines
    /// are run with `has_double_penalty = true` and the `to_bits()` patterns of
    /// `log_det`, `penalty_shrinkage_ridge`, `det1`, `det2`,
    /// `reparameterized_marginals` and `marginal_qs` are compared for equality.
    ///
    /// NOTE(review): each `lambdas` vector has two entries although
    /// `has_double_penalty` is `true`; `kronecker_cell_sigma` reads `lambdas[d]`
    /// only on joint-null cells, so this presumably relies on the fixture
    /// producing none — confirm.
    #[test]
    fn kronecker_memoized_invariant_is_bit_identical_to_unmemoized_engine() {
        let marginal_designs = vec![
            array![[1.0, 0.3, -0.2], [0.4, 1.0, 0.1], [-0.1, 0.2, 1.0]],
            array![[1.0, -0.5], [0.2, 1.0], [0.7, 0.3]],
        ];
        let marginal_penalties = vec![
            array![[2.0, -1.0, 0.0], [-1.0, 2.0, -1.0], [0.0, -1.0, 1.0]],
            array![[3.0, -1.5], [-1.5, 3.0]],
        ];
        let marginal_dims = vec![3usize, 2usize];

        // Shared invariant handed to the memoized entry point on every iteration.
        let invariant = super::KroneckerInvariantStructure::compute(
            &marginal_designs,
            &marginal_penalties,
            &marginal_dims,
        )
        .expect("invariant structure");

        // Lambda grid includes exact zeros and a widely separated pair.
        for lambdas in [
            vec![5.0, 7.0],
            vec![0.0, 7.0],
            vec![5.0, 0.0],
            vec![1e-3, 1e3],
        ] {
            for floor in [None, Some(1e-6)] {
                let unmemoized = super::kronecker_reparameterization_engine(
                    &marginal_designs,
                    &marginal_penalties,
                    &marginal_dims,
                    &lambdas,
                    true,
                    floor,
                )
                .expect("unmemoized engine");
                let memoized = super::kronecker_reparameterization_engine_with_invariant(
                    &invariant,
                    &marginal_dims,
                    &lambdas,
                    true,
                    floor,
                )
                .expect("memoized engine");

                // Scalars: compare raw bit patterns rather than using a tolerance.
                assert_eq!(memoized.log_det.to_bits(), unmemoized.log_det.to_bits());
                assert_eq!(
                    memoized.penalty_shrinkage_ridge.to_bits(),
                    unmemoized.penalty_shrinkage_ridge.to_bits()
                );
                // Derivative containers, element by element.
                for (a, b) in memoized.det1.iter().zip(unmemoized.det1.iter()) {
                    assert_eq!(a.to_bits(), b.to_bits());
                }
                for (a, b) in memoized.det2.iter().zip(unmemoized.det2.iter()) {
                    assert_eq!(a.to_bits(), b.to_bits());
                }
                // Per-marginal matrices, element by element.
                for (ma, ua) in memoized
                    .reparameterized_marginals
                    .iter()
                    .zip(unmemoized.reparameterized_marginals.iter())
                {
                    for (a, b) in ma.iter().zip(ua.iter()) {
                        assert_eq!(a.to_bits(), b.to_bits());
                    }
                }
                for (mq, uq) in memoized
                    .marginal_qs
                    .iter()
                    .zip(unmemoized.marginal_qs.iter())
                {
                    for (a, b) in mq.iter().zip(uq.iter()) {
                        assert_eq!(a.to_bits(), b.to_bits());
                    }
                }
            }
        }
    }
3201
    /// Two 2×2 marginals with penalties `diag(0, 2)` and `diag(0, 3)`,
    /// `lambdas = [5, 7, 11]`, double penalty enabled and no shrinkage floor.
    ///
    /// The expected diagonal `[11, 21, 10, 31]` is asserted for both the
    /// materialized penalty and the dense artifact. The test also pins
    /// `log_det` to the sum of the logs of those four values, `det1[2]` to `1`
    /// and `det2[[2, 2]]` to `0`, each within `1e-12`.
    #[test]
    fn kronecker_double_penalty_shrinks_only_joint_null_space() {
        let marginal_designs = vec![Array2::<f64>::eye(2), Array2::<f64>::eye(2)];
        let marginal_penalties = vec![
            array![[0.0, 0.0], [0.0, 2.0]],
            array![[0.0, 0.0], [0.0, 3.0]],
        ];
        let marginal_dims = vec![2usize, 2usize];
        // Third entry is the double-penalty lambda.
        let lambdas = vec![5.0, 7.0, 11.0];

        let rep = super::kronecker_reparameterization_engine(
            &marginal_designs,
            &marginal_penalties,
            &marginal_dims,
            &lambdas,
            true,
            None,
        )
        .expect("kronecker reparameterization");

        let s = rep.materialize_s_transformed(&lambdas);
        // 11 = lambdas[2]; 21 = 7 * 3; 10 = 5 * 2; 31 = 10 + 21.
        let expected = [11.0, 21.0, 10.0, 31.0];
        for (idx, expected_diag) in expected.iter().copied().enumerate() {
            assert!(
                (s[[idx, idx]] - expected_diag).abs() <= 1e-12,
                "diagonal {idx} got {}, expected {expected_diag}",
                s[[idx, idx]]
            );
        }

        let expected_logdet: f64 = expected.iter().map(|v| f64::ln(*v)).sum();
        assert!((rep.log_det - expected_logdet).abs() <= 1e-12);
        assert!(
            (rep.det1[2] - 1.0).abs() <= 1e-12,
            "double-penalty derivative must come only from the joint null mode, got {}",
            rep.det1[2]
        );
        assert!(rep.det2[[2, 2]].abs() <= 1e-12);

        // Roots R with RᵀR = diag(0, 0, 2, 2) and diag(0, 3, 0, 3) respectively.
        let tensor_roots = vec![
            array![
                [0.0, 0.0, 2.0_f64.sqrt(), 0.0],
                [0.0, 0.0, 0.0, 2.0_f64.sqrt()]
            ],
            array![
                [0.0, 3.0_f64.sqrt(), 0.0, 0.0],
                [0.0, 0.0, 0.0, 3.0_f64.sqrt()]
            ],
        ];
        let dense = rep
            .materialize_dense_artifact_result(&tensor_roots, &lambdas, 4)
            .expect("dense artifact materialization");
        for (idx, expected_diag) in expected.iter().copied().enumerate() {
            assert!(
                (dense.s_transformed[[idx, idx]] - expected_diag).abs() <= 1e-12,
                "dense artifact diagonal {idx} got {}, expected {expected_diag}",
                dense.s_transformed[[idx, idx]]
            );
        }
    }
3262
3263 #[test]
3264 fn transformed_penalty_is_diagonal_in_transformed_frame() {
3265 let p = 3usize;
3266 let inv_sqrt2 = 2.0_f64.sqrt().recip();
3267 let rs_list = vec![array![[inv_sqrt2, inv_sqrt2, 0.0]]];
3269 let canonical = canonical_from_roots(&rs_list, p);
3270 let lambdas = vec![4.0];
3271 let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3272 .expect("precompute invariant");
3273 let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
3274 .expect("stable reparam");
3275
3276 assert_eq!(rep.e_transformed.nrows(), 1);
3277 assert!(rep.e_transformed[[0, 0]].abs() > 0.0);
3278 assert!(rep.e_transformed[[0, 1]].abs() <= 1e-12);
3279 assert!(rep.e_transformed[[0, 2]].abs() <= 1e-12);
3280 let expected_det1 = 1.0_f64;
3283 assert!((rep.det1[0] - expected_det1).abs() <= 1e-12);
3284
3285 let s = rep.s_transformed;
3286 let mut max_offdiag = 0.0_f64;
3287 for i in 0..p {
3288 for j in 0..p {
3289 if i != j {
3290 max_offdiag = max_offdiag.max(s[[i, j]].abs());
3291 }
3292 }
3293 }
3294 assert!(
3295 max_offdiag <= 1e-10,
3296 "transformed penalty should be diagonal, max offdiag={max_offdiag}"
3297 );
3298 assert!(s[[1, 1]].abs() <= 1e-10);
3299 assert!(s[[2, 2]].abs() <= 1e-10);
3300 }
3301
3302 #[test]
3303 fn det1_matches_rank_for_single_full_rank_penalty() {
3304 let p = 2usize;
3305 let inv_sqrt2 = 2.0_f64.sqrt().recip();
3306 let q_t = [[inv_sqrt2, inv_sqrt2], [-inv_sqrt2, inv_sqrt2]];
3308 let rs = array![
3310 [3.0 * q_t[0][0], 3.0 * q_t[0][1]],
3311 [1.0 * q_t[1][0], 1.0 * q_t[1][1]]
3312 ];
3313 let rs_list = vec![rs];
3314 let canonical = canonical_from_roots(&rs_list, p);
3315 let lambdas = vec![5.0];
3316
3317 let inv = precompute_reparam_invariant_from_canonical(&canonical, p)
3318 .expect("precompute invariant");
3319 let rep = stable_reparameterizationwith_invariant(&canonical, &lambdas, p, &inv, None)
3320 .expect("stable reparam");
3321
3322 assert_eq!(rep.e_transformed.nrows(), p);
3323 let det1 = rep.det1[0];
3324 let s_k_eigs = [9.0_f64, 1.0_f64];
3328 let lambda = 5.0_f64;
3329 let expected_det1: f64 = s_k_eigs.iter().map(|&d| lambda * d / (lambda * d)).sum();
3330 assert!(
3331 (det1 - expected_det1).abs() <= 1e-12,
3332 "expected det1={expected_det1}, got {det1}",
3333 );
3334
3335 let s = rep.s_transformed;
3336 assert!(s[[0, 1]].abs() <= 1e-10);
3337 assert!(s[[1, 0]].abs() <= 1e-10);
3338 assert!(s[[0, 0]] > 0.0);
3339 assert!(s[[1, 1]] > 0.0);
3340 }
3341
    /// Cross-checks the Kronecker engine against a dense reference.
    ///
    /// 1. Builds tridiagonal marginal penalties `s1` (3×3) and `s2` (4×4).
    /// 2. Forms the dense penalty `lambdas[0] * (s1 ⊗ I) + lambdas[1] * (I ⊗ s2)`
    ///    and sums `ln` of its eigenvalues above `1e-12` as the reference.
    /// 3. Requires the engine's `log_det` (identity designs, no double penalty,
    ///    no floor) to match the reference within `1e-8`.
    /// 4. Compares `det1[k]` against a central finite difference of `log_det`
    ///    in `rho_k = ln(lambda_k)` with step `1e-5`, requiring a relative error
    ///    below `1e-4` (absolute when the analytic value is at most `1e-10` in
    ///    magnitude).
    #[test]
    fn kronecker_reparam_logdet_matches_dense() {
        let q1 = 3;
        let q2 = 4;
        let s1 = {
            let mut s = Array2::<f64>::zeros((q1, q1));
            s[[0, 0]] = 1.0;
            s[[0, 1]] = -1.0;
            s[[1, 0]] = -1.0;
            s[[1, 1]] = 2.0;
            s[[1, 2]] = -1.0;
            s[[2, 1]] = -1.0;
            s[[2, 2]] = 1.0;
            s
        };
        let s2 = {
            let mut s = Array2::<f64>::zeros((q2, q2));
            s[[0, 0]] = 1.0;
            s[[0, 1]] = -1.0;
            s[[1, 0]] = -1.0;
            s[[1, 1]] = 2.0;
            s[[1, 2]] = -1.0;
            s[[2, 1]] = -1.0;
            s[[2, 2]] = 2.0;
            s[[2, 3]] = -1.0;
            s[[3, 2]] = -1.0;
            s[[3, 3]] = 1.0;
            s
        };

        let lambdas = [2.5, 1.3];
        let p = q1 * q2;
        // Dense reference: lambda-weighted sum of the two Kronecker-lifted penalties.
        let i1 = Array2::<f64>::eye(q1);
        let i2 = Array2::<f64>::eye(q2);
        let pen0 = kronecker_product(&s1, &i2);
        let pen1 = kronecker_product(&i1, &s2);
        let mut s_dense = Array2::<f64>::zeros((p, p));
        s_dense.scaled_add(lambdas[0], &pen0);
        s_dense.scaled_add(lambdas[1], &pen1);

        let (evals_dense, _): (ndarray::Array1<f64>, ndarray::Array2<f64>) =
            s_dense.eigh(faer::Side::Lower).unwrap();
        // Only eigenvalues above the tolerance enter the reference log-determinant.
        let tol = 1e-12;
        let ref_logdet: f64 = evals_dense
            .iter()
            .filter(|&&v: &&f64| v > tol)
            .map(|&v: &f64| v.ln())
            .sum();

        let marginal_designs = vec![
            Array2::<f64>::eye(q1), Array2::<f64>::eye(q2),
        ];
        let marginal_penalties = vec![s1, s2];
        let kron_result = super::kronecker_reparameterization_engine(
            &marginal_designs,
            &marginal_penalties,
            &[q1, q2],
            &lambdas,
            false,
            None,
        )
        .unwrap();

        let diff = (kron_result.log_det - ref_logdet).abs();
        assert!(
            diff < 1e-8,
            "Kronecker logdet {:.10} vs dense {:.10}, diff={:.3e}",
            kron_result.log_det,
            ref_logdet,
            diff,
        );

        // Finite-difference check of det1 in log-lambda coordinates.
        let rhos: Vec<f64> = lambdas.iter().map(|&l| l.ln()).collect();
        let eps = 1e-5;
        for k in 0..2 {
            let mut rho_plus = rhos.clone();
            rho_plus[k] += eps;
            let mut rho_minus = rhos.clone();
            rho_minus[k] -= eps;
            let lam_plus: Vec<f64> = rho_plus.iter().map(|&r| r.exp()).collect();
            let lam_minus: Vec<f64> = rho_minus.iter().map(|&r| r.exp()).collect();
            let result_plus = super::kronecker_reparameterization_engine(
                &marginal_designs,
                &marginal_penalties,
                &[q1, q2],
                &lam_plus,
                false,
                None,
            )
            .unwrap();
            let result_minus = super::kronecker_reparameterization_engine(
                &marginal_designs,
                &marginal_penalties,
                &[q1, q2],
                &lam_minus,
                false,
                None,
            )
            .unwrap();
            // Central difference in rho_k.
            let fd_deriv = (result_plus.log_det - result_minus.log_det) / (2.0 * eps);
            let analytic_deriv = kron_result.det1[k];
            // Fall back to an absolute error when the analytic value is tiny.
            let rel_err = if analytic_deriv.abs() > 1e-10 {
                (fd_deriv - analytic_deriv).abs() / analytic_deriv.abs()
            } else {
                (fd_deriv - analytic_deriv).abs()
            };
            assert!(
                rel_err < 1e-4,
                "det1[{k}] mismatch: analytic={:.8}, fd={:.8}, rel_err={:.3e}",
                analytic_deriv,
                fd_deriv,
                rel_err,
            );
        }
    }
3465
3466 #[test]
3467 fn classify_strict_rejects_nan_eigenvalue() {
3468 let mut eigs = [1.0, f64::NAN, 0.5];
3469 match classify_eigenvalues_strict(&mut eigs, "test_nan") {
3470 Err(EstimationError::PenaltySpectrumNonFinite {
3471 context,
3472 index,
3473 value,
3474 }) => {
3475 assert_eq!(context, "test_nan");
3476 assert_eq!(index, 1);
3477 assert!(value.is_nan());
3478 }
3479 other => panic!("expected PenaltySpectrumNonFinite, got {:?}", other),
3480 }
3481 }
3482
3483 #[test]
3484 fn classify_strict_rejects_inf_eigenvalue() {
3485 let mut eigs = [1.0, 0.5, f64::INFINITY];
3486 match classify_eigenvalues_strict(&mut eigs, "test_inf") {
3487 Err(EstimationError::PenaltySpectrumNonFinite { index, value, .. }) => {
3488 assert_eq!(index, 2);
3489 assert!(value.is_infinite());
3490 }
3491 other => panic!("expected PenaltySpectrumNonFinite, got {:?}", other),
3492 }
3493 }
3494
3495 #[test]
3496 fn classify_strict_rejects_materially_indefinite() {
3497 let mut eigs = [1.0, -1e-2, 0.5];
3499 match classify_eigenvalues_strict(&mut eigs, "test_indef") {
3500 Err(EstimationError::PenaltySpectrumIndefinite {
3501 context,
3502 index,
3503 value,
3504 ..
3505 }) => {
3506 assert_eq!(context, "test_indef");
3507 assert_eq!(index, 1);
3508 assert!((value + 1e-2).abs() <= 1e-15);
3509 }
3510 other => panic!("expected PenaltySpectrumIndefinite, got {:?}", other),
3511 }
3512 }
3513
3514 #[test]
3515 fn classify_strict_accepts_roundoff_negative() {
3516 let scale = 1.0_f64;
3518 let roundoff = -1e-16 * scale;
3519 let mut eigs = [scale, 0.5 * scale, roundoff, 0.25 * scale];
3520 classify_eigenvalues_strict(&mut eigs, "test_roundoff").expect("roundoff must classify");
3521 assert_eq!(eigs[2], 0.0);
3523 assert!(eigs[0] > 0.0 && eigs[1] > 0.0 && eigs[3] > 0.0);
3525 }
3526
3527 #[test]
3528 fn classify_strict_snaps_subtol_positive_to_zero() {
3529 let scale = 10.0_f64;
3532 let subtol = 1e-15 * scale;
3533 let mut eigs = [scale, subtol];
3534 classify_eigenvalues_strict(&mut eigs, "test_sub_pos").expect("sub-tol positive ok");
3535 assert_eq!(eigs[1], 0.0);
3536 }
3537
3538 fn canonical_from_local(
3542 local: Array2<f64>,
3543 col_range: std::ops::Range<usize>,
3544 total_dim: usize,
3545 ) -> CanonicalPenalty {
3546 let block_dim = local.nrows();
3547 let root = Array2::<f64>::zeros((0, block_dim));
3549 CanonicalPenalty {
3550 root,
3551 col_range,
3552 total_dim,
3553 nullity: 0,
3554 local,
3555 prior_mean: Array1::zeros(block_dim),
3556 positive_eigenvalues: Vec::new(),
3557 op: None,
3558 }
3559 }
3560
3561 #[test]
3562 fn report_penalty_pair_redundancy_detects_identical_pair() {
3563 let s0 = ndarray::array![[2.0, 0.5, 0.0], [0.5, 1.0, 0.25], [0.0, 0.25, 1.5],];
3565 let s_shared = ndarray::array![[1.0, -0.5, 0.0], [-0.5, 2.0, -0.5], [0.0, -0.5, 1.0],];
3568
3569 let bundle = vec![
3570 canonical_from_local(s0, 0..3, 3),
3571 canonical_from_local(s_shared.clone(), 0..3, 3),
3572 canonical_from_local(s_shared, 0..3, 3),
3573 ];
3574
3575 let redundant = report_penalty_pair_redundancy(&bundle);
3576
3577 assert_eq!(
3580 redundant.len(),
3581 1,
3582 "expected exactly one redundant pair, got {:?}",
3583 redundant
3584 );
3585 let (i, j, cos) = redundant[0];
3586 assert_eq!((i, j), (1, 2));
3587 assert!(
3588 cos > 1.0 - 1e-12,
3589 "cosine for identical penalties should be ~1.0, got {cos}"
3590 );
3591 }
3592
3593 #[test]
3594 fn report_penalty_pair_redundancy_skips_different_col_ranges() {
3595 let s = ndarray::array![[1.0, 0.0], [0.0, 1.0]];
3599 let bundle = vec![
3600 canonical_from_local(s.clone(), 0..2, 4),
3601 canonical_from_local(s, 2..4, 4),
3602 ];
3603 let redundant = report_penalty_pair_redundancy(&bundle);
3604 assert!(
3605 redundant.is_empty(),
3606 "different col_ranges must not be flagged"
3607 );
3608 }
3609}