// oxicuda_rl/normalize/obs_norm.rs
use crate::error::{RlError, RlResult};
use crate::normalize::running_stats::RunningStats;
8
9#[derive(Debug, Clone)]
11pub struct ObservationNormalizer {
12 stats: RunningStats,
13 enabled: bool,
15 clip: f32,
17 update_stats: bool,
19}
20
21impl ObservationNormalizer {
22 #[must_use]
26 pub fn new(obs_dim: usize) -> Self {
27 Self {
28 stats: RunningStats::new(obs_dim),
29 enabled: true,
30 clip: 5.0,
31 update_stats: true,
32 }
33 }
34
35 #[must_use]
37 pub fn with_no_update(mut self) -> Self {
38 self.update_stats = false;
39 self
40 }
41
42 #[must_use]
44 pub fn with_clip(mut self, clip: f32) -> Self {
45 self.clip = clip;
46 self
47 }
48
49 pub fn disable(&mut self) {
51 self.enabled = false;
52 }
53
54 pub fn enable(&mut self) {
56 self.enabled = true;
57 }
58
59 #[must_use]
61 pub fn count(&self) -> u64 {
62 self.stats.count()
63 }
64
65 pub fn process(&mut self, batch: &[f32]) -> RlResult<Vec<f32>> {
75 let obs_dim = self.stats.dim();
76 if batch.len() % obs_dim != 0 {
77 return Err(RlError::DimensionMismatch {
78 expected: obs_dim,
79 got: batch.len() % obs_dim,
80 });
81 }
82 if !self.enabled {
83 return Ok(batch.to_vec());
84 }
85 if self.update_stats {
86 self.stats.update_batch(batch)?;
87 }
88 let mut out = Vec::with_capacity(batch.len());
89 for chunk in batch.chunks_exact(obs_dim) {
90 let norm = self.stats.normalise(chunk)?;
91 for v in norm {
92 out.push(v.clamp(-self.clip, self.clip));
93 }
94 }
95 Ok(out)
96 }
97
98 pub fn process_one(&mut self, obs: &[f32]) -> RlResult<Vec<f32>> {
104 let obs_dim = self.stats.dim();
105 if obs.len() != obs_dim {
106 return Err(RlError::DimensionMismatch {
107 expected: obs_dim,
108 got: obs.len(),
109 });
110 }
111 if !self.enabled {
112 return Ok(obs.to_vec());
113 }
114 if self.update_stats {
115 self.stats.update(obs)?;
116 }
117 let norm = self.stats.normalise(obs)?;
118 Ok(norm
119 .into_iter()
120 .map(|v| v.clamp(-self.clip, self.clip))
121 .collect())
122 }
123
124 #[must_use]
126 pub fn stats(&self) -> &RunningStats {
127 &self.stats
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// A disabled normaliser hands the observation back untouched.
    #[test]
    fn disabled_passthrough() {
        let input = vec![1.0, 2.0, 3.0];
        let mut normalizer = ObservationNormalizer::new(3);
        normalizer.disable();
        let output = normalizer.process_one(&input).unwrap();
        assert_eq!(output, input);
    }

    /// After many zero samples, a huge value must not exceed the clip bound.
    #[test]
    fn normalise_clips_extreme() {
        let mut normalizer = ObservationNormalizer::new(1).with_clip(2.0);
        let mut remaining = 200;
        while remaining > 0 {
            normalizer.process_one(&[0.0]).unwrap();
            remaining -= 1;
        }
        let result = normalizer.process_one(&[1000.0]).unwrap();
        let clipped = result[0];
        assert!(clipped <= 2.0 + 1e-3, "clipped={}", clipped);
    }

    /// Each `process_one` call with updating on bumps the sample count by one.
    #[test]
    fn count_increments() {
        let mut normalizer = ObservationNormalizer::new(2);
        (0..10).for_each(|_| {
            normalizer.process_one(&[1.0, 2.0]).unwrap();
        });
        assert_eq!(normalizer.count(), 10);
    }

    /// Feeding observations one by one or as one flat batch yields the same
    /// per-dimension standard deviation.
    #[test]
    fn batch_same_as_sequential() {
        let flat = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];

        let mut sequential = ObservationNormalizer::new(2);
        for pair in flat.chunks_exact(2) {
            sequential.process_one(pair).unwrap();
        }

        let mut batched = ObservationNormalizer::new(2);
        batched.process(&flat).unwrap();

        let seq_stds = sequential.stats().std_f32();
        let bat_stds = batched.stats().std_f32();
        for (a, b) in seq_stds.iter().zip(bat_stds.iter()) {
            assert!((a - b).abs() < 1e-4, "seq_std={a} bat_std={b}");
        }
    }

    /// An observation of the wrong length is rejected.
    #[test]
    fn dimension_mismatch_error() {
        let mut normalizer = ObservationNormalizer::new(4);
        let outcome = normalizer.process_one(&[1.0, 2.0]);
        assert!(outcome.is_err());
    }
}