1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
use super::{CloneBuild, EnvStructure, Environment, Successor};
use crate::feedback::Reward;
use crate::logging::StatsLogger;
use crate::spaces::{IndexSpace, IntervalSpace};
use crate::Prng;
use rand::prelude::*;
use serde::{Deserialize, Serialize};
/// Configuration of the memory game environment.
///
/// An episode starts in a state drawn from `0..num_actions` and the observation is always the
/// index of the current state. Intermediate steps give zero reward; on the last step the agent
/// earns `+1` if its action equals the index of the initial state and `-1` otherwise.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub struct MemoryGame {
    /// Number of actions, which is also the number of possible initial states.
    pub num_actions: usize,
    /// Number of extra states visited after the initial one; together with `num_actions` this
    /// fixes the size of the observation space (`num_actions + history_len`).
    pub history_len: usize,
}
// Empty impl: `MemoryGame` implements `CloneBuild` using only what the trait itself provides.
// NOTE(review): presumably this lets the configuration act as its own environment builder by
// cloning itself -- confirm against the `CloneBuild` trait definition.
impl CloneBuild for MemoryGame {}
impl Default for MemoryGame {
    /// Returns the smallest standard configuration: two actions and a history length of one.
    fn default() -> Self {
        const DEFAULT_NUM_ACTIONS: usize = 2;
        const DEFAULT_HISTORY_LEN: usize = 1;

        Self {
            num_actions: DEFAULT_NUM_ACTIONS,
            history_len: DEFAULT_HISTORY_LEN,
        }
    }
}
impl MemoryGame {
    /// Creates a memory game configuration from its two parameters.
    ///
    /// * `num_actions` - number of actions (and of possible initial states).
    /// * `history_len` - number of extra states visited after the initial one.
    ///
    /// The values are stored as given; no validation is performed here.
    #[must_use]
    pub const fn new(num_actions: usize, history_len: usize) -> Self {
        Self { num_actions, history_len }
    }
}
/// Describes the observation, action, and feedback spaces of the memory game.
impl EnvStructure for MemoryGame {
    type ObservationSpace = IndexSpace;
    type ActionSpace = IndexSpace;
    type FeedbackSpace = IntervalSpace<Reward>;

    /// Index space with one entry per state: `num_actions` initial states followed by
    /// `history_len` further states.
    fn observation_space(&self) -> Self::ObservationSpace {
        let num_states = self.num_actions + self.history_len;
        IndexSpace::new(num_states)
    }

    /// Index space over the `num_actions` available actions.
    fn action_space(&self) -> Self::ActionSpace {
        let num_actions = self.num_actions;
        IndexSpace::new(num_actions)
    }

    /// Rewards are bounded by `-1` (wrong final action) and `+1` (correct final action).
    fn feedback_space(&self) -> Self::FeedbackSpace {
        let (lowest, highest) = (Reward(-1.0), Reward(1.0));
        IntervalSpace::new(lowest, highest)
    }

    /// Rewards are not discounted.
    fn discount_factor(&self) -> f64 {
        1.0
    }
}
/// Episode dynamics of the memory game.
///
/// The state is the pair `(current_state, initial_state)`; only `current_state` is observed, so
/// the agent has to remember the initial observation until the final step.
impl Environment for MemoryGame {
    type State = (usize, usize);
    type Observation = usize;
    type Action = usize;
    type Feedback = Reward;

    /// Samples the initial state index from `0..num_actions` and records it as both the current
    /// and the initial state.
    ///
    /// # Panics
    ///
    /// `gen_range` is called with the empty range `0..0` when `num_actions` is zero, which
    /// `rand` documents as a panic.
    fn initial_state(&self, rng: &mut Prng) -> Self::State {
        let state = rng.gen_range(0..self.num_actions);
        (state, state)
    }

    /// The observation is the index of the current state; the stored initial state is hidden.
    fn observe(&self, state: &Self::State, _rng: &mut Prng) -> Self::Observation {
        let (current_state, _initial_state) = *state;
        current_state
    }

    /// Advances the episode by one step.
    ///
    /// * On the final step the episode terminates with reward `+1` if `action` equals the
    ///   initial state index and `-1` otherwise.
    /// * On every earlier step the action is ignored and the reward is `0`: an initial state
    ///   (`< num_actions`) moves to state `num_actions`, any later state moves to the next index.
    ///
    /// The final step is taken in state `num_actions + history_len - 1`. When `history_len` is
    /// zero there are no intermediate states, so the initial state itself is the final one.
    fn step(
        &self,
        state: Self::State,
        action: &Self::Action,
        _: &mut Prng,
        _: &mut dyn StatsLogger,
    ) -> (Successor<Self::State>, Self::Feedback) {
        let (current_state, initial_state) = state;

        // Written as `current_state + 1 == total` rather than `current_state == total - 1` so
        // the comparison cannot underflow when both configuration fields are zero.
        // The explicit `history_len == 0` case is required: without it only the initial state
        // `num_actions - 1` would terminate and every other episode would run forever, leaving
        // the observation space.
        let is_final_step = self.history_len == 0
            || current_state + 1 == self.num_actions + self.history_len;

        if is_final_step {
            let reward = if *action == initial_state { 1.0 } else { -1.0 };
            return (Successor::Terminate, Reward(reward));
        }

        let new_state = if current_state < self.num_actions {
            // All initial states funnel into the first history state.
            self.num_actions
        } else {
            current_state + 1
        };
        (Successor::Continue((new_state, initial_state)), Reward(0.0))
    }
}
#[cfg(test)]
mod tests {
    use super::super::testing;
    use super::*;

    /// Runs the shared structured-environment check against the default configuration.
    // NOTE(review): `1000` and `0` look like a step count and an RNG seed -- confirm against
    // `testing::check_structured_env`.
    #[test]
    fn run_default() {
        let env = MemoryGame::default();
        testing::check_structured_env(&env, 1000, 0);
    }
}