border_core/trainer/
config.rs

1//! Configuration of [`Trainer`](super::Trainer).
2use anyhow::Result;
3use serde::{Deserialize, Serialize};
4use std::{
5    fs::File,
6    io::{BufReader, Write},
7    path::Path,
8};
9
10/// Configuration of [`Trainer`](super::Trainer).
11#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
12pub struct TrainerConfig {
13    /// The maximum number of optimization steps.
14    pub max_opts: usize,
15
16    /// Directory where model parameters will be saved.
17    pub model_dir: Option<String>,
18
19    /// Interval of optimization steps in environment steps.
20    pub opt_interval: usize,
21
22    /// Interval of evaluation in optimization steps.
23    pub eval_interval: usize,
24
25    /// Interval of flushing records in optimization steps.
26    pub flush_record_interval: usize,
27
28    /// Interval of recording agent information in optimization steps.
29    pub record_compute_cost_interval: usize,
30
31    /// Interval of recording agent information in optimization steps.
32    pub record_agent_info_interval: usize,
33
34    /// Warmup period, for filling replay buffer, in environment steps
35    pub warmup_period: usize,
36
37    /// Intercal of saving model parameters in optimization steps.
38    pub save_interval: usize,
39}
40
41impl Default for TrainerConfig {
42    fn default() -> Self {
43        Self {
44            max_opts: 0,
45            eval_interval: 0,
46            // eval_threshold: None,
47            model_dir: None,
48            opt_interval: 1,
49            flush_record_interval: usize::MAX,
50            record_compute_cost_interval: usize::MAX,
51            record_agent_info_interval: usize::MAX,
52            warmup_period: 0,
53            save_interval: usize::MAX,
54        }
55    }
56}
57
58impl TrainerConfig {
59    /// Sets the number of optimization steps.
60    pub fn max_opts(mut self, v: usize) -> Self {
61        self.max_opts = v;
62        self
63    }
64
65    /// Sets the interval of evaluation in optimization steps.
66    pub fn eval_interval(mut self, v: usize) -> Self {
67        self.eval_interval = v;
68        self
69    }
70
71    /// (Deprecated) Sets the evaluation threshold.
72    pub fn eval_threshold(/*mut */ self, _v: f32) -> Self {
73        unimplemented!();
74        // self.eval_threshold = Some(v);
75        // self
76    }
77
78    /// Sets the directory the trained model being saved.
79    pub fn model_dir<T: Into<String>>(mut self, model_dir: T) -> Self {
80        self.model_dir = Some(model_dir.into());
81        self
82    }
83
84    /// Sets the interval of optimization in environment steps.
85    pub fn opt_interval(mut self, opt_interval: usize) -> Self {
86        self.opt_interval = opt_interval;
87        self
88    }
89
90    /// Sets the interval of flushing recordd in optimization steps.
91    pub fn flush_record_interval(mut self, flush_record_interval: usize) -> Self {
92        self.flush_record_interval = flush_record_interval;
93        self
94    }
95
96    /// Sets the interval of computation cost in optimization steps.
97    pub fn record_compute_cost_interval(mut self, record_compute_cost_interval: usize) -> Self {
98        self.record_compute_cost_interval = record_compute_cost_interval;
99        self
100    }
101
102    /// Sets the interval of recording agent information in optimization steps.
103    pub fn record_agent_info_interval(mut self, record_agent_info_interval: usize) -> Self {
104        self.record_agent_info_interval = record_agent_info_interval;
105        self
106    }
107
108    /// Sets warmup period in environment steps.
109    pub fn warmup_period(mut self, warmup_period: usize) -> Self {
110        self.warmup_period = warmup_period;
111        self
112    }
113
114    /// Sets the interval of saving in optimization steps.
115    pub fn save_interval(mut self, save_interval: usize) -> Self {
116        self.save_interval = save_interval;
117        self
118    }
119
120    /// Constructs [`TrainerConfig`] from YAML file.
121    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
122        let file = File::open(path)?;
123        let rdr = BufReader::new(file);
124        let b = serde_yaml::from_reader(rdr)?;
125        Ok(b)
126    }
127
128    /// Saves [`TrainerConfig`].
129    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
130        let mut file = File::create(path)?;
131        file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
132        Ok(())
133    }
134}
135
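// NOTE: this commented-out test is stale. It refers to `TrainerBuilder` and
// `n_episodes_per_eval`, which are not defined in this file. It also expects an
// `eval_threshold` field that `TrainerConfig` no longer has. Port it to `TrainerConfig`
// before re-enabling it.
// #[cfg(test)]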
137// mod tests {
138//     use super::*;
139//     use tempdir::TempDir;
140
141//     #[test]
142//     fn test_serde_trainer_builder() -> Result<()> {
143//         let builder = TrainerBuilder::default()
144//             .max_opts(100)
145//             .eval_interval(10000)
146//             .n_episodes_per_eval(5)
147//             .model_dir("some/directory");
148
149//         let dir = TempDir::new("trainer_builder")?;
150//         let path = dir.path().join("trainer_builder.yaml");
151//         println!("{:?}", path);
152
153//         builder.save(&path)?;
154//         let builder_ = TrainerBuilder::load(&path)?;
155//         assert_eq!(builder, builder_);
156//         // let yaml = serde_yaml::to_string(&trainer)?;
157//         // println!("{}", yaml);
158//         // assert_eq!(
159//         //     yaml,
160//         //     "---\n\
161//         //      max_opts: 100\n\
162//         //      eval_interval: 10000\n\
163//         //      n_episodes_per_eval: 5\n\
164//         //      eval_threshold: ~\n\
165//         //      model_dir: some/directory\n\
166//         // "
167//         // );
168//         Ok(())
169//     }
170// }