use faer::Side;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::linalg::faer_ndarray::FaerEigh;
use crate::linalg::lanczos::{
SymmetricLanczosOptions, symmetric_lanczos_eigenpairs, symmetric_lanczos_log_quadrature,
};
use crate::linalg::triangular::cholesky_solve_vector;
use crate::solver::arrow_schur::{ArrowFactorCache, ArrowSchurSystem};
use crate::solver::priority_selection::{PriorityCandidate, rank_priority_candidates};
/// Largest operator dimension for which [`hessian_log_det_from_hvp`] densifies the HVP
/// (one `apply` call per basis vector) and takes an exact eigen-decomposition; larger
/// operators are routed to `stochastic_hvp_log_det`.
pub const ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD: usize = 1024;
// NOTE(review): the next two constants are not read anywhere in this chunk; by name they
// presumably configure the stochastic Lanczos quadrature (SLQ) logdet path — confirm.
/// Number of random probes for the stochastic logdet estimate.
const EVIDENCE_LOGDET_SLQ_PROBES: usize = 16;
/// Number of Lanczos steps per probe for the stochastic logdet estimate.
const EVIDENCE_LOGDET_LANCZOS_STEPS: usize = 32;
/// Largest accepted relative skew `||A - A^T||_F / max(||A||_F, 1)` of a densified HVP
/// (checked by `validate_dense_hvp_symmetry`).
const EVIDENCE_HVP_SYMMETRY_REL_TOL: f64 = 1e-8;
/// Number of probe pairs used by `validate_hvp_randomized_symmetry`.
const EVIDENCE_HVP_SYMMETRY_PROBES: usize = 4;
/// Matrix-free Hessian operator whose log-determinant is wanted.
#[derive(Clone, Copy)]
pub struct EvidenceHvpLogDet<'a> {
    /// Dimension `p` of the square operator.
    pub dim: usize,
    /// Hessian-vector product `x -> H x`. The dense path of `hessian_log_det_from_hvp`
    /// rejects outputs that are not length `dim` or contain non-finite entries, and it
    /// requires `H` to be symmetric positive definite.
    pub apply: &'a dyn Fn(&[f64]) -> Vec<f64>,
}
/// Where the evidence Hessian log-determinant comes from.
#[derive(Clone, Copy)]
pub enum EvidenceLogDetSource<'a> {
    /// Reuse an existing arrow/Schur factorisation. If `arrow_log_det_from_cache` yields
    /// nothing, fall back to `fallback_hvp`; that is an error when it is `None`.
    FactoredArrow {
        cache: &'a ArrowFactorCache,
        fallback_hvp: Option<EvidenceHvpLogDet<'a>>,
    },
    /// Compute the log-determinant from a Hessian-vector product only.
    Hvp(EvidenceHvpLogDet<'a>),
}
/// Topology families that topology selection can compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TopologyKind {
    Periodic,
    Flat,
    Sphere,
    Torus,
}
impl TopologyKind {
    /// Simplicity rank of the topology; smaller means simpler.
    ///
    /// `Flat` = 0, `Periodic` = 1, `Sphere` = 2, `Torus` = 3.
    pub fn complexity_rank(self) -> u8 {
        match self {
            Self::Periodic => 1,
            Self::Flat => 0,
            Self::Sphere => 2,
            Self::Torus => 3,
        }
    }
}
/// One scored topology hypothesis.
#[derive(Debug, Clone)]
pub struct TopologyCandidate {
    /// Topology family this candidate represents.
    pub kind: TopologyKind,
    /// Negative log evidence of the fit (presumably lower is better — confirm with the
    /// selector, which is not in this chunk).
    pub negative_log_evidence: f64,
    /// Effective dimension of the fit; presumably the divisor under
    /// `TopologyScoreScale::PerEffectiveDim`.
    pub effective_dim: f64,
    /// Number of observations; presumably the divisor under
    /// `TopologyScoreScale::PerObservation`.
    pub n_obs: usize,
    /// Whether the underlying fit reported convergence.
    pub converged: bool,
    /// Why the candidate was excluded from selection, if it was.
    pub exclusion_reason: Option<String>,
}
/// Outcome of topology selection.
#[derive(Debug, Clone)]
pub struct SelectedTopology {
    /// Selected topology.
    pub winner: TopologyKind,
    /// Candidates in ranked order.
    pub ranking: Vec<TopologyCandidate>,
    /// Whether the winner was within the tie tolerance of another candidate
    /// (NOTE(review): inferred from `TopologySelectOptions::tie_tolerance`; confirm).
    pub tie: bool,
}
/// Options for topology selection; the default is `tie_tolerance = 1e-3` with
/// per-observation scoring.
#[derive(Debug, Clone, Copy)]
pub struct TopologySelectOptions {
    /// Score difference below which candidates are treated as tied.
    pub tie_tolerance: f64,
    /// How negative log evidence is normalised before comparison.
    pub score_scale: TopologyScoreScale,
}
/// Normalisation applied to the negative log evidence before candidates are compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TopologyScoreScale {
    /// Scale per observation (presumably divide by `n_obs`).
    PerObservation,
    /// Scale per effective dimension (presumably divide by `effective_dim`).
    PerEffectiveDim,
}
/// Controls for the EM iteration in `solve_stacking_weights`.
#[derive(Debug, Clone, Copy)]
pub struct StackingConfig {
    /// Maximum number of EM sweeps over the held-out rows.
    pub max_iter: usize,
    /// Stop once the largest absolute change of any weight is at most this value.
    pub weight_tol: f64,
}
impl Default for StackingConfig {
    /// At most 1000 sweeps with a convergence tolerance of `1e-10`.
    fn default() -> Self {
        let weight_tol = 1e-10;
        let max_iter = 1000;
        StackingConfig {
            weight_tol,
            max_iter,
        }
    }
}
/// Result of `solve_stacking_weights`.
#[derive(Debug, Clone)]
pub struct StackingWeights {
    /// One weight per input column; columns without any finite density get `0.0`.
    pub weights: Array1<f64>,
    /// Mean held-out log score of the stacked mixture over the rows that contribute.
    pub mean_log_score: f64,
    /// Number of EM sweeps performed.
    pub iterations: usize,
}
/// Solve for stacking weights over candidate predictive densities.
///
/// `log_density[[i, k]]` is the held-out log predictive density of observation `i`
/// under candidate `k`. The weights maximise the mean log score
/// `mean_i ln(sum_k w_k exp(log_density[[i, k]]))` over the probability simplex using
/// the EM fixed point `w_k <- mean_i r_ik`, where `r_ik` is the responsibility of
/// candidate `k` for row `i`.
///
/// Non-finite entries are treated as missing:
/// * a column with no finite entry is dropped and reported with weight `0.0`;
/// * a row with no finite entry among the kept columns is ignored.
///
/// Iteration starts from uniform weights over the kept columns. It stops after
/// `config.max_iter` sweeps, or once the largest absolute weight change is at most
/// `config.weight_tol`.
///
/// # Errors
/// Returns `Err` when any of these holds:
/// * there are no columns or no rows;
/// * no column has a finite entry;
/// * no row has a finite entry in a kept column.
pub fn solve_stacking_weights(
    log_density: ArrayView2<'_, f64>,
    config: StackingConfig,
) -> Result<StackingWeights, String> {
    let n_obs = log_density.nrows();
    let n_cand = log_density.ncols();
    if n_cand == 0 {
        return Err("stacking requires at least one candidate column".to_string());
    }
    if n_obs == 0 {
        return Err("stacking requires at least one held-out observation row".to_string());
    }
    let kept_cols: Vec<usize> = (0..n_cand)
        .filter(|&k| (0..n_obs).any(|i| log_density[[i, k]].is_finite()))
        .collect();
    if kept_cols.is_empty() {
        return Err("stacking found no candidate with any finite held-out density".to_string());
    }
    let rows: Vec<usize> = (0..n_obs)
        .filter(|&i| kept_cols.iter().any(|&k| log_density[[i, k]].is_finite()))
        .collect();
    if rows.is_empty() {
        return Err("stacking found no held-out row with a finite density".to_string());
    }
    let kept = kept_cols.len();
    let mut weights = Array1::<f64>::from_elem(kept, 1.0 / kept as f64);
    let mut next = Array1::<f64>::zeros(kept);
    let mut iterations = 0usize;
    for _ in 0..config.max_iter {
        iterations += 1;
        // `ln(w)` is invariant within a sweep; take it once per weight instead of
        // several times per (row, column) pair.
        let log_weights = stacking_log_weights(weights.view());
        next.fill(0.0);
        let mut active_rows = 0usize;
        for &row in &rows {
            let Some(log_mix) =
                stacking_row_log_mixture(log_density, row, &kept_cols, &log_weights)
            else {
                continue;
            };
            active_rows += 1;
            // E-step: accumulate each candidate's responsibility for this row.
            for (local_col, &source_col) in kept_cols.iter().enumerate() {
                let log_p = log_density[[row, source_col]];
                let log_w = log_weights[local_col];
                if log_p.is_finite() && log_w > f64::NEG_INFINITY {
                    next[local_col] += (log_w + log_p - log_mix).exp();
                }
            }
        }
        if active_rows == 0 {
            break;
        }
        // M-step: mean responsibility, renormalised to absorb rounding drift.
        next.mapv_inplace(|value| value / active_rows as f64);
        let total = next.sum();
        if total > 0.0 {
            next.mapv_inplace(|value| value / total);
        }
        let delta = next
            .iter()
            .zip(weights.iter())
            .fold(0.0_f64, |acc, (a, b)| acc.max((a - b).abs()));
        weights.assign(&next);
        if delta <= config.weight_tol {
            break;
        }
    }
    let mean_log_score = stacking_mean_log_score(log_density, &rows, &kept_cols, weights.view());
    let mut full = Array1::<f64>::zeros(n_cand);
    for (local_col, &source_col) in kept_cols.iter().enumerate() {
        full[source_col] = weights[local_col];
    }
    Ok(StackingWeights {
        weights: full,
        mean_log_score,
        iterations,
    })
}
/// Mean over `rows` of the stacked log density under `weights`, or `-inf` when no row
/// has a contributing candidate.
fn stacking_mean_log_score(
    log_density: ArrayView2<'_, f64>,
    rows: &[usize],
    kept_cols: &[usize],
    weights: ArrayView1<'_, f64>,
) -> f64 {
    let log_weights = stacking_log_weights(weights);
    let (score_sum, counted) = rows
        .iter()
        .filter_map(|&row| stacking_row_log_mixture(log_density, row, kept_cols, &log_weights))
        .fold((0.0_f64, 0usize), |(sum, count), value| (sum + value, count + 1));
    if counted == 0 {
        f64::NEG_INFINITY
    } else {
        score_sum / counted as f64
    }
}
/// `ln(w)` for every weight. Non-positive or NaN weights map to `-inf`, so they drop
/// out of every mixture term.
fn stacking_log_weights(weights: ArrayView1<'_, f64>) -> Vec<f64> {
    weights
        .iter()
        .map(|&w| if w > 0.0 { w.ln() } else { f64::NEG_INFINITY })
        .collect()
}
/// Stable `ln(sum_k w_k exp(log_density[[row, kept_cols[k]]]))` over candidates with
/// positive weight and a finite density.
///
/// Returns `None` when no candidate contributes to this row.
fn stacking_row_log_mixture(
    log_density: ArrayView2<'_, f64>,
    row: usize,
    kept_cols: &[usize],
    log_weights: &[f64],
) -> Option<f64> {
    let mut row_max = f64::NEG_INFINITY;
    for (local_col, &source_col) in kept_cols.iter().enumerate() {
        let log_p = log_density[[row, source_col]];
        let log_w = log_weights[local_col];
        if log_p.is_finite() && log_w > f64::NEG_INFINITY {
            row_max = row_max.max(log_w + log_p);
        }
    }
    if !row_max.is_finite() {
        return None;
    }
    let mut denom = 0.0_f64;
    for (local_col, &source_col) in kept_cols.iter().enumerate() {
        let log_p = log_density[[row, source_col]];
        let log_w = log_weights[local_col];
        if log_p.is_finite() && log_w > f64::NEG_INFINITY {
            denom += (log_w + log_p - row_max).exp();
        }
    }
    (denom > 0.0).then(|| row_max + denom.ln())
}
/// Weighted sum `sum_k weights[k] * candidate_means[k]` of per-candidate predictive
/// mean vectors. Candidates with weight exactly `0.0` are skipped.
///
/// # Errors
/// Returns `Err` when any of these holds:
/// * the number of weights and mean vectors differ;
/// * there are no candidates;
/// * the mean vectors disagree on length.
pub fn stacked_predictive_mean(
    weights: &Array1<f64>,
    candidate_means: &[Array1<f64>],
) -> Result<Array1<f64>, String> {
    if weights.len() != candidate_means.len() {
        return Err(format!(
            "stacked_predictive_mean: {} weights but {} candidate mean vectors",
            weights.len(),
            candidate_means.len()
        ));
    }
    let n_rows = match candidate_means.first() {
        Some(first) => first.len(),
        None => {
            return Err("stacked_predictive_mean requires at least one candidate".to_string())
        }
    };
    let lengths_agree = candidate_means.iter().all(|means| means.len() == n_rows);
    if !lengths_agree {
        return Err(
            "stacked_predictive_mean: candidate mean vectors disagree on row count".to_string(),
        );
    }
    let mut combined = Array1::<f64>::zeros(n_rows);
    let contributing = weights
        .iter()
        .zip(candidate_means)
        .filter(|(weight, _)| **weight != 0.0);
    for (&weight, means) in contributing {
        combined.scaled_add(weight, means);
    }
    Ok(combined)
}
/// Controls for `fit_gaussian_mixture` (EM with k-means seeding).
#[derive(Debug, Clone, Copy)]
pub struct GaussianMixtureConfig {
    /// Maximum EM iterations (at least one is always run).
    pub max_iter: usize,
    /// Relative tolerance on the change of mean per-row log-likelihood.
    pub loglik_tol: f64,
    /// Amount added to every covariance diagonal, and the floor for circle radial variance.
    pub covariance_floor: f64,
    /// Iteration budget for the k-means seeding step.
    pub kmeans_max_iter: usize,
}
impl Default for GaussianMixtureConfig {
    /// 200 EM iterations, tolerance `1e-7`, floor `1e-6`, 25 k-means iterations.
    fn default() -> Self {
        GaussianMixtureConfig {
            kmeans_max_iter: 25,
            covariance_floor: 1e-6,
            loglik_tol: 1e-7,
            max_iter: 200,
        }
    }
}
/// Fitted full-covariance Gaussian mixture, as produced by `fit_gaussian_mixture`.
#[derive(Debug, Clone)]
pub struct GaussianMixtureFit {
    /// Mixing weights, length `k`.
    pub weights: Array1<f64>,
    /// Component means, one row per component (`k x d`).
    pub means: Array2<f64>,
    /// Component covariances, each `d x d`.
    pub covariances: Vec<Array2<f64>>,
    /// Number of components.
    pub k: usize,
    /// Data dimension (number of columns).
    pub d: usize,
    /// Number of rows the mixture was fitted to.
    pub n_obs: usize,
    /// Total log-likelihood of the fitting data.
    pub loglik: f64,
    /// Number of EM iterations performed.
    pub iterations: usize,
}
impl GaussianMixtureFit {
    /// Number of free parameters.
    ///
    /// These are `k - 1` mixing weights, `k * d` mean entries and `k * d (d + 1) / 2`
    /// lower-triangular covariance entries, matching the layout of
    /// `empirical_fisher_information`.
    ///
    /// NOTE(review): `self.k - 1` underflows for `k == 0`. `fit_gaussian_mixture` rejects
    /// `k == 0`, but the fields are public, so a hand-built fit could hit this.
    pub fn num_free_parameters(&self) -> usize {
        let cov_per = self.d * (self.d + 1) / 2;
        (self.k - 1) + self.k * self.d + self.k * cov_per
    }
    /// Log mixture density `ln sum_j w_j N(y | mu_j, Sigma_j)` of every row `y` of `data`.
    ///
    /// # Errors
    /// Returns `Err` if `data` does not have `self.d` columns, or if a component
    /// covariance is not symmetric positive definite.
    pub fn per_point_log_density(&self, data: ArrayView2<'_, f64>) -> Result<Array1<f64>, String> {
        if data.ncols() != self.d {
            return Err(format!(
                "mixture log-density expects {} columns, got {}",
                self.d,
                data.ncols()
            ));
        }
        let n = data.nrows();
        // Placeholders only; every slot is overwritten by its factored component below.
        let mut comp = vec![GaussianComponentEval::new(self.d); self.k];
        for j in 0..self.k {
            comp[j] = GaussianComponentEval::factor(self.means.row(j), &self.covariances[j])?;
        }
        let mut out = Array1::<f64>::zeros(n);
        // Clamp to the smallest positive normal, so a zero weight gives a very negative
        // but finite log-weight instead of `-inf`.
        let log_w: Vec<f64> = self
            .weights
            .iter()
            .map(|w| w.max(f64::MIN_POSITIVE).ln())
            .collect();
        for i in 0..n {
            let row = data.row(i);
            let mut log_terms = vec![f64::NEG_INFINITY; self.k];
            let mut max_term = f64::NEG_INFINITY;
            for j in 0..self.k {
                let lt = log_w[j] + comp[j].log_density(row);
                log_terms[j] = lt;
                if lt > max_term {
                    max_term = lt;
                }
            }
            out[i] = log_sum_exp(&log_terms, max_term);
        }
        Ok(out)
    }
    /// Laplace-approximate negative log evidence of this fit.
    ///
    /// The Hessian is the empirical Fisher information on `data` (per-row score outer
    /// products plus a unit diagonal). The data term is `-self.loglik`, with no penalty
    /// log-determinant and all `p` parameters treated as unpenalised.
    ///
    /// NOTE(review): the data term comes from `self.loglik`, not from `data`. `data`
    /// presumably has to be the data the mixture was fitted on — confirm at call sites.
    ///
    /// # Errors
    /// Returns `Err` on any of these:
    /// * a column-count mismatch;
    /// * a non-SPD component covariance;
    /// * an information/parameter dimension mismatch;
    /// * a non-finite evidence value.
    pub fn laplace_negative_log_evidence(&self, data: ArrayView2<'_, f64>) -> Result<f64, String> {
        let p = self.num_free_parameters();
        let information = self.empirical_fisher_information(data)?;
        if information.nrows() != p {
            return Err(format!(
                "mixture empirical-Fisher information has dim {} but expected free-parameter count {p}",
                information.nrows()
            ));
        }
        // Dense matrix-vector product exposing the information matrix as an HVP.
        let apply_info = |x: &[f64]| -> Vec<f64> {
            let mut out = vec![0.0_f64; p];
            for r in 0..p {
                let mut acc = 0.0_f64;
                for c in 0..p {
                    acc += information[[r, c]] * x[c];
                }
                out[r] = acc;
            }
            out
        };
        let hvp = EvidenceHvpLogDet {
            dim: p,
            apply: &apply_info,
        };
        let v = laplace_evidence(
            EvidenceLogDetSource::Hvp(hvp),
            0.0,
            -self.loglik,
            p as f64,
            0.0,
        );
        if !v.is_finite() {
            return Err("mixture Laplace evidence is not finite".to_string());
        }
        Ok(v)
    }
    /// Empirical Fisher information `sum_i s_i s_i^T + I` over the rows of `data`.
    ///
    /// Parameter layout of the score `s_i` (length `num_free_parameters()`):
    /// * `[0, k-1)`: weight parameters for components `1..k`. The score `r_j - w_j` is the
    ///   gradient for log-ratio parameters `ln(w_j / w_0)`.
    /// * `[k-1, k-1 + k d)`: component means, `d` per component. The score is
    ///   `r_j Sigma_j^{-1} (y - mu_j)`.
    /// * the rest: lower-triangular covariance entries (`a >= b`, row by row), `d(d+1)/2`
    ///   per component. The score is `r_j * 0.5 * (P v v^T P - P)_{ab}` with
    ///   `P = Sigma_j^{-1}`, `v = y - mu_j`, doubled off the diagonal because one
    ///   parameter stands for both symmetric entries.
    ///
    /// `r_j` is the posterior responsibility of component `j` for the row.
    fn empirical_fisher_information(
        &self,
        data: ArrayView2<'_, f64>,
    ) -> Result<Array2<f64>, String> {
        if data.ncols() != self.d {
            return Err(format!(
                "mixture information expects {} columns, got {}",
                self.d,
                data.ncols()
            ));
        }
        let n = data.nrows();
        let p = self.num_free_parameters();
        let cov_per = self.d * (self.d + 1) / 2;
        let mut comp = Vec::with_capacity(self.k);
        for j in 0..self.k {
            comp.push(GaussianComponentEval::factor(
                self.means.row(j),
                &self.covariances[j],
            )?);
        }
        let log_w: Vec<f64> = self
            .weights
            .iter()
            .map(|w| w.max(f64::MIN_POSITIVE).ln())
            .collect();
        // Offsets of the mean block and the covariance block in the parameter vector.
        let mean_base = self.k - 1;
        let cov_base = mean_base + self.k * self.d;
        let mut info = Array2::<f64>::zeros((p, p));
        let mut score = vec![0.0_f64; p];
        for i in 0..n {
            let row = data.row(i);
            let mut log_terms = vec![0.0_f64; self.k];
            let mut max_term = f64::NEG_INFINITY;
            for j in 0..self.k {
                let lt = log_w[j] + comp[j].log_density(row);
                log_terms[j] = lt;
                if lt > max_term {
                    max_term = lt;
                }
            }
            let log_mix = log_sum_exp(&log_terms, max_term);
            // Posterior responsibilities r_j of each component for this row.
            let resp: Vec<f64> = log_terms.iter().map(|lt| (lt - log_mix).exp()).collect();
            for s in score.iter_mut() {
                *s = 0.0;
            }
            // Weight block: component 0 is the reference and has no parameter.
            for j in 1..self.k {
                score[j - 1] = resp[j] - self.weights[j];
            }
            for j in 0..self.k {
                // prec_v = Sigma_j^{-1} (y - mu_j); mbo = start of component j's means.
                let prec_v = comp[j].precision_times_residual(row); let mbo = mean_base + j * self.d;
                for c in 0..self.d {
                    score[mbo + c] = resp[j] * prec_v[c];
                }
                let cbo = cov_base + j * cov_per;
                let mut idx = 0usize;
                for a in 0..self.d {
                    for b in 0..=a {
                        let outer = prec_v[a] * prec_v[b];
                        let prec_ab = comp[j].precision[[a, b]];
                        let mut g = 0.5 * (outer - prec_ab);
                        // Off-diagonal parameters stand for both (a, b) and (b, a).
                        if a != b {
                            g *= 2.0;
                        }
                        score[cbo + idx] = resp[j] * g;
                        idx += 1;
                    }
                }
            }
            // Rank-one update info += s s^T; rows with a zero score entry add nothing.
            for r in 0..p {
                let sr = score[r];
                if sr == 0.0 {
                    continue;
                }
                for c in 0..p {
                    info[[r, c]] += sr * score[c];
                }
            }
        }
        // Enforce exact symmetry and add a unit ridge so the matrix is positive definite.
        // NOTE(review): the ridge presumably plays the role of a unit-precision prior — confirm.
        for r in 0..p {
            for c in (r + 1)..p {
                let avg = 0.5 * (info[[r, c]] + info[[c, r]]);
                info[[r, c]] = avg;
                info[[c, r]] = avg;
            }
            info[[r, r]] += 1.0;
        }
        Ok(info)
    }
}
/// Pre-factored multivariate Gaussian used for repeated density evaluations.
#[derive(Debug, Clone)]
struct GaussianComponentEval {
    /// Component mean, length `d`.
    mean: Array1<f64>,
    /// Inverse covariance, rebuilt from the covariance's eigendecomposition in `factor`.
    precision: Array2<f64>,
    /// `-0.5 * (d ln(2 pi) + ln det Sigma)`.
    log_norm: f64,
    /// Dimension.
    d: usize,
}
impl GaussianComponentEval {
    /// Placeholder component: zero mean, identity precision, zero log-normaliser.
    fn new(d: usize) -> Self {
        Self {
            mean: Array1::zeros(d),
            precision: Array2::eye(d),
            log_norm: 0.0,
            d,
        }
    }
    /// Factor the covariance `cov` of a Gaussian with the given `mean`.
    ///
    /// Uses a symmetric eigendecomposition `cov = V diag(lambda) V^T` to form the
    /// precision `V diag(1/lambda) V^T` and the log normalising constant.
    ///
    /// # Errors
    /// Returns `Err` in any of these cases:
    /// * `cov` is not `d x d`;
    /// * the eigendecomposition fails;
    /// * an eigenvalue is non-finite or non-positive.
    fn factor(mean: ArrayView1<'_, f64>, cov: &Array2<f64>) -> Result<Self, String> {
        let d = mean.len();
        let square_of_d = cov.nrows() == d && cov.ncols() == d;
        if !square_of_d {
            return Err(format!(
                "mixture component covariance must be {d}x{d}, got {}x{}",
                cov.nrows(),
                cov.ncols()
            ));
        }
        let (eigenvalues, eigenvectors) = cov
            .eigh(Side::Lower)
            .map_err(|e| format!("mixture component covariance eigendecomposition failed: {e}"))?;
        let mut reciprocal = Array1::<f64>::zeros(d);
        let mut log_det = 0.0_f64;
        for (idx, &ev) in eigenvalues.iter().enumerate() {
            if !(ev.is_finite() && ev > 0.0) {
                return Err(format!(
                    "mixture component covariance is not SPD: eigenvalue {idx} is {ev:.3e}"
                ));
            }
            log_det += ev.ln();
            reciprocal[idx] = 1.0 / ev;
        }
        let precision = Array2::<f64>::from_shape_fn((d, d), |(a, b)| {
            (0..d).fold(0.0_f64, |acc, m| {
                acc + eigenvectors[[a, m]] * reciprocal[m] * eigenvectors[[b, m]]
            })
        });
        let log_norm = -0.5 * (d as f64 * (2.0 * std::f64::consts::PI).ln() + log_det);
        Ok(Self {
            mean: mean.to_owned(),
            precision,
            log_norm,
            d,
        })
    }
    /// Log density `log_norm - 0.5 (y - mu)^T P (y - mu)` at `y`.
    #[inline]
    fn log_density(&self, y: ArrayView1<'_, f64>) -> f64 {
        let weighted = self.precision_times_residual(y);
        let quad = weighted
            .iter()
            .enumerate()
            .fold(0.0_f64, |acc, (c, &w)| acc + (y[c] - self.mean[c]) * w);
        self.log_norm - 0.5 * quad
    }
    /// `P (y - mu)` as a freshly allocated vector of length `d`.
    #[inline]
    fn precision_times_residual(&self, y: ArrayView1<'_, f64>) -> Vec<f64> {
        (0..self.d)
            .map(|a| {
                (0..self.d).fold(0.0_f64, |acc, b| {
                    acc + self.precision[[a, b]] * (y[b] - self.mean[b])
                })
            })
            .collect()
    }
}
/// Numerically stable `ln(sum_i exp(terms[i]))`, given `max_term = max_i terms[i]`.
///
/// Every term is shifted by `max_term` before exponentiating, so the largest summand is
/// `exp(0) = 1` and nothing overflows. A non-finite `max_term` is returned unchanged:
/// * `-inf` (empty input, or every term `-inf`) yields `-inf`;
/// * `+inf` yields `+inf`;
/// * `NaN` propagates.
///
/// Callers must pass the true maximum; a smaller value risks overflow.
#[inline]
fn log_sum_exp(terms: &[f64], max_term: f64) -> f64 {
    if !max_term.is_finite() {
        // Shifting by an infinite maximum would compute `inf - inf = NaN`. The limit of
        // the log-sum-exp is the maximum itself, which also keeps `-inf` for empty input.
        return max_term;
    }
    let acc = terms
        .iter()
        .fold(0.0_f64, |acc, &t| acc + (t - max_term).exp());
    max_term + acc.ln()
}
pub fn fit_gaussian_mixture(
data: ArrayView2<'_, f64>,
k: usize,
config: GaussianMixtureConfig,
) -> Result<GaussianMixtureFit, String> {
let n = data.nrows();
let d = data.ncols();
if k == 0 {
return Err("gaussian mixture requires k >= 1".to_string());
}
if d == 0 {
return Err("gaussian mixture requires at least one column".to_string());
}
if k > n {
return Err(format!(
"gaussian mixture requested {k} components but data has {n} rows"
));
}
let centers = crate::basis::select_centers_by_strategy(
data,
&crate::basis::CenterStrategy::KMeans {
num_centers: k,
max_iter: config.kmeans_max_iter,
},
)
.map_err(|e| format!("gaussian mixture k-means seeding failed: {e}"))?;
if centers.nrows() != k || centers.ncols() != d {
return Err(format!(
"gaussian mixture seeding returned {}x{} centers, expected {k}x{d}",
centers.nrows(),
centers.ncols()
));
}
let mut means = centers;
let global_cov = data_covariance(data, config.covariance_floor);
let mut covariances = vec![global_cov; k];
let mut weights = Array1::<f64>::from_elem(k, 1.0 / k as f64);
let mut resp = Array2::<f64>::zeros((n, k));
let mut prev_mean_ll = f64::NEG_INFINITY;
let mut total_loglik = f64::NEG_INFINITY;
let mut iterations = 0usize;
for iter in 0..config.max_iter.max(1) {
iterations = iter + 1;
let mut comp = Vec::with_capacity(k);
for j in 0..k {
comp.push(GaussianComponentEval::factor(
means.row(j),
&covariances[j],
)?);
}
let log_w: Vec<f64> = weights
.iter()
.map(|w| w.max(f64::MIN_POSITIVE).ln())
.collect();
total_loglik = 0.0;
for i in 0..n {
let yrow = data.row(i);
let mut log_terms = vec![0.0_f64; k];
let mut max_term = f64::NEG_INFINITY;
for j in 0..k {
let lt = log_w[j] + comp[j].log_density(yrow);
log_terms[j] = lt;
if lt > max_term {
max_term = lt;
}
}
let log_mix = log_sum_exp(&log_terms, max_term);
total_loglik += log_mix;
for j in 0..k {
resp[[i, j]] = (log_terms[j] - log_mix).exp();
}
}
let mean_ll = total_loglik / n as f64;
if iter > 0 {
let denom = prev_mean_ll.abs().max(1.0);
if (mean_ll - prev_mean_ll).abs() / denom <= config.loglik_tol {
break;
}
}
prev_mean_ll = mean_ll;
let mut nk = vec![0.0_f64; k];
for j in 0..k {
let mut sum = 0.0_f64;
for i in 0..n {
sum += resp[[i, j]];
}
nk[j] = sum.max(f64::MIN_POSITIVE);
}
for j in 0..k {
weights[j] = nk[j] / n as f64;
let mut mu = Array1::<f64>::zeros(d);
for i in 0..n {
let r = resp[[i, j]];
if r == 0.0 {
continue;
}
for c in 0..d {
mu[c] += r * data[[i, c]];
}
}
mu.mapv_inplace(|v| v / nk[j]);
for c in 0..d {
means[[j, c]] = mu[c];
}
let mut cov = Array2::<f64>::zeros((d, d));
for i in 0..n {
let r = resp[[i, j]];
if r == 0.0 {
continue;
}
for a in 0..d {
let da = data[[i, a]] - mu[a];
for b in 0..d {
cov[[a, b]] += r * da * (data[[i, b]] - mu[b]);
}
}
}
cov.mapv_inplace(|v| v / nk[j]);
for a in 0..d {
cov[[a, a]] += config.covariance_floor;
}
covariances[j] = cov;
}
}
Ok(GaussianMixtureFit {
weights,
means,
covariances,
k,
d,
n_obs: n,
loglik: total_loglik,
iterations,
})
}
/// Population (divide-by-`n`) covariance of the rows of `data`, with `floor` added to
/// the diagonal.
///
/// An empty `data` yields `floor * I`.
fn data_covariance(data: ArrayView2<'_, f64>, floor: f64) -> Array2<f64> {
    let (n, d) = data.dim();
    let count = n.max(1) as f64;
    let mut centroid = Array1::<f64>::zeros(d);
    for i in 0..n {
        for c in 0..d {
            centroid[c] += data[[i, c]];
        }
    }
    centroid.mapv_inplace(|v| v / count);
    let mut scatter = Array2::<f64>::zeros((d, d));
    for i in 0..n {
        for a in 0..d {
            let offset_a = data[[i, a]] - centroid[a];
            for b in 0..d {
                scatter[[a, b]] += offset_a * (data[[i, b]] - centroid[b]);
            }
        }
    }
    let inv_count = 1.0 / count;
    scatter.mapv_inplace(|v| v * inv_count);
    scatter.diag_mut().mapv_inplace(|v| v + floor);
    scatter
}
/// Two-part composite ("union") structures that `fit_union_ladder` compares.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnionStructure {
    /// Two circles.
    CircleCircle,
    /// A circle plus a Gaussian point cluster.
    CirclePointCluster,
    /// A line plus a Gaussian point cluster.
    LineCluster,
}
/// Every union structure, in the order `fit_union_ladder` tries them.
pub const UNION_STRUCTURE_LADDER: &[UnionStructure] = &[
    UnionStructure::CircleCircle,
    UnionStructure::CirclePointCluster,
    UnionStructure::LineCluster,
];
/// Geometric kind of a single union component.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnionComponentKind {
    Circle,
    Line,
    PointCluster,
}
impl UnionStructure {
    const CIRCLE_CIRCLE_PARTS: &'static [UnionComponentKind] =
        &[UnionComponentKind::Circle, UnionComponentKind::Circle];
    const CIRCLE_CLUSTER_PARTS: &'static [UnionComponentKind] =
        &[UnionComponentKind::Circle, UnionComponentKind::PointCluster];
    const LINE_CLUSTER_PARTS: &'static [UnionComponentKind] =
        &[UnionComponentKind::Line, UnionComponentKind::PointCluster];

    /// Stable label used in error messages and reports.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::CircleCircle => "union_circle+circle",
            Self::CirclePointCluster => "union_circle+cluster",
            Self::LineCluster => "union_line+cluster",
        }
    }
    /// Component kinds in the order they are matched to responsibility groups.
    pub const fn components(self) -> &'static [UnionComponentKind] {
        match self {
            Self::CircleCircle => Self::CIRCLE_CIRCLE_PARTS,
            Self::CirclePointCluster => Self::CIRCLE_CLUSTER_PARTS,
            Self::LineCluster => Self::LINE_CLUSTER_PARTS,
        }
    }
    /// Number of components, i.e. `components().len()`.
    pub const fn num_components(self) -> usize {
        self.components().len()
    }
}
/// Evidence of one component of a fitted union structure.
#[derive(Debug, Clone)]
pub struct UnionComponentFit {
    /// Geometric kind this component was fitted as.
    pub kind: UnionComponentKind,
    /// Number of rows hard-assigned to this component.
    pub row_count: usize,
    /// Free parameters of the component model (4 for a circle).
    pub num_parameters: usize,
    /// Laplace-approximate negative log evidence of the component.
    pub negative_log_evidence: f64,
}
/// A union structure fitted component by component.
#[derive(Debug, Clone)]
pub struct UnionStructureFit {
    /// Which structure was fitted.
    pub structure: UnionStructure,
    /// Per-component results, in `structure.components()` order.
    pub components: Vec<UnionComponentFit>,
    /// Sum of the component negative log evidences.
    pub negative_log_evidence: f64,
    /// Sum of the component parameter counts.
    pub total_parameters: usize,
}
/// Partition the rows of `data` into `m` groups by hard (MAP) assignment under an
/// `m`-component Gaussian mixture.
///
/// Ties go to the lowest component index. With `m == 1`, every row forms one group
/// and no mixture is fitted. Groups may be empty.
///
/// # Errors
/// Returns `Err` when `m == 0`, `m` exceeds the row count, or the mixture fit or a
/// component factorisation fails.
pub fn union_responsibility_split(
    data: ArrayView2<'_, f64>,
    m: usize,
    config: GaussianMixtureConfig,
) -> Result<Vec<Vec<usize>>, String> {
    let n = data.nrows();
    if m == 0 {
        return Err("union split requires at least one component".to_string());
    }
    if m > n {
        return Err(format!(
            "union split requested {m} groups but data has {n} rows"
        ));
    }
    if m == 1 {
        return Ok(vec![(0..n).collect()]);
    }
    let fit = fit_gaussian_mixture(data, m, config)?;
    let components = (0..m)
        .map(|j| GaussianComponentEval::factor(fit.means.row(j), &fit.covariances[j]))
        .collect::<Result<Vec<_>, _>>()?;
    let log_weights: Vec<f64> = fit
        .weights
        .iter()
        .map(|w| w.max(f64::MIN_POSITIVE).ln())
        .collect();
    let mut groups: Vec<Vec<usize>> = vec![Vec::new(); m];
    for i in 0..n {
        let point = data.row(i);
        // Strict `>` keeps the first maximiser, and NaN scores never win.
        let (winner, _) = components.iter().zip(&log_weights).enumerate().fold(
            (0usize, f64::NEG_INFINITY),
            |(best_j, best_score), (j, (component, &log_w))| {
                let score = log_w + component.log_density(point);
                if score > best_score {
                    (j, score)
                } else {
                    (best_j, best_score)
                }
            },
        );
        groups[winner].push(i);
    }
    Ok(groups)
}
/// Fit `structure` to `data`.
///
/// The rows are first split with `union_responsibility_split`. Each group is then fitted
/// as its component kind, and the component evidences and parameter counts are summed.
///
/// # Errors
/// Propagates split and component errors, and rejects non-finite component evidence.
pub fn fit_union_structure(
    data: ArrayView2<'_, f64>,
    structure: UnionStructure,
    config: GaussianMixtureConfig,
) -> Result<UnionStructureFit, String> {
    let kinds = structure.components();
    let groups = union_responsibility_split(data, kinds.len(), config)?;
    let mut components = Vec::with_capacity(kinds.len());
    for (&kind, rows) in kinds.iter().zip(&groups) {
        let subset = gather_union_rows(data, rows);
        let (negative_log_evidence, num_parameters) =
            fit_union_component(subset.view(), kind, config)?;
        if !negative_log_evidence.is_finite() {
            return Err(format!(
                "union {} component {:?} produced non-finite evidence",
                structure.as_str(),
                kind
            ));
        }
        components.push(UnionComponentFit {
            kind,
            row_count: rows.len(),
            num_parameters,
            negative_log_evidence,
        });
    }
    let negative_log_evidence = components
        .iter()
        .fold(0.0_f64, |acc, part| acc + part.negative_log_evidence);
    let total_parameters = components.iter().map(|part| part.num_parameters).sum();
    Ok(UnionStructureFit {
        structure,
        components,
        negative_log_evidence,
        total_parameters,
    })
}
/// Fit every structure in `UNION_STRUCTURE_LADDER` and return the successful fits.
///
/// The fits are ordered by `rank_priority_candidates`, keyed on negative log evidence
/// with the ladder index and `total_parameters` as tie-breakers.
///
/// # Errors
/// Returns `Err` listing every per-structure failure when no structure could be fitted.
pub fn fit_union_ladder(
    data: ArrayView2<'_, f64>,
    config: GaussianMixtureConfig,
) -> Result<Vec<UnionStructureFit>, String> {
    let (mut fits, mut errors) = (Vec::new(), Vec::new());
    for structure in UNION_STRUCTURE_LADDER.iter().copied() {
        match fit_union_structure(data, structure, config) {
            Ok(fit) => fits.push(fit),
            Err(e) => errors.push(format!("{}: {e}", structure.as_str())),
        }
    }
    if fits.is_empty() {
        let detail = if errors.is_empty() {
            String::new()
        } else {
            format!(" ({})", errors.join("; "))
        };
        return Err(format!("union ladder produced no fittable composites{detail}"));
    }
    let ordered = rank_priority_candidates(
        fits.into_iter()
            .enumerate()
            .map(|(idx, fit)| {
                let score = fit.negative_log_evidence;
                let tie = fit.total_parameters;
                PriorityCandidate::new(fit, idx, score, tie)
            })
            .collect(),
    );
    Ok(ordered.into_iter().map(|ranked| ranked.item).collect())
}
/// Copy the rows of `data` listed in `idx`, in that order, into a new owned matrix.
///
/// Panics if an index is out of bounds.
fn gather_union_rows(data: ArrayView2<'_, f64>, idx: &[usize]) -> Array2<f64> {
    Array2::<f64>::from_shape_fn((idx.len(), data.ncols()), |(r, c)| data[[idx[r], c]])
}
/// Negative log evidence and parameter count of one union component fitted to `group`.
///
/// Circles use `fit_circle_component_evidence`. Lines and point clusters are both
/// fitted as a single full-covariance Gaussian, which needs at least `d + 1` rows.
fn fit_union_component(
    group: ArrayView2<'_, f64>,
    kind: UnionComponentKind,
    config: GaussianMixtureConfig,
) -> Result<(f64, usize), String> {
    match kind {
        UnionComponentKind::Circle => fit_circle_component_evidence(group, config),
        UnionComponentKind::Line | UnionComponentKind::PointCluster => {
            let min_rows = group.ncols() + 1;
            if group.nrows() < min_rows {
                return Err(format!(
                    "union gaussian component needs >= {} rows, got {}",
                    min_rows,
                    group.nrows()
                ));
            }
            let single = fit_gaussian_mixture(group, 1, config)?;
            let evidence = single.laplace_negative_log_evidence(group)?;
            Ok((evidence, single.num_free_parameters()))
        }
    }
}
/// Laplace-approximate negative log evidence of a single circle component fitted to the
/// 2-D rows of `group`; returns `(negative_log_evidence, 4)`.
///
/// Model: a radius `r ~ N(radius, var_r)` around `(cx, cy)` and a uniform angle. This
/// gives the planar log density `ln N(r | radius, var_r) - ln(2 pi) - ln r`.
///
/// The four parameters are the centre coordinates, the radius and `ln var_r`. They are
/// set from moments (centroid, mean distance, floored distance variance), which need not
/// be the likelihood maximiser.
///
/// # Errors
/// Returns `Err` when `group` is not 2-D, has fewer than 5 rows, or the resulting
/// evidence is not finite.
fn fit_circle_component_evidence(
    group: ArrayView2<'_, f64>,
    config: GaussianMixtureConfig,
) -> Result<(f64, usize), String> {
    let d = group.ncols();
    if d != 2 {
        return Err(format!(
            "union circle component requires 2-D data, got {d} columns"
        ));
    }
    let n = group.nrows();
    // Four parameters (cx, cy, radius, ln var_r); require at least p + 1 rows.
    let p = 4usize; if n < p + 1 {
        return Err(format!(
            "union circle component needs >= {} rows, got {n}",
            p + 1
        ));
    }
    // Centre estimate: centroid of the group.
    let mut cx = 0.0_f64;
    let mut cy = 0.0_f64;
    for i in 0..n {
        cx += group[[i, 0]];
        cy += group[[i, 1]];
    }
    cx /= n as f64;
    cy /= n as f64;
    // Radius estimate: mean distance from the centroid.
    let mut radii = vec![0.0_f64; n];
    let mut radius = 0.0_f64;
    for i in 0..n {
        let dx = group[[i, 0]] - cx;
        let dy = group[[i, 1]] - cy;
        let r = (dx * dx + dy * dy).sqrt();
        radii[i] = r;
        radius += r;
    }
    radius /= n as f64;
    // Radial variance, floored so its log and inverse stay finite.
    let mut var_r = 0.0_f64;
    for &r in &radii {
        let e = r - radius;
        var_r += e * e;
    }
    var_r = (var_r / n as f64).max(config.covariance_floor);
    let inv_var = 1.0 / var_r;
    let mut loglik = 0.0_f64;
    let log_2pi = (2.0 * std::f64::consts::PI).ln();
    for &r in &radii {
        let e = r - radius;
        // Gaussian log density of the radial residual.
        let radial = -0.5 * (log_2pi + var_r.ln()) - 0.5 * e * e * inv_var;
        // Uniform angle (1 / 2 pi) times the 1 / r polar-to-Cartesian Jacobian.
        let angular = -(log_2pi + r.max(f64::MIN_POSITIVE).ln());
        loglik += radial + angular;
    }
    // Empirical Fisher information: sum of per-row score outer products.
    let mut info = Array2::<f64>::zeros((p, p));
    let mut score = [0.0_f64; 4];
    for i in 0..n {
        let dx = group[[i, 0]] - cx;
        let dy = group[[i, 1]] - cy;
        let r = radii[i].max(f64::MIN_POSITIVE);
        let e = radii[i] - radius;
        let ee = e * inv_var;
        // NOTE(review): d(radial)/d(cx) is (e / var_r) * (dx / r), so both centre scores
        // below have the opposite sign. Flipping both is a diag(-1, -1, 1, 1) similarity
        // of the information matrix and leaves its log-determinant unchanged. The centre
        // scores also omit the derivative of the `-ln r` angular term — confirm intended.
        score[0] = ee * (-dx / r);
        score[1] = ee * (-dy / r);
        // d(radial)/d(radius).
        score[2] = ee;
        // d(radial)/d(ln var_r).
        score[3] = -0.5 + 0.5 * e * e * inv_var;
        for a in 0..p {
            let sa = score[a];
            if sa == 0.0 {
                continue;
            }
            for b in 0..p {
                info[[a, b]] += sa * score[b];
            }
        }
    }
    // Enforce symmetry and add a unit ridge so the matrix is positive definite.
    for a in 0..p {
        for b in (a + 1)..p {
            let avg = 0.5 * (info[[a, b]] + info[[b, a]]);
            info[[a, b]] = avg;
            info[[b, a]] = avg;
        }
        info[[a, a]] += 1.0;
    }
    // Dense matrix-vector product exposing `info` as an HVP.
    let apply_info = |x: &[f64]| -> Vec<f64> {
        let mut out = vec![0.0_f64; p];
        for r in 0..p {
            let mut acc = 0.0_f64;
            for c in 0..p {
                acc += info[[r, c]] * x[c];
            }
            out[r] = acc;
        }
        out
    };
    let hvp = EvidenceHvpLogDet {
        dim: p,
        apply: &apply_info,
    };
    // No penalty logdet and zero penalty rank: all p parameters count as unpenalised.
    let v = laplace_evidence(EvidenceLogDetSource::Hvp(hvp), 0.0, -loglik, p as f64, 0.0);
    if !v.is_finite() {
        return Err("union circle component Laplace evidence is not finite".to_string());
    }
    Ok((v, p))
}
/// Fitted density of one union component, used for held-out evaluation.
#[derive(Debug, Clone)]
enum UnionComponentDensity {
    /// Gaussian component (lines and point clusters).
    Gaussian {
        /// Log of the component's share of training rows.
        log_weight: f64,
        /// Factored Gaussian.
        eval: GaussianComponentEval,
    },
    /// Circle component: Gaussian radius around `center` with a uniform angle.
    Circle {
        /// Log of the component's share of training rows.
        log_weight: f64,
        /// Circle centre `(x, y)`.
        center: [f64; 2],
        /// Mean radius.
        radius: f64,
        /// Floored radial variance.
        var_r: f64,
    },
}
impl UnionComponentDensity {
fn weighted_log_density(&self, y: ArrayView1<'_, f64>) -> f64 {
match self {
UnionComponentDensity::Gaussian { log_weight, eval } => {
log_weight + eval.log_density(y)
}
UnionComponentDensity::Circle {
log_weight,
center,
radius,
var_r,
} => {
let dx = y[0] - center[0];
let dy = y[1] - center[1];
let r = (dx * dx + dy * dy).sqrt();
let log_2pi = (2.0 * std::f64::consts::PI).ln();
let e = r - radius;
let radial = -0.5 * (log_2pi + var_r.ln()) - 0.5 * e * e / var_r;
let angular = -(log_2pi + r.max(f64::MIN_POSITIVE).ln());
log_weight + radial + angular
}
}
}
}
/// Fit the per-component densities of `structure` on `train`, for held-out evaluation.
///
/// Rows are split with `union_responsibility_split`. Each component's log-weight is the
/// log of its share of the training rows.
///
/// # Errors
/// Returns `Err` in any of these cases:
/// * a group is empty;
/// * a Gaussian group has fewer than `d + 1` rows;
/// * a circle group is not 2-D or has fewer than 5 rows;
/// * a fit or factorisation fails.
fn fit_union_component_densities(
    train: ArrayView2<'_, f64>,
    structure: UnionStructure,
    config: GaussianMixtureConfig,
) -> Result<Vec<UnionComponentDensity>, String> {
    // Moment fit of a circle: centroid, mean radius and floored radial variance.
    fn fit_circle_density(
        group: ArrayView2<'_, f64>,
        log_weight: f64,
        covariance_floor: f64,
    ) -> Result<UnionComponentDensity, String> {
        let d = group.ncols();
        if d != 2 {
            return Err(format!(
                "union circle component density requires 2-D data, got {d} columns"
            ));
        }
        let n = group.nrows();
        if n < 5 {
            return Err(format!(
                "union circle component density needs >= 5 rows, got {n}"
            ));
        }
        let count = n as f64;
        let (mut sum_x, mut sum_y) = (0.0_f64, 0.0_f64);
        for i in 0..n {
            sum_x += group[[i, 0]];
            sum_y += group[[i, 1]];
        }
        let cx = sum_x / count;
        let cy = sum_y / count;
        let radii: Vec<f64> = (0..n)
            .map(|i| {
                let dx = group[[i, 0]] - cx;
                let dy = group[[i, 1]] - cy;
                (dx * dx + dy * dy).sqrt()
            })
            .collect();
        let radius = radii.iter().fold(0.0_f64, |acc, &r| acc + r) / count;
        let spread = radii.iter().fold(0.0_f64, |acc, &r| {
            let e = r - radius;
            acc + e * e
        });
        let var_r = (spread / count).max(covariance_floor);
        Ok(UnionComponentDensity::Circle {
            log_weight,
            center: [cx, cy],
            radius,
            var_r,
        })
    }

    let kinds = structure.components();
    let groups = union_responsibility_split(train, kinds.len(), config)?;
    let n_train = train.nrows().max(1) as f64;
    let mut densities = Vec::with_capacity(kinds.len());
    for (kind, rows) in kinds.iter().zip(&groups) {
        if rows.is_empty() {
            return Err(format!(
                "union {} held-out density: empty component group",
                structure.as_str()
            ));
        }
        let log_weight = (rows.len() as f64 / n_train).max(f64::MIN_POSITIVE).ln();
        let group = gather_union_rows(train, rows);
        let density = match kind {
            UnionComponentKind::Circle => {
                fit_circle_density(group.view(), log_weight, config.covariance_floor)?
            }
            UnionComponentKind::Line | UnionComponentKind::PointCluster => {
                if group.nrows() < group.ncols() + 1 {
                    return Err(format!(
                        "union gaussian component density needs >= {} rows, got {}",
                        group.ncols() + 1,
                        group.nrows()
                    ));
                }
                let single = fit_gaussian_mixture(group.view(), 1, config)?;
                let eval =
                    GaussianComponentEval::factor(single.means.row(0), &single.covariances[0])?;
                UnionComponentDensity::Gaussian { log_weight, eval }
            }
        };
        densities.push(density);
    }
    Ok(densities)
}
/// Held-out log density of every row of `eval` under `structure` fitted on `train`.
///
/// Each value is `ln sum_c exp(weighted_log_density_c(row))` over the fitted components.
///
/// # Errors
/// Returns `Err` when `train` and `eval` differ in column count, or when fitting the
/// component densities fails.
pub fn union_per_point_log_density(
    train: ArrayView2<'_, f64>,
    eval: ArrayView2<'_, f64>,
    structure: UnionStructure,
    config: GaussianMixtureConfig,
) -> Result<Array1<f64>, String> {
    if train.ncols() != eval.ncols() {
        return Err(format!(
            "union held-out density: train has {} columns, eval has {}",
            train.ncols(),
            eval.ncols()
        ));
    }
    let densities = fit_union_component_densities(train, structure, config)?;
    let mut log_terms = vec![f64::NEG_INFINITY; densities.len()];
    let mut out = Array1::<f64>::zeros(eval.nrows());
    for (i, slot) in out.iter_mut().enumerate() {
        let point = eval.row(i);
        for (term, density) in log_terms.iter_mut().zip(&densities) {
            *term = density.weighted_log_density(point);
        }
        let peak = log_terms
            .iter()
            .fold(f64::NEG_INFINITY, |best, &t| if t > best { t } else { best });
        *slot = log_sum_exp(&log_terms, peak);
    }
    Ok(out)
}
/// One fitted model offered to `compare_reml_fits`.
#[derive(Clone, Debug)]
pub struct RemlCandidate {
    /// Caller-side index of the model.
    pub index: usize,
    /// Display name.
    pub name: String,
    /// REML score used for ranking (NOTE(review): presumably lower is better, since the
    /// first-ranked row is treated as best — confirm with `rank_priority_candidates`).
    pub score: f64,
    /// Effective degrees of freedom, if known.
    pub edf: Option<f64>,
}
/// Ranked comparison produced by `compare_reml_fits`.
#[derive(Clone, Debug)]
pub struct RemlComparison {
    /// Candidates in ranked order.
    pub ranking: Vec<RankedRow>,
    /// Name of the first-ranked candidate.
    pub winner: String,
    /// One-line human-readable summary of the winner versus the runner-up.
    pub evidence_summary: String,
    /// Same rows as `ranking`, with table-oriented field names.
    pub score_table: Vec<ScoreRow>,
}
/// One row of `RemlComparison::ranking`.
#[derive(Clone, Debug)]
pub struct RankedRow {
    pub name: String,
    pub score: f64,
    /// `score - best_score`, i.e. `log_bayes_factor(best_score, score)`.
    pub delta: f64,
    /// `exp(delta)`.
    pub bayes_factor: f64,
    pub edf: Option<f64>,
}
/// One row of `RemlComparison::score_table`.
#[derive(Clone, Debug)]
pub struct ScoreRow {
    pub name: String,
    pub reml_score: f64,
    /// `reml_score - best_score`.
    pub delta_reml: f64,
    /// `exp(delta_reml)`.
    pub bayes_factor_best_over_model: f64,
    pub effective_dof: Option<f64>,
}
/// Natural-log Bayes factor of model `a` over model `b` from their REML scores:
/// `reml_score_b - reml_score_a`.
#[inline]
pub fn log_bayes_factor(reml_score_a: f64, reml_score_b: f64) -> f64 {
    reml_score_b - reml_score_a
}
pub fn compare_reml_fits(mut candidates: Vec<RemlCandidate>) -> Result<RemlComparison, String> {
if candidates.is_empty() {
return Err("compare_models requires at least one fit".to_string());
}
candidates = rank_priority_candidates(
candidates
.into_iter()
.enumerate()
.map(|(idx, row)| {
let score = row.score;
PriorityCandidate::new(row, idx, score, 0)
})
.collect(),
)
.into_iter()
.map(|row| row.item)
.collect();
let best_score = candidates[0].score;
let winner = candidates[0].name.clone();
let mut ranking = Vec::with_capacity(candidates.len());
let mut score_table = Vec::with_capacity(candidates.len());
for row in &candidates {
let delta = log_bayes_factor(best_score, row.score);
let bayes_factor = delta.exp();
ranking.push(RankedRow {
name: row.name.clone(),
score: row.score,
delta,
bayes_factor,
edf: row.edf,
});
score_table.push(ScoreRow {
name: row.name.clone(),
reml_score: row.score,
delta_reml: delta,
bayes_factor_best_over_model: bayes_factor,
effective_dof: row.edf,
});
}
let evidence_summary = if let Some(runner_up) = candidates.get(1) {
format!(
"{} wins by Bayes factor {} over {}",
winner,
format_bayes_factor(log_bayes_factor(best_score, runner_up.score)),
runner_up.name
)
} else {
format!("{winner} (single fit; no comparison)")
};
Ok(RemlComparison {
ranking,
winner,
evidence_summary,
score_table,
})
}
/// Render a Bayes factor given its natural logarithm `log_bf`.
///
/// * `NaN` renders as `"nan"`, `+inf` as `"inf"`, and `-inf` as `"0"` (a Bayes factor
///   of exactly zero).
/// * When `|log_bf| >= 3 ln 10` (a factor beyond 1e±3), the value is a power of ten
///   with one decimal in the exponent, e.g. `"1e+4.3"`.
/// * Otherwise `exp(log_bf)` is printed with three significant figures.
pub fn format_bayes_factor(log_bf: f64) -> String {
    if log_bf.is_nan() {
        return "nan".to_string();
    }
    if log_bf == f64::INFINITY {
        return "inf".to_string();
    }
    if log_bf == f64::NEG_INFINITY {
        // exp(-inf) = 0; reporting "inf" here would invert the meaning.
        return "0".to_string();
    }
    if log_bf.abs() >= std::f64::consts::LN_10 * 3.0 {
        return format!("1e{:+.1}", log_bf / std::f64::consts::LN_10);
    }
    format_three_significant(log_bf.exp())
}
/// Format `value` with three significant figures.
///
/// Zero prints as `"0"` and non-finite values use their default formatting. Magnitudes
/// of 1000 or more (after rounding) use scientific notation with two decimals (`"1.23e3"`).
/// Smaller magnitudes use fixed notation, e.g. `"0.500"`, `"12.3"`, `"123"`.
pub fn format_three_significant(value: f64) -> String {
    if value == 0.0 {
        return "0".to_string();
    }
    if !value.is_finite() {
        return format!("{value}");
    }
    let magnitude = value.abs();
    let mut exponent = magnitude.log10().floor() as i32;
    // Rounding to three significant figures can carry into the next decade
    // (9.996 -> 10.0, 999.6 -> 1000). Re-derive the exponent from the rounded magnitude
    // so the output keeps exactly three significant digits. This also corrects a
    // `log10` result that lands just below an integer.
    let carry_scale = 10f64.powi(2 - exponent);
    if (magnitude * carry_scale).round() / carry_scale >= 10f64.powi(exponent + 1) {
        exponent += 1;
    }
    if exponent >= 3 {
        return format!("{value:.2e}");
    }
    let decimals = (2 - exponent).max(0) as usize;
    let scale = 10f64.powi(decimals as i32);
    let rounded = (magnitude * scale).round() / scale * value.signum();
    format!("{rounded:.decimals$}")
}
impl Default for TopologySelectOptions {
fn default() -> Self {
Self {
tie_tolerance: 1e-3,
score_scale: TopologyScoreScale::PerObservation,
}
}
}
/// Laplace-approximate negative log evidence.
///
/// Returns
/// `residual_objective + 0.5 ln|H| - 0.5 penalty_log_det - 0.5 max(null_dim, 0) ln(2 pi)`,
/// where `null_dim = effective_dim - penalty_rank` and `ln|H|` comes from
/// `evidence_hessian_log_det(logdet_source)`.
///
/// Returns `NaN` in any of these cases:
/// * `effective_dim` or `penalty_rank` is non-finite;
/// * `null_dim` is non-finite or below `-1e-9`;
/// * the Hessian log-determinant cannot be computed.
///
/// The cheap dimension checks run first, so an invalid call never pays for the
/// log-determinant (up to `p` HVP applications plus an eigendecomposition).
pub fn laplace_evidence(
    logdet_source: EvidenceLogDetSource<'_>,
    penalty_log_det: f64,
    residual_objective: f64,
    effective_dim: f64,
    penalty_rank: f64,
) -> f64 {
    if !(effective_dim.is_finite() && penalty_rank.is_finite()) {
        return f64::NAN;
    }
    let null_dim = effective_dim - penalty_rank;
    // Small negative values are rounding noise; anything more negative is invalid.
    if !null_dim.is_finite() || null_dim < -1e-9 {
        return f64::NAN;
    }
    let Ok(log_det_h) = evidence_hessian_log_det(logdet_source) else {
        return f64::NAN;
    };
    residual_objective + 0.5 * log_det_h
        - 0.5 * penalty_log_det
        - 0.5 * null_dim.max(0.0) * (2.0 * std::f64::consts::PI).ln()
}
/// Log-determinant of the evidence Hessian from the given source.
///
/// A factored arrow source is tried first through `arrow_log_det_from_cache`. When that
/// yields nothing, its optional HVP fallback is used. A bare HVP source goes straight to
/// `hessian_log_det_from_hvp`.
///
/// # Errors
/// Returns `Err` when the arrow cache yields nothing and there is no fallback, or when
/// the HVP path fails.
pub fn evidence_hessian_log_det(source: EvidenceLogDetSource<'_>) -> Result<f64, String> {
    let hvp = match source {
        EvidenceLogDetSource::Hvp(hvp) => hvp,
        EvidenceLogDetSource::FactoredArrow {
            cache,
            fallback_hvp,
        } => {
            if let Some(exact) = arrow_log_det_from_cache(cache) {
                return Ok(exact);
            }
            fallback_hvp.ok_or_else(|| {
                String::from("evidence Hessian logdet requires exact factors or HVP fallback")
            })?
        }
    };
    hessian_log_det_from_hvp(hvp)
}
/// Log-determinant of the SPD operator described by `hvp`.
///
/// * `dim == 0` gives `0.0`.
/// * `dim <= ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD`: the operator is densified column by
///   column, checked for symmetry, symmetrised and passed to `dense_spd_log_det`.
/// * Larger operators use `stochastic_hvp_log_det`.
///
/// # Errors
/// Returns `Err` in any of these cases:
/// * a column has the wrong length;
/// * a column has a non-finite entry (the column and row are reported);
/// * the operator is too skew;
/// * the dense log-determinant fails.
pub fn hessian_log_det_from_hvp(hvp: EvidenceHvpLogDet<'_>) -> Result<f64, String> {
    let dim = hvp.dim;
    if dim == 0 {
        return Ok(0.0);
    }
    if dim > ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD {
        return stochastic_hvp_log_det(hvp);
    }
    let mut dense = Array2::<f64>::zeros((dim, dim));
    let mut basis = vec![0.0_f64; dim];
    for j in 0..dim {
        // Column j of H is H e_j.
        basis[j] = 1.0;
        let col = (hvp.apply)(&basis);
        basis[j] = 0.0;
        if col.len() != dim {
            return Err(format!(
                "evidence HVP logdet expected finite column of length {}, got {}",
                dim,
                col.len()
            ));
        }
        // Report the offending entry; a length-only message is misleading when the
        // length is right but a value is NaN or infinite.
        if let Some(i) = col.iter().position(|v| !v.is_finite()) {
            return Err(format!(
                "evidence HVP logdet column {j} has non-finite entry at row {i}: {}",
                col[i]
            ));
        }
        for (i, &value) in col.iter().enumerate() {
            dense[[i, j]] = value;
        }
    }
    validate_dense_hvp_symmetry(&dense)?;
    // Remove the (tolerated) residual skew before the symmetric eigensolver.
    for i in 0..dim {
        for j in (i + 1)..dim {
            let avg = 0.5 * (dense[[i, j]] + dense[[j, i]]);
            dense[[i, j]] = avg;
            dense[[j, i]] = avg;
        }
    }
    dense_spd_log_det(&dense)
}
/// Exact log-determinant of a dense symmetric positive-definite matrix.
///
/// When `crate::gpu::cuda_selected()` is true, the computation is delegated to the
/// GPU REML routine, and its `logdet_hessian` (or its error) is returned as-is.
/// Otherwise it sums `ln λ` over the eigenvalues from a lower-triangle symmetric
/// eigendecomposition.
///
/// # Errors
/// Non-square input, eigendecomposition failure, or any eigenvalue that is
/// non-finite or `<= 0` (the message names the first offending index).
fn dense_spd_log_det(matrix: &Array2<f64>) -> Result<f64, String> {
    let (rows, cols) = matrix.dim();
    if rows != cols {
        return Err(format!(
            "evidence dense logdet requires square matrix, got {}x{}",
            rows, cols
        ));
    }
    if crate::gpu::cuda_selected() {
        let input = crate::solver::gpu::reml_gpu::RemlGpuInput {
            penalized_hessian: matrix.view(),
            derivative_hessians: Vec::new(),
        };
        return crate::solver::gpu::reml_gpu::evidence_derivatives_gpu(input)
            .map(|evidence| evidence.logdet_hessian);
    }
    let (eigenvalues, _) = matrix
        .eigh(Side::Lower)
        .map_err(|e| format!("evidence dense logdet eigendecomposition failed: {e}"))?;
    eigenvalues
        .iter()
        .enumerate()
        .try_fold(0.0_f64, |total, (idx, &ev)| {
            if ev.is_finite() && ev > 0.0 {
                Ok(total + ev.ln())
            } else {
                Err(format!(
                    "evidence dense logdet expected SPD Hessian, eigenvalue {idx} is {ev:.3e}"
                ))
            }
        })
}
/// Checks that a materialized HVP matrix is symmetric to within
/// `EVIDENCE_HVP_SYMMETRY_REL_TOL`.
///
/// The measure is `‖A − Aᵀ‖_F / max(‖A‖_F, 1)`. The skew norm is accumulated over
/// the strict upper triangle, counting each pair twice. The `max(…, 1)` makes the
/// test absolute for matrices with tiny norm.
///
/// # Errors
/// Returns an error when the relative skew is non-finite or above tolerance.
fn validate_dense_hvp_symmetry(matrix: &Array2<f64>) -> Result<(), String> {
    let n = matrix.nrows();
    let frobenius_sq = matrix.iter().fold(0.0_f64, |acc, &v| acc + v * v);
    let mut antisym_sq = 0.0_f64;
    for i in 0..n {
        for j in (i + 1)..n {
            let diff = matrix[[i, j]] - matrix[[j, i]];
            antisym_sq += 2.0 * diff * diff;
        }
    }
    let rel_skew = antisym_sq.sqrt() / frobenius_sq.sqrt().max(1.0);
    if rel_skew.is_finite() && rel_skew <= EVIDENCE_HVP_SYMMETRY_REL_TOL {
        return Ok(());
    }
    Err(format!(
        "evidence HVP logdet requires symmetric operator, relative skew norm is {rel_skew:.3e}"
    ))
}
/// Matrix-free symmetry check for an HVP operator that is too large to materialize.
///
/// For each of `EVIDENCE_HVP_SYMMETRY_PROBES` (at least one) probe pairs, it builds
/// deterministic unit-norm Rademacher vectors `x` (seed `2p`) and `y` (seed `2p+1`)
/// and compares `xᵀ(Hy)` with `(Hx)ᵀy`. The mismatch is scaled by the largest of the
/// relevant norm products, the two bilinear values, and `1`.
///
/// # Errors
/// Returns an error if an HVP output has the wrong length or a non-finite entry, or
/// if any probe's relative mismatch is non-finite or exceeds
/// `EVIDENCE_HVP_SYMMETRY_REL_TOL`.
fn validate_hvp_randomized_symmetry(hvp: EvidenceHvpLogDet<'_>) -> Result<(), String> {
    let dim = hvp.dim;
    let inv_norm = 1.0 / (dim as f64).sqrt();
    let check_output = |v: &[f64]| -> Result<(), String> {
        if v.len() == dim && v.iter().all(|x| x.is_finite()) {
            Ok(())
        } else {
            Err(format!(
                "evidence HVP symmetry check expected finite vector of length {}, got {}",
                dim,
                v.len()
            ))
        }
    };
    let probe_count = EVIDENCE_HVP_SYMMETRY_PROBES.max(1);
    for probe in 0..probe_count {
        let mut x = vec![0.0_f64; dim];
        let mut y = vec![0.0_f64; dim];
        rademacher_unit_probe_into_slice(&mut x, (2 * probe) as u64, inv_norm);
        rademacher_unit_probe_into_slice(&mut y, (2 * probe + 1) as u64, inv_norm);
        // Both products are taken before either is validated.
        let hx = (hvp.apply)(&x);
        let hy = (hvp.apply)(&y);
        check_output(&hx)?;
        check_output(&hy)?;
        let x_hy = dot_slice(&x, &hy);
        let hx_y = dot_slice(&hx, &y);
        let scale = (norm2_slice(&hx) * norm2_slice(&y))
            .max(norm2_slice(&hy) * norm2_slice(&x))
            .max(x_hy.abs())
            .max(hx_y.abs())
            .max(1.0);
        let rel = (x_hy - hx_y).abs() / scale;
        if !(rel.is_finite() && rel <= EVIDENCE_HVP_SYMMETRY_REL_TOL) {
            return Err(format!(
                "evidence HVP logdet requires symmetric operator, randomized symmetry probe {probe} has relative bilinear mismatch {rel:.3e}"
            ));
        }
    }
    Ok(())
}
/// Stochastic Lanczos quadrature (SLQ) estimate of `log det H` for a large HVP.
///
/// The operator is first screened with [`validate_hvp_randomized_symmetry`]. Each
/// probe `p` in `0..EVIDENCE_LOGDET_SLQ_PROBES` (at least one) starts Lanczos from
/// a unit-norm Rademacher vector seeded by `p`, running at most
/// `min(EVIDENCE_LOGDET_LANCZOS_STEPS, dim)` steps (at least one). The per-probe
/// quadrature `qᵀ log(H) q` is scaled by `dim` and averaged over probes.
///
/// # Errors
/// Propagates symmetry-check and Lanczos/quadrature failures.
fn stochastic_hvp_log_det(hvp: EvidenceHvpLogDet<'_>) -> Result<f64, String> {
    validate_hvp_randomized_symmetry(hvp)?;
    let dim = hvp.dim;
    let probe_count = EVIDENCE_LOGDET_SLQ_PROBES.max(1);
    let lanczos_steps = EVIDENCE_LOGDET_LANCZOS_STEPS.min(dim).max(1);
    let scale = 1.0 / (dim as f64).sqrt();
    let total = (0..probe_count).try_fold(0.0_f64, |acc, probe| -> Result<f64, String> {
        let mut start = vec![0.0_f64; dim];
        rademacher_unit_probe_into_slice(&mut start, probe as u64, scale);
        let quad = lanczos_log_quadrature_hvp(hvp, start, lanczos_steps)?;
        Ok(acc + dim as f64 * quad)
    })?;
    Ok(total / probe_count as f64)
}
/// One SLQ probe: runs symmetric Lanczos on the HVP from start vector `q`, then
/// evaluates the log quadrature over the resulting Ritz pairs.
///
/// The options are at most `max_steps` steps, a residual tolerance of `1e-12`, and
/// no reorthogonalization.
///
/// # Errors
/// - an HVP output has the wrong length or a non-finite entry (raised from inside
///   the Lanczos callback and wrapped);
/// - Lanczos itself fails;
/// - the quadrature rejects the spectrum as not SPD.
fn lanczos_log_quadrature_hvp(
    hvp: EvidenceHvpLogDet<'_>,
    q: Vec<f64>,
    max_steps: usize,
) -> Result<f64, String> {
    let n = hvp.dim;
    let options = SymmetricLanczosOptions {
        max_steps,
        residual_tol: 1e-12,
        local_reorthogonalize: false,
        full_reorthogonalize: false,
    };
    let eigen = symmetric_lanczos_eigenpairs(n, &q, options, |v, out| {
        let applied = (hvp.apply)(v);
        let well_formed = applied.len() == n && applied.iter().all(|x| x.is_finite());
        if !well_formed {
            return Err(format!(
                "evidence HVP SLQ expected finite vector of length {n}, got {}",
                applied.len()
            ));
        }
        out.copy_from_slice(&applied);
        Ok(())
    })
    .map_err(|e| format!("evidence HVP SLQ Lanczos failed: {e}"))?;
    symmetric_lanczos_log_quadrature(&eigen, "evidence HVP SLQ expected SPD Hessian")
}
/// Plain left-to-right dot product of two equal-length slices.
///
/// The sum starts from `+0.0`, so an empty product is `+0.0`.
///
/// # Panics
/// Panics if the slices have different lengths.
#[inline]
fn dot_slice(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len());
    a.iter().zip(b).fold(0.0_f64, |acc, (x, y)| acc + x * y)
}
/// Euclidean norm `sqrt(Σ aᵢ²)`, accumulated left to right from `+0.0`.
#[inline]
fn norm2_slice(a: &[f64]) -> f64 {
    a.iter().fold(0.0_f64, |acc, &v| acc + v * v).sqrt()
}
/// Fills `z` with a deterministic Rademacher vector whose entries are `±scale`.
///
/// The generator state is seeded from `probe`. Each 64-entry chunk of `z` consumes
/// one `splitmix64` word, and its bits are used least-significant first: a `0` bit
/// gives `+scale`, a `1` bit gives `-scale`. The same `probe` therefore always yields
/// the same vector.
fn rademacher_unit_probe_into_slice(z: &mut [f64], probe: u64, scale: f64) {
    let mut state = 0x6A09E667F3BCC909_u64 ^ probe.wrapping_mul(0xD1B54A32D192ED03);
    for chunk in z.chunks_mut(64) {
        let word = splitmix64(&mut state);
        for (bit, slot) in chunk.iter_mut().enumerate() {
            *slot = if (word >> bit) & 1 == 0 { scale } else { -scale };
        }
    }
}
/// Module-local forwarding wrapper for `crate::linalg::utils::splitmix64`: it
/// advances `state` and returns the next pseudo-random word.
#[inline]
const fn splitmix64(state: &mut u64) -> u64 {
    use crate::linalg::utils::splitmix64 as utils_splitmix64;
    utils_splitmix64(state)
}
/// Exact Hessian log-determinant from cached arrow Cholesky factors.
///
/// It sums `2·Σ ln diag(L)` over every undamped per-row factor, plus the same
/// quantity for the Schur-complement factor.
///
/// Returns `None` when either ridge (`ridge_t`, `ridge_beta`) is nonzero (or NaN),
/// or when there is no Schur factor. A non-positive factor diagonal yields
/// `Some(NaN)`.
pub fn arrow_log_det_from_cache(cache: &ArrowFactorCache) -> Option<f64> {
    let damped = cache.ridge_t != 0.0 || cache.ridge_beta != 0.0;
    if damped {
        return None;
    }
    let schur = cache.schur_factor.as_ref()?;
    let row_part = cache
        .undamped_factors_iter()
        .fold(0.0_f64, |acc, l| acc + 2.0 * log_det_from_chol_lower(l));
    Some(row_part + 2.0 * log_det_from_chol_lower(schur.view()))
}
/// `Σ ln L[i,i]` over the diagonal of a lower Cholesky factor; this is half of
/// `log det(L Lᵀ)`.
///
/// Stops at the first diagonal entry that is not strictly positive (including NaN)
/// and returns `NaN`.
fn log_det_from_chol_lower(l: ArrayView2<'_, f64>) -> f64 {
    (0..l.nrows())
        .map(|i| l[[i, i]])
        .try_fold(0.0_f64, |acc, d| (d > 0.0).then(|| acc + d.ln()))
        .unwrap_or(f64::NAN)
}
/// Implicit-function sensitivity of the stacked row unknowns to β.
///
/// For each row `i`, it solves `H_uu,i · y = H_uβ,i · e_col` with the undamped
/// row Cholesky factor for every β basis direction `e_col`. It stores `-y` in rows
/// `row_offsets[i] .. row_offsets[i] + row_dims[i]`, column `col`, of a
/// `delta_t_len() × k` matrix.
///
/// Returns an all-NaN matrix of that shape when the `H_uβ` blocks are unavailable,
/// or when applying a row's `H_uβ` fails.
///
/// Each right-hand side is a fresh `row_dims[i]`-length zero vector. The previous
/// version sliced and copied an unused `cache.d`-length buffer, which did the same
/// allocation but panicked if a row were wider than `cache.d`.
pub fn ift_du_dbeta(cache: &ArrowFactorCache) -> Array2<f64> {
    let total_len = cache.delta_t_len();
    let k = cache.k;
    let nan_result = || Array2::<f64>::from_elem((total_len, k), f64::NAN);
    if !cache.htbeta_available() {
        return nan_result();
    }
    let mut out = Array2::<f64>::zeros((total_len, k));
    let mut beta_basis = Array1::<f64>::zeros(k);
    for i in 0..cache.undamped_factor_count() {
        let di = cache.row_dims[i];
        let row_base = cache.row_offsets[i];
        let factor = cache.undamped_factor(i);
        for col in 0..k {
            beta_basis.fill(0.0);
            beta_basis[col] = 1.0;
            let mut rhs_i = Array1::<f64>::zeros(di);
            if !cache.apply_htbeta_row(i, beta_basis.view(), &mut rhs_i) {
                return nan_result();
            }
            let y = cholesky_solve_vector(factor, &rhs_i);
            for c in 0..di {
                out[[row_base + c, col]] = -y[c];
            }
        }
    }
    out
}
/// Connected components of the sparsity graph of a square matrix.
///
/// Indices `i` and `j` are linked when either `hessian[[i, j]]` or
/// `hessian[[j, i]]` is nonzero (NaN counts as nonzero). Components are found with
/// union–find, using path halving and union by size. Labels are dense, starting at
/// `0`, and assigned in order of each component's smallest index.
///
/// Returns an empty vector for an empty or non-square matrix.
pub fn coupling_components(hessian: ArrayView2<'_, f64>) -> Vec<usize> {
    let p = hessian.nrows();
    if p == 0 || hessian.ncols() != p {
        return Vec::new();
    }

    // Path-halving root lookup: every visited node is re-pointed at its grandparent.
    fn root_of(parent: &mut [usize], mut node: usize) -> usize {
        while parent[node] != node {
            let grandparent = parent[parent[node]];
            parent[node] = grandparent;
            node = grandparent;
        }
        node
    }

    let mut parent: Vec<usize> = (0..p).collect();
    let mut size = vec![1usize; p];
    for i in 0..p {
        for j in (i + 1)..p {
            let coupled = hessian[[i, j]] != 0.0 || hessian[[j, i]] != 0.0;
            if !coupled {
                continue;
            }
            let ri = root_of(&mut parent, i);
            let rj = root_of(&mut parent, j);
            if ri == rj {
                continue;
            }
            let (small, large) = if size[ri] < size[rj] { (ri, rj) } else { (rj, ri) };
            parent[small] = large;
            size[large] += size[small];
        }
    }

    let mut label_of_root: Vec<Option<usize>> = vec![None; p];
    let mut next_label = 0usize;
    (0..p)
        .map(|idx| {
            let root = root_of(&mut parent, idx);
            *label_of_root[root].get_or_insert_with(|| {
                let fresh = next_label;
                next_label += 1;
                fresh
            })
        })
        .collect()
}
/// All indices whose component label matches the label of at least one support
/// index, returned in ascending order.
///
/// Support indices outside `labels` are ignored. An empty support, or one with no
/// in-range entries, gives an empty result.
pub fn cone_of_influence(labels: &[usize], support: &[usize]) -> Vec<usize> {
    let mut seed_labels: Vec<usize> = support
        .iter()
        .filter_map(|&idx| labels.get(idx).copied())
        .collect();
    if seed_labels.is_empty() {
        return Vec::new();
    }
    seed_labels.sort_unstable();
    seed_labels.dedup();
    labels
        .iter()
        .enumerate()
        .filter_map(|(idx, label)| seed_labels.binary_search(label).is_ok().then_some(idx))
        .collect()
}
/// Implicit-function sensitivity of β to the smoothing parameters, computed from
/// the Schur-complement Cholesky factor.
///
/// Delegates to `FitSensitivity::from_lower_triangular(schur).mode_response`.
///
/// Returns `None` when:
/// - either ridge is nonzero (or NaN);
/// - there is no Schur factor;
/// - `dg_red_drho` does not have `k` rows;
/// - the factor is not `k × k` (only its row count is checked).
pub fn ift_dbeta_drho(
    cache: &ArrowFactorCache,
    dg_red_drho: ArrayView2<'_, f64>,
) -> Option<Array2<f64>> {
    let undamped = cache.ridge_t == 0.0 && cache.ridge_beta == 0.0;
    if !undamped {
        return None;
    }
    let schur = cache.schur_factor.as_ref()?;
    let shapes_ok = dg_red_drho.nrows() == cache.k && schur.nrows() == cache.k;
    if !shapes_ok {
        return None;
    }
    let sensitivity = crate::solver::sensitivity::FitSensitivity::from_lower_triangular(schur);
    sensitivity.mode_response(dg_red_drho)
}
pub fn ift_du_drho(
cache: &ArrowFactorCache,
gu_rho: ArrayView2<'_, f64>,
dbeta_drho: ArrayView2<'_, f64>,
) -> Array2<f64> {
let n = cache.undamped_factor_count();
let total_len = cache.delta_t_len();
let k = cache.k;
let r = dbeta_drho.ncols();
if !cache.htbeta_available()
|| gu_rho.nrows() != total_len
|| gu_rho.ncols() != r
|| dbeta_drho.nrows() != k
{
return Array2::<f64>::from_elem((total_len, r), f64::NAN);
}
let mut out = Array2::<f64>::zeros((total_len, r));
let mut rhs = Array1::<f64>::zeros(cache.d);
let mut htbeta_delta = Array1::<f64>::zeros(cache.d);
for a in 0..r {
for i in 0..n {
let di = cache.row_dims[i];
let row_base = cache.row_offsets[i];
let mut htbeta_i = htbeta_delta.slice_mut(ndarray::s![..di]).to_owned();
if !cache.apply_htbeta_row(i, dbeta_drho.column(a), &mut htbeta_i) {
return Array2::<f64>::from_elem((total_len, r), f64::NAN);
}
{
let mut rhs_i = rhs.slice_mut(ndarray::s![..di]);
for c in 0..di {
rhs_i[c] = gu_rho[[row_base + c, a]] + htbeta_i[c];
}
}
let rhs_slice = rhs.slice(ndarray::s![..di]).to_owned();
let v = cholesky_solve_vector(cache.undamped_factor(i), &rhs_slice);
for c in 0..di {
out[[row_base + c, a]] = -v[c];
}
}
}
out
}
/// Inputs to [`evidence_ift_gradient_correction`]: mode sensitivities plus the
/// first-order terms they are contracted with.
///
/// With `k = dbeta_drho.nrows()`, `nd = du_drho.nrows()` and
/// `r = dbeta_drho.ncols()`, the correction function requires:
/// - `du_drho.ncols() == r`;
/// - `value_beta` and `logdet_h_beta` to have length `k`;
/// - `value_u` and `logdet_h_u` to have length `nd`.
///
/// On any mismatch it returns an all-NaN vector.
#[derive(Clone)]
pub struct EvidenceIftGradientTerms<'a> {
    /// `k × r` sensitivity of β to each smoothing parameter (e.g. from
    /// `ift_dbeta_drho`; confirm with callers).
    pub dbeta_drho: ArrayView2<'a, f64>,
    /// `nd × r` sensitivity of the stacked row unknowns to each smoothing parameter.
    pub du_drho: ArrayView2<'a, f64>,
    /// Length-`k` vector contracted with `dbeta_drho`. By name, this is the
    /// objective's partial derivative w.r.t. β (TODO confirm).
    pub value_beta: ArrayView1<'a, f64>,
    /// Length-`nd` vector contracted with `du_drho`. By name, this is the
    /// objective's partial derivative w.r.t. u (TODO confirm).
    pub value_u: ArrayView1<'a, f64>,
    /// Length-`k` vector, weighted by ½ and contracted with `dbeta_drho`. By name,
    /// ∂ log|H| / ∂β.
    pub logdet_h_beta: ArrayView1<'a, f64>,
    /// Length-`nd` vector, weighted by ½ and contracted with `du_drho`. By name,
    /// ∂ log|H| / ∂u.
    pub logdet_h_u: ArrayView1<'a, f64>,
}
pub fn evidence_ift_gradient_correction(terms: EvidenceIftGradientTerms<'_>) -> Array1<f64> {
let k = terms.dbeta_drho.nrows();
let nd = terms.du_drho.nrows();
let r = terms.dbeta_drho.ncols();
if terms.du_drho.ncols() != r
|| terms.value_beta.len() != k
|| terms.logdet_h_beta.len() != k
|| terms.value_u.len() != nd
|| terms.logdet_h_u.len() != nd
{
return Array1::<f64>::from_elem(r, f64::NAN);
}
let mut out = Array1::<f64>::zeros(r);
for a in 0..r {
let mut acc = 0.0_f64;
for j in 0..k {
let mode = terms.dbeta_drho[[j, a]];
acc += terms.value_beta[j] * mode;
acc += 0.5 * terms.logdet_h_beta[j] * mode;
}
for j in 0..nd {
let mode = terms.du_drho[[j, a]];
acc += terms.value_u[j] * mode;
acc += 0.5 * terms.logdet_h_u[j] * mode;
}
out[a] = acc;
}
out
}
/// Gradient of the Laplace evidence objective with respect to the smoothing
/// parameters ρ.
///
/// For each parameter `a`:
///
/// ```text
/// grad[a] = value_rho[a]
///         + ½·( Σ_i tr(H_uu,i⁻¹ ∂_a H_uu,i) + tr(S⁻¹ ∂_a S) )
///         + ift_correction[a]
///         − ½·pen_logdet_drho[a]
/// ```
///
/// `S` is the Schur complement whose Cholesky factor is `cache.schur_factor`, and
/// `ift_correction` comes from [`evidence_ift_gradient_correction`]. With
/// `Y_i = H_uu,i⁻¹ H_uβ,i` (one `row_dims[i] × k` block per row), the code forms
///
/// ```text
/// ∂_a S = hbb_drho[a] − Σ_i ( D_iᵀ Y_i + Y_iᵀ D_i − Y_iᵀ huu_drho[i][a] Y_i )
/// ```
///
/// with `D_i = htbeta_drho[i][a]` — presumably `∂_a H_uβ,i`, since its validated
/// shape is `row_dims[i] × k`; TODO confirm with callers.
///
/// Returns an all-NaN vector of length `value_rho.len()` when:
/// - the `H_uβ` blocks are unavailable;
/// - any input has the wrong count or shape;
/// - the IFT correction contains NaN;
/// - there is no Schur factor;
/// - applying a row's `H_uβ` fails.
///
/// All scratch vectors and matrices are sized to each row's own dimension rather
/// than sliced out of `cache.d`-sized buffers. The floating-point operation order
/// matches the earlier implementation.
pub fn evidence_grad_rho(
    cache: &ArrowFactorCache,
    value_rho: ArrayView1<'_, f64>,
    huu_drho: &[Vec<Array2<f64>>],
    htbeta_drho: &[Vec<Array2<f64>>],
    hbb_drho: &[Array2<f64>],
    pen_logdet_drho: ArrayView1<'_, f64>,
    ift_terms: EvidenceIftGradientTerms<'_>,
) -> Array1<f64> {
    let r = value_rho.len();
    let n = cache.undamped_factor_count();
    let k = cache.k;
    let nan_result = || Array1::<f64>::from_elem(r, f64::NAN);

    if !cache.htbeta_available()
        || pen_logdet_drho.len() != r
        || huu_drho.len() != n
        || htbeta_drho.len() != n
        || hbb_drho.len() != r
        || huu_drho.iter().any(|row| row.len() != r)
        || htbeta_drho.iter().any(|row| row.len() != r)
        || hbb_drho.iter().any(|m| m.nrows() != k || m.ncols() != k)
        || huu_drho.iter().enumerate().any(|(i, row)| {
            let di = cache.row_dims[i];
            row.iter().any(|m| m.nrows() != di || m.ncols() != di)
        })
        || htbeta_drho.iter().enumerate().any(|(i, row)| {
            let di = cache.row_dims[i];
            row.iter().any(|m| m.nrows() != di || m.ncols() != k)
        })
    {
        return nan_result();
    }

    let ift_correction = evidence_ift_gradient_correction(ift_terms);
    if ift_correction.len() != r || ift_correction.iter().any(|v| v.is_nan()) {
        return nan_result();
    }
    let Some(schur) = cache.schur_factor.as_ref() else {
        return nan_result();
    };

    // Y_i = H_uu,i⁻¹ H_uβ,i, built one β basis direction (column) at a time.
    let mut y_blocks: Vec<Array2<f64>> = Vec::with_capacity(n);
    let mut beta_basis = Array1::<f64>::zeros(k);
    for i in 0..n {
        let di = cache.row_dims[i];
        let factor = cache.undamped_factor(i);
        let mut yi = Array2::<f64>::zeros((di, k));
        for col in 0..k {
            beta_basis.fill(0.0);
            beta_basis[col] = 1.0;
            let mut rhs_i = Array1::<f64>::zeros(di);
            if !cache.apply_htbeta_row(i, beta_basis.view(), &mut rhs_i) {
                return nan_result();
            }
            let v = cholesky_solve_vector(factor, &rhs_i);
            for c in 0..di {
                yi[[c, col]] = v[c];
            }
        }
        y_blocks.push(yi);
    }

    let mut out = Array1::<f64>::zeros(r);
    let mut col_scratch = Array1::<f64>::zeros(k);
    for a in 0..r {
        // Row blocks: Σ_i tr(H_uu,i⁻¹ ∂_a H_uu,i), one column solve at a time.
        let mut row_trace_acc = 0.0_f64;
        for i in 0..n {
            let di = cache.row_dims[i];
            let m_i = &huu_drho[i][a];
            debug_assert_eq!(m_i.shape(), &[di, di]);
            let factor = cache.undamped_factor(i);
            for col in 0..di {
                let v = cholesky_solve_vector(factor, &m_i.column(col).to_owned());
                row_trace_acc += v[col];
            }
        }

        // ∂_a S, accumulated row block by row block.
        let mut da = hbb_drho[a].clone();
        debug_assert_eq!(da.shape(), &[k, k]);
        for i in 0..n {
            let di = cache.row_dims[i];
            let dhtb = &htbeta_drho[i][a];
            let dhuu = &huu_drho[i][a];
            let yi = &y_blocks[i];
            // da -= D_iᵀ Y_i
            for r0 in 0..k {
                for c0 in 0..k {
                    let mut acc = 0.0;
                    for cc in 0..di {
                        acc += dhtb[[cc, r0]] * yi[[cc, c0]];
                    }
                    da[[r0, c0]] -= acc;
                }
            }
            // da -= Y_iᵀ D_i
            for r0 in 0..k {
                for c0 in 0..k {
                    let mut acc = 0.0;
                    for cc in 0..di {
                        acc += yi[[cc, r0]] * dhtb[[cc, c0]];
                    }
                    da[[r0, c0]] -= acc;
                }
            }
            // da += Y_iᵀ (∂_a H_uu,i) Y_i, via the intermediate (∂_a H_uu,i) Y_i.
            let mut huu_y = Array2::<f64>::zeros((di, k));
            for r0 in 0..di {
                for c0 in 0..k {
                    let mut acc = 0.0;
                    for cc in 0..di {
                        acc += dhuu[[r0, cc]] * yi[[cc, c0]];
                    }
                    huu_y[[r0, c0]] = acc;
                }
            }
            for r0 in 0..k {
                for c0 in 0..k {
                    let mut acc = 0.0;
                    for cc in 0..di {
                        acc += yi[[cc, r0]] * huu_y[[cc, c0]];
                    }
                    da[[r0, c0]] += acc;
                }
            }
        }

        // tr(S⁻¹ ∂_a S), one column solve against the Schur factor at a time.
        let mut schur_trace_acc = 0.0_f64;
        for j in 0..k {
            col_scratch.assign(&da.column(j));
            let v = cholesky_solve_vector(schur, &col_scratch);
            schur_trace_acc += v[j];
        }

        let mut grad = value_rho[a];
        grad += 0.5 * (row_trace_acc + schur_trace_acc);
        grad += ift_correction[a];
        grad -= 0.5 * pen_logdet_drho[a];
        out[a] = grad;
    }
    out
}
/// Picks the best topology candidate by its scaled evidence score.
///
/// A candidate is eligible when it is converged, has no exclusion reason, and has
/// a finite negative log-evidence and a finite score under
/// `options.score_scale`. Eligible candidates are ordered by
/// `rank_priority_candidates`; the score, the topology's complexity rank (as a tie
/// break) and the input position are supplied to it.
///
/// If the top two scores are within `options.tie_tolerance`, the run of candidates
/// within tolerance of the top score is stably re-sorted by complexity rank, so the
/// simplest tied topology wins and `tie` is set.
///
/// `ranking` holds the eligible candidates in final order, followed by the
/// ineligible ones in input order.
///
/// # Panics
/// Panics when no candidate is eligible; silent fallback is deliberately not
/// allowed.
pub fn select_topology(
    candidates: &[TopologyCandidate],
    options: TopologySelectOptions,
) -> SelectedTopology {
    let scale = options.score_scale;
    let score = |c: &TopologyCandidate| topology_selection_score(c, scale);
    let is_eligible = |c: &TopologyCandidate| {
        c.converged
            && c.exclusion_reason.is_none()
            && c.negative_log_evidence.is_finite()
            && score(c).is_finite()
    };
    let (eligible, mut excluded): (Vec<TopologyCandidate>, Vec<TopologyCandidate>) =
        candidates.iter().cloned().partition(|c| is_eligible(c));
    assert!(
        !eligible.is_empty(),
        "select_topology: no finite valid candidates; proposal §6.11 forbids silent fallback"
    );

    let mut ranked: Vec<TopologyCandidate> = rank_priority_candidates(
        eligible
            .into_iter()
            .enumerate()
            .map(|(idx, row)| {
                let row_score = score(&row);
                let tie_break = usize::from(row.kind.complexity_rank());
                PriorityCandidate::new(row, idx, row_score, tie_break)
            })
            .collect(),
    )
    .into_iter()
    .map(|row| row.item)
    .collect();

    let tie = match ranked.as_slice() {
        [first, second, ..] => (score(second) - score(first)).abs() <= options.tie_tolerance,
        _ => false,
    };
    if tie {
        let top_score = score(&ranked[0]);
        let tied_end = ranked
            .iter()
            .position(|c| (score(c) - top_score).abs() > options.tie_tolerance)
            .unwrap_or(ranked.len());
        ranked[..tied_end].sort_by_key(|c| c.kind.complexity_rank());
    }

    let winner = ranked[0].kind;
    ranked.append(&mut excluded);
    SelectedTopology {
        winner,
        ranking: ranked,
        tie,
    }
}
/// Negative log-evidence normalized by the chosen scale.
///
/// - `PerObservation`: divides by `n_obs`; returns NaN when `n_obs == 0`.
/// - `PerEffectiveDim`: divides by `effective_dim`; returns NaN unless it is finite
///   and strictly positive.
fn topology_selection_score(candidate: &TopologyCandidate, scale: TopologyScoreScale) -> f64 {
    let denominator = match scale {
        TopologyScoreScale::PerObservation => {
            if candidate.n_obs == 0 {
                return f64::NAN;
            }
            candidate.n_obs as f64
        }
        TopologyScoreScale::PerEffectiveDim => {
            let dim = candidate.effective_dim;
            if !dim.is_finite() || dim <= 0.0 {
                return f64::NAN;
            }
            dim
        }
    };
    candidate.negative_log_evidence / denominator
}
/// Whether the cache has everything needed for exact (factor-based) evidence.
///
/// Requires:
/// - both ridges to be exactly zero;
/// - a Schur-complement factor to be present;
/// - the `H_uβ` blocks to be available.
pub fn cache_supports_exact_evidence(cache: &ArrowFactorCache) -> bool {
    let undamped = cache.ridge_t == 0.0 && cache.ridge_beta == 0.0;
    undamped && cache.schur_factor.is_some() && cache.htbeta_available()
}
/// Whether a factor cache was built for `sys` and is still current.
///
/// Checks, in order:
/// 1. the dimensions `d` and `k`;
/// 2. the row count, both stored rows and undamped factors;
/// 3. the manifold-mode fingerprint;
/// 4. the row-Hessian fingerprint.
///
/// Evaluation stops at the first mismatch, so the system's current row-Hessian
/// fingerprint is only computed when everything else agrees.
pub fn cache_matches_system(cache: &ArrowFactorCache, sys: &ArrowSchurSystem) -> bool {
    if cache.d != sys.d || cache.k != sys.k {
        return false;
    }
    let row_count = sys.rows.len();
    if cache.n_rows() != row_count || cache.undamped_factor_count() != row_count {
        return false;
    }
    cache.manifold_mode_fingerprint == sys.manifold_mode_fingerprint
        && cache.row_hessian_fingerprint == sys.current_row_hessian_fingerprint()
}
// Unit tests for the coupling/cone helpers, the coned IFT solve, Laplace evidence,
// the IFT sensitivities, topology selection and stacking. `solve_stacking_weights`
// and `stacked_predictive_mean` are not defined in this part of the file; they are
// reached through `use super::*` (presumably defined earlier in the parent module).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::arrow_schur::ArrowFactorSlab;

    /// Test-only dense inverse: Gauss–Jordan elimination on `[h | I]` with partial
    /// (row) pivoting.
    ///
    /// There is no singularity check, so a zero pivot yields non-finite entries
    /// instead of an error.
    fn dense_inverse(h: &Array2<f64>) -> Array2<f64> {
        let p = h.nrows();
        let mut aug = Array2::<f64>::zeros((p, 2 * p));
        for i in 0..p {
            for j in 0..p {
                aug[[i, j]] = h[[i, j]];
            }
            aug[[i, p + i]] = 1.0;
        }
        for col in 0..p {
            // Choose the row with the largest |pivot| at or below the diagonal.
            let mut pivot = col;
            for row in (col + 1)..p {
                if aug[[row, col]].abs() > aug[[pivot, col]].abs() {
                    pivot = row;
                }
            }
            if pivot != col {
                for j in 0..(2 * p) {
                    aug.swap([col, j], [pivot, j]);
                }
            }
            let d = aug[[col, col]];
            for j in 0..(2 * p) {
                aug[[col, j]] /= d;
            }
            // Eliminate this column from every other row (full Gauss–Jordan).
            for row in 0..p {
                if row == col {
                    continue;
                }
                let f = aug[[row, col]];
                if f != 0.0 {
                    for j in 0..(2 * p) {
                        aug[[row, j]] -= f * aug[[col, j]];
                    }
                }
            }
        }
        // The right half of the augmented matrix now holds h⁻¹.
        let mut inv = Array2::<f64>::zeros((p, p));
        for i in 0..p {
            for j in 0..p {
                inv[[i, j]] = aug[[i, p + j]];
            }
        }
        inv
    }

    // Two independent 2×2 coupled blocks must give exactly two component labels
    // (despite the test name, no index is a singleton).
    #[test]
    fn coupling_components_block_diagonal_is_all_singletons_by_block() {
        let mut h = Array2::<f64>::eye(4);
        h[[0, 1]] = 0.3;
        h[[1, 0]] = 0.3;
        h[[2, 3]] = 0.7;
        h[[3, 2]] = 0.7;
        let labels = coupling_components(h.view());
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[2], labels[3]);
        assert_ne!(labels[0], labels[2]);
        let mut uniq = labels.clone();
        uniq.sort_unstable();
        uniq.dedup();
        assert_eq!(uniq.len(), 2);
    }

    // Every off-diagonal nonzero: a single component.
    #[test]
    fn coupling_components_fully_coupled_is_one_component() {
        let mut h = Array2::<f64>::eye(3);
        for i in 0..3 {
            for j in 0..3 {
                if i != j {
                    h[[i, j]] = 0.1;
                }
            }
        }
        let labels = coupling_components(h.view());
        assert!(labels.iter().all(|&l| l == labels[0]));
    }

    // 0–1 and 1–2 are coupled but 0–2 is not: union–find must still merge all three.
    #[test]
    fn coupling_components_transitive_chain_merges() {
        let mut h = Array2::<f64>::eye(3);
        h[[0, 1]] = 0.5;
        h[[1, 0]] = 0.5;
        h[[1, 2]] = 0.5;
        h[[2, 1]] = 0.5;
        let labels = coupling_components(h.view());
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[1], labels[2]);
    }

    #[test]
    fn cone_of_influence_empty_support_is_empty() {
        let labels = vec![0usize, 0, 1, 1];
        assert!(cone_of_influence(&labels, &[]).is_empty());
    }

    // A single support index pulls in its whole component; supports spanning both
    // components pull in everything.
    #[test]
    fn cone_of_influence_returns_full_component() {
        let labels = vec![0usize, 0, 1, 1];
        assert_eq!(cone_of_influence(&labels, &[0]), vec![0, 1]);
        assert_eq!(cone_of_influence(&labels, &[1, 2]), vec![0, 1, 2, 3]);
    }

    // On a fully coupled Hessian the coned solve must reproduce the full solve.
    // `mode_response_coned` is defined outside this file.
    #[test]
    fn coned_matches_full_solve_on_fully_coupled_hessian() {
        let h = Array2::from_shape_vec((3, 3), vec![4.0, 1.0, 0.5, 1.0, 3.0, 0.8, 0.5, 0.8, 2.5])
            .unwrap();
        let inv = dense_inverse(&h);
        let mut dg = Array2::<f64>::zeros((3, 2));
        dg[[0, 0]] = 1.3;
        dg[[2, 1]] = -0.7;
        let supports = vec![0..1usize, 2..3usize];
        let eye: Array2<f64> = Array2::eye(3);
        let op = crate::solver::sensitivity::FitSensitivity::from_projected(&eye, &inv);
        let full = op.mode_response(dg.view()).unwrap();
        let coned = op
            .mode_response_coned(h.view(), dg.view(), &supports)
            .unwrap();
        for i in 0..3 {
            for a in 0..2 {
                assert!(
                    (full[[i, a]] - coned[[i, a]]).abs() < 1e-12,
                    "fully-coupled mismatch at ({i},{a}): {} vs {}",
                    full[[i, a]],
                    coned[[i, a]]
                );
            }
        }
    }

    // Two decoupled blocks with support only in the first: the coned result must
    // equal -h⁻¹·dg, and it must be exactly zero in the untouched block.
    #[test]
    fn coned_confines_to_component_on_decoupled_hessian() {
        let mut h = Array2::<f64>::zeros((4, 4));
        h[[0, 0]] = 4.0;
        h[[1, 1]] = 3.0;
        h[[0, 1]] = 1.0;
        h[[1, 0]] = 1.0;
        h[[2, 2]] = 2.0;
        h[[3, 3]] = 5.0;
        h[[2, 3]] = 0.6;
        h[[3, 2]] = 0.6;
        let inv = dense_inverse(&h);
        let mut dg = Array2::<f64>::zeros((4, 1));
        dg[[0, 0]] = 0.9;
        dg[[1, 0]] = -0.4;
        let support_range = 0..2usize;
        let supports = std::slice::from_ref(&support_range);
        let eye: Array2<f64> = Array2::eye(4);
        let coned = crate::solver::sensitivity::FitSensitivity::from_projected(&eye, &inv)
            .mode_response_coned(h.view(), dg.view(), supports)
            .unwrap();
        let q = dg.column(0).to_owned();
        let exact = inv.dot(&q).mapv(|v| -v);
        for i in 0..4 {
            assert!(
                (coned[[i, 0]] - exact[[i]]).abs() < 1e-12,
                "decoupled mismatch at {i}: {} vs {}",
                coned[[i, 0]],
                exact[[i]]
            );
        }
        assert_eq!(coned[[2, 0]], 0.0);
        assert_eq!(coned[[3, 0]], 0.0);
    }

    // With an empty support range the column is expected to come back as zeros even
    // though the supplied inverse is all NaN. This suggests the column is skipped
    // rather than solved (behavior lives in `mode_response_coned`).
    #[test]
    fn coned_skips_inactive_column_with_empty_support() {
        let h = Array2::<f64>::eye(2);
        let dg = Array2::<f64>::zeros((2, 1));
        let empty_support = 0..0usize;
        let supports = std::slice::from_ref(&empty_support);
        let eye: Array2<f64> = Array2::eye(2);
        let nan_inv = Array2::<f64>::from_elem((2, 2), f64::NAN);
        let coned = crate::solver::sensitivity::FitSensitivity::from_projected(&eye, &nan_inv)
            .mode_response_coned(h.view(), dg.view(), supports)
            .unwrap();
        assert_eq!(coned[[0, 0]], 0.0);
        assert_eq!(coned[[1, 0]], 0.0);
    }

    /// One-row arrow cache with `d = k = 1` and no ridge.
    ///
    /// - `H_uu = 2`, stored as Cholesky factor √2.
    /// - `H_uβ = 0.5`.
    /// - Schur factor √1.875.
    ///
    /// These are consistent with the joint Hessian `[[2, 0.5], [0.5, 2]]`, since
    /// `2 − 0.5²/2 = 1.875` and the determinant is `2 · 1.875 = 3.75`.
    fn make_minimal_cache() -> ArrowFactorCache {
        let l_huu = Array2::from_shape_vec((1, 1), vec![std::f64::consts::SQRT_2]).unwrap();
        let l_schur = Array2::from_shape_vec((1, 1), vec![(1.875_f64).sqrt()]).unwrap();
        let htbeta = Array2::from_shape_vec((1, 1), vec![0.5]).unwrap();
        ArrowFactorCache {
            htt_factors: ArrowFactorSlab::from_blocks(vec![l_huu]),
            htt_factors_undamped: crate::solver::arrow_schur::ArrowUndampedFactors::SameAsDamped,
            schur_factor: Some(l_schur),
            solver_mode: crate::solver::arrow_schur::ArrowSolverMode::Direct,
            ridge_t: 0.0,
            ridge_beta: 0.0,
            htbeta: crate::solver::arrow_schur::ArrowHtbetaCache::Dense {
                blocks: std::sync::Arc::from(vec![htbeta]),
                estimated_bytes: std::mem::size_of::<f64>(),
            },
            d: 1,
            row_dims: std::sync::Arc::from(vec![1usize]),
            row_offsets: std::sync::Arc::from(vec![0usize, 1usize]),
            k: 1,
            manifold_mode_fingerprint: 0,
            row_hessian_fingerprint: 0,
            pcg_diagnostics: crate::solver::arrow_schur::PcgDiagnostics::default(),
            gauge_deflated_directions: 0,
        }
    }

    // Exact factor path: ½·ln(2·1.875) minus the ½·ln(2π) term for the single
    // null-space dimension (effective_dim 2 − penalty_rank 1).
    #[test]
    fn laplace_evidence_returns_finite_for_minimal_cache() {
        let cache = make_minimal_cache();
        let v = laplace_evidence(
            EvidenceLogDetSource::FactoredArrow {
                cache: &cache,
                fallback_hvp: None,
            },
            0.0,
            0.0,
            2.0,
            1.0,
        );
        assert!(v.is_finite());
        let expected =
            0.5 * (2.0_f64.ln() + 1.875_f64.ln()) - 0.5 * (2.0 * std::f64::consts::PI).ln();
        assert!((v - expected).abs() < 1e-12);
    }

    // A nonzero ridge disables the exact path; with no HVP fallback the result is NaN.
    #[test]
    fn laplace_evidence_nan_when_ridge_is_nonzero() {
        let mut cache = make_minimal_cache();
        cache.ridge_t = 1e-3;
        assert!(
            laplace_evidence(
                EvidenceLogDetSource::FactoredArrow {
                    cache: &cache,
                    fallback_hvp: None,
                },
                0.0,
                0.0,
                2.0,
                1.0,
            )
            .is_nan()
        );
    }

    // Without a Schur factor the HVP fallback (diag(2, 1.875), same determinant)
    // must reproduce the exact answer.
    #[test]
    fn laplace_evidence_uses_hvp_fallback_without_schur_factor() {
        let mut cache = make_minimal_cache();
        cache.schur_factor = None;
        let hvp = |x: &[f64]| -> Vec<f64> { vec![2.0 * x[0], 1.875 * x[1]] };
        let v = laplace_evidence(
            EvidenceLogDetSource::FactoredArrow {
                cache: &cache,
                fallback_hvp: Some(EvidenceHvpLogDet {
                    dim: 2,
                    apply: &hvp,
                }),
            },
            0.0,
            0.0,
            2.0,
            1.0,
        );
        let expected =
            0.5 * (2.0_f64.ln() + 1.875_f64.ln()) - 0.5 * (2.0 * std::f64::consts::PI).ln();
        assert!((v - expected).abs() < 1e-12);
    }

    // du/dβ = -H_uu⁻¹ H_uβ = -0.5 / 2 = -0.25.
    #[test]
    fn ift_du_dbeta_has_expected_shape() {
        let cache = make_minimal_cache();
        let du_db = ift_du_dbeta(&cache);
        assert_eq!(du_db.shape(), &[1, 1]);
        assert!((du_db[[0, 0]] - (-0.25)).abs() < 1e-12);
    }

    // dβ/dρ for a unit reduced-gradient derivative is -S⁻¹ = -1/1.875.
    #[test]
    fn ift_dbeta_drho_returns_some_for_direct_cache() {
        let cache = make_minimal_cache();
        let q = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
        let out = ift_dbeta_drho(&cache, q.view()).unwrap();
        assert_eq!(out.shape(), &[1, 1]);
        assert!((out[[0, 0]] + 1.0 / 1.875).abs() < 1e-12);
    }

    // Lowest per-observation score wins; the excluded, non-converged torus is ignored.
    #[test]
    fn topology_select_picks_lowest_negative_log_evidence() {
        let candidates = vec![
            TopologyCandidate {
                kind: TopologyKind::Flat,
                negative_log_evidence: 10.0,
                effective_dim: 4.0,
                n_obs: 100,
                converged: true,
                exclusion_reason: None,
            },
            TopologyCandidate {
                kind: TopologyKind::Sphere,
                negative_log_evidence: 8.0,
                effective_dim: 5.0,
                n_obs: 100,
                converged: true,
                exclusion_reason: None,
            },
            TopologyCandidate {
                kind: TopologyKind::Torus,
                negative_log_evidence: f64::NAN,
                effective_dim: 6.0,
                n_obs: 100,
                converged: false,
                exclusion_reason: Some("torus periods missing".to_string()),
            },
        ];
        let sel = select_topology(&candidates, TopologySelectOptions::default());
        assert_eq!(sel.winner, TopologyKind::Sphere);
        assert!(!sel.tie);
    }

    // Scores differ by 1e-8 per observation (< default tolerance 1e-3), so the
    // simpler Flat topology must win and `tie` must be set.
    #[test]
    fn topology_select_tie_breaks_to_simpler() {
        let candidates = vec![
            TopologyCandidate {
                kind: TopologyKind::Sphere,
                negative_log_evidence: 5.0,
                effective_dim: 5.0,
                n_obs: 100,
                converged: true,
                exclusion_reason: None,
            },
            TopologyCandidate {
                kind: TopologyKind::Flat,
                negative_log_evidence: 5.0 + 1e-6,
                effective_dim: 4.0,
                n_obs: 100,
                converged: true,
                exclusion_reason: None,
            },
        ];
        let sel = select_topology(&candidates, TopologySelectOptions::default());
        assert_eq!(sel.winner, TopologyKind::Flat);
        assert!(sel.tie);
    }

    /// Log-density of `N(mean, sd²)` at `y`.
    fn gaussian_logpdf(y: f64, mean: f64, sd: f64) -> f64 {
        let z = (y - mean) / sd;
        -0.5 * (2.0 * std::f64::consts::PI).ln() - sd.ln() - 0.5 * z * z
    }

    #[test]
    fn stacking_single_candidate_gets_full_weight() {
        let log_density = Array2::from_shape_vec((3, 1), vec![-1.0, -2.0, -0.5]).unwrap();
        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
        assert!((out.weights[0] - 1.0).abs() < 1e-12);
        assert_eq!(out.weights.len(), 1);
    }

    #[test]
    fn stacking_dominant_candidate_attracts_nearly_all_weight() {
        let mut log_density = Array2::<f64>::zeros((50, 2));
        for i in 0..50 {
            log_density[[i, 0]] = -0.1;
            log_density[[i, 1]] = -5.0;
        }
        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
        assert!(out.weights[0] > 0.99, "w0 = {}", out.weights[0]);
        assert!(out.weights[1] < 0.01, "w1 = {}", out.weights[1]);
    }

    // Each candidate is best on exactly half the rows, so neither should dominate.
    #[test]
    fn stacking_complementary_candidates_share_weight() {
        let n = 40;
        let mut log_density = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            if i < n / 2 {
                log_density[[i, 0]] = gaussian_logpdf(0.0, 0.0, 0.5);
                log_density[[i, 1]] = gaussian_logpdf(0.0, 1.5, 0.5);
            } else {
                log_density[[i, 0]] = gaussian_logpdf(0.0, 1.5, 0.5);
                log_density[[i, 1]] = gaussian_logpdf(0.0, 0.0, 0.5);
            }
        }
        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
        assert!(
            out.weights[0] > 0.2 && out.weights[0] < 0.8,
            "w0 = {}",
            out.weights[0]
        );
        assert!((out.weights.sum() - 1.0).abs() < 1e-9);
    }

    #[test]
    fn stacking_weights_stay_on_the_simplex() {
        let log_density = Array2::from_shape_vec(
            (3, 3),
            vec![-1.0, -2.0, -3.0, -2.5, -1.0, -2.0, -3.0, -2.0, -1.0],
        )
        .unwrap();
        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
        assert!((out.weights.sum() - 1.0).abs() < 1e-9);
        assert!(out.weights.iter().all(|&w| w >= -1e-12));
    }

    // With `weight_tol: 0.0` the solver runs up to `max_iter`. The mean log score
    // must not decrease as the iteration budget grows.
    #[test]
    fn stacking_mean_log_score_is_monotone_under_more_iterations() {
        let log_density =
            Array2::from_shape_vec((4, 2), vec![-0.2, -3.0, -3.0, -0.2, -0.5, -1.5, -1.5, -0.5])
                .unwrap();
        let mut prev = f64::NEG_INFINITY;
        for max_iter in [1usize, 2, 4, 8, 32] {
            let out = solve_stacking_weights(
                log_density.view(),
                StackingConfig {
                    max_iter,
                    weight_tol: 0.0,
                },
            )
            .unwrap();
            assert!(
                out.mean_log_score >= prev - 1e-12,
                "log-score decreased at max_iter={max_iter}: {prev} -> {}",
                out.mean_log_score
            );
            prev = out.mean_log_score;
        }
    }

    // Column 1 is entirely -inf/NaN: it must receive zero weight.
    #[test]
    fn stacking_dead_candidate_column_is_rejected_and_zero_weighted() {
        let log_density = Array2::from_shape_vec(
            (3, 2),
            vec![
                -1.0,
                f64::NEG_INFINITY,
                -2.0,
                f64::NAN,
                -0.5,
                f64::NEG_INFINITY,
            ],
        )
        .unwrap();
        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
        assert_eq!(out.weights[1], 0.0);
        assert!((out.weights[0] - 1.0).abs() < 1e-12);
    }

    // Row 1 is NaN / -inf in every column: it must be dropped, not poison the score.
    #[test]
    fn stacking_rows_with_no_finite_density_are_dropped() {
        let log_density = Array2::from_shape_vec(
            (3, 2),
            vec![-1.0, -2.0, f64::NAN, f64::NEG_INFINITY, -2.0, -1.0],
        )
        .unwrap();
        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
        assert!((out.weights.sum() - 1.0).abs() < 1e-9);
        assert!(out.mean_log_score.is_finite());
    }

    #[test]
    fn stacking_all_dead_table_errors() {
        let log_density = Array2::from_elem((2, 2), f64::NEG_INFINITY);
        assert!(solve_stacking_weights(log_density.view(), StackingConfig::default()).is_err());
    }

    #[test]
    fn stacked_mean_is_weighted_combination() {
        let weights = Array1::from_vec(vec![0.25, 0.75]);
        let means = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![5.0, 6.0, 7.0]),
        ];
        let out = stacked_predictive_mean(&weights, &means).unwrap();
        assert!((out[0] - (0.25 * 1.0 + 0.75 * 5.0)).abs() < 1e-12);
        assert!((out[2] - (0.25 * 3.0 + 0.75 * 7.0)).abs() < 1e-12);
    }

    // Candidate mean vectors of different lengths must be rejected.
    #[test]
    fn stacked_mean_rejects_shape_mismatch() {
        let weights = Array1::from_vec(vec![0.5, 0.5]);
        let means = vec![
            Array1::from_vec(vec![1.0, 2.0]),
            Array1::from_vec(vec![3.0]),
        ];
        assert!(stacked_predictive_mean(&weights, &means).is_err());
    }
}