1use super::IqnModelConfig;
3use crate::{
4 model::{ModelBase, SubModel},
5 opt::{Optimizer, OptimizerConfig},
6 util::OutDim,
7};
8use anyhow::{Context, Result};
9use log::{info, trace};
10use serde::{de::DeserializeOwned, Deserialize, Serialize};
11use std::{default::Default, f64::consts::PI, marker::PhantomData, path::Path};
12use tch::{
13 nn,
14 nn::{Module, VarStore},
15 Device,
16 Kind::Float,
17 Tensor,
18};
19
20#[allow(clippy::upper_case_acronyms)]
21pub struct IqnModel<F, M>
24where
25 F: SubModel<Output = Tensor>,
26 M: SubModel<Input = Tensor, Output = Tensor>,
27 F::Config: DeserializeOwned + Serialize,
28 M::Config: DeserializeOwned + Serialize,
29{
30 device: Device,
31 var_store: nn::VarStore,
32
33 feature_dim: i64,
36
37 embed_dim: i64,
39
40 pub(super) out_dim: i64,
42
43 psi: F,
45
46 phi: nn::Sequential,
48
49 f: M,
51
52 opt_config: OptimizerConfig,
54 opt: Optimizer,
55
56 phantom: PhantomData<(F, M)>,
57}
58
59impl<F, M> IqnModel<F, M>
60where
61 F: SubModel<Output = Tensor>,
62 M: SubModel<Input = Tensor, Output = Tensor>,
63 F::Config: DeserializeOwned + Serialize,
64 M::Config: DeserializeOwned + Serialize + OutDim,
65{
66 pub fn build(
68 config: IqnModelConfig<F::Config, M::Config>,
69 device: Device,
70 ) -> Result<IqnModel<F, M>> {
71 let f_config = config.f_config.context("f_config is not set.")?;
72 let m_config = config.m_config.context("m_config is not set.")?;
73 let feature_dim = config.feature_dim;
74 let embed_dim = config.embed_dim;
75 let out_dim = m_config.get_out_dim();
76 let opt_config = config.opt_config;
77 let var_store = nn::VarStore::new(device);
78
79 let psi = F::build(&var_store, f_config);
81
82 let phi = IqnModel::<F, M>::cos_embed_nn(&var_store, feature_dim, embed_dim);
84
85 let f = M::build(&var_store, m_config);
87
88 let opt = opt_config.build(&var_store)?;
90
91 Ok(IqnModel {
99 device,
100 var_store,
101 feature_dim,
102 embed_dim,
103 out_dim,
104 psi,
105 phi,
106 f,
107 opt_config,
108 opt,
109 phantom: PhantomData,
110 })
111 }
112
113 pub fn build_with_submodel_configs(
115 config: IqnModelConfig<F::Config, M::Config>,
116 f_config: F::Config,
117 m_config: M::Config,
118 device: Device,
119 ) -> IqnModel<F, M> {
120 let feature_dim = config.feature_dim;
121 let embed_dim = config.embed_dim;
122 let out_dim = m_config.get_out_dim();
123 let opt_config = config.opt_config.clone();
124 let var_store = nn::VarStore::new(device);
125
126 let psi = F::build(&var_store, f_config);
128
129 let phi = IqnModel::<F, M>::cos_embed_nn(&var_store, feature_dim, embed_dim);
131
132 let f = M::build(&var_store, m_config);
134
135 let opt = opt_config.build(&var_store).unwrap();
138
139 IqnModel {
147 device,
148 var_store,
149 feature_dim,
150 embed_dim,
151 out_dim,
152 psi,
153 phi,
154 f,
155 opt_config,
156 opt,
157 phantom: PhantomData,
158 }
159 }
160
161 fn cos_embed_nn(var_store: &VarStore, feature_dim: i64, embed_dim: i64) -> nn::Sequential {
163 let p = &var_store.root();
164 let device = p.device();
165 nn::seq()
166 .add_fn(move |tau| {
167 let batch_size = tau.size().as_slice()[0];
168 let n_percent_points = tau.size().as_slice()[1];
169 let tau = tau.unsqueeze(-1);
170 let i = Tensor::range(1, embed_dim, (Float, device))
171 .unsqueeze(0)
172 .unsqueeze(0);
173 debug_assert_eq!(tau.size().as_slice(), &[batch_size, n_percent_points, 1]);
174 debug_assert_eq!(i.size().as_slice(), &[1, 1, embed_dim]);
175
176 let cos = Tensor::cos(&(tau * (PI * i)));
177 debug_assert_eq!(
178 cos.size().as_slice(),
179 &[batch_size, n_percent_points, embed_dim]
180 );
181
182 cos.reshape(&[-1, embed_dim])
183 })
184 .add(nn::linear(
185 p / "iqn_cos_to_feature",
186 embed_dim,
187 feature_dim,
188 Default::default(),
189 ))
190 .add_fn(|x| x.relu())
191 }
192
193 pub fn forward(&self, x: &F::Input, tau: &Tensor) -> Tensor {
199 let feature_dim = self.feature_dim;
201 let n_percent_points = tau.size().as_slice()[1];
202
203 let psi = self.psi.forward(x);
205 let batch_size = psi.size().as_slice()[0];
206 debug_assert_eq!(psi.size().as_slice(), &[batch_size, feature_dim]);
207
208 debug_assert_eq!(tau.size().as_slice(), &[batch_size, n_percent_points]);
210 let phi = self.phi.forward(tau);
211 debug_assert_eq!(
212 phi.size().as_slice(),
213 &[batch_size * n_percent_points, self.feature_dim]
214 );
215 let phi = phi.reshape(&[batch_size, n_percent_points, self.feature_dim]);
216
217 let psi = psi.unsqueeze(1);
219 debug_assert_eq!(psi.size().as_slice(), &[batch_size, 1, self.feature_dim]);
220 let m = psi * phi;
221 debug_assert_eq!(
222 m.size().as_slice(),
223 &[batch_size, n_percent_points, self.feature_dim]
224 );
225
226 let a = self.f.forward(&m);
228 debug_assert_eq!(
229 a.size().as_slice(),
230 &[batch_size, n_percent_points, self.out_dim]
231 );
232
233 a
234 }
235}
236
impl<F, M> Clone for IqnModel<F, M>
where
    F: SubModel<Output = Tensor>,
    M: SubModel<Input = Tensor, Output = Tensor>,
    F::Config: DeserializeOwned + Serialize,
    M::Config: DeserializeOwned + Serialize + OutDim,
{
    /// Clones the model by rebuilding it on a fresh variable store and then
    /// copying all variable values from `self` (by name).
    ///
    /// The optimizer is rebuilt from `opt_config`, so any state the original
    /// optimizer accumulated is not carried over to the clone.
    ///
    /// Panics if the optimizer cannot be built or the variable copy fails.
    fn clone(&self) -> Self {
        let device = self.device;
        let feature_dim = self.feature_dim;
        let embed_dim = self.embed_dim;
        let out_dim = self.out_dim;
        let opt_config = self.opt_config.clone();
        let mut var_store = nn::VarStore::new(device);

        // Rebuild the submodels and the cosine embedding on the new store.
        let psi = self.psi.clone_with_var_store(&var_store);

        let phi = IqnModel::<F, M>::cos_embed_nn(&var_store, feature_dim, embed_dim);

        let f = self.f.clone_with_var_store(&var_store);

        let opt = opt_config.build(&var_store).unwrap();

        // Copy the variable values of `self` into the freshly built store.
        var_store.copy(&self.var_store).unwrap();

        Self {
            device,
            var_store,
            feature_dim,
            embed_dim,
            out_dim,
            psi,
            phi,
            f,
            opt_config,
            opt,
            phantom: PhantomData,
        }
    }
}
288
289impl<F, M> ModelBase for IqnModel<F, M>
290where
291 F: SubModel<Output = Tensor>,
292 M: SubModel<Input = Tensor, Output = Tensor>,
293 F::Config: DeserializeOwned + Serialize,
294 M::Config: DeserializeOwned + Serialize,
295{
296 fn backward_step(&mut self, loss: &Tensor) {
297 self.opt.backward_step(loss);
298 }
299
300 fn get_var_store(&self) -> &nn::VarStore {
301 &self.var_store
302 }
303
304 fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
305 &mut self.var_store
306 }
307
308 fn save<T: AsRef<Path>>(&self, path: T) -> Result<()> {
309 self.var_store.save(&path)?;
310 info!("Save IQN model to {:?}", path.as_ref());
311 let vs = self.var_store.variables();
312 for (name, _) in vs.iter() {
313 trace!("Save variable {}", name);
314 }
315 Ok(())
316 }
317
318 fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()> {
319 self.var_store.load(&path)?;
320 info!("Load IQN model from {:?}", path.as_ref());
321 Ok(())
322 }
323}
324
/// Strategy for sampling percent points (quantile levels) `tau` in [0, 1].
///
/// `Const*` variants produce a fixed, evenly spaced grid; `Uniform*` variants
/// draw i.i.d. samples from U(0, 1); `Median` uses the single point 0.5.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub enum IqnSample {
    /// 10 fixed points: 0.05, 0.15, ..., 0.95 (bin midpoints).
    Const10,

    /// Fixed, evenly spaced grid of points (see [`IqnSample::sample`]).
    Const32,

    /// 10 samples drawn from U(0, 1).
    Uniform10,

    /// 8 samples drawn from U(0, 1).
    Uniform8,

    /// 32 samples drawn from U(0, 1).
    Uniform32,

    /// 64 samples drawn from U(0, 1).
    Uniform64,

    /// The single point 0.5 (the median).
    Median,
}
351
352impl IqnSample {
353 pub fn sample(&self, batch_size: i64) -> Tensor {
355 match self {
356 Self::Const10 => Tensor::from_slice(&[
357 0.05_f32, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95,
358 ])
359 .unsqueeze(0)
360 .repeat(&[batch_size, 1]),
361 Self::Const32 => {
362 let t: Tensor = (1.0 / 32.0) * Tensor::range(0, 32, tch::kind::FLOAT_CPU);
363 t.unsqueeze(0).repeat(&[batch_size, 1])
364 }
365 Self::Uniform10 => Tensor::rand(&[batch_size, 10], tch::kind::FLOAT_CPU),
366 Self::Uniform8 => Tensor::rand(&[batch_size, 8], tch::kind::FLOAT_CPU),
367 Self::Uniform32 => Tensor::rand(&[batch_size, 32], tch::kind::FLOAT_CPU),
368 Self::Uniform64 => Tensor::rand(&[batch_size, 64], tch::kind::FLOAT_CPU),
369 Self::Median => Tensor::from_slice(&[0.5_f32])
370 .unsqueeze(0)
371 .repeat(&[batch_size, 1]),
372 }
373 }
374
375 pub fn n_percent_points(&self) -> i64 {
377 match self {
378 Self::Const10 => 10,
379 Self::Const32 => 32,
380 Self::Uniform10 => 10,
381 Self::Uniform8 => 8,
382 Self::Uniform32 => 32,
383 Self::Uniform64 => 64,
384 Self::Median => 1,
385 }
386 }
387}
388
389pub fn average<F, M>(
395 batch_size: i64,
396 obs: &F::Input,
397 iqn: &IqnModel<F, M>,
398 mode: &IqnSample,
399 device: Device,
400) -> Tensor
401where
402 F: SubModel<Output = Tensor>,
403 M: SubModel<Input = Tensor, Output = Tensor>,
404 F::Config: DeserializeOwned + Serialize,
405 M::Config: DeserializeOwned + Serialize + OutDim,
406{
407 let tau = mode.sample(batch_size).to(device);
408 let averaged_action_value = iqn
409 .forward(obs, &tau)
410 .mean_dim(Some([1].as_slice()), false, Float);
411 let batch_size = averaged_action_value.size()[0];
412 let n_action = iqn.out_dim;
413 debug_assert_eq!(
414 averaged_action_value.size().as_slice(),
415 &[batch_size, n_action]
416 );
417 averaged_action_value
418}
419
#[cfg(test)]
mod test {
    use super::super::IqnModelConfig;
    use super::*;
    use crate::util::OutDim;
    use std::default::Default;
    use tch::{nn, Device, Tensor};

    // Config of the test feature extractor (no parameters).
    #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
    struct FeatureExtractorConfig {}

    // Identity feature extractor: `forward` returns a copy of its input,
    // so `feature_dim` must equal the input dimension in tests.
    struct FeatureExtractor {}

    impl SubModel for FeatureExtractor {
        type Config = FeatureExtractorConfig;
        type Input = Tensor;
        type Output = Tensor;

        fn clone_with_var_store(&self, _var_store: &nn::VarStore) -> Self {
            Self {}
        }

        fn build(_var_store: &VarStore, _config: Self::Config) -> Self {
            Self {}
        }

        fn forward(&self, input: &Self::Input) -> Self::Output {
            input.copy()
        }
    }

    // Config of the test merge network; only carries the output dimension.
    #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
    struct MergeConfig {
        out_dim: i64,
    }

    impl OutDim for MergeConfig {
        fn get_out_dim(&self) -> i64 {
            self.out_dim
        }

        fn set_out_dim(&mut self, v: i64) {
            self.out_dim = v;
        }
    }

    // Identity merge network: passes the merged features through unchanged,
    // so `out_dim` must equal `feature_dim` in tests.
    struct Merge {}

    impl SubModel for Merge {
        type Config = MergeConfig;
        type Input = Tensor;
        type Output = Tensor;

        fn clone_with_var_store(&self, _var_store: &nn::VarStore) -> Self {
            Self {}
        }

        fn build(_var_store: &VarStore, _config: Self::Config) -> Self {
            Self {}
        }

        fn forward(&self, input: &Self::Input) -> Self::Output {
            input.copy()
        }
    }

    // Builds an IQN model on the CPU from the trivial submodels above.
    fn iqn_model(
        feature_dim: i64,
        embed_dim: i64,
        out_dim: i64,
    ) -> IqnModel<FeatureExtractor, Merge> {
        let fe_config = FeatureExtractorConfig {};
        let m_config = MergeConfig { out_dim };
        let device = Device::Cpu;
        let learning_rate = 1e-4;

        let config = IqnModelConfig::default()
            .feature_dim(feature_dim)
            .embed_dim(embed_dim)
            .learning_rate(learning_rate);

        IqnModel::build_with_submodel_configs(config, fe_config, m_config, device)
    }

    // Smoke test: a forward pass through the model runs without panicking
    // (shape checks are the `debug_assert_eq!`s inside `forward`).
    #[test]
    fn test_iqn_model() {
        let in_dim = 100;
        let feature_dim = 100;
        let embed_dim = 64;
        let out_dim = 100;
        let n_percent_points = 8;
        let batch_size = 32;

        let model = iqn_model(feature_dim, embed_dim, out_dim);
        let psi = Tensor::rand(&[batch_size, in_dim], tch::kind::FLOAT_CPU);
        let tau = Tensor::rand(&[batch_size, n_percent_points], tch::kind::FLOAT_CPU);
        let _q = model.forward(&psi, &tau);
    }
}