madepro/environments/gridworld/
env.rs1use crate::models::{Action, Sampler, State, MDP};
2
3use super::{END_TRANSITION_REWARD, NO_OP_TRANSITION_REWARD};
4
5#[derive(PartialEq, Eq, Hash, Debug, Clone)]
7pub struct GridworldState {
8 i: usize,
9 j: usize,
10}
11
12impl GridworldState {
13 pub const fn new(i: usize, j: usize) -> Self {
15 Self { i, j }
16 }
17}
18
19impl State for GridworldState {}
20
21#[derive(PartialEq, Eq, Hash, Debug, Clone)]
23pub enum GridworldAction {
24 Down,
25 Left,
26 Right,
27 Up,
28}
29
30impl Action for GridworldAction {}
31
/// The contents of a single grid cell.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Cell {
    /// Open space; moving into it succeeds.
    Air,
    /// Obstacle; moving into it leaves the agent where it was.
    Wall,
    /// Terminal cell; entering it yields `END_TRANSITION_REWARD`.
    End,
}

40pub struct Gridworld {
42 cell_grid: Vec<Vec<Cell>>,
43 states: Sampler<GridworldState>,
44 actions: Sampler<GridworldAction>,
45}
46
47impl Gridworld {
48 pub fn new(
50 cell_grid: Vec<Vec<Cell>>,
51 states: Vec<GridworldState>,
52 actions: Vec<GridworldAction>,
53 ) -> Self {
54 Self {
55 cell_grid,
56 states: states.into(),
57 actions: actions.into(),
58 }
59 }
60
61 fn get_grid_size(&self) -> (usize, usize) {
63 (self.cell_grid.len(), self.cell_grid[0].len())
64 }
65}
66
67impl MDP for Gridworld {
68 type State = GridworldState;
69 type Action = GridworldAction;
70
71 fn get_states(&self) -> &Sampler<Self::State> {
72 &self.states
73 }
74
75 fn get_actions(&self) -> &Sampler<Self::Action> {
76 &self.actions
77 }
78
79 fn is_state_terminal(&self, state: &Self::State) -> bool {
80 let cell = &self.cell_grid[state.i][state.j];
81 *cell == Cell::End
82 }
83
84 fn transition(&self, state: &Self::State, action: &Self::Action) -> (Self::State, f64) {
85 let cell = &self.cell_grid[state.i][state.j];
86
87 if (*cell) == Cell::End || (*cell) == Cell::Wall {
90 return (state.clone(), 0.0);
91 }
92
93 let (i, j) = (state.i as i32, state.j as i32);
95 let (i_, j_) = match action {
96 Self::Action::Up => (i - 1, j),
97 Self::Action::Down => (i + 1, j),
98 Self::Action::Left => (i, j - 1),
99 Self::Action::Right => (i, j + 1),
100 };
101
102 let (n, m) = self.get_grid_size();
104 let (n, m) = (n as i32, m as i32);
105 if i_ < 0 || i_ >= n || j_ < 0 || j_ >= m {
106 return (state.clone(), NO_OP_TRANSITION_REWARD);
107 }
108
109 let (i_, j_) = (i_ as usize, j_ as usize);
111 let cell_ = &self.cell_grid[i_][j_];
112 match cell_ {
113 Cell::Air => (Self::State::new(i_, j_), NO_OP_TRANSITION_REWARD),
114 Cell::Wall => (state.clone(), NO_OP_TRANSITION_REWARD),
115 Cell::End => (Self::State::new(i_, j_), END_TRANSITION_REWARD),
116 }
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::environments::gridworld::{
        get_gridworld, BOTTOM_RIGHT, DOWN, LEFT, RIGHT, TOP_LEFT, TOP_RIGHT, UP,
    };

    #[test]
    fn is_not_terminal() {
        let mdp = get_gridworld();
        assert!(!mdp.is_state_terminal(&TOP_LEFT));
        assert!(!mdp.is_state_terminal(&TOP_RIGHT));
    }

    #[test]
    fn is_terminal() {
        let mdp = get_gridworld();
        assert!(mdp.is_state_terminal(&BOTTOM_RIGHT));
    }

    #[test]
    fn transition_to_boundaries() {
        let mdp = get_gridworld();
        // Left edge: the column index would underflow.
        assert_eq!(
            mdp.transition(&TOP_LEFT, &LEFT),
            (TOP_LEFT.clone(), NO_OP_TRANSITION_REWARD)
        );
        // Top edge: the row index would underflow.
        assert_eq!(
            mdp.transition(&TOP_LEFT, &UP),
            (TOP_LEFT.clone(), NO_OP_TRANSITION_REWARD)
        );
        // Right edge: the column index would reach the grid width.
        assert_eq!(
            mdp.transition(&TOP_RIGHT, &RIGHT),
            (TOP_RIGHT.clone(), NO_OP_TRANSITION_REWARD)
        );
    }

    #[test]
    fn transition_to_air() {
        let mdp = get_gridworld();
        assert_eq!(
            mdp.transition(&TOP_LEFT, &RIGHT),
            (TOP_RIGHT.clone(), NO_OP_TRANSITION_REWARD)
        );
    }

    #[test]
    fn transition_to_wall() {
        let mdp = get_gridworld();
        assert_eq!(
            mdp.transition(&TOP_LEFT, &DOWN),
            (TOP_LEFT.clone(), NO_OP_TRANSITION_REWARD)
        );
    }

    #[test]
    fn transition_to_end() {
        let mdp = get_gridworld();
        assert_eq!(
            mdp.transition(&TOP_RIGHT, &DOWN),
            (BOTTOM_RIGHT.clone(), END_TRANSITION_REWARD)
        );
    }

    #[test]
    fn transition_from_terminal() {
        let mdp = get_gridworld();
        assert_eq!(
            mdp.transition(&BOTTOM_RIGHT, &UP),
            (BOTTOM_RIGHT.clone(), 0.0)
        );
    }
}