border_tch_agent/dqn/
config.rs

1//! Configuration of DQN agent.
2use super::{
3    explorer::{DqnExplorer, Softmax},
4    DqnModelConfig,
5};
6use crate::{
7    model::SubModel,
8    opt::OptimizerConfig,
9    util::{CriticLoss, OutDim},
10    Device,
11};
12use anyhow::Result;
13use log::info;
14use serde::{de::DeserializeOwned, Deserialize, Serialize};
15use std::{
16    default::Default,
17    fs::File,
18    io::{BufReader, Write},
19    marker::PhantomData,
20    path::Path,
21};
22use tch::Tensor;
23
24/// Configuration of [`Dqn`](super::Dqn) agent.
25#[derive(Debug, Deserialize, Serialize, PartialEq)]
26pub struct DqnConfig<Q>
27where
28    Q: SubModel<Output = Tensor>,
29    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
30{
31    pub model_config: DqnModelConfig<Q::Config>,
32    pub soft_update_interval: usize,
33    pub n_updates_per_opt: usize,
34    pub batch_size: usize,
35    pub discount_factor: f64,
36    pub tau: f64,
37    pub train: bool,
38    pub explorer: DqnExplorer,
39    #[serde(default)]
40    pub clip_reward: Option<f64>,
41    #[serde(default)]
42    pub double_dqn: bool,
43    pub clip_td_err: Option<(f64, f64)>,
44    pub device: Option<Device>,
45    pub critic_loss: CriticLoss,
46    pub record_verbose_level: usize,
47    pub phantom: PhantomData<Q>,
48}
49
50impl<Q> Clone for DqnConfig<Q>
51where
52    Q: SubModel<Output = Tensor>,
53    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
54{
55    fn clone(&self) -> Self {
56        Self {
57            model_config: self.model_config.clone(),
58            soft_update_interval: self.soft_update_interval,
59            n_updates_per_opt: self.n_updates_per_opt,
60            batch_size: self.batch_size,
61            discount_factor: self.discount_factor,
62            tau: self.tau,
63            train: self.train,
64            explorer: self.explorer.clone(),
65            clip_reward: self.clip_reward,
66            double_dqn: self.double_dqn,
67            clip_td_err: self.clip_td_err,
68            device: self.device.clone(),
69            critic_loss: self.critic_loss.clone(),
70            record_verbose_level: self.record_verbose_level,
71            phantom: PhantomData,
72        }
73    }
74}
75
76impl<Q> Default for DqnConfig<Q>
77where
78    Q: SubModel<Output = Tensor>,
79    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
80{
81    /// Constructs DQN builder with default parameters.
82    fn default() -> Self {
83        Self {
84            model_config: Default::default(),
85            soft_update_interval: 1,
86            n_updates_per_opt: 1,
87            batch_size: 1,
88            discount_factor: 0.99,
89            tau: 0.005,
90            train: false,
91            // replay_burffer_capacity: 100,
92            explorer: DqnExplorer::Softmax(Softmax::new()),
93            // expr_sampling: ExperienceSampling::Uniform,
94            clip_reward: None,
95            double_dqn: false,
96            clip_td_err: None,
97            device: None,
98            critic_loss: CriticLoss::Mse,
99            record_verbose_level: 0,
100            phantom: PhantomData,
101        }
102    }
103}
104
105impl<Q> DqnConfig<Q>
106where
107    Q: SubModel<Output = Tensor>,
108    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
109{
110    /// Sets soft update interval.
111    pub fn soft_update_interval(mut self, v: usize) -> Self {
112        self.soft_update_interval = v;
113        self
114    }
115
116    /// Sets the numper of parameter update steps per optimization step.
117    pub fn n_updates_per_opt(mut self, v: usize) -> Self {
118        self.n_updates_per_opt = v;
119        self
120    }
121
122    /// Batch size.
123    pub fn batch_size(mut self, v: usize) -> Self {
124        self.batch_size = v;
125        self
126    }
127
128    /// Discount factor.
129    pub fn discount_factor(mut self, v: f64) -> Self {
130        self.discount_factor = v;
131        self
132    }
133
134    /// Soft update coefficient.
135    pub fn tau(mut self, v: f64) -> Self {
136        self.tau = v;
137        self
138    }
139
140    /// Explorer.
141    pub fn explorer(mut self, v: DqnExplorer) -> Self {
142        self.explorer = v;
143        self
144    }
145
146    /// Sets the configuration of the model.
147    pub fn model_config(mut self, model_config: DqnModelConfig<Q::Config>) -> Self {
148        self.model_config = model_config;
149        self
150    }
151
152    /// Sets the configration of the optimizer.
153    pub fn opt_config(mut self, opt_config: OptimizerConfig) -> Self {
154        self.model_config = self.model_config.opt_config(opt_config);
155        self
156    }
157
158    /// Sets the output dimention of the dqn model of the DQN agent.
159    pub fn out_dim(mut self, out_dim: i64) -> Self {
160        let model_config = self.model_config.clone();
161        self.model_config = model_config.out_dim(out_dim);
162        self
163    }
164
165    /// Reward clipping.
166    pub fn clip_reward(mut self, clip_reward: Option<f64>) -> Self {
167        self.clip_reward = clip_reward;
168        self
169    }
170
171    /// Double DQN
172    pub fn double_dqn(mut self, double_dqn: bool) -> Self {
173        self.double_dqn = double_dqn;
174        self
175    }
176
177    /// TD-error clipping.
178    pub fn clip_td_err(mut self, clip_td_err: Option<(f64, f64)>) -> Self {
179        self.clip_td_err = clip_td_err;
180        self
181    }
182
183    /// Device.
184    pub fn device(mut self, device: tch::Device) -> Self {
185        self.device = Some(device.into());
186        self
187    }
188
189    /// Sets critic loss.
190    pub fn critic_loss(mut self, v: CriticLoss) -> Self {
191        self.critic_loss = v;
192        self
193    }
194
195    /// Sets verbose level.
196    pub fn record_verbose_level(mut self, v: usize) -> Self {
197        self.record_verbose_level = v;
198        self
199    }
200
201    /// Loads [`DqnConfig`] from YAML file.
202    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
203        let path_ = path.as_ref().to_owned();
204        let file = File::open(path)?;
205        let rdr = BufReader::new(file);
206        let b = serde_yaml::from_reader(rdr)?;
207        info!("Load config of DQN agent from {}", path_.to_str().unwrap());
208        Ok(b)
209    }
210
211    /// Saves [`DqnConfig`].
212    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
213        let path_ = path.as_ref().to_owned();
214        let mut file = File::create(path)?;
215        file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
216        info!("Save config of DQN agent into {}", path_.to_str().unwrap());
217        Ok(())
218    }
219}