use anyhow::{Context as _, Result};
use serde::{Deserialize, Serialize};
use std::{
    fs::File,
    io::{BufReader, Write},
    path::Path,
};
/// Configuration for a trainer.
///
/// Build one from [`Default`] plus the chained setters on this type, or
/// read/write it as YAML with [`TrainerConfig::load`] and
/// [`TrainerConfig::save`].
///
/// There are no `#[serde(default)]` attributes, so every field must be present
/// when deserializing. The derived `Serialize` emits fields in declaration
/// order, so reordering them changes the key order of saved files.
///
/// NOTE(review): the meanings and units of the fields below are inferred from
/// their names only and are not established by this file. Confirm them against
/// the trainer that consumes this config.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct TrainerConfig {
    /// Presumably the maximum number of optimization steps. Defaults to `0`.
    pub max_opts: usize,
    /// Presumably how often an optimization step is run. Defaults to `1`.
    pub opt_interval: usize,
    /// Presumably how often evaluation is run. Defaults to `0`.
    pub eval_interval: usize,
    /// Presumably how often recorded values are flushed. Defaults to `usize::MAX`.
    pub flush_record_interval: usize,
    /// Presumably how often compute cost is recorded. Defaults to `usize::MAX`.
    pub record_compute_cost_interval: usize,
    /// Presumably how often agent information is recorded. Defaults to `usize::MAX`.
    pub record_agent_info_interval: usize,
    /// Presumably the number of initial steps before optimization starts.
    /// Defaults to `0`.
    pub warmup_period: usize,
    /// Presumably how often a checkpoint is saved. Defaults to `usize::MAX`.
    pub save_interval: usize,
}
impl Default for TrainerConfig {
fn default() -> Self {
Self {
max_opts: 0,
eval_interval: 0,
opt_interval: 1,
flush_record_interval: usize::MAX,
record_compute_cost_interval: usize::MAX,
record_agent_info_interval: usize::MAX,
warmup_period: 0,
save_interval: usize::MAX,
}
}
}
impl TrainerConfig {
pub fn max_opts(mut self, v: usize) -> Self {
self.max_opts = v;
self
}
pub fn eval_interval(mut self, v: usize) -> Self {
self.eval_interval = v;
self
}
pub fn eval_threshold( self, _v: f32) -> Self {
unimplemented!();
}
pub fn opt_interval(mut self, opt_interval: usize) -> Self {
self.opt_interval = opt_interval;
self
}
pub fn flush_record_interval(mut self, flush_record_interval: usize) -> Self {
self.flush_record_interval = flush_record_interval;
self
}
pub fn record_compute_cost_interval(mut self, record_compute_cost_interval: usize) -> Self {
self.record_compute_cost_interval = record_compute_cost_interval;
self
}
pub fn record_agent_info_interval(mut self, record_agent_info_interval: usize) -> Self {
self.record_agent_info_interval = record_agent_info_interval;
self
}
pub fn warmup_period(mut self, warmup_period: usize) -> Self {
self.warmup_period = warmup_period;
self
}
pub fn save_interval(mut self, save_interval: usize) -> Self {
self.save_interval = save_interval;
self
}
pub fn load(path: impl AsRef<Path>) -> Result<Self> {
let file = File::open(path)?;
let rdr = BufReader::new(file);
let b = serde_yaml::from_reader(rdr)?;
Ok(b)
}
pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
let mut file = File::create(path)?;
file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
Ok(())
}
}