/// Finite stand-in for negative infinity: the initial (empty) value of
/// `LogSumExpGrad`'s running log-sum-exp, and therefore what
/// `LogSumExpGrad::value` returns before any sample has been accumulated.
const NEG_INF: f32 = -1e30;
/// Numerically stable `ln(mean(exp(rewards)))`.
///
/// Every term is shifted by the maximum reward before exponentiating, so large
/// rewards cannot overflow. Terms equal to the maximum contribute exactly `1.0`
/// (what `exp(0)` gives for finite values). This also keeps the result
/// well-defined when the maximum is infinite, where `r - max` would otherwise be
/// `inf - inf = NaN`:
/// - all-`-inf` input yields `-inf`;
/// - any `+inf` input yields `+inf`.
///
/// A NaN reward makes the result NaN.
///
/// # Panics
///
/// Panics if `rewards` is empty.
pub fn log_mean_exp(rewards: &[f32]) -> f32 {
    assert!(!rewards.is_empty());
    let max_r = rewards.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let sum: f32 = rewards
        .iter()
        .map(|&r| if r == max_r { 1.0 } else { (r - max_r).exp() })
        .sum();
    max_r + (sum / rewards.len() as f32).ln()
}
/// Numerically stable softmax of `rewards`; the returned weights sum to one.
///
/// Each reward is shifted by the maximum before exponentiating. Entries equal to
/// the maximum get an unnormalized weight of exactly `1.0`, which avoids
/// `inf - inf = NaN` when the maximum is infinite:
/// - all-`-inf` input yields uniform weights;
/// - `+inf` entries share all of the weight, and finite entries get `0`.
///
/// A NaN reward makes every weight NaN.
///
/// # Panics
///
/// Panics if `rewards` is empty.
pub fn softmax_weights(rewards: &[f32]) -> Vec<f32> {
    assert!(!rewards.is_empty());
    let max_r = rewards.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut weights: Vec<f32> = rewards
        .iter()
        .map(|&r| if r == max_r { 1.0 } else { (r - max_r).exp() })
        .collect();
    let sum: f32 = weights.iter().sum();
    // Normalize in place rather than allocating a second vector.
    for w in &mut weights {
        *w /= sum;
    }
    weights
}
/// Streaming accumulator for a log-sum-exp value and the matching
/// softmax-weighted average of per-sample gradients.
///
/// After accumulating pairs `(r_i, g_i)`:
/// - [`value`](Self::value) is `ln(sum_i exp(r_i))`;
/// - [`grad`](Self::grad) is `sum_i softmax(r)_i * g_i`.
///
/// This is the same gradient [`softmax_grad_aggregate`] computes in one batch.
/// Storage is a single gradient-sized buffer, however many samples are seen.
#[derive(Debug, Clone)]
pub struct LogSumExpGrad {
    /// Running `ln(sum exp(reward))`; `NEG_INF` until the first sample arrives.
    logsumexp: f32,
    /// Softmax-weighted average of the gradients seen so far (zeros when empty).
    grad: Vec<f32>,
    /// Number of samples accumulated. Emptiness is tracked explicitly so that
    /// no reward value (infinite, NaN, or very negative) is mistaken for "empty".
    count: usize,
}
impl LogSumExpGrad {
    /// Creates an empty accumulator for gradients of length `dim`.
    pub fn new(dim: usize) -> Self {
        Self {
            logsumexp: NEG_INF,
            grad: vec![0.0; dim],
            count: 0,
        }
    }
    /// Folds one `(reward, grad)` sample into the running state.
    ///
    /// The running gradient is rescaled by the previous total's share of the new
    /// total, and the incoming gradient is added with its own share:
    /// `w_prev = 1 / (1 + exp(reward - prev))` and
    /// `w_curr = 1 / (1 + exp(prev - reward))`.
    ///
    /// For finite inputs these equal `exp(prev - next)` and `exp(reward - next)`.
    /// Unlike that form, they stay well-defined when one side is infinite: an
    /// infinite reward takes the full weight instead of producing NaN. A NaN
    /// reward makes the state NaN from then on.
    ///
    /// # Panics
    ///
    /// Panics if `grad.len()` differs from the `dim` given to [`new`](Self::new).
    pub fn accumulate(&mut self, reward: f32, grad: &[f32]) {
        assert_eq!(self.grad.len(), grad.len());
        self.count += 1;
        if self.count == 1 {
            // First sample carries all of the weight.
            self.logsumexp = reward;
            self.grad.copy_from_slice(grad);
            return;
        }
        let prev = self.logsumexp;
        let (next, w_prev, w_curr) = if prev == reward {
            // Equal terms split the mass evenly: ln(2 e^x) = x + ln 2. Handled
            // directly so equal infinities never evaluate `inf - inf`.
            (prev + std::f32::consts::LN_2, 0.5, 0.5)
        } else {
            (
                logaddexp(prev, reward),
                1.0 / (1.0 + (reward - prev).exp()),
                1.0 / (1.0 + (prev - reward).exp()),
            )
        };
        for (g, &dg) in self.grad.iter_mut().zip(grad) {
            *g = *g * w_prev + dg * w_curr;
        }
        self.logsumexp = next;
    }
    /// Returns the running `ln(sum exp(reward))`.
    ///
    /// Before any sample is accumulated this is `NEG_INF` (`-1e30`).
    pub fn value(&self) -> f32 {
        self.logsumexp
    }
    /// Returns the softmax-weighted gradient average (all zeros when empty).
    pub fn grad(&self) -> &[f32] {
        &self.grad
    }
    /// Returns how many samples have been accumulated.
    pub fn count(&self) -> usize {
        self.count
    }
}
/// Numerically stable `ln(exp(a) + exp(b))`.
///
/// The larger argument is factored out so the remaining exponential is at most
/// `1.0`. [`f32::ln_1p`] is used so that a tiny second term (e.g. `b` far below
/// `a`) still contributes, instead of being rounded away by `1.0 + x`.
///
/// Equal arguments are handled directly as `a + ln 2`. This also makes
/// `logaddexp(-inf, -inf) == -inf` and `logaddexp(inf, inf) == inf` instead of
/// NaN. A NaN argument yields NaN.
#[inline]
pub fn logaddexp(a: f32, b: f32) -> f32 {
    let (hi, lo) = if a > b { (a, b) } else { (b, a) };
    if hi == lo {
        return hi + std::f32::consts::LN_2;
    }
    hi + (lo - hi).exp().ln_1p()
}
/// Batch softmax-weighted average of per-sample gradients:
/// `sum_i softmax(rewards)_i * grads[i]`.
///
/// # Panics
///
/// Panics in any of these cases:
/// - `rewards` is empty;
/// - `rewards` and `grads` differ in length;
/// - the gradients do not all have the same length.
pub fn softmax_grad_aggregate(rewards: &[f32], grads: &[Vec<f32>]) -> Vec<f32> {
    assert!(!rewards.is_empty());
    assert_eq!(rewards.len(), grads.len());
    let dim = grads[0].len();
    // The inner `zip` would silently truncate a gradient of a different length,
    // so reject mismatches up front (mirrors the length check in
    // `LogSumExpGrad::accumulate`).
    assert!(
        grads.iter().all(|g| g.len() == dim),
        "all gradients must have the same length"
    );
    let weights = softmax_weights(rewards);
    let mut out = vec![0.0f32; dim];
    for (&w, g) in weights.iter().zip(grads) {
        for (o, &gi) in out.iter_mut().zip(g) {
            *o += w * gi;
        }
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two particles with rewards 0 and 1 match the closed form `ln((1 + e) / 2)`.
    #[test]
    fn log_mean_exp_two_particles() {
        let got = log_mean_exp(&[0.0f32, 1.0]);
        let want = ((1.0 + std::f32::consts::E) * 0.5f32).ln();
        assert!((got - want).abs() < 1e-4);
    }

    /// Streaming accumulation reproduces the batch softmax-weighted gradient.
    #[test]
    fn online_matches_batch_softmax() {
        let rewards = [0.1f32, 0.5, 0.2];
        let grads = [vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];
        let batch = softmax_grad_aggregate(&rewards, &grads);
        let online = rewards
            .iter()
            .zip(&grads)
            .fold(LogSumExpGrad::new(2), |mut acc, (&reward, g)| {
                acc.accumulate(reward, g);
                acc
            });
        for (b, o) in batch.iter().zip(online.grad()) {
            assert!((b - o).abs() < 1e-5);
        }
    }
}