use faer::Side;
use ndarray::Array2;
use crate::linalg::faer_ndarray::FaerCholesky;
use crate::linalg::triangular::forward_substitution_lower_matrix;
/// Two-sided bounds on a log-determinant, as produced by
/// [`block_preconditioned_logdet_enclosure`].
#[derive(Debug, Clone)]
pub struct LogdetEnclosure {
    /// Sum of the log-determinants of the diagonal blocks, `Σ_i log|S_ii|`.
    pub block_diag_logdet: f64,
    /// Lower bound on the log-determinant.
    pub lower: f64,
    /// Upper bound on the log-determinant.
    pub upper: f64,
    /// Spectral-radius bound used to control the truncated series tail.
    pub rho: f64,
    /// Second-moment term `tr(E²)` of the whitened off-diagonal part.
    pub p2: f64,
    /// Third-moment term `tr(E³)`, when it was computed.
    pub p3: Option<f64>,
}

impl LogdetEnclosure {
    /// Width of the enclosure, `upper - lower`.
    pub fn gap(&self) -> f64 {
        self.upper - self.lower
    }

    /// Returns `true` when an enclosure of width `gap` is narrow enough to
    /// decide at `decision_margin`.
    ///
    /// Both values must be finite, the margin must be strictly positive, and
    /// the gap must be strictly smaller than the margin. NaN or infinite
    /// inputs never resolve the margin.
    pub fn gap_resolves_margin(gap: f64, decision_margin: f64) -> bool {
        decision_margin.is_finite()
            && decision_margin > 0.0
            && gap.is_finite()
            && gap < decision_margin
    }

    /// Classifies this enclosure against `decision_margin`.
    ///
    /// Yields [`MarginVerdict::Decided`] with the enclosure midpoint when
    /// [`Self::gap_resolves_margin`] holds for [`Self::gap`], and
    /// [`MarginVerdict::InsufficientMargin`] otherwise. Delegating to the
    /// shared predicate keeps the two APIs from disagreeing. A non-finite
    /// gap (e.g. `-inf`) is never reported as decided.
    pub fn decide_within_margin(&self, decision_margin: f64) -> MarginVerdict {
        let gap = self.gap();
        if Self::gap_resolves_margin(gap, decision_margin) {
            MarginVerdict::Decided {
                value: 0.5 * (self.lower + self.upper),
                gap,
                decision_margin,
            }
        } else {
            MarginVerdict::InsufficientMargin {
                gap,
                decision_margin,
            }
        }
    }
}

/// Outcome of comparing an enclosure's width against a decision margin.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MarginVerdict {
    /// The enclosure is narrower than the margin; `value` is its midpoint.
    Decided {
        value: f64,
        gap: f64,
        decision_margin: f64,
    },
    /// The enclosure is too wide (or the inputs are not finite) to decide.
    InsufficientMargin {
        gap: f64,
        decision_margin: f64,
    },
}

impl MarginVerdict {
    /// The decided midpoint, or `None` when the margin was insufficient.
    pub fn decided_value(&self) -> Option<f64> {
        match self {
            MarginVerdict::Decided { value, .. } => Some(*value),
            MarginVerdict::InsufficientMargin { .. } => None,
        }
    }

    /// `true` for [`MarginVerdict::Decided`].
    pub fn is_decided(&self) -> bool {
        matches!(self, MarginVerdict::Decided { .. })
    }
}
/// Tightens a log-determinant enclosure until it resolves `decision_margin`
/// or the absorption budget runs out.
///
/// Each round first tries the order-2 enclosure, then the order-3 one. If
/// neither decides, the most strongly coupled pair of blocks is merged with
/// `absorb_strongest_pair` and the round repeats on the coarser partition.
/// This happens at most `max_absorptions` times, and stops early when no
/// off-diagonal blocks remain.
///
/// Returns the enclosure of the last attempted order together with its
/// verdict. Errors from the enclosure or the absorption step are propagated
/// unchanged.
pub fn refine_logdet_enclosure_to_margin(
    diag: &[Array2<f64>],
    off: &[(usize, usize, Array2<f64>)],
    decision_margin: f64,
    max_absorptions: usize,
) -> Result<(LogdetEnclosure, MarginVerdict), String> {
    // `None` while still working on the caller's partition; afterwards, the
    // partition produced by the most recent absorption.
    let mut absorbed: Option<(Vec<Array2<f64>>, Vec<(usize, usize, Array2<f64>)>)> = None;
    let mut remaining = max_absorptions;
    loop {
        let (blocks, couplings): (&[Array2<f64>], &[(usize, usize, Array2<f64>)]) =
            match &absorbed {
                Some((d, o)) => (d.as_slice(), o.as_slice()),
                None => (diag, off),
            };

        let second_order = block_preconditioned_logdet_enclosure(blocks, couplings, false)?;
        let second_verdict = second_order.decide_within_margin(decision_margin);
        if second_verdict.is_decided() {
            return Ok((second_order, second_verdict));
        }

        let third_order = block_preconditioned_logdet_enclosure(blocks, couplings, true)?;
        let third_verdict = third_order.decide_within_margin(decision_margin);
        let out_of_moves = remaining == 0 || couplings.is_empty();
        if third_verdict.is_decided() || out_of_moves {
            return Ok((third_order, third_verdict));
        }

        absorbed = Some(absorb_strongest_pair(blocks, couplings)?);
        remaining -= 1;
    }
}
/// Merges the most strongly coupled pair of diagonal blocks into one block.
///
/// Coupling strength is the Frobenius norm of the whitened off-diagonal block
/// computed by `whitened_off_block`; the first pair attaining the maximum wins.
///
/// The chosen pair `(a, b)` becomes the new block 0,
/// `[[S_aa, S_ab], [S_abᵀ, S_bb]]`. The remaining blocks keep their relative
/// order at indices `1..`. Couplings from any other block `g` to `a` or `b`
/// are stacked into a single `(0, g_new)` block, with rows `0..m_a` for `a`
/// and rows `m_a..` for `b`. All other couplings are only re-indexed. The
/// described matrix is therefore a symmetric block permutation of the input.
///
/// # Errors
///
/// Returns `Err` if a diagonal block is not SPD, if an off-diagonal entry does
/// not satisfy `i < j < diag.len()` or has the wrong shape, or if `off` is
/// empty. Malformed input is reported rather than panicking on an index or
/// slice.
fn absorb_strongest_pair(
    diag: &[Array2<f64>],
    off: &[(usize, usize, Array2<f64>)],
) -> Result<(Vec<Array2<f64>>, Vec<(usize, usize, Array2<f64>)>), String> {
    use std::collections::BTreeMap;

    let k = diag.len();
    let mut lowers: Vec<Array2<f64>> = Vec::with_capacity(k);
    for (i, s_ii) in diag.iter().enumerate() {
        let factor = s_ii
            .cholesky(Side::Lower)
            .map_err(|e| format!("absorb_strongest_pair: block {i} is not SPD: {e:?}"))?;
        lowers.push(factor.lower_triangular());
    }

    // Validate every off-block, and track the strongest whitened coupling
    // (ties keep the earliest entry).
    let mut best: Option<(usize, f64)> = None;
    for (idx, (i, j, s_ij)) in off.iter().enumerate() {
        let (i, j) = (*i, *j);
        if i >= j || j >= k {
            return Err(format!(
                "absorb_strongest_pair: off-block ({i},{j}) must satisfy i<j<K={k}"
            ));
        }
        if s_ij.nrows() != lowers[i].nrows() || s_ij.ncols() != lowers[j].nrows() {
            return Err(format!(
                "absorb_strongest_pair: off-block ({i},{j}) shape mismatch"
            ));
        }
        let f = frobenius_sq(&whitened_off_block(&lowers[i], &lowers[j], s_ij)).sqrt();
        match best {
            None => best = Some((idx, f)),
            Some((_, bf)) if f > bf => best = Some((idx, f)),
            _ => {}
        }
    }
    let (best_idx, _) =
        best.ok_or_else(|| "absorb_strongest_pair: no off-diagonal pair to absorb".to_string())?;

    // The validation above guarantees a < b, so `s_ab` is already in
    // (row block a, column block b) orientation.
    let (a, b, s_ab) = &off[best_idx];
    let (a, b) = (*a, *b);
    let (ma, mb) = (diag[a].nrows(), diag[b].nrows());
    let merged_dim = ma + mb;

    let mut joint = Array2::<f64>::zeros((merged_dim, merged_dim));
    joint.slice_mut(ndarray::s![..ma, ..ma]).assign(&diag[a]);
    joint.slice_mut(ndarray::s![ma.., ma..]).assign(&diag[b]);
    joint.slice_mut(ndarray::s![..ma, ma..]).assign(s_ab);
    joint.slice_mut(ndarray::s![ma.., ..ma]).assign(&s_ab.t());

    // Old block index -> new block index. Both a and b map to the merged
    // block 0; the others are renumbered 1.. in their original order.
    let mut old_to_new = vec![0usize; k];
    let mut new_diag: Vec<Array2<f64>> = Vec::with_capacity(k - 1);
    new_diag.push(joint);
    for g in (0..k).filter(|&g| g != a && g != b) {
        old_to_new[g] = new_diag.len();
        new_diag.push(diag[g].clone());
    }

    let mut new_off: Vec<(usize, usize, Array2<f64>)> = Vec::new();
    // Stacked couplings between the merged block and each other block, keyed
    // by that block's new index (BTreeMap keeps the output order deterministic).
    let mut joint_couplings: BTreeMap<usize, Array2<f64>> = BTreeMap::new();
    for (i, j, s_ij) in off {
        let (i, j) = (*i, *j);
        if (i, j) == (a, b) {
            // Already folded into the joint block.
            continue;
        }
        let touches_a = i == a || j == a;
        let touches_b = i == b || j == b;
        if touches_a || touches_b {
            let g = if i == a || i == b { j } else { i };
            let mg = diag[g].nrows();
            let entry = joint_couplings
                .entry(old_to_new[g])
                .or_insert_with(|| Array2::<f64>::zeros((merged_dim, mg)));
            // Orient as (rows of a or b) × (columns of g).
            let s_half = if i == g {
                s_ij.t().to_owned()
            } else {
                s_ij.clone()
            };
            let row_off = if touches_a { 0 } else { ma };
            entry
                .slice_mut(ndarray::s![row_off..row_off + s_half.nrows(), ..])
                .assign(&s_half);
        } else {
            let (ni, nj) = (old_to_new[i], old_to_new[j]);
            let (ni, nj, block) = if ni < nj {
                (ni, nj, s_ij.clone())
            } else {
                (nj, ni, s_ij.t().to_owned())
            };
            new_off.push((ni, nj, block));
        }
    }
    for (g_new, block) in joint_couplings {
        new_off.push((0, g_new, block));
    }
    Ok((new_diag, new_off))
}
/// Whitens an off-diagonal block by the Cholesky factors of its two diagonal
/// blocks.
///
/// Assuming `forward_substitution_lower_matrix(L, B)` returns `L⁻¹ B` (TODO
/// confirm against `crate::linalg::triangular`), this yields
/// `E_ij = L_i⁻¹ S_ij L_j⁻ᵀ`. The right-hand factor is applied as a second
/// left solve on the transpose: `L_j⁻¹ (L_i⁻¹ S_ij)ᵀ = E_ijᵀ`.
fn whitened_off_block(l_i: &Array2<f64>, l_j: &Array2<f64>, s_ij: &Array2<f64>) -> Array2<f64> {
    let left_whitened = forward_substitution_lower_matrix(l_i, s_ij);
    let flipped = left_whitened.t().to_owned();
    let e_ij_transposed = forward_substitution_lower_matrix(l_j, &flipped);
    e_ij_transposed.t().to_owned()
}
/// Squared Frobenius norm of `a`: the sum of the squares of all its entries.
fn frobenius_sq(a: &Array2<f64>) -> f64 {
    a.iter().map(|&entry| entry * entry).sum::<f64>()
}
/// Certified two-sided enclosure of `log|S|` for a symmetric positive-definite
/// matrix `S` supplied in block form.
///
/// `diag[i]` are the square diagonal blocks `S_ii`. `off` lists the strictly
/// upper off-diagonal blocks as `(i, j, S_ij)` with `i < j < diag.len()`;
/// lower blocks are implied by symmetry, and unlisted pairs are zero.
///
/// With `S_ii = L_i L_iᵀ` and whitened couplings `E_ij = L_i⁻¹ S_ij L_j⁻ᵀ`,
/// `log|S| = Σ_i log|S_ii| + log|I + E|`, where `E` has zero diagonal blocks
/// (so `tr E = 0`). The series `log|I + E| = −tr(E²)/2 + tr(E³)/3 − …` is
/// truncated after the second term, or the third when `use_third_moment` is
/// set. The tail is bounded with `ρ`, the smaller of `‖E‖_F` and the block
/// row-sum bound `maxᵢ Σⱼ ‖E_ij‖_F`; both bound the spectral radius of `E`.
///
/// # Errors
///
/// Returns `Err` when:
/// - there are no diagonal blocks;
/// - a diagonal block is not square or not SPD;
/// - an off-diagonal entry has out-of-order/out-of-range indices, the wrong
///   shape, or is listed more than once (duplicates would be double-counted in
///   `tr(E²)`);
/// - `ρ ≥ 1`, in which case the series bound cannot be certified.
pub fn block_preconditioned_logdet_enclosure(
    diag: &[Array2<f64>],
    off: &[(usize, usize, Array2<f64>)],
    use_third_moment: bool,
) -> Result<LogdetEnclosure, String> {
    use std::collections::HashMap;

    let k = diag.len();
    if k == 0 {
        return Err("block_preconditioned_logdet_enclosure: no diagonal blocks".to_string());
    }

    let mut lowers: Vec<Array2<f64>> = Vec::with_capacity(k);
    let mut block_diag_logdet = 0.0_f64;
    for (i, s_ii) in diag.iter().enumerate() {
        if s_ii.nrows() != s_ii.ncols() {
            return Err(format!(
                "block_preconditioned_logdet_enclosure: diagonal block {i} is not square"
            ));
        }
        let factor = s_ii.cholesky(Side::Lower).map_err(|e| {
            format!("block_preconditioned_logdet_enclosure: block {i} is not SPD: {e:?}")
        })?;
        let l = factor.lower_triangular();
        // log|S_ii| = 2 Σ_d ln L_ii[d, d].
        for d in 0..l.nrows() {
            block_diag_logdet += 2.0 * l[[d, d]].ln();
        }
        lowers.push(l);
    }

    let mut whitened: Vec<(usize, usize, Array2<f64>)> = Vec::with_capacity(off.len());
    // (i, j) -> position in `whitened`; rejects duplicates and gives O(1)
    // lookups for the third-moment sum.
    let mut pair_index: HashMap<(usize, usize), usize> = HashMap::with_capacity(off.len());
    let mut p2 = 0.0_f64;
    let mut row_sums = vec![0.0_f64; k];
    for (i, j, s_ij) in off {
        let (i, j) = (*i, *j);
        if i >= j || j >= k {
            return Err(format!(
                "block_preconditioned_logdet_enclosure: off-block ({i},{j}) must satisfy i<j<K={k}"
            ));
        }
        if s_ij.nrows() != lowers[i].nrows() || s_ij.ncols() != lowers[j].nrows() {
            return Err(format!(
                "block_preconditioned_logdet_enclosure: off-block ({i},{j}) shape mismatch"
            ));
        }
        if pair_index.insert((i, j), whitened.len()).is_some() {
            return Err(format!(
                "block_preconditioned_logdet_enclosure: off-block ({i},{j}) listed more than once"
            ));
        }
        let e_ij = whitened_off_block(&lowers[i], &lowers[j], s_ij);
        let f2 = frobenius_sq(&e_ij);
        // E_ij and E_ji = E_ijᵀ both contribute to tr(E²).
        p2 += 2.0 * f2;
        let f = f2.sqrt();
        row_sums[i] += f;
        row_sums[j] += f;
        whitened.push((i, j, e_ij));
    }

    let gershgorin = row_sums.iter().fold(0.0_f64, |a, &b| a.max(b));
    let rho = p2.sqrt().min(gershgorin);
    // `!(rho < 1.0)` also rejects NaN.
    if !(rho < 1.0) {
        return Err(format!(
            "block_preconditioned_logdet_enclosure: spectral-radius certificate failed \
(ρ = {rho:.6} ≥ 1); refine the block partition (absorb the strongest \
off-diagonal pair into the preconditioner) and retry"
        ));
    }

    let p3 = use_third_moment.then(|| third_moment_trace(k, &whitened, &pair_index));

    // Per eigenvalue λ of E (|λ| ≤ ρ < 1):
    // - order 3: the tail log(1+λ) − λ + λ²/2 − λ³/3 lies in
    //   [−ρ²λ²/(4(1−ρ)), 0];
    // - order 2: the tail log(1+λ) − λ + λ²/2 lies in
    //   [−ρλ²/(3(1−ρ)), ρλ²/3].
    // Summing over λ turns Σ λ² into p2.
    let (corr_lower, corr_upper) = match p3 {
        Some(p3) => {
            let upper = -p2 / 2.0 + p3 / 3.0;
            let lower = upper - rho * rho * p2 / (4.0 * (1.0 - rho));
            (lower, upper)
        }
        None => {
            let upper = -p2 / 2.0 + rho * p2 / 3.0;
            let lower = -p2 / 2.0 - rho * p2 / (3.0 * (1.0 - rho));
            (lower, upper)
        }
    };
    Ok(LogdetEnclosure {
        block_diag_logdet,
        lower: block_diag_logdet + corr_lower,
        upper: block_diag_logdet + corr_upper,
        rho,
        p2,
        p3,
    })
}

/// Computes `tr(E³)` for the symmetric block matrix `E` whose strictly upper
/// blocks are `whitened` (zero diagonal blocks, unlisted pairs zero).
///
/// Only pairwise-distinct block triples contribute. For an unordered triple
/// `a < b < c`, all six orderings give the same trace (by cyclicity and
/// `E_ji = E_ijᵀ`), so `tr(E³) = 6 Σ_{a<b<c} tr(E_ab E_bc E_acᵀ)`. Every factor
/// is then a stored upper block, and no transposes need to be materialized.
fn third_moment_trace(
    k: usize,
    whitened: &[(usize, usize, Array2<f64>)],
    pair_index: &std::collections::HashMap<(usize, usize), usize>,
) -> f64 {
    let mut acc = 0.0_f64;
    for (a, b, e_ab) in whitened {
        let (a, b) = (*a, *b);
        for c in (b + 1)..k {
            let (Some(&bc), Some(&ac)) = (pair_index.get(&(b, c)), pair_index.get(&(a, c)))
            else {
                continue;
            };
            let prod = e_ab.dot(&whitened[bc].2);
            // tr(P · E_acᵀ) = Σ_{r,s} P[r,s] · E_ac[r,s]; both iterate in logical order.
            acc += prod
                .iter()
                .zip(whitened[ac].2.iter())
                .map(|(p, e)| p * e)
                .sum::<f64>();
        }
    }
    6.0 * acc
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::Array2;
// Builds a `k`-block test matrix with `m`×`m` blocks. It returns the diagonal
// blocks, every upper off-diagonal block `(i, j)` with `i < j`, and the
// assembled dense matrix used as the exact oracle.
//
// Diagonal blocks are the symmetrised pattern (3 + 0.4 i + 0.2 r on the
// diagonal, 0.3·sin(...) elsewhere) plus the identity. Off-diagonal blocks are
// `coupling · cos(r − c + 0.31 (i + j))`. NOTE(review): SPD-ness of the dense
// matrix relies on `coupling` being small; `dense_logdet` panics otherwise.
fn fixture(
k: usize,
m: usize,
coupling: f64,
) -> (
Vec<Array2<f64>>,
Vec<(usize, usize, Array2<f64>)>,
Array2<f64>,
) {
let dim = k * m;
let mut dense = Array2::<f64>::zeros((dim, dim));
let mut diag = Vec::new();
let mut off = Vec::new();
for i in 0..k {
let mut d = Array2::<f64>::zeros((m, m));
for r in 0..m {
for c in 0..m {
let v = if r == c {
3.0 + 0.4 * (i as f64) + 0.2 * (r as f64)
} else {
0.3 * ((r + 2 * c + i) as f64 * 0.7).sin()
};
d[[r, c]] = v;
}
}
// Symmetrise and add the identity to strengthen the diagonal.
let mut sym = Array2::<f64>::zeros((m, m));
for r in 0..m {
for c in 0..m {
sym[[r, c]] = 0.5 * (d[[r, c]] + d[[c, r]]);
}
sym[[r, r]] += 1.0;
}
for r in 0..m {
for c in 0..m {
dense[[i * m + r, i * m + c]] = sym[[r, c]];
}
}
diag.push(sym);
}
// Every block pair is coupled; the dense copy is written symmetrically.
for i in 0..k {
for j in (i + 1)..k {
let mut b = Array2::<f64>::zeros((m, m));
for r in 0..m {
for c in 0..m {
b[[r, c]] =
coupling * ((r as f64) - (c as f64) + (i + j) as f64 * 0.31).cos();
}
}
for r in 0..m {
for c in 0..m {
dense[[i * m + r, j * m + c]] = b[[r, c]];
dense[[j * m + c, i * m + r]] = b[[r, c]];
}
}
off.push((i, j, b));
}
}
(diag, off, dense)
}
// Exact log-determinant oracle: 2 Σ ln L[d, d] from a dense Cholesky factor.
fn dense_logdet(s: &Array2<f64>) -> f64 {
let l = s
.cholesky(Side::Lower)
.expect("oracle fixture must be SPD")
.lower_triangular();
(0..l.nrows()).map(|d| 2.0 * l[[d, d]].ln()).sum()
}
// Both orders must bracket the exact value. Order 3 must be no wider than
// order 2, and (unless both are already ~0) at most about half as wide.
#[test]
fn enclosure_contains_dense_truth_and_order3_tightens() {
let (diag, off, dense) = fixture(4, 3, 0.08);
let truth = dense_logdet(&dense);
let e2 =
block_preconditioned_logdet_enclosure(&diag, &off, false).expect("order-2 enclosure");
let e3 =
block_preconditioned_logdet_enclosure(&diag, &off, true).expect("order-3 enclosure");
assert!(
e2.lower <= truth && truth <= e2.upper,
"order-2 enclosure [{}, {}] must contain dense log|S| = {}",
e2.lower,
e2.upper,
truth
);
assert!(
e3.lower <= truth && truth <= e3.upper,
"order-3 enclosure [{}, {}] must contain dense log|S| = {}",
e3.lower,
e3.upper,
truth
);
assert!(
e3.gap() <= e2.gap() + 1e-12,
"order-3 gap {} must not exceed order-2 gap {}",
e3.gap(),
e2.gap()
);
assert!(e3.gap() < 0.5 * e2.gap() + 1e-9 || e2.gap() < 1e-9);
}
// With no off-diagonal blocks, the enclosure collapses to the exact
// block-diagonal log-determinant.
#[test]
fn zero_coupling_is_exact() {
let (diag, _off, dense) = fixture(3, 2, 0.0);
let truth = dense_logdet(&dense);
let e = block_preconditioned_logdet_enclosure(&diag, &[], true).expect("enclosure");
assert!((e.lower - truth).abs() < 1e-10 && (e.upper - truth).abs() < 1e-10);
assert!(e.gap() < 1e-12);
}
// Strong coupling must trip the ρ ≥ 1 refusal rather than return bogus bounds.
#[test]
fn failed_radius_certificate_refuses() {
let (diag, off, _dense) = fixture(3, 2, 5.0);
let err = block_preconditioned_logdet_enclosure(&diag, &off, false)
.expect_err("ρ ≥ 1 must refuse");
assert!(err.contains("spectral-radius certificate failed"));
}
// A margin below the gap must not decide. A margin above it must decide, and
// the midpoint must then lie within half a gap of the truth.
#[test]
fn margin_verdict_is_honest_about_the_gap() {
let (diag, off, dense) = fixture(4, 3, 0.08);
let truth = dense_logdet(&dense);
let e = block_preconditioned_logdet_enclosure(&diag, &off, true).expect("enclosure");
let tight = e.gap() * 0.5;
assert!(!e.decide_within_margin(tight).is_decided());
assert!(e.decide_within_margin(tight).decided_value().is_none());
let loose = e.gap() * 2.0 + 1e-9;
let verdict = e.decide_within_margin(loose);
assert!(verdict.is_decided());
let value = verdict.decided_value().expect("decided");
assert!((value - truth).abs() <= 0.5 * e.gap() + 1e-12);
}
// Merging one pair removes exactly one block, keeps the truth enclosed, and
// does not widen the order-3 gap.
#[test]
fn pair_absorption_preserves_truth_and_tightens() {
let (diag, off, dense) = fixture(4, 3, 0.14);
let truth = dense_logdet(&dense);
let before =
block_preconditioned_logdet_enclosure(&diag, &off, true).expect("pre-absorption");
let (mdiag, moff) = absorb_strongest_pair(&diag, &off).expect("absorb");
assert_eq!(
mdiag.len(),
diag.len() - 1,
"one fewer block after absorption"
);
let after =
block_preconditioned_logdet_enclosure(&mdiag, &moff, true).expect("post-absorption");
assert!(
after.lower <= truth && truth <= after.upper,
"absorbed enclosure [{}, {}] must still contain log|S| = {truth}",
after.lower,
after.upper
);
assert!(
after.gap() <= before.gap() + 1e-9,
"absorption must not widen the gap ({} vs {})",
after.gap(),
before.gap()
);
}
// The margin is set to half the initial order-3 gap, so the ladder can only
// succeed by absorbing pairs.
#[test]
fn refinement_ladder_closes_margin_via_absorption() {
let (diag, off, dense) = fixture(5, 2, 0.16);
let truth = dense_logdet(&dense);
let order3 =
block_preconditioned_logdet_enclosure(&diag, &off, true).expect("order-3 enclosure");
let margin = order3.gap() * 0.5;
let (enc, verdict) =
refine_logdet_enclosure_to_margin(&diag, &off, margin, 8).expect("ladder");
assert!(
verdict.is_decided(),
"ladder must close the margin via absorption"
);
assert!(
enc.lower <= truth && truth <= enc.upper,
"refined enclosure [{}, {}] must contain log|S| = {truth}",
enc.lower,
enc.upper
);
assert!(enc.gap() < margin);
}
}