extern crate chrono;
extern crate rand;
use std::fmt;
use std::f64;
use std::hash;
use std::collections::HashMap;
use std::collections::HashSet;
use std::cmp;
use self::chrono::prelude::*;
use policy;
use softmax;
/// Tuning parameters for [`search`].
#[derive(Debug,Clone)]
pub struct SearchCriteria {
    /// Trace-decay parameter λ; must lie in `[0, 1]`.
    pub _lambda: f64,
    /// Discount factor γ; must lie in `[0, 1]`.
    pub _gamma: f64,
    /// Base step size α; must be strictly positive.
    pub _alpha: f64,
    /// Condition that ends the search.
    pub _stop_limit: StopCondition,
    /// Rule used to pick an action from the current value estimates.
    pub _policy_select_method: PolicySelectMethod,
}

impl SearchCriteria {
    /// Validates that every parameter lies within its permitted range.
    ///
    /// # Errors
    /// Returns `Err("search criteria out of range")` if:
    /// - λ or γ lies outside `[0, 1]`;
    /// - α is not strictly positive;
    /// - the ε of `EpsilonGreedy` lies outside `[0, 1]`;
    /// - a `TimeMicro` budget is NaN.
    ///
    /// NaN is rejected everywhere. Every comparison with NaN is false, so a NaN
    /// that slipped through would silently poison the value estimates. For
    /// `TimeMicro`, it would make the time limit unreachable.
    pub fn check( & self ) -> Result< (), & 'static str > {
        const OUT_OF_RANGE: &'static str = "search criteria out of range";
        // Positive range tests: NaN fails every comparison and is therefore rejected.
        let in_unit_interval = |x: f64| x >= 0. && x <= 1.;
        // `!(alpha > 0.)` rather than `alpha <= 0.` on purpose, so NaN is rejected too.
        if !in_unit_interval( self._lambda )
            || !in_unit_interval( self._gamma )
            || !( self._alpha > 0. )
        {
            return Err( OUT_OF_RANGE );
        }
        if let PolicySelectMethod::EpsilonGreedy( epsilon ) = self._policy_select_method {
            if !in_unit_interval( epsilon ) {
                return Err( OUT_OF_RANGE );
            }
        }
        if let StopCondition::TimeMicro( budget ) = self._stop_limit {
            if budget.is_nan() {
                return Err( OUT_OF_RANGE );
            }
        }
        Ok( () )
    }
}

/// Condition that ends [`search`].
#[derive(Debug, Clone)]
pub enum StopCondition {
    /// Wall-clock budget in microseconds, measured from the start of the search.
    TimeMicro(f64),
    /// Limit on the number of episodes run.
    EpisodeIter(usize),
}

/// How [`search`] chooses an action from the current value estimates.
#[derive(Debug, Clone)]
pub enum PolicySelectMethod {
    /// ε-greedy selection; the payload is ε and must lie in `[0, 1]`.
    EpsilonGreedy( f64 ),
    /// Softmax selection over the action values of the current state.
    Softmax,
}
/// Scalar reward returned by [`Game::do_action`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Reward(pub f64);

/// Environment interface driven by [`search`].
///
/// `search` reads the recorded `(state, action)` history back after every step. It then
/// overwrites that history with a loop-erased copy via `set_state_history`.
pub trait Game< State, Action > where State: Clone + cmp::Eq + hash::Hash, Action: Clone + cmp::Eq + hash::Hash {
    /// Returns the starting state of a new episode. `search` calls this once per episode.
    // NOTE(review): implementations presumably also reset the recorded history here — confirm.
    fn gen_initial_state( & mut self ) -> State;

    /// Returns the actions available in `state`.
    ///
    /// `search` also calls this for the successor of an episode's last step, which
    /// may be a terminal state.
    fn gen_possible_actions( & mut self, state: & State ) -> Vec< Action >;

    /// Applies `action` in `state`, returning the obtained reward and the successor state.
    // NOTE(review): `search` expects the performed `(state, action)` pair to show up in
    // `get_state_history` afterwards — presumably this method records it; verify.
    fn do_action( & mut self, state: & State, action: & Action ) -> (Reward, State);

    /// Reports whether `s` ends the episode.
    fn is_state_terminal( & mut self, s: & State ) -> bool;

    /// Returns the recorded `(state, action)` history, oldest first.
    fn get_state_history( & self ) -> Vec< ( State, Action ) >;

    /// Replaces the recorded history with `h`.
    fn set_state_history( & mut self, h: & [ (State, Action) ] );
}
pub fn search< G, State, Action >( criteria: & SearchCriteria,
g: & mut G )
-> Result< ( HashMap< ( State, Action ), f64 >,
HashMap< State, Vec< ( Action, f64 ) > >,
HashMap< State, f64 > ), & 'static str >
where G: Game< State, Action >, State: Clone + cmp::Eq + hash::Hash + fmt::Debug, Action: Clone + cmp::Eq + hash::Hash + fmt::Debug {
criteria.check()?;
let mut policy_values : HashMap< (State, Action), f64 > = HashMap::new();
let t0 = Local::now();
let mut iter = 0;
let state_init = g.gen_initial_state();
'outer_loop: loop {
let mut eligibility_trace : HashMap< (State, Action), f64 > = HashMap::new();
let mut state_episode = g.gen_initial_state();
if g.is_state_terminal( & state_episode ) {
break;
}
let mut action : Action = {
let possible_actions = g.gen_possible_actions( & state_init );
match criteria._policy_select_method {
PolicySelectMethod::EpsilonGreedy( epsilon ) => {
let action_greedy = policy::get_greedy_action_at_state( & policy_values, & state_episode );
policy::e_greedy_select( epsilon, possible_actions.as_slice(), & action_greedy )
},
PolicySelectMethod::Softmax => {
let mut vals = softmax::Distr(vec![]);
for (_k,i) in possible_actions.iter().enumerate() {
let val = match policy_values.get( &( state_init.clone(), i.clone() ) ) {
Some( x ) => *x,
None => 0.,
};
vals.0.push(val);
}
policy::softmax_select( possible_actions.as_slice(), & vals )
},
}
};
loop { if g.is_state_terminal( & state_episode ) {
break;
}
let ( reward, state_next ) = g.do_action( & state_episode, & action );
let action_next : Action = {
let possible_actions = g.gen_possible_actions( & state_next );
match criteria._policy_select_method {
PolicySelectMethod::EpsilonGreedy( epsilon ) => {
let action_greedy = policy::get_greedy_action_at_state( & policy_values, & state_next );
policy::e_greedy_select( epsilon, possible_actions.as_slice(), & action_greedy )
},
PolicySelectMethod::Softmax => {
let mut vals = softmax::Distr(vec![]);
for (_k,i) in possible_actions.iter().enumerate() {
let val = match policy_values.get( &( state_next.clone(), i.clone() ) ) {
Some( x ) => *x,
None => 0.,
};
vals.0.push(val);
}
policy::softmax_select( possible_actions.as_slice(), & vals )
},
}
};
let td_error = {
let q_next = policy_values.get( &( state_next.clone(), action_next.clone() ) ).unwrap_or(&0.);
let q = policy_values.get( &( state_episode.clone(), action.clone() ) ).unwrap_or(&0.);
reward.0 + criteria._gamma * q_next - q
};
{
let eligibility = eligibility_trace.entry( ( state_episode.clone(), action.clone() ) ).or_insert( 0. );
*eligibility = 1.;
}
let mut loop_detector = HashMap::new();
let mut items_in_path = HashSet::new();
let mut items_in_loops = HashSet::new();
let trace = g.get_state_history();
for i in 0..trace.len() {
let t = & trace[i];
let exists = match loop_detector.get( t ) {
None => { false },
Some(_) => { true },
};
if exists {
let index = *loop_detector.get( t ).unwrap();
for j in index..i {
items_in_path.remove( &j );
items_in_loops.insert( j );
}
loop_detector.insert( t.clone(), i );
items_in_path.insert( i );
items_in_loops.remove( &i );
} else {
loop_detector.insert( t.clone(), i );
items_in_path.insert( i );
items_in_loops.remove( &i );
}
}
for i in items_in_loops.iter() {
let t = & trace[ *i ];
let v = eligibility_trace.get_mut( t ).unwrap();
*v = criteria._gamma * criteria._lambda * *v;
}
let normalized_policies = normalized_policy_actions( & policy_values );
for i in items_in_path.iter() {
let t = & trace[ *i ];
let qq = *policy_values.get( t ).unwrap_or(& 0.);
let v = eligibility_trace.get_mut( t ).unwrap();
let alpha_adjust = match normalized_policies.get( t ) {
Some(x) => { ( 1. - if x.is_nan() { 0. } else { *x } ) * criteria._alpha },
None => { criteria._alpha },
};
let n = qq + alpha_adjust * td_error * *v;
policy_values.insert( ( t.0.clone(), t.1.clone() ), n );
*v = criteria._gamma * criteria._lambda * *v;
}
let mut sorted_index = items_in_path.iter().cloned().collect::<Vec<usize> >();
sorted_index.sort();
let history_filtered : Vec< (State, Action) > = sorted_index.iter().map( |x| trace[*x].clone() ).collect();
g.set_state_history( history_filtered.as_slice() );
state_episode = state_next;
action = action_next;
let t1 = Local::now();
match criteria._stop_limit {
StopCondition::TimeMicro(t) => {
let t_delta = t1.signed_duration_since(t0).num_microseconds().unwrap() as f64;
if t_delta >= t {
break 'outer_loop;
}
},
_ => {},
}
}
let t1 = Local::now();
match criteria._stop_limit {
StopCondition::TimeMicro(t) => {
let t_delta = t1.signed_duration_since(t0).num_microseconds().unwrap() as f64;
if t_delta >= t {
break;
}
},
StopCondition::EpisodeIter(n) => {
if iter >= n {
break;
}
},
}
iter += 1;
}
let policy_normalized = normalized_policy_actions_array( & policy_values );
let expect = get_expectation_policy( & policy_values );
Ok( ( policy_values, policy_normalized, expect ) )
}
/// Sums, for every state in `policy_map`, the values of all its `(state, action)` entries.
///
/// Returns a map from each state to that unweighted total.
// NOTE(review): despite the name this is a plain sum, not a policy-weighted
// expectation — confirm this is what consumers of `search`'s output expect.
fn get_expectation_policy< State, Action >( policy_map: & HashMap< (State, Action), f64 > )
    -> HashMap< State, f64 >
    where State: Clone + cmp::Eq + hash::Hash + fmt::Debug,
          Action: Clone + cmp::Eq + hash::Hash + fmt::Debug {
    let mut totals: HashMap< State, f64 > = HashMap::new();
    for ( &( ref state, _ ), &value ) in policy_map.iter() {
        *totals.entry( state.clone() ).or_insert( 0. ) += value;
    }
    totals
}
/// Normalises every `(state, action)` value against the other actions of the same state.
///
/// Each value becomes `(v − min) / Σ(v_i − min)`, where `min` and the sum range over
/// the actions recorded for that state. If every action of a state has the same value
/// (including a state with a single action), the divisor is 0 and the result is NaN.
/// `search` relies on that NaN, treating it as "no step-size scaling".
fn normalized_policy_actions< State, Action >( policy_map: & HashMap< (State, Action), f64 > ) -> HashMap< (State, Action), f64 >
    where State: Clone + cmp::Eq + hash::Hash + fmt::Debug,
          Action: Clone + cmp::Eq + hash::Hash + fmt::Debug {
    // Pass 1: smallest value recorded for each state.
    let mut lowest: HashMap< State, f64 > = HashMap::new();
    for ( key, &value ) in policy_map.iter() {
        let lo = lowest.entry( key.0.clone() ).or_insert( f64::MAX );
        if value <= *lo {
            *lo = value;
        }
    }

    // Pass 2: per-state sum of the min-shifted values.
    let mut totals: HashMap< State, f64 > = HashMap::new();
    for ( key, &value ) in policy_map.iter() {
        let lo = lowest[ &key.0 ];
        let total = totals.entry( key.0.clone() ).or_insert( 0. );
        *total = *total + value - lo;
    }

    // Pass 3: scale each shifted value by its state's total.
    policy_map.iter()
        .map( |( key, &value )| {
            let lo = lowest[ &key.0 ];
            ( key.clone(), ( value - lo ) / totals[ &key.0 ] )
        } )
        .collect()
}
/// Groups `policy_map` by state and turns each state's action values into weights summing to 1.
///
/// Each weight is `(v − min) / Σ(v_i − min)` over the actions recorded for that state.
/// That quotient is undefined when all of a state's actions share the same value
/// (including single-action states): it would be 0/0 = NaN. In that case every action
/// gets the uniform weight `1 / k` instead, so callers never receive NaN weights for
/// finite inputs.
fn normalized_policy_actions_array< State, Action >( policy_map: & HashMap< (State, Action), f64 > ) -> HashMap< State, Vec< ( Action, f64 ) > >
    where State: Clone + cmp::Eq + hash::Hash + fmt::Debug,
          Action: Clone + cmp::Eq + hash::Hash + fmt::Debug {
    let mut grouped: HashMap< State, Vec< ( Action, f64 ) > > = HashMap::new();
    for ( key, &value ) in policy_map.iter() {
        grouped.entry( key.0.clone() )
               .or_insert_with( Vec::new )
               .push( ( key.1.clone(), value ) );
    }

    for actions in grouped.values_mut() {
        let lowest = actions.iter()
            .fold( f64::MAX, |lo, &( _, v )| if v <= lo { v } else { lo } );
        let total = actions.iter()
            .fold( 0., |acc, &( _, v )| acc + v - lowest );
        if total > 0. {
            for entry in actions.iter_mut() {
                entry.1 = ( entry.1 - lowest ) / total;
            }
        } else {
            // All values equal (or non-finite total): no preference, so weight uniformly.
            let uniform = 1. / actions.len() as f64;
            for entry in actions.iter_mut() {
                entry.1 = uniform;
            }
        }
    }
    grouped
}