use crate::error::{RlError, RlResult};
/// A single n-step training sample produced by `NStepBuffer`.
///
/// Derives `PartialEq` so transitions can be compared directly (for example
/// with `assert_eq!`). `Eq` is intentionally not derived because the struct
/// contains `f32` fields.
#[derive(Debug, Clone, PartialEq)]
pub struct NStepTransition {
    /// Observation of the oldest step in the window the return starts from.
    pub obs: Vec<f32>,
    /// Action stored with that same oldest step.
    pub action: Vec<f32>,
    /// Discounted reward sum `r_0 + gamma * r_1 + gamma^2 * r_2 + ...` over the
    /// steps that were accumulated.
    pub n_step_return: f32,
    /// `next_obs` of the last step that contributed to the return.
    pub bootstrap_obs: Vec<f32>,
    /// `true` when accumulation stopped at a step flagged `done`.
    pub done: bool,
    /// Step count reported for this return; never larger than the buffer's `n`.
    pub actual_n: usize,
    /// Discount `gamma` multiplied once per accumulated reward, i.e. `gamma^k`
    /// after `k` rewards were summed.
    pub gamma_n: f32,
}
/// One raw step exactly as it was handed to `NStepBuffer::push`.
#[derive(Debug, Clone)]
struct Step {
// The `obs` argument of `push`; becomes `NStepTransition::obs` when this
// step is the oldest one in the window.
obs: Vec<f32>,
// The `action` argument of `push`.
action: Vec<f32>,
// The `reward` argument of `push`; weighted by the running discount when
// the return is summed.
reward: f32,
// The `next_obs` argument of `push`; used as the bootstrap observation when
// this is the last step accumulated.
next_obs: Vec<f32>,
// The `done` argument of `push`; accumulation of a return stops at the
// first step where this is `true`.
done: bool,
}
/// Fixed-capacity ring buffer that turns a stream of single steps into
/// n-step return transitions.
#[derive(Debug, Clone)]
pub struct NStepBuffer {
// Window length and capacity of `steps`; asserted to be > 0 in `new`.
n: usize,
// Per-step discount factor applied while summing rewards.
gamma: f32,
// Ring storage with exactly `n` slots; `None` marks an empty slot.
steps: Vec<Option<Step>>,
// Slot index the next `push` writes to (wraps modulo `n`).
head: usize,
// Number of steps currently stored, capped at `n`.
count: usize,
}
impl NStepBuffer {
    /// Creates an empty buffer that accumulates returns over windows of `n`
    /// steps, discounting each successive reward by `gamma`.
    ///
    /// # Panics
    ///
    /// Panics if `n` is zero.
    #[must_use]
    pub fn new(n: usize, gamma: f32) -> Self {
        assert!(n > 0, "n must be > 0");
        Self {
            n,
            gamma,
            steps: vec![None; n],
            head: 0,
            count: 0,
        }
    }

    /// Returns the window length `n` the buffer was created with.
    #[must_use]
    #[inline]
    pub fn n(&self) -> usize {
        self.n
    }

    /// Returns the per-step discount factor.
    #[must_use]
    #[inline]
    pub fn gamma(&self) -> f32 {
        self.gamma
    }

    /// Returns how many steps are currently stored (at most `n`).
    #[must_use]
    #[inline]
    pub fn count(&self) -> usize {
        self.count
    }

    /// Stores one step and, when possible, returns the transition for the
    /// oldest stored step.
    ///
    /// * Once `n` steps are stored, every call returns `Some` with the return
    ///   starting at the oldest step. When the buffer is already full the new
    ///   step overwrites the oldest slot first.
    /// * If fewer than `n` steps are stored but `done` is `true`, a partial
    ///   transition over the steps stored so far is returned.
    /// * Otherwise `None` is returned.
    ///
    /// The buffer is not cleared when `done` is `true`; the remaining steps
    /// stay stored until [`flush`](Self::flush) or [`reset`](Self::reset) is
    /// called or later pushes overwrite them.
    pub fn push(
        &mut self,
        obs: impl Into<Vec<f32>>,
        action: impl Into<Vec<f32>>,
        reward: f32,
        next_obs: impl Into<Vec<f32>>,
        done: bool,
    ) -> Option<NStepTransition> {
        let step = Step {
            obs: obs.into(),
            action: action.into(),
            reward,
            next_obs: next_obs.into(),
            done,
        };
        self.steps[self.head] = Some(step);
        self.head = (self.head + 1) % self.n;
        if self.count < self.n {
            self.count += 1;
        }

        if self.count == self.n {
            Some(self.compute_return())
        } else if done {
            Some(self.compute_partial_return())
        } else {
            None
        }
    }

    /// Drains the buffer, returning one transition per stored step, oldest
    /// first.
    ///
    /// Each transition starts at the then-oldest step and covers every step
    /// still stored after it (stopping early at a `done` step). The buffer is
    /// empty afterwards.
    pub fn flush(&mut self) -> Vec<NStepTransition> {
        let mut out = Vec::with_capacity(self.count);
        while self.count > 0 {
            out.push(self.compute_partial_return());
            let oldest = self.oldest_index();
            self.steps[oldest] = None;
            self.count -= 1;
        }
        out
    }

    /// Slot index of the oldest stored step.
    fn oldest_index(&self) -> usize {
        // `head` is one past the newest step; adding `n` before subtracting
        // keeps the arithmetic from underflowing.
        (self.head + self.n - self.count) % self.n
    }

    /// Return over a full window of `n` steps.
    fn compute_return(&self) -> NStepTransition {
        self.compute_n_step_return(self.n)
    }

    /// Return over however many steps are currently stored.
    fn compute_partial_return(&self) -> NStepTransition {
        self.compute_n_step_return(self.count)
    }

    /// Sums discounted rewards for up to `steps` steps starting at the oldest
    /// stored step, stopping at the first step flagged `done`.
    ///
    /// `actual_n` is the number of rewards that were really summed, so it
    /// always agrees with `gamma_n == gamma^actual_n` even when a terminal
    /// step cuts the window short.
    ///
    /// # Panics
    ///
    /// Panics if any slot in the requested range is empty.
    fn compute_n_step_return(&self, steps: usize) -> NStepTransition {
        let oldest = self.oldest_index();
        let first = self.steps[oldest]
            .as_ref()
            .expect("oldest step must be Some");

        let mut cumulative = 0.0_f32;
        let mut gamma_k = 1.0_f32;
        // Track the last contributing step by reference; its `next_obs` is
        // cloned once at the end instead of on every iteration.
        let mut last = first;
        let mut used = 0_usize;
        let mut terminated = false;

        for k in 0..steps {
            let idx = (oldest + k) % self.n;
            let step = self.steps[idx].as_ref().expect("step must be Some");
            cumulative += gamma_k * step.reward;
            gamma_k *= self.gamma;
            last = step;
            used = k + 1;
            if step.done {
                terminated = true;
                break;
            }
        }

        NStepTransition {
            obs: first.obs.clone(),
            action: first.action.clone(),
            n_step_return: cumulative,
            bootstrap_obs: last.next_obs.clone(),
            done: terminated,
            actual_n: used,
            gamma_n: gamma_k,
        }
    }

    /// Discards every stored step and rewinds the write position.
    pub fn reset(&mut self) {
        for s in self.steps.iter_mut() {
            *s = None;
        }
        self.head = 0;
        self.count = 0;
    }

    /// Returns the transition for the oldest stored step without modifying
    /// the buffer.
    ///
    /// # Errors
    ///
    /// Returns [`RlError::NStepIncomplete`] (carrying the stored and required
    /// step counts) when fewer than `n` steps are stored.
    pub fn try_get(&self) -> RlResult<NStepTransition> {
        if self.count < self.n {
            return Err(RlError::NStepIncomplete {
                have: self.count,
                need: self.n,
            });
        }
        Ok(self.compute_return())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor shared by every test in this module.
    fn make_buf(n: usize, gamma: f32) -> NStepBuffer {
        NStepBuffer::new(n, gamma)
    }

    /// Nothing is emitted while fewer than `n` non-terminal steps are stored.
    #[test]
    fn none_before_n_steps() {
        let mut buffer = make_buf(3, 0.99);
        let first = buffer.push([0.0], [0.0], 1.0, [1.0], false);
        let second = buffer.push([1.0], [0.0], 1.0, [2.0], false);
        assert!(first.is_none(), "should be None after 1 step");
        assert!(second.is_none(), "should be None after 2 steps");
    }

    /// The n-th push yields a transition carrying the full discounted sum.
    #[test]
    fn returns_transition_at_n() {
        let mut buffer = make_buf(3, 0.99);
        buffer.push([0.0], [0.0], 1.0, [1.0], false);
        buffer.push([1.0], [0.0], 1.0, [2.0], false);
        let emitted = buffer.push([2.0], [0.0], 1.0, [3.0], false);
        assert!(emitted.is_some(), "should return transition at n=3");

        let transition = emitted.unwrap();
        let expected: f32 = 1.0 + 0.99 + 0.99_f32 * 0.99;
        assert!(
            (transition.n_step_return - expected).abs() < 1e-4,
            "n_step_return={}",
            transition.n_step_return
        );
        assert_eq!(transition.actual_n, 3);
    }

    /// Rewards are weighted by increasing powers of gamma.
    #[test]
    fn discount_applied_correctly() {
        let mut buffer = make_buf(2, 0.5);
        buffer.push([0.0], [0.0], 2.0, [1.0], false);
        let transition = buffer.push([1.0], [0.0], 4.0, [2.0], false).unwrap();
        assert!(
            (transition.n_step_return - 4.0).abs() < 1e-5,
            "n_step_return={}",
            transition.n_step_return
        );
        assert!(
            (transition.gamma_n - 0.25).abs() < 1e-5,
            "gamma_n={}",
            transition.gamma_n
        );
    }

    /// A terminal step emits a transition before the window is full.
    #[test]
    fn terminal_truncates_return() {
        let mut buffer = make_buf(5, 0.99);
        buffer.push([0.0], [0.0], 1.0, [1.0], false);
        let emitted = buffer.push([1.0], [0.0], 2.0, [2.0], true);
        assert!(
            emitted.is_some(),
            "terminal step should emit transition early"
        );

        let transition = emitted.unwrap();
        assert!(transition.done, "done flag should be set");
        let expected: f32 = 1.0 + 0.99 * 2.0;
        assert!(
            (transition.n_step_return - expected).abs() < 1e-4,
            "n_step_return={}",
            transition.n_step_return
        );
    }

    /// Flushing a partially filled buffer produces at least one transition.
    #[test]
    fn flush_returns_remaining() {
        let mut buffer = make_buf(3, 0.99);
        buffer.push([0.0], [0.0], 1.0, [1.0], false);
        buffer.push([1.0], [0.0], 2.0, [2.0], false);
        let drained = buffer.flush();
        assert!(
            !drained.is_empty(),
            "flush should return partial transitions"
        );
    }

    /// After a flush no steps remain stored.
    #[test]
    fn flush_clears_buffer() {
        let mut buffer = make_buf(3, 0.99);
        buffer.push([0.0], [0.0], 1.0, [1.0], false);
        buffer.flush();
        assert_eq!(buffer.count(), 0);
    }

    /// `try_get` reports an error while the window is incomplete.
    #[test]
    fn try_get_before_n_error() {
        let mut buffer = make_buf(3, 0.99);
        buffer.push([0.0], [0.0], 1.0, [1.0], false);
        let outcome = buffer.try_get();
        assert!(outcome.is_err());
    }

    /// `try_get` succeeds once `n` steps are stored.
    #[test]
    fn try_get_after_n_ok() {
        let mut buffer = make_buf(2, 0.99);
        buffer.push([0.0], [0.0], 1.0, [1.0], false);
        buffer.push([1.0], [0.0], 2.0, [2.0], false);
        let outcome = buffer.try_get();
        assert!(outcome.is_ok());
    }

    /// `reset` empties the buffer so `try_get` fails again.
    #[test]
    fn reset_clears() {
        let mut buffer = make_buf(3, 0.99);
        buffer.push([0.0], [0.0], 1.0, [1.0], false);
        buffer.push([1.0], [0.0], 2.0, [2.0], false);
        buffer.reset();
        assert_eq!(buffer.count(), 0);
        assert!(buffer.try_get().is_err());
    }

    /// Observation, action and bootstrap observation pass through unchanged.
    #[test]
    fn obs_preserved_correctly() {
        let mut buffer = make_buf(1, 0.9);
        let transition = buffer
            .push([7.0, 8.0], [3.0], 5.0, [9.0, 10.0], false)
            .unwrap();
        assert_eq!(transition.obs, vec![7.0, 8.0]);
        assert_eq!(transition.action, vec![3.0]);
        assert_eq!(transition.bootstrap_obs, vec![9.0, 10.0]);
    }
}