use crate::error::{RlError, RlResult};
/// Hyper-parameters consumed by [`compute_td_lambda`] and [`compute_td_advantages`].
#[derive(Debug, Clone, Copy)]
pub struct TdConfig {
    /// Discount factor multiplying the masked, bootstrapped part of each return.
    pub gamma: f32,
    /// Mixing weight: `lambda` goes to the running return and `1 - lambda` to the
    /// next-step value estimate.
    pub lambda: f32,
}

impl Default for TdConfig {
    /// Builds the default configuration: `gamma = 0.99`, `lambda = 0.95`.
    fn default() -> Self {
        TdConfig {
            gamma: 0.99,
            lambda: 0.95,
        }
    }
}
pub fn compute_td_lambda(
rewards: &[f32],
values: &[f32],
dones: &[f32],
cfg: TdConfig,
) -> RlResult<Vec<f32>> {
let t = rewards.len();
if values.len() != t + 1 {
return Err(RlError::DimensionMismatch {
expected: t + 1,
got: values.len(),
});
}
if dones.len() != t {
return Err(RlError::DimensionMismatch {
expected: t,
got: dones.len(),
});
}
let mut returns = vec![0.0_f32; t];
let mut g = values[t];
for i in (0..t).rev() {
let mask = 1.0 - dones[i];
let mixed = (1.0 - cfg.lambda) * values[i + 1] + cfg.lambda * g;
g = rewards[i] + cfg.gamma * mask * mixed;
returns[i] = g;
}
Ok(returns)
}
/// Computes per-step advantages as the TD(λ) return minus the value estimate.
///
/// Delegates to [`compute_td_lambda`] for the returns and then subtracts
/// `values[i]` from the `i`-th return, producing one advantage per reward.
///
/// # Errors
///
/// Propagates any error returned by [`compute_td_lambda`] unchanged.
pub fn compute_td_advantages(
    rewards: &[f32],
    values: &[f32],
    dones: &[f32],
    cfg: TdConfig,
) -> RlResult<Vec<f32>> {
    compute_td_lambda(rewards, values, dones, cfg).map(|lambda_returns| {
        // On success `values` is exactly one element longer than the returns;
        // `zip` stops at the shorter side, so the trailing bootstrap entry is skipped.
        lambda_returns
            .into_iter()
            .zip(values)
            .map(|(ret, &baseline)| ret - baseline)
            .collect()
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparisons against exactly-known returns.
    const TOL: f32 = 1e-5;

    /// With `lambda = 0` and all-zero value estimates, every return is just the reward.
    #[test]
    fn td_lambda_zero_equals_one_step() {
        let cfg = TdConfig {
            gamma: 0.99,
            lambda: 0.0,
        };
        let r = [1.0_f32; 3];
        let v = [0.0_f32; 4];
        let d = [0.0_f32; 3];
        let g = compute_td_lambda(&r, &v, &d, cfg).unwrap();
        for gi in g.iter().copied() {
            assert!((gi - 1.0).abs() < TOL, "λ=0 return={gi}");
        }
    }

    /// With `gamma = lambda = 1` and zero values, returns are undiscounted reward sums.
    #[test]
    fn td_lambda_one_is_monte_carlo() {
        let cfg = TdConfig {
            gamma: 1.0,
            lambda: 1.0,
        };
        let r = [1.0_f32; 3];
        let v = [0.0_f32; 4];
        let d = [0.0_f32; 3];
        let g = compute_td_lambda(&r, &v, &d, cfg).unwrap();
        assert!((g[0] - 3.0).abs() < TOL, "G_0={}", g[0]);
        assert!((g[1] - 2.0).abs() < TOL, "G_1={}", g[1]);
        assert!((g[2] - 1.0).abs() < TOL, "G_2={}", g[2]);
    }

    /// A done flag at step 1 must cut the return there down to the immediate reward.
    #[test]
    fn td_lambda_done_truncates() {
        let cfg = TdConfig {
            gamma: 0.99,
            lambda: 0.95,
        };
        let r = [1.0_f32, 1.0, 1.0];
        let v = [0.5_f32, 0.5, 0.5, 0.5];
        let d = [0.0_f32, 1.0, 0.0];
        let g = compute_td_lambda(&r, &v, &d, cfg).unwrap();
        let gap = (g[1] - r[1]).abs();
        assert!(
            gap < 0.1,
            "G_1 at done should ≈ r={}, got {}",
            r[1],
            g[1]
        );
    }

    /// Advantages come back with one entry per reward.
    #[test]
    fn td_advantages_length() {
        let r = [1.0_f32; 5];
        let v = [0.5_f32; 6];
        let d = [0.0_f32; 5];
        let a = compute_td_advantages(&r, &v, &d, TdConfig::default()).unwrap();
        assert_eq!(a.len(), 5);
    }

    /// `values` without the extra bootstrap entry is rejected.
    #[test]
    fn td_dimension_mismatch_values() {
        let r = [1.0_f32; 3];
        let v = [0.5_f32; 3];
        let d = [0.0_f32; 3];
        let outcome = compute_td_lambda(&r, &v, &d, TdConfig::default());
        assert!(outcome.is_err());
    }

    /// Undiscounted returns over constant positive rewards shrink towards the end.
    #[test]
    fn td_returns_decrease_with_later_steps() {
        let cfg = TdConfig {
            gamma: 1.0,
            lambda: 1.0,
        };
        let r = [1.0_f32; 5];
        let v = [0.0_f32; 6];
        let d = [0.0_f32; 5];
        let g = compute_td_lambda(&r, &v, &d, cfg).unwrap();
        let strictly_decreasing = g[0] > g[1] && g[1] > g[2];
        assert!(
            strictly_decreasing,
            "returns should decrease for later steps"
        );
    }
}