1use crate::env::Env;
2
/// A one-dimensional world whose states are row indices of type `usize`.
#[derive(Clone, Debug)]
pub struct LineWorld {
    /// Number of rows (states) on the line.
    num_rows: usize,
    /// State reported by [`LineWorld::get_init_state`].
    init_state: usize,
    /// State reported by [`LineWorld::get_goal_state`].
    goal_state: usize,
    /// States reported by [`LineWorld::get_terminal_state`].
    terminal_state: Vec<usize>,
}

impl LineWorld {
    /// Builds a world from its four parameters.
    ///
    /// The values are stored exactly as given; no range or consistency
    /// checks are performed here.
    pub fn new(
        num_rows: usize,
        init_state: usize,
        goal_state: usize,
        terminal_state: Vec<usize>,
    ) -> Self {
        LineWorld {
            num_rows,
            init_state,
            goal_state,
            terminal_state,
        }
    }

    /// Returns the initial state passed to [`LineWorld::new`].
    pub fn get_init_state(&self) -> usize {
        self.init_state
    }

    /// Returns the goal state passed to [`LineWorld::new`].
    pub fn get_goal_state(&self) -> usize {
        self.goal_state
    }

    /// Borrows the terminal states passed to [`LineWorld::new`].
    pub fn get_terminal_state(&self) -> &Vec<usize> {
        &self.terminal_state
    }
}
41
/// The two moves an agent can make in a [`LineWorld`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LineWorldAction {
    /// Move to the next higher row index (`state + 1`).
    Up,
    /// Move to the next lower row index (`state - 1`).
    Down,
}
47
48impl Env<usize, LineWorldAction> for LineWorld {
49 fn is_terminal(&self, state: &usize) -> bool {
50 self.terminal_state.contains(state)
51 }
52
53 fn is_goal(&self, state: &usize) -> bool {
54 *state == self.goal_state
55 }
56
57 fn transition(&self, state: &usize, action: &Option<LineWorldAction>) -> (Option<usize>, f64) {
58 if self.is_terminal(state) {
59 (None, -1.0)
60 } else if self.is_goal(state) {
61 (None, 1.0)
62 } else {
63 let action = action.as_ref().unwrap();
64 match action {
65 LineWorldAction::Up => (Some(*state + 1), 0.0),
66 LineWorldAction::Down => (Some(*state - 1), 0.0),
67 }
68 }
69 }
70
71 fn available_actions(&self, state: &usize) -> Vec<LineWorldAction> {
72 match state {
73 0 => vec![LineWorldAction::Up],
74 r if *r == self.num_rows - 1 => vec![LineWorldAction::Down],
75 _ => vec![LineWorldAction::Up, LineWorldAction::Down],
76 }
77 }
78}