rsrl 0.8.1

A fast, extensible reinforcement learning framework in Rust
Documentation
use crate::{
    domains::Batch,
    fa::StateActionUpdate,
    policies::Policy,
    Function,
    Handler,
};

#[derive(Clone, Debug, Parameterised)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct BaselineREINFORCE<B, P> {
    #[weights]
    pub policy: P,
    pub baseline: B,

    pub alpha: f64,
    pub gamma: f64,
}

impl<B, P> BaselineREINFORCE<B, P> {
    pub fn new(policy: P, baseline: B, alpha: f64, gamma: f64) -> Self {
        BaselineREINFORCE {
            policy,
            baseline,

            alpha,
            gamma,
        }
    }
}

impl<'m, S, B, P> Handler<&'m Batch<S, P::Action>> for BaselineREINFORCE<B, P>
where
    P: Policy<S> + Handler<StateActionUpdate<&'m S, &'m <P as Policy<S>>::Action>>,
    B: Function<(&'m S, &'m P::Action), Output = f64>,
{
    type Response = Vec<P::Response>;
    type Error = P::Error;

    fn handle(&mut self, batch: &'m Batch<S, P::Action>) -> Result<Self::Response, Self::Error> {
        let mut ret = 0.0;

        batch.iter().map(|t| {
            let s = t.from.state();
            let baseline = self.baseline.evaluate((s, &t.action));

            ret = t.reward + self.gamma * ret;

            self.policy.handle(StateActionUpdate {
                state: t.from.state(),
                action: &t.action,
                error: self.alpha * (ret - baseline),
            })
        }).collect()
    }
}