1use faer::Side;
45use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
46
47use gam_linalg::faer_ndarray::FaerEigh;
48use gam_linalg::lanczos::{
49 SymmetricLanczosOptions, symmetric_lanczos_eigenpairs, symmetric_lanczos_log_quadrature,
50};
51use gam_linalg::triangular::cholesky_solve_vector;
52use crate::arrow_schur::{ArrowFactorCache, ArrowSchurSystem};
53use crate::priority_selection::{PriorityCandidate, rank_priority_candidates};
54
/// Dimension threshold (1024) tied to the analytic/dense log-determinant path.
/// NOTE(review): the code that consumes this constant is not visible in this
/// part of the file; whether the comparison is inclusive should be confirmed
/// at the use site.
pub const ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD: usize = 1024;
// Probe count (16) for the evidence log-determinant estimate.
// NOTE(review): judging by the name this is the number of random probe vectors
// for stochastic Lanczos quadrature — confirm against the consumer.
const EVIDENCE_LOGDET_SLQ_PROBES: usize = 16;
// Lanczos step count (32) for the evidence log-determinant estimate.
// NOTE(review): presumably the per-probe Krylov depth — confirm.
const EVIDENCE_LOGDET_LANCZOS_STEPS: usize = 32;
// Relative tolerance (1e-8) for the Hessian-vector-product symmetry check.
// NOTE(review): the check itself is outside this view.
const EVIDENCE_HVP_SYMMETRY_REL_TOL: f64 = 1e-8;
// Number of probes (4) used by the Hessian-vector-product symmetry check.
// NOTE(review): the check itself is outside this view.
const EVIDENCE_HVP_SYMMETRY_PROBES: usize = 4;
60
/// Matrix-free handle on a square linear operator, passed to the evidence code
/// as a log-determinant source (see [`EvidenceLogDetSource::Hvp`]).
///
/// Callers in this file build it from a dense `p x p` information matrix by
/// wrapping a plain matrix-vector product in a closure.
#[derive(Clone, Copy)]
pub struct EvidenceHvpLogDet<'a> {
    /// Dimension of the operator. The closures constructed in this file read
    /// `dim` input entries and return a vector of length `dim`.
    pub dim: usize,
    /// Borrowed callback returning the operator applied to the given vector.
    pub apply: &'a dyn Fn(&[f64]) -> Vec<f64>,
}
69
/// Where the log-determinant used by the evidence computation comes from.
///
/// Only the `Hvp` variant is constructed in the visible part of this file; the
/// code that interprets either variant lives elsewhere.
#[derive(Clone, Copy)]
pub enum EvidenceLogDetSource<'a> {
    /// Log-determinant source backed by an arrow-structured factor cache.
    FactoredArrow {
        /// Borrowed factor cache.
        /// NOTE(review): presumably holds an already-computed factorization
        /// from `arrow_schur` — confirm against `ArrowFactorCache`.
        cache: &'a ArrowFactorCache,
        /// Optional matrix-free operator.
        /// NOTE(review): by its name, used when the cached factors cannot
        /// supply the log-determinant — confirm at the consumer.
        fallback_hvp: Option<EvidenceHvpLogDet<'a>>,
    },
    /// Log-determinant source given only through operator-vector products.
    Hvp(EvidenceHvpLogDet<'a>),
}
85
/// Topology families that can be compared against each other.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TopologyKind {
    Periodic,
    Flat,
    Sphere,
    Torus,
}

impl TopologyKind {
    /// Ordinal complexity of the topology.
    ///
    /// The ordering is `Flat` (0) < `Periodic` (1) < `Sphere` (2) < `Torus` (3);
    /// note that it differs from the declaration order of the variants.
    pub fn complexity_rank(self) -> u8 {
        match self {
            Self::Torus => 3,
            Self::Sphere => 2,
            Self::Periodic => 1,
            Self::Flat => 0,
        }
    }
}
119
/// One fitted topology together with the quantities used to rank it.
///
/// NOTE(review): the selection routine that consumes these fields is outside
/// this view; the field notes below are limited to what names and types show.
#[derive(Debug, Clone)]
pub struct TopologyCandidate {
    /// Which topology this candidate represents.
    pub kind: TopologyKind,
    /// Negative log evidence of the fit (presumably lower is better — confirm
    /// against the selection routine).
    pub negative_log_evidence: f64,
    /// Effective dimension of the fit (real-valued).
    pub effective_dim: f64,
    /// Number of observations behind the fit.
    pub n_obs: usize,
    /// Whether the fit reported convergence.
    pub converged: bool,
    /// Free-text reason attached to the candidate, if any.
    /// NOTE(review): presumably `Some` marks the candidate as excluded from
    /// ranking — confirm against the selection routine.
    pub exclusion_reason: Option<String>,
}
142
/// Outcome of a topology selection.
#[derive(Debug, Clone)]
pub struct SelectedTopology {
    /// The topology chosen as the winner.
    pub winner: TopologyKind,
    /// Candidates carried along with the decision.
    /// NOTE(review): presumably ordered best-first — confirm at the producer.
    pub ranking: Vec<TopologyCandidate>,
    /// Tie flag. NOTE(review): presumably set when the leading candidates are
    /// within `TopologySelectOptions::tie_tolerance` — confirm at the producer.
    pub tie: bool,
}
155
/// Tunables for topology selection.
#[derive(Debug, Clone, Copy)]
pub struct TopologySelectOptions {
    /// Tolerance used when deciding whether candidates tie.
    /// NOTE(review): the scale on which it is applied is defined by the
    /// selection routine, which is not visible here.
    pub tie_tolerance: f64,
    /// Normalisation applied to scores before they are compared.
    pub score_scale: TopologyScoreScale,
}
168
/// How a topology score is normalised before comparison.
///
/// NOTE(review): the arithmetic is implemented by the selection routine
/// (outside this view); the variant notes are inferred from the names only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TopologyScoreScale {
    /// Presumably the score divided by the number of observations — confirm.
    PerObservation,
    /// Presumably the score divided by the effective dimension — confirm.
    PerEffectiveDim,
}
177
/// Iteration controls for [`solve_stacking_weights`].
#[derive(Debug, Clone, Copy)]
pub struct StackingConfig {
    /// Upper bound on the number of fixed-point iterations.
    pub max_iter: usize,
    /// Iteration stops once the largest absolute weight change is at most this.
    pub weight_tol: f64,
}

impl Default for StackingConfig {
    /// Defaults: at most 1000 iterations, weight tolerance `1e-10`.
    fn default() -> Self {
        let max_iter = 1000;
        let weight_tol = 1e-10;
        Self {
            max_iter,
            weight_tol,
        }
    }
}
193
/// Result of [`solve_stacking_weights`].
#[derive(Debug, Clone)]
pub struct StackingWeights {
    /// One weight per input candidate column; columns that had no finite
    /// held-out density are reported with weight `0.0`.
    pub weights: Array1<f64>,
    /// Mean over usable rows of the log of the weighted mixture density at the
    /// returned weights (`-inf` when no row could be scored).
    pub mean_log_score: f64,
    /// Number of fixed-point iterations that were started.
    pub iterations: usize,
}
202
203pub fn solve_stacking_weights(
210 log_density: ArrayView2<'_, f64>,
211 config: StackingConfig,
212) -> Result<StackingWeights, String> {
213 let n_obs = log_density.nrows();
214 let n_cand = log_density.ncols();
215 if n_cand == 0 {
216 return Err("stacking requires at least one candidate column".to_string());
217 }
218 if n_obs == 0 {
219 return Err("stacking requires at least one held-out observation row".to_string());
220 }
221
222 let kept_cols: Vec<usize> = (0..n_cand)
223 .filter(|&k| (0..n_obs).any(|i| log_density[[i, k]].is_finite()))
224 .collect();
225 if kept_cols.is_empty() {
226 return Err("stacking found no candidate with any finite held-out density".to_string());
227 }
228 let rows: Vec<usize> = (0..n_obs)
229 .filter(|&i| kept_cols.iter().any(|&k| log_density[[i, k]].is_finite()))
230 .collect();
231 if rows.is_empty() {
232 return Err("stacking found no held-out row with a finite density".to_string());
233 }
234
235 let kept = kept_cols.len();
236 let mut weights = Array1::<f64>::from_elem(kept, 1.0 / kept as f64);
237 let mut next = Array1::<f64>::zeros(kept);
238 let mut iterations = 0usize;
239 for _ in 0..config.max_iter {
240 iterations += 1;
241 next.fill(0.0);
242 let mut active_rows = 0usize;
243 for &row in &rows {
244 let mut row_max = f64::NEG_INFINITY;
245 for (local_col, &source_col) in kept_cols.iter().enumerate() {
246 let log_p = log_density[[row, source_col]];
247 if log_p.is_finite() && weights[local_col] > 0.0 {
248 row_max = row_max.max(weights[local_col].ln() + log_p);
249 }
250 }
251 if !row_max.is_finite() {
252 continue;
253 }
254 let mut denom = 0.0_f64;
255 for (local_col, &source_col) in kept_cols.iter().enumerate() {
256 let log_p = log_density[[row, source_col]];
257 if log_p.is_finite() && weights[local_col] > 0.0 {
258 denom += (weights[local_col].ln() + log_p - row_max).exp();
259 }
260 }
261 if denom <= 0.0 {
262 continue;
263 }
264 active_rows += 1;
265 let log_mix = row_max + denom.ln();
266 for (local_col, &source_col) in kept_cols.iter().enumerate() {
267 let log_p = log_density[[row, source_col]];
268 if log_p.is_finite() && weights[local_col] > 0.0 {
269 next[local_col] += (weights[local_col].ln() + log_p - log_mix).exp();
270 }
271 }
272 }
273 if active_rows == 0 {
274 break;
275 }
276 next.mapv_inplace(|value| value / active_rows as f64);
277 let total = next.sum();
278 if total > 0.0 {
279 next.mapv_inplace(|value| value / total);
280 }
281 let delta = next
282 .iter()
283 .zip(weights.iter())
284 .fold(0.0_f64, |acc, (a, b)| acc.max((a - b).abs()));
285 weights.assign(&next);
286 if delta <= config.weight_tol {
287 break;
288 }
289 }
290
291 let mean_log_score = stacking_mean_log_score(log_density, &rows, &kept_cols, weights.view());
292 let mut full = Array1::<f64>::zeros(n_cand);
293 for (local_col, &source_col) in kept_cols.iter().enumerate() {
294 full[source_col] = weights[local_col];
295 }
296 Ok(StackingWeights {
297 weights: full,
298 mean_log_score,
299 iterations,
300 })
301}
302
303fn stacking_mean_log_score(
304 log_density: ArrayView2<'_, f64>,
305 rows: &[usize],
306 kept_cols: &[usize],
307 weights: ArrayView1<'_, f64>,
308) -> f64 {
309 let mut score_sum = 0.0_f64;
310 let mut counted = 0usize;
311 for &row in rows {
312 let mut row_max = f64::NEG_INFINITY;
313 for (local_col, &source_col) in kept_cols.iter().enumerate() {
314 let log_p = log_density[[row, source_col]];
315 if log_p.is_finite() && weights[local_col] > 0.0 {
316 row_max = row_max.max(weights[local_col].ln() + log_p);
317 }
318 }
319 if !row_max.is_finite() {
320 continue;
321 }
322 let mut denom = 0.0_f64;
323 for (local_col, &source_col) in kept_cols.iter().enumerate() {
324 let log_p = log_density[[row, source_col]];
325 if log_p.is_finite() && weights[local_col] > 0.0 {
326 denom += (weights[local_col].ln() + log_p - row_max).exp();
327 }
328 }
329 if denom > 0.0 {
330 score_sum += row_max + denom.ln();
331 counted += 1;
332 }
333 }
334 if counted == 0 {
335 f64::NEG_INFINITY
336 } else {
337 score_sum / counted as f64
338 }
339}
340
341pub fn stacked_predictive_mean(
343 weights: &Array1<f64>,
344 candidate_means: &[Array1<f64>],
345) -> Result<Array1<f64>, String> {
346 if candidate_means.len() != weights.len() {
347 return Err(format!(
348 "stacked_predictive_mean: {} weights but {} candidate mean vectors",
349 weights.len(),
350 candidate_means.len()
351 ));
352 }
353 let Some(first) = candidate_means.first() else {
354 return Err("stacked_predictive_mean requires at least one candidate".to_string());
355 };
356 let n_rows = first.len();
357 if candidate_means.iter().any(|means| means.len() != n_rows) {
358 return Err(
359 "stacked_predictive_mean: candidate mean vectors disagree on row count".to_string(),
360 );
361 }
362 let mut out = Array1::<f64>::zeros(n_rows);
363 for (weight, means) in weights.iter().zip(candidate_means) {
364 if *weight != 0.0 {
365 out.scaled_add(*weight, means);
366 }
367 }
368 Ok(out)
369}
370
/// Controls for the EM fit in [`fit_gaussian_mixture`].
#[derive(Debug, Clone, Copy)]
pub struct GaussianMixtureConfig {
    /// Maximum number of EM iterations (at least one is always run).
    pub max_iter: usize,
    /// Relative tolerance on the change of the mean log-likelihood.
    pub loglik_tol: f64,
    /// Value added to the diagonal of every estimated covariance.
    pub covariance_floor: f64,
    /// Iteration cap handed to the k-means seeding step.
    pub kmeans_max_iter: usize,
}

impl Default for GaussianMixtureConfig {
    /// Defaults: 200 EM iterations, tolerance `1e-7`, covariance floor `1e-6`,
    /// 25 k-means iterations.
    fn default() -> Self {
        let (max_iter, kmeans_max_iter) = (200, 25);
        let (loglik_tol, covariance_floor) = (1e-7, 1e-6);
        Self {
            max_iter,
            loglik_tol,
            covariance_floor,
            kmeans_max_iter,
        }
    }
}
415
/// A fitted full-covariance Gaussian mixture, as produced by
/// [`fit_gaussian_mixture`].
#[derive(Debug, Clone)]
pub struct GaussianMixtureFit {
    /// Mixing weights, one per component.
    pub weights: Array1<f64>,
    /// Component means; row `j` is the mean of component `j` (shape `k x d`).
    pub means: Array2<f64>,
    /// Component covariance matrices, each `d x d`, indexed by component.
    pub covariances: Vec<Array2<f64>>,
    /// Number of mixture components.
    pub k: usize,
    /// Dimension of the data (number of columns).
    pub d: usize,
    /// Number of rows the mixture was fitted on.
    pub n_obs: usize,
    /// Total (summed over rows) log-likelihood recorded by the fit.
    pub loglik: f64,
    /// Number of EM iterations that were started.
    pub iterations: usize,
}
436
437impl GaussianMixtureFit {
438 pub fn num_free_parameters(&self) -> usize {
443 let cov_per = self.d * (self.d + 1) / 2;
444 (self.k - 1) + self.k * self.d + self.k * cov_per
445 }
446
447 pub fn per_point_log_density(&self, data: ArrayView2<'_, f64>) -> Result<Array1<f64>, String> {
451 if data.ncols() != self.d {
452 return Err(format!(
453 "mixture log-density expects {} columns, got {}",
454 self.d,
455 data.ncols()
456 ));
457 }
458 let n = data.nrows();
459 let mut comp = vec![GaussianComponentEval::new(self.d); self.k];
460 for j in 0..self.k {
461 comp[j] = GaussianComponentEval::factor(self.means.row(j), &self.covariances[j])?;
462 }
463 let mut out = Array1::<f64>::zeros(n);
464 let log_w: Vec<f64> = self
465 .weights
466 .iter()
467 .map(|w| w.max(f64::MIN_POSITIVE).ln())
468 .collect();
469 for i in 0..n {
470 let row = data.row(i);
471 let mut log_terms = vec![f64::NEG_INFINITY; self.k];
472 let mut max_term = f64::NEG_INFINITY;
473 for j in 0..self.k {
474 let lt = log_w[j] + comp[j].log_density(row);
475 log_terms[j] = lt;
476 if lt > max_term {
477 max_term = lt;
478 }
479 }
480 out[i] = log_sum_exp(&log_terms, max_term);
481 }
482 Ok(out)
483 }
484
485 pub fn laplace_negative_log_evidence(&self, data: ArrayView2<'_, f64>) -> Result<f64, String> {
491 let p = self.num_free_parameters();
492 let information = self.empirical_fisher_information(data)?;
493 if information.nrows() != p {
494 return Err(format!(
495 "mixture empirical-Fisher information has dim {} but expected free-parameter count {p}",
496 information.nrows()
497 ));
498 }
499 let apply_info = |x: &[f64]| -> Vec<f64> {
500 let mut out = vec![0.0_f64; p];
501 for r in 0..p {
502 let mut acc = 0.0_f64;
503 for c in 0..p {
504 acc += information[[r, c]] * x[c];
505 }
506 out[r] = acc;
507 }
508 out
509 };
510 let hvp = EvidenceHvpLogDet {
511 dim: p,
512 apply: &apply_info,
513 };
514 let v = laplace_evidence(
515 EvidenceLogDetSource::Hvp(hvp),
516 0.0,
517 -self.loglik,
518 p as f64,
519 0.0,
520 );
521 if !v.is_finite() {
522 return Err("mixture Laplace evidence is not finite".to_string());
523 }
524 Ok(v)
525 }
526
    /// Builds a `p x p` information matrix as the sum over rows of the outer
    /// product of a per-row score vector, then symmetrises it and adds `1.0`
    /// to every diagonal entry.
    ///
    /// The score vector is laid out as
    /// `[k - 1 weight entries | k * d mean entries | k * d(d+1)/2 covariance entries]`,
    /// where the covariance entries of a component enumerate the lower
    /// triangle row by row (`b <= a`).
    ///
    /// # Errors
    /// Returns an error when `data` does not have `d` columns or when a
    /// component covariance cannot be factored.
    fn empirical_fisher_information(
        &self,
        data: ArrayView2<'_, f64>,
    ) -> Result<Array2<f64>, String> {
        if data.ncols() != self.d {
            return Err(format!(
                "mixture information expects {} columns, got {}",
                self.d,
                data.ncols()
            ));
        }
        let n = data.nrows();
        let p = self.num_free_parameters();
        // Number of lower-triangle entries of one d x d covariance.
        let cov_per = self.d * (self.d + 1) / 2;
        let mut comp = Vec::with_capacity(self.k);
        for j in 0..self.k {
            comp.push(GaussianComponentEval::factor(
                self.means.row(j),
                &self.covariances[j],
            )?);
        }
        // Clamp before the log so a zero weight yields a finite value.
        let log_w: Vec<f64> = self
            .weights
            .iter()
            .map(|w| w.max(f64::MIN_POSITIVE).ln())
            .collect();

        // Offsets of the mean block and the covariance block in the score.
        let mean_base = self.k - 1;
        let cov_base = mean_base + self.k * self.d;

        let mut info = Array2::<f64>::zeros((p, p));
        let mut score = vec![0.0_f64; p];
        for i in 0..n {
            let row = data.row(i);
            let mut log_terms = vec![0.0_f64; self.k];
            let mut max_term = f64::NEG_INFINITY;
            for j in 0..self.k {
                let lt = log_w[j] + comp[j].log_density(row);
                log_terms[j] = lt;
                if lt > max_term {
                    max_term = lt;
                }
            }
            let log_mix = log_sum_exp(&log_terms, max_term);
            // Responsibilities of each component for this row.
            let resp: Vec<f64> = log_terms.iter().map(|lt| (lt - log_mix).exp()).collect();

            for s in score.iter_mut() {
                *s = 0.0;
            }
            // Weight block: component 0 has no entry of its own.
            // NOTE(review): `resp - weight` matches the gradient under a softmax
            // parametrisation with component 0 as reference — confirm intent.
            for j in 1..self.k {
                score[j - 1] = resp[j] - self.weights[j];
            }
            for j in 0..self.k {
                // precision * (y - mean) for component j.
                let prec_v = comp[j].precision_times_residual(row);
                let mbo = mean_base + j * self.d;
                for c in 0..self.d {
                    score[mbo + c] = resp[j] * prec_v[c];
                }
                let cbo = cov_base + j * cov_per;
                let mut idx = 0usize;
                for a in 0..self.d {
                    for b in 0..=a {
                        let outer = prec_v[a] * prec_v[b];
                        let prec_ab = comp[j].precision[[a, b]];
                        let mut g = 0.5 * (outer - prec_ab);
                        // Off-diagonal entries appear twice in the symmetric
                        // matrix, so their contribution is doubled.
                        if a != b {
                            g *= 2.0;
                        }
                        score[cbo + idx] = resp[j] * g;
                        idx += 1;
                    }
                }
            }
            // Accumulate the outer product score * score^T, skipping rows of
            // the product that are exactly zero.
            for r in 0..p {
                let sr = score[r];
                if sr == 0.0 {
                    continue;
                }
                for c in 0..p {
                    info[[r, c]] += sr * score[c];
                }
            }
        }
        for r in 0..p {
            // Average mirrored entries so the result is exactly symmetric.
            for c in (r + 1)..p {
                let avg = 0.5 * (info[[r, c]] + info[[c, r]]);
                info[[r, c]] = avg;
                info[[c, r]] = avg;
            }
            // NOTE(review): unit diagonal shift; presumably a unit prior
            // precision / ridge keeping the matrix positive definite — confirm.
            info[[r, r]] += 1.0;
        }
        Ok(info)
    }
646}
647
/// Pre-factored form of one Gaussian component, ready for repeated density
/// evaluation.
#[derive(Debug, Clone)]
struct GaussianComponentEval {
    /// Component mean (length `d`).
    mean: Array1<f64>,
    /// Inverse of the component covariance (`d x d`).
    precision: Array2<f64>,
    /// Log normalising constant, `-0.5 * (d * ln(2*pi) + ln det(cov))`.
    log_norm: f64,
    /// Dimension of the component.
    d: usize,
}
657
658impl GaussianComponentEval {
659 fn new(d: usize) -> Self {
660 Self {
661 mean: Array1::zeros(d),
662 precision: Array2::eye(d),
663 log_norm: 0.0,
664 d,
665 }
666 }
667
668 fn factor(mean: ArrayView1<'_, f64>, cov: &Array2<f64>) -> Result<Self, String> {
669 let d = mean.len();
670 if cov.nrows() != d || cov.ncols() != d {
671 return Err(format!(
672 "mixture component covariance must be {d}x{d}, got {}x{}",
673 cov.nrows(),
674 cov.ncols()
675 ));
676 }
677 let (evals, evecs) = cov
678 .eigh(Side::Lower)
679 .map_err(|e| format!("mixture component covariance eigendecomposition failed: {e}"))?;
680 let mut log_det = 0.0_f64;
681 let mut inv_evals = Array1::<f64>::zeros(d);
682 for (idx, &ev) in evals.iter().enumerate() {
683 if !ev.is_finite() || ev <= 0.0 {
684 return Err(format!(
685 "mixture component covariance is not SPD: eigenvalue {idx} is {ev:.3e}"
686 ));
687 }
688 log_det += ev.ln();
689 inv_evals[idx] = 1.0 / ev;
690 }
691 let mut precision = Array2::<f64>::zeros((d, d));
693 for a in 0..d {
694 for b in 0..d {
695 let mut acc = 0.0_f64;
696 for m in 0..d {
697 acc += evecs[[a, m]] * inv_evals[m] * evecs[[b, m]];
698 }
699 precision[[a, b]] = acc;
700 }
701 }
702 let log_norm = -0.5 * (d as f64 * (2.0 * std::f64::consts::PI).ln() + log_det);
703 Ok(Self {
704 mean: mean.to_owned(),
705 precision,
706 log_norm,
707 d,
708 })
709 }
710
711 #[inline]
712 fn log_density(&self, y: ArrayView1<'_, f64>) -> f64 {
713 let pv = self.precision_times_residual(y);
714 let mut quad = 0.0_f64;
715 for c in 0..self.d {
716 quad += (y[c] - self.mean[c]) * pv[c];
717 }
718 self.log_norm - 0.5 * quad
719 }
720
721 #[inline]
723 fn precision_times_residual(&self, y: ArrayView1<'_, f64>) -> Vec<f64> {
724 let mut out = vec![0.0_f64; self.d];
725 for a in 0..self.d {
726 let mut acc = 0.0_f64;
727 for b in 0..self.d {
728 acc += self.precision[[a, b]] * (y[b] - self.mean[b]);
729 }
730 out[a] = acc;
731 }
732 out
733 }
734}
735
/// Max-shifted log-sum-exp: `max_term + ln(sum_i exp(terms[i] - max_term))`.
///
/// `max_term` is supplied by the caller (normally the largest entry of
/// `terms`). When it is not finite the result is `-inf`.
#[inline]
fn log_sum_exp(terms: &[f64], max_term: f64) -> f64 {
    if max_term.is_finite() {
        let shifted_sum = terms
            .iter()
            .fold(0.0_f64, |acc, &term| acc + (term - max_term).exp());
        max_term + shifted_sum.ln()
    } else {
        f64::NEG_INFINITY
    }
}
747
748pub fn fit_gaussian_mixture(
757 data: ArrayView2<'_, f64>,
758 k: usize,
759 config: GaussianMixtureConfig,
760) -> Result<GaussianMixtureFit, String> {
761 let n = data.nrows();
762 let d = data.ncols();
763 if k == 0 {
764 return Err("gaussian mixture requires k >= 1".to_string());
765 }
766 if d == 0 {
767 return Err("gaussian mixture requires at least one column".to_string());
768 }
769 if k > n {
770 return Err(format!(
771 "gaussian mixture requested {k} components but data has {n} rows"
772 ));
773 }
774 let centers = gam_terms::basis::select_centers_by_strategy(
777 data,
778 &gam_terms::basis::CenterStrategy::KMeans {
779 num_centers: k,
780 max_iter: config.kmeans_max_iter,
781 },
782 )
783 .map_err(|e| format!("gaussian mixture k-means seeding failed: {e}"))?;
784 if centers.nrows() != k || centers.ncols() != d {
785 return Err(format!(
786 "gaussian mixture seeding returned {}x{} centers, expected {k}x{d}",
787 centers.nrows(),
788 centers.ncols()
789 ));
790 }
791
792 let mut means = centers;
793 let global_cov = data_covariance(data, config.covariance_floor);
795 let mut covariances = vec![global_cov; k];
796 let mut weights = Array1::<f64>::from_elem(k, 1.0 / k as f64);
797
798 let mut resp = Array2::<f64>::zeros((n, k));
799 let mut prev_mean_ll = f64::NEG_INFINITY;
800 let mut total_loglik = f64::NEG_INFINITY;
801 let mut iterations = 0usize;
802
803 for iter in 0..config.max_iter.max(1) {
804 iterations = iter + 1;
805 let mut comp = Vec::with_capacity(k);
807 for j in 0..k {
808 comp.push(GaussianComponentEval::factor(
809 means.row(j),
810 &covariances[j],
811 )?);
812 }
813 let log_w: Vec<f64> = weights
814 .iter()
815 .map(|w| w.max(f64::MIN_POSITIVE).ln())
816 .collect();
817 total_loglik = 0.0;
818 for i in 0..n {
819 let yrow = data.row(i);
820 let mut log_terms = vec![0.0_f64; k];
821 let mut max_term = f64::NEG_INFINITY;
822 for j in 0..k {
823 let lt = log_w[j] + comp[j].log_density(yrow);
824 log_terms[j] = lt;
825 if lt > max_term {
826 max_term = lt;
827 }
828 }
829 let log_mix = log_sum_exp(&log_terms, max_term);
830 total_loglik += log_mix;
831 for j in 0..k {
832 resp[[i, j]] = (log_terms[j] - log_mix).exp();
833 }
834 }
835 let mean_ll = total_loglik / n as f64;
836 if iter > 0 {
837 let denom = prev_mean_ll.abs().max(1.0);
838 if (mean_ll - prev_mean_ll).abs() / denom <= config.loglik_tol {
839 break;
840 }
841 }
842 prev_mean_ll = mean_ll;
843
844 let mut nk = vec![0.0_f64; k];
846 for j in 0..k {
847 let mut sum = 0.0_f64;
848 for i in 0..n {
849 sum += resp[[i, j]];
850 }
851 nk[j] = sum.max(f64::MIN_POSITIVE);
852 }
853 for j in 0..k {
854 weights[j] = nk[j] / n as f64;
855 let mut mu = Array1::<f64>::zeros(d);
857 for i in 0..n {
858 let r = resp[[i, j]];
859 if r == 0.0 {
860 continue;
861 }
862 for c in 0..d {
863 mu[c] += r * data[[i, c]];
864 }
865 }
866 mu.mapv_inplace(|v| v / nk[j]);
867 for c in 0..d {
868 means[[j, c]] = mu[c];
869 }
870 let mut cov = Array2::<f64>::zeros((d, d));
872 for i in 0..n {
873 let r = resp[[i, j]];
874 if r == 0.0 {
875 continue;
876 }
877 for a in 0..d {
878 let da = data[[i, a]] - mu[a];
879 for b in 0..d {
880 cov[[a, b]] += r * da * (data[[i, b]] - mu[b]);
881 }
882 }
883 }
884 cov.mapv_inplace(|v| v / nk[j]);
885 for a in 0..d {
886 cov[[a, a]] += config.covariance_floor;
887 }
888 covariances[j] = cov;
889 }
890 }
891
892 Ok(GaussianMixtureFit {
893 weights,
894 means,
895 covariances,
896 k,
897 d,
898 n_obs: n,
899 loglik: total_loglik,
900 iterations,
901 })
902}
903
904fn data_covariance(data: ArrayView2<'_, f64>, floor: f64) -> Array2<f64> {
906 let n = data.nrows();
907 let d = data.ncols();
908 let mut mean = Array1::<f64>::zeros(d);
909 for i in 0..n {
910 for c in 0..d {
911 mean[c] += data[[i, c]];
912 }
913 }
914 mean.mapv_inplace(|v| v / n.max(1) as f64);
915 let mut cov = Array2::<f64>::zeros((d, d));
916 for i in 0..n {
917 for a in 0..d {
918 let da = data[[i, a]] - mean[a];
919 for b in 0..d {
920 cov[[a, b]] += da * (data[[i, b]] - mean[b]);
921 }
922 }
923 }
924 let inv = 1.0 / (n.max(1) as f64);
925 cov.mapv_inplace(|v| v * inv);
926 for a in 0..d {
927 cov[[a, a]] += floor;
928 }
929 cov
930}
931
/// Composite ("union") structures made of two primitive components.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnionStructure {
    /// Two circle components.
    CircleCircle,
    /// One circle component and one point-cluster component.
    CirclePointCluster,
    /// One line component and one point-cluster component.
    LineCluster,
}
969
/// Every union structure, in the order in which `fit_union_ladder` attempts
/// to fit them.
pub const UNION_STRUCTURE_LADDER: &[UnionStructure] = &[
    UnionStructure::CircleCircle,
    UnionStructure::CirclePointCluster,
    UnionStructure::LineCluster,
];
976
/// Primitive component types a union structure is assembled from.
///
/// In this file `Line` and `PointCluster` are both fitted as a single
/// Gaussian, while `Circle` uses a dedicated radial model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnionComponentKind {
    Circle,
    Line,
    PointCluster,
}
988
989impl UnionStructure {
990 pub const fn as_str(self) -> &'static str {
992 match self {
993 UnionStructure::CircleCircle => "union_circle+circle",
994 UnionStructure::CirclePointCluster => "union_circle+cluster",
995 UnionStructure::LineCluster => "union_line+cluster",
996 }
997 }
998
999 pub const fn components(self) -> &'static [UnionComponentKind] {
1001 match self {
1002 UnionStructure::CircleCircle => {
1003 &[UnionComponentKind::Circle, UnionComponentKind::Circle]
1004 }
1005 UnionStructure::CirclePointCluster => {
1006 &[UnionComponentKind::Circle, UnionComponentKind::PointCluster]
1007 }
1008 UnionStructure::LineCluster => {
1009 &[UnionComponentKind::Line, UnionComponentKind::PointCluster]
1010 }
1011 }
1012 }
1013
1014 pub const fn num_components(self) -> usize {
1016 self.components().len()
1017 }
1018}
1019
/// Fit summary for one component of a union structure.
#[derive(Debug, Clone)]
pub struct UnionComponentFit {
    /// Kind of component that was fitted.
    pub kind: UnionComponentKind,
    /// Number of data rows assigned to this component.
    pub row_count: usize,
    /// Number of parameters of the component model.
    pub num_parameters: usize,
    /// Negative log evidence of the component on its rows.
    pub negative_log_evidence: f64,
}
1030
/// Fit summary for a whole union structure.
#[derive(Debug, Clone)]
pub struct UnionStructureFit {
    /// The structure that was fitted.
    pub structure: UnionStructure,
    /// Per-component summaries, in the order of `structure.components()`.
    pub components: Vec<UnionComponentFit>,
    /// Sum of the component negative log evidences.
    pub negative_log_evidence: f64,
    /// Sum of the component parameter counts.
    pub total_parameters: usize,
}
1043
1044pub fn union_responsibility_split(
1049 data: ArrayView2<'_, f64>,
1050 m: usize,
1051 config: GaussianMixtureConfig,
1052) -> Result<Vec<Vec<usize>>, String> {
1053 let n = data.nrows();
1054 if m == 0 {
1055 return Err("union split requires at least one component".to_string());
1056 }
1057 if m > n {
1058 return Err(format!(
1059 "union split requested {m} groups but data has {n} rows"
1060 ));
1061 }
1062 if m == 1 {
1063 return Ok(vec![(0..n).collect()]);
1064 }
1065 let fit = fit_gaussian_mixture(data, m, config)?;
1066 let mut groups: Vec<Vec<usize>> = vec![Vec::new(); m];
1067 let mut comp = Vec::with_capacity(m);
1069 for j in 0..m {
1070 comp.push(GaussianComponentEval::factor(
1071 fit.means.row(j),
1072 &fit.covariances[j],
1073 )?);
1074 }
1075 let log_w: Vec<f64> = fit
1076 .weights
1077 .iter()
1078 .map(|w| w.max(f64::MIN_POSITIVE).ln())
1079 .collect();
1080 for i in 0..n {
1081 let row = data.row(i);
1082 let mut best_j = 0usize;
1083 let mut best_lt = f64::NEG_INFINITY;
1084 for j in 0..m {
1085 let lt = log_w[j] + comp[j].log_density(row);
1086 if lt > best_lt {
1087 best_lt = lt;
1088 best_j = j;
1089 }
1090 }
1091 groups[best_j].push(i);
1092 }
1093 Ok(groups)
1094}
1095
1096pub fn fit_union_structure(
1105 data: ArrayView2<'_, f64>,
1106 structure: UnionStructure,
1107 config: GaussianMixtureConfig,
1108) -> Result<UnionStructureFit, String> {
1109 let comps = structure.components();
1110 let m = comps.len();
1111 let groups = union_responsibility_split(data, m, config)?;
1112 let mut fits = Vec::with_capacity(m);
1113 let mut total_nle = 0.0_f64;
1114 let mut total_parameters = 0usize;
1115 for (kind, rows) in comps.iter().zip(groups.iter()) {
1116 let group = gather_union_rows(data, rows);
1117 let (nle, p) = fit_union_component(group.view(), *kind, config)?;
1118 if !nle.is_finite() {
1119 return Err(format!(
1120 "union {} component {:?} produced non-finite evidence",
1121 structure.as_str(),
1122 kind
1123 ));
1124 }
1125 total_nle += nle;
1126 total_parameters += p;
1127 fits.push(UnionComponentFit {
1128 kind: *kind,
1129 row_count: rows.len(),
1130 num_parameters: p,
1131 negative_log_evidence: nle,
1132 });
1133 }
1134 Ok(UnionStructureFit {
1135 structure,
1136 components: fits,
1137 negative_log_evidence: total_nle,
1138 total_parameters,
1139 })
1140}
1141
1142pub fn fit_union_ladder(
1147 data: ArrayView2<'_, f64>,
1148 config: GaussianMixtureConfig,
1149) -> Result<Vec<UnionStructureFit>, String> {
1150 let mut fits = Vec::new();
1151 let mut errors = Vec::new();
1152 for &structure in UNION_STRUCTURE_LADDER {
1153 match fit_union_structure(data, structure, config) {
1154 Ok(fit) => fits.push(fit),
1155 Err(e) => errors.push(format!("{}: {e}", structure.as_str())),
1156 }
1157 }
1158 if fits.is_empty() {
1159 return Err(format!(
1160 "union ladder produced no fittable composites{}",
1161 if errors.is_empty() {
1162 String::new()
1163 } else {
1164 format!(" ({})", errors.join("; "))
1165 }
1166 ));
1167 }
1168 let ranked = rank_priority_candidates(
1169 fits.into_iter()
1170 .enumerate()
1171 .map(|(idx, row)| {
1172 let score = row.negative_log_evidence;
1173 let tie = row.total_parameters; PriorityCandidate::new(row, idx, score, tie)
1175 })
1176 .collect(),
1177 )
1178 .into_iter()
1179 .map(|row| row.item)
1180 .collect::<Vec<_>>();
1181 Ok(ranked)
1182}
1183
1184fn gather_union_rows(data: ArrayView2<'_, f64>, idx: &[usize]) -> Array2<f64> {
1185 let d = data.ncols();
1186 let mut out = Array2::<f64>::zeros((idx.len(), d));
1187 for (r, &i) in idx.iter().enumerate() {
1188 for c in 0..d {
1189 out[[r, c]] = data[[i, c]];
1190 }
1191 }
1192 out
1193}
1194
1195fn fit_union_component(
1200 group: ArrayView2<'_, f64>,
1201 kind: UnionComponentKind,
1202 config: GaussianMixtureConfig,
1203) -> Result<(f64, usize), String> {
1204 match kind {
1205 UnionComponentKind::Line | UnionComponentKind::PointCluster => {
1206 if group.nrows() < group.ncols() + 1 {
1210 return Err(format!(
1211 "union gaussian component needs >= {} rows, got {}",
1212 group.ncols() + 1,
1213 group.nrows()
1214 ));
1215 }
1216 let fit = fit_gaussian_mixture(group, 1, config)?;
1217 let nle = fit.laplace_negative_log_evidence(group)?;
1218 Ok((nle, fit.num_free_parameters()))
1219 }
1220 UnionComponentKind::Circle => fit_circle_component_evidence(group, config),
1221 }
1222}
1223
/// Fits a circle model to 2-D rows and returns
/// `(negative log evidence, parameter count)`.
///
/// The centre is the centroid of the rows, the radius is the mean distance to
/// the centre, and the radial variance is the mean squared radial residual
/// (floored at `config.covariance_floor`). The evidence is obtained from
/// `laplace_evidence` using an outer-product-of-scores information matrix over
/// the four parameters (centre x, centre y, radius, variance).
///
/// # Errors
/// Returns an error when the data is not 2-D, when fewer than five rows are
/// available, or when the resulting evidence is not finite.
fn fit_circle_component_evidence(
    group: ArrayView2<'_, f64>,
    config: GaussianMixtureConfig,
) -> Result<(f64, usize), String> {
    let d = group.ncols();
    if d != 2 {
        return Err(format!(
            "union circle component requires 2-D data, got {d} columns"
        ));
    }
    let n = group.nrows();
    // Four parameters: centre x, centre y, radius, radial variance.
    let p = 4usize;
    if n < p + 1 {
        return Err(format!(
            "union circle component needs >= {} rows, got {n}",
            p + 1
        ));
    }
    // Centre estimate: centroid of the rows.
    let mut cx = 0.0_f64;
    let mut cy = 0.0_f64;
    for i in 0..n {
        cx += group[[i, 0]];
        cy += group[[i, 1]];
    }
    cx /= n as f64;
    cy /= n as f64;
    // Radius estimate: mean distance from the centre.
    let mut radii = vec![0.0_f64; n];
    let mut radius = 0.0_f64;
    for i in 0..n {
        let dx = group[[i, 0]] - cx;
        let dy = group[[i, 1]] - cy;
        let r = (dx * dx + dy * dy).sqrt();
        radii[i] = r;
        radius += r;
    }
    radius /= n as f64;
    // Radial variance, floored so the inverse below stays finite.
    let mut var_r = 0.0_f64;
    for &r in &radii {
        let e = r - radius;
        var_r += e * e;
    }
    var_r = (var_r / n as f64).max(config.covariance_floor);
    let inv_var = 1.0 / var_r;
    let mut loglik = 0.0_f64;
    let log_2pi = (2.0 * std::f64::consts::PI).ln();
    for &r in &radii {
        let e = r - radius;
        // Gaussian log density of the radial residual.
        let radial = -0.5 * (log_2pi + var_r.ln()) - 0.5 * e * e * inv_var;
        // Equals -ln(2*pi*r); r is clamped so the log stays finite at r = 0.
        // NOTE(review): presumably a uniform angle expressed per unit area in
        // Cartesian coordinates — confirm the intended density.
        let angular = -(log_2pi + r.max(f64::MIN_POSITIVE).ln());
        loglik += radial + angular;
    }
    let mut info = Array2::<f64>::zeros((p, p));
    let mut score = [0.0_f64; 4];
    for i in 0..n {
        let dx = group[[i, 0]] - cx;
        let dy = group[[i, 1]] - cy;
        let r = radii[i].max(f64::MIN_POSITIVE);
        let e = radii[i] - radius;
        let ee = e * inv_var;
        // Per-row score entries: centre x, centre y, radius, variance.
        // Only the radial term contributes; the angular term is not
        // differentiated here.
        score[0] = ee * (-dx / r);
        score[1] = ee * (-dy / r);
        score[2] = ee;
        score[3] = -0.5 + 0.5 * e * e * inv_var;
        // Accumulate the outer product score * score^T.
        for a in 0..p {
            let sa = score[a];
            if sa == 0.0 {
                continue;
            }
            for b in 0..p {
                info[[a, b]] += sa * score[b];
            }
        }
    }
    for a in 0..p {
        // Average mirrored entries so the matrix is exactly symmetric.
        for b in (a + 1)..p {
            let avg = 0.5 * (info[[a, b]] + info[[b, a]]);
            info[[a, b]] = avg;
            info[[b, a]] = avg;
        }
        // NOTE(review): unit diagonal shift; presumably a unit prior
        // precision / ridge — confirm.
        info[[a, a]] += 1.0;
    }
    // Dense matrix-vector product exposed through the operator interface.
    let apply_info = |x: &[f64]| -> Vec<f64> {
        let mut out = vec![0.0_f64; p];
        for r in 0..p {
            let mut acc = 0.0_f64;
            for c in 0..p {
                acc += info[[r, c]] * x[c];
            }
            out[r] = acc;
        }
        out
    };
    let hvp = EvidenceHvpLogDet {
        dim: p,
        apply: &apply_info,
    };
    let v = laplace_evidence(EvidenceLogDetSource::Hvp(hvp), 0.0, -loglik, p as f64, 0.0);
    if !v.is_finite() {
        return Err("union circle component Laplace evidence is not finite".to_string());
    }
    Ok((v, p))
}
1347
/// Fitted density of one union component, carrying the log of its mixing
/// weight so it can be combined with the other components.
#[derive(Debug, Clone)]
enum UnionComponentDensity {
    /// Single Gaussian component.
    Gaussian {
        /// Natural log of the component's mixing weight.
        log_weight: f64,
        /// Pre-factored Gaussian used for density evaluation.
        eval: GaussianComponentEval,
    },
    /// Circle component in 2-D.
    Circle {
        /// Natural log of the component's mixing weight.
        log_weight: f64,
        /// Circle centre `[x, y]`.
        center: [f64; 2],
        /// Circle radius.
        radius: f64,
        /// Variance of the radial residual.
        var_r: f64,
    },
}
1366
1367impl UnionComponentDensity {
1368 fn weighted_log_density(&self, y: ArrayView1<'_, f64>) -> f64 {
1370 match self {
1371 UnionComponentDensity::Gaussian { log_weight, eval } => {
1372 log_weight + eval.log_density(y)
1373 }
1374 UnionComponentDensity::Circle {
1375 log_weight,
1376 center,
1377 radius,
1378 var_r,
1379 } => {
1380 let dx = y[0] - center[0];
1381 let dy = y[1] - center[1];
1382 let r = (dx * dx + dy * dy).sqrt();
1383 let log_2pi = (2.0 * std::f64::consts::PI).ln();
1384 let e = r - radius;
1385 let radial = -0.5 * (log_2pi + var_r.ln()) - 0.5 * e * e / var_r;
1386 let angular = -(log_2pi + r.max(f64::MIN_POSITIVE).ln());
1387 log_weight + radial + angular
1388 }
1389 }
1390 }
1391}
1392
1393fn fit_union_component_densities(
1397 train: ArrayView2<'_, f64>,
1398 structure: UnionStructure,
1399 config: GaussianMixtureConfig,
1400) -> Result<Vec<UnionComponentDensity>, String> {
1401 let comps = structure.components();
1402 let m = comps.len();
1403 let groups = union_responsibility_split(train, m, config)?;
1404 let n_train = train.nrows().max(1) as f64;
1405 let mut out = Vec::with_capacity(m);
1406 for (kind, rows) in comps.iter().zip(groups.iter()) {
1407 if rows.is_empty() {
1408 return Err(format!(
1409 "union {} held-out density: empty component group",
1410 structure.as_str()
1411 ));
1412 }
1413 let log_weight = (rows.len() as f64 / n_train).max(f64::MIN_POSITIVE).ln();
1414 let group = gather_union_rows(train, rows);
1415 match kind {
1416 UnionComponentKind::Line | UnionComponentKind::PointCluster => {
1417 if group.nrows() < group.ncols() + 1 {
1418 return Err(format!(
1419 "union gaussian component density needs >= {} rows, got {}",
1420 group.ncols() + 1,
1421 group.nrows()
1422 ));
1423 }
1424 let fit = fit_gaussian_mixture(group.view(), 1, config)?;
1425 let eval = GaussianComponentEval::factor(fit.means.row(0), &fit.covariances[0])?;
1426 out.push(UnionComponentDensity::Gaussian { log_weight, eval });
1427 }
1428 UnionComponentKind::Circle => {
1429 let d = group.ncols();
1430 if d != 2 {
1431 return Err(format!(
1432 "union circle component density requires 2-D data, got {d} columns"
1433 ));
1434 }
1435 let n = group.nrows();
1436 if n < 5 {
1437 return Err(format!(
1438 "union circle component density needs >= 5 rows, got {n}"
1439 ));
1440 }
1441 let mut cx = 0.0_f64;
1442 let mut cy = 0.0_f64;
1443 for i in 0..n {
1444 cx += group[[i, 0]];
1445 cy += group[[i, 1]];
1446 }
1447 cx /= n as f64;
1448 cy /= n as f64;
1449 let mut radius = 0.0_f64;
1450 let mut radii = vec![0.0_f64; n];
1451 for i in 0..n {
1452 let dx = group[[i, 0]] - cx;
1453 let dy = group[[i, 1]] - cy;
1454 let r = (dx * dx + dy * dy).sqrt();
1455 radii[i] = r;
1456 radius += r;
1457 }
1458 radius /= n as f64;
1459 let mut var_r = 0.0_f64;
1460 for &r in &radii {
1461 let e = r - radius;
1462 var_r += e * e;
1463 }
1464 var_r = (var_r / n as f64).max(config.covariance_floor);
1465 out.push(UnionComponentDensity::Circle {
1466 log_weight,
1467 center: [cx, cy],
1468 radius,
1469 var_r,
1470 });
1471 }
1472 }
1473 }
1474 Ok(out)
1475}
1476
1477pub fn union_per_point_log_density(
1482 train: ArrayView2<'_, f64>,
1483 eval: ArrayView2<'_, f64>,
1484 structure: UnionStructure,
1485 config: GaussianMixtureConfig,
1486) -> Result<Array1<f64>, String> {
1487 if train.ncols() != eval.ncols() {
1488 return Err(format!(
1489 "union held-out density: train has {} columns, eval has {}",
1490 train.ncols(),
1491 eval.ncols()
1492 ));
1493 }
1494 let densities = fit_union_component_densities(train, structure, config)?;
1495 let mut out = Array1::<f64>::zeros(eval.nrows());
1496 let mut terms = vec![f64::NEG_INFINITY; densities.len()];
1497 for i in 0..eval.nrows() {
1498 let row = eval.row(i);
1499 let mut max_term = f64::NEG_INFINITY;
1500 for (c, dens) in densities.iter().enumerate() {
1501 let lt = dens.weighted_log_density(row);
1502 terms[c] = lt;
1503 if lt > max_term {
1504 max_term = lt;
1505 }
1506 }
1507 out[i] = log_sum_exp(&terms, max_term);
1508 }
1509 Ok(out)
1510}
1511
/// One fitted model entering a REML/LAML comparison.
#[derive(Clone, Debug)]
pub struct RemlCandidate {
    /// Caller-supplied index of the fit.
    pub index: usize,
    /// Display name; reported as the winner / row label.
    pub name: String,
    /// Raw score of the fit; used for ranking only when the AIC-style score
    /// cannot be formed (see [`RemlCandidate::ranking_score`]).
    pub score: f64,
    /// Effective degrees of freedom, when known.
    pub edf: Option<f64>,
    /// Log-likelihood of the fit, when known.
    pub log_lik: Option<f64>,
    /// Response family label, when known.
    pub family: Option<String>,
    /// Number of observations the fit was made on, when known.
    pub n_obs: Option<usize>,
}

impl RemlCandidate {
    /// Score used to order candidates (lower is better).
    ///
    /// When both `log_lik` and `edf` are present and finite this is
    /// `-2 * log_lik + 2 * edf`; otherwise it falls back to the raw `score`.
    pub fn ranking_score(&self) -> f64 {
        self.log_lik
            .zip(self.edf)
            .filter(|(log_lik, edf)| log_lik.is_finite() && edf.is_finite())
            .map_or(self.score, |(log_lik, edf)| -2.0 * log_lik + 2.0 * edf)
    }
}
1569
/// Result of comparing several fits with `compare_reml_fits`.
#[derive(Clone, Debug)]
pub struct RemlComparison {
    /// One row per candidate, in ranked order (best first).
    pub ranking: Vec<RankedRow>,
    /// Name of the top-ranked candidate.
    pub winner: String,
    /// Human-readable one-line summary of the winner and its margin over the
    /// runner-up (or a note that only a single fit was supplied).
    pub evidence_summary: String,
    /// One row per candidate, in the same order as `ranking`, reporting the
    /// raw scores and their differences from the smallest raw score.
    pub score_table: Vec<ScoreRow>,
}
1577
/// One candidate's entry in the ranked comparison output.
#[derive(Clone, Debug)]
pub struct RankedRow {
    /// Candidate name.
    pub name: String,
    /// The candidate's raw score (not the ranking score).
    pub score: f64,
    /// Ranking score of this candidate minus the ranking score of the winner.
    pub delta: f64,
    /// `exp(delta)`.
    pub bayes_factor: f64,
    /// Effective degrees of freedom copied from the candidate, when known.
    pub edf: Option<f64>,
}
1594
/// One candidate's entry in the raw-score table of a comparison.
#[derive(Clone, Debug)]
pub struct ScoreRow {
    /// Candidate name.
    pub name: String,
    /// The candidate's raw score.
    pub reml_score: f64,
    /// Raw score of this candidate minus the smallest raw score among all
    /// candidates.
    pub delta_reml: f64,
    /// `exp(delta_reml)`.
    pub bayes_factor_best_over_model: f64,
    /// Effective degrees of freedom copied from the candidate, when known.
    pub effective_dof: Option<f64>,
}
1603
/// Log Bayes factor of model `a` over model `b` from two scores on a
/// lower-is-better scale: the score of `b` minus the score of `a`.
#[inline]
pub fn log_bayes_factor(reml_score_a: f64, reml_score_b: f64) -> f64 {
    let advantage_of_a = reml_score_b - reml_score_a;
    advantage_of_a
}
1609
1610pub fn compare_reml_fits(mut candidates: Vec<RemlCandidate>) -> Result<RemlComparison, String> {
1614 if candidates.is_empty() {
1615 return Err("compare_models requires at least one fit".to_string());
1616 }
1617 {
1625 let mut seen_family: Option<&str> = None;
1626 for cand in &candidates {
1627 if let Some(fam) = cand.family.as_deref() {
1628 match seen_family {
1629 None => seen_family = Some(fam),
1630 Some(prev) if prev != fam => {
1631 return Err(format!(
1632 "compare_models: cannot compare fits of different response families ('{prev}' vs '{fam}'); their REML/LAML evidence scores are on incomparable base measures. Compare models fit to the same response under the same family."
1633 ));
1634 }
1635 Some(_) => {}
1636 }
1637 }
1638 }
1639 }
1640 {
1651 let mut seen_n: Option<usize> = None;
1652 for cand in &candidates {
1653 if let Some(n) = cand.n_obs {
1654 match seen_n {
1655 None => seen_n = Some(n),
1656 Some(prev) if prev != n => {
1657 return Err(format!(
1658 "compare_models: cannot compare fits made on a different number of \
1659 observations (n={prev} vs n={n}); AIC / REML-LAML evidence scales \
1660 with the sample size, so their score difference is not a Bayes \
1661 factor. Compare models fit to the same response on the same data."
1662 ));
1663 }
1664 Some(_) => {}
1665 }
1666 }
1667 }
1668 }
1669 candidates = rank_priority_candidates(
1670 candidates
1671 .into_iter()
1672 .enumerate()
1673 .map(|(idx, row)| {
1674 let ranking = row.ranking_score();
1677 PriorityCandidate::new(row, idx, ranking, 0)
1678 })
1679 .collect(),
1680 )
1681 .into_iter()
1682 .map(|row| row.item)
1683 .collect();
1684
1685 let winner = candidates[0].name.clone();
1686 let best_ranking_score = candidates[0].ranking_score();
1696 let best_raw_score = candidates
1701 .iter()
1702 .map(|c| c.score)
1703 .fold(f64::INFINITY, f64::min);
1704 let mut ranking = Vec::with_capacity(candidates.len());
1705 let mut score_table = Vec::with_capacity(candidates.len());
1706 for row in &candidates {
1707 let delta = log_bayes_factor(best_ranking_score, row.ranking_score());
1708 let bayes_factor = delta.exp();
1709 let delta_reml = log_bayes_factor(best_raw_score, row.score);
1710 ranking.push(RankedRow {
1711 name: row.name.clone(),
1712 score: row.score,
1713 delta,
1714 bayes_factor,
1715 edf: row.edf,
1716 });
1717 score_table.push(ScoreRow {
1718 name: row.name.clone(),
1719 reml_score: row.score,
1720 delta_reml,
1721 bayes_factor_best_over_model: delta_reml.exp(),
1722 effective_dof: row.edf,
1723 });
1724 }
1725 let evidence_summary = if let Some(runner_up) = candidates.get(1) {
1730 let margin = runner_up.ranking_score() - candidates[0].ranking_score();
1731 format!(
1732 "{} wins by Bayes factor {} over {}",
1733 winner,
1734 format_bayes_factor(margin),
1735 runner_up.name
1736 )
1737 } else {
1738 format!("{winner} (single fit; no comparison)")
1739 };
1740 Ok(RemlComparison {
1741 ranking,
1742 winner,
1743 evidence_summary,
1744 score_table,
1745 })
1746}
1747
1748pub fn format_bayes_factor(log_bf: f64) -> String {
1749 if !log_bf.is_finite() {
1750 return "inf".to_string();
1751 }
1752 if log_bf.abs() >= std::f64::consts::LN_10 * 3.0 {
1753 return format!("1e{:+.1}", log_bf / std::f64::consts::LN_10);
1754 }
1755 format_three_significant(log_bf.exp())
1756}
1757
/// Formats `value` with three significant digits.
///
/// * `0` renders as `"0"`; non-finite values use their default `Display`.
/// * Magnitudes of `1000` and above use scientific notation with a
///   two-decimal mantissa (e.g. `1.23e3`).
/// * Smaller magnitudes use fixed notation with as many decimals as are
///   needed for three significant digits (e.g. `1.50`, `0.0123`, `250`).
///
/// Two edge cases are handled explicitly:
/// * when rounding carries into the next decade (`9.996`, `999.6`) the number
///   of decimals is reduced (or scientific notation is used) so the output
///   still has three significant digits (`"10.0"`, `"1.00e3"`);
/// * when the magnitude is so small that the decimal scale `10^decimals`
///   overflows, scientific notation is used instead of producing `NaN`.
pub fn format_three_significant(value: f64) -> String {
    if value == 0.0 {
        return "0".to_string();
    }
    if !value.is_finite() {
        return format!("{value}");
    }
    let exponent = value.abs().log10().floor() as i32;
    if exponent >= 3 {
        return format!("{value:.2e}");
    }
    // `exponent <= 2` here, so `decimals >= 0`.
    let mut decimals = 2 - exponent;
    let scale = 10f64.powi(decimals);
    // Nominally in [100, 1000): the three significant digits as an integer.
    let scaled = (value * scale).abs();
    if !scaled.is_finite() {
        // `10^decimals` overflowed (|value| below roughly 1e-306); fixed
        // notation is not representable, fall back to scientific.
        return format!("{value:.2e}");
    }
    let mantissa = scaled.round();
    let rounded = mantissa / scale * value.signum();
    if mantissa >= 1000.0 {
        // Rounding carried into the next decade: one fewer decimal is needed.
        if decimals == 0 {
            return format!("{rounded:.2e}");
        }
        decimals -= 1;
    }
    let decimals = decimals.max(0) as usize;
    format!("{rounded:.decimals$}")
}
1774
impl Default for TopologySelectOptions {
    /// Default options: a tie tolerance of `1e-3` on the selection score and
    /// per-observation score scaling.
    fn default() -> Self {
        Self {
            tie_tolerance: 1e-3,
            score_scale: TopologyScoreScale::PerObservation,
        }
    }
}
1783
1784pub fn laplace_evidence(
1829 logdet_source: EvidenceLogDetSource<'_>,
1830 penalty_log_det: f64,
1831 residual_objective: f64,
1832 effective_dim: f64,
1833 penalty_rank: f64,
1834) -> f64 {
1835 if !(effective_dim.is_finite() && penalty_rank.is_finite()) {
1836 return f64::NAN;
1837 }
1838 let log_det_h = match evidence_hessian_log_det(logdet_source) {
1839 Ok(v) => v,
1840 Err(_) => return f64::NAN,
1841 };
1842 let null_dim = effective_dim - penalty_rank;
1843 if !null_dim.is_finite() || null_dim < -1e-9 {
1844 return f64::NAN;
1845 }
1846 residual_objective + 0.5 * log_det_h
1847 - 0.5 * penalty_log_det
1848 - 0.5 * null_dim.max(0.0) * (2.0 * std::f64::consts::PI).ln()
1849}
1850
1851pub fn evidence_hessian_log_det(source: EvidenceLogDetSource<'_>) -> Result<f64, String> {
1853 match source {
1854 EvidenceLogDetSource::FactoredArrow {
1855 cache,
1856 fallback_hvp,
1857 } => match arrow_log_det_from_cache(cache) {
1858 Some(v) => Ok(v),
1859 None => match fallback_hvp {
1860 Some(hvp) => hessian_log_det_from_hvp(hvp),
1861 None => {
1862 Err("evidence Hessian logdet requires exact factors or HVP fallback".into())
1863 }
1864 },
1865 },
1866 EvidenceLogDetSource::Hvp(hvp) => hessian_log_det_from_hvp(hvp),
1867 }
1868}
1869
1870pub fn hessian_log_det_from_hvp(hvp: EvidenceHvpLogDet<'_>) -> Result<f64, String> {
1877 if hvp.dim == 0 {
1878 return Ok(0.0);
1879 }
1880 if hvp.dim <= ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD {
1881 let mut dense = Array2::<f64>::zeros((hvp.dim, hvp.dim));
1882 let mut basis = vec![0.0_f64; hvp.dim];
1883 for j in 0..hvp.dim {
1884 basis[j] = 1.0;
1885 let col = (hvp.apply)(&basis);
1886 basis[j] = 0.0;
1887 if col.len() != hvp.dim || col.iter().any(|v| !v.is_finite()) {
1888 return Err(format!(
1889 "evidence HVP logdet expected finite column of length {}, got {}",
1890 hvp.dim,
1891 col.len()
1892 ));
1893 }
1894 for i in 0..hvp.dim {
1895 dense[[i, j]] = col[i];
1896 }
1897 }
1898 validate_dense_hvp_symmetry(&dense)?;
1899 for i in 0..hvp.dim {
1900 for j in (i + 1)..hvp.dim {
1901 let avg = 0.5 * (dense[[i, j]] + dense[[j, i]]);
1902 dense[[i, j]] = avg;
1903 dense[[j, i]] = avg;
1904 }
1905 }
1906 dense_spd_log_det(&dense)
1907 } else {
1908 stochastic_hvp_log_det(hvp)
1909 }
1910}
1911
1912fn dense_spd_log_det(matrix: &Array2<f64>) -> Result<f64, String> {
1913 if matrix.nrows() != matrix.ncols() {
1914 return Err(format!(
1915 "evidence dense logdet requires square matrix, got {}x{}",
1916 matrix.nrows(),
1917 matrix.ncols()
1918 ));
1919 }
1920 if gam_gpu::cuda_selected() {
1921 return crate::gpu::reml_gpu::evidence_derivatives_gpu(
1922 crate::gpu::reml_gpu::RemlGpuInput {
1923 penalized_hessian: matrix.view(),
1924 derivative_hessians: Vec::new(),
1925 },
1926 )
1927 .map(|evidence| evidence.logdet_hessian);
1928 }
1929 let (evals, _) = matrix
1930 .eigh(Side::Lower)
1931 .map_err(|e| format!("evidence dense logdet eigendecomposition failed: {e}"))?;
1932 let mut logdet = 0.0_f64;
1933 for (idx, &ev) in evals.iter().enumerate() {
1934 if !ev.is_finite() || ev <= 0.0 {
1935 return Err(format!(
1936 "evidence dense logdet expected SPD Hessian, eigenvalue {idx} is {ev:.3e}"
1937 ));
1938 }
1939 logdet += ev.ln();
1940 }
1941 Ok(logdet)
1942}
1943
1944fn validate_dense_hvp_symmetry(matrix: &Array2<f64>) -> Result<(), String> {
1945 let n = matrix.nrows();
1946 let mut norm_sq = 0.0_f64;
1947 for &value in matrix.iter() {
1948 norm_sq += value * value;
1949 }
1950
1951 let mut skew_sq = 0.0_f64;
1952 for i in 0..n {
1953 for j in (i + 1)..n {
1954 let skew = matrix[[i, j]] - matrix[[j, i]];
1955 skew_sq += 2.0 * skew * skew;
1956 }
1957 }
1958
1959 let rel_skew = skew_sq.sqrt() / norm_sq.sqrt().max(1.0);
1960 if !rel_skew.is_finite() || rel_skew > EVIDENCE_HVP_SYMMETRY_REL_TOL {
1961 return Err(format!(
1962 "evidence HVP logdet requires symmetric operator, relative skew norm is {rel_skew:.3e}"
1963 ));
1964 }
1965 Ok(())
1966}
1967
1968fn validate_hvp_randomized_symmetry(hvp: EvidenceHvpLogDet<'_>) -> Result<(), String> {
1969 let inv_norm = 1.0 / (hvp.dim as f64).sqrt();
1970 for probe in 0..EVIDENCE_HVP_SYMMETRY_PROBES.max(1) {
1971 let mut x = vec![0.0_f64; hvp.dim];
1972 let mut y = vec![0.0_f64; hvp.dim];
1973 rademacher_unit_probe_into_slice(&mut x, (2 * probe) as u64, inv_norm);
1974 rademacher_unit_probe_into_slice(&mut y, (2 * probe + 1) as u64, inv_norm);
1975
1976 let hx = (hvp.apply)(&x);
1977 let hy = (hvp.apply)(&y);
1978 if hx.len() != hvp.dim || hx.iter().any(|v| !v.is_finite()) {
1979 return Err(format!(
1980 "evidence HVP symmetry check expected finite vector of length {}, got {}",
1981 hvp.dim,
1982 hx.len()
1983 ));
1984 }
1985 if hy.len() != hvp.dim || hy.iter().any(|v| !v.is_finite()) {
1986 return Err(format!(
1987 "evidence HVP symmetry check expected finite vector of length {}, got {}",
1988 hvp.dim,
1989 hy.len()
1990 ));
1991 }
1992
1993 let lhs = dot_slice(&x, &hy);
1994 let rhs = dot_slice(&hx, &y);
1995 let scale = (norm2_slice(&hx) * norm2_slice(&y))
1996 .max(norm2_slice(&hy) * norm2_slice(&x))
1997 .max(lhs.abs())
1998 .max(rhs.abs())
1999 .max(1.0);
2000 let rel = (lhs - rhs).abs() / scale;
2001 if !rel.is_finite() || rel > EVIDENCE_HVP_SYMMETRY_REL_TOL {
2002 return Err(format!(
2003 "evidence HVP logdet requires symmetric operator, randomized symmetry probe {probe} has relative bilinear mismatch {rel:.3e}"
2004 ));
2005 }
2006 }
2007 Ok(())
2008}
2009
2010fn stochastic_hvp_log_det(hvp: EvidenceHvpLogDet<'_>) -> Result<f64, String> {
2011 validate_hvp_randomized_symmetry(hvp)?;
2012 let probes = EVIDENCE_LOGDET_SLQ_PROBES.max(1);
2013 let steps = EVIDENCE_LOGDET_LANCZOS_STEPS.min(hvp.dim).max(1);
2014 let inv_norm = 1.0 / (hvp.dim as f64).sqrt();
2015 let mut estimate = 0.0_f64;
2016 for probe in 0..probes {
2017 let mut q0 = vec![0.0_f64; hvp.dim];
2018 rademacher_unit_probe_into_slice(&mut q0, probe as u64, inv_norm);
2019 let quad = lanczos_log_quadrature_hvp(hvp, q0, steps)?;
2020 estimate += hvp.dim as f64 * quad;
2021 }
2022 Ok(estimate / probes as f64)
2023}
2024
2025fn lanczos_log_quadrature_hvp(
2026 hvp: EvidenceHvpLogDet<'_>,
2027 q: Vec<f64>,
2028 max_steps: usize,
2029) -> Result<f64, String> {
2030 let n = hvp.dim;
2031 let eigen = symmetric_lanczos_eigenpairs(
2032 n,
2033 &q,
2034 SymmetricLanczosOptions {
2035 max_steps,
2036 residual_tol: 1e-12,
2037 local_reorthogonalize: false,
2038 full_reorthogonalize: false,
2039 },
2040 |q, out| {
2041 let applied = (hvp.apply)(q);
2042 if applied.len() != n || applied.iter().any(|v| !v.is_finite()) {
2043 return Err(format!(
2044 "evidence HVP SLQ expected finite vector of length {n}, got {}",
2045 applied.len()
2046 ));
2047 }
2048 out.copy_from_slice(&applied);
2049 Ok(())
2050 },
2051 )
2052 .map_err(|e| format!("evidence HVP SLQ Lanczos failed: {e}"))?;
2053 symmetric_lanczos_log_quadrature(&eigen, "evidence HVP SLQ expected SPD Hessian")
2054}
2055
/// Dot product of two equally long slices, accumulated left to right.
///
/// # Panics
/// Panics when the slices differ in length.
#[inline]
fn dot_slice(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b.iter())
        .fold(0.0_f64, |sum, (x, y)| sum + x * y)
}
2065
2066#[inline]
2067fn norm2_slice(a: &[f64]) -> f64 {
2068 dot_slice(a, a).sqrt()
2069}
2070
2071fn rademacher_unit_probe_into_slice(z: &mut [f64], probe: u64, scale: f64) {
2072 let mut state = 0x6A09E667F3BCC909_u64 ^ probe.wrapping_mul(0xD1B54A32D192ED03);
2073 let mut bits = 0_u64;
2074 let mut remaining_bits = 0_u32;
2075 for value in z.iter_mut() {
2076 if remaining_bits == 0 {
2077 bits = splitmix64(&mut state);
2078 remaining_bits = 64;
2079 }
2080 *value = if bits & 1 == 0 { scale } else { -scale };
2081 bits >>= 1;
2082 remaining_bits -= 1;
2083 }
2084}
2085
/// Local alias for `gam_linalg::utils::splitmix64`: forwards `state` to the
/// shared implementation and returns its result.
#[inline]
const fn splitmix64(state: &mut u64) -> u64 {
    gam_linalg::utils::splitmix64(state)
}
2090
2091pub fn arrow_log_det_from_cache(cache: &ArrowFactorCache) -> Option<f64> {
2100 if cache.ridge_t != 0.0 || cache.ridge_beta != 0.0 {
2101 return None;
2105 }
2106 if let Some(log_det) = cache.joint_hessian_log_det {
2107 return log_det.is_finite().then_some(log_det);
2108 }
2109 let schur = match cache.schur_factor.as_ref() {
2116 Some(schur) => Some(schur),
2117 None if cache.k == 0 => None,
2118 None => return None,
2119 };
2120
2121 let mut acc = 0.0_f64;
2122 for l in cache.undamped_factors_iter() {
2124 acc += 2.0 * log_det_from_chol_lower(l);
2125 }
2126 if let Some(schur) = schur {
2128 acc += 2.0 * log_det_from_chol_lower(schur.view());
2129 }
2130 let woodbury_correction = cache.cross_row_woodbury_log_det();
2135 if !woodbury_correction.is_finite() {
2136 return None;
2139 }
2140 acc += woodbury_correction;
2141 Some(acc)
2142}
2143
2144fn log_det_from_chol_lower(l: ArrayView2<'_, f64>) -> f64 {
2146 let n = l.nrows();
2147 let mut acc = 0.0_f64;
2148 for i in 0..n {
2149 let d = l[[i, i]];
2150 if d > 0.0 {
2151 acc += d.ln();
2152 } else {
2153 panic!(
2159 "log_det_from_chol_lower: non-positive Cholesky diagonal {d} at index {i}; \
2160 caller passed a corrupted or non-SPD factor"
2161 );
2162 }
2163 }
2164 acc
2165}
2166
2167pub fn ift_du_dbeta(cache: &ArrowFactorCache) -> Array2<f64> {
2177 let n = cache.undamped_factor_count();
2178 let total_len = cache.delta_t_len();
2179 let k = cache.k;
2180 if !cache.htbeta_available() {
2181 return Array2::<f64>::from_elem((total_len, k), f64::NAN);
2182 }
2183 let mut out = Array2::<f64>::zeros((total_len, k));
2184 let mut beta_basis = Array1::<f64>::zeros(k);
2185 let mut rhs = Array1::<f64>::zeros(cache.d);
2187 for i in 0..n {
2188 let di = cache.row_dims[i];
2189 let row_base = cache.row_offsets[i];
2190 let factor = cache.undamped_factor(i);
2191 for col in 0..k {
2193 beta_basis.fill(0.0);
2194 beta_basis[col] = 1.0;
2195 let mut rhs_i = rhs.slice_mut(ndarray::s![..di]).to_owned();
2196 if !cache.apply_htbeta_row(i, beta_basis.view(), &mut rhs_i) {
2199 return Array2::<f64>::from_elem((total_len, k), f64::NAN);
2202 }
2203 let y = cholesky_solve_vector(factor, &rhs_i);
2204 for c in 0..di {
2205 out[[row_base + c, col]] = -y[c];
2206 }
2207 }
2208 }
2209 out
2210}
2211
2212pub fn coupling_components(hessian: ArrayView2<'_, f64>) -> Vec<usize> {
2233 let p = hessian.nrows();
2234 if p == 0 || hessian.ncols() != p {
2235 return Vec::new();
2236 }
2237 let mut parent: Vec<usize> = (0..p).collect();
2239 let mut size: Vec<usize> = vec![1; p];
2240
2241 fn find(parent: &mut [usize], mut x: usize) -> usize {
2242 while parent[x] != x {
2243 parent[x] = parent[parent[x]];
2244 x = parent[x];
2245 }
2246 x
2247 }
2248
2249 for i in 0..p {
2250 for j in (i + 1)..p {
2251 if hessian[[i, j]] != 0.0 || hessian[[j, i]] != 0.0 {
2254 let (ri, rj) = (find(&mut parent, i), find(&mut parent, j));
2255 if ri != rj {
2256 let (small, large) = if size[ri] < size[rj] {
2257 (ri, rj)
2258 } else {
2259 (rj, ri)
2260 };
2261 parent[small] = large;
2262 size[large] += size[small];
2263 }
2264 }
2265 }
2266 }
2267
2268 let mut label_of_root: Vec<Option<usize>> = vec![None; p];
2271 let mut next_label = 0usize;
2272 let mut labels = vec![0usize; p];
2273 for idx in 0..p {
2274 let root = find(&mut parent, idx);
2275 let label = match label_of_root[root] {
2276 Some(l) => l,
2277 None => {
2278 let l = next_label;
2279 label_of_root[root] = Some(l);
2280 next_label += 1;
2281 l
2282 }
2283 };
2284 labels[idx] = label;
2285 }
2286 labels
2287}
2288
/// Returns, in ascending order, every index whose component label matches the
/// label of at least one index in `support`.
///
/// Support indices outside `labels` are ignored; an empty (or entirely
/// out-of-range) support yields an empty result.
pub fn cone_of_influence(labels: &[usize], support: &[usize]) -> Vec<usize> {
    if support.is_empty() {
        return Vec::new();
    }
    let touched: std::collections::HashSet<usize> = support
        .iter()
        .filter_map(|&idx| labels.get(idx).copied())
        .collect();
    if touched.is_empty() {
        return Vec::new();
    }
    labels
        .iter()
        .enumerate()
        .filter(|(_, label)| touched.contains(*label))
        .map(|(idx, _)| idx)
        .collect()
}
2316
2317pub fn ift_dbeta_drho(
2329 cache: &ArrowFactorCache,
2330 dg_red_drho: ArrayView2<'_, f64>,
2331) -> Option<Array2<f64>> {
2332 if cache.ridge_t != 0.0 || cache.ridge_beta != 0.0 {
2333 return None;
2334 }
2335 let schur = cache.schur_factor.as_ref()?;
2336 if dg_red_drho.nrows() != cache.k || schur.nrows() != cache.k {
2337 return None;
2338 }
2339 crate::sensitivity::FitSensitivity::from_lower_triangular(schur)
2340 .mode_response(dg_red_drho)
2341}
2342
2343
/// Borrowed inputs for `evidence_ift_gradient_correction`.
///
/// With `k = dbeta_drho.nrows()`, `nd = du_drho.nrows()` and
/// `r = dbeta_drho.ncols()`, the correction routine requires `du_drho` to
/// have `r` columns, the `*_beta` vectors to have length `k`, and the `*_u`
/// vectors to have length `nd`.
#[derive(Clone)]
pub struct EvidenceIftGradientTerms<'a> {
    /// `k × r` matrix multiplied against the beta-side vectors.
    pub dbeta_drho: ArrayView2<'a, f64>,
    /// `nd × r` matrix multiplied against the u-side vectors.
    pub du_drho: ArrayView2<'a, f64>,
    /// Length-`k` vector contributing with weight `1`.
    pub value_beta: ArrayView1<'a, f64>,
    /// Length-`nd` vector contributing with weight `1`.
    pub value_u: ArrayView1<'a, f64>,
    /// Length-`k` vector contributing with weight `0.5`.
    pub logdet_h_beta: ArrayView1<'a, f64>,
    /// Length-`nd` vector contributing with weight `0.5`.
    pub logdet_h_u: ArrayView1<'a, f64>,
}
2370
2371pub fn evidence_ift_gradient_correction(terms: EvidenceIftGradientTerms<'_>) -> Array1<f64> {
2374 let k = terms.dbeta_drho.nrows();
2375 let nd = terms.du_drho.nrows();
2376 let r = terms.dbeta_drho.ncols();
2377 if terms.du_drho.ncols() != r
2378 || terms.value_beta.len() != k
2379 || terms.logdet_h_beta.len() != k
2380 || terms.value_u.len() != nd
2381 || terms.logdet_h_u.len() != nd
2382 {
2383 return Array1::<f64>::from_elem(r, f64::NAN);
2384 }
2385
2386 let mut out = Array1::<f64>::zeros(r);
2387 for a in 0..r {
2388 let mut acc = 0.0_f64;
2389 for j in 0..k {
2390 let mode = terms.dbeta_drho[[j, a]];
2391 acc += terms.value_beta[j] * mode;
2392 acc += 0.5 * terms.logdet_h_beta[j] * mode;
2393 }
2394 for j in 0..nd {
2395 let mode = terms.du_drho[[j, a]];
2396 acc += terms.value_u[j] * mode;
2397 acc += 0.5 * terms.logdet_h_u[j] * mode;
2398 }
2399 out[a] = acc;
2400 }
2401 out
2402}
2403
/// Gradient of the evidence objective with respect to each of the `r`
/// hyper-parameters (`r = value_rho.len()`).
///
/// For hyper-parameter `a` the entry is
///
/// `value_rho[a] + ½·(row_trace + schur_trace) + ift_correction[a] − ½·pen_logdet_drho[a]`
///
/// where
/// * `row_trace` sums, over row blocks `i`, the trace of the block factor
///   solve applied to `huu_drho[i][a]`;
/// * `schur_trace` is the trace of the Schur factor solve applied to
///   `hbb_drho[a] − Σ_i (dHtbᵀ·Y_i + Y_iᵀ·dHtb) + Σ_i Y_iᵀ·dHuu·Y_i`, with
///   `dHtb = htbeta_drho[i][a]`, `dHuu = huu_drho[i][a]` and `Y_i` the block
///   factor solve of the coupling columns;
/// * `ift_correction` comes from `evidence_ift_gradient_correction`.
///
/// Expected shapes (with `n` row blocks of sizes `row_dims[i]` and `k` beta
/// coefficients): `huu_drho[i][a]` is `d_i × d_i`, `htbeta_drho[i][a]` is
/// `d_i × k`, `hbb_drho[a]` is `k × k`, `pen_logdet_drho` has length `r`.
///
/// The whole result is filled with `NaN` when the coupling is unavailable,
/// any shape check fails, the IFT correction contains `NaN` or has the wrong
/// length, the Schur factor is missing, or applying the coupling to a row
/// fails.
pub fn evidence_grad_rho(
    cache: &ArrowFactorCache,
    value_rho: ArrayView1<'_, f64>,
    huu_drho: &[Vec<Array2<f64>>],
    htbeta_drho: &[Vec<Array2<f64>>],
    hbb_drho: &[Array2<f64>],
    pen_logdet_drho: ArrayView1<'_, f64>,
    ift_terms: EvidenceIftGradientTerms<'_>,
) -> Array1<f64> {
    let r = value_rho.len();
    let n = cache.undamped_factor_count();
    let k = cache.k;
    let mut out = Array1::<f64>::zeros(r);
    // Reject the call up front when any input is missing or mis-shaped.
    if !cache.htbeta_available()
        || pen_logdet_drho.len() != r
        || huu_drho.len() != n
        || htbeta_drho.len() != n
        || hbb_drho.len() != r
        || huu_drho.iter().any(|row| row.len() != r)
        || htbeta_drho.iter().any(|row| row.len() != r)
        || hbb_drho.iter().any(|m| m.nrows() != k || m.ncols() != k)
        || huu_drho.iter().enumerate().any(|(i, row)| {
            let di = cache.row_dims[i];
            row.iter().any(|m| m.nrows() != di || m.ncols() != di)
        })
        || htbeta_drho.iter().enumerate().any(|(i, row)| {
            let di = cache.row_dims[i];
            row.iter().any(|m| m.nrows() != di || m.ncols() != k)
        })
    {
        out.fill(f64::NAN);
        return out;
    }
    let ift_correction = evidence_ift_gradient_correction(ift_terms);
    if ift_correction.len() != r || ift_correction.iter().any(|v| v.is_nan()) {
        out.fill(f64::NAN);
        return out;
    }

    // Without a Schur factor the Schur trace term cannot be formed.
    let schur = match cache.schur_factor.as_ref() {
        Some(s) => s,
        None => {
            for a in 0..r {
                out[a] = f64::NAN;
            }
            return out;
        }
    };

    // Y_i: for each row block, the factor solve of the coupling applied to
    // every unit beta vector, stored column by column (d_i × k).
    let mut y_blocks: Vec<Array2<f64>> = Vec::with_capacity(n);
    let mut beta_basis = Array1::<f64>::zeros(k);
    let mut rhs = Array1::<f64>::zeros(cache.d);
    for i in 0..n {
        let di = cache.row_dims[i];
        let factor = cache.undamped_factor(i);
        let mut yi = Array2::<f64>::zeros((di, k));
        for col in 0..k {
            beta_basis.fill(0.0);
            beta_basis[col] = 1.0;
            // Owned zero vector of length d_i taken from the scratch buffer.
            let mut rhs_i = rhs.slice_mut(ndarray::s![..di]).to_owned();
            if !cache.apply_htbeta_row(i, beta_basis.view(), &mut rhs_i) {
                out.fill(f64::NAN);
                return out;
            }
            let v = cholesky_solve_vector(factor, &rhs_i);
            for c in 0..di {
                yi[[c, col]] = v[c];
            }
        }
        y_blocks.push(yi);
    }

    let mut trace_rhs = Array1::<f64>::zeros(cache.d);
    let mut da_tmp = Array2::<f64>::zeros((cache.d, k));
    let mut col_scratch = Array1::<f64>::zeros(k);
    for a in 0..r {
        let mut grad = value_rho[a];

        // Row-block trace: sum of the diagonal of the factor solve applied to
        // each column of huu_drho[i][a].
        let mut row_trace_acc = 0.0_f64;
        for i in 0..n {
            let di = cache.row_dims[i];
            let m_i = &huu_drho[i][a];
            assert_eq!(m_i.shape(), &[di, di]);
            for col in 0..di {
                let mut tr_rhs_i = trace_rhs.slice_mut(ndarray::s![..di]).to_owned();
                for r0 in 0..di {
                    tr_rhs_i[r0] = m_i[[r0, col]];
                }
                let v = cholesky_solve_vector(cache.undamped_factor(i), &tr_rhs_i);
                row_trace_acc += v[col];
            }
        }

        // Derivative of the reduced (Schur) block, starting from hbb_drho[a].
        let mut da = hbb_drho[a].clone();
        assert_eq!(da.shape(), &[k, k]);
        for i in 0..n {
            let di = cache.row_dims[i];
            let dhtb = &htbeta_drho[i][a];
            let yi = &y_blocks[i];
            // da -= dHtbᵀ · Y_i
            for r0 in 0..k {
                for c0 in 0..k {
                    let mut acc = 0.0;
                    for cc in 0..di {
                        acc += dhtb[[cc, r0]] * yi[[cc, c0]];
                    }
                    da[[r0, c0]] -= acc;
                }
            }
            // da -= Y_iᵀ · dHtb
            for r0 in 0..k {
                for c0 in 0..k {
                    let mut acc = 0.0;
                    for cc in 0..di {
                        acc += yi[[cc, r0]] * dhtb[[cc, c0]];
                    }
                    da[[r0, c0]] -= acc;
                }
            }
            let dhuu = &huu_drho[i][a];
            // da_tmp_i = dHuu · Y_i
            let mut da_tmp_i = da_tmp.slice_mut(ndarray::s![..di, ..]).to_owned();
            for r0 in 0..di {
                for c0 in 0..k {
                    let mut acc = 0.0;
                    for cc in 0..di {
                        acc += dhuu[[r0, cc]] * yi[[cc, c0]];
                    }
                    da_tmp_i[[r0, c0]] = acc;
                }
            }
            // da += Y_iᵀ · (dHuu · Y_i)
            for r0 in 0..k {
                for c0 in 0..k {
                    let mut acc = 0.0;
                    for cc in 0..di {
                        acc += yi[[cc, r0]] * da_tmp_i[[cc, c0]];
                    }
                    da[[r0, c0]] += acc;
                }
            }
        }

        // Schur trace: diagonal of the Schur factor solve applied to each
        // column of da.
        let mut schur_trace_acc = 0.0_f64;
        for j in 0..k {
            for r0 in 0..k {
                col_scratch[r0] = da[[r0, j]];
            }
            let v = cholesky_solve_vector(schur, &col_scratch);
            schur_trace_acc += v[j];
        }

        grad += 0.5 * (row_trace_acc + schur_trace_acc);
        grad += ift_correction[a];

        grad -= 0.5 * pen_logdet_drho[a];

        out[a] = grad;
    }
    out
}
2620
2621pub fn select_topology(
2647 candidates: &[TopologyCandidate],
2648 options: TopologySelectOptions,
2649) -> SelectedTopology {
2650 let mut valid: Vec<TopologyCandidate> = candidates
2652 .iter()
2653 .filter(|c| {
2654 c.converged
2655 && c.exclusion_reason.is_none()
2656 && c.negative_log_evidence.is_finite()
2657 && topology_selection_score(c, options.score_scale).is_finite()
2658 })
2659 .cloned()
2660 .collect();
2661 let mut excluded: Vec<TopologyCandidate> = candidates
2662 .iter()
2663 .filter(|c| {
2664 !(c.converged && c.exclusion_reason.is_none() && c.negative_log_evidence.is_finite())
2665 || !topology_selection_score(c, options.score_scale).is_finite()
2666 })
2667 .cloned()
2668 .collect();
2669
2670 assert!(
2671 !valid.is_empty(),
2672 "select_topology: no finite valid candidates; proposal §6.11 forbids silent fallback"
2673 );
2674
2675 valid = rank_priority_candidates(
2680 valid
2681 .into_iter()
2682 .enumerate()
2683 .map(|(idx, row)| {
2684 let score = topology_selection_score(&row, options.score_scale);
2685 let tie_break = usize::from(row.kind.complexity_rank());
2686 PriorityCandidate::new(row, idx, score, tie_break)
2687 })
2688 .collect(),
2689 )
2690 .into_iter()
2691 .map(|row| row.item)
2692 .collect();
2693
2694 let tie = if valid.len() >= 2 {
2696 let top = topology_selection_score(&valid[0], options.score_scale);
2697 let next = topology_selection_score(&valid[1], options.score_scale);
2698 (next - top).abs() <= options.tie_tolerance
2699 } else {
2700 false
2701 };
2702
2703 if tie {
2705 let top_score = topology_selection_score(&valid[0], options.score_scale);
2706 let tied_end = valid
2708 .iter()
2709 .position(|c| {
2710 (topology_selection_score(c, options.score_scale) - top_score).abs()
2711 > options.tie_tolerance
2712 })
2713 .unwrap_or(valid.len());
2714 valid[..tied_end].sort_by_key(|c| c.kind.complexity_rank());
2716 }
2717
2718 let winner = valid[0].kind;
2719 valid.append(&mut excluded);
2720 SelectedTopology {
2721 winner,
2722 ranking: valid,
2723 tie,
2724 }
2725}
2726
2727fn topology_selection_score(candidate: &TopologyCandidate, scale: TopologyScoreScale) -> f64 {
2728 match scale {
2729 TopologyScoreScale::PerObservation => {
2730 if candidate.n_obs == 0 {
2731 f64::NAN
2732 } else {
2733 candidate.negative_log_evidence / candidate.n_obs as f64
2734 }
2735 }
2736 TopologyScoreScale::PerEffectiveDim => {
2737 if !(candidate.effective_dim.is_finite() && candidate.effective_dim > 0.0) {
2738 f64::NAN
2739 } else {
2740 candidate.negative_log_evidence / candidate.effective_dim
2741 }
2742 }
2743 }
2744}
2745
2746pub fn cache_matches_system(cache: &ArrowFactorCache, sys: &ArrowSchurSystem) -> bool {
2755 cache.d == sys.d
2756 && cache.k == sys.k
2757 && cache.n_rows() == sys.rows.len()
2758 && cache.undamped_factor_count() == sys.rows.len()
2759 && cache.manifold_mode_fingerprint == sys.manifold_mode_fingerprint
2760 && cache.row_hessian_fingerprint == sys.current_row_hessian_fingerprint()
2761}
2762
/// Parameterisation chosen for one atom of a hybrid model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HybridAtomParam {
    /// Curved parameterisation carrying its latent dimension.
    Curved { latent_dim: usize },
    /// Linear parameterisation.
    Linear,
}

impl HybridAtomParam {
    /// Short lowercase tag: `"curved"` or `"linear"`.
    pub const fn as_str(self) -> &'static str {
        if self.is_linear() {
            "linear"
        } else {
            "curved"
        }
    }

    /// `true` only for [`HybridAtomParam::Linear`].
    pub const fn is_linear(self) -> bool {
        match self {
            HybridAtomParam::Linear => true,
            HybridAtomParam::Curved { .. } => false,
        }
    }
}
2837
/// One candidate parameterisation of a single hybrid atom, together with the
/// quantities the selector compares.
#[derive(Debug, Clone, Copy)]
pub struct HybridAtomCandidate {
    /// Which parameterisation this candidate represents.
    pub param: HybridAtomParam,
    /// Negative log evidence of the fit; smaller is preferred by the selector.
    pub negative_log_evidence: f64,
    /// Parameter count; the selector prefers fewer on an exact evidence tie.
    pub num_parameters: usize,
    /// Fitted turning of the candidate, when available.
    pub fitted_turning: Option<f64>,
}

impl HybridAtomCandidate {
    /// Builds a linear candidate; its fitted turning is recorded as `0.0`.
    pub fn linear(negative_log_evidence: f64, num_parameters: usize) -> Self {
        Self {
            param: HybridAtomParam::Linear,
            negative_log_evidence,
            num_parameters,
            fitted_turning: Some(0.0),
        }
    }

    /// Builds a curved candidate with the given latent dimension and optional
    /// fitted turning.
    pub fn curved(
        latent_dim: usize,
        negative_log_evidence: f64,
        num_parameters: usize,
        fitted_turning: Option<f64>,
    ) -> Self {
        Self {
            param: HybridAtomParam::Curved { latent_dim },
            negative_log_evidence,
            num_parameters,
            fitted_turning,
        }
    }
}
2885
/// Outcome of selecting a parameterisation for one hybrid atom.
#[derive(Debug, Clone, Copy)]
pub struct HybridAtomChoice {
    /// The selected parameterisation.
    pub param: HybridAtomParam,
    /// Negative log evidence of the selected candidate.
    pub negative_log_evidence: f64,
    /// Parameter count of the selected candidate.
    pub num_parameters: usize,
    /// Fitted turning of the first curved candidate in the slot, if one was
    /// present and reported a turning (independent of which candidate won).
    pub curved_turning: Option<f64>,
    /// Negative log evidence of the first linear candidate minus that of the
    /// first curved candidate; `0.0` when either kind is absent.
    pub curved_evidence_margin: f64,
}
2905
/// Turning threshold used by `select_hybrid_atom`: when a linear candidate is
/// available and the curved candidate's fitted turning is at or below this
/// value, the linear candidate is selected outright.
pub const HYBRID_LINEAR_TURNING_FLOOR: f64 = 1e-9;
2914
2915pub fn select_hybrid_atom(candidates: &[HybridAtomCandidate]) -> Option<HybridAtomChoice> {
2943 if candidates.is_empty() {
2944 return None;
2945 }
2946 let linear = candidates.iter().find(|c| c.param.is_linear());
2947 let curved = candidates.iter().find(|c| !c.param.is_linear());
2948 let curved_turning = curved.and_then(|c| c.fitted_turning);
2949 let curved_evidence_margin = match (linear, curved) {
2950 (Some(l), Some(c)) => l.negative_log_evidence - c.negative_log_evidence,
2951 _ => 0.0,
2952 };
2953
2954 if let (Some(l), Some(turning)) = (linear, curved_turning)
2957 && turning <= HYBRID_LINEAR_TURNING_FLOOR
2958 {
2959 return Some(HybridAtomChoice {
2960 param: l.param,
2961 negative_log_evidence: l.negative_log_evidence,
2962 num_parameters: l.num_parameters,
2963 curved_turning,
2964 curved_evidence_margin,
2965 });
2966 }
2967
2968 let mut best = candidates[0];
2970 for cand in &candidates[1..] {
2971 let better_evidence = cand.negative_log_evidence < best.negative_log_evidence;
2972 let tied = cand.negative_log_evidence == best.negative_log_evidence;
2973 let cheaper_on_tie = tied && cand.num_parameters < best.num_parameters;
2974 if better_evidence || cheaper_on_tie {
2975 best = *cand;
2976 }
2977 }
2978 Some(HybridAtomChoice {
2979 param: best.param,
2980 negative_log_evidence: best.negative_log_evidence,
2981 num_parameters: best.num_parameters,
2982 curved_turning,
2983 curved_evidence_margin,
2984 })
2985}
2986
2987#[derive(Debug, Clone)]
2991pub struct HybridSplitSelection {
2992 pub atoms: Vec<HybridAtomChoice>,
2994 pub total_negative_log_evidence: f64,
3003 pub total_parameters: usize,
3006 pub curved_atom_count: usize,
3008}
3009
3010impl HybridSplitSelection {
3011 pub fn linear_atom_count(&self) -> usize {
3013 self.atoms.len() - self.curved_atom_count
3014 }
3015
3016 pub fn is_pure_linear(&self) -> bool {
3019 self.curved_atom_count == 0 && !self.atoms.is_empty()
3020 }
3021
3022 pub fn is_pure_curved(&self) -> bool {
3025 self.curved_atom_count == self.atoms.len() && !self.atoms.is_empty()
3026 }
3027}
3028
3029pub fn select_hybrid_split(
3043 slots: &[Vec<HybridAtomCandidate>],
3044) -> Result<HybridSplitSelection, String> {
3045 let mut atoms = Vec::with_capacity(slots.len());
3046 let mut total_nle = 0.0_f64;
3047 let mut total_parameters = 0usize;
3048 let mut curved_atom_count = 0usize;
3049 for (i, slot) in slots.iter().enumerate() {
3050 let choice = select_hybrid_atom(slot)
3051 .ok_or_else(|| format!("hybrid split slot {i} has no candidate parameterizations"))?;
3052 if !choice.negative_log_evidence.is_finite() {
3053 return Err(format!(
3054 "hybrid split slot {i} selected a non-finite evidence ({})",
3055 choice.negative_log_evidence
3056 ));
3057 }
3058 if !choice.param.is_linear() {
3059 curved_atom_count += 1;
3060 }
3061 total_nle += choice.negative_log_evidence;
3062 total_parameters += choice.num_parameters;
3063 atoms.push(choice);
3064 }
3065 Ok(HybridSplitSelection {
3066 atoms,
3067 total_negative_log_evidence: total_nle,
3068 total_parameters,
3069 curved_atom_count,
3070 })
3071}
3072
3073#[cfg(test)]
3083mod tests {
3084 use super::*;
3085 use crate::arrow_schur::ArrowFactorSlab;
3086
3087 fn dense_inverse(h: &Array2<f64>) -> Array2<f64> {
3089 let p = h.nrows();
3090 let mut aug = Array2::<f64>::zeros((p, 2 * p));
3091 for i in 0..p {
3092 for j in 0..p {
3093 aug[[i, j]] = h[[i, j]];
3094 }
3095 aug[[i, p + i]] = 1.0;
3096 }
3097 for col in 0..p {
3098 let mut pivot = col;
3099 for row in (col + 1)..p {
3100 if aug[[row, col]].abs() > aug[[pivot, col]].abs() {
3101 pivot = row;
3102 }
3103 }
3104 if pivot != col {
3105 for j in 0..(2 * p) {
3106 aug.swap([col, j], [pivot, j]);
3107 }
3108 }
3109 let d = aug[[col, col]];
3110 for j in 0..(2 * p) {
3111 aug[[col, j]] /= d;
3112 }
3113 for row in 0..p {
3114 if row == col {
3115 continue;
3116 }
3117 let f = aug[[row, col]];
3118 if f != 0.0 {
3119 for j in 0..(2 * p) {
3120 aug[[row, j]] -= f * aug[[col, j]];
3121 }
3122 }
3123 }
3124 }
3125 let mut inv = Array2::<f64>::zeros((p, p));
3126 for i in 0..p {
3127 for j in 0..p {
3128 inv[[i, j]] = aug[[i, p + j]];
3129 }
3130 }
3131 inv
3132 }
3133
3134 #[test]
3135 fn coupling_components_block_diagonal_is_all_singletons_by_block() {
3136 let mut h = Array2::<f64>::eye(4);
3138 h[[0, 1]] = 0.3;
3139 h[[1, 0]] = 0.3;
3140 h[[2, 3]] = 0.7;
3141 h[[3, 2]] = 0.7;
3142 let labels = coupling_components(h.view());
3143 assert_eq!(labels[0], labels[1]);
3144 assert_eq!(labels[2], labels[3]);
3145 assert_ne!(labels[0], labels[2]);
3146 let mut uniq = labels.clone();
3148 uniq.sort_unstable();
3149 uniq.dedup();
3150 assert_eq!(uniq.len(), 2);
3151 }
3152
3153 #[test]
3154 fn coupling_components_fully_coupled_is_one_component() {
3155 let mut h = Array2::<f64>::eye(3);
3156 for i in 0..3 {
3157 for j in 0..3 {
3158 if i != j {
3159 h[[i, j]] = 0.1;
3160 }
3161 }
3162 }
3163 let labels = coupling_components(h.view());
3164 assert!(labels.iter().all(|&l| l == labels[0]));
3165 }
3166
3167 #[test]
3168 fn coupling_components_transitive_chain_merges() {
3169 let mut h = Array2::<f64>::eye(3);
3171 h[[0, 1]] = 0.5;
3172 h[[1, 0]] = 0.5;
3173 h[[1, 2]] = 0.5;
3174 h[[2, 1]] = 0.5;
3175 let labels = coupling_components(h.view());
3176 assert_eq!(labels[0], labels[1]);
3177 assert_eq!(labels[1], labels[2]);
3178 }
3179
3180 #[test]
3181 fn compare_reml_fits_delta_and_bayes_factor_never_contradict_winner_gh1465() {
3182 let cand = |name: &str, score: f64, edf: f64| RemlCandidate {
3194 index: 0,
3195 name: name.to_string(),
3196 score,
3197 edf: Some(edf),
3198 log_lik: Some(0.0),
3199 family: Some("gaussian".to_string()),
3200 n_obs: Some(100),
3201 };
3202 let candidates = vec![
3205 cand("m1", 53.748, 50.0),
3206 cand("m2", 41.605, 51.0),
3207 cand("m3", 120.011, 65.0),
3208 ];
3209 let cmp = compare_reml_fits(candidates).expect("comparison");
3210
3211 assert_eq!(cmp.winner, "m1", "AIC winner");
3212 for row in &cmp.ranking {
3214 assert!(
3215 row.delta >= 0.0,
3216 "ranking delta for {} must be >= 0, got {}",
3217 row.name,
3218 row.delta
3219 );
3220 assert!(
3221 row.bayes_factor >= 1.0 - 1e-12,
3222 "ranking bayes_factor for {} must be >= 1, got {}",
3223 row.name,
3224 row.bayes_factor
3225 );
3226 }
3227 let winner_row = cmp.ranking.iter().find(|r| r.name == "m1").unwrap();
3228 assert!(winner_row.delta.abs() < 1e-12, "winner delta == 0");
3229 assert!(
3230 (winner_row.bayes_factor - 1.0).abs() < 1e-9,
3231 "winner bayes_factor == 1"
3232 );
3233
3234 for row in &cmp.score_table {
3237 assert!(
3238 row.delta_reml >= 0.0,
3239 "score-table delta_reml for {} must be >= 0, got {}",
3240 row.name,
3241 row.delta_reml
3242 );
3243 assert!(
3244 row.bayes_factor_best_over_model >= 1.0 - 1e-12,
3245 "score-table bayes_factor for {} must be >= 1, got {}",
3246 row.name,
3247 row.bayes_factor_best_over_model
3248 );
3249 }
3250 let m2 = cmp.score_table.iter().find(|r| r.name == "m2").unwrap();
3252 assert!(
3253 m2.delta_reml.abs() < 1e-12,
3254 "the minimum-raw-REML row has delta_reml 0"
3255 );
3256 }
3257
3258 #[test]
3259 fn cone_of_influence_empty_support_is_empty() {
3260 let labels = vec![0usize, 0, 1, 1];
3261 assert!(cone_of_influence(&labels, &[]).is_empty());
3262 }
3263
3264 #[test]
3265 fn cone_of_influence_returns_full_component() {
3266 let labels = vec![0usize, 0, 1, 1];
3267 assert_eq!(cone_of_influence(&labels, &[0]), vec![0, 1]);
3269 assert_eq!(cone_of_influence(&labels, &[1, 2]), vec![0, 1, 2, 3]);
3271 }
3272
3273 #[test]
3274 fn coned_matches_full_solve_on_fully_coupled_hessian() {
3275 let h = Array2::from_shape_vec((3, 3), vec![4.0, 1.0, 0.5, 1.0, 3.0, 0.8, 0.5, 0.8, 2.5])
3278 .unwrap();
3279 let inv = dense_inverse(&h);
3280 let mut dg = Array2::<f64>::zeros((3, 2));
3282 dg[[0, 0]] = 1.3;
3283 dg[[2, 1]] = -0.7;
3284 let supports = vec![0..1usize, 2..3usize];
3285
3286 let eye: Array2<f64> = Array2::eye(3);
3287 let op = crate::sensitivity::FitSensitivity::from_projected(&eye, &inv);
3288 let full = op.mode_response(dg.view()).unwrap();
3289 let coned = op
3290 .mode_response_coned(h.view(), dg.view(), &supports)
3291 .unwrap();
3292 for i in 0..3 {
3293 for a in 0..2 {
3294 assert!(
3295 (full[[i, a]] - coned[[i, a]]).abs() < 1e-12,
3296 "fully-coupled mismatch at ({i},{a}): {} vs {}",
3297 full[[i, a]],
3298 coned[[i, a]]
3299 );
3300 }
3301 }
3302 }
3303
3304 #[test]
3305 fn coned_confines_to_component_on_decoupled_hessian() {
3306 let mut h = Array2::<f64>::zeros((4, 4));
3310 h[[0, 0]] = 4.0;
3312 h[[1, 1]] = 3.0;
3313 h[[0, 1]] = 1.0;
3314 h[[1, 0]] = 1.0;
3315 h[[2, 2]] = 2.0;
3317 h[[3, 3]] = 5.0;
3318 h[[2, 3]] = 0.6;
3319 h[[3, 2]] = 0.6;
3320 let inv = dense_inverse(&h);
3321
3322 let mut dg = Array2::<f64>::zeros((4, 1));
3323 dg[[0, 0]] = 0.9;
3324 dg[[1, 0]] = -0.4;
3325 let support_range = 0..2usize;
3326 let supports = std::slice::from_ref(&support_range);
3327
3328 let eye: Array2<f64> = Array2::eye(4);
3329 let coned = crate::sensitivity::FitSensitivity::from_projected(&eye, &inv)
3330 .mode_response_coned(h.view(), dg.view(), supports)
3331 .unwrap();
3332 let q = dg.column(0).to_owned();
3335 let exact = inv.dot(&q).mapv(|v| -v);
3336 for i in 0..4 {
3337 assert!(
3338 (coned[[i, 0]] - exact[[i]]).abs() < 1e-12,
3339 "decoupled mismatch at {i}: {} vs {}",
3340 coned[[i, 0]],
3341 exact[[i]]
3342 );
3343 }
3344 assert_eq!(coned[[2, 0]], 0.0);
3346 assert_eq!(coned[[3, 0]], 0.0);
3347 }
3348
3349 #[test]
3350 fn coned_skips_inactive_column_with_empty_support() {
3351 let h = Array2::<f64>::eye(2);
3352 let dg = Array2::<f64>::zeros((2, 1));
3353 let empty_support = 0..0usize;
3355 let supports = std::slice::from_ref(&empty_support);
3356 let eye: Array2<f64> = Array2::eye(2);
3361 let nan_inv = Array2::<f64>::from_elem((2, 2), f64::NAN);
3362 let coned = crate::sensitivity::FitSensitivity::from_projected(&eye, &nan_inv)
3363 .mode_response_coned(h.view(), dg.view(), supports)
3364 .unwrap();
3365 assert_eq!(coned[[0, 0]], 0.0);
3366 assert_eq!(coned[[1, 0]], 0.0);
3367 }
3368
3369 fn make_minimal_cache() -> ArrowFactorCache {
3370 let l_huu = Array2::from_shape_vec((1, 1), vec![std::f64::consts::SQRT_2]).unwrap();
3373 let l_schur = Array2::from_shape_vec((1, 1), vec![(1.875_f64).sqrt()]).unwrap();
3374 let htbeta = Array2::from_shape_vec((1, 1), vec![0.5]).unwrap();
3375 ArrowFactorCache {
3376 htt_factors: ArrowFactorSlab::from_blocks(vec![l_huu]),
3377 htt_factors_undamped: crate::arrow_schur::ArrowUndampedFactors::SameAsDamped,
3378 schur_factor: Some(l_schur),
3379 joint_hessian_log_det: None,
3380 solver_mode: crate::arrow_schur::ArrowSolverMode::Direct,
3381 ridge_t: 0.0,
3382 ridge_beta: 0.0,
3383 htbeta: crate::arrow_schur::ArrowHtbetaCache::Dense {
3384 blocks: std::sync::Arc::from(vec![htbeta]),
3385 estimated_bytes: std::mem::size_of::<f64>(),
3386 },
3387 d: 1,
3388 row_dims: std::sync::Arc::from(vec![1usize]),
3389 row_offsets: std::sync::Arc::from(vec![0usize, 1usize]),
3390 k: 1,
3391 manifold_mode_fingerprint: 0,
3392 row_hessian_fingerprint: 0,
3393 pcg_diagnostics: crate::arrow_schur::PcgDiagnostics::default(),
3394 gauge_deflated_directions: 0,
3395 deflated_row_directions: std::sync::Arc::from(Vec::new()),
3396 deflation_row_spectra: std::sync::Arc::from(Vec::new()),
3397 cross_row_woodbury: None,
3398 }
3399 }
3400
3401 #[test]
3402 fn laplace_evidence_returns_finite_for_minimal_cache() {
3403 let cache = make_minimal_cache();
3404 let v = laplace_evidence(
3407 EvidenceLogDetSource::FactoredArrow {
3408 cache: &cache,
3409 fallback_hvp: None,
3410 },
3411 0.0,
3412 0.0,
3413 2.0,
3414 1.0,
3415 );
3416 assert!(v.is_finite());
3417 let expected =
3418 0.5 * (2.0_f64.ln() + 1.875_f64.ln()) - 0.5 * (2.0 * std::f64::consts::PI).ln();
3419 assert!((v - expected).abs() < 1e-12);
3420 }
3421
3422 fn k0_direct_cache_no_schur(latent_diag: f64) -> ArrowFactorCache {
3431 let l_huu = Array2::from_shape_vec((1, 1), vec![latent_diag.sqrt()]).unwrap();
3432 ArrowFactorCache {
3433 htt_factors: ArrowFactorSlab::from_blocks(vec![l_huu]),
3434 htt_factors_undamped: crate::arrow_schur::ArrowUndampedFactors::SameAsDamped,
3435 schur_factor: None,
3436 joint_hessian_log_det: None,
3437 solver_mode: crate::arrow_schur::ArrowSolverMode::Direct,
3438 ridge_t: 0.0,
3439 ridge_beta: 0.0,
3440 htbeta: crate::arrow_schur::ArrowHtbetaCache::Disabled { estimated_bytes: 0 },
3441 d: 1,
3442 row_dims: std::sync::Arc::from(vec![1usize]),
3443 row_offsets: std::sync::Arc::from(vec![0usize, 1usize]),
3444 k: 0,
3445 manifold_mode_fingerprint: 0,
3446 row_hessian_fingerprint: 0,
3447 pcg_diagnostics: crate::arrow_schur::PcgDiagnostics::default(),
3448 gauge_deflated_directions: 0,
3449 deflated_row_directions: std::sync::Arc::from(Vec::new()),
3450 deflation_row_spectra: std::sync::Arc::from(Vec::new()),
3451 cross_row_woodbury: None,
3452 }
3453 }
3454
3455 #[test]
3456 fn arrow_log_det_some_for_k0_direct_cache_without_schur() {
3457 let cache = k0_direct_cache_no_schur(3.0);
3458 let log_det = arrow_log_det_from_cache(&cache)
3459 .expect("k==0 Direct cache must yield Some(per-row sum), not None (#1132)");
3460 assert!(
3462 (log_det - 3.0_f64.ln()).abs() < 1e-12,
3463 "log_det = {log_det}"
3464 );
3465 let cached = cache
3467 .compute_undamped_arrow_log_det()
3468 .expect("compute_undamped_arrow_log_det must be Some for k==0");
3469 assert!((cached - 3.0_f64.ln()).abs() < 1e-12, "cached = {cached}");
3470 }
3471
3472 #[test]
3473 fn arrow_log_det_none_for_kpos_cache_without_schur() {
3474 let mut cache = k0_direct_cache_no_schur(3.0);
3477 cache.k = 1;
3478 cache.solver_mode = crate::arrow_schur::ArrowSolverMode::InexactPCG;
3479 assert!(arrow_log_det_from_cache(&cache).is_none());
3480 assert!(cache.compute_undamped_arrow_log_det().is_none());
3481 }
3482
3483 #[test]
3484 fn laplace_evidence_nan_when_ridge_is_nonzero() {
3485 let mut cache = make_minimal_cache();
3486 cache.ridge_t = 1e-3;
3487 assert!(
3488 laplace_evidence(
3489 EvidenceLogDetSource::FactoredArrow {
3490 cache: &cache,
3491 fallback_hvp: None,
3492 },
3493 0.0,
3494 0.0,
3495 2.0,
3496 1.0,
3497 )
3498 .is_nan()
3499 );
3500 }
3501
3502 #[test]
3503 fn laplace_evidence_uses_hvp_fallback_without_schur_factor() {
3504 let mut cache = make_minimal_cache();
3505 cache.schur_factor = None;
3506 let hvp = |x: &[f64]| -> Vec<f64> { vec![2.0 * x[0], 1.875 * x[1]] };
3507 let v = laplace_evidence(
3508 EvidenceLogDetSource::FactoredArrow {
3509 cache: &cache,
3510 fallback_hvp: Some(EvidenceHvpLogDet {
3511 dim: 2,
3512 apply: &hvp,
3513 }),
3514 },
3515 0.0,
3516 0.0,
3517 2.0,
3518 1.0,
3519 );
3520 let expected =
3521 0.5 * (2.0_f64.ln() + 1.875_f64.ln()) - 0.5 * (2.0 * std::f64::consts::PI).ln();
3522 assert!((v - expected).abs() < 1e-12);
3523 }
3524
3525 #[test]
3526 fn ift_du_dbeta_has_expected_shape() {
3527 let cache = make_minimal_cache();
3528 let du_db = ift_du_dbeta(&cache);
3529 assert_eq!(du_db.shape(), &[1, 1]);
3530 assert!((du_db[[0, 0]] - (-0.25)).abs() < 1e-12);
3532 }
3533
3534 #[test]
3535 fn ift_dbeta_drho_returns_some_for_direct_cache() {
3536 let cache = make_minimal_cache();
3537 let q = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
3538 let out = ift_dbeta_drho(&cache, q.view()).unwrap();
3539 assert_eq!(out.shape(), &[1, 1]);
3540 assert!((out[[0, 0]] + 1.0 / 1.875).abs() < 1e-12);
3542 }
3543
3544 #[test]
3545 fn topology_select_picks_lowest_negative_log_evidence() {
3546 let candidates = vec![
3547 TopologyCandidate {
3548 kind: TopologyKind::Flat,
3549 negative_log_evidence: 10.0,
3550 effective_dim: 4.0,
3551 n_obs: 100,
3552 converged: true,
3553 exclusion_reason: None,
3554 },
3555 TopologyCandidate {
3556 kind: TopologyKind::Sphere,
3557 negative_log_evidence: 8.0,
3558 effective_dim: 5.0,
3559 n_obs: 100,
3560 converged: true,
3561 exclusion_reason: None,
3562 },
3563 TopologyCandidate {
3564 kind: TopologyKind::Torus,
3565 negative_log_evidence: f64::NAN,
3566 effective_dim: 6.0,
3567 n_obs: 100,
3568 converged: false,
3569 exclusion_reason: Some("torus periods missing".to_string()),
3570 },
3571 ];
3572 let sel = select_topology(&candidates, TopologySelectOptions::default());
3573 assert_eq!(sel.winner, TopologyKind::Sphere);
3574 assert!(!sel.tie);
3575 }
3576
3577 #[test]
3578 fn topology_select_tie_breaks_to_simpler() {
3579 let candidates = vec![
3580 TopologyCandidate {
3581 kind: TopologyKind::Sphere,
3582 negative_log_evidence: 5.0,
3583 effective_dim: 5.0,
3584 n_obs: 100,
3585 converged: true,
3586 exclusion_reason: None,
3587 },
3588 TopologyCandidate {
3589 kind: TopologyKind::Flat,
3590 negative_log_evidence: 5.0 + 1e-6,
3591 effective_dim: 4.0,
3592 n_obs: 100,
3593 converged: true,
3594 exclusion_reason: None,
3595 },
3596 ];
3597 let sel = select_topology(&candidates, TopologySelectOptions::default());
3598 assert_eq!(sel.winner, TopologyKind::Flat);
3599 assert!(sel.tie);
3600 }
3601
3602 fn gaussian_logpdf(y: f64, mean: f64, sd: f64) -> f64 {
3603 let z = (y - mean) / sd;
3604 -0.5 * (2.0 * std::f64::consts::PI).ln() - sd.ln() - 0.5 * z * z
3605 }
3606
3607 #[test]
3608 fn stacking_single_candidate_gets_full_weight() {
3609 let log_density = Array2::from_shape_vec((3, 1), vec![-1.0, -2.0, -0.5]).unwrap();
3610 let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
3611 assert!((out.weights[0] - 1.0).abs() < 1e-12);
3612 assert_eq!(out.weights.len(), 1);
3613 }
3614
3615 #[test]
3616 fn stacking_dominant_candidate_attracts_nearly_all_weight() {
3617 let mut log_density = Array2::<f64>::zeros((50, 2));
3618 for i in 0..50 {
3619 log_density[[i, 0]] = -0.1;
3620 log_density[[i, 1]] = -5.0;
3621 }
3622 let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
3623 assert!(out.weights[0] > 0.99, "w0 = {}", out.weights[0]);
3624 assert!(out.weights[1] < 0.01, "w1 = {}", out.weights[1]);
3625 }
3626
3627 #[test]
3628 fn stacking_complementary_candidates_share_weight() {
3629 let n = 40;
3632 let mut log_density = Array2::<f64>::zeros((n, 2));
3633 for i in 0..n {
3634 if i < n / 2 {
3635 log_density[[i, 0]] = gaussian_logpdf(0.0, 0.0, 0.5);
3636 log_density[[i, 1]] = gaussian_logpdf(0.0, 1.5, 0.5);
3637 } else {
3638 log_density[[i, 0]] = gaussian_logpdf(0.0, 1.5, 0.5);
3639 log_density[[i, 1]] = gaussian_logpdf(0.0, 0.0, 0.5);
3640 }
3641 }
3642 let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
3643 assert!(
3644 out.weights[0] > 0.2 && out.weights[0] < 0.8,
3645 "w0 = {}",
3646 out.weights[0]
3647 );
3648 assert!((out.weights.sum() - 1.0).abs() < 1e-9);
3649 }
3650
3651 #[test]
3652 fn stacking_weights_stay_on_the_simplex() {
3653 let log_density = Array2::from_shape_vec(
3654 (3, 3),
3655 vec![-1.0, -2.0, -3.0, -2.5, -1.0, -2.0, -3.0, -2.0, -1.0],
3656 )
3657 .unwrap();
3658 let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
3659 assert!((out.weights.sum() - 1.0).abs() < 1e-9);
3660 assert!(out.weights.iter().all(|&w| w >= -1e-12));
3661 }
3662
3663 #[test]
3664 fn stacking_mean_log_score_is_monotone_under_more_iterations() {
3665 let log_density =
3668 Array2::from_shape_vec((4, 2), vec![-0.2, -3.0, -3.0, -0.2, -0.5, -1.5, -1.5, -0.5])
3669 .unwrap();
3670 let mut prev = f64::NEG_INFINITY;
3671 for max_iter in [1usize, 2, 4, 8, 32] {
3672 let out = solve_stacking_weights(
3673 log_density.view(),
3674 StackingConfig {
3675 max_iter,
3676 weight_tol: 0.0,
3677 },
3678 )
3679 .unwrap();
3680 assert!(
3681 out.mean_log_score >= prev - 1e-12,
3682 "log-score decreased at max_iter={max_iter}: {prev} -> {}",
3683 out.mean_log_score
3684 );
3685 prev = out.mean_log_score;
3686 }
3687 }
3688
3689 #[test]
3690 fn stacking_dead_candidate_column_is_rejected_and_zero_weighted() {
3691 let log_density = Array2::from_shape_vec(
3692 (3, 2),
3693 vec![
3694 -1.0,
3695 f64::NEG_INFINITY,
3696 -2.0,
3697 f64::NAN,
3698 -0.5,
3699 f64::NEG_INFINITY,
3700 ],
3701 )
3702 .unwrap();
3703 let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
3704 assert_eq!(out.weights[1], 0.0);
3705 assert!((out.weights[0] - 1.0).abs() < 1e-12);
3706 }
3707
3708 #[test]
3709 fn stacking_rows_with_no_finite_density_are_dropped() {
3710 let log_density = Array2::from_shape_vec(
3711 (3, 2),
3712 vec![-1.0, -2.0, f64::NAN, f64::NEG_INFINITY, -2.0, -1.0],
3713 )
3714 .unwrap();
3715 let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
3716 assert!((out.weights.sum() - 1.0).abs() < 1e-9);
3717 assert!(out.mean_log_score.is_finite());
3718 }
3719
3720 #[test]
3721 fn stacking_all_dead_table_errors() {
3722 let log_density = Array2::from_elem((2, 2), f64::NEG_INFINITY);
3723 assert!(solve_stacking_weights(log_density.view(), StackingConfig::default()).is_err());
3724 }
3725
3726 #[test]
3727 fn stacked_mean_is_weighted_combination() {
3728 let weights = Array1::from_vec(vec![0.25, 0.75]);
3729 let means = vec![
3730 Array1::from_vec(vec![1.0, 2.0, 3.0]),
3731 Array1::from_vec(vec![5.0, 6.0, 7.0]),
3732 ];
3733 let out = stacked_predictive_mean(&weights, &means).unwrap();
3734 assert!((out[0] - (0.25 * 1.0 + 0.75 * 5.0)).abs() < 1e-12);
3735 assert!((out[2] - (0.25 * 3.0 + 0.75 * 7.0)).abs() < 1e-12);
3736 }
3737
3738 #[test]
3739 fn stacked_mean_rejects_shape_mismatch() {
3740 let weights = Array1::from_vec(vec![0.5, 0.5]);
3741 let means = vec![
3742 Array1::from_vec(vec![1.0, 2.0]),
3743 Array1::from_vec(vec![3.0]),
3744 ];
3745 assert!(stacked_predictive_mean(&weights, &means).is_err());
3746 }
3747
3748 fn hybrid_slot(
3763 linear_nle: f64,
3764 p_linear: usize,
3765 latent_dim: usize,
3766 p_curved: usize,
3767 theta: f64,
3768 curved_loglik_gain: f64,
3769 ) -> Vec<HybridAtomCandidate> {
3770 let param_price =
3771 0.5 * (p_curved as f64 - p_linear as f64) * (2.0 * std::f64::consts::PI).ln();
3772 let curved_nle = linear_nle - curved_loglik_gain + param_price;
3773 vec![
3774 HybridAtomCandidate::linear(linear_nle, p_linear),
3775 HybridAtomCandidate::curved(latent_dim, curved_nle, p_curved, Some(theta)),
3776 ]
3777 }
3778
3779 #[test]
3780 fn hybrid_dominance_floor_selects_linear_when_turning_is_zero() {
3781 let slot = hybrid_slot(100.0, 2, 1, 5, 0.0, 0.0);
3786 let choice = select_hybrid_atom(&slot).unwrap();
3787 assert!(choice.param.is_linear());
3788 assert_eq!(choice.param, HybridAtomParam::Linear);
3789 assert!(choice.curved_turning.unwrap() <= HYBRID_LINEAR_TURNING_FLOOR);
3791 }
3792
3793 #[test]
3794 fn hybrid_selects_curved_when_turning_pays_for_itself() {
3795 let slot = hybrid_slot(100.0, 2, 1, 5, 2.0 * std::f64::consts::PI, 30.0);
3799 let choice = select_hybrid_atom(&slot).unwrap();
3800 assert_eq!(choice.param, HybridAtomParam::Curved { latent_dim: 1 });
3801 assert!(choice.curved_evidence_margin > 0.0);
3803 }
3804
3805 #[test]
3806 fn hybrid_keeps_linear_when_curvature_doesnt_pay_its_price() {
3807 let slot = hybrid_slot(100.0, 2, 1, 5, 0.05, 0.1);
3811 let choice = select_hybrid_atom(&slot).unwrap();
3812 assert!(choice.param.is_linear());
3813 assert!(choice.curved_evidence_margin <= 0.0);
3814 }
3815
3816 #[test]
3817 fn hybrid_tie_breaks_to_the_cheaper_linear_atom() {
3818 let theta = 0.5; let nle = 42.0;
3823 let slot = vec![
3824 HybridAtomCandidate::linear(nle, 2),
3825 HybridAtomCandidate::curved(1, nle, 5, Some(theta)),
3826 ];
3827 let choice = select_hybrid_atom(&slot).unwrap();
3828 assert!(choice.param.is_linear());
3829 assert_eq!(choice.num_parameters, 2);
3830 }
3831
3832 #[test]
3833 fn hybrid_split_reduces_to_pure_linear_when_all_features_are_straight() {
3834 let slots: Vec<Vec<HybridAtomCandidate>> = (0..6)
3838 .map(|i| hybrid_slot(50.0 + i as f64, 2, 1, 5, 0.0, 0.0))
3839 .collect();
3840 let split = select_hybrid_split(&slots).unwrap();
3841 assert!(split.is_pure_linear());
3842 assert_eq!(split.curved_atom_count, 0);
3843 assert_eq!(split.linear_atom_count(), 6);
3844 let pure_linear: f64 = (0..6).map(|i| 50.0 + i as f64).sum();
3846 assert!((split.total_negative_log_evidence - pure_linear).abs() < 1e-12);
3847 }
3848
3849 #[test]
3850 fn hybrid_split_reduces_to_pure_curved_when_every_feature_curves() {
3851 let slots: Vec<Vec<HybridAtomCandidate>> = (0..5)
3854 .map(|i| hybrid_slot(80.0 + i as f64, 2, 1, 5, 2.0 * std::f64::consts::PI, 40.0))
3855 .collect();
3856 let split = select_hybrid_split(&slots).unwrap();
3857 assert!(split.is_pure_curved());
3858 assert_eq!(split.curved_atom_count, 5);
3859 assert_eq!(split.linear_atom_count(), 0);
3860 }
3861
3862 #[test]
3863 fn hybrid_split_on_mixed_dictionary_picks_curved_for_circles_linear_for_directions() {
3864 let mut slots: Vec<Vec<HybridAtomCandidate>> = Vec::new();
3874 let mut pure_linear_baseline = 0.0_f64;
3875 for i in 0..3 {
3878 let linear_nle = 120.0 + 3.0 * i as f64;
3879 pure_linear_baseline += linear_nle;
3880 slots.push(hybrid_slot(
3881 linear_nle,
3882 2,
3883 1,
3884 5,
3885 2.0 * std::f64::consts::PI,
3886 35.0,
3887 ));
3888 }
3889 for i in 0..4 {
3892 let linear_nle = 90.0 + 2.0 * i as f64;
3893 pure_linear_baseline += linear_nle;
3894 slots.push(hybrid_slot(linear_nle, 2, 1, 5, 0.0, 0.0));
3895 }
3896
3897 let split = select_hybrid_split(&slots).unwrap();
3898
3899 for (idx, choice) in split.atoms.iter().enumerate() {
3902 if idx < 3 {
3903 assert_eq!(
3904 choice.param,
3905 HybridAtomParam::Curved { latent_dim: 1 },
3906 "circle slot {idx} should select curved"
3907 );
3908 } else {
3909 assert!(
3910 choice.param.is_linear(),
3911 "direction slot {idx} should select linear"
3912 );
3913 }
3914 }
3915 assert_eq!(split.curved_atom_count, 3);
3916 assert_eq!(split.linear_atom_count(), 4);
3917
3918 assert!(
3924 split.total_negative_log_evidence <= pure_linear_baseline + 1e-9,
3925 "hybrid NLE {} must be <= summed linear-candidate NLE {}",
3926 split.total_negative_log_evidence,
3927 pure_linear_baseline
3928 );
3929 assert!(split.total_negative_log_evidence < pure_linear_baseline);
3931 }
3932
3933 #[test]
3934 fn hybrid_split_rejects_empty_slot() {
3935 let slots = vec![hybrid_slot(10.0, 2, 1, 5, 0.0, 0.0), Vec::new()];
3936 assert!(select_hybrid_split(&slots).is_err());
3937 }
3938
3939 fn cand(name: &str, score: f64, edf: f64, log_lik: f64) -> RemlCandidate {
3947 RemlCandidate {
3948 index: 0,
3949 name: name.to_string(),
3950 score,
3951 edf: Some(edf),
3952 log_lik: Some(log_lik),
3953 family: None,
3954 n_obs: None,
3955 }
3956 }
3957
3958 #[test]
3959 fn ranking_score_is_conditional_aic_when_loglik_and_edf_present() {
3960 let c = cand("m", 999.0, 6.748, -32.0866);
3962 let expected = -2.0 * -32.0866 + 2.0 * 6.748;
3963 assert!((c.ranking_score() - expected).abs() < 1e-9);
3964 }
3965
3966 #[test]
3967 fn ranking_score_falls_back_to_evidence_without_loglik() {
3968 let c = RemlCandidate {
3969 index: 0,
3970 name: "m".to_string(),
3971 score: 151.28,
3972 edf: Some(6.0),
3973 log_lik: None,
3974 family: None,
3975 n_obs: None,
3976 };
3977 assert_eq!(c.ranking_score(), 151.28);
3978 }
3979
3980 #[test]
3981 fn compare_models_rejects_pure_noise_smooth_despite_lower_evidence() {
3982 let small = cand("small", 180.526, 6.748, -32.0866);
3989 let big = cand("big", 177.404, 14.250, -32.1212);
3990
3991 assert!(big.score < small.score);
3993
3994 let cmp = compare_reml_fits(vec![small, big]).expect("compare");
3995 assert_eq!(
3996 cmp.winner, "small",
3997 "compare_models must Occam-penalise the pure-noise smooth and pick the smaller model"
3998 );
3999 let small_row = cmp
4002 .score_table
4003 .iter()
4004 .find(|r| r.name == "small")
4005 .expect("small row");
4006 let big_row = cmp
4007 .score_table
4008 .iter()
4009 .find(|r| r.name == "big")
4010 .expect("big row");
4011 assert!((small_row.reml_score - 180.526).abs() < 1e-9);
4012 assert!((big_row.reml_score - 177.404).abs() < 1e-9);
4013 }
4014
4015 #[test]
4016 fn compare_models_keeps_power_for_a_relevant_smooth() {
4017 let small = cand("small", 1025.067, 6.75, -368.985);
4023 let big = cand("big", 199.509, 14.25, -33.165);
4024 let cmp = compare_reml_fits(vec![small, big]).expect("compare");
4025 assert_eq!(
4026 cmp.winner, "big",
4027 "compare_models must retain power: the relevant smooth's model must win"
4028 );
4029 }
4030
4031 #[test]
4032 fn compare_models_rejects_mismatched_observation_counts() {
4033 let with_n = |name: &str, n: usize| RemlCandidate {
4037 index: 0,
4038 name: name.to_string(),
4039 score: 100.0,
4040 edf: Some(5.0),
4041 log_lik: Some(-40.0),
4042 family: Some("gaussian".to_string()),
4043 n_obs: Some(n),
4044 };
4045 let err = compare_reml_fits(vec![with_n("big", 500), with_n("small", 100)])
4046 .expect_err("cross-n comparison must be rejected");
4047 assert!(
4048 err.contains("number of observations") && err.contains("500") && err.contains("100"),
4049 "n-guard error should name the incomparable counts, got: {err}"
4050 );
4051
4052 compare_reml_fits(vec![with_n("a", 250), with_n("b", 250)])
4054 .expect("same-n comparison must succeed");
4055
4056 let without_n = RemlCandidate {
4059 index: 0,
4060 name: "legacy".to_string(),
4061 score: 90.0,
4062 edf: Some(4.0),
4063 log_lik: Some(-35.0),
4064 family: Some("gaussian".to_string()),
4065 n_obs: None,
4066 };
4067 compare_reml_fits(vec![with_n("counted", 500), without_n])
4068 .expect("an unconstrained (None) count must not trip the guard");
4069 }
4070}