use crate::env::{Env, ResetResult, StepResult};
use crate::error::Result;
use crate::macros::delegate_env;
/// Environment wrapper that rescales every reward by a running estimate of the
/// standard deviation of the discounted return.
///
/// Rewards are only scaled, never centered: the running mean is tracked solely to
/// update the variance estimate.
#[derive(Debug)]
pub struct NormalizeReward<E>
where
    E: Env,
{
    /// The wrapped environment.
    env: E,
    /// Discount factor applied to `ret` before each new reward is added.
    gamma: f64,
    /// Added to the variance estimate before taking the square root, keeping the
    /// divisor away from zero.
    epsilon: f64,
    /// Discounted return of the current episode; cleared on `reset` and whenever a
    /// step reports `terminated` or `truncated`.
    ret: f64,
    /// Number of rewards folded into the running statistics so far.
    count: f64,
    /// Running mean of `ret` across all processed steps.
    mean: f64,
    /// Running sum of squared deviations of `ret` from `mean` (Welford's M2),
    /// seeded with 1.0; divided by `count` to obtain the variance estimate.
    var: f64,
}
impl<E: Env> NormalizeReward<E> {
    /// Wraps `env`, discounting returns by `gamma` and padding the variance with `epsilon`.
    ///
    /// Statistics start empty (`count == 0`, `mean == 0`), with the squared-deviation
    /// accumulator seeded at `1.0`. As a result, the very first finite reward is divided
    /// by `sqrt(1 + epsilon)`.
    #[must_use]
    pub const fn new(env: E, gamma: f64, epsilon: f64) -> Self {
        Self {
            ret: 0.0,
            count: 0.0,
            mean: 0.0,
            var: 1.0,
            env,
            gamma,
            epsilon,
        }
    }

    /// Shared access to the wrapped environment.
    #[must_use]
    pub const fn inner(&self) -> &E {
        &self.env
    }

    /// Exclusive access to the wrapped environment.
    #[must_use]
    pub const fn inner_mut(&mut self) -> &mut E {
        &mut self.env
    }

    /// Discards the wrapper and its statistics, returning the wrapped environment.
    #[must_use]
    pub fn into_inner(self) -> E {
        let Self { env, .. } = self;
        env
    }

    /// Folds `reward` into the discounted return and updates the running statistics.
    /// Returns `reward` divided by the current standard-deviation estimate.
    fn normalize_reward(&mut self, reward: f64) -> f64 {
        // ret <- ret * gamma + reward, computed as a single fused multiply-add.
        let discounted = f64::mul_add(self.ret, self.gamma, reward);
        self.ret = discounted;

        // Welford's online update: deviation before and after moving the mean.
        self.count += 1.0;
        let before = discounted - self.mean;
        self.mean += before / self.count;
        let after = discounted - self.mean;
        self.var += before * after;

        let scale = f64::sqrt(self.var / self.count + self.epsilon);
        reward / scale
    }
}
impl<E: Env> Env for NormalizeReward<E> {
    type Obs = E::Obs;
    type Act = E::Act;
    type ObsSpace = E::ObsSpace;
    type ActSpace = E::ActSpace;

    /// Steps the wrapped environment and replaces its reward with the normalized value.
    ///
    /// An error from the wrapped environment is returned before any statistic changes.
    /// When the step ends the episode (terminated or truncated), the discounted return is
    /// cleared so that the next episode starts from zero.
    fn step(&mut self, action: &Self::Act) -> Result<StepResult<Self::Obs>> {
        let mut outcome = self.env.step(action)?;
        outcome.reward = self.normalize_reward(outcome.reward);

        let episode_over = outcome.terminated || outcome.truncated;
        if episode_over {
            self.ret = 0.0;
        }
        Ok(outcome)
    }

    /// Clears the discounted return, then resets the wrapped environment with `seed`.
    ///
    /// The running mean/variance statistics are left untouched.
    fn reset(&mut self, seed: Option<u64>) -> Result<ResetResult<Self::Obs>> {
        // Whatever return had accumulated belonged to the abandoned episode.
        self.ret = 0.0;
        self.env.reset(seed)
    }

    delegate_env!(env);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::envs::classic_control::{CartPoleConfig, CartPoleEnv};

    const GAMMA: f64 = 0.99;
    const EPS: f64 = 1e-8;

    /// Builds a default CartPole wrapped with the shared test parameters.
    fn make_env() -> NormalizeReward<CartPoleEnv> {
        let env = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        NormalizeReward::new(env, GAMMA, EPS)
    }

    #[test]
    fn rewards_are_finite() {
        let mut env = make_env();
        env.reset(Some(42)).unwrap();
        for _ in 0..100 {
            let r = env.step(&0).unwrap();
            assert!(r.reward.is_finite(), "reward should be finite");
            // Reset on either episode-ending signal so a finished episode is never stepped.
            if r.terminated || r.truncated {
                env.reset(None).unwrap();
            }
        }
    }

    #[test]
    fn first_step_divides_by_prior_std() {
        // With M2 seeded at 1.0 and count == 1, the first divisor is sqrt(1 + EPS).
        let mut raw = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let mut wrapped = make_env();
        raw.reset(Some(7)).unwrap();
        wrapped.reset(Some(7)).unwrap();
        let expected = raw.step(&0).unwrap().reward / (1.0 + EPS).sqrt();
        let got = wrapped.step(&0).unwrap().reward;
        assert!((got - expected).abs() < 1e-12, "got {got}, expected {expected}");
    }

    #[test]
    fn reset_clears_return() {
        let mut env = make_env();
        env.reset(Some(0)).unwrap();
        env.step(&0).unwrap();
        // Precondition: without it, the post-reset check below could pass vacuously.
        assert!(
            env.ret.abs() > f64::EPSILON,
            "a step should accumulate a non-zero return"
        );
        env.reset(Some(1)).unwrap();
        assert!((env.ret - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn episode_end_clears_return() {
        let mut env = make_env();
        env.reset(Some(0)).unwrap();
        let mut ended = false;
        // Always pushing the same way should end a CartPole episode quickly; the bound
        // only guards against an infinite loop.
        for _ in 0..1_000 {
            let r = env.step(&0).unwrap();
            if r.terminated || r.truncated {
                ended = true;
                break;
            }
        }
        assert!(ended, "episode should end within the step budget");
        assert!((env.ret - 0.0).abs() < f64::EPSILON);
    }
}