Crate border_async_trainer
Asynchronous trainer with parallel sampling processes.
The training code might look like the following. This is a sketch rather than a compiled doctest: agent_config(), env_config(), load_replay_buffer_config(), and load_async_trainer_config() are user-defined helpers, and name and model_dir are assumed to be in scope.
use anyhow::Result;
use crossbeam_channel::unbounded;
use std::sync::{Arc, Mutex};

fn train() -> Result<()> {
    let agent_configs: Vec<_> = vec![agent_config()];
    let env_config_train = env_config(name);
    let env_config_eval = env_config(name).eval();
    let replay_buffer_config = load_replay_buffer_config(model_dir.as_str())?;
    let step_proc_config = SimpleStepProcessorConfig::default();
    let actor_man_config = ActorManagerConfig::default();
    let async_trainer_config = load_async_trainer_config(model_dir.as_str())?;
    let mut recorder = TensorboardRecorder::new(model_dir);
    let mut evaluator = Evaluator::new(&env_config_eval, 0, 1)?;

    // Shared flag to stop actor threads
    let stop = Arc::new(Mutex::new(false));

    // Creates channels
    let (item_s, item_r) = unbounded(); // items pushed to the replay buffer
    let (model_s, model_r) = unbounded(); // model info for synchronization

    // Guard for initialization of envs in multiple threads
    let guard_init_env = Arc::new(Mutex::new(true));

    // Actor manager and async trainer
    let mut actors = ActorManager::build(
        &actor_man_config,
        &agent_configs,
        &env_config_train,
        &step_proc_config,
        item_s,
        model_r,
        stop.clone(),
    );
    let mut trainer = AsyncTrainer::build(
        &async_trainer_config,
        &agent_configs[0],
        &env_config_eval,
        &replay_buffer_config,
        item_r,
        model_s,
        stop.clone(),
    );

    // Limit tch's intra-op threads so they do not contend with actor threads
    tch::set_num_threads(1);

    // Starts sampling and training
    actors.run(guard_init_env.clone());
    let stats = trainer.train(&mut recorder, &mut evaluator, guard_init_env);
    println!("Stats of async trainer");
    println!("{}", stats.fmt());
    let stats = actors.stop_and_join();
    println!("Stats of generated samples in actors");
    println!("{}", actor_stats_fmt(&stats));
    Ok(())
}
The training process consists of the following two components:

- ActorManager manages Actors, each of which runs a thread in which an Agent interacts with an Env to collect samples. Those samples are sent to the replay buffer in AsyncTrainer.
- AsyncTrainer is responsible for training the agent. It also runs a thread that pushes the samples received from ActorManager into a replay buffer (a minimal sketch of this sample flow follows the list).
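The sample flow between the two components is an ordinary producer/consumer pattern over channels. The following self-contained sketch mimics that wiring without the crate's types (the thread and channel layout is an assumption modeled on the example above, not the crate's internals): producer threads stand in for Actors, a consumer thread stands in for the trainer-side thread that fills the replay buffer, and a shared flag stops the producers.

use std::sync::{Arc, Mutex};
use std::thread;

use crossbeam_channel::unbounded;

fn main() {
    let (item_s, item_r) = unbounded::<u64>();
    let stop = Arc::new(Mutex::new(false));

    // Producer threads stand in for Actors generating samples.
    let producers: Vec<_> = (0..2u64)
        .map(|id| {
            let item_s = item_s.clone();
            let stop = stop.clone();
            thread::spawn(move || {
                let mut n = 0u64;
                while !*stop.lock().unwrap() {
                    // Stop if the consumer has hung up.
                    if item_s.send(id * 1_000_000 + n).is_err() {
                        break;
                    }
                    n += 1;
                }
            })
        })
        .collect();
    drop(item_s); // producers own their clones of the sender

    // Consumer thread stands in for the AsyncTrainer-side thread
    // that pushes received samples into the replay buffer.
    let consumer = thread::spawn(move || {
        let mut buffer = Vec::new();
        while let Ok(item) = item_r.recv() {
            buffer.push(item);
            if buffer.len() >= 10_000 {
                break; // enough samples for this demo
            }
        }
        buffer.len()
    });

    let n = consumer.join().unwrap();
    *stop.lock().unwrap() = true; // signal producers to stop
    for p in producers {
        p.join().unwrap();
    }
    println!("collected {n} samples");
}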
An Agent must implement the SyncModel trait in order to synchronize the model of the agent in an Actor with the trained agent in AsyncTrainer. The trait provides the ability to import and export the information of the model as SyncModel::ModelInfo.
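To make the import/export idea concrete, here is a minimal sketch of what such an implementation might look like. The trait definition below is an illustrative stand-in, not the crate's actual SyncModel signature, and LinearAgent is a made-up agent; consult the SyncModel docs for the real interface.

// Illustrative stand-in for border_async_trainer::SyncModel;
// the real trait's method names and signatures may differ.
trait SyncModel {
    type ModelInfo;
    // Exports the model information (paired with a version counter here).
    fn model_info(&self) -> (usize, Self::ModelInfo);
    // Imports model information produced by the trained agent.
    fn sync_model(&mut self, model_info: &Self::ModelInfo);
}

// A made-up agent whose "model" is a flat weight vector.
struct LinearAgent {
    weights: Vec<f32>,
    version: usize,
}

impl SyncModel for LinearAgent {
    type ModelInfo = Vec<f32>;

    fn model_info(&self) -> (usize, Self::ModelInfo) {
        (self.version, self.weights.clone())
    }

    fn sync_model(&mut self, model_info: &Self::ModelInfo) {
        self.weights = model_info.clone();
        self.version += 1;
    }
}

In this pattern, the trained agent periodically exports its ModelInfo and sends it over the model channel, and each Actor imports it so that its sampling agent follows the latest parameters.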
The Agent in AsyncTrainer is responsible for training, typically with a GPU, while the Agents in the Actors in ActorManager are responsible for sampling using CPUs.
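As an illustration of this device placement with the tch backend: how the crate assigns devices is determined by your agent configurations, so this self-contained sketch only shows the placement itself and the device-to-device copy that synchronizing a model implies.

use tch::{Device, Kind, Tensor};

fn main() {
    // Training side: prefer a GPU when one is available.
    let trainer_device = Device::cuda_if_available();
    // Sampling side: actors stay on the CPU.
    let actor_device = Device::Cpu;

    // Parameters of the trained agent live on the training device...
    let weights = Tensor::randn(&[4, 4], (Kind::Float, trainer_device));
    // ...and are copied to the CPU when synchronized to a sampling agent.
    let actor_weights = weights.to_device(actor_device);
    println!("{:?} -> {:?}", weights.device(), actor_weights.device());
}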
Both AsyncTrainer and ActorManager run on the same machine and communicate through channels.
Modules
- Utility function.
Structs

- ActorManager: Manages Actors.
- ActorManagerConfig: Configuration of ActorManager.
- ActorStat: Stats of the sampling process in each Actor.
- AsyncTrainStat: Stats of AsyncTrainer::train().
- AsyncTrainer: Manages the asynchronous training loop on a single machine.
- AsyncTrainerConfig: Configuration of AsyncTrainer.
- PushedItemMessage: Message containing a ReplayBufferBase::PushedItem.
- ReplayBufferProxy: A wrapper of the replay buffer for the asynchronous trainer.
- ReplayBufferProxyConfig: Configuration of ReplayBufferProxy.
Enums
Traits
- SyncModel: Synchronizes the model of the agent in asynchronous training.
Functions
- actor_stats_fmt: Returns a formatted string of a set of ActorStats for reporting.