1use super::{HistoryDataBound, WriteExperience, WriteExperienceError, WriteExperienceIncremental};
2use crate::feedback::Reward;
3use crate::simulation::PartialStep;
4use crate::utils::iter::{Differences, SplitChunksByLength};
5use std::iter::Copied;
6use std::{slice, vec};
7
8#[derive(Debug, Clone, PartialEq)]
15pub struct VecBuffer<O, A, F = Reward> {
16 steps: Vec<PartialStep<O, A, F>>,
18 episode_ends: Vec<usize>,
20}
21
22impl<O, A, F> VecBuffer<O, A, F> {
23 #[must_use]
25 pub const fn new() -> Self {
26 Self {
27 steps: Vec::new(),
28 episode_ends: Vec::new(),
29 }
30 }
31
32 #[must_use]
34 pub fn with_capacity_for(bound: HistoryDataBound) -> Self {
35 Self {
36 steps: Vec::with_capacity(bound.min_steps.saturating_add(bound.slack_steps)),
37 episode_ends: Vec::new(),
38 }
39 }
40
41 pub fn clear(&mut self) {
43 self.steps.clear();
44 self.episode_ends.clear();
45 }
46
47 #[must_use]
49 pub fn num_steps(&self) -> usize {
50 self.steps.len()
51 }
52
53 #[must_use]
55 pub fn num_episodes(&self) -> usize {
56 self.episode_ends.len()
57 }
58
59 pub fn steps(&self) -> slice::Iter<PartialStep<O, A, F>> {
61 self.steps.iter()
62 }
63
64 pub fn drain_steps(&mut self) -> vec::Drain<PartialStep<O, A, F>> {
66 self.steps.drain(..)
67 }
68
69 #[must_use]
71 pub fn episodes(&self) -> EpisodesIter<O, A, F> {
72 SplitChunksByLength::new(
73 &self.steps,
74 Differences::new(self.episode_ends.iter().copied(), 0),
75 )
76 }
77}
78
79impl<O, A, F> From<Vec<PartialStep<O, A, F>>> for VecBuffer<O, A, F> {
80 fn from(steps: Vec<PartialStep<O, A, F>>) -> Self {
81 let episode_ends = steps
82 .iter()
83 .enumerate()
84 .filter_map(|(i, step)| {
85 if step.episode_done() {
86 Some(i + 1)
87 } else {
88 None
89 }
90 })
91 .collect();
92 let mut buffer = Self {
93 steps,
94 episode_ends,
95 };
96 buffer.end_experience();
97 buffer
98 }
99}
100
101impl<O, A, F> FromIterator<PartialStep<O, A, F>> for VecBuffer<O, A, F> {
102 fn from_iter<I>(steps: I) -> Self
103 where
104 I: IntoIterator<Item = PartialStep<O, A, F>>,
105 {
106 let mut buffer = Self::new();
107 buffer.write_experience(steps).unwrap(); buffer
109 }
110}
111
112impl<O, A, F> WriteExperience<O, A, F> for VecBuffer<O, A, F> {
113 fn write_experience<I>(&mut self, steps: I) -> Result<(), WriteExperienceError>
114 where
115 I: IntoIterator<Item = PartialStep<O, A, F>>,
116 {
117 let offset = self.steps.len();
118 self.steps.extend(steps);
119 for (i, step) in self.steps[offset..].iter().enumerate() {
120 if step.episode_done() {
121 self.episode_ends.push(offset + i + 1)
122 }
123 }
124 self.end_experience();
125 Ok(())
126 }
127}
128
129impl<O, A, F> WriteExperienceIncremental<O, A, F> for VecBuffer<O, A, F> {
130 fn write_step(&mut self, step: PartialStep<O, A, F>) -> Result<(), WriteExperienceError> {
131 let episode_done = step.episode_done();
132 self.steps.push(step);
133 if episode_done {
134 self.episode_ends.push(self.steps.len());
135 }
136 Ok(())
137 }
138
139 fn end_experience(&mut self) {
140 if super::finalize_last_episode(&mut self.steps) {
141 self.episode_ends.push(self.steps.len())
142 }
143 }
144}
145
146pub type EpisodesIter<'a, O, A, F> = SplitChunksByLength<
147 &'a [PartialStep<O, A, F>],
148 Differences<Copied<slice::Iter<'a, usize>>, usize>,
149>;
150
#[cfg(test)]
// Every rstest case receives the `buffer` fixture by value, which this lint would flag.
#[allow(clippy::needless_pass_by_value)]
mod tests {
    use super::*;
    use crate::envs::Successor::{self, Continue, Interrupt, Terminate};
    use rstest::{fixture, rstest};

    /// Shorthand for a step with a fixed action and zero reward.
    const fn step(observation: usize, next: Successor<usize, ()>) -> PartialStep<usize, bool> {
        PartialStep {
            observation,
            action: false,
            feedback: Reward(0.0),
            next,
        }
    }

    /// The raw input shared by the fixture and the construction tests.
    fn raw_steps() -> [PartialStep<usize, bool>; 5] {
        [
            step(0, Terminate),
            step(1, Continue(())),
            step(2, Terminate),
            step(3, Continue(())),
            step(4, Continue(())),
        ]
    }

    /// Buffer collected from `raw_steps`.
    #[fixture]
    fn buffer() -> VecBuffer<usize, bool> {
        raw_steps().into_iter().collect()
    }

    #[rstest]
    fn from_vec(buffer: VecBuffer<usize, bool>) {
        let test: VecBuffer<_, _> = Vec::from(raw_steps()).into();
        assert_eq!(test, buffer);
    }

    #[rstest]
    fn write_experience(buffer: VecBuffer<usize, bool>) {
        let mut test = VecBuffer::new();
        test.write_experience(raw_steps()).unwrap();
        assert_eq!(test, buffer);
    }

    #[rstest]
    fn num_steps(buffer: VecBuffer<usize, bool>) {
        assert_eq!(buffer.num_steps(), 4);
    }

    #[rstest]
    fn num_episodes(buffer: VecBuffer<usize, bool>) {
        assert_eq!(buffer.num_episodes(), 3);
    }

    #[rstest]
    fn steps(buffer: VecBuffer<usize, bool>) {
        let expected = [
            step(0, Terminate),
            step(1, Continue(())),
            step(2, Terminate),
            step(3, Interrupt(4)),
        ];
        let mut steps_iter = buffer.steps();
        for want in &expected {
            assert_eq!(steps_iter.next(), Some(want));
        }
        assert_eq!(steps_iter.next(), None);
    }

    #[rstest]
    fn steps_is_fused(buffer: VecBuffer<usize, bool>) {
        let mut steps_iter = buffer.steps();
        // Exactly four items come out before the iterator is exhausted.
        assert_eq!(steps_iter.by_ref().take(4).count(), 4);
        assert!(steps_iter.next().is_none());
        assert!(steps_iter.next().is_none());
    }

    #[rstest]
    fn steps_len(buffer: VecBuffer<usize, bool>) {
        assert_eq!(buffer.steps().len(), buffer.num_steps());
    }

    #[rstest]
    fn episodes(buffer: VecBuffer<usize, bool>) {
        let expected: [&[PartialStep<usize, bool>]; 3] = [
            &[step(0, Terminate)],
            &[step(1, Continue(())), step(2, Terminate)],
            &[step(3, Interrupt(4))],
        ];
        let mut episodes_iter = buffer.episodes();
        for want in expected {
            let got = episodes_iter.next().unwrap();
            assert_eq!(
                got.iter().collect::<Vec<_>>(),
                want.iter().collect::<Vec<_>>()
            );
        }
        assert!(episodes_iter.next().is_none());
    }

    #[rstest]
    fn episodes_len(buffer: VecBuffer<usize, bool>) {
        assert_eq!(buffer.episodes().len(), buffer.num_episodes());
    }

    #[rstest]
    fn episode_len_sum(buffer: VecBuffer<usize, bool>) {
        let total: usize = buffer.episodes().map(|episode| episode.len()).sum();
        assert_eq!(total, buffer.num_steps());
    }
}