use std::collections::HashMap;
use crate::mdp::State;
use crate::strategy::learn::LearningStrategy;
/// Q-learning update strategy:
/// Q(s,a) ← Q(s,a) + α·(r + γ·max_a' Q(s',a') − Q(s,a)).
pub struct QLearning {
    // Learning rate α: how strongly new information overrides the old value.
    alpha: f64,
    // Discount factor γ: weight given to the estimated future reward.
    gamma: f64,
    // Default Q-value assumed for state/action pairs not yet recorded.
    initial_value: f64,
}
impl QLearning {
pub fn new(alpha: f64, gamma: f64, initial_value: f64) -> QLearning {
QLearning {
alpha,
gamma,
initial_value,
}
}
}
impl<S: State> LearningStrategy<S> for QLearning {
    /// Returns the updated Q-value for the action just taken.
    ///
    /// `new_action_values` holds the known Q-values of the successor state
    /// (None if that state was never visited); `old_value` is the current
    /// Q-value of the taken action (None if never taken); and
    /// `reward_after_action` is the immediate reward observed.
    fn value(
        &self,
        new_action_values: &Option<&HashMap<S::A, f64>>,
        old_value: &Option<&f64>,
        reward_after_action: f64,
    ) -> f64 {
        // Best estimated value among the successor state's actions.
        // `f64::max` ignores NaN entries instead of panicking, unlike the
        // previous `partial_cmp(..).unwrap()` comparator, which would panic
        // on the first NaN Q-value. Falls back to `initial_value` when the
        // successor state is unseen or has no recorded actions.
        let max_next = new_action_values
            .and_then(|m| m.values().copied().reduce(f64::max))
            .unwrap_or(self.initial_value);
        // Standard Q-learning update:
        // Q ← Q + α·(r + γ·max_a' Q(s',a') − Q);
        // an action never taken before starts from `initial_value`.
        old_value.map_or(self.initial_value, |&q| {
            q + self.alpha * (reward_after_action + self.gamma * max_next - q)
        })
    }
}