1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
use super::{
reward_to_go, BuildCritic, Critic, Device, HistoryFeatures, PackedTensor, StatsLogger,
};
use serde::{Deserialize, Serialize};
/// Configuration for the [`RewardToGo`] critic.
///
/// This is a unit struct with no settings of its own: the only value the
/// critic keeps (the discount factor) is passed to
/// [`BuildCritic::build_critic`] at construction time.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub struct RewardToGoConfig;
impl BuildCritic for RewardToGoConfig {
    type Critic = RewardToGo;

    /// Build a [`RewardToGo`] critic.
    ///
    /// The input dimension and device are ignored; only `discount_factor` is
    /// retained, stored in single precision.
    // The critic stores its discount factor as `f32`, so narrowing the `f64`
    // argument is intentional and the resulting precision loss is accepted.
    #[allow(clippy::cast_possible_truncation)]
    fn build_critic(&self, _in_dim: usize, discount_factor: f64, _device: Device) -> Self::Critic {
        let discount_factor = discount_factor as f32;
        RewardToGo { discount_factor }
    }
}
/// Critic whose advantage estimates come straight from [`reward_to_go`].
///
/// It holds no trainable state; its only field is the discount factor.
#[derive(Debug, Copy, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct RewardToGo {
    /// Discount factor handed to [`reward_to_go`]; narrowed from the `f64`
    /// value supplied when the critic is built.
    discount_factor: f32,
}
impl Critic for RewardToGo {
    /// Compute advantages as the output of [`reward_to_go`] for `features`,
    /// using this critic's stored discount factor.
    fn advantages(&self, features: &dyn HistoryFeatures) -> PackedTensor {
        let gamma = self.discount_factor;
        reward_to_go(gamma, features)
    }

    /// Does nothing: this critic has no parameters to fit, so the features
    /// and logger are deliberately unused.
    fn update(&mut self, _features: &dyn HistoryFeatures, _logger: &mut dyn StatsLogger) {
        // Intentionally empty.
    }
}