border_async_trainer/actor_manager/base.rs

use crate::{
    Actor, ActorManagerConfig, ActorStat, PushedItemMessage, ReplayBufferProxyConfig, SyncModel,
};
use border_core::{
    Agent, Configurable, Env, ExperienceBufferBase, ReplayBufferBase, StepProcessor,
};
use crossbeam_channel::{bounded, /*unbounded,*/ Receiver, Sender};
use log::info;
use std::{
    marker::PhantomData,
    sync::{Arc, Mutex},
    thread::JoinHandle,
};

/// Manages [`Actor`]s.
///
/// This struct handles the following requests:
/// * From the [`AsyncTrainer`], to update the latest model info stored in this struct.
/// * From the [`Actor`]s, to get the latest model info.
/// * From the [`Actor`]s, to push sample batches to the `LearnerManager`.
///
/// [`AsyncTrainer`]: crate::AsyncTrainer
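///
/// A rough usage sketch is shown below; the agent, environment, step processor,
/// and replay buffer types, as well as the configuration values, are placeholders
/// rather than a compiled example from this crate:
///
/// ```ignore
/// use std::sync::{Arc, Mutex};
/// use crossbeam_channel::unbounded;
///
/// // Channels shared with the AsyncTrainer.
/// let (item_s, item_r) = unbounded(); // pushed samples: actors -> trainer
/// let (model_s, model_r) = unbounded(); // model info: trainer -> actors
/// let stop = Arc::new(Mutex::new(false));
///
/// let mut actor_manager = ActorManager::<MyAgent, MyEnv, MyReplayBuffer, MyStepProc>::build(
///     &actor_manager_config,
///     &agent_configs,
///     &env_config,
///     &step_proc_config,
///     item_s,
///     model_r,
///     stop.clone(),
/// );
///
/// // Spawn the actor threads; they block until the trainer sends the first model info.
/// actor_manager.run(Arc::new(Mutex::new(true)));
///
/// // ... run the AsyncTrainer with `item_r`, `model_s`, and `stop` ...
///
/// let stats = actor_manager.stop_and_join();
/// ```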
pub struct ActorManager<A, E, R, P>
where
    A: Agent<E, R> + Configurable + SyncModel,
    E: Env,
    P: StepProcessor<E>,
    R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase,
{
    /// Configurations of [`Agent`]s.
    agent_configs: Vec<A::Config>,

    /// Configuration of [`Env`].
    env_config: E::Config,

    /// Configuration of a `StepProcessor`.
    step_proc_config: P::Config,

    /// Thread handles.
    threads: Vec<JoinHandle<()>>,

    /// Number of samples to be buffered in each actor before being pushed to the replay buffer.
    ///
    /// This parameter is used as `n_buffer` in [`ReplayBufferProxyConfig`].
    n_buffer: usize,

    /// Flag to stop training.
    stop: Arc<Mutex<bool>>,

    /// Receiver of [`PushedItemMessage`]s from [`Actor`]s.
    batch_message_receiver: Option<Receiver<PushedItemMessage<R::Item>>>,

    /// Sender of [`PushedItemMessage`]s to [`AsyncTrainer`](crate::AsyncTrainer).
    pushed_item_message_sender: Sender<PushedItemMessage<R::Item>>,

    /// Information of the model.
    ///
    /// Also holds the number of optimization steps at which the current model info was taken.
    model_info: Option<Arc<Mutex<(usize, A::ModelInfo)>>>,

    /// Receives incoming model info from [`AsyncTrainer`](crate::AsyncTrainer).
    model_info_receiver: Receiver<(usize, A::ModelInfo)>,

    /// Stats of [`Actor`]s, shared with actor threads.
    actor_stats: Vec<Arc<Mutex<Option<ActorStat>>>>,

    phantom: PhantomData<R>,
}

impl<A, E, R, P> ActorManager<A, E, R, P>
where
    A: Agent<E, R> + Configurable + SyncModel + 'static,
    E: Env,
    P: StepProcessor<E>,
    R: ExperienceBufferBase<Item = P::Output> + Send + 'static + ReplayBufferBase,
    A::Config: Send + 'static,
    E::Config: Send + 'static,
    P::Config: Send + 'static,
    R::Item: Send + 'static,
    A::ModelInfo: Send + 'static,
{
    /// Builds an [`ActorManager`].
    pub fn build(
        config: &ActorManagerConfig,
        agent_configs: &Vec<A::Config>,
        env_config: &E::Config,
        step_proc_config: &P::Config,
        pushed_item_message_sender: Sender<PushedItemMessage<R::Item>>,
        model_info_receiver: Receiver<(usize, A::ModelInfo)>,
        stop: Arc<Mutex<bool>>,
    ) -> Self {
        Self {
            agent_configs: agent_configs.clone(),
            env_config: env_config.clone(),
            step_proc_config: step_proc_config.clone(),
            n_buffer: config.n_buffer,
            stop,
            threads: vec![],
            batch_message_receiver: None,
            pushed_item_message_sender,
            model_info: None,
            model_info_receiver,
            actor_stats: vec![],
            phantom: PhantomData,
        }
    }

    /// Runs threads for [`Actor`]s and a thread for sending samples into the replay buffer.
    ///
    /// Each thread is blocked until receiving the initial [`SyncModel::ModelInfo`]
    /// from [`AsyncTrainer`](crate::AsyncTrainer).
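    ///
    /// A minimal calling sketch (assumes an `actor_manager` created with
    /// [`ActorManager::build`]; in practice `guard_init_env` is shared with the
    /// trainer side):
    ///
    /// ```ignore
    /// let guard_init_env = Arc::new(Mutex::new(true));
    /// actor_manager.run(guard_init_env.clone());
    /// // The AsyncTrainer must send the initial model info next;
    /// // otherwise the spawned actor threads stay blocked.
    /// ```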
    pub fn run(&mut self, guard_init_env: Arc<Mutex<bool>>) {
        // Guard for sync of the initial model
        let guard_init_model = Arc::new(Mutex::new(true));

        // Dummy model info, overwritten once the first message arrives from the trainer
        self.model_info = {
            let agent = A::build(self.agent_configs[0].clone());
            Some(Arc::new(Mutex::new(agent.model_info())))
        };

        // Thread for waiting [SyncModel::ModelInfo]
        {
            let stop = self.stop.clone();
            let model_info_receiver = self.model_info_receiver.clone();
            let model_info = self.model_info.as_ref().unwrap().clone();
            let guard_init_model = guard_init_model.clone();
            let handle = std::thread::spawn(move || {
                Self::run_model_info_loop(model_info_receiver, model_info, stop, guard_init_model);
            });
            self.threads.push(handle);
            info!("Starts thread for updating model info");
        }

        // Create a bounded channel for [PushedItemMessage]s from the actors
        // let (s, r) = unbounded();
        let (s, r) = bounded(1000);
        self.batch_message_receiver = Some(r.clone());

        // Runs sampling processes
        self.agent_configs
            .clone()
            .into_iter()
            .enumerate()
            .for_each(|(id, agent_config)| {
                let sender = s.clone();
                let replay_buffer_proxy_config = ReplayBufferProxyConfig {
                    n_buffer: self.n_buffer,
                };
                let env_config = self.env_config.clone();
                let step_proc_config = self.step_proc_config.clone();
                let stop = self.stop.clone();
                let seed = id;
                let guard = guard_init_env.clone();
                let guard_init_model = guard_init_model.clone();
                let model_info = self.model_info.as_ref().unwrap().clone();
                let stats = Arc::new(Mutex::new(None));
                self.actor_stats.push(stats.clone());

                // Spawn actor thread
                let handle = std::thread::spawn(move || {
                    Actor::<A, E, P, R>::build(
                        id,
                        agent_config,
                        env_config,
                        step_proc_config,
                        replay_buffer_proxy_config,
                        stop,
                        seed as i64,
                        stats,
                    )
                    .run(sender, model_info, guard, guard_init_model);
                });
                self.threads.push(handle);
            });

        // Thread for handling incoming samples
        {
            let stop = self.stop.clone();
            let s = self.pushed_item_message_sender.clone();
            let handle = std::thread::spawn(move || {
                Self::handle_message(r, stop, s);
            });
            self.threads.push(handle);
        }
    }

    /// Waits until all actors finish.
    pub fn join(self) -> Vec<ActorStat> {
        for h in self.threads {
            h.join().unwrap();
        }

        self.actor_stats
            .iter()
            .map(|e| e.lock().unwrap().clone().unwrap())
            .collect::<Vec<_>>()
    }

    /// Stops actor threads.
    pub fn stop(&self) {
        let mut stop = self.stop.lock().unwrap();
        *stop = true;
    }

    /// Stops and joins actors.
    pub fn stop_and_join(self) -> Vec<ActorStat> {
        self.stop();
        self.join()
    }

    /// Loop waiting for [`PushedItemMessage`]s from [`Actor`]s and forwarding them
    /// to the [`AsyncTrainer`](crate::AsyncTrainer).
    fn handle_message(
        receiver: Receiver<PushedItemMessage<R::Item>>,
        stop: Arc<Mutex<bool>>,
        sender: Sender<PushedItemMessage<R::Item>>,
    ) {
        let mut _n_samples = 0;

        loop {
            // Handle incoming message
            // TODO: error handling, timeout
            // TODO: caching
            // TODO: stats
            let msg = receiver.recv();
            if msg.is_ok() {
                _n_samples += 1;
                // Forward the pushed items to the AsyncTrainer; panics if the send fails.
                sender.try_send(msg.unwrap()).unwrap();
            }

            // Stop the loop
            if *stop.lock().unwrap() {
                break;
            }
        }
        info!("Stopped thread for message handling");
    }

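    /// Keeps `model_info` in sync with the latest model information sent from
    /// [`AsyncTrainer`](crate::AsyncTrainer).
    ///
    /// The first received message is expected to carry 0 optimization steps and is
    /// used to unblock the actor threads waiting on `guard_init_model`. A rough
    /// sketch of the sending side (the names below are illustrative, not the actual
    /// trainer code):
    ///
    /// ```ignore
    /// // Initial synchronization, before optimization starts.
    /// model_info_sender.send(agent.model_info()).unwrap();
    /// // After each round of optimization.
    /// model_info_sender.send((opt_steps, agent.model_info().1)).unwrap();
    /// ```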
    fn run_model_info_loop(
        model_info_receiver: Receiver<(usize, A::ModelInfo)>,
        model_info: Arc<Mutex<(usize, A::ModelInfo)>>,
        stop: Arc<Mutex<bool>>,
        guard_init_model: Arc<Mutex<bool>>,
    ) {
        // Blocks the threads sharing model_info until the first message arrives from the AsyncTrainer.
        {
            let mut guard_init_model = guard_init_model.lock().unwrap();
            let mut model_info = model_info.lock().unwrap();
            // TODO: error handling
            let msg = model_info_receiver.recv().unwrap();
            assert_eq!(msg.0, 0);
            *model_info = msg;
            *guard_init_model = true;
        }

        loop {
            // TODO: error handling
            let msg = model_info_receiver.recv().unwrap();
            let mut model_info = model_info.lock().unwrap();
            *model_info = msg;
            if *stop.lock().unwrap() {
                break;
            }
        }
        info!("Stopped model info thread");
    }
}