use super::*;
use gam_linalg::lanczos::{symmetric_lanczos_eigenpairs, SymmetricLanczosOptions};
use gam_linalg::utils::splitmix64;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
/// Outcome of a stochastic Lanczos quadrature estimate of `ln det(A)`.
#[derive(Clone, Copy, Debug)]
pub struct SlqLogDet {
    /// Mean of the per-probe estimates of the log-determinant.
    pub estimate: f64,
    /// Standard error of `estimate`: sample standard deviation of the per-probe
    /// estimates divided by the square root of the probe count.
    pub std_err: f64,
}
/// Smallest value passed to `ln` when turning (Ritz) eigenvalues into log terms.
/// Values at or below it are raised to it, so a non-positive or underflowed
/// eigenvalue contributes a large but finite negative term instead of `-inf`/`NaN`.
const RITZ_LN_FLOOR: f64 = 1.0e-300;
/// Overwrites `z` with random ±1 (Rademacher) entries derived from `probe_seed`.
///
/// A fresh 64-bit word is drawn with `splitmix64` at every element index that is
/// a multiple of 64. Element `idx` uses bit `idx % 64` of the current word: a set
/// bit gives `+1.0`, a clear bit gives `-1.0`.
fn rademacher_into(z: &mut Array1<f64>, probe_seed: u64) {
    let mut rng_state = probe_seed;
    let mut word: u64 = 0;
    for (idx, slot) in z.iter_mut().enumerate() {
        let bit = idx % 64;
        if bit == 0 {
            word = splitmix64(&mut rng_state);
        }
        *slot = if (word >> bit) & 1 == 0 { -1.0 } else { 1.0 };
    }
}
pub fn slq_logdet(
dim: usize,
matvec: impl Fn(ArrayView1<f64>) -> Array1<f64> + Sync,
num_probes: usize,
lanczos_steps: usize,
seed: u64,
) -> SlqLogDet {
if dim == 0 {
return SlqLogDet {
estimate: 0.0,
std_err: 0.0,
};
}
let num_probes = num_probes.max(1);
let steps = lanczos_steps.max(1).min(dim);
let norm_sq = dim as f64;
let lanczos_options = SymmetricLanczosOptions {
max_steps: steps,
residual_tol: 0.0,
local_reorthogonalize: false,
full_reorthogonalize: true,
};
let matvec = &matvec;
let contributions: Vec<f64> = (0..num_probes)
.into_par_iter()
.map(|probe| {
let probe_seed = seed.wrapping_add(probe as u64);
let mut z = Array1::<f64>::zeros(dim);
rademacher_into(&mut z, probe_seed);
let mut in_buf = Array1::<f64>::zeros(dim);
let mut apply = |x: &[f64], out: &mut [f64]| -> Result<(), String> {
in_buf
.as_slice_mut()
.expect("contiguous probe input buffer")
.copy_from_slice(x);
let y = matvec(in_buf.view());
if y.len() != dim {
return Err(format!(
"slq_logdet matvec returned length {}, expected {dim}",
y.len()
));
}
out.copy_from_slice(y.as_slice().expect("contiguous matvec output"));
Ok(())
};
let start = z.as_slice().expect("contiguous probe vector");
match symmetric_lanczos_eigenpairs(dim, start, lanczos_options, &mut apply) {
Ok(pairs) => {
norm_sq * clamped_log_quadrature(&pairs.eigenvalues, &pairs.eigenvectors)
}
Err(_) => 0.0,
}
})
.collect();
let n = contributions.len() as f64;
let mean = contributions.iter().sum::<f64>() / n;
let std_err = if contributions.len() > 1 {
let var = contributions
.iter()
.map(|c| {
let d = c - mean;
d * d
})
.sum::<f64>()
/ (n - 1.0);
(var / n).sqrt()
} else {
0.0
};
SlqLogDet {
estimate: mean,
std_err,
}
}
/// Computes the quadrature sum `Σᵢ τᵢ² · ln(max(θᵢ, RITZ_LN_FLOOR))` for one Lanczos run.
///
/// Here `θᵢ = eigenvalues[i]` and `τᵢ = eigenvectors[[0, i]]` (row 0 of column `i`).
/// Clamping at the floor keeps every log term finite. A `NaN` Ritz value is also
/// mapped to the floor, because `f64::max` ignores `NaN`.
///
/// NOTE(review): assumes the columns of `eigenvectors` are eigenvectors of the
/// Lanczos tridiagonal, so that row 0 squared gives the Gauss quadrature weights.
/// Confirm against `symmetric_lanczos_eigenpairs`.
///
/// # Panics
/// Panics if `eigenvectors` has no rows or fewer than `eigenvalues.len()` columns.
fn clamped_log_quadrature(eigenvalues: &Array1<f64>, eigenvectors: &Array2<f64>) -> f64 {
    eigenvalues
        .iter()
        .enumerate()
        .fold(0.0_f64, |acc, (col, &theta)| {
            let tau = eigenvectors[[0, col]];
            acc + tau * tau * theta.max(RITZ_LN_FLOOR).ln()
        })
}
// Accuracy, reproducibility, and edge-case tests for `slq_logdet`. Estimates are
// checked against dense eigendecompositions of small seeded fixtures.
//
// NOTE(review): every fixture comes from consuming `splitmix64` draws in a fixed
// order, and the tolerances below appear to be tuned to those specific fixtures.
// Reordering the draws changes the matrices under test.
#[cfg(test)]
mod tests {
use super::*;
/// Returns a sample in `[lo, hi)` built from the top 53 bits of one
/// `splitmix64` output, and advances `state`.
fn next_uniform(state: &mut u64, lo: f64, hi: f64) -> f64 {
let bits = splitmix64(state) >> 11;
let unit = (bits as f64) / ((1u64 << 53) as f64);
lo + (hi - lo) * unit
}
/// Builds the `dim × dim` matrix `MᵀM + delta·I`. `M` is `m_rows × dim` with
/// entries uniform in `[-1, 1)`, filled in `M`'s iteration order from `seed`.
/// With `delta > 0` the result is symmetric positive definite. The
/// ill-conditioned test relies on a small `delta` with `m_rows` close to `dim`.
fn random_spd(dim: usize, m_rows: usize, delta: f64, seed: u64) -> Array2<f64> {
let mut state = seed;
let mut m = Array2::<f64>::zeros((m_rows, dim));
for value in m.iter_mut() {
*value = next_uniform(&mut state, -1.0, 1.0);
}
let mut a = m.t().dot(&m);
for i in 0..dim {
a[[i, i]] += delta;
}
// `MᵀM` is symmetric mathematically. Writing the average into both halves
// makes it exactly symmetric in floating point as well.
for i in 0..dim {
for j in (i + 1)..dim {
let avg = 0.5 * (a[[i, j]] + a[[j, i]]);
a[[i, j]] = avg;
a[[j, i]] = avg;
}
}
a
}
/// Reference value: the sum of `ln λ` over the eigenvalues from a dense `eigh`,
/// clamped with the same `RITZ_LN_FLOOR` the estimator applies to Ritz values.
fn exact_logdet(a: &Array2<f64>) -> f64 {
let (evals, _) = a.eigh(Side::Lower).expect("SPD eigendecomposition");
evals.iter().map(|&l| l.max(RITZ_LN_FLOOR).ln()).sum()
}
/// Ratio of the largest to the smallest eigenvalue from a dense `eigh`. For an
/// SPD matrix this is the 2-norm condition number. Used only for diagnostics and
/// for the fixture precondition.
fn condition_number(a: &Array2<f64>) -> f64 {
let (evals, _) = a.eigh(Side::Lower).expect("SPD eigendecomposition");
let max = evals.iter().cloned().fold(f64::MIN, f64::max);
let min = evals.iter().cloned().fold(f64::MAX, f64::min);
max / min
}
/// For well-conditioned SPD matrices (`delta = 5.0`, `m_rows = dim + 40`) at three
/// sizes, SLQ should land within 5% of the exact log-determinant.
#[test]
fn slq_matches_exact_logdet_well_conditioned() {
for (dim, seed) in [(60usize, 1u64), (120, 2), (200, 3)] {
let a = random_spd(dim, dim + 40, 5.0, seed);
let exact = exact_logdet(&a);
let cond = condition_number(&a);
let result = slq_logdet(dim, |v| a.dot(&v), 48, 70, 0xA5A5_0000 ^ seed);
let rel_err = (result.estimate - exact).abs() / exact.abs();
eprintln!(
"well-conditioned dim={dim} cond={cond:.2e} exact={exact:.6} \
est={:.6} rel_err={rel_err:.4e} std_err={:.4e}",
result.estimate, result.std_err
);
assert!(
rel_err < 0.05,
"dim={dim}: SLQ relative error {rel_err:.4e} exceeds 5% \
(exact={exact}, est={})",
result.estimate
);
// Checks consistency with the reported `std_err`. Because of the `0.05 * |exact|`
// slack term, the assertion above already implies this one.
assert!(
(result.estimate - exact).abs() < 3.0 * result.std_err + 0.05 * exact.abs(),
"dim={dim}: estimate not within ~3 std_err of exact \
(|Δ|={:.4e}, std_err={:.4e})",
(result.estimate - exact).abs(),
result.std_err
);
}
}
/// With `delta = 0.05` and only 5 extra rows the fixture is ill-conditioned
/// (asserted as `cond > 1e3`). SLQ gets a larger Lanczos budget and a looser
/// 10% bound.
#[test]
fn slq_handles_moderately_ill_conditioned() {
let dim = 150usize;
let a = random_spd(dim, dim + 5, 0.05, 7);
let exact = exact_logdet(&a);
let cond = condition_number(&a);
// Guards the fixture itself: the test is meaningless if the matrix is not
// actually ill-conditioned.
assert!(
cond > 1e3,
"test fixture should be moderately ill-conditioned, got cond={cond:.2e}"
);
let result = slq_logdet(dim, |v| a.dot(&v), 40, 110, 0xC0FFEE);
let rel_err = (result.estimate - exact).abs() / exact.abs();
eprintln!(
"ill-conditioned dim={dim} cond={cond:.2e} exact={exact:.6} \
est={:.6} rel_err={rel_err:.4e} std_err={:.4e}",
result.estimate, result.std_err
);
assert!(
rel_err < 0.10,
"ill-conditioned dim={dim}: SLQ relative error {rel_err:.4e} \
exceeds 10% (cond={cond:.2e}, exact={exact}, est={})",
result.estimate
);
}
/// Two calls with identical inputs and seed must agree bit-for-bit, even though
/// the probes are evaluated in parallel.
#[test]
fn slq_is_deterministic_for_fixed_seed() {
let dim = 80usize;
let a = random_spd(dim, dim + 20, 2.0, 11);
let r1 = slq_logdet(dim, |v| a.dot(&v), 24, 50, 99);
let r2 = slq_logdet(dim, |v| a.dot(&v), 24, 50, 99);
assert_eq!(
r1.estimate, r2.estimate,
"SLQ must be bit-reproducible for a fixed seed"
);
assert_eq!(r1.std_err, r2.std_err);
}
/// For a diagonal operator the exact answer is `Σ ln dᵢ`. The matvec is supplied
/// as a closure rather than as a dense matrix product.
#[test]
fn slq_diagonal_operator_matches_closed_form() {
let dim = 100usize;
let mut state = 123u64;
let diag: Vec<f64> = (0..dim).map(|_| next_uniform(&mut state, 0.5, 4.0)).collect();
let exact: f64 = diag.iter().map(|d| d.ln()).sum();
let diag_clone = diag.clone();
let result = slq_logdet(
dim,
move |v| {
let mut out = v.to_owned();
for (o, d) in out.iter_mut().zip(diag_clone.iter()) {
*o *= d;
}
out
},
32,
60,
7,
);
let rel_err = (result.estimate - exact).abs() / exact.abs();
eprintln!(
"diagonal dim={dim} exact={exact:.6} est={:.6} rel_err={rel_err:.4e}",
result.estimate
);
assert!(
rel_err < 0.05,
"diagonal operator: relative error {rel_err:.4e} exceeds 5%"
);
}
/// `dim == 0` short-circuits to a zero estimate with zero standard error.
#[test]
fn slq_empty_operator_is_zero() {
let result = slq_logdet(0, |v| v.to_owned(), 8, 8, 1);
assert_eq!(result.estimate, 0.0);
assert_eq!(result.std_err, 0.0);
}
/// 96 probes should report a smaller standard error than 6. Both runs use
/// `seed = 5`, and probe seeds are `seed + probe_index`, so the first 6 probes of
/// the larger run are the same as those of the smaller run.
#[test]
fn std_err_shrinks_with_more_probes() {
let dim = 120usize;
let a = random_spd(dim, dim + 30, 3.0, 21);
let few = slq_logdet(dim, |v| a.dot(&v), 6, 60, 5);
let many = slq_logdet(dim, |v| a.dot(&v), 96, 60, 5);
eprintln!(
"std_err few(6)={:.4e} many(96)={:.4e}",
few.std_err, many.std_err
);
assert!(
many.std_err < few.std_err,
"more probes should reduce std_err (few={:.4e}, many={:.4e})",
few.std_err,
many.std_err
);
}
}