border_core/trainer/config.rs
1//! Configuration of [`Trainer`](super::Trainer).
2use anyhow::Result;
3use serde::{Deserialize, Serialize};
4use std::{
5 fs::File,
6 io::{BufReader, Write},
7 path::Path,
8};
9
10/// Configuration of [`Trainer`](super::Trainer).
11#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
12pub struct TrainerConfig {
13 /// The maximum number of optimization steps.
14 pub max_opts: usize,
15
16 /// Directory where model parameters will be saved.
17 pub model_dir: Option<String>,
18
19 /// Interval of optimization steps in environment steps.
20 pub opt_interval: usize,
21
22 /// Interval of evaluation in optimization steps.
23 pub eval_interval: usize,
24
25 /// Interval of flushing records in optimization steps.
26 pub flush_record_interval: usize,
27
28 /// Interval of recording agent information in optimization steps.
29 pub record_compute_cost_interval: usize,
30
31 /// Interval of recording agent information in optimization steps.
32 pub record_agent_info_interval: usize,
33
34 /// Warmup period, for filling replay buffer, in environment steps
35 pub warmup_period: usize,
36
37 /// Intercal of saving model parameters in optimization steps.
38 pub save_interval: usize,
39}
40
41impl Default for TrainerConfig {
42 fn default() -> Self {
43 Self {
44 max_opts: 0,
45 eval_interval: 0,
46 // eval_threshold: None,
47 model_dir: None,
48 opt_interval: 1,
49 flush_record_interval: usize::MAX,
50 record_compute_cost_interval: usize::MAX,
51 record_agent_info_interval: usize::MAX,
52 warmup_period: 0,
53 save_interval: usize::MAX,
54 }
55 }
56}
57
58impl TrainerConfig {
59 /// Sets the number of optimization steps.
60 pub fn max_opts(mut self, v: usize) -> Self {
61 self.max_opts = v;
62 self
63 }
64
65 /// Sets the interval of evaluation in optimization steps.
66 pub fn eval_interval(mut self, v: usize) -> Self {
67 self.eval_interval = v;
68 self
69 }
70
71 /// (Deprecated) Sets the evaluation threshold.
72 pub fn eval_threshold(/*mut */ self, _v: f32) -> Self {
73 unimplemented!();
74 // self.eval_threshold = Some(v);
75 // self
76 }
77
78 /// Sets the directory the trained model being saved.
79 pub fn model_dir<T: Into<String>>(mut self, model_dir: T) -> Self {
80 self.model_dir = Some(model_dir.into());
81 self
82 }
83
84 /// Sets the interval of optimization in environment steps.
85 pub fn opt_interval(mut self, opt_interval: usize) -> Self {
86 self.opt_interval = opt_interval;
87 self
88 }
89
90 /// Sets the interval of flushing recordd in optimization steps.
91 pub fn flush_record_interval(mut self, flush_record_interval: usize) -> Self {
92 self.flush_record_interval = flush_record_interval;
93 self
94 }
95
96 /// Sets the interval of computation cost in optimization steps.
97 pub fn record_compute_cost_interval(mut self, record_compute_cost_interval: usize) -> Self {
98 self.record_compute_cost_interval = record_compute_cost_interval;
99 self
100 }
101
102 /// Sets the interval of recording agent information in optimization steps.
103 pub fn record_agent_info_interval(mut self, record_agent_info_interval: usize) -> Self {
104 self.record_agent_info_interval = record_agent_info_interval;
105 self
106 }
107
108 /// Sets warmup period in environment steps.
109 pub fn warmup_period(mut self, warmup_period: usize) -> Self {
110 self.warmup_period = warmup_period;
111 self
112 }
113
114 /// Sets the interval of saving in optimization steps.
115 pub fn save_interval(mut self, save_interval: usize) -> Self {
116 self.save_interval = save_interval;
117 self
118 }
119
120 /// Constructs [`TrainerConfig`] from YAML file.
121 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
122 let file = File::open(path)?;
123 let rdr = BufReader::new(file);
124 let b = serde_yaml::from_reader(rdr)?;
125 Ok(b)
126 }
127
128 /// Saves [`TrainerConfig`].
129 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
130 let mut file = File::create(path)?;
131 file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
132 Ok(())
133 }
134}
135
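// NOTE(review): the test below is stale. It refers to `TrainerBuilder` and
// `n_episodes_per_eval`, which do not exist in this module. Port it to
// `TrainerConfig` before re-enabling it.
// #[cfg(test)]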
137// mod tests {
138// use super::*;
139// use tempdir::TempDir;
140
141// #[test]
142// fn test_serde_trainer_builder() -> Result<()> {
143// let builder = TrainerBuilder::default()
144// .max_opts(100)
145// .eval_interval(10000)
146// .n_episodes_per_eval(5)
147// .model_dir("some/directory");
148
149// let dir = TempDir::new("trainer_builder")?;
150// let path = dir.path().join("trainer_builder.yaml");
151// println!("{:?}", path);
152
153// builder.save(&path)?;
154// let builder_ = TrainerBuilder::load(&path)?;
155// assert_eq!(builder, builder_);
156// // let yaml = serde_yaml::to_string(&trainer)?;
157// // println!("{}", yaml);
158// // assert_eq!(
159// // yaml,
160// // "---\n\
161// // max_opts: 100\n\
162// // eval_interval: 10000\n\
163// // n_episodes_per_eval: 5\n\
164// // eval_threshold: ~\n\
165// // model_dir: some/directory\n\
166// // "
167// // );
168// Ok(())
169// }
170// }