// border_tch_agent/dqn/model/config.rs
use crate::{opt::OptimizerConfig, util::OutDim};
use anyhow::{Context, Result};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{
    fs::File,
    io::{BufReader, Write},
    path::Path,
};
9
10#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
11pub struct DqnModelConfig<Q>
13where
14 Q: OutDim,
17{
18 pub q_config: Option<Q>,
19 pub opt_config: OptimizerConfig,
20}
21
22impl<Q> Default for DqnModelConfig<Q>
24where
25 Q: OutDim,
28{
29 fn default() -> Self {
30 Self {
31 q_config: None,
32 opt_config: OptimizerConfig::Adam { lr: 0.0 },
33 }
34 }
35}
36
37impl<Q> DqnModelConfig<Q>
39where
40 Q: DeserializeOwned + Serialize + OutDim,
43{
44 pub fn q_config(mut self, v: Q) -> Self {
47 self.q_config = Some(v);
48 self
49 }
50
51 pub fn out_dim(mut self, v: i64) -> Self {
53 match &mut self.q_config {
54 None => {}
55 Some(q_config) => q_config.set_out_dim(v),
56 };
57 self
58 }
59
60 pub fn opt_config(mut self, v: OptimizerConfig) -> Self {
62 self.opt_config = v;
63 self
64 }
65
66 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
68 let file = File::open(path)?;
69 let rdr = BufReader::new(file);
70 let b = serde_yaml::from_reader(rdr)?;
71 Ok(b)
72 }
73
74 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
76 let mut file = File::create(path)?;
77 file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
78 Ok(())
79 }
80}