// rsrl/control/mc/reinforce.rs

1use crate::{domains::Batch, fa::StateActionUpdate, policies::Policy, Handler};
2
3#[derive(Clone, Debug, Parameterised)]
4#[cfg_attr(
5    feature = "serde",
6    derive(Serialize, Deserialize),
7    serde(crate = "serde_crate")
8)]
9pub struct REINFORCE<P> {
10    #[weights]
11    pub policy: P,
12
13    pub alpha: f64,
14    pub gamma: f64,
15}
16
17impl<P> REINFORCE<P> {
18    pub fn new(policy: P, alpha: f64, gamma: f64) -> Self {
19        REINFORCE {
20            policy,
21
22            alpha,
23            gamma,
24        }
25    }
26}
27
28impl<'m, S, P> Handler<&'m Batch<S, P::Action>> for REINFORCE<P>
29where P: Policy<S> + Handler<StateActionUpdate<&'m S, &'m <P as Policy<S>>::Action>>
30{
31    type Response = Vec<P::Response>;
32    type Error = P::Error;
33
34    fn handle(&mut self, batch: &'m Batch<S, P::Action>) -> Result<Self::Response, Self::Error> {
35        let mut ret = 0.0;
36
37        batch.iter().map(|t| {
38            ret = t.reward + self.gamma * ret;
39
40            self.policy.handle(StateActionUpdate {
41                state: t.from.state(),
42                action: &t.action,
43                error: self.alpha * ret,
44            })
45        }).collect()
46    }
47}