Skip to main content

rsrl/prediction/
mc.rs

1use crate::{
2    domains::Trajectory,
3    fa::StateUpdate,
4    Function,
5    Handler,
6    Parameterised,
7};
8
/// Error raised while processing a trajectory, tagged with the position of
/// the transition whose update failed.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize), serde(crate = "serde_crate"))]
pub struct Error<E> {
    /// Index of the transition (counted from the start of the trajectory)
    /// at which the failure occurred.
    pub timestep: usize,

    /// The underlying error reported for that transition.
    pub error: E,
}
19
20#[derive(Clone, Debug, Parameterised)]
21#[cfg_attr(
22    feature = "serde",
23    derive(Serialize, Deserialize),
24    serde(crate = "serde_crate")
25)]
26pub struct GradientMC<V> {
27    #[weights]
28    pub v_func: V,
29
30    pub gamma: f64,
31}
32
33impl<'m, S, A, V> Handler<&'m Trajectory<S, A>> for GradientMC<V>
34where V: Function<(&'m S,), Output = f64> + Handler<StateUpdate<&'m S, f64>>
35{
36    type Response = Vec<V::Response>;
37    type Error = Error<V::Error>;
38
39    fn handle(&mut self, traj: &'m Trajectory<S, A>) -> Result<Self::Response, Self::Error> {
40        let n = traj.n_transitions();
41        let mut sum = 0.0;
42
43        traj.iter().rev().enumerate().map(|(t, transition)| {
44            sum = transition.reward + self.gamma * sum;
45
46            let from = transition.from.state();
47            let pred = self.v_func.evaluate((from,));
48
49            self.v_func.handle(StateUpdate {
50                state: from,
51                error: sum - pred,
52            }).map_err(|e| Error {
53                timestep: n - t - 1,
54                error: e,
55            })
56        }).collect()
57    }
58}