use faer::Side;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::linalg::faer_ndarray::FaerEigh;
use crate::solver::arrow_schur::{ArrowFactorCache, ArrowSchurSystem};
/// Largest HVP dimension for which `hessian_log_det_from_hvp` materialises the dense
/// Hessian column by column and takes an exact eigendecomposition; larger operators use
/// stochastic Lanczos quadrature instead.
pub const ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD: usize = 1024;
/// Number of Rademacher probe vectors averaged by the stochastic log-determinant estimate.
const EVIDENCE_LOGDET_SLQ_PROBES: usize = 16;
/// Maximum number of Lanczos steps per probe (additionally capped at the operator dimension).
const EVIDENCE_LOGDET_LANCZOS_STEPS: usize = 32;
/// Relative tolerance shared by the dense and the randomized HVP symmetry checks.
const EVIDENCE_HVP_SYMMETRY_REL_TOL: f64 = 1e-8;
/// Number of randomized `(x, y)` probe pairs used to check `xᵀ(Hy) ≈ (Hx)ᵀy` for large operators.
const EVIDENCE_HVP_SYMMETRY_PROBES: usize = 4;
/// Matrix-free view of the evidence Hessian used to compute its log-determinant.
#[derive(Clone, Copy)]
pub struct EvidenceHvpLogDet<'a> {
/// Dimension of the (square) operator.
pub dim: usize,
/// Hessian-vector product: maps a length-`dim` vector `x` to `H x`. The log-determinant
/// routines reject results that are not finite or not of length `dim`.
pub apply: &'a dyn Fn(&[f64]) -> Vec<f64>,
}
/// Where the Hessian log-determinant for the evidence should come from.
#[derive(Clone, Copy)]
pub enum EvidenceLogDetSource<'a> {
/// Exact Cholesky factors of an arrow-structured Hessian, with an optional
/// matrix-free fallback used when the factors cannot provide an exact value.
FactoredArrow {
cache: &'a ArrowFactorCache,
fallback_hvp: Option<EvidenceHvpLogDet<'a>>,
},
/// Matrix-free Hessian-vector product only.
Hvp(EvidenceHvpLogDet<'a>),
}
/// Candidate latent topologies compared by `select_topology`.
/// Their simplicity ordering is given by `complexity_rank`, not by declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TopologyKind {
Periodic,
Flat,
Sphere,
Torus,
}
impl TopologyKind {
pub fn complexity_rank(self) -> u8 {
match self {
TopologyKind::Flat => 0,
TopologyKind::Periodic => 1,
TopologyKind::Sphere => 2,
TopologyKind::Torus => 3,
}
}
}
/// One fitted topology considered by `select_topology`.
#[derive(Debug, Clone)]
pub struct TopologyCandidate {
pub kind: TopologyKind,
/// Negative log evidence of the fit; lower is better. Must be finite to be eligible.
pub negative_log_evidence: f64,
/// Effective dimension, used as the divisor under `TopologyScoreScale::PerEffectiveDim`
/// (must be finite and positive there).
pub effective_dim: f64,
/// Number of observations, used as the divisor under `TopologyScoreScale::PerObservation`
/// (must be non-zero there).
pub n_obs: usize,
/// Non-converged fits are excluded from selection.
pub converged: bool,
/// `Some(reason)` excludes the candidate from selection.
pub exclusion_reason: Option<String>,
}
/// Result of `select_topology`.
#[derive(Debug, Clone)]
pub struct SelectedTopology {
/// Kind of the first entry of `ranking`.
pub winner: TopologyKind,
/// Eligible candidates ordered best-first, followed by all excluded candidates in input order.
pub ranking: Vec<TopologyCandidate>,
/// Whether the two best eligible scores were within `tie_tolerance` of each other.
pub tie: bool,
}
/// Tuning knobs for `select_topology`.
#[derive(Debug, Clone, Copy)]
pub struct TopologySelectOptions {
/// Absolute score difference (on the chosen scale) within which candidates count as tied;
/// tied candidates are ordered by `TopologyKind::complexity_rank`.
pub tie_tolerance: f64,
/// How negative log evidence is normalised before comparison.
pub score_scale: TopologyScoreScale,
}
/// Normalisation applied to `negative_log_evidence` before ranking candidates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TopologyScoreScale {
/// Divide by `n_obs`.
PerObservation,
/// Divide by `effective_dim`.
PerEffectiveDim,
}
impl Default for TopologySelectOptions {
fn default() -> Self {
Self {
tie_tolerance: 1e-3,
score_scale: TopologyScoreScale::PerObservation,
}
}
}
/// Laplace-approximate negative log evidence.
///
/// Computes
///
/// ```text
/// residual_objective + ½·log|H| − ½·penalty_log_det − ½·max(effective_dim − penalty_rank, 0)·log(2π)
/// ```
///
/// where `log|H|` comes from [`evidence_hessian_log_det`] applied to `logdet_source`.
///
/// # Returns
/// `NaN` when `effective_dim` or `penalty_rank` is non-finite, or when the null-space
/// dimension `effective_dim − penalty_rank` is non-finite or below `-1e-9`. Small negative
/// values are treated as round-off and clamped to zero. `NaN` is also returned when the
/// Hessian log-determinant cannot be obtained.
///
/// The scalar checks run before the log-determinant because the latter may cost a dense
/// eigendecomposition or many Hessian-vector products.
pub fn laplace_evidence(
    logdet_source: EvidenceLogDetSource<'_>,
    penalty_log_det: f64,
    residual_objective: f64,
    effective_dim: f64,
    penalty_rank: f64,
) -> f64 {
    if !(effective_dim.is_finite() && penalty_rank.is_finite()) {
        return f64::NAN;
    }
    // Finite − finite can still overflow, so re-check finiteness.
    let null_dim = effective_dim - penalty_rank;
    if !null_dim.is_finite() || null_dim < -1e-9 {
        return f64::NAN;
    }
    let log_det_h = match evidence_hessian_log_det(logdet_source) {
        Ok(v) => v,
        Err(_) => return f64::NAN,
    };
    residual_objective + 0.5 * log_det_h
        - 0.5 * penalty_log_det
        - 0.5 * null_dim.max(0.0) * (2.0 * std::f64::consts::PI).ln()
}
/// Log-determinant of the evidence Hessian from whichever source is usable.
///
/// * `FactoredArrow` — uses the exact arrow Cholesky factors via [`arrow_log_det_from_cache`].
///   The optional HVP fallback is used instead when those factors are unavailable
///   (`None`: ridge present or no Schur factor). It is also used when they yield a
///   non-finite value, which happens when a factor diagonal is not strictly positive.
/// * `Hvp` — dense or stochastic log-determinant via [`hessian_log_det_from_hvp`].
///
/// # Errors
/// Returns `Err` when the factored path has neither a finite exact value nor a fallback,
/// or when the HVP path fails its shape, finiteness, symmetry or positive-definiteness checks.
/// The factored path never returns `Ok` with a non-finite value.
pub fn evidence_hessian_log_det(source: EvidenceLogDetSource<'_>) -> Result<f64, String> {
    match source {
        EvidenceLogDetSource::FactoredArrow {
            cache,
            fallback_hvp,
        } => {
            let exact = arrow_log_det_from_cache(cache);
            if let Some(v) = exact {
                if v.is_finite() {
                    return Ok(v);
                }
            }
            match (fallback_hvp, exact) {
                (Some(hvp), _) => hessian_log_det_from_hvp(hvp),
                (None, Some(v)) => Err(format!(
                    "evidence Hessian logdet from exact arrow factors is non-finite ({v}) and no HVP fallback was supplied"
                )),
                (None, None) => {
                    Err("evidence Hessian logdet requires exact factors or HVP fallback".into())
                }
            }
        }
        EvidenceLogDetSource::Hvp(hvp) => hessian_log_det_from_hvp(hvp),
    }
}
/// Log-determinant of the SPD operator described by `hvp`.
///
/// * `dim == 0` gives `0.0`.
/// * Operators with `dim` above `ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD` are delegated to
///   stochastic Lanczos quadrature.
/// * Otherwise the operator is materialised column by column as `H e_j`. It is checked for
///   symmetry, exactly symmetrised by averaging mirrored entries, and passed to the dense
///   SPD log-determinant.
///
/// # Errors
/// Returns `Err` on wrong-length or non-finite columns, excessive skew, or a non-SPD result.
pub fn hessian_log_det_from_hvp(hvp: EvidenceHvpLogDet<'_>) -> Result<f64, String> {
    let n = hvp.dim;
    if n == 0 {
        return Ok(0.0);
    }
    if n > ANALYTIC_LOGDET_DENSE_DIM_THRESHOLD {
        return stochastic_hvp_log_det(hvp);
    }
    let mut dense = Array2::<f64>::zeros((n, n));
    let mut unit = vec![0.0_f64; n];
    for j in 0..n {
        unit[j] = 1.0;
        let column = (hvp.apply)(&unit);
        unit[j] = 0.0;
        let well_formed = column.len() == n && column.iter().all(|v| v.is_finite());
        if !well_formed {
            return Err(format!(
                "evidence HVP logdet expected finite column of length {}, got {}",
                n,
                column.len()
            ));
        }
        for (i, value) in column.into_iter().enumerate() {
            dense[[i, j]] = value;
        }
    }
    validate_dense_hvp_symmetry(&dense)?;
    // Remove the (already validated, tiny) skew so the eigensolver sees an exactly symmetric matrix.
    for i in 0..n {
        for j in (i + 1)..n {
            let mean = 0.5 * (dense[[i, j]] + dense[[j, i]]);
            dense[[i, j]] = mean;
            dense[[j, i]] = mean;
        }
    }
    dense_spd_log_det(&dense)
}
/// Log-determinant of a dense symmetric positive-definite matrix.
///
/// When the CUDA backend is selected, the GPU evidence routine computes the value. Otherwise
/// it is the sum of the logs of the eigenvalues from a lower-triangle symmetric
/// eigendecomposition.
///
/// # Errors
/// Returns `Err` for a non-square matrix, a failed eigendecomposition, or any eigenvalue that
/// is non-finite or not strictly positive. On the GPU path, the GPU routine's error is returned.
fn dense_spd_log_det(matrix: &Array2<f64>) -> Result<f64, String> {
    let (rows, cols) = matrix.dim();
    if rows != cols {
        return Err(format!(
            "evidence dense logdet requires square matrix, got {}x{}",
            rows, cols
        ));
    }
    if crate::solver::gpu::cuda_selected() {
        let gpu_input = crate::solver::gpu::reml_gpu::RemlGpuInput {
            penalized_hessian: matrix.view(),
            derivative_hessians: Vec::new(),
        };
        return crate::solver::gpu::reml_gpu::evidence_derivatives_gpu(gpu_input)
            .map(|evidence| evidence.logdet_hessian);
    }
    let (eigenvalues, _) = matrix
        .eigh(Side::Lower)
        .map_err(|e| format!("evidence dense logdet eigendecomposition failed: {e}"))?;
    eigenvalues
        .iter()
        .enumerate()
        .try_fold(0.0_f64, |running, (idx, &ev)| {
            if ev.is_finite() && ev > 0.0 {
                Ok(running + ev.ln())
            } else {
                Err(format!(
                    "evidence dense logdet expected SPD Hessian, eigenvalue {} is {:.3e}",
                    idx, ev
                ))
            }
        })
}
/// Checks that a materialised HVP matrix is symmetric to within `EVIDENCE_HVP_SYMMETRY_REL_TOL`.
///
/// The measure is `‖A − Aᵀ‖_F / max(‖A‖_F, 1)`. The floor of 1 makes the check absolute for
/// matrices with small norm.
///
/// # Errors
/// Returns `Err` when the relative skew is non-finite or exceeds the tolerance.
fn validate_dense_hvp_symmetry(matrix: &Array2<f64>) -> Result<(), String> {
    let n = matrix.nrows();
    let frobenius_sq = matrix.iter().fold(0.0_f64, |sum, &v| sum + v * v);
    // Each strictly-upper pair (i, j) contributes its skew twice: once for (i, j), once for (j, i).
    let skew_sq = (0..n)
        .flat_map(|i| ((i + 1)..n).map(move |j| (i, j)))
        .fold(0.0_f64, |sum, (i, j)| {
            let delta = matrix[[i, j]] - matrix[[j, i]];
            sum + 2.0 * delta * delta
        });
    let rel_skew = skew_sq.sqrt() / frobenius_sq.sqrt().max(1.0);
    if rel_skew.is_finite() && rel_skew <= EVIDENCE_HVP_SYMMETRY_REL_TOL {
        return Ok(());
    }
    Err(format!(
        "evidence HVP logdet requires symmetric operator, relative skew norm is {rel_skew:.3e}"
    ))
}
/// Randomized symmetry check for operators too large to materialise.
///
/// For each of `EVIDENCE_HVP_SYMMETRY_PROBES` (at least one) pairs of unit-norm Rademacher
/// vectors `x` and `y`, the check compares `xᵀ(Hy)` with `(Hx)ᵀy`. The mismatch is measured
/// relative to the largest of the available scale estimates (floored at 1).
///
/// # Errors
/// Returns `Err` if an HVP result has the wrong length or is non-finite, or if any pair's
/// relative mismatch is non-finite or exceeds `EVIDENCE_HVP_SYMMETRY_REL_TOL`.
fn validate_hvp_randomized_symmetry(hvp: EvidenceHvpLogDet<'_>) -> Result<(), String> {
    let n = hvp.dim;
    let probe_scale = 1.0 / (n as f64).sqrt();
    let ensure_well_formed = |v: &[f64]| -> Result<(), String> {
        if v.len() == n && v.iter().all(|x| x.is_finite()) {
            Ok(())
        } else {
            Err(format!(
                "evidence HVP symmetry check expected finite vector of length {}, got {}",
                n,
                v.len()
            ))
        }
    };
    for probe in 0..EVIDENCE_HVP_SYMMETRY_PROBES.max(1) {
        let mut left = vec![0.0_f64; n];
        let mut right = vec![0.0_f64; n];
        rademacher_unit_probe_into_slice(&mut left, (2 * probe) as u64, probe_scale);
        rademacher_unit_probe_into_slice(&mut right, (2 * probe + 1) as u64, probe_scale);
        // Apply the operator to both probes before validating either result.
        let h_left = (hvp.apply)(&left);
        let h_right = (hvp.apply)(&right);
        ensure_well_formed(&h_left)?;
        ensure_well_formed(&h_right)?;
        let lhs = dot_slice(&left, &h_right);
        let rhs = dot_slice(&h_left, &right);
        let scale = (norm2_slice(&h_left) * norm2_slice(&right))
            .max(norm2_slice(&h_right) * norm2_slice(&left))
            .max(lhs.abs())
            .max(rhs.abs())
            .max(1.0);
        let mismatch = (lhs - rhs).abs() / scale;
        let acceptable = mismatch.is_finite() && mismatch <= EVIDENCE_HVP_SYMMETRY_REL_TOL;
        if !acceptable {
            return Err(format!(
                "evidence HVP logdet requires symmetric operator, randomized symmetry probe {probe} has relative bilinear mismatch {mismatch:.3e}"
            ));
        }
    }
    Ok(())
}
/// Stochastic Lanczos quadrature estimate of `log|H|` for a large matrix-free operator.
///
/// First validates symmetry with random probes. Then `log|H| = tr(log H)` is estimated as the
/// average over Rademacher probes `q` (unit norm, so scaled by `n`) of `n · qᵀ log(H) q`. Each
/// quadratic form is approximated by Lanczos quadrature with at most
/// `min(EVIDENCE_LOGDET_LANCZOS_STEPS, n)` steps.
///
/// # Errors
/// Propagates any symmetry-check or Lanczos error.
fn stochastic_hvp_log_det(hvp: EvidenceHvpLogDet<'_>) -> Result<f64, String> {
    validate_hvp_randomized_symmetry(hvp)?;
    let n = hvp.dim;
    let probe_count = EVIDENCE_LOGDET_SLQ_PROBES.max(1);
    let lanczos_steps = EVIDENCE_LOGDET_LANCZOS_STEPS.min(n).max(1);
    let probe_scale = 1.0 / (n as f64).sqrt();
    let total = (0..probe_count).try_fold(0.0_f64, |running, probe| {
        let mut start = vec![0.0_f64; n];
        rademacher_unit_probe_into_slice(&mut start, probe as u64, probe_scale);
        lanczos_log_quadrature_hvp(hvp, start, lanczos_steps)
            .map(|quad| running + n as f64 * quad)
    })?;
    Ok(total / probe_count as f64)
}
/// Lanczos (Gauss) quadrature estimate of `qᵀ log(H) q` for a start vector `q`.
///
/// Runs at most `max_steps` Lanczos iterations. No reorthogonalisation is performed. The
/// run stops early when the residual norm `beta` drops to `1e-12` or below. The method then
/// forms the `k × k` tridiagonal matrix `T` and returns `Σ_j (e₁ᵀ v_j)² · ln θ_j` over its
/// eigenpairs `(θ_j, v_j)`.
/// The caller is expected to pass a unit-norm `q` (the SLQ driver does).
///
/// # Errors
/// Returns `Err` if an HVP result is malformed, `alpha`/`beta` become non-finite, the
/// tridiagonal eigendecomposition fails, or any Ritz value is non-finite or non-positive.
fn lanczos_log_quadrature_hvp(
hvp: EvidenceHvpLogDet<'_>,
mut q: Vec<f64>,
max_steps: usize,
) -> Result<f64, String> {
let n = hvp.dim;
let mut q_prev = vec![0.0_f64; n];
let mut alphas = Vec::<f64>::with_capacity(max_steps);
let mut betas = Vec::<f64>::with_capacity(max_steps.saturating_sub(1));
let mut beta_prev = 0.0_f64;
// Absolute breakdown threshold for the residual norm.
let tol = 1e-12_f64;
for step in 0..max_steps {
let applied = (hvp.apply)(&q);
if applied.len() != n || applied.iter().any(|v| !v.is_finite()) {
return Err(format!(
"evidence HVP SLQ expected finite vector of length {n}, got {}",
applied.len()
));
}
// Three-term recurrence: w = H q_j − beta_{j−1} q_{j−1} − alpha_j q_j.
let mut w = applied;
if step > 0 {
for i in 0..n {
w[i] -= beta_prev * q_prev[i];
}
}
let alpha = dot_slice(&q, &w);
if !alpha.is_finite() {
return Err("evidence HVP SLQ produced non-finite alpha".to_string());
}
for i in 0..n {
w[i] -= alpha * q[i];
}
let beta = norm2_slice(&w);
alphas.push(alpha);
// Stop before recording beta on the last step or on breakdown, so betas.len() == alphas.len() − 1.
if step + 1 == max_steps || beta <= tol {
break;
}
if !beta.is_finite() {
return Err("evidence HVP SLQ produced non-finite beta".to_string());
}
betas.push(beta);
q_prev = q;
q = w;
for v in q.iter_mut() {
*v /= beta;
}
beta_prev = beta;
}
// Assemble the symmetric tridiagonal Lanczos matrix T.
let k = alphas.len();
let mut tri = Array2::<f64>::zeros((k, k));
for i in 0..k {
tri[[i, i]] = alphas[i];
if i + 1 < k {
tri[[i, i + 1]] = betas[i];
tri[[i + 1, i]] = betas[i];
}
}
let (evals, evecs) = tri
.eigh(Side::Lower)
.map_err(|e| format!("evidence HVP SLQ eigendecomposition failed: {e}"))?;
let mut quad = 0.0_f64;
for j in 0..k {
let theta = evals[j];
if !theta.is_finite() || theta <= 0.0 {
return Err(format!(
"evidence HVP SLQ expected SPD Hessian, Lanczos Ritz value {j} is {theta:.3e}"
));
}
// Quadrature weight = squared first component of the j-th eigenvector
// (assumes `eigh` returns eigenvectors as columns — as the [[0, j]] indexing implies).
let weight = evecs[[0, j]] * evecs[[0, j]];
quad += weight * theta.ln();
}
Ok(quad)
}
/// Euclidean inner product of two equal-length slices, summed left to right from `0.0`.
///
/// # Panics
/// Panics if the slices differ in length.
#[inline]
fn dot_slice(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len());
    a.iter().zip(b).fold(0.0_f64, |acc, (x, y)| acc + x * y)
}
/// Euclidean (L2) norm of a slice.
#[inline]
fn norm2_slice(a: &[f64]) -> f64 {
    let squared_length = dot_slice(a, a);
    f64::sqrt(squared_length)
}
/// Fills `z` with a deterministic Rademacher vector whose entries are `±scale`.
///
/// The sign stream is seeded from `probe`. One 64-bit SplitMix draw supplies the signs for
/// each consecutive group of 64 entries, least-significant bit first; a clear bit gives
/// `+scale`. Passing `scale = 1/√len` yields a unit-norm probe.
fn rademacher_unit_probe_into_slice(z: &mut [f64], probe: u64, scale: f64) {
    let mut state = 0x6A09E667F3BCC909_u64 ^ probe.wrapping_mul(0xD1B54A32D192ED03);
    for block in z.chunks_mut(64) {
        let word = splitmix64(&mut state);
        for (bit, slot) in block.iter_mut().enumerate() {
            let negative = (word >> bit) & 1 != 0;
            *slot = if negative { -scale } else { scale };
        }
    }
}
/// Advances `state` and returns the next 64-bit output of the crate's SplitMix64 generator.
/// This is a thin local alias for `crate::linalg::utils::splitmix64`.
#[inline]
const fn splitmix64(state: &mut u64) -> u64 {
crate::linalg::utils::splitmix64(state)
}
/// Exact `log|H|` of an arrow Hessian from its Cholesky factors.
///
/// The value is the sum of `2·Σ log diag(L)` over every undamped row-block factor and over
/// the Schur-complement factor.
///
/// # Returns
/// `None` if either ridge is non-zero or no Schur factor is stored. The result is `Some(NaN)`
/// if any factor has a non-positive (or NaN) diagonal entry.
pub fn arrow_log_det_from_cache(cache: &ArrowFactorCache) -> Option<f64> {
    let ridge_free = cache.ridge_t == 0.0 && cache.ridge_beta == 0.0;
    if !ridge_free {
        return None;
    }
    let schur = cache.schur_factor.as_ref()?;
    let row_blocks = cache
        .undamped_factors_iter()
        .fold(0.0_f64, |acc, factor| acc + 2.0 * log_det_from_chol_lower(factor));
    Some(row_blocks + 2.0 * log_det_from_chol_lower(schur))
}
/// `Σ_i ln L[i, i]` for a lower Cholesky factor `L`, i.e. `½·log|L Lᵀ|`.
///
/// Returns `NaN` as soon as a diagonal entry is not strictly positive (including NaN).
fn log_det_from_chol_lower(l: &Array2<f64>) -> f64 {
    (0..l.nrows())
        .map(|i| l[[i, i]])
        .try_fold(0.0_f64, |acc, diag| (diag > 0.0).then(|| acc + diag.ln()))
        .unwrap_or(f64::NAN)
}
/// Implicit-function sensitivity of the row unknowns with respect to `beta`.
///
/// For each row block `i` and each `beta` coordinate `j`, the block stores
/// `−A_i⁻¹ (H_uβ,i e_j)` in rows `i·d .. i·d + d` of column `j`. Here `A_i` is the matrix
/// whose lower Cholesky factor is `cache.undamped_factor(i)`.
///
/// # Returns
/// An `(n·d) × k` matrix, filled with NaN if the cache cannot apply `H_uβ` rows.
pub fn ift_du_dbeta(cache: &ArrowFactorCache) -> Array2<f64> {
    let rows = cache.undamped_factor_count();
    let d = cache.d;
    let k = cache.k;
    let nan_result = || Array2::<f64>::from_elem((rows * d, k), f64::NAN);
    if !cache.htbeta_available() {
        return nan_result();
    }
    let mut result = Array2::<f64>::zeros((rows * d, k));
    let mut unit = Array1::<f64>::zeros(k);
    let mut coupling = Array1::<f64>::zeros(d);
    for row in 0..rows {
        let chol = cache.undamped_factor(row);
        for j in 0..k {
            unit.fill(0.0);
            unit[j] = 1.0;
            if !cache.apply_htbeta_row(row, unit.view(), &mut coupling) {
                return nan_result();
            }
            let solved = chol_lower_solve_vector(chol, &coupling);
            for c in 0..d {
                result[[row * d + c, j]] = -solved[c];
            }
        }
    }
    result
}
/// Implicit-function sensitivity of `beta` with respect to `rho`: `−S⁻¹ · dg_red_drho`.
///
/// `S` is the matrix whose lower Cholesky factor is the cached Schur factor. Each column of
/// `dg_red_drho` (length `k`) is solved independently.
///
/// # Returns
/// `None` if either ridge is non-zero, no Schur factor is stored, or `dg_red_drho` does not
/// have `k` rows. Otherwise returns a `k × r` matrix.
pub fn ift_dbeta_drho(
    cache: &ArrowFactorCache,
    dg_red_drho: ArrayView2<'_, f64>,
) -> Option<Array2<f64>> {
    let ridge_free = cache.ridge_t == 0.0 && cache.ridge_beta == 0.0;
    if !ridge_free {
        return None;
    }
    let schur = cache.schur_factor.as_ref()?;
    let k = cache.k;
    if dg_red_drho.nrows() != k {
        return None;
    }
    let r = dg_red_drho.ncols();
    let mut result = Array2::<f64>::zeros((k, r));
    let mut rhs = Array1::<f64>::zeros(k);
    for a in 0..r {
        rhs.assign(&dg_red_drho.column(a));
        let solved = chol_lower_solve_vector(schur, &rhs);
        for row in 0..k {
            result[[row, a]] = -solved[row];
        }
    }
    Some(result)
}
/// Implicit-function sensitivity of the row unknowns with respect to `rho`.
///
/// For each `rho` column `a` and row block `i`, rows `i·d .. i·d + d` hold
/// `−A_i⁻¹ (gu_rho[i-block, a] + H_uβ,i · dbeta_drho[:, a])`. Here `A_i` is the matrix whose
/// lower Cholesky factor is `cache.undamped_factor(i)`.
///
/// # Returns
/// An `(n·d) × r` matrix (`r = dbeta_drho.ncols()`). It is filled with NaN if `H_uβ` rows
/// are unavailable or fail to apply, or if `gu_rho` is not `(n·d) × r`, or if `dbeta_drho`
/// does not have `k` rows.
pub fn ift_du_drho(
    cache: &ArrowFactorCache,
    gu_rho: ArrayView2<'_, f64>,
    dbeta_drho: ArrayView2<'_, f64>,
) -> Array2<f64> {
    let rows = cache.undamped_factor_count();
    let d = cache.d;
    let k = cache.k;
    let r = dbeta_drho.ncols();
    let stacked = rows * d;
    let nan_result = || Array2::<f64>::from_elem((stacked, r), f64::NAN);
    let consistent = cache.htbeta_available()
        && gu_rho.nrows() == stacked
        && gu_rho.ncols() == r
        && dbeta_drho.nrows() == k;
    if !consistent {
        return nan_result();
    }
    let mut result = Array2::<f64>::zeros((stacked, r));
    let mut rhs = Array1::<f64>::zeros(d);
    let mut coupled = Array1::<f64>::zeros(d);
    for a in 0..r {
        for row in 0..rows {
            if !cache.apply_htbeta_row(row, dbeta_drho.column(a), &mut coupled) {
                return nan_result();
            }
            for c in 0..d {
                rhs[c] = gu_rho[[row * d + c, a]] + coupled[c];
            }
            let solved = chol_lower_solve_vector(cache.undamped_factor(row), &rhs);
            for c in 0..d {
                result[[row * d + c, a]] = -solved[c];
            }
        }
    }
    result
}
/// Inputs to `evidence_ift_gradient_correction`: mode sensitivities and the partial
/// derivatives they are contracted with.
#[derive(Clone)]
pub struct EvidenceIftGradientTerms<'a> {
/// `∂beta/∂rho`, shape `k × r`.
pub dbeta_drho: ArrayView2<'a, f64>,
/// `∂u/∂rho` for the stacked row unknowns, shape `nd × r`.
pub du_drho: ArrayView2<'a, f64>,
/// Partial derivative of the value term w.r.t. `beta`, length `k`.
pub value_beta: ArrayView1<'a, f64>,
/// Partial derivative of the value term w.r.t. `u`, length `nd`.
pub value_u: ArrayView1<'a, f64>,
/// Partial derivative of `log|H|` w.r.t. `beta`, length `k` (weighted by ½ in the correction).
pub logdet_h_beta: ArrayView1<'a, f64>,
/// Partial derivative of `log|H|` w.r.t. `u`, length `nd` (weighted by ½ in the correction).
pub logdet_h_u: ArrayView1<'a, f64>,
}
pub fn evidence_ift_gradient_correction(terms: EvidenceIftGradientTerms<'_>) -> Array1<f64> {
let k = terms.dbeta_drho.nrows();
let nd = terms.du_drho.nrows();
let r = terms.dbeta_drho.ncols();
if terms.du_drho.ncols() != r
|| terms.value_beta.len() != k
|| terms.logdet_h_beta.len() != k
|| terms.value_u.len() != nd
|| terms.logdet_h_u.len() != nd
{
return Array1::<f64>::from_elem(r, f64::NAN);
}
let mut out = Array1::<f64>::zeros(r);
for a in 0..r {
let mut acc = 0.0_f64;
for j in 0..k {
let mode = terms.dbeta_drho[[j, a]];
acc += terms.value_beta[j] * mode;
acc += 0.5 * terms.logdet_h_beta[j] * mode;
}
for j in 0..nd {
let mode = terms.du_drho[[j, a]];
acc += terms.value_u[j] * mode;
acc += 0.5 * terms.logdet_h_u[j] * mode;
}
out[a] = acc;
}
out
}
/// Exact gradient of the Laplace negative log evidence with respect to `rho`, for an
/// arrow-structured Hessian `H = [[H_uu (one d×d block per row), H_uβ], [H_βu, H_ββ]]`.
///
/// With `Y_i = H_uu,i⁻¹ H_uβ,i` and Schur complement `S = H_ββ − Σ_i H_βu,i Y_i` we have
/// `log|H| = Σ_i log|H_uu,i| + log|S|`, and component `a` of the result is
///
/// ```text
/// value_rho[a]
///   + ½ [ Σ_i tr(H_uu,i⁻¹ ∂_a H_uu,i) + tr(S⁻¹ ∂_a S) ]
///   + evidence_ift_gradient_correction(ift_terms)[a]
///   − ½ pen_logdet_drho[a]
/// ```
///
/// # Arguments
/// * `huu_drho[i][a]` — `∂_a H_uu,i`, shape `d × d`.
/// * `htbeta_drho[i][a]` — `∂_a H_uβ,i`, shape `d × k`.
/// * `hbb_drho[a]` — `∂_a H_ββ`, shape `k × k`.
///
/// # Returns
/// A length-`r` vector (`r = value_rho.len()`). It is filled with NaN when the cache carries a
/// non-zero ridge, which matches `arrow_log_det_from_cache` and `ift_dbeta_drho`: both refuse
/// ridged caches, so their Schur factor must not be differentiated here either. NaN is also
/// returned when there is no Schur factor, `H_uβ` rows are unavailable, any input has the
/// wrong shape, or the IFT correction contains NaN.
pub fn evidence_grad_rho(
    cache: &ArrowFactorCache,
    value_rho: ArrayView1<'_, f64>,
    huu_drho: &[Vec<Array2<f64>>],
    htbeta_drho: &[Vec<Array2<f64>>],
    hbb_drho: &[Array2<f64>],
    pen_logdet_drho: ArrayView1<'_, f64>,
    ift_terms: EvidenceIftGradientTerms<'_>,
) -> Array1<f64> {
    let r = value_rho.len();
    let n = cache.undamped_factor_count();
    let d = cache.d;
    let k = cache.k;
    let nan_out = || Array1::<f64>::from_elem(r, f64::NAN);
    if cache.ridge_t != 0.0 || cache.ridge_beta != 0.0 {
        return nan_out();
    }
    if !cache.htbeta_available()
        || pen_logdet_drho.len() != r
        || huu_drho.len() != n
        || htbeta_drho.len() != n
        || hbb_drho.len() != r
        || huu_drho.iter().any(|row| row.len() != r)
        || htbeta_drho.iter().any(|row| row.len() != r)
        || hbb_drho.iter().any(|m| m.nrows() != k || m.ncols() != k)
        || huu_drho
            .iter()
            .any(|row| row.iter().any(|m| m.nrows() != d || m.ncols() != d))
        || htbeta_drho
            .iter()
            .any(|row| row.iter().any(|m| m.nrows() != d || m.ncols() != k))
    {
        return nan_out();
    }
    let ift_correction = evidence_ift_gradient_correction(ift_terms);
    if ift_correction.len() != r || ift_correction.iter().any(|v| v.is_nan()) {
        return nan_out();
    }
    let schur = match cache.schur_factor.as_ref() {
        Some(s) => s,
        None => return nan_out(),
    };
    let y_blocks = match evidence_grad_row_solves(cache, n, d, k) {
        Some(blocks) => blocks,
        None => return nan_out(),
    };
    let mut scratch_d = Array1::<f64>::zeros(d);
    let mut scratch_k = Array1::<f64>::zeros(k);
    let mut dhuu_y = Array2::<f64>::zeros((d, k));
    let mut out = Array1::<f64>::zeros(r);
    for a in 0..r {
        // ∂_a Σ_i log|H_uu,i| = Σ_i tr(H_uu,i⁻¹ ∂_a H_uu,i).
        let mut row_trace_acc = 0.0_f64;
        for (i, dhuu_row) in huu_drho.iter().enumerate() {
            row_trace_acc += evidence_chol_inverse_trace(
                cache.undamped_factor(i),
                &dhuu_row[a],
                &mut scratch_d,
            );
        }
        // ∂_a log|S| = tr(S⁻¹ ∂_a S).
        let ds = evidence_schur_derivative(
            &hbb_drho[a],
            htbeta_drho,
            huu_drho,
            &y_blocks,
            a,
            &mut dhuu_y,
        );
        let schur_trace_acc = evidence_chol_inverse_trace(schur, &ds, &mut scratch_k);
        let mut grad = value_rho[a];
        grad += 0.5 * (row_trace_acc + schur_trace_acc);
        grad += ift_correction[a];
        grad -= 0.5 * pen_logdet_drho[a];
        out[a] = grad;
    }
    out
}

/// Solves `A_i Y_i = H_uβ,i` (one `d × k` block per row, column by column). `A_i` is the
/// matrix factored by `cache.undamped_factor(i)`. Returns `None` if an `H_uβ` row cannot be applied.
fn evidence_grad_row_solves(
    cache: &ArrowFactorCache,
    n: usize,
    d: usize,
    k: usize,
) -> Option<Vec<Array2<f64>>> {
    let mut y_blocks: Vec<Array2<f64>> = Vec::with_capacity(n);
    let mut beta_basis = Array1::<f64>::zeros(k);
    let mut rhs = Array1::<f64>::zeros(d);
    for i in 0..n {
        let factor = cache.undamped_factor(i);
        let mut yi = Array2::<f64>::zeros((d, k));
        for col in 0..k {
            beta_basis.fill(0.0);
            beta_basis[col] = 1.0;
            if !cache.apply_htbeta_row(i, beta_basis.view(), &mut rhs) {
                return None;
            }
            let v = chol_lower_solve_vector(factor, &rhs);
            for c in 0..d {
                yi[[c, col]] = v[c];
            }
        }
        y_blocks.push(yi);
    }
    Some(y_blocks)
}

/// `tr(A⁻¹ M)` where `A = L Lᵀ`, computed one column of `M` at a time.
/// `scratch.len()` sets the dimension and must equal the size of `L` and `M`.
fn evidence_chol_inverse_trace(
    l: &Array2<f64>,
    m: &Array2<f64>,
    scratch: &mut Array1<f64>,
) -> f64 {
    let dim = scratch.len();
    let mut trace = 0.0_f64;
    for col in 0..dim {
        for row in 0..dim {
            scratch[row] = m[[row, col]];
        }
        let v = chol_lower_solve_vector(l, scratch);
        trace += v[col];
    }
    trace
}

/// `∂_a S = ∂_a H_ββ − Σ_i (∂_a H_uβ,iᵀ Y_i + Y_iᵀ ∂_a H_uβ,i − Y_iᵀ ∂_a H_uu,i Y_i)`.
/// `scratch` (`d × k`) receives `∂_a H_uu,i Y_i` for each row in turn.
fn evidence_schur_derivative(
    dhbb: &Array2<f64>,
    htbeta_drho: &[Vec<Array2<f64>>],
    huu_drho: &[Vec<Array2<f64>>],
    y_blocks: &[Array2<f64>],
    a: usize,
    scratch: &mut Array2<f64>,
) -> Array2<f64> {
    let (d, k) = scratch.dim();
    let mut ds = dhbb.clone();
    for ((dhtb_row, dhuu_row), yi) in htbeta_drho.iter().zip(huu_drho).zip(y_blocks) {
        let dhtb = &dhtb_row[a];
        let dhuu = &dhuu_row[a];
        // − ∂H_βu,i Y_i
        for r0 in 0..k {
            for c0 in 0..k {
                let mut acc = 0.0_f64;
                for cc in 0..d {
                    acc += dhtb[[cc, r0]] * yi[[cc, c0]];
                }
                ds[[r0, c0]] -= acc;
            }
        }
        // − Y_iᵀ ∂H_uβ,i
        for r0 in 0..k {
            for c0 in 0..k {
                let mut acc = 0.0_f64;
                for cc in 0..d {
                    acc += yi[[cc, r0]] * dhtb[[cc, c0]];
                }
                ds[[r0, c0]] -= acc;
            }
        }
        // + Y_iᵀ (∂H_uu,i Y_i)
        for r0 in 0..d {
            for c0 in 0..k {
                let mut acc = 0.0_f64;
                for cc in 0..d {
                    acc += dhuu[[r0, cc]] * yi[[cc, c0]];
                }
                scratch[[r0, c0]] = acc;
            }
        }
        for r0 in 0..k {
            for c0 in 0..k {
                let mut acc = 0.0_f64;
                for cc in 0..d {
                    acc += yi[[cc, r0]] * scratch[[cc, c0]];
                }
                ds[[r0, c0]] += acc;
            }
        }
    }
    ds
}
/// Ranks topology candidates by normalised negative log evidence and picks a winner.
///
/// A candidate is eligible when it converged, has no exclusion reason, and has a finite
/// `negative_log_evidence` and a finite score on `options.score_scale`. Eligible candidates
/// are sorted by ascending score, with ties broken by `complexity_rank`. If the two best
/// scores lie within `options.tie_tolerance`, every candidate within that tolerance of the
/// best is reordered simplest-first and `tie` is set. Excluded candidates follow in input
/// order.
///
/// # Panics
/// Panics if no candidate is eligible (silent fallback is deliberately forbidden).
pub fn select_topology(
    candidates: &[TopologyCandidate],
    options: TopologySelectOptions,
) -> SelectedTopology {
    let scale = options.score_scale;
    let tolerance = options.tie_tolerance;
    let score = |c: &TopologyCandidate| topology_selection_score(c, scale);
    let is_eligible = |c: &TopologyCandidate| {
        c.converged
            && c.exclusion_reason.is_none()
            && c.negative_log_evidence.is_finite()
            && score(c).is_finite()
    };
    let (mut valid, mut excluded): (Vec<TopologyCandidate>, Vec<TopologyCandidate>) =
        candidates.iter().cloned().partition(is_eligible);
    assert!(
        !valid.is_empty(),
        "select_topology: no finite valid candidates; proposal §6.11 forbids silent fallback"
    );
    valid.sort_by(|lhs, rhs| {
        score(lhs)
            .partial_cmp(&score(rhs))
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| lhs.kind.complexity_rank().cmp(&rhs.kind.complexity_rank()))
    });
    let tie = match valid.as_slice() {
        [first, second, ..] => (score(second) - score(first)).abs() <= tolerance,
        _ => false,
    };
    if tie {
        let best = score(&valid[0]);
        let tied_len = valid
            .iter()
            .take_while(|&c| (score(c) - best).abs() <= tolerance)
            .count();
        valid[..tied_len].sort_by_key(|c| c.kind.complexity_rank());
    }
    let winner = valid[0].kind;
    valid.append(&mut excluded);
    SelectedTopology {
        winner,
        ranking: valid,
        tie,
    }
}
/// Normalised score used to rank a candidate (lower is better).
///
/// Returns `negative_log_evidence / n_obs` or `negative_log_evidence / effective_dim`. The
/// result is `NaN` when the chosen divisor is zero or, for the effective-dimension scale,
/// non-finite or non-positive.
fn topology_selection_score(candidate: &TopologyCandidate, scale: TopologyScoreScale) -> f64 {
    let divisor = match scale {
        TopologyScoreScale::PerObservation if candidate.n_obs == 0 => return f64::NAN,
        TopologyScoreScale::PerObservation => candidate.n_obs as f64,
        TopologyScoreScale::PerEffectiveDim => {
            let dim = candidate.effective_dim;
            if dim.is_finite() && dim > 0.0 {
                dim
            } else {
                return f64::NAN;
            }
        }
    };
    candidate.negative_log_evidence / divisor
}
/// Whether `cache` can feed the exact (factor-based) evidence and gradient paths. This
/// requires no ridge on either block, a stored Schur factor, and an available `H_uβ` operator.
pub fn cache_supports_exact_evidence(cache: &ArrowFactorCache) -> bool {
    let ridge_free = cache.ridge_t == 0.0 && cache.ridge_beta == 0.0;
    ridge_free && cache.schur_factor.is_some() && cache.htbeta_available()
}
/// Whether `cache` was built for `sys`.
///
/// Dimensions, row counts (both factored and undamped), manifold-mode fingerprint and
/// current row-Hessian fingerprint must all agree. The checks run in that order and stop at
/// the first mismatch.
pub fn cache_matches_system(cache: &ArrowFactorCache, sys: &ArrowSchurSystem) -> bool {
    if cache.d != sys.d || cache.k != sys.k {
        return false;
    }
    let row_count = sys.rows.len();
    if cache.n_rows() != row_count || cache.undamped_factor_count() != row_count {
        return false;
    }
    cache.manifold_mode_fingerprint == sys.manifold_mode_fingerprint
        && cache.row_hessian_fingerprint == sys.current_row_hessian_fingerprint()
}
/// Solves `L Lᵀ x = b` given the lower Cholesky factor `L`.
///
/// It performs a forward substitution for `L y = b`, then a back substitution for
/// `Lᵀ x = y`. No dimension or pivot checks are done: a zero diagonal produces inf/NaN.
fn chol_lower_solve_vector(l: &Array2<f64>, b: &Array1<f64>) -> Array1<f64> {
    let n = l.nrows();
    let mut y = Array1::<f64>::zeros(n);
    for i in 0..n {
        let partial = (0..i).fold(b[i], |acc, j| acc - l[[i, j]] * y[j]);
        y[i] = partial / l[[i, i]];
    }
    // Lᵀ[i, j] == L[j, i], so the back substitution reads column i of L below the diagonal.
    let mut x = Array1::<f64>::zeros(n);
    for i in (0..n).rev() {
        let partial = ((i + 1)..n).fold(y[i], |acc, j| acc - l[[j, i]] * x[j]);
        x[i] = partial / l[[i, i]];
    }
    x
}
#[cfg(test)]
mod tests {
use super::*;
/// One-row arrow cache with d = k = 1 and no ridge:
/// row factor sqrt(2) (block value 2), H_uβ block 0.5, Schur factor sqrt(1.875) (S = 1.875).
fn make_minimal_cache() -> ArrowFactorCache {
let l_huu = Array2::from_shape_vec((1, 1), vec![std::f64::consts::SQRT_2]).unwrap();
let l_schur = Array2::from_shape_vec((1, 1), vec![(1.875_f64).sqrt()]).unwrap();
let htbeta = Array2::from_shape_vec((1, 1), vec![0.5]).unwrap();
ArrowFactorCache {
htt_factors: std::sync::Arc::from(vec![l_huu]),
htt_factors_undamped: crate::solver::arrow_schur::ArrowUndampedFactors::SameAsDamped,
schur_factor: Some(l_schur),
solver_mode: crate::solver::arrow_schur::ArrowSolverMode::Direct,
ridge_t: 0.0,
ridge_beta: 0.0,
htbeta: crate::solver::arrow_schur::ArrowHtbetaCache::Dense {
blocks: std::sync::Arc::from(vec![htbeta]),
estimated_bytes: std::mem::size_of::<f64>(),
},
d: 1,
k: 1,
manifold_mode_fingerprint: 0,
row_hessian_fingerprint: 0,
}
}
// log|H| = ln 2 + ln 1.875 from the factors; null dimension 2 − 1 = 1 contributes −½ ln 2π.
#[test]
fn laplace_evidence_returns_finite_for_minimal_cache() {
let cache = make_minimal_cache();
let v = laplace_evidence(
EvidenceLogDetSource::FactoredArrow {
cache: &cache,
fallback_hvp: None,
},
0.0,
0.0,
2.0,
1.0,
);
assert!(v.is_finite());
let expected =
0.5 * (2.0_f64.ln() + 1.875_f64.ln()) - 0.5 * (2.0 * std::f64::consts::PI).ln();
assert!((v - expected).abs() < 1e-12);
}
// A ridged cache has no exact logdet and, without a fallback, yields NaN.
#[test]
fn laplace_evidence_nan_when_ridge_is_nonzero() {
let mut cache = make_minimal_cache();
cache.ridge_t = 1e-3;
assert!(
laplace_evidence(
EvidenceLogDetSource::FactoredArrow {
cache: &cache,
fallback_hvp: None,
},
0.0,
0.0,
2.0,
1.0,
)
.is_nan()
);
}
// Without a Schur factor the diagonal HVP diag(2, 1.875) is used; it has the same log-determinant.
#[test]
fn laplace_evidence_uses_hvp_fallback_without_schur_factor() {
let mut cache = make_minimal_cache();
cache.schur_factor = None;
let hvp = |x: &[f64]| -> Vec<f64> { vec![2.0 * x[0], 1.875 * x[1]] };
let v = laplace_evidence(
EvidenceLogDetSource::FactoredArrow {
cache: &cache,
fallback_hvp: Some(EvidenceHvpLogDet {
dim: 2,
apply: &hvp,
}),
},
0.0,
0.0,
2.0,
1.0,
);
let expected =
0.5 * (2.0_f64.ln() + 1.875_f64.ln()) - 0.5 * (2.0 * std::f64::consts::PI).ln();
assert!((v - expected).abs() < 1e-12);
}
// du/dβ = −H_uu⁻¹ H_uβ = −0.5 / 2.
#[test]
fn ift_du_dbeta_has_expected_shape() {
let cache = make_minimal_cache();
let du_db = ift_du_dbeta(&cache);
assert_eq!(du_db.shape(), &[1, 1]);
assert!((du_db[[0, 0]] - (-0.25)).abs() < 1e-12);
}
// dβ/dρ = −S⁻¹ · 1 = −1 / 1.875.
#[test]
fn ift_dbeta_drho_returns_some_for_direct_cache() {
let cache = make_minimal_cache();
let q = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
let out = ift_dbeta_drho(&cache, q.view()).unwrap();
assert_eq!(out.shape(), &[1, 1]);
assert!((out[[0, 0]] + 1.0 / 1.875).abs() < 1e-12);
}
// Sphere scores 0.08 vs Flat 0.10 per observation; the torus is excluded.
#[test]
fn topology_select_picks_lowest_negative_log_evidence() {
let candidates = vec![
TopologyCandidate {
kind: TopologyKind::Flat,
negative_log_evidence: 10.0,
effective_dim: 4.0,
n_obs: 100,
converged: true,
exclusion_reason: None,
},
TopologyCandidate {
kind: TopologyKind::Sphere,
negative_log_evidence: 8.0,
effective_dim: 5.0,
n_obs: 100,
converged: true,
exclusion_reason: None,
},
TopologyCandidate {
kind: TopologyKind::Torus,
negative_log_evidence: f64::NAN,
effective_dim: 6.0,
n_obs: 100,
converged: false,
exclusion_reason: Some("torus periods missing".to_string()),
},
];
let sel = select_topology(&candidates, TopologySelectOptions::default());
assert_eq!(sel.winner, TopologyKind::Sphere);
assert!(!sel.tie);
}
// Scores differ by 1e-8 per observation (< default tolerance 1e-3), so the simpler Flat wins.
#[test]
fn topology_select_tie_breaks_to_simpler() {
let candidates = vec![
TopologyCandidate {
kind: TopologyKind::Sphere,
negative_log_evidence: 5.0,
effective_dim: 5.0,
n_obs: 100,
converged: true,
exclusion_reason: None,
},
TopologyCandidate {
kind: TopologyKind::Flat,
negative_log_evidence: 5.0 + 1e-6,
effective_dim: 4.0,
n_obs: 100,
converged: true,
exclusion_reason: None,
},
];
let sel = select_topology(&candidates, TopologySelectOptions::default());
assert_eq!(sel.winner, TopologyKind::Flat);
assert!(sel.tie);
}
}