relearn/agents/buffers/
vec.rs

1use super::{HistoryDataBound, WriteExperience, WriteExperienceError, WriteExperienceIncremental};
2use crate::feedback::Reward;
3use crate::simulation::PartialStep;
4use crate::utils::iter::{Differences, SplitChunksByLength};
5use std::iter::Copied;
6use std::{slice, vec};
7
8/// Simple vector history buffer. Stores steps in a vector.
9///
10/// The buffer records steps from a series of episodes one after another.
11/// The buffer is ready when either
12/// * the current episode is done and at least `soft_threshold` steps have been collected; or
13/// * at least `hard_threshold` steps have been collected.
14#[derive(Debug, Clone, PartialEq)]
15pub struct VecBuffer<O, A, F = Reward> {
16    /// Steps from all episodes with each episode stored contiguously
17    steps: Vec<PartialStep<O, A, F>>,
18    /// One past the end index of each episode within `steps`.
19    episode_ends: Vec<usize>,
20}
21
22impl<O, A, F> VecBuffer<O, A, F> {
23    /// Create a new empty [`VecBuffer`].
24    #[must_use]
25    pub const fn new() -> Self {
26        Self {
27            steps: Vec::new(),
28            episode_ends: Vec::new(),
29        }
30    }
31
32    /// Create a new buffer with capacity for the given amount of history data.
33    #[must_use]
34    pub fn with_capacity_for(bound: HistoryDataBound) -> Self {
35        Self {
36            steps: Vec::with_capacity(bound.min_steps.saturating_add(bound.slack_steps)),
37            episode_ends: Vec::new(),
38        }
39    }
40
41    /// Clear all stored data
42    pub fn clear(&mut self) {
43        self.steps.clear();
44        self.episode_ends.clear();
45    }
46
47    /// The number of steps stored in the buffer.
48    #[must_use]
49    pub fn num_steps(&self) -> usize {
50        self.steps.len()
51    }
52
53    /// The number of episodes stored in the buffer.
54    #[must_use]
55    pub fn num_episodes(&self) -> usize {
56        self.episode_ends.len()
57    }
58
59    /// Iterator over all steps stored in the buffer.
60    pub fn steps(&self) -> slice::Iter<PartialStep<O, A, F>> {
61        self.steps.iter()
62    }
63
64    /// Draining iterator over all steps stored in the buffer.
65    pub fn drain_steps(&mut self) -> vec::Drain<PartialStep<O, A, F>> {
66        self.steps.drain(..)
67    }
68
69    /// Iterator over all episode slices stored in the buffer.
70    #[must_use]
71    pub fn episodes(&self) -> EpisodesIter<O, A, F> {
72        SplitChunksByLength::new(
73            &self.steps,
74            Differences::new(self.episode_ends.iter().copied(), 0),
75        )
76    }
77}
78
79impl<O, A, F> From<Vec<PartialStep<O, A, F>>> for VecBuffer<O, A, F> {
80    fn from(steps: Vec<PartialStep<O, A, F>>) -> Self {
81        let episode_ends = steps
82            .iter()
83            .enumerate()
84            .filter_map(|(i, step)| {
85                if step.episode_done() {
86                    Some(i + 1)
87                } else {
88                    None
89                }
90            })
91            .collect();
92        let mut buffer = Self {
93            steps,
94            episode_ends,
95        };
96        buffer.end_experience();
97        buffer
98    }
99}
100
101impl<O, A, F> FromIterator<PartialStep<O, A, F>> for VecBuffer<O, A, F> {
102    fn from_iter<I>(steps: I) -> Self
103    where
104        I: IntoIterator<Item = PartialStep<O, A, F>>,
105    {
106        let mut buffer = Self::new();
107        buffer.write_experience(steps).unwrap(); // TODO: Maybe just ignore if full?
108        buffer
109    }
110}
111
112impl<O, A, F> WriteExperience<O, A, F> for VecBuffer<O, A, F> {
113    fn write_experience<I>(&mut self, steps: I) -> Result<(), WriteExperienceError>
114    where
115        I: IntoIterator<Item = PartialStep<O, A, F>>,
116    {
117        let offset = self.steps.len();
118        self.steps.extend(steps);
119        for (i, step) in self.steps[offset..].iter().enumerate() {
120            if step.episode_done() {
121                self.episode_ends.push(offset + i + 1)
122            }
123        }
124        self.end_experience();
125        Ok(())
126    }
127}
128
129impl<O, A, F> WriteExperienceIncremental<O, A, F> for VecBuffer<O, A, F> {
130    fn write_step(&mut self, step: PartialStep<O, A, F>) -> Result<(), WriteExperienceError> {
131        let episode_done = step.episode_done();
132        self.steps.push(step);
133        if episode_done {
134            self.episode_ends.push(self.steps.len());
135        }
136        Ok(())
137    }
138
139    fn end_experience(&mut self) {
140        if super::finalize_last_episode(&mut self.steps) {
141            self.episode_ends.push(self.steps.len())
142        }
143    }
144}
145
146pub type EpisodesIter<'a, O, A, F> = SplitChunksByLength<
147    &'a [PartialStep<O, A, F>],
148    Differences<Copied<slice::Iter<'a, usize>>, usize>,
149>;
150
// rstest injects fixtures by value; tests that only borrow their fixture would otherwise
// trip this pedantic lint.
#[allow(clippy::needless_pass_by_value)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::envs::Successor::{self, Continue, Interrupt, Terminate};
    use rstest::{fixture, rstest};

    /// Build a step with the given observation and successor, action `false`, and zero reward.
    const fn step(observation: usize, next: Successor<usize, ()>) -> PartialStep<usize, bool> {
        PartialStep {
            observation,
            action: false,
            feedback: Reward(0.0),
            next,
        }
    }

    /// A buffer built from (in order)
    /// * an episode of length 1,
    /// * an episode of length 2, and
    /// * 2 steps of an incomplete episode.
    ///
    /// Finalization drops the last incomplete step and marks its predecessor as interrupted,
    /// so the buffer holds 4 steps in 3 episodes.
    #[fixture]
    fn buffer() -> VecBuffer<usize, bool> {
        [
            step(0, Terminate),
            step(1, Continue(())),
            step(2, Terminate),
            step(3, Continue(())),
            step(4, Continue(())),
        ]
        .into_iter()
        .collect()
    }

    #[rstest]
    fn from_vec(buffer: VecBuffer<usize, bool>) {
        let test: VecBuffer<_, _> = vec![
            step(0, Terminate),
            step(1, Continue(())),
            step(2, Terminate),
            step(3, Continue(())),
            step(4, Continue(())),
        ]
        .into();
        assert_eq!(test, buffer);
    }

    #[rstest]
    fn write_experience(buffer: VecBuffer<usize, bool>) {
        let mut test = VecBuffer::new();
        test.write_experience([
            step(0, Terminate),
            step(1, Continue(())),
            step(2, Terminate),
            step(3, Continue(())),
            step(4, Continue(())),
        ])
        .unwrap();
        assert_eq!(test, buffer);
    }

    /// Writing the same steps one at a time and then ending the experience must produce
    /// the same buffer as a bulk write.
    #[rstest]
    fn write_step_incremental(buffer: VecBuffer<usize, bool>) {
        let mut test = VecBuffer::new();
        for s in [
            step(0, Terminate),
            step(1, Continue(())),
            step(2, Terminate),
            step(3, Continue(())),
            step(4, Continue(())),
        ] {
            test.write_step(s).unwrap();
        }
        test.end_experience();
        assert_eq!(test, buffer);
    }

    #[rstest]
    fn clear(mut buffer: VecBuffer<usize, bool>) {
        buffer.clear();
        assert_eq!(buffer.num_steps(), 0);
        assert_eq!(buffer.num_episodes(), 0);
        assert!(buffer.episodes().next().is_none());
    }

    #[rstest]
    fn num_steps(buffer: VecBuffer<usize, bool>) {
        // The last step is dropped in finalization
        assert_eq!(buffer.num_steps(), 4);
    }

    #[rstest]
    fn num_episodes(buffer: VecBuffer<usize, bool>) {
        assert_eq!(buffer.num_episodes(), 3);
    }

    #[rstest]
    fn steps(buffer: VecBuffer<usize, bool>) {
        let mut steps_iter = buffer.steps();
        assert_eq!(steps_iter.next(), Some(&step(0, Terminate)));
        assert_eq!(steps_iter.next(), Some(&step(1, Continue(()))));
        assert_eq!(steps_iter.next(), Some(&step(2, Terminate)));
        assert_eq!(steps_iter.next(), Some(&step(3, Interrupt(4))));
        assert_eq!(steps_iter.next(), None);
    }

    #[rstest]
    fn steps_is_fused(buffer: VecBuffer<usize, bool>) {
        let mut steps_iter = buffer.steps();
        for _ in 0..4 {
            assert!(steps_iter.next().is_some());
        }
        assert!(steps_iter.next().is_none());
        assert!(steps_iter.next().is_none());
    }

    #[rstest]
    fn steps_len(buffer: VecBuffer<usize, bool>) {
        assert_eq!(buffer.steps().len(), buffer.num_steps());
    }

    #[rstest]
    fn episodes(buffer: VecBuffer<usize, bool>) {
        let mut episodes_iter = buffer.episodes();
        assert_eq!(
            episodes_iter.next().unwrap().iter().collect::<Vec<_>>(),
            [&step(0, Terminate)]
        );
        assert_eq!(
            episodes_iter.next().unwrap().iter().collect::<Vec<_>>(),
            [&step(1, Continue(())), &step(2, Terminate)]
        );
        assert_eq!(
            episodes_iter.next().unwrap().iter().collect::<Vec<_>>(),
            [&step(3, Interrupt(4))]
        );
        assert!(episodes_iter.next().is_none());
    }

    #[rstest]
    fn episodes_len(buffer: VecBuffer<usize, bool>) {
        assert_eq!(buffer.episodes().len(), buffer.num_episodes());
    }

    #[rstest]
    fn episode_len_sum(buffer: VecBuffer<usize, bool>) {
        assert_eq!(
            buffer.episodes().map(<[_]>::len).sum::<usize>(),
            buffer.num_steps()
        );
    }
}