border_core/trainer.rs
//! Training loop management and agent optimization.
//!
//! This module provides functionality for managing the training process of reinforcement
//! learning agents. It handles environment interactions, experience collection,
//! optimization steps, and evaluation.

mod config;
mod sampler;
use std::time::{Duration, SystemTime};

use crate::{
    record::{Record, RecordValue::Scalar, Recorder},
    Agent, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase, StepProcessor,
};
use anyhow::Result;
pub use config::TrainerConfig;
use log::info;
pub use sampler::Sampler;

/// Manages the training loop and coordinates interactions between components.
///
/// The `Trainer` orchestrates the training process by managing:
///
/// * Environment interactions and experience collection
/// * Agent optimization and parameter updates
/// * Performance evaluation and model saving
/// * Training metrics recording
///
/// # Training Process
///
/// The training loop follows these steps:
///
/// 1. Initialize training components:
///    * Reset environment step counter (`env_steps = 0`)
///    * Reset optimization step counter (`opt_steps = 0`)
///    * Initialize performance monitoring
///
/// 2. Environment interaction:
///    * Agent observes environment state
///    * Agent selects and executes action
///    * Environment transitions to new state
///    * Experience is collected and stored
///
/// 3. Optimization:
///    * At specified intervals (`opt_interval`):
///      * Sample experiences from the replay buffer
///      * Update agent parameters
///      * Track optimization performance
///
/// 4. Evaluation and recording:
///    * Periodically evaluate agent performance
///    * Record training metrics
///    * Save model checkpoints
///    * Monitor optimization speed
///
/// # Model Selection
///
/// During training, the best performing model is automatically saved based on evaluation rewards:
///
/// * At each evaluation interval (`eval_interval`), the agent's performance is evaluated
/// * The evaluation reward is the score returned by the [`Evaluator`]
/// * If the current evaluation reward exceeds the previous maximum reward:
///   * The model is saved as the "best" model
///   * The maximum reward is updated
/// * This ensures that the saved "best" model represents the agent's peak performance
///
/// # Configuration
///
/// Training behavior is controlled by various intervals and parameters
/// (see the example below):
///
/// * `opt_interval`: Environment steps between optimization updates
/// * `eval_interval`: Optimization steps between performance evaluations
/// * `save_interval`: Optimization steps between model checkpoints
/// * `warmup_period`: Initial environment steps before optimization begins
/// * `max_opts`: Maximum number of optimization steps
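/// # Examples
///
/// A minimal sketch of a typical online training run. `MyEnv`, `MyStepProcessor`,
/// `MyAgent`, `MyReplayBuffer`, `MyRecorder`, and `MyEvaluator` are hypothetical
/// user-provided implementations of the corresponding traits, and the sketch
/// assumes `TrainerConfig` can be obtained via `Default`; the block is therefore
/// marked `ignore` and is not compiled as a doctest.
///
/// ```ignore
/// // Build the trainer from a configuration (placeholder values).
/// let mut trainer = Trainer::build(TrainerConfig::default());
///
/// // Hypothetical components implementing the border_core traits.
/// let env = MyEnv::new();
/// let step_proc = MyStepProcessor::new();
/// let mut agent: Box<dyn Agent<MyEnv, MyReplayBuffer>> = Box::new(MyAgent::new());
/// let mut buffer = MyReplayBuffer::new();
/// let mut recorder: Box<dyn Recorder<MyEnv, MyReplayBuffer>> = Box::new(MyRecorder::new());
/// let mut evaluator = MyEvaluator::new();
///
/// // Run the online training loop until `max_opts` optimization steps have been performed.
/// trainer.train(env, step_proc, &mut agent, &mut buffer, &mut recorder, &mut evaluator)?;
/// ```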
pub struct Trainer {
    /// Interval between optimization steps in environment steps.
    /// Ignored for offline training.
    opt_interval: usize,

    /// Interval for recording computational cost in optimization steps.
    record_compute_cost_interval: usize,

    /// Interval for recording agent information in optimization steps.
    record_agent_info_interval: usize,

    /// Interval for flushing records in optimization steps.
    flush_records_interval: usize,

    /// Interval for evaluation in optimization steps.
    eval_interval: usize,

    /// Interval for saving the model in optimization steps.
    save_interval: usize,

    /// Maximum number of optimization steps.
    max_opts: usize,

    /// Warmup period for filling replay buffer in environment steps.
    /// Ignored for offline training.
    warmup_period: usize,

    /// Counter for replay buffer samples.
    samples_counter: usize,

    /// Timer for replay buffer samples.
    timer_for_samples: Duration,

    /// Counter for optimization steps.
    opt_steps_counter: usize,

    /// Timer for optimization steps.
    timer_for_opt_steps: Duration,

    /// Maximum evaluation reward achieved.
    max_eval_reward: f32,

    /// Current environment step count.
    env_steps: usize,

    /// Current optimization step count.
    opt_steps: usize,
}

impl Trainer {
    /// Creates a new trainer with the specified configuration.
    ///
    /// # Arguments
    ///
    /// * `config` - Configuration parameters for the trainer
    ///
    /// # Returns
    ///
    /// A new `Trainer` instance with the specified configuration
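    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming `TrainerConfig` provides a `Default`
    /// implementation (see the actual fields and setters in `config.rs`);
    /// the block is marked `ignore` and is not compiled as a doctest.
    ///
    /// ```ignore
    /// // Placeholder configuration; adjust the intervals for a real run.
    /// let config = TrainerConfig::default();
    /// let trainer = Trainer::build(config);
    /// ```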
    pub fn build(config: TrainerConfig) -> Self {
        Self {
            opt_interval: config.opt_interval,
            record_compute_cost_interval: config.record_compute_cost_interval,
            record_agent_info_interval: config.record_agent_info_interval,
            flush_records_interval: config.flush_record_interval,
            eval_interval: config.eval_interval,
            save_interval: config.save_interval,
            max_opts: config.max_opts,
            warmup_period: config.warmup_period,
            samples_counter: 0,
            timer_for_samples: Duration::new(0, 0),
            opt_steps_counter: 0,
            timer_for_opt_steps: Duration::new(0, 0),
            max_eval_reward: f32::MIN,
            env_steps: 0,
            opt_steps: 0,
        }
    }

    /// Resets the counters.
    fn reset_counters(&mut self) {
        self.samples_counter = 0;
        self.timer_for_samples = Duration::new(0, 0);
        self.opt_steps_counter = 0;
        self.timer_for_opt_steps = Duration::new(0, 0);
    }

    /// Calculates the average time per optimization step and per sample in milliseconds.
    ///
    /// Returns `-1.0` for either value if the corresponding counter is zero.
    fn average_time(&mut self) -> (f32, f32) {
        let avr_opt_time = match self.opt_steps_counter {
            0 => -1f32,
            n => self.timer_for_opt_steps.as_millis() as f32 / n as f32,
        };
        let avr_sample_time = match self.samples_counter {
            0 => -1f32,
            n => self.timer_for_samples.as_millis() as f32 / n as f32,
        };
        (avr_opt_time, avr_sample_time)
    }

    /// Performs a single training step.
    ///
    /// Depending on the training state, this method:
    /// 1. Returns an empty record without optimizing during the warmup period
    /// 2. Skips optimization unless `env_steps` is a multiple of `opt_interval`
    /// 3. Otherwise performs an optimization step, additionally collecting the
    ///    agent's records every `record_agent_info_interval` optimization steps
    ///
    /// # Arguments
    ///
    /// * `agent` - The agent being trained
    /// * `buffer` - The replay buffer storing experiences
    ///
    /// # Returns
    ///
    /// A tuple containing:
    /// * A record of the training step
    /// * A boolean indicating whether an optimization step was performed
    ///
    /// # Errors
    ///
    /// Returns an error if the optimization step fails
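    ///
    /// # Examples
    ///
    /// The following sketch shows how the online training loop in [`Trainer::train`]
    /// consumes the return value; it is illustrative only and is marked `ignore` so
    /// it is not compiled as a doctest.
    ///
    /// ```ignore
    /// // After an environment step has been sampled and pushed to the replay buffer:
    /// let (r, is_opt) = self.train_step(agent, buffer)?;
    /// let mut record = record.merge(r);
    /// if is_opt {
    ///     // Evaluation, best-model saving, and checkpointing happen here.
    ///     self.post_process(agent, evaluator, recorder, &mut record)?;
    /// }
    /// ```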
    pub fn train_step<E, R>(
        &mut self,
        agent: &mut Box<dyn Agent<E, R>>,
        buffer: &mut R,
    ) -> Result<(Record, bool)>
    where
        E: Env,
        R: ReplayBufferBase,
    {
        if self.env_steps < self.warmup_period {
            Ok((Record::empty(), false))
        } else if self.env_steps % self.opt_interval != 0 {
            // Skip optimization step
            Ok((Record::empty(), false))
        } else if (self.opt_steps + 1) % self.record_agent_info_interval == 0 {
            // Do optimization step with record
            let timer = SystemTime::now();
            let record_agent = agent.opt_with_record(buffer);
            self.opt_steps += 1;
            self.timer_for_opt_steps += timer.elapsed()?;
            self.opt_steps_counter += 1;
            Ok((record_agent, true))
        } else {
            // Do optimization step without record
            let timer = SystemTime::now();
            agent.opt(buffer);
            self.opt_steps += 1;
            self.timer_for_opt_steps += timer.elapsed()?;
            self.opt_steps_counter += 1;
            Ok((Record::empty(), true))
        }
    }

    /// Evaluates the agent, saves the best model, and saves periodic checkpoints.
    fn post_process<E, R, D>(
        &mut self,
        agent: &mut Box<dyn Agent<E, R>>,
        evaluator: &mut D,
        recorder: &mut Box<dyn Recorder<E, R>>,
        record: &mut Record,
    ) -> Result<()>
    where
        E: Env,
        R: ReplayBufferBase,
        D: Evaluator<E>,
    {
        // Evaluation
        if self.opt_steps % self.eval_interval == 0 {
            info!("Starts evaluation of the trained model");
            agent.eval();
            let (score, record_eval) = evaluator.evaluate(agent)?;
            agent.train();
            record.merge_inplace(record_eval);

            // Save the best model up to the current iteration
            if score > self.max_eval_reward {
                self.max_eval_reward = score;
                recorder.save_model("best".as_ref(), agent)?;
            }
        }

        // Save the current model
        if (self.save_interval > 0) && (self.opt_steps % self.save_interval == 0) {
            recorder.save_model(format!("{}", self.opt_steps).as_ref(), agent)?;
        }

        Ok(())
    }

    /// Trains the agent online.
    ///
    /// The loop alternates between sampling experiences from `env` (via `step_proc`)
    /// into `buffer` and optimizing `agent`, with periodic evaluation and recording,
    /// until `max_opts` optimization steps have been performed.
    pub fn train<E, P, R, D>(
        &mut self,
        env: E,
        step_proc: P,
        agent: &mut Box<dyn Agent<E, R>>,
        buffer: &mut R,
        recorder: &mut Box<dyn Recorder<E, R>>,
        evaluator: &mut D,
    ) -> Result<()>
    where
        E: Env,
        P: StepProcessor<E>,
        R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase,
        D: Evaluator<E>,
    {
        let mut sampler = Sampler::new(env, step_proc);
        agent.train();

        loop {
            // Take samples from the environment and push them to the replay buffer
            let now = SystemTime::now();
            let record = sampler.sample_and_push(agent, buffer)?;
            self.timer_for_samples += now.elapsed()?;
            self.samples_counter += 1;
            self.env_steps += 1;

            // Perform optimization step(s)
            let (mut record, is_opt) = {
                let (r, is_opt) = self.train_step(agent, buffer)?;
                (record.merge(r), is_opt)
            };

            // Postprocessing after each training step
            if is_opt {
                self.post_process(agent, evaluator, recorder, &mut record)?;
            }

            // Record average time for optimization steps and sampling steps in milliseconds
            if self.opt_steps % self.record_compute_cost_interval == 0 {
                let (avr_opt_time, avr_sample_time) = self.average_time();
                record.insert("average_opt_time", Scalar(avr_opt_time));
                record.insert("average_sample_time", Scalar(avr_sample_time));
                self.reset_counters();
            }

            // Store the record in the recorder
            if !record.is_empty() {
                recorder.store(record);
            }

            // Flush records
            if is_opt && ((self.opt_steps - 1) % self.flush_records_interval == 0) {
                recorder.flush(self.opt_steps as _);
            }

            // Finish training
            if self.opt_steps == self.max_opts {
                return Ok(());
            }
        }
    }

    /// Trains the agent offline.
    ///
    /// Repeatedly performs optimization steps on the experiences already stored in
    /// `buffer`, without interacting with an environment, until `max_opts`
    /// optimization steps have been performed. The warmup period is ignored and an
    /// optimization step is attempted on every iteration.
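    ///
    /// # Examples
    ///
    /// A minimal sketch of offline training. `MyEnv`, `MyAgent`, `MyReplayBuffer`,
    /// `MyRecorder`, and `MyEvaluator` are hypothetical implementations of the
    /// corresponding traits, and the replay buffer is assumed to already contain a
    /// dataset of experiences; the block is marked `ignore` and is not compiled as
    /// a doctest.
    ///
    /// ```ignore
    /// let mut trainer = Trainer::build(TrainerConfig::default());
    /// let mut agent: Box<dyn Agent<MyEnv, MyReplayBuffer>> = Box::new(MyAgent::new());
    /// let mut buffer = MyReplayBuffer::from_dataset("path/to/dataset");
    /// let mut recorder: Box<dyn Recorder<MyEnv, MyReplayBuffer>> = Box::new(MyRecorder::new());
    /// let mut evaluator = MyEvaluator::new();
    ///
    /// // Optimize the agent purely from the pre-filled buffer.
    /// trainer.train_offline(&mut agent, &mut buffer, &mut recorder, &mut evaluator)?;
    /// ```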
    pub fn train_offline<E, R, D>(
        &mut self,
        agent: &mut Box<dyn Agent<E, R>>,
        buffer: &mut R,
        recorder: &mut Box<dyn Recorder<E, R>>,
        evaluator: &mut D,
    ) -> Result<()>
    where
        E: Env,
        R: ReplayBufferBase,
        D: Evaluator<E>,
    {
        // No warmup period; attempt an optimization step on every iteration
        self.warmup_period = 0;
        self.opt_interval = 1;
        agent.train();

        loop {
            let record = Record::empty();
            self.env_steps += 1;

            // Perform optimization step(s)
            let (mut record, is_opt) = {
                let (r, is_opt) = self.train_step(agent, buffer)?;
                (record.merge(r), is_opt)
            };

            // Postprocessing after each training step
            if is_opt {
                self.post_process(agent, evaluator, recorder, &mut record)?;
            }

            // Record average time for optimization steps in milliseconds
            if self.opt_steps % self.record_compute_cost_interval == 0 {
                let (avr_opt_time, _) = self.average_time();
                record.insert("average_opt_time", Scalar(avr_opt_time));
                self.reset_counters();
            }

            // Store the record in the recorder
            if !record.is_empty() {
                recorder.store(record);
            }

            // Flush records
            if is_opt && ((self.opt_steps - 1) % self.flush_records_interval == 0) {
                recorder.flush(self.opt_steps as _);
            }

            // Finish training
            if self.opt_steps == self.max_opts {
                return Ok(());
            }
        }
    }
}