use super::super::{EnvStructure, Environment, Successor};
use super::{StructurePreservingWrapper, Wrapped};
use crate::logging::StatsLogger;
use crate::spaces::IntervalSpace;
use crate::Prng;
use serde::{Deserialize, Serialize};
/// Environment wrapper that caps the number of steps per episode without revealing the
/// remaining step count to the agent.
///
/// Observations of the wrapped environment are passed through as-is, so from the agent's
/// point of view the limit is "latent".
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct LatentStepLimit {
    /// Maximum number of steps in an episode before it is interrupted.
    ///
    /// [`LatentStepLimit::new`] rejects zero; a value set directly or deserialized is not
    /// validated.
    pub max_steps_per_episode: u64,
}
impl LatentStepLimit {
    /// Creates a limit of at most `max_steps_per_episode` steps per episode.
    ///
    /// # Panics
    /// Panics if `max_steps_per_episode` is zero.
    #[must_use]
    #[inline]
    pub const fn new(max_steps_per_episode: u64) -> Self {
        assert!(max_steps_per_episode != 0, "step limit must be positive");
        Self { max_steps_per_episode }
    }
}
impl Default for LatentStepLimit {
    /// A limit of 100 steps per episode.
    #[inline]
    fn default() -> Self {
        Self::new(100)
    }
}
/// An environment `E` wrapped with a [`LatentStepLimit`].
pub type WithLatentStepLimit<E> = Wrapped<E, LatentStepLimit>;

// The latent limit keeps the inner observation, action and feedback types unchanged,
// so the wrapper is declared structure-preserving.
impl StructurePreservingWrapper for LatentStepLimit {}
/// State of a step-limited environment: the wrapped state together with its step budget.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct StepLimitState<T> {
    /// State of the wrapped environment.
    pub inner: T,
    /// Number of steps that may still be taken in the current episode.
    pub steps_remaining: u64,
}
impl<E: Environment> Environment for Wrapped<E, LatentStepLimit> {
    type State = StepLimitState<E::State>;
    type Observation = E::Observation;
    type Action = E::Action;
    type Feedback = E::Feedback;

    /// Starts an episode of the inner environment with the full step budget.
    fn initial_state(&self, rng: &mut Prng) -> Self::State {
        let inner_state = self.inner.initial_state(rng);
        StepLimitState {
            inner: inner_state,
            steps_remaining: self.wrapper.max_steps_per_episode,
        }
    }

    /// Observes only the inner state; the remaining step count is not exposed.
    fn observe(&self, state: &Self::State, rng: &mut Prng) -> Self::Observation {
        let StepLimitState { inner, .. } = state;
        self.inner.observe(inner, rng)
    }

    /// Steps the inner environment and spends one unit of the step budget.
    ///
    /// When the budget reaches zero the successor is passed through
    /// `Successor::then_interrupt_if` so the episode is interrupted.
    ///
    /// # Panics
    /// Panics if `state` has no steps remaining.
    fn step(
        &self,
        state: Self::State,
        action: &Self::Action,
        rng: &mut Prng,
        logger: &mut dyn StatsLogger,
    ) -> (Successor<Self::State>, Self::Feedback) {
        let StepLimitState {
            inner: inner_state,
            steps_remaining,
        } = state;
        assert!(
            steps_remaining > 0,
            "invalid step from a state with no remaining steps"
        );
        // Cannot underflow: the assertion above guarantees at least one step is left.
        let budget_after = steps_remaining - 1;

        let (inner_successor, feedback) = self.inner.step(inner_state, action, rng, logger);
        let successor = inner_successor
            .map(|next_inner| StepLimitState {
                inner: next_inner,
                steps_remaining: budget_after,
            })
            .then_interrupt_if(|_| budget_after == 0);
        (successor, feedback)
    }
}
/// Environment wrapper that caps the number of steps per episode and reveals the fraction
/// of remaining steps in every observation.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct VisibleStepLimit {
    /// Maximum number of steps in an episode before it is interrupted.
    ///
    /// [`VisibleStepLimit::new`] rejects zero. A zero set directly or deserialized is not
    /// checked, and would make the observed remaining fraction `0.0 / 0.0` (NaN).
    pub max_steps_per_episode: u64,
}
impl VisibleStepLimit {
    /// Creates a limit of at most `max_steps_per_episode` steps per episode.
    ///
    /// # Panics
    /// Panics if `max_steps_per_episode` is zero.
    #[must_use]
    #[inline]
    pub const fn new(max_steps_per_episode: u64) -> Self {
        assert!(max_steps_per_episode != 0, "step limit must be positive");
        Self { max_steps_per_episode }
    }
}
impl Default for VisibleStepLimit {
    /// A limit of 100 steps per episode.
    #[inline]
    fn default() -> Self {
        Self::new(100)
    }
}
/// An environment `E` wrapped with a [`VisibleStepLimit`].
pub type WithVisibleStepLimit<E> = Wrapped<E, VisibleStepLimit>;
/// Observation of a [`VisibleStepLimit`]-wrapped environment.
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub struct StepLimitObs<T> {
    /// Observation produced by the wrapped environment.
    pub inner: T,
    /// Remaining steps divided by the per-episode maximum.
    pub remaining: f64,
}
/// Observation space of a [`VisibleStepLimit`]-wrapped environment: the inner observation
/// space paired with a space for the fraction of remaining steps.
///
/// Elements are presumably [`StepLimitObs`] values, as named by the `#[element]` attribute
/// consumed by the `ProductSpace` derive — confirm against that derive.
#[derive(Debug, Copy, Clone, PartialEq, ProductSpace, Serialize, Deserialize)]
#[element(StepLimitObs<T::Element>)]
pub struct StepLimitObsSpace<T> {
    /// Observation space of the wrapped environment.
    pub inner: T,
    /// Space of the remaining-steps fraction; the `From` and `Default` impls in this file
    /// set it to the interval `[0, 1]`.
    pub remaining: IntervalSpace<f64>,
}
/// Extends an inner observation space with the `[0, 1]` remaining-fraction interval.
impl<T> From<T> for StepLimitObsSpace<T> {
    fn from(inner: T) -> Self {
        let remaining = IntervalSpace::new(0.0, 1.0);
        Self { inner, remaining }
    }
}
impl<T: Default> Default for StepLimitObsSpace<T> {
fn default() -> Self {
StepLimitObsSpace {
inner: T::default(),
remaining: IntervalSpace::new(0.0, 1.0),
}
}
}
impl<T: EnvStructure> EnvStructure for Wrapped<T, VisibleStepLimit> {
type ObservationSpace = StepLimitObsSpace<T::ObservationSpace>;
type ActionSpace = T::ActionSpace;
type FeedbackSpace = T::FeedbackSpace;
fn observation_space(&self) -> Self::ObservationSpace {
self.inner.observation_space().into()
}
fn action_space(&self) -> Self::ActionSpace {
self.inner.action_space()
}
fn feedback_space(&self) -> Self::FeedbackSpace {
self.inner.feedback_space()
}
fn discount_factor(&self) -> f64 {
self.inner.discount_factor()
}
}
impl<E: Environment> Environment for Wrapped<E, VisibleStepLimit> {
    type State = StepLimitState<E::State>;
    type Observation = StepLimitObs<E::Observation>;
    type Action = E::Action;
    type Feedback = E::Feedback;

    /// Starts an episode of the inner environment with the full step budget.
    fn initial_state(&self, rng: &mut Prng) -> Self::State {
        let inner_state = self.inner.initial_state(rng);
        StepLimitState {
            inner: inner_state,
            steps_remaining: self.wrapper.max_steps_per_episode,
        }
    }

    /// Observes the inner state and attaches the fraction of the step budget still left.
    fn observe(&self, state: &Self::State, rng: &mut Prng) -> Self::Observation {
        let budget = self.wrapper.max_steps_per_episode as f64;
        let left = state.steps_remaining as f64;
        StepLimitObs {
            inner: self.inner.observe(&state.inner, rng),
            remaining: left / budget,
        }
    }

    /// Steps the inner environment and spends one unit of the step budget.
    ///
    /// When the budget reaches zero the successor is passed through
    /// `Successor::then_interrupt_if` so the episode is interrupted.
    ///
    /// # Panics
    /// Panics if `state` has no steps remaining.
    fn step(
        &self,
        state: Self::State,
        action: &Self::Action,
        rng: &mut Prng,
        logger: &mut dyn StatsLogger,
    ) -> (Successor<Self::State>, Self::Feedback) {
        let StepLimitState {
            inner: inner_state,
            steps_remaining,
        } = state;
        assert!(
            steps_remaining > 0,
            "invalid step from a state with no remaining steps"
        );
        // Cannot underflow: the assertion above guarantees at least one step is left.
        let budget_after = steps_remaining - 1;

        let (inner_successor, feedback) = self.inner.step(inner_state, action, rng, logger);
        let successor = inner_successor
            .map(|next_inner| StepLimitState {
                inner: next_inner,
                steps_remaining: budget_after,
            })
            .then_interrupt_if(|_| budget_after == 0);
        (successor, feedback)
    }
}
#[cfg(test)]
mod latent {
    use super::super::super::{chain::Move, testing, BuildEnv, Chain};
    use super::*;
    use rand::SeedableRng;

    /// The default-configured wrapped environment passes the generic environment checks.
    #[test]
    fn run_default() {
        testing::check_structured_env(&WithLatentStepLimit::<Chain>::default(), 1000, 119);
    }

    /// The wrapper configuration builds into an environment.
    #[test]
    fn build() {
        let config = WithLatentStepLimit::<Chain>::default();
        let _env = config.build_env(&mut Prng::seed_from_u64(0)).unwrap();
    }

    /// With a limit of 2, the budget goes 2 -> 1 -> 0; the first step continues and the
    /// second is interrupted.
    #[test]
    fn step_limit() {
        let mut rng = Prng::seed_from_u64(110);
        let env = WithLatentStepLimit::new(Chain::default(), LatentStepLimit::new(2));
        let state = env.initial_state(&mut rng);
        assert_eq!(state.steps_remaining, 2);

        let (successor, _) = env.step(state, &Move::Left, &mut rng, &mut ());
        assert!(matches!(successor, Successor::Continue(_)));
        let state = successor.into_continue().unwrap();
        assert_eq!(state.steps_remaining, 1);

        let (successor, _) = env.step(state, &Move::Left, &mut rng, &mut ());
        assert!(matches!(successor, Successor::Interrupt(_)));
        assert_eq!(successor.into_interrupt().unwrap().steps_remaining, 0);
    }

    /// Stepping from a state whose budget is already exhausted is a caller bug and panics.
    #[test]
    #[should_panic(expected = "invalid step from a state with no remaining steps")]
    fn step_exhausted_panics() {
        let mut rng = Prng::seed_from_u64(111);
        let env = WithLatentStepLimit::new(Chain::default(), LatentStepLimit::new(2));
        let state = StepLimitState {
            inner: Chain::default().initial_state(&mut rng),
            steps_remaining: 0,
        };
        let _ = env.step(state, &Move::Left, &mut rng, &mut ());
    }
}
#[cfg(test)]
mod visible {
    use super::super::super::{chain::Move, testing, BuildEnv, Chain};
    use super::super::Wrap;
    use super::*;
    use rand::SeedableRng;

    /// The default-configured wrapped environment passes the generic environment checks.
    #[test]
    fn run_default() {
        testing::check_structured_env(
            &Chain::default().wrap(VisibleStepLimit::default()),
            1000,
            119,
        );
    }

    /// The wrapper configuration builds into an environment.
    #[test]
    fn build() {
        let config = Chain::default().wrap(VisibleStepLimit::default());
        let _env = config.build_env(&mut Prng::seed_from_u64(0)).unwrap();
    }

    /// With a limit of 2, the observed remaining fraction goes 1.0 -> 0.5 -> 0.0 and the
    /// second step is interrupted.
    #[test]
    // Exact float comparison is intended: 2/2, 1/2 and 0/2 are exactly representable.
    #[allow(clippy::float_cmp)]
    fn step_limit() {
        let mut rng = Prng::seed_from_u64(110);
        let env = WithVisibleStepLimit::new(Chain::default(), VisibleStepLimit::new(2));
        let state = env.initial_state(&mut rng);
        assert_eq!(env.observe(&state, &mut rng).remaining, 1.0);

        let (successor, _) = env.step(state, &Move::Left, &mut rng, &mut ());
        assert!(matches!(successor, Successor::Continue(_)));
        let state = successor.into_continue().unwrap();
        assert_eq!(env.observe(&state, &mut rng).remaining, 0.5);

        let (successor, _) = env.step(state, &Move::Left, &mut rng, &mut ());
        assert!(matches!(successor, Successor::Interrupt(_)));
        let state = successor.into_interrupt().unwrap();
        assert_eq!(state.steps_remaining, 0);
        assert_eq!(env.observe(&state, &mut rng).remaining, 0.0);
    }

    /// Stepping from a state whose budget is already exhausted is a caller bug and panics.
    #[test]
    #[should_panic(expected = "invalid step from a state with no remaining steps")]
    fn step_exhausted_panics() {
        let mut rng = Prng::seed_from_u64(111);
        let env = WithVisibleStepLimit::new(Chain::default(), VisibleStepLimit::new(2));
        let state = StepLimitState {
            inner: Chain::default().initial_state(&mut rng),
            steps_remaining: 0,
        };
        let _ = env.step(state, &Move::Left, &mut rng, &mut ());
    }
}