use num_traits::Float;
/// Computes the softmax of `x`: element `i` of the result is `exp(x_i) / Σ_j exp(x_j)`.
///
/// Every input is shifted by the largest element before exponentiating, so large
/// inputs (e.g. `1000.0`) do not overflow to infinity; the shift cancels in the ratio.
/// An empty slice yields an empty `Vec`.
///
/// NOTE: if the largest element is infinite (for example every element is negative
/// infinity), the shift evaluates `inf - inf`, and the affected outputs are NaN.
pub fn softmax<T: Float>(x: &[T]) -> Vec<T> {
    // Running maximum, seeded with negative infinity so any real element replaces it.
    let mut peak = T::neg_infinity();
    for &value in x {
        peak = peak.max(value);
    }

    // Exponentiate the shifted inputs and accumulate their total in the same pass,
    // adding left-to-right starting from zero.
    let mut weights = Vec::with_capacity(x.len());
    let mut total = T::zero();
    for &value in x {
        let weight = (value - peak).exp();
        total = total + weight;
        weights.push(weight);
    }

    // Normalize in place so the returned vector reuses the same buffer.
    for weight in weights.iter_mut() {
        *weight = *weight / total;
    }
    weights
}
/// Computes `ln(Σ_i exp(x_i))` without overflowing for large inputs.
///
/// The largest element `m` is factored out via the identity
/// `ln Σ exp(x_i) = m + ln Σ exp(x_i - m)`, so for finite `m` every exponent is at
/// most zero and cannot overflow.
///
/// Edge cases:
/// - An empty slice returns negative infinity (`ln 0`, the log of an empty sum).
/// - If every element is negative infinity, the result is negative infinity.
/// - If an element is positive infinity (and none is NaN), the result is positive infinity.
/// - Any NaN element makes the result NaN.
pub fn log_sum_exp<T: Float>(x: &[T]) -> T {
    let max = x.iter().copied().fold(T::neg_infinity(), T::max);
    // Shifting by a non-finite maximum would evaluate `-inf - -inf` or `inf - inf`,
    // both NaN. When the maximum is infinite the true answer is itself infinite
    // (`+inf`, or `ln 0 = -inf` when every term is `-inf`), and a zero shift yields
    // exactly that.
    let shift = if max.is_finite() { max } else { T::zero() };
    let sum = x
        .iter()
        .map(|&v| (v - shift).exp())
        .fold(T::zero(), |acc, term| acc + term);
    shift + sum.ln()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Textbook softmax without the max-shift. It overflows (or underflows to 0/0)
    /// for large-magnitude inputs, so it is only a reference for small values.
    fn naive_softmax(x: &[f64]) -> Vec<f64> {
        let exps: Vec<f64> = x.iter().map(|v| v.exp()).collect();
        let sum: f64 = exps.iter().sum();
        exps.into_iter().map(|v| v / sum).collect()
    }

    /// Textbook log-sum-exp without the max-shift; a reference for small values only.
    fn naive_log_sum_exp(x: &[f64]) -> f64 {
        x.iter().map(|v| v.exp()).sum::<f64>().ln()
    }

    #[test]
    fn matches_naive_for_small_values() {
        let x: [f64; 3] = [1.0, 2.0, 3.0];
        let stable = softmax(&x);
        let naive = naive_softmax(&x);
        assert_eq!(stable.len(), naive.len());
        for (s, n) in stable.iter().zip(naive.iter()) {
            assert!((s - n).abs() < 1e-12, "stable={s}, naive={n}");
        }
    }

    #[test]
    fn sums_to_one() {
        let x: [f64; 4] = [-3.0, 0.5, 2.0, 7.0];
        let result = softmax(&x);
        let sum: f64 = result.iter().sum();
        assert!((sum - 1.0).abs() < 1e-9);
    }

    #[test]
    fn handles_large_values_without_overflow() {
        let x: [f64; 3] = [1000.0, 1000.0, 1000.0];
        let result = softmax(&x);
        for v in &result {
            assert!(v.is_finite(), "expected finite, got {v}");
            assert!((v - 1.0 / 3.0).abs() < 1e-9);
        }
    }

    // The naive form computes exp(-1000) = 0 for every term and then 0/0;
    // the shifted form must still produce a uniform distribution.
    #[test]
    fn handles_large_negative_values_without_underflow() {
        let x: [f64; 2] = [-1000.0, -1000.0];
        let result = softmax(&x);
        assert_eq!(result.len(), 2);
        for v in &result {
            assert!((v - 0.5).abs() < 1e-12, "expected 0.5, got {v}");
        }
        let lse = log_sum_exp(&x);
        assert!((lse - (-1000.0 + 2.0_f64.ln())).abs() < 1e-9, "got {lse}");
    }

    #[test]
    fn empty_input() {
        let x: [f64; 0] = [];
        assert!(softmax(&x).is_empty());
        // ln of an empty sum is ln(0) = -inf.
        assert_eq!(log_sum_exp(&x), f64::NEG_INFINITY);
    }

    #[test]
    fn log_sum_exp_matches_naive_for_small_values() {
        let x: [f64; 3] = [1.0, 2.0, 3.0];
        let stable = log_sum_exp(&x);
        let naive = naive_log_sum_exp(&x);
        assert!((stable - naive).abs() < 1e-12, "stable={stable}, naive={naive}");
    }

    #[test]
    fn log_sum_exp_handles_large_values() {
        let x: [f64; 3] = [1000.0, 1000.0, 1000.0];
        let result = log_sum_exp(&x);
        assert!(result.is_finite());
        let expected = 1000.0 + 3.0_f64.ln();
        assert!((result - expected).abs() < 1e-9);
    }
}