1use crate::error::{RlError, RlResult};
31use crate::handle::RlHandle;
32
/// One environment step: observation, action taken, reward received, the following
/// observation, and whether the step ended the episode.
///
/// Derives `PartialEq` so transitions can be compared directly (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq)]
pub struct Transition {
    /// Observation at the start of the step.
    pub obs: Vec<f32>,
    /// Action taken from `obs`.
    pub action: Vec<f32>,
    /// Scalar reward for this step.
    pub reward: f32,
    /// Observation after the step.
    pub next_obs: Vec<f32>,
    /// `true` if this step ended the episode.
    pub done: bool,
}

impl Transition {
    /// Builds a transition from anything convertible into `Vec<f32>` (vectors, arrays, ...).
    ///
    /// No length validation is performed here.
    #[must_use]
    pub fn new(
        obs: impl Into<Vec<f32>>,
        action: impl Into<Vec<f32>>,
        reward: f32,
        next_obs: impl Into<Vec<f32>>,
        done: bool,
    ) -> Self {
        Self {
            obs: obs.into(),
            action: action.into(),
            reward,
            next_obs: next_obs.into(),
            done,
        }
    }
}
69
/// Fixed-capacity ring buffer of transitions, stored column-wise in flat `f32` arrays.
///
/// Slot `i` occupies `obs[i * obs_dim..(i + 1) * obs_dim]`,
/// `actions[i * act_dim..(i + 1) * act_dim]`, `rewards[i]`,
/// `next_obs[i * obs_dim..(i + 1) * obs_dim]` and `dones[i]`. Once the buffer is full,
/// each new transition overwrites the oldest slot.
#[derive(Debug, Clone)]
pub struct UniformReplayBuffer {
    /// Maximum number of transitions held.
    capacity: usize,
    /// Number of floats per observation.
    obs_dim: usize,
    /// Number of floats per action.
    act_dim: usize,
    /// `capacity * obs_dim` observation values, one row per slot.
    obs: Vec<f32>,
    /// `capacity * act_dim` action values, one row per slot.
    actions: Vec<f32>,
    /// One reward per slot.
    rewards: Vec<f32>,
    /// `capacity * obs_dim` next-observation values, one row per slot.
    next_obs: Vec<f32>,
    /// One done flag per slot, stored as `1.0` (done) or `0.0`.
    dones: Vec<f32>,
    /// Slot the next push writes to.
    head: usize,
    /// Number of valid slots (always `0..size`); never exceeds `capacity`.
    size: usize,
}
89
90impl UniformReplayBuffer {
91 pub fn new(capacity: usize, obs_dim: usize, act_dim: usize) -> Self {
98 assert!(capacity > 0, "capacity must be > 0");
99 Self {
100 capacity,
101 obs_dim,
102 act_dim,
103 obs: vec![0.0; capacity * obs_dim],
104 actions: vec![0.0; capacity * act_dim],
105 rewards: vec![0.0; capacity],
106 next_obs: vec![0.0; capacity * obs_dim],
107 dones: vec![0.0; capacity],
108 head: 0,
109 size: 0,
110 }
111 }
112
113 #[must_use]
115 #[inline]
116 pub fn len(&self) -> usize {
117 self.size
118 }
119
120 #[must_use]
122 #[inline]
123 pub fn capacity(&self) -> usize {
124 self.capacity
125 }
126
127 #[must_use]
129 #[inline]
130 pub fn is_empty(&self) -> bool {
131 self.size == 0
132 }
133
134 #[must_use]
136 #[inline]
137 pub fn is_full(&self) -> bool {
138 self.size == self.capacity
139 }
140
141 #[must_use]
143 #[inline]
144 pub fn obs_dim(&self) -> usize {
145 self.obs_dim
146 }
147
148 #[must_use]
150 #[inline]
151 pub fn act_dim(&self) -> usize {
152 self.act_dim
153 }
154
155 pub fn push(
162 &mut self,
163 obs: impl AsRef<[f32]>,
164 action: impl AsRef<[f32]>,
165 reward: f32,
166 next_obs: impl AsRef<[f32]>,
167 done: bool,
168 ) {
169 let obs = obs.as_ref();
170 let action = action.as_ref();
171 let next_obs = next_obs.as_ref();
172 debug_assert_eq!(obs.len(), self.obs_dim);
173 debug_assert_eq!(action.len(), self.act_dim);
174 debug_assert_eq!(next_obs.len(), self.obs_dim);
175
176 let i = self.head;
177 self.obs[i * self.obs_dim..(i + 1) * self.obs_dim].copy_from_slice(obs);
178 self.actions[i * self.act_dim..(i + 1) * self.act_dim].copy_from_slice(action);
179 self.rewards[i] = reward;
180 self.next_obs[i * self.obs_dim..(i + 1) * self.obs_dim].copy_from_slice(next_obs);
181 self.dones[i] = if done { 1.0 } else { 0.0 };
182
183 self.head = (self.head + 1) % self.capacity;
184 if self.size < self.capacity {
185 self.size += 1;
186 }
187 }
188
189 pub fn push_transition(&mut self, t: Transition) {
191 self.push(t.obs, t.action, t.reward, t.next_obs, t.done);
192 }
193
194 pub fn sample(&self, batch_size: usize, handle: &mut RlHandle) -> RlResult<Vec<Transition>> {
202 if self.size < batch_size {
203 return Err(RlError::InsufficientTransitions {
204 have: self.size,
205 need: batch_size,
206 });
207 }
208 let rng = handle.rng_mut();
209 let mut out = Vec::with_capacity(batch_size);
210 let mut indices: Vec<usize> = Vec::with_capacity(batch_size);
213 while indices.len() < batch_size {
214 let idx = rng.next_usize(self.size);
215 if !indices.contains(&idx) {
216 indices.push(idx);
217 }
218 }
219 for idx in indices {
220 let obs = self.obs[idx * self.obs_dim..(idx + 1) * self.obs_dim].to_vec();
221 let action = self.actions[idx * self.act_dim..(idx + 1) * self.act_dim].to_vec();
222 let reward = self.rewards[idx];
223 let next_obs = self.next_obs[idx * self.obs_dim..(idx + 1) * self.obs_dim].to_vec();
224 let done = self.dones[idx] > 0.5;
225 out.push(Transition {
226 obs,
227 action,
228 reward,
229 next_obs,
230 done,
231 });
232 }
233 Ok(out)
234 }
235
236 #[allow(clippy::too_many_arguments)]
250 pub fn sample_into(
251 &self,
252 batch_size: usize,
253 obs_out: &mut [f32],
254 action_out: &mut [f32],
255 reward_out: &mut [f32],
256 next_obs_out: &mut [f32],
257 done_out: &mut [f32],
258 handle: &mut RlHandle,
259 ) -> RlResult<()> {
260 if self.size < batch_size {
261 return Err(RlError::InsufficientTransitions {
262 have: self.size,
263 need: batch_size,
264 });
265 }
266 let rng = handle.rng_mut();
267 let mut indices: Vec<usize> = Vec::with_capacity(batch_size);
268 while indices.len() < batch_size {
269 let idx = rng.next_usize(self.size);
270 if !indices.contains(&idx) {
271 indices.push(idx);
272 }
273 }
274 for (b, &idx) in indices.iter().enumerate() {
275 obs_out[b * self.obs_dim..(b + 1) * self.obs_dim]
276 .copy_from_slice(&self.obs[idx * self.obs_dim..(idx + 1) * self.obs_dim]);
277 action_out[b * self.act_dim..(b + 1) * self.act_dim]
278 .copy_from_slice(&self.actions[idx * self.act_dim..(idx + 1) * self.act_dim]);
279 reward_out[b] = self.rewards[idx];
280 next_obs_out[b * self.obs_dim..(b + 1) * self.obs_dim]
281 .copy_from_slice(&self.next_obs[idx * self.obs_dim..(idx + 1) * self.obs_dim]);
282 done_out[b] = self.dones[idx];
283 }
284 Ok(())
285 }
286
287 pub fn clear(&mut self) {
289 self.head = 0;
290 self.size = 0;
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushes `n` transitions where transition `i` has `obs = [i; obs_dim]`,
    /// `action = [0; act_dim]`, `reward = i * 0.1`, `next_obs = [i + 1; obs_dim]`, and
    /// `done` set on every tenth step (`i % 10 == 9`).
    fn push_n(buf: &mut UniformReplayBuffer, n: usize) {
        let od = buf.obs_dim();
        let ad = buf.act_dim();
        for i in 0..n {
            buf.push(
                vec![i as f32; od],
                vec![0.0_f32; ad],
                i as f32 * 0.1,
                vec![i as f32 + 1.0; od],
                i % 10 == 9,
            );
        }
    }

    #[test]
    fn buffer_empty_initially() {
        let buf = UniformReplayBuffer::new(100, 4, 2);
        assert!(buf.is_empty());
        assert_eq!(buf.len(), 0);
    }

    #[test]
    fn buffer_grows_to_capacity() {
        let mut buf = UniformReplayBuffer::new(10, 2, 1);
        push_n(&mut buf, 10);
        assert_eq!(buf.len(), 10);
        assert!(buf.is_full());
    }

    #[test]
    fn buffer_overwrites_oldest() {
        let mut buf = UniformReplayBuffer::new(5, 1, 1);
        for i in 0..7_usize {
            buf.push([i as f32], [0.0], i as f32, [i as f32 + 1.0], false);
        }
        assert_eq!(buf.len(), 5);
        // Sampling the whole buffer returns every stored transition; the first two
        // pushes (0 and 1) must have been overwritten by 5 and 6.
        let mut handle = RlHandle::default_handle();
        let mut kept: Vec<u32> = buf
            .sample(5, &mut handle)
            .unwrap()
            .iter()
            .map(|t| t.obs[0] as u32)
            .collect();
        kept.sort_unstable();
        assert_eq!(kept, vec![2, 3, 4, 5, 6]);
    }

    #[test]
    fn sample_correct_size() {
        let mut buf = UniformReplayBuffer::new(100, 4, 2);
        push_n(&mut buf, 100);
        let mut handle = RlHandle::default_handle();
        let batch = buf.sample(32, &mut handle).unwrap();
        assert_eq!(batch.len(), 32);
    }

    #[test]
    fn sample_zero_batch_is_empty() {
        let buf = UniformReplayBuffer::new(4, 1, 1);
        let mut handle = RlHandle::default_handle();
        assert!(buf.sample(0, &mut handle).unwrap().is_empty());
    }

    #[test]
    fn sample_no_duplicates() {
        let mut buf = UniformReplayBuffer::new(100, 1, 1);
        push_n(&mut buf, 100);
        let mut handle = RlHandle::default_handle();
        let batch = buf.sample(50, &mut handle).unwrap();
        let mut seen: std::collections::HashSet<usize> = std::collections::HashSet::new();
        for t in &batch {
            let idx = t.obs[0] as usize;
            assert!(seen.insert(idx), "duplicate index {idx}");
        }
    }

    #[test]
    fn sample_insufficient_error() {
        let buf = UniformReplayBuffer::new(100, 4, 2);
        let mut handle = RlHandle::default_handle();
        assert!(buf.sample(10, &mut handle).is_err());
    }

    #[test]
    fn push_transition_struct() {
        let mut buf = UniformReplayBuffer::new(10, 3, 2);
        buf.push_transition(Transition::new(
            [1.0, 2.0, 3.0],
            [4.0, 5.0],
            1.0,
            [2.0, 3.0, 4.0],
            false,
        ));
        assert_eq!(buf.len(), 1);

        let mut handle = RlHandle::default_handle();
        let batch = buf.sample(1, &mut handle).unwrap();
        let t = &batch[0];
        assert_eq!(t.obs, vec![1.0, 2.0, 3.0]);
        assert_eq!(t.action, vec![4.0, 5.0]);
        assert_eq!(t.reward, 1.0);
        assert_eq!(t.next_obs, vec![2.0, 3.0, 4.0]);
        assert!(!t.done);
    }

    #[test]
    fn sample_into_fills_slices() {
        let mut buf = UniformReplayBuffer::new(64, 4, 2);
        push_n(&mut buf, 64);
        let mut handle = RlHandle::default_handle();
        let bs = 16;
        // Prefilled with -1.0 so any element left unwritten is detected below.
        let mut obs = vec![-1.0_f32; bs * 4];
        let mut act = vec![-1.0_f32; bs * 2];
        let mut rew = vec![-1.0_f32; bs];
        let mut nobs = vec![-1.0_f32; bs * 4];
        let mut done = vec![-1.0_f32; bs];
        buf.sample_into(
            bs,
            &mut obs,
            &mut act,
            &mut rew,
            &mut nobs,
            &mut done,
            &mut handle,
        )
        .unwrap();

        // Every row must be a consistent copy of exactly one pushed transition.
        let mut seen = std::collections::HashSet::new();
        for b in 0..bs {
            let i = obs[b * 4];
            assert!(i >= 0.0, "row {b} was not written");
            assert!(seen.insert(i as usize), "duplicate transition {i}");
            assert!(obs[b * 4..(b + 1) * 4].iter().all(|&v| v == i));
            assert!(act[b * 2..(b + 1) * 2].iter().all(|&v| v == 0.0));
            assert_eq!(rew[b], i * 0.1);
            assert!(nobs[b * 4..(b + 1) * 4].iter().all(|&v| v == i + 1.0));
            let expected_done = if (i as usize) % 10 == 9 { 1.0 } else { 0.0 };
            assert_eq!(done[b], expected_done);
        }
    }

    #[test]
    fn sample_into_insufficient_error_leaves_outputs() {
        let mut buf = UniformReplayBuffer::new(10, 1, 1);
        push_n(&mut buf, 4);
        let mut handle = RlHandle::default_handle();
        let bs = 8;
        let mut obs = vec![-1.0_f32; bs];
        let mut act = vec![-1.0_f32; bs];
        let mut rew = vec![-1.0_f32; bs];
        let mut nobs = vec![-1.0_f32; bs];
        let mut done = vec![-1.0_f32; bs];
        let res = buf.sample_into(
            bs,
            &mut obs,
            &mut act,
            &mut rew,
            &mut nobs,
            &mut done,
            &mut handle,
        );
        assert!(res.is_err());
        assert!(obs
            .iter()
            .chain(&act)
            .chain(&rew)
            .chain(&nobs)
            .chain(&done)
            .all(|&v| v == -1.0));
    }

    #[test]
    #[should_panic]
    fn sample_into_short_output_panics() {
        let mut buf = UniformReplayBuffer::new(10, 2, 1);
        push_n(&mut buf, 10);
        let mut handle = RlHandle::default_handle();
        let bs = 4;
        let mut obs = vec![0.0_f32; 2]; // room for one row, four needed
        let mut act = vec![0.0_f32; bs];
        let mut rew = vec![0.0_f32; bs];
        let mut nobs = vec![0.0_f32; bs * 2];
        let mut done = vec![0.0_f32; bs];
        let _ = buf.sample_into(
            bs,
            &mut obs,
            &mut act,
            &mut rew,
            &mut nobs,
            &mut done,
            &mut handle,
        );
    }

    #[test]
    fn clear_resets_buffer() {
        let mut buf = UniformReplayBuffer::new(10, 2, 1);
        push_n(&mut buf, 10);
        buf.clear();
        assert!(buf.is_empty());

        let mut handle = RlHandle::default_handle();
        assert!(buf.sample(1, &mut handle).is_err());

        // The buffer is reusable after clearing.
        push_n(&mut buf, 1);
        assert_eq!(buf.len(), 1);
    }

    #[test]
    fn transition_done_flag() {
        let mut buf = UniformReplayBuffer::new(5, 1, 1);
        buf.push([0.0], [0.0], 0.0, [1.0], true);
        let mut handle = RlHandle::default_handle();
        let batch = buf.sample(1, &mut handle).unwrap();
        assert!(batch[0].done);
    }
}