forger/env/
lineworld.rs

1use crate::env::Env;
2
// ┌──────────────────────────────────────────────────────────┐
//  Line World
// └──────────────────────────────────────────────────────────┘

/// A one-dimensional environment: a line of cells addressed by `usize` indices,
/// with one goal cell and a list of terminal cells.
#[derive(Debug, Clone)]
pub struct LineWorld {
    /// Number of cells on the line.
    num_rows: usize,
    /// Initial cell. It is only exposed through `get_init_state` in this file;
    /// presumably episodes start here (confirm against callers).
    init_state: usize,
    /// The goal cell.
    goal_state: usize,
    /// Cells that count as terminal.
    terminal_state: Vec<usize>,
}

impl LineWorld {
    /// Builds a line world from its raw parts.
    ///
    /// No validation is done: `init_state`, `goal_state` and the entries of
    /// `terminal_state` are stored as given, without being checked against
    /// `num_rows`.
    pub fn new(
        num_rows: usize,
        init_state: usize,
        goal_state: usize,
        terminal_state: Vec<usize>,
    ) -> Self {
        LineWorld {
            terminal_state,
            goal_state,
            init_state,
            num_rows,
        }
    }

    /// Returns the configured initial state.
    pub fn get_init_state(&self) -> usize {
        self.init_state
    }

    /// Returns the configured goal state.
    pub fn get_goal_state(&self) -> usize {
        self.goal_state
    }

    /// Borrows the list of terminal states.
    pub fn get_terminal_state(&self) -> &Vec<usize> {
        &self.terminal_state
    }
}

/// A single move along the line world.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LineWorldAction {
    /// Step towards the next-higher index.
    Up,
    /// Step towards the next-lower index.
    Down,
}

48impl Env<usize, LineWorldAction> for LineWorld {
49    fn is_terminal(&self, state: &usize) -> bool {
50        self.terminal_state.contains(state)
51    }
52
53    fn is_goal(&self, state: &usize) -> bool {
54        *state == self.goal_state
55    }
56
57    fn transition(&self, state: &usize, action: &Option<LineWorldAction>) -> (Option<usize>, f64) {
58        if self.is_terminal(state) {
59            (None, -1.0)
60        } else if self.is_goal(state) {
61            (None, 1.0)
62        } else {
63            let action = action.as_ref().unwrap();
64            match action {
65                LineWorldAction::Up => (Some(*state + 1), 0.0),
66                LineWorldAction::Down => (Some(*state - 1), 0.0),
67            }
68        }
69    }
70
71    fn available_actions(&self, state: &usize) -> Vec<LineWorldAction> {
72        match state {
73            0 => vec![LineWorldAction::Up],
74            r if *r == self.num_rows - 1 => vec![LineWorldAction::Down],
75            _ => vec![LineWorldAction::Up, LineWorldAction::Down],
76        }
77    }
78}