Struct border_core::Trainer
pub struct Trainer<E, P, R>
where
    E: Env,
    P: StepProcessorBase<E>,
    R: ReplayBufferBase<PushedItem = P::Output>,
{ /* private fields */ }
Manages the training loop and related objects.
Training loop
The training loop looks like the following (a rough control-flow sketch in Rust is given after the list):
1. Given an agent implementing `Agent` and a recorder implementing `Recorder`.
2. Initialize the objects used in the training loop, involving instances of `Env`, `StepProcessorBase`, and `SyncSampler`.
   - Reset a counter of the environment steps: `env_steps = 0`
   - Reset a counter of the optimization steps: `opt_steps = 0`
   - Reset the objects for computing optimization steps per sec (OSPS):
     - A timer: `timer = SystemTime::now()`
     - A counter: `opt_steps_ops = 0`
3. Reset `Env`.
4. Do an environment step and push a transition to the replay buffer, implementing `ReplayBufferBase`.
   - `env_steps += 1`
5. If `env_steps % opt_interval == 0`:
   - Do an optimization step for the agent with transition batches sampled from the replay buffer.
     - NOTE: Here, the agent can skip the optimization step for some reason, for example, during a warmup period for the replay buffer. In this case, the following steps are skipped as well.
   - `opt_steps += 1`, `opt_steps_ops += 1`
   - If `opt_steps % eval_interval == 0`:
     - Do an evaluation of the agent and add the evaluation result to the recorder as `"eval_reward"`.
     - Reset `timer` and `opt_steps_ops`.
     - If the evaluation result is the best so far, the agent's model parameters are saved in the directory `(model_dir)/best`.
   - If `opt_steps % record_interval == 0`, compute OSPS as `opt_steps_ops / timer.elapsed()?.as_secs_f32()` and add it to the recorder as `"opt_steps_per_sec"`.
   - If `opt_steps % save_interval == 0`, the agent's model parameters are saved in the directory `(model_dir)/(opt_steps)`.
   - If `opt_steps == max_opts`, finish the training loop.
6. Back to step 3.
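The control flow above can be summarized with the rough sketch below. This is not the implementation of `Trainer::train()`; `LoopConfig` and the closure parameters (`env_step`, `opt_step`, `evaluate`, `record`) are hypothetical placeholders for the agent, environment, replay buffer, and recorder interactions, and the fields of `LoopConfig` mirror the corresponding `TrainerConfig` settings.

```rust
use std::time::SystemTime;

// Hypothetical stand-in for the interval parameters held by TrainerConfig.
struct LoopConfig {
    max_opts: usize,
    opt_interval: usize,
    eval_interval: usize,
    record_interval: usize,
    save_interval: usize,
}

// Control-flow sketch only; the closures stand in for the agent, environment,
// replay buffer, and recorder interactions described in the list above.
// (Steps 1-3, building the objects and resetting the environment, are omitted.)
fn training_loop(
    cfg: &LoopConfig,
    mut env_step: impl FnMut(),          // do an environment step, push a transition
    mut opt_step: impl FnMut() -> bool,  // returns false if the step was skipped (e.g. warmup)
    mut evaluate: impl FnMut() -> f32,   // returns the evaluation reward
    mut record: impl FnMut(&str, f32),   // add a scalar to the recorder
) {
    let mut env_steps = 0;
    let mut opt_steps = 0;
    let mut opt_steps_ops = 0;
    let mut timer = SystemTime::now();
    let mut best_eval = f32::NEG_INFINITY;

    loop {
        // Step 4: environment step.
        env_step();
        env_steps += 1;

        // Step 5: try an optimization step every `opt_interval` environment steps.
        if env_steps % cfg.opt_interval == 0 {
            if !opt_step() {
                continue; // skipped, e.g. during the replay-buffer warmup period
            }
            opt_steps += 1;
            opt_steps_ops += 1;

            if opt_steps % cfg.eval_interval == 0 {
                let eval_reward = evaluate();
                record("eval_reward", eval_reward);
                if eval_reward > best_eval {
                    best_eval = eval_reward;
                    // save the model parameters to (model_dir)/best
                }
                timer = SystemTime::now();
                opt_steps_ops = 0;
            }
            if opt_steps % cfg.record_interval == 0 {
                let osps = opt_steps_ops as f32 / timer.elapsed().unwrap().as_secs_f32();
                record("opt_steps_per_sec", osps);
            }
            if opt_steps % cfg.save_interval == 0 {
                // save the model parameters to (model_dir)/(opt_steps)
            }
            if opt_steps == cfg.max_opts {
                break; // finish the training loop
            }
        }
    }
}
```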
Interaction of objects
In the `Trainer::train()` method, objects interact as shown below:
graph LR
A[Agent]-->|Env::Act|B[Env]
B -->|Env::Obs|A
B -->|"Step<E: Env>"|C[StepProcessor]
C -->|ReplayBufferBase::PushedItem|D[ReplayBufferBase]
D -->|BatchBase|A
- First, `Agent` emits an `Env::Act` `a_t` based on `Env::Obs` `o_t` received from `Env`. Given `a_t`, `Env` changes its state and creates the observation at the next step, `o_t+1`. This step of interaction between `Agent` and `Env` is referred to as an environment step.
- Next, `Step<E: Env>` will be created with the next observation `o_t+1`, reward `r_t`, and `a_t`.
- The `Step<E: Env>` object will be processed by `StepProcessorBase`, which creates a `ReplayBufferBase::PushedItem`, typically representing a transition `(o_t, a_t, o_t+1, r_t)`, where `o_t` is kept in the `StepProcessorBase`, while the other items come from the given `Step<E: Env>`.
- Finally, the transitions pushed to the `ReplayBufferBase` will be used to create batches, each of which implements `BatchBase`. These batches will be used in optimization steps, where the agent updates its parameters using the sampled experiences in the batches.
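The data flow in the diagram can also be illustrated with the stand-alone sketch below. The types and closures here are illustration-only stand-ins, not the actual `border_core` items: `Obs`, `Act`, `Step`, `Transition`, and `Batch` play the roles of `Env::Obs`, `Env::Act`, `Step<E: Env>`, `ReplayBufferBase::PushedItem`, and `BatchBase` respectively.

```rust
// Stand-in types for illustration only; in border_core these are the
// associated types of Env, StepProcessorBase, and ReplayBufferBase.
struct Obs;                                          // Env::Obs      (o_t, o_t+1)
struct Act;                                          // Env::Act      (a_t)
struct Step { next_obs: Obs, act: Act, reward: f32 } // Step<E: Env>
struct Transition;                                   // ReplayBufferBase::PushedItem
struct Batch;                                        // BatchBase

// One pass through the pipeline: Agent -> Env -> StepProcessor -> ReplayBuffer -> Agent.
fn one_cycle(
    o_t: Obs,
    policy: impl Fn(&Obs) -> Act,           // Agent choosing a_t from o_t
    env_step: impl Fn(&Act) -> (Obs, f32),  // Env producing o_t+1 and r_t
    process: impl Fn(Step) -> Transition,   // StepProcessor (keeps o_t internally)
    mut push: impl FnMut(Transition),       // pushing into the replay buffer
    sample: impl Fn(usize) -> Batch,        // sampling a batch from the buffer
    mut optimize: impl FnMut(&Batch),       // Agent updating its parameters
) {
    let a_t = policy(&o_t);                 // Agent -> Env::Act
    let (o_next, r_t) = env_step(&a_t);     // Env -> (next observation, reward)
    let step = Step { next_obs: o_next, act: a_t, reward: r_t };
    let transition = process(step);         // StepProcessor -> PushedItem
    push(transition);                       // into the replay buffer
    let batch = sample(32);                 // replay buffer -> BatchBase
    optimize(&batch);                       // optimization step on the agent
}
```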
Implementations

impl<E, P, R> Trainer<E, P, R>
where
    E: Env,
    P: StepProcessorBase<E>,
    R: ReplayBufferBase<PushedItem = P::Output>,
pub fn build(
    config: TrainerConfig,
    env_config_train: E::Config,
    step_proc_config: P::Config,
    replay_buffer_config: R::Config
) -> Self
Constructs a trainer.
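A minimal sketch of calling this constructor, written as a generic helper so it applies to any concrete environment, step processor, and replay buffer; the import paths and the assumption that the four config values are created elsewhere (e.g. deserialized from files) are assumptions for illustration, not taken from this page.

```rust
// Import paths assumed for illustration; adjust to the actual crate layout.
use border_core::{Env, ReplayBufferBase, StepProcessorBase, Trainer, TrainerConfig};

// Generic helper around Trainer::build.
fn build_trainer<E, P, R>(
    config: TrainerConfig,
    env_config_train: E::Config,
    step_proc_config: P::Config,
    replay_buffer_config: R::Config,
) -> Trainer<E, P, R>
where
    E: Env,
    P: StepProcessorBase<E>,
    R: ReplayBufferBase<PushedItem = P::Output>,
{
    Trainer::build(config, env_config_train, step_proc_config, replay_buffer_config)
}
```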
pub fn train_step<A>(
    &self,
    agent: &mut A,
    buffer: &mut R,
    sampler: &mut SyncSampler<E, P>,
    env_steps: &mut usize
) -> Result<Option<Record>>
where
    A: Agent<E, R>,
Performs a training step.
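A sketch of driving `train_step` from a custom loop; the import paths, the assumption that `Result` here is `anyhow::Result`, and the choice of simply discarding the returned `Record` are assumptions for illustration.

```rust
// Import paths assumed for illustration; adjust to the actual crate layout.
use anyhow::Result;
use border_core::{
    record::Record, Agent, Env, ReplayBufferBase, StepProcessorBase, SyncSampler, Trainer,
};

fn run_steps<E, P, R, A>(
    trainer: &Trainer<E, P, R>,
    agent: &mut A,
    buffer: &mut R,
    sampler: &mut SyncSampler<E, P>,
    n: usize,
) -> Result<()>
where
    E: Env,
    P: StepProcessorBase<E>,
    R: ReplayBufferBase<PushedItem = P::Output>,
    A: Agent<E, R>,
{
    let mut env_steps = 0;
    for _ in 0..n {
        // `None` means the step produced nothing worth recording
        // (e.g. no optimization step was performed).
        if let Some(record) = trainer.train_step(agent, buffer, sampler, &mut env_steps)? {
            let _record: Record = record; // hand off to a Recorder or log it here
        }
    }
    Ok(())
}
```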
Auto Trait Implementations

impl<E, P, R> RefUnwindSafe for Trainer<E, P, R> where <E as Env>::Config: RefUnwindSafe, <P as StepProcessorBase<E>>::Config: RefUnwindSafe, <R as ReplayBufferBase>::Config: RefUnwindSafe,
impl<E, P, R> Send for Trainer<E, P, R> where <E as Env>::Config: Send, <P as StepProcessorBase<E>>::Config: Send, <R as ReplayBufferBase>::Config: Send,
impl<E, P, R> Sync for Trainer<E, P, R> where <E as Env>::Config: Sync, <P as StepProcessorBase<E>>::Config: Sync, <R as ReplayBufferBase>::Config: Sync,
impl<E, P, R> Unpin for Trainer<E, P, R> where <E as Env>::Config: Unpin, <P as StepProcessorBase<E>>::Config: Unpin, <R as ReplayBufferBase>::Config: Unpin,
impl<E, P, R> UnwindSafe for Trainer<E, P, R> where <E as Env>::Config: UnwindSafe, <P as StepProcessorBase<E>>::Config: UnwindSafe, <R as ReplayBufferBase>::Config: UnwindSafe,
Blanket Implementations

impl<T> BorrowMut<T> for T
where
    T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.