use crate::mdp::policy::Policy;
use std::error::Error;
use std::fmt;
use std::fmt::Debug;
use std::hash::Hash;
/// Errors that can occur while building or running an MDP.
///
/// Variants carrying a `state` borrow it from the MDP so callers can
/// inspect exactly where a rollout failed.
#[derive(Debug, PartialEq, Eq)]
pub enum MDPError<'a, S: State> {
/// The MDP contains no states/actions.
Empty,
/// The policy had no action registered for `state` during a rollout.
NoAction { state: &'a S },
/// No transition is available out of `state`.
NoTransition { state: &'a S },
/// Transition matrix has wrong dimensions or rows that do not sum to 1.
InvalidTransitionMatrix,
/// Reward matrix has wrong dimensions.
InvalidRewardMatrix,
}
// `MDPError` already implements `Debug` (derived) and `Display` (below),
// so the blanket `std::error::Error` requirements hold with no overrides.
impl<'a, S: State> Error for MDPError<'a, S> {}
impl<'a, S: State> fmt::Display for MDPError<'a, S> {
    /// Renders a human-readable message for each error variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Variants that embed a state interpolate its id; the rest are
        // fixed strings, written via `write_str` to skip format machinery.
        match self {
            Self::NoAction { state } => {
                write!(f, "No action available for state {}.", state.id())
            }
            Self::NoTransition { state } => {
                write!(f, "No transition is available for state {}.", state.id())
            }
            Self::Empty => f.write_str("The MDP cannot be empty."),
            Self::InvalidTransitionMatrix => f.write_str(
                "The transition matrix is invalid. Either the dimensions are incorrect or the probabilities do not sum to 1.",
            ),
            Self::InvalidRewardMatrix => {
                f.write_str("The reward matrix has invalid dimensions.")
            }
        }
    }
}
/// The result of rolling a policy out on an MDP (see `MDP::run_policy`).
#[derive(Debug)]
pub struct Episode<'a, S: State> {
/// The state the rollout began from (also the first trajectory entry).
pub starting_state: &'a S,
/// All visited states in order, beginning with `starting_state`.
pub trajectory: Vec<&'a S>,
/// Sum of the rewards of every transition taken (undiscounted).
pub total_reward: f64,
}
/// A state of an MDP. `Hash + Eq` let states serve as policy-table keys;
/// `Debug` supports diagnostics.
pub trait State: Debug + Hash + Eq {
/// Numeric identifier for this state (used in error messages and,
/// conventionally, as an index into the MDP's state list).
fn id(&self) -> usize;
}
/// An action of an MDP, comparable and debuggable.
pub trait Action: Debug + Eq {
/// Numeric identifier for this action.
fn id(&self) -> usize;
}
/// A finite Markov decision process over states `S` and actions `A`.
pub trait MDP<S: State, A: Action> {
/// Number of states in the MDP.
fn n_states(&self) -> usize;
/// All states of the MDP.
fn states(&self) -> &[S];
/// Number of actions in the MDP.
fn n_actions(&self) -> usize;
/// All actions of the MDP.
fn actions(&self) -> &[A];
/// Whether `state` ends an episode.
fn is_terminal(&self, state: &S) -> bool;
/// Discount factor for future rewards. Defaults to 1.0 (undiscounted).
// `#[inline]` is enough of a hint here; `#[inline(always)]` is almost
// never warranted for a trivial constant-returning default method.
#[inline]
fn discount_factor(&self) -> f64 {
    1.0
}
/// Probability of reaching `next_state` by taking `action` in `state`.
fn transition_probability(&self, state: &S, action: &A, next_state: &S) -> f64;
/// Reward for the transition `state --action--> next_state`.
fn reward(&self, state: &S, action: &A, next_state: &S) -> f64;
/// Samples a successor state for taking `action` in `state`.
fn act(&self, state: &S, action: &A) -> &S;
/// Rolls `policy` out from `starting_state` for at most `maximum_steps`
/// transitions, accumulating the total (undiscounted) reward.
///
/// The returned trajectory always contains `starting_state` as its first
/// entry. A rollout that begins in — or reaches — a terminal state stops
/// without taking further actions.
///
/// # Errors
/// Returns [`MDPError::NoAction`] if `policy` has no action for a
/// visited non-terminal state.
fn run_policy<'a>(
    &'a self,
    policy: &'a Policy<S, A>,
    starting_state: &'a S,
    maximum_steps: usize,
) -> Result<Episode<'a, S>, MDPError<'a, S>> {
    let mut total_reward = 0f64;
    let mut trajectory = vec![starting_state];
    let mut state = starting_state;
    for _ in 0..maximum_steps {
        // Check terminality *before* acting: previously the check ran
        // only after a step, so a rollout starting in a terminal state
        // still took one action. Mid-episode behavior is unchanged.
        if self.is_terminal(state) {
            break;
        }
        match policy.select_action(state) {
            Some(action) => {
                let next_state = self.act(state, action);
                trajectory.push(next_state);
                total_reward += self.reward(state, action, next_state);
                state = next_state;
            }
            None => {
                return Err(MDPError::NoAction { state });
            }
        }
    }
    Ok(Episode {
        starting_state,
        trajectory,
        total_reward,
    })
}
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use crate::mdp::model::{Action, State, MDP};
    use crate::mdp::policy::Policy;
    use rand::Rng;

    /// Minimal state: a position index on a line.
    #[derive(Debug, Hash, PartialEq, Eq)]
    struct S {
        id: usize,
    }

    impl State for S {
        fn id(&self) -> usize {
            self.id
        }
    }

    #[derive(Debug, PartialEq, Eq)]
    enum A {
        Forward,
        Backward,
    }

    impl Action for A {
        fn id(&self) -> usize {
            0
        }
    }

    /// A 1-D random-walk environment: states `0..n` on a line, the last
    /// state is terminal, and every step moves the agent one position
    /// left or right with probability 0.5 (saturating at the ends).
    struct Line {
        states: Vec<S>,
        actions: Vec<A>,
    }

    impl MDP<S, A> for Line {
        fn n_states(&self) -> usize {
            self.states.len()
        }
        fn n_actions(&self) -> usize {
            self.actions.len()
        }
        fn states(&self) -> &[S] {
            &self.states
        }
        fn actions(&self) -> &[A] {
            &self.actions
        }
        fn is_terminal(&self, state: &S) -> bool {
            state.id() == self.n_states() - 1
        }
        /// The chosen action is ignored: the environment moves randomly,
        /// which is enough to exercise `run_policy`'s bookkeeping.
        fn act(&self, state: &S, _: &A) -> &S {
            if rand::thread_rng().gen_bool(0.5) {
                // Step right, saturating at the terminal state.
                if state.id() != self.n_states() - 1 {
                    &self.states[state.id() + 1]
                } else {
                    &self.states[state.id()]
                }
            } else {
                // Step left, saturating at state 0.
                if state.id() != 0 {
                    &self.states[state.id() - 1]
                } else {
                    &self.states[state.id()]
                }
            }
        }
        fn transition_probability(&self, state: &S, action: &A, next_state: &S) -> f64 {
            // Compare via addition on the smaller id: the original
            // `state.id() == next_state.id() - 1` underflows (and panics
            // in debug builds) whenever `next_state.id()` is 0.
            match action {
                A::Forward if state.id() + 1 == next_state.id() => 0.5,
                A::Backward if state.id() == next_state.id() + 1 => 0.5,
                _ => 0.0,
            }
        }
        fn reward(&self, _: &S, _: &A, next_state: &S) -> f64 {
            // Cost of -1 per step until the terminal state is entered.
            if next_state.id() != self.n_states() - 1 {
                -1.0
            } else {
                0.0
            }
        }
    }

    #[test]
    fn run_incomplete_policy() {
        let env = Line {
            states: (0..2).map(|id| S { id }).collect(),
            actions: vec![A::Forward, A::Backward],
        };
        // The policy only covers state 1, so starting at state 0 must fail.
        let incomplete_policy = Policy::new(HashMap::from([(&env.states[1], &A::Forward)]));
        let episode = env.run_policy(&incomplete_policy, &env.states[0], 10);
        assert!(episode.is_err());
        assert!(episode
            .unwrap_err()
            .to_string()
            .contains("No action available for state 0."));
    }

    #[test]
    fn run_random_policy() {
        let env = Line {
            states: (0..10).map(|id| S { id }).collect(),
            actions: vec![A::Forward, A::Backward],
        };
        let starting_state = env.states.iter().find(|state| state.id() == 0);
        assert!(starting_state.is_some());
        let policy = Policy::random(&env.states, &env.actions);
        let episode = env
            .run_policy(&policy, starting_state.unwrap(), 100)
            .unwrap();
        assert_eq!(episode.starting_state.id(), 0);
        // Consecutive states on the line differ by at most one position.
        for pair in episode.trajectory.windows(2) {
            assert!(pair[0].id().abs_diff(pair[1].id()) <= 1);
        }
        // The total reward is -1 for each visited non-terminal state
        // after the start, matching `Line::reward`.
        let actual_reward: f64 = episode.trajectory[1..]
            .iter()
            .filter(|&&state| !env.is_terminal(state))
            .map(|_| -1f64)
            .sum();
        assert_eq!(episode.total_reward, actual_reward);
    }
}