border_async_trainer/
replay_buffer_proxy.rs1use crate::PushedItemMessage;
2use anyhow::Result;
3use border_core::{ExperienceBufferBase, ReplayBufferBase};
4use crossbeam_channel::Sender;
5use std::marker::PhantomData;
6
/// Settings for [`ReplayBufferProxy`].
#[derive(Debug, Clone)]
pub struct ReplayBufferProxyConfig {
    /// How many pushed items the proxy accumulates locally before it
    /// forwards them through its channel as one message.
    pub n_buffer: usize,
}
15
16pub struct ReplayBufferProxy<R: ExperienceBufferBase> {
18 id: usize,
19
20 sender: Sender<PushedItemMessage<R::Item>>,
22
23 n_buffer: usize,
25
26 buffer: Vec<R::Item>,
28
29 phantom: PhantomData<R>,
30}
31
32impl<R: ExperienceBufferBase> ReplayBufferProxy<R> {
33 pub fn build_with_sender(
34 id: usize,
35 config: &ReplayBufferProxyConfig,
36 sender: Sender<PushedItemMessage<R::Item>>,
37 ) -> Self {
38 let n_buffer = config.n_buffer;
39 Self {
40 id,
41 sender,
42 n_buffer,
43 buffer: Vec::with_capacity(n_buffer),
44 phantom: PhantomData,
45 }
46 }
47}
48
49impl<R: ExperienceBufferBase> ExperienceBufferBase for ReplayBufferProxy<R> {
50 type Item = R::Item;
51
52 fn push(&mut self, tr: Self::Item) -> Result<()> {
53 self.buffer.push(tr);
54 if self.buffer.len() == self.n_buffer {
55 let mut buffer = Vec::with_capacity(self.n_buffer);
56 std::mem::swap(&mut self.buffer, &mut buffer);
57
58 let msg = PushedItemMessage {
59 id: self.id,
60 pushed_items: buffer,
61 };
62
63 match self.sender.try_send(msg) {
64 Ok(()) => {}
65 Err(_e) => {
66 return Err(crate::BorderAsyncTrainerError::SendMsgForPush)?;
67 }
68 }
69 }
70
71 Ok(())
72 }
73
74 fn len(&self) -> usize {
75 unimplemented!();
76 }
77}
78
79impl<R: ExperienceBufferBase + ReplayBufferBase> ReplayBufferBase for ReplayBufferProxy<R> {
80 type Config = ReplayBufferProxyConfig;
81 type Batch = R::Batch;
82
83 fn build(_config: &Self::Config) -> Self {
84 unimplemented!();
85 }
86
87 fn batch(&mut self, _size: usize) -> anyhow::Result<Self::Batch> {
88 unimplemented!();
89 }
90
91 fn update_priority(&mut self, _ixs: &Option<Vec<usize>>, _td_err: &Option<Vec<f32>>) {
92 unimplemented!();
93 }
94}