use crate::error::{RlError, RlResult};
/// Hyperparameters consumed by [`compute_retrace`].
#[derive(Debug, Clone, Copy)]
pub struct RetraceConfig {
/// Discount factor multiplying the bootstrap value and the propagated correction term.
pub gamma: f32,
/// Scale applied to the truncated importance ratio `min(1, rho)` to form each trace
/// coefficient; `0.0` removes the multi-step correction entirely.
pub lambda: f32,
}
impl Default for RetraceConfig {
fn default() -> Self {
Self {
gamma: 0.99,
lambda: 1.0,
}
}
}
/// Result of [`compute_retrace`]; both vectors have one entry per trajectory step.
#[derive(Debug, Clone)]
pub struct RetraceOutput {
/// Corrected Q-value target for each step, produced by the backward recursion.
pub q_targets: Vec<f32>,
/// One-step TD error for each step:
/// `rewards[i] + gamma * values[i + 1] * (1 - dones[i]) - q_values[i]`.
pub td_errors: Vec<f32>,
}
/// Computes Retrace-style Q-value targets and one-step TD errors for one trajectory.
///
/// Let `T = rewards.len()`. For every step `i` the function derives:
///
/// * a trace coefficient `c[i] = cfg.lambda * min(1, exp(log_probs_new[i] - log_probs_old[i]))`,
/// * a TD error
///   `td_errors[i] = rewards[i] + cfg.gamma * values[i + 1] * (1 - dones[i]) - q_values[i]`,
/// * a target, evaluated backwards from the last step,
///   `q_targets[i] = q_values[i] + td_errors[i]
///        + cfg.gamma * (1 - dones[i]) * c[i + 1] * (q_targets[i + 1] - q_values[i + 1])`,
///   where the correction term is zero for the final step.
///
/// # Arguments
///
/// * `rewards` - per-step rewards, length `T`.
/// * `q_values` - Q estimates of the actions taken, length `T`.
/// * `values` - state-value estimates including the bootstrap state, length `T + 1`.
/// * `dones` - per-step flags used as the multiplier `1.0 - dones[i]`, length `T`
///   (so `1.0` cuts both bootstrapping and the propagated correction at that step).
/// * `log_probs_new` / `log_probs_old` - log-probabilities whose difference is exponentiated
///   to form the importance ratio, length `T` each.
/// * `cfg` - discount factor and trace scale.
///
/// # Errors
///
/// Returns [`RlError::DimensionMismatch`] describing the first slice (checked in the order
/// `q_values`, `values`, `dones`, `log_probs_new`, `log_probs_old`) whose length differs from
/// the one required above; `expected` and `got` refer to that specific slice.
pub fn compute_retrace(
    rewards: &[f32],
    q_values: &[f32],
    values: &[f32],
    dones: &[f32],
    log_probs_new: &[f32],
    log_probs_old: &[f32],
    cfg: RetraceConfig,
) -> RlResult<RetraceOutput> {
    let t = rewards.len();

    // Check every slice against its own required length so the error payload describes the
    // slice that is actually wrong (`values` needs one extra bootstrap entry).
    let length_checks = [
        (q_values.len(), t),
        (values.len(), t + 1),
        (dones.len(), t),
        (log_probs_new.len(), t),
        (log_probs_old.len(), t),
    ];
    if let Some(&(got, expected)) = length_checks
        .iter()
        .find(|&&(got, expected)| got != expected)
    {
        return Err(RlError::DimensionMismatch { expected, got });
    }

    // Trace coefficients: importance ratio bounded to a sane range, truncated at 1, then
    // scaled by lambda.
    let c: Vec<f32> = log_probs_new
        .iter()
        .zip(log_probs_old.iter())
        .map(|(&lp_new, &lp_old)| {
            let rho = (lp_new - lp_old).exp().clamp(0.0, 1e6);
            cfg.lambda * rho.min(1.0)
        })
        .collect();

    // One-step TD errors; the bootstrap value is masked out where `dones[i]` is set.
    let td_errors: Vec<f32> = (0..t)
        .map(|i| {
            let mask = 1.0 - dones[i];
            rewards[i] + cfg.gamma * values[i + 1] * mask - q_values[i]
        })
        .collect();

    // Backward recursion. `ret_next` carries the target of step `i + 1`; it starts at the
    // bootstrap value so the (already zero-weighted) correction of the last step is 0.
    let mut q_targets = vec![0.0_f32; t];
    let mut ret_next = values[t];
    for i in (0..t).rev() {
        let mask = 1.0 - dones[i];
        let is_last = i + 1 >= t;
        let c_next = if is_last { 0.0 } else { c[i + 1] };
        let q_next = if is_last { values[t] } else { q_values[i + 1] };
        q_targets[i] =
            q_values[i] + td_errors[i] + cfg.gamma * mask * c_next * (ret_next - q_next);
        ret_next = q_targets[i];
    }

    Ok(RetraceOutput {
        q_targets,
        td_errors,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `compute_retrace` on a constant trajectory of `t` steps (reward 1.0, Q and V 0.5,
    /// no done flags) with default config. When `same_policy` is false the old log-probs are
    /// lower than the new ones, so the raw importance ratio exceeds 1.
    fn simple_retrace(t: usize, same_policy: bool) -> RetraceOutput {
        let r = vec![1.0_f32; t];
        let q = vec![0.5_f32; t];
        let v = vec![0.5_f32; t + 1];
        let d = vec![0.0_f32; t];
        let lp_new = vec![0.0_f32; t];
        let lp_old = if same_policy {
            vec![0.0_f32; t]
        } else {
            vec![-0.5_f32; t]
        };
        compute_retrace(&r, &q, &v, &d, &lp_new, &lp_old, RetraceConfig::default()).unwrap()
    }

    #[test]
    fn retrace_output_length() {
        let out = simple_retrace(5, true);
        assert_eq!(out.q_targets.len(), 5);
        assert_eq!(out.td_errors.len(), 5);
    }

    #[test]
    fn retrace_td_errors_finite() {
        let out = simple_retrace(4, false);
        for (i, &d) in out.td_errors.iter().enumerate() {
            assert!(d.is_finite(), "td_error[{i}]={d}");
        }
    }

    #[test]
    fn retrace_on_policy_q_targets_finite() {
        let out = simple_retrace(4, true);
        for (i, &q) in out.q_targets.iter().enumerate() {
            assert!(q.is_finite(), "q_target[{i}]={q}");
        }
    }

    #[test]
    fn retrace_dimension_mismatch() {
        let r = vec![1.0_f32; 3];
        let q = vec![0.5_f32; 3];
        // `values` must hold t + 1 entries; t entries is a mismatch.
        let v = vec![0.5_f32; 3];
        let d = vec![0.0_f32; 3];
        let lp = vec![0.0_f32; 3];
        assert!(compute_retrace(&r, &q, &v, &d, &lp, &lp, RetraceConfig::default()).is_err());
    }

    #[test]
    fn retrace_done_stops_accumulation() {
        let cfg = RetraceConfig::default();
        let r = vec![1.0, 1.0, 1.0];
        let q = vec![0.0_f32; 3];
        let v = vec![0.0_f32; 4];
        let d = vec![0.0, 1.0, 0.0];
        let lp = vec![0.0_f32; 3];
        let out = compute_retrace(&r, &q, &v, &d, &lp, &lp, cfg).unwrap();
        assert!(out.q_targets.iter().all(|&q| q.is_finite()));

        // The done flag at step 1 must block the correction coming from step 2, so step 1
        // keeps just its own reward; step 0 still receives gamma * (target[1] - q[1]).
        let expected = [1.99_f32, 1.0, 1.0];
        for (i, (&got, &want)) in out.q_targets.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - want).abs() < 1e-5,
                "Q_target[{i}]={got} vs {want}"
            );
        }
    }

    #[test]
    fn retrace_lambda_zero_is_bellman() {
        let cfg = RetraceConfig {
            gamma: 0.99,
            lambda: 0.0,
        };
        let r = vec![1.0_f32; 3];
        let q = vec![0.5_f32; 3];
        let v = vec![0.5_f32; 4];
        let d = vec![0.0_f32; 3];
        let lp = vec![0.0_f32; 3];
        let out = compute_retrace(&r, &q, &v, &d, &lp, &lp, cfg).unwrap();
        // With lambda = 0 every trace coefficient is 0, leaving r + gamma * V(next).
        let expected = 1.0 + 0.99 * 0.5;
        for (i, &qt) in out.q_targets.iter().enumerate() {
            assert!(
                (qt - expected).abs() < 1e-4,
                "Q_target[{i}]={qt} vs {expected}"
            );
        }
    }
}