//! Critics for an actor-critic agent.
#![allow(clippy::use_self)] // false positive with serde derives
mod opt;
mod rtg;

pub use opt::{ValuesOpt, ValuesOptConfig};
pub use rtg::{RewardToGo, RewardToGoConfig};

use super::features::HistoryFeatures;
use crate::logging::StatsLogger;
use crate::torch::modules::SeqPacked;
use crate::torch::packed::PackedTensor;
use serde::{Deserialize, Serialize};
use tch::Device;

/// A critic for an [actor-critic agent][super::ActorCriticAgent].
///
/// Estimates the value (minus any per-state baseline) of each selected action in a collection
/// of experience, and learns from that experience.
pub trait Critic {
    /// Value estimates of the selected actions offset by a baseline function of state.
    ///
    /// Formally, an estimate of `Q(a_t; o_0, ..., o_t) - b(o_0, ..., o_t)` for each time step `t`
    /// where `Q` is the discounted state-action value function and `b` is any baseline function
    /// that does not depend on the value of `a_t`. The environment is assumed to be partially
    /// observable so state is represented by the observation history.
    ///
    /// The returned values are suitable for use in REINFORCE policy gradient updates.
    ///
    /// # Design Note
    /// These "advantages" are more general than standard advantages, which require `b` to be the
    /// state value function. As far as I am aware, there is no name for "state-action value
    /// function with a state baseline" and the name "generalized advantages" is
    /// [taken][AdvantageFn::Gae]. "Advantages" was chosen as a reasonably evocative short name
    /// despite being technically incorrect.
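    ///
    /// # Example (sketch)
    /// A REINFORCE-style loss built from these advantages, assuming a hypothetical
    /// `log_probs` [`PackedTensor`] of selected-action log probabilities:
    /// ```ignore
    /// let advantages = critic.advantages(&features);
    /// // Advantages act as fixed weights; gradients flow through `log_probs` only.
    /// let loss = -(advantages.tensor().detach() * log_probs.tensor()).mean(Kind::Float);
    /// ```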
    fn advantages(&self, features: &dyn HistoryFeatures) -> PackedTensor;

    /// Update the critic given a collection of experience features.
    fn update(&mut self, features: &dyn HistoryFeatures, logger: &mut dyn StatsLogger);
}

/// Build a [`Critic`].
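///
/// # Example (sketch)
/// Assuming a config type such as [`RewardToGoConfig`] implements this trait
/// (illustrative; see the `rtg` module for the actual impl):
/// ```ignore
/// let critic = RewardToGoConfig::default().build_critic(obs_dim, 0.99, Device::Cpu);
/// ```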
pub trait BuildCritic {
    /// Type of critic that this builds.
    type Critic: Critic;

    /// Build a critic for input feature vectors of size `in_dim`, discounting future
    /// rewards by `discount_factor` and placing tensors on `device`.
    fn build_critic(&self, in_dim: usize, discount_factor: f64, device: Device) -> Self::Critic;
}

/// Estimate baselined advantages from state value estimates and history features.
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub enum AdvantageFn {
    /// Generalized Advantage Estimation
    ///
    /// # Reference
    /// "[High-Dimensional Continuous Control Using Generalized Advantage Estimation][gae]"
    /// by Schulman et al.
    ///
    /// The `gamma` parameter is implemented as [`ValuesOptConfig::max_discount_factor`] so that
    /// the learned value module uses the same discount factor.
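    ///
    /// Computed as the discounted sum of one-step temporal-difference residuals,
    /// as in [`gae`]:
    ///
    /// ```text
    /// δ_t = r_t + γ * V(s_{t+1}) - V(s_t)
    /// A_t = Σ_{l ≥ 0} (γλ)^l * δ_{t+l}
    /// ```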
    ///
    /// [gae]: https://arxiv.org/abs/1506.02438
    Gae {
        /// Shaping parameter in `[0, 1]` prioritizing sampled reward-to-go over the value module.
        ///
        /// Selects the degree to which advantage estimates rely on the value module estimates
        /// (low values) versus the empirical reward-to-go (high values).
        /// * `lambda = 0` is the 1-step temporal difference: `r_t + γ * V(s_{t+1}) - V(s_t)`
        /// * `lambda = 1` is the reward-to-go with baseline: `sum_l(γ^l r_{t+l}) - V(s_t)`
        ///
        /// Lower values reduce variance but increase bias when the value function module is
        /// incorrect.
        lambda: f32,
    },
}

impl Default for AdvantageFn {
    fn default() -> Self {
        Self::Gae { lambda: 0.95 }
    }
}

impl AdvantageFn {
    /// Estimate baselined advantages of selected actions given a state value function module.
    pub fn advantages<M: SeqPacked + ?Sized>(
        &self,
        state_value_fn: &M,
        discount_factor: f32,
        features: &dyn HistoryFeatures,
    ) -> PackedTensor {
        match self {
            Self::Gae { lambda } => gae(state_value_fn, discount_factor, *lambda, features),
        }
    }
}

/// Discounted reward-to-go
///
/// # Args:
/// * `discount_factor` - Discount factor on future rewards. In `[0, 1]`.
/// * `features` - Experience features.
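///
/// # Example
/// Within one episode, rewards `[1.0, 1.0, 1.0]` with `discount_factor = 0.9` give
/// reward-to-go `[1 + 0.9 * 1.9, 1 + 0.9 * 1.0, 1.0] = [2.71, 1.9, 1.0]`.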
pub fn reward_to_go(discount_factor: f32, features: &dyn HistoryFeatures) -> PackedTensor {
    features
        .rewards()
        .discounted_cumsum_from_end(discount_factor)
}

/// Apply a state value function to `HistoryFeatures::extended_observation_features`.
///
/// # Args:
/// * `state_value_fn` - State value function estimator using past & present episode observations.
/// * `features` - Experience features.
///
/// Returns a [`PackedTensor`] of sequences that are one longer than the episode lengths.
/// The final value is `0` if the episode ended and the next state value if interrupted.
pub fn eval_extended_state_values<M: SeqPacked + ?Sized>(
    state_value_fn: &M,
    features: &dyn HistoryFeatures,
) -> PackedTensor {
    let (extended_observation_features, is_invalid) = features.extended_observation_features();

    // Packed estimated values of the observed states
    let mut extended_estimated_values = state_value_fn
        .seq_packed(extended_observation_features)
        .batch_map(|t| t.squeeze_dim(-1));
    let _ = extended_estimated_values
        .tensor_mut()
        .masked_fill_(is_invalid.tensor(), 0.0);

    extended_estimated_values
}

/// One-step targets of a state value function.
///
/// # Args:
/// * `state_value_fn` - State value function estimator using past & present episode observations.
/// * `discount_factor` - Discount factor on future rewards. In `[0, 1]`.
/// * `features` - Experience features.
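///
/// Returns `r_t + γ * V(s_{t+1})` for each step `t`, where the next-state value is `0`
/// if the episode ended at that step.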
pub fn one_step_values<M: SeqPacked + ?Sized>(
    state_value_fn: &M,
    discount_factor: f32,
    features: &dyn HistoryFeatures,
) -> PackedTensor {
    // Estimated value for each of `step.next.into_inner().observation`
    let estimated_next_values =
        eval_extended_state_values(state_value_fn, features).view_trim_start(1);
    features
        .rewards()
        .batch_map_ref(|rewards| rewards + discount_factor * estimated_next_values.tensor())
}

/// One-step temporal difference residuals of a state value function
///
/// # Args:
/// * `state_value_fn` - State value function estimator using past & present episode observations.
/// * `discount_factor` - Discount factor on future rewards. In `[0, 1]`.
/// * `features` - Experience features.
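///
/// Returns `δ_t = r_t + γ * V(s_{t+1}) - V(s_t)` for each step `t`.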
pub fn temporal_differences<M: SeqPacked + ?Sized>(
    state_value_fn: &M,
    discount_factor: f32,
    features: &dyn HistoryFeatures,
) -> PackedTensor {
    let extended_state_values = eval_extended_state_values(state_value_fn, features);

    // Estimated values for each of `step.observation`
    let estimated_values = extended_state_values.trim_end(1);

    // Estimated value for each of `step.next.into_inner().observation`
    let estimated_next_values = extended_state_values.view_trim_start(1);

    features.rewards().batch_map_ref(|rewards| {
        rewards + discount_factor * estimated_next_values.tensor() - estimated_values.tensor()
    })
}

/// Generalized advantage estimation
///
/// # Reference
/// "[High-Dimensional Continuous Control Using Generalized Advantage Estimation][gae]"
/// by Schulman et al.
///
/// [gae]: https://arxiv.org/abs/1506.02438
///
/// # Args:
/// * `state_value_fn` - State value function estimator using past & present episode observations.
/// * `discount_factor` - Discount factor on future rewards. In `[0, 1]`.
/// * `lambda` - Parameter prioritizing sampled reward-to-go over the value module. In `[0, 1]`.
///     See [`AdvantageFn::Gae::lambda`].
/// * `features` - Experience features.
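///
/// # Example (sketch)
/// An illustrative call with `γ = 0.99` and `λ = 0.95`; `value_module` (any
/// [`SeqPacked`] module) and `features` are assumed to exist:
/// ```ignore
/// let advantages = gae(&value_module, 0.99, 0.95, &features);
/// ```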
pub fn gae<M: SeqPacked + ?Sized>(
    state_value_fn: &M,
    discount_factor: f32,
    lambda: f32,
    features: &dyn HistoryFeatures,
) -> PackedTensor {
    let residuals = temporal_differences(state_value_fn, discount_factor, features);

    residuals.discounted_cumsum_from_end(lambda * discount_factor)
}

/// Target function for per-step state value estimates.
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub enum StepValueTarget {
    /// The empirical reward-to-go: discounted sum of future rewards.
    RewardToGo,
    /// One-step temporal-difference targets: `r_t + γ * V(s_{t+1})`
    OneStepTd,
}

impl Default for StepValueTarget {
    fn default() -> Self {
        Self::RewardToGo
    }
}

impl StepValueTarget {
    /// Generate state value targets for each state in a collection of experience.
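    ///
    /// # Example (sketch)
    /// Producing regression targets for a value-module update; `value_module` and
    /// `features` are illustrative names:
    /// ```ignore
    /// let targets = StepValueTarget::default().targets(&value_module, 0.99, &features);
    /// ```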
    pub fn targets<M: SeqPacked + ?Sized>(
        &self,
        state_value_fn: &M,
        discount_factor: f32,
        features: &dyn HistoryFeatures,
    ) -> PackedTensor {
        match self {
            Self::RewardToGo => reward_to_go(discount_factor, features),
            Self::OneStepTd => one_step_values(state_value_fn, discount_factor, features),
        }
    }
}