Struct border_async_trainer::AsyncTrainer
pub struct AsyncTrainer<A, E, R>
where
    A: Agent<E, R> + SyncModel,
    E: Env,
    R: ReplayBufferBase,
    R::PushedItem: Send + 'static,
{ /* private fields */ }
Manages the asynchronous training loop on a single machine.
It interacts with ActorManager as shown below:
flowchart LR
subgraph ActorManager
E[Actor]-->|ReplayBufferBase::PushedItem|H[ReplayBufferProxy]
F[Actor]-->H
G[Actor]-->H
end
K-->|SyncModel::ModelInfo|E
K-->|SyncModel::ModelInfo|F
K-->|SyncModel::ModelInfo|G
subgraph I[AsyncTrainer]
H-->|PushedItemMessage|J[ReplayBuffer]
J-->|ReplayBufferBase::Batch|K[Agent]
end
- In ActorManager (right), Actors sample transitions of type ReplayBufferBase::PushedItem in parallel and push them into ReplayBufferProxy. Note that ReplayBufferProxy is parameterized by a ReplayBufferBase type, and the proxy accepts ReplayBufferBase::PushedItem.
- The proxy forwards the transitions to the replay buffer, a type implementing ReplayBufferBase, inside AsyncTrainer.
- The Agent in AsyncTrainer trains its model parameters using batches of type ReplayBufferBase::Batch taken from the replay buffer.
- The model parameters of the Agent in AsyncTrainer are wrapped in SyncModel::ModelInfo and periodically sent to the Agents in the Actors. Agent must therefore implement SyncModel to synchronize its model.
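The data flow described above can be sketched with standard-library threads and channels. This is a toy analogue under loud assumptions, not border's implementation: the hypothetical `run_toy` function plays the role of `AsyncTrainer`, each actor is a plain thread, a bulk push is a `Vec<f64>`, the replay buffer is a `Vec`, and the "model" is a single `f64` sent back as `(optimization step, model)` pairs (treating the `usize` as the step count is an assumption). Unlike `build`, which takes a single `model_info_sender`, the sketch keeps one channel per actor for simplicity.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::{Arc, Mutex};
use std::thread;

/// Toy asynchronous trainer (hypothetical, not border's API).
/// Returns `(opt_steps, samples_total, model_syncs)`.
pub fn run_toy(num_actors: usize, max_opt_steps: usize, sync_interval: usize) -> (usize, usize, usize) {
    assert!(sync_interval > 0, "sync_interval must be positive");
    let stop = Arc::new(Mutex::new(false));
    // One channel carries bulk pushes from all actors to the trainer.
    let (bulk_tx, bulk_rx): (Sender<Vec<f64>>, Receiver<Vec<f64>>) = channel();
    let mut model_txs: Vec<Sender<(usize, f64)>> = Vec::new();
    let mut handles = Vec::new();

    for id in 0..num_actors {
        // Per-actor channel carrying (opt_step, model) pairs back to the actor.
        let (model_tx, model_rx) = channel::<(usize, f64)>();
        model_txs.push(model_tx);
        let (tx, actor_stop) = (bulk_tx.clone(), Arc::clone(&stop));
        handles.push(thread::spawn(move || {
            let mut weight = 0.0; // the actor's local copy of the model
            while !*actor_stop.lock().unwrap() {
                // Adopt the most recent model, if the trainer sent any.
                while let Ok((_step, w)) = model_rx.try_recv() {
                    weight = w;
                }
                // "Sample" four transitions with the current model and push them in bulk.
                if tx.send(vec![weight + id as f64; 4]).is_err() {
                    break;
                }
            }
        }));
    }
    drop(bulk_tx); // only the actors' clones remain

    let mut buffer: Vec<f64> = Vec::new(); // stands in for the replay buffer
    let (mut opt_steps, mut samples_total, mut syncs) = (0, 0, 0);
    while opt_steps < max_opt_steps {
        // Block until some actor pushes a bulk of transitions.
        let items = bulk_rx.recv().expect("actors are still running");
        samples_total += items.len();
        buffer.extend(items);
        // One "optimization step" per push once a batch of 8 is available.
        if buffer.len() >= 8 {
            let batch = &buffer[buffer.len() - 8..];
            let weight = batch.iter().sum::<f64>() / batch.len() as f64;
            opt_steps += 1;
            if opt_steps % sync_interval == 0 {
                // Periodically send the model to every actor.
                for tx in &model_txs {
                    let _ = tx.send((opt_steps, weight));
                }
                syncs += 1;
            }
        }
    }

    // Signal the actors to stop, then wait for them.
    *stop.lock().unwrap() = true;
    for h in handles {
        h.join().unwrap();
    }
    (opt_steps, samples_total, syncs)
}
```

For example, `run_toy(2, 10, 5)` returns `(10, 44, 2)`: the first optimization step needs two pushes of four samples, every later push enables one more step, and the model is sent at steps 5 and 10. The counts are deterministic even though the actors run concurrently, because the trainer stops after a fixed number of steps and only counts what it has received.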
Implementations
impl<A, E, R> AsyncTrainer<A, E, R>
where
    A: Agent<E, R> + SyncModel,
    E: Env,
    R: ReplayBufferBase,
    R::PushedItem: Send + 'static,
pub fn build(
    config: &AsyncTrainerConfig,
    agent_config: &A::Config,
    env_config: &E::Config,
    replay_buffer_config: &R::Config,
    r_bulk_pushed_item: Receiver<PushedItemMessage<R::PushedItem>>,
    model_info_sender: Sender<(usize, A::ModelInfo)>,
    stop: Arc<Mutex<bool>>
) -> Self
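Creates an AsyncTrainer from the trainer, agent, environment, and replay buffer configurations. The remaining arguments connect it to the actors: r_bulk_pushed_item receives the transitions pushed by the actors (wrapped in PushedItemMessage), model_info_sender sends model parameters (SyncModel::ModelInfo) back to them, and stop is a shared stop flag.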
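The last three arguments of `build` are the trainer's ends of its communication with the actors. A minimal wiring sketch, assuming `std::sync::mpsc` channels (border's actual `Sender`/`Receiver` types may come from another crate); `BulkPush`, `TrainerEnds`, `ActorEnds`, and `wire` are hypothetical names, and `BulkPush` merely stands in for `PushedItemMessage`, whose real fields may differ.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::{Arc, Mutex};

/// Hypothetical stand-in for `PushedItemMessage<T>`; the real fields may differ.
pub struct BulkPush<T> {
    pub actor_id: usize,
    pub items: Vec<T>,
}

/// Ends handed to the trainer, named after `build`'s last three parameters.
pub struct TrainerEnds<T, M> {
    pub r_bulk_pushed_item: Receiver<BulkPush<T>>,
    pub model_info_sender: Sender<(usize, M)>,
    pub stop: Arc<Mutex<bool>>,
}

/// Matching ends kept by whatever runs the actors.
pub struct ActorEnds<T, M> {
    pub s_bulk_pushed_item: Sender<BulkPush<T>>,
    pub model_info_receiver: Receiver<(usize, M)>,
    pub stop: Arc<Mutex<bool>>,
}

/// Creates both channels and a stop flag shared by the two sides.
pub fn wire<T, M>() -> (TrainerEnds<T, M>, ActorEnds<T, M>) {
    let (s_items, r_items) = channel();
    let (s_model, r_model) = channel();
    let stop = Arc::new(Mutex::new(false));
    let trainer = TrainerEnds {
        r_bulk_pushed_item: r_items,
        model_info_sender: s_model,
        stop: Arc::clone(&stop),
    };
    let actors = ActorEnds {
        s_bulk_pushed_item: s_items,
        model_info_receiver: r_model,
        stop,
    };
    (trainer, actors)
}
```

In this sketch the `TrainerEnds` fields are what would be passed to `build` alongside the configs, while the `ActorEnds` go to the component that runs the actors; both sides observe the same stop flag because they share one `Arc`.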
pub fn train<D>(
    &mut self,
    recorder: &mut impl Recorder,
    evaluator: &mut D,
    guard_init_env: Arc<Mutex<bool>>
) -> AsyncTrainStat
where
    D: Evaluator<E, A>,
Runs the training loop.
During the loop, the following values are pushed into the given recorder:
- samples_total: Total number of samples pushed into the replay buffer. Here, a “sample” is an item of type ExperienceBufferBase::PushedItem.
- opt_steps_per_sec: Number of optimization steps per second.
- samples_per_sec: Number of samples per second.
- samples_per_opt_steps: Number of samples per optimization step.
These values are typically monitored with TensorBoard.
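The three rates are plain ratios of counters over a time window. The sketch below assumes the counts are taken over one recording interval of `elapsed_secs` seconds and that `samples_total` is cumulative; border may compute them differently, and `ThroughputStats` and `throughput` are hypothetical names.

```rust
/// Quantities recorded during training (hypothetical helper, not border's code).
#[derive(Debug, PartialEq)]
pub struct ThroughputStats {
    pub samples_total: usize,
    pub opt_steps_per_sec: f64,
    pub samples_per_sec: f64,
    pub samples_per_opt_steps: f64,
}

/// Computes the rates from `samples` and `opt_steps` counted over
/// `elapsed_secs` seconds; `samples_total` is passed through unchanged.
/// Zero denominators yield 0.0 instead of infinity or NaN.
pub fn throughput(samples_total: usize, samples: usize, opt_steps: usize, elapsed_secs: f64) -> ThroughputStats {
    let ratio = |num: f64, den: f64| if den > 0.0 { num / den } else { 0.0 };
    ThroughputStats {
        samples_total,
        opt_steps_per_sec: ratio(opt_steps as f64, elapsed_secs),
        samples_per_sec: ratio(samples as f64, elapsed_secs),
        samples_per_opt_steps: ratio(samples as f64, opt_steps as f64),
    }
}
```

For instance, 400 samples and 50 optimization steps in 2 seconds give 25 steps/s, 200 samples/s, and 8 samples per optimization step.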
Auto Trait Implementations
impl<A, E, R> RefUnwindSafe for AsyncTrainer<A, E, R> where A: RefUnwindSafe, E: RefUnwindSafe, R: RefUnwindSafe, <A as Policy<E>>::Config: RefUnwindSafe, <E as Env>::Config: RefUnwindSafe, <R as ReplayBufferBase>::Config: RefUnwindSafe,
impl<A, E, R> Send for AsyncTrainer<A, E, R> where A: Send, E: Send, R: Send, <A as Policy<E>>::Config: Send, <E as Env>::Config: Send, <R as ReplayBufferBase>::Config: Send, <A as SyncModel>::ModelInfo: Send,
impl<A, E, R> Sync for AsyncTrainer<A, E, R> where A: Sync, E: Sync, R: Sync, <A as Policy<E>>::Config: Sync, <E as Env>::Config: Sync, <R as ReplayBufferBase>::Config: Sync, <A as SyncModel>::ModelInfo: Send,
impl<A, E, R> Unpin for AsyncTrainer<A, E, R> where A: Unpin, E: Unpin, R: Unpin, <A as Policy<E>>::Config: Unpin, <E as Env>::Config: Unpin, <R as ReplayBufferBase>::Config: Unpin, <R as ExperienceBufferBase>::PushedItem: Unpin,
impl<A, E, R> UnwindSafe for AsyncTrainer<A, E, R> where A: UnwindSafe, E: UnwindSafe, R: UnwindSafe, <A as Policy<E>>::Config: UnwindSafe, <E as Env>::Config: UnwindSafe, <R as ReplayBufferBase>::Config: UnwindSafe,
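The compiler derives these auto-trait bounds from the types the trainer stores. `build` hands the trainer a `Sender<(usize, A::ModelInfo)>`, and a channel sender is `Send` only if its message type is `Send`, which is presumably why `<A as SyncModel>::ModelInfo: Send` appears above; the `R::PushedItem: Send + 'static` bound on the struct presumably serves the same purpose for pushed items crossing threads. The sketch below uses a hypothetical `TrainerLike` struct holding std channel ends (border's internals may differ) to show the practical consequence: when `Send` is derived, the whole value can be moved into a worker thread.

```rust
use std::sync::mpsc::{Receiver, Sender};
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical stand-in holding the kinds of fields suggested by `build`'s
/// parameters: a receiver of pushed items, a model-info sender, and a stop flag.
pub struct TrainerLike<T, M> {
    pub items: Receiver<Vec<T>>,
    pub model_info: Sender<(usize, M)>,
    pub stop: Arc<Mutex<bool>>,
}

/// Compiles only when `X: Send`; useful as a compile-time check.
pub fn assert_send<X: Send>() {}

/// Moves `trainer` into a worker thread (allowed only because it is `Send`),
/// drains all pushed items, reports the count as model info, and raises `stop`.
pub fn run_in_thread<T, M>(trainer: TrainerLike<T, M>, info: M) -> thread::JoinHandle<usize>
where
    T: Send + 'static,
    M: Send + 'static,
{
    thread::spawn(move || {
        // Iteration ends once every sender of pushed items has been dropped.
        let n: usize = trainer.items.iter().map(|bulk| bulk.len()).sum();
        let _ = trainer.model_info.send((n, info));
        *trainer.stop.lock().unwrap() = true;
        n
    })
}
```

With a non-`Send` message type such as `std::rc::Rc<u32>`, both `assert_send::<TrainerLike<std::rc::Rc<u32>, String>>()` and the `thread::spawn` call would be rejected at compile time.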
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.