use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Dimension, Ix1, Ix2};
use scirs2_core::numeric::{Float, NumCast};
use scirs2_core::simd_ops::SimdUnifiedOps;
use std::collections::HashMap;
use crate::error::{MetricsError, Result};
/// Computes the Shannon entropy, in bits, of a discrete probability distribution.
///
/// The result is `H = -Σ p · log2(p)`, summed over every element of
/// `probabilities`; elements equal to zero contribute nothing to the sum.
///
/// # Arguments
///
/// * `probabilities` - Array of any dimensionality. Every element is treated
///   as one probability of the distribution.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when:
/// * the array is empty,
/// * any element is NaN or negative,
/// * the elements do not sum to 1 within an absolute tolerance of `1e-6`,
/// * a numeric constant cannot be converted to `F`.
pub fn entropy<F, S, D>(probabilities: &ArrayBase<S, D>) -> Result<F>
where
    F: Float + NumCast + std::fmt::Debug + SimdUnifiedOps,
    S: Data<Elem = F>,
    D: Dimension,
{
    if probabilities.is_empty() {
        return Err(MetricsError::InvalidInput(
            "Empty probability distribution".to_string(),
        ));
    }

    let zero = F::zero();
    let mut sum = zero;
    for &p in probabilities.iter() {
        // NaN compares false against everything, so it would slip past both
        // the sign check and the sum-to-one check below; reject it explicitly.
        if p.is_nan() {
            return Err(MetricsError::InvalidInput(
                "Probabilities must not be NaN".to_string(),
            ));
        }
        if p < zero {
            return Err(MetricsError::InvalidInput(
                "Probabilities must be non-negative".to_string(),
            ));
        }
        sum = sum + p;
    }

    let epsilon = <F as NumCast>::from(1e-6)
        .ok_or_else(|| MetricsError::InvalidInput("Failed to convert epsilon".to_string()))?;
    if (sum - F::one()).abs() > epsilon {
        return Err(MetricsError::InvalidInput(format!(
            "Probabilities must sum to 1.0, got {sum:?}"
        )));
    }

    // Natural log divided by ln(2) gives the base-2 logarithm.
    let ln2 = <F as NumCast>::from(std::f64::consts::LN_2)
        .ok_or_else(|| MetricsError::InvalidInput("Failed to convert ln(2)".to_string()))?;

    let mut h = zero;
    for &p in probabilities.iter() {
        // Skip zeros: p * log(p) is taken as 0 there and ln(0) is -inf.
        if p > zero {
            h = h - p * (p.ln() / ln2);
        }
    }
    Ok(h)
}
/// Computes the joint Shannon entropy, in bits, of a 2-D joint probability table.
///
/// The result is `H = -Σ p · log2(p)`, summed over every cell of
/// `joint_probabilities`; cells equal to zero contribute nothing to the sum.
///
/// # Arguments
///
/// * `joint_probabilities` - Two-dimensional array of joint probabilities.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when:
/// * the table is empty,
/// * any cell is NaN or negative,
/// * the cells do not sum to 1 within an absolute tolerance of `1e-6`,
/// * a numeric constant cannot be converted to `F`.
pub fn joint_entropy<F, S>(joint_probabilities: &ArrayBase<S, Ix2>) -> Result<F>
where
    F: Float + NumCast + std::fmt::Debug + SimdUnifiedOps,
    S: Data<Elem = F>,
{
    if joint_probabilities.is_empty() {
        return Err(MetricsError::InvalidInput(
            "Empty joint probability distribution".to_string(),
        ));
    }

    let zero = F::zero();
    let mut sum = zero;
    for &p in joint_probabilities.iter() {
        // NaN compares false against everything, so it would slip past both
        // the sign check and the sum-to-one check below; reject it explicitly.
        if p.is_nan() {
            return Err(MetricsError::InvalidInput(
                "Probabilities must not be NaN".to_string(),
            ));
        }
        if p < zero {
            return Err(MetricsError::InvalidInput(
                "Probabilities must be non-negative".to_string(),
            ));
        }
        sum = sum + p;
    }

    let epsilon = <F as NumCast>::from(1e-6)
        .ok_or_else(|| MetricsError::InvalidInput("Failed to convert epsilon".to_string()))?;
    if (sum - F::one()).abs() > epsilon {
        return Err(MetricsError::InvalidInput(format!(
            "Joint probabilities must sum to 1.0, got {sum:?}"
        )));
    }

    // Natural log divided by ln(2) gives the base-2 logarithm.
    let ln2 = <F as NumCast>::from(std::f64::consts::LN_2)
        .ok_or_else(|| MetricsError::InvalidInput("Failed to convert ln(2)".to_string()))?;

    let mut h = zero;
    for &p in joint_probabilities.iter() {
        // Skip zeros: p * log(p) is taken as 0 there and ln(0) is -inf.
        if p > zero {
            h = h - p * (p.ln() / ln2);
        }
    }
    Ok(h)
}
/// Computes the conditional entropy `H(Y|X) = H(X, Y) - H(X)` in bits.
///
/// The joint entropy is evaluated first, then the entropy of `marginal_x`;
/// the first failure of either is returned unchanged.
///
/// # Arguments
///
/// * `joint_probabilities` - Two-dimensional joint probability table.
/// * `marginal_x` - One-dimensional marginal distribution that is subtracted.
///   It is not checked for consistency with the joint table.
///
/// # Errors
///
/// Propagates any error from [`joint_entropy`] or [`entropy`].
pub fn conditional_entropy<F, S1, S2>(
    joint_probabilities: &ArrayBase<S1, Ix2>,
    marginal_x: &ArrayBase<S2, Ix1>,
) -> Result<F>
where
    F: Float + NumCast + std::fmt::Debug + SimdUnifiedOps,
    S1: Data<Elem = F>,
    S2: Data<Elem = F>,
{
    joint_entropy(joint_probabilities)
        .and_then(|joint_bits| entropy(marginal_x).map(|marginal_bits| joint_bits - marginal_bits))
}
/// Computes the Kullback–Leibler divergence `D(P || Q) = Σ p · log2(p / q)` in bits.
///
/// Elements are paired in iteration order. Pairs where `p` is zero contribute
/// nothing to the sum. Neither input is checked for summing to one.
///
/// # Arguments
///
/// * `p` - The reference distribution.
/// * `q` - The distribution compared against `p`; must have the same shape.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when:
/// * the shapes differ,
/// * the inputs are empty,
/// * any element of `p` or `q` is NaN or negative,
/// * `q` is at most `1e-10` at a position where `p` is positive (the
///   divergence would be infinite),
/// * a numeric constant cannot be converted to `F`.
pub fn kl_divergence<F, S1, S2, D1, D2>(
    p: &ArrayBase<S1, D1>,
    q: &ArrayBase<S2, D2>,
) -> Result<F>
where
    F: Float + NumCast + std::fmt::Debug + SimdUnifiedOps,
    S1: Data<Elem = F>,
    S2: Data<Elem = F>,
    D1: Dimension,
    D2: Dimension,
{
    if p.shape() != q.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "Distributions must have the same shape: {:?} vs {:?}",
            p.shape(),
            q.shape()
        )));
    }
    if p.is_empty() {
        return Err(MetricsError::InvalidInput(
            "Empty distributions provided".to_string(),
        ));
    }

    let zero = F::zero();
    // Natural log divided by ln(2) gives the base-2 logarithm.
    let ln2 = <F as NumCast>::from(std::f64::consts::LN_2)
        .ok_or_else(|| MetricsError::InvalidInput("Failed to convert ln(2)".to_string()))?;
    let epsilon = <F as NumCast>::from(1e-10)
        .ok_or_else(|| MetricsError::InvalidInput("Failed to convert epsilon".to_string()))?;

    let mut kl = zero;
    for (&p_val, &q_val) in p.iter().zip(q.iter()) {
        // NaN compares false against everything: a NaN in `p` would be
        // silently skipped and a NaN in `q` would poison the result.
        if p_val.is_nan() || q_val.is_nan() {
            return Err(MetricsError::InvalidInput(
                "Probabilities must not be NaN".to_string(),
            ));
        }
        if p_val < zero || q_val < zero {
            return Err(MetricsError::InvalidInput(
                "Probabilities must be non-negative".to_string(),
            ));
        }
        if p_val > zero {
            if q_val <= epsilon {
                return Err(MetricsError::InvalidInput(
                    "Q has zero probability where P is non-zero (infinite divergence)".to_string(),
                ));
            }
            kl = kl + p_val * ((p_val / q_val).ln() / ln2);
        }
    }
    Ok(kl)
}
/// Computes the Jensen–Shannon divergence between `p` and `q` in bits.
///
/// With the mixture `M = (P + Q) / 2`, the result is
/// `(D(P || M) + D(Q || M)) / 2`, where `D` is [`kl_divergence`].
///
/// # Arguments
///
/// * `p` - First distribution, of any dimensionality.
/// * `q` - Second distribution; must have the same shape as `p`.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when the shapes differ or `0.5`
/// cannot be converted to `F`, and propagates any error from
/// [`kl_divergence`] (for example on empty or negative inputs).
pub fn js_divergence<F, S1, S2, D1, D2>(
    p: &ArrayBase<S1, D1>,
    q: &ArrayBase<S2, D2>,
) -> Result<F>
where
    F: Float + NumCast + std::fmt::Debug + SimdUnifiedOps,
    S1: Data<Elem = F>,
    S2: Data<Elem = F>,
    D1: Dimension,
    D2: Dimension,
{
    if p.shape() != q.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "Distributions must have the same shape: {:?} vs {:?}",
            p.shape(),
            q.shape()
        )));
    }

    let half = <F as NumCast>::from(0.5)
        .ok_or_else(|| MetricsError::InvalidInput("Failed to convert 0.5".to_string()))?;

    // The mixture is built as a 1-D array, so flatten both inputs the same
    // way (logical iteration order). Passing a multi-dimensional `p` or `q`
    // straight to `kl_divergence` alongside the 1-D mixture would always trip
    // its shape check.
    let p_flat: Array1<F> = p.iter().copied().collect();
    let q_flat: Array1<F> = q.iter().copied().collect();
    let m: Array1<F> = p_flat
        .iter()
        .zip(q_flat.iter())
        .map(|(&p_val, &q_val)| (p_val + q_val) * half)
        .collect();

    let kl_pm = kl_divergence(&p_flat, &m)?;
    let kl_qm = kl_divergence(&q_flat, &m)?;
    Ok((kl_pm + kl_qm) * half)
}
/// Computes the mutual information `I(X; Y) = H(X) + H(Y) - H(X, Y)` in bits.
///
/// The entropies are evaluated in the order `marginal_x`, `marginal_y`,
/// joint table; the first failure is returned unchanged.
///
/// # Arguments
///
/// * `joint_probabilities` - Two-dimensional joint probability table.
/// * `marginal_x` - One-dimensional marginal distribution of the first variable.
/// * `marginal_y` - One-dimensional marginal distribution of the second variable.
///
/// The marginals are not checked for consistency with the joint table.
///
/// # Errors
///
/// Propagates any error from [`entropy`] or [`joint_entropy`].
pub fn mutual_information<F, S1, S2, S3>(
    joint_probabilities: &ArrayBase<S1, Ix2>,
    marginal_x: &ArrayBase<S2, Ix1>,
    marginal_y: &ArrayBase<S3, Ix1>,
) -> Result<F>
where
    F: Float + NumCast + std::fmt::Debug + SimdUnifiedOps,
    S1: Data<Elem = F>,
    S2: Data<Elem = F>,
    S3: Data<Elem = F>,
{
    let marginal_bits = entropy(marginal_x)? + entropy(marginal_y)?;
    let joint_bits = joint_entropy(joint_probabilities)?;
    Ok(marginal_bits - joint_bits)
}
/// Estimates the mutual information, in bits, between two label sequences.
///
/// Joint and marginal probabilities are the empirical frequencies of the
/// label pairs and of the individual labels; the result is
/// `Σ p(x, y) · log2(p(x, y) / (p(x) · p(y)))` over the observed pairs.
///
/// # Arguments
///
/// * `labels_x` - First label array; elements are paired with `labels_y` in
///   iteration order.
/// * `labels_y` - Second label array; must contain the same number of elements.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when the arrays differ in length or
/// are empty.
pub fn mutual_information_from_labels<T, S1, S2, D1, D2>(
    labels_x: &ArrayBase<S1, D1>,
    labels_y: &ArrayBase<S2, D2>,
) -> Result<f64>
where
    T: std::hash::Hash + std::cmp::Eq + Copy + std::fmt::Debug,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    let sample_count = labels_x.len();
    if sample_count != labels_y.len() {
        return Err(MetricsError::InvalidInput(format!(
            "Label arrays must have the same length: {} vs {}",
            sample_count,
            labels_y.len()
        )));
    }
    if sample_count == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty label arrays provided".to_string(),
        ));
    }
    let total = sample_count as f64;

    // Tally pair and per-variable occurrences in a single pass.
    let mut pair_freq: HashMap<(T, T), usize> = HashMap::new();
    let mut x_freq: HashMap<T, usize> = HashMap::new();
    let mut y_freq: HashMap<T, usize> = HashMap::new();
    for (&a, &b) in labels_x.iter().zip(labels_y.iter()) {
        *pair_freq.entry((a, b)).or_default() += 1;
        *x_freq.entry(a).or_default() += 1;
        *y_freq.entry(b).or_default() += 1;
    }

    let ln2 = std::f64::consts::LN_2;
    let mut bits = 0.0_f64;
    for (&(a, b), &pair_n) in &pair_freq {
        let x_n = match x_freq.get(&a) {
            Some(&n) => n,
            None => {
                return Err(MetricsError::InvalidInput(format!(
                    "Missing count for x={a:?}"
                )))
            }
        };
        let y_n = match y_freq.get(&b) {
            Some(&n) => n,
            None => {
                return Err(MetricsError::InvalidInput(format!(
                    "Missing count for y={b:?}"
                )))
            }
        };

        let p_xy = pair_n as f64 / total;
        let p_x = x_n as f64 / total;
        let p_y = y_n as f64 / total;
        if p_xy > 0.0 && p_x > 0.0 && p_y > 0.0 {
            let ratio = p_xy / (p_x * p_y);
            bits += p_xy * (ratio.ln() / ln2);
        }
    }
    Ok(bits)
}
/// Computes the normalized mutual information between two label sequences.
///
/// The mutual information from [`mutual_information_from_labels`] is divided
/// by `sqrt(H(X) · H(Y))`, where `H` is the Shannon entropy (bits) of the
/// empirical label frequencies. The result is clamped to `[0, 1]`.
///
/// # Arguments
///
/// * `labels_x` - First label array.
/// * `labels_y` - Second label array; must contain the same number of elements.
///
/// # Returns
///
/// A value in `[0, 1]`. `0.0` is returned when either label sequence has
/// zero entropy (for example, a single distinct label).
///
/// # Errors
///
/// Propagates any error from [`mutual_information_from_labels`] (mismatched
/// lengths or empty input).
pub fn normalized_mutual_information<T, S1, S2, D1, D2>(
    labels_x: &ArrayBase<S1, D1>,
    labels_y: &ArrayBase<S2, D2>,
) -> Result<f64>
where
    T: std::hash::Hash + std::cmp::Eq + Copy + std::fmt::Debug,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    /// Shannon entropy, in bits, of the empirical distribution of `labels`,
    /// where `total` is the number of labels yielded.
    fn empirical_entropy<L, I>(labels: I, total: f64) -> f64
    where
        L: std::hash::Hash + std::cmp::Eq,
        I: Iterator<Item = L>,
    {
        let mut counts: HashMap<L, usize> = HashMap::new();
        for label in labels {
            *counts.entry(label).or_insert(0) += 1;
        }
        let ln2 = std::f64::consts::LN_2;
        // Every stored count is at least 1, so `p` is strictly positive.
        counts.values().fold(0.0, |h, &count| {
            let p = count as f64 / total;
            h - p * (p.ln() / ln2)
        })
    }

    // Validates lengths and non-emptiness before anything else is computed.
    let mi = mutual_information_from_labels(labels_x, labels_y)?;

    let n = labels_x.len() as f64;
    let h_x = empirical_entropy(labels_x.iter().copied(), n);
    let h_y = empirical_entropy(labels_y.iter().copied(), n);

    // A constant labeling carries no information; avoid dividing by zero.
    if h_x <= 0.0 || h_y <= 0.0 {
        return Ok(0.0);
    }

    // Floating-point rounding can push the ratio marginally outside [0, 1]
    // (e.g. for identical labelings), so clamp to the mathematical bounds.
    Ok((mi / (h_x * h_y).sqrt()).clamp(0.0, 1.0))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Four equally likely outcomes carry exactly two bits.
    #[test]
    fn test_entropy_uniform() {
        let dist = array![0.25, 0.25, 0.25, 0.25];
        let bits: f64 = entropy(&dist).expect("Failed to compute entropy");
        assert_relative_eq!(bits, 2.0, epsilon = 1e-10);
    }

    /// A certain outcome carries no information.
    #[test]
    fn test_entropy_deterministic() {
        let dist = array![1.0, 0.0, 0.0, 0.0];
        let bits: f64 = entropy(&dist).expect("Failed to compute entropy");
        assert_relative_eq!(bits, 0.0, epsilon = 1e-10);
    }

    /// KL divergence of a distribution from itself is zero.
    #[test]
    fn test_kl_divergence_identical() {
        let lhs = array![0.5, 0.5];
        let rhs = array![0.5, 0.5];
        let divergence: f64 = kl_divergence(&lhs, &rhs).expect("Failed to compute KL divergence");
        assert_relative_eq!(divergence, 0.0, epsilon = 1e-10);
    }

    /// JS divergence of a distribution from itself is zero.
    #[test]
    fn test_js_divergence_identical() {
        let lhs = array![0.5, 0.5];
        let rhs = array![0.5, 0.5];
        let divergence: f64 = js_divergence(&lhs, &rhs).expect("Failed to compute JS divergence");
        assert_relative_eq!(divergence, 0.0, epsilon = 1e-10);
    }

    /// Swapping the arguments must not change the JS divergence.
    #[test]
    fn test_js_divergence_symmetric() {
        let lhs = array![0.6, 0.4];
        let rhs = array![0.3, 0.7];
        let forward: f64 = js_divergence(&lhs, &rhs).expect("Failed to compute JS divergence");
        let backward: f64 = js_divergence(&rhs, &lhs).expect("Failed to compute JS divergence");
        assert_relative_eq!(forward, backward, epsilon = 1e-10);
    }

    /// Every (x, y) pair occurs equally often, so the labels share no information.
    #[test]
    fn test_mutual_information_independent() {
        let xs = array![0, 0, 1, 1, 0, 0, 1, 1];
        let ys = array![0, 1, 0, 1, 0, 1, 0, 1];
        let bits = mutual_information_from_labels(&xs, &ys).expect("Failed to compute MI");
        assert_relative_eq!(bits, 0.0, epsilon = 1e-10);
    }

    /// Identical balanced binary labelings share exactly one bit.
    #[test]
    fn test_mutual_information_identical() {
        let xs = array![0, 1, 0, 1, 0, 1];
        let ys = xs.clone();
        let bits = mutual_information_from_labels(&xs, &ys).expect("Failed to compute MI");
        assert_relative_eq!(bits, 1.0, epsilon = 1e-10);
    }

    /// NMI must stay within the closed unit interval.
    #[test]
    fn test_nmi_bounds() {
        let xs = array![0, 1, 2, 0, 1, 2];
        let ys = array![0, 0, 1, 1, 2, 2];
        let score = normalized_mutual_information(&xs, &ys).expect("Failed to compute NMI");
        assert!((0.0..=1.0).contains(&score));
    }

    /// Probabilities summing to 0.9 are rejected.
    #[test]
    fn test_entropy_invalid_sum() {
        let bad = array![0.3, 0.3, 0.3];
        let outcome = entropy::<f64, _, _>(&bad);
        assert!(outcome.is_err());
    }

    /// A negative entry is rejected.
    #[test]
    fn test_entropy_negative_probability() {
        let bad = array![0.5, -0.3, 0.8];
        let outcome = entropy::<f64, _, _>(&bad);
        assert!(outcome.is_err());
    }

    /// Q is zero where P is positive, so the divergence is reported as an error.
    #[test]
    fn test_kl_divergence_infinite() {
        let lhs = array![0.5, 0.5];
        let rhs = array![1.0, 0.0];
        let outcome = kl_divergence::<f64, _, _, _, _>(&lhs, &rhs);
        assert!(outcome.is_err());
    }
}