1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
use super::{PartialStep, Simulation};
use std::iter::FusedIterator;

/// Iterator adaptor that passes through steps from `steps` until `n` episodes have ended.
///
/// Every step of the wrapped iterator is yielded unchanged. Each yielded step for which
/// `episode_done()` is true uses up one episode from the budget; once the budget reaches zero
/// the adaptor stops yielding steps.
#[derive(Debug, Default, Clone)]
pub struct TakeEpisodes<I> {
    /// The wrapped source of steps.
    steps: I,
    /// Number of episodes that may still end before iteration stops.
    n: usize,
}

impl<I> TakeEpisodes<I> {
    /// Creates an adaptor over `steps` that is limited to the first `n` episodes.
    ///
    /// With `n == 0` the resulting iterator yields nothing.
    #[inline]
    pub const fn new(steps: I, n: usize) -> Self {
        TakeEpisodes { steps, n }
    }
}

/// A `TakeEpisodes` over a simulation is itself a simulation.
///
/// All associated types and accessors are forwarded unchanged to the wrapped `steps` value.
impl<I> Simulation for TakeEpisodes<I>
where
    I: Simulation,
{
    type Observation = I::Observation;
    type Action = I::Action;
    type Feedback = I::Feedback;
    type Environment = I::Environment;
    type Actor = I::Actor;
    type Logger = I::Logger;

    /// Shared reference to the wrapped simulation's environment.
    #[inline]
    fn env(&self) -> &Self::Environment {
        <I as Simulation>::env(&self.steps)
    }

    /// Mutable reference to the wrapped simulation's environment.
    #[inline]
    fn env_mut(&mut self) -> &mut Self::Environment {
        <I as Simulation>::env_mut(&mut self.steps)
    }

    /// Shared reference to the wrapped simulation's actor.
    #[inline]
    fn actor(&self) -> &Self::Actor {
        <I as Simulation>::actor(&self.steps)
    }

    /// Mutable reference to the wrapped simulation's actor.
    #[inline]
    fn actor_mut(&mut self) -> &mut Self::Actor {
        <I as Simulation>::actor_mut(&mut self.steps)
    }

    /// Shared reference to the wrapped simulation's logger.
    #[inline]
    fn logger(&self) -> &Self::Logger {
        <I as Simulation>::logger(&self.steps)
    }

    /// Mutable reference to the wrapped simulation's logger.
    #[inline]
    fn logger_mut(&mut self) -> &mut Self::Logger {
        <I as Simulation>::logger_mut(&mut self.steps)
    }
}

/// Yields the steps of the wrapped iterator until `n` episodes have ended.
impl<I, O, A, F> Iterator for TakeEpisodes<I>
where
    I: Iterator<Item = PartialStep<O, A, F>>,
{
    type Item = PartialStep<O, A, F>;

    /// Returns the next step, or `None` once the episode budget is used up or the wrapped
    /// iterator is exhausted.
    ///
    /// The wrapped iterator is not polled when no episodes remain.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.n == 0 {
            return None;
        }
        let step = self.steps.next()?;
        if step.episode_done() {
            // The step that ends an episode is still yielded; it only shrinks the budget.
            self.n -= 1;
        }
        Some(step)
    }

    /// Bounds on the number of remaining steps.
    ///
    /// Once no episodes remain the iterator is known to be empty. Otherwise the upper bound
    /// is that of the wrapped iterator, because an episode may be arbitrarily long.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.n == 0 {
            return (0, Some(0));
        }
        let (min, max) = self.steps.size_hint();
        // Each episode requires at least one step, so at least `min(min, n)` steps are
        // produced before either the wrapped iterator or the episode budget runs out.
        (min.min(self.n), max)
    }

    /// Folds every step of the first `n` episodes into an accumulator.
    ///
    /// Uses the wrapped iterator's `try_fold` so that its internal-iteration fast path is used
    /// and folding stops immediately after the step that ends the final episode, without
    /// pulling any further step from the wrapped iterator.
    #[inline]
    fn fold<B, G>(mut self, init: B, mut g: G) -> B
    where
        G: FnMut(B, Self::Item) -> B,
    {
        if self.n == 0 {
            return init;
        }
        let mut remaining = self.n;
        // `Err` is used purely as an early exit carrying the final accumulator.
        let outcome = self.steps.try_fold(init, |acc, step| {
            // Query the flag before `step` is moved into `g`.
            let episode_done = step.episode_done();
            let acc = g(acc, step);
            if episode_done {
                remaining -= 1;
                if remaining == 0 {
                    return Err(acc);
                }
            }
            Ok(acc)
        });
        match outcome {
            Ok(acc) | Err(acc) => acc,
        }
    }
}

/// `TakeEpisodes` is fused whenever the wrapped iterator is fused.
///
/// Once the episode budget reaches zero `next` returns `None` without polling the wrapped
/// iterator; before that, `None` is only returned when the wrapped (fused) iterator returns it.
/// The impl is generic over the feedback type `F` so that it covers every item type for which
/// `Iterator` is implemented.
impl<I, O, A, F> FusedIterator for TakeEpisodes<I>
where
    I: FusedIterator<Item = PartialStep<O, A, F>>,
{
}

#[cfg(test)]
mod tests {
    use crate::agents::RandomAgent;
    use crate::envs::{Chain, EnvStructure, Environment, VisibleStepLimit, Wrap};
    use crate::simulation::{SimSeed, StepsIter, StepsSummary};

    /// Checks that `take_episodes(n)` stops after `n` episodes and keeps all of their steps.
    // The `as usize` casts below only convert small test constants
    // (at most 5 * 10 * 30 = 1500), so truncation cannot occur in practice.
    #[allow(clippy::cast_possible_truncation)]
    #[test]
    fn episode_count() {
        let steps_per_episode = 10;
        let num_episodes = 30;

        // NOTE(review): the final assertion relies on every episode lasting exactly
        // `steps_per_episode` steps under `VisibleStepLimit` — confirm against that wrapper.
        let env = Chain::default().wrap(VisibleStepLimit::new(steps_per_episode));
        let agent = RandomAgent::new(env.action_space());
        let summary: StepsSummary<_> = env
            .run(agent, SimSeed::Root(53), ())
            // Additional step bound so that the test does not hang if take_episodes breaks
            .take((5 * steps_per_episode * num_episodes) as usize)
            .take_episodes(num_episodes as usize)
            .collect();
        // Exactly the requested number of episodes must have been collected ...
        assert_eq!(summary.num_episodes(), num_episodes);
        // ... and no step of those episodes may have been dropped or duplicated.
        assert_eq!(summary.num_steps(), steps_per_episode * num_episodes);
    }
}