use crate::env::{Env, StepResult};
use crate::error::Result;
use crate::macros::delegate_env;
/// Environment wrapper that passes every step reward through a user-supplied function.
///
/// Observations, actions and spaces of the wrapped environment are exposed unchanged.
/// Only the `reward` field of the `StepResult` returned by `step` is rewritten as
/// `f(reward)`.
///
/// The `Env` / `Fn(f64) -> f64` bounds live on the `impl` blocks rather than on the
/// struct itself. This lets the type be named (e.g. inside other generic types) without
/// repeating those bounds.
pub struct TransformReward<E, F> {
    /// The wrapped environment.
    env: E,
    /// Transformation applied to each reward produced by `env.step`.
    f: F,
}
impl<E, F> std::fmt::Debug for TransformReward<E, F>
where
    E: Env + std::fmt::Debug,
    F: Fn(f64) -> f64,
{
    /// Formats as `TransformReward { env: .., .. }`.
    ///
    /// The reward function is elided, because `F` carries no `Debug` bound
    /// (closures generally cannot be printed).
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = out.debug_struct("TransformReward");
        repr.field("env", &self.env);
        repr.finish_non_exhaustive()
    }
}
impl<E: Env, F: Fn(f64) -> f64> TransformReward<E, F> {
#[must_use]
pub const fn new(env: E, f: F) -> Self {
Self { env, f }
}
#[must_use]
pub const fn inner(&self) -> &E {
&self.env
}
#[must_use]
pub const fn inner_mut(&mut self) -> &mut E {
&mut self.env
}
#[must_use]
pub fn into_inner(self) -> E {
self.env
}
}
impl<E, F> Env for TransformReward<E, F>
where
    E: Env,
    F: Fn(f64) -> f64,
{
    type Obs = E::Obs;
    type Act = E::Act;
    type ObsSpace = E::ObsSpace;
    type ActSpace = E::ActSpace;

    /// Steps the wrapped environment and rewrites the resulting reward as `f(reward)`.
    ///
    /// An error from the inner `step` is returned as-is; the function is not applied.
    fn step(&mut self, action: &Self::Act) -> Result<StepResult<Self::Obs>> {
        let transform = &self.f;
        self.env.step(action).map(|mut outcome| {
            outcome.reward = transform(outcome.reward);
            outcome
        })
    }

    // Every other `Env` method forwards straight to the wrapped environment.
    delegate_env!(
        env,
        reset,
        render,
        close,
        render_mode,
        observation_space,
        action_space
    );
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::envs::classic_control::{CartPoleConfig, CartPoleEnv};

    /// Builds a default CartPole and wraps it with `transform`.
    ///
    /// Then resets it with `Some(42)` and returns the reward of a single `step(&0)`.
    fn first_step_reward(transform: impl Fn(f64) -> f64) -> f64 {
        let cartpole = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let mut wrapped = TransformReward::new(cartpole, transform);
        wrapped.reset(Some(42)).unwrap();
        wrapped.step(&0).unwrap().reward
    }

    #[test]
    fn scales_reward() {
        let reward = first_step_reward(|r| r * 0.1);
        assert!((reward - 0.1).abs() < 1e-9);
    }

    #[test]
    fn negates_reward() {
        let reward = first_step_reward(|r| -r);
        assert!((reward + 1.0).abs() < 1e-9);
    }
}