pub struct Trainer { /* private fields */ }

Manages the training loop and related objects.
§Training loop
Given an agent implementing Agent and a recorder implementing Recorder, the training loop looks like the following:

1. Initialize the objects used in the training loop, involving instances of Env, StepProcessor, and Sampler:
   - Reset the counter of environment steps: env_steps = 0
   - Reset the counter of optimization steps: opt_steps = 0
   - Reset the objects used for computing optimization steps per second (OSPS):
     - a timer: timer = SystemTime::now()
     - a counter: opt_steps_ops = 0
2. Reset Env.
3. Do an environment step and push a transition to the replay buffer, which implements ReplayBufferBase, then increment the counter: env_steps += 1
4. If env_steps % opt_interval == 0:
   - Do an optimization step for the agent with transition batches sampled from the replay buffer.
     - NOTE: the agent can skip an optimization step for some reason, for example during a warmup period for the replay buffer. In that case, the following sub-steps are skipped as well.
   - opt_steps += 1, opt_steps_ops += 1
   - If opt_steps % eval_interval == 0:
     - Evaluate the agent and add the evaluation result to the recorder as "eval_reward".
     - Reset timer and opt_steps_ops.
     - If the evaluation result is the best so far, the agent's model parameters are saved in the directory (model_dir)/best.
   - If opt_steps % record_interval == 0, compute OSPS as opt_steps_ops / timer.elapsed()?.as_secs_f32() and add it to the recorder as "opt_steps_per_sec".
   - If opt_steps % save_interval == 0, the agent's model parameters are saved in the directory (model_dir)/(opt_steps).
   - If opt_steps == max_opts, finish the training loop.
5. Go back to step 3.
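As a rough, self-contained sketch of the counter and interval bookkeeping described above (the functions env_step, opt_step, and evaluate, as well as training_loop itself, are hypothetical stand-ins rather than this crate's API; environment reset and recorder interaction are omitted):

use std::time::SystemTime;

// Hypothetical stand-ins, not part of this crate's API.
fn env_step() { /* interact with the environment and push a transition */ }
fn opt_step() -> bool { true /* may return false, e.g. while the replay buffer warms up */ }
fn evaluate() -> f32 { 0.0 }

fn training_loop(max_opts: usize, opt_interval: usize, eval_interval: usize, record_interval: usize) {
    // Step 1: counters and the OSPS timer.
    let (mut env_steps, mut opt_steps, mut opt_steps_ops) = (0usize, 0usize, 0usize);
    let mut timer = SystemTime::now();

    loop {
        // Step 3: environment step.
        env_step();
        env_steps += 1;

        // Step 4: optimization every opt_interval environment steps.
        if env_steps % opt_interval != 0 {
            continue;
        }
        if !opt_step() {
            continue; // the agent skipped this optimization step (e.g. warmup)
        }
        opt_steps += 1;
        opt_steps_ops += 1;

        if opt_steps % eval_interval == 0 {
            let _eval_reward = evaluate(); // recorded as "eval_reward"
            opt_steps_ops = 0;
            timer = SystemTime::now();
        }
        if opt_steps % record_interval == 0 {
            let secs = timer.elapsed().map(|d| d.as_secs_f32()).unwrap_or(f32::NAN);
            let _osps = opt_steps_ops as f32 / secs; // recorded as "opt_steps_per_sec"
        }
        if opt_steps == max_opts {
            break;
        }
        // Step 5: go back to step 3 (top of the loop).
    }
}

fn main() {
    training_loop(100, 1, 50, 10);
}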
§Interaction of objects
In the Trainer::train() method, objects interact as shown below:
graph LR
A[Agent]-->|Env::Act|B[Env]
B -->|Env::Obs|A
B -->|"Step<E: Env>"|C[StepProcessor]
C -->|ReplayBufferBase::PushedItem|D[ReplayBufferBase]
D -->|BatchBase|A
- First, Agent emits an Env::Act a_t based on the Env::Obs o_t received from Env. Given a_t, Env changes its state and creates the observation at the next step, o_t+1. This step of interaction between Agent and Env is referred to as an environment step.
- Next, a Step<E: Env> is created with the next observation o_t+1, the reward r_t, and a_t.
- The Step<E: Env> object is processed by StepProcessor to create a ReplayBufferBase::Item, typically representing a transition (o_t, a_t, o_t+1, r_t), where o_t is kept in the StepProcessor while the other items come from the given Step<E: Env>.
- Finally, the transitions pushed to the ReplayBufferBase are used to create batches, each of which implements BatchBase. These batches are used in optimization steps, where the agent updates its parameters using the sampled experiences in the batches.
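To make the data flow concrete, here is a minimal, self-contained sketch of how a step processor can turn a step into a transition; the Step, Transition, and Processor types below are illustrative assumptions, not this crate's Step, StepProcessor, or item types:

use std::mem;

// Illustrative stand-ins only; not the crate's actual types.
struct Step { obs_next: Vec<f32>, act: Vec<f32>, reward: f32 }
struct Transition { obs: Vec<f32>, act: Vec<f32>, obs_next: Vec<f32>, reward: f32 }

// Keeps o_t between environment steps, mirroring how the step processor
// pairs the previous observation with the incoming step.
struct Processor { prev_obs: Vec<f32> }

impl Processor {
    fn process(&mut self, step: Step) -> Transition {
        // o_t comes from the processor's state; o_t+1, a_t, r_t come from the step.
        let obs = mem::replace(&mut self.prev_obs, step.obs_next.clone());
        Transition { obs, act: step.act, obs_next: step.obs_next, reward: step.reward }
    }
}

fn main() {
    let mut proc = Processor { prev_obs: vec![0.0] };
    let step = Step { obs_next: vec![1.0], act: vec![0.5], reward: 1.0 };
    let tr = proc.process(step);
    assert_eq!(tr.obs, vec![0.0]);      // o_t came from the processor's state
    assert_eq!(tr.obs_next, vec![1.0]); // o_t+1 came from the step
    assert_eq!(tr.act, vec![0.5]);
    assert_eq!(tr.reward, 1.0);
}

The point of the sketch is that o_t lives in the processor's state, while a_t, r_t, and o_t+1 arrive with each new step.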
Implementations§
impl Trainer

pub fn build(config: TrainerConfig) -> Self
Constructs a trainer.
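A minimal construction sketch, assuming Trainer and TrainerConfig are in scope; how the TrainerConfig is actually populated (builder methods, defaults, or loading from a file) depends on your setup, so it is left as a placeholder here:

// Sketch only: obtaining the TrainerConfig is left as a placeholder.
let config: TrainerConfig = todo!("construct or load a TrainerConfig for your setup");
let mut trainer = Trainer::build(config);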
pub fn train_step<E, A, R>(
    &mut self,
    agent: &mut A,
    buffer: &mut R,
) -> Result<(Record, bool)>
Performs a training step.
First, it performs an environment step once and pushes a transition into the given buffer with the Sampler. Then, if the number of environment steps reaches the optimization interval opt_interval, it performs an optimization step.

The second value of the returned tuple indicates whether an optimization step was performed (true).
pub fn train<E, A, P, R, D>(
    &mut self,
    env: E,
    step_proc: P,
    agent: &mut A,
    buffer: &mut R,
    recorder: &mut Box<dyn AggregateRecorder>,
    evaluator: &mut D,
) -> Result<()>
where
    E: Env,
    A: Agent<E, R>,
    P: StepProcessor<E>,
    R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase,
    D: Evaluator<E, A>,
Trains the agent online.
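A sketch of how the pieces can be wired together, written as a thin generic helper whose bounds mirror the signature above; it assumes the crate's Trainer, TrainerConfig, Env, Agent, StepProcessor, ExperienceBufferBase, ReplayBufferBase, Evaluator, AggregateRecorder, and Result are in scope (exact import paths depend on your setup):

// A thin wrapper mirroring the bounds of Trainer::train; the concrete types
// for E, A, P, R, and D come from your own environment/agent setup.
fn run_online_training<E, A, P, R, D>(
    config: TrainerConfig,
    env: E,
    step_proc: P,
    agent: &mut A,
    buffer: &mut R,
    recorder: &mut Box<dyn AggregateRecorder>,
    evaluator: &mut D,
) -> Result<()>
where
    E: Env,
    A: Agent<E, R>,
    P: StepProcessor<E>,
    R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase,
    D: Evaluator<E, A>,
{
    // Build the trainer from its configuration, then run the training loop
    // described in the "Training loop" section above.
    let mut trainer = Trainer::build(config);
    trainer.train(env, step_proc, agent, buffer, recorder, evaluator)
}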
pub fn train_offline<E, A, R, D>(
    &mut self,
    agent: &mut A,
    buffer: &mut R,
    recorder: &mut Box<dyn AggregateRecorder>,
    evaluator: &mut D,
) -> Result<()>
Trains the agent offline.
Auto Trait Implementations§
impl Freeze for Trainer
impl RefUnwindSafe for Trainer
impl Send for Trainer
impl Sync for Trainer
impl Unpin for Trainer
impl UnwindSafe for Trainer
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.