border_core/
trainer.rs

//! Training loop management and agent optimization.
//!
//! This module provides functionality for managing the training process of reinforcement
//! learning agents. It handles environment interactions, experience collection,
//! optimization steps, and evaluation.

mod config;
mod sampler;
use std::time::{Duration, SystemTime};

use crate::{
    record::{Record, RecordValue::Scalar, Recorder},
    Agent, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase, StepProcessor,
};
use anyhow::Result;
pub use config::TrainerConfig;
use log::info;
pub use sampler::Sampler;

/// Manages the training loop and coordinates interactions between components.
///
/// The `Trainer` orchestrates the training process by managing:
///
/// * Environment interactions and experience collection
/// * Agent optimization and parameter updates
/// * Performance evaluation and model saving
/// * Training metrics recording
///
/// # Training Process
///
/// The training loop follows these steps:
///
/// 1. Initialize training components:
///    * Reset environment step counter (`env_steps = 0`)
///    * Reset optimization step counter (`opt_steps = 0`)
///    * Initialize performance monitoring
///
/// 2. Environment Interaction:
///    * Agent observes environment state
///    * Agent selects and executes action
///    * Environment transitions to new state
///    * Experience is collected and stored
///
/// 3. Optimization:
///    * At specified intervals (`opt_interval`):
///      * Sample experiences from replay buffer
///      * Update agent parameters
///      * Track optimization performance
///
/// 4. Evaluation and Recording:
///    * Periodically evaluate agent performance
///    * Record training metrics
///    * Save model checkpoints
///    * Monitor optimization speed
///
/// # Model Selection
///
/// During training, the best performing model is automatically saved based on evaluation rewards:
///
/// * At each evaluation interval (`eval_interval`), the agent's performance is evaluated
/// * The evaluation reward is obtained from `Record::get_scalar_without_key()`
/// * If the current evaluation reward exceeds the previous maximum reward:
///   * The model is saved as the "best" model
///   * The maximum reward is updated
/// * This ensures that the saved "best" model represents the agent's peak performance
///
/// # Configuration
///
/// Training behavior is controlled by various intervals and parameters:
///
/// * `opt_interval`: Environment steps between optimization updates
/// * `eval_interval`: Optimization steps between performance evaluations
/// * `save_interval`: Optimization steps between model checkpoints
/// * `warmup_period`: Initial environment steps before optimization begins
/// * `max_opts`: Maximum number of optimization steps
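///
/// # Example
///
/// The following is a minimal sketch of an online training setup, not a compiled
/// doctest. `MyEnv`, `MyStepProc`, `MyAgent`, `MyBuffer`, `MyRecorder`, and
/// `MyEvaluator` are placeholders for user-provided implementations of the
/// corresponding traits, and the way `TrainerConfig` is constructed depends on
/// its own API.
///
/// ```ignore
/// // Build a trainer from a configuration.
/// let mut trainer = Trainer::build(config);
///
/// // User-provided components implementing `Env`, `StepProcessor`, `Agent`,
/// // `ExperienceBufferBase` + `ReplayBufferBase`, `Recorder`, and `Evaluator`.
/// let env = MyEnv::new();
/// let step_proc = MyStepProc::new();
/// let mut agent: Box<dyn Agent<MyEnv, MyBuffer>> = Box::new(MyAgent::new());
/// let mut buffer = MyBuffer::new();
/// let mut recorder: Box<dyn Recorder<MyEnv, MyBuffer>> = Box::new(MyRecorder::new());
/// let mut evaluator = MyEvaluator::new();
///
/// // Run the online training loop until `max_opts` optimization steps.
/// trainer.train(env, step_proc, &mut agent, &mut buffer, &mut recorder, &mut evaluator)?;
/// ```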
pub struct Trainer {
    /// Interval between optimization steps in environment steps.
    /// Ignored for offline training.
    opt_interval: usize,

    /// Interval for recording computational cost in optimization steps.
    record_compute_cost_interval: usize,

    /// Interval for recording agent information in optimization steps.
    record_agent_info_interval: usize,

    /// Interval for flushing records in optimization steps.
    flush_records_interval: usize,

    /// Interval for evaluation in optimization steps.
    eval_interval: usize,

    /// Interval for saving the model in optimization steps.
    save_interval: usize,

    /// Maximum number of optimization steps.
    max_opts: usize,

    /// Warmup period for filling replay buffer in environment steps.
    /// Ignored for offline training.
    warmup_period: usize,

    /// Counter for replay buffer samples.
    samples_counter: usize,

    /// Timer for replay buffer samples.
    timer_for_samples: Duration,

    /// Counter for optimization steps.
    opt_steps_counter: usize,

    /// Timer for optimization steps.
    timer_for_opt_steps: Duration,

    /// Maximum evaluation reward achieved.
    max_eval_reward: f32,

    /// Current environment step count.
    env_steps: usize,

    /// Current optimization step count.
    opt_steps: usize,
}

impl Trainer {
    /// Creates a new trainer with the specified configuration.
    ///
    /// # Arguments
    ///
    /// * `config` - Configuration parameters for the trainer
    ///
    /// # Returns
    ///
    /// A new `Trainer` instance with the specified configuration
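    ///
    /// # Example
    ///
    /// A minimal sketch, not a compiled doctest; how the `TrainerConfig` is
    /// obtained (default values, builder, or deserialization) depends on its
    /// own API:
    ///
    /// ```ignore
    /// let config: TrainerConfig = /* load or build a configuration */;
    /// let trainer = Trainer::build(config);
    /// ```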
    pub fn build(config: TrainerConfig) -> Self {
        Self {
            opt_interval: config.opt_interval,
            record_compute_cost_interval: config.record_compute_cost_interval,
            record_agent_info_interval: config.record_agent_info_interval,
            flush_records_interval: config.flush_record_interval,
            eval_interval: config.eval_interval,
            save_interval: config.save_interval,
            max_opts: config.max_opts,
            warmup_period: config.warmup_period,
            samples_counter: 0,
            timer_for_samples: Duration::new(0, 0),
            opt_steps_counter: 0,
            timer_for_opt_steps: Duration::new(0, 0),
            max_eval_reward: f32::MIN,
            env_steps: 0,
            opt_steps: 0,
        }
    }

    /// Resets the counters.
    fn reset_counters(&mut self) {
        self.samples_counter = 0;
        self.timer_for_samples = Duration::new(0, 0);
        self.opt_steps_counter = 0;
        self.timer_for_opt_steps = Duration::new(0, 0);
    }

    /// Calculates the average time per optimization step and per replay buffer sample in milliseconds.
    ///
    /// Returns `-1.0` for a component whose counter is zero (no measurements yet).
    fn average_time(&mut self) -> (f32, f32) {
        let avr_opt_time = match self.opt_steps_counter {
            0 => -1f32,
            n => self.timer_for_opt_steps.as_millis() as f32 / n as f32,
        };
        let avr_sample_time = match self.samples_counter {
            0 => -1f32,
            n => self.timer_for_samples.as_millis() as f32 / n as f32,
        };
        (avr_opt_time, avr_sample_time)
    }

    /// Performs a single training step.
    ///
    /// This method:
    /// 1. Skips optimization during the warmup period and between optimization intervals
    /// 2. Otherwise performs an optimization step, additionally recording agent
    ///    information every `record_agent_info_interval` optimization steps
    ///
    /// Environment interaction and experience collection are handled by the caller
    /// (see `Trainer::train`); this method only decides whether and how to run an
    /// optimization step.
    ///
    /// # Arguments
    ///
    /// * `agent` - The agent being trained
    /// * `buffer` - The replay buffer storing experiences
    ///
    /// # Returns
    ///
    /// A tuple containing:
    /// * A record of the training step
    /// * A boolean indicating if an optimization step was performed
    ///
    /// # Errors
    ///
    /// Returns an error if measuring the elapsed time of the optimization step fails
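    ///
    /// # Scheduling example
    ///
    /// As an illustration (the values are arbitrary), assume `warmup_period = 100`,
    /// `opt_interval = 2`, and `record_agent_info_interval = 10`, and note that
    /// `env_steps` is incremented by the caller before this method is invoked:
    ///
    /// * env steps 1 through 99 return `(Record::empty(), false)` (warm-up),
    /// * env steps 100, 102, 104, ... perform an optimization step,
    /// * every 10th optimization step records agent information via
    ///   `opt_with_record`; the others call `opt` without a record.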
    pub fn train_step<E, R>(
        &mut self,
        agent: &mut Box<dyn Agent<E, R>>,
        buffer: &mut R,
    ) -> Result<(Record, bool)>
    where
        E: Env,
        R: ReplayBufferBase,
    {
        if self.env_steps < self.warmup_period {
            Ok((Record::empty(), false))
        } else if self.env_steps % self.opt_interval != 0 {
            // Skip optimization step
            Ok((Record::empty(), false))
        } else if (self.opt_steps + 1) % self.record_agent_info_interval == 0 {
            // Do optimization step with record
            let timer = SystemTime::now();
            let record_agent = agent.opt_with_record(buffer);
            self.opt_steps += 1;
            self.timer_for_opt_steps += timer.elapsed()?;
            self.opt_steps_counter += 1;
            Ok((record_agent, true))
        } else {
            // Do optimization step without record
            let timer = SystemTime::now();
            agent.opt(buffer);
            self.opt_steps += 1;
            self.timer_for_opt_steps += timer.elapsed()?;
            self.opt_steps_counter += 1;
            Ok((Record::empty(), true))
        }
    }

    /// Evaluates the agent, saves the best model so far, and periodically saves checkpoints.
    fn post_process<E, R, D>(
        &mut self,
        agent: &mut Box<dyn Agent<E, R>>,
        evaluator: &mut D,
        recorder: &mut Box<dyn Recorder<E, R>>,
        record: &mut Record,
    ) -> Result<()>
    where
        E: Env,
        R: ReplayBufferBase,
        D: Evaluator<E>,
    {
        // Evaluation
        if self.opt_steps % self.eval_interval == 0 {
            info!("Starting evaluation of the trained model");
            agent.eval();
            let (score, record_eval) = evaluator.evaluate(agent)?;
            agent.train();
            record.merge_inplace(record_eval);

            // Save the best model up to the current iteration
            if score > self.max_eval_reward {
                self.max_eval_reward = score;
                recorder.save_model("best".as_ref(), agent)?;
            }
        }

        // Save the current model
        if (self.save_interval > 0) && (self.opt_steps % self.save_interval == 0) {
            recorder.save_model(format!("{}", self.opt_steps).as_ref(), agent)?;
        }

        Ok(())
    }

    /// Train the agent online.
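    ///
    /// Repeatedly samples experiences from the environment, pushes them to the
    /// replay buffer, and performs optimization steps until `max_opts`
    /// optimization steps have been completed.
    ///
    /// # Arguments
    ///
    /// * `env` - The environment to interact with
    /// * `step_proc` - Processor converting environment steps into replay buffer items
    /// * `agent` - The agent being trained
    /// * `buffer` - The replay buffer storing experiences
    /// * `recorder` - Recorder for training metrics and saved models
    /// * `evaluator` - Evaluator used to measure the agent's performance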
    pub fn train<E, P, R, D>(
        &mut self,
        env: E,
        step_proc: P,
        agent: &mut Box<dyn Agent<E, R>>,
        buffer: &mut R,
        recorder: &mut Box<dyn Recorder<E, R>>,
        evaluator: &mut D,
    ) -> Result<()>
    where
        E: Env,
        P: StepProcessor<E>,
        R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase,
        D: Evaluator<E>,
    {
        let mut sampler = Sampler::new(env, step_proc);
        agent.train();

        loop {
            // Taking samples from the environment and pushing them to the replay buffer
            let now = SystemTime::now();
            let record = sampler.sample_and_push(agent, buffer)?;
            self.timer_for_samples += now.elapsed()?;
            self.samples_counter += 1;
            self.env_steps += 1;

            // Perform optimization step(s)
            let (mut record, is_opt) = {
                let (r, is_opt) = self.train_step(agent, buffer)?;
                (record.merge(r), is_opt)
            };

            // Postprocessing after each training step
            if is_opt {
                self.post_process(agent, evaluator, recorder, &mut record)?;
            }

            // Record average time for optimization steps and sampling steps in milliseconds
            if self.opt_steps % self.record_compute_cost_interval == 0 {
                let (avr_opt_time, avr_sample_time) = self.average_time();
                record.insert("average_opt_time", Scalar(avr_opt_time));
                record.insert("average_sample_time", Scalar(avr_sample_time));
                self.reset_counters();
            }

            // Store record to the recorder
            if !record.is_empty() {
                recorder.store(record);
            }

            // Flush records
            if is_opt && ((self.opt_steps - 1) % self.flush_records_interval == 0) {
                recorder.flush(self.opt_steps as _);
            }

            // Finish training
            if self.opt_steps == self.max_opts {
                return Ok(());
            }
        }
    }

    /// Train the agent offline.
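    ///
    /// In offline training there is no environment interaction: `warmup_period`
    /// is forced to `0` and `opt_interval` to `1`, so every iteration performs an
    /// optimization step on the previously collected experiences in the replay
    /// buffer until `max_opts` optimization steps have been completed.
    ///
    /// # Arguments
    ///
    /// * `agent` - The agent being trained
    /// * `buffer` - The replay buffer containing pre-collected experiences
    /// * `recorder` - Recorder for training metrics and saved models
    /// * `evaluator` - Evaluator used to measure the agent's performance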
    pub fn train_offline<E, R, D>(
        &mut self,
        agent: &mut Box<dyn Agent<E, R>>,
        buffer: &mut R,
        recorder: &mut Box<dyn Recorder<E, R>>,
        evaluator: &mut D,
    ) -> Result<()>
    where
        E: Env,
        R: ReplayBufferBase,
        D: Evaluator<E>,
    {
        // Offline training performs an optimization step at every iteration,
        // without a warmup period for filling the replay buffer.
        self.warmup_period = 0;
        self.opt_interval = 1;
        agent.train();

        loop {
            // No environment interaction in offline training; start from an empty record
            let record = Record::empty();
            self.env_steps += 1;

            // Perform optimization step(s)
            let (mut record, is_opt) = {
                let (r, is_opt) = self.train_step(agent, buffer)?;
                (record.merge(r), is_opt)
            };

            // Postprocessing after each training step
            if is_opt {
                self.post_process(agent, evaluator, recorder, &mut record)?;
            }

            // Record average time for optimization steps in milliseconds
            if self.opt_steps % self.record_compute_cost_interval == 0 {
                let (avr_opt_time, _) = self.average_time();
                record.insert("average_opt_time", Scalar(avr_opt_time));
                self.reset_counters();
            }

            // Store record to the recorder
            if !record.is_empty() {
                recorder.store(record);
            }

            // Flush records
            if is_opt && ((self.opt_steps - 1) % self.flush_records_interval == 0) {
                recorder.flush(self.opt_steps as _);
            }

            // Finish training
            if self.opt_steps == self.max_opts {
                return Ok(());
            }
        }
    }
}