/// Interface for a regularizer together with its conjugate and prediction map.
///
/// Implementors provide the three primitive operations; [`Regularizer::loss`]
/// combines them into a Fenchel-Young loss and may be overridden when a
/// specialised formula is available.
pub trait Regularizer {
    /// Evaluates the regularizer at the point `p`.
    fn omega(&self, p: &[f64]) -> f64;

    /// Maps the score vector `theta` to a prediction vector.
    fn predict(&self, theta: &[f64]) -> Vec<f64>;

    /// Evaluates the conjugate of the regularizer at `theta`.
    fn conjugate(&self, theta: &[f64]) -> f64;

    /// Computes `conjugate(theta) - <theta, y> + omega(y)`.
    ///
    /// Returns `NaN` when `theta` and `y` have different lengths.
    fn loss(&self, theta: &[f64], y: &[f64]) -> f64 {
        if theta.len() == y.len() {
            let dot: f64 = theta
                .iter()
                .zip(y.iter())
                .map(|(score, target)| score * target)
                .sum();
            self.conjugate(theta) - dot + self.omega(y)
        } else {
            // Mismatched inputs have no meaningful loss value.
            f64::NAN
        }
    }
}
/// Marker type for the regularizer `omega(p) = sum_i p_i * ln(p_i)`.
///
/// Its [`Regularizer`] implementation predicts with `softmax` and uses
/// `log_sum_exp` as the conjugate.
#[derive(Debug, Clone, Copy, Default)]
pub struct Shannon;

impl Regularizer for Shannon {
    /// Sums `p_i * ln(p_i)` over the strictly positive entries of `p`.
    ///
    /// Entries that are not greater than zero are skipped and contribute
    /// nothing to the total.
    fn omega(&self, p: &[f64]) -> f64 {
        p.iter()
            .copied()
            .filter(|&mass| mass > 0.0)
            .map(|mass| mass * mass.ln())
            .sum()
    }

    /// Returns `softmax(theta)`.
    fn predict(&self, theta: &[f64]) -> Vec<f64> {
        softmax(theta)
    }

    /// Returns `log_sum_exp(theta)`.
    fn conjugate(&self, theta: &[f64]) -> f64 {
        log_sum_exp(theta)
    }
}
/// Marker type for the regularizer `omega(p) = 0.5 * sum_i p_i^2`.
///
/// Its [`Regularizer`] implementation predicts with `sparsemax` and computes
/// the loss with `sparsemax_loss`.
#[derive(Debug, Clone, Copy, Default)]
pub struct SquaredL2;

impl Regularizer for SquaredL2 {
    /// Returns half the sum of squares of `p`.
    fn omega(&self, p: &[f64]) -> f64 {
        let squared_norm: f64 = p.iter().map(|&v| v * v).sum();
        0.5 * squared_norm
    }

    /// Returns `sparsemax(theta)`.
    fn predict(&self, theta: &[f64]) -> Vec<f64> {
        sparsemax(theta)
    }

    /// Returns `<theta, p> - 0.5 * sum_i p_i^2` evaluated at `p = sparsemax(theta)`.
    fn conjugate(&self, theta: &[f64]) -> f64 {
        let p = sparsemax(theta);
        let dot: f64 = theta
            .iter()
            .zip(p.iter())
            .map(|(&score, &mass)| score * mass)
            .sum();
        dot - self.omega(&p)
    }

    /// Overrides the default loss with the direct `sparsemax_loss` formula.
    fn loss(&self, theta: &[f64], y: &[f64]) -> f64 {
        sparsemax_loss(theta, y)
    }
}
/// Regularizer family indexed by a single positive parameter `alpha`.
#[derive(Debug, Clone, Copy)]
pub struct Tsallis {
    /// Family parameter; [`Tsallis::new`] only accepts values greater than zero.
    pub alpha: f64,
}

impl Tsallis {
    /// Creates a regularizer with the given `alpha`.
    ///
    /// # Panics
    ///
    /// Panics unless `alpha > 0.0`; a `NaN` argument therefore panics as well.
    pub fn new(alpha: f64) -> Self {
        assert!(alpha > 0.0, "alpha must be positive");
        Tsallis { alpha }
    }

    /// Shorthand for `Tsallis::new(1.5)`.
    pub fn entmax15() -> Self {
        Self::new(1.5)
    }
}

impl Default for Tsallis {
    /// The default instance uses `alpha = 1.5`.
    fn default() -> Self {
        Self::entmax15()
    }
}
impl Regularizer for Tsallis {
    /// Returns `(sum_i p_i^alpha - 1) / (alpha * (alpha - 1))`.
    ///
    /// When `alpha` is within `1e-10` of 1 the formula would divide by
    /// (nearly) zero, so the value is taken from `Shannon.omega` instead.
    fn omega(&self, p: &[f64]) -> f64 {
        let a = self.alpha;
        if (a - 1.0).abs() < 1e-10 {
            Shannon.omega(p)
        } else {
            let power_sum: f64 = p.iter().map(|&mass| mass.powf(a)).sum();
            (power_sum - 1.0) / (a * (a - 1.0))
        }
    }

    /// Dispatches to `softmax` (alpha near 1), `sparsemax` (alpha near 2),
    /// or `entmax` for every other `alpha`.
    fn predict(&self, theta: &[f64]) -> Vec<f64> {
        let a = self.alpha;
        let is_near = |target: f64| (a - target).abs() < 1e-10;
        if is_near(1.0) {
            softmax(theta)
        } else if is_near(2.0) {
            sparsemax(theta)
        } else {
            entmax(theta, a)
        }
    }

    /// Returns `<theta, p> - omega(p)` with `p = predict(theta)`.
    ///
    /// For `alpha` within `1e-10` of 1 the value is taken from
    /// `Shannon.conjugate` instead.
    fn conjugate(&self, theta: &[f64]) -> f64 {
        if (self.alpha - 1.0).abs() < 1e-10 {
            return Shannon.conjugate(theta);
        }
        let prediction = self.predict(theta);
        let dot: f64 = theta
            .iter()
            .zip(prediction.iter())
            .map(|(score, mass)| score * mass)
            .sum();
        dot - self.omega(&prediction)
    }
}
/// Exponentiates and normalises `theta` so that the outputs sum to one.
///
/// The largest score is subtracted before exponentiating so that `exp` never
/// receives a positive argument. An empty input yields an empty vector.
pub fn softmax(theta: &[f64]) -> Vec<f64> {
    if theta.is_empty() {
        return Vec::new();
    }
    let peak = theta.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut weights: Vec<f64> = theta.iter().map(|&score| (score - peak).exp()).collect();
    let total: f64 = weights.iter().sum();
    for weight in &mut weights {
        *weight /= total;
    }
    weights
}
/// Computes `ln(sum_i exp(theta_i))` after shifting by the largest score.
///
/// Returns negative infinity for an empty slice and whenever the running
/// maximum stays at negative infinity.
fn log_sum_exp(theta: &[f64]) -> f64 {
    // An empty slice leaves the fold at its -inf seed, so this single guard
    // covers both the empty case and the all -inf case.
    let peak = theta.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if peak == f64::NEG_INFINITY {
        return f64::NEG_INFINITY;
    }
    let shifted_total: f64 = theta.iter().map(|&score| (score - peak).exp()).sum();
    peak + shifted_total.ln()
}
/// Computes `max(theta_i - tau, 0)` for a threshold `tau` derived from the
/// sorted scores, which yields exact zeros for low-scoring entries.
///
/// The scores are sorted in descending order; the support size is the largest
/// rank `k` for which `1 + k * s_k` exceeds the sum of the top `k` scores, and
/// `tau = (sum of top k - 1) / k`. An empty input yields an empty vector.
pub fn sparsemax(theta: &[f64]) -> Vec<f64> {
    if theta.is_empty() {
        return Vec::new();
    }

    let mut descending = theta.to_vec();
    // Incomparable pairs (NaN) are treated as equal instead of panicking.
    descending.sort_by(|lhs, rhs| rhs.partial_cmp(lhs).unwrap_or(std::cmp::Ordering::Equal));

    // Largest rank that satisfies the support condition, or 0 if none does.
    let support = descending
        .iter()
        .enumerate()
        .scan(0.0_f64, |running, (idx, &value)| {
            *running += value;
            Some((idx + 1, value, *running))
        })
        .filter(|&(rank, value, running)| 1.0 + rank as f64 * value > running)
        .map(|(rank, _, _)| rank)
        .last()
        .unwrap_or(0);

    let top_sum: f64 = descending[..support].iter().sum();
    let tau = (top_sum - 1.0) / support as f64;

    theta.iter().map(|&score| (score - tau).max(0.0)).collect()
}
/// Computes `0.5 * (|y|^2 - |p|^2) + <p - y, theta>` with `p = sparsemax(theta)`.
///
/// Returns `NaN` when the inputs are empty or differ in length.
pub fn sparsemax_loss(theta: &[f64], y: &[f64]) -> f64 {
    if theta.is_empty() || theta.len() != y.len() {
        return f64::NAN;
    }

    let p = sparsemax(theta);
    let squared_norm = |v: &[f64]| -> f64 { v.iter().map(|&x| x * x).sum() };

    let half_gap = 0.5 * (squared_norm(y) - squared_norm(&p));
    let alignment: f64 = p
        .iter()
        .zip(y.iter())
        .zip(theta.iter())
        .map(|((&mass, &target), &score)| (mass - target) * score)
        .sum();

    half_gap + alignment
}
/// Computes the alpha-entmax mapping of `theta`.
///
/// The result is `p_i = max((alpha - 1) * theta_i - tau, 0)^(1 / (alpha - 1))`,
/// where the threshold `tau` is found by bisection so that the entries sum to
/// one. This is the maximiser of `<theta, p> - omega(p)` over the simplex for
/// `omega(p) = (sum_i p_i^alpha - 1) / (alpha * (alpha - 1))`.
///
/// Special cases:
/// * empty `theta` returns an empty vector;
/// * `alpha` within `1e-10` of 1 returns `softmax(theta)`;
/// * `alpha` within `1e-10` of 2 returns `sparsemax(theta)`.
///
/// The bisection bracket assumes `alpha > 1`; for smaller values the exponent
/// is negative and the output is not meaningful.
pub fn entmax(theta: &[f64], alpha: f64) -> Vec<f64> {
    if theta.is_empty() {
        return vec![];
    }
    if (alpha - 1.0).abs() < 1e-10 {
        return softmax(theta);
    }
    if (alpha - 2.0).abs() < 1e-10 {
        return sparsemax(theta);
    }

    let scale = alpha - 1.0;
    let exponent = 1.0 / scale;

    // The scores must be multiplied by (alpha - 1) before thresholding;
    // thresholding the raw scores solves the problem for theta / (alpha - 1)
    // instead of theta.
    let scaled: Vec<f64> = theta.iter().map(|&t| scale * t).collect();
    let max_scaled = scaled.iter().copied().fold(f64::NEG_INFINITY, f64::max);

    // Total mass produced by a candidate threshold; non-increasing in `tau`.
    let mass_at = |tau: f64| -> f64 {
        scaled
            .iter()
            .map(|&z| (z - tau).max(0.0).powf(exponent))
            .sum()
    };

    // At tau = max the mass is 0; at tau = max - 1 the top entry alone
    // contributes 1, so the root lies in [max - 1, max].
    let mut tau_lo = max_scaled - 1.0;
    let mut tau_hi = max_scaled;
    for _ in 0..50 {
        let tau = (tau_lo + tau_hi) / 2.0;
        if mass_at(tau) < 1.0 {
            tau_hi = tau;
        } else {
            tau_lo = tau;
        }
    }

    let tau = (tau_lo + tau_hi) / 2.0;
    let mut result: Vec<f64> = scaled
        .iter()
        .map(|&z| (z - tau).max(0.0).powf(exponent))
        .collect();

    // Renormalise to remove the residual bisection error.
    let sum: f64 = result.iter().sum();
    if sum > 0.0 {
        for r in &mut result {
            *r /= sum;
        }
    }
    result
}
/// Free-function form of [`Regularizer::loss`]: evaluates the loss of `reg`
/// for scores `theta` against the target `y`.
pub fn fy_loss<R: Regularizer>(reg: &R, theta: &[f64], y: &[f64]) -> f64 {
    R::loss(reg, theta, y)
}
/// Softmax of `theta / temperature`.
///
/// When `temperature` is not a finite, strictly positive number (zero,
/// negative, infinite or `NaN`), the function falls back to the uniform
/// distribution over the inputs, or to an empty vector for empty input.
pub fn softmax_with_temperature(theta: &[f64], temperature: f64) -> Vec<f64> {
    let usable = temperature > 0.0 && temperature.is_finite();
    if !usable {
        return match theta.len() {
            0 => Vec::new(),
            count => vec![1.0 / count as f64; count],
        };
    }

    // Largest scaled score, subtracted below to keep `exp` from overflowing.
    let peak = theta
        .iter()
        .fold(f64::NEG_INFINITY, |best, &score| best.max(score / temperature));
    let inverse = 1.0 / temperature;

    let mut weights: Vec<f64> = theta
        .iter()
        .map(|&score| (score * inverse - peak).exp())
        .collect();
    let total: f64 = weights.iter().sum();
    for weight in &mut weights {
        *weight /= total;
    }
    weights
}
/// Sums `-p_i * ln(p_i)` over the strictly positive entries of `p`.
///
/// Entries that are not greater than zero are skipped.
pub fn entropy_nats(p: &[f64]) -> f64 {
    p.iter()
        .copied()
        .filter(|&mass| mass > 0.0)
        .map(|mass| -mass * mass.ln())
        .sum()
}
/// Converts the result of `entropy_nats(p)` to bits by dividing by `ln(2)`.
pub fn entropy_bits(p: &[f64]) -> f64 {
    use std::f64::consts::LN_2;

    let nats = entropy_nats(p);
    nats / LN_2
}
pub fn softmax_loss(theta: &[f64], y: &[f64]) -> f64 {
Shannon.loss(theta, y)
}
/// Loss of the `Tsallis::entmax15()` regularizer for scores `theta` and target `y`.
pub fn entmax15_loss(theta: &[f64], y: &[f64]) -> f64 {
    let regularizer = Tsallis::entmax15();
    regularizer.loss(theta, y)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Softmax output sums to one.
    #[test]
    fn test_softmax_sums_to_one() {
        let theta = [2.0, 1.0, 0.1, -1.0];
        let p = softmax(&theta);
        let sum: f64 = p.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    /// Sparsemax output sums to one.
    #[test]
    fn test_sparsemax_sums_to_one() {
        let theta = [2.0, 1.0, 0.1, -1.0];
        let p = sparsemax(&theta);
        let sum: f64 = p.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    /// Sparsemax assigns exactly zero to at least one entry of this input.
    #[test]
    fn test_sparsemax_is_sparse() {
        let theta = [2.0, 1.0, 0.1, -1.0];
        let p = sparsemax(&theta);
        let zeros = p.iter().filter(|&&x| x == 0.0).count();
        assert!(zeros > 0, "sparsemax should produce zeros");
    }

    /// Entmax is dense at alpha = 1 and agrees with sparsemax at alpha = 2.
    #[test]
    fn test_entmax_interpolates() {
        let theta = [2.0, 1.0, 0.1];
        let p1 = entmax(&theta, 1.0);
        assert!(p1.iter().all(|&x| x > 0.0), "α=1 should be dense");
        let p2 = entmax(&theta, 2.0);
        let sparse2 = sparsemax(&theta);
        for (a, b) in p2.iter().zip(&sparse2) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    /// The Shannon loss against a one-hot target equals the cross-entropy.
    #[test]
    fn test_shannon_loss_equals_cross_entropy() {
        let theta = [2.0, 1.0, 0.1];
        let y = [1.0, 0.0, 0.0];
        let loss = Shannon.loss(&theta, &y);
        let p = softmax(&theta);
        let ce: f64 = -y.iter().zip(&p).map(|(&yi, &pi)| yi * pi.ln()).sum::<f64>();
        assert!((loss - ce).abs() < 1e-6, "loss={}, ce={}", loss, ce);
    }

    /// Every regularizer's loss is non-negative (up to rounding) on this input.
    #[test]
    fn test_fy_loss_nonnegative() {
        let theta = [2.0, 1.0, 0.1];
        let y = [0.5, 0.3, 0.2];
        assert!(Shannon.loss(&theta, &y) >= -1e-10);
        assert!(SquaredL2.loss(&theta, &y) >= -1e-10);
        assert!(Tsallis::entmax15().loss(&theta, &y) >= -1e-10);
    }

    /// Temperature 1 matches plain softmax, and entropy grows with temperature.
    #[test]
    fn test_softmax_temperature_scaling() {
        let theta = [2.0, 1.0, 0.1];
        let cold = softmax_with_temperature(&theta, 0.5);
        let normal = softmax_with_temperature(&theta, 1.0);
        let hot = softmax_with_temperature(&theta, 2.0);
        assert!((cold.iter().sum::<f64>() - 1.0).abs() < 1e-10);
        assert!((normal.iter().sum::<f64>() - 1.0).abs() < 1e-10);
        assert!((hot.iter().sum::<f64>() - 1.0).abs() < 1e-10);
        let regular = softmax(&theta);
        for (a, b) in normal.iter().zip(&regular) {
            assert!((a - b).abs() < 1e-10);
        }
        let h_cold = entropy_bits(&cold);
        let h_normal = entropy_bits(&normal);
        let h_hot = entropy_bits(&hot);
        assert!(h_cold < h_normal, "cold={}, normal={}", h_cold, h_normal);
        assert!(h_normal < h_hot, "normal={}, hot={}", h_normal, h_hot);
    }

    /// A uniform distribution over four outcomes has two bits of entropy.
    #[test]
    fn test_entropy_uniform() {
        let p = [0.25, 0.25, 0.25, 0.25];
        let h = entropy_bits(&p);
        assert!((h - 2.0).abs() < 1e-10);
    }

    /// The loss vanishes when the target equals the regularizer's own prediction.
    #[test]
    fn test_fy_loss_zero_at_prediction() {
        let theta = [2.0, 1.0, 0.1];
        let y_shannon = Shannon.predict(&theta);
        let loss_shannon = Shannon.loss(&theta, &y_shannon);
        assert!(
            loss_shannon.abs() < 1e-6,
            "Shannon loss at prediction: {}",
            loss_shannon
        );
        let y_sparse = SquaredL2.predict(&theta);
        let loss_sparse = SquaredL2.loss(&theta, &y_sparse);
        assert!(
            loss_sparse.abs() < 1e-6,
            "Sparsemax loss at prediction: {}",
            loss_sparse
        );
    }

    /// `SquaredL2::conjugate` equals `<theta, p*> - omega(p*)` at `p* = sparsemax(theta)`.
    #[test]
    fn test_squared_l2_conjugate_definition() {
        let theta = [2.0, 1.0, 0.1, -1.0];
        let p_star = sparsemax(&theta);
        let inner: f64 = theta.iter().zip(&p_star).map(|(&t, &p)| t * p).sum();
        let omega_p = 0.5 * p_star.iter().map(|&x| x * x).sum::<f64>();
        let expected = inner - omega_p;
        let actual = SquaredL2.conjugate(&theta);
        assert!(
            (actual - expected).abs() < 1e-10,
            "conjugate={}, expected={}",
            actual,
            expected
        );
    }
}