border_core/trainer/
config.rs

1//! Configuration for the training process.
2//!
3//! This module provides configuration options for controlling the training process
4//! of reinforcement learning agents. It allows fine-tuning of various aspects
5//! of the training loop, including optimization intervals, evaluation frequency,
6//! and model saving.
7//!
8//! # Configuration Options
9//!
10//! The configuration allows control over:
11//!
12//! * Training duration and optimization steps
13//! * Evaluation frequency and model selection
14//! * Performance monitoring and metrics recording
15//! * Model checkpointing and warmup periods
16use anyhow::Result;
17use serde::{Deserialize, Serialize};
18use std::{
19    fs::File,
20    io::{BufReader, Write},
21    path::Path,
22};
23
24/// Configuration parameters for the training process.
25///
26/// This struct defines various intervals and thresholds that control the
27/// behavior of the training loop. Each parameter can be set using the
28/// builder pattern methods.
29#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
30pub struct TrainerConfig {
31    /// Maximum number of optimization steps to perform.
32    /// Training stops when this number is reached.
33    pub max_opts: usize,
34
35    /// Number of environment steps between optimization updates.
36    /// For example, if set to 1, optimization occurs after every environment step.
37    pub opt_interval: usize,
38
39    /// Number of optimization steps between performance evaluations.
40    /// During evaluation, the agent's performance is measured and the best model is saved.
41    pub eval_interval: usize,
42
43    /// Number of optimization steps between flushing recorded metrics to storage.
44    /// This controls how frequently training metrics are persisted.
45    pub flush_record_interval: usize,
46
47    /// Number of optimization steps between recording computational performance metrics.
48    /// This includes metrics like optimization steps per second.
49    pub record_compute_cost_interval: usize,
50
51    /// Number of optimization steps between recording agent-specific information.
52    /// This can include internal agent metrics or state information.
53    pub record_agent_info_interval: usize,
54
55    /// Initial number of environment steps before optimization begins.
56    /// During this period, the replay buffer is filled with initial experiences.
57    pub warmup_period: usize,
58
59    /// Number of optimization steps between saving model checkpoints.
60    /// These checkpoints can be used for resuming training or analysis.
61    pub save_interval: usize,
62}
63
64impl Default for TrainerConfig {
65    /// Creates a default configuration with conservative values.
66    ///
67    /// Default values are set to:
68    /// * `max_opts`: 0
69    /// * `opt_interval`: 1 (optimize every step)
70    /// * `eval_interval`: 0 (no evaluation)
71    /// * `flush_record_interval`: usize::MAX (never flush)
72    /// * `record_compute_cost_interval`: usize::MAX (never record)
73    /// * `record_agent_info_interval`: usize::MAX (never record)
74    /// * `warmup_period`: 0 (no warmup)
75    /// * `save_interval`: usize::MAX (never save)
76    fn default() -> Self {
77        Self {
78            max_opts: 0,
79            eval_interval: 0,
80            opt_interval: 1,
81            flush_record_interval: usize::MAX,
82            record_compute_cost_interval: usize::MAX,
83            record_agent_info_interval: usize::MAX,
84            warmup_period: 0,
85            save_interval: usize::MAX,
86        }
87    }
88}
89
90impl TrainerConfig {
91    /// Sets the maximum number of optimization steps.
92    ///
93    /// # Arguments
94    ///
95    /// * `v` - Maximum number of optimization steps
96    ///
97    /// # Returns
98    ///
99    /// Self with the updated configuration
100    pub fn max_opts(mut self, v: usize) -> Self {
101        self.max_opts = v;
102        self
103    }
104
105    /// Sets the interval between performance evaluations.
106    ///
107    /// # Arguments
108    ///
109    /// * `v` - Number of optimization steps between evaluations
110    ///
111    /// # Returns
112    ///
113    /// Self with the updated configuration
114    pub fn eval_interval(mut self, v: usize) -> Self {
115        self.eval_interval = v;
116        self
117    }
118
119    /// (Deprecated) Sets the evaluation threshold.
120    ///
121    /// This method is currently unimplemented and may be removed in future versions.
122    pub fn eval_threshold(/*mut */ self, _v: f32) -> Self {
123        unimplemented!();
124        // self.eval_threshold = Some(v);
125        // self
126    }
127
128    /// Sets the interval between optimization updates.
129    ///
130    /// # Arguments
131    ///
132    /// * `opt_interval` - Number of environment steps between optimizations
133    ///
134    /// # Returns
135    ///
136    /// Self with the updated configuration
137    pub fn opt_interval(mut self, opt_interval: usize) -> Self {
138        self.opt_interval = opt_interval;
139        self
140    }
141
142    /// Sets the interval for flushing recorded metrics.
143    ///
144    /// # Arguments
145    ///
146    /// * `flush_record_interval` - Number of optimization steps between flushes
147    ///
148    /// # Returns
149    ///
150    /// Self with the updated configuration
151    pub fn flush_record_interval(mut self, flush_record_interval: usize) -> Self {
152        self.flush_record_interval = flush_record_interval;
153        self
154    }
155
156    /// Sets the interval for recording computational performance metrics.
157    ///
158    /// # Arguments
159    ///
160    /// * `record_compute_cost_interval` - Number of optimization steps between recordings
161    ///
162    /// # Returns
163    ///
164    /// Self with the updated configuration
165    pub fn record_compute_cost_interval(mut self, record_compute_cost_interval: usize) -> Self {
166        self.record_compute_cost_interval = record_compute_cost_interval;
167        self
168    }
169
170    /// Sets the interval for recording agent-specific information.
171    ///
172    /// # Arguments
173    ///
174    /// * `record_agent_info_interval` - Number of optimization steps between recordings
175    ///
176    /// # Returns
177    ///
178    /// Self with the updated configuration
179    pub fn record_agent_info_interval(mut self, record_agent_info_interval: usize) -> Self {
180        self.record_agent_info_interval = record_agent_info_interval;
181        self
182    }
183
184    /// Sets the initial warmup period before optimization begins.
185    ///
186    /// # Arguments
187    ///
188    /// * `warmup_period` - Number of environment steps in the warmup period
189    ///
190    /// # Returns
191    ///
192    /// Self with the updated configuration
193    pub fn warmup_period(mut self, warmup_period: usize) -> Self {
194        self.warmup_period = warmup_period;
195        self
196    }
197
198    /// Sets the interval for saving model checkpoints.
199    ///
200    /// # Arguments
201    ///
202    /// * `save_interval` - Number of optimization steps between checkpoints
203    ///
204    /// # Returns
205    ///
206    /// Self with the updated configuration
207    pub fn save_interval(mut self, save_interval: usize) -> Self {
208        self.save_interval = save_interval;
209        self
210    }
211
212    /// Loads configuration from a YAML file.
213    ///
214    /// # Arguments
215    ///
216    /// * `path` - Path to the configuration file
217    ///
218    /// # Returns
219    ///
220    /// Result containing the loaded configuration
221    ///
222    /// # Errors
223    ///
224    /// Returns an error if the file cannot be read or parsed
225    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
226        let file = File::open(path)?;
227        let rdr = BufReader::new(file);
228        let b = serde_yaml::from_reader(rdr)?;
229        Ok(b)
230    }
231
232    /// Saves the configuration to a file.
233    ///
234    /// # Arguments
235    ///
236    /// * `path` - Path where the configuration will be saved
237    ///
238    /// # Returns
239    ///
240    /// Result indicating success or failure
241    ///
242    /// # Errors
243    ///
244    /// Returns an error if the file cannot be written
245    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
246        let mut file = File::create(path)?;
247        file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
248        Ok(())
249    }
250}
251
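// NOTE(review): the disabled test below refers to `TrainerBuilder`,
// `n_episodes_per_eval` and `model_dir`, none of which appear in this file.
// It presumably predates `TrainerConfig` and needs porting before it can be
// re-enabled -- TODO confirm.
// #[cfg(test)]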
253// mod tests {
254//     use super::*;
255//     use tempdir::TempDir;
256
257//     #[test]
258//     fn test_serde_trainer_builder() -> Result<()> {
259//         let builder = TrainerBuilder::default()
260//             .max_opts(100)
261//             .eval_interval(10000)
262//             .n_episodes_per_eval(5)
263//             .model_dir("some/directory");
264
265//         let dir = TempDir::new("trainer_builder")?;
266//         let path = dir.path().join("trainer_builder.yaml");
267//         println!("{:?}", path);
268
269//         builder.save(&path)?;
270//         let builder_ = TrainerBuilder::load(&path)?;
271//         assert_eq!(builder, builder_);
272//         // let yaml = serde_yaml::to_string(&trainer)?;
273//         // println!("{}", yaml);
274//         // assert_eq!(
275//         //     yaml,
276//         //     "---\n\
277//         //      max_opts: 100\n\
278//         //      eval_interval: 10000\n\
279//         //      n_episodes_per_eval: 5\n\
280//         //      eval_threshold: ~\n\
281//         //      model_dir: some/directory\n\
282//         // "
283//         // );
284//         Ok(())
285//     }
286// }