border_tch_agent/iqn/
config.rs

//! Configuration of the IQN agent.
2use super::{IqnModelConfig, IqnSample};
3use crate::{
4    iqn::{IqnExplorer, Softmax},
5    model::SubModel,
6    util::OutDim,
7    Device,
8};
9use anyhow::Result;
10use serde::{de::DeserializeOwned, Deserialize, Serialize};
11use std::{
12    default::Default,
13    fs::File,
14    io::{BufReader, Write},
15    marker::PhantomData,
16    path::Path,
17};
18
19#[derive(Debug, Deserialize, Serialize, PartialEq)]
20/// Configuration of [`Iqn`](super::Iqn) agent.
21pub struct IqnConfig<F, M>
22where
23    F: SubModel,
24    M: SubModel,
25    F::Config: DeserializeOwned + Serialize + Clone,
26    M::Config: DeserializeOwned + Serialize + Clone + OutDim,
27{
28    pub model_config: IqnModelConfig<F::Config, M::Config>,
29    pub soft_update_interval: usize,
30    pub n_updates_per_opt: usize,
31    pub batch_size: usize,
32    pub discount_factor: f64,
33    pub tau: f64,
34    pub train: bool,
35    pub explorer: IqnExplorer,
36    pub sample_percents_pred: IqnSample,
37    pub sample_percents_tgt: IqnSample,
38    pub sample_percents_act: IqnSample,
39    pub device: Option<Device>,
40    phantom: PhantomData<(F, M)>,
41}
42
43impl<F, M> Default for IqnConfig<F, M>
44where
45    F: SubModel,
46    M: SubModel,
47    F::Config: DeserializeOwned + Serialize + Clone,
48    M::Config: DeserializeOwned + Serialize + Clone + OutDim,
49{
50    fn default() -> Self {
51        Self {
52            model_config: Default::default(),
53            soft_update_interval: 1,
54            n_updates_per_opt: 1,
55            batch_size: 1,
56            discount_factor: 0.99,
57            tau: 0.005,
58            sample_percents_pred: IqnSample::Uniform8,
59            sample_percents_tgt: IqnSample::Uniform8,
60            sample_percents_act: IqnSample::Const32,
61            train: false,
62            explorer: IqnExplorer::Softmax(Softmax::new()),
63            // explorer: IqnExplorer::EpsilonGreedy(EpsilonGreedy::default()),
64            device: None,
65            phantom: PhantomData,
66        }
67    }
68}
69
70impl<F, M> IqnConfig<F, M>
71where
72    F: SubModel,
73    M: SubModel,
74    F::Config: DeserializeOwned + Serialize + Clone,
75    M::Config: DeserializeOwned + Serialize + Clone + OutDim,
76{
77    /// Sets the configuration of the model.
78    pub fn model_config(mut self, model_config: IqnModelConfig<F::Config, M::Config>) -> Self {
79        self.model_config = model_config;
80        self
81    }
82
83    /// Set soft update interval.
84    pub fn soft_update_interval(mut self, v: usize) -> Self {
85        self.soft_update_interval = v;
86        self
87    }
88
89    /// Set numper of parameter update steps per optimization step.
90    pub fn n_updates_per_opt(mut self, v: usize) -> Self {
91        self.n_updates_per_opt = v;
92        self
93    }
94
95    /// Batch size.
96    pub fn batch_size(mut self, v: usize) -> Self {
97        self.batch_size = v;
98        self
99    }
100
101    /// Discount factor.
102    pub fn discount_factor(mut self, v: f64) -> Self {
103        self.discount_factor = v;
104        self
105    }
106
107    /// Soft update coefficient.
108    pub fn tau(mut self, v: f64) -> Self {
109        self.tau = v;
110        self
111    }
112
113    /// Set explorer.
114    pub fn explorer(mut self, v: IqnExplorer) -> Self {
115        self.explorer = v;
116        self
117    }
118
119    /// Sets the output dimention of the iqn model.
120    pub fn out_dim(mut self, out_dim: i64) -> Self {
121        let model_config = self.model_config.clone();
122        self.model_config = model_config.out_dim(out_dim);
123        self
124    }
125
126    /// Sampling percent points.
127    pub fn sample_percent_pred(mut self, v: IqnSample) -> Self {
128        self.sample_percents_pred = v;
129        self
130    }
131
132    /// Sampling percent points.
133    pub fn sample_percent_tgt(mut self, v: IqnSample) -> Self {
134        self.sample_percents_tgt = v;
135        self
136    }
137
138    /// Sampling percent points.
139    pub fn sample_percent_act(mut self, v: IqnSample) -> Self {
140        self.sample_percents_act = v;
141        self
142    }
143
144    /// Device.
145    pub fn device(mut self, device: tch::Device) -> Self {
146        self.device = Some(device.into());
147        self
148    }
149
150    /// Constructs [`IqnConfig`] from YAML file.
151    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
152        let file = File::open(path)?;
153        let rdr = BufReader::new(file);
154        let b = serde_yaml::from_reader(rdr)?;
155        Ok(b)
156    }
157
158    /// Saves [`IqnConfig`].
159    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
160        let mut file = File::create(path)?;
161        file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
162        Ok(())
163    }
164
165    // /// Constructs [IQN] agent with the given replay buffer.
166    // pub fn build_with_replay_bufferbuild<E, F, M, O, A>(
167    //     self,
168    //     iqn_model: IQNModel<F, M>,
169    //     replay_buffer: ReplayBuffer<E, O, A>,
170    //     expr_sampling: ExperienceSampling,
171    //     device: Device,
172    // ) -> IQN<E, F, M, O, A>
173    // where
174    //     E: Env,
175    //     F: SubModel<Output = Tensor>,
176    //     M: SubModel<Input = Tensor, Output = Tensor>,
177    //     E::Obs: Into<F::Input>,
178    //     E::Act: From<Tensor>,
179    //     O: TchBuffer<Item = E::Obs, SubBatch = F::Input>,
180    //     A: TchBuffer<Item = E::Act, SubBatch = Tensor>,
181    // {
182    //     let iqn = iqn_model;
183    //     let iqn_tgt = iqn.clone();
184
185    //     IQN {
186    //         iqn,
187    //         iqn_tgt,
188    //         replay_buffer,
189    //         prev_obs: RefCell::new(None),
190    //         opt_interval_counter: self.opt_interval_counter,
191    //         soft_update_interval: self.soft_update_interval,
192    //         soft_update_counter: 0,
193    //         n_updates_per_opt: self.n_updates_per_opt,
194    //         min_transitions_warmup: self.min_transitions_warmup,
195    //         batch_size: self.batch_size,
196    //         discount_factor: self.discount_factor,
197    //         tau: self.tau,
198    //         sample_percents_pred: self.sample_percents_pred,
199    //         sample_percents_tgt: self.sample_percents_tgt,
200    //         sample_percents_act: self.sample_percents_act,
201    //         train: self.train,
202    //         explorer: self.explorer,
203    //         // expr_sampling,
204    //         device,
205    //         phantom: PhantomData,
206    //     }
207    // }
208}
209
210impl<F, M> Clone for IqnConfig<F, M>
211where
212    F: SubModel,
213    M: SubModel,
214    F::Config: DeserializeOwned + Serialize + Clone,
215    M::Config: DeserializeOwned + Serialize + Clone + OutDim,
216{
217    fn clone(&self) -> Self {
218        Self {
219            model_config: self.model_config.clone(),
220            soft_update_interval: self.soft_update_interval,
221            n_updates_per_opt: self.n_updates_per_opt,
222            batch_size: self.batch_size,
223            discount_factor: self.discount_factor,
224            tau: self.tau,
225            sample_percents_pred: self.sample_percents_pred.clone(),
226            sample_percents_tgt: self.sample_percents_tgt.clone(),
227            sample_percents_act: self.sample_percents_act.clone(),
228            train: self.train,
229            explorer: self.explorer.clone(),
230            device: self.device.clone(),
231            phantom: PhantomData,
232        }
233    }
234}