rsrl 0.8.1

A fast, extensible reinforcement learning framework in Rust
Documentation
use crate::{
    domains::Trajectory,
    fa::StateUpdate,
    Function,
    Handler,
    Parameterised,
};

#[derive(Clone, Copy, Debug)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct Error<E> {
    pub timestep: usize,
    pub error: E,
}

#[derive(Clone, Debug, Parameterised)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct GradientMC<V> {
    #[weights]
    pub v_func: V,

    pub gamma: f64,
}

impl<'m, S, A, V> Handler<&'m Trajectory<S, A>> for GradientMC<V>
where V: Function<(&'m S,), Output = f64> + Handler<StateUpdate<&'m S, f64>>
{
    type Response = Vec<V::Response>;
    type Error = Error<V::Error>;

    fn handle(&mut self, traj: &'m Trajectory<S, A>) -> Result<Self::Response, Self::Error> {
        let n = traj.n_transitions();
        let mut sum = 0.0;

        traj.iter().rev().enumerate().map(|(t, transition)| {
            sum = transition.reward + self.gamma * sum;

            let from = transition.from.state();
            let pred = self.v_func.evaluate((from,));

            self.v_func.handle(StateUpdate {
                state: from,
                error: sum - pred,
            }).map_err(|e| Error {
                timestep: n - t - 1,
                error: e,
            })
        }).collect()
    }
}