border_async_trainer/
replay_buffer_proxy.rs

1use crate::PushedItemMessage;
2use anyhow::Result;
3use border_core::{ExperienceBufferBase, ReplayBufferBase};
4use crossbeam_channel::Sender;
5use std::marker::PhantomData;
6
/// Settings used to construct a [`ReplayBufferProxy`].
#[derive(Debug, Clone)]
pub struct ReplayBufferProxyConfig {
    /// How many samples the proxy accumulates before forwarding them to the
    /// trainer in a single message.
    ///
    /// For a [`ReplayBufferProxy`]`<R>`, one sample is one `R::Item`.
    pub n_buffer: usize,
}
15
16/// A wrapper of replay buffer for asynchronous trainer.
17pub struct ReplayBufferProxy<R: ExperienceBufferBase> {
18    id: usize,
19
20    /// Sender of [PushedItemMessage].
21    sender: Sender<PushedItemMessage<R::Item>>,
22
23    /// Number of samples buffered until sent to the trainer.
24    n_buffer: usize,
25
26    /// Buffer of `R::Item`s.
27    buffer: Vec<R::Item>,
28
29    phantom: PhantomData<R>,
30}
31
32impl<R: ExperienceBufferBase> ReplayBufferProxy<R> {
33    pub fn build_with_sender(
34        id: usize,
35        config: &ReplayBufferProxyConfig,
36        sender: Sender<PushedItemMessage<R::Item>>,
37    ) -> Self {
38        let n_buffer = config.n_buffer;
39        Self {
40            id,
41            sender,
42            n_buffer,
43            buffer: Vec::with_capacity(n_buffer),
44            phantom: PhantomData,
45        }
46    }
47}
48
49impl<R: ExperienceBufferBase> ExperienceBufferBase for ReplayBufferProxy<R> {
50    type Item = R::Item;
51
52    fn push(&mut self, tr: Self::Item) -> Result<()> {
53        self.buffer.push(tr);
54        if self.buffer.len() == self.n_buffer {
55            let mut buffer = Vec::with_capacity(self.n_buffer);
56            std::mem::swap(&mut self.buffer, &mut buffer);
57
58            let msg = PushedItemMessage {
59                id: self.id,
60                pushed_items: buffer,
61            };
62
63            match self.sender.try_send(msg) {
64                Ok(()) => {}
65                Err(_e) => {
66                    return Err(crate::BorderAsyncTrainerError::SendMsgForPush)?;
67                }
68            }
69        }
70
71        Ok(())
72    }
73
74    fn len(&self) -> usize {
75        unimplemented!();
76    }
77}
78
79impl<R: ExperienceBufferBase + ReplayBufferBase> ReplayBufferBase for ReplayBufferProxy<R> {
80    type Config = ReplayBufferProxyConfig;
81    type Batch = R::Batch;
82
83    fn build(_config: &Self::Config) -> Self {
84        unimplemented!();
85    }
86
87    fn batch(&mut self, _size: usize) -> anyhow::Result<Self::Batch> {
88        unimplemented!();
89    }
90
91    fn update_priority(&mut self, _ixs: &Option<Vec<usize>>, _td_err: &Option<Vec<f32>>) {
92        unimplemented!();
93    }
94}