border_core/
trainer.rs

1//! Train [`Agent`].
2mod config;
3mod sampler;
use std::time::{Duration, SystemTime};

use crate::{
    record::{AggregateRecorder, Record, RecordValue::Scalar},
    Agent, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase, StepProcessor,
};
use anyhow::Result;
pub use config::TrainerConfig;
use log::{info, warn};
pub use sampler::Sampler;
14
15#[cfg_attr(doc, aquamarine::aquamarine)]
16/// Manages training loop and related objects.
17///
18/// # Training loop
19///
20/// Training loop looks like following:
21///
/// 0. Given an agent implementing [`Agent`] and a recorder implementing [`AggregateRecorder`].
23/// 1. Initialize the objects used in the training loop, involving instances of [`Env`],
24///    [`StepProcessor`], [`Sampler`].
25///    * Reset a counter of the environment steps: `env_steps = 0`
26///    * Reset a counter of the optimization steps: `opt_steps = 0`
27///    * Reset objects for computing optimization steps per sec (OSPS):
28///        * A timer `timer = SystemTime::now()`.
29///        * A counter `opt_steps_ops = 0`
30/// 2. Reset [`Env`].
31/// 3. Do an environment step and push a transition to the replaybuffer, implementing
32///    [`ReplayBufferBase`].
33/// 4. `env_steps += 1`
34/// 5. If `env_steps % opt_interval == 0`:
35///     1. Do an optimization step for the agent with transition batches
36///        sampled from the replay buffer.
37///         * NOTE: Here, the agent can skip an optimization step because of some reason,
38///           for example, during a warmup period for the replay buffer.
39///           In this case, the following steps are skipped as well.
40///     2. `opt_steps += 1, opt_steps_ops += 1`
41///     3. If `opt_steps % eval_interval == 0`:
///         * Do an evaluation of the agent and add the evaluation result to the recorder
///           as `"eval_reward"`.
44///         * Reset `timer` and `opt_steps_ops`.
45///         * If the evaluation result is the best, agent's model parameters are saved
46///           in directory `(model_dir)/best`.
47///     4. If `opt_steps % record_interval == 0`, compute OSPS as
48///        `opt_steps_ops / timer.elapsed()?.as_secs_f32()` and add it to the
49///        recorder as `"opt_steps_per_sec"`.
50///     5. If `opt_steps % save_interval == 0`, agent's model parameters are saved
51///        in directory `(model_dir)/(opt_steps)`.
52///     6. If `opt_steps == max_opts`, finish training loop.
53/// 6. Back to step 3.
54///
55/// # Interaction of objects
56///
57/// In [`Trainer::train()`] method, objects interact as shown below:
58///
59/// ```mermaid
60/// graph LR
61///     A[Agent]-->|Env::Act|B[Env]
62///     B -->|Env::Obs|A
63///     B -->|"Step<E: Env>"|C[StepProcessor]
64///     C -->|ReplayBufferBase::PushedItem|D[ReplayBufferBase]
65///     D -->|BatchBase|A
66/// ```
67///
68/// * First, [`Agent`] emits an [`Env::Act`] `a_t` based on [`Env::Obs`] `o_t` received from
69///   [`Env`]. Given `a_t`, [`Env`] changes its state and creates the observation at the
70///   next step, `o_t+1`. This step of interaction between [`Agent`] and [`Env`] is
71///   referred to as an *environment step*.
72/// * Next, [`Step<E: Env>`] will be created with the next observation `o_t+1`,
73///   reward `r_t`, and `a_t`.
74/// * The [`Step<E: Env>`] object will be processed by [`StepProcessor`] and
75///   creates [`ReplayBufferBase::Item`], typically representing a transition
76///   `(o_t, a_t, o_t+1, r_t)`, where `o_t` is kept in the
77///   [`StepProcessor`], while other items in the given [`Step<E: Env>`].
78/// * Finally, the transitions pushed to the [`ReplayBufferBase`] will be used to create
79///   batches, each of which implementing [`BatchBase`]. These batches will be used in
80///   *optimization step*s, where the agent updates its parameters using sampled
///    experiences in batches.
82///
83/// [`Trainer::train()`]: Trainer::train
84/// [`Act`]: crate::Act
85/// [`BatchBase`]: crate::BatchBase
86/// [`Step<E: Env>`]: crate::Step
pub struct Trainer {
    /// Where to save the trained model.
    model_dir: Option<String>,

    /// Interval of optimization in environment steps.
    /// This is ignored for offline training.
    opt_interval: usize,

    /// Interval of recording computational cost in optimization steps.
    record_compute_cost_interval: usize,

    /// Interval of recording agent information in optimization steps.
    record_agent_info_interval: usize,

    /// Interval of flushing records in optimization steps.
    flush_records_interval: usize,

    /// Interval of evaluation in optimization steps.
    eval_interval: usize,

    /// Interval of saving the model in optimization steps.
    save_interval: usize,

    /// The maximal number of optimization steps.
    max_opts: usize,

    /// Optimization steps accumulated since the last OSPS report;
    /// paired with `timer_for_ops` to compute optimization steps per second.
    opt_steps_for_ops: usize,

    /// Wall-clock time accumulated in optimization steps since the last
    /// OSPS report; reset together with `opt_steps_for_ops`.
    timer_for_ops: Duration,

    /// Warmup period, for filling replay buffer, in environment steps.
    /// This is ignored for offline training.
    warmup_period: usize,

    /// Max value of evaluation reward observed so far; used to decide
    /// whether to save the best model.
    max_eval_reward: f32,

    /// Environment steps during online training.
    env_steps: usize,

    /// Optimization steps during training.
    opt_steps: usize,
}
132
133impl Trainer {
134    /// Constructs a trainer.
135    pub fn build(config: TrainerConfig) -> Self {
136        Self {
137            model_dir: config.model_dir,
138            opt_interval: config.opt_interval,
139            record_compute_cost_interval: config.record_compute_cost_interval,
140            record_agent_info_interval: config.record_agent_info_interval,
141            flush_records_interval: config.flush_record_interval,
142            eval_interval: config.eval_interval,
143            save_interval: config.save_interval,
144            max_opts: config.max_opts,
145            warmup_period: config.warmup_period,
146            opt_steps_for_ops: 0,
147            timer_for_ops: Duration::new(0, 0),
148            max_eval_reward: f32::MIN,
149            env_steps: 0,
150            opt_steps: 0,
151        }
152    }
153
154    fn save_model<E, A, R>(agent: &A, model_dir: String)
155    where
156        E: Env,
157        A: Agent<E, R>,
158        R: ReplayBufferBase,
159    {
160        match agent.save_params(&model_dir) {
161            Ok(()) => info!("Saved the model in {:?}.", &model_dir),
162            Err(_) => info!("Failed to save model in {:?}.", &model_dir),
163        }
164    }
165
166    fn save_best_model<E, A, R>(agent: &A, model_dir: String)
167    where
168        E: Env,
169        A: Agent<E, R>,
170        R: ReplayBufferBase,
171    {
172        let model_dir = model_dir + "/best";
173        Self::save_model(agent, model_dir);
174    }
175
176    fn save_model_with_steps<E, A, R>(agent: &A, model_dir: String, steps: usize)
177    where
178        E: Env,
179        A: Agent<E, R>,
180        R: ReplayBufferBase,
181    {
182        let model_dir = model_dir + format!("/{}", steps).as_str();
183        Self::save_model(agent, model_dir);
184    }
185
186    /// Returns optimization steps per second, then reset the internal counter.
187    fn opt_steps_per_sec(&mut self) -> f32 {
188        let osps = 1000. * self.opt_steps_for_ops as f32 / (self.timer_for_ops.as_millis() as f32);
189        self.opt_steps_for_ops = 0;
190        self.timer_for_ops = Duration::new(0, 0);
191        osps
192    }
193
194    /// Performs a training step.
195    ///
196    /// First, it performes an environment step once and pushes a transition
197    /// into the given buffer with [`Sampler`]. Then, if the number of environment steps
198    /// reaches the optimization interval `opt_interval`, performes an optimization
199    /// step.
200    ///
201    /// The second return value in the tuple is if an optimization step is done (`true`).
202    // pub fn train_step<E, A, P, R>(
203    pub fn train_step<E, A, R>(&mut self, agent: &mut A, buffer: &mut R) -> Result<(Record, bool)>
204    where
205        E: Env,
206        A: Agent<E, R>,
207        R: ReplayBufferBase,
208    {
209        if self.env_steps < self.warmup_period {
210            Ok((Record::empty(), false))
211        } else if self.env_steps % self.opt_interval != 0 {
212            // skip optimization step
213            Ok((Record::empty(), false))
214        } else if (self.opt_steps + 1) % self.record_agent_info_interval == 0 {
215            // Do optimization step with record
216            let timer = SystemTime::now();
217            let record_agent = agent.opt_with_record(buffer);
218            self.opt_steps += 1;
219            self.timer_for_ops += timer.elapsed()?;
220            self.opt_steps_for_ops += 1;
221            Ok((record_agent, true))
222        } else {
223            // Do optimization step without record
224            let timer = SystemTime::now();
225            agent.opt(buffer);
226            self.opt_steps += 1;
227            self.timer_for_ops += timer.elapsed()?;
228            self.opt_steps_for_ops += 1;
229            Ok((Record::empty(), true))
230        }
231    }
232
233    fn post_process<E, A, R, D>(
234        &mut self,
235        agent: &mut A,
236        evaluator: &mut D,
237        record: &mut Record,
238        fps: f32,
239    ) -> Result<()>
240    where
241        E: Env,
242        A: Agent<E, R>,
243        R: ReplayBufferBase,
244        D: Evaluator<E, A>,
245    {
246        // Add stats wrt computation cost
247        if self.opt_steps % self.record_compute_cost_interval == 0 {
248            record.insert("fps", Scalar(fps));
249            record.insert("opt_steps_per_sec", Scalar(self.opt_steps_per_sec()));
250        }
251
252        // Evaluation
253        if self.opt_steps % self.eval_interval == 0 {
254            info!("Starts evaluation of the trained model");
255            agent.eval();
256            let eval_reward = evaluator.evaluate(agent)?;
257            agent.train();
258            record.insert("eval_reward", Scalar(eval_reward));
259
260            // Save the best model up to the current iteration
261            if eval_reward > self.max_eval_reward {
262                self.max_eval_reward = eval_reward;
263                let model_dir = self.model_dir.as_ref().unwrap().clone();
264                Self::save_best_model(agent, model_dir)
265            }
266        };
267
268        // Save the current model
269        if (self.save_interval > 0) && (self.opt_steps % self.save_interval == 0) {
270            let model_dir = self.model_dir.as_ref().unwrap().clone();
271            Self::save_model_with_steps(agent, model_dir, self.opt_steps);
272        }
273
274        Ok(())
275    }
276
277    fn loop_step<E, A, R, D>(
278        &mut self,
279        agent: &mut A,
280        buffer: &mut R,
281        recorder: &mut Box<dyn AggregateRecorder>,
282        evaluator: &mut D,
283        record: Record,
284        fps: f32,
285    ) -> Result<bool>
286    where
287        E: Env,
288        A: Agent<E, R>,
289        R: ReplayBufferBase,
290        D: Evaluator<E, A>,
291    {
292        // Performe optimization step(s)
293        let (mut record, is_opt) = {
294            let (r, is_opt) = self.train_step(agent, buffer)?;
295            (record.merge(r), is_opt)
296        };
297
298        // Postprocessing after each training step
299        if is_opt {
300            self.post_process(agent, evaluator, &mut record, fps)?;
301
302            // End loop
303            if self.opt_steps == self.max_opts {
304                return Ok(true);
305            }
306        }
307
308        // Store record to the recorder
309        if !record.is_empty() {
310            recorder.store(record);
311        }
312
313        // Flush records
314        if is_opt && ((self.opt_steps - 1) % self.flush_records_interval == 0) {
315            recorder.flush(self.opt_steps as _);
316        }
317
318        Ok(false)
319    }
320
321    /// Train the agent online.
322    pub fn train<E, A, P, R, D>(
323        &mut self,
324        env: E,
325        step_proc: P,
326        agent: &mut A,
327        buffer: &mut R,
328        recorder: &mut Box<dyn AggregateRecorder>,
329        evaluator: &mut D,
330    ) -> Result<()>
331    where
332        E: Env,
333        A: Agent<E, R>,
334        P: StepProcessor<E>,
335        R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase,
336        D: Evaluator<E, A>,
337    {
338        let mut sampler = Sampler::new(env, step_proc);
339        sampler.reset_fps_counter();
340        agent.train();
341
342        loop {
343            let record = sampler.sample_and_push(agent, buffer)?;
344            let fps = sampler.fps();
345            self.env_steps += 1;
346
347            if self.loop_step(agent, buffer, recorder, evaluator, record, fps)? {
348                return Ok(());
349            }
350        }
351    }
352
353    /// Train the agent offline.
354    pub fn train_offline<E, A, R, D>(
355        &mut self,
356        agent: &mut A,
357        buffer: &mut R,
358        recorder: &mut Box<dyn AggregateRecorder>,
359        evaluator: &mut D,
360    ) -> Result<()>
361    where
362        E: Env,
363        A: Agent<E, R>,
364        R: ReplayBufferBase,
365        D: Evaluator<E, A>,
366    {
367        // Return empty record
368        self.warmup_period = 0;
369        self.opt_interval = 1;
370        agent.train();
371        let fps = 0f32;
372
373        loop {
374            let record = Record::empty();
375
376            if self.loop_step(agent, buffer, recorder, evaluator, record, fps)? {
377                return Ok(());
378            }
379        }
380    }
381}