// rl_examples/agents/mcts.rs
use crate::{ environment::Environment, store::Store };
2
3use super::selector::Selector;
4
/// Monte-Carlo agent that runs episodes against an environment, choosing
/// actions via a selector and persisting value estimates and visit counts
/// in pluggable stores.
///
/// Trait bounds (`Environment`, `Selector`, `Store`) live on the `impl`
/// block, not the struct definition, per Rust convention — the fields
/// themselves need no bounds.
pub struct AgentMcts<T, U, S> {
    // Environment the agent acts in.
    environment: T,
    // Policy component: picks actions and produces updated estimates.
    selector: U,
    // Q(state, action) estimates.
    q_store: S,
    // V(state) estimates.
    state_value_store: S,
    // Visit count per (state, action) pair.
    store_action_count: S,
    // Visit count per state.
    store_state_count: S,
    // Total actions taken across all episodes run by this agent.
    total_actions_taken: usize,
}
17
// One recorded step of an episode: (state, action, reward).
type StateActionValue = (String, usize, f64);
19
20impl<T: Environment, U: Selector, S: Store> AgentMcts<T, U, S> {
21 pub fn new(
22 environment: T,
23 selector: U,
24 q_store: S,
25 state_value_store: S,
26 store_action_count: S,
27 store_state_count: S
28 ) -> AgentMcts<T, U, S> {
29 AgentMcts {
30 environment,
31 selector,
32 q_store: q_store,
33 state_value_store: state_value_store,
34 store_action_count,
35 store_state_count,
36 total_actions_taken: 0,
37 }
38 }
39
40 pub fn select_action(&mut self) -> usize {
41 self.selector.select_action(&mut self.environment, &self.q_store, &self.store_action_count)
42 }
43
44 pub fn get_number_of_possible_states(&self) -> usize {
45 self.environment.get_number_of_possible_states()
46 }
47
48 pub fn take_action(&mut self, action: usize) -> f64 {
49 let current_state = self.environment.get_state();
51 let id = self.store_action_count.generate_id(current_state.clone(), Some(action));
52 let current_count = self.store_action_count.get_float(&id);
53 self.store_action_count.store_float(id, current_count + 1.0);
54 let state_id = self.store_state_count.generate_id(current_state, None);
56 let current_state_count = self.store_state_count.get_float(&state_id);
57 self.store_state_count.store_float(state_id, current_state_count + 1.0);
58 self.environment.step(action)
60 }
61
62 fn update_q_estimate(&mut self, state: String, action: usize, reward: f64) {
63 let new_estimate = self.selector.get_new_q_estimate(
64 &mut self.environment,
65 &mut self.q_store,
66 &mut self.store_action_count,
67 state.clone(),
68 action,
69 reward
70 );
71 let id: String = self.q_store.generate_id(state, Some(action));
72 self.q_store.store_float(id, new_estimate);
73 }
74
75 fn update_state_value_estimate(&mut self, state: String, reward: f64) {
76 let new_estimate = self.selector.get_new_value_estimate(
77 &mut self.environment,
78 &self.state_value_store,
79 &self.store_state_count,
80 state.clone(),
81 reward
82 );
83 let id: String = self.state_value_store.generate_id(state, None);
84 self.state_value_store.store_float(id, new_estimate);
85 }
86
87 pub fn get_state_value_estimate(&self, state: String) -> f64 {
88 let id: String = self.state_value_store.generate_id(state, None);
89 self.state_value_store.get_float(&id)
90 }
91
92 pub fn get_state_visit_count(&self, state: String) -> f64 {
93 let id: String = self.store_state_count.generate_id(state, None);
94 self.store_state_count.get_float(&id)
95 }
96
97 pub fn all_possible_states(&self) -> Vec<String> {
98 self.environment.all_possible_states()
99 }
100
101 pub fn run_episode(&mut self) -> f64 {
104 let mut state_action_values: Vec<StateActionValue> = Vec::new();
105 let mut reward: f64;
106 loop {
107 let state = self.environment.get_state();
108 let action = self.select_action();
109 reward = self.take_action(action);
110 state_action_values.push((state, action, reward));
111 self.total_actions_taken += 1;
112 if self.environment.is_terminal() {
113 break;
114 }
115 }
116 let mut total_reward = 0.0;
119 for (state, action, reward) in state_action_values.iter().rev() {
120 println!("State: {}, Action: {}, Reward: {}", state, action, reward);
121 total_reward += reward;
122 self.update_q_estimate(state.clone(), *action, total_reward);
123 self.update_state_value_estimate(state.clone(), total_reward);
124 }
125 self.environment.reset();
127 total_reward
129 }
130}