use crate::env::{Env, ResetResult, StepResult};
use crate::error::Result;
use crate::macros::delegate_env;
/// Environment wrapper that rescales `Vec<f32>` observations with running,
/// per-dimension statistics.
///
/// The statistics are updated by every observation returned from `step` and
/// `reset`, and the same observation is then shifted by the running mean and
/// divided by the running standard deviation.
#[derive(Debug)]
pub struct NormalizeObservation<E>
where
    E: Env<Obs = Vec<f32>>,
{
    /// Wrapped environment that produces the raw observations.
    env: E,
    /// Number of observations folded into the statistics so far.
    count: f64,
    /// Running mean per observation dimension; empty until the first
    /// observation has been seen.
    mean: Vec<f64>,
    /// Per-dimension accumulator of squared deviations (seeded with `1.0`).
    /// It is divided by `count` to obtain the variance estimate.
    var: Vec<f64>,
    /// Constant added to the variance estimate before the square root is taken.
    epsilon: f64,
}
impl<E> NormalizeObservation<E>
where
    E: Env<Obs = Vec<f32>>,
{
    /// Wraps `env` with empty running statistics.
    ///
    /// `epsilon` is added to the variance estimate before the square root is
    /// taken when an observation is normalized.
    #[must_use]
    pub const fn new(env: E, epsilon: f64) -> Self {
        Self {
            env,
            epsilon,
            count: 0.0,
            mean: Vec::new(),
            var: Vec::new(),
        }
    }

    /// Returns a shared reference to the wrapped environment.
    #[must_use]
    pub const fn inner(&self) -> &E {
        &self.env
    }

    /// Returns an exclusive reference to the wrapped environment.
    #[must_use]
    pub const fn inner_mut(&mut self) -> &mut E {
        &mut self.env
    }

    /// Consumes the wrapper and hands back the wrapped environment.
    #[must_use]
    pub fn into_inner(self) -> E {
        let Self { env, .. } = self;
        env
    }

    /// Folds `obs` into the running statistics and returns it normalized.
    ///
    /// While `mean` is still empty, the per-dimension buffers are sized to
    /// `obs.len()`: means start at `0.0` and the squared-deviation
    /// accumulators at `1.0`. Every call then increments `count` and, for each
    /// element, updates the mean incrementally and adds
    /// `(v - old_mean) * (v - new_mean)` to the accumulator. The returned value
    /// is `(v - new_mean) / sqrt(var / count + epsilon)`, narrowed to `f32`.
    ///
    /// # Panics
    ///
    /// Panics on out-of-bounds indexing if `obs` has more elements than the
    /// buffers sized by the first non-empty observation.
    // The statistics are kept in `f64`; narrowing the result back to the
    // observation's `f32` element type is intentional.
    #[allow(clippy::cast_possible_truncation)]
    fn normalize(&mut self, obs: &[f32]) -> Vec<f32> {
        if self.mean.is_empty() {
            let dims = obs.len();
            self.mean = vec![0.0; dims];
            self.var = vec![1.0; dims];
        }

        self.count += 1.0;
        let seen = self.count;
        let eps = self.epsilon;

        let mut scaled = Vec::with_capacity(obs.len());
        for (idx, raw) in obs.iter().copied().map(f64::from).enumerate() {
            // Deviation from the mean before and after this sample is folded in.
            let before = raw - self.mean[idx];
            self.mean[idx] += before / seen;
            let after = raw - self.mean[idx];
            self.var[idx] += before * after;

            let std_dev = (self.var[idx] / seen + eps).sqrt();
            scaled.push((after / std_dev) as f32);
        }
        scaled
    }
}
/// `Env` implementation that forwards to the wrapped environment and
/// normalizes the observation carried by every `step` and `reset` result.
impl<E> Env for NormalizeObservation<E>
where
    E: Env<Obs = Vec<f32>>,
{
    type Obs = Vec<f32>;
    type Act = E::Act;
    type ObsSpace = E::ObsSpace;
    type ActSpace = E::ActSpace;

    /// Steps the wrapped environment with `action`, then replaces the
    /// observation in the result with its normalized counterpart.
    ///
    /// An error from the wrapped environment is returned as-is, and in that
    /// case the running statistics are left untouched.
    fn step(&mut self, action: &Self::Act) -> Result<StepResult<Vec<f32>>> {
        let mut outcome = self.env.step(action)?;
        let scaled = self.normalize(&outcome.obs);
        outcome.obs = scaled;
        Ok(outcome)
    }

    /// Resets the wrapped environment with `seed`, then replaces the initial
    /// observation with its normalized counterpart.
    ///
    /// The running statistics are not cleared here; the initial observation is
    /// folded into them like any other.
    fn reset(&mut self, seed: Option<u64>) -> Result<ResetResult<Vec<f32>>> {
        let mut outcome = self.env.reset(seed)?;
        let scaled = self.normalize(&outcome.obs);
        outcome.obs = scaled;
        Ok(outcome)
    }

    // NOTE(review): presumably forwards the remaining `Env` items to `self.env`
    // unchanged — confirm against the macro definition.
    delegate_env!(env);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::envs::classic_control::{CartPoleConfig, CartPoleEnv};

    /// Runs 50 steps of a wrapped CartPole with a fixed action and checks that
    /// every normalized observation component is finite, resetting whenever an
    /// episode terminates.
    #[test]
    fn observations_are_normalized() {
        let base = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let mut wrapped = NormalizeObservation::new(base, 1e-8);
        wrapped.reset(Some(42)).unwrap();

        (0..50).for_each(|_| {
            let outcome = wrapped.step(&0).unwrap();
            outcome.obs.iter().copied().for_each(|v| {
                assert!(v.is_finite(), "observation should be finite: {v}");
            });
            if outcome.terminated {
                wrapped.reset(None).unwrap();
            }
        });
    }

    /// The very first observation, normalized right after construction, must
    /// already be finite in every component.
    #[test]
    fn first_observation_is_finite() {
        let base = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let mut wrapped = NormalizeObservation::new(base, 1e-8);

        let initial = wrapped.reset(Some(0)).unwrap();
        initial.obs.iter().copied().for_each(|v| {
            assert!(v.is_finite());
        });
    }
}