border_async_trainer/
util.rs1use crate::{
3 actor_stats_fmt, ActorManager, ActorManagerConfig, AsyncTrainer, AsyncTrainerConfig, SyncModel,
4};
5use border_core::{
6 record::Recorder, Agent, Configurable, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase,
7 StepProcessor,
8};
9use crossbeam_channel::unbounded;
10use log::info;
11use std::sync::{Arc, Mutex};
12
13pub fn train_async<A, E, R, S>(
32 agent_config: &A::Config,
33 agent_configs: &Vec<A::Config>,
34 env_config_train: &E::Config,
35 env_config_eval: &E::Config,
36 step_proc_config: &S::Config,
37 replay_buffer_config: &R::Config,
38 actor_man_config: &ActorManagerConfig,
39 async_trainer_config: &AsyncTrainerConfig,
40 recorder: &mut Box<dyn Recorder<E, R>>,
41 evaluator: &mut impl Evaluator<E>,
42) where
43 A: Agent<E, R> + Configurable + SyncModel + 'static,
44 E: Env,
45 R: ExperienceBufferBase<Item = S::Output> + Send + 'static + ReplayBufferBase,
46 S: StepProcessor<E>,
47 A::Config: Send + 'static,
48 E::Config: Send + 'static,
49 S::Config: Send + 'static,
50 R::Item: Send + 'static,
51 A::ModelInfo: Send + 'static,
52{
53 let stop = Arc::new(Mutex::new(false));
55
56 let (item_s, item_r) = unbounded(); let (model_s, model_r) = unbounded(); let guard_init_env = Arc::new(Mutex::new(true));
62
63 let mut actors = ActorManager::<A, E, R, S>::build(
65 actor_man_config,
66 agent_configs,
67 env_config_train,
68 step_proc_config,
69 item_s,
70 model_r,
71 stop.clone(),
72 );
73 let mut trainer = AsyncTrainer::<A, E, R>::build(
74 async_trainer_config,
75 agent_config,
76 env_config_eval,
77 replay_buffer_config,
78 item_r,
79 model_s,
80 stop.clone(),
81 );
82
83 actors.run(guard_init_env.clone());
85 let stats = trainer.train(recorder, evaluator, guard_init_env);
86 info!("Stats of async trainer");
87 info!("{}", stats.fmt());
88
89 let stats = actors.stop_and_join();
90 info!("Stats of generated samples in actors");
91 info!("{}", actor_stats_fmt(&stats));
92}