// border_async_trainer/async_trainer/config.rs
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::{
    fs::File,
    io::{BufReader, Write},
    path::Path,
};

9#[derive(Clone, Debug, Deserialize, Serialize)]
11pub struct AsyncTrainerConfig {
12 pub max_opts: usize,
14
15 pub eval_interval: usize,
17
18 pub flush_record_interval: usize,
20
21 pub record_compute_cost_interval: usize,
23
24 pub record_agent_info_interval: usize,
26
27 pub save_interval: usize,
29
30 pub sync_interval: usize,
32
33 pub warmup_period: usize,
35}
36
37impl AsyncTrainerConfig {
38 pub fn max_opts(mut self, v: usize) -> Result<Self> {
40 self.max_opts = v;
41 Ok(self)
42 }
43
44 pub fn eval_interval(mut self, v: usize) -> Result<Self> {
46 self.eval_interval = v;
47 Ok(self)
48 }
49
50 pub fn record_compute_cost_interval(
52 mut self,
53 record_compute_cost_interval: usize,
54 ) -> Result<Self> {
55 self.record_compute_cost_interval = record_compute_cost_interval;
56 Ok(self)
57 }
58
59 pub fn flush_record_interval(mut self, flush_record_interval: usize) -> Result<Self> {
61 self.flush_record_interval = flush_record_interval;
62 Ok(self)
63 }
64
65 pub fn warmup_period(mut self, warmup_period: usize) -> Result<Self> {
67 self.warmup_period = warmup_period;
68 Ok(self)
69 }
70
71 pub fn save_interval(mut self, save_interval: usize) -> Result<Self> {
73 self.save_interval = save_interval;
74 Ok(self)
75 }
76
77 pub fn sync_interval(mut self, sync_interval: usize) -> Result<Self> {
79 self.sync_interval = sync_interval;
80 Ok(self)
81 }
82
83 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
85 let file = File::open(path)?;
86 let rdr = BufReader::new(file);
87 let b = serde_yaml::from_reader(rdr)?;
88 Ok(b)
89 }
90
91 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
93 let mut file = File::create(path)?;
94 file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
95 Ok(())
96 }
97}
98
99impl Default for AsyncTrainerConfig {
100 fn default() -> Self {
102 Self {
103 max_opts: 10, eval_interval: 5000,
105 flush_record_interval: 5000,
106 record_compute_cost_interval: 5000,
107 record_agent_info_interval: 5000,
108 save_interval: 50000,
109 sync_interval: 100,
110 warmup_period: 10000,
111 }
112 }
113}