use super::entropy::{conditional_entropy, joint_entropy, shannon_entropy, LogBase};
use super::{validate_distribution, InfoTheoryError, InfoTheoryResult};
use crate::error::NumRs2Error;
use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView2, Axis};
/// Selects the denominator used by [`normalized_mutual_information`] to scale
/// mutual information by the marginal entropies `H(X)` and `H(Y)`.
///
/// The enum is field-less, so it derives `Eq` and `Hash` in addition to
/// `PartialEq`; this lets callers use it as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NormalizationType {
    /// Divide by the arithmetic mean `(H(X) + H(Y)) / 2`.
    Arithmetic,
    /// Divide by the geometric mean `sqrt(H(X) * H(Y))`.
    Geometric,
    /// Divide by `max(H(X), H(Y))`.
    Max,
    /// Divide by `min(H(X), H(Y))`.
    Min,
}
/// Mutual information `I(X; Y) = H(X) + H(Y) - H(X, Y)` in bits.
///
/// `joint_probs` is read as a joint table whose rows index `X` and whose
/// columns index `Y`: the `X` marginal is obtained by summing along axis 1 and
/// the `Y` marginal by summing along axis 0.
///
/// The result is floored at `0.0`, so a slightly negative difference
/// (presumably from floating-point rounding) is reported as zero.
///
/// # Errors
///
/// Propagates any error returned by `shannon_entropy` or `joint_entropy`.
pub fn mutual_information(joint_probs: &Array2<f64>) -> Result<f64, NumRs2Error> {
    let row_marginal: Array1<f64> = joint_probs.sum_axis(Axis(1));
    let col_marginal: Array1<f64> = joint_probs.sum_axis(Axis(0));

    let entropy_x = shannon_entropy(&row_marginal, LogBase::Bits)?;
    let entropy_y = shannon_entropy(&col_marginal, LogBase::Bits)?;
    let entropy_xy = joint_entropy(joint_probs, LogBase::Bits)?;

    let information = entropy_x + entropy_y - entropy_xy;
    Ok(information.max(0.0))
}
pub fn conditional_mutual_information(joint_probs: &Array3<f64>) -> Result<f64, NumRs2Error> {
if joint_probs.is_empty() {
return Err(NumRs2Error::InvalidInput(
"Joint probability array is empty".to_string(),
));
}
let shape = joint_probs.shape();
let (nx, ny, nz) = (shape[0], shape[1], shape[2]);
let flat = joint_probs.iter().cloned().collect::<Vec<_>>();
let flat_array = Array1::from_vec(flat);
let h_xyz = shannon_entropy(&flat_array, LogBase::Bits)?;
let mut joint_xz = Array2::zeros((nx, nz));
for i in 0..nx {
for k in 0..nz {
for j in 0..ny {
joint_xz[[i, k]] += joint_probs[[i, j, k]];
}
}
}
let h_xz = joint_entropy(&joint_xz, LogBase::Bits)?;
let mut joint_yz = Array2::zeros((ny, nz));
for j in 0..ny {
for k in 0..nz {
for i in 0..nx {
joint_yz[[j, k]] += joint_probs[[i, j, k]];
}
}
}
let h_yz = joint_entropy(&joint_yz, LogBase::Bits)?;
let mut marginal_z = Array1::zeros(nz);
for k in 0..nz {
for i in 0..nx {
for j in 0..ny {
marginal_z[k] += joint_probs[[i, j, k]];
}
}
}
let h_z = shannon_entropy(&marginal_z, LogBase::Bits)?;
let cmi = h_xz + h_yz - h_z - h_xyz;
Ok(cmi.max(0.0))
}
/// Mutual information scaled by a function of the marginal entropies.
///
/// Computes `I(X; Y) / N`, where the normaliser `N` is chosen by `norm_type`
/// (see [`NormalizationType`]) from `H(X)` and `H(Y)`, all measured in bits.
/// The ratio is clamped to `[0.0, 1.0]`.
///
/// Special cases:
/// * if both marginal entropies are exactly zero the function returns `1.0`;
/// * otherwise, if the selected normaliser is exactly zero it returns `0.0`.
///
/// # Errors
///
/// Propagates any error returned by [`mutual_information`] or
/// `shannon_entropy`.
pub fn normalized_mutual_information(
    joint_probs: &Array2<f64>,
    norm_type: NormalizationType,
) -> Result<f64, NumRs2Error> {
    let mi = mutual_information(joint_probs)?;

    let row_marginal: Array1<f64> = joint_probs.sum_axis(Axis(1));
    let col_marginal: Array1<f64> = joint_probs.sum_axis(Axis(0));
    let h_x = shannon_entropy(&row_marginal, LogBase::Bits)?;
    let h_y = shannon_entropy(&col_marginal, LogBase::Bits)?;

    // Both variables are deterministic: report full agreement.
    if h_x == 0.0 && h_y == 0.0 {
        return Ok(1.0);
    }

    let normalizer = match norm_type {
        NormalizationType::Arithmetic => (h_x + h_y) / 2.0,
        NormalizationType::Geometric => (h_x * h_y).sqrt(),
        NormalizationType::Max => h_x.max(h_y),
        NormalizationType::Min => h_x.min(h_y),
    };

    // A zero normaliser would divide by zero; report no shared information.
    let ratio = if normalizer == 0.0 {
        0.0
    } else {
        mi / normalizer
    };

    Ok(ratio.clamp(0.0, 1.0))
}
/// Pointwise mutual information for every cell of a joint table.
///
/// For each `(i, j)` the output holds `ln(p(x_i, y_j) / (p(x_i) * p(y_j)))`,
/// where the marginals are the row sums (axis 1) and column sums (axis 0) of
/// `joint_probs`. The output has the same shape as the input.
///
/// Note that this uses the natural logarithm (nats), whereas the entropy-based
/// functions in this file request `LogBase::Bits`.
///
/// Cells where the joint probability or either marginal is not strictly
/// positive are set to `f64::NEG_INFINITY`.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidInput` when `joint_probs` has no elements.
pub fn pointwise_mutual_information(
    joint_probs: &Array2<f64>,
) -> Result<Array2<f64>, NumRs2Error> {
    if joint_probs.is_empty() {
        return Err(NumRs2Error::InvalidInput(
            "Joint probability array is empty".to_string(),
        ));
    }

    let row_totals: Array1<f64> = joint_probs.sum_axis(Axis(1));
    let col_totals: Array1<f64> = joint_probs.sum_axis(Axis(0));

    let pmi = Array2::from_shape_fn(joint_probs.dim(), |(row, col)| {
        let p_xy = joint_probs[[row, col]];
        let p_x = row_totals[row];
        let p_y = col_totals[col];

        // Written as positive comparisons so that NaN falls into the
        // NEG_INFINITY branch.
        let all_positive = p_xy > 0.0 && p_x > 0.0 && p_y > 0.0;
        if all_positive {
            (p_xy / (p_x * p_y)).ln()
        } else {
            f64::NEG_INFINITY
        }
    });

    Ok(pmi)
}
/// Variation of information `VI(X, Y) = H(X, Y) - I(X; Y)` in bits.
///
/// The joint entropy is evaluated first, then the mutual information.
///
/// # Errors
///
/// Propagates any error returned by `joint_entropy` or
/// [`mutual_information`].
pub fn variation_of_information(joint_probs: &Array2<f64>) -> Result<f64, NumRs2Error> {
    let entropy_xy = joint_entropy(joint_probs, LogBase::Bits)?;
    let shared = mutual_information(joint_probs)?;
    let distance = entropy_xy - shared;
    Ok(distance)
}
/// Mutual information adjusted for chance agreement.
///
/// Computes
///
/// ```text
/// AMI = (I(X; Y) - E[I]) / ((H(X) + H(Y)) / 2 - E[I])
/// ```
///
/// where `E[I]` is the expected mutual information returned by
/// `compute_expected_mi_hypergeometric` and all quantities are in bits.
///
/// When the denominator is smaller than `1e-10` in magnitude the ratio is not
/// evaluated: the function returns `0.0` if the numerator is also below
/// `1e-10` in magnitude and `1.0` otherwise. The result is not clamped.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidInput` when `joint_probs` has no elements, and
/// propagates any error from [`mutual_information`], `shannon_entropy` or
/// `compute_expected_mi_hypergeometric`.
pub fn adjusted_mutual_information(joint_probs: &Array2<f64>) -> Result<f64, NumRs2Error> {
    const DEGENERATE_TOLERANCE: f64 = 1e-10;

    if joint_probs.is_empty() {
        return Err(NumRs2Error::InvalidInput(
            "Joint probability array is empty".to_string(),
        ));
    }

    let mi = mutual_information(joint_probs)?;

    let row_marginal: Array1<f64> = joint_probs.sum_axis(Axis(1));
    let col_marginal: Array1<f64> = joint_probs.sum_axis(Axis(0));
    let h_x = shannon_entropy(&row_marginal, LogBase::Bits)?;
    let h_y = shannon_entropy(&col_marginal, LogBase::Bits)?;

    let expected_mi = compute_expected_mi_hypergeometric(joint_probs)?;

    let numerator = mi - expected_mi;
    let denominator = (h_x + h_y) / 2.0 - expected_mi;

    if denominator.abs() < DEGENERATE_TOLERANCE {
        let degenerate_value = if numerator.abs() < DEGENERATE_TOLERANCE {
            0.0
        } else {
            1.0
        };
        return Ok(degenerate_value);
    }

    Ok(numerator / denominator)
}
/// Expected mutual information (bits) of two labelings with the given
/// marginals under a hypergeometric model of the contingency counts.
///
/// The marginals of `joint_probs` are converted to integer counts for a fixed
/// pseudo-sample of `n = 1000` (rounded, then nudged by
/// `adjust_counts_to_sum` towards a total of `n`). For every pair of non-zero
/// row count `a_i` and column count `b_j`, and every feasible cell count
/// `n_ij` in `max(1, a_i + b_j - n) ..= min(a_i, b_j)`, the term
///
/// ```text
/// (n_ij / n) * log2(n * n_ij / (a_i * b_j)) * P(n_ij)
/// P(n_ij) = C(a_i, n_ij) * C(n - a_i, b_j - n_ij) / C(n, b_j)
/// ```
///
/// is accumulated. The sum is floored at `0.0`.
///
/// NOTE(review): because counts are quantised to 1/1000, a marginal
/// probability that rounds to a zero count is skipped entirely.
///
/// # Errors
///
/// The current body has no failing path; the `Result` return type is kept for
/// callers that use `?`.
fn compute_expected_mi_hypergeometric(joint_probs: &Array2<f64>) -> Result<f64, NumRs2Error> {
    let n: i64 = 1000;
    let n_f = n as f64;

    // Quantise a marginal distribution into integer counts totalling `n`.
    let to_counts = |marginal: Array1<f64>| -> Vec<i64> {
        let mut counts: Vec<i64> = marginal
            .iter()
            .map(|&p| (p * n_f).round() as i64)
            .collect();
        adjust_counts_to_sum(&mut counts, n);
        counts
    };
    let row_counts = to_counts(joint_probs.sum_axis(Axis(1)));
    let col_counts = to_counts(joint_probs.sum_axis(Axis(0)));

    let log_fact = precompute_log_factorials(n as usize);
    let log_n = n_f.ln();
    let mut emi = 0.0;

    for &ai in row_counts.iter().filter(|&&count| count != 0) {
        let ln_ai = (ai as f64).ln();

        for &bj in col_counts.iter().filter(|&&count| count != 0) {
            // n_ij = 0 contributes nothing, so the range starts at 1.
            let nij_min = 1_i64.max(ai + bj - n);
            let nij_max = ai.min(bj);
            if nij_min > nij_max {
                // No feasible cell count; also keeps the table lookups below
                // from running on counts that fall outside 0..=n.
                continue;
            }

            let ln_bj = (bj as f64).ln();
            // ln C(n, b_j): independent of n_ij, so computed once per pair.
            let log_total_ways =
                log_fact[n as usize] - log_fact[bj as usize] - log_fact[(n - bj) as usize];

            for nij in nij_min..=nij_max {
                // ln C(a_i, n_ij)
                let log_row_ways =
                    log_fact[ai as usize] - log_fact[nij as usize] - log_fact[(ai - nij) as usize];
                // ln C(n - a_i, b_j - n_ij)
                let log_rest_ways = log_fact[(n - ai) as usize]
                    - log_fact[(bj - nij) as usize]
                    - log_fact[(n - ai - bj + nij) as usize];
                let log_prob = log_row_ways + log_rest_ways - log_total_ways;

                // log2(n * n_ij / (a_i * b_j)), evaluated in nats then converted.
                let log_term = log_n + (nij as f64).ln() - ln_ai - ln_bj;
                let log2_term = log_term / std::f64::consts::LN_2;

                let nij_over_n = nij as f64 / n_f;
                let prob = log_prob.exp();
                emi += nij_over_n * log2_term * prob;
            }
        }
    }

    Ok(emi.max(0.0))
}
/// Builds a lookup table of `ln(k!)` for `k = 0..=n`.
///
/// Entry `k` is the running sum `ln(1) + ln(2) + ... + ln(k)`; entry `0` is
/// `0.0`. The returned vector has `n + 1` elements.
fn precompute_log_factorials(n: usize) -> Vec<f64> {
    let mut table: Vec<f64> = Vec::with_capacity(n + 1);
    let mut running = 0.0_f64;
    table.push(running);
    for k in 1..=n {
        running += (k as f64).ln();
        table.push(running);
    }
    table
}
/// Nudges `counts` in place, one unit at a time, until they sum to `target`.
///
/// Entries are visited round-robin in descending order of their initial value
/// (ties keep their original index order). Each visit moves the entry one step
/// towards the target, skipping any step that would make the entry negative.
///
/// The function stops as soon as the sum equals `target`, or when a complete
/// pass over all entries could not change anything (for example when every
/// entry is zero and the sum still has to go down). In the latter case the
/// counts are left as close to the target as could be reached.
///
/// An empty slice is returned untouched.
fn adjust_counts_to_sum(counts: &mut [i64], target: i64) {
    let mut diff = target - counts.iter().sum::<i64>();
    if diff == 0 || counts.is_empty() {
        return;
    }

    // Largest counts absorb the correction first; the sort is stable, so equal
    // counts are visited in index order.
    let mut order: Vec<usize> = (0..counts.len()).collect();
    order.sort_by_key(|&i| std::cmp::Reverse(counts[i]));

    let step: i64 = if diff > 0 { 1 } else { -1 };

    loop {
        let mut progressed = false;
        for &i in &order {
            if counts[i] + step >= 0 {
                counts[i] += step;
                diff -= step;
                progressed = true;
                if diff == 0 {
                    return;
                }
            }
        }
        // A whole pass without a single change means no entry can move any
        // further; stop instead of spinning forever.
        if !progressed {
            return;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point results.
    const EPSILON: f64 = 1e-10;

    /// Builds a 2x2 joint table from row-major cell values.
    fn table_2x2(cells: [f64; 4]) -> Array2<f64> {
        Array2::from_shape_vec((2, 2), cells.to_vec()).expect("array creation failed")
    }

    /// Uniform 2x2 table: X and Y are independent fair bits.
    fn independent() -> Array2<f64> {
        table_2x2([0.25, 0.25, 0.25, 0.25])
    }

    /// Diagonal 2x2 table: Y is a copy of X.
    fn perfectly_correlated() -> Array2<f64> {
        table_2x2([0.5, 0.0, 0.0, 0.5])
    }

    /// Mostly-diagonal 2x2 table: X and Y agree 80% of the time.
    fn partially_correlated() -> Array2<f64> {
        table_2x2([0.4, 0.1, 0.1, 0.4])
    }

    /// Asserts that `actual` is within `EPSILON` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < EPSILON);
    }

    #[test]
    fn test_mutual_information_independent() {
        let mi = mutual_information(&independent()).expect("mi failed");
        assert_close(mi, 0.0);
    }

    #[test]
    fn test_mutual_information_correlated() {
        let mi = mutual_information(&perfectly_correlated()).expect("mi failed");
        assert_close(mi, 1.0);
    }

    #[test]
    fn test_mutual_information_partial() {
        let mi = mutual_information(&partially_correlated()).expect("mi failed");
        assert!(mi > 0.0);
        assert!(mi < 1.0);
    }

    #[test]
    fn test_conditional_mutual_information_independent() {
        // Uniform over all eight (x, y, z) combinations.
        let joint = Array3::from_shape_vec((2, 2, 2), vec![0.125; 8])
            .expect("array creation failed");
        let cmi = conditional_mutual_information(&joint).expect("cmi failed");
        assert_close(cmi, 0.0);
    }

    #[test]
    fn test_normalized_mutual_information_independent() {
        let joint = independent();
        for norm in [NormalizationType::Arithmetic, NormalizationType::Geometric].iter() {
            let nmi = normalized_mutual_information(&joint, *norm).expect("nmi failed");
            assert_close(nmi, 0.0);
        }
    }

    #[test]
    fn test_normalized_mutual_information_correlated() {
        let joint = perfectly_correlated();
        for norm in [NormalizationType::Arithmetic, NormalizationType::Max].iter() {
            let nmi = normalized_mutual_information(&joint, *norm).expect("nmi failed");
            assert_close(nmi, 1.0);
        }
    }

    #[test]
    fn test_pointwise_mutual_information() {
        let pmi = pointwise_mutual_information(&partially_correlated()).expect("pmi failed");
        // Both marginals are uniform, so p(x) * p(y) = 0.25 for every cell.
        assert_close(pmi[[0, 0]], (0.4_f64 / 0.25_f64).ln());
        assert_close(pmi[[0, 1]], (0.1_f64 / 0.25_f64).ln());
    }

    #[test]
    fn test_variation_of_information_independent() {
        let vi = variation_of_information(&independent()).expect("vi failed");
        assert_close(vi, 2.0);
    }

    #[test]
    fn test_variation_of_information_correlated() {
        let vi = variation_of_information(&perfectly_correlated()).expect("vi failed");
        assert_close(vi, 0.0);
    }

    #[test]
    fn test_adjusted_mutual_information() {
        let ami_correlated =
            adjusted_mutual_information(&perfectly_correlated()).expect("ami failed");
        assert!(ami_correlated > 0.9);

        let ami_independent = adjusted_mutual_information(&independent()).expect("ami failed");
        assert!(ami_independent.abs() < 0.1);
    }

    #[test]
    fn test_mutual_information_symmetry() {
        let joint = Array2::from_shape_vec((2, 3), vec![0.1, 0.15, 0.05, 0.2, 0.3, 0.2])
            .expect("array creation failed");
        let forward = mutual_information(&joint).expect("mi failed");
        let transposed = joint.t().to_owned();
        let backward = mutual_information(&transposed).expect("mi failed");
        assert_close(forward, backward);
    }

    #[test]
    fn test_mi_bounds() {
        let joint = Array2::from_shape_vec(
            (3, 3),
            vec![0.2, 0.05, 0.05, 0.05, 0.2, 0.05, 0.05, 0.05, 0.2],
        )
        .expect("array creation failed");
        let mi = mutual_information(&joint).expect("mi failed");

        let row_marginal: Array1<f64> = joint.sum_axis(Axis(1));
        let col_marginal: Array1<f64> = joint.sum_axis(Axis(0));
        let h_x = shannon_entropy(&row_marginal, LogBase::Bits).expect("entropy failed");
        let h_y = shannon_entropy(&col_marginal, LogBase::Bits).expect("entropy failed");

        // Mutual information can never exceed either marginal entropy.
        assert!(mi <= h_x + EPSILON);
        assert!(mi <= h_y + EPSILON);
    }
}