use crate::{
actor_stats_fmt, ActorManager, ActorManagerConfig, AsyncTrainer, AsyncTrainerConfig, SyncModel,
};
use border_core::{
Agent, DefaultEvaluator, Env, ReplayBufferBase,
StepProcessorBase,
};
use border_tensorboard::TensorboardRecorder;
use crossbeam_channel::unbounded;
use log::info;
use std::{
path::Path,
sync::{Arc, Mutex},
};
/// Runs asynchronous training: a set of actors and one trainer wired together
/// through two unbounded channels.
///
/// The function blocks until the trainer's `train` call returns, then stops and
/// joins the actors and logs the statistics reported by both sides.
///
/// # Arguments
///
/// * `model_dir` - Path handed to [`TensorboardRecorder::new`].
/// * `agent_config` - Agent configuration passed to the [`AsyncTrainer`].
/// * `agent_configs` - Agent configurations passed to the [`ActorManager`]
///   (presumably one per actor; confirm against `ActorManager::build`).
/// * `env_config_train` - Environment configuration passed to the
///   [`ActorManager`].
/// * `env_config_eval` - Environment configuration used both for the
///   [`DefaultEvaluator`] and for the [`AsyncTrainer`].
/// * `step_proc_config` - Step processor configuration passed to the
///   [`ActorManager`].
/// * `replay_buffer_config` - Replay buffer configuration passed to the
///   [`AsyncTrainer`].
/// * `actor_man_config` - Configuration of the [`ActorManager`].
/// * `async_trainer_config` - Configuration of the [`AsyncTrainer`].
///
/// # Panics
///
/// Panics if [`DefaultEvaluator::new`] returns an error.
pub fn train_async<A, E, R, S, P>(
    model_dir: &P,
    agent_config: &A::Config,
    agent_configs: &Vec<A::Config>,
    env_config_train: &E::Config,
    env_config_eval: &E::Config,
    step_proc_config: &S::Config,
    replay_buffer_config: &R::Config,
    actor_man_config: &ActorManagerConfig,
    async_trainer_config: &AsyncTrainerConfig,
) where
    A: Agent<E, R> + SyncModel,
    E: Env,
    R: ReplayBufferBase<PushedItem = S::Output> + Send + 'static,
    S: StepProcessorBase<E>,
    A::Config: Send + 'static,
    E::Config: Send + 'static,
    S::Config: Send + 'static,
    R::PushedItem: Send + 'static,
    A::ModelInfo: Send + 'static,
    P: AsRef<Path>,
{
    // Recorder first, evaluator second. The literal `0` and `1` are forwarded
    // unchanged; their meaning is defined by `DefaultEvaluator::new`.
    let mut recorder = TensorboardRecorder::new(model_dir);
    let mut evaluator = DefaultEvaluator::new(env_config_eval, 0, 1).unwrap();

    // Boolean flag (initially `false`) shared by the actor manager and the
    // trainer; presumably set to request the actor threads to stop.
    let stop_flag = Arc::new(Mutex::new(false));

    // Mutex shared between `actors.run` and `trainer.train`; going by its
    // original name it serializes environment initialization across threads.
    let env_init_guard = Arc::new(Mutex::new(true));

    // Item channel: sending half goes to the actor manager, receiving half to
    // the trainer.
    let (item_tx, item_rx) = unbounded();

    // Model channel: sending half goes to the trainer, receiving half to the
    // actor manager.
    let (model_tx, model_rx) = unbounded();

    let mut actor_manager = ActorManager::<A, E, R, S>::build(
        actor_man_config,
        agent_configs,
        env_config_train,
        step_proc_config,
        item_tx,
        model_rx,
        Arc::clone(&stop_flag),
    );

    let mut async_trainer = AsyncTrainer::<A, E, R>::build(
        async_trainer_config,
        agent_config,
        env_config_eval,
        replay_buffer_config,
        item_rx,
        model_tx,
        Arc::clone(&stop_flag),
    );

    // Start the actors before entering the training call on this thread.
    actor_manager.run(Arc::clone(&env_init_guard));
    let trainer_stats = async_trainer.train(&mut recorder, &mut evaluator, env_init_guard);

    info!("Stats of async trainer");
    info!("{}", trainer_stats.fmt());

    // Training has returned: shut the actors down and report what they produced.
    let actor_stats = actor_manager.stop_and_join();

    info!("Stats of generated samples in actors");
    info!("{}", actor_stats_fmt(&actor_stats));
}