1use crate::{
2 domains::Trajectory,
3 fa::StateUpdate,
4 Function,
5 Handler,
6 Parameterised,
7};
8
/// Failure raised while processing a trajectory.
///
/// Pairs the underlying error produced by the wrapped handler with the
/// (zero-based, chronological) index of the transition being processed when
/// it occurred.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize), serde(crate = "serde_crate"))]
pub struct Error<E> {
    /// Index of the transition, counted from the start of the trajectory.
    pub timestep: usize,

    /// The error reported by the inner handler.
    pub error: E,
}
19
20#[derive(Clone, Debug, Parameterised)]
21#[cfg_attr(
22 feature = "serde",
23 derive(Serialize, Deserialize),
24 serde(crate = "serde_crate")
25)]
26pub struct GradientMC<V> {
27 #[weights]
28 pub v_func: V,
29
30 pub gamma: f64,
31}
32
33impl<'m, S, A, V> Handler<&'m Trajectory<S, A>> for GradientMC<V>
34where V: Function<(&'m S,), Output = f64> + Handler<StateUpdate<&'m S, f64>>
35{
36 type Response = Vec<V::Response>;
37 type Error = Error<V::Error>;
38
39 fn handle(&mut self, traj: &'m Trajectory<S, A>) -> Result<Self::Response, Self::Error> {
40 let n = traj.n_transitions();
41 let mut sum = 0.0;
42
43 traj.iter().rev().enumerate().map(|(t, transition)| {
44 sum = transition.reward + self.gamma * sum;
45
46 let from = transition.from.state();
47 let pred = self.v_func.evaluate((from,));
48
49 self.v_func.handle(StateUpdate {
50 state: from,
51 error: sum - pred,
52 }).map_err(|e| Error {
53 timestep: n - t - 1,
54 error: e,
55 })
56 }).collect()
57 }
58}