1use faer::Side;
45use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
46
47use gam_linalg::faer_ndarray::FaerEigh;
48use gam_linalg::lanczos::{
49 SymmetricLanczosOptions, symmetric_lanczos_eigenpairs, symmetric_lanczos_log_quadrature,
50};
51use gam_linalg::triangular::cholesky_solve_vector;
52use crate::arrow_schur::{ArrowFactorCache, ArrowSchurSystem};
53use crate::priority_selection::{PriorityCandidate, rank_priority_candidates};
54
/// Dimension threshold for the analytic (dense) log-determinant path of the evidence code.
///
/// NOTE(review): the consumer of this threshold is not visible in this chunk; presumably
/// operators at or below this dimension are handled densely — confirm against `laplace_evidence`.
pub const ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD: usize = 1024;
/// Probe count for a stochastic Lanczos quadrature log-determinant estimate (assumed from the
/// name and the `symmetric_lanczos_log_quadrature` import — TODO confirm at the use site).
const EVIDENCE_LOGDET_SLQ_PROBES: usize = 16;
/// Lanczos steps per probe for that estimate (assumed from the name — TODO confirm).
const EVIDENCE_LOGDET_LANCZOS_STEPS: usize = 32;
/// Relative tolerance for a symmetry check on a matrix-free operator (assumed from the name —
/// TODO confirm).
const EVIDENCE_HVP_SYMMETRY_REL_TOL: f64 = 1e-8;
/// Number of probe vectors used by that symmetry check (assumed from the name — TODO confirm).
const EVIDENCE_HVP_SYMMETRY_PROBES: usize = 4;
60
/// Matrix-free operator whose log-determinant enters a Laplace evidence: a dimension plus a
/// matrix-vector product callback.
///
/// In this file the operator is always a symmetrised dense information matrix; callers
/// presumably must supply a symmetric operator — TODO confirm the contract of `laplace_evidence`.
#[derive(Clone, Copy)]
pub struct EvidenceHvpLogDet<'a> {
    /// Operator dimension; `apply` takes and returns vectors of this length.
    pub dim: usize,
    /// Applies the operator to a vector.
    pub apply: &'a dyn Fn(&[f64]) -> Vec<f64>,
}
69
/// Where a Laplace evidence obtains its log-determinant term.
#[derive(Clone, Copy)]
pub enum EvidenceLogDetSource<'a> {
    /// A factored arrow/Schur system, with an optional matrix-free operator (presumably used
    /// when the factorisation cannot serve the request — TODO confirm).
    FactoredArrow {
        cache: &'a ArrowFactorCache,
        fallback_hvp: Option<EvidenceHvpLogDet<'a>>,
    },
    /// A matrix-free operator only; the mixture and circle evidences in this file use this.
    Hvp(EvidenceHvpLogDet<'a>),
}
85
/// Latent topology assumed by a candidate fit.
///
/// Declaration order is not the complexity order; see `TopologyKind::complexity_rank`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TopologyKind {
    /// Periodic topology (rank 1).
    Periodic,
    /// Flat topology (rank 0, the simplest).
    Flat,
    /// Spherical topology (rank 2).
    Sphere,
    /// Toroidal topology (rank 3, the most complex).
    Torus,
}
106
107impl TopologyKind {
108 pub fn complexity_rank(self) -> u8 {
111 match self {
112 TopologyKind::Flat => 0,
113 TopologyKind::Periodic => 1,
114 TopologyKind::Sphere => 2,
115 TopologyKind::Torus => 3,
116 }
117 }
118}
119
/// One scored topology candidate.
#[derive(Debug, Clone)]
pub struct TopologyCandidate {
    /// Topology the candidate was fitted with.
    pub kind: TopologyKind,
    /// Negative log evidence of the fit (presumably lower is better — TODO confirm).
    pub negative_log_evidence: f64,
    /// Effective dimension of the fit (presumably the divisor for
    /// `TopologyScoreScale::PerEffectiveDim` — TODO confirm).
    pub effective_dim: f64,
    /// Observation count of the fit (presumably the divisor for
    /// `TopologyScoreScale::PerObservation` — TODO confirm).
    pub n_obs: usize,
    /// Whether the underlying fit reported convergence.
    pub converged: bool,
    /// Reason the candidate was excluded from selection, if it was.
    pub exclusion_reason: Option<String>,
}
142
/// Result of topology selection.
#[derive(Debug, Clone)]
pub struct SelectedTopology {
    /// The selected topology.
    pub winner: TopologyKind,
    /// All candidates in ranked order (presumably best first — TODO confirm).
    pub ranking: Vec<TopologyCandidate>,
    /// Whether the selection was a tie (presumably within `TopologySelectOptions::tie_tolerance`
    /// — TODO confirm).
    pub tie: bool,
}
155
/// Options controlling topology selection.
#[derive(Debug, Clone, Copy)]
pub struct TopologySelectOptions {
    /// Score difference within which candidates count as tied (presumably measured on the
    /// scale chosen by `score_scale` — TODO confirm).
    pub tie_tolerance: f64,
    /// How negative log evidences are normalised before comparison.
    pub score_scale: TopologyScoreScale,
}
168
/// Normalisation applied to candidate scores before comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TopologyScoreScale {
    /// Per observation (presumably dividing by `TopologyCandidate::n_obs` — TODO confirm).
    PerObservation,
    /// Per effective dimension (presumably dividing by `TopologyCandidate::effective_dim` —
    /// TODO confirm).
    PerEffectiveDim,
}
177
/// Controls for the EM iteration in `solve_stacking_weights`.
#[derive(Debug, Clone, Copy)]
pub struct StackingConfig {
    /// Maximum number of EM updates; `0` returns the uniform starting weights.
    pub max_iter: usize,
    /// Stop once the largest absolute change of any weight is at most this value.
    pub weight_tol: f64,
}
184
185impl Default for StackingConfig {
186 fn default() -> Self {
187 Self {
188 max_iter: 1000,
189 weight_tol: 1e-10,
190 }
191 }
192}
193
/// Output of `solve_stacking_weights`.
#[derive(Debug, Clone)]
pub struct StackingWeights {
    /// One weight per input candidate column; columns without any finite density get `0`.
    pub weights: Array1<f64>,
    /// Mean log mixture density over the rows whose mixture density is finite and positive.
    pub mean_log_score: f64,
    /// Number of EM updates performed.
    pub iterations: usize,
}
202
203pub fn solve_stacking_weights(
210 log_density: ArrayView2<'_, f64>,
211 config: StackingConfig,
212) -> Result<StackingWeights, String> {
213 let n_obs = log_density.nrows();
214 let n_cand = log_density.ncols();
215 if n_cand == 0 {
216 return Err("stacking requires at least one candidate column".to_string());
217 }
218 if n_obs == 0 {
219 return Err("stacking requires at least one held-out observation row".to_string());
220 }
221
222 let kept_cols: Vec<usize> = (0..n_cand)
223 .filter(|&k| (0..n_obs).any(|i| log_density[[i, k]].is_finite()))
224 .collect();
225 if kept_cols.is_empty() {
226 return Err("stacking found no candidate with any finite held-out density".to_string());
227 }
228 let rows: Vec<usize> = (0..n_obs)
229 .filter(|&i| kept_cols.iter().any(|&k| log_density[[i, k]].is_finite()))
230 .collect();
231 if rows.is_empty() {
232 return Err("stacking found no held-out row with a finite density".to_string());
233 }
234
235 let kept = kept_cols.len();
236 let mut weights = Array1::<f64>::from_elem(kept, 1.0 / kept as f64);
237 let mut next = Array1::<f64>::zeros(kept);
238 let mut iterations = 0usize;
239 for _ in 0..config.max_iter {
240 iterations += 1;
241 next.fill(0.0);
242 let mut active_rows = 0usize;
243 for &row in &rows {
244 let mut row_max = f64::NEG_INFINITY;
245 for (local_col, &source_col) in kept_cols.iter().enumerate() {
246 let log_p = log_density[[row, source_col]];
247 if log_p.is_finite() && weights[local_col] > 0.0 {
248 row_max = row_max.max(weights[local_col].ln() + log_p);
249 }
250 }
251 if !row_max.is_finite() {
252 continue;
253 }
254 let mut denom = 0.0_f64;
255 for (local_col, &source_col) in kept_cols.iter().enumerate() {
256 let log_p = log_density[[row, source_col]];
257 if log_p.is_finite() && weights[local_col] > 0.0 {
258 denom += (weights[local_col].ln() + log_p - row_max).exp();
259 }
260 }
261 if denom <= 0.0 {
262 continue;
263 }
264 active_rows += 1;
265 let log_mix = row_max + denom.ln();
266 for (local_col, &source_col) in kept_cols.iter().enumerate() {
267 let log_p = log_density[[row, source_col]];
268 if log_p.is_finite() && weights[local_col] > 0.0 {
269 next[local_col] += (weights[local_col].ln() + log_p - log_mix).exp();
270 }
271 }
272 }
273 if active_rows == 0 {
274 break;
275 }
276 next.mapv_inplace(|value| value / active_rows as f64);
277 let total = next.sum();
278 if total > 0.0 {
279 next.mapv_inplace(|value| value / total);
280 }
281 let delta = next
282 .iter()
283 .zip(weights.iter())
284 .fold(0.0_f64, |acc, (a, b)| acc.max((a - b).abs()));
285 weights.assign(&next);
286 if delta <= config.weight_tol {
287 break;
288 }
289 }
290
291 let mean_log_score = stacking_mean_log_score(log_density, &rows, &kept_cols, weights.view());
292 let mut full = Array1::<f64>::zeros(n_cand);
293 for (local_col, &source_col) in kept_cols.iter().enumerate() {
294 full[source_col] = weights[local_col];
295 }
296 Ok(StackingWeights {
297 weights: full,
298 mean_log_score,
299 iterations,
300 })
301}
302
303fn stacking_mean_log_score(
304 log_density: ArrayView2<'_, f64>,
305 rows: &[usize],
306 kept_cols: &[usize],
307 weights: ArrayView1<'_, f64>,
308) -> f64 {
309 let mut score_sum = 0.0_f64;
310 let mut counted = 0usize;
311 for &row in rows {
312 let mut row_max = f64::NEG_INFINITY;
313 for (local_col, &source_col) in kept_cols.iter().enumerate() {
314 let log_p = log_density[[row, source_col]];
315 if log_p.is_finite() && weights[local_col] > 0.0 {
316 row_max = row_max.max(weights[local_col].ln() + log_p);
317 }
318 }
319 if !row_max.is_finite() {
320 continue;
321 }
322 let mut denom = 0.0_f64;
323 for (local_col, &source_col) in kept_cols.iter().enumerate() {
324 let log_p = log_density[[row, source_col]];
325 if log_p.is_finite() && weights[local_col] > 0.0 {
326 denom += (weights[local_col].ln() + log_p - row_max).exp();
327 }
328 }
329 if denom > 0.0 {
330 score_sum += row_max + denom.ln();
331 counted += 1;
332 }
333 }
334 if counted == 0 {
335 f64::NEG_INFINITY
336 } else {
337 score_sum / counted as f64
338 }
339}
340
341pub fn stacked_predictive_mean(
343 weights: &Array1<f64>,
344 candidate_means: &[Array1<f64>],
345) -> Result<Array1<f64>, String> {
346 if candidate_means.len() != weights.len() {
347 return Err(format!(
348 "stacked_predictive_mean: {} weights but {} candidate mean vectors",
349 weights.len(),
350 candidate_means.len()
351 ));
352 }
353 let Some(first) = candidate_means.first() else {
354 return Err("stacked_predictive_mean requires at least one candidate".to_string());
355 };
356 let n_rows = first.len();
357 if candidate_means.iter().any(|means| means.len() != n_rows) {
358 return Err(
359 "stacked_predictive_mean: candidate mean vectors disagree on row count".to_string(),
360 );
361 }
362 let mut out = Array1::<f64>::zeros(n_rows);
363 for (weight, means) in weights.iter().zip(candidate_means) {
364 if *weight != 0.0 {
365 out.scaled_add(*weight, means);
366 }
367 }
368 Ok(out)
369}
370
/// Controls for `fit_gaussian_mixture`.
#[derive(Debug, Clone, Copy)]
pub struct GaussianMixtureConfig {
    /// EM iteration cap; at least one iteration always runs.
    pub max_iter: usize,
    /// Stop when the change of the mean per-row log-likelihood, relative to
    /// `max(|previous|, 1)`, is at most this value.
    pub loglik_tol: f64,
    /// Ridge added to every covariance diagonal (initial and M-step). It also floors the
    /// radial variance of circle components.
    pub covariance_floor: f64,
    /// Iteration cap for the k-means seeding of the component means.
    pub kmeans_max_iter: usize,
}
404
405impl Default for GaussianMixtureConfig {
406 fn default() -> Self {
407 Self {
408 max_iter: 200,
409 loglik_tol: 1e-7,
410 covariance_floor: 1e-6,
411 kmeans_max_iter: 25,
412 }
413 }
414}
415
/// A fitted full-covariance Gaussian mixture, as returned by `fit_gaussian_mixture`.
#[derive(Debug, Clone)]
pub struct GaussianMixtureFit {
    /// Mixing weights, one per component.
    pub weights: Array1<f64>,
    /// Component means, one row per component (`k x d`).
    pub means: Array2<f64>,
    /// Component covariances, each `d x d`.
    pub covariances: Vec<Array2<f64>>,
    /// Number of components.
    pub k: usize,
    /// Data dimension (number of columns).
    pub d: usize,
    /// Number of training rows.
    pub n_obs: usize,
    /// Total (summed over rows, not averaged) training log-likelihood reported by EM.
    pub loglik: f64,
    /// Number of EM iterations run.
    pub iterations: usize,
}
436
impl GaussianMixtureFit {
    /// Number of free parameters of the mixture.
    ///
    /// That is `k - 1` mixing weights, `k * d` mean entries and `k * d(d+1)/2` covariance
    /// entries (the lower triangle of each symmetric covariance).
    pub fn num_free_parameters(&self) -> usize {
        let cov_per = self.d * (self.d + 1) / 2;
        // NOTE(review): `self.k - 1` underflows if a caller builds a fit with `k == 0`;
        // `fit_gaussian_mixture` never returns one, but the fields are public.
        (self.k - 1) + self.k * self.d + self.k * cov_per
    }

    /// Log mixture density `ln sum_j w_j N(y_i | mu_j, Sigma_j)` for every row `y_i` of `data`.
    ///
    /// # Errors
    /// Returns an error if `data` does not have `d` columns or a component covariance is not
    /// symmetric positive definite.
    pub fn per_point_log_density(&self, data: ArrayView2<'_, f64>) -> Result<Array1<f64>, String> {
        if data.ncols() != self.d {
            return Err(format!(
                "mixture log-density expects {} columns, got {}",
                self.d,
                data.ncols()
            ));
        }
        let n = data.nrows();
        // Placeholders, each immediately overwritten by the factored component.
        let mut comp = vec![GaussianComponentEval::new(self.d); self.k];
        for j in 0..self.k {
            comp[j] = GaussianComponentEval::factor(self.means.row(j), &self.covariances[j])?;
        }
        let mut out = Array1::<f64>::zeros(n);
        // Weights are clamped to the smallest positive normal f64 so `ln` stays finite.
        let log_w: Vec<f64> = self
            .weights
            .iter()
            .map(|w| w.max(f64::MIN_POSITIVE).ln())
            .collect();
        for i in 0..n {
            let row = data.row(i);
            let mut log_terms = vec![f64::NEG_INFINITY; self.k];
            let mut max_term = f64::NEG_INFINITY;
            for j in 0..self.k {
                let lt = log_w[j] + comp[j].log_density(row);
                log_terms[j] = lt;
                if lt > max_term {
                    max_term = lt;
                }
            }
            out[i] = log_sum_exp(&log_terms, max_term);
        }
        Ok(out)
    }

    /// Laplace-approximate negative log evidence of this fit.
    ///
    /// The curvature term is the empirical-Fisher information on `data`, plus a unit ridge.
    ///
    /// NOTE(review): the data-fit term is the stored `self.loglik`, not a log-likelihood
    /// recomputed on `data`, so `data` is presumably the training data — confirm at call sites.
    /// The meaning of the positional `laplace_evidence` arguments (`0.0`, `-loglik`, `p`, `0.0`)
    /// is defined outside this chunk.
    ///
    /// # Errors
    /// Propagates information-matrix errors, rejects a dimension mismatch, and rejects a
    /// non-finite evidence.
    pub fn laplace_negative_log_evidence(&self, data: ArrayView2<'_, f64>) -> Result<f64, String> {
        let p = self.num_free_parameters();
        let information = self.empirical_fisher_information(data)?;
        if information.nrows() != p {
            return Err(format!(
                "mixture empirical-Fisher information has dim {} but expected free-parameter count {p}",
                information.nrows()
            ));
        }
        // Dense matrix-vector product with the information matrix, exposed matrix-free.
        let apply_info = |x: &[f64]| -> Vec<f64> {
            let mut out = vec![0.0_f64; p];
            for r in 0..p {
                let mut acc = 0.0_f64;
                for c in 0..p {
                    acc += information[[r, c]] * x[c];
                }
                out[r] = acc;
            }
            out
        };
        let hvp = EvidenceHvpLogDet {
            dim: p,
            apply: &apply_info,
        };
        let v = laplace_evidence(
            EvidenceLogDetSource::Hvp(hvp),
            0.0,
            -self.loglik,
            p as f64,
            0.0,
        );
        if !v.is_finite() {
            return Err("mixture Laplace evidence is not finite".to_string());
        }
        Ok(v)
    }

    /// Empirical Fisher information `sum_i s_i s_i^T + I` at the fitted parameters.
    ///
    /// Parameter layout of the per-row score `s_i` (length `num_free_parameters()`):
    /// - `[0, k-1)`: weight coordinates for components `1..k`, with score `resp_j - w_j`. This
    ///   is consistent with a softmax-logit parameterisation using component 0 as the reference.
    /// - `[k-1, k-1 + k*d)`: means, with score `resp_j * Sigma_j^{-1} (y - mu_j)`.
    /// - then `k * d(d+1)/2` covariance entries, lower triangle enumerated as `a >= b`, with
    ///   score `resp_j * 0.5 * (P r r^T P - P)[a, b]` where `P = Sigma_j^{-1}`. Off-diagonal
    ///   entries are doubled because each appears twice in the symmetric matrix.
    ///
    /// The final `+ I` keeps the matrix positive definite even when scores are degenerate.
    fn empirical_fisher_information(
        &self,
        data: ArrayView2<'_, f64>,
    ) -> Result<Array2<f64>, String> {
        if data.ncols() != self.d {
            return Err(format!(
                "mixture information expects {} columns, got {}",
                self.d,
                data.ncols()
            ));
        }
        let n = data.nrows();
        let p = self.num_free_parameters();
        let cov_per = self.d * (self.d + 1) / 2;
        let mut comp = Vec::with_capacity(self.k);
        for j in 0..self.k {
            comp.push(GaussianComponentEval::factor(
                self.means.row(j),
                &self.covariances[j],
            )?);
        }
        let log_w: Vec<f64> = self
            .weights
            .iter()
            .map(|w| w.max(f64::MIN_POSITIVE).ln())
            .collect();

        // Offsets of the mean block and the covariance block inside the score vector.
        let mean_base = self.k - 1;
        let cov_base = mean_base + self.k * self.d;

        let mut info = Array2::<f64>::zeros((p, p));
        let mut score = vec![0.0_f64; p];
        for i in 0..n {
            let row = data.row(i);
            let mut log_terms = vec![0.0_f64; self.k];
            let mut max_term = f64::NEG_INFINITY;
            for j in 0..self.k {
                let lt = log_w[j] + comp[j].log_density(row);
                log_terms[j] = lt;
                if lt > max_term {
                    max_term = lt;
                }
            }
            let log_mix = log_sum_exp(&log_terms, max_term);
            // Posterior responsibilities of each component for this row.
            let resp: Vec<f64> = log_terms.iter().map(|lt| (lt - log_mix).exp()).collect();

            for s in score.iter_mut() {
                *s = 0.0;
            }
            // Weight block: component 0 has no free coordinate.
            for j in 1..self.k {
                score[j - 1] = resp[j] - self.weights[j];
            }
            for j in 0..self.k {
                // prec_v = Sigma_j^{-1} (y - mu_j)
                let prec_v = comp[j].precision_times_residual(row);
                let mbo = mean_base + j * self.d;
                for c in 0..self.d {
                    score[mbo + c] = resp[j] * prec_v[c];
                }
                let cbo = cov_base + j * cov_per;
                let mut idx = 0usize;
                for a in 0..self.d {
                    for b in 0..=a {
                        let outer = prec_v[a] * prec_v[b];
                        let prec_ab = comp[j].precision[[a, b]];
                        let mut g = 0.5 * (outer - prec_ab);
                        if a != b {
                            // One free parameter drives both (a, b) and (b, a).
                            g *= 2.0;
                        }
                        score[cbo + idx] = resp[j] * g;
                        idx += 1;
                    }
                }
            }
            // Accumulate the outer product s_i s_i^T, skipping zero rows.
            for r in 0..p {
                let sr = score[r];
                if sr == 0.0 {
                    continue;
                }
                for c in 0..p {
                    info[[r, c]] += sr * score[c];
                }
            }
        }
        for r in 0..p {
            // Defensive symmetrisation; the outer-product sum is already symmetric.
            for c in (r + 1)..p {
                let avg = 0.5 * (info[[r, c]] + info[[c, r]]);
                info[[r, c]] = avg;
                info[[c, r]] = avg;
            }
            // Unit ridge so the information is positive definite.
            info[[r, r]] += 1.0;
        }
        Ok(info)
    }
}
647
/// Pre-factored multivariate normal component, used to evaluate log densities and
/// precision-weighted residuals quickly.
#[derive(Debug, Clone)]
struct GaussianComponentEval {
    /// Component mean (length `d`).
    mean: Array1<f64>,
    /// Inverse covariance (`d x d`), built from the covariance eigendecomposition.
    precision: Array2<f64>,
    /// `-0.5 * (d ln(2 pi) + ln det Sigma)`.
    log_norm: f64,
    /// Dimension.
    d: usize,
}
657
658impl GaussianComponentEval {
659 fn new(d: usize) -> Self {
660 Self {
661 mean: Array1::zeros(d),
662 precision: Array2::eye(d),
663 log_norm: 0.0,
664 d,
665 }
666 }
667
668 fn factor(mean: ArrayView1<'_, f64>, cov: &Array2<f64>) -> Result<Self, String> {
669 let d = mean.len();
670 if cov.nrows() != d || cov.ncols() != d {
671 return Err(format!(
672 "mixture component covariance must be {d}x{d}, got {}x{}",
673 cov.nrows(),
674 cov.ncols()
675 ));
676 }
677 let (evals, evecs) = cov
678 .eigh(Side::Lower)
679 .map_err(|e| format!("mixture component covariance eigendecomposition failed: {e}"))?;
680 let mut log_det = 0.0_f64;
681 let mut inv_evals = Array1::<f64>::zeros(d);
682 for (idx, &ev) in evals.iter().enumerate() {
683 if !ev.is_finite() || ev <= 0.0 {
684 return Err(format!(
685 "mixture component covariance is not SPD: eigenvalue {idx} is {ev:.3e}"
686 ));
687 }
688 log_det += ev.ln();
689 inv_evals[idx] = 1.0 / ev;
690 }
691 let mut precision = Array2::<f64>::zeros((d, d));
693 for a in 0..d {
694 for b in 0..d {
695 let mut acc = 0.0_f64;
696 for m in 0..d {
697 acc += evecs[[a, m]] * inv_evals[m] * evecs[[b, m]];
698 }
699 precision[[a, b]] = acc;
700 }
701 }
702 let log_norm = -0.5 * (d as f64 * (2.0 * std::f64::consts::PI).ln() + log_det);
703 Ok(Self {
704 mean: mean.to_owned(),
705 precision,
706 log_norm,
707 d,
708 })
709 }
710
711 #[inline]
712 fn log_density(&self, y: ArrayView1<'_, f64>) -> f64 {
713 let pv = self.precision_times_residual(y);
714 let mut quad = 0.0_f64;
715 for c in 0..self.d {
716 quad += (y[c] - self.mean[c]) * pv[c];
717 }
718 self.log_norm - 0.5 * quad
719 }
720
721 #[inline]
723 fn precision_times_residual(&self, y: ArrayView1<'_, f64>) -> Vec<f64> {
724 let mut out = vec![0.0_f64; self.d];
725 for a in 0..self.d {
726 let mut acc = 0.0_f64;
727 for b in 0..self.d {
728 acc += self.precision[[a, b]] * (y[b] - self.mean[b]);
729 }
730 out[a] = acc;
731 }
732 out
733 }
734}
735
/// Numerically stable `ln(sum_i exp(terms[i]))`, given `max_term == max(terms)`.
///
/// Non-finite maxima are handled explicitly, because shifting by an infinite maximum would
/// produce NaN:
/// - every term `-inf` (or an empty slice) returns `-inf`;
/// - a `+inf` maximum returns `+inf`, since it dominates the sum;
/// - a NaN maximum is propagated as NaN.
#[inline]
fn log_sum_exp(terms: &[f64], max_term: f64) -> f64 {
    if max_term == f64::NEG_INFINITY {
        // The sum of exponentials is exactly zero.
        return f64::NEG_INFINITY;
    }
    if !max_term.is_finite() {
        // `+inf` or NaN: the result is the maximum itself.
        return max_term;
    }
    let acc = terms
        .iter()
        .fold(0.0_f64, |acc, &t| acc + (t - max_term).exp());
    max_term + acc.ln()
}
747
748pub fn fit_gaussian_mixture(
757 data: ArrayView2<'_, f64>,
758 k: usize,
759 config: GaussianMixtureConfig,
760) -> Result<GaussianMixtureFit, String> {
761 let n = data.nrows();
762 let d = data.ncols();
763 if k == 0 {
764 return Err("gaussian mixture requires k >= 1".to_string());
765 }
766 if d == 0 {
767 return Err("gaussian mixture requires at least one column".to_string());
768 }
769 if k > n {
770 return Err(format!(
771 "gaussian mixture requested {k} components but data has {n} rows"
772 ));
773 }
774 let centers = gam_terms::basis::select_centers_by_strategy(
777 data,
778 &gam_terms::basis::CenterStrategy::KMeans {
779 num_centers: k,
780 max_iter: config.kmeans_max_iter,
781 },
782 )
783 .map_err(|e| format!("gaussian mixture k-means seeding failed: {e}"))?;
784 if centers.nrows() != k || centers.ncols() != d {
785 return Err(format!(
786 "gaussian mixture seeding returned {}x{} centers, expected {k}x{d}",
787 centers.nrows(),
788 centers.ncols()
789 ));
790 }
791
792 let mut means = centers;
793 let global_cov = data_covariance(data, config.covariance_floor);
795 let mut covariances = vec![global_cov; k];
796 let mut weights = Array1::<f64>::from_elem(k, 1.0 / k as f64);
797
798 let mut resp = Array2::<f64>::zeros((n, k));
799 let mut prev_mean_ll = f64::NEG_INFINITY;
800 let mut total_loglik = f64::NEG_INFINITY;
801 let mut iterations = 0usize;
802
803 for iter in 0..config.max_iter.max(1) {
804 iterations = iter + 1;
805 let mut comp = Vec::with_capacity(k);
807 for j in 0..k {
808 comp.push(GaussianComponentEval::factor(
809 means.row(j),
810 &covariances[j],
811 )?);
812 }
813 let log_w: Vec<f64> = weights
814 .iter()
815 .map(|w| w.max(f64::MIN_POSITIVE).ln())
816 .collect();
817 total_loglik = 0.0;
818 for i in 0..n {
819 let yrow = data.row(i);
820 let mut log_terms = vec![0.0_f64; k];
821 let mut max_term = f64::NEG_INFINITY;
822 for j in 0..k {
823 let lt = log_w[j] + comp[j].log_density(yrow);
824 log_terms[j] = lt;
825 if lt > max_term {
826 max_term = lt;
827 }
828 }
829 let log_mix = log_sum_exp(&log_terms, max_term);
830 total_loglik += log_mix;
831 for j in 0..k {
832 resp[[i, j]] = (log_terms[j] - log_mix).exp();
833 }
834 }
835 let mean_ll = total_loglik / n as f64;
836 if iter > 0 {
837 let denom = prev_mean_ll.abs().max(1.0);
838 if (mean_ll - prev_mean_ll).abs() / denom <= config.loglik_tol {
839 break;
840 }
841 }
842 prev_mean_ll = mean_ll;
843
844 let mut nk = vec![0.0_f64; k];
846 for j in 0..k {
847 let mut sum = 0.0_f64;
848 for i in 0..n {
849 sum += resp[[i, j]];
850 }
851 nk[j] = sum.max(f64::MIN_POSITIVE);
852 }
853 for j in 0..k {
854 weights[j] = nk[j] / n as f64;
855 let mut mu = Array1::<f64>::zeros(d);
857 for i in 0..n {
858 let r = resp[[i, j]];
859 if r == 0.0 {
860 continue;
861 }
862 for c in 0..d {
863 mu[c] += r * data[[i, c]];
864 }
865 }
866 mu.mapv_inplace(|v| v / nk[j]);
867 for c in 0..d {
868 means[[j, c]] = mu[c];
869 }
870 let mut cov = Array2::<f64>::zeros((d, d));
872 for i in 0..n {
873 let r = resp[[i, j]];
874 if r == 0.0 {
875 continue;
876 }
877 for a in 0..d {
878 let da = data[[i, a]] - mu[a];
879 for b in 0..d {
880 cov[[a, b]] += r * da * (data[[i, b]] - mu[b]);
881 }
882 }
883 }
884 cov.mapv_inplace(|v| v / nk[j]);
885 for a in 0..d {
886 cov[[a, a]] += config.covariance_floor;
887 }
888 covariances[j] = cov;
889 }
890 }
891
892 Ok(GaussianMixtureFit {
893 weights,
894 means,
895 covariances,
896 k,
897 d,
898 n_obs: n,
899 loglik: total_loglik,
900 iterations,
901 })
902}
903
904fn data_covariance(data: ArrayView2<'_, f64>, floor: f64) -> Array2<f64> {
906 let n = data.nrows();
907 let d = data.ncols();
908 let mut mean = Array1::<f64>::zeros(d);
909 for i in 0..n {
910 for c in 0..d {
911 mean[c] += data[[i, c]];
912 }
913 }
914 mean.mapv_inplace(|v| v / n.max(1) as f64);
915 let mut cov = Array2::<f64>::zeros((d, d));
916 for i in 0..n {
917 for a in 0..d {
918 let da = data[[i, a]] - mean[a];
919 for b in 0..d {
920 cov[[a, b]] += da * (data[[i, b]] - mean[b]);
921 }
922 }
923 }
924 let inv = 1.0 / (n.max(1) as f64);
925 cov.mapv_inplace(|v| v * inv);
926 for a in 0..d {
927 cov[[a, a]] += floor;
928 }
929 cov
930}
931
/// Two-component composite shape hypothesis. The data is split into one group per component,
/// and each group is fitted with its own component model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnionStructure {
    /// Two circles.
    CircleCircle,
    /// A circle plus a Gaussian point cluster.
    CirclePointCluster,
    /// A line (fitted as a Gaussian) plus a Gaussian point cluster.
    LineCluster,
}
969
/// Composites attempted by `fit_union_ladder`, in this order. A composite's position here is
/// its candidate index passed to the ranking.
pub const UNION_STRUCTURE_LADDER: &[UnionStructure] = &[
    UnionStructure::CircleCircle,
    UnionStructure::CirclePointCluster,
    UnionStructure::LineCluster,
];
976
/// Shape model used for one component of a `UnionStructure`.
///
/// `Line` and `PointCluster` are both fitted as a single full-covariance Gaussian in this file.
/// `Circle` uses a radial-normal, uniform-angle density.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnionComponentKind {
    Circle,
    Line,
    PointCluster,
}
988
989impl UnionStructure {
990 pub const fn as_str(self) -> &'static str {
992 match self {
993 UnionStructure::CircleCircle => "union_circle+circle",
994 UnionStructure::CirclePointCluster => "union_circle+cluster",
995 UnionStructure::LineCluster => "union_line+cluster",
996 }
997 }
998
999 pub const fn components(self) -> &'static [UnionComponentKind] {
1001 match self {
1002 UnionStructure::CircleCircle => {
1003 &[UnionComponentKind::Circle, UnionComponentKind::Circle]
1004 }
1005 UnionStructure::CirclePointCluster => {
1006 &[UnionComponentKind::Circle, UnionComponentKind::PointCluster]
1007 }
1008 UnionStructure::LineCluster => {
1009 &[UnionComponentKind::Line, UnionComponentKind::PointCluster]
1010 }
1011 }
1012 }
1013
1014 pub const fn num_components(self) -> usize {
1016 self.components().len()
1017 }
1018}
1019
/// Evidence summary for one component of a fitted `UnionStructure`.
#[derive(Debug, Clone)]
pub struct UnionComponentFit {
    /// Shape model of this component.
    pub kind: UnionComponentKind,
    /// Number of data rows assigned to this component.
    pub row_count: usize,
    /// Free parameters of the component model.
    pub num_parameters: usize,
    /// Laplace negative log evidence of the component on its rows.
    pub negative_log_evidence: f64,
}
1030
/// Evidence summary for a whole `UnionStructure`.
#[derive(Debug, Clone)]
pub struct UnionStructureFit {
    /// The composite that was fitted.
    pub structure: UnionStructure,
    /// Per-component fits, in `structure.components()` order.
    pub components: Vec<UnionComponentFit>,
    /// Sum of the component negative log evidences.
    pub negative_log_evidence: f64,
    /// Sum of the component parameter counts.
    pub total_parameters: usize,
}
1043
1044pub fn union_responsibility_split(
1049 data: ArrayView2<'_, f64>,
1050 m: usize,
1051 config: GaussianMixtureConfig,
1052) -> Result<Vec<Vec<usize>>, String> {
1053 let n = data.nrows();
1054 if m == 0 {
1055 return Err("union split requires at least one component".to_string());
1056 }
1057 if m > n {
1058 return Err(format!(
1059 "union split requested {m} groups but data has {n} rows"
1060 ));
1061 }
1062 if m == 1 {
1063 return Ok(vec![(0..n).collect()]);
1064 }
1065 let fit = fit_gaussian_mixture(data, m, config)?;
1066 let mut groups: Vec<Vec<usize>> = vec![Vec::new(); m];
1067 let mut comp = Vec::with_capacity(m);
1069 for j in 0..m {
1070 comp.push(GaussianComponentEval::factor(
1071 fit.means.row(j),
1072 &fit.covariances[j],
1073 )?);
1074 }
1075 let log_w: Vec<f64> = fit
1076 .weights
1077 .iter()
1078 .map(|w| w.max(f64::MIN_POSITIVE).ln())
1079 .collect();
1080 for i in 0..n {
1081 let row = data.row(i);
1082 let mut best_j = 0usize;
1083 let mut best_lt = f64::NEG_INFINITY;
1084 for j in 0..m {
1085 let lt = log_w[j] + comp[j].log_density(row);
1086 if lt > best_lt {
1087 best_lt = lt;
1088 best_j = j;
1089 }
1090 }
1091 groups[best_j].push(i);
1092 }
1093 Ok(groups)
1094}
1095
1096pub fn fit_union_structure(
1105 data: ArrayView2<'_, f64>,
1106 structure: UnionStructure,
1107 config: GaussianMixtureConfig,
1108) -> Result<UnionStructureFit, String> {
1109 let comps = structure.components();
1110 let m = comps.len();
1111 let groups = union_responsibility_split(data, m, config)?;
1112 let mut fits = Vec::with_capacity(m);
1113 let mut total_nle = 0.0_f64;
1114 let mut total_parameters = 0usize;
1115 for (kind, rows) in comps.iter().zip(groups.iter()) {
1116 let group = gather_union_rows(data, rows);
1117 let (nle, p) = fit_union_component(group.view(), *kind, config)?;
1118 if !nle.is_finite() {
1119 return Err(format!(
1120 "union {} component {:?} produced non-finite evidence",
1121 structure.as_str(),
1122 kind
1123 ));
1124 }
1125 total_nle += nle;
1126 total_parameters += p;
1127 fits.push(UnionComponentFit {
1128 kind: *kind,
1129 row_count: rows.len(),
1130 num_parameters: p,
1131 negative_log_evidence: nle,
1132 });
1133 }
1134 Ok(UnionStructureFit {
1135 structure,
1136 components: fits,
1137 negative_log_evidence: total_nle,
1138 total_parameters,
1139 })
1140}
1141
1142pub fn fit_union_ladder(
1147 data: ArrayView2<'_, f64>,
1148 config: GaussianMixtureConfig,
1149) -> Result<Vec<UnionStructureFit>, String> {
1150 let mut fits = Vec::new();
1151 let mut errors = Vec::new();
1152 for &structure in UNION_STRUCTURE_LADDER {
1153 match fit_union_structure(data, structure, config) {
1154 Ok(fit) => fits.push(fit),
1155 Err(e) => errors.push(format!("{}: {e}", structure.as_str())),
1156 }
1157 }
1158 if fits.is_empty() {
1159 return Err(format!(
1160 "union ladder produced no fittable composites{}",
1161 if errors.is_empty() {
1162 String::new()
1163 } else {
1164 format!(" ({})", errors.join("; "))
1165 }
1166 ));
1167 }
1168 let ranked = rank_priority_candidates(
1169 fits.into_iter()
1170 .enumerate()
1171 .map(|(idx, row)| {
1172 let score = row.negative_log_evidence;
1173 let tie = row.total_parameters; PriorityCandidate::new(row, idx, score, tie)
1175 })
1176 .collect(),
1177 )
1178 .into_iter()
1179 .map(|row| row.item)
1180 .collect::<Vec<_>>();
1181 Ok(ranked)
1182}
1183
1184fn gather_union_rows(data: ArrayView2<'_, f64>, idx: &[usize]) -> Array2<f64> {
1185 let d = data.ncols();
1186 let mut out = Array2::<f64>::zeros((idx.len(), d));
1187 for (r, &i) in idx.iter().enumerate() {
1188 for c in 0..d {
1189 out[[r, c]] = data[[i, c]];
1190 }
1191 }
1192 out
1193}
1194
1195fn fit_union_component(
1200 group: ArrayView2<'_, f64>,
1201 kind: UnionComponentKind,
1202 config: GaussianMixtureConfig,
1203) -> Result<(f64, usize), String> {
1204 match kind {
1205 UnionComponentKind::Line | UnionComponentKind::PointCluster => {
1206 if group.nrows() < group.ncols() + 1 {
1210 return Err(format!(
1211 "union gaussian component needs >= {} rows, got {}",
1212 group.ncols() + 1,
1213 group.nrows()
1214 ));
1215 }
1216 let fit = fit_gaussian_mixture(group, 1, config)?;
1217 let nle = fit.laplace_negative_log_evidence(group)?;
1218 Ok((nle, fit.num_free_parameters()))
1219 }
1220 UnionComponentKind::Circle => fit_circle_component_evidence(group, config),
1221 }
1222}
1223
/// Laplace negative log evidence of a circle component on 2-D `group`, plus its parameter count.
///
/// The density is a normal on the radius `r` (mean `radius`, variance `var_r`) with a uniform
/// angle. In Cartesian coordinates this gives `p(x, y) = N(r; radius, var_r) / (2 pi r)`.
/// Parameters are `(cx, cy, radius, ln var_r)`, so `p = 4`.
///
/// The centre is the centroid and the radius the mean distance to it: moment estimates, not
/// likelihood maximisers. `var_r` is floored at `config.covariance_floor`. The curvature is the
/// empirical Fisher information (outer products of per-row scores) plus a unit ridge.
///
/// NOTE(review): the centre scores use only the radial term. The angular term `-ln r` also
/// depends on the centre (it contributes `dx / r^2` and `dy / r^2`). Confirm whether omitting it
/// is intentional.
///
/// # Errors
/// Returns an error unless `group` has exactly 2 columns and at least `p + 1 = 5` rows, or if
/// the evidence is not finite.
fn fit_circle_component_evidence(
    group: ArrayView2<'_, f64>,
    config: GaussianMixtureConfig,
) -> Result<(f64, usize), String> {
    let d = group.ncols();
    if d != 2 {
        return Err(format!(
            "union circle component requires 2-D data, got {d} columns"
        ));
    }
    let n = group.nrows();
    // Free parameters: centre x, centre y, radius, log radial variance.
    let p = 4usize;
    if n < p + 1 {
        return Err(format!(
            "union circle component needs >= {} rows, got {n}",
            p + 1
        ));
    }
    // Centre = centroid of the group.
    let mut cx = 0.0_f64;
    let mut cy = 0.0_f64;
    for i in 0..n {
        cx += group[[i, 0]];
        cy += group[[i, 1]];
    }
    cx /= n as f64;
    cy /= n as f64;
    // Radius = mean distance to the centre.
    let mut radii = vec![0.0_f64; n];
    let mut radius = 0.0_f64;
    for i in 0..n {
        let dx = group[[i, 0]] - cx;
        let dy = group[[i, 1]] - cy;
        let r = (dx * dx + dy * dy).sqrt();
        radii[i] = r;
        radius += r;
    }
    radius /= n as f64;
    // Radial variance, floored so the density stays proper for a perfect circle.
    let mut var_r = 0.0_f64;
    for &r in &radii {
        let e = r - radius;
        var_r += e * e;
    }
    var_r = (var_r / n as f64).max(config.covariance_floor);
    let inv_var = 1.0 / var_r;
    let mut loglik = 0.0_f64;
    let log_2pi = (2.0 * std::f64::consts::PI).ln();
    for &r in &radii {
        let e = r - radius;
        let radial = -0.5 * (log_2pi + var_r.ln()) - 0.5 * e * e * inv_var;
        // Uniform angle plus the 1/r polar Jacobian; r is clamped away from zero.
        let angular = -(log_2pi + r.max(f64::MIN_POSITIVE).ln());
        loglik += radial + angular;
    }
    let mut info = Array2::<f64>::zeros((p, p));
    let mut score = [0.0_f64; 4];
    for i in 0..n {
        let dx = group[[i, 0]] - cx;
        let dy = group[[i, 1]] - cy;
        let r = radii[i].max(f64::MIN_POSITIVE);
        let e = radii[i] - radius;
        let ee = e * inv_var;
        // d/d(cx, cy): chain rule through dr/dcx = -dx/r, dr/dcy = -dy/r (radial term only).
        score[0] = ee * (-dx / r);
        score[1] = ee * (-dy / r);
        // d/d(radius)
        score[2] = ee;
        // d/d(ln var_r)
        score[3] = -0.5 + 0.5 * e * e * inv_var;
        for a in 0..p {
            let sa = score[a];
            if sa == 0.0 {
                continue;
            }
            for b in 0..p {
                info[[a, b]] += sa * score[b];
            }
        }
    }
    for a in 0..p {
        // Defensive symmetrisation, then a unit ridge for positive definiteness.
        for b in (a + 1)..p {
            let avg = 0.5 * (info[[a, b]] + info[[b, a]]);
            info[[a, b]] = avg;
            info[[b, a]] = avg;
        }
        info[[a, a]] += 1.0;
    }
    // Dense matrix-vector product with the information matrix, exposed matrix-free.
    let apply_info = |x: &[f64]| -> Vec<f64> {
        let mut out = vec![0.0_f64; p];
        for r in 0..p {
            let mut acc = 0.0_f64;
            for c in 0..p {
                acc += info[[r, c]] * x[c];
            }
            out[r] = acc;
        }
        out
    };
    let hvp = EvidenceHvpLogDet {
        dim: p,
        apply: &apply_info,
    };
    let v = laplace_evidence(EvidenceLogDetSource::Hvp(hvp), 0.0, -loglik, p as f64, 0.0);
    if !v.is_finite() {
        return Err("union circle component Laplace evidence is not finite".to_string());
    }
    Ok((v, p))
}
1347
/// Fitted, weighted density of one union component, used for held-out log densities.
#[derive(Debug, Clone)]
enum UnionComponentDensity {
    /// Single Gaussian (line or point cluster).
    Gaussian {
        /// Log of the component's share of training rows.
        log_weight: f64,
        eval: GaussianComponentEval,
    },
    /// Circle with a normal radial profile and a uniform angle.
    Circle {
        /// Log of the component's share of training rows.
        log_weight: f64,
        center: [f64; 2],
        radius: f64,
        /// Radial variance.
        var_r: f64,
    },
}
1366
1367impl UnionComponentDensity {
1368 fn weighted_log_density(&self, y: ArrayView1<'_, f64>) -> f64 {
1370 match self {
1371 UnionComponentDensity::Gaussian { log_weight, eval } => {
1372 log_weight + eval.log_density(y)
1373 }
1374 UnionComponentDensity::Circle {
1375 log_weight,
1376 center,
1377 radius,
1378 var_r,
1379 } => {
1380 let dx = y[0] - center[0];
1381 let dy = y[1] - center[1];
1382 let r = (dx * dx + dy * dy).sqrt();
1383 let log_2pi = (2.0 * std::f64::consts::PI).ln();
1384 let e = r - radius;
1385 let radial = -0.5 * (log_2pi + var_r.ln()) - 0.5 * e * e / var_r;
1386 let angular = -(log_2pi + r.max(f64::MIN_POSITIVE).ln());
1387 log_weight + radial + angular
1388 }
1389 }
1390 }
1391}
1392
1393fn fit_union_component_densities(
1397 train: ArrayView2<'_, f64>,
1398 structure: UnionStructure,
1399 config: GaussianMixtureConfig,
1400) -> Result<Vec<UnionComponentDensity>, String> {
1401 let comps = structure.components();
1402 let m = comps.len();
1403 let groups = union_responsibility_split(train, m, config)?;
1404 let n_train = train.nrows().max(1) as f64;
1405 let mut out = Vec::with_capacity(m);
1406 for (kind, rows) in comps.iter().zip(groups.iter()) {
1407 if rows.is_empty() {
1408 return Err(format!(
1409 "union {} held-out density: empty component group",
1410 structure.as_str()
1411 ));
1412 }
1413 let log_weight = (rows.len() as f64 / n_train).max(f64::MIN_POSITIVE).ln();
1414 let group = gather_union_rows(train, rows);
1415 match kind {
1416 UnionComponentKind::Line | UnionComponentKind::PointCluster => {
1417 if group.nrows() < group.ncols() + 1 {
1418 return Err(format!(
1419 "union gaussian component density needs >= {} rows, got {}",
1420 group.ncols() + 1,
1421 group.nrows()
1422 ));
1423 }
1424 let fit = fit_gaussian_mixture(group.view(), 1, config)?;
1425 let eval = GaussianComponentEval::factor(fit.means.row(0), &fit.covariances[0])?;
1426 out.push(UnionComponentDensity::Gaussian { log_weight, eval });
1427 }
1428 UnionComponentKind::Circle => {
1429 let d = group.ncols();
1430 if d != 2 {
1431 return Err(format!(
1432 "union circle component density requires 2-D data, got {d} columns"
1433 ));
1434 }
1435 let n = group.nrows();
1436 if n < 5 {
1437 return Err(format!(
1438 "union circle component density needs >= 5 rows, got {n}"
1439 ));
1440 }
1441 let mut cx = 0.0_f64;
1442 let mut cy = 0.0_f64;
1443 for i in 0..n {
1444 cx += group[[i, 0]];
1445 cy += group[[i, 1]];
1446 }
1447 cx /= n as f64;
1448 cy /= n as f64;
1449 let mut radius = 0.0_f64;
1450 let mut radii = vec![0.0_f64; n];
1451 for i in 0..n {
1452 let dx = group[[i, 0]] - cx;
1453 let dy = group[[i, 1]] - cy;
1454 let r = (dx * dx + dy * dy).sqrt();
1455 radii[i] = r;
1456 radius += r;
1457 }
1458 radius /= n as f64;
1459 let mut var_r = 0.0_f64;
1460 for &r in &radii {
1461 let e = r - radius;
1462 var_r += e * e;
1463 }
1464 var_r = (var_r / n as f64).max(config.covariance_floor);
1465 out.push(UnionComponentDensity::Circle {
1466 log_weight,
1467 center: [cx, cy],
1468 radius,
1469 var_r,
1470 });
1471 }
1472 }
1473 }
1474 Ok(out)
1475}
1476
1477pub fn union_per_point_log_density(
1482 train: ArrayView2<'_, f64>,
1483 eval: ArrayView2<'_, f64>,
1484 structure: UnionStructure,
1485 config: GaussianMixtureConfig,
1486) -> Result<Array1<f64>, String> {
1487 if train.ncols() != eval.ncols() {
1488 return Err(format!(
1489 "union held-out density: train has {} columns, eval has {}",
1490 train.ncols(),
1491 eval.ncols()
1492 ));
1493 }
1494 let densities = fit_union_component_densities(train, structure, config)?;
1495 let mut out = Array1::<f64>::zeros(eval.nrows());
1496 let mut terms = vec![f64::NEG_INFINITY; densities.len()];
1497 for i in 0..eval.nrows() {
1498 let row = eval.row(i);
1499 let mut max_term = f64::NEG_INFINITY;
1500 for (c, dens) in densities.iter().enumerate() {
1501 let lt = dens.weighted_log_density(row);
1502 terms[c] = lt;
1503 if lt > max_term {
1504 max_term = lt;
1505 }
1506 }
1507 out[i] = log_sum_exp(&terms, max_term);
1508 }
1509 Ok(out)
1510}
1511
/// One fitted model entering `compare_reml_fits`.
#[derive(Clone, Debug)]
pub struct RemlCandidate {
    /// Caller-supplied position of this fit (the ranking code in this file
    /// uses its own enumeration index instead).
    pub index: usize,
    /// Display name reported in the comparison output.
    pub name: String,
    /// REML/LAML score of the fit; `compare_reml_fits` treats lower as better.
    pub score: f64,
    /// Effective degrees of freedom, when available.
    pub edf: Option<f64>,
    /// Log-likelihood at the fitted mode, when available.
    pub log_lik: Option<f64>,
    /// Response-family label; `compare_reml_fits` rejects mixed families.
    pub family: Option<String>,
}

impl RemlCandidate {
    /// Score used to order candidates.
    ///
    /// When both `log_lik` and `edf` are present and finite this is the
    /// AIC-style quantity `-2 * log_lik + 2 * edf`; otherwise it falls back to
    /// the raw `score`.
    pub fn ranking_score(&self) -> f64 {
        self.log_lik
            .zip(self.edf)
            .filter(|(log_lik, edf)| log_lik.is_finite() && edf.is_finite())
            .map_or(self.score, |(log_lik, edf)| -2.0 * log_lik + 2.0 * edf)
    }
}
1561
/// Result of `compare_reml_fits`.
#[derive(Clone, Debug)]
pub struct RemlComparison {
    /// One row per candidate, in ranked order (winner first).
    pub ranking: Vec<RankedRow>,
    /// Name of the top-ranked candidate.
    pub winner: String,
    /// One-line human-readable summary of the winner's margin over the
    /// runner-up (or a note that only one fit was given).
    pub evidence_summary: String,
    /// Raw-score table, one row per candidate, in the same order as `ranking`.
    pub score_table: Vec<ScoreRow>,
}

/// Ranking-scale view of one candidate, as built by `compare_reml_fits`.
#[derive(Clone, Debug)]
pub struct RankedRow {
    /// Candidate name.
    pub name: String,
    /// The candidate's raw REML/LAML score (not its ranking score).
    pub score: f64,
    /// Ranking-score gap to the winner: candidate ranking score minus the
    /// winner's (via `log_bayes_factor`).
    pub delta: f64,
    /// `exp(delta)`.
    pub bayes_factor: f64,
    /// Effective degrees of freedom copied from the candidate.
    pub edf: Option<f64>,
}

/// Raw-REML-score view of one candidate, as built by `compare_reml_fits`.
#[derive(Clone, Debug)]
pub struct ScoreRow {
    /// Candidate name.
    pub name: String,
    /// The candidate's raw REML/LAML score.
    pub reml_score: f64,
    /// Raw score minus the smallest raw score among all candidates.
    pub delta_reml: f64,
    /// `exp(delta_reml)`.
    pub bayes_factor_best_over_model: f64,
    /// Effective degrees of freedom copied from the candidate.
    pub effective_dof: Option<f64>,
}
1595
/// Log Bayes factor between two fits from their REML/LAML scores, computed as
/// `reml_score_b - reml_score_a`.
///
/// NOTE(review): if the scores are negative log evidences, this is
/// `log(evidence_a / evidence_b)` — confirm against how scores are produced.
#[inline]
pub fn log_bayes_factor(reml_score_a: f64, reml_score_b: f64) -> f64 {
    // IEEE subtraction is addition of the negation, so this is bit-identical
    // to `reml_score_b - reml_score_a`.
    -reml_score_a + reml_score_b
}
1601
1602pub fn compare_reml_fits(mut candidates: Vec<RemlCandidate>) -> Result<RemlComparison, String> {
1606 if candidates.is_empty() {
1607 return Err("compare_models requires at least one fit".to_string());
1608 }
1609 {
1617 let mut seen_family: Option<&str> = None;
1618 for cand in &candidates {
1619 if let Some(fam) = cand.family.as_deref() {
1620 match seen_family {
1621 None => seen_family = Some(fam),
1622 Some(prev) if prev != fam => {
1623 return Err(format!(
1624 "compare_models: cannot compare fits of different response families ('{prev}' vs '{fam}'); their REML/LAML evidence scores are on incomparable base measures. Compare models fit to the same response under the same family."
1625 ));
1626 }
1627 Some(_) => {}
1628 }
1629 }
1630 }
1631 }
1632 candidates = rank_priority_candidates(
1633 candidates
1634 .into_iter()
1635 .enumerate()
1636 .map(|(idx, row)| {
1637 let ranking = row.ranking_score();
1640 PriorityCandidate::new(row, idx, ranking, 0)
1641 })
1642 .collect(),
1643 )
1644 .into_iter()
1645 .map(|row| row.item)
1646 .collect();
1647
1648 let winner = candidates[0].name.clone();
1649 let best_ranking_score = candidates[0].ranking_score();
1659 let best_raw_score = candidates
1664 .iter()
1665 .map(|c| c.score)
1666 .fold(f64::INFINITY, f64::min);
1667 let mut ranking = Vec::with_capacity(candidates.len());
1668 let mut score_table = Vec::with_capacity(candidates.len());
1669 for row in &candidates {
1670 let delta = log_bayes_factor(best_ranking_score, row.ranking_score());
1671 let bayes_factor = delta.exp();
1672 let delta_reml = log_bayes_factor(best_raw_score, row.score);
1673 ranking.push(RankedRow {
1674 name: row.name.clone(),
1675 score: row.score,
1676 delta,
1677 bayes_factor,
1678 edf: row.edf,
1679 });
1680 score_table.push(ScoreRow {
1681 name: row.name.clone(),
1682 reml_score: row.score,
1683 delta_reml,
1684 bayes_factor_best_over_model: delta_reml.exp(),
1685 effective_dof: row.edf,
1686 });
1687 }
1688 let evidence_summary = if let Some(runner_up) = candidates.get(1) {
1693 let margin = runner_up.ranking_score() - candidates[0].ranking_score();
1694 format!(
1695 "{} wins by Bayes factor {} over {}",
1696 winner,
1697 format_bayes_factor(margin),
1698 runner_up.name
1699 )
1700 } else {
1701 format!("{winner} (single fit; no comparison)")
1702 };
1703 Ok(RemlComparison {
1704 ranking,
1705 winner,
1706 evidence_summary,
1707 score_table,
1708 })
1709}
1710
/// Formats a Bayes factor given its natural log.
///
/// - `+inf` prints as `"inf"`, `-inf` (a factor of exactly zero) as `"0"`, and
///   NaN as `"NaN"`.
/// - Factors of at least 1000 (or at most 1/1000) use a compact decimal power
///   form such as `"1e+4.3"`.
/// - Everything else goes through [`format_three_significant`].
pub fn format_bayes_factor(log_bf: f64) -> String {
    if log_bf.is_nan() {
        return "NaN".to_string();
    }
    if log_bf.is_infinite() {
        // exp(+inf) = inf; exp(-inf) = 0.
        return if log_bf > 0.0 { "inf" } else { "0" }.to_string();
    }
    if log_bf.abs() >= std::f64::consts::LN_10 * 3.0 {
        return format!("1e{:+.1}", log_bf / std::f64::consts::LN_10);
    }
    format_three_significant(log_bf.exp())
}

/// Formats `value` with three significant digits.
///
/// Magnitudes of at least 1000 use scientific notation (`"1.23e3"`); smaller
/// ones use fixed-point with enough decimals for three significant digits.
/// Zero prints as `"0"`, and non-finite values use their `Display` form.
///
/// Rounding can carry into the next decade (e.g. `999.6` or `0.9996`). The
/// exponent is bumped in that case, so the output still has exactly three
/// significant digits (`"1.00e3"`, `"1.00"`) rather than four.
pub fn format_three_significant(value: f64) -> String {
    if value == 0.0 {
        return "0".to_string();
    }
    if !value.is_finite() {
        return format!("{value}");
    }
    let magnitude = value.abs();
    let mut exponent = magnitude.log10().floor() as i32;
    if exponent < 3 {
        // Scaled so that three significant digits sit left of the point:
        // nominally in [100, 1000); reaching 1000 means the rounding carried.
        let probe_scale = 10f64.powi(2 - exponent);
        if (magnitude * probe_scale).round() >= 1000.0 {
            exponent += 1;
        }
    }
    if exponent >= 3 {
        return format!("{value:.2e}");
    }
    let decimals = (2 - exponent).max(0) as usize;
    let scale = 10f64.powi(decimals as i32);
    let rounded = (value * scale).abs().round() / scale * value.signum();
    format!("{rounded:.decimals$}")
}
1737
1738impl Default for TopologySelectOptions {
1739 fn default() -> Self {
1740 Self {
1741 tie_tolerance: 1e-3,
1742 score_scale: TopologyScoreScale::PerObservation,
1743 }
1744 }
1745}
1746
1747pub fn laplace_evidence(
1792 logdet_source: EvidenceLogDetSource<'_>,
1793 penalty_log_det: f64,
1794 residual_objective: f64,
1795 effective_dim: f64,
1796 penalty_rank: f64,
1797) -> f64 {
1798 if !(effective_dim.is_finite() && penalty_rank.is_finite()) {
1799 return f64::NAN;
1800 }
1801 let log_det_h = match evidence_hessian_log_det(logdet_source) {
1802 Ok(v) => v,
1803 Err(_) => return f64::NAN,
1804 };
1805 let null_dim = effective_dim - penalty_rank;
1806 if !null_dim.is_finite() || null_dim < -1e-9 {
1807 return f64::NAN;
1808 }
1809 residual_objective + 0.5 * log_det_h
1810 - 0.5 * penalty_log_det
1811 - 0.5 * null_dim.max(0.0) * (2.0 * std::f64::consts::PI).ln()
1812}
1813
1814pub fn evidence_hessian_log_det(source: EvidenceLogDetSource<'_>) -> Result<f64, String> {
1816 match source {
1817 EvidenceLogDetSource::FactoredArrow {
1818 cache,
1819 fallback_hvp,
1820 } => match arrow_log_det_from_cache(cache) {
1821 Some(v) => Ok(v),
1822 None => match fallback_hvp {
1823 Some(hvp) => hessian_log_det_from_hvp(hvp),
1824 None => {
1825 Err("evidence Hessian logdet requires exact factors or HVP fallback".into())
1826 }
1827 },
1828 },
1829 EvidenceLogDetSource::Hvp(hvp) => hessian_log_det_from_hvp(hvp),
1830 }
1831}
1832
1833pub fn hessian_log_det_from_hvp(hvp: EvidenceHvpLogDet<'_>) -> Result<f64, String> {
1840 if hvp.dim == 0 {
1841 return Ok(0.0);
1842 }
1843 if hvp.dim <= ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD {
1844 let mut dense = Array2::<f64>::zeros((hvp.dim, hvp.dim));
1845 let mut basis = vec![0.0_f64; hvp.dim];
1846 for j in 0..hvp.dim {
1847 basis[j] = 1.0;
1848 let col = (hvp.apply)(&basis);
1849 basis[j] = 0.0;
1850 if col.len() != hvp.dim || col.iter().any(|v| !v.is_finite()) {
1851 return Err(format!(
1852 "evidence HVP logdet expected finite column of length {}, got {}",
1853 hvp.dim,
1854 col.len()
1855 ));
1856 }
1857 for i in 0..hvp.dim {
1858 dense[[i, j]] = col[i];
1859 }
1860 }
1861 validate_dense_hvp_symmetry(&dense)?;
1862 for i in 0..hvp.dim {
1863 for j in (i + 1)..hvp.dim {
1864 let avg = 0.5 * (dense[[i, j]] + dense[[j, i]]);
1865 dense[[i, j]] = avg;
1866 dense[[j, i]] = avg;
1867 }
1868 }
1869 dense_spd_log_det(&dense)
1870 } else {
1871 stochastic_hvp_log_det(hvp)
1872 }
1873}
1874
1875fn dense_spd_log_det(matrix: &Array2<f64>) -> Result<f64, String> {
1876 if matrix.nrows() != matrix.ncols() {
1877 return Err(format!(
1878 "evidence dense logdet requires square matrix, got {}x{}",
1879 matrix.nrows(),
1880 matrix.ncols()
1881 ));
1882 }
1883 if gam_gpu::cuda_selected() {
1884 return crate::gpu::reml_gpu::evidence_derivatives_gpu(
1885 crate::gpu::reml_gpu::RemlGpuInput {
1886 penalized_hessian: matrix.view(),
1887 derivative_hessians: Vec::new(),
1888 },
1889 )
1890 .map(|evidence| evidence.logdet_hessian);
1891 }
1892 let (evals, _) = matrix
1893 .eigh(Side::Lower)
1894 .map_err(|e| format!("evidence dense logdet eigendecomposition failed: {e}"))?;
1895 let mut logdet = 0.0_f64;
1896 for (idx, &ev) in evals.iter().enumerate() {
1897 if !ev.is_finite() || ev <= 0.0 {
1898 return Err(format!(
1899 "evidence dense logdet expected SPD Hessian, eigenvalue {idx} is {ev:.3e}"
1900 ));
1901 }
1902 logdet += ev.ln();
1903 }
1904 Ok(logdet)
1905}
1906
1907fn validate_dense_hvp_symmetry(matrix: &Array2<f64>) -> Result<(), String> {
1908 let n = matrix.nrows();
1909 let mut norm_sq = 0.0_f64;
1910 for &value in matrix.iter() {
1911 norm_sq += value * value;
1912 }
1913
1914 let mut skew_sq = 0.0_f64;
1915 for i in 0..n {
1916 for j in (i + 1)..n {
1917 let skew = matrix[[i, j]] - matrix[[j, i]];
1918 skew_sq += 2.0 * skew * skew;
1919 }
1920 }
1921
1922 let rel_skew = skew_sq.sqrt() / norm_sq.sqrt().max(1.0);
1923 if !rel_skew.is_finite() || rel_skew > EVIDENCE_HVP_SYMMETRY_REL_TOL {
1924 return Err(format!(
1925 "evidence HVP logdet requires symmetric operator, relative skew norm is {rel_skew:.3e}"
1926 ));
1927 }
1928 Ok(())
1929}
1930
1931fn validate_hvp_randomized_symmetry(hvp: EvidenceHvpLogDet<'_>) -> Result<(), String> {
1932 let inv_norm = 1.0 / (hvp.dim as f64).sqrt();
1933 for probe in 0..EVIDENCE_HVP_SYMMETRY_PROBES.max(1) {
1934 let mut x = vec![0.0_f64; hvp.dim];
1935 let mut y = vec![0.0_f64; hvp.dim];
1936 rademacher_unit_probe_into_slice(&mut x, (2 * probe) as u64, inv_norm);
1937 rademacher_unit_probe_into_slice(&mut y, (2 * probe + 1) as u64, inv_norm);
1938
1939 let hx = (hvp.apply)(&x);
1940 let hy = (hvp.apply)(&y);
1941 if hx.len() != hvp.dim || hx.iter().any(|v| !v.is_finite()) {
1942 return Err(format!(
1943 "evidence HVP symmetry check expected finite vector of length {}, got {}",
1944 hvp.dim,
1945 hx.len()
1946 ));
1947 }
1948 if hy.len() != hvp.dim || hy.iter().any(|v| !v.is_finite()) {
1949 return Err(format!(
1950 "evidence HVP symmetry check expected finite vector of length {}, got {}",
1951 hvp.dim,
1952 hy.len()
1953 ));
1954 }
1955
1956 let lhs = dot_slice(&x, &hy);
1957 let rhs = dot_slice(&hx, &y);
1958 let scale = (norm2_slice(&hx) * norm2_slice(&y))
1959 .max(norm2_slice(&hy) * norm2_slice(&x))
1960 .max(lhs.abs())
1961 .max(rhs.abs())
1962 .max(1.0);
1963 let rel = (lhs - rhs).abs() / scale;
1964 if !rel.is_finite() || rel > EVIDENCE_HVP_SYMMETRY_REL_TOL {
1965 return Err(format!(
1966 "evidence HVP logdet requires symmetric operator, randomized symmetry probe {probe} has relative bilinear mismatch {rel:.3e}"
1967 ));
1968 }
1969 }
1970 Ok(())
1971}
1972
1973fn stochastic_hvp_log_det(hvp: EvidenceHvpLogDet<'_>) -> Result<f64, String> {
1974 validate_hvp_randomized_symmetry(hvp)?;
1975 let probes = EVIDENCE_LOGDET_SLQ_PROBES.max(1);
1976 let steps = EVIDENCE_LOGDET_LANCZOS_STEPS.min(hvp.dim).max(1);
1977 let inv_norm = 1.0 / (hvp.dim as f64).sqrt();
1978 let mut estimate = 0.0_f64;
1979 for probe in 0..probes {
1980 let mut q0 = vec![0.0_f64; hvp.dim];
1981 rademacher_unit_probe_into_slice(&mut q0, probe as u64, inv_norm);
1982 let quad = lanczos_log_quadrature_hvp(hvp, q0, steps)?;
1983 estimate += hvp.dim as f64 * quad;
1984 }
1985 Ok(estimate / probes as f64)
1986}
1987
1988fn lanczos_log_quadrature_hvp(
1989 hvp: EvidenceHvpLogDet<'_>,
1990 q: Vec<f64>,
1991 max_steps: usize,
1992) -> Result<f64, String> {
1993 let n = hvp.dim;
1994 let eigen = symmetric_lanczos_eigenpairs(
1995 n,
1996 &q,
1997 SymmetricLanczosOptions {
1998 max_steps,
1999 residual_tol: 1e-12,
2000 local_reorthogonalize: false,
2001 full_reorthogonalize: false,
2002 },
2003 |q, out| {
2004 let applied = (hvp.apply)(q);
2005 if applied.len() != n || applied.iter().any(|v| !v.is_finite()) {
2006 return Err(format!(
2007 "evidence HVP SLQ expected finite vector of length {n}, got {}",
2008 applied.len()
2009 ));
2010 }
2011 out.copy_from_slice(&applied);
2012 Ok(())
2013 },
2014 )
2015 .map_err(|e| format!("evidence HVP SLQ Lanczos failed: {e}"))?;
2016 symmetric_lanczos_log_quadrature(&eigen, "evidence HVP SLQ expected SPD Hessian")
2017}
2018
/// Dot product of two equal-length slices, accumulated left to right.
///
/// # Panics
/// Panics if the slices differ in length.
#[inline]
fn dot_slice(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len());
    a.iter().zip(b).fold(0.0_f64, |acc, (x, y)| acc + x * y)
}

/// Euclidean norm of a slice.
#[inline]
fn norm2_slice(a: &[f64]) -> f64 {
    f64::sqrt(dot_slice(a, a))
}
2033
2034fn rademacher_unit_probe_into_slice(z: &mut [f64], probe: u64, scale: f64) {
2035 let mut state = 0x6A09E667F3BCC909_u64 ^ probe.wrapping_mul(0xD1B54A32D192ED03);
2036 let mut bits = 0_u64;
2037 let mut remaining_bits = 0_u32;
2038 for value in z.iter_mut() {
2039 if remaining_bits == 0 {
2040 bits = splitmix64(&mut state);
2041 remaining_bits = 64;
2042 }
2043 *value = if bits & 1 == 0 { scale } else { -scale };
2044 bits >>= 1;
2045 remaining_bits -= 1;
2046 }
2047}
2048
/// Local `const fn` shim that forwards to `gam_linalg::utils::splitmix64`.
///
/// It presumably advances `state` and returns the next 64-bit output; see the
/// gam_linalg definition. The shim lets the probe generators in this file call
/// it unqualified.
#[inline]
const fn splitmix64(state: &mut u64) -> u64 {
    gam_linalg::utils::splitmix64(state)
}
2053
2054pub fn arrow_log_det_from_cache(cache: &ArrowFactorCache) -> Option<f64> {
2063 if cache.ridge_t != 0.0 || cache.ridge_beta != 0.0 {
2064 return None;
2068 }
2069 if let Some(log_det) = cache.joint_hessian_log_det {
2070 return log_det.is_finite().then_some(log_det);
2071 }
2072 let schur = match cache.schur_factor.as_ref() {
2079 Some(schur) => Some(schur),
2080 None if cache.k == 0 => None,
2081 None => return None,
2082 };
2083
2084 let mut acc = 0.0_f64;
2085 for l in cache.undamped_factors_iter() {
2087 acc += 2.0 * log_det_from_chol_lower(l);
2088 }
2089 if let Some(schur) = schur {
2091 acc += 2.0 * log_det_from_chol_lower(schur.view());
2092 }
2093 let woodbury_correction = cache.cross_row_woodbury_log_det();
2098 if !woodbury_correction.is_finite() {
2099 return None;
2102 }
2103 acc += woodbury_correction;
2104 Some(acc)
2105}
2106
2107fn log_det_from_chol_lower(l: ArrayView2<'_, f64>) -> f64 {
2109 let n = l.nrows();
2110 let mut acc = 0.0_f64;
2111 for i in 0..n {
2112 let d = l[[i, i]];
2113 if d > 0.0 {
2114 acc += d.ln();
2115 } else {
2116 panic!(
2122 "log_det_from_chol_lower: non-positive Cholesky diagonal {d} at index {i}; \
2123 caller passed a corrupted or non-SPD factor"
2124 );
2125 }
2126 }
2127 acc
2128}
2129
2130pub fn ift_du_dbeta(cache: &ArrowFactorCache) -> Array2<f64> {
2140 let n = cache.undamped_factor_count();
2141 let total_len = cache.delta_t_len();
2142 let k = cache.k;
2143 if !cache.htbeta_available() {
2144 return Array2::<f64>::from_elem((total_len, k), f64::NAN);
2145 }
2146 let mut out = Array2::<f64>::zeros((total_len, k));
2147 let mut beta_basis = Array1::<f64>::zeros(k);
2148 let mut rhs = Array1::<f64>::zeros(cache.d);
2150 for i in 0..n {
2151 let di = cache.row_dims[i];
2152 let row_base = cache.row_offsets[i];
2153 let factor = cache.undamped_factor(i);
2154 for col in 0..k {
2156 beta_basis.fill(0.0);
2157 beta_basis[col] = 1.0;
2158 let mut rhs_i = rhs.slice_mut(ndarray::s![..di]).to_owned();
2159 if !cache.apply_htbeta_row(i, beta_basis.view(), &mut rhs_i) {
2162 return Array2::<f64>::from_elem((total_len, k), f64::NAN);
2165 }
2166 let y = cholesky_solve_vector(factor, &rhs_i);
2167 for c in 0..di {
2168 out[[row_base + c, col]] = -y[c];
2169 }
2170 }
2171 }
2172 out
2173}
2174
2175pub fn coupling_components(hessian: ArrayView2<'_, f64>) -> Vec<usize> {
2196 let p = hessian.nrows();
2197 if p == 0 || hessian.ncols() != p {
2198 return Vec::new();
2199 }
2200 let mut parent: Vec<usize> = (0..p).collect();
2202 let mut size: Vec<usize> = vec![1; p];
2203
2204 fn find(parent: &mut [usize], mut x: usize) -> usize {
2205 while parent[x] != x {
2206 parent[x] = parent[parent[x]];
2207 x = parent[x];
2208 }
2209 x
2210 }
2211
2212 for i in 0..p {
2213 for j in (i + 1)..p {
2214 if hessian[[i, j]] != 0.0 || hessian[[j, i]] != 0.0 {
2217 let (ri, rj) = (find(&mut parent, i), find(&mut parent, j));
2218 if ri != rj {
2219 let (small, large) = if size[ri] < size[rj] {
2220 (ri, rj)
2221 } else {
2222 (rj, ri)
2223 };
2224 parent[small] = large;
2225 size[large] += size[small];
2226 }
2227 }
2228 }
2229 }
2230
2231 let mut label_of_root: Vec<Option<usize>> = vec![None; p];
2234 let mut next_label = 0usize;
2235 let mut labels = vec![0usize; p];
2236 for idx in 0..p {
2237 let root = find(&mut parent, idx);
2238 let label = match label_of_root[root] {
2239 Some(l) => l,
2240 None => {
2241 let l = next_label;
2242 label_of_root[root] = Some(l);
2243 next_label += 1;
2244 l
2245 }
2246 };
2247 labels[idx] = label;
2248 }
2249 labels
2250}
2251
/// All indices sharing a component label with at least one index in `support`.
///
/// Support indices outside `labels` are ignored. The result is in ascending
/// index order, and is empty when no support index is in range.
pub fn cone_of_influence(labels: &[usize], support: &[usize]) -> Vec<usize> {
    let mut touched: Vec<usize> = Vec::with_capacity(support.len());
    for &idx in support {
        if let Some(&label) = labels.get(idx) {
            touched.push(label);
        }
    }
    if touched.is_empty() {
        return Vec::new();
    }
    // Sorted and deduplicated so membership can use binary search.
    touched.sort_unstable();
    touched.dedup();
    labels
        .iter()
        .enumerate()
        .filter(|&(_, &label)| touched.binary_search(&label).is_ok())
        .map(|(idx, _)| idx)
        .collect()
}
2279
2280pub fn ift_dbeta_drho(
2292 cache: &ArrowFactorCache,
2293 dg_red_drho: ArrayView2<'_, f64>,
2294) -> Option<Array2<f64>> {
2295 if cache.ridge_t != 0.0 || cache.ridge_beta != 0.0 {
2296 return None;
2297 }
2298 let schur = cache.schur_factor.as_ref()?;
2299 if dg_red_drho.nrows() != cache.k || schur.nrows() != cache.k {
2300 return None;
2301 }
2302 crate::sensitivity::FitSensitivity::from_lower_triangular(schur)
2303 .mode_response(dg_red_drho)
2304}
2305
2306
/// Inputs to `evidence_ift_gradient_correction`: mode sensitivities with
/// respect to the smoothing parameters, plus the partial derivatives they are
/// contracted against.
///
/// Shapes expected by that function (it returns an all-NaN vector otherwise):
/// - `dbeta_drho` is `k x r` and `du_drho` is `nd x r`;
/// - `value_beta` and `logdet_h_beta` have length `k`;
/// - `value_u` and `logdet_h_u` have length `nd`.
#[derive(Clone)]
pub struct EvidenceIftGradientTerms<'a> {
    /// Presumably d(beta)/d(rho), one column per smoothing parameter.
    pub dbeta_drho: ArrayView2<'a, f64>,
    /// Presumably d(u)/d(rho) for the stacked row latents, one column per
    /// smoothing parameter.
    pub du_drho: ArrayView2<'a, f64>,
    /// Weight applied to `dbeta_drho` (presumably d(value)/d(beta)).
    pub value_beta: ArrayView1<'a, f64>,
    /// Weight applied to `du_drho` (presumably d(value)/d(u)).
    pub value_u: ArrayView1<'a, f64>,
    /// Weight applied (times 0.5) to `dbeta_drho`; presumably
    /// d(log|H|)/d(beta).
    pub logdet_h_beta: ArrayView1<'a, f64>,
    /// Weight applied (times 0.5) to `du_drho`; presumably d(log|H|)/d(u).
    pub logdet_h_u: ArrayView1<'a, f64>,
}
2333
2334pub fn evidence_ift_gradient_correction(terms: EvidenceIftGradientTerms<'_>) -> Array1<f64> {
2337 let k = terms.dbeta_drho.nrows();
2338 let nd = terms.du_drho.nrows();
2339 let r = terms.dbeta_drho.ncols();
2340 if terms.du_drho.ncols() != r
2341 || terms.value_beta.len() != k
2342 || terms.logdet_h_beta.len() != k
2343 || terms.value_u.len() != nd
2344 || terms.logdet_h_u.len() != nd
2345 {
2346 return Array1::<f64>::from_elem(r, f64::NAN);
2347 }
2348
2349 let mut out = Array1::<f64>::zeros(r);
2350 for a in 0..r {
2351 let mut acc = 0.0_f64;
2352 for j in 0..k {
2353 let mode = terms.dbeta_drho[[j, a]];
2354 acc += terms.value_beta[j] * mode;
2355 acc += 0.5 * terms.logdet_h_beta[j] * mode;
2356 }
2357 for j in 0..nd {
2358 let mode = terms.du_drho[[j, a]];
2359 acc += terms.value_u[j] * mode;
2360 acc += 0.5 * terms.logdet_h_u[j] * mode;
2361 }
2362 out[a] = acc;
2363 }
2364 out
2365}
2366
/// Analytic gradient of the Laplace evidence objective with respect to the
/// smoothing log-parameters `rho`. It is assembled from the cached
/// arrow-structured factorisation: per-row blocks plus a `k x k` Schur
/// complement.
///
/// For every `a in 0..r` the returned entry is
///
/// ```text
/// value_rho[a]
///   + 0.5 * (sum_i tr(F_i^{-1} huu_drho[i][a]) + tr(S^{-1} dS[a]))
///   + ift_correction[a]
///   - 0.5 * pen_logdet_drho[a]
/// ```
///
/// The symbols are:
/// - `F_i` is `cache.undamped_factor(i)`;
/// - `S` is the matrix factored by `cache.schur_factor`;
/// - `ift_correction` is `evidence_ift_gradient_correction(ift_terms)`;
/// - `dS[a]` is defined as
///
/// ```text
/// dS[a] = hbb_drho[a]
///       - sum_i (htbeta_drho[i][a]^T Y_i + Y_i^T htbeta_drho[i][a])
///       + sum_i Y_i^T huu_drho[i][a] Y_i
/// ```
///
/// Column `c` of `Y_i` is `cholesky_solve_vector(F_i, rhs)`, where `rhs` is
/// produced by `cache.apply_htbeta_row(i, e_c, ..)`.
///
/// NOTE(review): this presumes `cholesky_solve_vector` solves against the
/// full factored matrix and that `apply_htbeta_row` writes `Htbeta_i * e_c`.
/// Confirm against their definitions.
///
/// # Returns
/// A length-`r` vector, with `r = value_rho.len()`. Every entry is NaN when
/// any of the following holds:
/// - the coupling block is unavailable;
/// - any derivative input has the wrong length or shape;
/// - the IFT correction has the wrong length or contains NaN;
/// - the Schur factor is missing;
/// - `apply_htbeta_row` reports failure.
///
/// # Panics
/// - Through the `assert_eq!` shape checks, which the up-front validation
///   should already guarantee.
/// - Through any panic inside the indexing or solve helpers.
pub fn evidence_grad_rho(
    cache: &ArrowFactorCache,
    value_rho: ArrayView1<'_, f64>,
    huu_drho: &[Vec<Array2<f64>>],
    htbeta_drho: &[Vec<Array2<f64>>],
    hbb_drho: &[Array2<f64>],
    pen_logdet_drho: ArrayView1<'_, f64>,
    ift_terms: EvidenceIftGradientTerms<'_>,
) -> Array1<f64> {
    let r = value_rho.len();
    let n = cache.undamped_factor_count();
    let k = cache.k;
    let mut out = Array1::<f64>::zeros(r);
    // Shape validation: huu_drho[i][a] must be di x di, htbeta_drho[i][a]
    // di x k, and hbb_drho[a] k x k, for every row i and parameter a.
    if !cache.htbeta_available()
        || pen_logdet_drho.len() != r
        || huu_drho.len() != n
        || htbeta_drho.len() != n
        || hbb_drho.len() != r
        || huu_drho.iter().any(|row| row.len() != r)
        || htbeta_drho.iter().any(|row| row.len() != r)
        || hbb_drho.iter().any(|m| m.nrows() != k || m.ncols() != k)
        || huu_drho.iter().enumerate().any(|(i, row)| {
            let di = cache.row_dims[i];
            row.iter().any(|m| m.nrows() != di || m.ncols() != di)
        })
        || htbeta_drho.iter().enumerate().any(|(i, row)| {
            let di = cache.row_dims[i];
            row.iter().any(|m| m.nrows() != di || m.ncols() != k)
        })
    {
        out.fill(f64::NAN);
        return out;
    }
    let ift_correction = evidence_ift_gradient_correction(ift_terms);
    if ift_correction.len() != r || ift_correction.iter().any(|v| v.is_nan()) {
        out.fill(f64::NAN);
        return out;
    }

    let schur = match cache.schur_factor.as_ref() {
        Some(s) => s,
        None => {
            for a in 0..r {
                out[a] = f64::NAN;
            }
            return out;
        }
    };

    // Y_i (di x k) for every row block, built one coefficient column at a time
    // from the coupling block and the row's undamped factor; reused for every a.
    let mut y_blocks: Vec<Array2<f64>> = Vec::with_capacity(n);
    let mut beta_basis = Array1::<f64>::zeros(k);
    // NOTE(review): `rhs` is never written; the `.to_owned()` below copies a
    // zero slice into a fresh allocation on every (row, column) pair, so this
    // buffer provides no reuse.
    let mut rhs = Array1::<f64>::zeros(cache.d);
    for i in 0..n {
        let di = cache.row_dims[i];
        let factor = cache.undamped_factor(i);
        let mut yi = Array2::<f64>::zeros((di, k));
        for col in 0..k {
            beta_basis.fill(0.0);
            beta_basis[col] = 1.0;
            let mut rhs_i = rhs.slice_mut(ndarray::s![..di]).to_owned();
            if !cache.apply_htbeta_row(i, beta_basis.view(), &mut rhs_i) {
                out.fill(f64::NAN);
                return out;
            }
            let v = cholesky_solve_vector(factor, &rhs_i);
            for c in 0..di {
                yi[[c, col]] = v[c];
            }
        }
        y_blocks.push(yi);
    }

    // NOTE(review): like `rhs`, `trace_rhs` and `da_tmp` are only ever copied
    // from via `.to_owned()` and never written, so each use allocates afresh.
    let mut trace_rhs = Array1::<f64>::zeros(cache.d);
    let mut da_tmp = Array2::<f64>::zeros((cache.d, k));
    let mut col_scratch = Array1::<f64>::zeros(k);
    for a in 0..r {
        let mut grad = value_rho[a];

        // sum_i tr(F_i^{-1} huu_drho[i][a]): solve against each column of the
        // derivative block and keep the diagonal entry of the solution.
        let mut row_trace_acc = 0.0_f64;
        for i in 0..n {
            let di = cache.row_dims[i];
            let m_i = &huu_drho[i][a];
            assert_eq!(m_i.shape(), &[di, di]);
            for col in 0..di {
                let mut tr_rhs_i = trace_rhs.slice_mut(ndarray::s![..di]).to_owned();
                for r0 in 0..di {
                    tr_rhs_i[r0] = m_i[[r0, col]];
                }
                let v = cholesky_solve_vector(cache.undamped_factor(i), &tr_rhs_i);
                row_trace_acc += v[col];
            }
        }

        // da accumulates dS[a], the derivative of the Schur complement.
        let mut da = hbb_drho[a].clone();
        assert_eq!(da.shape(), &[k, k]);
        for i in 0..n {
            let di = cache.row_dims[i];
            let dhtb = &htbeta_drho[i][a];
            let yi = &y_blocks[i];
            // da -= dHtb^T Y_i
            for r0 in 0..k {
                for c0 in 0..k {
                    let mut acc = 0.0;
                    for cc in 0..di {
                        acc += dhtb[[cc, r0]] * yi[[cc, c0]];
                    }
                    da[[r0, c0]] -= acc;
                }
            }
            // da -= Y_i^T dHtb
            for r0 in 0..k {
                for c0 in 0..k {
                    let mut acc = 0.0;
                    for cc in 0..di {
                        acc += yi[[cc, r0]] * dhtb[[cc, c0]];
                    }
                    da[[r0, c0]] -= acc;
                }
            }
            let dhuu = &huu_drho[i][a];
            // da_tmp_i = dHuu Y_i (di x k)
            let mut da_tmp_i = da_tmp.slice_mut(ndarray::s![..di, ..]).to_owned();
            for r0 in 0..di {
                for c0 in 0..k {
                    let mut acc = 0.0;
                    for cc in 0..di {
                        acc += dhuu[[r0, cc]] * yi[[cc, c0]];
                    }
                    da_tmp_i[[r0, c0]] = acc;
                }
            }
            // da += Y_i^T (dHuu Y_i)
            for r0 in 0..k {
                for c0 in 0..k {
                    let mut acc = 0.0;
                    for cc in 0..di {
                        acc += yi[[cc, r0]] * da_tmp_i[[cc, c0]];
                    }
                    da[[r0, c0]] += acc;
                }
            }
        }

        // tr(S^{-1} dS[a]) via one Schur solve per column of da.
        let mut schur_trace_acc = 0.0_f64;
        for j in 0..k {
            for r0 in 0..k {
                col_scratch[r0] = da[[r0, j]];
            }
            let v = cholesky_solve_vector(schur, &col_scratch);
            schur_trace_acc += v[j];
        }

        grad += 0.5 * (row_trace_acc + schur_trace_acc);
        grad += ift_correction[a];

        grad -= 0.5 * pen_logdet_drho[a];

        out[a] = grad;
    }
    out
}
2583
2584pub fn select_topology(
2610 candidates: &[TopologyCandidate],
2611 options: TopologySelectOptions,
2612) -> SelectedTopology {
2613 let mut valid: Vec<TopologyCandidate> = candidates
2615 .iter()
2616 .filter(|c| {
2617 c.converged
2618 && c.exclusion_reason.is_none()
2619 && c.negative_log_evidence.is_finite()
2620 && topology_selection_score(c, options.score_scale).is_finite()
2621 })
2622 .cloned()
2623 .collect();
2624 let mut excluded: Vec<TopologyCandidate> = candidates
2625 .iter()
2626 .filter(|c| {
2627 !(c.converged && c.exclusion_reason.is_none() && c.negative_log_evidence.is_finite())
2628 || !topology_selection_score(c, options.score_scale).is_finite()
2629 })
2630 .cloned()
2631 .collect();
2632
2633 assert!(
2634 !valid.is_empty(),
2635 "select_topology: no finite valid candidates; proposal §6.11 forbids silent fallback"
2636 );
2637
2638 valid = rank_priority_candidates(
2643 valid
2644 .into_iter()
2645 .enumerate()
2646 .map(|(idx, row)| {
2647 let score = topology_selection_score(&row, options.score_scale);
2648 let tie_break = usize::from(row.kind.complexity_rank());
2649 PriorityCandidate::new(row, idx, score, tie_break)
2650 })
2651 .collect(),
2652 )
2653 .into_iter()
2654 .map(|row| row.item)
2655 .collect();
2656
2657 let tie = if valid.len() >= 2 {
2659 let top = topology_selection_score(&valid[0], options.score_scale);
2660 let next = topology_selection_score(&valid[1], options.score_scale);
2661 (next - top).abs() <= options.tie_tolerance
2662 } else {
2663 false
2664 };
2665
2666 if tie {
2668 let top_score = topology_selection_score(&valid[0], options.score_scale);
2669 let tied_end = valid
2671 .iter()
2672 .position(|c| {
2673 (topology_selection_score(c, options.score_scale) - top_score).abs()
2674 > options.tie_tolerance
2675 })
2676 .unwrap_or(valid.len());
2677 valid[..tied_end].sort_by_key(|c| c.kind.complexity_rank());
2679 }
2680
2681 let winner = valid[0].kind;
2682 valid.append(&mut excluded);
2683 SelectedTopology {
2684 winner,
2685 ranking: valid,
2686 tie,
2687 }
2688}
2689
2690fn topology_selection_score(candidate: &TopologyCandidate, scale: TopologyScoreScale) -> f64 {
2691 match scale {
2692 TopologyScoreScale::PerObservation => {
2693 if candidate.n_obs == 0 {
2694 f64::NAN
2695 } else {
2696 candidate.negative_log_evidence / candidate.n_obs as f64
2697 }
2698 }
2699 TopologyScoreScale::PerEffectiveDim => {
2700 if !(candidate.effective_dim.is_finite() && candidate.effective_dim > 0.0) {
2701 f64::NAN
2702 } else {
2703 candidate.negative_log_evidence / candidate.effective_dim
2704 }
2705 }
2706 }
2707}
2708
2709pub fn cache_matches_system(cache: &ArrowFactorCache, sys: &ArrowSchurSystem) -> bool {
2718 cache.d == sys.d
2719 && cache.k == sys.k
2720 && cache.n_rows() == sys.rows.len()
2721 && cache.undamped_factor_count() == sys.rows.len()
2722 && cache.manifold_mode_fingerprint == sys.manifold_mode_fingerprint
2723 && cache.row_hessian_fingerprint == sys.current_row_hessian_fingerprint()
2724}
2725
/// Parameterisation choice for one atom of a hybrid model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HybridAtomParam {
    /// Curved parameterisation with the given latent dimension.
    Curved { latent_dim: usize },
    /// Linear parameterisation.
    Linear,
}

impl HybridAtomParam {
    /// Short lowercase label: `"curved"` or `"linear"`.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Linear => "linear",
            Self::Curved { .. } => "curved",
        }
    }

    /// `true` only for [`HybridAtomParam::Linear`].
    pub const fn is_linear(self) -> bool {
        match self {
            Self::Linear => true,
            Self::Curved { .. } => false,
        }
    }
}
2800
2801#[derive(Debug, Clone, Copy)]
2809pub struct HybridAtomCandidate {
2810 pub param: HybridAtomParam,
2811 pub negative_log_evidence: f64,
2813 pub num_parameters: usize,
2815 pub fitted_turning: Option<f64>,
2820}
2821
2822impl HybridAtomCandidate {
2823 pub fn linear(negative_log_evidence: f64, num_parameters: usize) -> Self {
2825 Self {
2826 param: HybridAtomParam::Linear,
2827 negative_log_evidence,
2828 num_parameters,
2829 fitted_turning: Some(0.0),
2830 }
2831 }
2832
2833 pub fn curved(
2835 latent_dim: usize,
2836 negative_log_evidence: f64,
2837 num_parameters: usize,
2838 fitted_turning: Option<f64>,
2839 ) -> Self {
2840 Self {
2841 param: HybridAtomParam::Curved { latent_dim },
2842 negative_log_evidence,
2843 num_parameters,
2844 fitted_turning,
2845 }
2846 }
2847}
2848
/// Outcome of `select_hybrid_atom` for a single atom slot.
#[derive(Debug, Clone, Copy)]
pub struct HybridAtomChoice {
    /// Parameterisation that was selected.
    pub param: HybridAtomParam,
    /// Negative log evidence of the selected candidate.
    pub negative_log_evidence: f64,
    /// Parameter count of the selected candidate.
    pub num_parameters: usize,
    /// `fitted_turning` of the first curved candidate in the slot, if any.
    pub curved_turning: Option<f64>,
    /// Linear minus curved negative log evidence, using the first candidate of
    /// each kind; `0.0` when either kind is absent. Positive values mean the
    /// curved candidate had the lower (better) negative log evidence.
    pub curved_evidence_margin: f64,
}
2868
/// Turning threshold at or below which `select_hybrid_atom` treats the fitted
/// curved parameterisation as degenerate. When a linear candidate is present,
/// it is then returned regardless of evidence.
pub const HYBRID_LINEAR_TURNING_FLOOR: f64 = 1e-9;
2877
2878pub fn select_hybrid_atom(candidates: &[HybridAtomCandidate]) -> Option<HybridAtomChoice> {
2906 if candidates.is_empty() {
2907 return None;
2908 }
2909 let linear = candidates.iter().find(|c| c.param.is_linear());
2910 let curved = candidates.iter().find(|c| !c.param.is_linear());
2911 let curved_turning = curved.and_then(|c| c.fitted_turning);
2912 let curved_evidence_margin = match (linear, curved) {
2913 (Some(l), Some(c)) => l.negative_log_evidence - c.negative_log_evidence,
2914 _ => 0.0,
2915 };
2916
2917 if let (Some(l), Some(turning)) = (linear, curved_turning)
2920 && turning <= HYBRID_LINEAR_TURNING_FLOOR
2921 {
2922 return Some(HybridAtomChoice {
2923 param: l.param,
2924 negative_log_evidence: l.negative_log_evidence,
2925 num_parameters: l.num_parameters,
2926 curved_turning,
2927 curved_evidence_margin,
2928 });
2929 }
2930
2931 let mut best = candidates[0];
2933 for cand in &candidates[1..] {
2934 let better_evidence = cand.negative_log_evidence < best.negative_log_evidence;
2935 let tied = cand.negative_log_evidence == best.negative_log_evidence;
2936 let cheaper_on_tie = tied && cand.num_parameters < best.num_parameters;
2937 if better_evidence || cheaper_on_tie {
2938 best = *cand;
2939 }
2940 }
2941 Some(HybridAtomChoice {
2942 param: best.param,
2943 negative_log_evidence: best.negative_log_evidence,
2944 num_parameters: best.num_parameters,
2945 curved_turning,
2946 curved_evidence_margin,
2947 })
2948}
2949
/// Per-slot atom choices for a hybrid split, as produced by
/// `select_hybrid_split`.
#[derive(Debug, Clone)]
pub struct HybridSplitSelection {
    /// One choice per slot, in slot order.
    pub atoms: Vec<HybridAtomChoice>,
    /// Sum of the chosen candidates' negative log evidences.
    pub total_negative_log_evidence: f64,
    /// Sum of the chosen candidates' parameter counts.
    pub total_parameters: usize,
    /// Number of chosen atoms with a curved parameterisation.
    pub curved_atom_count: usize,
}
2972
2973impl HybridSplitSelection {
2974 pub fn linear_atom_count(&self) -> usize {
2976 self.atoms.len() - self.curved_atom_count
2977 }
2978
2979 pub fn is_pure_linear(&self) -> bool {
2982 self.curved_atom_count == 0 && !self.atoms.is_empty()
2983 }
2984
2985 pub fn is_pure_curved(&self) -> bool {
2988 self.curved_atom_count == self.atoms.len() && !self.atoms.is_empty()
2989 }
2990}
2991
2992pub fn select_hybrid_split(
3006 slots: &[Vec<HybridAtomCandidate>],
3007) -> Result<HybridSplitSelection, String> {
3008 let mut atoms = Vec::with_capacity(slots.len());
3009 let mut total_nle = 0.0_f64;
3010 let mut total_parameters = 0usize;
3011 let mut curved_atom_count = 0usize;
3012 for (i, slot) in slots.iter().enumerate() {
3013 let choice = select_hybrid_atom(slot)
3014 .ok_or_else(|| format!("hybrid split slot {i} has no candidate parameterizations"))?;
3015 if !choice.negative_log_evidence.is_finite() {
3016 return Err(format!(
3017 "hybrid split slot {i} selected a non-finite evidence ({})",
3018 choice.negative_log_evidence
3019 ));
3020 }
3021 if !choice.param.is_linear() {
3022 curved_atom_count += 1;
3023 }
3024 total_nle += choice.negative_log_evidence;
3025 total_parameters += choice.num_parameters;
3026 atoms.push(choice);
3027 }
3028 Ok(HybridSplitSelection {
3029 atoms,
3030 total_negative_log_evidence: total_nle,
3031 total_parameters,
3032 curved_atom_count,
3033 })
3034}
3035
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arrow_schur::ArrowFactorSlab;

    /// Inverts a small square matrix by Gauss–Jordan elimination with partial
    /// pivoting on the augmented matrix `[h | I]`.
    ///
    /// Test-only reference: it produces the exact inverse that the coned
    /// sensitivity solves below are compared against.
    ///
    /// # Panics
    /// Panics if `h` is not square, or if a zero / non-finite pivot is met (i.e. `h`
    /// is singular to working precision). Failing here is clearer than letting
    /// `inf`/`NaN` leak into the downstream comparisons.
    fn dense_inverse(h: &Array2<f64>) -> Array2<f64> {
        let p = h.nrows();
        assert_eq!(p, h.ncols(), "dense_inverse: matrix must be square");
        // Augmented matrix [h | I]; after elimination the right half holds h^{-1}.
        let mut aug = Array2::<f64>::from_shape_fn((p, 2 * p), |(i, j)| {
            if j < p {
                h[[i, j]]
            } else if j - p == i {
                1.0
            } else {
                0.0
            }
        });
        for col in 0..p {
            // Partial pivoting: move the largest-magnitude entry of this column onto
            // the diagonal to keep the elimination numerically stable.
            let mut pivot = col;
            for row in (col + 1)..p {
                if aug[[row, col]].abs() > aug[[pivot, col]].abs() {
                    pivot = row;
                }
            }
            if pivot != col {
                for j in 0..(2 * p) {
                    aug.swap([col, j], [pivot, j]);
                }
            }
            let d = aug[[col, col]];
            assert!(
                d.is_finite() && d != 0.0,
                "dense_inverse: singular matrix (pivot {d} in column {col})"
            );
            for j in 0..(2 * p) {
                aug[[col, j]] /= d;
            }
            for row in 0..p {
                if row == col {
                    continue;
                }
                let f = aug[[row, col]];
                // Skipping exact zeros leaves structurally zero entries untouched.
                if f != 0.0 {
                    for j in 0..(2 * p) {
                        aug[[row, j]] -= f * aug[[col, j]];
                    }
                }
            }
        }
        Array2::from_shape_fn((p, p), |(i, j)| aug[[i, p + j]])
    }

    /// Two independent coupled pairs must produce exactly two components, with each
    /// pair's coordinates sharing a label.
    #[test]
    fn coupling_components_block_diagonal_is_all_singletons_by_block() {
        let mut h = Array2::<f64>::eye(4);
        // Couple (0, 1) and (2, 3); no off-diagonal entry links the two pairs.
        h[[0, 1]] = 0.3;
        h[[1, 0]] = 0.3;
        h[[2, 3]] = 0.7;
        h[[3, 2]] = 0.7;
        let labels = coupling_components(h.view());
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[2], labels[3]);
        assert_ne!(labels[0], labels[2]);
        // Exactly two distinct labels overall.
        let mut uniq = labels.clone();
        uniq.sort_unstable();
        uniq.dedup();
        assert_eq!(uniq.len(), 2);
    }

    /// Every off-diagonal entry nonzero: all coordinates share one label.
    #[test]
    fn coupling_components_fully_coupled_is_one_component() {
        let mut h = Array2::<f64>::eye(3);
        for i in 0..3 {
            for j in 0..3 {
                if i != j {
                    h[[i, j]] = 0.1;
                }
            }
        }
        let labels = coupling_components(h.view());
        assert!(labels.iter().all(|&l| l == labels[0]));
    }

    /// 0–1 and 1–2 are coupled but 0–2 is not: coupling is transitive, so all three
    /// coordinates still end up in one component.
    #[test]
    fn coupling_components_transitive_chain_merges() {
        let mut h = Array2::<f64>::eye(3);
        h[[0, 1]] = 0.5;
        h[[1, 0]] = 0.5;
        h[[1, 2]] = 0.5;
        h[[2, 1]] = 0.5;
        let labels = coupling_components(h.view());
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[1], labels[2]);
    }

    /// Regression test for gh-1465. The reported deltas and Bayes factors must
    /// never contradict the declared winner, in either the ranking or the raw
    /// score table.
    #[test]
    fn compare_reml_fits_delta_and_bayes_factor_never_contradict_winner_gh1465() {
        // Local builder, named to avoid shadowing the module-level `cand` helper.
        // Every candidate has log_lik = 0, so its conditional-AIC ranking score
        // (-2·log_lik + 2·edf) reduces to 2·edf.
        let gaussian_cand = |name: &str, score: f64, edf: f64| RemlCandidate {
            index: 0,
            name: name.to_string(),
            score,
            edf: Some(edf),
            log_lik: Some(0.0),
            family: Some("gaussian".to_string()),
        };
        // m1 has the smallest edf (best AIC) while m2 has the smallest raw REML
        // score, so the AIC ranking and the raw-REML ordering disagree.
        let candidates = vec![
            gaussian_cand("m1", 53.748, 50.0),
            gaussian_cand("m2", 41.605, 51.0),
            gaussian_cand("m3", 120.011, 65.0),
        ];
        let cmp = compare_reml_fits(candidates).expect("comparison");

        assert_eq!(cmp.winner, "m1", "AIC winner");
        // Ranking rows are measured against the AIC winner, so no delta may be
        // negative and no Bayes factor may fall below 1.
        for row in &cmp.ranking {
            assert!(
                row.delta >= 0.0,
                "ranking delta for {} must be >= 0, got {}",
                row.name,
                row.delta
            );
            assert!(
                row.bayes_factor >= 1.0 - 1e-12,
                "ranking bayes_factor for {} must be >= 1, got {}",
                row.name,
                row.bayes_factor
            );
        }
        let winner_row = cmp.ranking.iter().find(|r| r.name == "m1").unwrap();
        assert!(winner_row.delta.abs() < 1e-12, "winner delta == 0");
        assert!(
            (winner_row.bayes_factor - 1.0).abs() < 1e-9,
            "winner bayes_factor == 1"
        );

        // Score-table rows are measured against the minimum raw REML score (m2),
        // which is a different reference than the AIC winner.
        for row in &cmp.score_table {
            assert!(
                row.delta_reml >= 0.0,
                "score-table delta_reml for {} must be >= 0, got {}",
                row.name,
                row.delta_reml
            );
            assert!(
                row.bayes_factor_best_over_model >= 1.0 - 1e-12,
                "score-table bayes_factor for {} must be >= 1, got {}",
                row.name,
                row.bayes_factor_best_over_model
            );
        }
        let m2 = cmp.score_table.iter().find(|r| r.name == "m2").unwrap();
        assert!(
            m2.delta_reml.abs() < 1e-12,
            "the minimum-raw-REML row has delta_reml 0"
        );
    }

    /// No support indices means an empty cone.
    #[test]
    fn cone_of_influence_empty_support_is_empty() {
        let labels = vec![0usize, 0, 1, 1];
        assert!(cone_of_influence(&labels, &[]).is_empty());
    }

    /// The cone contains every index sharing a component label with any support
    /// index: one support index pulls in its whole component, and support indices
    /// spanning both components pull in everything.
    #[test]
    fn cone_of_influence_returns_full_component() {
        let labels = vec![0usize, 0, 1, 1];
        assert_eq!(cone_of_influence(&labels, &[0]), vec![0, 1]);
        assert_eq!(cone_of_influence(&labels, &[1, 2]), vec![0, 1, 2, 3]);
    }

    /// On a fully coupled Hessian the cone of any support is the whole space, so
    /// the coned response must match the unconed `mode_response` to round-off.
    #[test]
    fn coned_matches_full_solve_on_fully_coupled_hessian() {
        let h = Array2::from_shape_vec((3, 3), vec![4.0, 1.0, 0.5, 1.0, 3.0, 0.8, 0.5, 0.8, 2.5])
            .unwrap();
        let inv = dense_inverse(&h);
        // Two modes, each with a single nonzero gradient entry inside its support.
        let mut dg = Array2::<f64>::zeros((3, 2));
        dg[[0, 0]] = 1.3;
        dg[[2, 1]] = -0.7;
        let supports = vec![0..1usize, 2..3usize];

        let eye: Array2<f64> = Array2::eye(3);
        let op = crate::sensitivity::FitSensitivity::from_projected(&eye, &inv);
        let full = op.mode_response(dg.view()).unwrap();
        let coned = op
            .mode_response_coned(h.view(), dg.view(), &supports)
            .unwrap();
        for i in 0..3 {
            for a in 0..2 {
                assert!(
                    (full[[i, a]] - coned[[i, a]]).abs() < 1e-12,
                    "fully-coupled mismatch at ({i},{a}): {} vs {}",
                    full[[i, a]],
                    coned[[i, a]]
                );
            }
        }
    }

    /// Block-diagonal Hessian with blocks {0, 1} and {2, 3}: a mode supported on
    /// {0, 1} must equal the exact response `-H^{-1} q` and be exactly zero on the
    /// other block.
    #[test]
    fn coned_confines_to_component_on_decoupled_hessian() {
        let mut h = Array2::<f64>::zeros((4, 4));
        // First block: coordinates 0 and 1.
        h[[0, 0]] = 4.0;
        h[[1, 1]] = 3.0;
        h[[0, 1]] = 1.0;
        h[[1, 0]] = 1.0;
        // Second block: coordinates 2 and 3.
        h[[2, 2]] = 2.0;
        h[[3, 3]] = 5.0;
        h[[2, 3]] = 0.6;
        h[[3, 2]] = 0.6;
        let inv = dense_inverse(&h);

        let mut dg = Array2::<f64>::zeros((4, 1));
        dg[[0, 0]] = 0.9;
        dg[[1, 0]] = -0.4;
        let support_range = 0..2usize;
        let supports = std::slice::from_ref(&support_range);

        let eye: Array2<f64> = Array2::eye(4);
        let coned = crate::sensitivity::FitSensitivity::from_projected(&eye, &inv)
            .mode_response_coned(h.view(), dg.view(), supports)
            .unwrap();
        let q = dg.column(0).to_owned();
        let exact = inv.dot(&q).mapv(|v| -v);
        for i in 0..4 {
            assert!(
                (coned[[i, 0]] - exact[[i]]).abs() < 1e-12,
                "decoupled mismatch at {i}: {} vs {}",
                coned[[i, 0]],
                exact[[i]]
            );
        }
        // Outside the cone the response must be exactly zero, not merely tiny.
        assert_eq!(coned[[2, 0]], 0.0);
        assert_eq!(coned[[3, 0]], 0.0);
    }

    /// A mode with an empty support must be skipped entirely. The projected
    /// inverse is all-NaN, so any accidental read of it would show up as NaN in the
    /// output instead of the expected zeros.
    #[test]
    fn coned_skips_inactive_column_with_empty_support() {
        let h = Array2::<f64>::eye(2);
        let dg = Array2::<f64>::zeros((2, 1));
        let empty_support = 0..0usize;
        let supports = std::slice::from_ref(&empty_support);
        let eye: Array2<f64> = Array2::eye(2);
        let nan_inv = Array2::<f64>::from_elem((2, 2), f64::NAN);
        let coned = crate::sensitivity::FitSensitivity::from_projected(&eye, &nan_inv)
            .mode_response_coned(h.view(), dg.view(), supports)
            .unwrap();
        assert_eq!(coned[[0, 0]], 0.0);
        assert_eq!(coned[[1, 0]], 0.0);
    }

    /// Builds a one-row `Direct` arrow factor cache with `d = 1` and `k = 1`, zero
    /// ridges, and a Schur factor present.
    ///
    /// The stored factors are `l_huu = √2` and `l_schur = √1.875`; their squares
    /// are 2 and 1.875, so the tests below expect a log-determinant of
    /// `ln 2 + ln 1.875`. The cross block `htbeta` is 0.5. This presumably encodes
    /// the joint Hessian `[[2, 0.5], [0.5, 2]]` (since 2 − 0.5²/2 = 1.875); confirm
    /// against the `ArrowFactorCache` field semantics.
    fn make_minimal_cache() -> ArrowFactorCache {
        let l_huu = Array2::from_shape_vec((1, 1), vec![std::f64::consts::SQRT_2]).unwrap();
        let l_schur = Array2::from_shape_vec((1, 1), vec![(1.875_f64).sqrt()]).unwrap();
        let htbeta = Array2::from_shape_vec((1, 1), vec![0.5]).unwrap();
        ArrowFactorCache {
            htt_factors: ArrowFactorSlab::from_blocks(vec![l_huu]),
            htt_factors_undamped: crate::arrow_schur::ArrowUndampedFactors::SameAsDamped,
            schur_factor: Some(l_schur),
            joint_hessian_log_det: None,
            solver_mode: crate::arrow_schur::ArrowSolverMode::Direct,
            ridge_t: 0.0,
            ridge_beta: 0.0,
            htbeta: crate::arrow_schur::ArrowHtbetaCache::Dense {
                blocks: std::sync::Arc::from(vec![htbeta]),
                estimated_bytes: std::mem::size_of::<f64>(),
            },
            d: 1,
            row_dims: std::sync::Arc::from(vec![1usize]),
            row_offsets: std::sync::Arc::from(vec![0usize, 1usize]),
            k: 1,
            manifold_mode_fingerprint: 0,
            row_hessian_fingerprint: 0,
            pcg_diagnostics: crate::arrow_schur::PcgDiagnostics::default(),
            gauge_deflated_directions: 0,
            deflated_row_directions: std::sync::Arc::from(Vec::new()),
            deflation_row_spectra: std::sync::Arc::from(Vec::new()),
            cross_row_woodbury: None,
        }
    }

    /// With the factored cache, the evidence is ½·log det of the joint Hessian
    /// (ln 2 + ln 1.875) minus ½·ln(2π), to round-off.
    #[test]
    fn laplace_evidence_returns_finite_for_minimal_cache() {
        let cache = make_minimal_cache();
        let v = laplace_evidence(
            EvidenceLogDetSource::FactoredArrow {
                cache: &cache,
                fallback_hvp: None,
            },
            0.0,
            0.0,
            2.0,
            1.0,
        );
        assert!(v.is_finite());
        let expected =
            0.5 * (2.0_f64.ln() + 1.875_f64.ln()) - 0.5 * (2.0 * std::f64::consts::PI).ln();
        assert!((v - expected).abs() < 1e-12);
    }

    /// Builds a one-row `Direct` cache with `k = 0`, no Schur factor, and the
    /// `htbeta` cache disabled. The single latent block factor is `√latent_diag`,
    /// so the expected log-determinant is `ln(latent_diag)`.
    fn k0_direct_cache_no_schur(latent_diag: f64) -> ArrowFactorCache {
        let l_huu = Array2::from_shape_vec((1, 1), vec![latent_diag.sqrt()]).unwrap();
        ArrowFactorCache {
            htt_factors: ArrowFactorSlab::from_blocks(vec![l_huu]),
            htt_factors_undamped: crate::arrow_schur::ArrowUndampedFactors::SameAsDamped,
            schur_factor: None,
            joint_hessian_log_det: None,
            solver_mode: crate::arrow_schur::ArrowSolverMode::Direct,
            ridge_t: 0.0,
            ridge_beta: 0.0,
            htbeta: crate::arrow_schur::ArrowHtbetaCache::Disabled { estimated_bytes: 0 },
            d: 1,
            row_dims: std::sync::Arc::from(vec![1usize]),
            row_offsets: std::sync::Arc::from(vec![0usize, 1usize]),
            k: 0,
            manifold_mode_fingerprint: 0,
            row_hessian_fingerprint: 0,
            pcg_diagnostics: crate::arrow_schur::PcgDiagnostics::default(),
            gauge_deflated_directions: 0,
            deflated_row_directions: std::sync::Arc::from(Vec::new()),
            deflation_row_spectra: std::sync::Arc::from(Vec::new()),
            cross_row_woodbury: None,
        }
    }

    /// Regression for #1132: with `k == 0` there is no Schur complement, so a
    /// missing Schur factor must not make the log-det `None`. The answer is just
    /// the per-row sum.
    #[test]
    fn arrow_log_det_some_for_k0_direct_cache_without_schur() {
        let cache = k0_direct_cache_no_schur(3.0);
        let log_det = arrow_log_det_from_cache(&cache)
            .expect("k==0 Direct cache must yield Some(per-row sum), not None (#1132)");
        assert!(
            (log_det - 3.0_f64.ln()).abs() < 1e-12,
            "log_det = {log_det}"
        );
        // The cache's own undamped path must agree.
        let cached = cache
            .compute_undamped_arrow_log_det()
            .expect("compute_undamped_arrow_log_det must be Some for k==0");
        assert!((cached - 3.0_f64.ln()).abs() < 1e-12, "cached = {cached}");
    }

    /// With `k > 0`, an inexact-PCG solver and no Schur factor, the log-det is not
    /// available from the cache, and both entry points must report `None`.
    #[test]
    fn arrow_log_det_none_for_kpos_cache_without_schur() {
        let mut cache = k0_direct_cache_no_schur(3.0);
        cache.k = 1;
        cache.solver_mode = crate::arrow_schur::ArrowSolverMode::InexactPCG;
        assert!(arrow_log_det_from_cache(&cache).is_none());
        assert!(cache.compute_undamped_arrow_log_det().is_none());
    }

    /// A nonzero ridge means the factors describe a damped Hessian, so the evidence
    /// is reported as NaN rather than a silently biased value.
    #[test]
    fn laplace_evidence_nan_when_ridge_is_nonzero() {
        let mut cache = make_minimal_cache();
        cache.ridge_t = 1e-3;
        assert!(
            laplace_evidence(
                EvidenceLogDetSource::FactoredArrow {
                    cache: &cache,
                    fallback_hvp: None,
                },
                0.0,
                0.0,
                2.0,
                1.0,
            )
            .is_nan()
        );
    }

    /// Without a Schur factor, the factored source falls back to the supplied
    /// Hessian-vector product. Here that is diag(2, 1.875), with the same log-det
    /// as the minimal cache, so the result must match to round-off. The 2×2
    /// dimension is far below `ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD`; the tight
    /// tolerance presumably relies on a non-stochastic path being taken there.
    #[test]
    fn laplace_evidence_uses_hvp_fallback_without_schur_factor() {
        let mut cache = make_minimal_cache();
        cache.schur_factor = None;
        let hvp = |x: &[f64]| -> Vec<f64> { vec![2.0 * x[0], 1.875 * x[1]] };
        let v = laplace_evidence(
            EvidenceLogDetSource::FactoredArrow {
                cache: &cache,
                fallback_hvp: Some(EvidenceHvpLogDet {
                    dim: 2,
                    apply: &hvp,
                }),
            },
            0.0,
            0.0,
            2.0,
            1.0,
        );
        let expected =
            0.5 * (2.0_f64.ln() + 1.875_f64.ln()) - 0.5 * (2.0 * std::f64::consts::PI).ln();
        assert!((v - expected).abs() < 1e-12);
    }

    /// du/dβ for the minimal cache is a 1×1 matrix equal to -0.5 / 2 = -0.25.
    #[test]
    fn ift_du_dbeta_has_expected_shape() {
        let cache = make_minimal_cache();
        let du_db = ift_du_dbeta(&cache);
        assert_eq!(du_db.shape(), &[1, 1]);
        assert!((du_db[[0, 0]] - (-0.25)).abs() < 1e-12);
    }

    /// dβ/dρ for q = [1] against the Schur value 1.875 is -1 / 1.875.
    #[test]
    fn ift_dbeta_drho_returns_some_for_direct_cache() {
        let cache = make_minimal_cache();
        let q = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
        let out = ift_dbeta_drho(&cache, q.view()).unwrap();
        assert_eq!(out.shape(), &[1, 1]);
        assert!((out[[0, 0]] + 1.0 / 1.875).abs() < 1e-12);
    }

    /// The lowest finite negative log-evidence wins. The non-converged, NaN-scored
    /// torus with an exclusion reason must not win.
    #[test]
    fn topology_select_picks_lowest_negative_log_evidence() {
        let candidates = vec![
            TopologyCandidate {
                kind: TopologyKind::Flat,
                negative_log_evidence: 10.0,
                effective_dim: 4.0,
                n_obs: 100,
                converged: true,
                exclusion_reason: None,
            },
            TopologyCandidate {
                kind: TopologyKind::Sphere,
                negative_log_evidence: 8.0,
                effective_dim: 5.0,
                n_obs: 100,
                converged: true,
                exclusion_reason: None,
            },
            TopologyCandidate {
                kind: TopologyKind::Torus,
                negative_log_evidence: f64::NAN,
                effective_dim: 6.0,
                n_obs: 100,
                converged: false,
                exclusion_reason: Some("torus periods missing".to_string()),
            },
        ];
        let sel = select_topology(&candidates, TopologySelectOptions::default());
        assert_eq!(sel.winner, TopologyKind::Sphere);
        assert!(!sel.tie);
    }

    /// Scores 1e-6 apart are treated as a tie under the default options, and the
    /// tie goes to the lower `complexity_rank` (Flat over Sphere) even though
    /// Sphere scored marginally better.
    #[test]
    fn topology_select_tie_breaks_to_simpler() {
        let candidates = vec![
            TopologyCandidate {
                kind: TopologyKind::Sphere,
                negative_log_evidence: 5.0,
                effective_dim: 5.0,
                n_obs: 100,
                converged: true,
                exclusion_reason: None,
            },
            TopologyCandidate {
                kind: TopologyKind::Flat,
                negative_log_evidence: 5.0 + 1e-6,
                effective_dim: 4.0,
                n_obs: 100,
                converged: true,
                exclusion_reason: None,
            },
        ];
        let sel = select_topology(&candidates, TopologySelectOptions::default());
        assert_eq!(sel.winner, TopologyKind::Flat);
        assert!(sel.tie);
    }

    /// Log-density of `y` under a normal distribution with the given mean and
    /// standard deviation.
    fn gaussian_logpdf(y: f64, mean: f64, sd: f64) -> f64 {
        let z = (y - mean) / sd;
        -0.5 * (2.0 * std::f64::consts::PI).ln() - sd.ln() - 0.5 * z * z
    }

    /// A single candidate column receives weight exactly 1.
    #[test]
    fn stacking_single_candidate_gets_full_weight() {
        let log_density = Array2::from_shape_vec((3, 1), vec![-1.0, -2.0, -0.5]).unwrap();
        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
        assert!((out.weights[0] - 1.0).abs() < 1e-12);
        assert_eq!(out.weights.len(), 1);
    }

    /// A candidate that is better on every row takes essentially all the weight.
    #[test]
    fn stacking_dominant_candidate_attracts_nearly_all_weight() {
        let mut log_density = Array2::<f64>::zeros((50, 2));
        for i in 0..50 {
            log_density[[i, 0]] = -0.1;
            log_density[[i, 1]] = -5.0;
        }
        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
        assert!(out.weights[0] > 0.99, "w0 = {}", out.weights[0]);
        assert!(out.weights[1] < 0.01, "w1 = {}", out.weights[1]);
    }

    /// Each candidate is best on half of the rows (mirror-image densities), so the
    /// stacked mixture must split the weight rather than collapse onto one model.
    #[test]
    fn stacking_complementary_candidates_share_weight() {
        let n = 40;
        let mut log_density = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            if i < n / 2 {
                log_density[[i, 0]] = gaussian_logpdf(0.0, 0.0, 0.5);
                log_density[[i, 1]] = gaussian_logpdf(0.0, 1.5, 0.5);
            } else {
                log_density[[i, 0]] = gaussian_logpdf(0.0, 1.5, 0.5);
                log_density[[i, 1]] = gaussian_logpdf(0.0, 0.0, 0.5);
            }
        }
        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
        assert!(
            out.weights[0] > 0.2 && out.weights[0] < 0.8,
            "w0 = {}",
            out.weights[0]
        );
        assert!((out.weights.sum() - 1.0).abs() < 1e-9);
    }

    /// Weights are non-negative (to round-off) and sum to 1.
    #[test]
    fn stacking_weights_stay_on_the_simplex() {
        let log_density = Array2::from_shape_vec(
            (3, 3),
            vec![-1.0, -2.0, -3.0, -2.5, -1.0, -2.0, -3.0, -2.0, -1.0],
        )
        .unwrap();
        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
        assert!((out.weights.sum() - 1.0).abs() < 1e-9);
        assert!(out.weights.iter().all(|&w| w >= -1e-12));
    }

    /// With the weight tolerance disabled, raising `max_iter` must never lower the
    /// mean log score: each additional iteration may only improve the objective.
    #[test]
    fn stacking_mean_log_score_is_monotone_under_more_iterations() {
        let log_density =
            Array2::from_shape_vec((4, 2), vec![-0.2, -3.0, -3.0, -0.2, -0.5, -1.5, -1.5, -0.5])
                .unwrap();
        let mut prev = f64::NEG_INFINITY;
        for max_iter in [1usize, 2, 4, 8, 32] {
            let out = solve_stacking_weights(
                log_density.view(),
                StackingConfig {
                    max_iter,
                    weight_tol: 0.0,
                },
            )
            .unwrap();
            assert!(
                out.mean_log_score >= prev - 1e-12,
                "log-score decreased at max_iter={max_iter}: {prev} -> {}",
                out.mean_log_score
            );
            prev = out.mean_log_score;
        }
    }

    /// A column with no finite density (-inf / NaN throughout) gets exactly zero
    /// weight, and the remaining candidate gets all of it.
    #[test]
    fn stacking_dead_candidate_column_is_rejected_and_zero_weighted() {
        let log_density = Array2::from_shape_vec(
            (3, 2),
            vec![
                -1.0,
                f64::NEG_INFINITY,
                -2.0,
                f64::NAN,
                -0.5,
                f64::NEG_INFINITY,
            ],
        )
        .unwrap();
        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
        assert_eq!(out.weights[1], 0.0);
        assert!((out.weights[0] - 1.0).abs() < 1e-12);
    }

    /// The middle row has no finite density for any candidate. It must be dropped
    /// so that the weights and the mean log score stay finite and well-formed.
    #[test]
    fn stacking_rows_with_no_finite_density_are_dropped() {
        let log_density = Array2::from_shape_vec(
            (3, 2),
            vec![-1.0, -2.0, f64::NAN, f64::NEG_INFINITY, -2.0, -1.0],
        )
        .unwrap();
        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
        assert!((out.weights.sum() - 1.0).abs() < 1e-9);
        assert!(out.mean_log_score.is_finite());
    }

    /// With no finite entry anywhere, there is nothing to stack, and the call errors.
    #[test]
    fn stacking_all_dead_table_errors() {
        let log_density = Array2::from_elem((2, 2), f64::NEG_INFINITY);
        assert!(solve_stacking_weights(log_density.view(), StackingConfig::default()).is_err());
    }

    /// The stacked predictive mean is the element-wise weighted sum of the
    /// candidate means.
    #[test]
    fn stacked_mean_is_weighted_combination() {
        let weights = Array1::from_vec(vec![0.25, 0.75]);
        let means = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![5.0, 6.0, 7.0]),
        ];
        let out = stacked_predictive_mean(&weights, &means).unwrap();
        assert!((out[0] - (0.25 * 1.0 + 0.75 * 5.0)).abs() < 1e-12);
        assert!((out[2] - (0.25 * 3.0 + 0.75 * 7.0)).abs() < 1e-12);
    }

    /// Candidate means of different lengths are rejected.
    #[test]
    fn stacked_mean_rejects_shape_mismatch() {
        let weights = Array1::from_vec(vec![0.5, 0.5]);
        let means = vec![
            Array1::from_vec(vec![1.0, 2.0]),
            Array1::from_vec(vec![3.0]),
        ];
        assert!(stacked_predictive_mean(&weights, &means).is_err());
    }

    /// Builds a two-candidate hybrid slot:
    /// - a linear atom with NLE `linear_nle` and `p_linear` parameters, and
    /// - a curved atom (`latent_dim`, turning `theta`, `p_curved` parameters).
    ///
    /// The curved atom's NLE is `linear_nle - curved_loglik_gain` plus a parameter
    /// price of ½·(p_curved − p_linear)·ln(2π). With `curved_loglik_gain = 0` the
    /// curved atom is therefore strictly worse than the linear one.
    fn hybrid_slot(
        linear_nle: f64,
        p_linear: usize,
        latent_dim: usize,
        p_curved: usize,
        theta: f64,
        curved_loglik_gain: f64,
    ) -> Vec<HybridAtomCandidate> {
        let param_price =
            0.5 * (p_curved as f64 - p_linear as f64) * (2.0 * std::f64::consts::PI).ln();
        let curved_nle = linear_nle - curved_loglik_gain + param_price;
        vec![
            HybridAtomCandidate::linear(linear_nle, p_linear),
            HybridAtomCandidate::curved(latent_dim, curved_nle, p_curved, Some(theta)),
        ]
    }

    /// Zero turning sits at or below the linear turning floor, so the linear atom
    /// is chosen.
    #[test]
    fn hybrid_dominance_floor_selects_linear_when_turning_is_zero() {
        let slot = hybrid_slot(100.0, 2, 1, 5, 0.0, 0.0);
        let choice = select_hybrid_atom(&slot).unwrap();
        assert!(choice.param.is_linear());
        assert_eq!(choice.param, HybridAtomParam::Linear);
        let turning = choice
            .curved_turning
            .expect("slot supplies a curved candidate with a turning angle");
        assert!(turning <= HYBRID_LINEAR_TURNING_FLOOR);
    }

    /// A full turn with a 30-nat likelihood gain easily pays the ~2.76-nat
    /// parameter price, so the curved atom wins with a positive evidence margin.
    #[test]
    fn hybrid_selects_curved_when_turning_pays_for_itself() {
        let slot = hybrid_slot(100.0, 2, 1, 5, 2.0 * std::f64::consts::PI, 30.0);
        let choice = select_hybrid_atom(&slot).unwrap();
        assert_eq!(choice.param, HybridAtomParam::Curved { latent_dim: 1 });
        assert!(choice.curved_evidence_margin > 0.0);
    }

    /// A 0.1-nat gain does not cover the parameter price, so linear is kept and
    /// the curved evidence margin is non-positive.
    #[test]
    fn hybrid_keeps_linear_when_curvature_doesnt_pay_its_price() {
        let slot = hybrid_slot(100.0, 2, 1, 5, 0.05, 0.1);
        let choice = select_hybrid_atom(&slot).unwrap();
        assert!(choice.param.is_linear());
        assert!(choice.curved_evidence_margin <= 0.0);
    }

    /// Equal NLE for both candidates: the tie goes to the cheaper linear atom
    /// (2 parameters instead of 5).
    #[test]
    fn hybrid_tie_breaks_to_the_cheaper_linear_atom() {
        let theta = 0.5;
        let nle = 42.0;
        let slot = vec![
            HybridAtomCandidate::linear(nle, 2),
            HybridAtomCandidate::curved(1, nle, 5, Some(theta)),
        ];
        let choice = select_hybrid_atom(&slot).unwrap();
        assert!(choice.param.is_linear());
        assert_eq!(choice.num_parameters, 2);
    }

    /// When no slot gains anything from curving, the split is purely linear, and
    /// its total NLE is exactly the sum of the linear NLEs.
    #[test]
    fn hybrid_split_reduces_to_pure_linear_when_all_features_are_straight() {
        let slots: Vec<Vec<HybridAtomCandidate>> = (0..6)
            .map(|i| hybrid_slot(50.0 + i as f64, 2, 1, 5, 0.0, 0.0))
            .collect();
        let split = select_hybrid_split(&slots).unwrap();
        assert!(split.is_pure_linear());
        assert_eq!(split.curved_atom_count, 0);
        assert_eq!(split.linear_atom_count(), 6);
        let pure_linear: f64 = (0..6).map(|i| 50.0 + i as f64).sum();
        assert!((split.total_negative_log_evidence - pure_linear).abs() < 1e-12);
    }

    /// When every slot gains enough from curving, the split is purely curved.
    #[test]
    fn hybrid_split_reduces_to_pure_curved_when_every_feature_curves() {
        let slots: Vec<Vec<HybridAtomCandidate>> = (0..5)
            .map(|i| hybrid_slot(80.0 + i as f64, 2, 1, 5, 2.0 * std::f64::consts::PI, 40.0))
            .collect();
        let split = select_hybrid_split(&slots).unwrap();
        assert!(split.is_pure_curved());
        assert_eq!(split.curved_atom_count, 5);
        assert_eq!(split.linear_atom_count(), 0);
    }

    /// Mixed dictionary: three "circle" slots that profit from curving, followed
    /// by four "direction" slots that do not. The split must choose per slot, and
    /// it must beat the all-linear baseline.
    #[test]
    fn hybrid_split_on_mixed_dictionary_picks_curved_for_circles_linear_for_directions() {
        let mut slots: Vec<Vec<HybridAtomCandidate>> = Vec::new();
        let mut pure_linear_baseline = 0.0_f64;
        // Circle slots: a full turn with a 35-nat gain.
        for i in 0..3 {
            let linear_nle = 120.0 + 3.0 * i as f64;
            pure_linear_baseline += linear_nle;
            slots.push(hybrid_slot(
                linear_nle,
                2,
                1,
                5,
                2.0 * std::f64::consts::PI,
                35.0,
            ));
        }
        // Direction slots: no turning and no gain.
        for i in 0..4 {
            let linear_nle = 90.0 + 2.0 * i as f64;
            pure_linear_baseline += linear_nle;
            slots.push(hybrid_slot(linear_nle, 2, 1, 5, 0.0, 0.0));
        }

        let split = select_hybrid_split(&slots).unwrap();

        // Atoms come back in slot order.
        for (idx, choice) in split.atoms.iter().enumerate() {
            if idx < 3 {
                assert_eq!(
                    choice.param,
                    HybridAtomParam::Curved { latent_dim: 1 },
                    "circle slot {idx} should select curved"
                );
            } else {
                assert!(
                    choice.param.is_linear(),
                    "direction slot {idx} should select linear"
                );
            }
        }
        assert_eq!(split.curved_atom_count, 3);
        assert_eq!(split.linear_atom_count(), 4);

        // Curving the circle slots must strictly improve on the all-linear baseline.
        // This strict check subsumes a non-strict `<=` bound.
        assert!(
            split.total_negative_log_evidence < pure_linear_baseline,
            "hybrid NLE {} must be < summed linear-candidate NLE {}",
            split.total_negative_log_evidence,
            pure_linear_baseline
        );
    }

    /// A slot with no candidate parameterizations is an error, not a silent skip.
    #[test]
    fn hybrid_split_rejects_empty_slot() {
        let slots = vec![hybrid_slot(10.0, 2, 1, 5, 0.0, 0.0), Vec::new()];
        assert!(select_hybrid_split(&slots).is_err());
    }

    /// Builds a `RemlCandidate` with explicit edf and log-likelihood and no family.
    fn cand(name: &str, score: f64, edf: f64, log_lik: f64) -> RemlCandidate {
        RemlCandidate {
            index: 0,
            name: name.to_string(),
            score,
            edf: Some(edf),
            log_lik: Some(log_lik),
            family: None,
        }
    }

    /// With both log_lik and edf present, the ranking score is the conditional
    /// AIC `-2·log_lik + 2·edf`, and the raw `score` is ignored.
    #[test]
    fn ranking_score_is_conditional_aic_when_loglik_and_edf_present() {
        let c = cand("m", 999.0, 6.748, -32.0866);
        let expected = -2.0 * -32.0866 + 2.0 * 6.748;
        assert!((c.ranking_score() - expected).abs() < 1e-9);
    }

    /// Without a log-likelihood, the ranking score falls back to the raw evidence
    /// score unchanged.
    #[test]
    fn ranking_score_falls_back_to_evidence_without_loglik() {
        let c = RemlCandidate {
            index: 0,
            name: "m".to_string(),
            score: 151.28,
            edf: Some(6.0),
            log_lik: None,
            family: None,
        };
        assert_eq!(c.ranking_score(), 151.28);
    }

    /// "big" has a lower raw REML score, but it buys almost no log-likelihood for
    /// ~7.5 extra edf. Its conditional AIC (~92.7) is therefore worse than
    /// "small"'s (~77.7), and "small" must win. The raw REML scores are still
    /// reported unchanged in the score table.
    #[test]
    fn compare_models_rejects_pure_noise_smooth_despite_lower_evidence() {
        let small = cand("small", 180.526, 6.748, -32.0866);
        let big = cand("big", 177.404, 14.250, -32.1212);

        assert!(big.score < small.score);

        let cmp = compare_reml_fits(vec![small, big]).expect("compare");
        assert_eq!(
            cmp.winner, "small",
            "compare_models must Occam-penalise the pure-noise smooth and pick the smaller model"
        );
        let small_row = cmp
            .score_table
            .iter()
            .find(|r| r.name == "small")
            .expect("small row");
        let big_row = cmp
            .score_table
            .iter()
            .find(|r| r.name == "big")
            .expect("big row");
        assert!((small_row.reml_score - 180.526).abs() < 1e-9);
        assert!((big_row.reml_score - 177.404).abs() < 1e-9);
    }

    /// When the extra smooth buys a large likelihood gain, "big" wins. Its
    /// conditional AIC is ~94.8, against ~751.5 for "small".
    #[test]
    fn compare_models_keeps_power_for_a_relevant_smooth() {
        let small = cand("small", 1025.067, 6.75, -368.985);
        let big = cand("big", 199.509, 14.25, -33.165);
        let cmp = compare_reml_fits(vec![small, big]).expect("compare");
        assert_eq!(
            cmp.winner, "big",
            "compare_models must retain power: the relevant smooth's model must win"
        );
    }
}