use crate::{
domains::Transition,
fa::StateActionUpdate,
policies::Policy,
Function,
Handler,
Parameterised,
};
use rand::thread_rng;
#[derive(Clone, Debug)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct Response<R> {
td_error: f64,
qfunc_response: R,
}
#[derive(Clone, Debug, Parameterised)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct SARSA<Q, P> {
#[weights]
pub q_func: Q,
pub policy: P,
pub gamma: f64,
}
impl<'m, S, Q, P> Handler<&'m Transition<S, P::Action>> for SARSA<Q, P>
where
Q: Function<(&'m S, P::Action), Output = f64>
+ for<'a> Function<(&'m S, &'a P::Action), Output = f64>
+ Handler<StateActionUpdate<&'m S, &'m P::Action>>,
P: Policy<&'m S>,
{
type Response = Response<Q::Response>;
type Error = Q::Error;
fn handle(&mut self, t: &'m Transition<S, P::Action>) -> Result<Self::Response, Self::Error> {
let s = t.from.state();
let qsa = self.q_func.evaluate((s, &t.action));
let residual = if t.terminated() {
t.reward - qsa
} else {
let ns = t.to.state();
let na = self.policy.sample(&mut thread_rng(), ns);
let nqsna = self.q_func.evaluate((ns, na));
t.reward + self.gamma * nqsna - qsa
};
self.q_func.handle(StateActionUpdate {
state: s,
action: &t.action,
error: residual,
}).map(|r| Response {
td_error: residual,
qfunc_response: r,
})
}
}