mod null;
mod replay;
mod vec;
pub use null::NullBuffer;
pub use replay::ReplayBuffer;
pub use vec::VecBuffer;
use crate::envs::Successor;
use crate::simulation::{PartialStep, Step, StepsIter, TakeAlignedSteps};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use thiserror::Error;
/// A bound on the number of history steps to collect.
///
/// As evaluated by `is_satisfied`: once `min_steps` steps are available the
/// bound is met at the next episode boundary, and it is met unconditionally
/// once `min_steps + slack_steps` steps are available.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct HistoryDataBound {
/// Minimum number of steps required before the bound can be satisfied.
pub min_steps: usize,
/// Number of extra steps beyond `min_steps` after which the bound is
/// satisfied even if the last step does not end an episode.
pub slack_steps: usize,
}
impl HistoryDataBound {
    /// Creates a bound with zero minimum steps and zero slack.
    #[inline]
    #[must_use]
    pub const fn empty() -> Self {
        Self {
            min_steps: 0,
            slack_steps: 0,
        }
    }

    /// Creates a bound from an explicit minimum step count and slack.
    #[inline]
    #[must_use]
    pub const fn new(min_steps: usize, slack_steps: usize) -> Self {
        Self {
            min_steps,
            slack_steps,
        }
    }

    /// Creates a bound whose slack is derived from `min_steps`.
    ///
    /// The slack is 1% of `min_steps` (integer division), clamped to the
    /// inclusive range `5..=1000`.
    #[inline]
    #[must_use]
    pub fn with_default_slack(min_steps: usize) -> Self {
        let slack_steps = (min_steps / 100).clamp(5, 1000);
        Self {
            min_steps,
            slack_steps,
        }
    }

    /// Returns the field-wise maximum of two bounds.
    #[inline]
    #[must_use]
    pub fn max(self, other: Self) -> Self {
        Self {
            min_steps: self.min_steps.max(other.min_steps),
            slack_steps: self.slack_steps.max(other.slack_steps),
        }
    }

    /// Divides the minimum step count by `n`, rounding up.
    ///
    /// The slack is left unchanged.
    ///
    /// # Panics
    /// Panics if `n` is zero (integer division by zero).
    #[inline]
    #[must_use]
    pub const fn divide(self, n: usize) -> Self {
        Self {
            min_steps: div_ceil(self.min_steps, n),
            slack_steps: self.slack_steps,
        }
    }

    /// Checks whether `num_steps` collected steps satisfy this bound.
    ///
    /// The bound is satisfied when either
    /// * at least `min_steps` steps were collected and `last` is `None` or
    ///   reports `episode_done()`, or
    /// * at least `min_steps + slack_steps` steps were collected.
    ///
    /// # Args
    /// * `num_steps` - Number of steps collected so far.
    /// * `last` - The most recently collected step, if any.
    #[inline]
    #[must_use]
    pub fn is_satisfied<O, A, F, U>(
        &self,
        num_steps: usize,
        last: Option<&Step<O, A, F, U>>,
    ) -> bool {
        // Saturate instead of overflowing: a wrapped sum would be small and
        // make the hard limit trigger almost immediately in release builds
        // (and panic in debug builds).
        let hard_limit = self.min_steps.saturating_add(self.slack_steps);
        (num_steps >= self.min_steps && last.map_or(true, Step::episode_done))
            || num_steps >= hard_limit
    }

    /// Wraps `steps` with `take_aligned_steps(min_steps, slack_steps)`.
    ///
    /// NOTE(review): the exact cut-off behaviour is defined by
    /// `TakeAlignedSteps`, which is not visible in this file.
    #[inline]
    pub fn take<I, O, A, F>(self, steps: I) -> TakeAlignedSteps<I::IntoIter>
    where
        I: IntoIterator<Item = PartialStep<O, A, F>>,
    {
        steps
            .into_iter()
            .take_aligned_steps(self.min_steps, self.slack_steps)
    }
}
/// Divides `numerator` by `denominator`, rounding the result up.
///
/// # Panics
/// Panics if `denominator` is zero.
const fn div_ceil(numerator: usize, denominator: usize) -> usize {
    let whole = numerator / denominator;
    if numerator % denominator == 0 {
        whole
    } else {
        // A non-zero remainder means the exact quotient lies above `whole`.
        whole + 1
    }
}
/// Bulk interface for writing a sequence of steps into a buffer.
///
/// Built on top of [`WriteExperienceIncremental`]; the provided method writes
/// each step in turn and then ends the experience.
pub trait WriteExperience<O, A, F>: WriteExperienceIncremental<O, A, F> {
    /// Writes every step of `steps` and then calls `end_experience`.
    ///
    /// # Errors
    /// Returns [`WriteExperienceError::Full`] if a `write_step` call reports
    /// that the buffer is full. `written_steps` is then the number of steps
    /// from `steps` that were written before the failing one.
    /// `end_experience` is not called in that case.
    ///
    /// # Panics
    /// Panics if `write_step` returns `Full` with a non-zero `written_steps`.
    fn write_experience<I>(&mut self, steps: I) -> Result<(), WriteExperienceError>
    where
        I: IntoIterator<Item = PartialStep<O, A, F>>,
        Self: Sized,
    {
        for (index, step) in steps.into_iter().enumerate() {
            match self.write_step(step) {
                Ok(()) => {}
                Err(WriteExperienceError::Full { written_steps }) => {
                    assert_eq!(
                        written_steps, 0,
                        "write_step `Full` has non-zero written steps"
                    );
                    // `index` steps were accepted before this one was rejected.
                    return Err(WriteExperienceError::Full {
                        written_steps: index,
                    });
                }
            }
        }
        self.end_experience();
        Ok(())
    }
}
// Implements `WriteExperience` for a pointer-like type wrapping `T`.
//
// The impl body is empty, so the wrapper uses the trait's provided
// `write_experience` method. `T` may be unsized (e.g. a trait object).
macro_rules! impl_wrapped_write_experience {
    ($pointer:ty) => {
        impl<T, O, A, F> WriteExperience<O, A, F> for $pointer
        where
            T: ?Sized + WriteExperience<O, A, F>,
        {
        }
    };
}
// Mutable references and boxes forward to the wrapped writer.
impl_wrapped_write_experience!(&'_ mut T);
impl_wrapped_write_experience!(Box<T>);
/// Step-at-a-time interface for writing experience into a buffer.
pub trait WriteExperienceIncremental<O, A, F> {
/// Writes a single step.
///
/// NOTE(review): `WriteExperience::write_experience` asserts that a `Full`
/// error returned from this method has `written_steps == 0`, so
/// implementations appear to be expected to uphold that.
fn write_step(&mut self, step: PartialStep<O, A, F>) -> Result<(), WriteExperienceError>;
/// Signals the end of the steps written so far.
///
/// `WriteExperience::write_experience` calls this once after all of its steps
/// have been written successfully.
fn end_experience(&mut self);
}
// Implements `WriteExperienceIncremental` for a pointer-like type wrapping
// `T` by forwarding both methods to the wrapped value. `T` may be unsized.
macro_rules! impl_wrapped_write_experience_incremental {
    ($pointer:ty) => {
        impl<T, O, A, F> WriteExperienceIncremental<O, A, F> for $pointer
        where
            T: ?Sized + WriteExperienceIncremental<O, A, F>,
        {
            fn write_step(
                &mut self,
                step: PartialStep<O, A, F>,
            ) -> Result<(), WriteExperienceError> {
                // Reborrow through the wrapper to reach the inner `T`.
                let inner: &mut T = &mut **self;
                <T as WriteExperienceIncremental<O, A, F>>::write_step(inner, step)
            }

            fn end_experience(&mut self) {
                let inner: &mut T = &mut **self;
                <T as WriteExperienceIncremental<O, A, F>>::end_experience(inner)
            }
        }
    };
}
// Mutable references and boxes forward to the wrapped writer.
impl_wrapped_write_experience_incremental!(&'_ mut T);
impl_wrapped_write_experience_incremental!(Box<T>);
/// Error returned when writing experience into a buffer fails.
#[derive(Error, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum WriteExperienceError {
/// The buffer is full; `written_steps` is the number of steps written
/// before it filled up.
#[error("buffer full after writing {written_steps} steps")]
Full { written_steps: usize },
}
/// Closes off a trailing unfinished episode on a stack of partial steps.
///
/// If the stack is empty or its top step already reports `episode_done()`,
/// nothing changes and `false` is returned. Otherwise the top step is removed;
/// if the step beneath it exists and is not done, that step's successor is set
/// to `Successor::Interrupt` carrying the removed step's observation and `true`
/// is returned. In every other case the removed step is discarded and `false`
/// is returned.
fn finalize_last_episode<S, O, A, F>(steps: &mut S) -> bool
where
    S: Stack<PartialStep<O, A, F>>,
{
    // Nothing to finalize unless the top step belongs to an unfinished episode.
    match steps.top() {
        Some(last) if !last.episode_done() => {}
        _ => return false,
    }

    // The top was just observed to be `Some`, so this pop cannot fail.
    let final_observation = steps.pop().unwrap().observation;

    match steps.top_mut() {
        Some(previous) if !previous.episode_done() => {
            previous.next = Successor::Interrupt(final_observation);
            true
        }
        _ => false,
    }
}
/// Minimal stack interface over a container of `T`.
///
/// Implemented in this file for `Vec<T>` and `VecDeque<T>`, where every
/// operation acts on the back of the container.
trait Stack<T> {
/// Adds `value` on top of the stack.
fn push(&mut self, value: T);
/// Removes and returns the top value, or `None` if there is none.
fn pop(&mut self) -> Option<T>;
/// Returns a shared reference to the top value, if any.
fn top(&self) -> Option<&T>;
/// Returns a mutable reference to the top value, if any.
fn top_mut(&mut self) -> Option<&mut T>;
}
/// Stack operations for `Vec`, acting on the end of the vector.
impl<T> Stack<T> for Vec<T> {
    fn push(&mut self, value: T) {
        // Explicit path selects the inherent `Vec::push`, not this trait method.
        Vec::push(self, value)
    }

    fn pop(&mut self) -> Option<T> {
        Vec::pop(self)
    }

    fn top(&self) -> Option<&T> {
        self.as_slice().last()
    }

    fn top_mut(&mut self) -> Option<&mut T> {
        self.as_mut_slice().last_mut()
    }
}
/// Stack operations for `VecDeque`, acting on the back of the deque.
impl<T> Stack<T> for VecDeque<T> {
    fn push(&mut self, value: T) {
        VecDeque::push_back(self, value)
    }

    fn pop(&mut self) -> Option<T> {
        VecDeque::pop_back(self)
    }

    fn top(&self) -> Option<&T> {
        VecDeque::back(self)
    }

    fn top_mut(&mut self) -> Option<&mut T> {
        VecDeque::back_mut(self)
    }
}