1use super::{IqnModelConfig, IqnSample};
3use crate::{
4 iqn::{IqnExplorer, Softmax},
5 model::SubModel,
6 util::OutDim,
7 Device,
8};
9use anyhow::Result;
10use serde::{de::DeserializeOwned, Deserialize, Serialize};
11use std::{
12 default::Default,
13 fs::File,
14 io::{BufReader, Write},
15 marker::PhantomData,
16 path::Path,
17};
18
19#[derive(Debug, Deserialize, Serialize, PartialEq)]
20pub struct IqnConfig<F, M>
22where
23 F: SubModel,
24 M: SubModel,
25 F::Config: DeserializeOwned + Serialize + Clone,
26 M::Config: DeserializeOwned + Serialize + Clone + OutDim,
27{
28 pub model_config: IqnModelConfig<F::Config, M::Config>,
29 pub soft_update_interval: usize,
30 pub n_updates_per_opt: usize,
31 pub batch_size: usize,
32 pub discount_factor: f64,
33 pub tau: f64,
34 pub train: bool,
35 pub explorer: IqnExplorer,
36 pub sample_percents_pred: IqnSample,
37 pub sample_percents_tgt: IqnSample,
38 pub sample_percents_act: IqnSample,
39 pub device: Option<Device>,
40 phantom: PhantomData<(F, M)>,
41}
42
43impl<F, M> Default for IqnConfig<F, M>
44where
45 F: SubModel,
46 M: SubModel,
47 F::Config: DeserializeOwned + Serialize + Clone,
48 M::Config: DeserializeOwned + Serialize + Clone + OutDim,
49{
50 fn default() -> Self {
51 Self {
52 model_config: Default::default(),
53 soft_update_interval: 1,
54 n_updates_per_opt: 1,
55 batch_size: 1,
56 discount_factor: 0.99,
57 tau: 0.005,
58 sample_percents_pred: IqnSample::Uniform8,
59 sample_percents_tgt: IqnSample::Uniform8,
60 sample_percents_act: IqnSample::Const32,
61 train: false,
62 explorer: IqnExplorer::Softmax(Softmax::new()),
63 device: None,
65 phantom: PhantomData,
66 }
67 }
68}
69
70impl<F, M> IqnConfig<F, M>
71where
72 F: SubModel,
73 M: SubModel,
74 F::Config: DeserializeOwned + Serialize + Clone,
75 M::Config: DeserializeOwned + Serialize + Clone + OutDim,
76{
77 pub fn model_config(mut self, model_config: IqnModelConfig<F::Config, M::Config>) -> Self {
79 self.model_config = model_config;
80 self
81 }
82
83 pub fn soft_update_interval(mut self, v: usize) -> Self {
85 self.soft_update_interval = v;
86 self
87 }
88
89 pub fn n_updates_per_opt(mut self, v: usize) -> Self {
91 self.n_updates_per_opt = v;
92 self
93 }
94
95 pub fn batch_size(mut self, v: usize) -> Self {
97 self.batch_size = v;
98 self
99 }
100
101 pub fn discount_factor(mut self, v: f64) -> Self {
103 self.discount_factor = v;
104 self
105 }
106
107 pub fn tau(mut self, v: f64) -> Self {
109 self.tau = v;
110 self
111 }
112
113 pub fn explorer(mut self, v: IqnExplorer) -> Self {
115 self.explorer = v;
116 self
117 }
118
119 pub fn out_dim(mut self, out_dim: i64) -> Self {
121 let model_config = self.model_config.clone();
122 self.model_config = model_config.out_dim(out_dim);
123 self
124 }
125
126 pub fn sample_percent_pred(mut self, v: IqnSample) -> Self {
128 self.sample_percents_pred = v;
129 self
130 }
131
132 pub fn sample_percent_tgt(mut self, v: IqnSample) -> Self {
134 self.sample_percents_tgt = v;
135 self
136 }
137
138 pub fn sample_percent_act(mut self, v: IqnSample) -> Self {
140 self.sample_percents_act = v;
141 self
142 }
143
144 pub fn device(mut self, device: tch::Device) -> Self {
146 self.device = Some(device.into());
147 self
148 }
149
150 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
152 let file = File::open(path)?;
153 let rdr = BufReader::new(file);
154 let b = serde_yaml::from_reader(rdr)?;
155 Ok(b)
156 }
157
158 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
160 let mut file = File::create(path)?;
161 file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
162 Ok(())
163 }
164
165 }
209
210impl<F, M> Clone for IqnConfig<F, M>
211where
212 F: SubModel,
213 M: SubModel,
214 F::Config: DeserializeOwned + Serialize + Clone,
215 M::Config: DeserializeOwned + Serialize + Clone + OutDim,
216{
217 fn clone(&self) -> Self {
218 Self {
219 model_config: self.model_config.clone(),
220 soft_update_interval: self.soft_update_interval,
221 n_updates_per_opt: self.n_updates_per_opt,
222 batch_size: self.batch_size,
223 discount_factor: self.discount_factor,
224 tau: self.tau,
225 sample_percents_pred: self.sample_percents_pred.clone(),
226 sample_percents_tgt: self.sample_percents_tgt.clone(),
227 sample_percents_act: self.sample_percents_act.clone(),
228 train: self.train,
229 explorer: self.explorer.clone(),
230 device: self.device.clone(),
231 phantom: PhantomData,
232 }
233 }
234}