// border_tch_agent/iqn/model/config.rs
config.rs1use crate::{opt::OptimizerConfig, util::OutDim};
3use anyhow::Result;
4use serde::{de::DeserializeOwned, Deserialize, Serialize};
5use std::{
6 default::Default,
7 fs::File,
8 io::{BufReader, Write},
9 path::Path,
10};
11
12impl<F, M> IqnModelConfig<F, M>
13where
14 F: DeserializeOwned + Serialize,
15 M: DeserializeOwned + Serialize + OutDim,
16{
17 pub fn learning_rate(mut self, v: f64) -> Self {
19 match &self.opt_config {
20 OptimizerConfig::Adam { lr: _ } => self.opt_config = OptimizerConfig::Adam { lr: v },
21 _ => unimplemented!(),
22 };
23 self
24 }
25}
26
27#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
28pub struct IqnModelConfig<F, M>
34where
35 M: OutDim,
36{
37 pub feature_dim: i64,
39
40 pub embed_dim: i64,
42
43 pub f_config: Option<F>,
45
46 pub m_config: Option<M>,
48
49 pub opt_config: OptimizerConfig,
51}
52
53impl<F, M> Default for IqnModelConfig<F, M>
54where
55 M: OutDim,
56{
57 fn default() -> Self {
58 Self {
59 feature_dim: 0,
60 embed_dim: 0,
61 f_config: None,
62 m_config: None,
63 opt_config: OptimizerConfig::Adam { lr: 0.0 },
64 }
65 }
66}
67
68impl<F, M> IqnModelConfig<F, M>
69where
70 F: DeserializeOwned + Serialize,
71 M: DeserializeOwned + Serialize + OutDim,
72{
73 pub fn embed_dim(mut self, v: i64) -> Self {
75 self.embed_dim = v;
76 self
77 }
78
79 pub fn feature_dim(mut self, v: i64) -> Self {
81 self.feature_dim = v;
82 self
83 }
84
85 pub fn f_config(mut self, v: F) -> Self {
87 self.f_config = Some(v);
88 self
89 }
90
91 pub fn m_config(mut self, v: M) -> Self {
93 self.m_config = Some(v);
94 self
95 }
96
97 pub fn out_dim(mut self, out_dim: i64) -> Self {
99 if self.m_config.is_some() {
100 self.m_config.as_mut().unwrap().set_out_dim(out_dim);
101 }
102 self
103 }
104
105 pub fn opt_config(mut self, v: OptimizerConfig) -> Self {
107 self.opt_config = v;
108 self
109 }
110
111 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
113 let file = File::open(path)?;
114 let rdr = BufReader::new(file);
115 let b = serde_yaml::from_reader(rdr)?;
116 Ok(b)
117 }
118
119 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
121 let mut file = File::create(path)?;
122 file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
123 Ok(())
124 }
125}