rl_examples/agents/mcts.rs

use crate::{environment::Environment, store::Store};

use super::selector::Selector;

// TODO: add store visit count
// TODO: change to struct arg
// TODO: add epsilon decay
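/// Monte Carlo agent: runs whole episodes against the environment and updates
/// per-(state, action) Q estimates and per-state value estimates from the
/// returns observed at the end of each episode.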
pub struct AgentMcts<T: Environment, U: Selector, S: Store> {
    environment: T,
    selector: U,
    q_store: S,
    state_value_store: S,
    store_action_count: S,
    store_state_count: S,
    total_actions_taken: usize,
}

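/// A (state, action, reward) triple recorded for one step of an episode.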
type StateActionValue = (String, usize, f64);

impl<T: Environment, U: Selector, S: Store> AgentMcts<T, U, S> {
    pub fn new(
        environment: T,
        selector: U,
        q_store: S,
        state_value_store: S,
        store_action_count: S,
        store_state_count: S
    ) -> AgentMcts<T, U, S> {
        AgentMcts {
            environment,
            selector,
            q_store,
            state_value_store,
            store_action_count,
            store_state_count,
            total_actions_taken: 0,
        }
    }

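    /// Ask the selector to choose an action for the current environment state,
    /// based on the stored Q estimates and action counts.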
    pub fn select_action(&mut self) -> usize {
        self.selector.select_action(&mut self.environment, &self.q_store, &self.store_action_count)
    }

    pub fn get_number_of_possible_states(&self) -> usize {
        self.environment.get_number_of_possible_states()
    }

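    /// Take `action` in the environment, incrementing the visit counts for the
    /// current state and for the (state, action) pair, and return the reward.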
    pub fn take_action(&mut self, action: usize) -> f64 {
        // record action taken
        let current_state = self.environment.get_state();
        let id = self.store_action_count.generate_id(current_state.clone(), Some(action));
        let current_count = self.store_action_count.get_float(&id);
        self.store_action_count.store_float(id, current_count + 1.0);
        // record visit to state
        let state_id = self.store_state_count.generate_id(current_state, None);
        let current_state_count = self.store_state_count.get_float(&state_id);
        self.store_state_count.store_float(state_id, current_state_count + 1.0);
        // take step
        self.environment.step(action)
    }

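    /// Ask the selector for a new Q estimate for `(state, action)` given the
    /// observed return, then write it back to the Q store.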
    fn update_q_estimate(&mut self, state: String, action: usize, reward: f64) {
        let new_estimate = self.selector.get_new_q_estimate(
            &mut self.environment,
            &mut self.q_store,
            &mut self.store_action_count,
            state.clone(),
            action,
            reward
        );
        let id: String = self.q_store.generate_id(state, Some(action));
        self.q_store.store_float(id, new_estimate);
    }

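    /// Ask the selector for a new value estimate for `state` given the
    /// observed return, then write it back to the state-value store.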
    fn update_state_value_estimate(&mut self, state: String, reward: f64) {
        let new_estimate = self.selector.get_new_value_estimate(
            &mut self.environment,
            &self.state_value_store,
            &self.store_state_count,
            state.clone(),
            reward
        );
        let id: String = self.state_value_store.generate_id(state, None);
        self.state_value_store.store_float(id, new_estimate);
    }

    pub fn get_state_value_estimate(&self, state: String) -> f64 {
        let id: String = self.state_value_store.generate_id(state, None);
        self.state_value_store.get_float(&id)
    }

    pub fn get_state_visit_count(&self, state: String) -> f64 {
        let id: String = self.store_state_count.generate_id(state, None);
        self.store_state_count.get_float(&id)
    }

    pub fn all_possible_states(&self) -> Vec<String> {
        self.environment.all_possible_states()
    }

    /// Run an episode of the environment and update the Q/value estimates.
    /// Returns the total reward accumulated over the episode.
    pub fn run_episode(&mut self) -> f64 {
        let mut state_action_values: Vec<StateActionValue> = Vec::new();
        loop {
            let state = self.environment.get_state();
            let action = self.select_action();
            let reward = self.take_action(action);
            state_action_values.push((state, action, reward));
            self.total_actions_taken += 1;
            if self.environment.is_terminal() {
                break;
            }
        }
        // now update the q/value estimates for each state-action pair;
        // the return for time t is the sum of rewards from t to the end of
        // the episode, so iterate in reverse and accumulate
        let mut total_reward = 0.0;
        for (state, action, reward) in state_action_values.iter().rev() {
            println!("State: {}, Action: {}, Reward: {}", state, action, reward);
            total_reward += reward;
            self.update_q_estimate(state.clone(), *action, total_reward);
            self.update_state_value_estimate(state.clone(), total_reward);
        }
        // reset the environment
        self.environment.reset();
        // return the total reward
        total_reward
    }
}
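
// Usage sketch: a minimal training loop. `GridWorld`, `EpsilonGreedy`, and
// `HashMapStore` are hypothetical stand-ins for types implementing
// `Environment`, `Selector`, and `Store`; none of them are defined in this
// module, so the sketch is left commented out.
//
// let mut agent = AgentMcts::new(
//     GridWorld::new(),
//     EpsilonGreedy::new(0.1), // hypothetical epsilon-greedy selector
//     HashMapStore::new(),     // q_store
//     HashMapStore::new(),     // state_value_store
//     HashMapStore::new(),     // store_action_count
//     HashMapStore::new(),     // store_state_count
// );
// for episode in 0..1000 {
//     let total_reward = agent.run_episode();
//     println!("episode {}: total reward {}", episode, total_reward);
// }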