rsrl 0.8.1

A fast, extensible reinforcement learning framework in Rust
Documentation
use crate::{
    domains::Transition,
    fa::ScaledGradientUpdate,
    params::BufferMut,
    Differentiable,
    Handler,
};

/// Outcome of a single update step, bundling the responses returned by the
/// two underlying function-approximator updates with the TD error that drove
/// them.
///
/// `T` is the response type of the `theta` update and `W` the response type
/// of the `w` update.
#[derive(Clone, Debug)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct Response<T, W> {
    /// Response produced by the update applied to the `theta` approximator.
    pub res_theta: T,
    /// Response produced by the update applied to the `w` approximator.
    pub res_w: W,
    /// Temporal-difference error computed for the handled transition.
    pub td_error: f64,
}

/// GTD2 learner built from two function approximators of the same type `F`
/// and a discount factor.
///
/// Both approximators are evaluated and updated when a transition is handled;
/// see the `Handler` implementation for the exact update rule.
#[derive(Clone, Debug, Parameterised)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct GTD2<F> {
    /// Approximator whose output is used as the state-value estimate when
    /// forming the TD error. Tagged `#[weights]` for the `Parameterised`
    /// derive (presumably making it the source of the learner's exposed
    /// weights; confirm against the derive macro).
    #[weights]
    pub fa_theta: F,
    /// Secondary approximator; its output at the current state scales the
    /// `fa_theta` update, and it is itself updated with
    /// `td_error - fa_w(s)` as step scale.
    pub fa_w: F,

    /// Discount factor applied to the next-state value estimate.
    pub gamma: f64,
}

/// Scaled gradient update carrying an *owned* Jacobian of `F` taken with
/// respect to a borrowed state input `(&'m S,)`.
type SGU<'m, S, F> = ScaledGradientUpdate<<F as Differentiable<(&'m S,)>>::Jacobian>;
/// Same as `SGU`, but carrying a Jacobian *borrowed* for lifetime `'j`, so the
/// gradient buffer can be reused after the update has been applied.
type SGURef<'m, 'j, S, F> = ScaledGradientUpdate<&'j <F as Differentiable<(&'m S,)>>::Jacobian>;

/// Online GTD2 update driven by a single borrowed `Transition`.
///
/// `F` must be differentiable at a borrowed state with an `f64` output, and
/// must accept scaled gradient updates both by value (`SGU`) and by reference
/// (`SGURef`), with identical response and error types for the two forms.
impl<'m, S, A, F> Handler<&'m Transition<S, A>> for GTD2<F>
where
    F: Differentiable<(&'m S,), Output = f64> + Handler<SGU<'m, S, F>>,
    F: for<'j> Handler<
        SGURef<'m, 'j, S, F>,
        Response = <F as Handler<SGU<'m, S, F>>>::Response,
        Error = <F as Handler<SGU<'m, S, F>>>::Error,
    >,
{
    /// Responses of the `fa_theta` and `fa_w` updates plus the TD error.
    type Response =
        Response<<F as Handler<SGU<'m, S, F>>>::Response, <F as Handler<SGU<'m, S, F>>>::Response>;
    /// Error type shared by both approximator updates.
    type Error = <F as Handler<SGU<'m, S, F>>>::Error;

    /// Applies one GTD2 step for transition `t`.
    ///
    /// With `phi = grad fa_theta(s)` and `phi' = grad fa_theta(ns)`:
    ///
    /// * `td_error = r + g * fa_theta(ns) - fa_theta(s)`
    /// * `fa_w`     is updated along `phi` scaled by `td_error - fa_w(s)`
    /// * `fa_theta` is updated along `phi - g * phi'` scaled by `fa_w(s)`
    ///
    /// where `g` is `self.gamma` for ordinary transitions and `0` when the
    /// transition is terminal. Using the same effective discount in both the
    /// TD error and the gradient correction keeps the terminal state's
    /// features out of the update.
    ///
    /// All estimates (`fa_w(s)`, `fa_theta(s)`, `fa_theta(ns)`) and `phi` are
    /// taken before either approximator is modified.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by an approximator update. The
    /// `fa_w` update is attempted first; if it fails, the `fa_theta` update
    /// is not attempted.
    fn handle(&mut self, t: &'m Transition<S, A>) -> Result<Self::Response, Self::Error> {
        let (s, ns) = t.states();
        let terminated = t.terminated();
        let gamma = self.gamma;

        let w_s = self.fa_w.evaluate((s,));
        let theta_s = self.fa_theta.evaluate((s,));

        // No bootstrapping from the successor once the episode has ended.
        let td_error = if terminated {
            t.reward - theta_s
        } else {
            t.reward + gamma * self.fa_theta.evaluate((ns,)) - theta_s
        };

        let mut grad_s = self.fa_theta.grad((s,));

        // Borrow the gradient here so the same buffer can be reused for the
        // `fa_theta` update below.
        let res_w = self.fa_w.handle(ScaledGradientUpdate {
            alpha: td_error - w_s,
            jacobian: &grad_s,
        })?;

        // Turn `phi` into `phi - gamma * phi'`. On terminal transitions the
        // effective discount is zero, so the direction is just `phi` and the
        // successor gradient is never computed.
        if !terminated {
            let grad_ns = self.fa_theta.grad((ns,));

            grad_s.merge_inplace(&grad_ns, |x, y| x - gamma * y);
        }

        let res_theta = self.fa_theta.handle(ScaledGradientUpdate {
            alpha: w_s,
            jacobian: grad_s,
        })?;

        Ok(Response {
            res_theta,
            res_w,
            td_error,
        })
    }
}