#![forbid(unsafe_code)]
#![warn(missing_docs)]
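//! Information-theoretic quantities over discrete probability vectors:
//! Shannon, Rényi, and Tsallis entropies; KL, Jensen-Shannon, Hellinger,
//! Bhattacharyya, Csiszár f-, Bregman, and Amari divergences; mutual
//! information; and numerically stable helpers such as `log_sum_exp` and
//! `digamma`. Quantities are in nats unless a `_bits` variant says otherwise.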
use thiserror::Error;
pub mod distprop;
mod ksg;
pub use ksg::{mutual_information_ksg, KsgVariant};
use core::f64::consts::LN_2;
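/// KL divergence in nats between two diagonal Gaussians specified by
/// per-dimension means and standard deviations:
/// `0.5 * sum_i [ s1_i^2/s2_i^2 + (m2_i - m1_i)^2/s2_i^2 - 1 + 2 ln(s2_i/s1_i) ]`.
///
/// Errors if the inputs are empty, differ in length, or contain a
/// non-positive standard deviation.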
pub fn kl_divergence_gaussians(
mu1: &[f64],
std1: &[f64],
mu2: &[f64],
std2: &[f64],
) -> Result<f64> {
    ensure_nonempty(mu1)?;
    ensure_same_len(mu1, std1)?;
ensure_same_len(mu1, mu2)?;
ensure_same_len(mu1, std2)?;
let mut kl = 0.0;
for (((&m1, &s1), &m2), &s2) in mu1.iter().zip(std1).zip(mu2).zip(std2) {
if s1 <= 0.0 || s2 <= 0.0 {
return Err(Error::Domain("standard deviation must be positive"));
}
let v1 = s1 * s1;
let v2 = s2 * s2;
kl += (v1 / v2) + (m2 - m1).powi(2) / v2 - 1.0 + 2.0 * (s2.ln() - s1.ln());
}
Ok(0.5 * kl)
}
/// Error type shared by all validation and divergence routines in this crate.
#[derive(Debug, Error)]
pub enum Error {
    /// Two slices that must have equal length do not.
    #[error("length mismatch: {0} vs {1}")]
    LengthMismatch(usize, usize),
    /// An input slice was empty.
    #[error("empty input")]
    Empty,
    /// An entry was NaN or infinite.
    #[error("non-finite entry at index {idx}: {value}")]
    NonFinite {
        /// Index of the offending entry.
        idx: usize,
        /// The offending value.
        value: f64,
    },
    /// An entry was negative where a probability was expected.
    #[error("negative entry at index {idx}: {value}")]
    Negative {
        /// Index of the offending entry.
        idx: usize,
        /// The offending value.
        value: f64,
    },
    /// A probability vector did not sum to 1 within the given tolerance.
    #[error("not normalized (expected sum≈1): sum={sum}")]
    NotNormalized {
        /// The actual sum of the entries.
        sum: f64,
    },
    /// An order parameter `alpha` was outside the valid range.
    #[error("invalid alpha: {alpha} (must be finite and not equal to {forbidden})")]
    InvalidAlpha {
        /// The rejected value.
        alpha: f64,
        /// A value `alpha` must not take.
        forbidden: f64,
    },
    /// A domain error described by a static message.
    #[error("domain error: {0}")]
    Domain(&'static str),
}
/// Convenience alias for `Result` with this crate's [`Error`].
pub type Result<T> = core::result::Result<T, Error>;
fn ensure_nonempty(x: &[f64]) -> Result<()> {
if x.is_empty() {
return Err(Error::Empty);
}
Ok(())
}
fn ensure_same_len(a: &[f64], b: &[f64]) -> Result<()> {
if a.len() != b.len() {
return Err(Error::LengthMismatch(a.len(), b.len()));
}
Ok(())
}
fn ensure_nonnegative(x: &[f64]) -> Result<()> {
for (i, &v) in x.iter().enumerate() {
if !v.is_finite() {
return Err(Error::NonFinite { idx: i, value: v });
}
if v < 0.0 {
return Err(Error::Negative { idx: i, value: v });
}
}
Ok(())
}
fn sum(x: &[f64]) -> f64 {
x.iter().sum()
}
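/// Checks that `p` is a probability vector: nonempty, every entry finite and
/// nonnegative, and summing to 1 within `tol`.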
pub fn validate_simplex(p: &[f64], tol: f64) -> Result<()> {
ensure_nonempty(p)?;
ensure_nonnegative(p)?;
let s = sum(p);
if (s - 1.0).abs() > tol {
return Err(Error::NotNormalized { sum: s });
}
Ok(())
}
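/// Rescales `p` to sum to 1 and returns the original sum. Errors if `p` is
/// empty, has negative or non-finite entries, or sums to zero.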
pub fn normalize_in_place(p: &mut [f64]) -> Result<f64> {
ensure_nonempty(p)?;
ensure_nonnegative(p)?;
let s = sum(p);
if s <= 0.0 {
return Err(Error::Domain("cannot normalize: sum <= 0"));
}
for v in p.iter_mut() {
*v /= s;
}
Ok(s)
}
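/// Shannon entropy `H(p) = -sum_i p_i ln p_i` in nats, with the convention
/// `0 ln 0 = 0`. Validates `p` as a probability vector within `tol` first.
///
/// # Examples
/// ```ignore
/// // Illustrative sketch (`ignore`d: assumes this crate's items in scope).
/// let h = entropy_nats(&[0.5, 0.5], 1e-9).unwrap();
/// assert!((h - core::f64::consts::LN_2).abs() < 1e-12);
/// ```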
pub fn entropy_nats(p: &[f64], tol: f64) -> Result<f64> {
validate_simplex(p, tol)?;
let mut h = 0.0;
for &pi in p {
if pi > 0.0 {
h -= pi * pi.ln();
}
}
Ok(h)
}
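/// Shannon entropy of `p` in bits: [`entropy_nats`] divided by `ln 2`.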
pub fn entropy_bits(p: &[f64], tol: f64) -> Result<f64> {
Ok(entropy_nats(p, tol)? / LN_2)
}
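/// Shannon entropy in nats with no input validation; the caller must ensure
/// `p` is a valid probability vector.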
#[inline]
pub fn entropy_unchecked(p: &[f64]) -> f64 {
let mut h = 0.0;
for &pi in p {
if pi > 0.0 {
h -= pi * pi.ln();
}
}
h
}
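/// Rényi entropy `H_alpha(p) = ln(sum_i p_i^alpha) / (1 - alpha)` in nats for
/// finite `alpha >= 0`, falling back to Shannon entropy when `alpha` is
/// within `1e-12` of 1.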
pub fn renyi_entropy(p: &[f64], alpha: f64, tol: f64) -> Result<f64> {
validate_simplex(p, tol)?;
if !alpha.is_finite() || alpha < 0.0 {
return Err(Error::InvalidAlpha {
alpha,
forbidden: f64::NAN,
});
}
if (alpha - 1.0).abs() < 1e-12 {
return entropy_nats(p, tol);
}
let mut s = 0.0;
for &pi in p {
if pi > 0.0 {
s += pi.powf(alpha);
}
}
if s <= 0.0 {
return Err(Error::Domain("renyi_entropy: sum of p_i^alpha <= 0"));
}
Ok(s.ln() / (1.0 - alpha))
}
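/// Tsallis entropy `S_alpha(p) = (1 - sum_i p_i^alpha) / (alpha - 1)` for
/// finite `alpha >= 0`, falling back to Shannon entropy when `alpha` is
/// within `1e-12` of 1.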
pub fn tsallis_entropy(p: &[f64], alpha: f64, tol: f64) -> Result<f64> {
validate_simplex(p, tol)?;
if !alpha.is_finite() || alpha < 0.0 {
return Err(Error::InvalidAlpha {
alpha,
forbidden: f64::NAN,
});
}
if (alpha - 1.0).abs() < 1e-12 {
return entropy_nats(p, tol);
}
let mut s = 0.0;
for &pi in p {
if pi > 0.0 {
s += pi.powf(alpha);
}
}
Ok((1.0 - s) / (alpha - 1.0))
}
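/// Cross-entropy `H(p, q) = -sum_i p_i ln q_i` in nats, with `0 ln 0 = 0`.
/// Errors if some `q_i = 0` where `p_i > 0`, since the sum diverges.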
pub fn cross_entropy_nats(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
    ensure_same_len(p, q)?;
    validate_simplex(p, tol)?;
validate_simplex(q, tol)?;
let mut h = 0.0;
for (&pi, &qi) in p.iter().zip(q.iter()) {
if pi == 0.0 {
continue;
}
if qi <= 0.0 {
return Err(Error::Domain("cross-entropy undefined: q_i=0 while p_i>0"));
}
h -= pi * qi.ln();
}
Ok(h)
}
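/// Kullback-Leibler divergence `D(p || q) = sum_i p_i ln(p_i / q_i)` in nats.
/// Terms with `p_i = 0` contribute zero; `q_i = 0` with `p_i > 0` is an error.
///
/// # Examples
/// ```ignore
/// // Illustrative sketch (`ignore`d: assumes this crate's items in scope).
/// let d = kl_divergence(&[0.3, 0.7], &[0.5, 0.5], 1e-9).unwrap();
/// assert!(d >= 0.0);
/// ```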
pub fn kl_divergence(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
ensure_same_len(p, q)?;
validate_simplex(p, tol)?;
validate_simplex(q, tol)?;
let mut d = 0.0;
for (&pi, &qi) in p.iter().zip(q.iter()) {
if pi == 0.0 {
continue;
}
if qi <= 0.0 {
return Err(Error::Domain("KL undefined: q_i=0 while p_i>0"));
}
d += pi * (pi / qi).ln();
}
Ok(d)
}
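/// Jensen-Shannon divergence in nats: `0.5 * D(p || m) + 0.5 * D(q || m)`
/// with `m = (p + q) / 2`. Symmetric and bounded above by `ln 2`.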
pub fn jensen_shannon_divergence(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
ensure_same_len(p, q)?;
validate_simplex(p, tol)?;
validate_simplex(q, tol)?;
let mut m = vec![0.0; p.len()];
for i in 0..p.len() {
m[i] = 0.5 * (p[i] + q[i]);
}
Ok(0.5 * kl_divergence(p, &m, tol)? + 0.5 * kl_divergence(q, &m, tol)?)
}
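/// Weighted Jensen-Shannon divergence
/// `pi1 * D(p || m) + (1 - pi1) * D(q || m)` with
/// `m = pi1 * p + (1 - pi1) * q`; `pi1` must lie in `[0, 1]`.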
pub fn jensen_shannon_weighted(p: &[f64], q: &[f64], pi1: f64, tol: f64) -> Result<f64> {
ensure_same_len(p, q)?;
validate_simplex(p, tol)?;
validate_simplex(q, tol)?;
if !(0.0..=1.0).contains(&pi1) || !pi1.is_finite() {
return Err(Error::Domain("pi1 must be in [0, 1]"));
}
let pi2 = 1.0 - pi1;
let mut m = vec![0.0; p.len()];
for i in 0..p.len() {
m[i] = pi1 * p[i] + pi2 * q[i];
}
let kl_p = if pi1 > 0.0 {
kl_divergence(p, &m, tol)?
} else {
0.0
};
let kl_q = if pi2 > 0.0 {
kl_divergence(q, &m, tol)?
} else {
0.0
};
Ok(pi1 * kl_p + pi2 * kl_q)
}
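/// Mutual information `I(X; Y)` in nats of a joint distribution stored
/// row-major: `p_xy[i * n_y + j] = P(X = i, Y = j)`.
///
/// # Examples
/// ```ignore
/// // Illustrative sketch (`ignore`d: assumes this crate's items in scope).
/// // Perfectly correlated bits share ln 2 nats.
/// let mi = mutual_information(&[0.5, 0.0, 0.0, 0.5], 2, 2, 1e-9).unwrap();
/// assert!((mi - core::f64::consts::LN_2).abs() < 1e-12);
/// ```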
pub fn mutual_information(p_xy: &[f64], n_x: usize, n_y: usize, tol: f64) -> Result<f64> {
if n_x == 0 || n_y == 0 {
return Err(Error::Domain(
"mutual_information: n_x and n_y must be >= 1",
));
}
if p_xy.len() != n_x * n_y {
return Err(Error::LengthMismatch(p_xy.len(), n_x * n_y));
}
validate_simplex(p_xy, tol)?;
let mut p_x = vec![0.0; n_x];
let mut p_y = vec![0.0; n_y];
for i in 0..n_x {
for j in 0..n_y {
let p = p_xy[i * n_y + j];
p_x[i] += p;
p_y[j] += p;
}
}
let mut mi = 0.0;
for i in 0..n_x {
for j in 0..n_y {
let pxy = p_xy[i * n_y + j];
if pxy > 0.0 {
let px = p_x[i];
let py = p_y[j];
if px <= 0.0 || py <= 0.0 {
return Err(Error::Domain(
"mutual_information: p(x)=0 or p(y)=0 while p(x,y)>0",
));
}
mi += pxy * (pxy / (px * py)).ln();
}
}
}
Ok(mi)
}
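/// Mutual information from a 2-D `ndarray` joint distribution; the array is
/// flattened in logical (row-major) order and passed to
/// [`mutual_information`]. Available with the `ndarray` feature.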
#[cfg(feature = "ndarray")]
pub fn mutual_information_ndarray(p_xy: &ndarray::Array2<f64>, tol: f64) -> Result<f64> {
let (n_x, n_y) = p_xy.dim();
let flat: Vec<f64> = p_xy.iter().copied().collect();
mutual_information(&flat, n_x, n_y, tol)
}
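/// Conditional entropy `H(X | Y) = H(X) - I(X; Y)` in nats, with `p_xy`
/// stored row-major as in [`mutual_information`].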
pub fn conditional_entropy(p_xy: &[f64], n_x: usize, n_y: usize, tol: f64) -> Result<f64> {
let mi = mutual_information(p_xy, n_x, n_y, tol)?;
let mut p_x = vec![0.0; n_x];
for i in 0..n_x {
for j in 0..n_y {
p_x[i] += p_xy[i * n_y + j];
}
}
let h_x = entropy_nats(&p_x, tol)?;
Ok(h_x - mi)
}
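/// Normalized mutual information `2 * I(X; Y) / (H(X) + H(Y))`, which lies
/// in `[0, 1]`; returns 0 when both marginal entropies are zero.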
pub fn normalized_mutual_information(
p_xy: &[f64],
n_x: usize,
n_y: usize,
tol: f64,
) -> Result<f64> {
if n_x == 0 || n_y == 0 {
return Err(Error::Domain("nmi: n_x and n_y must be >= 1"));
}
if p_xy.len() != n_x * n_y {
return Err(Error::LengthMismatch(p_xy.len(), n_x * n_y));
}
validate_simplex(p_xy, tol)?;
let mut p_x = vec![0.0; n_x];
let mut p_y = vec![0.0; n_y];
for i in 0..n_x {
for j in 0..n_y {
let p = p_xy[i * n_y + j];
p_x[i] += p;
p_y[j] += p;
}
}
let h_x = entropy_nats(&p_x, tol)?;
let h_y = entropy_nats(&p_y, tol)?;
let denom = h_x + h_y;
if denom <= 0.0 {
return Ok(0.0);
}
let mi = mutual_information(p_xy, n_x, n_y, tol)?;
Ok(2.0 * mi / denom)
}
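/// Pointwise mutual information `ln(p(x, y) / (p(x) * p(y)))` in nats.
/// Returns 0 (rather than negative infinity) when `p(x, y) = 0`, and errors
/// on the impossible case `p(x, y) > 0` with a zero marginal.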
pub fn pmi(pxy: f64, px: f64, py: f64) -> Result<f64> {
if pxy > 0.0 && px == 0.0 {
return Err(Error::Domain("pmi: p(x,y)>0 but p(x)=0 is impossible"));
}
if pxy > 0.0 && py == 0.0 {
return Err(Error::Domain("pmi: p(x,y)>0 but p(y)=0 is impossible"));
}
if pxy <= 0.0 || px <= 0.0 || py <= 0.0 {
Ok(0.0)
} else {
Ok((pxy / (px * py)).ln())
}
}
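/// Numerically stable `ln(sum_i exp(v_i))` via the max-shift trick; returns
/// negative infinity for an empty slice.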
#[inline]
pub fn log_sum_exp(values: &[f64]) -> f64 {
if values.is_empty() {
return f64::NEG_INFINITY;
}
let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
if max.is_infinite() {
return max;
}
let sum: f64 = values.iter().map(|v| (v - max).exp()).sum();
max + sum.ln()
}
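/// Numerically stable `ln(exp(a) + exp(b))` for exactly two values.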
#[inline]
pub fn log_sum_exp2(a: f64, b: f64) -> f64 {
let max = a.max(b);
if max.is_infinite() {
return max;
}
max + ((a - max).exp() + (b - max).exp()).ln()
}
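/// Single-pass, numerically stable log-sum-exp over an iterator: the running
/// sum is rescaled whenever a larger maximum is encountered, so the input is
/// traversed only once.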
#[inline]
pub fn log_sum_exp_iter(iter: impl Iterator<Item = f64>) -> f64 {
let mut max = f64::NEG_INFINITY;
let mut sum_exp = 0.0;
for v in iter {
if v > max {
if max.is_finite() {
sum_exp *= (max - v).exp();
}
max = v;
}
sum_exp += (v - max).exp();
}
    if max.is_infinite() {
        return max;
    }
max + sum_exp.ln()
}
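/// Digamma function `psi(x)` for `x > 0`: the recurrence
/// `psi(x) = psi(x + 1) - 1/x` shifts the argument to `x >= 10`, where the
/// asymptotic series applies. Returns NaN for `x <= 0`.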
pub fn digamma(mut x: f64) -> f64 {
if x <= 0.0 {
return f64::NAN;
}
let mut result = 0.0;
while x < 10.0 {
result -= 1.0 / x;
x += 1.0;
}
let r = 1.0 / x;
result += x.ln() - 0.5 * r;
let r2 = r * r;
result -= r2 * (1.0 / 12.0 - r2 * (1.0 / 120.0 - r2 * (1.0 / 252.0 - r2 / 240.0)));
result
}
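/// Bhattacharyya coefficient `BC(p, q) = sum_i sqrt(p_i q_i)`, in `[0, 1]`
/// and equal to 1 exactly when `p = q`.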
pub fn bhattacharyya_coeff(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
ensure_same_len(p, q)?;
validate_simplex(p, tol)?;
validate_simplex(q, tol)?;
let bc: f64 = p
.iter()
.zip(q.iter())
.map(|(&pi, &qi)| pi.sqrt() * qi.sqrt())
.sum();
Ok(bc)
}
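/// Bhattacharyya distance `-ln BC(p, q)`; errors when `BC = 0`, where the
/// distance would be infinite.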
pub fn bhattacharyya_distance(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
let bc = bhattacharyya_coeff(p, q, tol)?;
if bc == 0.0 {
return Err(Error::Domain("Bhattacharyya distance is infinite (BC=0)"));
}
Ok(-bc.ln())
}
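/// Squared Hellinger distance
/// `H^2(p, q) = 0.5 * sum_i (sqrt(p_i) - sqrt(q_i))^2 = 1 - BC(p, q)`,
/// clamped to be nonnegative against round-off.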
pub fn hellinger_squared(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
ensure_same_len(p, q)?;
validate_simplex(p, tol)?;
validate_simplex(q, tol)?;
let h2: f64 = p
.iter()
.zip(q.iter())
.map(|(&pi, &qi)| {
let diff = pi.sqrt() - qi.sqrt();
diff * diff
})
.sum();
Ok((0.5 * h2).max(0.0))
}
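/// Hellinger distance, the square root of [`hellinger_squared`].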
pub fn hellinger(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
Ok(hellinger_squared(p, q, tol)?.sqrt())
}
fn pow_nonneg(x: f64, a: f64) -> Result<f64> {
if x < 0.0 || !x.is_finite() || !a.is_finite() {
return Err(Error::Domain("pow_nonneg: invalid input"));
}
if x == 0.0 {
if a == 0.0 {
return Ok(1.0);
}
if a > 0.0 {
return Ok(0.0);
}
return Err(Error::Domain("0^a for a<0 is infinite"));
}
Ok(x.powf(a))
}
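/// Chernoff coefficient `rho_alpha(p, q) = sum_i p_i^alpha q_i^(1 - alpha)`,
/// the building block shared by the Rényi, Tsallis, and Amari divergences.
/// Errors when a zero entry would be raised to a negative power.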
pub fn rho_alpha(p: &[f64], q: &[f64], alpha: f64, tol: f64) -> Result<f64> {
ensure_same_len(p, q)?;
validate_simplex(p, tol)?;
validate_simplex(q, tol)?;
if !alpha.is_finite() {
return Err(Error::InvalidAlpha {
alpha,
forbidden: f64::NAN,
});
}
let mut s = 0.0;
for (&pi, &qi) in p.iter().zip(q.iter()) {
let a = pow_nonneg(pi, alpha)?;
let b = pow_nonneg(qi, 1.0 - alpha)?;
s += a * b;
}
Ok(s)
}
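/// Rényi divergence `D_alpha(p || q) = ln(rho_alpha(p, q)) / (alpha - 1)`,
/// falling back to KL divergence when `alpha` is within `1e-12` of 1.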
pub fn renyi_divergence(p: &[f64], q: &[f64], alpha: f64, tol: f64) -> Result<f64> {
if (alpha - 1.0).abs() < 1e-12 {
return kl_divergence(p, q, tol);
}
let rho = rho_alpha(p, q, alpha, tol)?;
if rho <= 0.0 {
return Err(Error::Domain("rho_alpha <= 0"));
}
Ok(rho.ln() / (alpha - 1.0))
}
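/// Tsallis divergence `(rho_alpha(p, q) - 1) / (alpha - 1)`, falling back to
/// KL divergence when `alpha` is within `1e-12` of 1.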
pub fn tsallis_divergence(p: &[f64], q: &[f64], alpha: f64, tol: f64) -> Result<f64> {
if (alpha - 1.0).abs() < 1e-12 {
return kl_divergence(p, q, tol);
}
Ok((rho_alpha(p, q, alpha, tol)? - 1.0) / (alpha - 1.0))
}
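/// Amari alpha-divergence
/// `4 / (1 - alpha^2) * (1 - sum_i p_i^((1-alpha)/2) q_i^((1+alpha)/2))`.
/// Near the poles it falls back to the KL limits: `alpha = -1` gives
/// `D(p || q)` and `alpha = 1` gives `D(q || p)`.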
pub fn amari_alpha_divergence(p: &[f64], q: &[f64], alpha: f64, tol: f64) -> Result<f64> {
if !alpha.is_finite() {
return Err(Error::InvalidAlpha {
alpha,
forbidden: f64::NAN,
});
}
let eps = tol.sqrt();
if (alpha + 1.0).abs() <= eps {
return kl_divergence(p, q, tol);
}
if (alpha - 1.0).abs() <= eps {
return kl_divergence(q, p, tol);
}
let t = (1.0 - alpha) / 2.0;
let rho = rho_alpha(p, q, t, tol)?;
Ok((4.0 / (1.0 - alpha * alpha)) * (1.0 - rho))
}
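/// Csiszár f-divergence `D_f(p || q) = sum_i q_i f(p_i / q_i)` for a
/// caller-supplied generator `f` (convex with `f(1) = 0` for the usual
/// properties). `q_i = 0` with `p_i > 0` is an error; `p_i = q_i = 0` terms
/// are skipped.
///
/// # Examples
/// ```ignore
/// // Illustrative sketch (`ignore`d): f(t) = t ln t recovers KL divergence.
/// let d = csiszar_f_divergence(&[0.3, 0.7], &[0.5, 0.5], |t| t * t.ln(), 1e-9).unwrap();
/// ```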
pub fn csiszar_f_divergence(p: &[f64], q: &[f64], f: impl Fn(f64) -> f64, tol: f64) -> Result<f64> {
ensure_same_len(p, q)?;
validate_simplex(p, tol)?;
validate_simplex(q, tol)?;
let mut d = 0.0;
for (&pi, &qi) in p.iter().zip(q.iter()) {
if qi == 0.0 {
if pi == 0.0 {
continue;
}
return Err(Error::Domain("f-divergence undefined: q_i=0 while p_i>0"));
}
d += qi * f(pi / qi);
}
Ok(d)
}
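/// Total variation distance `0.5 * sum_i |p_i - q_i|`, in `[0, 1]`.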
pub fn total_variation(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
ensure_same_len(p, q)?;
validate_simplex(p, tol)?;
validate_simplex(q, tol)?;
let tv: f64 = p
.iter()
.zip(q.iter())
.map(|(&pi, &qi)| (pi - qi).abs())
.sum();
Ok(0.5 * tv)
}
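/// Pearson chi-squared divergence `sum_i (p_i - q_i)^2 / q_i`; `q_i = 0`
/// with `p_i > 0` is an error.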
pub fn chi_squared_divergence(p: &[f64], q: &[f64], tol: f64) -> Result<f64> {
ensure_same_len(p, q)?;
validate_simplex(p, tol)?;
validate_simplex(q, tol)?;
let mut d = 0.0;
for (&pi, &qi) in p.iter().zip(q.iter()) {
if qi == 0.0 {
if pi == 0.0 {
continue;
}
return Err(Error::Domain("chi-squared undefined: q_i=0 while p_i>0"));
}
let diff = pi - qi;
d += diff * diff / qi;
}
Ok(d)
}
/// A convex generator `F` for Bregman divergences, exposing the function
/// value and its gradient.
pub trait BregmanGenerator {
    /// Evaluates the generator `F` at `x`.
    fn f(&self, x: &[f64]) -> Result<f64>;
    /// Writes the gradient of `F` at `x` into `out`; `out` must have the same
    /// length as `x`.
    fn grad_into(&self, x: &[f64], out: &mut [f64]) -> Result<()>;
}
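/// Bregman divergence
/// `B_F(p, q) = F(p) - F(q) - <grad F(q), p - q>` for the given generator.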
pub fn bregman_divergence(gen: &impl BregmanGenerator, p: &[f64], q: &[f64]) -> Result<f64> {
ensure_nonempty(p)?;
ensure_same_len(p, q)?;
let mut grad_q = vec![0.0; q.len()];
gen.grad_into(q, &mut grad_q)?;
let fp = gen.f(p)?;
let fq = gen.f(q)?;
let mut inner = 0.0;
for i in 0..p.len() {
inner += (p[i] - q[i]) * grad_q[i];
}
Ok(fp - fq - inner)
}
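/// Total Bregman divergence: the Bregman divergence scaled by
/// `1 / sqrt(1 + ||grad F(q)||^2)`, i.e. measured orthogonally to the
/// tangent of `F` at `q` rather than vertically.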
pub fn total_bregman_divergence(gen: &impl BregmanGenerator, p: &[f64], q: &[f64]) -> Result<f64> {
ensure_nonempty(p)?;
ensure_same_len(p, q)?;
let mut grad_q = vec![0.0; q.len()];
gen.grad_into(q, &mut grad_q)?;
let fp = gen.f(p)?;
let fq = gen.f(q)?;
let mut inner = 0.0;
for i in 0..p.len() {
inner += (p[i] - q[i]) * grad_q[i];
}
let b = fp - fq - inner;
let grad_norm_sq: f64 = grad_q.iter().map(|&x| x * x).sum();
Ok(b / (1.0 + grad_norm_sq).sqrt())
}
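/// Generator `F(x) = 0.5 * ||x||^2`; its Bregman divergence is half the
/// squared Euclidean distance.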
#[derive(Debug, Clone, Copy, Default)]
pub struct SquaredL2;
impl BregmanGenerator for SquaredL2 {
fn f(&self, x: &[f64]) -> Result<f64> {
ensure_nonempty(x)?;
Ok(0.5 * x.iter().map(|&v| v * v).sum::<f64>())
}
fn grad_into(&self, x: &[f64], out: &mut [f64]) -> Result<()> {
ensure_nonempty(x)?;
if out.len() != x.len() {
return Err(Error::LengthMismatch(out.len(), x.len()));
}
out.copy_from_slice(x);
Ok(())
}
}
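/// Negative-entropy generator `F(x) = sum_i x_i ln x_i`; restricted to the
/// simplex, its Bregman divergence equals the KL divergence.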
#[derive(Debug, Clone, Copy, Default)]
pub struct NegEntropy;
impl BregmanGenerator for NegEntropy {
fn f(&self, x: &[f64]) -> Result<f64> {
ensure_nonempty(x)?;
let mut s = 0.0;
for &xi in x {
if xi < 0.0 {
return Err(Error::Domain("NegEntropy: input must be nonnegative"));
}
if xi > 0.0 {
s += xi * xi.ln();
}
}
Ok(s)
}
fn grad_into(&self, x: &[f64], out: &mut [f64]) -> Result<()> {
ensure_nonempty(x)?;
if out.len() != x.len() {
return Err(Error::LengthMismatch(out.len(), x.len()));
}
for (o, &xi) in out.iter_mut().zip(x.iter()) {
if xi <= 0.0 {
return Err(Error::Domain("NegEntropy grad: input must be positive"));
}
*o = 1.0 + xi.ln();
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use proptest::prelude::*;
const TOL: f64 = 1e-9;
fn simplex_vec(len: usize) -> impl Strategy<Value = Vec<f64>> {
prop::collection::vec(0.0f64..10.0, len).prop_map(|mut v| {
let s: f64 = v.iter().sum();
if s == 0.0 {
v[0] = 1.0;
return v;
}
for x in v.iter_mut() {
*x /= s;
}
v
})
}
fn simplex_vec_pos(len: usize, eps: f64) -> impl Strategy<Value = Vec<f64>> {
prop::collection::vec(0.0f64..10.0, len).prop_map(move |mut v| {
for x in v.iter_mut() {
*x += eps;
}
let s: f64 = v.iter().sum();
for x in v.iter_mut() {
*x /= s;
}
v
})
}
fn random_partition(n: usize) -> impl Strategy<Value = Vec<usize>> {
prop::collection::vec(0usize..n, n).prop_map(|labels| {
use std::collections::BTreeMap;
let mut map = BTreeMap::<usize, usize>::new();
let mut next = 0usize;
labels
.into_iter()
.map(|l| {
*map.entry(l).or_insert_with(|| {
let id = next;
next += 1;
id
})
})
.collect::<Vec<_>>()
})
}
fn coarse_grain(p: &[f64], labels: &[usize]) -> Vec<f64> {
let k = labels.iter().copied().max().unwrap_or(0) + 1;
let mut out = vec![0.0; k];
for (i, &lab) in labels.iter().enumerate() {
out[lab] += p[i];
}
out
}
fn l1(p: &[f64], q: &[f64]) -> f64 {
p.iter().zip(q.iter()).map(|(&a, &b)| (a - b).abs()).sum()
}
#[test]
fn test_entropy_unchecked() {
let p = [0.5, 0.5];
let h = entropy_unchecked(&p);
assert!((h - core::f64::consts::LN_2).abs() < 1e-12);
}
#[test]
fn js_is_bounded_by_ln2() {
let p = [1.0, 0.0];
let q = [0.0, 1.0];
let js = jensen_shannon_divergence(&p, &q, TOL).unwrap();
assert!(js <= core::f64::consts::LN_2 + 1e-12);
assert!(js >= 0.0);
}
#[test]
fn mutual_information_independent_is_zero() {
let p_x = [0.5, 0.5];
let p_y = [0.25, 0.75];
let p_xy = [
p_x[0] * p_y[0],
p_x[0] * p_y[1],
p_x[1] * p_y[0],
p_x[1] * p_y[1],
];
let mi = mutual_information(&p_xy, 2, 2, TOL).unwrap();
assert!(mi.abs() < 1e-12, "mi={}", mi);
}
#[test]
fn mutual_information_perfect_correlation_is_ln2() {
        let p_xy = [0.5, 0.0, 0.0, 0.5];
        let mi = mutual_information(&p_xy, 2, 2, TOL).unwrap();
assert!((mi - core::f64::consts::LN_2).abs() < 1e-12, "mi={}", mi);
}
#[test]
fn bregman_squared_l2_matches_half_l2() {
let gen = SquaredL2;
let p = [1.0, 2.0, 3.0];
let q = [1.5, 1.5, 2.5];
let b = bregman_divergence(&gen, &p, &q).unwrap();
let expected = 0.5
* p.iter()
.zip(q.iter())
.map(|(&a, &b)| (a - b) * (a - b))
.sum::<f64>();
assert!((b - expected).abs() < 1e-12);
}
#[test]
fn entropy_nats_uniform_is_ln_n() {
for n in [2, 4, 8, 16] {
let p: Vec<f64> = vec![1.0 / n as f64; n];
let h = entropy_nats(&p, TOL).unwrap();
let expected = (n as f64).ln();
assert!(
(h - expected).abs() < 1e-12,
"n={n}: h={h} expected={expected}"
);
}
}
#[test]
fn entropy_nats_singleton_is_zero() {
let h = entropy_nats(&[1.0], TOL).unwrap();
assert!(h.abs() < 1e-15);
}
#[test]
fn entropy_bits_converts_correctly() {
let p = [0.25, 0.75];
let nats = entropy_nats(&p, TOL).unwrap();
let bits = entropy_bits(&p, TOL).unwrap();
assert!((bits - nats / core::f64::consts::LN_2).abs() < 1e-12);
}
#[test]
fn cross_entropy_identity_h_pq_eq_h_p_plus_kl() {
let p = [0.3, 0.7];
let q = [0.5, 0.5];
let h_pq = cross_entropy_nats(&p, &q, TOL).unwrap();
let h_p = entropy_nats(&p, TOL).unwrap();
let kl = kl_divergence(&p, &q, TOL).unwrap();
assert!((h_pq - (h_p + kl)).abs() < 1e-12);
}
#[test]
fn cross_entropy_rejects_zero_q_with_positive_p() {
let p = [0.5, 0.5];
        let q = [1.0, 0.0];
        assert!(cross_entropy_nats(&p, &q, TOL).is_err());
}
#[test]
fn validate_simplex_accepts_valid() {
assert!(validate_simplex(&[0.3, 0.7], TOL).is_ok());
assert!(validate_simplex(&[1.0], TOL).is_ok());
}
#[test]
fn validate_simplex_rejects_bad_sum() {
        assert!(validate_simplex(&[0.3, 0.6], TOL).is_err());
    }
#[test]
fn validate_simplex_rejects_negative() {
assert!(validate_simplex(&[1.5, -0.5], TOL).is_err());
}
#[test]
fn validate_simplex_rejects_empty() {
assert!(validate_simplex(&[], TOL).is_err());
}
#[test]
fn normalize_in_place_works() {
let mut v = vec![2.0, 3.0];
let s = normalize_in_place(&mut v).unwrap();
assert!((s - 5.0).abs() < 1e-12);
assert!((v[0] - 0.4).abs() < 1e-12);
assert!((v[1] - 0.6).abs() < 1e-12);
}
#[test]
fn normalize_in_place_rejects_zero_sum() {
let mut v = vec![0.0, 0.0];
assert!(normalize_in_place(&mut v).is_err());
}
#[test]
fn hellinger_identical_is_zero() {
let p = [0.25, 0.75];
let h = hellinger(&p, &p, TOL).unwrap();
assert!(h.abs() < 1e-12);
}
#[test]
fn hellinger_squared_in_unit_interval() {
let p = [0.1, 0.9];
let q = [0.9, 0.1];
let h2 = hellinger_squared(&p, &q, TOL).unwrap();
assert!((-1e-12..=1.0 + 1e-12).contains(&h2), "h2={h2}");
}
#[test]
fn bhattacharyya_coeff_identical_is_one() {
let p = [0.3, 0.7];
let bc = bhattacharyya_coeff(&p, &p, TOL).unwrap();
assert!((bc - 1.0).abs() < 1e-12);
}
#[test]
fn bhattacharyya_distance_identical_is_zero() {
let p = [0.5, 0.5];
let d = bhattacharyya_distance(&p, &p, TOL).unwrap();
assert!(d.abs() < 1e-12);
}
#[test]
fn renyi_alpha_half_on_simple_case() {
let p = [0.5, 0.5];
let q = [0.25, 0.75];
let r = renyi_divergence(&p, &q, 0.5, TOL).unwrap();
assert!(r >= -1e-12, "renyi={r}");
}
#[test]
fn renyi_identical_is_zero() {
let p = [0.3, 0.7];
let r = renyi_divergence(&p, &p, 2.0, TOL).unwrap();
assert!(r.abs() < 1e-12, "renyi(p,p)={r}");
}
#[test]
fn tsallis_identical_is_zero() {
let p = [0.4, 0.6];
let t = tsallis_divergence(&p, &p, 2.0, TOL).unwrap();
assert!(t.abs() < 1e-12, "tsallis(p,p)={t}");
}
#[test]
fn digamma_at_one_is_neg_euler_mascheroni() {
let psi1 = digamma(1.0);
assert!(
(psi1 - (-0.57721566490153286)).abs() < 1e-12,
"psi(1)={psi1}"
);
}
#[test]
fn digamma_recurrence_relation() {
for &x in &[1.0, 2.0, 3.5, 10.0] {
let lhs = digamma(x + 1.0);
let rhs = digamma(x) + 1.0 / x;
assert!(
(lhs - rhs).abs() < 1e-12,
"recurrence at x={x}: {lhs} vs {rhs}"
);
}
}
#[test]
fn pmi_independent_is_zero() {
        let pmi_val = pmi(0.06, 0.3, 0.2).unwrap();
        assert!(
pmi_val.abs() < 1e-10,
"PMI of independent events should be 0: {pmi_val}"
);
}
#[test]
fn pmi_positive_for_correlated() {
        let pmi_val = pmi(0.4, 0.5, 0.5).unwrap();
        assert!(
pmi_val > 0.0,
"correlated events should have positive PMI: {pmi_val}"
);
}
#[test]
fn renyi_approaches_kl_as_alpha_to_one() {
let p = [0.3, 0.7];
let q = [0.5, 0.5];
let tol = 1e-9;
let kl = kl_divergence(&p, &q, tol).unwrap();
let r099 = renyi_divergence(&p, &q, 0.99, tol).unwrap();
let r0999 = renyi_divergence(&p, &q, 0.999, tol).unwrap();
assert!((r099 - kl).abs() < 0.01, "Renyi(0.99)={r099}, KL={kl}");
assert!((r0999 - kl).abs() < 0.001, "Renyi(0.999)={r0999}, KL={kl}");
}
#[test]
fn amari_alpha_neg1_is_kl_forward() {
let p = [0.3, 0.7];
let q = [0.5, 0.5];
let tol = 1e-9;
let kl_pq = kl_divergence(&p, &q, tol).unwrap();
let amari = amari_alpha_divergence(&p, &q, -1.0, tol).unwrap();
assert!(
(amari - kl_pq).abs() < 1e-6,
"Amari(-1)={amari}, KL(p||q)={kl_pq}"
);
}
#[test]
fn amari_alpha_pos1_is_kl_reverse() {
let p = [0.3, 0.7];
let q = [0.5, 0.5];
let tol = 1e-9;
let kl_qp = kl_divergence(&q, &p, tol).unwrap();
let amari = amari_alpha_divergence(&p, &q, 1.0, tol).unwrap();
assert!(
(amari - kl_qp).abs() < 1e-6,
"Amari(1)={amari}, KL(q||p)={kl_qp}"
);
}
#[test]
fn csiszar_with_kl_generator_matches_kl() {
let p = [0.3, 0.7];
let q = [0.5, 0.5];
let tol = 1e-9;
let kl = kl_divergence(&p, &q, tol).unwrap();
let cs = csiszar_f_divergence(&p, &q, |t| t * t.ln(), tol).unwrap();
assert!((cs - kl).abs() < 1e-6, "Csiszar(t*ln(t))={cs}, KL={kl}");
}
#[test]
fn mutual_information_deterministic_equals_entropy() {
        let p_xy = [0.3, 0.0, 0.0, 0.7];
        let mi = mutual_information(&p_xy, 2, 2, 1e-9).unwrap();
let h_x = entropy_nats(&[0.3, 0.7], 1e-9).unwrap();
assert!((mi - h_x).abs() < 1e-6, "MI={mi}, H(X)={h_x}");
}
proptest! {
#[test]
fn kl_is_nonnegative(p in simplex_vec_pos(8, 1e-6), q in simplex_vec_pos(8, 1e-6)) {
let d = kl_divergence(&p, &q, 1e-6).unwrap();
prop_assert!(d >= -1e-12);
}
#[test]
fn js_is_bounded(p in simplex_vec(16), q in simplex_vec(16)) {
let js = jensen_shannon_divergence(&p, &q, 1e-6).unwrap();
prop_assert!(js >= -1e-12);
prop_assert!(js <= core::f64::consts::LN_2 + 1e-9);
}
#[test]
fn prop_kl_gaussians_is_nonnegative(
mu1 in prop::collection::vec(-10.0f64..10.0, 1..16),
std1 in prop::collection::vec(0.1f64..5.0, 1..16),
mu2 in prop::collection::vec(-10.0f64..10.0, 1..16),
std2 in prop::collection::vec(0.1f64..5.0, 1..16),
) {
let n = mu1.len().min(std1.len()).min(mu2.len()).min(std2.len());
let d = kl_divergence_gaussians(&mu1[..n], &std1[..n], &mu2[..n], &std2[..n]).unwrap();
prop_assert!(d >= -1e-12);
}
#[test]
fn prop_kl_gaussians_is_zero_for_identical(
mu in prop::collection::vec(-10.0f64..10.0, 1..16),
std in prop::collection::vec(0.1f64..5.0, 1..16),
) {
let n = mu.len().min(std.len());
let d = kl_divergence_gaussians(&mu[..n], &std[..n], &mu[..n], &std[..n]).unwrap();
prop_assert!(d.abs() < 1e-12);
}
#[test]
fn f_divergence_monotone_under_coarse_graining(
p in simplex_vec_pos(12, 1e-6),
q in simplex_vec_pos(12, 1e-6),
labels in random_partition(12),
) {
let f = |t: f64| if t == 0.0 { 0.0 } else { t * t.ln() };
let d_f = csiszar_f_divergence(&p, &q, f, 1e-6).unwrap();
let pc = coarse_grain(&p, &labels);
let qc = coarse_grain(&q, &labels);
let d_fc = csiszar_f_divergence(&pc, &qc, f, 1e-6).unwrap();
prop_assert!(d_fc <= d_f + 1e-9);
}
}
proptest! {
#![proptest_config(ProptestConfig { cases: 64, .. ProptestConfig::default() })]
#[test]
fn pinsker_kl_lower_bounds_l1_squared(
p in simplex_vec_pos(16, 1e-6),
q in simplex_vec_pos(16, 1e-6),
) {
let kl = kl_divergence(&p, &q, 1e-6).unwrap();
let d1 = l1(&p, &q);
prop_assert!(kl + 1e-9 >= 0.5 * d1 * d1, "kl={kl} l1={d1}");
}
#[test]
fn sqrt_js_satisfies_triangle_inequality(
p in simplex_vec(12),
q in simplex_vec(12),
r in simplex_vec(12),
) {
let js_pq = jensen_shannon_divergence(&p, &q, 1e-6).unwrap().max(0.0).sqrt();
let js_qr = jensen_shannon_divergence(&q, &r, 1e-6).unwrap().max(0.0).sqrt();
let js_pr = jensen_shannon_divergence(&p, &r, 1e-6).unwrap().max(0.0).sqrt();
prop_assert!(js_pr <= js_pq + js_qr + 1e-7, "js_pr={js_pr} js_pq+js_qr={}", js_pq+js_qr);
}
#[test]
fn mutual_information_equals_kl_to_product(
p_xy in simplex_vec_pos(16, 1e-6),
nx in 2usize..=4,
ny in 2usize..=4,
) {
let n = nx * ny;
let mut joint = p_xy;
joint.truncate(n);
let _ = normalize_in_place(&mut joint).unwrap();
let mi = mutual_information(&joint, nx, ny, 1e-6).unwrap();
let mut p_x = vec![0.0; nx];
let mut p_y = vec![0.0; ny];
for i in 0..nx {
for j in 0..ny {
let p = joint[i * ny + j];
p_x[i] += p;
p_y[j] += p;
}
}
let mut prod = vec![0.0; n];
for i in 0..nx {
for j in 0..ny {
prod[i * ny + j] = p_x[i] * p_y[j];
}
}
let kl = kl_divergence(&joint, &prod, 1e-6).unwrap();
prop_assert!((mi - kl).abs() < 1e-9, "mi={mi} kl={kl}");
}
#[test]
fn hellinger_satisfies_triangle_inequality(
p in simplex_vec(8),
q in simplex_vec(8),
r in simplex_vec(8),
) {
let h_pq = hellinger(&p, &q, 1e-6).unwrap();
let h_qr = hellinger(&q, &r, 1e-6).unwrap();
let h_pr = hellinger(&p, &r, 1e-6).unwrap();
prop_assert!(h_pr <= h_pq + h_qr + 1e-7, "h_pr={h_pr} h_pq+h_qr={}", h_pq + h_qr);
}
}
#[test]
fn total_bregman_le_bregman() {
let gen = SquaredL2;
let p = [1.0, 2.0, 3.0];
let q = [4.0, 5.0, 6.0];
let b = bregman_divergence(&gen, &p, &q).unwrap();
let tb = total_bregman_divergence(&gen, &p, &q).unwrap();
assert!(tb <= b + 1e-12, "total_bregman={tb} > bregman={b}");
assert!(tb >= 0.0);
}
#[test]
fn total_bregman_is_zero_for_identical() {
let gen = SquaredL2;
let p = [1.0, 2.0];
let tb = total_bregman_divergence(&gen, &p, &p).unwrap();
assert!(tb.abs() < 1e-15);
}
#[test]
fn rho_alpha_self_is_one() {
let p = [0.1, 0.2, 0.3, 0.4];
for alpha in [0.0, 0.25, 0.5, 0.75, 1.0, 2.0, -1.0] {
let r = rho_alpha(&p, &p, alpha, TOL).unwrap();
assert!((r - 1.0).abs() < 1e-10, "rho_alpha(p,p,{alpha})={r}");
}
}
#[test]
fn digamma_nonpositive_is_nan() {
assert!(digamma(0.0).is_nan());
assert!(digamma(-1.0).is_nan());
assert!(digamma(-100.0).is_nan());
}
#[test]
fn pmi_zero_joint_returns_zero() {
assert_eq!(pmi(0.0, 0.5, 0.5).unwrap(), 0.0);
}
#[test]
fn pmi_zero_marginal_with_zero_joint_returns_zero() {
assert_eq!(pmi(0.0, 0.0, 0.5).unwrap(), 0.0);
assert_eq!(pmi(0.0, 0.5, 0.0).unwrap(), 0.0);
}
#[test]
fn pmi_all_zero_returns_zero() {
assert_eq!(pmi(0.0, 0.0, 0.0).unwrap(), 0.0);
}
#[test]
fn digamma_at_dlmf_reference_values() {
let gamma = 0.57721566490153286;
let expected_half = -gamma - 2.0 * core::f64::consts::LN_2;
let psi_half = digamma(0.5);
assert!(
(psi_half - expected_half).abs() < 1e-12,
"psi(0.5)={psi_half} expected={expected_half}"
);
let expected_2 = 1.0 - gamma;
let psi_2 = digamma(2.0);
assert!(
(psi_2 - expected_2).abs() < 1e-12,
"psi(2)={psi_2} expected={expected_2}"
);
let expected_3 = 1.5 - gamma;
let psi_3 = digamma(3.0);
assert!(
(psi_3 - expected_3).abs() < 1e-12,
"psi(3)={psi_3} expected={expected_3}"
);
let expected_4 = 1.0 + 0.5 + 1.0 / 3.0 - gamma;
let psi_4 = digamma(4.0);
assert!(
(psi_4 - expected_4).abs() < 1e-12,
"psi(4)={psi_4} expected={expected_4}"
);
}
#[test]
fn tsallis_approaches_kl_as_alpha_to_one() {
let p = [0.3, 0.7];
let q = [0.5, 0.5];
let tol = 1e-9;
let kl = kl_divergence(&p, &q, tol).unwrap();
let t099 = tsallis_divergence(&p, &q, 0.99, tol).unwrap();
let t0999 = tsallis_divergence(&p, &q, 0.999, tol).unwrap();
let t101 = tsallis_divergence(&p, &q, 1.01, tol).unwrap();
let t1001 = tsallis_divergence(&p, &q, 1.001, tol).unwrap();
assert!((t099 - kl).abs() < 0.01, "Tsallis(0.99)={t099}, KL={kl}");
assert!(
(t0999 - kl).abs() < 0.001,
"Tsallis(0.999)={t0999}, KL={kl}"
);
assert!((t101 - kl).abs() < 0.01, "Tsallis(1.01)={t101}, KL={kl}");
assert!(
(t1001 - kl).abs() < 0.001,
"Tsallis(1.001)={t1001}, KL={kl}"
);
}
#[test]
fn renyi_approaches_kl_from_above() {
let p = [0.3, 0.7];
let q = [0.5, 0.5];
let tol = 1e-9;
let kl = kl_divergence(&p, &q, tol).unwrap();
let r101 = renyi_divergence(&p, &q, 1.01, tol).unwrap();
let r1001 = renyi_divergence(&p, &q, 1.001, tol).unwrap();
assert!((r101 - kl).abs() < 0.01, "Renyi(1.01)={r101}, KL={kl}");
assert!((r1001 - kl).abs() < 0.001, "Renyi(1.001)={r1001}, KL={kl}");
}
#[test]
fn renyi_at_half_equals_neg2_ln_bc() {
let p = [0.2, 0.3, 0.5];
let q = [0.4, 0.4, 0.2];
let tol = 1e-9;
let renyi_half = renyi_divergence(&p, &q, 0.5, tol).unwrap();
let bc = bhattacharyya_coeff(&p, &q, tol).unwrap();
let expected = -2.0 * bc.ln();
assert!(
(renyi_half - expected).abs() < 1e-10,
"Renyi(0.5)={renyi_half}, -2*ln(BC)={expected}"
);
}
#[test]
fn hellinger_squared_equals_one_minus_bc() {
let p = [0.1, 0.4, 0.5];
let q = [0.3, 0.3, 0.4];
let tol = 1e-9;
let h2 = hellinger_squared(&p, &q, tol).unwrap();
let bc = bhattacharyya_coeff(&p, &q, tol).unwrap();
assert!(
(h2 - (1.0 - bc)).abs() < 1e-12,
"H^2={h2}, 1-BC={}",
1.0 - bc
);
}
#[test]
fn csiszar_hellinger_generator_matches_twice_hellinger_squared() {
let p = [0.2, 0.3, 0.5];
let q = [0.4, 0.4, 0.2];
let tol = 1e-9;
let h2 = hellinger_squared(&p, &q, tol).unwrap();
let cs = csiszar_f_divergence(&p, &q, |t| (t.sqrt() - 1.0).powi(2), tol).unwrap();
assert!(
(cs - 2.0 * h2).abs() < 1e-10,
"Csiszar(Hellinger)={cs}, 2*H^2={}",
2.0 * h2
);
}
#[test]
fn csiszar_chi_squared_generator_is_nonneg() {
let p = [0.2, 0.3, 0.5];
let q = [0.4, 0.4, 0.2];
let tol = 1e-9;
let chi2 = csiszar_f_divergence(&p, &q, |t| (t - 1.0).powi(2), tol).unwrap();
assert!(chi2 >= 0.0, "chi2={chi2}");
let chi2_self = csiszar_f_divergence(&p, &p, |t| (t - 1.0).powi(2), tol).unwrap();
assert!(chi2_self.abs() < 1e-12, "chi2(p,p)={chi2_self}");
}
#[test]
fn near_boundary_inputs_no_nan() {
let tiny = 1e-300;
let p = [tiny, 1.0 - tiny];
let q = [tiny * 2.0, 1.0 - tiny * 2.0];
let tol = 1e-6;
let kl = kl_divergence(&p, &q, tol).unwrap();
assert!(kl.is_finite(), "kl={kl}");
assert!(kl >= -1e-12, "kl negative: {kl}");
let js = jensen_shannon_divergence(&p, &q, tol).unwrap();
assert!(js.is_finite(), "js={js}");
let h = hellinger(&p, &q, tol).unwrap();
assert!(h.is_finite(), "hellinger={h}");
let bc = bhattacharyya_coeff(&p, &q, tol).unwrap();
assert!(bc.is_finite(), "bc={bc}");
let ent = entropy_nats(&p, tol).unwrap();
assert!(ent.is_finite(), "entropy={ent}");
}
proptest! {
#![proptest_config(ProptestConfig { cases: 64, .. ProptestConfig::default() })]
#[test]
fn entropy_is_concave(
p in simplex_vec(8),
q in simplex_vec(8),
lambda in 0.0f64..=1.0,
) {
let mix: Vec<f64> = p.iter().zip(q.iter())
.map(|(&pi, &qi)| lambda * pi + (1.0 - lambda) * qi)
.collect();
let h_mix = entropy_nats(&mix, 1e-6).unwrap();
let h_p = entropy_nats(&p, 1e-6).unwrap();
let h_q = entropy_nats(&q, 1e-6).unwrap();
let rhs = lambda * h_p + (1.0 - lambda) * h_q;
prop_assert!(h_mix + 1e-10 >= rhs, "h_mix={h_mix} rhs={rhs}");
}
#[test]
fn renyi_monotone_in_alpha(
p in simplex_vec_pos(8, 1e-6),
q in simplex_vec_pos(8, 1e-6),
) {
let alphas = [0.1, 0.25, 0.5, 0.75, 0.99];
let vals: Vec<f64> = alphas.iter()
.map(|&a| renyi_divergence(&p, &q, a, 1e-6).unwrap())
.collect();
for i in 1..vals.len() {
prop_assert!(
vals[i] + 1e-9 >= vals[i - 1],
"Renyi not monotone: D({})={} < D({})={}",
alphas[i], vals[i], alphas[i - 1], vals[i - 1]
);
}
}
#[test]
fn cross_entropy_decomposition(
p in simplex_vec_pos(8, 1e-6),
q in simplex_vec_pos(8, 1e-6),
) {
let h_pq = cross_entropy_nats(&p, &q, 1e-6).unwrap();
let h_p = entropy_nats(&p, 1e-6).unwrap();
let kl = kl_divergence(&p, &q, 1e-6).unwrap();
prop_assert!(
(h_pq - (h_p + kl)).abs() < 1e-9,
"H(p,q)={h_pq} != H(p)+KL={}", h_p + kl
);
}
#[test]
fn bhattacharyya_renyi_consistency(
p in simplex_vec_pos(8, 1e-6),
q in simplex_vec_pos(8, 1e-6),
) {
let renyi_half = renyi_divergence(&p, &q, 0.5, 1e-6).unwrap();
let bc = bhattacharyya_coeff(&p, &q, 1e-6).unwrap();
let expected = -2.0 * bc.ln();
prop_assert!(
(renyi_half - expected).abs() < 1e-8,
"Renyi(0.5)={renyi_half}, -2*ln(BC)={expected}"
);
}
#[test]
fn csiszar_hellinger_consistency(
p in simplex_vec_pos(8, 1e-6),
q in simplex_vec_pos(8, 1e-6),
) {
let h2 = hellinger_squared(&p, &q, 1e-6).unwrap();
let cs = csiszar_f_divergence(&p, &q, |t| (t.sqrt() - 1.0).powi(2), 1e-6).unwrap();
prop_assert!(
(cs - 2.0 * h2).abs() < 1e-8,
"Csiszar(Hellinger)={cs}, 2*H^2={}", 2.0 * h2
);
}
#[test]
fn pinsker_tightness_for_nearby_distributions(
p in simplex_vec_pos(8, 1e-6),
) {
let n = p.len();
let q: Vec<f64> = p.iter().map(|&pi| 0.99 * pi + 0.01 / n as f64).collect();
let kl = kl_divergence(&p, &q, 1e-6).unwrap();
let d1: f64 = p.iter().zip(q.iter()).map(|(&a, &b)| (a - b).abs()).sum();
let pinsker_rhs = 0.5 * d1 * d1;
prop_assert!(kl + 1e-12 >= pinsker_rhs, "kl={kl} pinsker_rhs={pinsker_rhs}");
if pinsker_rhs > 1e-20 {
let ratio = kl / pinsker_rhs;
prop_assert!(ratio < 1000.0, "Pinsker ratio too large: {ratio}");
}
}
#[test]
fn total_variation_satisfies_triangle(
p in simplex_vec(8),
q in simplex_vec(8),
r in simplex_vec(8),
) {
let tv_pq = total_variation(&p, &q, 1e-6).unwrap();
let tv_qr = total_variation(&q, &r, 1e-6).unwrap();
let tv_pr = total_variation(&p, &r, 1e-6).unwrap();
prop_assert!(tv_pr <= tv_pq + tv_qr + 1e-10);
}
#[test]
fn chi_squared_matches_csiszar(
p in simplex_vec_pos(8, 1e-6),
q in simplex_vec_pos(8, 1e-6),
) {
let chi2 = chi_squared_divergence(&p, &q, 1e-6).unwrap();
let cs = csiszar_f_divergence(&p, &q, |t| (t - 1.0).powi(2), 1e-6).unwrap();
prop_assert!(
(chi2 - cs).abs() < 1e-8,
"chi2={chi2}, csiszar={cs}"
);
}
#[test]
fn renyi_entropy_monotone_in_alpha(
p in simplex_vec_pos(8, 1e-6),
) {
let alphas = [0.1, 0.25, 0.5, 0.75, 0.99];
let vals: Vec<f64> = alphas.iter()
.map(|&a| renyi_entropy(&p, a, 1e-6).unwrap())
.collect();
for i in 1..vals.len() {
prop_assert!(
vals[i] <= vals[i - 1] + 1e-9,
"H_alpha not monotone: H({})={} > H({})={}",
alphas[i], vals[i], alphas[i - 1], vals[i - 1]
);
}
}
}
#[test]
fn weighted_js_at_extreme_weights() {
let p = [0.3, 0.7];
let q = [0.5, 0.5];
let js0 = jensen_shannon_weighted(&p, &q, 0.0, TOL).unwrap();
assert!(js0.abs() < 1e-12, "JS(pi=0)={js0}");
}
#[test]
fn conditional_entropy_chain_rule() {
        let p_xy = [0.2, 0.1, 0.3, 0.4];
        let h_xy = entropy_nats(&p_xy, TOL).unwrap();
let p_y = [p_xy[0] + p_xy[2], p_xy[1] + p_xy[3]];
let h_y = entropy_nats(&p_y, TOL).unwrap();
let h_x_given_y = conditional_entropy(&p_xy, 2, 2, TOL).unwrap();
assert!(
(h_x_given_y - (h_xy - h_y)).abs() < 1e-10,
"H(X|Y)={h_x_given_y}, H(X,Y)-H(Y)={}",
h_xy - h_y
);
}
#[test]
fn conditional_entropy_nonnegative() {
let p_xy = [0.1, 0.2, 0.3, 0.4];
let h = conditional_entropy(&p_xy, 2, 2, TOL).unwrap();
assert!(h >= -1e-12, "H(X|Y) negative: {h}");
}
#[test]
fn nmi_bounds() {
let p_xy = [0.1, 0.2, 0.3, 0.4];
let nmi = normalized_mutual_information(&p_xy, 2, 2, TOL).unwrap();
assert!((-1e-12..=1.0 + 1e-12).contains(&nmi), "nmi={nmi}");
}
#[test]
fn total_variation_self_is_zero() {
let p = [0.3, 0.7];
assert!(total_variation(&p, &p, TOL).unwrap().abs() < 1e-15);
}
#[test]
fn total_variation_disjoint_is_one() {
let a = [1.0, 0.0];
let b = [0.0, 1.0];
assert!((total_variation(&a, &b, TOL).unwrap() - 1.0).abs() < 1e-12);
}
#[test]
fn chi_squared_self_is_zero() {
let p = [0.3, 0.7];
assert!(chi_squared_divergence(&p, &p, TOL).unwrap().abs() < 1e-15);
}
#[test]
fn chi_squared_upper_bounds_kl() {
let p = [0.2, 0.3, 0.5];
let q = [0.4, 0.4, 0.2];
let kl = kl_divergence(&p, &q, TOL).unwrap();
let chi2 = chi_squared_divergence(&p, &q, TOL).unwrap();
assert!(
kl <= (1.0 + chi2).ln() + 1e-10,
"kl={kl} > ln(1+chi2)={}",
(1.0 + chi2).ln()
);
}
#[test]
fn renyi_entropy_uniform_is_ln_n() {
let p = [0.25, 0.25, 0.25, 0.25];
for alpha in [0.5, 2.0, 3.0, 10.0] {
let h = renyi_entropy(&p, alpha, TOL).unwrap();
let expected = 4.0_f64.ln();
assert!(
(h - expected).abs() < 1e-12,
"H_{alpha}(uniform) = {h}, expected {expected}"
);
}
}
#[test]
fn tsallis_entropy_delta_is_zero() {
let delta = [1.0, 0.0, 0.0];
for alpha in [0.5, 2.0, 3.0] {
let s = tsallis_entropy(&delta, alpha, TOL).unwrap();
assert!(s.abs() < 1e-12, "Tsallis({alpha}) of delta = {s}");
}
}
#[test]
fn renyi_entropy_collision() {
let p = [0.3, 0.7];
let h2 = renyi_entropy(&p, 2.0, TOL).unwrap();
let expected = -(0.3_f64.powi(2) + 0.7_f64.powi(2)).ln();
assert!(
(h2 - expected).abs() < 1e-12,
"H_2={h2} expected={expected}"
);
}
#[test]
fn neg_entropy_bregman_matches_kl_on_simplex() {
let p = [0.2, 0.3, 0.5];
let q = [0.4, 0.4, 0.2];
let kl = kl_divergence(&p, &q, TOL).unwrap();
let gen = NegEntropy;
let breg = bregman_divergence(&gen, &p, &q).unwrap();
assert!(
(breg - kl).abs() < 1e-10,
"Bregman(NegEntropy)={breg}, KL={kl}"
);
}
#[test]
fn neg_entropy_bregman_self_is_zero() {
let p = [0.3, 0.7];
let gen = NegEntropy;
let breg = bregman_divergence(&gen, &p, &p).unwrap();
assert!(breg.abs() < 1e-14, "Bregman(p,p)={breg}");
}
#[test]
fn log_sum_exp_iter_matches_slice() {
let values = [1.0, 2.0, 3.0, -1.0, 0.5];
let lse_slice = log_sum_exp(&values);
let lse_iter = log_sum_exp_iter(values.iter().copied());
assert!(
(lse_slice - lse_iter).abs() < 1e-12,
"slice={lse_slice} iter={lse_iter}"
);
}
#[test]
fn log_sum_exp_iter_empty() {
assert_eq!(log_sum_exp_iter(std::iter::empty()), f64::NEG_INFINITY);
}
#[test]
fn log_sum_exp_iter_single() {
assert_eq!(log_sum_exp_iter(std::iter::once(42.0)), 42.0);
}
#[test]
fn log_sum_exp_iter_large_values() {
let lse = log_sum_exp_iter([1000.0, 0.0].iter().copied());
assert!((lse - 1000.0).abs() < 1e-10);
}
#[test]
fn data_processing_inequality_mi() {
        let p_xy = [0.3, 0.1, 0.05, 0.05, 0.1, 0.2, 0.05, 0.15];
        let n_x = 2;
let n_y = 4;
let mi_full = mutual_information(&p_xy, n_x, n_y, TOL).unwrap();
        let mut p_coarse = [0.0; 4];
        for i in 0..n_x {
p_coarse[i * 2] = p_xy[i * n_y] + p_xy[i * n_y + 1];
p_coarse[i * 2 + 1] = p_xy[i * n_y + 2] + p_xy[i * n_y + 3];
}
let mi_coarse = mutual_information(&p_coarse, n_x, 2, TOL).unwrap();
assert!(
mi_coarse <= mi_full + 1e-10,
"DPI violated: MI(coarse)={mi_coarse} > MI(full)={mi_full}"
);
}
#[test]
fn weighted_js_bounded_by_entropy_of_weights() {
let p = [0.1, 0.9];
let q = [0.9, 0.1];
let pi1 = 0.3;
let jsw = jensen_shannon_weighted(&p, &q, pi1, TOL).unwrap();
let pi2 = 1.0 - pi1;
let h_pi = -(pi1 * pi1.ln() + pi2 * pi2.ln());
assert!(jsw <= h_pi + 1e-10, "JS_pi={jsw} > H(pi)={h_pi}");
}
#[test]
fn renyi_entropy_approaches_shannon_as_alpha_to_one() {
let p = [0.2, 0.3, 0.5];
let h_shannon = entropy_nats(&p, TOL).unwrap();
let h_099 = renyi_entropy(&p, 0.99, TOL).unwrap();
let h_0999 = renyi_entropy(&p, 0.999, TOL).unwrap();
let h_101 = renyi_entropy(&p, 1.01, TOL).unwrap();
assert!((h_099 - h_shannon).abs() < 0.01);
assert!((h_0999 - h_shannon).abs() < 0.001);
assert!((h_101 - h_shannon).abs() < 0.01);
}
#[test]
fn tsallis_entropy_approaches_shannon_as_alpha_to_one() {
let p = [0.2, 0.3, 0.5];
let h_shannon = entropy_nats(&p, TOL).unwrap();
let s_099 = tsallis_entropy(&p, 0.99, TOL).unwrap();
let s_0999 = tsallis_entropy(&p, 0.999, TOL).unwrap();
assert!((s_099 - h_shannon).abs() < 0.01);
assert!((s_0999 - h_shannon).abs() < 0.001);
}
#[test]
fn bhattacharyya_precision_near_identical() {
let p = [0.5 + 1e-15, 0.5 - 1e-15];
let q = [0.5, 0.5];
let bc = bhattacharyya_coeff(&p, &q, 1e-6).unwrap();
assert!(
(bc - 1.0).abs() < 1e-14,
"BC should be very close to 1.0: {bc}"
);
let h2 = hellinger_squared(&p, &q, 1e-6).unwrap();
assert!(h2 < 1e-14, "h2 should be tiny: {h2}");
assert!(h2.is_finite(), "h2 should be finite");
}
#[test]
fn renyi_alpha_sweep_continuity() {
let p = [0.2, 0.3, 0.5];
let tol = 1e-9;
let mut prev_renyi = renyi_entropy(&p, 0.5, tol).unwrap();
let mut prev_tsallis = tsallis_entropy(&p, 0.5, tol).unwrap();
let mut alpha = 0.6;
while alpha <= 2.0 + 1e-9 {
let r = renyi_entropy(&p, alpha, tol).unwrap();
let t = tsallis_entropy(&p, alpha, tol).unwrap();
let jump_r = (r - prev_renyi).abs();
let jump_t = (t - prev_tsallis).abs();
assert!(
jump_r < 0.5,
"Renyi discontinuity at alpha={alpha}: jump={jump_r}"
);
assert!(
jump_t < 0.5,
"Tsallis discontinuity at alpha={alpha}: jump={jump_t}"
);
prev_renyi = r;
prev_tsallis = t;
alpha += 0.1;
}
}
#[test]
fn ksg_ties_finite() {
let x: Vec<Vec<f64>> = (0..50).map(|i| vec![(i % 5) as f64]).collect();
let y: Vec<Vec<f64>> = (0..50).map(|i| vec![(i % 3) as f64]).collect();
let mi1 = mutual_information_ksg(&x, &y, 3, KsgVariant::Alg1).unwrap();
let mi2 = mutual_information_ksg(&x, &y, 3, KsgVariant::Alg2).unwrap();
assert!(
mi1.is_finite(),
"KSG Alg1 with ties returned NaN/Inf: {mi1}"
);
assert!(
mi2.is_finite(),
"KSG Alg2 with ties returned NaN/Inf: {mi2}"
);
}
proptest! {
#![proptest_config(ProptestConfig { cases: 64, .. ProptestConfig::default() })]
#[test]
fn bregman_nonnegative(
p in simplex_vec_pos(8, 1e-6),
q in simplex_vec_pos(8, 1e-6),
) {
let gen = NegEntropy;
let b = bregman_divergence(&gen, &p, &q).unwrap();
prop_assert!(b >= -1e-12, "Bregman(NegEntropy) negative: {b}");
}
#[test]
fn renyi_divergence_alpha1_equals_kl(
p in simplex_vec_pos(8, 1e-6),
q in simplex_vec_pos(8, 1e-6),
) {
let tol = 1e-6;
let kl = kl_divergence(&p, &q, tol).unwrap();
let r1 = renyi_divergence(&p, &q, 1.0, tol).unwrap();
prop_assert!(
(r1 - kl).abs() < 1e-9,
"renyi(alpha=1)={r1} != kl={kl}"
);
}
}
#[test]
fn pmi_impossible_input_errors() {
assert!(pmi(0.1, 0.0, 0.5).is_err());
assert!(pmi(0.1, 0.5, 0.0).is_err());
}
}