use crate::mdp::model::{Action, MDPError, State, MDP};
use crate::mdp::policy::Policy;
use rand::Rng;
use std::collections::HashMap;
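/// Computes an optimal policy for a finite MDP.
///
/// A sketch of the intended call pattern (`model` is hypothetical here and
/// stands for any type implementing [`MDP`]):
///
/// ```ignore
/// let policy = ValueIteration { theta: 1e-6, max_iterations: 10_000 }
///     .find_optimal_policy(&model)?;
/// ```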
pub trait Optimizer<'a, S: State, A: Action, M: MDP<S, A>> {
    fn find_optimal_policy(&self, mdp: &'a M) -> Result<Policy<'a, S, A>, MDPError<'a, S>>;
}
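/// Policy iteration: starts from a random policy, then alternates iterative
/// policy evaluation with greedy policy improvement until no state changes
/// its action.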
pub struct PolicyIteration {
    /// Convergence threshold for policy evaluation: a sweep stops once the
    /// largest value change falls below this bound.
    pub theta: f64,
    /// Upper bound on the number of evaluation sweeps per improvement step.
    pub max_iterations: usize,
}
impl<'a, S: State, A: Action, M: MDP<S, A>> Optimizer<'a, S, A, M> for PolicyIteration {
    fn find_optimal_policy(&self, mdp: &'a M) -> Result<Policy<'a, S, A>, MDPError<'a, S>> {
        let mut values = vec![0.0; mdp.n_states()];
        // Start from an arbitrary policy: assign each state a random action.
        let mut rng = rand::thread_rng();
        let mut mapping: HashMap<&S, &A> = mdp
            .states()
            .iter()
            .map(|state| (state, &mdp.actions()[rng.gen_range(0..mdp.n_actions())]))
            .collect();
        loop {
            // Policy evaluation: sweep the state space until the value
            // function converges below `theta` (or the sweep budget runs out).
            for _ in 0..self.max_iterations {
                let mut delta = 0f64;
                for state in mdp.states() {
                    let value = values[state.id()];
                    match mapping.get(state) {
                        Some(&action) => {
                            let new_value = mdp.states().iter().fold(0.0, |v, next_state| {
                                let r = mdp.reward(state, action, next_state);
                                let p = mdp.transition_probability(state, action, next_state);
                                v + p * (r + mdp.discount_factor() * values[next_state.id()])
                            });
                            delta = delta.max((value - new_value).abs());
                            values[state.id()] = new_value;
                        }
                        None => return Err(MDPError::NoAction { state }),
                    }
                }
                if delta < self.theta {
                    break;
                }
            }
            // Policy improvement: make the policy greedy with respect to the
            // freshly evaluated value function.
            let mut stable = true;
            for state in mdp.states() {
                match mapping.get(state) {
                    Some(&prev_action) => {
                        let mut best_action = prev_action;
                        let mut best_value = f64::NEG_INFINITY;
                        for action in mdp.actions() {
                            let v = mdp.states().iter().fold(0.0, |v, s| {
                                let r = mdp.reward(state, action, s);
                                let p = mdp.transition_probability(state, action, s);
                                v + p * (r + mdp.discount_factor() * values[s.id()])
                            });
                            if v > best_value {
                                best_value = v;
                                best_action = action;
                            }
                        }
                        stable &= best_action == prev_action;
                        mapping.insert(state, best_action);
                    }
                    None => return Err(MDPError::NoAction { state }),
                }
            }
            // Once no state changed its greedy action, the policy is optimal.
            if stable {
                return Ok(Policy::new(mapping));
            }
        }
    }
}
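/// Value iteration: folds evaluation and improvement into a single sweep by
/// backing up the maximum expected return over all actions, then extracts
/// the greedy policy from the converged value function.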
pub struct ValueIteration {
    /// Convergence threshold: iteration stops once the largest value change
    /// in a sweep falls below this bound.
    pub theta: f64,
    /// Upper bound on the number of value sweeps.
    pub max_iterations: usize,
}
impl<'a, S: State, A: Action, M: MDP<S, A>> Optimizer<'a, S, A, M> for ValueIteration {
    fn find_optimal_policy(&self, mdp: &'a M) -> Result<Policy<'a, S, A>, MDPError<'a, S>> {
        let mut values = vec![0.0; mdp.n_states()];
        // Bellman optimality backups: each state's value becomes the maximum
        // expected return over all actions.
        for _ in 0..self.max_iterations {
            let mut delta = 0f64;
            for state in mdp.states() {
                let value = values[state.id()];
                values[state.id()] = mdp
                    .actions()
                    .iter()
                    .fold(f64::NEG_INFINITY, |max_v, action| {
                        let x = mdp.states().iter().fold(0.0, |v, next_state| {
                            let r = mdp.reward(state, action, next_state);
                            let p = mdp.transition_probability(state, action, next_state);
                            v + p * (r + mdp.discount_factor() * values[next_state.id()])
                        });
                        max_v.max(x)
                    });
                delta = delta.max((value - values[state.id()]).abs());
            }
            if delta < self.theta {
                break;
            }
        }
        // Extract the greedy policy from the converged value function.
        let mut mapping = HashMap::with_capacity(mdp.n_states());
        for state in mdp.states() {
            let mut best_action = &mdp.actions()[0];
            let mut best_value = f64::NEG_INFINITY;
            for action in mdp.actions() {
                let v = mdp.states().iter().fold(0.0, |v, s| {
                    let r = mdp.reward(state, action, s);
                    let p = mdp.transition_probability(state, action, s);
                    v + p * (r + mdp.discount_factor() * values[s.id()])
                });
                if v > best_value {
                    best_value = v;
                    best_action = action;
                }
            }
            mapping.insert(state, best_action);
        }
        Ok(Policy::new(mapping))
    }
}
#[cfg(test)]
mod tests {
    use crate::mdp::environment::{GridWorld, Move};
    use crate::mdp::model::{State, MDP};
    use crate::mdp::optimizer::{Optimizer, PolicyIteration, ValueIteration};
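    // Both tests share a 3x4 grid world: state 5 is singled out by the first
    // predicate (presumably a blocked cell, as in the classic 4x3 grid),
    // states 3 and 7 are terminal with rewards +1.0 and -1.0, every other
    // transition is rewarded -0.5, and a move goes in the intended direction
    // with probability 0.8, slipping to each perpendicular direction with
    // probability 0.1.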
    #[test]
    fn test_policy_iteration() {
        let grid = GridWorld::from(
            3,
            4,
            |s| s.id() == 5,
            |a| match a {
                Move::North => |d| match d {
                    Move::North => 0.8,
                    Move::South => 0.0,
                    Move::East => 0.1,
                    Move::West => 0.1,
                },
                Move::South => |d| match d {
                    Move::North => 0.0,
                    Move::South => 0.8,
                    Move::East => 0.1,
                    Move::West => 0.1,
                },
                Move::East => |d| match d {
                    Move::North => 0.1,
                    Move::South => 0.1,
                    Move::East => 0.8,
                    Move::West => 0.0,
                },
                Move::West => |d| match d {
                    Move::North => 0.1,
                    Move::South => 0.1,
                    Move::East => 0.0,
                    Move::West => 0.8,
                },
            },
            |s| {
                if s.id() == 3 {
                    1.0
                } else if s.id() == 7 {
                    -1.0
                } else {
                    -0.5
                }
            },
            |s| s.id() == 3 || s.id() == 7,
        )
        .unwrap();
        let optimal_policy = PolicyIteration {
            theta: 1e-6,
            max_iterations: 100000,
        }
        .find_optimal_policy(&grid)
        .unwrap();
        assert_eq!(
            optimal_policy.select_action(&grid.states()[0]),
            Some(&Move::East)
        );
        assert_eq!(
            optimal_policy.select_action(&grid.states()[1]),
            Some(&Move::East)
        );
        assert_eq!(
            optimal_policy.select_action(&grid.states()[2]),
            Some(&Move::East)
        );
        assert_eq!(
            optimal_policy.select_action(&grid.states()[4]),
            Some(&Move::North)
        );
        assert_eq!(
            optimal_policy.select_action(&grid.states()[6]),
            Some(&Move::North)
        );
        assert_eq!(
            optimal_policy.select_action(&grid.states()[8]),
            Some(&Move::North)
        );
        assert_eq!(
            optimal_policy.select_action(&grid.states()[9]),
            Some(&Move::East)
        );
        assert_eq!(
            optimal_policy.select_action(&grid.states()[10]),
            Some(&Move::North)
        );
        assert_eq!(
            optimal_policy.select_action(&grid.states()[11]),
            Some(&Move::North)
        );
    }
    #[test]
    fn test_value_iteration() {
        let grid = GridWorld::from(
            3,
            4,
            |s| s.id() == 5,
            |a| match a {
                Move::North => |d| match d {
                    Move::North => 0.8,
                    Move::South => 0.0,
                    Move::East => 0.1,
                    Move::West => 0.1,
                },
                Move::South => |d| match d {
                    Move::North => 0.0,
                    Move::South => 0.8,
                    Move::East => 0.1,
                    Move::West => 0.1,
                },
                Move::East => |d| match d {
                    Move::North => 0.1,
                    Move::South => 0.1,
                    Move::East => 0.8,
                    Move::West => 0.0,
                },
                Move::West => |d| match d {
                    Move::North => 0.1,
                    Move::South => 0.1,
                    Move::East => 0.0,
                    Move::West => 0.8,
                },
            },
            |s| {
                if s.id() == 3 {
                    1.0
                } else if s.id() == 7 {
                    -1.0
                } else {
                    -0.5
                }
            },
            |s| s.id() == 3 || s.id() == 7,
        )
        .unwrap();
        let optimal_policy = ValueIteration {
            theta: 1e-6,
            max_iterations: 100000,
        }
        .find_optimal_policy(&grid)
        .unwrap();
        assert_eq!(
            optimal_policy.select_action(&grid.states()[0]),
            Some(&Move::East)
        );
        assert_eq!(
            optimal_policy.select_action(&grid.states()[1]),
            Some(&Move::East)
        );
        assert_eq!(
            optimal_policy.select_action(&grid.states()[2]),
            Some(&Move::East)
        );
        assert_eq!(
            optimal_policy.select_action(&grid.states()[4]),
            Some(&Move::North)
        );
        assert_eq!(
            optimal_policy.select_action(&grid.states()[6]),
            Some(&Move::North)
        );
        assert_eq!(
            optimal_policy.select_action(&grid.states()[8]),
            Some(&Move::North)
        );
        assert_eq!(
            optimal_policy.select_action(&grid.states()[9]),
            Some(&Move::East)
        );
        assert_eq!(
            optimal_policy.select_action(&grid.states()[10]),
            Some(&Move::North)
        );
        assert_eq!(
            optimal_policy.select_action(&grid.states()[11]),
            Some(&Move::North)
        );
    }
}