use super::{validate_distribution, xlogy, InfoTheoryError, InfoTheoryResult};
use crate::error::NumRs2Error;
use scirs2_core::ndarray::{Array1, ArrayView1};
/// Computes the Kullback-Leibler divergence `D_KL(P || Q) = sum_i p_i * ln(p_i / q_i)`.
///
/// The natural logarithm is used, so the result is expressed in nats. Both
/// inputs are checked with `validate_distribution` and then passed through
/// `normalize_distribution` before the sum is evaluated.
///
/// Entries whose normalized `p_i` is not positive contribute nothing. If any
/// entry has `p_i > 0` while `q_i == 0`, the divergence is `f64::INFINITY`.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` when `p` and `q` differ in length.
/// * `NumRs2Error::ValueError` when validation or normalization of either
///   input fails.
pub fn kl_divergence(p: &Array1<f64>, q: &Array1<f64>) -> Result<f64, NumRs2Error> {
    if p.len() != q.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Probability arrays must have same length: {} vs {}",
            p.len(),
            q.len()
        )));
    }

    // Validate both inputs before normalizing either, so that the first
    // reported error follows the order p, then q.
    validate_distribution(&p.view()).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;
    validate_distribution(&q.view()).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;

    let p_norm =
        super::normalize_distribution(p).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;
    let q_norm =
        super::normalize_distribution(q).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;

    let mut total = 0.0;
    for (&p_i, &q_i) in p_norm.iter().zip(q_norm.iter()) {
        if p_i > 0.0 {
            // Mass in P where Q has none: the divergence is unbounded.
            if q_i == 0.0 {
                return Ok(f64::INFINITY);
            }
            total += p_i * (p_i / q_i).ln();
        }
    }
    Ok(total)
}
pub fn jensen_shannon_divergence(p: &Array1<f64>, q: &Array1<f64>) -> Result<f64, NumRs2Error> {
if p.len() != q.len() {
return Err(NumRs2Error::DimensionMismatch(format!(
"Probability arrays must have same length: {} vs {}",
p.len(),
q.len()
)));
}
let p_norm =
super::normalize_distribution(p).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;
let q_norm =
super::normalize_distribution(q).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;
let m = (&p_norm + &q_norm) / 2.0;
let d_pm = kl_divergence(&p_norm, &m)?;
let d_qm = kl_divergence(&q_norm, &m)?;
Ok((d_pm + d_qm) / 2.0)
}
/// Computes the Bhattacharyya coefficient `BC(P, Q) = sum_i sqrt(p_i * q_i)`.
///
/// Both inputs are checked with `validate_distribution` and then passed
/// through `normalize_distribution` before the sum is evaluated.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` when `p` and `q` differ in length.
/// * `NumRs2Error::ValueError` when validation or normalization of either
///   input fails.
pub fn bhattacharyya_coefficient(p: &Array1<f64>, q: &Array1<f64>) -> Result<f64, NumRs2Error> {
    if p.len() != q.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Probability arrays must have same length: {} vs {}",
            p.len(),
            q.len()
        )));
    }

    validate_distribution(&p.view()).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;
    validate_distribution(&q.view()).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;

    let p_norm =
        super::normalize_distribution(p).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;
    let q_norm =
        super::normalize_distribution(q).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;

    // Pair up the entries and accumulate the geometric means.
    let paired = p_norm.iter().zip(q_norm.iter());
    let overlap: f64 = paired.map(|(a, b)| (a * b).sqrt()).sum();
    Ok(overlap)
}
/// Computes the Bhattacharyya distance `-ln(BC(P, Q))`, where `BC` is the
/// value returned by [`bhattacharyya_coefficient`].
///
/// * A coefficient that is zero or negative yields `f64::INFINITY`.
/// * A coefficient of one or more yields exactly `0.0`. Floating-point
///   rounding can push the coefficient marginally above 1 for (near-)identical
///   distributions, which would otherwise produce a tiny negative distance.
///
/// # Errors
///
/// Propagates any error returned by [`bhattacharyya_coefficient`].
pub fn bhattacharyya_distance(p: &Array1<f64>, q: &Array1<f64>) -> Result<f64, NumRs2Error> {
    let bc = bhattacharyya_coefficient(p, q)?;
    if bc <= 0.0 {
        // No overlap at all: the distance is unbounded.
        return Ok(f64::INFINITY);
    }
    if bc >= 1.0 {
        // Clamp rounding overshoot so the distance is never negative.
        return Ok(0.0);
    }
    Ok(-bc.ln())
}
/// Computes the Hellinger distance `sqrt(max(0, 1 - BC(P, Q)))`, where `BC`
/// is the value returned by [`bhattacharyya_coefficient`].
///
/// # Errors
///
/// Propagates any error returned by [`bhattacharyya_coefficient`].
pub fn hellinger_distance(p: &Array1<f64>, q: &Array1<f64>) -> Result<f64, NumRs2Error> {
    // Clamp at zero before the square root so a coefficient that lands
    // slightly above 1 cannot produce NaN.
    bhattacharyya_coefficient(p, q).map(|bc| (1.0 - bc).max(0.0).sqrt())
}
/// Computes the total variation distance `(1/2) * sum_i |p_i - q_i|`.
///
/// Both inputs are checked with `validate_distribution` and then passed
/// through `normalize_distribution` before the sum is evaluated.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` when `p` and `q` differ in length.
/// * `NumRs2Error::ValueError` when validation or normalization of either
///   input fails.
pub fn total_variation_distance(p: &Array1<f64>, q: &Array1<f64>) -> Result<f64, NumRs2Error> {
    if p.len() != q.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Probability arrays must have same length: {} vs {}",
            p.len(),
            q.len()
        )));
    }

    validate_distribution(&p.view()).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;
    validate_distribution(&q.view()).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;

    let p_norm =
        super::normalize_distribution(p).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;
    let q_norm =
        super::normalize_distribution(q).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;

    // L1 distance between the two normalized arrays.
    let abs_gap_total = p_norm
        .iter()
        .zip(q_norm.iter())
        .map(|(a, b)| (a - b).abs())
        .sum::<f64>();

    Ok(abs_gap_total / 2.0)
}
/// Computes the chi-squared divergence `sum_i (p_i - q_i)^2 / q_i`.
///
/// Both inputs are checked with `validate_distribution` and then passed
/// through `normalize_distribution` before the sum is evaluated.
///
/// Only entries with `q_i > 0` contribute to the sum. If any entry has
/// `q_i == 0` while `p_i > 0`, the divergence is `f64::INFINITY`.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` when `p` and `q` differ in length.
/// * `NumRs2Error::ValueError` when validation or normalization of either
///   input fails.
pub fn chi_squared_divergence(p: &Array1<f64>, q: &Array1<f64>) -> Result<f64, NumRs2Error> {
    if p.len() != q.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Probability arrays must have same length: {} vs {}",
            p.len(),
            q.len()
        )));
    }

    validate_distribution(&p.view()).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;
    validate_distribution(&q.view()).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;

    let p_norm =
        super::normalize_distribution(p).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;
    let q_norm =
        super::normalize_distribution(q).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;

    let mut acc = 0.0;
    for (&p_i, &q_i) in p_norm.iter().zip(q_norm.iter()) {
        if q_i > 0.0 {
            let delta = p_i - q_i;
            acc += (delta * delta) / q_i;
        } else if q_i == 0.0 && p_i > 0.0 {
            // Mass in P where Q has none: the divergence is unbounded.
            return Ok(f64::INFINITY);
        }
    }
    Ok(acc)
}
/// Computes a generic f-divergence `sum_i q_i * f(p_i / q_i)` for a
/// caller-supplied generator function `f`.
///
/// Both inputs are checked with `validate_distribution` and then passed
/// through `normalize_distribution` before the sum is evaluated.
///
/// Only entries with `q_i > 0` are fed to `f`. For the remaining entries, a
/// positive `p_i` makes the result `f64::INFINITY`; otherwise the entry is
/// skipped.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` when `p` and `q` differ in length.
/// * `NumRs2Error::ValueError` when validation or normalization of either
///   input fails.
pub fn f_divergence<F>(p: &Array1<f64>, q: &Array1<f64>, f: F) -> Result<f64, NumRs2Error>
where
    F: Fn(f64) -> f64,
{
    if p.len() != q.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Probability arrays must have same length: {} vs {}",
            p.len(),
            q.len()
        )));
    }

    validate_distribution(&p.view()).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;
    validate_distribution(&q.view()).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;

    let p_norm =
        super::normalize_distribution(p).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;
    let q_norm =
        super::normalize_distribution(q).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;

    let mut acc = 0.0;
    for (&p_i, &q_i) in p_norm.iter().zip(q_norm.iter()) {
        if q_i > 0.0 {
            acc += q_i * f(p_i / q_i);
            continue;
        }
        // Q has no usable mass here; any mass in P makes the result unbounded.
        if p_i > 0.0 {
            return Ok(f64::INFINITY);
        }
    }
    Ok(acc)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point results.
    const EPSILON: f64 = 1e-10;

    /// Builds an owned 1-D array from a slice of weights.
    fn dist(values: &[f64]) -> Array1<f64> {
        Array1::from_vec(values.to_vec())
    }

    /// KL divergence of a distribution against itself is zero.
    #[test]
    fn test_kl_divergence_identical() {
        let lhs = dist(&[0.5, 0.5]);
        let rhs = dist(&[0.5, 0.5]);
        let kl = kl_divergence(&lhs, &rhs).expect("kl divergence failed");
        assert!(kl.abs() < EPSILON);
    }

    /// KL divergence is positive for distinct inputs and is not symmetric.
    #[test]
    fn test_kl_divergence_different() {
        let lhs = dist(&[0.5, 0.5]);
        let rhs = dist(&[0.6, 0.4]);

        let forward = kl_divergence(&lhs, &rhs).expect("kl divergence failed");
        assert!(forward > 0.0);

        let backward = kl_divergence(&rhs, &lhs).expect("kl divergence failed");
        assert!((forward - backward).abs() > 1e-6);
    }

    /// Disjoint supports give an infinite KL divergence.
    #[test]
    fn test_kl_divergence_disjoint() {
        let lhs = dist(&[1.0, 0.0]);
        let rhs = dist(&[0.0, 1.0]);
        let kl = kl_divergence(&lhs, &rhs).expect("kl divergence failed");
        assert!(kl.is_infinite());
    }

    /// JSD is zero for identical inputs, bounded by ln(2), and symmetric.
    #[test]
    fn test_jensen_shannon_divergence() {
        let base = dist(&[0.5, 0.5]);
        let same = dist(&[0.5, 0.5]);
        let zero_case = jensen_shannon_divergence(&base, &same).expect("jsd failed");
        assert!(zero_case.abs() < EPSILON);

        let skewed = dist(&[0.6, 0.4]);
        let forward = jensen_shannon_divergence(&base, &skewed).expect("jsd failed");
        assert!(forward > 0.0);
        assert!(forward < 2_f64.ln());

        let backward = jensen_shannon_divergence(&skewed, &base).expect("jsd failed");
        assert!((forward - backward).abs() < EPSILON);
    }

    /// BC is one for identical inputs, zero for disjoint ones, and symmetric.
    #[test]
    fn test_bhattacharyya_coefficient() {
        let base = dist(&[0.5, 0.5]);
        let same = dist(&[0.5, 0.5]);
        let full_overlap = bhattacharyya_coefficient(&base, &same).expect("bc failed");
        assert!((full_overlap - 1.0).abs() < EPSILON);

        let left = dist(&[1.0, 0.0]);
        let right = dist(&[0.0, 1.0]);
        let no_overlap = bhattacharyya_coefficient(&left, &right).expect("bc failed");
        assert!(no_overlap.abs() < EPSILON);

        let swapped = bhattacharyya_coefficient(&right, &left).expect("bc failed");
        assert!((no_overlap - swapped).abs() < EPSILON);
    }

    /// Bhattacharyya distance is zero for identical inputs and infinite for
    /// disjoint ones.
    #[test]
    fn test_bhattacharyya_distance() {
        let base = dist(&[0.5, 0.5]);
        let same = dist(&[0.5, 0.5]);
        let zero_case =
            bhattacharyya_distance(&base, &same).expect("bhattacharyya distance failed");
        assert!(zero_case.abs() < EPSILON);

        let left = dist(&[1.0, 0.0]);
        let right = dist(&[0.0, 1.0]);
        let unbounded =
            bhattacharyya_distance(&left, &right).expect("bhattacharyya distance failed");
        assert!(unbounded.is_infinite());
    }

    /// Hellinger distance is zero for identical inputs, one for disjoint
    /// ones, symmetric, and stays within [0, 1].
    #[test]
    fn test_hellinger_distance() {
        let base = dist(&[0.5, 0.5]);
        let same = dist(&[0.5, 0.5]);
        let zero_case = hellinger_distance(&base, &same).expect("hellinger distance failed");
        assert!(zero_case.abs() < EPSILON);

        let left = dist(&[1.0, 0.0]);
        let right = dist(&[0.0, 1.0]);
        let maximal = hellinger_distance(&left, &right).expect("hellinger distance failed");
        assert!((maximal - 1.0).abs() < EPSILON);

        let swapped = hellinger_distance(&right, &left).expect("hellinger distance failed");
        assert!((maximal - swapped).abs() < EPSILON);

        let mixed_a = dist(&[0.7, 0.3]);
        let mixed_b = dist(&[0.2, 0.8]);
        let in_range = hellinger_distance(&mixed_a, &mixed_b).expect("hellinger distance failed");
        assert!((0.0..=1.0).contains(&in_range));
    }

    /// Total variation distance is zero for identical inputs, one for
    /// disjoint ones, symmetric, and stays within [0, 1].
    #[test]
    fn test_total_variation_distance() {
        let base = dist(&[0.5, 0.5]);
        let same = dist(&[0.5, 0.5]);
        let zero_case = total_variation_distance(&base, &same).expect("tv distance failed");
        assert!(zero_case.abs() < EPSILON);

        let left = dist(&[1.0, 0.0]);
        let right = dist(&[0.0, 1.0]);
        let maximal = total_variation_distance(&left, &right).expect("tv distance failed");
        assert!((maximal - 1.0).abs() < EPSILON);

        let swapped = total_variation_distance(&right, &left).expect("tv distance failed");
        assert!((maximal - swapped).abs() < EPSILON);

        let mixed_a = dist(&[0.7, 0.3]);
        let mixed_b = dist(&[0.2, 0.8]);
        let in_range = total_variation_distance(&mixed_a, &mixed_b).expect("tv distance failed");
        assert!((0.0..=1.0).contains(&in_range));
    }

    /// Chi-squared divergence is zero for identical inputs and matches a
    /// hand-computed value for a simple pair.
    #[test]
    fn test_chi_squared_divergence() {
        let base = dist(&[0.5, 0.5]);
        let same = dist(&[0.5, 0.5]);
        let zero_case = chi_squared_divergence(&base, &same).expect("chi2 divergence failed");
        assert!(zero_case.abs() < EPSILON);

        let skewed = dist(&[0.6, 0.4]);
        let uniform = dist(&[0.5, 0.5]);
        let known = chi_squared_divergence(&skewed, &uniform).expect("chi2 divergence failed");
        assert!(known > 0.0);
        // (0.1^2 / 0.5) + (0.1^2 / 0.5) = 0.04
        assert!((known - 0.04).abs() < EPSILON);
    }

    /// The generator `t * ln(t)` reproduces the KL divergence.
    #[test]
    fn test_f_divergence_kl() {
        let lhs = dist(&[0.5, 0.5]);
        let rhs = dist(&[0.6, 0.4]);

        let via_generator = f_divergence(&lhs, &rhs, |t| if t > 0.0 { t * t.ln() } else { 0.0 })
            .expect("f-divergence failed");
        let direct = kl_divergence(&lhs, &rhs).expect("kl divergence failed");

        assert!((via_generator - direct).abs() < EPSILON);
    }

    /// The generator `(t - 1)^2` reproduces the chi-squared divergence.
    #[test]
    fn test_f_divergence_chi_squared() {
        let lhs = dist(&[0.5, 0.5]);
        let rhs = dist(&[0.6, 0.4]);

        let via_generator =
            f_divergence(&lhs, &rhs, |t| (t - 1.0).powi(2)).expect("f-divergence failed");
        let direct = chi_squared_divergence(&lhs, &rhs).expect("chi2 divergence failed");

        assert!((via_generator - direct).abs() < EPSILON);
    }

    /// Length mismatches and negative entries are reported as errors.
    #[test]
    fn test_divergence_errors() {
        let two = dist(&[0.5, 0.5]);
        let three = dist(&[0.3, 0.3, 0.4]);
        assert!(kl_divergence(&two, &three).is_err());
        assert!(jensen_shannon_divergence(&two, &three).is_err());
        assert!(hellinger_distance(&two, &three).is_err());

        let negative = dist(&[0.5, -0.1]);
        let valid = dist(&[0.5, 0.5]);
        assert!(kl_divergence(&negative, &valid).is_err());
    }
}