use logp::{mutual_information_ksg, KsgVariant};
/// Demonstrates the KSG mutual-information estimator on a multivariate Gaussian
/// whose exact MI is known in closed form, printing how the estimate compares with
/// the theoretical value across increasing sample sizes.
fn main() {
    // Explicit `usize` so `min` below resolves (method calls on an unsuffixed
    // integer literal are ambiguous) and so the types match the generator.
    let dx: usize = 5;
    let dy: usize = 5;
    let rho = 0.5_f64;
    let k = 6;

    // `generate_correlated_gaussian` couples y[j] to x[j] only for j < min(dx, dy);
    // every other coordinate is independent noise. MI is additive over those
    // independent pairs, each contributing -0.5 * ln(1 - rho^2) nats. Counting
    // `dx` pairs alone would overstate the truth whenever dx > dy.
    let correlated_pairs = dx.min(dy);
    let theoretical_mi = -0.5 * correlated_pairs as f64 * (1.0 - rho * rho).ln();

    println!(
        "KSG Mutual Information: Multivariate Gaussian ({dx}D + {dy}D = {}D joint)",
        dx + dy
    );
    println!(" cross-correlation rho = {rho}");
    println!(" theoretical MI = {theoretical_mi:.4} nats");
    println!(" k = {k} neighbors");
    println!();

    let sample_sizes = [200, 500, 1000, 2000, 5000];
    println!(
        "{:>6} {:>12} {:>12} {:>10} {:>10}",
        "n", "MI (theory)", "MI (KSG)", "abs_err", "rel_err%"
    );
    println!("{}", "-".repeat(54));
    for &n in &sample_sizes {
        // A fixed seed keeps the printed table reproducible between runs.
        let (xs, ys) = generate_correlated_gaussian(n, dx, dy, rho, 42);
        let mi = mutual_information_ksg(&xs, &ys, k, KsgVariant::Alg1)
            .unwrap_or_else(|e| panic!("KSG estimate failed for n = {n}: {e:?}"));
        let abs_err = (mi - theoretical_mi).abs();
        // Relative error is undefined when the true MI is zero (rho == 0);
        // report NaN instead of dividing by zero.
        let rel_err = if theoretical_mi > 0.0 {
            abs_err / theoretical_mi * 100.0
        } else {
            f64::NAN
        };
        println!(
            "{:>6} {:>12.4} {:>12.4} {:>10.4} {:>9.1}%",
            n, theoretical_mi, mi, abs_err, rel_err
        );
    }

    println!();
    println!("The estimate converges toward the true value as n increases.");
    println!("A histogram estimator with even 10 bins per dimension would need");
    println!(
        "10^{} bins for the joint space -- infeasible at these sample sizes.",
        dx + dy
    );
}
/// Draws `n` paired samples `(x, y)` where every component of `x` is a standard
/// normal and `y[j] = rho * x[j] + sqrt(1 - rho^2) * noise` for each
/// `j < min(dx, dy)`. Any remaining `y` components are pure standard-normal noise.
///
/// Randomness comes from a 64-bit linear congruential generator seeded with `seed`,
/// fed through the Box–Muller transform (cosine branch only). Output is therefore
/// fully deterministic for a given argument set. Within each sample, the `dx`
/// components of `x` are drawn before the `dy` components of `y`.
///
/// NOTE(review): assumes `|rho| <= 1`; otherwise the noise scale is NaN.
///
/// Returns `(xs, ys)` with `n` rows each; every `xs` row has length `dx` and
/// every `ys` row has length `dy`.
fn generate_correlated_gaussian(
    n: usize,
    dx: usize,
    dy: usize,
    rho: f64,
    seed: u64,
) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
    // Advances the LCG state and maps its top 53 bits to a uniform value in [0, 1).
    fn step_uniform(state: &mut u64) -> f64 {
        *state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        (*state >> 11) as f64 / (1u64 << 53) as f64
    }

    let mut rng_state = seed;
    let mut draw_standard_normal = || {
        // Clamp away from zero so the logarithm stays finite.
        let radius_u = step_uniform(&mut rng_state).max(1e-15);
        let angle_u = step_uniform(&mut rng_state);
        (-2.0 * radius_u.ln()).sqrt() * (2.0 * std::f64::consts::PI * angle_u).cos()
    };

    // Weight on the fresh noise so each correlated y component keeps unit variance.
    let noise_scale = (1.0 - rho * rho).sqrt();

    (0..n)
        .map(|_| {
            let x: Vec<f64> = (0..dx).map(|_| draw_standard_normal()).collect();
            let y: Vec<f64> = (0..dy)
                .map(|j| {
                    let noise = draw_standard_normal();
                    // `x.get(j)` is `Some` exactly when j < dx.
                    match x.get(j) {
                        Some(&paired) => rho * paired + noise_scale * noise,
                        None => noise,
                    }
                })
                .collect();
            (x, y)
        })
        .unzip()
}