
oxicuda_rl/normalize/reward_norm.rs

//! # Reward Normalizer
//!
//! Normalizes rewards by a running estimate of the return's standard deviation,
//! following the convention in PPO and similar algorithms:
//!
//! ```text
//! G_t = γ G_{t-1} + r_t   (discounted running return)
//! normalised_r = clamp(r / (std(G) + ε), -clip, clip)
//! ```
//!
//! Three modes are supported:
//! * **Return normalisation** (PPO default): maintain running stats on the
//!   discounted return `G` and divide rewards by `std(G)`.
//! * **Reward clipping** (simple): clip raw rewards to `[-clip, clip]`.
//! * **None**: pass rewards through unchanged.
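//!
//! A minimal usage sketch with two parallel environments (PPO-style return
//! normalisation):
//!
//! ```ignore
//! let mut rn = RewardNormalizer::new(2, 0.99, 10.0, RewardNormMode::ReturnNorm);
//! // One vectorized step: a reward and a done flag per environment.
//! let r = rn.process(&[1.0, -0.5], &[0.0, 1.0]);
//! assert_eq!(r.len(), 2);
//! ```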

use crate::normalize::running_stats::RunningStats;

/// Reward normalization mode.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RewardNormMode {
    /// Divide by running std of discounted return.
    ReturnNorm,
    /// Clip rewards to `[-clip, clip]` (no normalisation).
    Clip,
    /// No normalisation or clipping.
    None,
}

/// Reward normalizer with optional return-based normalisation and clipping.
#[derive(Debug, Clone)]
pub struct RewardNormalizer {
    mode: RewardNormMode,
    gamma: f32,
    clip: f32,
    /// Running statistics on the discounted return.
    return_stats: RunningStats,
    /// Current discounted return accumulator (per environment, if vectorized).
    running_returns: Vec<f64>,
    n_envs: usize,
}

impl RewardNormalizer {
    /// Create a reward normalizer.
    ///
    /// * `n_envs` — number of parallel environments.
    /// * `gamma`  — discount factor for return accumulation.
    /// * `clip`   — symmetric clip range.
    /// * `mode`   — normalisation mode.
    #[must_use]
    pub fn new(n_envs: usize, gamma: f32, clip: f32, mode: RewardNormMode) -> Self {
        assert!(n_envs > 0, "n_envs must be > 0");
        Self {
            mode,
            gamma,
            clip,
            return_stats: RunningStats::new(1),
            running_returns: vec![0.0_f64; n_envs],
            n_envs,
        }
    }

    /// Process rewards for a step across `n_envs` parallel environments.
    ///
    /// Updates the running return estimates and normalizes.
    ///
    /// Returns the normalised / clipped rewards.
    ///
    /// # Panics
    ///
    /// Panics if `rewards.len() != n_envs` or `dones.len() != n_envs`.
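    ///
    /// # Example
    ///
    /// A sketch of a single vectorized step; the second environment finishes its
    /// episode (`done = 1.0`), so its return accumulator is reset before `r` is added:
    ///
    /// ```ignore
    /// let mut rn = RewardNormalizer::new(2, 0.99, 10.0, RewardNormMode::ReturnNorm);
    /// let r = rn.process(&[0.5, 2.0], &[0.0, 1.0]);
    /// assert_eq!(r.len(), 2);
    /// ```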
    pub fn process(&mut self, rewards: &[f32], dones: &[f32]) -> Vec<f32> {
        assert_eq!(rewards.len(), self.n_envs);
        assert_eq!(dones.len(), self.n_envs);

        match self.mode {
            RewardNormMode::None => rewards.to_vec(),
            RewardNormMode::Clip => rewards
                .iter()
                .map(|&r| r.clamp(-self.clip, self.clip))
                .collect(),
            RewardNormMode::ReturnNorm => {
                // Update running returns: G_t = γ G_{t-1} (1 - done) + r_t,
                // i.e. the accumulator is zeroed whenever an episode ends.
                for (i, (&r, &d)) in rewards.iter().zip(dones.iter()).enumerate() {
                    self.running_returns[i] =
                        self.gamma as f64 * self.running_returns[i] * (1.0 - d as f64) + r as f64;
                    // Update the running return statistics with the new value
                    let _ = self.return_stats.update(&[self.running_returns[i] as f32]);
                }
                let std = self.return_stats.std_f32()[0];
                rewards
                    .iter()
                    .map(|&r| (r / (std + 1e-8)).clamp(-self.clip, self.clip))
                    .collect()
            }
        }
    }

    /// Normalise a batch of rewards without updating running statistics (evaluation).
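    ///
    /// A sketch of evaluation-time use: the running statistics come from earlier
    /// `process` calls and are left untouched here.
    ///
    /// ```ignore
    /// let mut rn = RewardNormalizer::new(1, 0.99, 5.0, RewardNormMode::ReturnNorm);
    /// for _ in 0..100 {
    ///     rn.process(&[1.0], &[0.0]); // training: updates running return stats
    /// }
    /// let r_eval = rn.normalise_eval(&[1.0]); // evaluation: stats stay frozen
    /// assert!(r_eval[0].abs() <= 5.0);
    /// ```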
    pub fn normalise_eval(&self, rewards: &[f32]) -> Vec<f32> {
        match self.mode {
            RewardNormMode::None => rewards.to_vec(),
            RewardNormMode::Clip => rewards
                .iter()
                .map(|&r| r.clamp(-self.clip, self.clip))
                .collect(),
            RewardNormMode::ReturnNorm => {
                let std = self.return_stats.std_f32()[0];
                rewards
                    .iter()
                    .map(|&r| (r / (std + 1e-8)).clamp(-self.clip, self.clip))
                    .collect()
            }
        }
    }

    /// Reset running returns (call at episode start).
    pub fn reset_returns(&mut self) {
        self.running_returns.iter_mut().for_each(|g| *g = 0.0);
    }
}

// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn clip_mode_clips_rewards() {
        let mut rn = RewardNormalizer::new(2, 0.99, 1.0, RewardNormMode::Clip);
        let r = rn.process(&[5.0, -5.0], &[0.0, 0.0]);
        assert!((r[0] - 1.0).abs() < 1e-5, "r[0]={}", r[0]);
        assert!((r[1] + 1.0).abs() < 1e-5, "r[1]={}", r[1]);
    }

    #[test]
    fn none_mode_passthrough() {
        let mut rn = RewardNormalizer::new(3, 0.99, 10.0, RewardNormMode::None);
        let r = rn.process(&[1.0, 2.0, 3.0], &[0.0, 0.0, 0.0]);
        assert_eq!(r, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn return_norm_output_within_clip() {
        let mut rn = RewardNormalizer::new(1, 0.99, 5.0, RewardNormMode::ReturnNorm);
        // Feed 200 steps to build running stats
        for _ in 0..200 {
            let r = rn.process(&[1.0], &[0.0]);
            assert!(r[0].abs() <= 5.0 + 1e-4, "clipped |r|={}", r[0].abs());
        }
    }

    #[test]
    fn done_resets_return() {
        let mut rn = RewardNormalizer::new(1, 0.99, 10.0, RewardNormMode::ReturnNorm);
        for _ in 0..10 {
            rn.process(&[1.0], &[0.0]);
        }
        let g_before = rn.running_returns[0];
        rn.process(&[1.0], &[1.0]); // done=1 → reset return
        let g_after = rn.running_returns[0];
        // After done: G = γ·G_before·(1 - 1) + r = 1.0, whereas without the reset
        // the return would have kept growing past g_before.
        assert!(g_after < g_before, "done should reset return, G={}", g_after);
        assert!((g_after - 1.0).abs() < 1e-6, "expected G = r = 1.0, got {}", g_after);
    }

    #[test]
    fn normalise_eval_no_stat_change() {
        let rn = RewardNormalizer::new(1, 0.99, 5.0, RewardNormMode::ReturnNorm);
        let before = rn.return_stats.count();
        let _ = rn.normalise_eval(&[1.0, 2.0]);
        let after = rn.return_stats.count();
        assert_eq!(before, after, "eval should not update stats");
    }
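
    // Sketch of an additional check: `reset_returns` zeroes the per-env return
    // accumulators while leaving the accumulated running statistics untouched.
    #[test]
    fn reset_returns_zeroes_accumulators() {
        let mut rn = RewardNormalizer::new(2, 0.99, 10.0, RewardNormMode::ReturnNorm);
        for _ in 0..5 {
            rn.process(&[1.0, 2.0], &[0.0, 0.0]);
        }
        assert!(rn.running_returns.iter().all(|&g| g != 0.0));
        let count_before = rn.return_stats.count();
        rn.reset_returns();
        assert!(rn.running_returns.iter().all(|&g| g == 0.0));
        assert_eq!(rn.return_stats.count(), count_before, "stats should be untouched");
    }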
}