Struct Trainer

Source
pub struct Trainer { /* private fields */ }

Manages training loop and related objects.

§Training loop

The training loop looks like the following (a schematic code sketch of these steps is given after the list):

  1. Given an agent implementing Agent and a recorder implementing Recorder.
  2. Initialize the objects used in the training loop, including instances of Env, StepProcessor, and Sampler.
    • Reset a counter of the environment steps: env_steps = 0
    • Reset a counter of the optimization steps: opt_steps = 0
    • Reset objects for computing optimization steps per sec (OSPS):
      • A timer timer = SystemTime::now().
      • A counter opt_steps_ops = 0
  3. Reset Env.
  4. Do an environment step and push a transition to the replay buffer, which implements ReplayBufferBase.
  5. env_steps += 1
  6. If env_steps % opt_interval == 0:
    1. Do an optimization step for the agent with transition batches sampled from the replay buffer.
      • NOTE: Here, the agent may skip the optimization step for some reason, for example during a warmup period for the replay buffer. In that case, the following steps are skipped as well.
    2. opt_steps += 1, opt_steps_ops += 1
    3. If opt_steps % eval_interval == 0:
      • Do an evaluation of the agent and add the evaluation result to the recorder as "eval_reward".
      • Reset timer and opt_steps_ops.
      • If the evaluation result is the best so far, the agent’s model parameters are saved in the directory (model_dir)/best.
    4. If opt_steps % record_interval == 0, compute OSPS as opt_steps_ops / timer.elapsed()?.as_secs_f32() and add it to the recorder as "opt_steps_per_sec".
    5. If opt_steps % save_interval == 0, the agent’s model parameters are saved in the directory (model_dir)/(opt_steps).
    6. If opt_steps == max_opts, finish the training loop.
  7. Back to step 3.
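
The same steps, written as a minimal self-contained sketch in Rust. The closures below stand in for the real Env/Agent/ReplayBuffer/Recorder interactions; they are illustrative assumptions only and not part of the crate's API. The actual implementation lives in Trainer::train().

```rust
use std::time::SystemTime;

/// Minimal sketch of the counter logic described above. The closures are
/// stand-ins for the real Env/Agent/ReplayBuffer/Recorder objects and are
/// NOT part of the crate's API. Saving at save_interval is omitted.
fn training_loop_sketch(
    mut env_step: impl FnMut(),         // step 4: environment step + push transition
    mut opt_step: impl FnMut() -> bool, // optimization step; false if skipped (warmup)
    mut record: impl FnMut(&str, f32),  // stands in for the recorder
    opt_interval: usize,
    eval_interval: usize,
    record_interval: usize,
    max_opts: usize,
) -> Result<(), std::time::SystemTimeError> {
    let (mut env_steps, mut opt_steps, mut opt_steps_ops) = (0usize, 0usize, 0usize);
    let mut timer = SystemTime::now();

    loop {
        env_step();                          // step 4
        env_steps += 1;                      // step 5

        if env_steps % opt_interval == 0 {   // step 6
            if !opt_step() {
                continue;                    // skipped, e.g. during replay-buffer warmup
            }
            opt_steps += 1;
            opt_steps_ops += 1;

            if opt_steps % eval_interval == 0 {
                record("eval_reward", 0.0);  // evaluation result would go here
                timer = SystemTime::now();   // reset OSPS timer and counter
                opt_steps_ops = 0;
            }
            if opt_steps % record_interval == 0 {
                let osps = opt_steps_ops as f32 / timer.elapsed()?.as_secs_f32();
                record("opt_steps_per_sec", osps);
            }
            if opt_steps == max_opts {
                return Ok(());               // finish the training loop
            }
        }
    }
}
```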

§Interaction of objects

In the Trainer::train() method, objects interact as shown below:

graph LR
    A[Agent] -->|Env::Act| B[Env]
    B -->|Env::Obs| A
    B -->|"Step<E: Env>"| C[StepProcessor]
    C -->|ReplayBufferBase::PushedItem| D[ReplayBufferBase]
    D -->|BatchBase| A
  • First, Agent emits an Env::Act a_t based on Env::Obs o_t received from Env. Given a_t, Env changes its state and creates the observation at the next step, o_t+1. This step of interaction between Agent and Env is referred to as an environment step.
  • Next, a Step<E: Env> is created from the next observation o_t+1, the reward r_t, and a_t.
  • The Step<E: Env> object is processed by StepProcessor, which creates a ReplayBufferBase::Item, typically representing a transition (o_t, a_t, o_t+1, r_t), where o_t is kept in the StepProcessor while the other elements come from the given Step<E: Env> (see the sketch after this list).
  • Finally, the transitions pushed to the ReplayBufferBase are used to create batches, each of which implements BatchBase. These batches are used in optimization steps, where the agent updates its parameters from the sampled experiences.
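
To make the third bullet concrete, here is a standalone toy showing how a processor can cache o_t and combine it with (a_t, o_t+1, r_t) into a transition. It does not implement the crate's StepProcessor trait, and all type and field names are assumptions made for illustration.

```rust
/// Toy versions of a step record and a transition; the field names are
/// illustrative assumptions and do not match the crate's Step<E: Env>.
struct ToyStep<O, A> {
    act: A,      // a_t
    obs: O,      // o_t+1, the observation after taking the action
    reward: f32, // r_t
}

struct ToyTransition<O, A> {
    obs: O,      // o_t
    act: A,      // a_t
    next_obs: O, // o_t+1
    reward: f32, // r_t
}

/// Caches the previous observation, mimicking how a step processor keeps o_t.
struct ToyStepProcessor<O> {
    prev_obs: Option<O>,
}

impl<O: Clone> ToyStepProcessor<O> {
    fn reset(init_obs: O) -> Self {
        Self { prev_obs: Some(init_obs) }
    }

    /// Combines the cached o_t with the incoming (a_t, o_t+1, r_t).
    fn process<A>(&mut self, step: ToyStep<O, A>) -> ToyTransition<O, A> {
        let obs = self
            .prev_obs
            .replace(step.obs.clone())
            .expect("reset() must be called before process()");
        ToyTransition { obs, act: step.act, next_obs: step.obs, reward: step.reward }
    }
}
```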

Implementations§

Source§

impl Trainer

Source

pub fn build(config: TrainerConfig) -> Self

Constructs a trainer.
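
For instance, a Trainer might be set up like the sketch below. The builder-style methods on TrainerConfig (max_opts, opt_interval, and so on) are assumed from the parameters described in the training loop above, not verified signatures; see TrainerConfig's own documentation for the actual constructors.

```rust
// Hypothetical setup; the TrainerConfig methods below are assumptions
// derived from the training-loop parameters, not verified API.
let config = TrainerConfig::default()
    .max_opts(100_000)
    .opt_interval(1)
    .eval_interval(5_000)
    .record_interval(1_000)
    .save_interval(50_000)
    .model_dir("./model");
let mut trainer = Trainer::build(config);
```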

Source

pub fn train_step<E, A, R>( &mut self, agent: &mut A, buffer: &mut R, ) -> Result<(Record, bool)>
where E: Env, A: Agent<E, R>, R: ReplayBufferBase,

Performs a training step.

First, it performs a single environment step and pushes a transition into the given buffer with Sampler. Then, if the number of environment steps reaches the optimization interval opt_interval, it performs an optimization step.

The second value of the returned tuple indicates whether an optimization step was performed (true).
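
A hand-rolled loop built on train_step might look like the following sketch, assuming that agent, buffer, and a step budget max_env_steps already exist (they are not part of this page's API).

```rust
// Sketch only: `agent`, `buffer`, and `max_env_steps` are assumed to exist,
// with A: Agent<E, R> and R: ReplayBufferBase.
for _ in 0..max_env_steps {
    let (record, opt_done) = trainer.train_step(&mut agent, &mut buffer)?;
    if opt_done {
        // An optimization step was performed; `record` carries its metrics,
        // which could be pushed into a recorder here.
        let _ = record;
    }
}
```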

Source

pub fn train<E, A, P, R, D>( &mut self, env: E, step_proc: P, agent: &mut A, buffer: &mut R, recorder: &mut Box<dyn AggregateRecorder>, evaluator: &mut D, ) -> Result<()>
where E: Env, A: Agent<E, R>, P: StepProcessor<E>, R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase, D: Evaluator<E, A>,

Trains the agent online.
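
Wiring up the generic parameters might look like the following sketch. Every My* type is a placeholder for a concrete implementation of the corresponding trait; none of them exist in the crate.

```rust
// Placeholders: MyEnv: Env, MyStepProcessor: StepProcessor<MyEnv>,
// MyAgent: Agent<MyEnv, MyBuffer>,
// MyBuffer: ExperienceBufferBase<Item = ...> + ReplayBufferBase,
// MyEvaluator: Evaluator<MyEnv, MyAgent>, MyRecorder: AggregateRecorder.
let env = MyEnv::default();
let step_proc = MyStepProcessor::default();
let mut agent = MyAgent::default();
let mut buffer = MyBuffer::default();
let mut recorder: Box<dyn AggregateRecorder> = Box::new(MyRecorder::default());
let mut evaluator = MyEvaluator::default();

trainer.train(env, step_proc, &mut agent, &mut buffer, &mut recorder, &mut evaluator)?;
```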

Source

pub fn train_offline<E, A, R, D>( &mut self, agent: &mut A, buffer: &mut R, recorder: &mut Box<dyn AggregateRecorder>, evaluator: &mut D, ) -> Result<()>
where E: Env, A: Agent<E, R>, R: ReplayBufferBase, D: Evaluator<E, A>,

Trains the agent offline.

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self.
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value.
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V