use crate::error::{RlError, RlResult};
/// Hyper-parameters read by [`compute_vtrace`].
#[derive(Clone, Copy, Debug)]
pub struct VtraceConfig {
    /// Discount factor multiplied into the next-step value at every step.
    pub gamma: f32,
    /// Cap applied to the importance ratio that scales the carried-over
    /// correction term in the backward recursion.
    pub c_bar: f32,
    /// Cap applied to the importance ratio that weights each TD error and
    /// each advantage.
    pub rho_bar: f32,
}
impl Default for VtraceConfig {
fn default() -> Self {
Self {
gamma: 0.99,
c_bar: 1.0,
rho_bar: 1.0,
}
}
}
/// Result of [`compute_vtrace`]; both vectors hold one entry per step.
#[derive(Clone, Debug)]
pub struct VtraceOutput {
    /// Corrected value target for each step.
    pub vs: Vec<f32>,
    /// Importance-weighted advantage for each step.
    pub advantages: Vec<f32>,
}
/// Computes V-trace value targets and advantages for a single trajectory.
///
/// With `T = rewards.len()`, every per-step slice must contain `T` entries and
/// `values` must contain `T + 1` (the extra entry bootstraps the final step).
/// For each step `i`:
///
/// * `ratio_i = exp(log_probs_new[i] - log_probs_old[i])`, clamped to `[0, 1e6]`;
/// * `c_i = min(ratio_i, cfg.c_bar)` and `rho_i = min(ratio_i, cfg.rho_bar)`;
/// * `mask_i = 1 - dones[i]`;
/// * `delta_i = rho_i * (rewards[i] + gamma * values[i + 1] * mask_i - values[i])`;
/// * `vs[i] = values[i] + delta_i + gamma * c_i * mask_i * (vs[i + 1] - values[i + 1])`,
///   where `vs[T]` is taken to be `values[T]`;
/// * `advantages[i] = rho_i * (rewards[i] + gamma * vs[i + 1] * mask_i - values[i])`.
///
/// An empty trajectory (`T == 0` with a single bootstrap value) yields empty
/// output vectors.
///
/// # Errors
///
/// Returns [`RlError::DimensionMismatch`] for the first slice (checked in the
/// order `values`, `dones`, `log_probs_new`, `log_probs_old`) whose length is
/// wrong. `expected` and `got` carry that slice's required and actual lengths.
pub fn compute_vtrace(
    rewards: &[f32],
    values: &[f32],
    dones: &[f32],
    log_probs_new: &[f32],
    log_probs_old: &[f32],
    cfg: VtraceConfig,
) -> RlResult<VtraceOutput> {
    let t = rewards.len();

    // Validate every slice individually so the error names the real mismatch
    // instead of a placeholder pair of numbers.
    let length_checks = [
        (t + 1, values.len()),
        (t, dones.len()),
        (t, log_probs_new.len()),
        (t, log_probs_old.len()),
    ];
    for (expected, got) in length_checks {
        if got != expected {
            return Err(RlError::DimensionMismatch { expected, got });
        }
    }

    // Raw importance ratios; the clamp keeps an overflowing `exp` finite.
    let ratios: Vec<f32> = log_probs_new
        .iter()
        .zip(log_probs_old)
        .map(|(&lp_new, &lp_old)| (lp_new - lp_old).exp().clamp(0.0, 1e6))
        .collect();

    // NOTE(review): `f32::min` returns the other operand when one side is NaN,
    // so a NaN ratio is silently replaced by the cap here — confirm this is
    // the intended handling of NaN log-probabilities.
    let c_vals: Vec<f32> = ratios.iter().map(|&r| r.min(cfg.c_bar)).collect();
    let rho_vals: Vec<f32> = ratios.iter().map(|&r| r.min(cfg.rho_bar)).collect();

    // Backward recursion: `next_vs` always holds the target of step `i + 1`,
    // starting from the bootstrap value so the first correction term is zero.
    let mut vs = vec![0.0_f32; t];
    let mut next_vs = values[t];
    for i in (0..t).rev() {
        // A done flag of 1.0 zeroes both the bootstrap and the carried trace.
        let mask = 1.0 - dones[i];
        let delta = rho_vals[i] * (rewards[i] + cfg.gamma * values[i + 1] * mask - values[i]);
        next_vs = values[i] + delta + cfg.gamma * c_vals[i] * mask * (next_vs - values[i + 1]);
        vs[i] = next_vs;
    }

    // Advantages bootstrap from the corrected target of the following step;
    // the last step falls back to the plain bootstrap value.
    let advantages: Vec<f32> = (0..t)
        .map(|i| {
            let mask = 1.0 - dones[i];
            let v_next = if i + 1 < t { vs[i + 1] } else { values[t] };
            rho_vals[i] * (rewards[i] + cfg.gamma * v_next * mask - values[i])
        })
        .collect();

    Ok(VtraceOutput { vs, advantages })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` is within a small absolute tolerance of `expected`.
    fn assert_close(actual: f32, expected: f32, what: &str) {
        assert!(
            (actual - expected).abs() < 1e-4,
            "{what}: expected {expected}, got {actual}"
        );
    }

    /// Runs `compute_vtrace` on a constant trajectory of length `t`
    /// (reward 1.0, value 0.5, no terminal steps). When `rho_same` is false
    /// the old log-probabilities are lower, so the raw ratio exceeds 1.
    fn run_vtrace(t: usize, rho_same: bool) -> VtraceOutput {
        let r = vec![1.0_f32; t];
        let v = vec![0.5_f32; t + 1];
        let d = vec![0.0_f32; t];
        let lp_new = vec![0.0_f32; t];
        let lp_old = if rho_same {
            vec![0.0_f32; t]
        } else {
            vec![-1.0_f32; t]
        };
        compute_vtrace(&r, &v, &d, &lp_new, &lp_old, VtraceConfig::default()).unwrap()
    }

    #[test]
    fn vtrace_output_length() {
        let out = run_vtrace(10, true);
        assert_eq!(out.vs.len(), 10);
        assert_eq!(out.advantages.len(), 10);
    }

    #[test]
    fn vtrace_empty_trajectory() {
        let empty: [f32; 0] = [];
        let out = compute_vtrace(
            &empty,
            &[0.5],
            &empty,
            &empty,
            &empty,
            VtraceConfig::default(),
        )
        .unwrap();
        assert!(out.vs.is_empty());
        assert!(out.advantages.is_empty());
    }

    #[test]
    fn vtrace_on_policy_matches_td() {
        let cfg = VtraceConfig {
            gamma: 0.99,
            c_bar: 1.0,
            rho_bar: 1.0,
        };
        let r = vec![1.0_f32; 3];
        let v = vec![0.0_f32; 4];
        let d = vec![0.0_f32; 3];
        let lp = vec![0.0_f32; 3];
        let out = compute_vtrace(&r, &v, &d, &lp, &lp, cfg).unwrap();
        // With zero values and unit ratios the targets are the discounted
        // reward sums: 1, 1 + 0.99, 1 + 0.99 * 1.99.
        let expected = [2.9701_f32, 1.99, 1.0];
        for (i, &want) in expected.iter().enumerate() {
            assert_close(out.vs[i], want, "vs");
            assert_close(out.advantages[i], want, "advantage");
        }
    }

    #[test]
    fn vtrace_advantages_non_nan() {
        let out = run_vtrace(5, false);
        for (i, &a) in out.advantages.iter().enumerate() {
            assert!(a.is_finite(), "advantage[{i}] is NaN/inf: {a}");
        }
    }

    #[test]
    fn vtrace_dimension_mismatch() {
        let r = vec![1.0_f32; 3];
        let v_short = vec![0.5_f32; 3];
        let v_ok = vec![0.5_f32; 4];
        let d = vec![0.0_f32; 3];
        let d_short = vec![0.0_f32; 2];
        let lp = vec![0.0_f32; 3];
        let cfg = VtraceConfig::default();

        // `values` must carry one more entry than `rewards`.
        let bad_values = compute_vtrace(&r, &v_short, &d, &lp, &lp, cfg);
        assert!(matches!(
            bad_values,
            Err(RlError::DimensionMismatch { .. })
        ));

        // Per-step slices must match `rewards` exactly.
        let bad_dones = compute_vtrace(&r, &v_ok, &d_short, &lp, &lp, cfg);
        assert!(matches!(
            bad_dones,
            Err(RlError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn vtrace_done_stops_accumulation() {
        let cfg = VtraceConfig::default();
        let r = vec![1.0, 1.0, 1.0];
        let v = vec![0.0_f32; 4];
        let d = vec![0.0, 1.0, 0.0];
        let lp = vec![0.0_f32; 3];
        let out = compute_vtrace(&r, &v, &d, &lp, &lp, cfg).unwrap();
        // Step 1 is terminal, so nothing from step 2 may leak into it;
        // step 0 still sees step 1's target through the discount.
        let expected = [1.99_f32, 1.0, 1.0];
        for (i, &want) in expected.iter().enumerate() {
            assert_close(out.vs[i], want, "vs");
            assert_close(out.advantages[i], want, "advantage");
        }
    }

    #[test]
    fn vtrace_clipping_reduces_large_rho() {
        let cfg = VtraceConfig {
            gamma: 0.99,
            c_bar: 1.0,
            rho_bar: 1.0,
        };
        let r = vec![1.0_f32; 2];
        let v = vec![0.5_f32; 3];
        let d = vec![0.0_f32; 2];
        let lp_new = vec![0.0_f32; 2];
        // exp(100) overflows f32, so this also exercises the ratio clamp.
        let lp_old = vec![-100.0_f32; 2];
        let out = compute_vtrace(&r, &v, &d, &lp_new, &lp_old, cfg).unwrap();
        let lp = vec![0.0_f32; 2];
        let out_on = compute_vtrace(&r, &v, &d, &lp, &lp, cfg).unwrap();
        assert!(
            (out.vs[0] - out_on.vs[0]).abs() < 1e-4,
            "clipped rho should match on-policy: {} vs {}",
            out.vs[0],
            out_on.vs[0]
        );
        assert_close(out.advantages[0], out_on.advantages[0], "advantage");
    }
}