border_async_trainer/async_trainer/base.rs

use crate::{AsyncTrainStat, AsyncTrainerConfig, PushedItemMessage, SyncModel};
use anyhow::Result;
use border_core::{
    record::{Record, RecordValue::Scalar, Recorder},
    Agent, Configurable, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase,
};
use crossbeam_channel::{Receiver, Sender};
use log::{debug, info};
use std::{
    marker::PhantomData,
    ops::{Deref, DerefMut},
    sync::{Arc, Mutex},
    time::{Duration, SystemTime},
};

#[cfg_attr(doc, aquamarine::aquamarine)]
/// Manages the asynchronous training loop on a single machine.
///
/// It interacts with an [`ActorManager`] as shown below:
///
/// ```mermaid
/// flowchart LR
///   subgraph ActorManager
///     E[Actor]-->|ReplayBufferBase::PushedItem|H[ReplayBufferProxy]
///     F[Actor]-->H
///     G[Actor]-->H
///   end
///   K-->|SyncModel::ModelInfo|E
///   K-->|SyncModel::ModelInfo|F
///   K-->|SyncModel::ModelInfo|G
///
///   subgraph I[AsyncTrainer]
///     H-->|PushedItemMessage|J[ReplayBuffer]
///     J-->|ReplayBufferBase::Batch|K[Agent]
///   end
/// ```
///
/// * The [`Agent`] in [`AsyncTrainer`] (left) is trained with batches
///   of type [`ReplayBufferBase::Batch`], which are taken from the replay buffer.
/// * The model parameters of the [`Agent`] in [`AsyncTrainer`] are wrapped in
///   [`SyncModel::ModelInfo`] and periodically sent to the [`Agent`]s in the [`Actor`]s.
///   The [`Agent`] must implement [`SyncModel`] to synchronize its model parameters.
/// * In [`ActorManager`] (right), [`Actor`]s sample transitions, which have type
///   [`ReplayBufferBase::Item`], and push the transitions into
///   [`ReplayBufferProxy`].
/// * [`ReplayBufferProxy`] is parameterized by [`ReplayBufferBase`] and accepts
///   [`ReplayBufferBase::Item`].
/// * The proxy sends the transitions to the replay buffer in the [`AsyncTrainer`].
///
/// [`ActorManager`]: crate::ActorManager
/// [`Actor`]: crate::Actor
/// [`ReplayBufferBase::Item`]: border_core::ExperienceBufferBase::Item
/// [`ReplayBufferBase::Batch`]: border_core::ReplayBufferBase::Batch
/// [`ReplayBufferProxy`]: crate::ReplayBufferProxy
/// [`ReplayBufferBase`]: border_core::ReplayBufferBase
/// [`SyncModel::ModelInfo`]: crate::SyncModel::ModelInfo
/// [`Agent`]: border_core::Agent
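///
/// The following is a minimal, hypothetical sketch of how an [`AsyncTrainer`] might be
/// wired up with the channels it shares with an [`ActorManager`]. `MyAgent`, `MyEnv`,
/// `MyReplayBuffer`, and the configuration bindings are placeholders, not items of this
/// crate, and error handling is omitted:
///
/// ```ignore
/// use std::sync::{Arc, Mutex};
///
/// // Channels shared with the ActorManager side: pushed items flow from the
/// // ReplayBufferProxy to the trainer, model info flows from the trainer to the actors.
/// // `item_s` and `model_info_r` would be handed to the ActorManager (not shown here).
/// let (item_s, item_r) = crossbeam_channel::unbounded();
/// let (model_info_s, model_info_r) = crossbeam_channel::unbounded();
/// let stop = Arc::new(Mutex::new(false));
/// let guard_init_env = Arc::new(Mutex::new(false));
///
/// let mut trainer = AsyncTrainer::<MyAgent, MyEnv, MyReplayBuffer>::build(
///     &async_trainer_config,  // AsyncTrainerConfig
///     &agent_config,          // MyAgent::Config
///     &env_config,            // MyEnv::Config
///     &replay_buffer_config,  // MyReplayBuffer::Config
///     item_r,                 // receives pushed items from the proxy
///     model_info_s,           // sends model parameters to the actors
///     stop.clone(),
/// );
///
/// // `recorder` and `evaluator` are assumed to be built elsewhere.
/// let stats = trainer.train(&mut recorder, &mut evaluator, guard_init_env);
/// ```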
pub struct AsyncTrainer<A, E, R>
where
    A: Agent<E, R> + Configurable + SyncModel,
    E: Env,
    // R: ReplayBufferBase + Sync + Send + 'static,
    R: ExperienceBufferBase + ReplayBufferBase,
    R::Item: Send + 'static,
{
    /// Configuration of [`Env`]. Note that it is used only for evaluation, not for training.
    env_config: E::Config,

    /// Configuration of the replay buffer.
    replay_buffer_config: R::Config,

    /// Interval of recording computational cost in optimization steps.
    record_compute_cost_interval: usize,

    /// Interval of recording agent information in optimization steps.
    record_agent_info_interval: usize,

    /// Interval of flushing records in optimization steps.
    flush_records_interval: usize,

    /// Interval of evaluation in optimization steps.
    eval_interval: usize,

    /// Interval of saving the model in optimization steps.
    save_interval: usize,

    /// The maximum number of optimization steps.
    max_opts: usize,

    /// Warmup period for filling the replay buffer, in environment steps.
    warmup_period: usize,

    /// Interval of synchronizing model parameters in optimization steps.
    sync_interval: usize,

    /// Receiver of pushed items.
    r_bulk_pushed_item: Receiver<PushedItemMessage<R::Item>>,

    /// If set to `true`, stops the actor threads.
    stop: Arc<Mutex<bool>>,

    /// Configuration of [`Agent`].
    agent_config: A::Config,

    /// Sender of model info.
    model_info_sender: Sender<(usize, A::ModelInfo)>,

    /// Counter of samples pushed into the replay buffer.
    samples_counter: usize,

    /// Timer for pushing samples into the replay buffer.
    timer_for_samples: Duration,

    /// Counter of optimization steps.
    opt_steps_counter: usize,

    /// Timer for optimization steps.
    timer_for_opt_steps: Duration,

    /// Maximum evaluation reward observed so far.
    max_eval_reward: f32,

    /// Number of optimization steps performed during training.
    opt_steps: usize,

    phantom: PhantomData<(A, E, R)>,
}

impl<A, E, R> AsyncTrainer<A, E, R>
where
    A: Agent<E, R> + Configurable + SyncModel + 'static,
    E: Env,
    // R: ReplayBufferBase + Sync + Send + 'static,
    R: ExperienceBufferBase + ReplayBufferBase,
    R::Item: Send + 'static,
{
    /// Creates an [`AsyncTrainer`].
    pub fn build(
        config: &AsyncTrainerConfig,
        agent_config: &A::Config,
        env_config: &E::Config,
        replay_buffer_config: &R::Config,
        r_bulk_pushed_item: Receiver<PushedItemMessage<R::Item>>,
        model_info_sender: Sender<(usize, A::ModelInfo)>,
        stop: Arc<Mutex<bool>>,
    ) -> Self {
        Self {
            eval_interval: config.eval_interval,
            max_opts: config.max_opts,
            record_compute_cost_interval: config.record_compute_cost_interval,
            record_agent_info_interval: config.record_agent_info_interval,
            flush_records_interval: config.flush_record_interval,
            save_interval: config.save_interval,
            sync_interval: config.sync_interval,
            warmup_period: config.warmup_period,
            agent_config: agent_config.clone(),
            env_config: env_config.clone(),
            replay_buffer_config: replay_buffer_config.clone(),
            r_bulk_pushed_item,
            model_info_sender,
            stop,
            samples_counter: 0,
            timer_for_samples: Duration::new(0, 0),
            opt_steps_counter: 0,
            timer_for_opt_steps: Duration::new(0, 0),
            max_eval_reward: f32::MIN,
            opt_steps: 0,
            phantom: PhantomData,
        }
    }

    /// Resets the counters.
    fn reset_counters(&mut self) {
        self.samples_counter = 0;
        self.timer_for_samples = Duration::new(0, 0);
        self.opt_steps_counter = 0;
        self.timer_for_opt_steps = Duration::new(0, 0);
    }

    /// Calculates the average time per optimization step and per pushed sample in milliseconds.
    ///
    /// Returns `-1.0` for an average whose counter is zero.
    fn average_time(&mut self) -> (f32, f32) {
        let avr_opt_time = match self.opt_steps_counter {
            0 => -1f32,
            n => self.timer_for_opt_steps.as_millis() as f32 / n as f32,
        };
        let avr_sample_time = match self.samples_counter {
            0 => -1f32,
            n => self.timer_for_samples.as_millis() as f32 / n as f32,
        };
        (avr_opt_time, avr_sample_time)
    }

    #[inline]
    fn downcast_ref(agent: &Box<dyn Agent<E, R>>) -> &A {
        agent.deref().as_any_ref().downcast_ref::<A>().unwrap()
    }

    #[inline]
    fn downcast_mut(agent: &mut Box<dyn Agent<E, R>>) -> &mut A {
        agent.deref_mut().as_any_mut().downcast_mut::<A>().unwrap()
    }

    #[inline]
    fn train_step(&mut self, agent: &mut Box<dyn Agent<E, R>>, buffer: &mut R) -> Record {
        if buffer.len() < self.warmup_period {
            // Skip optimization until the replay buffer has enough samples.
            return Record::empty();
        } else if (self.opt_steps + 1) % self.record_agent_info_interval == 0 {
            // Optimization step with agent information recorded.
            let timer = SystemTime::now();
            let record = agent.opt_with_record(buffer);
            self.opt_steps += 1;
            self.opt_steps_counter += 1;
            self.timer_for_opt_steps += timer.elapsed().unwrap();
            return record;
        } else {
            // Optimization step without recording agent information.
            let timer = SystemTime::now();
            agent.opt(buffer);
            self.opt_steps += 1;
            self.opt_steps_counter += 1;
            self.timer_for_opt_steps += timer.elapsed().unwrap();
            return Record::empty();
        }
    }

    /// Evaluates the agent, saves the best model, and syncs the model.
    fn post_process<D>(
        &mut self,
        agent: &mut Box<dyn Agent<E, R>>,
        evaluator: &mut D,
        recorder: &mut Box<dyn Recorder<E, R>>,
        record: &mut Record,
    ) -> Result<()>
    where
        E: Env,
        R: ReplayBufferBase,
        D: Evaluator<E>,
    {
        // Evaluation
        if self.opt_steps % self.eval_interval == 0 {
            info!("Starts evaluation of the trained model");
            agent.eval();
            let (score, record_eval) = evaluator.evaluate(agent)?;
            agent.train();
            record.merge_inplace(record_eval);

            // Save the best model up to the current iteration
            if score > self.max_eval_reward {
                self.max_eval_reward = score;
                recorder.save_model("best".as_ref(), agent)?;
            }
        }

        // Save the current model
        if (self.save_interval > 0) && (self.opt_steps % self.save_interval == 0) {
            recorder.save_model(format!("{}", self.opt_steps).as_ref(), agent)?;
        }

        // Sync the current model
        if self.opt_steps % self.sync_interval == 0 {
            debug!("Sends the trained model info to ActorManager");
            self.sync(Self::downcast_mut(agent));
        }

        Ok(())
    }

    /// Synchronizes the model by sending the current model info to the actors.
    #[inline]
    fn sync(&mut self, agent: &A) {
        let model_info = agent.model_info();
        // TODO: error handling
        self.model_info_sender.send(model_info).unwrap();
    }

    /// Drains the pushed items from the channel and pushes them into the replay buffer.
    #[inline]
    fn update_replay_buffer(&mut self, buffer: &mut R, samples_total: &mut usize) {
        let msgs: Vec<_> = self.r_bulk_pushed_item.try_iter().collect();
        msgs.into_iter().for_each(|msg| {
            self.samples_counter += msg.pushed_items.len();
            *samples_total += msg.pushed_items.len();
            msg.pushed_items
                .into_iter()
                .for_each(|pushed_item| buffer.push(pushed_item).unwrap())
        });
    }

    /// Runs the training loop.
    ///
    /// In the training loop, the following values will be pushed into the given recorder:
    ///
    /// * `samples_total` - Total number of samples pushed into the replay buffer.
    ///   Here, a "sample" is an item of type [`ExperienceBufferBase::Item`].
    /// * `opt_steps_per_sec` - The number of optimization steps per second.
    /// * `samples_per_sec` - The number of samples per second.
    /// * `samples_per_opt_steps` - The number of samples per optimization step.
    ///
    /// These values will typically be monitored with tensorboard.
    ///
    /// [`ExperienceBufferBase::Item`]: border_core::ExperienceBufferBase::Item
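    ///
    /// The returned [`AsyncTrainStat`] summarizes the run. A hypothetical sketch of
    /// calling this method and reading the statistics (the `trainer`, `recorder`, and
    /// `evaluator` bindings are placeholders, not part of this crate):
    ///
    /// ```ignore
    /// // Runs the asynchronous training loop and collects the run statistics.
    /// let stats = trainer.train(&mut recorder, &mut evaluator, guard_init_env);
    ///
    /// // Fields of AsyncTrainStat, as returned by this method.
    /// println!("samples/sec: {}", stats.samples_per_sec);
    /// println!("opt steps/sec: {}", stats.opt_per_sec);
    /// println!("total duration: {:?}", stats.duration);
    /// ```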
    pub fn train<D>(
        &mut self,
        recorder: &mut Box<dyn Recorder<E, R>>,
        evaluator: &mut D,
        guard_init_env: Arc<Mutex<bool>>,
    ) -> AsyncTrainStat
    where
        D: Evaluator<E>,
    {
        // TODO: error handling
        let _env = {
            let mut tmp = guard_init_env.lock().unwrap();
            *tmp = true;
            E::build(&self.env_config, 0).unwrap()
        };
        let mut agent: Box<dyn Agent<E, R>> = Box::new(A::build(self.agent_config.clone()));
        let mut buffer = R::build(&self.replay_buffer_config);
        agent.train();

        self.reset_counters();
        self.opt_steps = 0;
        self.max_eval_reward = f32::MIN;
        let time_total = SystemTime::now();
        let mut samples_total = 0;

        info!("Sends the initial model info in AsyncTrainer");
        self.sync(Self::downcast_ref(&agent));

        info!("Warmup period");
        loop {
            self.update_replay_buffer(&mut buffer, &mut samples_total);
            if buffer.len() >= self.warmup_period {
                break;
            }
            // Avoid busy-waiting while the actors fill the replay buffer.
            std::thread::sleep(Duration::from_millis(100));
        }

        info!("Starts training loop");
        loop {
            // Update the replay buffer
            let now = SystemTime::now();
            self.update_replay_buffer(&mut buffer, &mut samples_total);
            self.timer_for_samples += now.elapsed().unwrap();

            // Perform optimization step(s)
            let mut record = self.train_step(&mut agent, &mut buffer);

            // Postprocessing after each training step
            self.post_process(&mut agent, evaluator, recorder, &mut record)
                .unwrap(); // TODO: error handling

            // Record average time for optimization steps and sampling steps in milliseconds
            if self.opt_steps % self.record_compute_cost_interval == 0 {
                let (avr_opt_time, avr_sample_time) = self.average_time();
                record.insert("average_opt_time", Scalar(avr_opt_time));
                record.insert("average_sample_time", Scalar(avr_sample_time));
                self.reset_counters();
            }

            // Store the record in the recorder
            if !record.is_empty() {
                recorder.store(record);
            }

            // Flush records
            if (self.opt_steps - 1) % self.flush_records_interval == 0 {
                recorder.flush(self.opt_steps as _);
            }

            // Finish training
            if self.opt_steps == self.max_opts {
                // Stop the actors and flush the channel
                *self.stop.lock().unwrap() = true;
                let _: Vec<_> = self.r_bulk_pushed_item.try_iter().collect();
                self.sync(Self::downcast_mut(&mut agent));
                break;
            }
        }
        info!("Stopped training loop");

        let duration = time_total.elapsed().unwrap();
        let time_total = duration.as_secs_f32();
        let samples_per_sec = samples_total as f32 / time_total;
        let opt_per_sec = self.max_opts as f32 / time_total;
        AsyncTrainStat {
            samples_per_sec,
            duration,
            opt_per_sec,
        }
    }
}