border_core/trainer/config.rs
//! Configuration for the training process.
//!
//! This module provides [`TrainerConfig`], which controls how the training loop of a
//! reinforcement learning agent is run: how often optimization updates happen, how
//! often the agent is evaluated, and how often models and metrics are persisted.
//!
//! # Configuration Options
//!
//! The configuration controls:
//!
//! * Training duration and optimization steps
//! * Evaluation frequency and model selection
//! * Performance monitoring and metrics recording
//! * Model checkpointing and warmup periods
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::{
    fs::File,
    io::{BufReader, Write},
    path::Path,
};
23
24/// Configuration parameters for the training process.
25///
26/// This struct defines various intervals and thresholds that control the
27/// behavior of the training loop. Each parameter can be set using the
28/// builder pattern methods.
29#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
30pub struct TrainerConfig {
31 /// Maximum number of optimization steps to perform.
32 /// Training stops when this number is reached.
33 pub max_opts: usize,
34
35 /// Number of environment steps between optimization updates.
36 /// For example, if set to 1, optimization occurs after every environment step.
37 pub opt_interval: usize,
38
39 /// Number of optimization steps between performance evaluations.
40 /// During evaluation, the agent's performance is measured and the best model is saved.
41 pub eval_interval: usize,
42
43 /// Number of optimization steps between flushing recorded metrics to storage.
44 /// This controls how frequently training metrics are persisted.
45 pub flush_record_interval: usize,
46
47 /// Number of optimization steps between recording computational performance metrics.
48 /// This includes metrics like optimization steps per second.
49 pub record_compute_cost_interval: usize,
50
51 /// Number of optimization steps between recording agent-specific information.
52 /// This can include internal agent metrics or state information.
53 pub record_agent_info_interval: usize,
54
55 /// Initial number of environment steps before optimization begins.
56 /// During this period, the replay buffer is filled with initial experiences.
57 pub warmup_period: usize,
58
59 /// Number of optimization steps between saving model checkpoints.
60 /// These checkpoints can be used for resuming training or analysis.
61 pub save_interval: usize,
62}
63
64impl Default for TrainerConfig {
65 /// Creates a default configuration with conservative values.
66 ///
67 /// Default values are set to:
68 /// * `max_opts`: 0
69 /// * `opt_interval`: 1 (optimize every step)
70 /// * `eval_interval`: 0 (no evaluation)
71 /// * `flush_record_interval`: usize::MAX (never flush)
72 /// * `record_compute_cost_interval`: usize::MAX (never record)
73 /// * `record_agent_info_interval`: usize::MAX (never record)
74 /// * `warmup_period`: 0 (no warmup)
75 /// * `save_interval`: usize::MAX (never save)
76 fn default() -> Self {
77 Self {
78 max_opts: 0,
79 eval_interval: 0,
80 opt_interval: 1,
81 flush_record_interval: usize::MAX,
82 record_compute_cost_interval: usize::MAX,
83 record_agent_info_interval: usize::MAX,
84 warmup_period: 0,
85 save_interval: usize::MAX,
86 }
87 }
88}
89
90impl TrainerConfig {
91 /// Sets the maximum number of optimization steps.
92 ///
93 /// # Arguments
94 ///
95 /// * `v` - Maximum number of optimization steps
96 ///
97 /// # Returns
98 ///
99 /// Self with the updated configuration
100 pub fn max_opts(mut self, v: usize) -> Self {
101 self.max_opts = v;
102 self
103 }
104
105 /// Sets the interval between performance evaluations.
106 ///
107 /// # Arguments
108 ///
109 /// * `v` - Number of optimization steps between evaluations
110 ///
111 /// # Returns
112 ///
113 /// Self with the updated configuration
114 pub fn eval_interval(mut self, v: usize) -> Self {
115 self.eval_interval = v;
116 self
117 }
118
119 /// (Deprecated) Sets the evaluation threshold.
120 ///
121 /// This method is currently unimplemented and may be removed in future versions.
122 pub fn eval_threshold(/*mut */ self, _v: f32) -> Self {
123 unimplemented!();
124 // self.eval_threshold = Some(v);
125 // self
126 }
127
128 /// Sets the interval between optimization updates.
129 ///
130 /// # Arguments
131 ///
132 /// * `opt_interval` - Number of environment steps between optimizations
133 ///
134 /// # Returns
135 ///
136 /// Self with the updated configuration
137 pub fn opt_interval(mut self, opt_interval: usize) -> Self {
138 self.opt_interval = opt_interval;
139 self
140 }
141
142 /// Sets the interval for flushing recorded metrics.
143 ///
144 /// # Arguments
145 ///
146 /// * `flush_record_interval` - Number of optimization steps between flushes
147 ///
148 /// # Returns
149 ///
150 /// Self with the updated configuration
151 pub fn flush_record_interval(mut self, flush_record_interval: usize) -> Self {
152 self.flush_record_interval = flush_record_interval;
153 self
154 }
155
156 /// Sets the interval for recording computational performance metrics.
157 ///
158 /// # Arguments
159 ///
160 /// * `record_compute_cost_interval` - Number of optimization steps between recordings
161 ///
162 /// # Returns
163 ///
164 /// Self with the updated configuration
165 pub fn record_compute_cost_interval(mut self, record_compute_cost_interval: usize) -> Self {
166 self.record_compute_cost_interval = record_compute_cost_interval;
167 self
168 }
169
170 /// Sets the interval for recording agent-specific information.
171 ///
172 /// # Arguments
173 ///
174 /// * `record_agent_info_interval` - Number of optimization steps between recordings
175 ///
176 /// # Returns
177 ///
178 /// Self with the updated configuration
179 pub fn record_agent_info_interval(mut self, record_agent_info_interval: usize) -> Self {
180 self.record_agent_info_interval = record_agent_info_interval;
181 self
182 }
183
184 /// Sets the initial warmup period before optimization begins.
185 ///
186 /// # Arguments
187 ///
188 /// * `warmup_period` - Number of environment steps in the warmup period
189 ///
190 /// # Returns
191 ///
192 /// Self with the updated configuration
193 pub fn warmup_period(mut self, warmup_period: usize) -> Self {
194 self.warmup_period = warmup_period;
195 self
196 }
197
198 /// Sets the interval for saving model checkpoints.
199 ///
200 /// # Arguments
201 ///
202 /// * `save_interval` - Number of optimization steps between checkpoints
203 ///
204 /// # Returns
205 ///
206 /// Self with the updated configuration
207 pub fn save_interval(mut self, save_interval: usize) -> Self {
208 self.save_interval = save_interval;
209 self
210 }
211
212 /// Loads configuration from a YAML file.
213 ///
214 /// # Arguments
215 ///
216 /// * `path` - Path to the configuration file
217 ///
218 /// # Returns
219 ///
220 /// Result containing the loaded configuration
221 ///
222 /// # Errors
223 ///
224 /// Returns an error if the file cannot be read or parsed
225 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
226 let file = File::open(path)?;
227 let rdr = BufReader::new(file);
228 let b = serde_yaml::from_reader(rdr)?;
229 Ok(b)
230 }
231
232 /// Saves the configuration to a file.
233 ///
234 /// # Arguments
235 ///
236 /// * `path` - Path where the configuration will be saved
237 ///
238 /// # Returns
239 ///
240 /// Result indicating success or failure
241 ///
242 /// # Errors
243 ///
244 /// Returns an error if the file cannot be written
245 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
246 let mut file = File::create(path)?;
247 file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
248 Ok(())
249 }
250}
251
252// #[cfg(test)]
253// mod tests {
254// use super::*;
255// use tempdir::TempDir;
256
257// #[test]
258// fn test_serde_trainer_builder() -> Result<()> {
259// let builder = TrainerBuilder::default()
260// .max_opts(100)
261// .eval_interval(10000)
262// .n_episodes_per_eval(5)
263// .model_dir("some/directory");
264
265// let dir = TempDir::new("trainer_builder")?;
266// let path = dir.path().join("trainer_builder.yaml");
267// println!("{:?}", path);
268
269// builder.save(&path)?;
270// let builder_ = TrainerBuilder::load(&path)?;
271// assert_eq!(builder, builder_);
272// // let yaml = serde_yaml::to_string(&trainer)?;
273// // println!("{}", yaml);
274// // assert_eq!(
275// // yaml,
276// // "---\n\
277// // max_opts: 100\n\
278// // eval_interval: 10000\n\
279// // n_episodes_per_eval: 5\n\
280// // eval_threshold: ~\n\
281// // model_dir: some/directory\n\
282// // "
283// // );
284// Ok(())
285// }
286// }