use crate::error::{SpecialError, SpecialResult};
use std::f64::consts::LN_2;
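/// Checks that every entry of `probs` is finite and non-negative.
///
/// Normalization is deliberately not checked, so these routines accept
/// unnormalized weight vectors as well as proper distributions.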
fn validate_probs(probs: &[f64], name: &str) -> SpecialResult<()> {
for (i, &p) in probs.iter().enumerate() {
if !p.is_finite() {
return Err(SpecialError::ValueError(format!(
"{name}[{i}] = {p} is not finite"
)));
}
if p < 0.0 {
return Err(SpecialError::ValueError(format!(
"{name}[{i}] = {p} is negative"
)));
}
}
Ok(())
}
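/// Checks that two slices have the same length, for the pairwise
/// divergences below.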
fn same_length(a: &[f64], b: &[f64]) -> SpecialResult<()> {
if a.len() != b.len() {
Err(SpecialError::ValueError(format!(
"slice lengths differ: {} vs {}",
a.len(),
b.len()
)))
} else {
Ok(())
}
}
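/// Binary (Bernoulli) entropy in bits:
/// H(p) = -p·log2(p) - (1 - p)·log2(1 - p).
///
/// Inputs outside [0, 1] are clamped rather than rejected, and the
/// endpoints are special-cased so that H(0) = H(1) = 0 exactly.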
pub fn binary_entropy(p: f64) -> f64 {
let p = p.clamp(0.0, 1.0);
if p == 0.0 || p == 1.0 {
return 0.0;
}
-(p * p.log2() + (1.0 - p) * (1.0 - p).log2())
}
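/// Binary entropy in nats. Converting from bits only needs a scale by
/// ln 2, since log2(x) = ln(x) / ln(2).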
pub fn binary_entropy_nats(p: f64) -> f64 {
binary_entropy(p) * LN_2
}
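/// Shannon entropy H = -Σ p_i · log_base(p_i), with 0·log(0) taken as 0.
///
/// A usage sketch (marked `ignore` so no particular crate path is
/// assumed):
///
/// ```ignore
/// // A uniform distribution over four outcomes carries exactly 2 bits.
/// let h = entropy(&[0.25, 0.25, 0.25, 0.25], 2.0).unwrap();
/// assert!((h - 2.0).abs() < 1e-12);
/// ```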
pub fn entropy(probs: &[f64], base: f64) -> SpecialResult<f64> {
if base <= 0.0 || base == 1.0 {
return Err(SpecialError::ValueError(
"entropy: base must be > 0 and ≠ 1".to_string(),
));
}
validate_probs(probs, "probs")?;
let ln_base = base.ln();
let h = probs
.iter()
.filter(|&&p| p > 0.0)
.map(|&p| -p * p.ln() / ln_base)
.sum();
Ok(h)
}
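/// Kullback-Leibler divergence D(p ‖ q) = Σ p_i · ln(p_i / q_i), in nats.
///
/// Conventions: terms with p_i = 0 contribute 0, and any q_i = 0 with
/// p_i > 0 yields +∞ (q must cover the support of p).
///
/// A usage sketch (marked `ignore` so no particular crate path is
/// assumed):
///
/// ```ignore
/// let kl = kl_divergence(&[0.7, 0.3], &[0.5, 0.5]).unwrap();
/// assert!(kl >= 0.0); // Gibbs' inequality for normalized inputs
/// ```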
pub fn kl_divergence(p: &[f64], q: &[f64]) -> SpecialResult<f64> {
same_length(p, q)?;
validate_probs(p, "p")?;
validate_probs(q, "q")?;
let mut kl = 0.0;
for (&pi, &qi) in p.iter().zip(q.iter()) {
if pi == 0.0 {
continue;
}
if qi == 0.0 {
return Ok(f64::INFINITY);
}
kl += pi * (pi / qi).ln();
}
Ok(kl)
}
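/// Jensen-Shannon divergence in nats: the average of KL(p ‖ m) and
/// KL(q ‖ m) with m = (p + q) / 2. It is symmetric, finite for valid
/// inputs (m covers the support of both p and q), and bounded above by
/// ln 2 for normalized distributions.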
pub fn js_divergence(p: &[f64], q: &[f64]) -> SpecialResult<f64> {
same_length(p, q)?;
validate_probs(p, "p")?;
validate_probs(q, "q")?;
    let m: Vec<f64> = p
        .iter()
        .zip(q.iter())
        .map(|(&pi, &qi)| 0.5 * (pi + qi))
        .collect();
let kl_pm = kl_divergence(p, &m)?;
let kl_qm = kl_divergence(q, &m)?;
Ok(0.5 * (kl_pm + kl_qm))
}
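/// Rényi entropy of order `alpha`, in nats:
/// H_α = ln(Σ p_i^α) / (1 - α) for α ≥ 0, α ≠ 1.
///
/// Special cases: α → 1 recovers Shannon entropy, α = 0 gives the
/// Hartley entropy ln|support|, and α = ∞ gives the min-entropy
/// -ln(max p_i). Very large finite α can underflow Σ p_i^α to zero;
/// this is reported as a domain error rather than silently returning ∞.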
pub fn renyi_entropy(probs: &[f64], alpha: f64) -> SpecialResult<f64> {
if alpha < 0.0 {
return Err(SpecialError::DomainError(
"renyi_entropy: alpha must be ≥ 0".to_string(),
));
}
validate_probs(probs, "probs")?;
if (alpha - 1.0).abs() < 1e-12 {
return Ok(probs
.iter()
.filter(|&&p| p > 0.0)
.map(|&p| -p * p.ln())
.sum());
}
if alpha == 0.0 {
let support = probs.iter().filter(|&&p| p > 0.0).count();
return Ok((support as f64).ln());
}
if alpha == f64::INFINITY {
let max_p = probs.iter().cloned().fold(0.0f64, f64::max);
if max_p == 0.0 {
return Err(SpecialError::DomainError(
"renyi_entropy: all probabilities are zero".to_string(),
));
}
return Ok(-max_p.ln());
}
let sum_pow: f64 = probs.iter().map(|&p| p.powf(alpha)).sum();
if sum_pow == 0.0 {
return Err(SpecialError::DomainError(
"renyi_entropy: Σ p_i^α = 0".to_string(),
));
}
Ok(sum_pow.ln() / (1.0 - alpha))
}
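/// Tsallis entropy with entropic index `q`:
/// S_q = (1 - Σ p_i^q) / (q - 1), recovering Shannon entropy (in nats)
/// as q → 1.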
pub fn tsallis_entropy(probs: &[f64], q: f64) -> SpecialResult<f64> {
if q < 0.0 {
return Err(SpecialError::DomainError(
"tsallis_entropy: q must be ≥ 0".to_string(),
));
}
validate_probs(probs, "probs")?;
if (q - 1.0).abs() < 1e-12 {
return Ok(probs
.iter()
.filter(|&&p| p > 0.0)
.map(|&p| -p * p.ln())
.sum());
}
    // Zero entries are excluded so that q = 0 counts only the support:
    // Rust evaluates 0f64.powf(0.0) as 1.0, which would otherwise inflate
    // the sum by the number of zero entries.
    let sum_pow: f64 = probs
        .iter()
        .filter(|&&p| p > 0.0)
        .map(|&p| p.powf(q))
        .sum();
    Ok((1.0 - sum_pow) / (q - 1.0))
}
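/// Logistic sigmoid σ(x) = 1 / (1 + e^(-x)), split by sign so the
/// exponential argument is never positive; the two branches are
/// algebraically identical but neither can overflow.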
pub fn sigmoid(x: f64) -> f64 {
if x >= 0.0 {
let e = (-x).exp();
1.0 / (1.0 + e)
} else {
let e = x.exp();
e / (1.0 + e)
}
}
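/// Derivative of the sigmoid, via the identity σ'(x) = σ(x)·(1 - σ(x)).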
pub fn sigmoid_derivative(x: f64) -> f64 {
let s = sigmoid(x);
s * (1.0 - s)
}
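/// Numerically stable ln(Σ e^(x_i)) using the max-shift trick:
/// lse(x) = max + ln(Σ e^(x_i - max)), so every exponent is ≤ 0.
/// An empty slice yields -∞, the logarithm of an empty sum.
///
/// A quick sketch of why the shift matters (marked `ignore` so no
/// particular crate path is assumed):
///
/// ```ignore
/// // exp(1000.0) overflows f64, but the shifted form has no trouble.
/// let v = log_sum_exp(&[1000.0, 1000.0]);
/// assert!((v - (1000.0 + (2.0f64).ln())).abs() < 1e-9);
/// ```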
pub fn log_sum_exp(xs: &[f64]) -> f64 {
if xs.is_empty() {
return f64::NEG_INFINITY;
}
let max = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
if max.is_infinite() {
return max;
}
let sum: f64 = xs.iter().map(|&x| (x - max).exp()).sum();
max + sum.ln()
}
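/// Softplus ln(1 + e^x), with asymptotic shortcuts: for x > 20 the
/// result equals x to within f64 precision, and for x < -20 it equals
/// e^x.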
pub fn softplus(x: f64) -> f64 {
    if x > 20.0 {
        // ln(1 + e^x) = x + ln(1 + e^(-x)) ≈ x once e^(-x) drops below
        // f64 resolution.
        x
    } else if x < -20.0 {
        // ln(1 + e^x) ≈ e^x for large negative x.
        x.exp()
    } else {
        // ln_1p(e^x) is more accurate than (1.0 + x.exp()).ln() when
        // e^x is small.
        x.exp().ln_1p()
    }
}
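/// Softmax with the usual max subtraction for numerical stability; the
/// output sums to 1 for any finite input.
///
/// A usage sketch (marked `ignore` so no particular crate path is
/// assumed):
///
/// ```ignore
/// let s = softmax(&[1.0, 2.0, 3.0]);
/// assert!((s.iter().sum::<f64>() - 1.0).abs() < 1e-12);
/// ```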
pub fn softmax(xs: &[f64]) -> Vec<f64> {
if xs.is_empty() {
return Vec::new();
}
let max = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let exps: Vec<f64> = xs.iter().map(|&x| (x - max).exp()).collect();
let sum: f64 = exps.iter().sum();
exps.into_iter().map(|e| e / sum).collect()
}
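/// Log-softmax computed directly as x_i - logsumexp(x), which avoids the
/// precision loss of taking softmax(x).ln() element-wise.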
pub fn log_softmax(xs: &[f64]) -> Vec<f64> {
if xs.is_empty() {
return Vec::new();
}
let lse = log_sum_exp(xs);
xs.iter().map(|&x| x - lse).collect()
}
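/// Gumbel-softmax (concrete) relaxation: each logit is perturbed with an
/// independent Gumbel(0, 1) sample drawn by inverse-CDF sampling,
/// g = -ln(-ln u) with u ~ Uniform(0, 1), and a temperature-scaled
/// softmax is applied. Lower temperatures push the output toward
/// one-hot.
///
/// Randomness comes from a small deterministic xorshift64 generator
/// seeded by `seed`, so results are reproducible, but the stream is not
/// of statistical-test quality; swap in a real RNG for serious sampling.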
pub fn gumbel_softmax(logits: &[f64], temperature: f64, seed: u64) -> SpecialResult<Vec<f64>> {
if logits.is_empty() {
return Err(SpecialError::ValueError(
"gumbel_softmax: logits must not be empty".to_string(),
));
}
if temperature <= 0.0 {
return Err(SpecialError::DomainError(
"gumbel_softmax: temperature must be > 0".to_string(),
));
}
    let gumbel_samples: Vec<f64> = {
        // xorshift64; the additive constant decorrelates nearby seeds.
        let mut state = seed.wrapping_add(0x853c49e6748fea9b);
        if state == 0 {
            // The all-zero state is a fixed point of xorshift, which would
            // make every sample identical; nudge off it.
            state = 0x9e3779b97f4a7c15;
        }
        (0..logits.len())
            .map(|_| {
                state ^= state << 13;
                state ^= state >> 7;
                state ^= state << 17;
                // Map the state to u ∈ (0, 1), then invert the Gumbel CDF:
                // g = -ln(-ln u). Clamping keeps both logarithms finite.
                let u = (state as f64) / (u64::MAX as f64);
                let u_clamped = u.clamp(1e-38, 1.0 - 1e-15);
                -(-u_clamped.ln()).ln()
            })
            .collect()
    };
let perturbed: Vec<f64> = logits
.iter()
.zip(gumbel_samples.iter())
.map(|(&l, &g)| (l + g) / temperature)
.collect();
Ok(softmax(&perturbed))
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
#[test]
fn test_binary_entropy_bits() {
assert_relative_eq!(binary_entropy(0.5), 1.0, epsilon = 1e-12);
assert_eq!(binary_entropy(0.0), 0.0);
assert_eq!(binary_entropy(1.0), 0.0);
assert!(binary_entropy(0.3) > 0.0);
assert!(binary_entropy(0.3) < 1.0);
}
#[test]
fn test_binary_entropy_nats() {
assert_relative_eq!(binary_entropy_nats(0.5), LN_2, epsilon = 1e-12);
}
#[test]
fn test_entropy_uniform() {
let h_bits = entropy(&[0.25, 0.25, 0.25, 0.25], 2.0).expect("ok");
assert_relative_eq!(h_bits, 2.0, epsilon = 1e-12);
let h_nats = entropy(&[0.25, 0.25, 0.25, 0.25], std::f64::consts::E).expect("ok");
assert_relative_eq!(h_nats, (4.0f64).ln(), epsilon = 1e-12);
}
#[test]
fn test_entropy_certain() {
let h = entropy(&[1.0, 0.0, 0.0], 2.0).expect("ok");
assert_relative_eq!(h, 0.0, epsilon = 1e-14);
}
#[test]
fn test_entropy_invalid_base() {
assert!(entropy(&[0.5, 0.5], 1.0).is_err());
assert!(entropy(&[0.5, 0.5], -2.0).is_err());
}
#[test]
fn test_kl_divergence_zero() {
let d = kl_divergence(&[0.5, 0.5], &[0.5, 0.5]).expect("ok");
assert_relative_eq!(d, 0.0, epsilon = 1e-12);
}
#[test]
fn test_kl_divergence_degenerate() {
let d = kl_divergence(&[1.0, 0.0], &[1.0, 0.0]).expect("ok");
assert_relative_eq!(d, 0.0, epsilon = 1e-12);
}
#[test]
fn test_kl_divergence_nonneg() {
let p = [0.7, 0.3];
let q = [0.4, 0.6];
let kl = kl_divergence(&p, &q).expect("ok");
assert!(kl >= 0.0);
}
#[test]
fn test_kl_divergence_infinity() {
let kl = kl_divergence(&[0.5, 0.5], &[1.0, 0.0]).expect("ok");
assert!(kl.is_infinite());
}
#[test]
fn test_js_divergence_symmetric() {
let p = [0.7, 0.3];
let q = [0.4, 0.6];
let pq = js_divergence(&p, &q).expect("ok");
let qp = js_divergence(&q, &p).expect("ok");
assert_relative_eq!(pq, qp, epsilon = 1e-12);
assert!(pq >= 0.0);
        assert!(pq <= LN_2); // JSD in nats is bounded above by ln 2
    }
#[test]
fn test_js_divergence_identical() {
let p = [0.3, 0.4, 0.3];
let jsd = js_divergence(&p, &p).expect("ok");
assert_relative_eq!(jsd, 0.0, epsilon = 1e-12);
}
#[test]
fn test_renyi_entropy_uniform() {
let probs = [0.25, 0.25, 0.25, 0.25];
for &alpha in &[0.5, 2.0, 3.0] {
let h = renyi_entropy(&probs, alpha).expect("ok");
assert_relative_eq!(h, (4.0f64).ln(), epsilon = 1e-10);
}
}
#[test]
fn test_renyi_limit_shannon() {
let probs = [0.5, 0.3, 0.2];
let h_shannon = entropy(&probs, std::f64::consts::E).expect("ok");
let h_renyi = renyi_entropy(&probs, 1.0).expect("ok");
assert_relative_eq!(h_shannon, h_renyi, epsilon = 1e-12);
}
#[test]
fn test_tsallis_entropy() {
let h = tsallis_entropy(&[0.5, 0.5], 2.0).expect("ok");
assert_relative_eq!(h, 0.5, epsilon = 1e-12);
}
#[test]
fn test_tsallis_limit_shannon() {
let probs = [0.5, 0.3, 0.2];
let h_shannon = entropy(&probs, std::f64::consts::E).expect("ok");
let h_tsallis = tsallis_entropy(&probs, 1.0).expect("ok");
assert_relative_eq!(h_shannon, h_tsallis, epsilon = 1e-12);
}
#[test]
fn test_sigmoid() {
assert_relative_eq!(sigmoid(0.0), 0.5, epsilon = 1e-14);
assert!(sigmoid(10.0) > 0.99);
assert!(sigmoid(-10.0) < 0.01);
assert_relative_eq!(sigmoid(-2.0), 1.0 - sigmoid(2.0), epsilon = 1e-14);
}
#[test]
fn test_sigmoid_derivative() {
assert_relative_eq!(sigmoid_derivative(0.0), 0.25, epsilon = 1e-14);
let s = sigmoid(1.5);
assert_relative_eq!(sigmoid_derivative(1.5), s * (1.0 - s), epsilon = 1e-14);
}
#[test]
fn test_log_sum_exp() {
let v = log_sum_exp(&[1.0, 2.0]);
let expected = (1.0f64.exp() + 2.0f64.exp()).ln();
assert_relative_eq!(v, expected, epsilon = 1e-12);
assert_relative_eq!(log_sum_exp(&[5.0]), 5.0, epsilon = 1e-14);
assert!(log_sum_exp(&[]).is_infinite() && log_sum_exp(&[]).is_sign_negative());
}
#[test]
fn test_softplus() {
assert_relative_eq!(softplus(0.0), LN_2, epsilon = 1e-12);
assert_relative_eq!(softplus(50.0), 50.0, epsilon = 0.001);
assert!(softplus(-50.0) < 1e-10);
}
#[test]
fn test_softmax() {
let s = softmax(&[1.0, 2.0, 3.0]);
let sum: f64 = s.iter().sum();
assert_relative_eq!(sum, 1.0, epsilon = 1e-12);
assert!(s.iter().all(|&v| v > 0.0));
assert!(s[2] > s[1] && s[1] > s[0]);
}
#[test]
fn test_softmax_empty() {
let s = softmax(&[]);
assert!(s.is_empty());
}
#[test]
fn test_log_softmax() {
let xs = [1.0, 2.0, 3.0];
let ls = log_softmax(&xs);
let s = softmax(&xs);
for (l, sv) in ls.iter().zip(s.iter()) {
assert_relative_eq!(*l, sv.ln(), epsilon = 1e-12);
}
}
#[test]
fn test_gumbel_softmax_valid() {
let gs = gumbel_softmax(&[1.0, 2.0, 3.0], 1.0, 42).expect("ok");
assert_eq!(gs.len(), 3);
let sum: f64 = gs.iter().sum();
assert_relative_eq!(sum, 1.0, epsilon = 1e-12);
assert!(gs.iter().all(|&v| v > 0.0));
}
#[test]
fn test_gumbel_softmax_low_temp() {
let gs = gumbel_softmax(&[0.0, 10.0, 0.0], 0.01, 7).expect("ok");
assert!(gs[1] > 0.99);
}
#[test]
fn test_gumbel_softmax_errors() {
assert!(gumbel_softmax(&[], 1.0, 0).is_err());
assert!(gumbel_softmax(&[1.0], 0.0, 0).is_err());
assert!(gumbel_softmax(&[1.0], -1.0, 0).is_err());
}
}