// oxicuda_rl/normalize/obs_norm.rs
//! # Observation Normalizer
//!
//! Wraps [`RunningStats`] to provide a stateful normalizer that can be
//! toggled on/off and clipped to prevent extreme values.

use crate::error::{RlError, RlResult};
use crate::normalize::running_stats::RunningStats;
/// Observation normalizer with running statistics.
///
/// Holds a [`RunningStats`] instance plus three knobs: an on/off switch,
/// a symmetric clip range applied after normalisation, and a flag that
/// controls whether processed observations also update the statistics.
#[derive(Debug, Clone)]
pub struct ObservationNormalizer {
    // Underlying running statistics (see [`RunningStats`]).
    stats: RunningStats,
    /// Whether normalisation is active.
    enabled: bool,
    /// Clip range after normalisation (symmetric: `[-clip, clip]`).
    clip: f32,
    /// Whether to update statistics when processing observations.
    update_stats: bool,
}

21impl ObservationNormalizer {
22    /// Create a normalizer for observations of dimension `obs_dim`.
23    ///
24    /// Default: enabled, `clip = 5.0`, `update_stats = true`.
25    #[must_use]
26    pub fn new(obs_dim: usize) -> Self {
27        Self {
28            stats: RunningStats::new(obs_dim),
29            enabled: true,
30            clip: 5.0,
31            update_stats: true,
32        }
33    }
34
35    /// Disable statistics update (useful at evaluation time).
36    #[must_use]
37    pub fn with_no_update(mut self) -> Self {
38        self.update_stats = false;
39        self
40    }
41
42    /// Set the clip range.
43    #[must_use]
44    pub fn with_clip(mut self, clip: f32) -> Self {
45        self.clip = clip;
46        self
47    }
48
49    /// Disable normalisation (pass-through mode).
50    pub fn disable(&mut self) {
51        self.enabled = false;
52    }
53
54    /// Enable normalisation.
55    pub fn enable(&mut self) {
56        self.enabled = true;
57    }
58
59    /// Number of observations seen.
60    #[must_use]
61    pub fn count(&self) -> u64 {
62        self.stats.count()
63    }
64
65    /// Process a batch of observations: optionally update stats, normalise, clip.
66    ///
67    /// * `batch` — `[B × obs_dim]` flat slice.
68    ///
69    /// Returns a normalised `[B × obs_dim]` vector.
70    ///
71    /// # Errors
72    ///
73    /// * [`RlError::DimensionMismatch`] if `batch.len() % obs_dim != 0`.
74    pub fn process(&mut self, batch: &[f32]) -> RlResult<Vec<f32>> {
75        let obs_dim = self.stats.dim();
76        if batch.len() % obs_dim != 0 {
77            return Err(RlError::DimensionMismatch {
78                expected: obs_dim,
79                got: batch.len() % obs_dim,
80            });
81        }
82        if !self.enabled {
83            return Ok(batch.to_vec());
84        }
85        if self.update_stats {
86            self.stats.update_batch(batch)?;
87        }
88        let mut out = Vec::with_capacity(batch.len());
89        for chunk in batch.chunks_exact(obs_dim) {
90            let norm = self.stats.normalise(chunk)?;
91            for v in norm {
92                out.push(v.clamp(-self.clip, self.clip));
93            }
94        }
95        Ok(out)
96    }
97
98    /// Process a single observation.
99    ///
100    /// # Errors
101    ///
102    /// * [`RlError::DimensionMismatch`] if `obs.len() != obs_dim`.
103    pub fn process_one(&mut self, obs: &[f32]) -> RlResult<Vec<f32>> {
104        let obs_dim = self.stats.dim();
105        if obs.len() != obs_dim {
106            return Err(RlError::DimensionMismatch {
107                expected: obs_dim,
108                got: obs.len(),
109            });
110        }
111        if !self.enabled {
112            return Ok(obs.to_vec());
113        }
114        if self.update_stats {
115            self.stats.update(obs)?;
116        }
117        let norm = self.stats.normalise(obs)?;
118        Ok(norm
119            .into_iter()
120            .map(|v| v.clamp(-self.clip, self.clip))
121            .collect())
122    }
123
124    /// Access the underlying running statistics.
125    #[must_use]
126    pub fn stats(&self) -> &RunningStats {
127        &self.stats
128    }
129}
130
// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn disabled_passthrough() {
        // A disabled normalizer must return its input unchanged.
        let mut normalizer = ObservationNormalizer::new(3);
        normalizer.disable();
        let input = vec![1.0, 2.0, 3.0];
        let output = normalizer.process_one(&input).unwrap();
        assert_eq!(output, input);
    }

    #[test]
    fn normalise_clips_extreme() {
        let mut normalizer = ObservationNormalizer::new(1).with_clip(2.0);
        // Seed with many identical values so the std estimate settles.
        (0..200).for_each(|_| {
            normalizer.process_one(&[0.0]).unwrap();
        });
        // An extreme outlier must be clamped into the clip range.
        let out = normalizer.process_one(&[1000.0]).unwrap();
        assert!(out[0] <= 2.0 + 1e-3, "clipped={}", out[0]);
    }

    #[test]
    fn count_increments() {
        let mut normalizer = ObservationNormalizer::new(2);
        for _ in 0..10 {
            normalizer.process_one(&[1.0, 2.0]).unwrap();
        }
        assert_eq!(normalizer.count(), 10);
    }

    #[test]
    fn batch_same_as_sequential() {
        // Feeding observations one at a time and as one batch must yield
        // matching running statistics.
        let mut sequential = ObservationNormalizer::new(2);
        let mut batched = ObservationNormalizer::new(2);
        let observations = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3 × dim 2
        for single in observations.chunks_exact(2) {
            sequential.process_one(single).unwrap();
        }
        batched.process(&observations).unwrap();
        let std_seq = sequential.stats().std_f32();
        let std_bat = batched.stats().std_f32();
        for (a, b) in std_seq.iter().zip(std_bat.iter()) {
            assert!((a - b).abs() < 1e-4, "seq_std={a} bat_std={b}");
        }
    }

    #[test]
    fn dimension_mismatch_error() {
        // Wrong observation length must surface as an error, not a panic.
        let mut normalizer = ObservationNormalizer::new(4);
        assert!(normalizer.process_one(&[1.0, 2.0]).is_err());
    }
}