pub struct Trainer { /* private fields */ }
Manages the training loop and related objects.
§Training loop

The training loop proceeds as follows (a schematic sketch of the control flow appears after this list):

1. Given an agent implementing Agent and a recorder implementing Recorder.
2. Initialize the objects used in the training loop, involving instances of Env, StepProcessor, and Sampler:
   - Reset a counter of the environment steps: env_steps = 0
   - Reset a counter of the optimization steps: opt_steps = 0
   - Reset the objects for computing optimization steps per second (OSPS):
     - a timer: timer = SystemTime::now()
     - a counter: opt_steps_ops = 0
3. Reset Env.
4. Do an environment step and push a transition to the replay buffer, implementing ReplayBufferBase, then env_steps += 1.
5. If env_steps % opt_interval == 0:
   - Do an optimization step for the agent with transition batches sampled from the replay buffer, then opt_steps += 1, opt_steps_ops += 1.
     - NOTE: the agent may skip an optimization step for some reason, for example during a warmup period for the replay buffer. In that case, the following sub-steps are skipped as well.
   - If opt_steps % eval_interval == 0:
     - Evaluate the agent and add the evaluation result to the recorder as "eval_reward".
     - Reset timer and opt_steps_ops.
     - If the evaluation result is the best so far, the agent's model parameters are saved in the directory (model_dir)/best.
   - If opt_steps % record_interval == 0, compute OSPS as opt_steps_ops / timer.elapsed()?.as_secs_f32() and add it to the recorder as "opt_steps_per_sec".
   - If opt_steps % save_interval == 0, the agent's model parameters are saved in the directory (model_dir)/(opt_steps).
   - If opt_steps == max_opts, finish the training loop.
6. Go back to step 4.
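The following is a minimal, self-contained sketch of this control flow. The SketchEnv and SketchAgent traits and the sketch_train function are hypothetical stand-ins introduced only for illustration; the real loop is driven by Trainer::train with the crate's own traits.

use std::time::SystemTime;

// Hypothetical stand-in traits, introduced only to make the control flow above
// concrete; they are not part of this crate.
trait SketchEnv {
    fn reset(&mut self);
    fn step(&mut self) -> f32;
}

trait SketchAgent {
    // Returns None when the optimization step is skipped (e.g. replay-buffer warmup).
    fn opt_step(&mut self) -> Option<f32>;
    fn eval(&mut self) -> f32;
}

fn sketch_train<E: SketchEnv, A: SketchAgent>(
    env: &mut E,
    agent: &mut A,
    opt_interval: usize,
    eval_interval: usize,
    record_interval: usize,
    max_opts: usize,
) {
    // Step 2: counters and the OSPS timer.
    let (mut env_steps, mut opt_steps, mut opt_steps_ops) = (0usize, 0usize, 0usize);
    let mut timer = SystemTime::now();

    // Step 3: reset the environment.
    env.reset();

    loop {
        // Step 4: environment step; pushing the transition to the replay buffer is elided.
        let _reward = env.step();
        env_steps += 1;

        // Step 5: optimization and bookkeeping.
        if env_steps % opt_interval == 0 {
            // The agent may skip the optimization step, e.g. during warmup.
            if let Some(_loss) = agent.opt_step() {
                opt_steps += 1;
                opt_steps_ops += 1;

                if opt_steps % eval_interval == 0 {
                    let _eval_reward = agent.eval(); // recorded as "eval_reward"
                    timer = SystemTime::now();
                    opt_steps_ops = 0;
                }
                if opt_steps % record_interval == 0 {
                    let elapsed = timer.elapsed().unwrap().as_secs_f32();
                    let _osps = opt_steps_ops as f32 / elapsed; // "opt_steps_per_sec"
                }
                if opt_steps == max_opts {
                    break; // finish the training loop
                }
            }
        }
        // Step 6: back to the environment step.
    }
}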
§Interaction of objects

In the Trainer::train() method, the objects interact as shown below:
graph LR
A[Agent]-->|Env::Act|B[Env]
B -->|Env::Obs|A
B -->|"Step<E: Env>"|C[StepProcessor]
C -->|ReplayBufferBase::PushedItem|D[ReplayBufferBase]
D -->|BatchBase|A
1. First, Agent emits an Env::Act a_t based on the Env::Obs o_t received from Env. Given a_t, Env changes its state and produces the observation at the next step, o_t+1. This step of interaction between Agent and Env is referred to as an environment step.
2. Next, a Step<E: Env> object is created with the next observation o_t+1, the reward r_t, and a_t.
3. The Step<E: Env> object is processed by StepProcessor, which creates a ReplayBufferBase::Item, typically representing a transition (o_t, a_t, o_t+1, r_t), where o_t is kept in the StepProcessor while the other elements come from the given Step<E: Env> (see the sketch after this list).
4. Finally, the transitions pushed to the ReplayBufferBase are used to create batches, each implementing BatchBase. These batches are used in optimization steps, where the agent updates its parameters using the sampled experiences.
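Below is a minimal sketch of the step-processing stage in item 3, assuming a processor that caches the previous observation. The Obs, Act, Transition, and SketchStepProcessor types are hypothetical and not part of this crate; the real StepProcessor and replay buffer are trait-based and generic over the environment.

// Hypothetical types for illustration only; the crate's StepProcessor and
// ReplayBufferBase are trait-based and generic over the environment.
#[derive(Clone)]
struct Obs(Vec<f32>);

#[derive(Clone)]
struct Act(Vec<f32>);

// A transition (o_t, a_t, o_t+1, r_t) as assembled from two consecutive steps.
struct Transition {
    obs: Obs,
    act: Act,
    next_obs: Obs,
    reward: f32,
}

// Keeps the previous observation o_t so that, given the step carrying
// (o_t+1, r_t, a_t), it can emit a full transition for the replay buffer.
struct SketchStepProcessor {
    prev_obs: Option<Obs>,
}

impl SketchStepProcessor {
    fn new(initial_obs: Obs) -> Self {
        Self { prev_obs: Some(initial_obs) }
    }

    // Called once per environment step.
    fn process(&mut self, act: Act, next_obs: Obs, reward: f32) -> Transition {
        let obs = self.prev_obs.take().expect("processor must be reset with an observation");
        // o_t+1 becomes o_t of the next transition.
        self.prev_obs = Some(next_obs.clone());
        Transition { obs, act, next_obs, reward }
    }
}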
§Implementations

impl Trainer
pub fn build(config: TrainerConfig) -> Self

Constructs a trainer.
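A hedged sketch of building a Trainer from a TrainerConfig. The builder-style setters below (max_opts, opt_interval, eval_interval, record_interval, save_interval, model_dir) mirror the parameters described in the training loop, but their exact names are an assumption and may differ in your version of the crate.

use border_core::{Trainer, TrainerConfig};

fn build_trainer() -> Trainer {
    // Assumption: TrainerConfig offers builder-style setters named after the
    // parameters in the training-loop section; the exact names and defaults
    // may differ in your version of the crate.
    let config = TrainerConfig::default()
        .max_opts(1_000_000)
        .opt_interval(1)
        .eval_interval(10_000)
        .record_interval(10_000)
        .save_interval(100_000)
        .model_dir("model/example");

    Trainer::build(config)
}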
pub fn train_step<E, A, R>(
    &mut self,
    agent: &mut A,
    buffer: &mut R,
) -> Result<(Record, bool)>

Performs a training step.

First, it performs an environment step once and pushes a transition into the given buffer with Sampler. Then, if the number of environment steps reaches the optimization interval opt_interval, it performs an optimization step.

The second element of the returned tuple indicates whether an optimization step was performed (true).
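The sketch below illustrates how the returned (Record, bool) tuple is meant to be consumed when driving training manually. SketchTrainer and SketchRecord are placeholders that only reproduce the return shape of train_step; they are not types from this crate.

use anyhow::Result;

// Stand-ins with the same return shape as Trainer::train_step, used only to
// show how the (Record, bool) tuple is consumed; Record is replaced by a
// placeholder so the sketch stays self-contained.
struct SketchRecord;

struct SketchTrainer {
    env_steps: usize,
    opt_interval: usize,
}

impl SketchTrainer {
    fn train_step(&mut self) -> Result<(SketchRecord, bool)> {
        self.env_steps += 1;
        // The bool is true only when an optimization step was actually performed.
        Ok((SketchRecord, self.env_steps % self.opt_interval == 0))
    }
}

fn main() -> Result<()> {
    let mut trainer = SketchTrainer { env_steps: 0, opt_interval: 4 };
    let mut opt_steps = 0;
    for _ in 0..16 {
        let (_record, did_opt) = trainer.train_step()?;
        if did_opt {
            opt_steps += 1;
        }
    }
    println!("performed {opt_steps} optimization steps");
    Ok(())
}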
pub fn train<E, A, P, R, D>(
    &mut self,
    env: E,
    step_proc: P,
    agent: &mut A,
    buffer: &mut R,
    recorder: &mut Box<dyn AggregateRecorder>,
    evaluator: &mut D,
) -> Result<()>
where
    E: Env,
    A: Agent<E, R>,
    P: StepProcessor<E>,
    R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase,
    D: Evaluator<E, A>,

Trains the agent online.
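The generic wrapper below shows how the pieces fit together, simply forwarding to Trainer::train with the trait bounds from the signature above. The import paths, and Result being anyhow::Result, are assumptions about the crate layout; concrete implementations of Env, Agent, StepProcessor, the replay buffer, the recorder, and the Evaluator come from elsewhere.

use anyhow::Result;
use border_core::{
    record::AggregateRecorder, Agent, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase,
    StepProcessor, Trainer,
};

// A thin generic wrapper that forwards to Trainer::train, reproducing the
// bounds from the signature above. The import paths (and Result being
// anyhow::Result) are assumptions and may differ across crate versions.
fn run_online_training<E, A, P, R, D>(
    trainer: &mut Trainer,
    env: E,
    step_proc: P,
    agent: &mut A,
    buffer: &mut R,
    recorder: &mut Box<dyn AggregateRecorder>,
    evaluator: &mut D,
) -> Result<()>
where
    E: Env,
    A: Agent<E, R>,
    P: StepProcessor<E>,
    R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase,
    D: Evaluator<E, A>,
{
    trainer.train(env, step_proc, agent, buffer, recorder, evaluator)
}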
pub fn train_offline<E, A, R, D>(
    &mut self,
    agent: &mut A,
    buffer: &mut R,
    recorder: &mut Box<dyn AggregateRecorder>,
    evaluator: &mut D,
) -> Result<()>

Trains the agent offline.
§Auto Trait Implementations
impl Freeze for Trainer
impl RefUnwindSafe for Trainer
impl Send for Trainer
impl Sync for Trainer
impl Unpin for Trainer
impl UnwindSafe for Trainer
§Blanket Implementations

impl<T> BorrowMut<T> for T
where
    T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.