madepro/environments/gridworld/
env.rs

1use crate::models::{Action, Sampler, State, MDP};
2
3use super::{END_TRANSITION_REWARD, NO_OP_TRANSITION_REWARD};
4
5/// A gridworld state (i, j)
6#[derive(PartialEq, Eq, Hash, Debug, Clone)]
7pub struct GridworldState {
8    i: usize,
9    j: usize,
10}
11
12impl GridworldState {
13    /// Creates a new gridworld state with the specified coordinates
14    pub const fn new(i: usize, j: usize) -> Self {
15        Self { i, j }
16    }
17}
18
19impl State for GridworldState {}
20
21/// A gridworld action
22#[derive(PartialEq, Eq, Hash, Debug, Clone)]
23pub enum GridworldAction {
24    Down,
25    Left,
26    Right,
27    Up,
28}
29
30impl Action for GridworldAction {}
31
/// The content of a single gridworld cell.
///
/// Derives the same comparison/hashing/cloning traits as the state and action
/// types, plus `Copy` (it is a fieldless enum), so cells can be copied out of
/// the grid and compared or hashed without borrowing.
#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
pub enum Cell {
    /// Open floor: the agent may move onto it.
    Air,
    /// Obstacle: a move into it leaves the agent where it was.
    Wall,
    /// Terminal cell: states on it are terminal, and moving onto it yields
    /// `END_TRANSITION_REWARD`.
    End,
}
39
40/// A gridworld
41pub struct Gridworld {
42    cell_grid: Vec<Vec<Cell>>,
43    states: Sampler<GridworldState>,
44    actions: Sampler<GridworldAction>,
45}
46
47impl Gridworld {
48    /// Creates a new gridworld with the specified cell grid, states, and actions
49    pub fn new(
50        cell_grid: Vec<Vec<Cell>>,
51        states: Vec<GridworldState>,
52        actions: Vec<GridworldAction>,
53    ) -> Self {
54        Self {
55            cell_grid,
56            states: states.into(),
57            actions: actions.into(),
58        }
59    }
60
61    /// Returns the grid's width and height
62    fn get_grid_size(&self) -> (usize, usize) {
63        (self.cell_grid.len(), self.cell_grid[0].len())
64    }
65}
66
67impl MDP for Gridworld {
68    type State = GridworldState;
69    type Action = GridworldAction;
70
71    fn get_states(&self) -> &Sampler<Self::State> {
72        &self.states
73    }
74
75    fn get_actions(&self) -> &Sampler<Self::Action> {
76        &self.actions
77    }
78
79    fn is_state_terminal(&self, state: &Self::State) -> bool {
80        let cell = &self.cell_grid[state.i][state.j];
81        *cell == Cell::End
82    }
83
84    fn transition(&self, state: &Self::State, action: &Self::Action) -> (Self::State, f64) {
85        let cell = &self.cell_grid[state.i][state.j];
86
87        // Edge cases
88        // In theory the Cell::Wall case should never happen
89        if (*cell) == Cell::End || (*cell) == Cell::Wall {
90            return (state.clone(), 0.0);
91        }
92
93        // Tentative position
94        let (i, j) = (state.i as i32, state.j as i32);
95        let (i_, j_) = match action {
96            Self::Action::Up => (i - 1, j),
97            Self::Action::Down => (i + 1, j),
98            Self::Action::Left => (i, j - 1),
99            Self::Action::Right => (i, j + 1),
100        };
101
102        // Check out of bounds
103        let (n, m) = self.get_grid_size();
104        let (n, m) = (n as i32, m as i32);
105        if i_ < 0 || i_ >= n || j_ < 0 || j_ >= m {
106            return (state.clone(), NO_OP_TRANSITION_REWARD);
107        }
108
109        // Result
110        let (i_, j_) = (i_ as usize, j_ as usize);
111        let cell_ = &self.cell_grid[i_][j_];
112        match cell_ {
113            Cell::Air => (Self::State::new(i_, j_), NO_OP_TRANSITION_REWARD),
114            Cell::Wall => (state.clone(), NO_OP_TRANSITION_REWARD),
115            Cell::End => (Self::State::new(i_, j_), END_TRANSITION_REWARD),
116        }
117    }
118}
119
#[cfg(test)]
mod tests {
    // All tests run against the fixture returned by `get_gridworld()`, using the
    // named corner states and action constants exported alongside it.
    use super::*;
    use crate::environments::gridworld::{
        get_gridworld, BOTTOM_RIGHT, DOWN, LEFT, RIGHT, TOP_LEFT, TOP_RIGHT, UP,
    };

    #[test]
    fn is_not_terminal() {
        let mdp = get_gridworld();
        assert!(!mdp.is_state_terminal(&TOP_LEFT));
        assert!(!mdp.is_state_terminal(&TOP_RIGHT));
    }

    #[test]
    fn is_terminal() {
        let mdp = get_gridworld();
        assert!(mdp.is_state_terminal(&BOTTOM_RIGHT));
    }

    #[test]
    fn transition_to_boundaries() {
        let mdp = get_gridworld();
        // Left edge.
        assert_eq!(
            mdp.transition(&TOP_LEFT, &LEFT),
            (TOP_LEFT.clone(), NO_OP_TRANSITION_REWARD)
        );
        // Top edge: the row index would go below zero.
        assert_eq!(
            mdp.transition(&TOP_LEFT, &UP),
            (TOP_LEFT.clone(), NO_OP_TRANSITION_REWARD)
        );
        assert_eq!(
            mdp.transition(&TOP_RIGHT, &UP),
            (TOP_RIGHT.clone(), NO_OP_TRANSITION_REWARD)
        );
        // Right edge.
        assert_eq!(
            mdp.transition(&TOP_RIGHT, &RIGHT),
            (TOP_RIGHT.clone(), NO_OP_TRANSITION_REWARD)
        );
    }

    #[test]
    fn transition_to_air() {
        let mdp = get_gridworld();
        assert_eq!(
            mdp.transition(&TOP_LEFT, &RIGHT),
            (TOP_RIGHT.clone(), NO_OP_TRANSITION_REWARD)
        );
    }

    #[test]
    fn transition_to_wall() {
        let mdp = get_gridworld();
        assert_eq!(
            mdp.transition(&TOP_LEFT, &DOWN),
            (TOP_LEFT.clone(), NO_OP_TRANSITION_REWARD)
        );
    }

    #[test]
    fn transition_to_end() {
        let mdp = get_gridworld();
        assert_eq!(
            mdp.transition(&TOP_RIGHT, &DOWN),
            (BOTTOM_RIGHT.clone(), END_TRANSITION_REWARD)
        );
    }

    #[test]
    fn transition_from_terminal() {
        let mdp = get_gridworld();
        // A terminal state is absorbing under every action.
        for &action in [&UP, &DOWN, &LEFT, &RIGHT].iter() {
            assert_eq!(
                mdp.transition(&BOTTOM_RIGHT, action),
                (BOTTOM_RIGHT.clone(), 0.0)
            );
        }
    }
}