/// A scalar reward defined on a latent vector `z`, together with its gradient.
///
/// Implementors must be `Send + Sync`, so a reward can be shared across threads.
pub trait LatentReward: Send + Sync {
    /// Evaluates the reward at the latent vector `z`.
    fn reward(&self, z: &[f32]) -> f32;

    /// Evaluates the gradient of [`LatentReward::reward`] with respect to `z`.
    ///
    /// The implementations in this file return one entry per element of `z`.
    fn grad_wrt_z(&self, z: &[f32]) -> Vec<f32>;
}
24
/// Reward parameterised by a single multiplier.
///
/// NOTE(review): the name suggests the rewarded component is the blue channel
/// of interleaved RGB triples — confirm against the caller.
#[derive(Debug, Clone, Copy, Default)]
pub struct BluenessReward {
    /// Multiplier applied to the reward; the derived `Default` sets it to `0.0`.
    pub scale: f32,
}
30
31impl LatentReward for BluenessReward {
32 fn reward(&self, z: &[f32]) -> f32 {
33 let s: f32 = z.chunks(3).map(|c| c.get(2).copied().unwrap_or(0.0)).sum();
34 self.scale * s
35 }
36
37 fn grad_wrt_z(&self, z: &[f32]) -> Vec<f32> {
38 let mut g = vec![0.0f32; z.len()];
39 for (i, chunk) in z.chunks(3).enumerate() {
40 if chunk.len() > 2 {
41 g[i * 3 + 2] = self.scale;
42 }
43 }
44 g
45 }
46}
47
/// Data for a reward built from a linear map and a target measurement vector.
#[derive(Debug, Clone)]
pub struct LinearMeasurementReward {
    /// Matrix entries stored row-major: element `(r, c)` lives at `r * cols + c`.
    pub matrix: Vec<f32>,
    /// Target measurement, one entry per matrix row.
    pub measurement: Vec<f32>,
    /// Number of matrix rows.
    pub rows: usize,
    /// Number of matrix columns.
    pub cols: usize,
}
56
57impl LinearMeasurementReward {
58 pub fn new(matrix: Vec<f32>, measurement: Vec<f32>, rows: usize, cols: usize) -> Self {
59 assert_eq!(matrix.len(), rows * cols);
60 assert_eq!(measurement.len(), rows);
61 Self {
62 matrix,
63 measurement,
64 rows,
65 cols,
66 }
67 }
68
69 fn matvec(&self, z: &[f32]) -> Vec<f32> {
70 let mut out = vec![0.0f32; self.rows];
71 for r in 0..self.rows {
72 let mut acc = 0.0f32;
73 for c in 0..self.cols.min(z.len()) {
74 acc += self.matrix[r * self.cols + c] * z[c];
75 }
76 out[r] = acc;
77 }
78 out
79 }
80}
81
82impl LatentReward for LinearMeasurementReward {
83 fn reward(&self, z: &[f32]) -> f32 {
84 let pred = self.matvec(z);
85 let err: f32 = pred
86 .iter()
87 .zip(self.measurement.iter())
88 .map(|(p, m)| (p - m).powi(2))
89 .sum();
90 -err
91 }
92
93 fn grad_wrt_z(&self, z: &[f32]) -> Vec<f32> {
94 let pred = self.matvec(z);
95 let mut g = vec![0.0f32; z.len()];
96 for r in 0..self.rows {
97 let residual = 2.0 * (pred[r] - self.measurement[r]);
98 for c in 0..self.cols.min(z.len()) {
99 g[c] -= residual * self.matrix[r * self.cols + c];
100 }
101 }
102 g
103 }
104}
105
/// Returns an owned copy of `grad_z`, element for element.
///
/// NOTE(review): the name suggests this maps a gradient w.r.t. `z` to one
/// w.r.t. `x_t` under an identity relationship — confirm against the caller.
pub fn grad_xt_via_z(grad_z: &[f32]) -> Vec<f32> {
    let mut grad_x = Vec::with_capacity(grad_z.len());
    grad_x.extend_from_slice(grad_z);
    grad_x
}
110
/// Estimates the gradient of `eval_v` at `x_t` by simultaneous perturbation
/// (SPSA) with random ±1 directions and central differences.
///
/// For each of `num_dirs` directions, a sign vector `delta` is drawn from a
/// xorshift generator, `eval_v` is called at `x_t + eps * delta` and
/// `x_t - eps * delta` (in that order), and
/// `(v_plus - v_minus) / (2 * eps) * delta` is accumulated. The result is the
/// average over all directions and has the same length as `x_t`.
///
/// # Parameters
///
/// * `x_t` - point at which the gradient is estimated.
/// * `eval_v` - scalar function to differentiate; called `2 * num_dirs` times.
/// * `eps` - perturbation half-width. It is used as a divisor, so `eps == 0.0`
///   produces non-finite entries.
/// * `num_dirs` - number of random directions to average over. `0` yields an
///   all-zero gradient and `eval_v` is never called.
/// * `seed` - seed for the direction generator. The same seed reproduces the
///   same directions. A seed of `0` is replaced by a fixed non-zero constant.
pub fn spsa_grad(
    x_t: &[f32],
    mut eval_v: impl FnMut(&[f32]) -> f32,
    eps: f32,
    num_dirs: usize,
    seed: u64,
) -> Vec<f32> {
    // Zero is a fixed point of the xorshift update: a zero state never changes,
    // so every direction would be the same all-`+1.0` vector.
    const ZERO_SEED_FALLBACK: u64 = 0x9E37_79B9_7F4A_7C15;

    let dim = x_t.len();
    let mut grad = vec![0.0f32; dim];
    if num_dirs == 0 {
        // Averaging over zero samples would scale by `1.0 / 0.0` and turn the
        // zero gradient into NaNs.
        return grad;
    }

    let mut state = if seed == 0 { ZERO_SEED_FALLBACK } else { seed };

    // Scratch buffers reused across directions.
    let mut delta = vec![0.0f32; dim];
    let mut xp = vec![0.0f32; dim];
    let mut xm = vec![0.0f32; dim];

    for _ in 0..num_dirs {
        // One xorshift step per coordinate; the low bit picks the sign.
        for d in &mut delta {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            *d = if state & 1 == 0 { 1.0 } else { -1.0 };
        }

        for (((p, m), &x), &d) in xp.iter_mut().zip(xm.iter_mut()).zip(x_t).zip(&delta) {
            *p = x + eps * d;
            *m = x - eps * d;
        }

        let vp = eval_v(&xp);
        let vm = eval_v(&xm);
        // Central-difference slope along `delta`.
        let scale = (vp - vm) / (2.0 * eps);
        for (g, &d) in grad.iter_mut().zip(&delta) {
            *g += scale * d;
        }
    }

    let inv = 1.0 / num_dirs as f32;
    grad.iter_mut().for_each(|g| *g *= inv);
    grad
}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// With unit scale, the gradient is 1 at the third slot of each triple
    /// (indices 2 and 5) and 0 at index 0.
    #[test]
    fn blueness_grad_sparsity() {
        let latent = [0.1f32, 0.2, 0.3, 0.4, 0.5, 0.6];
        let grad = BluenessReward { scale: 1.0 }.grad_wrt_z(&latent);
        for (index, expected) in [(2usize, 1.0f32), (5, 1.0), (0, 0.0)] {
            assert!((grad[index] - expected).abs() < 1e-6);
        }
    }
}