border_core/trainer.rs
1//! Train [`Agent`].
2mod config;
3mod sampler;
use std::time::{Duration, Instant, SystemTime};

use crate::{
    record::{AggregateRecorder, Record, RecordValue::Scalar},
    Agent, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase, StepProcessor,
};
use anyhow::Result;
pub use config::TrainerConfig;
use log::{info, warn};
pub use sampler::Sampler;
14
#[cfg_attr(doc, aquamarine::aquamarine)]
/// Manages the training loop and related objects.
///
/// # Training loop
///
/// The online training loop ([`Trainer::train()`]) proceeds as follows:
///
/// 0. Given an agent implementing [`Agent`], a replay buffer implementing
///    [`ReplayBufferBase`], a recorder implementing [`AggregateRecorder`] and an evaluator
///    implementing [`Evaluator`].
/// 1. Create a [`Sampler`] from an [`Env`] and a [`StepProcessor`]. The counters of
///    environment steps (`env_steps`) and optimization steps (`opt_steps`) start at zero
///    when the trainer is built.
/// 2. Do an environment step with the [`Sampler`], which pushes the processed transition
///    into the replay buffer, then `env_steps += 1`.
/// 3. If `env_steps >= warmup_period` and `env_steps % opt_interval == 0`:
///     1. Do an optimization step of the agent with batches sampled from the replay buffer,
///        then `opt_steps += 1`. When `opt_steps` becomes a multiple of
///        `record_agent_info_interval`, the agent's `opt_with_record` is used so that the
///        agent can report its internal information; otherwise its `opt` is used. The step
///        is counted whether or not the agent actually updated its parameters. The time
///        spent in optimization steps is accumulated to compute optimization steps per
///        second (OSPS).
///     2. If `opt_steps % record_compute_cost_interval == 0`, add `"fps"` and
///        `"opt_steps_per_sec"` to the record and reset the OSPS counters.
///     3. If `opt_steps % eval_interval == 0`, evaluate the agent with the evaluator and add
///        the result to the record as `"eval_reward"`. If the result is the best so far,
///        the agent's parameters are saved in `(model_dir)/best`.
///     4. If `save_interval > 0` and `opt_steps % save_interval == 0`, the agent's
///        parameters are saved in `(model_dir)/(opt_steps)`.
///     5. When `opt_steps` reaches `max_opts`, training finishes.
/// 4. Records produced in the iteration are stored in the recorder, which is flushed
///    periodically according to `flush_records_interval`.
/// 5. Back to step 2.
///
/// [`Trainer::train_offline()`] runs the same loop without environment steps: the warmup
/// period is set to zero and an optimization step is done in every iteration.
///
/// # Interaction of objects
///
/// In [`Trainer::train()`] method, objects interact as shown below:
///
/// ```mermaid
/// graph LR
///     A[Agent]-->|Env::Act|B[Env]
///     B -->|Env::Obs|A
///     B -->|"Step<E: Env>"|C[StepProcessor]
///     C -->|ExperienceBufferBase::Item|D[ReplayBufferBase]
///     D -->|BatchBase|A
/// ```
///
/// * First, [`Agent`] emits an [`Env::Act`] `a_t` based on an [`Env::Obs`] `o_t` received
///   from [`Env`]. Given `a_t`, [`Env`] changes its state and creates the observation at
///   the next step, `o_t+1`. This interaction between [`Agent`] and [`Env`] is referred to
///   as an *environment step*.
/// * Next, a [`Step<E: Env>`] is created with the next observation `o_t+1`, the reward
///   `r_t`, and `a_t`.
/// * The [`Step<E: Env>`] object is processed by [`StepProcessor`], which creates an
///   [`ExperienceBufferBase::Item`], typically representing a transition
///   `(o_t, a_t, o_t+1, r_t)`. Here `o_t` is kept in the [`StepProcessor`], while the other
///   items come from the given [`Step<E: Env>`]. The item is pushed into the replay buffer.
/// * Finally, the transitions pushed to the [`ReplayBufferBase`] are used to create
///   batches, each implementing [`BatchBase`]. These batches are used in *optimization
///   steps*, where the agent updates its parameters using the sampled experiences.
///
/// [`Trainer::train()`]: Trainer::train
/// [`Trainer::train_offline()`]: Trainer::train_offline
/// [`Act`]: crate::Act
/// [`BatchBase`]: crate::BatchBase
/// [`Step<E: Env>`]: crate::Step
/// [`ExperienceBufferBase::Item`]: crate::ExperienceBufferBase::Item
pub struct Trainer {
    /// Where to save the trained model.
    model_dir: Option<String>,

    /// Interval of optimization in environment steps.
    /// Overridden to `1` by [`Trainer::train_offline`].
    opt_interval: usize,

    /// Interval of recording computational cost (`"fps"`, `"opt_steps_per_sec"`) in
    /// optimization steps.
    record_compute_cost_interval: usize,

    /// Interval of recording agent information in optimization steps.
    record_agent_info_interval: usize,

    /// Interval of flushing records in optimization steps.
    flush_records_interval: usize,

    /// Interval of evaluation in optimization steps.
    eval_interval: usize,

    /// Interval of saving the model in optimization steps; `0` disables periodic saving.
    save_interval: usize,

    /// The maximal number of optimization steps.
    max_opts: usize,

    /// Number of optimization steps since optimization steps per second were last computed.
    opt_steps_for_ops: usize,

    /// Time spent in optimization steps since optimization steps per second were last
    /// computed.
    timer_for_ops: Duration,

    /// Warmup period, for filling the replay buffer, in environment steps.
    /// Overridden to `0` by [`Trainer::train_offline`].
    warmup_period: usize,

    /// Best evaluation reward observed so far; starts at `f32::MIN`.
    max_eval_reward: f32,

    /// Environment steps during online training.
    env_steps: usize,

    /// Optimization steps during training.
    opt_steps: usize,
}
132
133impl Trainer {
134 /// Constructs a trainer.
135 pub fn build(config: TrainerConfig) -> Self {
136 Self {
137 model_dir: config.model_dir,
138 opt_interval: config.opt_interval,
139 record_compute_cost_interval: config.record_compute_cost_interval,
140 record_agent_info_interval: config.record_agent_info_interval,
141 flush_records_interval: config.flush_record_interval,
142 eval_interval: config.eval_interval,
143 save_interval: config.save_interval,
144 max_opts: config.max_opts,
145 warmup_period: config.warmup_period,
146 opt_steps_for_ops: 0,
147 timer_for_ops: Duration::new(0, 0),
148 max_eval_reward: f32::MIN,
149 env_steps: 0,
150 opt_steps: 0,
151 }
152 }
153
154 fn save_model<E, A, R>(agent: &A, model_dir: String)
155 where
156 E: Env,
157 A: Agent<E, R>,
158 R: ReplayBufferBase,
159 {
160 match agent.save_params(&model_dir) {
161 Ok(()) => info!("Saved the model in {:?}.", &model_dir),
162 Err(_) => info!("Failed to save model in {:?}.", &model_dir),
163 }
164 }
165
166 fn save_best_model<E, A, R>(agent: &A, model_dir: String)
167 where
168 E: Env,
169 A: Agent<E, R>,
170 R: ReplayBufferBase,
171 {
172 let model_dir = model_dir + "/best";
173 Self::save_model(agent, model_dir);
174 }
175
176 fn save_model_with_steps<E, A, R>(agent: &A, model_dir: String, steps: usize)
177 where
178 E: Env,
179 A: Agent<E, R>,
180 R: ReplayBufferBase,
181 {
182 let model_dir = model_dir + format!("/{}", steps).as_str();
183 Self::save_model(agent, model_dir);
184 }
185
186 /// Returns optimization steps per second, then reset the internal counter.
187 fn opt_steps_per_sec(&mut self) -> f32 {
188 let osps = 1000. * self.opt_steps_for_ops as f32 / (self.timer_for_ops.as_millis() as f32);
189 self.opt_steps_for_ops = 0;
190 self.timer_for_ops = Duration::new(0, 0);
191 osps
192 }
193
194 /// Performs a training step.
195 ///
196 /// First, it performes an environment step once and pushes a transition
197 /// into the given buffer with [`Sampler`]. Then, if the number of environment steps
198 /// reaches the optimization interval `opt_interval`, performes an optimization
199 /// step.
200 ///
201 /// The second return value in the tuple is if an optimization step is done (`true`).
202 // pub fn train_step<E, A, P, R>(
203 pub fn train_step<E, A, R>(&mut self, agent: &mut A, buffer: &mut R) -> Result<(Record, bool)>
204 where
205 E: Env,
206 A: Agent<E, R>,
207 R: ReplayBufferBase,
208 {
209 if self.env_steps < self.warmup_period {
210 Ok((Record::empty(), false))
211 } else if self.env_steps % self.opt_interval != 0 {
212 // skip optimization step
213 Ok((Record::empty(), false))
214 } else if (self.opt_steps + 1) % self.record_agent_info_interval == 0 {
215 // Do optimization step with record
216 let timer = SystemTime::now();
217 let record_agent = agent.opt_with_record(buffer);
218 self.opt_steps += 1;
219 self.timer_for_ops += timer.elapsed()?;
220 self.opt_steps_for_ops += 1;
221 Ok((record_agent, true))
222 } else {
223 // Do optimization step without record
224 let timer = SystemTime::now();
225 agent.opt(buffer);
226 self.opt_steps += 1;
227 self.timer_for_ops += timer.elapsed()?;
228 self.opt_steps_for_ops += 1;
229 Ok((Record::empty(), true))
230 }
231 }
232
233 fn post_process<E, A, R, D>(
234 &mut self,
235 agent: &mut A,
236 evaluator: &mut D,
237 record: &mut Record,
238 fps: f32,
239 ) -> Result<()>
240 where
241 E: Env,
242 A: Agent<E, R>,
243 R: ReplayBufferBase,
244 D: Evaluator<E, A>,
245 {
246 // Add stats wrt computation cost
247 if self.opt_steps % self.record_compute_cost_interval == 0 {
248 record.insert("fps", Scalar(fps));
249 record.insert("opt_steps_per_sec", Scalar(self.opt_steps_per_sec()));
250 }
251
252 // Evaluation
253 if self.opt_steps % self.eval_interval == 0 {
254 info!("Starts evaluation of the trained model");
255 agent.eval();
256 let eval_reward = evaluator.evaluate(agent)?;
257 agent.train();
258 record.insert("eval_reward", Scalar(eval_reward));
259
260 // Save the best model up to the current iteration
261 if eval_reward > self.max_eval_reward {
262 self.max_eval_reward = eval_reward;
263 let model_dir = self.model_dir.as_ref().unwrap().clone();
264 Self::save_best_model(agent, model_dir)
265 }
266 };
267
268 // Save the current model
269 if (self.save_interval > 0) && (self.opt_steps % self.save_interval == 0) {
270 let model_dir = self.model_dir.as_ref().unwrap().clone();
271 Self::save_model_with_steps(agent, model_dir, self.opt_steps);
272 }
273
274 Ok(())
275 }
276
277 fn loop_step<E, A, R, D>(
278 &mut self,
279 agent: &mut A,
280 buffer: &mut R,
281 recorder: &mut Box<dyn AggregateRecorder>,
282 evaluator: &mut D,
283 record: Record,
284 fps: f32,
285 ) -> Result<bool>
286 where
287 E: Env,
288 A: Agent<E, R>,
289 R: ReplayBufferBase,
290 D: Evaluator<E, A>,
291 {
292 // Performe optimization step(s)
293 let (mut record, is_opt) = {
294 let (r, is_opt) = self.train_step(agent, buffer)?;
295 (record.merge(r), is_opt)
296 };
297
298 // Postprocessing after each training step
299 if is_opt {
300 self.post_process(agent, evaluator, &mut record, fps)?;
301
302 // End loop
303 if self.opt_steps == self.max_opts {
304 return Ok(true);
305 }
306 }
307
308 // Store record to the recorder
309 if !record.is_empty() {
310 recorder.store(record);
311 }
312
313 // Flush records
314 if is_opt && ((self.opt_steps - 1) % self.flush_records_interval == 0) {
315 recorder.flush(self.opt_steps as _);
316 }
317
318 Ok(false)
319 }
320
321 /// Train the agent online.
322 pub fn train<E, A, P, R, D>(
323 &mut self,
324 env: E,
325 step_proc: P,
326 agent: &mut A,
327 buffer: &mut R,
328 recorder: &mut Box<dyn AggregateRecorder>,
329 evaluator: &mut D,
330 ) -> Result<()>
331 where
332 E: Env,
333 A: Agent<E, R>,
334 P: StepProcessor<E>,
335 R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase,
336 D: Evaluator<E, A>,
337 {
338 let mut sampler = Sampler::new(env, step_proc);
339 sampler.reset_fps_counter();
340 agent.train();
341
342 loop {
343 let record = sampler.sample_and_push(agent, buffer)?;
344 let fps = sampler.fps();
345 self.env_steps += 1;
346
347 if self.loop_step(agent, buffer, recorder, evaluator, record, fps)? {
348 return Ok(());
349 }
350 }
351 }
352
353 /// Train the agent offline.
354 pub fn train_offline<E, A, R, D>(
355 &mut self,
356 agent: &mut A,
357 buffer: &mut R,
358 recorder: &mut Box<dyn AggregateRecorder>,
359 evaluator: &mut D,
360 ) -> Result<()>
361 where
362 E: Env,
363 A: Agent<E, R>,
364 R: ReplayBufferBase,
365 D: Evaluator<E, A>,
366 {
367 // Return empty record
368 self.warmup_period = 0;
369 self.opt_interval = 1;
370 agent.train();
371 let fps = 0f32;
372
373 loop {
374 let record = Record::empty();
375
376 if self.loop_step(agent, buffer, recorder, evaluator, record, fps)? {
377 return Ok(());
378 }
379 }
380 }
381}