1use super::ValueConfig;
3use crate::{
4 model::{SubModel1, SubModel2},
5 util::{actor::GaussianActorConfig, critic::MultiCriticConfig, CriticLoss, OutDim},
6 Device,
7};
8use anyhow::Result;
9use candle_core::Tensor;
10use log::info;
11use serde::{de::DeserializeOwned, Deserialize, Serialize};
12use std::{
13 fmt::Debug,
14 fs::File,
15 io::{BufReader, Write},
16 path::Path,
17};
18
19#[allow(clippy::upper_case_acronyms)]
21#[derive(Debug, Deserialize, Serialize, PartialEq)]
22pub struct IqlConfig<Q, P, V>
23where
24 Q: SubModel2<Output = Tensor>,
25 P: SubModel1<Output = (Tensor, Tensor)>,
26 V: SubModel1<Output = Tensor>,
27 Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
28 P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
29 V::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
30{
31 pub value_config: ValueConfig<V::Config>,
33
34 pub critic_config: MultiCriticConfig<Q::Config>,
36
37 pub actor_config: GaussianActorConfig<P::Config>,
39
40 pub gamma: f32,
42
43 pub tau_iql: f64,
45
46 pub inv_lambda: f64,
48
49 pub n_updates_per_opt: usize,
51
52 pub batch_size: usize,
54
55 pub adv_softmax: bool,
59
60 pub critic_loss: CriticLoss,
64
65 pub device: Option<Device>,
67
68 pub exp_adv_max: f64,
70}
71
72impl<Q, P, V> Clone for IqlConfig<Q, P, V>
73where
74 Q: SubModel2<Output = Tensor>,
75 P: SubModel1<Output = (Tensor, Tensor)>,
76 V: SubModel1<Output = Tensor>,
77 Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
78 P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
79 V::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
80{
81 fn clone(&self) -> Self {
82 Self {
83 value_config: self.value_config.clone(),
84 critic_config: self.critic_config.clone(),
85 actor_config: self.actor_config.clone(),
86 gamma: self.gamma,
87 tau_iql: self.tau_iql,
88 inv_lambda: self.inv_lambda,
89 n_updates_per_opt: self.n_updates_per_opt,
90 batch_size: self.batch_size,
91 adv_softmax: self.adv_softmax,
93 critic_loss: self.critic_loss.clone(),
94 device: self.device.clone(),
95 exp_adv_max: self.exp_adv_max,
96 }
97 }
98}
99
100impl<Q, P, V> Default for IqlConfig<Q, P, V>
101where
102 Q: SubModel2<Output = Tensor>,
103 P: SubModel1<Output = (Tensor, Tensor)>,
104 V: SubModel1<Output = Tensor>,
105 Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
106 P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
107 V::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
108{
109 fn default() -> Self {
110 Self {
111 value_config: Default::default(),
112 critic_config: Default::default(),
113 actor_config: Default::default(),
114 gamma: 0.99,
115 tau_iql: 0.7,
116 inv_lambda: 10.0,
117 n_updates_per_opt: 1,
118 batch_size: 1,
119 adv_softmax: false,
121 critic_loss: CriticLoss::Mse,
122 device: None,
123 exp_adv_max: 100.0,
124 }
125 }
126}
127
128impl<Q, P, V> IqlConfig<Q, P, V>
129where
130 Q: SubModel2<Output = Tensor>,
131 P: SubModel1<Output = (Tensor, Tensor)>,
132 V: SubModel1<Output = Tensor>,
133 Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
134 P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
135 V::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
136{
137 pub fn lambda(mut self, v: f64) -> Self {
139 self.inv_lambda = 1.0 / v;
140 self
141 }
142
143 pub fn n_updates_per_opt(mut self, v: usize) -> Self {
145 self.n_updates_per_opt = v;
146 self
147 }
148
149 pub fn batch_size(mut self, v: usize) -> Self {
151 self.batch_size = v;
152 self
153 }
154
155 pub fn discount_factor(mut self, v: f32) -> Self {
157 self.gamma = v;
158 self
159 }
160
161 pub fn critic_loss(mut self, v: CriticLoss) -> Self {
171 self.critic_loss = v;
172 self
173 }
174
175 pub fn value_config(mut self, value_config: ValueConfig<V::Config>) -> Self {
177 self.value_config = value_config;
178 self
179 }
180
181 pub fn actor_config(mut self, actor_config: GaussianActorConfig<P::Config>) -> Self {
183 self.actor_config = actor_config;
184 self
185 }
186
187 pub fn critic_config(mut self, critic_config: MultiCriticConfig<Q::Config>) -> Self {
189 self.critic_config = critic_config;
190 self
191 }
192
193 pub fn device(mut self, device: candle_core::Device) -> Self {
195 self.device = Some(device.into());
196 self
197 }
198
199 pub fn adv_softmax(mut self, b: bool) -> Self {
201 self.adv_softmax = b;
202 self
203 }
204
205 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
207 let path_ = path.as_ref().to_owned();
208 let mut file = File::create(path)?;
209 file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
210 info!("Save config of IQL agent into {}", path_.to_str().unwrap());
211 Ok(())
212 }
213
214 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
216 let path_ = path.as_ref().to_owned();
217 let file = File::open(path)?;
218 let rdr = BufReader::new(file);
219 let b = serde_yaml::from_reader(rdr)?;
220 info!("Load config of IQL agent from {}", path_.to_str().unwrap());
221 Ok(b)
222 }
223}