use faer::Side;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::linalg::faer_ndarray::FaerEigh;
use crate::linalg::lanczos::{
SymmetricLanczosOptions, symmetric_lanczos_eigenpairs, symmetric_lanczos_log_quadrature,
};
use crate::linalg::triangular::cholesky_solve_vector;
use crate::solver::arrow_schur::{ArrowFactorCache, ArrowSchurSystem};
use crate::solver::priority_selection::{PriorityCandidate, rank_priority_candidates};
/// Largest HVP dimension for which the evidence log-determinant is obtained by
/// materialising the dense Hessian column by column; larger operators are
/// handled by stochastic Lanczos quadrature.
pub const ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD: usize = 1 << 10;
/// Number of Rademacher probe vectors averaged by the stochastic log-det estimator.
const EVIDENCE_LOGDET_SLQ_PROBES: usize = 16;
/// Upper bound on Lanczos steps per probe (also capped by the operator dimension).
const EVIDENCE_LOGDET_LANCZOS_STEPS: usize = 32;
/// Relative asymmetry tolerated in an HVP operator before it is rejected.
const EVIDENCE_HVP_SYMMETRY_REL_TOL: f64 = 1e-8;
/// Number of randomized bilinear probes used to check HVP symmetry.
const EVIDENCE_HVP_SYMMETRY_PROBES: usize = 4;
/// A matrix-free Hessian operator used to compute `log|H|` for the evidence.
///
/// `apply(x)` must return `H x` as a vector of length `dim`; callers reject
/// outputs of the wrong length or containing non-finite values.
#[derive(Clone, Copy)]
pub struct EvidenceHvpLogDet<'a> {
    /// Dimension of the (square) operator.
    pub dim: usize,
    /// Hessian-vector product closure.
    pub apply: &'a dyn Fn(&[f64]) -> Vec<f64>,
}
/// Where the evidence Hessian log-determinant should come from.
#[derive(Clone, Copy)]
pub enum EvidenceLogDetSource<'a> {
    /// Use the exact Cholesky factors held by an arrow-system factor cache;
    /// `fallback_hvp`, when present, is consulted if the exact path is unusable.
    FactoredArrow {
        cache: &'a ArrowFactorCache,
        fallback_hvp: Option<EvidenceHvpLogDet<'a>>,
    },
    /// Use only a matrix-free Hessian-vector product.
    Hvp(EvidenceHvpLogDet<'a>),
}
/// Candidate latent-space topologies considered by [`select_topology`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TopologyKind {
    Periodic,
    Flat,
    Sphere,
    Torus,
}
impl TopologyKind {
    /// Simplicity ordering used to break ties: smaller means simpler
    /// (Flat < Periodic < Sphere < Torus).
    pub fn complexity_rank(self) -> u8 {
        match self {
            TopologyKind::Torus => 3,
            TopologyKind::Sphere => 2,
            TopologyKind::Periodic => 1,
            TopologyKind::Flat => 0,
        }
    }
}
/// One fitted topology together with the information needed to rank it.
#[derive(Debug, Clone)]
pub struct TopologyCandidate {
    /// Topology that was fitted.
    pub kind: TopologyKind,
    /// Negative log evidence of the fit (lower is better).
    pub negative_log_evidence: f64,
    /// Effective dimension, used by [`TopologyScoreScale::PerEffectiveDim`].
    pub effective_dim: f64,
    /// Observation count, used by [`TopologyScoreScale::PerObservation`].
    pub n_obs: usize,
    /// Whether the fit converged; unconverged fits are never selected.
    pub converged: bool,
    /// Reason the candidate was excluded, if any; excluded fits are never selected.
    pub exclusion_reason: Option<String>,
}
/// Result of [`select_topology`].
#[derive(Debug, Clone)]
pub struct SelectedTopology {
    /// Topology ranked first.
    pub winner: TopologyKind,
    /// Selectable candidates in rank order, followed by the excluded ones.
    pub ranking: Vec<TopologyCandidate>,
    /// Whether the top two scores fell within the tie tolerance.
    pub tie: bool,
}
/// Tuning knobs for [`select_topology`].
#[derive(Debug, Clone, Copy)]
pub struct TopologySelectOptions {
    /// Absolute score difference below which candidates are considered tied.
    pub tie_tolerance: f64,
    /// How the negative log evidence is normalised before comparison.
    pub score_scale: TopologyScoreScale,
}
/// Normalisation applied to the negative log evidence when ranking topologies.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TopologyScoreScale {
    /// Divide by `n_obs`.
    PerObservation,
    /// Divide by `effective_dim`.
    PerEffectiveDim,
}
/// Iteration controls for [`solve_stacking_weights`].
#[derive(Debug, Clone, Copy)]
pub struct StackingConfig {
    /// Maximum number of weight-update sweeps.
    pub max_iter: usize,
    /// Stop once the largest absolute weight change in a sweep is at most this.
    pub weight_tol: f64,
}
impl Default for StackingConfig {
    /// 1000 sweeps with a `1e-10` convergence tolerance.
    fn default() -> Self {
        StackingConfig {
            weight_tol: 1e-10,
            max_iter: 1_000,
        }
    }
}
/// Output of [`solve_stacking_weights`].
#[derive(Debug, Clone)]
pub struct StackingWeights {
    /// One weight per input column; columns with no finite density get exactly 0.
    pub weights: Array1<f64>,
    /// Mean over usable rows of the log mixture density under `weights`.
    pub mean_log_score: f64,
    /// Number of update sweeps that were started.
    pub iterations: usize,
}
/// Fits stacking weights on the probability simplex from a table of held-out
/// log predictive densities.
///
/// `log_density[[i, k]]` is the held-out log density of observation row `i`
/// under candidate column `k`. Non-finite entries (NaN, ±inf) are treated as
/// "no usable density" and skipped. Columns with no finite entry are dropped
/// and receive weight exactly zero. Rows with no finite entry among the kept
/// columns are ignored.
///
/// Weights start uniform and are refined by the fixed-point update
/// `w_k <- mean_i [ w_k p_ik / sum_j w_j p_ij ]`. The update is evaluated in log
/// space with a per-row max shift and then renormalised to sum to one.
/// Iteration stops after `config.max_iter` sweeps, when the largest absolute
/// weight change is at most `config.weight_tol`, or when no row has a finite
/// mixture density.
///
/// # Errors
/// Returns `Err` when the table has no columns, no rows, no column with a
/// finite density, or no row with a finite density among the kept columns.
pub fn solve_stacking_weights(
    log_density: ArrayView2<'_, f64>,
    config: StackingConfig,
) -> Result<StackingWeights, String> {
    let n_obs = log_density.nrows();
    let n_cand = log_density.ncols();
    if n_cand == 0 {
        return Err("stacking requires at least one candidate column".to_string());
    }
    if n_obs == 0 {
        return Err("stacking requires at least one held-out observation row".to_string());
    }
    // Drop candidates that never produce a finite density.
    let kept_cols: Vec<usize> = (0..n_cand)
        .filter(|&k| (0..n_obs).any(|i| log_density[[i, k]].is_finite()))
        .collect();
    if kept_cols.is_empty() {
        return Err("stacking found no candidate with any finite held-out density".to_string());
    }
    // Drop rows that no kept candidate can score.
    let rows: Vec<usize> = (0..n_obs)
        .filter(|&i| kept_cols.iter().any(|&k| log_density[[i, k]].is_finite()))
        .collect();
    if rows.is_empty() {
        return Err("stacking found no held-out row with a finite density".to_string());
    }
    let kept = kept_cols.len();
    let mut weights = Array1::<f64>::from_elem(kept, 1.0 / kept as f64);
    let mut log_weights = Array1::<f64>::zeros(kept);
    let mut next = Array1::<f64>::zeros(kept);
    let mut iterations = 0usize;
    for _ in 0..config.max_iter {
        iterations += 1;
        // ln(w_k) does not depend on the row, so compute it once per sweep.
        // Zero weights are never read (every use is guarded by `weights[k] > 0`).
        for (log_w, &w) in log_weights.iter_mut().zip(weights.iter()) {
            *log_w = if w > 0.0 { w.ln() } else { f64::NEG_INFINITY };
        }
        next.fill(0.0);
        let mut active_rows = 0usize;
        for &row in &rows {
            // Max shift for a numerically stable log-sum-exp.
            let mut row_max = f64::NEG_INFINITY;
            for (local_col, &source_col) in kept_cols.iter().enumerate() {
                let log_p = log_density[[row, source_col]];
                if log_p.is_finite() && weights[local_col] > 0.0 {
                    row_max = row_max.max(log_weights[local_col] + log_p);
                }
            }
            if !row_max.is_finite() {
                continue;
            }
            let mut denom = 0.0_f64;
            for (local_col, &source_col) in kept_cols.iter().enumerate() {
                let log_p = log_density[[row, source_col]];
                if log_p.is_finite() && weights[local_col] > 0.0 {
                    denom += (log_weights[local_col] + log_p - row_max).exp();
                }
            }
            if denom <= 0.0 {
                continue;
            }
            active_rows += 1;
            let log_mix = row_max + denom.ln();
            // Accumulate this row's responsibilities w_k p_k / sum_j w_j p_j.
            for (local_col, &source_col) in kept_cols.iter().enumerate() {
                let log_p = log_density[[row, source_col]];
                if log_p.is_finite() && weights[local_col] > 0.0 {
                    next[local_col] += (log_weights[local_col] + log_p - log_mix).exp();
                }
            }
        }
        if active_rows == 0 {
            break;
        }
        next.mapv_inplace(|value| value / active_rows as f64);
        let total = next.sum();
        if total > 0.0 {
            next.mapv_inplace(|value| value / total);
        }
        let delta = next
            .iter()
            .zip(weights.iter())
            .fold(0.0_f64, |acc, (a, b)| acc.max((a - b).abs()));
        weights.assign(&next);
        if delta <= config.weight_tol {
            break;
        }
    }
    let mean_log_score = stacking_mean_log_score(log_density, &rows, &kept_cols, weights.view());
    // Scatter the kept-column weights back to the caller's column indexing.
    let mut full = Array1::<f64>::zeros(n_cand);
    for (local_col, &source_col) in kept_cols.iter().enumerate() {
        full[source_col] = weights[local_col];
    }
    Ok(StackingWeights {
        weights: full,
        mean_log_score,
        iterations,
    })
}
/// Mean log mixture density `log sum_k w_k p_k(row)` over `rows`, restricted to
/// the `kept_cols` columns (with `weights[j]` belonging to `kept_cols[j]`).
///
/// Non-finite densities and non-positive weights contribute nothing to a row.
/// Rows with no contribution are skipped. Returns `-inf` when no row could be
/// scored.
fn stacking_mean_log_score(
    log_density: ArrayView2<'_, f64>,
    rows: &[usize],
    kept_cols: &[usize],
    weights: ArrayView1<'_, f64>,
) -> f64 {
    let mut total = 0.0_f64;
    let mut scored_rows = 0usize;
    // Reused per row: the finite terms ln(w_k) + ln p_k(row).
    let mut terms: Vec<f64> = Vec::with_capacity(kept_cols.len());
    for &row in rows {
        terms.clear();
        terms.extend(
            kept_cols
                .iter()
                .enumerate()
                .filter_map(|(local_col, &source_col)| {
                    let log_p = log_density[[row, source_col]];
                    let weight = weights[local_col];
                    (log_p.is_finite() && weight > 0.0).then(|| weight.ln() + log_p)
                }),
        );
        let peak = terms.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        if !peak.is_finite() {
            continue;
        }
        let denom = terms
            .iter()
            .fold(0.0_f64, |acc, &term| acc + (term - peak).exp());
        if denom > 0.0 {
            total += peak + denom.ln();
            scored_rows += 1;
        }
    }
    match scored_rows {
        0 => f64::NEG_INFINITY,
        count => total / count as f64,
    }
}
/// Weighted combination `sum_k weights[k] * candidate_means[k]` of per-candidate
/// predictive mean vectors.
///
/// Candidates with a weight of exactly zero are skipped.
///
/// # Errors
/// Returns `Err` if the weight count differs from the candidate count, if there
/// are no candidates, or if the mean vectors differ in length.
pub fn stacked_predictive_mean(
    weights: &Array1<f64>,
    candidate_means: &[Array1<f64>],
) -> Result<Array1<f64>, String> {
    let (n_weights, n_candidates) = (weights.len(), candidate_means.len());
    if n_weights != n_candidates {
        return Err(format!(
            "stacked_predictive_mean: {} weights but {} candidate mean vectors",
            n_weights, n_candidates
        ));
    }
    let n_rows = match candidate_means.first() {
        Some(first) => first.len(),
        None => {
            return Err("stacked_predictive_mean requires at least one candidate".to_string());
        }
    };
    let ragged = candidate_means.iter().any(|means| means.len() != n_rows);
    if ragged {
        return Err(
            "stacked_predictive_mean: candidate mean vectors disagree on row count".to_string(),
        );
    }
    let mut mixture = Array1::<f64>::zeros(n_rows);
    weights
        .iter()
        .zip(candidate_means)
        .filter(|(weight, _)| **weight != 0.0)
        .for_each(|(&weight, means)| mixture.scaled_add(weight, means));
    Ok(mixture)
}
/// A fitted model entered into [`compare_reml_fits`].
#[derive(Clone, Debug)]
pub struct RemlCandidate {
    /// Caller-side identifier of the fit.
    pub index: usize,
    /// Display name of the fit.
    pub name: String,
    /// REML score used for ranking.
    pub score: f64,
    /// Effective degrees of freedom, if known.
    pub edf: Option<f64>,
}
/// Output of [`compare_reml_fits`].
#[derive(Clone, Debug)]
pub struct RemlComparison {
    /// Fits in rank order, best first.
    pub ranking: Vec<RankedRow>,
    /// Name of the best-ranked fit.
    pub winner: String,
    /// One-line human-readable summary of the comparison.
    pub evidence_summary: String,
    /// Same rows as `ranking`, with table-oriented field names.
    pub score_table: Vec<ScoreRow>,
}
/// One ranked fit in a [`RemlComparison`].
#[derive(Clone, Debug)]
pub struct RankedRow {
    pub name: String,
    pub score: f64,
    /// `score - best_score`.
    pub delta: f64,
    /// `exp(delta)`.
    pub bayes_factor: f64,
    pub edf: Option<f64>,
}
/// Tabular view of a ranked fit in a [`RemlComparison`].
#[derive(Clone, Debug)]
pub struct ScoreRow {
    pub name: String,
    pub reml_score: f64,
    /// `reml_score - best_score`.
    pub delta_reml: f64,
    /// `exp(delta_reml)`.
    pub bayes_factor_best_over_model: f64,
    pub effective_dof: Option<f64>,
}
/// Log Bayes factor between two REML scores, computed as `reml_score_b - reml_score_a`.
///
/// With `a` the best score, this is the log factor by which `a` is preferred
/// over `b`.
#[inline]
pub fn log_bayes_factor(reml_score_a: f64, reml_score_b: f64) -> f64 {
    reml_score_b - reml_score_a
}
pub fn compare_reml_fits(mut candidates: Vec<RemlCandidate>) -> Result<RemlComparison, String> {
if candidates.is_empty() {
return Err("compare_models requires at least one fit".to_string());
}
candidates = rank_priority_candidates(
candidates
.into_iter()
.enumerate()
.map(|(idx, row)| {
let score = row.score;
PriorityCandidate::new(row, idx, score, 0)
})
.collect(),
)
.into_iter()
.map(|row| row.item)
.collect();
let best_score = candidates[0].score;
let winner = candidates[0].name.clone();
let mut ranking = Vec::with_capacity(candidates.len());
let mut score_table = Vec::with_capacity(candidates.len());
for row in &candidates {
let delta = log_bayes_factor(best_score, row.score);
let bayes_factor = delta.exp();
ranking.push(RankedRow {
name: row.name.clone(),
score: row.score,
delta,
bayes_factor,
edf: row.edf,
});
score_table.push(ScoreRow {
name: row.name.clone(),
reml_score: row.score,
delta_reml: delta,
bayes_factor_best_over_model: bayes_factor,
effective_dof: row.edf,
});
}
let evidence_summary = if let Some(runner_up) = candidates.get(1) {
format!(
"{} wins by Bayes factor {} over {}",
winner,
format_bayes_factor(log_bayes_factor(best_score, runner_up.score)),
runner_up.name
)
} else {
format!("{winner} (single fit; no comparison)")
};
Ok(RemlComparison {
ranking,
winner,
evidence_summary,
score_table,
})
}
/// Human-readable Bayes factor given its natural logarithm.
///
/// * `NaN` formats as `"NaN"`, `+inf` as `"inf"`, and `-inf` (a Bayes factor
///   of zero) as `"0"`.
/// * When `|log_bf| >= 3 ln 10` (factor outside `(1e-3, 1e3)`), the result is
///   `"1e±X.X"` with the base-10 exponent shown to one decimal.
/// * Otherwise `exp(log_bf)` is shown with three significant digits via
///   [`format_three_significant`].
pub fn format_bayes_factor(log_bf: f64) -> String {
    if log_bf.is_nan() {
        return "NaN".to_string();
    }
    if log_bf == f64::INFINITY {
        return "inf".to_string();
    }
    if log_bf == f64::NEG_INFINITY {
        // exp(-inf) == 0; reporting "inf" here would invert the meaning.
        return "0".to_string();
    }
    if log_bf.abs() >= std::f64::consts::LN_10 * 3.0 {
        return format!("1e{:+.1}", log_bf / std::f64::consts::LN_10);
    }
    format_three_significant(log_bf.exp())
}
/// Formats `value` with three significant digits.
///
/// Magnitudes of at least 1000 use scientific notation (`"1.23e3"`). Smaller
/// magnitudes use fixed-point with as many decimals as three significant digits
/// need (`"1.23"`, `"0.0123"`). Zero formats as `"0"`, and non-finite values use
/// their `Display` form.
pub fn format_three_significant(value: f64) -> String {
    if value == 0.0 {
        return "0".to_string();
    }
    if !value.is_finite() {
        return format!("{value}");
    }
    let magnitude = value.abs();
    let mut exponent = magnitude.log10().floor() as i32;
    if exponent >= 3 {
        return format!("{value:.2e}");
    }
    let mut scale = 10f64.powi(2 - exponent);
    if !scale.is_finite() {
        // 10^(2 - exponent) overflows for magnitudes below ~1e-306; fixed-point
        // output would need more than 308 decimals, so use scientific notation.
        return format!("{value:.2e}");
    }
    // `scaled` holds the three significant digits as an integer in [100, 1000).
    let mut scaled = (magnitude * scale).round();
    if scaled >= 1000.0 {
        // Rounding carried into the next decade (e.g. 9.996 -> 10.00); redo it
        // one decade up so exactly three significant digits remain.
        exponent += 1;
        if exponent >= 3 {
            return format!("{value:.2e}");
        }
        scale = 10f64.powi(2 - exponent);
        scaled = (magnitude * scale).round();
    }
    // exponent <= 2 here, so the decimal count is non-negative.
    let decimals = (2 - exponent) as usize;
    let rounded = scaled / scale * value.signum();
    format!("{rounded:.decimals$}")
}
impl Default for TopologySelectOptions {
fn default() -> Self {
Self {
tie_tolerance: 1e-3,
score_scale: TopologyScoreScale::PerObservation,
}
}
}
/// Laplace-approximation evidence objective:
/// `residual_objective + 0.5 log|H| - 0.5 penalty_log_det - 0.5 m ln(2 pi)`,
/// where `m = max(effective_dim - penalty_rank, 0)` is the penalty null-space
/// dimension and `log|H|` comes from [`evidence_hessian_log_det`].
///
/// NOTE(review): assuming `residual_objective` is the penalised negative
/// log-likelihood at the mode, this is a negative log evidence (lower is
/// better) — confirm with callers.
///
/// Returns `NaN` when `effective_dim` or `penalty_rank` is non-finite, when the
/// Hessian log-determinant cannot be computed, or when
/// `effective_dim - penalty_rank` is non-finite or below `-1e-9`.
pub fn laplace_evidence(
    logdet_source: EvidenceLogDetSource<'_>,
    penalty_log_det: f64,
    residual_objective: f64,
    effective_dim: f64,
    penalty_rank: f64,
) -> f64 {
    let dims_finite = effective_dim.is_finite() && penalty_rank.is_finite();
    if !dims_finite {
        return f64::NAN;
    }
    let Ok(log_det_h) = evidence_hessian_log_det(logdet_source) else {
        return f64::NAN;
    };
    let null_dim = effective_dim - penalty_rank;
    // Tiny negative values are rounding noise; clearly negative ones are errors.
    if !null_dim.is_finite() || null_dim < -1e-9 {
        return f64::NAN;
    }
    let log_two_pi = (2.0 * std::f64::consts::PI).ln();
    residual_objective + 0.5 * log_det_h
        - 0.5 * penalty_log_det
        - 0.5 * null_dim.max(0.0) * log_two_pi
}
/// Log-determinant `log|H|` of the evidence Hessian from the given source.
///
/// * `FactoredArrow` uses the exact Cholesky factors in the arrow cache via
///   [`arrow_log_det_from_cache`]. If that path is unavailable (damped cache or
///   missing Schur factor) or yields a non-finite value (a factor with a
///   non-positive pivot), the optional `fallback_hvp` operator is used instead.
/// * `Hvp` always uses [`hessian_log_det_from_hvp`].
///
/// # Errors
/// Returns `Err` when the exact path is unusable and no HVP fallback was
/// supplied, or when the HVP path itself fails.
pub fn evidence_hessian_log_det(source: EvidenceLogDetSource<'_>) -> Result<f64, String> {
    match source {
        EvidenceLogDetSource::FactoredArrow {
            cache,
            fallback_hvp,
        } => match (arrow_log_det_from_cache(cache), fallback_hvp) {
            (Some(v), _) if v.is_finite() => Ok(v),
            // Exact factors missing or unusable: defer to the HVP operator when given.
            (_, Some(hvp)) => hessian_log_det_from_hvp(hvp),
            (Some(v), None) => Err(format!(
                "evidence Hessian logdet from exact factors is non-finite ({v}) and no HVP fallback is available"
            )),
            (None, None) => {
                Err("evidence Hessian logdet requires exact factors or HVP fallback".into())
            }
        },
        EvidenceLogDetSource::Hvp(hvp) => hessian_log_det_from_hvp(hvp),
    }
}
/// `log|H|` for a matrix-free symmetric positive-definite operator.
///
/// For `dim == 0` the result is `0`. Up to [`ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD`],
/// the dense matrix is assembled by applying the operator to each unit basis
/// vector. It is then checked for symmetry, symmetrised by averaging mirrored
/// entries, and passed to a dense eigendecomposition. Larger operators use the
/// stochastic Lanczos estimator.
///
/// # Errors
/// Returns `Err` if an operator output has the wrong length or non-finite
/// entries, if the operator is not symmetric within tolerance, or if the dense
/// or stochastic log-determinant fails.
pub fn hessian_log_det_from_hvp(hvp: EvidenceHvpLogDet<'_>) -> Result<f64, String> {
    let n = hvp.dim;
    if n == 0 {
        return Ok(0.0);
    }
    if n > ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD {
        return stochastic_hvp_log_det(hvp);
    }
    let mut dense = Array2::<f64>::zeros((n, n));
    let mut unit = vec![0.0_f64; n];
    for j in 0..n {
        unit[j] = 1.0;
        let column = (hvp.apply)(&unit);
        unit[j] = 0.0;
        let well_formed = column.len() == n && column.iter().all(|v| v.is_finite());
        if !well_formed {
            return Err(format!(
                "evidence HVP logdet expected finite column of length {}, got {}",
                n,
                column.len()
            ));
        }
        for (i, &value) in column.iter().enumerate() {
            dense[[i, j]] = value;
        }
    }
    validate_dense_hvp_symmetry(&dense)?;
    // Remove residual (within-tolerance) asymmetry before the symmetric eigensolve.
    for i in 0..n {
        for j in (i + 1)..n {
            let mean = 0.5 * (dense[[i, j]] + dense[[j, i]]);
            dense[[i, j]] = mean;
            dense[[j, i]] = mean;
        }
    }
    dense_spd_log_det(&dense)
}
/// `log|M|` for a dense symmetric positive-definite matrix.
///
/// When CUDA is selected, the computation is delegated to the GPU REML
/// routine. Otherwise the result is the sum of the logs of the eigenvalues from
/// `eigh` (lower triangle).
///
/// # Errors
/// Returns `Err` for a non-square matrix, an eigendecomposition failure, a GPU
/// failure, or any eigenvalue that is non-finite or not strictly positive.
fn dense_spd_log_det(matrix: &Array2<f64>) -> Result<f64, String> {
    let (rows, cols) = matrix.dim();
    if rows != cols {
        return Err(format!(
            "evidence dense logdet requires square matrix, got {}x{}",
            rows, cols
        ));
    }
    if crate::gpu::cuda_selected() {
        return crate::solver::gpu::reml_gpu::evidence_derivatives_gpu(
            crate::solver::gpu::reml_gpu::RemlGpuInput {
                penalized_hessian: matrix.view(),
                derivative_hessians: Vec::new(),
            },
        )
        .map(|evidence| evidence.logdet_hessian);
    }
    let (eigenvalues, _) = matrix
        .eigh(Side::Lower)
        .map_err(|e| format!("evidence dense logdet eigendecomposition failed: {e}"))?;
    eigenvalues
        .iter()
        .enumerate()
        .try_fold(0.0_f64, |logdet, (idx, &ev)| {
            if ev.is_finite() && ev > 0.0 {
                Ok(logdet + ev.ln())
            } else {
                Err(format!(
                    "evidence dense logdet expected SPD Hessian, eigenvalue {idx} is {ev:.3e}"
                ))
            }
        })
}
/// Rejects a dense operator whose relative skew `||M - M^T||_F / max(||M||_F, 1)`
/// is non-finite or exceeds [`EVIDENCE_HVP_SYMMETRY_REL_TOL`].
fn validate_dense_hvp_symmetry(matrix: &Array2<f64>) -> Result<(), String> {
    let frobenius = matrix.iter().fold(0.0_f64, |acc, &v| acc + v * v).sqrt();
    let n = matrix.nrows();
    // Each off-diagonal pair appears twice in ||M - M^T||_F^2, hence the factor 2.
    let mut skew_sq = 0.0_f64;
    for row in 0..n {
        for col in (row + 1)..n {
            let gap = matrix[[row, col]] - matrix[[col, row]];
            skew_sq += 2.0 * gap * gap;
        }
    }
    let rel_skew = skew_sq.sqrt() / frobenius.max(1.0);
    if rel_skew.is_finite() && rel_skew <= EVIDENCE_HVP_SYMMETRY_REL_TOL {
        Ok(())
    } else {
        Err(format!(
            "evidence HVP logdet requires symmetric operator, relative skew norm is {rel_skew:.3e}"
        ))
    }
}
/// Matrix-free symmetry check: for deterministic Rademacher unit probes `x`, `y`,
/// compares `x^T (H y)` with `(H x)^T y` relative to a scale of at least 1.
///
/// # Errors
/// Returns `Err` if an operator output has the wrong length or non-finite
/// entries, or if any probe's relative mismatch is non-finite or exceeds
/// [`EVIDENCE_HVP_SYMMETRY_REL_TOL`].
fn validate_hvp_randomized_symmetry(hvp: EvidenceHvpLogDet<'_>) -> Result<(), String> {
    let n = hvp.dim;
    let inv_norm = 1.0 / (n as f64).sqrt();
    let check_output = |image: &[f64]| -> Result<(), String> {
        if image.len() == n && image.iter().all(|v| v.is_finite()) {
            Ok(())
        } else {
            Err(format!(
                "evidence HVP symmetry check expected finite vector of length {}, got {}",
                n,
                image.len()
            ))
        }
    };
    let probe_count = EVIDENCE_HVP_SYMMETRY_PROBES.max(1);
    for probe in 0..probe_count {
        // Distinct seeds for the two probes of each pair.
        let mut x = vec![0.0_f64; n];
        let mut y = vec![0.0_f64; n];
        rademacher_unit_probe_into_slice(&mut x, (2 * probe) as u64, inv_norm);
        rademacher_unit_probe_into_slice(&mut y, (2 * probe + 1) as u64, inv_norm);
        let hx = (hvp.apply)(&x);
        let hy = (hvp.apply)(&y);
        check_output(&hx)?;
        check_output(&hy)?;
        let lhs = dot_slice(&x, &hy);
        let rhs = dot_slice(&hx, &y);
        let scale = (norm2_slice(&hx) * norm2_slice(&y))
            .max(norm2_slice(&hy) * norm2_slice(&x))
            .max(lhs.abs())
            .max(rhs.abs())
            .max(1.0);
        let rel = (lhs - rhs).abs() / scale;
        let acceptable = rel.is_finite() && rel <= EVIDENCE_HVP_SYMMETRY_REL_TOL;
        if !acceptable {
            return Err(format!(
                "evidence HVP logdet requires symmetric operator, randomized symmetry probe {probe} has relative bilinear mismatch {rel:.3e}"
            ));
        }
    }
    Ok(())
}
/// Stochastic Lanczos quadrature estimate of `log|H|`.
///
/// After a randomized symmetry check, each of the deterministic Rademacher unit
/// probes `z` yields a Lanczos quadrature estimate of `z^T log(H) z`. The
/// result is `dim` times the mean over probes.
fn stochastic_hvp_log_det(hvp: EvidenceHvpLogDet<'_>) -> Result<f64, String> {
    validate_hvp_randomized_symmetry(hvp)?;
    let n = hvp.dim;
    let probe_count = EVIDENCE_LOGDET_SLQ_PROBES.max(1);
    let lanczos_steps = EVIDENCE_LOGDET_LANCZOS_STEPS.min(n).max(1);
    let unit_scale = 1.0 / (n as f64).sqrt();
    let dim_f = n as f64;
    let mut total = 0.0_f64;
    for seed in 0..probe_count {
        let mut start = vec![0.0_f64; n];
        rademacher_unit_probe_into_slice(&mut start, seed as u64, unit_scale);
        let quadrature = lanczos_log_quadrature_hvp(hvp, start, lanczos_steps)?;
        total += dim_f * quadrature;
    }
    Ok(total / probe_count as f64)
}
/// Runs symmetric Lanczos from start vector `q` for at most `max_steps` steps on
/// the HVP operator and returns the log-quadrature value from
/// `symmetric_lanczos_log_quadrature`.
///
/// Both reorthogonalisation options are disabled. Operator outputs of the wrong
/// length or with non-finite entries abort the run with an error.
fn lanczos_log_quadrature_hvp(
    hvp: EvidenceHvpLogDet<'_>,
    q: Vec<f64>,
    max_steps: usize,
) -> Result<f64, String> {
    let dim = hvp.dim;
    let options = SymmetricLanczosOptions {
        max_steps,
        residual_tol: 1e-12,
        local_reorthogonalize: false,
        full_reorthogonalize: false,
    };
    let eigen = symmetric_lanczos_eigenpairs(dim, &q, options, |vector, out| {
        let image = (hvp.apply)(vector);
        let usable = image.len() == dim && image.iter().all(|v| v.is_finite());
        if !usable {
            return Err(format!(
                "evidence HVP SLQ expected finite vector of length {dim}, got {}",
                image.len()
            ));
        }
        out.copy_from_slice(&image);
        Ok(())
    })
    .map_err(|e| format!("evidence HVP SLQ Lanczos failed: {e}"))?;
    symmetric_lanczos_log_quadrature(&eigen, "evidence HVP SLQ expected SPD Hessian")
}
/// Dot product of two equal-length slices, accumulated left to right.
///
/// # Panics
/// Panics if the slice lengths differ.
#[inline]
fn dot_slice(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len());
    a.iter().zip(b).fold(0.0_f64, |acc, (x, y)| acc + x * y)
}
/// Euclidean norm of a slice.
#[inline]
fn norm2_slice(a: &[f64]) -> f64 {
    let squared = dot_slice(a, a);
    squared.sqrt()
}
/// Fills `z` with a deterministic Rademacher vector `±scale` seeded by `probe`.
///
/// One 64-bit splitmix64 draw supplies the signs for each run of 64 consecutive
/// entries, consumed least-significant bit first (bit 0 -> `+scale`).
fn rademacher_unit_probe_into_slice(z: &mut [f64], probe: u64, scale: f64) {
    let mut state = 0x6A09E667F3BCC909_u64 ^ probe.wrapping_mul(0xD1B54A32D192ED03);
    for chunk in z.chunks_mut(64) {
        let mut word = splitmix64(&mut state);
        for entry in chunk.iter_mut() {
            *entry = if word & 1 == 0 { scale } else { -scale };
            word >>= 1;
        }
    }
}
/// Local alias for the crate's splitmix64 generator; advances `state` and
/// returns the next 64-bit output.
#[inline]
const fn splitmix64(state: &mut u64) -> u64 {
    crate::linalg::utils::splitmix64(state)
}
/// Exact `log|H|` from an undamped arrow factor cache.
///
/// The result is `2 * sum(log diag)` over every undamped row factor plus the
/// same quantity for the Schur factor. Returns `None` when either ridge is
/// non-zero (including NaN) or no Schur factor is cached. Returns `Some(NaN)`
/// if any factor has a non-positive diagonal.
pub fn arrow_log_det_from_cache(cache: &ArrowFactorCache) -> Option<f64> {
    let undamped = cache.ridge_t == 0.0 && cache.ridge_beta == 0.0;
    if !undamped {
        return None;
    }
    let schur = cache.schur_factor.as_ref()?;
    let row_blocks = cache
        .undamped_factors_iter()
        .fold(0.0_f64, |acc, l| acc + 2.0 * log_det_from_chol_lower(l));
    Some(row_blocks + 2.0 * log_det_from_chol_lower(schur))
}
/// Sum of the logs of the first `nrows` diagonal entries of a lower Cholesky
/// factor, i.e. half of `log|L L^T|`.
///
/// Returns `NaN` as soon as a diagonal entry is not strictly positive.
fn log_det_from_chol_lower(l: &Array2<f64>) -> f64 {
    (0..l.nrows())
        .try_fold(0.0_f64, |acc, i| {
            let pivot = l[[i, i]];
            (pivot > 0.0).then(|| acc + pivot.ln())
        })
        .unwrap_or(f64::NAN)
}
/// Implicit-function sensitivity of the per-row modes to the shared
/// coefficients, stacked over rows: row block `i` of the result is
/// `-F_i^{-1} B_i`.
///
/// Here `F_i` is solved with `cache.undamped_factor(i)` via
/// `cholesky_solve_vector`, and column `col` of `B_i` is what
/// `apply_htbeta_row` writes for the unit vector `e_col`. NOTE(review): this is
/// presumably `-H_uu_i^{-1} H_{u_i beta}`.
///
/// Output shape is `(delta_t_len, k)`. Returns an all-NaN matrix when the cache
/// has no cross blocks or `apply_htbeta_row` reports failure.
pub fn ift_du_dbeta(cache: &ArrowFactorCache) -> Array2<f64> {
    let n = cache.undamped_factor_count();
    let total_len = cache.delta_t_len();
    let k = cache.k;
    let nan_result = || Array2::<f64>::from_elem((total_len, k), f64::NAN);
    if !cache.htbeta_available() {
        return nan_result();
    }
    let mut out = Array2::<f64>::zeros((total_len, k));
    let mut beta_basis = Array1::<f64>::zeros(k);
    for i in 0..n {
        let di = cache.row_dims[i];
        let row_base = cache.row_offsets[i];
        let factor = cache.undamped_factor(i);
        // One buffer per row, reset to zero before every column so
        // `apply_htbeta_row` always starts from a zeroed output.
        let mut htbeta_col = Array1::<f64>::zeros(di);
        for col in 0..k {
            beta_basis.fill(0.0);
            beta_basis[col] = 1.0;
            htbeta_col.fill(0.0);
            if !cache.apply_htbeta_row(i, beta_basis.view(), &mut htbeta_col) {
                return nan_result();
            }
            let y = cholesky_solve_vector(factor, &htbeta_col);
            for c in 0..di {
                out[[row_base + c, col]] = -y[c];
            }
        }
    }
    out
}
/// Computes `-A^{-1} G` column by column, where `G = dg_drho` (`beta_dim x r`)
/// and `solve_beta_hessian(rhs)` returns `A^{-1} rhs`.
///
/// Returns `None` when `dg_drho` does not have `beta_dim` rows, or when any
/// solve returns a vector of the wrong length or with non-finite entries.
pub(crate) fn ift_dbeta_drho_from_solver(
    beta_dim: usize,
    dg_drho: ArrayView2<'_, f64>,
    mut solve_beta_hessian: impl FnMut(&Array1<f64>) -> Array1<f64>,
) -> Option<Array2<f64>> {
    let (rows, n_rho) = dg_drho.dim();
    if rows != beta_dim {
        return None;
    }
    let mut sensitivity = Array2::<f64>::zeros((beta_dim, n_rho));
    let mut rhs = Array1::<f64>::zeros(beta_dim);
    for a in 0..n_rho {
        rhs.assign(&dg_drho.column(a));
        let solved = solve_beta_hessian(&rhs);
        let valid = solved.len() == beta_dim && solved.iter().all(|v| v.is_finite());
        if !valid {
            return None;
        }
        for (dst, &src) in sensitivity.column_mut(a).iter_mut().zip(solved.iter()) {
            *dst = -src;
        }
    }
    Some(sensitivity)
}
pub fn ift_dbeta_drho(
cache: &ArrowFactorCache,
dg_red_drho: ArrayView2<'_, f64>,
) -> Option<Array2<f64>> {
if cache.ridge_t != 0.0 || cache.ridge_beta != 0.0 {
return None;
}
let schur = cache.schur_factor.as_ref()?;
ift_dbeta_drho_from_solver(cache.k, dg_red_drho, |rhs| {
cholesky_solve_vector(schur, rhs)
})
}
/// Total sensitivity of the per-row modes to the smoothing parameters.
///
/// For each `rho_a` and row block `i`, this returns
/// `-F_i^{-1} (gu_rho[block_i, a] + B_i dbeta_drho[:, a])`, where `F_i` is
/// solved with `cache.undamped_factor(i)` and `B_i x` is what
/// `apply_htbeta_row` writes. NOTE(review): this is presumably the IFT
/// `du/drho = -H_uu^{-1} (dg_u/drho + H_{u beta} dbeta/drho)`.
///
/// Output shape is `(delta_t_len, r)` with `r = dbeta_drho.ncols()`. Returns an
/// all-NaN matrix when cross blocks are unavailable, when `gu_rho` is not
/// `(delta_t_len, r)`, when `dbeta_drho` does not have `k` rows, or when
/// `apply_htbeta_row` reports failure.
pub fn ift_du_drho(
    cache: &ArrowFactorCache,
    gu_rho: ArrayView2<'_, f64>,
    dbeta_drho: ArrayView2<'_, f64>,
) -> Array2<f64> {
    let n = cache.undamped_factor_count();
    let total_len = cache.delta_t_len();
    let k = cache.k;
    let r = dbeta_drho.ncols();
    let nan_result = || Array2::<f64>::from_elem((total_len, r), f64::NAN);
    if !cache.htbeta_available()
        || gu_rho.nrows() != total_len
        || gu_rho.ncols() != r
        || dbeta_drho.nrows() != k
    {
        return nan_result();
    }
    let mut out = Array2::<f64>::zeros((total_len, r));
    for a in 0..r {
        let dbeta_a = dbeta_drho.column(a);
        for i in 0..n {
            let di = cache.row_dims[i];
            let row_base = cache.row_offsets[i];
            // Start from a zeroed buffer, let the cache write B_i dbeta_a into it,
            // then add the direct term to form the right-hand side in place.
            let mut rhs_i = Array1::<f64>::zeros(di);
            if !cache.apply_htbeta_row(i, dbeta_a, &mut rhs_i) {
                return nan_result();
            }
            for c in 0..di {
                rhs_i[c] += gu_rho[[row_base + c, a]];
            }
            let v = cholesky_solve_vector(cache.undamped_factor(i), &rhs_i);
            for c in 0..di {
                out[[row_base + c, a]] = -v[c];
            }
        }
    }
    out
}
/// Inputs to [`evidence_ift_gradient_correction`].
///
/// `dbeta_drho` is `k x r` and `du_drho` is `nd x r`. The `*_beta` vectors have
/// length `k` and the `*_u` vectors length `nd`. NOTE(review): `value_*` and
/// `logdet_h_*` are presumably the partial derivatives of the objective value
/// and of `log|H|` with respect to the modes — confirm with callers.
#[derive(Clone)]
pub struct EvidenceIftGradientTerms<'a> {
    pub dbeta_drho: ArrayView2<'a, f64>,
    pub du_drho: ArrayView2<'a, f64>,
    pub value_beta: ArrayView1<'a, f64>,
    pub value_u: ArrayView1<'a, f64>,
    pub logdet_h_beta: ArrayView1<'a, f64>,
    pub logdet_h_u: ArrayView1<'a, f64>,
}
/// Chain-rule correction through the mode sensitivities. For each column `a`:
///
/// `sum_j (value_beta[j] + 0.5 logdet_h_beta[j]) dbeta_drho[j, a]`
/// `+ sum_j (value_u[j] + 0.5 logdet_h_u[j]) du_drho[j, a]`
///
/// Returns a length-`r` vector of NaN when the input lengths are inconsistent.
pub fn evidence_ift_gradient_correction(terms: EvidenceIftGradientTerms<'_>) -> Array1<f64> {
    let (k, r) = terms.dbeta_drho.dim();
    let (nd, du_cols) = terms.du_drho.dim();
    let shapes_ok = du_cols == r
        && terms.value_beta.len() == k
        && terms.logdet_h_beta.len() == k
        && terms.value_u.len() == nd
        && terms.logdet_h_u.len() == nd;
    if !shapes_ok {
        return Array1::<f64>::from_elem(r, f64::NAN);
    }
    // One term: add value*mode, then 0.5*logdet*mode (two sequential additions).
    let step = |acc: f64, mode: f64, value: f64, logdet: f64| {
        acc + value * mode + 0.5 * logdet * mode
    };
    Array1::from_shape_fn(r, |a| {
        let beta_part = (0..k).fold(0.0_f64, |acc, j| {
            step(
                acc,
                terms.dbeta_drho[[j, a]],
                terms.value_beta[j],
                terms.logdet_h_beta[j],
            )
        });
        (0..nd).fold(beta_part, |acc, j| {
            step(
                acc,
                terms.du_drho[[j, a]],
                terms.value_u[j],
                terms.logdet_h_u[j],
            )
        })
    })
}
/// Gradient of the Laplace-evidence objective with respect to the smoothing
/// parameters `rho`, assembled from the arrow (row-block + border) structure.
///
/// For each `a` the result is
/// `value_rho[a] + 0.5 * (row_trace + schur_trace) + ift_correction[a] - 0.5 * pen_logdet_drho[a]`
/// where
/// * `row_trace = sum_i tr(F_i^{-1} huu_drho[i][a])`, with `F_i` solved through
///   `cache.undamped_factor(i)`;
/// * `schur_trace = tr(S^{-1} dA)`, with `S` solved through `cache.schur_factor`;
/// * `dA = hbb_drho[a] - sum_i (D_i^T Y_i + Y_i^T D_i - Y_i^T huu_drho[i][a] Y_i)`,
///   where `D_i = htbeta_drho[i][a]` and `Y_i = F_i^{-1} B_i`. Column `col` of
///   `B_i` is what `apply_htbeta_row` writes for the unit vector `e_col`;
/// * `ift_correction` comes from [`evidence_ift_gradient_correction`].
///
/// NOTE(review): this reads as `d log|H| / d rho` through
/// `log|H| = sum_i log|H_uu_i| + log|S|`, with `S` the Schur complement of the
/// coefficient block — confirm against the cache construction.
///
/// Returns an all-NaN vector when the cache has no cross blocks or no Schur
/// factor, when any input shape disagrees (checked up front), when the IFT
/// correction contains NaN, or when `apply_htbeta_row` reports failure.
pub fn evidence_grad_rho(
    cache: &ArrowFactorCache,
    value_rho: ArrayView1<'_, f64>,
    huu_drho: &[Vec<Array2<f64>>],
    htbeta_drho: &[Vec<Array2<f64>>],
    hbb_drho: &[Array2<f64>],
    pen_logdet_drho: ArrayView1<'_, f64>,
    ift_terms: EvidenceIftGradientTerms<'_>,
) -> Array1<f64> {
    let r = value_rho.len();
    let n = cache.undamped_factor_count();
    let k = cache.k;
    let mut out = Array1::<f64>::zeros(r);
    // Shape validation: huu_drho[i][a] is d_i x d_i, htbeta_drho[i][a] is d_i x k,
    // hbb_drho[a] is k x k, and every per-rho list has length r.
    if !cache.htbeta_available()
        || pen_logdet_drho.len() != r
        || huu_drho.len() != n
        || htbeta_drho.len() != n
        || hbb_drho.len() != r
        || huu_drho.iter().any(|row| row.len() != r)
        || htbeta_drho.iter().any(|row| row.len() != r)
        || hbb_drho.iter().any(|m| m.nrows() != k || m.ncols() != k)
        || huu_drho.iter().enumerate().any(|(i, row)| {
            let di = cache.row_dims[i];
            row.iter().any(|m| m.nrows() != di || m.ncols() != di)
        })
        || htbeta_drho.iter().enumerate().any(|(i, row)| {
            let di = cache.row_dims[i];
            row.iter().any(|m| m.nrows() != di || m.ncols() != k)
        })
    {
        out.fill(f64::NAN);
        return out;
    }
    let ift_correction = evidence_ift_gradient_correction(ift_terms);
    if ift_correction.len() != r || ift_correction.iter().any(|v| v.is_nan()) {
        out.fill(f64::NAN);
        return out;
    }
    let schur = match cache.schur_factor.as_ref() {
        Some(s) => s,
        None => {
            for a in 0..r {
                out[a] = f64::NAN;
            }
            return out;
        }
    };
    // Precompute Y_i = F_i^{-1} B_i (d_i x k) for every row block; these do not
    // depend on rho and are reused for every gradient component.
    let mut y_blocks: Vec<Array2<f64>> = Vec::with_capacity(n);
    let mut beta_basis = Array1::<f64>::zeros(k);
    let mut rhs = Array1::<f64>::zeros(cache.d);
    for i in 0..n {
        let di = cache.row_dims[i];
        let factor = cache.undamped_factor(i);
        let mut yi = Array2::<f64>::zeros((di, k));
        for col in 0..k {
            beta_basis.fill(0.0);
            beta_basis[col] = 1.0;
            // `to_owned` copies a zeroed slice of `rhs`; `rhs` itself is never written.
            let mut rhs_i = rhs.slice_mut(ndarray::s![..di]).to_owned();
            if !cache.apply_htbeta_row(i, beta_basis.view(), &mut rhs_i) {
                out.fill(f64::NAN);
                return out;
            }
            let v = cholesky_solve_vector(factor, &rhs_i);
            for c in 0..di {
                yi[[c, col]] = v[c];
            }
        }
        y_blocks.push(yi);
    }
    let mut trace_rhs = Array1::<f64>::zeros(cache.d);
    let mut da_tmp = Array2::<f64>::zeros((cache.d, k));
    let mut col_scratch = Array1::<f64>::zeros(k);
    for a in 0..r {
        let mut grad = value_rho[a];
        // row_trace = sum_i tr(F_i^{-1} M_i): solve against each column of M_i and
        // keep the diagonal entry of the solution.
        let mut row_trace_acc = 0.0_f64;
        for i in 0..n {
            let di = cache.row_dims[i];
            let m_i = &huu_drho[i][a];
            assert_eq!(m_i.shape(), &[di, di]);
            for col in 0..di {
                let mut tr_rhs_i = trace_rhs.slice_mut(ndarray::s![..di]).to_owned();
                for r0 in 0..di {
                    tr_rhs_i[r0] = m_i[[r0, col]];
                }
                let v = cholesky_solve_vector(cache.undamped_factor(i), &tr_rhs_i);
                row_trace_acc += v[col];
            }
        }
        // dA = dH_bb - sum_i (D_i^T Y_i + Y_i^T D_i - Y_i^T dH_uu_i Y_i).
        let mut da = hbb_drho[a].clone();
        assert_eq!(da.shape(), &[k, k]);
        for i in 0..n {
            let di = cache.row_dims[i];
            let dhtb = &htbeta_drho[i][a];
            let yi = &y_blocks[i];
            // Subtract D_i^T Y_i.
            for r0 in 0..k {
                for c0 in 0..k {
                    let mut acc = 0.0;
                    for cc in 0..di {
                        acc += dhtb[[cc, r0]] * yi[[cc, c0]];
                    }
                    da[[r0, c0]] -= acc;
                }
            }
            // Subtract Y_i^T D_i.
            for r0 in 0..k {
                for c0 in 0..k {
                    let mut acc = 0.0;
                    for cc in 0..di {
                        acc += yi[[cc, r0]] * dhtb[[cc, c0]];
                    }
                    da[[r0, c0]] -= acc;
                }
            }
            // Add Y_i^T (dH_uu_i Y_i), via the temporary da_tmp_i = dH_uu_i Y_i (d_i x k).
            let dhuu = &huu_drho[i][a];
            let mut da_tmp_i = da_tmp.slice_mut(ndarray::s![..di, ..]).to_owned();
            for r0 in 0..di {
                for c0 in 0..k {
                    let mut acc = 0.0;
                    for cc in 0..di {
                        acc += dhuu[[r0, cc]] * yi[[cc, c0]];
                    }
                    da_tmp_i[[r0, c0]] = acc;
                }
            }
            for r0 in 0..k {
                for c0 in 0..k {
                    let mut acc = 0.0;
                    for cc in 0..di {
                        acc += yi[[cc, r0]] * da_tmp_i[[cc, c0]];
                    }
                    da[[r0, c0]] += acc;
                }
            }
        }
        // schur_trace = tr(S^{-1} dA), again via column solves plus the diagonal.
        let mut schur_trace_acc = 0.0_f64;
        for j in 0..k {
            for r0 in 0..k {
                col_scratch[r0] = da[[r0, j]];
            }
            let v = cholesky_solve_vector(schur, &col_scratch);
            schur_trace_acc += v[j];
        }
        grad += 0.5 * (row_trace_acc + schur_trace_acc);
        grad += ift_correction[a];
        grad -= 0.5 * pen_logdet_drho[a];
        out[a] = grad;
    }
    out
}
/// Chooses the best topology among fitted candidates.
///
/// A candidate is selectable when it converged, has no exclusion reason, and
/// both its negative log evidence and its scaled score are finite. Selectable
/// candidates are ordered by `rank_priority_candidates` on the scaled score,
/// with `complexity_rank` as the tie-break key. If the top two scores differ by
/// at most `options.tie_tolerance`, every candidate within that tolerance of the
/// top score is re-sorted (stably) so the simplest topology comes first.
/// Non-selectable candidates are appended after the ranked ones.
///
/// # Panics
/// Panics when no candidate is selectable; there is deliberately no silent
/// fallback.
pub fn select_topology(
    candidates: &[TopologyCandidate],
    options: TopologySelectOptions,
) -> SelectedTopology {
    let scale = options.score_scale;
    let is_selectable = |c: &TopologyCandidate| {
        c.converged
            && c.exclusion_reason.is_none()
            && c.negative_log_evidence.is_finite()
            && topology_selection_score(c, scale).is_finite()
    };
    let (selectable, mut excluded): (Vec<TopologyCandidate>, Vec<TopologyCandidate>) =
        candidates.iter().cloned().partition(|c| is_selectable(c));
    assert!(
        !selectable.is_empty(),
        "select_topology: no finite valid candidates; proposal §6.11 forbids silent fallback"
    );
    let mut ranked: Vec<TopologyCandidate> = rank_priority_candidates(
        selectable
            .into_iter()
            .enumerate()
            .map(|(position, candidate)| {
                let score = topology_selection_score(&candidate, scale);
                let tie_break = usize::from(candidate.kind.complexity_rank());
                PriorityCandidate::new(candidate, position, score, tie_break)
            })
            .collect(),
    )
    .into_iter()
    .map(|entry| entry.item)
    .collect();
    let top_score = topology_selection_score(&ranked[0], scale);
    let tie = ranked.len() >= 2
        && (topology_selection_score(&ranked[1], scale) - top_score).abs()
            <= options.tie_tolerance;
    if tie {
        let tied_end = ranked
            .iter()
            .position(|c| {
                (topology_selection_score(c, scale) - top_score).abs() > options.tie_tolerance
            })
            .unwrap_or(ranked.len());
        ranked[..tied_end].sort_by_key(|c| c.kind.complexity_rank());
    }
    let winner = ranked[0].kind;
    ranked.append(&mut excluded);
    SelectedTopology {
        winner,
        ranking: ranked,
        tie,
    }
}
/// Negative log evidence normalised per the chosen scale.
///
/// Returns `NaN` when the normaliser is unusable: `n_obs == 0` for
/// `PerObservation`, or a non-finite or non-positive `effective_dim` for
/// `PerEffectiveDim`.
fn topology_selection_score(candidate: &TopologyCandidate, scale: TopologyScoreScale) -> f64 {
    let denominator = match scale {
        TopologyScoreScale::PerObservation if candidate.n_obs == 0 => return f64::NAN,
        TopologyScoreScale::PerObservation => candidate.n_obs as f64,
        TopologyScoreScale::PerEffectiveDim => {
            let dim = candidate.effective_dim;
            if dim.is_finite() && dim > 0.0 {
                dim
            } else {
                return f64::NAN;
            }
        }
    };
    candidate.negative_log_evidence / denominator
}
/// True when the cache can provide exact evidence quantities: no ridge damping,
/// a cached Schur factor, and available cross (`H_{t beta}`) blocks.
pub fn cache_supports_exact_evidence(cache: &ArrowFactorCache) -> bool {
    let undamped = cache.ridge_t == 0.0 && cache.ridge_beta == 0.0;
    undamped && cache.schur_factor.is_some() && cache.htbeta_available()
}
/// True when `cache` was built for `sys`.
///
/// This requires matching dimensions (`d`, `k`), matching row counts for both
/// the cached rows and the undamped factors, the same manifold-mode
/// fingerprint, and a row-Hessian fingerprint equal to the system's current one.
pub fn cache_matches_system(cache: &ArrowFactorCache, sys: &ArrowSchurSystem) -> bool {
    let n_sys_rows = sys.rows.len();
    let shapes_agree = cache.d == sys.d
        && cache.k == sys.k
        && cache.n_rows() == n_sys_rows
        && cache.undamped_factor_count() == n_sys_rows;
    shapes_agree
        && cache.manifold_mode_fingerprint == sys.manifold_mode_fingerprint
        && cache.row_hessian_fingerprint == sys.current_row_hessian_fingerprint()
}
// Unit tests for the evidence, IFT-sensitivity, topology-selection and stacking helpers.
#[cfg(test)]
mod tests {
    use super::*;
    // Minimal undamped arrow cache: one row block of dimension 1 and one coefficient.
    // The row factor sqrt(2) encodes a row Hessian of 2, the cross block is 0.5, and the
    // Schur factor sqrt(1.875) encodes S = 1.875. That is consistent with a coefficient
    // Hessian of 2, since 2 - 0.5^2 / 2 = 1.875; the coefficient block itself is not
    // stored in the cache.
    fn make_minimal_cache() -> ArrowFactorCache {
        let l_huu = Array2::from_shape_vec((1, 1), vec![std::f64::consts::SQRT_2]).unwrap();
        let l_schur = Array2::from_shape_vec((1, 1), vec![(1.875_f64).sqrt()]).unwrap();
        let htbeta = Array2::from_shape_vec((1, 1), vec![0.5]).unwrap();
        ArrowFactorCache {
            htt_factors: std::sync::Arc::from(vec![l_huu]),
            htt_factors_undamped: crate::solver::arrow_schur::ArrowUndampedFactors::SameAsDamped,
            schur_factor: Some(l_schur),
            solver_mode: crate::solver::arrow_schur::ArrowSolverMode::Direct,
            ridge_t: 0.0,
            ridge_beta: 0.0,
            htbeta: crate::solver::arrow_schur::ArrowHtbetaCache::Dense {
                blocks: std::sync::Arc::from(vec![htbeta]),
                estimated_bytes: std::mem::size_of::<f64>(),
            },
            d: 1,
            row_dims: std::sync::Arc::from(vec![1usize]),
            row_offsets: std::sync::Arc::from(vec![0usize, 1usize]),
            k: 1,
            manifold_mode_fingerprint: 0,
            row_hessian_fingerprint: 0,
            pcg_diagnostics: crate::solver::arrow_schur::PcgDiagnostics::default(),
        }
    }
    // log|H| = 2 ln sqrt(2) + 2 ln sqrt(1.875) = ln 2 + ln 1.875. With
    // effective_dim - penalty_rank = 1 and zero residual and penalty terms, the
    // evidence is 0.5 log|H| - 0.5 ln(2 pi).
    #[test]
    fn laplace_evidence_returns_finite_for_minimal_cache() {
        let cache = make_minimal_cache();
        let v = laplace_evidence(
            EvidenceLogDetSource::FactoredArrow {
                cache: &cache,
                fallback_hvp: None,
            },
            0.0,
            0.0,
            2.0,
            1.0,
        );
        assert!(v.is_finite());
        let expected =
            0.5 * (2.0_f64.ln() + 1.875_f64.ln()) - 0.5 * (2.0 * std::f64::consts::PI).ln();
        assert!((v - expected).abs() < 1e-12);
    }
    // A non-zero ridge disables the exact path; with no HVP fallback the evidence is NaN.
    #[test]
    fn laplace_evidence_nan_when_ridge_is_nonzero() {
        let mut cache = make_minimal_cache();
        cache.ridge_t = 1e-3;
        assert!(
            laplace_evidence(
                EvidenceLogDetSource::FactoredArrow {
                    cache: &cache,
                    fallback_hvp: None,
                },
                0.0,
                0.0,
                2.0,
                1.0,
            )
            .is_nan()
        );
    }
    // With no Schur factor, the diagonal HVP diag(2, 1.875) has the same log-determinant
    // as the factored cache, so the expected value matches the exact-path test.
    #[test]
    fn laplace_evidence_uses_hvp_fallback_without_schur_factor() {
        let mut cache = make_minimal_cache();
        cache.schur_factor = None;
        let hvp = |x: &[f64]| -> Vec<f64> { vec![2.0 * x[0], 1.875 * x[1]] };
        let v = laplace_evidence(
            EvidenceLogDetSource::FactoredArrow {
                cache: &cache,
                fallback_hvp: Some(EvidenceHvpLogDet {
                    dim: 2,
                    apply: &hvp,
                }),
            },
            0.0,
            0.0,
            2.0,
            1.0,
        );
        let expected =
            0.5 * (2.0_f64.ln() + 1.875_f64.ln()) - 0.5 * (2.0 * std::f64::consts::PI).ln();
        assert!((v - expected).abs() < 1e-12);
    }
    // du/dbeta = -(row Hessian)^{-1} * cross block = -0.5 / 2 = -0.25.
    #[test]
    fn ift_du_dbeta_has_expected_shape() {
        let cache = make_minimal_cache();
        let du_db = ift_du_dbeta(&cache);
        assert_eq!(du_db.shape(), &[1, 1]);
        assert!((du_db[[0, 0]] - (-0.25)).abs() < 1e-12);
    }
    // dbeta/drho = -S^{-1} * 1 = -1 / 1.875.
    #[test]
    fn ift_dbeta_drho_returns_some_for_direct_cache() {
        let cache = make_minimal_cache();
        let q = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
        let out = ift_dbeta_drho(&cache, q.view()).unwrap();
        assert_eq!(out.shape(), &[1, 1]);
        assert!((out[[0, 0]] + 1.0 / 1.875).abs() < 1e-12);
    }
    // Per-observation scores: Flat 0.10, Sphere 0.08 (a gap beyond the 1e-3 tolerance).
    // Torus is excluded, so Sphere wins without a tie.
    #[test]
    fn topology_select_picks_lowest_negative_log_evidence() {
        let candidates = vec![
            TopologyCandidate {
                kind: TopologyKind::Flat,
                negative_log_evidence: 10.0,
                effective_dim: 4.0,
                n_obs: 100,
                converged: true,
                exclusion_reason: None,
            },
            TopologyCandidate {
                kind: TopologyKind::Sphere,
                negative_log_evidence: 8.0,
                effective_dim: 5.0,
                n_obs: 100,
                converged: true,
                exclusion_reason: None,
            },
            TopologyCandidate {
                kind: TopologyKind::Torus,
                negative_log_evidence: f64::NAN,
                effective_dim: 6.0,
                n_obs: 100,
                converged: false,
                exclusion_reason: Some("torus periods missing".to_string()),
            },
        ];
        let sel = select_topology(&candidates, TopologySelectOptions::default());
        assert_eq!(sel.winner, TopologyKind::Sphere);
        assert!(!sel.tie);
    }
    // Scores differ by 1e-8 per observation, which is within tolerance, so the
    // simpler Flat topology wins.
    #[test]
    fn topology_select_tie_breaks_to_simpler() {
        let candidates = vec![
            TopologyCandidate {
                kind: TopologyKind::Sphere,
                negative_log_evidence: 5.0,
                effective_dim: 5.0,
                n_obs: 100,
                converged: true,
                exclusion_reason: None,
            },
            TopologyCandidate {
                kind: TopologyKind::Flat,
                negative_log_evidence: 5.0 + 1e-6,
                effective_dim: 4.0,
                n_obs: 100,
                converged: true,
                exclusion_reason: None,
            },
        ];
        let sel = select_topology(&candidates, TopologySelectOptions::default());
        assert_eq!(sel.winner, TopologyKind::Flat);
        assert!(sel.tie);
    }
    // Normal log-density helper used to build stacking fixtures.
    fn gaussian_logpdf(y: f64, mean: f64, sd: f64) -> f64 {
        let z = (y - mean) / sd;
        -0.5 * (2.0 * std::f64::consts::PI).ln() - sd.ln() - 0.5 * z * z
    }
    #[test]
    fn stacking_single_candidate_gets_full_weight() {
        let log_density = Array2::from_shape_vec((3, 1), vec![-1.0, -2.0, -0.5]).unwrap();
        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
        assert!((out.weights[0] - 1.0).abs() < 1e-12);
        assert_eq!(out.weights.len(), 1);
    }
    #[test]
    fn stacking_dominant_candidate_attracts_nearly_all_weight() {
        let mut log_density = Array2::<f64>::zeros((50, 2));
        for i in 0..50 {
            log_density[[i, 0]] = -0.1;
            log_density[[i, 1]] = -5.0;
        }
        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
        assert!(out.weights[0] > 0.99, "w0 = {}", out.weights[0]);
        assert!(out.weights[1] < 0.01, "w1 = {}", out.weights[1]);
    }
    // Each candidate is better on a different half of the rows, so neither should dominate.
    #[test]
    fn stacking_complementary_candidates_share_weight() {
        let n = 40;
        let mut log_density = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            if i < n / 2 {
                log_density[[i, 0]] = gaussian_logpdf(0.0, 0.0, 0.5);
                log_density[[i, 1]] = gaussian_logpdf(0.0, 1.5, 0.5);
            } else {
                log_density[[i, 0]] = gaussian_logpdf(0.0, 1.5, 0.5);
                log_density[[i, 1]] = gaussian_logpdf(0.0, 0.0, 0.5);
            }
        }
        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
        assert!(
            out.weights[0] > 0.2 && out.weights[0] < 0.8,
            "w0 = {}",
            out.weights[0]
        );
        assert!((out.weights.sum() - 1.0).abs() < 1e-9);
    }
    #[test]
    fn stacking_weights_stay_on_the_simplex() {
        let log_density = Array2::from_shape_vec(
            (3, 3),
            vec![-1.0, -2.0, -3.0, -2.5, -1.0, -2.0, -3.0, -2.0, -1.0],
        )
        .unwrap();
        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
        assert!((out.weights.sum() - 1.0).abs() < 1e-9);
        assert!(out.weights.iter().all(|&w| w >= -1e-12));
    }
    // weight_tol = 0 forces exactly max_iter sweeps, so the scores form a sequence
    // over increasing iteration counts that must not decrease.
    #[test]
    fn stacking_mean_log_score_is_monotone_under_more_iterations() {
        let log_density =
            Array2::from_shape_vec((4, 2), vec![-0.2, -3.0, -3.0, -0.2, -0.5, -1.5, -1.5, -0.5])
                .unwrap();
        let mut prev = f64::NEG_INFINITY;
        for max_iter in [1usize, 2, 4, 8, 32] {
            let out = solve_stacking_weights(
                log_density.view(),
                StackingConfig {
                    max_iter,
                    weight_tol: 0.0,
                },
            )
            .unwrap();
            assert!(
                out.mean_log_score >= prev - 1e-12,
                "log-score decreased at max_iter={max_iter}: {prev} -> {}",
                out.mean_log_score
            );
            prev = out.mean_log_score;
        }
    }
    // Column 1 is never finite, so it is dropped and reported with weight exactly 0.
    #[test]
    fn stacking_dead_candidate_column_is_rejected_and_zero_weighted() {
        let log_density = Array2::from_shape_vec(
            (3, 2),
            vec![
                -1.0,
                f64::NEG_INFINITY,
                -2.0,
                f64::NAN,
                -0.5,
                f64::NEG_INFINITY,
            ],
        )
        .unwrap();
        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
        assert_eq!(out.weights[1], 0.0);
        assert!((out.weights[0] - 1.0).abs() < 1e-12);
    }
    // Row 1 (NaN, -inf) has no finite density and must be ignored.
    #[test]
    fn stacking_rows_with_no_finite_density_are_dropped() {
        let log_density = Array2::from_shape_vec(
            (3, 2),
            vec![-1.0, -2.0, f64::NAN, f64::NEG_INFINITY, -2.0, -1.0],
        )
        .unwrap();
        let out = solve_stacking_weights(log_density.view(), StackingConfig::default()).unwrap();
        assert!((out.weights.sum() - 1.0).abs() < 1e-9);
        assert!(out.mean_log_score.is_finite());
    }
    #[test]
    fn stacking_all_dead_table_errors() {
        let log_density = Array2::from_elem((2, 2), f64::NEG_INFINITY);
        assert!(solve_stacking_weights(log_density.view(), StackingConfig::default()).is_err());
    }
    #[test]
    fn stacked_mean_is_weighted_combination() {
        let weights = Array1::from_vec(vec![0.25, 0.75]);
        let means = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![5.0, 6.0, 7.0]),
        ];
        let out = stacked_predictive_mean(&weights, &means).unwrap();
        assert!((out[0] - (0.25 * 1.0 + 0.75 * 5.0)).abs() < 1e-12);
        assert!((out[2] - (0.25 * 3.0 + 0.75 * 7.0)).abs() < 1e-12);
    }
    #[test]
    fn stacked_mean_rejects_shape_mismatch() {
        let weights = Array1::from_vec(vec![0.5, 0.5]);
        let means = vec![
            Array1::from_vec(vec![1.0, 2.0]),
            Array1::from_vec(vec![3.0]),
        ];
        assert!(stacked_predictive_mean(&weights, &means).is_err());
    }
}