use crate::env::{Env, ResetResult, StepResult};
use crate::error::{Error, Result};
use crate::macros::delegate_env;
use crate::space::BoundedSpace;
/// Environment wrapper that linearly rescales every observation component.
///
/// Each component is mapped from the wrapped environment's observation-space
/// interval (`low[i]..high[i]`) onto a caller-chosen target interval, and the
/// wrapper advertises a matching [`BoundedSpace`] as its own observation space.
/// Only environments whose observations are `Vec<f32>` described by a
/// [`BoundedSpace`] can be wrapped.
#[derive(Debug)]
pub struct RescaleObservation<E>
where
E: Env<Obs = Vec<f32>, ObsSpace = BoundedSpace>,
{
/// The wrapped environment; all calls are ultimately served by it.
env: E,
/// Per-dimension lower target bound (the constructor fills it with one scalar).
min_obs: Vec<f32>,
/// Per-dimension upper target bound (the constructor fills it with one scalar).
max_obs: Vec<f32>,
/// Observation space built from `min_obs`/`max_obs`, returned by `observation_space`.
new_space: BoundedSpace,
}
impl<E> RescaleObservation<E>
where
    E: Env<Obs = Vec<f32>, ObsSpace = BoundedSpace>,
{
    /// Wraps `env` so that every observation component is mapped linearly from
    /// the inner space's `[low, high]` interval onto `[min_obs, max_obs]`.
    ///
    /// The same scalar target bounds are used for every dimension; the number
    /// of dimensions is taken from the length of the inner space's `low`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidSpace`] when
    /// * `min_obs >= max_obs`,
    /// * either bound is NaN or infinite, or `max_obs - min_obs` overflows
    ///   `f32` (any of these would make the rescaled values NaN/infinite), or
    /// * `BoundedSpace::new` rejects the target bounds.
    pub fn new(env: E, min_obs: f32, max_obs: f32) -> Result<Self> {
        if min_obs >= max_obs {
            return Err(Error::InvalidSpace {
                reason: format!("min_obs ({min_obs}) must be < max_obs ({max_obs})"),
            });
        }
        // `>=` is false whenever a NaN is involved, so NaN bounds slip past the
        // ordering check above. Testing the span catches NaN, +/-infinity and
        // finite bounds whose difference overflows, all in one comparison.
        if !(max_obs - min_obs).is_finite() {
            return Err(Error::InvalidSpace {
                reason: format!(
                    "min_obs ({min_obs}) and max_obs ({max_obs}) must be finite \
                     and span a finite range"
                ),
            });
        }

        let dim = env.observation_space().low.len();
        let low = vec![min_obs; dim];
        let high = vec![max_obs; dim];
        // The vectors are cloned because the wrapper keeps its own copies for
        // `rescale` in addition to the ones owned by the new space.
        let new_space =
            BoundedSpace::new(low.clone(), high.clone()).map_err(|e| Error::InvalidSpace {
                reason: format!("failed to create rescaled observation space: {e}"),
            })?;

        Ok(Self {
            env,
            min_obs: low,
            max_obs: high,
            new_space,
        })
    }

    /// Returns a shared reference to the wrapped environment.
    #[must_use]
    pub const fn inner(&self) -> &E {
        &self.env
    }

    /// Returns an exclusive reference to the wrapped environment.
    #[must_use]
    pub const fn inner_mut(&mut self) -> &mut E {
        &mut self.env
    }

    /// Consumes the wrapper and returns the wrapped environment.
    #[must_use]
    pub fn into_inner(self) -> E {
        self.env
    }

    /// Maps each component of `obs` from the inner space's interval onto the
    /// target interval of the same dimension.
    ///
    /// A dimension whose inner interval is narrower than `f32::EPSILON` cannot
    /// be normalised and is mapped to its lower target bound. Components are
    /// paired positionally with the bounds, so if `obs` and the bound vectors
    /// differ in length the output has the length of the shortest one.
    ///
    /// Inner bounds that are infinite (or whose difference overflows `f32`)
    /// are not treated specially; the value produced for such a dimension is
    /// not meaningful.
    fn rescale(&self, obs: &[f32]) -> Vec<f32> {
        // Read the inner bounds on every call rather than caching them.
        let inner = self.env.observation_space();
        obs.iter()
            .zip(inner.low.iter().zip(inner.high.iter()))
            .zip(self.min_obs.iter().zip(self.max_obs.iter()))
            .map(|((&v, (&old_lo, &old_hi)), (&new_lo, &new_hi))| {
                let range = old_hi - old_lo;
                if range.abs() < f32::EPSILON {
                    // Degenerate source interval: avoid dividing by ~0.
                    new_lo
                } else {
                    // Normalise to [0, 1] within the old interval, then
                    // stretch/shift into the new one with a fused multiply-add.
                    let t = (v - old_lo) / range;
                    t.mul_add(new_hi - new_lo, new_lo)
                }
            })
            .collect()
    }
}
impl<E> Env for RescaleObservation<E>
where
    E: Env<Obs = Vec<f32>, ObsSpace = BoundedSpace>,
{
    type Obs = Vec<f32>;
    type Act = E::Act;
    type ObsSpace = BoundedSpace;
    type ActSpace = E::ActSpace;

    /// Steps the wrapped environment and swaps the returned observation for
    /// its rescaled counterpart; every other field of the step result is
    /// handed back untouched. Errors from the wrapped environment propagate.
    fn step(&mut self, action: &Self::Act) -> Result<StepResult<Self::Obs>> {
        let mut transition = self.env.step(action)?;
        transition.obs = self.rescale(&transition.obs);
        Ok(transition)
    }

    /// Resets the wrapped environment with `seed` and rescales the initial
    /// observation; `info` is handed back untouched. Errors from the wrapped
    /// environment propagate.
    fn reset(&mut self, seed: Option<u64>) -> Result<ResetResult<Self::Obs>> {
        let mut initial = self.env.reset(seed)?;
        initial.obs = self.rescale(&initial.obs);
        Ok(initial)
    }

    /// Returns the rescaled observation space, not the wrapped environment's.
    fn observation_space(&self) -> &Self::ObsSpace {
        &self.new_space
    }

    // Presumably generates the listed methods as plain forwards to `self.env`
    // (the macro is defined in `crate::macros`; confirm there).
    delegate_env!(env, render, close, render_mode, action_space);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::envs::classic_control::{PendulumConfig, PendulumEnv};
    use crate::space::Space;

    /// Default-configured pendulum used as the wrapped environment in every test.
    fn pendulum() -> PendulumEnv {
        PendulumEnv::new(PendulumConfig::default()).unwrap()
    }

    /// Resetting through the wrapper yields a 3-component, finite observation.
    #[test]
    fn rescales_observations() {
        let mut wrapped = RescaleObservation::new(pendulum(), 0.0, 1.0).unwrap();
        let initial = wrapped.reset(Some(42)).unwrap();
        assert_eq!(initial.obs.len(), 3);
        assert!(initial.obs.iter().all(|v| v.is_finite()));
    }

    /// The advertised observation space carries the requested target bounds.
    #[test]
    fn observation_space_reflects_new_bounds() {
        let wrapped = RescaleObservation::new(pendulum(), -1.0, 1.0).unwrap();
        let bounds = wrapped.observation_space();
        assert!(bounds
            .low
            .iter()
            .all(|&lo| (lo - (-1.0)).abs() < f32::EPSILON));
        assert!(bounds
            .high
            .iter()
            .all(|&hi| (hi - 1.0).abs() < f32::EPSILON));
    }

    /// A target interval with `min_obs > max_obs` is refused.
    #[test]
    fn rejects_invalid_range() {
        let outcome = RescaleObservation::new(pendulum(), 1.0, 0.0);
        assert!(outcome.is_err());
    }

    /// Observations returned by `step` lie inside the advertised space.
    #[test]
    fn step_produces_rescaled_obs() {
        let mut wrapped = RescaleObservation::new(pendulum(), -10.0, 10.0).unwrap();
        wrapped.reset(Some(42)).unwrap();
        let transition = wrapped.step(&vec![0.0]).unwrap();
        assert_eq!(transition.obs.len(), 3);
        assert!(wrapped.observation_space().contains(&transition.obs));
    }
}