oxicuda_rl/estimator/
vtrace.rs1use crate::error::{RlError, RlResult};
26
/// Hyper-parameters consumed by [`compute_vtrace`].
#[derive(Debug, Clone, Copy)]
pub struct VtraceConfig {
    /// Discount factor multiplying the next-step value in every target.
    pub gamma: f32,
    /// Upper bound applied (via `min`) to the importance ratio used as the
    /// trace coefficient in the backward recursion for `vs`.
    pub c_bar: f32,
    /// Upper bound applied (via `min`) to the importance ratio that weights
    /// the TD error and the advantage.
    pub rho_bar: f32,
}
37
38impl Default for VtraceConfig {
39 fn default() -> Self {
40 Self {
41 gamma: 0.99,
42 c_bar: 1.0,
43 rho_bar: 1.0,
44 }
45 }
46}
47
/// Result of [`compute_vtrace`]: one entry per time step in each vector.
#[derive(Debug, Clone)]
pub struct VtraceOutput {
    /// V-trace value targets, indexed by time step.
    pub vs: Vec<f32>,
    /// Truncated-importance-weighted advantages, indexed by time step.
    pub advantages: Vec<f32>,
}
56
57pub fn compute_vtrace(
72 rewards: &[f32],
73 values: &[f32],
74 dones: &[f32],
75 log_probs_new: &[f32],
76 log_probs_old: &[f32],
77 cfg: VtraceConfig,
78) -> RlResult<VtraceOutput> {
79 let t = rewards.len();
80 if values.len() != t + 1
81 || dones.len() != t
82 || log_probs_new.len() != t
83 || log_probs_old.len() != t
84 {
85 return Err(RlError::DimensionMismatch {
86 expected: t,
87 got: t.wrapping_sub(1),
88 });
89 }
90
91 let rho: Vec<f32> = log_probs_new
93 .iter()
94 .zip(log_probs_old.iter())
95 .map(|(&lp_new, &lp_old)| {
96 let ratio = (lp_new - lp_old).exp();
97 ratio.clamp(0.0, 1e6) })
99 .collect();
100
101 let c_vals: Vec<f32> = rho.iter().map(|&r| r.min(cfg.c_bar)).collect();
102 let rho_vals: Vec<f32> = rho.iter().map(|&r| r.min(cfg.rho_bar)).collect();
103
104 let deltas: Vec<f32> = (0..t)
106 .map(|i| {
107 let mask = 1.0 - dones[i];
108 rho_vals[i] * (rewards[i] + cfg.gamma * values[i + 1] * mask - values[i])
109 })
110 .collect();
111
112 let mut vs = vec![0.0_f32; t];
114 let mut acc = values[t]; for i in (0..t).rev() {
117 let mask = 1.0 - dones[i];
118 acc = values[i] + deltas[i] + cfg.gamma * c_vals[i] * mask * (acc - values[i + 1]);
119 vs[i] = acc;
120 }
121
122 let advantages: Vec<f32> = (0..t)
124 .map(|i| {
125 let mask = 1.0 - dones[i];
126 let v_next = if i + 1 < t { vs[i + 1] } else { values[t] };
128 rho_vals[i] * (rewards[i] + cfg.gamma * v_next * mask - values[i])
129 })
130 .collect();
131
132 Ok(VtraceOutput { vs, advantages })
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `compute_vtrace` with the default config on a length-`t`
    /// trajectory of unit rewards, constant value 0.5 and no terminal steps.
    /// When `rho_same` is false the old log-probs are -1, so the raw ratio is
    /// `e` and gets truncated by the default bounds of 1.
    fn run_vtrace(t: usize, rho_same: bool) -> VtraceOutput {
        let r = vec![1.0_f32; t];
        let v = vec![0.5_f32; t + 1];
        let d = vec![0.0_f32; t];
        let lp_new = vec![0.0_f32; t];
        let lp_old = if rho_same {
            vec![0.0_f32; t]
        } else {
            vec![-1.0_f32; t]
        };
        compute_vtrace(&r, &v, &d, &lp_new, &lp_old, VtraceConfig::default()).unwrap()
    }

    #[test]
    fn vtrace_output_length() {
        let out = run_vtrace(10, true);
        assert_eq!(out.vs.len(), 10);
        assert_eq!(out.advantages.len(), 10);
    }

    #[test]
    fn vtrace_on_policy_matches_td() {
        let cfg = VtraceConfig {
            gamma: 0.99,
            c_bar: 1.0,
            rho_bar: 1.0,
        };
        let r = vec![1.0_f32; 3];
        let v = vec![0.0_f32; 4];
        let d = vec![0.0_f32; 3];
        let lp = vec![0.0_f32; 3];
        let out = compute_vtrace(&r, &v, &d, &lp, &lp, cfg).unwrap();
        assert!(
            out.vs[2] > 0.5,
            "v_trace target should be > v=0, vs[2]={}",
            out.vs[2]
        );
        // With rho = c = 1 and V = 0 the targets are the discounted returns
        // 1, 1 + 0.99, 1 + 0.99 * 1.99 (from the last step backwards).
        let expected = [2.9701_f32, 1.99, 1.0];
        for (i, (&got, &want)) in out.vs.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - want).abs() < 1e-4,
                "vs[{i}] = {got}, expected {want}"
            );
        }
    }

    #[test]
    fn vtrace_advantages_non_nan() {
        let out = run_vtrace(5, false);
        for (i, &a) in out.advantages.iter().enumerate() {
            assert!(a.is_finite(), "advantage[{i}] is NaN/inf: {a}");
        }
    }

    #[test]
    fn vtrace_dimension_mismatch() {
        let r = vec![1.0_f32; 3];
        // `values` must have one more entry than `rewards`; 3 is one short.
        let v = vec![0.5_f32; 3];
        let d = vec![0.0_f32; 3];
        let lp = vec![0.0_f32; 3];
        assert!(compute_vtrace(&r, &v, &d, &lp, &lp, VtraceConfig::default()).is_err());
    }

    #[test]
    fn vtrace_done_stops_accumulation() {
        let cfg = VtraceConfig::default();
        let r = vec![1.0, 1.0, 1.0];
        let v = vec![0.0_f32; 4];
        let d = vec![0.0, 1.0, 0.0];
        let lp = vec![0.0_f32; 3];
        let out = compute_vtrace(&r, &v, &d, &lp, &lp, cfg).unwrap();
        assert!(out.vs[1].is_finite());
        // The done flag at step 1 zeroes the trace term, so vs[1] is just the
        // immediate reward (1.0) instead of the 1.99 it would be otherwise.
        assert!(
            (out.vs[1] - 1.0).abs() < 1e-5,
            "done step must not accumulate future targets: vs[1]={}",
            out.vs[1]
        );
        // Step 0 is not terminal and still bootstraps from vs[1].
        assert!(
            (out.vs[0] - 1.99).abs() < 1e-4,
            "vs[0] should be 1 + 0.99 * vs[1]: vs[0]={}",
            out.vs[0]
        );
        assert!((out.vs[2] - 1.0).abs() < 1e-5, "vs[2]={}", out.vs[2]);
    }

    #[test]
    fn vtrace_clipping_reduces_large_rho() {
        let cfg = VtraceConfig {
            gamma: 0.99,
            c_bar: 1.0,
            rho_bar: 1.0,
        };
        let r = vec![1.0_f32; 2];
        let v = vec![0.5_f32; 3];
        let d = vec![0.0_f32; 2];
        let lp_new = vec![0.0_f32; 2];
        // exp(100) overflows f32; truncation must bring the ratio back to 1.
        let lp_old = vec![-100.0_f32; 2];
        let out = compute_vtrace(&r, &v, &d, &lp_new, &lp_old, cfg).unwrap();
        let lp = vec![0.0_f32; 2];
        let out_on = compute_vtrace(&r, &v, &d, &lp, &lp, cfg).unwrap();
        assert!(
            (out.vs[0] - out_on.vs[0]).abs() < 1e-4,
            "clipped rho should match on-policy: {} vs {}",
            out.vs[0],
            out_on.vs[0]
        );
    }
}