use crate::{
Handler, OutputOf,
domains::Transition,
fa::linear::basis::Basis,
params::*,
policies::{Policy, EnumerablePolicy},
prediction::{ValuePredictor, ActionValuePredictor},
traces::Trace,
utils::argmax_first,
};
use ndarray::{Array1, Ix1, linalg::Dot};
use std::f64;
/// Q-learning style agent with a linear value function and a trace buffer
/// (presumably "true online Q(lambda)", judging by the type name).
///
/// Action values are computed as the dot product between the features that
/// `basis` produces for a `(state, action)` pair and the weight vector `theta`.
// NOTE(review): the type parameter `T` is not used by any field declared here;
// confirm that the derive (or a field outside this view) makes use of it.
#[derive(Parameterised)]
pub struct TOQLambda<B: Space, P, T> {
/// Feature projector; `project` is called on `(state, action)` pairs.
pub basis: B,
/// Weight vector that projected features are dotted with to get action values.
#[weights] pub theta: Array1<f64>,
/// Trace buffer; created zeroed with `theta.len()` entries and blended with
/// the taken action's features on every handled transition.
pub trace: B::Value,
/// Policy; in the code shown it is only queried for the number of actions.
pub policy: P,
/// Step size: scales every additive update applied to `theta`.
pub alpha: f64,
/// Discount: multiplies the next-state value in the update target and,
/// together with `lambda`, the decay applied to the trace.
pub gamma: f64,
/// Trace decay parameter: `lambda * gamma` scales the existing trace.
pub lambda: f64,
/// Next-state value saved by the previous update; starts at `0.0` and is
/// reset to `0.0` after a terminal transition.
q_old: f64,
}
impl<B: Space, P, T> TOQLambda<B, P, T>
where
    B::Value: Buffer<Dim = Ix1>,
{
    /// Builds an agent from an explicit initial weight vector.
    ///
    /// The trace buffer is created zeroed with one entry per weight in
    /// `theta`, and the remembered previous value (`q_old`) starts at `0.0`.
    pub fn new(
        basis: B,
        theta: Array1<f64>,
        policy: P,
        alpha: f64,
        gamma: f64,
        lambda: f64,
    ) -> Self {
        // Read the length before `theta` is moved into the struct.
        let n_weights = theta.len();

        Self {
            trace: B::Value::zeros(n_weights),
            q_old: 0.0,
            basis,
            theta,
            policy,
            alpha,
            gamma,
            lambda,
        }
    }

    /// Builds an agent whose weights are all zero.
    ///
    /// The number of weights is taken from `basis.dim()`.
    pub fn zeros(
        basis: B,
        policy: P,
        alpha: f64,
        gamma: f64,
        lambda: f64,
    ) -> Self {
        let n_weights: usize = basis.dim().into();
        let theta = Array1::zeros(n_weights);

        Self::new(basis, theta, policy, alpha, gamma, lambda)
    }
}
/// Learning update driven by a single observed transition.
impl<'m, S, B, P, T> Handler<&'m Transition<S, P::Action>> for TOQLambda<B, P, T>
where
    B: Basis<(&'m S, &'m P::Action)> + Basis<(&'m S, P::Action)>,
    P: EnumerablePolicy<&'m S>,
    T: Trace<Buffer = B::Value>,
    B::Value: BufferMut<Dim = Ix1> + Dot<Array1<f64>, Output = f64> + Dot<B::Value, Output = f64>,
    OutputOf<P, (&'m S,)>: std::ops::Index<usize, Output = f64> + IntoIterator<Item = f64>,
    <OutputOf<P, (&'m S,)> as IntoIterator>::IntoIter: ExactSizeIterator,
{
    type Response = ();
    type Error = ();

    /// Updates the trace and `theta` from transition `t`.
    ///
    /// Steps, in order:
    /// 1. Project the features of every action in the source state and
    ///    evaluate them against `theta`.
    /// 2. Update the trace: if the taken action is the (first) greedy one the
    ///    old trace is decayed by `lambda * gamma` and the new features are
    ///    blended in; otherwise the trace is replaced by the new features.
    /// 3. Apply the weight update. For a terminal transition the target is the
    ///    reward alone; otherwise it is `reward + gamma * max_a Q(s', a)`.
    ///    In both cases a correction `alpha * (q_old - Q(s, a))` is applied
    ///    along the features of the taken `(s, a)` pair.
    /// 4. Reset the trace after a terminal transition or a non-greedy action.
    ///
    /// Always returns `Ok(())`.
    ///
    /// # Panics
    ///
    /// Panics if `basis.project` returns an error, or if `t.action` is not a
    /// valid index into the per-action feature list of the source state.
    fn handle(&mut self, t: &'m Transition<S, P::Action>) -> Result<(), ()> {
        let s = t.from.state();

        // Features and action values for every action available in `s`.
        let phi_s: Vec<_> = (0..self.policy.len((s,)))
            .map(|a| self.basis.project((s, a)).unwrap())
            .collect();
        let phi_s_a = &phi_s[t.action];

        let qs: Vec<_> = phi_s.iter().map(|f| f.dot(&self.theta)).collect();
        let qsa = qs[t.action];
        let (amax, _) = argmax_first(qs);
        let greedy = t.action == amax;

        // Trace update.
        if greedy {
            let a = self.alpha;
            let c = self.lambda * self.gamma;
            // Must be computed before the trace is modified below.
            let dotted = self.trace.dot(phi_s_a);

            self.trace.merge_inplace(&phi_s_a, move |x, y| {
                c * x + (1.0 - a * c * dotted) * y
            });
        } else {
            // Non-greedy action: discard the history and keep only the new features.
            self.trace.merge_inplace(&phi_s_a, |_, y| y);
        }

        // Weight update.
        if t.terminated() {
            self.trace.scaled_addto(self.alpha * (t.reward - self.q_old), &mut self.theta);
            phi_s_a.scaled_addto(self.alpha * (self.q_old - qsa), &mut self.theta);

            self.q_old = 0.0;
            self.trace.reset();
        } else {
            let ns = t.to.state();

            // Greedy value of the *next* state. Every candidate must be
            // projected from `ns`; a later action only replaces the running
            // best when it is larger by more than 1e-7, so ties keep the
            // earliest action.
            let q_first = self.basis.project((ns, 0)).unwrap().dot(&self.theta);
            let qnsna = (1..self.policy.len((ns,))).fold(q_first, |best, a| {
                let q = self.basis.project((ns, a)).unwrap().dot(&self.theta);

                if q - best > 1e-7 { q } else { best }
            });

            let residual = t.reward + self.gamma * qnsna - self.q_old;

            self.trace.scaled_addto(self.alpha * residual, &mut self.theta);
            // The correction term acts on the features of the pair that was
            // actually taken, exactly as in the terminal branch.
            phi_s_a.scaled_addto(self.alpha * (self.q_old - qsa), &mut self.theta);

            self.q_old = qnsna;

            if !greedy {
                self.trace.reset();
            }
        }

        Ok(())
    }
}
/// State-value prediction as the largest action value under `theta`.
impl<S, B, P, T> ValuePredictor<S> for TOQLambda<B, P, T>
where
    B: for<'s> Basis<(&'s S, usize)>,
    P: for<'s> EnumerablePolicy<&'s S>,
    B::Value: Dot<Array1<f64>, Output = f64>,
    for<'s> OutputOf<P, (&'s S,)>: std::ops::Index<usize, Output = f64> + IntoIterator<Item = f64>,
    for<'s> <OutputOf<P, (&'s S,)> as IntoIterator>::IntoIter: ExactSizeIterator,
{
    /// Returns the greatest action value found for state `s`.
    ///
    /// Actions are scanned in index order, and a candidate only replaces the
    /// running best when it exceeds it by more than `1e-7`. If the policy
    /// reports no actions the result is `f64::MIN`.
    ///
    /// # Panics
    ///
    /// Panics if `basis.project` returns an error.
    fn predict_v(&self, s: S) -> f64 {
        let n_actions = self.policy.len((&s,));
        let mut best = f64::MIN;

        for a in 0..n_actions {
            let q = self.basis.project((&s, a)).unwrap().dot(&self.theta);

            if q - best > 1e-7 {
                best = q;
            }
        }

        best
    }
}
/// Action-value prediction as a dot product of features and weights.
impl<S, B, P, T> ActionValuePredictor<S, P::Action> for TOQLambda<B, P, T>
where
    B: Basis<(S, P::Action)>,
    P: Policy<S>,
    B::Value: Dot<Array1<f64>, Output = f64>,
{
    /// Returns the value of taking action `a` in state `s`: the features
    /// projected for `(s, a)` dotted with `theta`.
    ///
    /// # Panics
    ///
    /// Panics if `basis.project` returns an error.
    fn predict_q(&self, s: S, a: P::Action) -> f64 {
        let features = self.basis.project((s, a)).unwrap();

        features.dot(&self.theta)
    }
}