use crate::error::{SpecialError, SpecialResult};
use crate::validation::{check_finite, check_non_negative};
use scirs2_core::ndarray::{Array1, ArrayView1, ArrayViewMut1};
use scirs2_core::numeric::{Float, FromPrimitive, Zero};
use std::fmt::{Debug, Display};
/// Elementwise entropy term `-x * ln(x)` (natural logarithm).
///
/// Conventions: `entr(0) = 0` (the limit of `-x ln x` as `x -> 0+`) and any negative input
/// maps to `-inf`. Every other input, including NaN and `+inf`, goes straight through the
/// formula.
#[allow(dead_code)]
pub fn entr<T>(x: T) -> T
where
    T: Float + FromPrimitive + Zero,
{
    match x {
        v if v.is_zero() => T::zero(),
        v if v < T::zero() => T::neg_infinity(),
        v => -v * v.ln(),
    }
}
/// Elementwise relative entropy `x * ln(x / y)`, following the conventions of
/// `scipy.special.rel_entr`.
///
/// Returns:
/// - NaN if either input is NaN;
/// - `x * ln(x / y)` when `x > 0` and `y > 0`;
/// - `0` when `x == 0` and `y >= 0` (the `0 * ln(0) = 0` convention);
/// - `+inf` otherwise (e.g. `x > 0, y == 0`, or any negative input).
///
/// Returning `+inf` for negative arguments, rather than the NaN that `ln` of a negative
/// ratio would produce, keeps sums such as a KL divergence well defined.
#[allow(dead_code)]
pub fn rel_entr<T>(x: T, y: T) -> T
where
    T: Float + FromPrimitive + Zero,
{
    if x.is_nan() || y.is_nan() {
        T::nan()
    } else if x > T::zero() && y > T::zero() {
        x * (x / y).ln()
    } else if x.is_zero() && y >= T::zero() {
        T::zero()
    } else {
        T::infinity()
    }
}
/// Elementwise Kullback-Leibler term `x * ln(x / y) - x + y`, following the conventions of
/// `scipy.special.kl_div`.
///
/// Returns:
/// - NaN if either input is NaN;
/// - `x * ln(x / y) - x + y` when `x > 0` and `y > 0`;
/// - `y` when `x == 0` and `y >= 0`;
/// - `+inf` otherwise (e.g. `x > 0, y == 0`, or any negative input).
///
/// Negative arguments are outside the domain. They yield `+inf` rather than NaN, or a
/// negative `y`, so the result is never negative.
#[allow(dead_code)]
pub fn kl_div<T>(x: T, y: T) -> T
where
    T: Float + FromPrimitive + Zero,
{
    if x.is_nan() || y.is_nan() {
        T::nan()
    } else if x > T::zero() && y > T::zero() {
        x * (x / y).ln() - x + y
    } else if x.is_zero() && y >= T::zero() {
        y
    } else {
        T::infinity()
    }
}
/// Huber loss of a residual `r` with threshold `delta`.
///
/// Returns `r² / 2` when `|r| <= delta`, and `delta * |r| - delta² / 2` otherwise.
/// The function is quadratic near zero and linear in the tails.
///
/// # Errors
/// - The error produced by `check_finite` if it rejects `delta` or `r`.
/// - `SpecialError::DomainError` if `delta <= 0`.
#[allow(dead_code)]
pub fn huber<T>(delta: T, r: T) -> SpecialResult<T>
where
    T: Float + FromPrimitive + Display,
{
    check_finite(delta, "delta value")?;
    check_finite(r, "r value")?;
    if delta <= T::zero() {
        return Err(SpecialError::DomainError(
            "huber: delta must be positive".to_string(),
        ));
    }
    // `1 + 1` is exact and infallible. The previous `T::from_f64(2.0).expect(..)` introduced
    // a panic path into a function that already reports failures through `Result`.
    let two = T::one() + T::one();
    let abs_r = r.abs();
    if abs_r <= delta {
        Ok(r * r / two)
    } else {
        Ok(delta * abs_r - delta * delta / two)
    }
}
/// Pseudo-Huber loss `delta² * (sqrt(1 + (r / delta)²) - 1)`, a smooth approximation of
/// [`huber`].
///
/// The naive formula loses all precision for `|r| << delta`, where `sqrt(1 + t²) - 1`
/// cancels to 0. It also overflows `(r / delta)²` for very large finite residuals. Two
/// algebraically equivalent forms are used instead:
/// - `|r| <= delta`: `r² / (sqrt(1 + t²) + 1)` with `t = |r| / delta <= 1`;
/// - `|r| > delta`: `delta * |r| / (sqrt(1 + s²) + s)` with `s = delta / |r| < 1`.
///
/// # Errors
/// - The error produced by `check_finite` if it rejects `delta` or `r`.
/// - `SpecialError::DomainError` if `delta <= 0`.
#[allow(dead_code)]
pub fn pseudo_huber<T>(delta: T, r: T) -> SpecialResult<T>
where
    T: Float + FromPrimitive + Display,
{
    check_finite(delta, "delta value")?;
    check_finite(r, "r value")?;
    if delta <= T::zero() {
        return Err(SpecialError::DomainError(
            "pseudo_huber: delta must be positive".to_string(),
        ));
    }
    let abs_r = r.abs();
    let one = T::one();
    let value = if abs_r <= delta {
        // t <= 1, so hypot(1, t) cannot overflow; the denominator lies in [2, 1 + √2].
        // Dividing before multiplying avoids a spurious overflow of r².
        let t = abs_r / delta;
        abs_r * (abs_r / (one.hypot(t) + one))
    } else {
        // s < 1, so hypot(1, s) cannot overflow; the denominator lies in (1, 1 + √2).
        let s = delta / abs_r;
        delta * (abs_r / (one.hypot(s) + s))
    };
    Ok(value)
}
#[allow(dead_code)]
pub fn entr_array<T>(x: &ArrayView1<T>) -> Array1<T>
where
T: Float + FromPrimitive + Zero + Send + Sync,
{
x.mapv(entr)
}
/// Shannon entropy `-Σ p_i ln(p_i)` (in nats) of the entries of `p`.
///
/// Each entry is passed through `check_non_negative` before it is accumulated, and the
/// first rejected entry aborts the computation. `p` is not normalized; it is summed as
/// given.
///
/// # Errors
/// The error from `check_non_negative` for the first entry it rejects.
#[allow(dead_code)]
pub fn entropy<T>(p: &ArrayView1<T>) -> SpecialResult<T>
where
    T: Float + FromPrimitive + Zero + Display + Debug,
{
    p.iter().try_fold(T::zero(), |acc, &prob| -> SpecialResult<T> {
        check_non_negative(prob, "probability")?;
        Ok(acc + entr(prob))
    })
}
/// Kullback-Leibler divergence `Σ rel_entr(p_i, q_i)` of `p` from `q`.
///
/// Elements are paired by position. Per-element edge cases (zeros and so on) are
/// whatever [`rel_entr`] returns for them. No normalization is performed.
///
/// # Errors
/// `SpecialError::ValueError` if `p` and `q` have different lengths.
#[allow(dead_code)]
pub fn kl_divergence<T>(p: &ArrayView1<T>, q: &ArrayView1<T>) -> SpecialResult<T>
where
    T: Float + FromPrimitive + Zero + Display + Debug,
{
    if p.len() != q.len() {
        return Err(SpecialError::ValueError(
            "kl_divergence: arrays must have the same length".to_string(),
        ));
    }
    let divergence = p
        .iter()
        .zip(q.iter())
        .fold(T::zero(), |acc, (&pi, &qi)| acc + rel_entr(pi, qi));
    Ok(divergence)
}
/// Elementwise Huber loss of `predictions - targets`, written into `output`.
///
/// `output[i] = huber(delta, predictions[i] - targets[i])` for every index.
///
/// # Errors
/// - `SpecialError::ValueError` if the three arrays do not all have the same length.
/// - Any error [`huber`] reports for `delta`. This is checked once, before anything is
///   written, so an invalid `delta` is rejected even when the arrays are empty.
/// - Any error [`huber`] reports for a residual. In that case the entries before the
///   offending index have already been written to `output`.
#[allow(dead_code)]
pub fn huber_loss<T>(
    delta: T,
    predictions: &ArrayView1<T>,
    targets: &ArrayView1<T>,
    output: &mut ArrayViewMut1<T>,
) -> SpecialResult<()>
where
    T: Float + FromPrimitive + Display + Debug,
{
    if predictions.len() != targets.len() || predictions.len() != output.len() {
        return Err(SpecialError::ValueError(
            "huber_loss: all arrays must have the same length".to_string(),
        ));
    }
    // Validate `delta` up front with exactly the rules and messages `huber` uses. A zero
    // residual is always accepted, so only `delta` can make this call fail.
    huber(delta, T::zero())?;
    let pairs = predictions.iter().zip(targets.iter());
    for (slot, (&prediction, &target)) in output.iter_mut().zip(pairs) {
        *slot = huber(delta, prediction - target)?;
    }
    Ok(())
}
/// Binary entropy `-p ln(p) - (1 - p) ln(1 - p)` (in nats) of a Bernoulli(`p`) variable.
///
/// The degenerate cases `p == 0` and `p == 1` return exactly `0`.
///
/// # Errors
/// The error from `crate::validation::check_probability` if it rejects `p`.
#[allow(dead_code)]
pub fn binary_entropy<T>(p: T) -> SpecialResult<T>
where
    T: Float + FromPrimitive + Display,
{
    crate::validation::check_probability(p, "p")?;
    let degenerate = p.is_zero() || p == T::one();
    let h = if degenerate {
        T::zero()
    } else {
        let complement = T::one() - p;
        entr(p) + entr(complement)
    };
    Ok(h)
}
/// Cross-entropy `H(p, q) = -Σ p_i ln(q_i)` (in nats).
///
/// Only entries with `p_i > 0` contribute; entries with `p_i <= 0` or NaN are skipped. If a
/// contributing entry has `q_i == 0`, the result is `+inf` immediately. The inputs are not
/// normalized or otherwise validated.
///
/// # Errors
/// `SpecialError::ValueError` if `p` and `q` have different lengths.
#[allow(dead_code)]
pub fn cross_entropy<T>(p: &ArrayView1<T>, q: &ArrayView1<T>) -> SpecialResult<T>
where
    T: Float + FromPrimitive + Zero + Display + Debug,
{
    if p.len() != q.len() {
        return Err(SpecialError::ValueError(
            "cross_entropy: arrays must have the same length".to_string(),
        ));
    }
    let support = p
        .iter()
        .zip(q.iter())
        .filter(|&(&pi, _)| pi > T::zero());
    let mut total = T::zero();
    for (&pi, &qi) in support {
        if qi.is_zero() {
            return Ok(T::infinity());
        }
        total = total - pi * qi.ln();
    }
    Ok(total)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::arr1;

    #[test]
    fn test_entr() {
        assert_eq!(entr(0.0), 0.0);
        assert_relative_eq!(entr(0.5), 0.34657359027997264, epsilon = 1e-10);
        assert_relative_eq!(entr(1.0), 0.0, epsilon = 1e-10);
        assert!(entr(-1.0).is_infinite() && entr(-1.0) < 0.0);
    }

    #[test]
    fn test_entr_array() {
        let x = arr1(&[0.0, 0.5, 1.0]);
        let out = entr_array(&x.view());
        assert_eq!(out.len(), 3);
        assert_eq!(out[0], 0.0);
        assert_relative_eq!(out[1], 0.34657359027997264, epsilon = 1e-10);
        assert_relative_eq!(out[2], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_rel_entr() {
        assert_eq!(rel_entr(0.0, 1.0), 0.0);
        assert!(rel_entr(1.0, 0.0).is_infinite());
        assert_relative_eq!(rel_entr(0.5, 0.5), 0.0, epsilon = 1e-10);
        assert_relative_eq!(rel_entr(0.7, 0.3), 0.5931085022710425, epsilon = 1e-10);
    }

    #[test]
    fn test_kl_div() {
        assert_eq!(kl_div(0.0, 1.0), 1.0);
        assert!(kl_div(1.0, 0.0).is_infinite());
        assert_relative_eq!(kl_div(0.5, 0.5), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_huber() {
        let delta = 1.0;
        assert_relative_eq!(
            huber(delta, 0.5).expect("Operation failed"),
            0.125,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            huber(delta, -0.5).expect("Operation failed"),
            0.125,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            huber(delta, 2.0).expect("Operation failed"),
            1.5,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            huber(delta, -2.0).expect("Operation failed"),
            1.5,
            epsilon = 1e-10
        );
    }

    #[test]
    fn test_huber_rejects_non_positive_delta() {
        assert!(huber(0.0, 1.0).is_err());
        assert!(huber(-1.0, 1.0).is_err());
        assert!(pseudo_huber(0.0, 1.0).is_err());
        assert!(pseudo_huber(-1.0, 1.0).is_err());
    }

    #[test]
    fn test_huber_loss() {
        let predictions = arr1(&[0.5, 3.0, -2.0]);
        let targets = arr1(&[0.0, 1.0, 0.0]);
        let mut output = arr1(&[0.0, 0.0, 0.0]);
        huber_loss(
            1.0,
            &predictions.view(),
            &targets.view(),
            &mut output.view_mut(),
        )
        .expect("Operation failed");
        assert_relative_eq!(output[0], 0.125, epsilon = 1e-10);
        assert_relative_eq!(output[1], 1.5, epsilon = 1e-10);
        assert_relative_eq!(output[2], 1.5, epsilon = 1e-10);

        let mut too_short = arr1(&[0.0]);
        assert!(huber_loss(
            1.0,
            &predictions.view(),
            &targets.view(),
            &mut too_short.view_mut(),
        )
        .is_err());
    }

    #[test]
    fn test_pseudo_huber() {
        let delta = 1.0;
        assert_relative_eq!(
            pseudo_huber(delta, 0.0).expect("Operation failed"),
            0.0,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            pseudo_huber(delta, 1.0).expect("Operation failed"),
            0.41421356237309515,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            pseudo_huber(delta, -1.0).expect("Operation failed"),
            0.41421356237309515,
            epsilon = 1e-10
        );
    }

    #[test]
    fn test_entropy() {
        let uniform = arr1(&[0.25, 0.25, 0.25, 0.25]);
        let h = entropy(&uniform.view()).expect("Operation failed");
        assert_relative_eq!(h, 1.3862943611198906, epsilon = 1e-10);
        let certain = arr1(&[1.0, 0.0, 0.0, 0.0]);
        let h = entropy(&certain.view()).expect("Operation failed");
        assert_relative_eq!(h, 0.0, epsilon = 1e-10);
        let negative = arr1(&[0.5, -0.5]);
        assert!(entropy(&negative.view()).is_err());
    }

    #[test]
    fn test_kl_divergence() {
        let p = arr1(&[0.5, 0.5]);
        let q = arr1(&[0.9, 0.1]);
        let kl = kl_divergence(&p.view(), &q.view()).expect("Operation failed");
        // 0.5 ln(0.5 / 0.9) + 0.5 ln(0.5 / 0.1) = 0.5 ln(25 / 9) = ln(5 / 3)
        assert_relative_eq!(kl, (5.0f64 / 3.0).ln(), epsilon = 1e-12);
        let same = kl_divergence(&p.view(), &p.view()).expect("Operation failed");
        assert_relative_eq!(same, 0.0, epsilon = 1e-12);
        let short = arr1(&[1.0]);
        assert!(kl_divergence(&p.view(), &short.view()).is_err());
    }

    #[test]
    fn test_cross_entropy() {
        let p = arr1(&[0.5, 0.5]);
        let ce = cross_entropy(&p.view(), &p.view()).expect("Operation failed");
        assert_relative_eq!(ce, std::f64::consts::LN_2, epsilon = 1e-12);

        // A zero-probability entry in `p` contributes nothing, even where `q` is 0.
        let p = arr1(&[0.0, 1.0]);
        let q = arr1(&[0.0, 1.0]);
        let ce = cross_entropy(&p.view(), &q.view()).expect("Operation failed");
        assert_relative_eq!(ce, 0.0, epsilon = 1e-12);

        // Positive mass in `p` where `q` is 0 gives +inf.
        let p = arr1(&[1.0, 0.0]);
        let q = arr1(&[0.0, 1.0]);
        let ce = cross_entropy(&p.view(), &q.view()).expect("Operation failed");
        assert!(ce.is_infinite() && ce > 0.0);

        let short = arr1(&[1.0]);
        assert!(cross_entropy(&p.view(), &short.view()).is_err());
    }

    #[test]
    fn test_binary_entropy() {
        assert_eq!(binary_entropy(0.0).expect("Operation failed"), 0.0);
        assert_eq!(binary_entropy(1.0).expect("Operation failed"), 0.0);
        assert_relative_eq!(
            binary_entropy(0.5).expect("Operation failed"),
            std::f64::consts::LN_2,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            binary_entropy(0.25).expect("Operation failed"),
            binary_entropy(0.75).expect("Operation failed"),
            epsilon = 1e-12
        );
    }
}