use crate::{digamma, Error, Result};
/// Selects which estimator formula `mutual_information_ksg` evaluates.
///
/// Below, `psi` stands for the crate's `digamma` function, `N` for the number
/// of samples, and `n_x` / `n_y` for the per-sample marginal neighbour counts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KsgVariant {
/// Marginal neighbours are counted with a strict `<` comparison and the
/// estimate is `psi(k) - mean(psi(n_x + 1) + psi(n_y + 1)) + psi(N)`.
Alg1,
/// Marginal neighbours are counted with an inclusive `<=` comparison and the
/// estimate is `psi(k) - 1/k - mean(psi(n_x) + psi(n_y)) + psi(N)`.
Alg2,
}
/// Estimates the mutual information between the paired samples `x` and `y`
/// with the Kraskov–Stögbauer–Grassberger (KSG) k-nearest-neighbour estimator.
///
/// `x[i]` and `y[i]` are the two halves of the i-th joint observation. All
/// distances use the max-norm (see `dist_inf`), and the joint distance between
/// two observations is the larger of their x- and y-distances.
///
/// * [`KsgVariant::Alg1`]: for every sample, `eps` is the joint distance to its
///   k-th nearest neighbour; `n_x` / `n_y` count the other samples whose
///   marginal distance is strictly below `eps`. The result is
///   `psi(k) - mean(psi(n_x + 1) + psi(n_y + 1)) + psi(n)`.
/// * [`KsgVariant::Alg2`]: for every sample, `eps_x` / `eps_y` are the largest
///   x- / y-distances found among its k nearest joint neighbours; `n_x` / `n_y`
///   count the other samples whose marginal distance is `<=` the matching
///   radius. The result is `psi(k) - 1/k - mean(psi(n_x) + psi(n_y)) + psi(n)`.
///
/// Here `psi` is the crate's `digamma` and `n` is the number of samples. The
/// search is brute force: for each sample every other sample is visited and the
/// distances are sorted.
///
/// # Errors
///
/// * `Error::LengthMismatch` if `x` and `y` hold a different number of samples.
/// * `Error::Domain` if the number of samples is not greater than `k`, if `k`
///   is zero, if the first sample of `x` or `y` is an empty vector, or if the
///   samples within `x` (or within `y`) do not all have the same length.
pub fn mutual_information_ksg(
    x: &[Vec<f64>],
    y: &[Vec<f64>],
    k: usize,
    variant: KsgVariant,
) -> Result<f64> {
    let n = x.len();
    if n != y.len() {
        return Err(Error::LengthMismatch(n, y.len()));
    }
    if n <= k {
        return Err(Error::Domain("Sample size must be greater than k"));
    }
    if k == 0 {
        return Err(Error::Domain("k must be >= 1"));
    }

    let dim_x = x.first().map(|v| v.len()).unwrap_or(0);
    let dim_y = y.first().map(|v| v.len()).unwrap_or(0);
    if dim_x == 0 || dim_y == 0 {
        return Err(Error::Domain("x and y must have non-empty feature vectors"));
    }
    if x.iter().any(|v| v.len() != dim_x) {
        return Err(Error::Domain("x has inconsistent sample dimensionality"));
    }
    if y.iter().any(|v| v.len() != dim_y) {
        return Err(Error::Domain("y has inconsistent sample dimensionality"));
    }

    let mut nx = vec![0usize; n];
    let mut ny = vec![0usize; n];

    // Scratch buffer reused for every sample: the (x-distance, y-distance) pair
    // to each *other* sample. The checks above guarantee n > k >= 1, so there
    // are always at least k entries.
    let mut others: Vec<(f64, f64)> = Vec::with_capacity(n - 1);

    for (i, (nx_i, ny_i)) in nx.iter_mut().zip(ny.iter_mut()).enumerate() {
        others.clear();
        others.extend(
            (0..n)
                .filter(|&j| j != i)
                .map(|j| (dist_inf(&x[i], &x[j]), dist_inf(&y[i], &y[j]))),
        );
        // Order by joint max-norm distance. The sort is stable, so ties keep
        // their original index order and the result stays deterministic.
        others.sort_by(|a, b| a.0.max(a.1).total_cmp(&b.0.max(b.1)));

        let (count_x, count_y) = match variant {
            KsgVariant::Alg1 => {
                // One shared radius: the joint distance to the k-th neighbour.
                let (kth_dx, kth_dy) = others[k - 1];
                let eps = kth_dx.max(kth_dy);
                (
                    others.iter().filter(|&&(dx, _)| dx < eps).count(),
                    others.iter().filter(|&&(_, dy)| dy < eps).count(),
                )
            }
            KsgVariant::Alg2 => {
                // Separate radii: the extent of the k nearest joint neighbours
                // projected onto each marginal space. The sample itself is not
                // counted, and each count is at least k because the k nearest
                // neighbours always lie inside both radii.
                let nearest = &others[..k];
                let eps_x = nearest.iter().map(|&(dx, _)| dx).fold(0.0, f64::max);
                let eps_y = nearest.iter().map(|&(_, dy)| dy).fold(0.0, f64::max);
                (
                    others.iter().filter(|&&(dx, _)| dx <= eps_x).count(),
                    others.iter().filter(|&&(_, dy)| dy <= eps_y).count(),
                )
            }
        };
        *nx_i = count_x;
        *ny_i = count_y;
    }

    match variant {
        KsgVariant::Alg1 => {
            let avg_psi: f64 = nx
                .iter()
                .zip(ny.iter())
                .map(|(&nxi, &nyi)| digamma(nxi as f64 + 1.0) + digamma(nyi as f64 + 1.0))
                .sum::<f64>()
                / n as f64;
            Ok(digamma(k as f64) - avg_psi + digamma(n as f64))
        }
        KsgVariant::Alg2 => {
            let avg_psi: f64 = nx
                .iter()
                .zip(ny.iter())
                .map(|(&nxi, &nyi)| digamma(nxi as f64) + digamma(nyi as f64))
                .sum::<f64>()
                / n as f64;
            Ok(digamma(k as f64) - 1.0 / k as f64 - avg_psi + digamma(n as f64))
        }
    }
}
/// Max-norm (Chebyshev) distance between two coordinate slices.
///
/// Coordinates are compared pairwise; if the slices differ in length, the
/// surplus tail of the longer one is ignored. Returns `0.0` when there is
/// nothing to compare. A NaN difference never becomes the result, because
/// `f64::max` keeps the non-NaN operand.
fn dist_inf(a: &[f64], b: &[f64]) -> f64 {
    let mut largest = 0.0_f64;
    for (&lhs, &rhs) in a.iter().zip(b) {
        largest = largest.max((lhs - rhs).abs());
    }
    largest
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a tiny sample must give a finite, bounded Alg1 estimate.
    #[test]
    fn test_ksg_independent_alg1() {
        let x = vec![vec![0.1], vec![0.5], vec![0.9], vec![0.2], vec![0.8]];
        let y = vec![vec![0.9], vec![0.2], vec![0.1], vec![0.8], vec![0.5]];
        let mi = mutual_information_ksg(&x, &y, 2, KsgVariant::Alg1).unwrap();
        assert!(mi.is_finite());
        assert!(mi.abs() < 2.0);
    }

    /// Smoke test: a tiny sample must give a finite, bounded Alg2 estimate.
    #[test]
    fn test_ksg_independent_alg2() {
        let x = vec![vec![0.1], vec![0.5], vec![0.9], vec![0.2], vec![0.8]];
        let y = vec![vec![0.9], vec![0.2], vec![0.1], vec![0.8], vec![0.5]];
        let mi = mutual_information_ksg(&x, &y, 2, KsgVariant::Alg2).unwrap();
        assert!(mi.is_finite());
        assert!(mi.abs() < 2.0);
    }

    /// A perfectly dependent pairing must score higher than a scrambled one.
    #[test]
    fn test_ksg_correlated_is_larger_than_shuffled() {
        let n = 40usize;
        let x: Vec<Vec<f64>> = (0..n).map(|i| vec![i as f64 / 40.0]).collect();
        let y_corr = x.clone();
        // Scramble with the permutation i -> 7 * i mod 40 (a bijection because
        // gcd(7, 40) = 1). Merely reversing the grid would not do: a reversal
        // keeps every pairwise distance, so it carries exactly the same mutual
        // information as the identity pairing. With this permutation any two
        // samples are at least 5 grid steps apart in the joint max-norm, so
        // every marginal count is at least 4, while the identity pairing gives
        // counts of 2, and the inequality below holds by a wide margin.
        let y_shuf: Vec<Vec<f64>> = (0..n)
            .map(|i| vec![((7 * i) % n) as f64 / 40.0])
            .collect();
        let mi_corr = mutual_information_ksg(&x, &y_corr, 3, KsgVariant::Alg1).unwrap();
        let mi_shuf = mutual_information_ksg(&x, &y_shuf, 3, KsgVariant::Alg1).unwrap();
        assert!(mi_corr.is_finite() && mi_shuf.is_finite());
        assert!(mi_corr > mi_shuf);
    }

    /// Samples of differing dimensionality must be rejected as a domain error.
    #[test]
    fn test_ksg_rejects_inconsistent_dims() {
        let x = vec![vec![0.1], vec![0.2, 0.3]];
        let y = vec![vec![0.1], vec![0.2]];
        let err = mutual_information_ksg(&x, &y, 1, KsgVariant::Alg1).unwrap_err();
        assert!(matches!(err, Error::Domain(_)));
    }

    /// Alg1 against the closed form for a bivariate Gaussian:
    /// `I = -0.5 * ln(1 - rho^2)`.
    #[test]
    fn test_ksg_gaussian_ground_truth_alg1() {
        let rho: f64 = 0.8;
        let mi_true = -0.5_f64 * (1.0 - rho * rho).ln();
        let n = 2000;
        let (x, y) = correlated_gaussian_samples(n, rho, 12345);
        let mi_est = mutual_information_ksg(&x, &y, 5, KsgVariant::Alg1).unwrap();
        assert!(mi_est.is_finite(), "MI estimate is not finite: {mi_est}");
        let rel_err = (mi_est - mi_true).abs() / mi_true;
        assert!(
            rel_err < 0.30,
            "KSG Alg1 Gaussian ground-truth: mi_est={mi_est:.4}, mi_true={mi_true:.4}, rel_err={rel_err:.3}"
        );
    }

    /// Alg2 against the closed form for a bivariate Gaussian:
    /// `I = -0.5 * ln(1 - rho^2)`.
    #[test]
    fn test_ksg_gaussian_ground_truth_alg2() {
        let rho: f64 = 0.8;
        let mi_true = -0.5_f64 * (1.0 - rho * rho).ln();
        let n = 4000;
        let (x, y) = correlated_gaussian_samples(n, rho, 54321);
        let mi_est = mutual_information_ksg(&x, &y, 5, KsgVariant::Alg2).unwrap();
        assert!(mi_est.is_finite(), "MI estimate is not finite: {mi_est}");
        let rel_err = (mi_est - mi_true).abs() / mi_true;
        assert!(
            rel_err < 0.40,
            "KSG Alg2 Gaussian ground-truth: mi_est={mi_est:.4}, mi_true={mi_true:.4}, rel_err={rel_err:.3}"
        );
    }

    /// Independent pseudo-random uniforms should give an estimate close to zero.
    #[test]
    fn test_ksg_independent_near_zero() {
        let n = 500;
        let mut state: u64 = 99999;
        // 64-bit linear congruential generator; the top 53 bits of the state
        // are scaled into [0, 1).
        let mut next_uniform = || -> f64 {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            (state >> 11) as f64 / (1u64 << 53) as f64
        };
        let x: Vec<Vec<f64>> = (0..n).map(|_| vec![next_uniform()]).collect();
        let y: Vec<Vec<f64>> = (0..n).map(|_| vec![next_uniform()]).collect();
        let mi = mutual_information_ksg(&x, &y, 5, KsgVariant::Alg1).unwrap();
        assert!(mi.abs() < 0.15, "Independent MI should be near 0, got {mi}");
    }

    /// Draws `n` pairs `(z1, rho * z1 + sqrt(1 - rho^2) * z2)` where `z1`, `z2`
    /// come from a Box–Muller transform driven by a seeded 64-bit LCG, so the
    /// output is fully determined by `seed`.
    fn correlated_gaussian_samples(
        n: usize,
        rho: f64,
        seed: u64,
    ) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
        let mut state: u64 = seed;
        // The +0.5 offset keeps the uniform strictly positive, so the `ln`
        // in the Box–Muller step below never sees zero.
        let mut next_uniform = || -> f64 {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((state >> 11) as f64 + 0.5) / ((1u64 << 53) as f64)
        };
        // Only the cosine branch of Box–Muller is used; each normal draw
        // consumes two uniforms.
        let mut next_normal = || -> f64 {
            let u1 = next_uniform();
            let u2 = next_uniform();
            (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
        };
        let mut x = Vec::with_capacity(n);
        let mut y = Vec::with_capacity(n);
        for _ in 0..n {
            let z1 = next_normal();
            let z2 = next_normal();
            let yi = rho * z1 + (1.0 - rho * rho).sqrt() * z2;
            x.push(vec![z1]);
            y.push(vec![yi]);
        }
        (x, y)
    }
}