use crate::env::{Env, StepResult};
use crate::error::Result;
use crate::macros::delegate_env;
use crate::space::Space;
/// Environment wrapper that exposes a different action type than the
/// environment it wraps.
///
/// Each action handed to the wrapper's `step` is first passed through the
/// mapping function `f`, and the resulting `E::Act` is forwarded to the inner
/// environment. The wrapper reports `act_space` as its own action space.
///
/// # Type parameters
///
/// * `E` - the wrapped environment.
/// * `NewAct` - the action type accepted by the wrapper.
/// * `NewActSpace` - the space whose elements are `NewAct` values.
/// * `F` - the mapping from a borrowed `NewAct` to the inner `E::Act`.
pub struct TransformAction<E, NewAct, NewActSpace, F>
where
    E: Env,
    NewActSpace: Space<Element = NewAct>,
    F: Fn(&NewAct) -> E::Act,
{
    /// The wrapped environment that ultimately receives the mapped actions.
    env: E,
    /// Converts a wrapper-level action into an inner-environment action.
    f: F,
    /// Action space advertised by the wrapper.
    act_space: NewActSpace,
}
/// Manual `Debug` implementation that does not require `F: Debug`.
///
/// Only the inner environment and the action space are printed; the mapping
/// function is left out and the output is marked as non-exhaustive.
impl<E, NewAct, NewActSpace, F> std::fmt::Debug for TransformAction<E, NewAct, NewActSpace, F>
where
    E: Env + std::fmt::Debug,
    NewActSpace: Space<Element = NewAct> + std::fmt::Debug,
    F: Fn(&NewAct) -> E::Act,
{
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("TransformAction");
        builder.field("env", &self.env);
        builder.field("act_space", &self.act_space);
        // `f` is intentionally skipped, so signal that fields are hidden.
        builder.finish_non_exhaustive()
    }
}
impl<E, NewAct, NewActSpace, F> TransformAction<E, NewAct, NewActSpace, F>
where
E: Env,
NewActSpace: Space<Element = NewAct>,
F: Fn(&NewAct) -> E::Act,
{
#[must_use]
pub const fn new(env: E, f: F, act_space: NewActSpace) -> Self {
Self { env, f, act_space }
}
#[must_use]
pub const fn inner(&self) -> &E {
&self.env
}
#[must_use]
pub const fn inner_mut(&mut self) -> &mut E {
&mut self.env
}
#[must_use]
pub fn into_inner(self) -> E {
self.env
}
}
/// `Env` implementation: observations come from the inner environment
/// unchanged, while actions use the wrapper's own type and space.
impl<E, NewAct, NewActSpace, F> Env for TransformAction<E, NewAct, NewActSpace, F>
where
    E: Env,
    NewActSpace: Space<Element = NewAct>,
    F: Fn(&NewAct) -> E::Act,
{
    type Obs = E::Obs;
    type Act = NewAct;
    type ObsSpace = E::ObsSpace;
    type ActSpace = NewActSpace;

    /// Maps `action` through the stored function and steps the inner
    /// environment with the result, returning whatever the inner `step`
    /// returns (including its error, if any).
    fn step(&mut self, action: &Self::Act) -> Result<StepResult<Self::Obs>> {
        let mapped = (self.f)(action);
        self.env.step(&mapped)
    }

    /// Returns the action space supplied at construction time.
    fn action_space(&self) -> &Self::ActSpace {
        &self.act_space
    }

    // NOTE(review): presumably forwards the listed methods to the `env`
    // field; confirm against the `delegate_env!` macro definition.
    delegate_env!(env, reset, render, close, render_mode, observation_space);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::envs::classic_control::{CartPoleEnv, CartPoleConfig};

    /// Wrapping with an identity mapping still allows resetting and stepping.
    #[test]
    fn identity_transform() {
        let base = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        // Copy the space out before `base` is moved into the wrapper.
        let space = *base.action_space();
        let mut wrapped = TransformAction::new(base, |a: &i64| *a, space);

        wrapped.reset(Some(42)).unwrap();
        let outcome = wrapped.step(&0).unwrap();

        assert_eq!(outcome.obs.len(), 4);
    }

    /// Wrapping with a mapping that sends `a` to `1 - a` still allows
    /// resetting and stepping.
    #[test]
    fn flip_action() {
        let base = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        // Copy the space out before `base` is moved into the wrapper.
        let space = *base.action_space();
        let mut wrapped = TransformAction::new(base, |a: &i64| 1 - a, space);

        wrapped.reset(Some(42)).unwrap();
        let outcome = wrapped.step(&0).unwrap();

        assert_eq!(outcome.obs.len(), 4);
    }
}