1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
use super::{PartialStep, Simulation};
use std::iter::FusedIterator;
/// Iterator adapter that yields simulation steps until a fixed number of episodes have ended.
///
/// Steps from the wrapped iterator are passed through unchanged. A step whose
/// `episode_done()` is true ends one episode; once `n` such steps have been yielded the
/// adapter stops. It also forwards the [`Simulation`] accessors of the wrapped iterator.
#[derive(Debug, Default, Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct TakeEpisodes<I> {
    /// The wrapped step iterator.
    steps: I,
    /// Number of episodes still allowed to finish before iteration stops.
    n: usize,
}
impl<I> TakeEpisodes<I> {
#[inline]
pub const fn new(steps: I, n: usize) -> Self {
Self { steps, n }
}
}
/// Exposes the wrapped simulation's environment, actor, and logger unchanged.
///
/// Every accessor and associated type is forwarded to the inner `Simulation`; limiting the
/// number of episodes does not alter any of them.
impl<I: Simulation> Simulation for TakeEpisodes<I> {
    type Observation = <I as Simulation>::Observation;
    type Action = <I as Simulation>::Action;
    type Feedback = <I as Simulation>::Feedback;
    type Environment = <I as Simulation>::Environment;
    type Actor = <I as Simulation>::Actor;
    type Logger = <I as Simulation>::Logger;

    #[inline]
    fn env(&self) -> &Self::Environment {
        I::env(&self.steps)
    }

    #[inline]
    fn env_mut(&mut self) -> &mut Self::Environment {
        I::env_mut(&mut self.steps)
    }

    #[inline]
    fn actor(&self) -> &Self::Actor {
        I::actor(&self.steps)
    }

    #[inline]
    fn actor_mut(&mut self) -> &mut Self::Actor {
        I::actor_mut(&mut self.steps)
    }

    #[inline]
    fn logger(&self) -> &Self::Logger {
        I::logger(&self.steps)
    }

    #[inline]
    fn logger_mut(&mut self) -> &mut Self::Logger {
        I::logger_mut(&mut self.steps)
    }
}
/// Yields steps from the inner iterator until `n` episodes have ended.
///
/// A step for which `episode_done()` is true is still yielded; it then counts as the end of
/// one episode. Once the remaining-episode count reaches zero, no further steps are pulled
/// from the inner iterator (which matters because pulling a step may advance a simulation).
impl<I, O, A, F> Iterator for TakeEpisodes<I>
where
    I: Iterator<Item = PartialStep<O, A, F>>,
{
    type Item = PartialStep<O, A, F>;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.n == 0 {
            return None;
        }
        let step = self.steps.next()?;
        if step.episode_done() {
            self.n -= 1;
        }
        Some(step)
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.n == 0 {
            // Episode budget exhausted: `next` will never yield again.
            return (0, Some(0));
        }
        let (lower, upper) = self.steps.size_hint();
        // Each yielded step ends at most one episode, so at least `min(lower, n)` steps
        // remain; the inner upper bound still applies.
        (lower.min(self.n), upper)
    }

    /// Folds over the remaining steps by delegating to the inner iterator's `try_fold`,
    /// stopping immediately after the step that ends the last allowed episode so that no
    /// extra step is pulled from the inner iterator.
    #[inline]
    fn fold<B, G>(mut self, init: B, mut g: G) -> B
    where
        G: FnMut(B, Self::Item) -> B,
    {
        let mut remaining = self.n;
        if remaining == 0 {
            return init;
        }
        // `Err` is used purely as an early-exit signal carrying the accumulator.
        let outcome = self.steps.try_fold(init, |acc, step| {
            let ends_episode = step.episode_done();
            let acc = g(acc, step);
            if ends_episode {
                remaining -= 1;
                if remaining == 0 {
                    return Err(acc);
                }
            }
            Ok(acc)
        });
        match outcome {
            Ok(acc) | Err(acc) => acc,
        }
    }
}
/// `TakeEpisodes` is fused whenever the inner iterator is. Once the episode budget is used
/// up, `next` never consults the inner iterator again. Before that, it returns `None` only
/// when the (fused) inner iterator does.
///
/// Generic over the feedback type `F` to match the `Iterator` impl, so fusion holds for every
/// `PartialStep` item type rather than only the default feedback type.
impl<I, O, A, F> FusedIterator for TakeEpisodes<I>
where
    I: FusedIterator<Item = PartialStep<O, A, F>>,
{
}
#[cfg(test)]
mod tests {
    use crate::agents::RandomAgent;
    use crate::envs::{Chain, EnvStructure, Environment, VisibleStepLimit, Wrap};
    use crate::simulation::{SimSeed, StepsIter, StepsSummary};

    /// Runs an environment whose episodes all have the same fixed length and checks that
    /// `take_episodes` yields exactly the requested number of complete episodes.
    // The counts are small literal constants, so the `as usize` casts cannot truncate.
    #[allow(clippy::cast_possible_truncation)]
    #[test]
    fn episode_count() {
        let steps_per_episode = 10;
        let num_episodes = 30;

        let env = Chain::default().wrap(VisibleStepLimit::new(steps_per_episode));
        let agent = RandomAgent::new(env.action_space());

        // Generous cap on the raw step stream so the test terminates even if
        // `take_episodes` failed to stop by itself.
        let step_cap = (5 * steps_per_episode * num_episodes) as usize;
        let episode_limit = num_episodes as usize;

        let summary: StepsSummary<_> = env
            .run(agent, SimSeed::Root(53), ())
            .take(step_cap)
            .take_episodes(episode_limit)
            .collect();

        let expected_steps = steps_per_episode * num_episodes;
        assert_eq!(summary.num_episodes(), num_episodes);
        assert_eq!(summary.num_steps(), expected_steps);
    }
}