border_async_trainer/util.rs

//! Utility function for running asynchronous training.
use crate::{
    actor_stats_fmt, ActorManager, ActorManagerConfig, AsyncTrainer, AsyncTrainerConfig, SyncModel,
};
use border_core::{
    record::Recorder, Agent, Configurable, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase,
    StepProcessor,
};
use crossbeam_channel::unbounded;
use log::info;
use std::sync::{Arc, Mutex};

/// Runs asynchronous training.
///
/// This function runs an [`ActorManager`] and an [`AsyncTrainer`] on separate
/// threads. They communicate over [`crossbeam_channel`]s, and training logs are
/// written through the given [`Recorder`].
///
/// * `agent_config` - Configuration of the agent to be trained.
/// * `agent_configs` - Configurations of the agents used for asynchronous sampling.
///   They must share the structure of the model ([`SyncModel::ModelInfo`]) with the
///   agent being trained, while exploration parameters can differ.
/// * `env_config_train` - Configuration of the environment from which transitions
///   are sampled.
/// * `env_config_eval` - Configuration of the environment on which the agent being
///   trained is evaluated.
/// * `step_proc_config` - Configuration of the [`StepProcessor`] that converts
///   environment steps into replay buffer items.
/// * `replay_buffer_config` - Configuration of the replay buffer.
/// * `actor_man_config` - Configuration of [`ActorManager`].
/// * `async_trainer_config` - Configuration of [`AsyncTrainer`].
/// * `recorder` - Recorder to which training logs are written.
/// * `evaluator` - Evaluator applied to the agent being trained.
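///
/// # Example
///
/// A minimal sketch of a call to this function. `MyAgent`, `MyEnv`,
/// `MyStepProcessor`, `MyReplayBuffer`, and their `*Config` types are
/// hypothetical placeholders for user-defined types implementing the required
/// traits; `recorder`, `evaluator`, `actor_man_config`, and
/// `async_trainer_config` are assumed to be constructed beforehand.
///
/// ```ignore
/// let agent_config = MyAgentConfig::default();
/// // One configuration per actor; exploration parameters may differ per actor.
/// let agent_configs = vec![agent_config.clone(); 4];
///
/// train_async::<MyAgent, MyEnv, MyReplayBuffer, MyStepProcessor>(
///     &agent_config,
///     &agent_configs,
///     &MyEnvConfig::default(), // environment for sampling
///     &MyEnvConfig::default(), // environment for evaluation
///     &MyStepProcessorConfig::default(),
///     &MyReplayBufferConfig::default(),
///     &actor_man_config,
///     &async_trainer_config,
///     &mut recorder,
///     &mut evaluator,
/// );
/// ```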
pub fn train_async<A, E, R, S>(
    agent_config: &A::Config,
    agent_configs: &Vec<A::Config>,
    env_config_train: &E::Config,
    env_config_eval: &E::Config,
    step_proc_config: &S::Config,
    replay_buffer_config: &R::Config,
    actor_man_config: &ActorManagerConfig,
    async_trainer_config: &AsyncTrainerConfig,
    recorder: &mut Box<dyn Recorder<E, R>>,
    evaluator: &mut impl Evaluator<E>,
) where
    A: Agent<E, R> + Configurable + SyncModel + 'static,
    E: Env,
    R: ExperienceBufferBase<Item = S::Output> + ReplayBufferBase + Send + 'static,
    S: StepProcessor<E>,
    A::Config: Send + 'static,
    E::Config: Send + 'static,
    S::Config: Send + 'static,
    R::Item: Send + 'static,
    A::ModelInfo: Send + 'static,
{
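    // Thread layout: the actor manager spawns actor threads that sample
    // transitions and send them through the item channel, while the async
    // trainer consumes those items for training and sends updated model info
    // back through the model channel. The shared `stop` flag tells the actor
    // threads to terminate when training finishes.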
    // Shared flag used to stop the actor threads
    let stop = Arc::new(Mutex::new(false));

    // Creates channels
    let (item_s, item_r) = unbounded(); // items pushed to the replay buffer
    let (model_s, model_r) = unbounded(); // model info for synchronization

    // Guard serializing the initialization of envs across threads
    let guard_init_env = Arc::new(Mutex::new(true));

    // Actor manager and async trainer
    let mut actors = ActorManager::<A, E, R, S>::build(
        actor_man_config,
        agent_configs,
        env_config_train,
        step_proc_config,
        item_s,
        model_r,
        stop.clone(),
    );
    let mut trainer = AsyncTrainer::<A, E, R>::build(
        async_trainer_config,
        agent_config,
        env_config_eval,
        replay_buffer_config,
        item_r,
        model_s,
        stop.clone(),
    );

    // Starts sampling and training
    actors.run(guard_init_env.clone());
    let stats = trainer.train(recorder, evaluator, guard_init_env);
    info!("Stats of async trainer");
    info!("{}", stats.fmt());

    // Stops the actor threads and collects their sampling stats
    let stats = actors.stop_and_join();
    info!("Stats of generated samples in actors");
    info!("{}", actor_stats_fmt(&stats));
}