rsrl/control/mc/
reinforce.rs1use crate::{domains::Batch, fa::StateActionUpdate, policies::Policy, Handler};
2
3#[derive(Clone, Debug, Parameterised)]
4#[cfg_attr(
5 feature = "serde",
6 derive(Serialize, Deserialize),
7 serde(crate = "serde_crate")
8)]
9pub struct REINFORCE<P> {
10 #[weights]
11 pub policy: P,
12
13 pub alpha: f64,
14 pub gamma: f64,
15}
16
17impl<P> REINFORCE<P> {
18 pub fn new(policy: P, alpha: f64, gamma: f64) -> Self {
19 REINFORCE {
20 policy,
21
22 alpha,
23 gamma,
24 }
25 }
26}
27
28impl<'m, S, P> Handler<&'m Batch<S, P::Action>> for REINFORCE<P>
29where P: Policy<S> + Handler<StateActionUpdate<&'m S, &'m <P as Policy<S>>::Action>>
30{
31 type Response = Vec<P::Response>;
32 type Error = P::Error;
33
34 fn handle(&mut self, batch: &'m Batch<S, P::Action>) -> Result<Self::Response, Self::Error> {
35 let mut ret = 0.0;
36
37 batch.iter().map(|t| {
38 ret = t.reward + self.gamma * ret;
39
40 self.policy.handle(StateActionUpdate {
41 state: t.from.state(),
42 action: &t.action,
43 error: self.alpha * ret,
44 })
45 }).collect()
46 }
47}