use crate::error::{RlError, RlResult};
/// Hyper-parameters for Generalized Advantage Estimation (GAE).
#[derive(Debug, Clone, Copy)]
pub struct GaeConfig {
    /// Discount factor applied to future rewards (typically close to 1).
    pub gamma: f32,
    /// GAE smoothing coefficient trading bias (low) against variance (high).
    pub lambda: f32,
    /// When true, advantages are shifted/scaled to zero mean and unit std.
    pub normalise: bool,
}
impl Default for GaeConfig {
fn default() -> Self {
Self {
gamma: 0.99,
lambda: 0.95,
normalise: true,
}
}
}
/// Result of a GAE pass over one trajectory.
#[derive(Debug, Clone)]
pub struct GaeOutput {
    /// Per-step advantage estimates (normalised when the config asks for it).
    pub advantages: Vec<f32>,
    /// Per-step discounted-return targets for the value function.
    pub returns: Vec<f32>,
}
/// Computes Generalized Advantage Estimation (GAE) advantages and
/// value-function return targets for a trajectory of length `T`.
///
/// All four slices must share the same length `T`. `dones[i]` is `1.0`
/// when step `i` terminates an episode — the mask then zeroes both the
/// bootstrapped `next_values[i]` term and the advantage carried across
/// the episode boundary — and `0.0` otherwise.
///
/// Returns are always `advantage + value` computed from the *raw*
/// advantages; normalisation (when `cfg.normalise` is set and `T > 1`)
/// is applied to the advantages only, afterwards.
///
/// # Errors
///
/// Returns `RlError::DimensionMismatch` when any input slice's length
/// differs from `rewards.len()`.
pub fn compute_gae(
    rewards: &[f32],
    values: &[f32],
    next_values: &[f32],
    dones: &[f32],
    cfg: GaeConfig,
) -> RlResult<GaeOutput> {
    let t = rewards.len();
    // Report the length of the first slice that actually mismatches,
    // rather than unconditionally blaming `values`.
    for len in [values.len(), next_values.len(), dones.len()] {
        if len != t {
            return Err(RlError::DimensionMismatch { expected: t, got: len });
        }
    }
    let gamma_lambda = cfg.gamma * cfg.lambda;
    let mut advantages = vec![0.0_f32; t];
    let mut gae = 0.0_f32;
    // Backward recursion: A_i = delta_i + gamma*lambda * mask_i * A_{i+1},
    // where delta_i = r_i + gamma * V(s_{i+1}) * mask_i - V(s_i).
    for i in (0..t).rev() {
        let mask = 1.0 - dones[i];
        let delta = rewards[i] + cfg.gamma * next_values[i] * mask - values[i];
        gae = delta + gamma_lambda * mask * gae;
        advantages[i] = gae;
    }
    // Build the critic's regression targets from the RAW advantages.
    // Computing them after normalisation (as before) corrupts the value
    // targets whenever `cfg.normalise` is set.
    let returns: Vec<f32> = advantages
        .iter()
        .zip(values)
        .map(|(&a, &v)| a + v)
        .collect();
    if cfg.normalise && t > 1 {
        let mean = advantages.iter().sum::<f32>() / t as f32;
        let var = advantages
            .iter()
            .map(|&a| (a - mean).powi(2))
            .sum::<f32>()
            / t as f32;
        // Epsilon keeps the division finite when advantages are constant.
        let std = (var + 1e-8).sqrt();
        for a in advantages.iter_mut() {
            *a = (*a - mean) / std;
        }
    }
    Ok(GaeOutput { advantages, returns })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constant trajectory: reward 1.0 and value 0.5 at every step,
    /// no episode terminations.
    fn constant_trajectory(t: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
        (
            vec![1.0_f32; t],
            vec![0.5_f32; t],
            vec![0.5_f32; t],
            vec![0.0_f32; t],
        )
    }

    #[test]
    fn gae_output_length() {
        let (r, v, nv, d) = constant_trajectory(10);
        let out = compute_gae(&r, &v, &nv, &d, GaeConfig::default()).unwrap();
        assert_eq!(out.advantages.len(), 10);
        assert_eq!(out.returns.len(), 10);
    }

    #[test]
    fn gae_dimension_mismatch() {
        // `values` is deliberately one element short of the others.
        let result = compute_gae(
            &[1.0; 5],
            &[0.5; 4],
            &[0.5; 5],
            &[0.0; 5],
            GaeConfig::default(),
        );
        assert!(result.is_err());
    }

    #[test]
    fn gae_lambda_zero_is_td() {
        // With lambda = 0 the estimator collapses to the one-step TD error:
        // A_i = r_i + gamma * V(s_{i+1}) - V(s_i) = 1 + 0 - 0 = 1.
        let cfg = GaeConfig {
            gamma: 0.99,
            lambda: 0.0,
            normalise: false,
        };
        let zeros = vec![0.0_f32; 3];
        let out = compute_gae(&[1.0_f32; 3], &zeros, &zeros, &zeros, cfg).unwrap();
        for &a in &out.advantages {
            assert!((a - 1.0).abs() < 1e-5, "A={a}");
        }
    }

    #[test]
    fn gae_done_resets_accumulation() {
        let cfg = GaeConfig {
            gamma: 0.99,
            lambda: 0.95,
            normalise: false,
        };
        let zeros = vec![0.0_f32; 3];
        let out = compute_gae(&[1.0, 1.0, 1.0], &zeros, &zeros, &[0.0, 0.0, 1.0], cfg).unwrap();
        assert!(
            out.advantages[0] > out.advantages[2],
            "earlier steps with future returns should have higher advantage"
        );
    }

    #[test]
    fn gae_normalise_zero_mean() {
        let (r, v, nv, d) = constant_trajectory(20);
        let cfg = GaeConfig {
            normalise: true,
            ..GaeConfig::default()
        };
        let out = compute_gae(&r, &v, &nv, &d, cfg).unwrap();
        let mean = out.advantages.iter().sum::<f32>() / 20.0;
        assert!(mean.abs() < 1e-4, "normalised mean should be ≈0, got {mean}");
    }

    #[test]
    fn gae_normalise_unit_std() {
        let (r, v, nv, d) = constant_trajectory(50);
        let cfg = GaeConfig {
            normalise: true,
            ..GaeConfig::default()
        };
        let out = compute_gae(&r, &v, &nv, &d, cfg).unwrap();
        let mean = out.advantages.iter().sum::<f32>() / 50.0;
        let var = out
            .advantages
            .iter()
            .map(|&a| (a - mean) * (a - mean))
            .sum::<f32>()
            / 50.0;
        let std = var.sqrt();
        assert!((std - 1.0).abs() < 0.05, "normalised std should be ≈1, got {std}");
    }

    #[test]
    fn gae_returns_equal_advantage_plus_value() {
        let (r, v, nv, d) = constant_trajectory(5);
        let cfg = GaeConfig {
            normalise: false,
            ..GaeConfig::default()
        };
        let out = compute_gae(&r, &v, &nv, &d, cfg).unwrap();
        for i in 0..out.returns.len() {
            let expected = out.advantages[i] + v[i];
            assert!(
                (out.returns[i] - expected).abs() < 1e-5,
                "G[{i}] = A+V failed: {} vs {expected}",
                out.returns[i]
            );
        }
    }
}