border_async_trainer/async_trainer/base.rs

use crate::{AsyncTrainStat, AsyncTrainerConfig, PushedItemMessage, SyncModel};
use anyhow::Result;
use border_core::{
    record::{Record, RecordValue::Scalar, Recorder},
    Agent, Configurable, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase,
};
use crossbeam_channel::{Receiver, Sender};
use log::{debug, info};
use std::{
    marker::PhantomData,
    ops::{Deref, DerefMut},
    sync::{Arc, Mutex},
    time::{Duration, SystemTime},
};

#[cfg_attr(doc, aquamarine::aquamarine)]
/// Manages the training loop and related objects for asynchronous training.
///
/// Samples pushed by the actors are received over a channel and stored in the
/// replay buffer, while the trained model parameters are periodically sent back
/// to the actors through another channel.
pub struct AsyncTrainer<A, E, R>
where
    A: Agent<E, R> + Configurable + SyncModel,
    E: Env,
    R: ExperienceBufferBase + ReplayBufferBase,
    R::Item: Send + 'static,
{
    /// Configuration of the environment.
    env_config: E::Config,

    /// Configuration of the replay buffer.
    replay_buffer_config: R::Config,

    /// Interval of recording average compute costs, in optimization steps.
    record_compute_cost_interval: usize,

    /// Interval of recording agent information, in optimization steps.
    record_agent_info_interval: usize,

    /// Interval of flushing records, in optimization steps.
    flush_records_interval: usize,

    /// Interval of evaluation, in optimization steps.
    eval_interval: usize,

    /// Interval of saving the model, in optimization steps.
    save_interval: usize,

    /// The maximal number of optimization steps.
    max_opts: usize,

    /// Warmup period: minimum number of items in the replay buffer before optimization starts.
    warmup_period: usize,

    /// Interval of synchronizing the model with the actors, in optimization steps.
    sync_interval: usize,

    /// Receiver of items pushed by the actors.
    r_bulk_pushed_item: Receiver<PushedItemMessage<R::Item>>,

    /// Shared flag used to stop the actors when training finishes.
    stop: Arc<Mutex<bool>>,

    /// Configuration of the agent.
    agent_config: A::Config,

    /// Sender of model info to the actors.
    model_info_sender: Sender<(usize, A::ModelInfo)>,

    /// Number of samples received since the last reset of the counters.
    samples_counter: usize,

    /// Cumulative time spent receiving samples since the last reset of the counters.
    timer_for_samples: Duration,

    /// Number of optimization steps since the last reset of the counters.
    opt_steps_counter: usize,

    /// Cumulative time spent on optimization steps since the last reset of the counters.
    timer_for_opt_steps: Duration,

    /// Best evaluation score so far, used to decide when to save the best model.
    max_eval_reward: f32,

    /// Total number of optimization steps.
    opt_steps: usize,

    phantom: PhantomData<(A, E, R)>,
}

impl<A, E, R> AsyncTrainer<A, E, R>
where
    A: Agent<E, R> + Configurable + SyncModel + 'static,
    E: Env,
    R: ExperienceBufferBase + ReplayBufferBase,
    R::Item: Send + 'static,
{
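    /// Creates an [`AsyncTrainer`] from configurations and the channels used to
    /// communicate with the actor side.
    ///
    /// A hypothetical sketch of the wiring (the types `MyAgent`, `MyEnv` and
    /// `MyReplayBuffer`, as well as the `*_config` values, are placeholders and
    /// not defined in this crate):
    ///
    /// ```ignore
    /// use crossbeam_channel::unbounded;
    /// use std::sync::{Arc, Mutex};
    ///
    /// // Pushed samples flow from the actors to the trainer.
    /// let (item_sender, item_receiver) = unbounded();
    /// // Model info flows from the trainer to the actors.
    /// let (model_info_sender, model_info_receiver) = unbounded();
    /// // Flag used to tell all threads to stop.
    /// let stop = Arc::new(Mutex::new(false));
    ///
    /// let mut trainer = AsyncTrainer::<MyAgent, MyEnv, MyReplayBuffer>::build(
    ///     &async_trainer_config,
    ///     &agent_config,
    ///     &env_config,
    ///     &replay_buffer_config,
    ///     item_receiver,
    ///     model_info_sender,
    ///     stop.clone(),
    /// );
    /// // `item_sender` and `model_info_receiver` would be handed to the actor side.
    /// ```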
    pub fn build(
        config: &AsyncTrainerConfig,
        agent_config: &A::Config,
        env_config: &E::Config,
        replay_buffer_config: &R::Config,
        r_bulk_pushed_item: Receiver<PushedItemMessage<R::Item>>,
        model_info_sender: Sender<(usize, A::ModelInfo)>,
        stop: Arc<Mutex<bool>>,
    ) -> Self {
        Self {
            eval_interval: config.eval_interval,
            max_opts: config.max_opts,
            record_compute_cost_interval: config.record_compute_cost_interval,
            record_agent_info_interval: config.record_agent_info_interval,
            flush_records_interval: config.flush_record_interval,
            save_interval: config.save_interval,
            sync_interval: config.sync_interval,
            warmup_period: config.warmup_period,
            agent_config: agent_config.clone(),
            env_config: env_config.clone(),
            replay_buffer_config: replay_buffer_config.clone(),
            r_bulk_pushed_item,
            model_info_sender,
            stop,
            samples_counter: 0,
            timer_for_samples: Duration::new(0, 0),
            opt_steps_counter: 0,
            timer_for_opt_steps: Duration::new(0, 0),
            max_eval_reward: f32::MIN,
            opt_steps: 0,
            phantom: PhantomData,
        }
    }

    /// Resets the counters and timers used for computing average compute costs.
    fn reset_counters(&mut self) {
        self.samples_counter = 0;
        self.timer_for_samples = Duration::new(0, 0);
        self.opt_steps_counter = 0;
        self.timer_for_opt_steps = Duration::new(0, 0);
    }

    /// Returns the average time per optimization step and per received sample,
    /// in milliseconds. Returns `-1.0` for a quantity whose counter is zero.
    fn average_time(&mut self) -> (f32, f32) {
        let avr_opt_time = match self.opt_steps_counter {
            0 => -1f32,
            n => self.timer_for_opt_steps.as_millis() as f32 / n as f32,
        };
        let avr_sample_time = match self.samples_counter {
            0 => -1f32,
            n => self.timer_for_samples.as_millis() as f32 / n as f32,
        };
        (avr_opt_time, avr_sample_time)
    }

    /// Downcasts the boxed agent to a shared reference of the concrete type `A`.
    #[inline]
    fn downcast_ref(agent: &Box<dyn Agent<E, R>>) -> &A {
        agent.deref().as_any_ref().downcast_ref::<A>().unwrap()
    }

    /// Downcasts the boxed agent to a mutable reference of the concrete type `A`.
    #[inline]
    fn downcast_mut(agent: &mut Box<dyn Agent<E, R>>) -> &mut A {
        agent.deref_mut().as_any_mut().downcast_mut::<A>().unwrap()
    }

    /// Runs a single optimization step of the agent if the replay buffer has
    /// been filled beyond the warmup period.
    #[inline]
    fn train_step(&mut self, agent: &mut Box<dyn Agent<E, R>>, buffer: &mut R) -> Record {
        if buffer.len() < self.warmup_period {
            // Not enough samples yet; skip optimization.
            Record::empty()
        } else if (self.opt_steps + 1) % self.record_agent_info_interval == 0 {
            // Optimization step that also records agent information.
            let timer = SystemTime::now();
            let record = agent.opt_with_record(buffer);
            self.opt_steps += 1;
            self.opt_steps_counter += 1;
            self.timer_for_opt_steps += timer.elapsed().unwrap();
            record
        } else {
            // Ordinary optimization step without recording.
            let timer = SystemTime::now();
            agent.opt(buffer);
            self.opt_steps += 1;
            self.opt_steps_counter += 1;
            self.timer_for_opt_steps += timer.elapsed().unwrap();
            Record::empty()
        }
    }

    /// Evaluates, saves, and synchronizes the model after an optimization step.
    fn post_process<D>(
        &mut self,
        agent: &mut Box<dyn Agent<E, R>>,
        evaluator: &mut D,
        recorder: &mut Box<dyn Recorder<E, R>>,
        record: &mut Record,
    ) -> Result<()>
    where
        E: Env,
        R: ReplayBufferBase,
        D: Evaluator<E>,
    {
        // Evaluation
        if self.opt_steps % self.eval_interval == 0 {
            info!("Starts evaluation of the trained model");
            agent.eval();
            let (score, record_eval) = evaluator.evaluate(agent)?;
            agent.train();
            record.merge_inplace(record_eval);

            // Keep the model with the best evaluation score.
            if score > self.max_eval_reward {
                self.max_eval_reward = score;
                recorder.save_model("best".as_ref(), agent)?;
            }
        }

        // Save the current model periodically.
        if (self.save_interval > 0) && (self.opt_steps % self.save_interval == 0) {
            recorder.save_model(format!("{}", self.opt_steps).as_ref(), agent)?;
        }

        // Synchronize the model with the actors periodically.
        if self.opt_steps % self.sync_interval == 0 {
            debug!("Sends the trained model info to ActorManager");
            self.sync(Self::downcast_mut(agent));
        }

        Ok(())
    }

    /// Sends the current model info of the agent to the actors.
    #[inline]
    fn sync(&mut self, agent: &A) {
        let model_info = agent.model_info();
        self.model_info_sender.send(model_info).unwrap();
    }

    /// Drains all messages currently in the channel and pushes the received
    /// items into the replay buffer, updating the sample counters.
    #[inline]
    fn update_replay_buffer(&mut self, buffer: &mut R, samples_total: &mut usize) {
        let msgs: Vec<_> = self.r_bulk_pushed_item.try_iter().collect();
        msgs.into_iter().for_each(|msg| {
            self.samples_counter += msg.pushed_items.len();
            *samples_total += msg.pushed_items.len();
            msg.pushed_items
                .into_iter()
                .for_each(|pushed_item| buffer.push(pushed_item).unwrap())
        });
    }

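    /// Runs the training loop.
    ///
    /// The loop drains samples sent by the actors into the replay buffer,
    /// performs optimization steps, and periodically evaluates, saves, and
    /// synchronizes the model. It stops after `max_opts` optimization steps,
    /// sets the shared stop flag, and returns an [`AsyncTrainStat`].
    ///
    /// A hypothetical sketch of a call site (the `recorder`, `evaluator`, and
    /// actor-side setup are assumed to exist and are not shown here):
    ///
    /// ```ignore
    /// let guard_init_env = Arc::new(Mutex::new(false));
    /// let stats = trainer.train(&mut recorder, &mut evaluator, guard_init_env);
    /// println!("samples/sec: {}", stats.samples_per_sec);
    /// ```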
    pub fn train<D>(
        &mut self,
        recorder: &mut Box<dyn Recorder<E, R>>,
        evaluator: &mut D,
        guard_init_env: Arc<Mutex<bool>>,
    ) -> AsyncTrainStat
    where
        D: Evaluator<E>,
    {
        // Build an environment while holding the initialization guard shared
        // with the actors.
        let _env = {
            let mut tmp = guard_init_env.lock().unwrap();
            *tmp = true;
            E::build(&self.env_config, 0).unwrap()
        };
        let mut agent: Box<dyn Agent<E, R>> = Box::new(A::build(self.agent_config.clone()));
        let mut buffer = R::build(&self.replay_buffer_config);
        agent.train();

        self.reset_counters();
        self.opt_steps = 0;
        self.max_eval_reward = f32::MIN;
        let time_total = SystemTime::now();
        let mut samples_total = 0;

        info!("Sends the initial model info from AsyncTrainer");
        self.sync(Self::downcast_ref(&agent));

        info!("Warmup period");
        loop {
            self.update_replay_buffer(&mut buffer, &mut samples_total);
            if buffer.len() >= self.warmup_period {
                break;
            }
            // Wait for the actors to fill the replay buffer up to the warmup
            // period, without busy-polling the channel.
            std::thread::sleep(Duration::from_millis(100));
        }

        info!("Starts training loop");
        loop {
            // Receive samples from the actors and push them into the replay buffer.
            let now = SystemTime::now();
            self.update_replay_buffer(&mut buffer, &mut samples_total);
            self.timer_for_samples += now.elapsed().unwrap();

            // Optimization step
            let mut record = self.train_step(&mut agent, &mut buffer);

            // Evaluation, saving, and synchronization
            self.post_process(&mut agent, evaluator, recorder, &mut record)
                .unwrap();

            // Record average compute costs
            if self.opt_steps % self.record_compute_cost_interval == 0 {
                let (avr_opt_time, avr_sample_time) = self.average_time();
                record.insert("average_opt_time", Scalar(avr_opt_time));
                record.insert("average_sample_time", Scalar(avr_sample_time));
                self.reset_counters();
            }

            if !record.is_empty() {
                recorder.store(record);
            }

            // Flush records
            if (self.opt_steps - 1) % self.flush_records_interval == 0 {
                recorder.flush(self.opt_steps as _);
            }

            // Finish the training loop
            if self.opt_steps == self.max_opts {
                // Tell the actors to stop pushing samples.
                *self.stop.lock().unwrap() = true;

                // Drop the samples remaining in the channel.
                let _: Vec<_> = self.r_bulk_pushed_item.try_iter().collect();
                self.sync(Self::downcast_mut(&mut agent));
                break;
            }
        }
        info!("Stopped training loop");

        let duration = time_total.elapsed().unwrap();
        let time_total = duration.as_secs_f32();
        let samples_per_sec = samples_total as f32 / time_total;
        let opt_per_sec = self.max_opts as f32 / time_total;
        AsyncTrainStat {
            samples_per_sec,
            duration,
            opt_per_sec,
        }
    }
}