use super::{WriteExperience, WriteExperienceError, WriteExperienceIncremental};
use crate::feedback::Reward;
use crate::simulation::PartialStep;
use crate::utils::iter::{Differences, SplitChunksByLength};
use crate::utils::sequence::Sequence;
use crate::utils::slice::SplitSlice;
use std::collections::{vec_deque, VecDeque};
use std::iter::{Copied, Map};
/// A bounded store of experience steps that evicts whole episodes, oldest first.
///
/// The buffer's size limit is the allocated capacity of its internal `VecDeque`.
/// When a step is written to a full buffer, the oldest complete episode is removed
/// first to make room.
#[derive(Debug, PartialEq)]
pub struct ReplayBuffer<O, A, F = Reward> {
    /// Stored steps, oldest first. The allocated capacity of this deque is the buffer's
    /// capacity, so it must be preserved by any copy of the buffer.
    steps: VecDeque<PartialStep<O, A, F>>,
    /// Global (exclusive) end index of each complete episode still stored, oldest first.
    episode_ends: VecDeque<u64>,
    /// Global index of `steps[0]`, i.e. the number of steps evicted so far.
    index_offset: u64,
    /// Global index one past the most recently stored step.
    total_step_count: u64,
}

// `VecDeque::clone` allocates only enough room for the current length. A derived
// `Clone` would therefore give the copy a smaller capacity than the original. An
// empty buffer would even be cloned with capacity 0 and reject every write. Rebuild
// the step deque with the original's capacity so the clone behaves identically.
impl<O: Clone, A: Clone, F: Clone> Clone for ReplayBuffer<O, A, F> {
    fn clone(&self) -> Self {
        let mut steps = VecDeque::with_capacity(self.steps.capacity());
        steps.extend(self.steps.iter().cloned());
        Self {
            steps,
            episode_ends: self.episode_ends.clone(),
            index_offset: self.index_offset,
            total_step_count: self.total_step_count,
        }
    }
}
impl<O, A, F> ReplayBuffer<O, A, F> {
    /// Creates an empty buffer holding up to [`capacity`](Self::capacity) steps.
    ///
    /// The effective capacity is whatever `VecDeque::with_capacity(capacity)`
    /// actually allocates, which is at least `capacity`.
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        let steps = VecDeque::with_capacity(capacity);
        Self {
            steps,
            episode_ends: VecDeque::new(),
            index_offset: 0,
            total_step_count: 0,
        }
    }

    /// Maximum number of steps held at once.
    ///
    /// Once the buffer is full, further writes evict the oldest complete episode.
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.steps.capacity()
    }

    /// Number of steps currently stored, including any trailing unfinished episode.
    #[must_use]
    pub fn num_steps(&self) -> usize {
        self.steps.len()
    }

    /// Number of complete (ended) episodes currently stored.
    #[must_use]
    pub fn num_episodes(&self) -> usize {
        self.episode_ends.len()
    }

    /// Iterates over the stored steps, oldest first.
    #[must_use]
    pub fn steps(&self) -> vec_deque::Iter<PartialStep<O, A, F>> {
        self.steps.iter()
    }

    /// Returns a borrowed view of the stored episodes, indexed by their recorded ends.
    #[must_use]
    pub fn episodes(&self) -> Episodes<O, A, F> {
        let (front, back) = self.steps.as_slices();
        Episodes {
            steps: (front, back).into(),
            episode_ends: &self.episode_ends,
            index_offset: self.index_offset,
        }
    }

    /// Total number of steps ever stored in this buffer, including evicted ones.
    ///
    /// A trailing step discarded while finalizing an unfinished episode is not counted.
    #[must_use]
    pub const fn total_step_count(&self) -> u64 {
        self.total_step_count
    }
}
impl<O, A, F> WriteExperienceIncremental<O, A, F> for ReplayBuffer<O, A, F> {
    /// Appends one step. If the buffer is full, it first evicts the oldest complete episode.
    ///
    /// # Errors
    /// Returns [`WriteExperienceError::Full`] when the buffer is full and holds no
    /// complete episode to evict, i.e. the current unfinished episode alone fills it.
    ///
    /// # Panics
    /// If the oldest recorded episode is empty, which violates an internal invariant.
    fn write_step(&mut self, step: PartialStep<O, A, F>) -> Result<(), WriteExperienceError> {
        let at_capacity = self.steps.len() == self.steps.capacity();
        if at_capacity {
            let oldest_end = match self.episode_ends.pop_front() {
                Some(end) => end,
                None => return Err(WriteExperienceError::Full { written_steps: 0 }),
            };
            assert!(
                oldest_end > self.index_offset,
                "episodes always have at least 1 step"
            );
            // The episode lives inside the in-memory deque, so its length fits in usize.
            #[allow(clippy::cast_possible_truncation)]
            let oldest_len = (oldest_end - self.index_offset) as usize;
            self.steps.drain(..oldest_len);
            self.index_offset = oldest_end;
        }

        // Read the flag before `step` is moved into the deque.
        let ends_episode = step.episode_done();
        self.steps.push_back(step);
        self.total_step_count += 1;
        if ends_episode {
            // The episode's end is the global index one past its final step.
            self.episode_ends.push_back(self.total_step_count);
        }
        Ok(())
    }

    /// Closes out an unfinished trailing episode, if there is one.
    ///
    /// When `super::finalize_last_episode` reports that it finalized an episode, the
    /// assertion below requires it to have removed exactly one step. The global step
    /// count is rolled back by one, and the resulting position is recorded as the
    /// episode's end.
    fn end_experience(&mut self) {
        if !super::finalize_last_episode(&mut self.steps) {
            return;
        }
        self.total_step_count -= 1;
        self.episode_ends.push_back(self.total_step_count);
        assert_eq!(
            self.total_step_count,
            self.index_offset + self.steps.len() as u64
        );
    }
}
// `WriteExperience` is implemented purely through the trait's default methods.
// NOTE(review): presumably those defaults drive `write_step` / `end_experience` from
// `WriteExperienceIncremental` — confirm against the trait definition in the parent module.
impl<O, A, F> WriteExperience<O, A, F> for ReplayBuffer<O, A, F> {}
/// A borrowed view of the episodes stored in a [`ReplayBuffer`].
///
/// Created by `ReplayBuffer::episodes`. Implements [`Sequence`] for random access by
/// episode index and [`IntoIterator`] for in-order traversal.
///
/// NOTE(review): the derived `Copy`/`Clone` only apply when `O`, `A` and `F` are
/// themselves `Copy`/`Clone`, even though the struct holds only borrows.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Episodes<'a, O, A, F> {
    /// All stored steps, as the two slices returned by `VecDeque::as_slices`.
    steps: SplitSlice<'a, PartialStep<O, A, F>>,
    /// Global (exclusive) end index of each complete episode, oldest first.
    episode_ends: &'a VecDeque<u64>,
    /// Global index of the first stored step; subtracted from `episode_ends` to get
    /// positions within `steps`.
    index_offset: u64,
}
impl<'a, O, A, F> Sequence for Episodes<'a, O, A, F> {
    type Item = SplitSlice<'a, PartialStep<O, A, F>>;

    /// Number of complete episodes in the view.
    #[inline]
    fn len(&self) -> usize {
        self.episode_ends.len()
    }

    /// Whether the view contains no complete episodes.
    #[inline]
    fn is_empty(&self) -> bool {
        self.episode_ends.is_empty()
    }

    /// Returns the steps of episode `idx` (0 = oldest stored), or `None` if out of range.
    ///
    /// # Panics
    /// If the recorded episode boundaries are inconsistent (an end precedes its start).
    // Positions index into an in-memory buffer, so they fit in usize.
    #[allow(clippy::cast_possible_truncation)]
    #[inline]
    fn get(&self, idx: usize) -> Option<Self::Item> {
        let offset = self.index_offset;
        let to_local = |global_end: u64| (global_end - offset) as usize;

        let end = to_local(*self.episode_ends.get(idx)?);
        let start = match idx.checked_sub(1) {
            None => 0,
            Some(prev) => to_local(self.episode_ends[prev]),
        };
        assert!(end >= start);

        let (_, from_start) = self.steps.split_at(start);
        let (episode, _) = from_start.split_at(end - start);
        Some(episode)
    }
}
impl<'a, O, A, F> IntoIterator for Episodes<'a, O, A, F> {
    type Item = SplitSlice<'a, PartialStep<O, A, F>>;
    type IntoIter = EpisodesIter<'a, O, A, F>;

    /// Yields each stored episode's steps in order, oldest first.
    ///
    /// Episode lengths are the differences between consecutive recorded end indices,
    /// starting from the view's index offset. The steps are split into chunks of
    /// those lengths.
    fn into_iter(self) -> Self::IntoIter {
        // Coerce to a plain fn pointer so the iterator type stays nameable.
        let to_len: fn(u64) -> usize = u64_as_usize;
        let episode_lengths =
            Differences::new(self.episode_ends.iter().copied(), self.index_offset).map(to_len);
        SplitChunksByLength::new(self.steps, episode_lengths)
    }
}
/// Iterator over the episodes of an [`Episodes`] view, returned by its `into_iter`.
///
/// Splits the stored steps into consecutive chunks. The chunk lengths are the
/// differences between successive episode end indices, converted to `usize` through
/// a plain `fn` pointer so the type can be named here.
pub type EpisodesIter<'a, O, A, F> = SplitChunksByLength<
    SplitSlice<'a, PartialStep<O, A, F>>,
    Map<Differences<Copied<vec_deque::Iter<'a, u64>>, u64>, fn(u64) -> usize>,
>;
/// Converts a `u64` step count or index to `usize`.
///
/// On targets where `usize` is narrower than 64 bits, large values are truncated.
/// Callers only pass lengths of in-memory buffers, which always fit.
#[allow(clippy::cast_possible_truncation)]
#[inline]
const fn u64_as_usize(value: u64) -> usize {
    value as usize
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::envs::Successor::{self, Continue, Interrupt, Terminate};
    use std::iter;

    /// Builds a test step with the given observation and successor. Action and
    /// feedback are fixed placeholder values.
    const fn step(observation: usize, next: Successor<usize, ()>) -> PartialStep<usize, bool> {
        PartialStep {
            observation,
            action: false,
            feedback: Reward(0.0),
            next,
        }
    }

    /// Takes a capacity-7 buffer through appending, finalizing and evicting episodes.
    #[test]
    fn comprehensive() {
        let mut buffer = ReplayBuffer::with_capacity(7);
        // The eviction points asserted below assume exactly 7 slots were allocated.
        assert_eq!(
            buffer.capacity(),
            7,
            "Implementation detail; rework test if this fails"
        );
        // Episode 1: three steps ending in termination. It fits without eviction.
        let ep1 = [
            step(0, Continue(())),
            step(1, Continue(())),
            step(2, Terminate),
        ];
        buffer.write_experience(ep1).unwrap();
        assert_eq!(buffer.num_steps(), 3);
        assert_eq!(buffer.num_episodes(), 1);
        assert!(buffer.steps().eq(&ep1));
        assert!(buffer.episodes().into_iter().eq([&ep1 as &[_]].into_iter()));
        // Episode 2 never terminates. Ending the experience finalizes it: the last
        // step is dropped and the one before becomes an interruption at observation 5.
        let ep2_raw = [
            step(3, Continue(())),
            step(4, Continue(())),
            step(5, Continue(())),
        ];
        let ep2_finalized = [step(3, Continue(())), step(4, Interrupt(5))];
        buffer.write_experience(ep2_raw).unwrap();
        assert_eq!(buffer.num_steps(), 5);
        assert_eq!(buffer.num_episodes(), 2);
        assert!(buffer.steps().eq(ep1.iter().chain(&ep2_finalized)));
        assert!(buffer
            .episodes()
            .into_iter()
            .eq([&ep1 as &[_], &ep2_finalized as &[_]].into_iter()));
        // Episode 3 overflows the capacity on its last step, so episode 1 is evicted.
        let ep3 = [
            step(6, Continue(())),
            step(7, Continue(())),
            step(8, Terminate),
        ];
        buffer.write_experience(ep3).unwrap();
        assert_eq!(buffer.num_steps(), 5);
        assert_eq!(buffer.num_episodes(), 2);
        assert!(buffer.steps().eq(ep2_finalized.iter().chain(&ep3)));
        assert!(buffer
            .episodes()
            .into_iter()
            .eq([&ep2_finalized as &[_], &ep3 as &[_]].into_iter()));
        // Two single-step episodes fill the buffer to exactly 7 steps with no eviction.
        let ep45 = [step(9, Terminate), step(10, Terminate)];
        buffer.write_experience(ep45).unwrap();
        assert_eq!(buffer.num_steps(), 7);
        assert_eq!(buffer.num_episodes(), 4);
        assert!(buffer
            .steps()
            .eq(ep2_finalized.iter().chain(&ep3).chain(&ep45)));
        assert!(buffer.episodes().into_iter().eq([
            &ep2_finalized as &[_],
            &ep3 as &[_],
            &ep45[..1] as &[_],
            &ep45[1..] as &[_]
        ]
        .into_iter()));
    }

    /// Random access into episodes after the first episode (`data[0..3]`) was evicted.
    #[test]
    fn get_episode() {
        let mut buffer = ReplayBuffer::with_capacity(7);
        assert_eq!(
            buffer.capacity(),
            7,
            "Implementation detail; rework test if this fails"
        );
        // Episodes: [0..3], [3..5] (ends with an interruption), [5..8], [8..9], [9..10].
        let data = [
            step(0, Continue(())),
            step(1, Continue(())),
            step(2, Terminate),
            step(3, Continue(())),
            step(4, Interrupt(5)),
            step(6, Continue(())),
            step(7, Continue(())),
            step(8, Terminate),
            step(9, Terminate),
            step(10, Terminate),
        ];
        buffer.write_experience(data).unwrap();
        // After eviction, index 0 is data[3..5], so index 1 is data[5..8].
        let episodes = buffer.episodes();
        assert_eq!(episodes.get(1).unwrap(), &data[5..8]);
        assert_eq!(episodes.get(3).unwrap(), &data[9..10]);
        assert!(episodes.get(4).is_none());
    }

    /// A single unfinished episode longer than the capacity cannot be stored.
    #[test]
    fn episode_too_large() {
        let mut buffer = ReplayBuffer::with_capacity(7);
        let result = buffer.write_experience(iter::repeat(step(0, Continue(()))).take(100));
        assert!(matches!(
            result,
            Err(WriteExperienceError::Full { written_steps: _ })
        ));
    }
}