border_async_trainer/actor_manager/
base.rs1use crate::{
2 Actor, ActorManagerConfig, ActorStat, PushedItemMessage, ReplayBufferProxyConfig, SyncModel,
3};
4use border_core::{
5 Agent, Configurable, Env, ExperienceBufferBase, ReplayBufferBase, StepProcessor,
6};
7use crossbeam_channel::{bounded, Receiver, Sender};
8use log::info;
9use std::{
10 marker::PhantomData,
11 sync::{Arc, Mutex},
12 thread::JoinHandle,
13};
14
15pub struct ActorManager<A, E, R, P>
24where
25 A: Agent<E, R> + Configurable + SyncModel,
26 E: Env,
27 P: StepProcessor<E>,
28 R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase,
29{
30 agent_configs: Vec<A::Config>,
32
33 env_config: E::Config,
35
36 step_proc_config: P::Config,
38
39 threads: Vec<JoinHandle<()>>,
41
42 n_buffer: usize,
46
47 stop: Arc<Mutex<bool>>,
49
50 batch_message_receiver: Option<Receiver<PushedItemMessage<R::Item>>>,
52
53 pushed_item_message_sender: Sender<PushedItemMessage<R::Item>>,
55
56 model_info: Option<Arc<Mutex<(usize, A::ModelInfo)>>>,
60
61 model_info_receiver: Receiver<(usize, A::ModelInfo)>,
63
64 actor_stats: Vec<Arc<Mutex<Option<ActorStat>>>>,
66
67 phantom: PhantomData<R>,
68}
69
70impl<A, E, R, P> ActorManager<A, E, R, P>
71where
72 A: Agent<E, R> + Configurable + SyncModel + 'static,
73 E: Env,
74 P: StepProcessor<E>,
75 R: ExperienceBufferBase<Item = P::Output> + Send + 'static + ReplayBufferBase,
76 A::Config: Send + 'static,
77 E::Config: Send + 'static,
78 P::Config: Send + 'static,
79 R::Item: Send + 'static,
80 A::ModelInfo: Send + 'static,
81{
82 pub fn build(
84 config: &ActorManagerConfig,
85 agent_configs: &Vec<A::Config>,
86 env_config: &E::Config,
87 step_proc_config: &P::Config,
88 pushed_item_message_sender: Sender<PushedItemMessage<R::Item>>,
89 model_info_receiver: Receiver<(usize, A::ModelInfo)>,
90 stop: Arc<Mutex<bool>>,
91 ) -> Self {
92 Self {
93 agent_configs: agent_configs.clone(),
94 env_config: env_config.clone(),
95 step_proc_config: step_proc_config.clone(),
96 n_buffer: config.n_buffer,
97 stop,
98 threads: vec![],
99 batch_message_receiver: None,
100 pushed_item_message_sender,
101 model_info: None,
102 model_info_receiver,
103 actor_stats: vec![],
104 phantom: PhantomData,
105 }
106 }
107
108 pub fn run(&mut self, guard_init_env: Arc<Mutex<bool>>) {
113 let guard_init_model = Arc::new(Mutex::new(true));
115
116 self.model_info = {
118 let agent = A::build(self.agent_configs[0].clone());
119 Some(Arc::new(Mutex::new(agent.model_info())))
120 };
121
122 {
124 let stop = self.stop.clone();
125 let model_info_receiver = self.model_info_receiver.clone();
126 let model_info = self.model_info.as_ref().unwrap().clone();
127 let guard_init_model = guard_init_model.clone();
128 let handle = std::thread::spawn(move || {
129 Self::run_model_info_loop(model_info_receiver, model_info, stop, guard_init_model);
130 });
131 self.threads.push(handle);
132 info!("Starts thread for updating model info");
133 }
134
135 let (s, r) = bounded(1000);
138 self.batch_message_receiver = Some(r.clone());
139
140 self.agent_configs
142 .clone()
143 .into_iter()
144 .enumerate()
145 .for_each(|(id, agent_config)| {
146 let sender = s.clone();
147 let replay_buffer_proxy_config = ReplayBufferProxyConfig {
148 n_buffer: self.n_buffer,
149 };
150 let env_config = self.env_config.clone();
151 let step_proc_config = self.step_proc_config.clone();
152 let stop = self.stop.clone();
153 let seed = id;
154 let guard = guard_init_env.clone();
155 let guard_init_model = guard_init_model.clone();
156 let model_info = self.model_info.as_ref().unwrap().clone();
157 let stats = Arc::new(Mutex::new(None));
158 self.actor_stats.push(stats.clone());
159
160 let handle = std::thread::spawn(move || {
162 Actor::<A, E, P, R>::build(
163 id,
164 agent_config,
165 env_config,
166 step_proc_config,
167 replay_buffer_proxy_config,
168 stop,
169 seed as i64,
170 stats,
171 )
172 .run(sender, model_info, guard, guard_init_model);
173 });
174 self.threads.push(handle);
175 });
176
177 {
179 let stop = self.stop.clone();
180 let s = self.pushed_item_message_sender.clone();
181 let handle = std::thread::spawn(move || {
182 Self::handle_message(r, stop, s);
183 });
184 self.threads.push(handle);
185 }
186 }
187
188 pub fn join(self) -> Vec<ActorStat> {
190 for h in self.threads {
191 h.join().unwrap();
192 }
193
194 self.actor_stats
195 .iter()
196 .map(|e| e.lock().unwrap().clone().unwrap())
197 .collect::<Vec<_>>()
198 }
199
200 pub fn stop(&self) {
202 let mut stop = self.stop.lock().unwrap();
203 *stop = true;
204 }
205
206 pub fn stop_and_join(self) -> Vec<ActorStat> {
208 self.stop();
209 self.join()
210 }
211
212 fn handle_message(
214 receiver: Receiver<PushedItemMessage<R::Item>>,
215 stop: Arc<Mutex<bool>>,
216 sender: Sender<PushedItemMessage<R::Item>>,
217 ) {
218 let mut _n_samples = 0;
219
220 loop {
221 let msg = receiver.recv();
226 if msg.is_ok() {
227 _n_samples += 1;
228 sender.try_send(msg.unwrap()).unwrap();
229 }
230
231 if *stop.lock().unwrap() {
233 break;
234 }
235 }
236 info!("Stopped thread for message handling");
237 }
238
239 fn run_model_info_loop(
240 model_info_receiver: Receiver<(usize, A::ModelInfo)>,
241 model_info: Arc<Mutex<(usize, A::ModelInfo)>>,
242 stop: Arc<Mutex<bool>>,
243 guard_init_model: Arc<Mutex<bool>>,
244 ) {
245 {
247 let mut guard_init_model = guard_init_model.lock().unwrap();
248 let mut model_info = model_info.lock().unwrap();
249 let msg = model_info_receiver.recv().unwrap();
251 assert_eq!(msg.0, 0);
252 *model_info = msg;
253 *guard_init_model = true;
254 }
255
256 loop {
257 let msg = model_info_receiver.recv().unwrap();
259 let mut model_info = model_info.lock().unwrap();
260 *model_info = msg;
261 if *stop.lock().unwrap() {
262 break;
263 }
264 }
265 info!("Stopped model info thread");
266 }
267}