Struct border_core::Trainer

pub struct Trainer<E, P, R>
where
    E: Env,
    P: StepProcessorBase<E>,
    R: ReplayBufferBase<PushedItem = P::Output>,
{ /* private fields */ }

Manages the training loop and related objects.

Training loop

The training loop looks like the following (a code-level sketch of this control flow appears after the list):

  1. Given an agent implementing Agent and a recorder implementing Recorder.
  2. Initialize the objects used in the training loop, including instances of Env, StepProcessorBase, and SyncSampler.
    • Reset a counter of the environment steps: env_steps = 0
    • Reset a counter of the optimization steps: opt_steps = 0
    • Reset objects for computing optimization steps per sec (OSPS):
      • A timer timer = SystemTime::now().
      • A counter opt_steps_ops = 0
  3. Reset Env.
  4. Do an environment step and push a transition to the replay buffer, which implements ReplayBufferBase.
  5. env_steps += 1
  6. If env_steps % opt_interval == 0:
    1. Do an optimization step for the agent with transition batches sampled from the replay buffer.
      • NOTE: The agent may skip an optimization step for some reason, for example during a warmup period for the replay buffer. In this case, the following sub-steps are skipped as well.
    2. opt_steps += 1, opt_steps_ops += 1
    3. If opt_steps % eval_interval == 0:
      • Evaluate the agent and add the evaluation result to the recorder as "eval_reward".
      • Reset timer and opt_steps_ops.
      • If the evaluation result is the best so far, the agent's model parameters are saved in the directory (model_dir)/best.
    4. If opt_steps % record_interval == 0, compute OSPS as opt_steps_ops / timer.elapsed()?.as_secs_f32() and add it to the recorder as "opt_steps_per_sec".
    5. If opt_steps % save_interval == 0, the agent's model parameters are saved in the directory (model_dir)/(opt_steps).
    6. If opt_steps == max_opts, finish the training loop.
  7. Back to step 4.
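
The control flow above can be condensed into the following self-contained sketch. It uses the standard library only; the closures and interval values are hypothetical placeholders standing in for the agent, environment, evaluator, recorder, and the opt_interval/eval_interval/record_interval/save_interval/max_opts settings referenced above, so it illustrates the counters and branching rather than the border_core API.

use std::time::SystemTime;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder interval values; the names mirror the settings referenced in the list above.
    let (opt_interval, eval_interval, record_interval, save_interval, max_opts) =
        (1usize, 100, 50, 500, 1_000);

    // Step 2: counters and the OSPS timer.
    let mut env_steps = 0usize;
    let mut opt_steps = 0usize;
    let mut opt_steps_ops = 0usize;
    let mut timer = SystemTime::now();

    // Hypothetical closures standing in for the real objects.
    let do_env_step_and_push = || { /* environment step + push transition to the replay buffer */ };
    let try_opt_step = || true; // would return false while e.g. the replay buffer warms up
    let evaluate = || 0.0_f32;  // returns the value recorded as "eval_reward"

    // Step 3: reset the environment (omitted in this sketch).

    loop {
        // Steps 4-5: environment step.
        do_env_step_and_push();
        env_steps += 1;

        // Step 6: optimization step and bookkeeping.
        if env_steps % opt_interval == 0 {
            if !try_opt_step() {
                continue; // warmup: the remaining sub-steps are skipped as well
            }
            opt_steps += 1;
            opt_steps_ops += 1;

            if opt_steps % eval_interval == 0 {
                let _eval_reward = evaluate(); // recorded as "eval_reward"; best model saved to (model_dir)/best
                timer = SystemTime::now();
                opt_steps_ops = 0;
            }
            if opt_steps % record_interval == 0 {
                let _osps = opt_steps_ops as f32 / timer.elapsed()?.as_secs_f32(); // "opt_steps_per_sec"
            }
            if opt_steps % save_interval == 0 {
                // model parameters would be saved to (model_dir)/(opt_steps) here
            }
            if opt_steps == max_opts {
                break; // finish the training loop
            }
        }
        // Step 7: back to step 4.
    }
    Ok(())
}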

Interaction of objects

In the Trainer::train() method, objects interact as shown below (a type-level sketch of these relationships follows the list):

graph LR
    A[Agent] -->|Env::Act| B[Env]
    B -->|Env::Obs| A
    B -->|"Step<E: Env>"| C[StepProcessor]
    C -->|ReplayBufferBase::PushedItem| D[ReplayBufferBase]
    D -->|BatchBase| A
  • First, Agent emits an Env::Act a_t based on Env::Obs o_t received from Env. Given a_t, Env changes its state and creates the observation at the next step, o_t+1. This step of interaction between Agent and Env is referred to as an environment step.
  • Next, Step<E: Env> will be created with the next observation o_t+1, reward r_t, and a_t.
  • The Step<E: Env> object is processed by StepProcessorBase into a ReplayBufferBase::PushedItem, typically representing a transition (o_t, a_t, o_t+1, r_t), where o_t is kept in the StepProcessorBase and the other items are taken from the given Step<E: Env>.
  • Finally, the transitions pushed to the ReplayBufferBase are used to create batches, each of which implements BatchBase. These batches are used in optimization steps, where the agent updates its parameters using the experiences sampled in the batches.
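
The same data flow can be written down at the type level. The sketch below is not from the crate's documentation; it merely restates the trait bounds shown on this page (with crate-root import paths assumed), and its body is empty because its only purpose is to record which associated type flows between which objects.

use border_core::{Agent, Env, ReplayBufferBase, StepProcessorBase};

// Type-level sketch of the diagram above; the paths in the `use` line are assumptions.
#[allow(dead_code)]
fn data_flow<E, P, R, A>()
where
    E: Env,                                      // consumes Env::Act, produces Env::Obs and Step<E>
    P: StepProcessorBase<E>,                     // consumes Step<E>, produces P::Output
    R: ReplayBufferBase<PushedItem = P::Output>, // consumes P::Output, yields batches (BatchBase)
    A: Agent<E, R>,                              // consumes Env::Obs and batches, produces Env::Act
{
    // Agent --Env::Act--> Env --Step<E>--> StepProcessor --PushedItem--> ReplayBuffer --BatchBase--> Agent
}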

Implementations

impl<E, P, R> Trainer<E, P, R> where E: Env, P: StepProcessorBase<E>, R: ReplayBufferBase<PushedItem = P::Output>,

pub fn build(config: TrainerConfig, env_config_train: E::Config, step_proc_config: P::Config, replay_buffer_config: R::Config) -> Self

Constructs a trainer.
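
The following is a minimal sketch of how build might be called. It is written as a generic helper so that it relies only on the bounds and argument order shown above; the crate-root paths in the use line are assumptions, and no concrete environment, step processor, or replay buffer is implied.

use border_core::{Env, ReplayBufferBase, StepProcessorBase, Trainer, TrainerConfig};

// Builds a Trainer from the four configuration values, in the order given by the
// signature above. The concrete E, P, and R are chosen by the caller.
fn make_trainer<E, P, R>(
    trainer_config: TrainerConfig,
    env_config_train: E::Config,
    step_proc_config: P::Config,
    replay_buffer_config: R::Config,
) -> Trainer<E, P, R>
where
    E: Env,
    P: StepProcessorBase<E>,
    R: ReplayBufferBase<PushedItem = P::Output>,
{
    Trainer::<E, P, R>::build(
        trainer_config,
        env_config_train,
        step_proc_config,
        replay_buffer_config,
    )
}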

pub fn train_step<A>(&self, agent: &mut A, buffer: &mut R, sampler: &mut SyncSampler<E, P>, env_steps: &mut usize) -> Result<Option<Record>> where A: Agent<E, R>,

Performs a training step.

pub fn train<A, S, D>(&mut self, agent: &mut A, recorder: &mut S, evaluator: &mut D) -> Result<()> where A: Agent<E, R>, S: Recorder, D: Evaluator<E, A>,

Trains the agent.
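
Below is a hedged sketch of driving the loop with train. The import paths, and the assumption that Result here is anyhow::Result, reflect common usage in the border crates rather than anything stated on this page; the agent, recorder, and evaluator are generic placeholders supplied by the caller.

use border_core::{record::Recorder, Agent, Env, Evaluator, ReplayBufferBase, StepProcessorBase, Trainer};

// Runs the whole training loop described at the top of this page: sampling, optimization
// steps, periodic evaluation, recording, and model saving are all handled inside train().
fn run_training<E, P, R, A, S, D>(
    trainer: &mut Trainer<E, P, R>,
    agent: &mut A,
    recorder: &mut S,
    evaluator: &mut D,
) -> anyhow::Result<()>
where
    E: Env,
    P: StepProcessorBase<E>,
    R: ReplayBufferBase<PushedItem = P::Output>,
    A: Agent<E, R>,
    S: Recorder,
    D: Evaluator<E, A>,
{
    trainer.train(agent, recorder, evaluator)
}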

Auto Trait Implementations

impl<E, P, R> RefUnwindSafe for Trainer<E, P, R> where <E as Env>::Config: RefUnwindSafe, <P as StepProcessorBase<E>>::Config: RefUnwindSafe, <R as ReplayBufferBase>::Config: RefUnwindSafe,

impl<E, P, R> Send for Trainer<E, P, R> where <E as Env>::Config: Send, <P as StepProcessorBase<E>>::Config: Send, <R as ReplayBufferBase>::Config: Send,

impl<E, P, R> Sync for Trainer<E, P, R> where <E as Env>::Config: Sync, <P as StepProcessorBase<E>>::Config: Sync, <R as ReplayBufferBase>::Config: Sync,

impl<E, P, R> Unpin for Trainer<E, P, R> where <E as Env>::Config: Unpin, <P as StepProcessorBase<E>>::Config: Unpin, <R as ReplayBufferBase>::Config: Unpin,

impl<E, P, R> UnwindSafe for Trainer<E, P, R> where <E as Env>::Config: UnwindSafe, <P as StepProcessorBase<E>>::Config: UnwindSafe, <R as ReplayBufferBase>::Config: UnwindSafe,

Blanket Implementations

impl<T> Any for T where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T, U> TryFrom<U> for T where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T where V: MultiLane<T>,

fn vzip(self) -> V