use crate::error::{StatsError, StatsResult};
/// Checks that `pk` is a non-empty vector of non-negative probabilities
/// that sum to 1 (within a 1e-6 absolute tolerance).
///
/// `name` is prefixed to every error message so callers can be identified.
fn validate_probability_vector(pk: &[f64], name: &str) -> StatsResult<()> {
    if pk.is_empty() {
        return Err(StatsError::InsufficientData(format!(
            "{name}: probability vector must not be empty"
        )));
    }
    // `!(p >= 0.0)` is true for both negative values and NaN, matching the
    // original `p < 0.0 || p.is_nan()` check.
    if let Some(&p) = pk.iter().find(|&&p| !(p >= 0.0)) {
        return Err(StatsError::InvalidArgument(format!(
            "{name}: all probabilities must be non-negative, found {p}"
        )));
    }
    let total: f64 = pk.iter().sum();
    if (total - 1.0).abs() > 1e-6 {
        return Err(StatsError::InvalidArgument(format!(
            "{name}: probabilities must sum to 1, got {total}"
        )));
    }
    Ok(())
}
/// Rescales `v` so its elements sum to 1.
///
/// # Errors
/// Returns `InvalidArgument` when the sum is zero, negative, or non-finite.
/// The non-finite check is required: a NaN or infinite element makes the
/// sum NaN/∞, for which `s < f64::EPSILON` is false, and the function
/// would previously return a vector of NaNs/zeros silently.
fn normalise(v: &[f64]) -> StatsResult<Vec<f64>> {
    let s: f64 = v.iter().sum();
    if !s.is_finite() || s < f64::EPSILON {
        return Err(StatsError::InvalidArgument(
            "cannot normalise: sum is zero, negative, or non-finite".into(),
        ));
    }
    Ok(v.iter().map(|&x| x / s).collect())
}
/// Logarithm of `x`: natural log when `base` is `None`, otherwise log in
/// the given base via the change-of-base identity log_b(x) = log2(x)/log2(b).
#[inline]
fn log_base(x: f64, base: Option<f64>) -> f64 {
    // Keeps the exact log2-ratio arithmetic of the original formulation.
    base.map_or_else(|| x.ln(), |b| x.log2() / b.log2())
}
/// Estimates the joint distribution of paired samples on a uniform
/// `bins` × `bins` grid.
///
/// Returns `(joint, px, py)`: `joint[bx][by]` is the empirical probability
/// of cell (x-bin `bx`, y-bin `by`); `px` and `py` are its marginals.
///
/// # Errors
/// * `InsufficientData` — either input is empty.
/// * `DimensionMismatch` — inputs differ in length.
/// * `InvalidArgument` — `bins == 0`, or any sample is NaN/±∞. The finite
///   check is required: `NaN as usize` saturates to 0, so NaN samples were
///   previously dropped silently into bin (0, 0).
fn joint_histogram(
    x: &[f64],
    y: &[f64],
    bins: usize,
) -> StatsResult<(Vec<Vec<f64>>, Vec<f64>, Vec<f64>)> {
    if x.is_empty() || y.is_empty() {
        return Err(StatsError::InsufficientData(
            "joint_histogram: data arrays must not be empty".into(),
        ));
    }
    if x.len() != y.len() {
        return Err(StatsError::DimensionMismatch(format!(
            "joint_histogram: x and y must have the same length, got {} vs {}",
            x.len(),
            y.len()
        )));
    }
    if bins == 0 {
        return Err(StatsError::InvalidArgument(
            "bins must be at least 1".into(),
        ));
    }
    if x.iter().chain(y.iter()).any(|v| !v.is_finite()) {
        return Err(StatsError::InvalidArgument(
            "joint_histogram: all data values must be finite".into(),
        ));
    }
    let n = x.len();
    let x_min = x.iter().cloned().fold(f64::INFINITY, f64::min);
    let x_max = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let y_min = y.iter().cloned().fold(f64::INFINITY, f64::min);
    let y_max = y.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    // A constant axis gets a unit range so every sample lands in bin 0
    // instead of dividing by zero.
    let x_range = if (x_max - x_min).abs() < f64::EPSILON { 1.0 } else { x_max - x_min };
    let y_range = if (y_max - y_min).abs() < f64::EPSILON { 1.0 } else { y_max - y_min };
    let mut counts = vec![vec![0usize; bins]; bins];
    for (&xi, &yi) in x.iter().zip(y.iter()) {
        // `.min(bins - 1)` clamps the axis maximum into the last bin.
        let bx = ((((xi - x_min) / x_range) * bins as f64).floor() as usize).min(bins - 1);
        let by = ((((yi - y_min) / y_range) * bins as f64).floor() as usize).min(bins - 1);
        counts[bx][by] += 1;
    }
    let nf = n as f64;
    let joint: Vec<Vec<f64>> = counts
        .iter()
        .map(|row| row.iter().map(|&c| c as f64 / nf).collect())
        .collect();
    // Marginals: px sums each row (over y-bins), py sums each column.
    let px: Vec<f64> = joint.iter().map(|row| row.iter().sum::<f64>()).collect();
    let py: Vec<f64> = (0..bins)
        .map(|j| joint.iter().map(|row| row[j]).sum::<f64>())
        .collect();
    Ok((joint, px, py))
}
/// Shannon entropy of the discrete distribution `pk`.
///
/// Natural log (nats) when `base` is `None`; otherwise the log in the
/// given base (e.g. `Some(2.0)` for bits).
///
/// # Errors
/// * Propagates validation errors for `pk` (empty, negative/NaN entries,
///   not summing to 1).
/// * `InvalidArgument` if `base` is not a finite positive number ≠ 1.
///   The finite check is required: `b <= 0.0` is false for NaN, so a NaN
///   base previously produced a NaN entropy and an infinite base silently
///   produced 0.
pub fn entropy(pk: &[f64], base: Option<f64>) -> StatsResult<f64> {
    validate_probability_vector(pk, "entropy")?;
    if let Some(b) = base {
        if !b.is_finite() || b <= 0.0 || (b - 1.0).abs() < f64::EPSILON {
            return Err(StatsError::InvalidArgument(
                "logarithm base must be positive and ≠ 1".into(),
            ));
        }
    }
    // Terms with p = 0 contribute 0 (the limit of −p·log p as p → 0).
    let h = pk
        .iter()
        .filter(|&&p| p > 0.0)
        .map(|&p| -p * log_base(p, base))
        .sum::<f64>();
    Ok(h)
}
/// Joint Shannon entropy H(X, Y) in nats, estimated from a histogram with
/// `bins` bins per axis. Errors are propagated from `joint_histogram`.
pub fn joint_entropy(x: &[f64], y: &[f64], bins: usize) -> StatsResult<f64> {
    let (joint, _px, _py) = joint_histogram(x, y, bins)?;
    let mut h = 0.0_f64;
    for row in &joint {
        for &p in row {
            // Empty cells contribute nothing to the entropy sum.
            if p > 0.0 {
                h -= p * p.ln();
            }
        }
    }
    Ok(h)
}
/// Conditional entropy H(X | Y) in nats, estimated from a histogram and
/// clamped to be non-negative (histogram noise can push it slightly below 0).
pub fn conditional_entropy(x: &[f64], y: &[f64], bins: usize) -> StatsResult<f64> {
    let (joint, _px, py) = joint_histogram(x, y, bins)?;
    let mut acc = 0.0_f64;
    for row in joint.iter() {
        for (by, &pxy) in row.iter().enumerate() {
            let p_y = py[by];
            if pxy > 0.0 && p_y > 0.0 {
                // H(X|Y) = Σ p(x,y) · ln( p(y) / p(x,y) )
                acc += pxy * (p_y / pxy).ln();
            }
        }
    }
    Ok(acc.max(0.0))
}
/// Mutual information I(X; Y) in nats, estimated from a histogram and
/// clamped to be non-negative.
pub fn mutual_information(x: &[f64], y: &[f64], bins: usize) -> StatsResult<f64> {
    let (joint, px, py) = joint_histogram(x, y, bins)?;
    let mut total = 0.0_f64;
    for (bx, row) in joint.iter().enumerate() {
        for (by, &pxy) in row.iter().enumerate() {
            // Only occupied cells with occupied marginals contribute.
            if pxy > 0.0 && px[bx] > 0.0 && py[by] > 0.0 {
                // I(X;Y) = Σ p(x,y) · ln( p(x,y) / (p(x)·p(y)) )
                total += pxy * (pxy / (px[bx] * py[by])).ln();
            }
        }
    }
    Ok(total.max(0.0))
}
/// Normalized mutual information 2·I(X;Y) / (H(X) + H(Y)), clamped to [0, 1].
/// Returns 0 when both marginal entropies are (numerically) zero.
pub fn normalized_mutual_information(x: &[f64], y: &[f64], bins: usize) -> StatsResult<f64> {
    let (joint, px, py) = joint_histogram(x, y, bins)?;
    // Shannon entropy (nats) of a marginal distribution.
    let marginal_entropy =
        |m: &[f64]| -> f64 { m.iter().filter(|&&p| p > 0.0).map(|&p| -p * p.ln()).sum() };
    let hx = marginal_entropy(&px);
    let hy = marginal_entropy(&py);
    let mut mi = 0.0_f64;
    for (bx, row) in joint.iter().enumerate() {
        for (by, &pxy) in row.iter().enumerate() {
            if pxy > 0.0 && px[bx] > 0.0 && py[by] > 0.0 {
                mi += pxy * (pxy / (px[bx] * py[by])).ln();
            }
        }
    }
    mi = mi.max(0.0);
    let denom = hx + hy;
    if denom < f64::EPSILON {
        // Both marginals are degenerate; NMI is defined as 0 here.
        return Ok(0.0);
    }
    Ok((2.0 * mi / denom).min(1.0))
}
/// Kullback–Leibler divergence D(P ‖ Q) in nats.
///
/// Both inputs must be valid probability vectors of equal length. Terms
/// with P(x) = 0 contribute nothing; Q(x) = 0 where P(x) > 0 is an error
/// because the divergence is infinite there.
pub fn kl_divergence(p: &[f64], q: &[f64]) -> StatsResult<f64> {
    validate_probability_vector(p, "kl_divergence(p)")?;
    validate_probability_vector(q, "kl_divergence(q)")?;
    if p.len() != q.len() {
        return Err(StatsError::DimensionMismatch(format!(
            "kl_divergence: p and q must have the same length, got {} vs {}",
            p.len(),
            q.len()
        )));
    }
    let mut total = 0.0_f64;
    for (&pi, &qi) in p.iter().zip(q) {
        if !(pi > 0.0) {
            // P(x) = 0 ⇒ the term vanishes by convention.
            continue;
        }
        if qi <= 0.0 {
            return Err(StatsError::DomainError(
                "kl_divergence: Q(x) = 0 where P(x) > 0 → KL divergence is infinite".into(),
            ));
        }
        total += pi * (pi / qi).ln();
    }
    Ok(total)
}
/// Jensen–Shannon divergence in nats: ½·D(P‖M) + ½·D(Q‖M) with
/// M = (P + Q)/2. Symmetric in P and Q, bounded above by ln 2.
pub fn js_divergence(p: &[f64], q: &[f64]) -> StatsResult<f64> {
    validate_probability_vector(p, "js_divergence(p)")?;
    validate_probability_vector(q, "js_divergence(q)")?;
    if p.len() != q.len() {
        return Err(StatsError::DimensionMismatch(format!(
            "js_divergence: p and q must have the same length, got {} vs {}",
            p.len(),
            q.len()
        )));
    }
    let m: Vec<f64> = p.iter().zip(q).map(|(&pi, &qi)| (pi + qi) * 0.5).collect();
    // KL(d ‖ m) over the support of d; m > 0 wherever d > 0, so the
    // division is always safe here.
    let kl_to_mid = |d: &[f64]| -> f64 {
        d.iter()
            .zip(&m)
            .filter(|&(&di, _)| di > 0.0)
            .map(|(&di, &mi)| di * (di / mi).ln())
            .sum()
    };
    Ok(0.5 * kl_to_mid(p) + 0.5 * kl_to_mid(q))
}
/// Total variation distance ½·Σ |P(x) − Q(x)|, in [0, 1].
pub fn total_variation(p: &[f64], q: &[f64]) -> StatsResult<f64> {
    validate_probability_vector(p, "total_variation(p)")?;
    validate_probability_vector(q, "total_variation(q)")?;
    if p.len() != q.len() {
        return Err(StatsError::DimensionMismatch(format!(
            "total_variation: p and q must have the same length, got {} vs {}",
            p.len(),
            q.len()
        )));
    }
    let l1: f64 = p.iter().zip(q).map(|(&a, &b)| (a - b).abs()).sum();
    Ok(0.5 * l1)
}
/// Hellinger distance √(½·Σ (√P(x) − √Q(x))²), in [0, 1].
pub fn hellinger_distance(p: &[f64], q: &[f64]) -> StatsResult<f64> {
    validate_probability_vector(p, "hellinger_distance(p)")?;
    validate_probability_vector(q, "hellinger_distance(q)")?;
    if p.len() != q.len() {
        return Err(StatsError::DimensionMismatch(format!(
            "hellinger_distance: p and q must have the same length, got {} vs {}",
            p.len(),
            q.len()
        )));
    }
    let mut sum_sq = 0.0_f64;
    for (&pi, &qi) in p.iter().zip(q) {
        let d = pi.sqrt() - qi.sqrt();
        sum_sq += d * d;
    }
    Ok((sum_sq * 0.5).sqrt())
}
/// Akaike information criterion: 2k − 2·ln L̂ (lower is better).
pub fn aic(log_likelihood: f64, k: usize) -> f64 {
    let params = k as f64;
    2.0 * params - 2.0 * log_likelihood
}
/// Bayesian information criterion: k·ln n − 2·ln L̂ (lower is better).
pub fn bic(log_likelihood: f64, k: usize, n: usize) -> f64 {
    let (params, samples) = (k as f64, n as f64);
    params * samples.ln() - 2.0 * log_likelihood
}
/// Small-sample corrected AIC: AIC + 2k(k+1)/(n − k − 1).
/// Returns +∞ when the correction denominator is non-positive (n ≤ k + 1),
/// where the criterion is undefined.
pub fn aicc(log_likelihood: f64, k: usize, n: usize) -> f64 {
    let params = k as f64;
    let denom = n as f64 - params - 1.0;
    if denom <= 0.0 {
        return f64::INFINITY;
    }
    aic(log_likelihood, k) + 2.0 * params * (params + 1.0) / denom
}
/// Hannan–Quinn information criterion: 2k·ln(ln n) − 2·ln L̂.
pub fn hqic(log_likelihood: f64, k: usize, n: usize) -> f64 {
    let loglog_n = (n as f64).ln().ln();
    2.0 * k as f64 * loglog_n - 2.0 * log_likelihood
}
#[cfg(test)]
mod tests {
    // Unit tests for the information-theoretic measures and the model
    // selection criteria above. Exact values are asserted where they are
    // analytically known; the histogram-based estimators (joint entropy,
    // conditional entropy, MI, NMI) are only checked for sign/range
    // properties because their exact values depend on the binning.
    use super::*;

    // --- entropy -----------------------------------------------------

    #[test]
    fn test_entropy_uniform_bits() {
        // Uniform over 4 outcomes in base 2 → exactly 2 bits.
        let uniform = vec![0.25_f64; 4];
        let h = entropy(&uniform, Some(2.0)).expect("ok");
        assert!((h - 2.0).abs() < 1e-10, "h={h}");
    }

    #[test]
    fn test_entropy_deterministic() {
        // A point mass has zero entropy.
        let certain = vec![1.0_f64, 0.0, 0.0];
        let h = entropy(&certain, None).expect("ok");
        assert_eq!(h, 0.0);
    }

    #[test]
    fn test_entropy_nats_binary() {
        // Fair coin in nats → ln 2.
        let p = vec![0.5_f64, 0.5];
        let h = entropy(&p, None).expect("ok");
        assert!((h - 2.0_f64.ln()).abs() < 1e-10, "h={h}");
    }

    #[test]
    fn test_entropy_invalid_negative() {
        assert!(entropy(&[-0.1, 1.1], None).is_err());
    }

    #[test]
    fn test_entropy_invalid_sum() {
        // Sums to 0.6, outside the 1e-6 tolerance around 1.
        assert!(entropy(&[0.3, 0.3], None).is_err());
    }

    #[test]
    fn test_entropy_empty() {
        assert!(entropy(&[], None).is_err());
    }

    // --- histogram-based estimators ----------------------------------

    #[test]
    fn test_joint_entropy_non_negative() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![5.0, 4.0, 3.0, 2.0, 1.0];
        let h = joint_entropy(&x, &y, 5).expect("ok");
        assert!(h >= 0.0, "h={h}");
    }

    #[test]
    fn test_joint_entropy_identical_data() {
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let h = joint_entropy(&x, &x, 4).expect("ok");
        assert!(h >= 0.0);
    }

    #[test]
    fn test_joint_entropy_length_mismatch() {
        assert!(joint_entropy(&[1.0, 2.0], &[1.0], 2).is_err());
    }

    #[test]
    fn test_conditional_entropy_non_negative() {
        let x = vec![1.0, 2.0, 1.0, 2.0, 3.0];
        let y = vec![1.0, 1.0, 2.0, 2.0, 3.0];
        let h = conditional_entropy(&x, &y, 3).expect("ok");
        assert!(h >= 0.0, "h={h}");
    }

    #[test]
    fn test_conditional_entropy_given_self() {
        // H(X|X) should be (near) zero: knowing X determines X.
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let h = conditional_entropy(&x, &x, 4).expect("ok");
        assert!(h >= 0.0);
        assert!(h < 0.5, "H(X|X) too large: {h}");
    }

    #[test]
    fn test_mutual_information_non_negative() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let mi = mutual_information(&x, &y, 5).expect("ok");
        assert!(mi >= 0.0, "mi={mi}");
    }

    #[test]
    fn test_mutual_information_self() {
        let x: Vec<f64> = (1..=10).map(|i| i as f64).collect();
        let mi = mutual_information(&x, &x, 10).expect("ok");
        assert!(mi >= 0.0);
    }

    #[test]
    fn test_nmi_range() {
        // NMI is normalized into [0, 1] (small tolerance for rounding).
        let x: Vec<f64> = (0..20).map(|i| i as f64).collect();
        let y: Vec<f64> = x.iter().map(|&v| v * 2.0).collect();
        let nmi = normalized_mutual_information(&x, &y, 10).expect("ok");
        assert!(nmi >= 0.0 && nmi <= 1.0 + 1e-9, "nmi={nmi}");
    }

    // --- divergences and distances -----------------------------------

    #[test]
    fn test_kl_divergence_identical() {
        // D(P‖P) = 0.
        let p = vec![0.2_f64, 0.3, 0.5];
        let kl = kl_divergence(&p, &p).expect("ok");
        assert!(kl.abs() < 1e-10, "kl={kl}");
    }

    #[test]
    fn test_kl_divergence_asymmetry() {
        // KL is not a metric: D(P‖Q) ≠ D(Q‖P) in general.
        let p = vec![0.9_f64, 0.1];
        let q = vec![0.5_f64, 0.5];
        let kl_pq = kl_divergence(&p, &q).expect("ok");
        let kl_qp = kl_divergence(&q, &p).expect("ok");
        assert!(kl_pq > 0.0);
        assert!((kl_pq - kl_qp).abs() > 1e-6, "KL should be asymmetric");
    }

    #[test]
    fn test_kl_divergence_q_zero_error() {
        // Q(x) = 0 where P(x) > 0 → divergence is infinite → error.
        let p = vec![0.5_f64, 0.5];
        let q = vec![1.0_f64, 0.0];
        assert!(kl_divergence(&p, &q).is_err());
    }

    #[test]
    fn test_kl_divergence_p_zero_ok() {
        // P(x) = 0 terms vanish by convention; this must not error.
        let p = vec![1.0_f64, 0.0];
        let q = vec![0.5_f64, 0.5];
        let kl = kl_divergence(&p, &q).expect("ok");
        assert!(kl >= 0.0);
    }

    #[test]
    fn test_kl_divergence_length_mismatch() {
        assert!(kl_divergence(&[0.5, 0.5], &[1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]).is_err());
    }

    #[test]
    fn test_js_divergence_symmetric() {
        let p = vec![0.7_f64, 0.3];
        let q = vec![0.4_f64, 0.6];
        let js_pq = js_divergence(&p, &q).expect("ok");
        let js_qp = js_divergence(&q, &p).expect("ok");
        assert!((js_pq - js_qp).abs() < 1e-10, "JS should be symmetric");
    }

    #[test]
    fn test_js_divergence_identical_zero() {
        let p = vec![0.3_f64, 0.4, 0.3];
        let js = js_divergence(&p, &p).expect("ok");
        assert!(js.abs() < 1e-10, "js={js}");
    }

    #[test]
    fn test_js_divergence_max() {
        // Disjoint supports hit the JS upper bound of ln 2 (in nats).
        let p = vec![1.0_f64, 0.0];
        let q = vec![0.0_f64, 1.0];
        let js = js_divergence(&p, &q).expect("ok");
        assert!((js - 2.0_f64.ln()).abs() < 1e-10, "js={js}");
    }

    #[test]
    fn test_total_variation_identical() {
        let p = vec![0.5_f64, 0.5];
        let tv = total_variation(&p, &p).expect("ok");
        assert_eq!(tv, 0.0);
    }

    #[test]
    fn test_total_variation_disjoint() {
        // Disjoint supports reach the TV maximum of 1.
        let p = vec![1.0_f64, 0.0];
        let q = vec![0.0_f64, 1.0];
        let tv = total_variation(&p, &q).expect("ok");
        assert!((tv - 1.0).abs() < 1e-10, "tv={tv}");
    }

    #[test]
    fn test_total_variation_half() {
        let p = vec![0.5_f64, 0.5];
        let q = vec![1.0_f64, 0.0];
        let tv = total_variation(&p, &q).expect("ok");
        assert!((tv - 0.5).abs() < 1e-10, "tv={tv}");
    }

    #[test]
    fn test_hellinger_identical() {
        let p = vec![0.25_f64; 4];
        let h = hellinger_distance(&p, &p).expect("ok");
        assert!(h.abs() < 1e-10, "h={h}");
    }

    #[test]
    fn test_hellinger_disjoint() {
        // Disjoint supports reach the Hellinger maximum of 1.
        let p = vec![1.0_f64, 0.0];
        let q = vec![0.0_f64, 1.0];
        let h = hellinger_distance(&p, &q).expect("ok");
        assert!((h - 1.0).abs() < 1e-10, "h={h}");
    }

    #[test]
    fn test_hellinger_range() {
        let p = vec![0.7_f64, 0.3];
        let q = vec![0.2_f64, 0.8];
        let h = hellinger_distance(&p, &q).expect("ok");
        assert!(h >= 0.0 && h <= 1.0 + 1e-10, "h={h}");
    }

    // --- information criteria ----------------------------------------

    #[test]
    fn test_aic_formula() {
        let ll = -50.0_f64;
        let k = 3;
        assert_eq!(aic(ll, k), 2.0 * 3.0 - 2.0 * (-50.0));
    }

    #[test]
    fn test_bic_formula() {
        let ll = -50.0_f64;
        let k = 3;
        let n = 100;
        let expected = 3.0 * (100_f64).ln() - 2.0 * (-50.0);
        assert!((bic(ll, k, n) - expected).abs() < 1e-10);
    }

    #[test]
    fn test_aicc_greater_than_aic() {
        // The small-sample correction is non-negative when n > k + 1.
        let ll = -50.0_f64;
        let k = 3;
        let n = 20;
        assert!(aicc(ll, k, n) >= aic(ll, k) - 1e-10);
    }

    #[test]
    fn test_aicc_converges_to_aic_large_n() {
        // The correction term vanishes as n → ∞.
        let ll = -50.0_f64;
        let k = 3;
        let large_n = 1_000_000;
        let correction = (aicc(ll, k, large_n) - aic(ll, k)).abs();
        assert!(correction < 1e-3, "correction={correction}");
    }

    #[test]
    fn test_hqic_formula() {
        let ll = -100.0_f64;
        let k = 5;
        let n = 100;
        let expected = 2.0 * 5.0 * (100_f64).ln().ln() - 2.0 * (-100.0);
        assert!((hqic(ll, k, n) - expected).abs() < 1e-10);
    }

    #[test]
    fn test_aic_bic_ordering() {
        // For n with ln n > 2 (n ≥ 8), BIC penalizes harder than AIC.
        let ll = -100.0_f64;
        let k = 5;
        let n = 1000;
        let a = aic(ll, k);
        let b = bic(ll, k, n);
        assert!(b > a, "bic={b} should exceed aic={a} for n={n}");
    }
}