pub mod coding;
pub mod divergence;
pub mod entropy;
pub mod mutual_information;
pub use coding::{
aic, bic, binary_symmetric_channel_capacity, entropy_rate, mdl, rate_distortion_binary,
};
pub use divergence::{
bhattacharyya_coefficient, bhattacharyya_distance, chi_squared_divergence, hellinger_distance,
jensen_shannon_divergence, kl_divergence, total_variation_distance,
};
pub use entropy::{
conditional_entropy, cross_entropy, differential_entropy, joint_entropy, renyi_entropy,
shannon_entropy, LogBase,
};
pub use mutual_information::{
adjusted_mutual_information, conditional_mutual_information, mutual_information,
normalized_mutual_information, pointwise_mutual_information, variation_of_information,
NormalizationType,
};
use crate::error::NumRs2Error;
use std::fmt;
use thiserror::Error;
#[derive(Error, Debug, Clone)]
pub enum InfoTheoryError {
#[error("Invalid probability distribution: {0}")]
InvalidDistribution(String),
#[error("Empty input: {0}")]
EmptyInput(String),
#[error("Invalid parameter: {0}")]
InvalidParameter(String),
#[error("Numerical error: {0}")]
NumericalError(String),
#[error("Dimension mismatch: {0}")]
DimensionMismatch(String),
#[error("Computation error: {0}")]
ComputationError(String),
}
/// Maps module-local errors onto the crate-wide [`NumRs2Error`], keeping the
/// original message text.
///
/// Both malformed distributions and bad parameters surface as
/// `NumRs2Error::ValueError`; empty input becomes `InvalidInput`; the rest map
/// to the identically named crate variant.
impl From<InfoTheoryError> for NumRs2Error {
    fn from(err: InfoTheoryError) -> Self {
        match err {
            InfoTheoryError::InvalidDistribution(message)
            | InfoTheoryError::InvalidParameter(message) => Self::ValueError(message),
            InfoTheoryError::EmptyInput(message) => Self::InvalidInput(message),
            InfoTheoryError::NumericalError(message) => Self::NumericalError(message),
            InfoTheoryError::DimensionMismatch(message) => Self::DimensionMismatch(message),
            InfoTheoryError::ComputationError(message) => Self::ComputationError(message),
        }
    }
}
/// Shorthand for a `Result` whose error type is [`InfoTheoryError`].
pub type InfoTheoryResult<T> = std::result::Result<T, InfoTheoryError>;
/// Checks that `probs` is non-empty and that every entry is finite and
/// non-negative.
///
/// The entries are *not* required to sum to one; see
/// [`normalize_distribution`] for rescaling.
///
/// # Errors
///
/// - [`InfoTheoryError::EmptyInput`] if `probs` has no elements.
/// - [`InfoTheoryError::InvalidDistribution`] for the first offending entry:
///   a negative value (including `-inf`) is reported as negative, and any
///   other non-finite value (`+inf`, NaN) as non-finite.
pub(crate) fn validate_distribution(
    probs: &scirs2_core::ndarray::ArrayView1<f64>,
) -> InfoTheoryResult<()> {
    if probs.is_empty() {
        return Err(InfoTheoryError::EmptyInput(
            "Probability array is empty".to_string(),
        ));
    }

    // Stops at the first bad entry; the negativity check runs first so that
    // `-inf` is reported as a negative probability.
    probs.iter().try_for_each(|&value| {
        if value < 0.0 {
            Err(InfoTheoryError::InvalidDistribution(format!(
                "Negative probability: {}",
                value
            )))
        } else if !value.is_finite() {
            Err(InfoTheoryError::InvalidDistribution(format!(
                "Non-finite probability: {}",
                value
            )))
        } else {
            Ok(())
        }
    })
}
/// Rescales `probs` so that its entries sum to one.
///
/// If the sum is already within `1e-10` of one, an unchanged copy of the
/// input is returned, so no extra rounding is introduced. Otherwise every
/// entry is divided by the sum.
///
/// Individual entries are not checked. Negative values pass through as long
/// as the total is positive, so call [`validate_distribution`] first when
/// entries must be finite and non-negative.
///
/// # Errors
///
/// Returns [`InfoTheoryError::InvalidDistribution`] if the sum is not
/// strictly positive and finite. This also covers empty input (the sum is
/// `0.0`) and any NaN or infinite entry.
pub(crate) fn normalize_distribution(
    probs: &scirs2_core::ndarray::Array1<f64>,
) -> InfoTheoryResult<scirs2_core::ndarray::Array1<f64>> {
    let sum: f64 = probs.iter().sum();
    // `!is_finite()` also rejects a NaN sum, which `<= 0.0` alone would accept.
    if sum <= 0.0 || !sum.is_finite() {
        return Err(InfoTheoryError::InvalidDistribution(format!(
            "Invalid distribution sum: {}",
            sum
        )));
    }
    if (sum - 1.0).abs() < 1e-10 {
        Ok(probs.clone())
    } else {
        Ok(probs / sum)
    }
}
/// Computes `x * ln(y)`, defining the result as `0.0` when `x == 0.0`.
///
/// This is the convention needed for entropy-style sums such as
/// `sum p * ln(q)`: a term with zero weight contributes nothing even when
/// its logarithm diverges (`0 * ln(0) := 0`).
///
/// The semantics match `scipy.special.xlogy`:
/// - If `x == 0.0` and `y` is not NaN, the result is `0.0`.
/// - Otherwise the result is `x * y.ln()`, which means:
///   - `y == 0.0` gives `-inf` for `x > 0` and `+inf` for `x < 0`;
///   - `y < 0.0` (logarithm undefined) gives NaN;
///   - NaN in either argument propagates as NaN instead of being masked.
pub(crate) fn xlogy(x: f64, y: f64) -> f64 {
    if x == 0.0 && !y.is_nan() {
        0.0
    } else {
        // `f64::ln` already yields -inf at 0 and NaN for negative/NaN input,
        // and the multiplication carries the correct sign for the infinity.
        x * y.ln()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_validate_distribution() {
        let valid = Array1::from_vec(vec![0.5, 0.3, 0.2]);
        assert!(validate_distribution(&valid.view()).is_ok());

        let negative = Array1::from_vec(vec![0.5, -0.1, 0.6]);
        assert!(matches!(
            validate_distribution(&negative.view()),
            Err(InfoTheoryError::InvalidDistribution(_))
        ));

        let infinite = Array1::from_vec(vec![0.5, f64::INFINITY, 0.2]);
        assert!(matches!(
            validate_distribution(&infinite.view()),
            Err(InfoTheoryError::InvalidDistribution(_))
        ));

        // NaN is neither negative nor finite; it must still be rejected.
        let nan = Array1::from_vec(vec![0.5, f64::NAN, 0.2]);
        assert!(matches!(
            validate_distribution(&nan.view()),
            Err(InfoTheoryError::InvalidDistribution(_))
        ));

        let empty: Array1<f64> = Array1::from_vec(vec![]);
        assert!(matches!(
            validate_distribution(&empty.view()),
            Err(InfoTheoryError::EmptyInput(_))
        ));
    }

    #[test]
    fn test_normalize_distribution() {
        let unnormalized = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let normalized = normalize_distribution(&unnormalized).expect("normalization failed");
        let sum: f64 = normalized.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
        assert!((normalized[0] - 1.0 / 6.0).abs() < 1e-10);
        assert!((normalized[1] - 2.0 / 6.0).abs() < 1e-10);
        assert!((normalized[2] - 3.0 / 6.0).abs() < 1e-10);

        // An input that already sums to one is returned unchanged.
        let already_normalized = Array1::from_vec(vec![0.25, 0.25, 0.5]);
        let result = normalize_distribution(&already_normalized).expect("normalization failed");
        let sum2: f64 = result.iter().sum();
        assert!((sum2 - 1.0).abs() < 1e-10);
        assert_eq!(result, already_normalized);
    }

    #[test]
    fn test_normalize_distribution_rejects_bad_sums() {
        let bad_inputs: Vec<Vec<f64>> = vec![
            vec![],
            vec![0.0, 0.0],
            vec![-1.0, 0.5],
            vec![0.5, f64::NAN],
            vec![0.5, f64::INFINITY],
        ];
        for input in bad_inputs {
            let arr = Array1::from_vec(input);
            assert!(matches!(
                normalize_distribution(&arr),
                Err(InfoTheoryError::InvalidDistribution(_))
            ));
        }
    }

    #[test]
    fn test_xlogy() {
        assert_eq!(xlogy(0.0, 0.0), 0.0);
        assert_eq!(xlogy(0.0, 1.0), 0.0);
        let result = xlogy(2.0, std::f64::consts::E);
        assert!((result - 2.0).abs() < 1e-10);
        let result2 = xlogy(1.0, 0.5);
        assert!((result2 - 0.5f64.ln()).abs() < 1e-12);
        // A positive weight on a zero-probability outcome diverges to -inf.
        assert_eq!(xlogy(1.0, 0.0), f64::NEG_INFINITY);
    }
}