border_async_trainer/async_trainer/
config.rs

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::{
4    fs::File,
5    io::{BufReader, Write},
6    path::Path,
7};
8
9/// Configuration of [`AsyncTrainer`](crate::AsyncTrainer).
10#[derive(Clone, Debug, Deserialize, Serialize)]
11pub struct AsyncTrainerConfig {
12    /// The maximum number of optimization steps.
13    pub max_opts: usize,
14
15    /// Interval of evaluation in training steps.
16    pub eval_interval: usize,
17
18    /// Interval of flushing records in optimization steps.
19    pub flush_record_interval: usize,
20
21    /// Interval of recording agent information in optimization steps.
22    pub record_compute_cost_interval: usize,
23
24    /// Interval of recording agent information in optimization steps.
25    pub record_agent_info_interval: usize,
26
27    /// Interval of saving the model in optimization steps.
28    pub save_interval: usize,
29
30    /// Interval of synchronizing model parameters in training steps.
31    pub sync_interval: usize,
32
33    /// Warmup period, for filling replay buffer, in environment steps
34    pub warmup_period: usize,
35}
36
37impl AsyncTrainerConfig {
38    /// Sets the number of optimization steps.
39    pub fn max_opts(mut self, v: usize) -> Result<Self> {
40        self.max_opts = v;
41        Ok(self)
42    }
43
44    /// Sets the interval of evaluation in optimization steps.
45    pub fn eval_interval(mut self, v: usize) -> Result<Self> {
46        self.eval_interval = v;
47        Ok(self)
48    }
49
50    /// Sets the interval of computation cost in optimization steps.
51    pub fn record_compute_cost_interval(
52        mut self,
53        record_compute_cost_interval: usize,
54    ) -> Result<Self> {
55        self.record_compute_cost_interval = record_compute_cost_interval;
56        Ok(self)
57    }
58
59    /// Sets the interval of flushing recordd in optimization steps.
60    pub fn flush_record_interval(mut self, flush_record_interval: usize) -> Result<Self> {
61        self.flush_record_interval = flush_record_interval;
62        Ok(self)
63    }
64
65    /// Sets warmup period in environment steps.
66    pub fn warmup_period(mut self, warmup_period: usize) -> Result<Self> {
67        self.warmup_period = warmup_period;
68        Ok(self)
69    }
70
71    /// Sets the interval of saving in optimization steps.
72    pub fn save_interval(mut self, save_interval: usize) -> Result<Self> {
73        self.save_interval = save_interval;
74        Ok(self)
75    }
76
77    /// Sets the interval of synchronizing model parameters in training steps.
78    pub fn sync_interval(mut self, sync_interval: usize) -> Result<Self> {
79        self.sync_interval = sync_interval;
80        Ok(self)
81    }
82
83    /// Constructs [AsyncTrainerConfig] from YAML file.
84    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
85        let file = File::open(path)?;
86        let rdr = BufReader::new(file);
87        let b = serde_yaml::from_reader(rdr)?;
88        Ok(b)
89    }
90
91    /// Saves [AsyncTrainerConfig].
92    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
93        let mut file = File::create(path)?;
94        file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
95        Ok(())
96    }
97}
98
99impl Default for AsyncTrainerConfig {
100    /// There is no special intention behind these initial values.
101    fn default() -> Self {
102        Self {
103            max_opts: 10, //000,
104            eval_interval: 5000,
105            flush_record_interval: 5000,
106            record_compute_cost_interval: 5000,
107            record_agent_info_interval: 5000,
108            save_interval: 50000,
109            sync_interval: 100,
110            warmup_period: 10000,
111        }
112    }
113}