use super::{validate_distribution, xlogy, InfoTheoryError, InfoTheoryResult};
use crate::error::NumRs2Error;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use std::f64::consts::E;
/// Logarithm base in which an entropy value is expressed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LogBase {
    /// Base-2 logarithm.
    Bits,
    /// Natural logarithm (base e).
    Nats,
    /// Base-10 logarithm.
    Hartleys,
    /// Logarithm with a caller-supplied base.
    Custom(f64),
}

impl LogBase {
    /// Rescales a quantity computed with the natural logarithm into this base.
    ///
    /// The conversion divides `ln_value` by the natural log of the base.
    /// A `Custom` base that is non-positive or exactly `1.0` has no usable
    /// logarithm, so `f64::NAN` is returned for it.
    fn convert(&self, ln_value: f64) -> f64 {
        match *self {
            LogBase::Nats => ln_value,
            LogBase::Bits => ln_value / 2_f64.ln(),
            LogBase::Hartleys => ln_value / 10_f64.ln(),
            // Invalid custom bases are signalled with NaN rather than a panic.
            LogBase::Custom(b) if b <= 0.0 || b == 1.0 => f64::NAN,
            LogBase::Custom(b) => ln_value / b.ln(),
        }
    }
}
/// Computes the Shannon entropy `H = -Σ p·log(p)` of a discrete distribution.
///
/// `probs` is first checked with `validate_distribution` and then passed
/// through `normalize_distribution`; the entropy is accumulated in nats on the
/// normalized values and finally converted to `base`.
///
/// # Errors
///
/// Returns `NumRs2Error::ValueError` carrying the message of whatever error
/// validation or normalization reports.
pub fn shannon_entropy(probs: &Array1<f64>, base: LogBase) -> Result<f64, NumRs2Error> {
    validate_distribution(&probs.view())
        .map_err(|err| NumRs2Error::ValueError(err.to_string()))?;
    let dist = super::normalize_distribution(probs)
        .map_err(|err| NumRs2Error::ValueError(err.to_string()))?;

    // Accumulate -Σ xlogy(p, p) in natural-log units, one term at a time.
    let nats = dist.iter().fold(0.0_f64, |acc, &p| acc - xlogy(p, p));
    Ok(base.convert(nats))
}
/// Computes the joint entropy `H(X, Y)` of a two-dimensional probability table.
///
/// Every cell of `joint_probs` is treated as one outcome: the table is
/// flattened into a one-dimensional array, validated, normalized, and its
/// entropy is accumulated in nats before being converted to `base`.
///
/// # Errors
///
/// * `NumRs2Error::InvalidInput` when `joint_probs` has no elements.
/// * `NumRs2Error::ValueError` carrying the message of whatever error
///   validation or normalization of the flattened table reports.
pub fn joint_entropy(joint_probs: &Array2<f64>, base: LogBase) -> Result<f64, NumRs2Error> {
    if joint_probs.is_empty() {
        return Err(NumRs2Error::InvalidInput(
            "Joint probability array is empty".to_string(),
        ));
    }

    // Flatten in the array's logical iteration order; entropy ignores layout.
    let cells: Vec<f64> = joint_probs.iter().cloned().collect();
    let cells = Array1::from_vec(cells);

    validate_distribution(&cells.view())
        .map_err(|err| NumRs2Error::ValueError(err.to_string()))?;
    let dist = super::normalize_distribution(&cells)
        .map_err(|err| NumRs2Error::ValueError(err.to_string()))?;

    let nats = dist.iter().fold(0.0_f64, |acc, &p| acc - xlogy(p, p));
    Ok(base.convert(nats))
}
/// Computes a conditional entropy from a joint table via the chain rule
/// `H(X | Y) = H(X, Y) - H(Y)`.
///
/// The marginal used for `H(Y)` is obtained by summing `joint_probs` along
/// axis 0, i.e. the conditioning variable is the one indexed by axis 1.
///
/// # Errors
///
/// Propagates any error from [`joint_entropy`] (evaluated first) or from
/// [`shannon_entropy`] on the marginal.
pub fn conditional_entropy(joint_probs: &Array2<f64>, base: LogBase) -> Result<f64, NumRs2Error> {
    let joint_term = joint_entropy(joint_probs, base)?;
    let marginal: Array1<f64> = joint_probs.sum_axis(Axis(0));
    shannon_entropy(&marginal, base).map(|marginal_term| joint_term - marginal_term)
}
/// Computes the cross-entropy `H(p, q) = -Σ p_i·log(q_i)` between two discrete
/// distributions of equal length.
///
/// Both inputs are validated and normalized independently before the sum is
/// accumulated in nats and converted to `base`.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` when `p` and `q` differ in length.
/// * `NumRs2Error::ValueError` carrying the message of whatever error
///   validation or normalization of either input reports.
pub fn cross_entropy(p: &Array1<f64>, q: &Array1<f64>, base: LogBase) -> Result<f64, NumRs2Error> {
    let (len_p, len_q) = (p.len(), q.len());
    if len_p != len_q {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Probability arrays must have same length: {} vs {}",
            len_p, len_q
        )));
    }

    // Validate both inputs before normalizing either, so that an invalid `q`
    // is reported ahead of any normalization failure on `p`.
    validate_distribution(&p.view()).map_err(|err| NumRs2Error::ValueError(err.to_string()))?;
    validate_distribution(&q.view()).map_err(|err| NumRs2Error::ValueError(err.to_string()))?;
    let p_dist = super::normalize_distribution(p)
        .map_err(|err| NumRs2Error::ValueError(err.to_string()))?;
    let q_dist = super::normalize_distribution(q)
        .map_err(|err| NumRs2Error::ValueError(err.to_string()))?;

    let nats = p_dist
        .iter()
        .zip(q_dist.iter())
        .fold(0.0_f64, |acc, (&pi, &qi)| acc - xlogy(pi, qi));
    Ok(base.convert(nats))
}
/// Computes the Rényi entropy of order `alpha`,
/// `H_α = log(Σ p_i^α) / (1 - α)`, of a discrete distribution.
///
/// Special orders are handled explicitly:
///
/// * `alpha` within `1e-10` of `1` delegates to [`shannon_entropy`].
/// * `alpha == 0` returns the log of the number of strictly positive entries.
/// * `alpha == +∞` returns `-log(max p_i)`.
///
/// For every other order, the sum `Σ p_i^α` is evaluated in log space with the
/// largest probability factored out, so that large orders do not underflow to
/// zero and fail spuriously.
///
/// # Errors
///
/// * `NumRs2Error::ValueError` when `alpha` is negative or NaN, or carrying
///   the message of whatever error validation or normalization reports.
/// * `NumRs2Error::NumericalError` when no probability is positive or the
///   log-sum is not finite.
pub fn renyi_entropy(probs: &Array1<f64>, alpha: f64, base: LogBase) -> Result<f64, NumRs2Error> {
    // NaN compares false against everything, so `alpha < 0.0` alone would let
    // it through and the function could end up returning `Ok(NaN)`.
    if alpha.is_nan() || alpha < 0.0 {
        return Err(NumRs2Error::ValueError(format!(
            "Alpha must be non-negative, got {}",
            alpha
        )));
    }
    // The general formula divides by (1 - alpha); the limit alpha -> 1 is Shannon.
    if (alpha - 1.0).abs() < 1e-10 {
        return shannon_entropy(probs, base);
    }

    validate_distribution(&probs.view()).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;
    let normalized =
        super::normalize_distribution(probs).map_err(|e| NumRs2Error::ValueError(e.to_string()))?;

    if alpha == 0.0 {
        // Order 0: log of the support size.
        let support = normalized.iter().filter(|&&p| p > 0.0).count() as f64;
        return Ok(base.convert(support.ln()));
    }

    let max_p = normalized.iter().cloned().fold(0.0, f64::max);
    if max_p <= 0.0 {
        return Err(NumRs2Error::NumericalError(
            "All probabilities are zero".to_string(),
        ));
    }
    if alpha.is_infinite() {
        // Order +inf: determined solely by the most likely outcome.
        return Ok(base.convert(-max_p.ln()));
    }

    // ln Σ p^α = α·ln(p_max) + ln Σ (p / p_max)^α.
    // Every scaled term lies in [0, 1] and the largest equals 1, so the scaled
    // sum cannot underflow to zero no matter how large `alpha` is.
    let scaled_sum: f64 = normalized.iter().map(|&p| (p / max_p).powf(alpha)).sum();
    let ln_sum_p_alpha = alpha * max_p.ln() + scaled_sum.ln();
    if !ln_sum_p_alpha.is_finite() {
        return Err(NumRs2Error::NumericalError(format!(
            "Invalid sum of p^α: {}",
            ln_sum_p_alpha.exp()
        )));
    }

    Ok(base.convert(ln_sum_p_alpha / (1.0 - alpha)))
}
/// Approximates the differential entropy `h = -∫ f(x)·log f(x) dx` with the
/// Riemann sum `-Σ f_i·log(f_i)·dx` over sampled density values.
///
/// `pdf_values` holds the density samples and `dx` is the (single) spacing
/// applied to every sample. The density is used as given; it is not
/// renormalized.
///
/// # Errors
///
/// * `NumRs2Error::InvalidInput` when `pdf_values` is empty.
/// * `NumRs2Error::ValueError` when `dx` is not a positive finite number, or
///   when any density sample is negative or non-finite.
pub fn differential_entropy(
    pdf_values: &Array1<f64>,
    dx: f64,
    base: LogBase,
) -> Result<f64, NumRs2Error> {
    if pdf_values.is_empty() {
        return Err(NumRs2Error::InvalidInput(
            "PDF values array is empty".to_string(),
        ));
    }
    if dx <= 0.0 {
        return Err(NumRs2Error::ValueError(format!(
            "dx must be positive, got {}",
            dx
        )));
    }
    // `dx <= 0.0` is false for NaN and for +inf; both would otherwise flow
    // into the sum and surface as an `Ok` holding NaN or an infinity.
    if !dx.is_finite() {
        return Err(NumRs2Error::ValueError(format!(
            "dx must be finite, got {}",
            dx
        )));
    }

    // Validate every sample before accumulating anything.
    for &f in pdf_values.iter() {
        if f < 0.0 {
            return Err(NumRs2Error::ValueError(format!(
                "PDF value cannot be negative: {}",
                f
            )));
        }
        if !f.is_finite() {
            return Err(NumRs2Error::ValueError(format!(
                "PDF value must be finite: {}",
                f
            )));
        }
    }

    let entropy = pdf_values
        .iter()
        .fold(0.0_f64, |acc, &f| acc - xlogy(f, f) * dx);
    Ok(base.convert(entropy))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparisons expected to hold up to rounding.
    const EPSILON: f64 = 1e-10;

    /// Returns `true` when `actual` differs from `expected` by strictly less than `tol`.
    fn is_close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    /// Builds a one-dimensional array from a slice of values.
    fn dist(values: &[f64]) -> Array1<f64> {
        Array1::from_vec(values.to_vec())
    }

    /// Builds a 2x2 table from four row-major cell values.
    fn joint_2x2(cells: [f64; 4]) -> Array2<f64> {
        Array2::from_shape_vec((2, 2), cells.to_vec()).expect("array creation failed")
    }

    #[test]
    fn test_log_base_conversion() {
        let ln_val = 2_f64.ln();
        // (base, expected value of log_base(2))
        let cases = [
            (LogBase::Bits, 1.0),
            (LogBase::Nats, 2_f64.ln()),
            (LogBase::Hartleys, 2_f64.log10()),
            (LogBase::Custom(2.0), 1.0),
        ];
        for &(base, expected) in cases.iter() {
            assert!(is_close(base.convert(ln_val), expected, EPSILON));
        }
    }

    #[test]
    fn test_shannon_entropy_uniform() {
        // Four equally likely outcomes carry log2(4) = 2 bits.
        let h = shannon_entropy(&dist(&[0.25, 0.25, 0.25, 0.25]), LogBase::Bits)
            .expect("entropy failed");
        assert!(is_close(h, 2.0, EPSILON));
    }

    #[test]
    fn test_shannon_entropy_deterministic() {
        // A certain outcome carries no information.
        let h =
            shannon_entropy(&dist(&[1.0, 0.0, 0.0, 0.0]), LogBase::Bits).expect("entropy failed");
        assert!(is_close(h, 0.0, EPSILON));
    }

    #[test]
    fn test_shannon_entropy_binary() {
        let p: f64 = 0.3;
        let q = 1.0 - p;
        let h = shannon_entropy(&dist(&[p, q]), LogBase::Bits).expect("entropy failed");
        let expected = -p * p.log2() - q * q.log2();
        assert!(is_close(h, expected, EPSILON));
    }

    #[test]
    fn test_shannon_entropy_normalization() {
        // Equal unnormalized weights behave like a uniform distribution.
        let h =
            shannon_entropy(&dist(&[1.0, 1.0, 1.0, 1.0]), LogBase::Bits).expect("entropy failed");
        assert!(is_close(h, 2.0, EPSILON));
    }

    #[test]
    fn test_joint_entropy() {
        let table = joint_2x2([0.25, 0.25, 0.25, 0.25]);
        let h = joint_entropy(&table, LogBase::Bits).expect("entropy failed");
        assert!(is_close(h, 2.0, EPSILON));
    }

    #[test]
    fn test_conditional_entropy_independent() {
        // Independent fair bits: conditioning removes nothing, so 1 bit remains.
        let table = joint_2x2([0.25, 0.25, 0.25, 0.25]);
        let h_cond = conditional_entropy(&table, LogBase::Bits).expect("entropy failed");
        assert!(is_close(h_cond, 1.0, EPSILON));
    }

    #[test]
    fn test_conditional_entropy_dependent() {
        // Perfectly correlated variables: nothing is left once one is known.
        let table = joint_2x2([0.5, 0.0, 0.0, 0.5]);
        let h_cond = conditional_entropy(&table, LogBase::Bits).expect("entropy failed");
        assert!(is_close(h_cond, 0.0, EPSILON));
    }

    #[test]
    fn test_cross_entropy() {
        let p = dist(&[0.5, 0.5]);
        let h_shannon = shannon_entropy(&p, LogBase::Bits).expect("shannon entropy failed");

        // Cross-entropy against an identical distribution equals the entropy.
        let same = dist(&[0.5, 0.5]);
        let h_self = cross_entropy(&p, &same, LogBase::Bits).expect("cross entropy failed");
        assert!(is_close(h_self, h_shannon, EPSILON));

        // Against a different distribution it is strictly larger.
        let other = dist(&[0.4, 0.6]);
        let h_cross = cross_entropy(&p, &other, LogBase::Bits).expect("cross entropy failed");
        assert!(h_cross > h_shannon);
    }

    #[test]
    fn test_renyi_entropy_alpha_2() {
        let probs = dist(&[0.5, 0.3, 0.2]);
        let h2 = renyi_entropy(&probs, 2.0, LogBase::Bits).expect("renyi entropy failed");
        let sum_p2: f64 = probs.iter().map(|&p| p * p).sum();
        assert!(is_close(h2, -sum_p2.log2(), EPSILON));
    }

    #[test]
    fn test_renyi_entropy_alpha_0() {
        let probs = dist(&[0.5, 0.3, 0.2]);
        let h0 = renyi_entropy(&probs, 0.0, LogBase::Bits).expect("renyi entropy failed");
        assert!(is_close(h0, (3.0_f64).log2(), EPSILON));
    }

    #[test]
    fn test_renyi_entropy_alpha_infinity() {
        let probs = dist(&[0.5, 0.3, 0.2]);
        let h_inf =
            renyi_entropy(&probs, f64::INFINITY, LogBase::Bits).expect("renyi entropy failed");
        assert!(is_close(h_inf, -(0.5_f64).log2(), EPSILON));
    }

    #[test]
    fn test_renyi_entropy_limit_to_shannon() {
        let probs = dist(&[0.5, 0.3, 0.2]);
        let h_shannon = shannon_entropy(&probs, LogBase::Bits).expect("shannon entropy failed");
        // Orders just below and just above 1 should both sit near Shannon entropy.
        for &alpha in [0.9999, 1.0001].iter() {
            let h_renyi =
                renyi_entropy(&probs, alpha, LogBase::Bits).expect("renyi entropy failed");
            assert!(is_close(h_renyi, h_shannon, 0.01));
        }
    }

    #[test]
    fn test_differential_entropy_uniform() {
        // Uniform density on [0, 1): differential entropy is ln(1) = 0.
        let n = 1000;
        let dx = 1.0 / (n as f64);
        let pdf = Array1::from_elem(n, 1.0);
        let h = differential_entropy(&pdf, dx, LogBase::Nats).expect("differential entropy failed");
        assert!(is_close(h, 0.0, 0.01));
    }

    #[test]
    fn test_differential_entropy_gaussian() {
        let n = 1000;
        let sigma: f64 = 1.0;
        let norm = 1.0 / (sigma * (2.0 * std::f64::consts::PI).sqrt());

        // Sample the density on n evenly spaced points starting at -5 with step 10/n.
        let samples: Vec<f64> = (0..n)
            .map(|i| {
                let x = -5.0 + 10.0 * (i as f64) / (n as f64);
                let z = x / sigma;
                norm * (-0.5 * z * z).exp()
            })
            .collect();
        let pdf: Array1<f64> = Array1::from_vec(samples);
        let dx = 10.0 / (n as f64);

        let h = differential_entropy(&pdf, dx, LogBase::Nats).expect("differential entropy failed");
        // Closed form for a unit-variance Gaussian: 0.5 * ln(2*pi*e).
        let expected = 0.5 * (2.0 * std::f64::consts::PI * E).ln();
        assert!(is_close(h, expected, 0.1));
    }

    #[test]
    fn test_entropy_errors() {
        // Empty distribution.
        let empty: Array1<f64> = Array1::from_vec(vec![]);
        assert!(shannon_entropy(&empty, LogBase::Bits).is_err());

        // Negative probability.
        assert!(shannon_entropy(&dist(&[0.5, -0.1, 0.6]), LogBase::Bits).is_err());

        // Mismatched lengths.
        let p = dist(&[0.5, 0.5]);
        let q = dist(&[0.3, 0.3, 0.4]);
        assert!(cross_entropy(&p, &q, LogBase::Bits).is_err());

        // Negative Renyi order.
        assert!(renyi_entropy(&dist(&[0.5, 0.5]), -1.0, LogBase::Bits).is_err());

        // Negative sample spacing.
        assert!(differential_entropy(&dist(&[1.0, 1.0]), -0.1, LogBase::Bits).is_err());
    }
}