use crate::prelude::*;
use beet_core::prelude::*;
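
/// A tabular Q-learning trainer for an [`RlSessionTypes`] session.
/// The environment and hyperparameters are wrapped in [`Readonly`] so they
/// stay fixed for the trainer's lifetime; each episode works on a clone of
/// the environment, starting from `initial_state`.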
pub struct QTableTrainer<S: RlSessionTypes> {
pub table: S::QLearnPolicy,
pub env: Readonly<S::Env>,
pub params: Readonly<QLearnParams>,
initial_state: S::State,
}

impl<S: RlSessionTypes> QTableTrainer<S> {
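/// Create a trainer from an environment, an initial (often empty)
/// policy table, hyperparameters, and the state each episode starts from.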
pub fn new(
env: S::Env,
table: S::QLearnPolicy,
params: QLearnParams,
initial_state: S::State,
) -> Self {
Self {
table,
env: Readonly::new(env),
params: Readonly::new(params),
initial_state,
}
}
}
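
// Policy queries and updates are delegated to the inner table, so the
// trainer can be used anywhere a plain policy is expected.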
impl<S: RlSessionTypes> QPolicy for QTableTrainer<S> {
type Action = S::Action;
type State = S::State;
fn greedy_policy(&self, state: &Self::State) -> (Self::Action, QValue) {
self.table.greedy_policy(state)
}
fn get_actions(
&self,
state: &Self::State,
) -> impl Iterator<Item = (&Self::Action, &QValue)> {
self.table.get_actions(state)
}
fn get_q(&self, state: &Self::State, action: &Self::Action) -> QValue {
self.table.get_q(state, action)
}
fn set_q(
&mut self,
state: &Self::State,
action: &Self::Action,
value: QValue,
) {
self.table.set_q(state, action, value)
}
}

impl<S: RlSessionTypes> QTrainer for QTableTrainer<S>
where
S::State: Clone,
{
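/// Runs `n_training_episodes` episodes of tabular Q-learning.
/// Each episode clones the environment, resets to the initial state, and
/// steps with an epsilon-greedy policy, where epsilon is recomputed per
/// episode via `QLearnParams::epsilon`.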
fn train(&mut self, rng: &mut impl Rng) {
let params = &self.params;
for episode in 0..params.n_training_episodes {
let epsilon = params.epsilon(episode);
let mut env = self.env.clone();
let mut state = self.initial_state.clone();
'step: for _step in 0..params.max_steps {
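// Select an action: explore with probability epsilon, otherwise exploit.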
let (action, _) =
self.table.epsilon_greedy_policy(&state, epsilon, rng);
let outcome = env.step(&state, &action);
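// Fold the observed reward and next state into the table's
// discounted value estimate for (state, action).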
self.table.set_discounted_reward(
params,
&action,
outcome.reward,
&state,
&outcome.state,
);
if outcome.done {
break 'step;
}
state = outcome.state;
}
}
}
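
/// Runs `n_eval_episodes` episodes using only the greedy policy and
/// returns the reward statistics along with the total number of steps.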
fn evaluate(&self) -> Evaluation {
let mut rewards = Vec::new();
let mut total_steps = 0;
let params = &self.params;
for _episode in 0..params.n_eval_episodes {
let mut env = self.env.clone();
let mut state = self.initial_state.clone();
let mut total_reward = 0.0;
for _step in 0..params.max_steps {
total_steps += 1;
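// Evaluation is purely greedy: no exploration.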
let (action, _) = self.table.greedy_policy(&state);
let outcome = env.step(&state, &action);
total_reward += outcome.reward;
state = outcome.state;
if outcome.done {
break;
}
}
rewards.push(total_reward);
}
Evaluation::new(rewards, total_steps)
}
}

#[cfg(test)]
mod test {
use crate::prelude::*;
use beet_core::prelude::*;
#[test]
fn works() {
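// Train a fresh QTable on the 4x4 FrozenLake map with default
// hyperparameters and a seeded rng, then evaluate the result.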
let mut policy_rng = RandomSource::from_seed(0);
let map = FrozenLakeMap::default_four_by_four();
let initial_state = map.agent_position();
let env = QTableEnv::new(map.transition_outcomes());
let params = QLearnParams::default();
let mut trainer = QTableTrainer::<FrozenLakeQTableSession>::new(
env.clone(),
QTable::default(),
params,
initial_state,
);
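// Sanity check: training should take a measurable amount of time,
// i.e. work actually happened.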
let now = Instant::now();
trainer.train(&mut policy_rng.0);
let elapsed = now.elapsed();
elapsed.xpect_greater_than(Duration::from_millis(2));
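// A converged table should solve the map every time: maximum mean
// reward, zero variance, and a stable total step count.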
let eval = trainer.evaluate();
eval.mean.xpect_eq(1.);
eval.std.xpect_eq(0.);
eval.total_steps.xpect_eq(600);
}
}