Skip to main content

rlx_diamond/
reward.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Pluggable reward functions for Diamond guidance.

/// Differentiable scalar objective evaluated on a clean latent / decoded
/// state `z` (the data end of the trajectory, `t = 1`).
///
/// The `Send + Sync` supertraits allow a single reward object to be shared
/// between threads.
pub trait LatentReward: Send + Sync {
    /// Evaluates the scalar reward `r(z)`.
    fn reward(&self, z: &[f32]) -> f32;

    /// Returns the gradient `∂r/∂z`; implementations must return a vector of
    /// the same length as `z`.
    fn grad_wrt_z(&self, z: &[f32]) -> Vec<f32>;
}
24
25/// Proxy “blueness” on packed latent: maximize channel index 2 in each 3-group.
26#[derive(Debug, Clone, Copy, Default)]
27pub struct BluenessReward {
28    pub scale: f32,
29}
30
31impl LatentReward for BluenessReward {
32    fn reward(&self, z: &[f32]) -> f32 {
33        let s: f32 = z.chunks(3).map(|c| c.get(2).copied().unwrap_or(0.0)).sum();
34        self.scale * s
35    }
36
37    fn grad_wrt_z(&self, z: &[f32]) -> Vec<f32> {
38        let mut g = vec![0.0f32; z.len()];
39        for (i, chunk) in z.chunks(3).enumerate() {
40            if chunk.len() > 2 {
41                g[i * 3 + 2] = self.scale;
42            }
43        }
44        g
45    }
46}
47
/// Data-consistency reward for a linear forward model `y = A z + noise`.
///
/// `A` is a dense `rows × cols` operator stored row-major in `matrix`
/// (entry `(r, c)` lives at `matrix[r * cols + c]`) and `y` is `measurement`.
/// The reward is the negative squared residual `-‖y − A z‖²`.
#[derive(Debug, Clone)]
pub struct LinearMeasurementReward {
    /// Row-major `rows × cols` forward operator `A`.
    pub matrix: Vec<f32>,
    /// Observed measurement `y`; `new` checks it has `rows` entries.
    pub measurement: Vec<f32>,
    /// Number of rows of `A`.
    pub rows: usize,
    /// Number of columns of `A`.
    pub cols: usize,
}

impl LinearMeasurementReward {
    /// Creates the reward from a row-major operator and a measurement vector.
    ///
    /// # Panics
    ///
    /// Panics if `matrix.len() != rows * cols` or `measurement.len() != rows`.
    pub fn new(matrix: Vec<f32>, measurement: Vec<f32>, rows: usize, cols: usize) -> Self {
        assert_eq!(matrix.len(), rows * cols);
        assert_eq!(measurement.len(), rows);
        Self {
            matrix,
            measurement,
            rows,
            cols,
        }
    }

    /// Computes `A z` (one entry per row of `A`).
    ///
    /// Only the first `min(cols, z.len())` columns participate: entries of `z`
    /// beyond `cols` are ignored, and a short `z` behaves as if zero-padded.
    fn matvec(&self, z: &[f32]) -> Vec<f32> {
        let width = self.cols.min(z.len());
        (0..self.rows)
            .map(|row| {
                (0..width).fold(0.0f32, |acc, col| {
                    acc + self.matrix[row * self.cols + col] * z[col]
                })
            })
            .collect()
    }
}
81
82impl LatentReward for LinearMeasurementReward {
83    fn reward(&self, z: &[f32]) -> f32 {
84        let pred = self.matvec(z);
85        let err: f32 = pred
86            .iter()
87            .zip(self.measurement.iter())
88            .map(|(p, m)| (p - m).powi(2))
89            .sum();
90        -err
91    }
92
93    fn grad_wrt_z(&self, z: &[f32]) -> Vec<f32> {
94        let pred = self.matvec(z);
95        let mut g = vec![0.0f32; z.len()];
96        for r in 0..self.rows {
97            let residual = 2.0 * (pred[r] - self.measurement[r]);
98            for c in 0..self.cols.min(z.len()) {
99                g[c] -= residual * self.matrix[r * self.cols + c];
100            }
101        }
102        g
103    }
104}
105
/// Chain-rule shortcut for guidance: approximates `∂r/∂x_t` by `∂r/∂z`.
///
/// Intended as a proxy when `z` is a posterior sample close to `x_t`; the
/// returned vector is simply an owned copy of `grad_z`.
pub fn grad_xt_via_z(grad_z: &[f32]) -> Vec<f32> {
    Vec::from(grad_z)
}
110
/// SPSA (simultaneous perturbation stochastic approximation) estimate of `∂V/∂x_t`.
///
/// For each of `num_dirs` random Rademacher directions `δ` (entries ±1), the
/// value function is probed at `x_t ± eps·δ`, and the central difference
/// `(V(x_t + eps·δ) − V(x_t − eps·δ)) / (2·eps)` is multiplied by `δ`.
/// The result is the mean of these per-direction estimates and has the same
/// length as `x_t`. `eval_v` is called exactly `2 * num_dirs` times, always
/// `+eps` probe first.
///
/// Directions come from a xorshift64 (13, 7, 17) generator seeded with `seed`,
/// so the output is deterministic for fixed inputs.
///
/// # Edge cases
///
/// - `num_dirs == 0` returns an all-zero vector without calling `eval_v`.
/// - `seed == 0` is replaced by a fixed non-zero constant: zero is a fixed
///   point of xorshift, which would otherwise make every direction the
///   all-ones vector.
/// - `eps` must be non-zero; `eps == 0.0` yields NaN/infinite entries.
pub fn spsa_grad(
    x_t: &[f32],
    mut eval_v: impl FnMut(&[f32]) -> f32,
    eps: f32,
    num_dirs: usize,
    seed: u64,
) -> Vec<f32> {
    /// Substitute state for `seed == 0` (any non-zero value works).
    const ZERO_SEED_FALLBACK: u64 = 0x9E37_79B9_7F4A_7C15;

    let dim = x_t.len();
    let mut grad = vec![0.0f32; dim];
    if num_dirs == 0 {
        // Averaging over zero directions would compute 0 * (1 / 0) = NaN.
        return grad;
    }

    let mut state = if seed == 0 { ZERO_SEED_FALLBACK } else { seed };

    // Scratch buffers reused across directions instead of reallocated.
    let mut delta = vec![0.0f32; dim];
    let mut xp = vec![0.0f32; dim];
    let mut xm = vec![0.0f32; dim];

    for _ in 0..num_dirs {
        for d in &mut delta {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            *d = if state & 1 == 0 { 1.0 } else { -1.0 };
        }
        for (((p, m), &x), &d) in xp.iter_mut().zip(xm.iter_mut()).zip(x_t).zip(&delta) {
            *p = x + eps * d;
            *m = x - eps * d;
        }
        let vp = eval_v(&xp);
        let vm = eval_v(&xm);
        let scale = (vp - vm) / (2.0 * eps);
        // Rademacher entries satisfy 1/δ_i = δ_i, so multiply instead of divide.
        for (g, &d) in grad.iter_mut().zip(&delta) {
            *g += scale * d;
        }
    }

    let inv = 1.0 / num_dirs as f32;
    grad.iter_mut().for_each(|g| *g *= inv);
    grad
}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Every entry of the blueness gradient is checked, not only a sample.
    #[test]
    fn blueness_grad_sparsity() {
        let r = BluenessReward { scale: 1.0 };
        let z = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
        let g = r.grad_wrt_z(&z);
        assert_eq!(g.len(), z.len());
        for (i, gi) in g.iter().enumerate() {
            let expected = if i % 3 == 2 { 1.0 } else { 0.0 };
            assert!((gi - expected).abs() < 1e-6, "g[{i}] = {gi}");
        }
    }

    #[test]
    fn blueness_ignores_partial_trailing_group() {
        let r = BluenessReward { scale: 2.0 };
        let z = [0.0, 0.0, 1.5, 7.0, 9.0];
        assert!((r.reward(&z) - 3.0).abs() < 1e-6);
        assert_eq!(r.grad_wrt_z(&z), vec![0.0, 0.0, 2.0, 0.0, 0.0]);
    }

    #[test]
    fn linear_measurement_grad_matches_analytic_and_finite_difference() {
        // A = [[1, 2], [3, -1]], y = [1, 0]; at z = [0.5, -0.25] the residual
        // A z - y is [-1, 1.75], so ∂r/∂z = -2 Aᵀ (A z - y) = [-8.5, 7.5].
        let r = LinearMeasurementReward::new(vec![1.0, 2.0, 3.0, -1.0], vec![1.0, 0.0], 2, 2);
        let z = [0.5f32, -0.25];
        let g = r.grad_wrt_z(&z);
        assert!((g[0] + 8.5).abs() < 1e-5 && (g[1] - 7.5).abs() < 1e-5, "{g:?}");

        // The reward is quadratic, so a central difference is exact up to rounding.
        let h = 1e-2f32;
        for (i, &gi) in g.iter().enumerate() {
            let mut zp = z;
            let mut zm = z;
            zp[i] += h;
            zm[i] -= h;
            let fd = (r.reward(&zp) - r.reward(&zm)) / (2.0 * h);
            assert!((gi - fd).abs() < 1e-2, "component {i}: analytic {gi} vs fd {fd}");
        }
    }

    #[test]
    fn linear_measurement_reward_is_zero_at_exact_fit() {
        let r = LinearMeasurementReward::new(vec![2.0, 0.0, 0.0, 4.0], vec![2.0, 8.0], 2, 2);
        assert!(r.reward(&[1.0, 2.0]).abs() < 1e-6);
    }

    #[test]
    fn grad_xt_via_z_is_identity_copy() {
        let gz = [0.5f32, -1.0, 2.0];
        assert_eq!(grad_xt_via_z(&gz), gz.to_vec());
    }

    #[test]
    fn spsa_recovers_slope_in_one_dimension() {
        // In 1-D every Rademacher draw has δ² = 1, so a linear V is recovered
        // up to floating-point rounding regardless of the drawn signs.
        let g = spsa_grad(&[2.0], |x| 3.0 * x[0], 1e-2, 8, 42);
        assert_eq!(g.len(), 1);
        assert!((g[0] - 3.0).abs() < 1e-2, "{g:?}");
    }
}