1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
use super::super::Pomdp;
use super::Wrapped;
use rand::rngs::StdRng;
/// Configuration for a wrapper that caps the number of steps in an episode.
///
/// Paired with an inner environment through [`WithStepLimit`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StepLimit {
    /// Step count at which an episode is flagged as done.
    pub max_steps_per_episode: u64,
}

impl StepLimit {
    /// Creates a step limit of `max_steps_per_episode` steps.
    ///
    /// Usable in `const` contexts.
    pub const fn new(max_steps_per_episode: u64) -> Self {
        StepLimit {
            max_steps_per_episode,
        }
    }
}

impl Default for StepLimit {
    /// Returns a limit of 100 steps per episode.
    fn default() -> Self {
        Self::new(100)
    }
}
/// An environment `E` wrapped with a [`StepLimit`].
///
/// When `E` implements [`Pomdp`], so does this type: each call to `step`
/// increments a per-episode counter and the episode is reported as done once
/// the counter reaches `max_steps_per_episode`.
pub type WithStepLimit<E> = Wrapped<E, StepLimit>;
/// Step-limited view of an inner POMDP.
///
/// The state is the inner environment's state paired with the number of steps
/// taken so far in the current episode. Observations and actions are those of
/// the inner environment, unchanged.
impl<E: Pomdp> Pomdp for Wrapped<E, StepLimit> {
    /// Inner state together with the count of steps taken in this episode.
    type State = (E::State, u64);
    type Observation = E::Observation;
    type Action = E::Action;

    /// Samples an initial inner state and starts the step counter at zero.
    fn initial_state(&self, rng: &mut StdRng) -> Self::State {
        let inner_start = self.inner.initial_state(rng);
        (inner_start, 0)
    }

    /// Observes the inner state; the step counter is not part of the observation.
    fn observe(&self, state: &Self::State, rng: &mut StdRng) -> Self::Observation {
        let (inner_state, _) = state;
        self.inner.observe(inner_state, rng)
    }

    /// Steps the inner environment and advances the step counter by one.
    ///
    /// Returns `(next_state, reward, episode_done)`:
    /// * `next_state` is the inner successor (if the inner environment
    ///   produced one) paired with the updated counter. It is passed through
    ///   even when the step limit ends the episode.
    /// * `reward` is the inner environment's reward, unmodified.
    /// * `episode_done` is true if the inner environment reported it, or if
    ///   the updated counter has reached `max_steps_per_episode`.
    fn step(
        &self,
        state: Self::State,
        action: &Self::Action,
        rng: &mut StdRng,
    ) -> (Option<Self::State>, f64, bool) {
        let (inner_state, steps_before) = state;
        let (inner_successor, reward, inner_done) = self.inner.step(inner_state, action, rng);

        let steps_taken = steps_before + 1;
        let limit_reached = steps_taken >= self.wrapper.max_steps_per_episode;

        let successor = inner_successor.map(|inner| (inner, steps_taken));
        (successor, reward, inner_done || limit_reached)
    }
}
#[cfg(test)]
mod tests {
    use super::super::super::{chain::Move, testing, BuildEnv, Chain, PomdpEnv};
    use super::*;
    use rand::SeedableRng;

    /// Runs the shared POMDP test routine on a default step-limited `Chain`.
    #[test]
    fn run_default() {
        let env = WithStepLimit::<Chain>::default();
        testing::run_pomdp(env, 1000, 119);
    }

    /// The default configuration builds into a `PomdpEnv` without error.
    #[test]
    fn build() {
        let _env: PomdpEnv<WithStepLimit<Chain>> =
            WithStepLimit::<Chain>::default().build_env(0).unwrap();
    }

    /// With a limit of two steps, the episode ends on the second step and the
    /// successor state is still returned.
    #[test]
    fn step_limit() {
        let env = WithStepLimit::new(Chain::default(), StepLimit::new(2));
        let mut rng = StdRng::seed_from_u64(110);

        // First step: below the limit, so the episode continues.
        let start = env.initial_state(&mut rng);
        let (after_first, _, done_after_first) = env.step(start, &Move::Left, &mut rng);
        assert!(!done_after_first);

        // Second step: the counter reaches the limit and the episode ends.
        let second_start = after_first.unwrap();
        let (after_second, _, done_after_second) = env.step(second_start, &Move::Left, &mut rng);
        assert!(done_after_second);
        assert!(after_second.is_some());
    }
}