/// Computes the Shannon entropy, in nats, of a discrete probability vector.
///
/// Only strictly positive entries contribute. Zeros, negatives and NaN are
/// skipped, which implements the usual `0 · ln 0 = 0` convention. The input
/// is not checked for normalization. An empty slice yields `0.0`.
#[inline]
pub fn shannon_entropy(probs: &[f32]) -> f32 {
    probs
        .iter()
        .copied()
        .filter(|&prob| prob > 0.0)
        .fold(0.0_f32, |acc, prob| acc - prob * prob.ln())
}
/// Computes the Shannon entropy, in nats, of `softmax(logits)` without
/// materializing the probability vector.
///
/// Logits are shifted by their maximum before exponentiation (log-sum-exp),
/// so large logits do not overflow. Entries whose probability is exactly
/// zero contribute nothing, following the `0 · ln 0 = 0` convention used by
/// [`shannon_entropy`]. This includes logits equal to `f32::NEG_INFINITY`.
/// NaN inputs still propagate to a NaN result. An empty slice yields `0.0`.
#[inline]
pub fn shannon_entropy_from_logits(logits: &[f32]) -> f32 {
    let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
    let mut sum_exp = 0.0;
    for &logit in logits {
        sum_exp += (logit - max_logit).exp();
    }
    // log of the partition function: ln Σ exp(logit).
    let log_z = max_logit + sum_exp.ln();
    let mut entropy = 0.0;
    for &logit in logits {
        let log_p = logit - log_z;
        let p = log_p.exp();
        // A -inf logit gives p == 0 and log_p == -inf. Without this guard
        // 0 * -inf = NaN would poison the sum. `!=` (not `>`) keeps NaN visible.
        if p != 0.0 {
            entropy -= p * log_p;
        }
    }
    entropy
}
/// Returns the population variance of `logits`.
///
/// This is the mean squared deviation from the mean, dividing by `n` rather
/// than `n - 1`. An empty slice yields `0.0`.
#[inline]
pub fn logit_variance(logits: &[f32]) -> f32 {
    if logits.is_empty() {
        return 0.0;
    }
    let count = logits.len() as f32;
    let mean = logits.iter().sum::<f32>() / count;
    let squared_deviations = logits.iter().map(|&value| (value - mean).powi(2));
    squared_deviations.sum::<f32>() / count
}
/// Returns the largest probability in `softmax(logits)` without computing
/// the full softmax.
///
/// Since `p_max = exp(max - max) / Σ exp(l - max) = 1 / Σ exp(l - max)`,
/// only the max-shifted partition sum is needed. That sum also avoids
/// overflow for large logits.
///
/// Returns `0.0` for an empty slice, because there is no distribution to
/// take a maximum over. Previously this case returned `+inf`.
#[inline]
pub fn max_probability_from_logits(logits: &[f32]) -> f32 {
    if logits.is_empty() {
        return 0.0;
    }
    let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
    let mut sum_exp = 0.0;
    for &logit in logits {
        sum_exp += (logit - max_logit).exp();
    }
    // The max logit itself contributes exp(0) = 1, so sum_exp >= 1 here.
    1.0 / sum_exp
}
/// Returns the Euclidean (L2) norm of `logits`, `sqrt(Σ xᵢ²)`.
///
/// An empty slice yields zero.
#[inline]
pub fn logit_l2_norm(logits: &[f32]) -> f32 {
    logits.iter().map(|value| value * value).sum::<f32>().sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f32::consts::LN_2;

    #[test]
    fn test_entropy_uniform() {
        // Two equal logits -> uniform over 2 outcomes -> ln 2 nats.
        let h = shannon_entropy_from_logits(&[0.0, 0.0]);
        assert!((h - LN_2).abs() < 1e-4);
    }

    #[test]
    fn test_entropy_from_probs() {
        // A fair coin carries ln 2 nats; a certain outcome carries none.
        assert!((shannon_entropy(&[0.5, 0.5]) - LN_2).abs() < 1e-6);
        assert_eq!(shannon_entropy(&[1.0, 0.0]), 0.0);
    }

    #[test]
    fn test_variance() {
        let v = logit_variance(&[1.0, 2.0, 3.0]);
        assert!((v - 2.0 / 3.0).abs() < 1e-4);
    }

    #[test]
    fn test_variance_degenerate() {
        let empty: [f32; 0] = [];
        assert_eq!(logit_variance(&empty), 0.0);
        assert_eq!(logit_variance(&[4.0, 4.0, 4.0]), 0.0);
    }

    #[test]
    fn test_max_prob() {
        let p = max_probability_from_logits(&[0.0, 100.0]);
        assert!(p > 0.9999);
        assert!(p <= 1.0);
    }

    #[test]
    fn test_l2_norm() {
        let n = logit_l2_norm(&[3.0, 4.0]);
        assert!((n - 5.0).abs() < 1e-4);
    }
}