// oxicuda_rl/estimator/vtrace.rs
//! # V-trace Off-Policy Return Estimation (IMPALA)
//!
//! Espeholt et al. (2018), "IMPALA: Scalable Distributed Deep-RL with
//! Importance Weighted Actor-Learner Architectures", ICML 2018.
//!
//! ## Algorithm
//!
//! Given IS ratios `ρ_t = π(a_t|s_t) / μ(a_t|s_t)` between the current
//! policy `π` and the behaviour policy `μ`:
//!
//! ```text
//! c_t  = min(c̄, ρ_t)
//! ρ̄_t = min(ρ̄, ρ_t)
//!
//! δ_t  = ρ̄_t [r_t + γ V(s_{t+1})(1-done_t) - V(s_t)]
//!
//! v_s = V(s_s) + Σ_{t=s}^{s+n-1} γ^{t-s} (Π_{i=s}^{t-1} c_i) δ_t
//! ```
//!
//! The advantage for the policy gradient is:
//! ```text
//! A_t = ρ̄_t [r_t + γ v_{t+1}(1-done_t) - V(s_t)]
//! ```
use crate::error::{RlError, RlResult};
/// Hyperparameters for the V-trace estimator.
///
/// Defaults (via [`Default`]) follow the IMPALA paper: `gamma = 0.99`,
/// `c_bar = rho_bar = 1.0`.
#[derive(Debug, Clone, Copy)]
pub struct VtraceConfig {
    /// Discount factor γ.
    pub gamma: f32,
    /// Clipping threshold c̄ for the "trace" importance weights `c_t` used in
    /// the backward accumulation of the v-trace target.
    pub c_bar: f32,
    /// Clipping threshold ρ̄ for the IS weights `ρ̄_t` used in the TD errors
    /// and the policy-gradient advantages.
    pub rho_bar: f32,
}
38impl Default for VtraceConfig {
39    fn default() -> Self {
40        Self {
41            gamma: 0.99,
42            c_bar: 1.0,
43            rho_bar: 1.0,
44        }
45    }
46}
/// Result of a V-trace computation over a length-`T` trajectory.
#[derive(Debug, Clone)]
pub struct VtraceOutput {
    /// V-trace value targets `v_t`, length `T` (regression targets for the critic).
    pub vs: Vec<f32>,
    /// Policy gradient advantages `A_t`, length `T`.
    pub advantages: Vec<f32>,
}
57/// Compute V-trace returns and policy gradient advantages.
58///
59/// # Arguments
60///
61/// * `rewards`          — `[T]` rewards.
62/// * `values`           — `[T+1]` value estimates (including bootstrap `v_T`).
63/// * `dones`            — `[T]` done flags.
64/// * `log_probs_new`    — `[T]` log-probs under current policy `π`.
65/// * `log_probs_old`    — `[T]` log-probs under behaviour policy `μ`.
66/// * `cfg`              — V-trace hyperparameters.
67///
68/// # Errors
69///
70/// * [`RlError::DimensionMismatch`] if slice lengths are inconsistent.
71pub fn compute_vtrace(
72    rewards: &[f32],
73    values: &[f32],
74    dones: &[f32],
75    log_probs_new: &[f32],
76    log_probs_old: &[f32],
77    cfg: VtraceConfig,
78) -> RlResult<VtraceOutput> {
79    let t = rewards.len();
80    if values.len() != t + 1
81        || dones.len() != t
82        || log_probs_new.len() != t
83        || log_probs_old.len() != t
84    {
85        return Err(RlError::DimensionMismatch {
86            expected: t,
87            got: t.wrapping_sub(1),
88        });
89    }
90
91    // Compute IS ratios
92    let rho: Vec<f32> = log_probs_new
93        .iter()
94        .zip(log_probs_old.iter())
95        .map(|(&lp_new, &lp_old)| {
96            let ratio = (lp_new - lp_old).exp();
97            ratio.clamp(0.0, 1e6) // prevent extreme ratios
98        })
99        .collect();
100
101    let c_vals: Vec<f32> = rho.iter().map(|&r| r.min(cfg.c_bar)).collect();
102    let rho_vals: Vec<f32> = rho.iter().map(|&r| r.min(cfg.rho_bar)).collect();
103
104    // Compute TD errors δ_t = ρ̄_t [r_t + γ v_{t+1} (1-done_t) - v_t]
105    let deltas: Vec<f32> = (0..t)
106        .map(|i| {
107            let mask = 1.0 - dones[i];
108            rho_vals[i] * (rewards[i] + cfg.gamma * values[i + 1] * mask - values[i])
109        })
110        .collect();
111
112    // Compute v_s via backward scan
113    let mut vs = vec![0.0_f32; t];
114    let mut acc = values[t]; // v_T = V(s_T) for bootstrap
115
116    for i in (0..t).rev() {
117        let mask = 1.0 - dones[i];
118        acc = values[i] + deltas[i] + cfg.gamma * c_vals[i] * mask * (acc - values[i + 1]);
119        vs[i] = acc;
120    }
121
122    // Policy gradient advantages: A_t = ρ̄_t [r_t + γ v_{t+1} (1-done) - V(s_t)]
123    let advantages: Vec<f32> = (0..t)
124        .map(|i| {
125            let mask = 1.0 - dones[i];
126            // v_{t+1} is vs[t+1] if t+1 < T, else values[T]
127            let v_next = if i + 1 < t { vs[i + 1] } else { values[t] };
128            rho_vals[i] * (rewards[i] + cfg.gamma * v_next * mask - values[i])
129        })
130        .collect();
131
132    Ok(VtraceOutput { vs, advantages })
133}

// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Run V-trace on a constant trajectory: r = 1, V = 0.5, no dones.
    /// `rho_same` selects on-policy (ρ = 1) vs off-policy (ρ = e, clipped).
    fn run_vtrace(t: usize, rho_same: bool) -> VtraceOutput {
        let r = vec![1.0_f32; t];
        let v = vec![0.5_f32; t + 1];
        let d = vec![0.0_f32; t];
        let lp_new = vec![0.0_f32; t]; // log π = 0
        let lp_old = if rho_same {
            vec![0.0_f32; t] // IS ratio = exp(0-0) = 1
        } else {
            vec![-1.0_f32; t] // IS ratio = exp(0-(-1)) = e > 1, clipped
        };
        compute_vtrace(&r, &v, &d, &lp_new, &lp_old, VtraceConfig::default()).unwrap()
    }

    #[test]
    fn vtrace_output_length() {
        let out = run_vtrace(10, true);
        assert_eq!(out.vs.len(), 10);
        assert_eq!(out.advantages.len(), 10);
    }

    #[test]
    fn vtrace_on_policy_matches_td() {
        // When π = μ (ρ = 1, both bars ≥ 1), V-trace reduces to n-step TD.
        let cfg = VtraceConfig {
            gamma: 0.99,
            c_bar: 1.0,
            rho_bar: 1.0,
        };
        let r = vec![1.0_f32; 3];
        let v = vec![0.0_f32; 4];
        let d = vec![0.0_f32; 3];
        let lp = vec![0.0_f32; 3];
        let out = compute_vtrace(&r, &v, &d, &lp, &lp, cfg).unwrap();
        // With V ≡ 0: δ_t = 1 everywhere, so the targets are the exact
        // discounted n-step returns. Pin them instead of the old `> 0.5`.
        assert!((out.vs[2] - 1.0).abs() < 1e-6, "vs[2]={}", out.vs[2]);
        assert!((out.vs[1] - 1.99).abs() < 1e-5, "vs[1]={}", out.vs[1]);
        assert!((out.vs[0] - 2.9701).abs() < 1e-4, "vs[0]={}", out.vs[0]);
    }

    #[test]
    fn vtrace_advantages_non_nan() {
        let out = run_vtrace(5, false);
        for (i, &a) in out.advantages.iter().enumerate() {
            assert!(a.is_finite(), "advantage[{i}] is NaN/inf: {a}");
        }
    }

    #[test]
    fn vtrace_dimension_mismatch() {
        let r = vec![1.0_f32; 3];
        let v = vec![0.5_f32; 3]; // wrong: should be T+1 = 4
        let d = vec![0.0_f32; 3];
        let lp = vec![0.0_f32; 3];
        assert!(compute_vtrace(&r, &v, &d, &lp, &lp, VtraceConfig::default()).is_err());
    }

    #[test]
    fn vtrace_done_stops_accumulation() {
        let cfg = VtraceConfig::default();
        let r = vec![1.0, 1.0, 1.0];
        let v = vec![0.0_f32; 4];
        let d = vec![0.0, 1.0, 0.0];
        let lp = vec![0.0_f32; 3];
        let out = compute_vtrace(&r, &v, &d, &lp, &lp, cfg).unwrap();
        // done at t=1 masks both bootstrap and recursion, so with V ≡ 0 the
        // target at the terminal step is exactly its own reward: v_1 = 1.0.
        // (The old test only checked `is_finite()`, which never tested this.)
        assert!((out.vs[1] - 1.0).abs() < 1e-6, "vs[1]={}", out.vs[1]);
        // The step before still discounts through: v_0 = 1 + γ·v_1.
        assert!(
            (out.vs[0] - (1.0 + cfg.gamma)).abs() < 1e-5,
            "vs[0]={}",
            out.vs[0]
        );
    }

    #[test]
    fn vtrace_clipping_reduces_large_rho() {
        // Large rho (off-policy) should be clipped to c̄/ρ̄.
        let cfg = VtraceConfig {
            gamma: 0.99,
            c_bar: 1.0,
            rho_bar: 1.0,
        };
        let r = vec![1.0_f32; 2];
        let v = vec![0.5_f32; 3];
        let d = vec![0.0_f32; 2];
        let lp_new = vec![0.0_f32; 2];
        let lp_old = vec![-100.0_f32; 2]; // huge IS ratio, clipped to ρ̄=1
        let out = compute_vtrace(&r, &v, &d, &lp_new, &lp_old, cfg).unwrap();
        // Clipped to ρ̄=1, so identical to the on-policy result.
        let lp = vec![0.0_f32; 2];
        let out_on = compute_vtrace(&r, &v, &d, &lp, &lp, cfg).unwrap();
        assert!(
            (out.vs[0] - out_on.vs[0]).abs() < 1e-4,
            "clipped rho should match on-policy: {} vs {}",
            out.vs[0],
            out_on.vs[0]
        );
    }
}
235}