// rlx_diamond/value.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Value function estimators (Diamond Maps Proposition 4.1).
17
/// Finite stand-in for negative infinity.
///
/// `LogSumExpGrad::new` stores this as the initial running log-sum-exp, and
/// `LogSumExpGrad::accumulate` treats any state at or below half of it as
/// "nothing accumulated yet".
const NEG_INF: f32 = -1.0e30;
19
/// log (1/K Σ exp(r_k)) — numerically stable log-mean-exp.
///
/// The maximum reward is subtracted before exponentiating so that large
/// rewards do not overflow `exp`.
///
/// Non-finite inputs:
/// * if the largest reward is `+∞` the result is `+∞`;
/// * if every reward is `-∞` the result is `-∞`;
/// * a NaN reward always yields NaN.
///
/// # Panics
///
/// Panics if `rewards` is empty.
pub fn log_mean_exp(rewards: &[f32]) -> f32 {
    assert!(!rewards.is_empty());
    let max_r = rewards.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    if max_r.is_infinite() {
        // Shifting by an infinite maximum would evaluate `inf - inf = NaN`.
        // The mean is dominated by (or entirely made of) the infinite terms,
        // so the maximum itself is the answer. `f32::max` skips NaN, so NaN
        // inputs are re-checked here to keep them propagating.
        return if rewards.iter().any(|r| r.is_nan()) {
            f32::NAN
        } else {
            max_r
        };
    }
    let sum: f32 = rewards.iter().map(|r| (r - max_r).exp()).sum();
    max_r + (sum / rewards.len() as f32).ln()
}
27
/// Normalised softmax weights: `w_k = exp(r_k) / Σ_j exp(r_j)`.
///
/// Each reward is shifted by the largest reward before exponentiation so the
/// exponentials stay bounded; the shift cancels in the ratio.
///
/// # Panics
///
/// Panics if `rewards` is empty.
pub fn softmax_weights(rewards: &[f32]) -> Vec<f32> {
    assert!(!rewards.is_empty());

    // Largest reward, used as the stabilising shift.
    let mut peak = f32::NEG_INFINITY;
    for &r in rewards {
        peak = f32::max(peak, r);
    }

    // Shifted exponentials and their running total, gathered in one pass.
    let mut weights: Vec<f32> = Vec::with_capacity(rewards.len());
    let mut total = 0.0f32;
    for &r in rewards {
        let shifted = (r - peak).exp();
        total += shifted;
        weights.push(shifted);
    }

    // Normalise in place.
    for w in weights.iter_mut() {
        *w /= total;
    }
    weights
}
36
37/// Online log-sum-exp accumulation (reference `make_guidance_value_and_grad_fn`).
38#[derive(Debug, Clone)]
39pub struct LogSumExpGrad {
40    logsumexp: f32,
41    grad: Vec<f32>,
42}
43
44impl LogSumExpGrad {
45    pub fn new(dim: usize) -> Self {
46        Self {
47            logsumexp: NEG_INF,
48            grad: vec![0.0; dim],
49        }
50    }
51
52    /// Incorporate one particle: reward value and ∂r/∂x_t.
53    pub fn accumulate(&mut self, reward: f32, grad: &[f32]) {
54        assert_eq!(self.grad.len(), grad.len());
55        if !self.logsumexp.is_finite() || self.logsumexp <= NEG_INF / 2.0 {
56            self.logsumexp = reward;
57            self.grad.copy_from_slice(grad);
58            return;
59        }
60        let logsumexp_next = logaddexp(self.logsumexp, reward);
61        let w_prev = (self.logsumexp - logsumexp_next).exp();
62        let w_curr = (reward - logsumexp_next).exp();
63        for (g, &dg) in self.grad.iter_mut().zip(grad.iter()) {
64            *g = *g * w_prev + dg * w_curr;
65        }
66        self.logsumexp = logsumexp_next;
67    }
68
69    pub fn value(&self) -> f32 {
70        self.logsumexp
71    }
72
73    pub fn grad(&self) -> &[f32] {
74        &self.grad
75    }
76}
77
/// Computes `ln(exp(a) + exp(b))` without overflowing the intermediate `exp`.
///
/// The larger argument is factored out, so only `exp` of a non-positive
/// number is ever evaluated. `ln_1p` keeps the small correction term accurate
/// when `exp(lo - hi)` is below the f32 rounding threshold of `1.0 + x`.
///
/// Special values:
/// * two equal infinities return that infinity (instead of the NaN that
///   `inf - inf` would otherwise produce);
/// * a `-∞` argument returns the other argument;
/// * a NaN argument yields NaN.
#[inline]
pub fn logaddexp(a: f32, b: f32) -> f32 {
    let (hi, lo) = if a > b { (a, b) } else { (b, a) };
    if hi == lo && hi.is_infinite() {
        // `lo - hi` would be `inf - inf`; the sum is simply that infinity.
        return hi;
    }
    hi + (lo - hi).exp().ln_1p()
}
86
87/// Aggregate particle gradients: Σ softmax(r)_k ∇_{x_t} r(z^k).
88pub fn softmax_grad_aggregate(rewards: &[f32], grads: &[Vec<f32>]) -> Vec<f32> {
89    assert!(!rewards.is_empty());
90    assert_eq!(rewards.len(), grads.len());
91    let weights = softmax_weights(rewards);
92    let dim = grads[0].len();
93    let mut out = vec![0.0f32; dim];
94    for (w, g) in weights.iter().zip(grads.iter()) {
95        for (o, &gi) in out.iter_mut().zip(g.iter()) {
96            *o += w * gi;
97        }
98    }
99    out
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// For rewards {0, 1} the log-mean-exp is ln((1 + e) / 2).
    #[test]
    fn log_mean_exp_two_particles() {
        let want = (0.5f32 * (1.0 + std::f32::consts::E)).ln();
        let got = log_mean_exp(&[0.0f32, 1.0]);
        let err = (got - want).abs();
        assert!(err < 1e-4);
    }

    /// Feeding particles one at a time must reproduce the batch aggregate.
    #[test]
    fn online_matches_batch_softmax() {
        let rewards = [0.1f32, 0.5, 0.2];
        let grads = [vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];

        let reference = softmax_grad_aggregate(&rewards, &grads);

        let mut streamed = LogSumExpGrad::new(2);
        for (reward, grad) in rewards.iter().zip(grads.iter()) {
            streamed.accumulate(*reward, grad);
        }

        for (lhs, rhs) in reference.iter().zip(streamed.grad()) {
            assert!((lhs - rhs).abs() < 1e-5);
        }
    }
}