use crate::inference::psis::pareto_smooth_weights;
use ndarray::{Array1, Array2};
/// Verdict on the Gaussian plug-in approximation, derived from the Pareto shape
/// estimate `k_hat` of its importance weights.
///
/// The bands are: `k_hat < 0.5`, `0.5 <= k_hat <= 0.7`, and `k_hat > 0.7` (or
/// non-finite).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RhoCertificate {
    /// `k_hat < 0.5`: the plug-in answer is accepted as-is.
    PlugInCertified,
    /// `0.5 <= k_hat <= 0.7`: usable, but the importance weights should be applied.
    ImportanceCorrect,
    /// `k_hat > 0.7` or `k_hat` not finite: the approximation is not trusted.
    Escalate,
}

impl RhoCertificate {
    /// Exclusive upper bound on `k_hat` for `PlugInCertified`.
    const PLUG_IN_LIMIT: f64 = 0.5;
    /// Inclusive upper bound on `k_hat` for `ImportanceCorrect`.
    const IMPORTANCE_LIMIT: f64 = 0.7;

    /// Maps a Pareto shape estimate to its certificate band.
    ///
    /// NaN and ±infinity always map to `Escalate`.
    fn from_k_hat(k_hat: f64) -> Self {
        match k_hat {
            k if !k.is_finite() => RhoCertificate::Escalate,
            k if k < Self::PLUG_IN_LIMIT => RhoCertificate::PlugInCertified,
            k if k <= Self::IMPORTANCE_LIMIT => RhoCertificate::ImportanceCorrect,
            _ => RhoCertificate::Escalate,
        }
    }
}
/// Outcome of `rho_posterior_certificate`.
///
/// That function importance-samples the Gaussian approximation `N(rho_hat, H⁻¹)`
/// against the target `exp(-criterion(rho))`.
#[derive(Debug, Clone)]
pub struct RhoPosteriorCertificate {
    /// Pareto shape estimate reported by `pareto_smooth_weights`.
    pub k_hat: f64,
    /// Band obtained from `k_hat` via `RhoCertificate::from_k_hat`.
    pub certificate: RhoCertificate,
    /// Number of proposal draws actually taken, after the minimum-count clamp.
    pub n_samples: usize,
    /// Pareto-smoothed importance weights, normalized to sum to 1.
    /// Presumably one per draw; this assumes `pareto_smooth_weights` preserves length.
    pub weights: Array1<f64>,
    /// Effective sample size `1 / Σ wᵢ²` of the normalized weights.
    /// It is 0 if that sum is not positive.
    pub effective_sample_size: f64,
}
/// Number of proposal draws used when the caller passes `n_samples: None`.
/// The count is still raised to at least `2 * psis::MIN_TAIL_COUNT`.
const DEFAULT_M: usize = 64;
/// Fixed seed for the certificate's RNG, so repeated calls give bit-identical results.
/// The value equals the SplitMix64 state increment (the 64-bit golden-ratio constant).
const CERTIFICATE_SEED: u64 = 0x9E37_79B9_7F4A_7C15;
/// Deterministic standard-normal generator.
///
/// A SplitMix64 stream feeds the cosine branch of the Box–Muller transform. The same
/// seed always yields the same sequence.
struct DetNormal {
    state: u64,
}

impl DetNormal {
    /// SplitMix64 state increment.
    const GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    /// 2^53, the divisor that maps a 53-bit integer into the unit interval.
    const TWO_POW_53: f64 = 9_007_199_254_740_992.0;

    fn new(seed: u64) -> Self {
        DetNormal { state: seed }
    }

    /// Advances the SplitMix64 state and returns the next mixed 64-bit output.
    fn next_bits(&mut self) -> u64 {
        self.state = self.state.wrapping_add(Self::GAMMA);
        let mixed = self.state;
        let mixed = (mixed ^ (mixed >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        let mixed = (mixed ^ (mixed >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        mixed ^ (mixed >> 31)
    }

    /// Uniform draw in `(0, 1]`, built from the top 53 bits plus a half-step offset.
    ///
    /// The result is never 0. Rounding of `+ 0.5` near 2^53 can make it exactly 1.0.
    fn uniform(&mut self) -> f64 {
        let top53 = self.next_bits() >> 11;
        (top53 as f64 + 0.5) / Self::TWO_POW_53
    }

    /// One standard-normal draw; consumes two uniforms (radius first, then angle).
    fn normal(&mut self) -> f64 {
        let radius_u = self.uniform().max(1e-300);
        let angle_u = self.uniform();
        let radius = (-2.0 * radius_u.ln()).sqrt();
        let angle = 2.0 * std::f64::consts::PI * angle_u;
        radius * angle.cos()
    }
}
/// Lower-triangular Cholesky factor `L` with `L Lᵀ = a + jitter·I`.
///
/// Only the diagonal and the strictly-lower triangle of `a` are read. The upper
/// triangle is assumed to mirror it, and symmetry is not checked. A diagonal jitter of
/// `1e-10 · max(1, maxᵢ |aᵢᵢ|)` is always added.
///
/// Returns `None` for an empty or non-square matrix, or when any pivot is non-finite or
/// not strictly positive (which also catches NaN/∞ entries).
fn cholesky_lower(a: &Array2<f64>) -> Option<Array2<f64>> {
    let dim = a.nrows();
    if dim == 0 || a.ncols() != dim {
        return None;
    }
    let diag_scale = a
        .diag()
        .iter()
        .map(|v| v.abs())
        .fold(0.0_f64, f64::max)
        .max(1.0);
    let jitter = 1e-10 * diag_scale;

    let mut factor = Array2::<f64>::zeros((dim, dim));
    for col in 0..dim {
        let pivot = (0..col).fold(a[[col, col]] + jitter, |acc, p| {
            acc - factor[[col, p]] * factor[[col, p]]
        });
        if !pivot.is_finite() || pivot <= 0.0 {
            return None;
        }
        let root = pivot.sqrt();
        factor[[col, col]] = root;
        for row in (col + 1)..dim {
            let partial = (0..col).fold(a[[row, col]], |acc, p| {
                acc - factor[[row, p]] * factor[[col, p]]
            });
            factor[[row, col]] = partial / root;
        }
    }
    Some(factor)
}
/// Returns `R⁻ᵀ`, where `R` is the jittered lower Cholesky factor of `outer_hessian`
/// (`R Rᵀ = H + jitter·I`).
///
/// Since `R⁻ᵀ (R⁻ᵀ)ᵀ = (R Rᵀ)⁻¹`, multiplying a standard-normal (white) vector by the
/// result gives a draw whose covariance is approximately `H⁻¹`. The result is upper
/// triangular.
///
/// Returns `None` if the factorisation fails or a diagonal entry of `R` is unusable.
fn whitening_factor_from_outer_hessian(outer_hessian: &Array2<f64>) -> Option<Array2<f64>> {
    let chol = cholesky_lower(outer_hessian)?;
    let dim = chol.nrows();
    let mut inv_t = Array2::<f64>::zeros((dim, dim));
    // Scratch column; every entry read in a pass has already been written in that pass.
    let mut column = vec![0.0_f64; dim];
    for unit in 0..dim {
        // Forward substitution: solve R · column = e_unit, i.e. column `unit` of R⁻¹.
        for row in 0..dim {
            let rhs = if row == unit { 1.0 } else { 0.0 };
            let partial = (0..row).fold(rhs, |acc, p| acc - chol[[row, p]] * column[p]);
            let pivot = chol[[row, row]];
            if !pivot.is_finite() || pivot.abs() <= 0.0 {
                return None;
            }
            column[row] = partial / pivot;
        }
        // Column `unit` of R⁻¹ is row `unit` of R⁻ᵀ.
        for (idx, &value) in column.iter().enumerate() {
            inv_t[[unit, idx]] = value;
        }
    }
    Some(inv_t)
}
/// Importance-sampling diagnostic for a Gaussian approximation to a target density.
///
/// The approximation is `N(rho_hat, H⁻¹)`, with `H = outer_hessian`, and the target is
/// `p(rho) ∝ exp(-criterion(rho))`.
///
/// `m` deterministic draws are formed as `rho = rho_hat + R⁻ᵀ z`, where `z ~ N(0, I)`
/// and `R Rᵀ = H + jitter·I`. Each draw gets the log importance weight
///
/// `log w = -(criterion(rho) - criterion(rho_hat)) + ½‖z‖²`,
///
/// which is log target minus log proposal, up to an additive constant. The weights are
/// Pareto-smoothed with `pareto_smooth_weights`, and the fitted shape `k̂` is mapped to
/// a [`RhoCertificate`].
///
/// # Arguments
/// * `rho_hat` – centre of the approximation. It must be non-empty with every entry finite.
/// * `outer_hessian` – `k × k` matrix with `k = rho_hat.len()`. Only its diagonal and
///   lower triangle are read, and it must be positive definite after the small jitter.
/// * `criterion` – negative log target up to a constant. A draw where it returns
///   `None` or a non-finite value gets weight zero.
/// * `n_samples` – number of draws, defaulting to `DEFAULT_M`. It is raised to at least
///   `2 * psis::MIN_TAIL_COUNT`, presumably so the Pareto tail fit has enough points.
///
/// # Returns
/// `None` in any of these cases:
/// * the dimensions are inconsistent or `rho_hat` has a non-finite entry;
/// * `criterion(rho_hat)` is `None` or non-finite;
/// * the Hessian cannot be factored;
/// * no draw has a finite log weight;
/// * `pareto_smooth_weights` returns `None`;
/// * the smoothed weights do not have a finite positive sum.
pub fn rho_posterior_certificate<F>(
    rho_hat: &Array1<f64>,
    outer_hessian: &Array2<f64>,
    criterion: F,
    n_samples: Option<usize>,
) -> Option<RhoPosteriorCertificate>
where
    F: Fn(&Array1<f64>) -> Option<f64>,
{
    let k = rho_hat.len();
    if k == 0 || outer_hessian.nrows() != k || outer_hessian.ncols() != k {
        return None;
    }
    // A non-finite centre makes every draw non-finite. A criterion that tolerates such
    // inputs would otherwise produce a meaningless certificate.
    if rho_hat.iter().any(|v| !v.is_finite()) {
        return None;
    }
    let cost_hat = criterion(rho_hat)?;
    if !cost_hat.is_finite() {
        return None;
    }
    let l_inv = whitening_factor_from_outer_hessian(outer_hessian)?;
    let m = n_samples
        .unwrap_or(DEFAULT_M)
        .max(2 * crate::inference::psis::MIN_TAIL_COUNT);

    // Fixed seed: identical inputs always yield bit-identical weights and k̂.
    let mut rng = DetNormal::new(CERTIFICATE_SEED);
    let mut log_weights: Vec<f64> = Vec::with_capacity(m);
    for _ in 0..m {
        let z: Array1<f64> = Array1::from_iter((0..k).map(|_| rng.normal()));
        // rho_m = rho_hat + R⁻ᵀ z. Summation order matches a plain j = 0..k loop.
        let mut rho_m = rho_hat.clone();
        for (i, rho_i) in rho_m.iter_mut().enumerate() {
            *rho_i += l_inv
                .row(i)
                .iter()
                .zip(z.iter())
                .fold(0.0, |acc, (&l, &zj)| acc + l * zj);
        }
        // log q(rho_m) = -½‖z‖² + const, so subtracting log q adds ½‖z‖².
        let half_norm_sq = 0.5 * z.iter().map(|&v| v * v).sum::<f64>();
        let log_w = match criterion(&rho_m) {
            Some(c) if c.is_finite() => -c + cost_hat + half_norm_sq,
            _ => f64::NEG_INFINITY,
        };
        log_weights.push(log_w);
    }

    let max_lw = log_weights
        .iter()
        .copied()
        .filter(|v| v.is_finite())
        .fold(f64::NEG_INFINITY, f64::max);
    if !max_lw.is_finite() {
        // No draw has a finite log weight, so there is nothing to smooth.
        return None;
    }
    // Shift by the maximum before exponentiating, so the largest weight is exactly 1.
    let weights: Vec<f64> = log_weights
        .iter()
        .map(|&lw| if lw.is_finite() { (lw - max_lw).exp() } else { 0.0 })
        .collect();

    let psis = pareto_smooth_weights(&weights)?;
    let k_hat = psis.k_hat;
    let total: f64 = psis.smoothed.iter().sum();
    if !(total.is_finite() && total > 0.0) {
        return None;
    }
    let normalized: Array1<f64> = Array1::from_iter(psis.smoothed.iter().map(|&w| w / total));
    // Effective sample size of the normalized weights: 1 / Σ wᵢ².
    let sum_sq: f64 = normalized.iter().map(|&w| w * w).sum();
    let ess = if sum_sq > 0.0 { 1.0 / sum_sq } else { 0.0 };
    Some(RhoPosteriorCertificate {
        k_hat,
        certificate: RhoCertificate::from_k_hat(k_hat),
        n_samples: m,
        weights: normalized,
        effective_sample_size: ess,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// When the criterion is exactly the quadratic form implied by the Hessian, proposal
    /// and target coincide (up to the tiny Cholesky jitter). Every weight is then equal.
    #[test]
    fn exact_gaussian_target_certifies_plug_in() {
        let rho_hat = array![0.3, -0.7];
        let h = array![[2.0, 0.5], [0.5, 1.5]];
        let crit = |rho: &Array1<f64>| {
            let d = rho - &rho_hat;
            let mut q = 0.0;
            for i in 0..2 {
                for j in 0..2 {
                    q += d[i] * h[[i, j]] * d[j];
                }
            }
            Some(0.5 * q)
        };
        let cert = rho_posterior_certificate(&rho_hat, &h, crit, Some(256)).expect("certificate");
        assert!(
            (cert.effective_sample_size - cert.n_samples as f64).abs() < 1e-6,
            "uniform weights must give ESS == M: ess={} M={}",
            cert.effective_sample_size,
            cert.n_samples
        );
        assert!(
            cert.k_hat < 0.5,
            "exact-Gaussian target must yield small k̂, got {}",
            cert.k_hat
        );
        assert_eq!(cert.certificate, RhoCertificate::PlugInCertified);
    }

    /// A Cauchy-like criterion has far heavier tails than the Gaussian proposal.
    /// The weights are therefore heavy-tailed, and the plug-in answer must not be
    /// certified.
    #[test]
    fn heavy_tailed_target_refuses_to_certify() {
        let rho_hat = array![0.0];
        let h = array![[4.0]];
        let crit = |rho: &Array1<f64>| {
            let r = rho[0];
            Some((1.0 + r * r).ln())
        };
        let cert = rho_posterior_certificate(&rho_hat, &h, crit, Some(512)).expect("certificate");
        assert!(
            cert.k_hat > 0.5,
            "heavy-tailed target must raise k̂ above 0.5, got {}",
            cert.k_hat
        );
        assert_ne!(cert.certificate, RhoCertificate::PlugInCertified);
    }

    /// Weights sum to one, and repeated calls are bit-for-bit reproducible in every
    /// reported quantity.
    #[test]
    fn weights_are_normalized_and_deterministic() {
        let rho_hat = array![1.0];
        let h = array![[1.0]];
        let crit = |rho: &Array1<f64>| {
            let d = rho[0] - 1.0;
            Some(0.5 * d * d)
        };
        let a = rho_posterior_certificate(&rho_hat, &h, crit, Some(64)).expect("a");
        let b = rho_posterior_certificate(&rho_hat, &h, crit, Some(64)).expect("b");
        let s: f64 = a.weights.iter().sum();
        assert!((s - 1.0).abs() < 1e-10, "weights must sum to 1, got {s}");
        assert_eq!(a.k_hat.to_bits(), b.k_hat.to_bits());
        assert_eq!(a.n_samples, b.n_samples);
        assert_eq!(a.weights.len(), b.weights.len());
        assert!(
            a.weights
                .iter()
                .zip(b.weights.iter())
                .all(|(x, y)| x.to_bits() == y.to_bits()),
            "weights must be bit-identical across calls"
        );
        assert_eq!(
            a.effective_sample_size.to_bits(),
            b.effective_sample_size.to_bits()
        );
    }

    /// An empty parameter vector cannot be certified.
    #[test]
    fn empty_rho_returns_none() {
        let rho_hat: Array1<f64> = array![];
        let h = Array2::<f64>::zeros((0, 0));
        assert!(rho_posterior_certificate(&rho_hat, &h, |_| Some(0.0), None).is_none());
    }

    /// A Hessian whose shape disagrees with `rho_hat` is rejected up front.
    #[test]
    fn mismatched_hessian_returns_none() {
        let rho_hat = array![0.0, 0.0];
        let h = array![[1.0]];
        assert!(rho_posterior_certificate(&rho_hat, &h, |_| Some(0.0), None).is_none());
    }

    /// A Hessian that is not positive definite cannot be factored, so no certificate.
    #[test]
    fn non_positive_definite_hessian_returns_none() {
        let rho_hat = array![0.0];
        let h = array![[-1.0]];
        assert!(rho_posterior_certificate(&rho_hat, &h, |_| Some(0.0), None).is_none());
    }
}