use super::{PartialStep, Simulation};
use std::iter::FusedIterator;
/// Iterator adaptor that yields a bounded number of steps, ending on an episode boundary when
/// one falls inside the trailing slack window.
///
/// The adaptor keeps a countdown `n` of steps it may still yield. Once that countdown has
/// dropped to `slack_steps` or below, the first step whose `episode_done()` is true ends the
/// iteration early.
#[derive(Debug, Default, Clone)]
pub struct TakeAlignedSteps<I> {
    /// Wrapped source of steps.
    steps: I,
    /// Maximum number of steps that may still be yielded; `0` means exhausted.
    n: usize,
    /// Width of the trailing window in which an episode end stops the iteration.
    slack_steps: usize,
}

impl<I> TakeAlignedSteps<I> {
    /// Wraps `steps` so that at most `min_steps + slack_steps` steps are yielded.
    ///
    /// * `min_steps` - number of steps yielded before an episode end may stop the iteration.
    ///   A value of `0` produces an adaptor that yields nothing, regardless of `slack_steps`.
    /// * `slack_steps` - number of extra steps allowed beyond `min_steps` while waiting for
    ///   an episode to end.
    ///
    /// The total budget saturates at `usize::MAX` instead of overflowing, so a very large
    /// `min_steps` behaves as "effectively unbounded" rather than panicking or wrapping.
    pub const fn new(steps: I, min_steps: usize, slack_steps: usize) -> Self {
        let n = if min_steps == 0 {
            0
        } else {
            min_steps.saturating_add(slack_steps)
        };
        Self {
            steps,
            n,
            slack_steps,
        }
    }
}
/// A `TakeAlignedSteps` wrapping a simulation is itself a simulation: every associated type
/// and accessor is forwarded unchanged to the wrapped value.
impl<I> Simulation for TakeAlignedSteps<I>
where
    I: Simulation,
{
    type Observation = I::Observation;
    type Action = I::Action;
    type Feedback = I::Feedback;
    type Environment = I::Environment;
    type Actor = I::Actor;
    type Logger = I::Logger;

    /// Shared access to the wrapped simulation's environment.
    #[inline]
    fn env(&self) -> &I::Environment {
        I::env(&self.steps)
    }

    /// Exclusive access to the wrapped simulation's environment.
    #[inline]
    fn env_mut(&mut self) -> &mut I::Environment {
        I::env_mut(&mut self.steps)
    }

    /// Shared access to the wrapped simulation's actor.
    #[inline]
    fn actor(&self) -> &I::Actor {
        I::actor(&self.steps)
    }

    /// Exclusive access to the wrapped simulation's actor.
    #[inline]
    fn actor_mut(&mut self) -> &mut I::Actor {
        I::actor_mut(&mut self.steps)
    }

    /// Shared access to the wrapped simulation's logger.
    #[inline]
    fn logger(&self) -> &I::Logger {
        I::logger(&self.steps)
    }

    /// Exclusive access to the wrapped simulation's logger.
    #[inline]
    fn logger_mut(&mut self) -> &mut I::Logger {
        I::logger_mut(&mut self.steps)
    }
}
/// Yields steps from the wrapped iterator until the step budget is used up or, once the
/// remaining budget is within the slack window, an episode ends.
impl<I, O, A, F> Iterator for TakeAlignedSteps<I>
where
    I: Iterator<Item = PartialStep<O, A, F>>,
{
    type Item = PartialStep<O, A, F>;

    /// Returns the next step, or `None` once the budget is exhausted or the inner iterator
    /// has no more steps.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.n == 0 {
            // Budget exhausted: do not touch the inner iterator any more.
            return None;
        }
        let step = self.steps.next()?;
        self.n -= 1;
        // Inside the slack window an episode boundary ends the iteration early.
        if step.episode_done() && self.n <= self.slack_steps {
            self.n = 0;
        }
        Some(step)
    }

    /// Bounds on the number of remaining steps.
    ///
    /// The lower bound is limited by the steps still owed before the slack window starts;
    /// the upper bound is limited by the full remaining budget.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (inner_min, inner_max) = self.steps.size_hint();
        let min = inner_min.min(self.n.saturating_sub(self.slack_steps));
        let max = inner_max.map_or(self.n, |m| m.min(self.n));
        (min, Some(max))
    }

    /// Folds every step this adaptor would yield.
    ///
    /// Drives the inner iterator's `try_fold` directly rather than going through
    /// [`next`](Iterator::next) for each item. The stopping rule is the same as in `next`, and
    /// no inner item is pulled after the last yielded step.
    #[inline]
    fn fold<B, G>(mut self, init: B, mut g: G) -> B
    where
        G: FnMut(B, Self::Item) -> B,
    {
        if self.n == 0 {
            return init;
        }
        let mut n = self.n;
        let slack_steps = self.slack_steps;
        // `Err` is used purely as an early-exit signal; both variants carry the accumulator.
        let outcome = self.steps.try_fold(init, |acc, step| {
            n -= 1;
            let stop = n == 0 || (step.episode_done() && n <= slack_steps);
            let acc = g(acc, step);
            if stop {
                Err(acc)
            } else {
                Ok(acc)
            }
        });
        match outcome {
            Ok(acc) | Err(acc) => acc,
        }
    }
}
/// The adaptor is fused whenever the wrapped iterator is.
///
/// Once the budget reaches zero `next` keeps returning `None`, and before that a `None` from
/// a fused inner iterator is permanent. The impl covers every feedback type `F`, matching the
/// `Iterator` impl, instead of only the default feedback type of `PartialStep<O, A>`.
impl<I, O, A, F> FusedIterator for TakeAlignedSteps<I> where
    I: FusedIterator<Item = PartialStep<O, A, F>>
{
}
#[cfg(test)]
// rstest hands fixtures to the test functions by value.
#[allow(clippy::needless_pass_by_value)]
mod tests {
    use super::super::StepsIter;
    use super::*;
    use crate::envs::Successor::{self, Continue, Interrupt, Terminate};
    use crate::feedback::Reward;
    use rstest::{fixture, rstest};

    /// Builds a step with a unit action and zero reward.
    const fn step<O>(observation: O, next: Successor<O, ()>) -> PartialStep<O, ()> {
        PartialStep {
            observation,
            action: (),
            feedback: Reward(0.0),
            next,
        }
    }

    type Steps = Vec<PartialStep<u8, ()>>;

    /// Three consecutive episodes of lengths 2, 3 and 3; the last one ends with an interrupt.
    #[fixture]
    fn steps() -> Steps {
        vec![
            // Episode 1
            step(0, Continue(())),
            step(1, Terminate),
            // Episode 2
            step(10, Continue(())),
            step(11, Continue(())),
            step(12, Terminate),
            // Episode 3
            step(20, Continue(())),
            step(21, Continue(())),
            step(23, Interrupt(23)),
        ]
    }

    /// `min_steps == 0` yields nothing even when slack is available.
    #[rstest]
    fn take_no_steps(steps: Steps) {
        let taken: Vec<_> = steps.into_iter().take_aligned_steps(0, 2).collect();
        assert_eq!(taken, []);
    }

    /// A budget larger than the input yields every step.
    #[rstest]
    fn take_all_steps(steps: Steps) {
        let taken: Vec<_> = steps.iter().copied().take_aligned_steps(100, 2).collect();
        assert_eq!(taken, steps);
    }

    /// The minimum count lands exactly on an episode end; no slack configured.
    #[rstest]
    fn take_aligned_no_slack(steps: Steps) {
        let taken: Vec<_> = steps.iter().copied().take_aligned_steps(5, 0).collect();
        assert_eq!(taken, steps[..5]);
    }

    /// The minimum count lands exactly on an episode end; the slack goes unused.
    #[rstest]
    fn take_aligned_slack(steps: Steps) {
        let taken: Vec<_> = steps.iter().copied().take_aligned_steps(5, 2).collect();
        assert_eq!(taken, steps[..5]);
    }

    /// The minimum count falls mid-episode and, with no slack, iteration stops there.
    #[rstest]
    fn take_unaligned_no_slack(steps: Steps) {
        let taken: Vec<_> = steps.iter().copied().take_aligned_steps(3, 0).collect();
        assert_eq!(taken, steps[..3]);
    }

    /// The minimum count falls mid-episode; the slack is used to reach the episode end.
    #[rstest]
    fn take_unaligned_slack(steps: Steps) {
        let taken: Vec<_> = steps.iter().copied().take_aligned_steps(3, 2).collect();
        assert_eq!(taken, steps[..5]);
    }
}