/// Large negative finite stand-in for "log of zero".
///
/// `LogSumExpGrad::new` stores this as the initial running log-sum-exp, and
/// `LogSumExpGrad::accumulate` treats any stored value at or below
/// `NEG_INF / 2.0` as "nothing accumulated yet".
const NEG_INF: f32 = -1e30;
19
/// Computes `ln(mean(exp(rewards)))` without overflowing for large rewards.
///
/// The largest reward is subtracted before exponentiating and added back to
/// the result, so every exponent is at most zero.
///
/// Non-finite inputs follow the mathematical limits: if every reward is
/// `-inf` the result is `-inf`, and if any reward is `+inf` the result is
/// `+inf`. A NaN reward yields NaN.
///
/// # Panics
///
/// Panics if `rewards` is empty.
pub fn log_mean_exp(rewards: &[f32]) -> f32 {
    assert!(!rewards.is_empty());
    let max_r = rewards.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let sum: f32 = rewards
        .iter()
        .map(|&r| {
            if r == max_r {
                // exp(0) == 1 for finite values; spelling it out also avoids
                // `inf - inf = NaN` when the maximum itself is infinite.
                1.0
            } else {
                (r - max_r).exp()
            }
        })
        .sum();
    max_r + (sum / rewards.len() as f32).ln()
}
27
/// Returns the softmax of `rewards`: `exp(r_i) / sum_j exp(r_j)`.
///
/// Each reward is shifted by the largest reward before exponentiating so the
/// exponentials cannot overflow.
///
/// # Panics
///
/// Panics if `rewards` is empty.
pub fn softmax_weights(rewards: &[f32]) -> Vec<f32> {
    assert!(!rewards.is_empty());

    let mut peak = f32::NEG_INFINITY;
    for &r in rewards {
        peak = peak.max(r);
    }

    // Fill the output with the shifted exponentials, then normalise in place.
    let mut weights: Vec<f32> = rewards.iter().map(|&r| (r - peak).exp()).collect();
    let total: f32 = weights.iter().sum();
    for w in weights.iter_mut() {
        *w /= total;
    }
    weights
}
36
37#[derive(Debug, Clone)]
39pub struct LogSumExpGrad {
40 logsumexp: f32,
41 grad: Vec<f32>,
42}
43
44impl LogSumExpGrad {
45 pub fn new(dim: usize) -> Self {
46 Self {
47 logsumexp: NEG_INF,
48 grad: vec![0.0; dim],
49 }
50 }
51
52 pub fn accumulate(&mut self, reward: f32, grad: &[f32]) {
54 assert_eq!(self.grad.len(), grad.len());
55 if !self.logsumexp.is_finite() || self.logsumexp <= NEG_INF / 2.0 {
56 self.logsumexp = reward;
57 self.grad.copy_from_slice(grad);
58 return;
59 }
60 let logsumexp_next = logaddexp(self.logsumexp, reward);
61 let w_prev = (self.logsumexp - logsumexp_next).exp();
62 let w_curr = (reward - logsumexp_next).exp();
63 for (g, &dg) in self.grad.iter_mut().zip(grad.iter()) {
64 *g = *g * w_prev + dg * w_curr;
65 }
66 self.logsumexp = logsumexp_next;
67 }
68
69 pub fn value(&self) -> f32 {
70 self.logsumexp
71 }
72
73 pub fn grad(&self) -> &[f32] {
74 &self.grad
75 }
76}
77
/// Returns `ln(exp(a) + exp(b))` without overflowing for large inputs.
///
/// The larger argument is factored out so the remaining exponent is never
/// positive, and `ln_1p` keeps the contribution of a much smaller argument
/// that `ln(1 + x)` would round away.
///
/// Equal arguments, including two `-inf` or two `+inf`, return
/// `a + ln(2)` instead of going through `inf - inf = NaN`. A NaN argument
/// yields NaN.
#[inline]
pub fn logaddexp(a: f32, b: f32) -> f32 {
    if a == b {
        // exp(a) + exp(a) == 2 * exp(a); also the only safe path when both
        // arguments are the same infinity.
        return a + std::f32::consts::LN_2;
    }
    let (hi, lo) = if a > b { (a, b) } else { (b, a) };
    hi + (lo - hi).exp().ln_1p()
}
86
87pub fn softmax_grad_aggregate(rewards: &[f32], grads: &[Vec<f32>]) -> Vec<f32> {
89 assert!(!rewards.is_empty());
90 assert_eq!(rewards.len(), grads.len());
91 let weights = softmax_weights(rewards);
92 let dim = grads[0].len();
93 let mut out = vec![0.0f32; dim];
94 for (w, g) in weights.iter().zip(grads.iter()) {
95 for (o, &gi) in out.iter_mut().zip(g.iter()) {
96 *o += w * gi;
97 }
98 }
99 out
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// `log_mean_exp` of `[0, 1]` must equal `ln((1 + e) / 2)`.
    #[test]
    fn log_mean_exp_two_particles() {
        let expected = (0.5f32 * (1.0 + std::f32::consts::E)).ln();
        let actual = log_mean_exp(&[0.0f32, 1.0]);
        assert!((actual - expected).abs() < 1e-4);
    }

    /// Feeding samples one at a time through `LogSumExpGrad` must give the
    /// same gradient as the one-shot `softmax_grad_aggregate`.
    #[test]
    fn online_matches_batch_softmax() {
        let rewards = [0.1f32, 0.5, 0.2];
        let grads = [vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];

        let batch = softmax_grad_aggregate(&rewards, &grads);

        let mut online = LogSumExpGrad::new(2);
        rewards
            .iter()
            .zip(grads.iter())
            .for_each(|(&r, g)| online.accumulate(r, g));

        for (b, o) in batch.iter().zip(online.grad()) {
            assert!((b - o).abs() < 1e-5);
        }
    }
}