// oxicuda_rl/normalize/reward_norm.rs
use crate::normalize::running_stats::RunningStats;
/// Strategy used by `RewardNormalizer` to transform raw environment rewards.
// `Eq` is derivable (fieldless enum) and expected alongside `PartialEq`
// (clippy: derive_partial_eq_without_eq).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RewardNormMode {
    /// Scale each reward by the running std of the discounted return,
    /// then clip to `[-clip, clip]`.
    ReturnNorm,
    /// Clip raw rewards to `[-clip, clip]` without rescaling.
    Clip,
    /// Pass rewards through unchanged.
    None,
}
28
/// Normalizes per-environment rewards for vectorized RL training.
///
/// The transformation applied by `process` depends on `mode`; see
/// `RewardNormMode` for the available strategies.
#[derive(Debug, Clone)]
pub struct RewardNormalizer {
    /// Which normalization strategy `process` applies.
    mode: RewardNormMode,
    /// Discount factor used to accumulate per-env returns (ReturnNorm mode).
    gamma: f32,
    /// Symmetric clipping bound applied to output rewards.
    clip: f32,
    /// Running statistics over the discounted returns (single-element stats).
    return_stats: RunningStats,
    /// Per-environment discounted return accumulators, kept in f64 for precision.
    running_returns: Vec<f64>,
    /// Number of parallel environments; fixes the expected slice lengths.
    n_envs: usize,
}
41
42impl RewardNormalizer {
43 #[must_use]
50 pub fn new(n_envs: usize, gamma: f32, clip: f32, mode: RewardNormMode) -> Self {
51 assert!(n_envs > 0, "n_envs must be > 0");
52 Self {
53 mode,
54 gamma,
55 clip,
56 return_stats: RunningStats::new(1),
57 running_returns: vec![0.0_f64; n_envs],
58 n_envs,
59 }
60 }
61
62 pub fn process(&mut self, rewards: &[f32], dones: &[f32]) -> Vec<f32> {
72 assert_eq!(rewards.len(), self.n_envs);
73 assert_eq!(dones.len(), self.n_envs);
74
75 match self.mode {
76 RewardNormMode::None => rewards.to_vec(),
77 RewardNormMode::Clip => rewards
78 .iter()
79 .map(|&r| r.clamp(-self.clip, self.clip))
80 .collect(),
81 RewardNormMode::ReturnNorm => {
82 for (i, (&r, &d)) in rewards.iter().zip(dones.iter()).enumerate() {
84 self.running_returns[i] =
85 self.gamma as f64 * self.running_returns[i] * (1.0 - d as f64) + r as f64;
86 let _ = self.return_stats.update(&[self.running_returns[i] as f32]);
88 }
89 let std = self.return_stats.std_f32()[0];
90 rewards
91 .iter()
92 .map(|&r| (r / (std + 1e-8)).clamp(-self.clip, self.clip))
93 .collect()
94 }
95 }
96 }
97
98 pub fn normalise_eval(&self, rewards: &[f32]) -> Vec<f32> {
100 match self.mode {
101 RewardNormMode::None => rewards.to_vec(),
102 RewardNormMode::Clip => rewards
103 .iter()
104 .map(|&r| r.clamp(-self.clip, self.clip))
105 .collect(),
106 RewardNormMode::ReturnNorm => {
107 let std = self.return_stats.std_f32()[0];
108 rewards
109 .iter()
110 .map(|&r| (r / (std + 1e-8)).clamp(-self.clip, self.clip))
111 .collect()
112 }
113 }
114 }
115
116 pub fn reset_returns(&mut self) {
118 self.running_returns.iter_mut().for_each(|g| *g = 0.0);
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Clip mode must bound rewards symmetrically at ±clip.
    #[test]
    fn clip_mode_clips_rewards() {
        let mut norm = RewardNormalizer::new(2, 0.99, 1.0, RewardNormMode::Clip);
        let out = norm.process(&[5.0, -5.0], &[0.0, 0.0]);
        for (got, want) in out.iter().zip([1.0_f32, -1.0]) {
            assert!((got - want).abs() < 1e-5, "got={got}");
        }
    }

    /// None mode returns the input untouched.
    #[test]
    fn none_mode_passthrough() {
        let mut norm = RewardNormalizer::new(3, 0.99, 10.0, RewardNormMode::None);
        let input = [1.0_f32, 2.0, 3.0];
        assert_eq!(norm.process(&input, &[0.0; 3]), input.to_vec());
    }

    /// ReturnNorm output must never exceed the clip bound.
    #[test]
    fn return_norm_output_within_clip() {
        let mut norm = RewardNormalizer::new(1, 0.99, 5.0, RewardNormMode::ReturnNorm);
        for _ in 0..200 {
            let out = norm.process(&[1.0], &[0.0]);
            assert!(out[0].abs() <= 5.0 + 1e-4, "clipped |r|={}", out[0].abs());
        }
    }

    /// A done flag must zero the carried discounted return.
    #[test]
    fn done_resets_return() {
        let mut norm = RewardNormalizer::new(1, 0.99, 10.0, RewardNormMode::ReturnNorm);
        for _ in 0..10 {
            norm.process(&[1.0], &[0.0]);
        }
        let before = norm.running_returns[0];
        norm.process(&[1.0], &[1.0]);
        let after = norm.running_returns[0];
        assert!(
            after.abs() < before.abs() + 1.0 + 1e-3,
            "done should reset return"
        );
    }

    /// Evaluation-time normalization is read-only on the running stats.
    #[test]
    fn normalise_eval_no_stat_change() {
        let norm = RewardNormalizer::new(1, 0.99, 5.0, RewardNormMode::ReturnNorm);
        let count_before = norm.return_stats.count();
        let _ = norm.normalise_eval(&[1.0, 2.0]);
        assert_eq!(
            count_before,
            norm.return_stats.count(),
            "eval should not update stats"
        );
    }
}