1use super::{
3 explorer::{DqnExplorer, Softmax},
4 DqnModelConfig,
5};
6use crate::{
7 model::SubModel,
8 opt::OptimizerConfig,
9 util::{CriticLoss, OutDim},
10 Device,
11};
12use anyhow::Result;
13use log::info;
14use serde::{de::DeserializeOwned, Deserialize, Serialize};
15use std::{
16 default::Default,
17 fs::File,
18 io::{BufReader, Write},
19 marker::PhantomData,
20 path::Path,
21};
22use tch::Tensor;
23
24#[derive(Debug, Deserialize, Serialize, PartialEq)]
26pub struct DqnConfig<Q>
27where
28 Q: SubModel<Output = Tensor>,
29 Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
30{
31 pub model_config: DqnModelConfig<Q::Config>,
32 pub soft_update_interval: usize,
33 pub n_updates_per_opt: usize,
34 pub batch_size: usize,
35 pub discount_factor: f64,
36 pub tau: f64,
37 pub train: bool,
38 pub explorer: DqnExplorer,
39 #[serde(default)]
40 pub clip_reward: Option<f64>,
41 #[serde(default)]
42 pub double_dqn: bool,
43 pub clip_td_err: Option<(f64, f64)>,
44 pub device: Option<Device>,
45 pub critic_loss: CriticLoss,
46 pub record_verbose_level: usize,
47 pub phantom: PhantomData<Q>,
48}
49
50impl<Q> Clone for DqnConfig<Q>
51where
52 Q: SubModel<Output = Tensor>,
53 Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
54{
55 fn clone(&self) -> Self {
56 Self {
57 model_config: self.model_config.clone(),
58 soft_update_interval: self.soft_update_interval,
59 n_updates_per_opt: self.n_updates_per_opt,
60 batch_size: self.batch_size,
61 discount_factor: self.discount_factor,
62 tau: self.tau,
63 train: self.train,
64 explorer: self.explorer.clone(),
65 clip_reward: self.clip_reward,
66 double_dqn: self.double_dqn,
67 clip_td_err: self.clip_td_err,
68 device: self.device.clone(),
69 critic_loss: self.critic_loss.clone(),
70 record_verbose_level: self.record_verbose_level,
71 phantom: PhantomData,
72 }
73 }
74}
75
76impl<Q> Default for DqnConfig<Q>
77where
78 Q: SubModel<Output = Tensor>,
79 Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
80{
81 fn default() -> Self {
83 Self {
84 model_config: Default::default(),
85 soft_update_interval: 1,
86 n_updates_per_opt: 1,
87 batch_size: 1,
88 discount_factor: 0.99,
89 tau: 0.005,
90 train: false,
91 explorer: DqnExplorer::Softmax(Softmax::new()),
93 clip_reward: None,
95 double_dqn: false,
96 clip_td_err: None,
97 device: None,
98 critic_loss: CriticLoss::Mse,
99 record_verbose_level: 0,
100 phantom: PhantomData,
101 }
102 }
103}
104
105impl<Q> DqnConfig<Q>
106where
107 Q: SubModel<Output = Tensor>,
108 Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
109{
110 pub fn soft_update_interval(mut self, v: usize) -> Self {
112 self.soft_update_interval = v;
113 self
114 }
115
116 pub fn n_updates_per_opt(mut self, v: usize) -> Self {
118 self.n_updates_per_opt = v;
119 self
120 }
121
122 pub fn batch_size(mut self, v: usize) -> Self {
124 self.batch_size = v;
125 self
126 }
127
128 pub fn discount_factor(mut self, v: f64) -> Self {
130 self.discount_factor = v;
131 self
132 }
133
134 pub fn tau(mut self, v: f64) -> Self {
136 self.tau = v;
137 self
138 }
139
140 pub fn explorer(mut self, v: DqnExplorer) -> Self {
142 self.explorer = v;
143 self
144 }
145
146 pub fn model_config(mut self, model_config: DqnModelConfig<Q::Config>) -> Self {
148 self.model_config = model_config;
149 self
150 }
151
152 pub fn opt_config(mut self, opt_config: OptimizerConfig) -> Self {
154 self.model_config = self.model_config.opt_config(opt_config);
155 self
156 }
157
158 pub fn out_dim(mut self, out_dim: i64) -> Self {
160 let model_config = self.model_config.clone();
161 self.model_config = model_config.out_dim(out_dim);
162 self
163 }
164
165 pub fn clip_reward(mut self, clip_reward: Option<f64>) -> Self {
167 self.clip_reward = clip_reward;
168 self
169 }
170
171 pub fn double_dqn(mut self, double_dqn: bool) -> Self {
173 self.double_dqn = double_dqn;
174 self
175 }
176
177 pub fn clip_td_err(mut self, clip_td_err: Option<(f64, f64)>) -> Self {
179 self.clip_td_err = clip_td_err;
180 self
181 }
182
183 pub fn device(mut self, device: tch::Device) -> Self {
185 self.device = Some(device.into());
186 self
187 }
188
189 pub fn critic_loss(mut self, v: CriticLoss) -> Self {
191 self.critic_loss = v;
192 self
193 }
194
195 pub fn record_verbose_level(mut self, v: usize) -> Self {
197 self.record_verbose_level = v;
198 self
199 }
200
201 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
203 let path_ = path.as_ref().to_owned();
204 let file = File::open(path)?;
205 let rdr = BufReader::new(file);
206 let b = serde_yaml::from_reader(rdr)?;
207 info!("Load config of DQN agent from {}", path_.to_str().unwrap());
208 Ok(b)
209 }
210
211 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
213 let path_ = path.as_ref().to_owned();
214 let mut file = File::create(path)?;
215 file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
216 info!("Save config of DQN agent into {}", path_.to_str().unwrap());
217 Ok(())
218 }
219}