// madepro/solvers/temporal_difference.rs

use crate::models::{ActionValue, Config, MDP};
7fn sarsa_q_learning<M>(
8 mdp: &M,
9 config: &Config,
10 q_learning: bool,
11) -> ActionValue<M::State, M::Action>
12where
13 M: MDP,
14{
15 let states = mdp.get_states();
16 let actions = mdp.get_actions();
17 let mut action_value = ActionValue::new(states, actions);
18 for _ in 0..config.num_episodes {
19 let mut state = states.get_random().clone();
20 let mut action = action_value
21 .epsilon_greedy(actions, &state, config.exploration_rate)
22 .clone();
23 for _ in 0..config.max_num_steps {
24 let (next_state, reward) = mdp.transition(&state, &action);
25 let next_action = action_value
26 .epsilon_greedy(actions, &next_state, config.exploration_rate)
27 .clone();
28 let current = action_value.get(&state, &action);
30 let q_value = if q_learning {
31 action_value.get(&next_state, action_value.greedy(&next_state))
32 } else {
33 action_value.get(&next_state, &next_action)
34 };
35 let target = reward + config.discount_factor * q_value;
36 action_value.insert(
37 &state,
38 &action,
39 current + config.learning_rate * (target - current),
40 );
41 state = next_state;
42 action = next_action;
43 if mdp.is_state_terminal(&state) {
44 break;
45 }
46 }
47 }
48 action_value
49}
50
51pub fn sarsa<M>(mdp: &M, config: &Config) -> ActionValue<M::State, M::Action>
60where
61 M: MDP,
62{
63 sarsa_q_learning(mdp, config, false)
64}
65
66pub fn q_learning<M>(mdp: &M, config: &Config) -> ActionValue<M::State, M::Action>
77where
78 M: MDP,
79{
80 sarsa_q_learning(mdp, config, true)
81}
82
#[cfg(test)]
mod tests {
    use super::*;
    use crate::environments::gridworld::{assert_policy_optimal, get_gridworld, get_test_config};

    /// SARSA should recover an optimal greedy policy on the test gridworld.
    #[test]
    fn test_sarsa() {
        let env = get_gridworld();
        let cfg = get_test_config();
        let learned = sarsa(&env, &cfg);
        let greedy_policy = learned.greedy_policy(env.get_states(), env.get_actions());
        assert_policy_optimal(&greedy_policy);
    }

    /// Q-learning should recover an optimal greedy policy on the test gridworld.
    #[test]
    fn test_q_learning() {
        let env = get_gridworld();
        let cfg = get_test_config();
        let learned = q_learning(&env, &cfg);
        let greedy_policy = learned.greedy_policy(env.get_states(), env.get_actions());
        assert_policy_optimal(&greedy_policy);
    }
}