use crate::env::{Env, ResetResult, StepResult};
use crate::error::Result;
use crate::macros::delegate_env;
/// Environment wrapper that may replay the previously executed action.
///
/// On every step after the first one following a reset, the wrapper draws a
/// random number and, if it falls below `repeat_probability`, forwards the
/// last executed action to the inner environment instead of the action
/// supplied by the caller.
#[derive(Debug)]
pub struct StickyAction<E>
where
    E: Env,
    E::Act: Clone,
{
    /// The wrapped environment that actually executes the actions.
    env: E,
    /// Threshold the per-step random draw is compared against; validated by
    /// the constructor to lie in `[0, 1)`.
    repeat_probability: f64,
    /// Action most recently forwarded to `env`; `None` until the first step
    /// after construction or after a reset.
    last_action: Option<E::Act>,
    /// Random source used to decide whether to repeat the previous action.
    rng: crate::rng::Rng,
}
impl<E: Env> StickyAction<E>
where
E::Act: Clone,
{
pub fn new(env: E, repeat_probability: f64) -> Result<Self> {
if !(0.0..1.0).contains(&repeat_probability) {
return Err(crate::error::Error::InvalidAction {
reason: format!("repeat_probability must be in [0, 1), got {repeat_probability}"),
});
}
Ok(Self {
env,
repeat_probability,
last_action: None,
rng: crate::rng::create_rng(None),
})
}
#[must_use]
pub const fn inner(&self) -> &E {
&self.env
}
#[must_use]
pub const fn inner_mut(&mut self) -> &mut E {
&mut self.env
}
#[must_use]
pub fn into_inner(self) -> E {
self.env
}
}
impl<E: Env> Env for StickyAction<E>
where
    E::Act: Clone,
{
    type Obs = E::Obs;
    type Act = E::Act;
    type ObsSpace = E::ObsSpace;
    type ActSpace = E::ActSpace;

    /// Steps the inner environment with either `action` or the previously
    /// executed action.
    ///
    /// If an action has been remembered since the last reset, one `f64` is
    /// drawn from the wrapper RNG; when it is below `repeat_probability` the
    /// remembered action is replayed and stays remembered. Otherwise (or on
    /// the first step after a reset, where no random number is drawn)
    /// `action` is remembered and executed.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the inner environment's `step` returns. The
    /// remembered action is updated before the inner call, so it is kept even
    /// if that call fails.
    fn step(&mut self, action: &Self::Act) -> Result<StepResult<Self::Obs>> {
        use rand::RngExt;

        if let Some(prev) = self.last_action.as_ref() {
            let roll: f64 = self.rng.random();
            if roll < self.repeat_probability {
                // The remembered action is already the one being executed, so
                // it can be borrowed in place: no clone and no re-store needed.
                return self.env.step(prev);
            }
        }

        // Only a new action has to be copied, and only once, for later replays.
        self.last_action = Some(action.clone());
        self.env.step(action)
    }

    /// Forgets the remembered action and resets the inner environment.
    ///
    /// When `seed` is `Some`, the wrapper RNG is recreated from that seed
    /// before the same seed is forwarded to the inner environment. With
    /// `None`, the wrapper RNG keeps its current state.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the inner environment's `reset` returns.
    fn reset(&mut self, seed: Option<u64>) -> Result<ResetResult<Self::Obs>> {
        self.last_action = None;
        if let Some(s) = seed {
            self.rng = crate::rng::create_rng(Some(s));
        }
        self.env.reset(seed)
    }

    // NOTE(review): presumably generates the listed `Env` methods by
    // forwarding to the `env` field — confirm against the macro definition.
    delegate_env!(
        env,
        render,
        close,
        render_mode,
        observation_space,
        action_space
    );
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::envs::classic_control::{CartPoleConfig, CartPoleEnv};

    /// Smoke test: with probability 0.0 the wrapper constructs, resets and
    /// steps, yielding a 4-element observation.
    #[test]
    fn zero_probability_passes_through() {
        let env = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let mut env = StickyAction::new(env, 0.0).unwrap();
        env.reset(Some(42)).unwrap();
        let r = env.step(&1).unwrap();
        assert_eq!(r.obs.len(), 4);
    }

    /// Smoke test: with a high repeat probability, stepping keeps working
    /// across several steps and across a reset after termination.
    #[test]
    fn high_probability_repeats() {
        let env = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        let mut env = StickyAction::new(env, 0.99).unwrap();
        env.reset(Some(42)).unwrap();
        env.step(&0).unwrap();
        for _ in 0..10 {
            let r = env.step(&1).unwrap();
            assert_eq!(r.obs.len(), 4);
            if r.terminated {
                env.reset(None).unwrap();
            }
        }
    }

    /// The constructor must reject every probability outside `[0, 1)`,
    /// including the non-finite values NaN and infinity.
    #[test]
    fn rejects_invalid_probability() {
        let env = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        assert!(StickyAction::new(env, 1.0).is_err());
        let env = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        assert!(StickyAction::new(env, -0.1).is_err());
        // NaN compares false against both range bounds, so it must be rejected.
        let env = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        assert!(StickyAction::new(env, f64::NAN).is_err());
        let env = CartPoleEnv::new(CartPoleConfig::default()).unwrap();
        assert!(StickyAction::new(env, f64::INFINITY).is_err());
    }
}