border_tch_agent/iqn/model/
base.rs

1//! IQN model.
2use super::IqnModelConfig;
3use crate::{
4    model::{ModelBase, SubModel},
5    opt::{Optimizer, OptimizerConfig},
6    util::OutDim,
7};
8use anyhow::{Context, Result};
9use log::{info, trace};
10use serde::{de::DeserializeOwned, Deserialize, Serialize};
11use std::{default::Default, f64::consts::PI, marker::PhantomData, path::Path};
12use tch::{
13    nn,
14    nn::{Module, VarStore},
15    Device,
16    Kind::Float,
17    Tensor,
18};
19
#[allow(clippy::upper_case_acronyms)]
/// Constructs IQN output layer, which takes input features and percent points.
/// It returns action-value quantiles.
pub struct IqnModel<F, M>
where
    F: SubModel<Output = Tensor>,
    M: SubModel<Input = Tensor, Output = Tensor>,
    F::Config: DeserializeOwned + Serialize,
    M::Config: DeserializeOwned + Serialize,
{
    // Device on which the model's variable store is allocated.
    device: Device,
    // Variable store holding the parameters of `psi`, `phi` and `f`.
    var_store: nn::VarStore,

    // Dimension of the input (feature) vector.
    // The `size()[-1]` of F::Output (Tensor) is feature_dim.
    feature_dim: i64,

    // Dimension of the cosine embedding vector.
    embed_dim: i64,

    // Dimension of the output vector (equal to the number of actions).
    pub(super) out_dim: i64,

    // Feature extractor: maps an observation to a feature vector of size `feature_dim`.
    psi: F,

    // Cos embedding of percent points (tau), mapped to size `feature_dim`.
    phi: nn::Sequential,

    // Merge network: maps `psi * phi` to action-value quantiles of size `out_dim`.
    f: M,

    // Optimizer configuration; retained so `Clone` can rebuild the optimizer.
    opt_config: OptimizerConfig,
    // Optimizer over the parameters registered in `var_store`.
    opt: Optimizer,

    // NOTE(review): F and M are already used by the `psi` and `f` fields, so
    // this PhantomData is redundant; kept as-is for struct-literal stability.
    phantom: PhantomData<(F, M)>,
}
58
59impl<F, M> IqnModel<F, M>
60where
61    F: SubModel<Output = Tensor>,
62    M: SubModel<Input = Tensor, Output = Tensor>,
63    F::Config: DeserializeOwned + Serialize,
64    M::Config: DeserializeOwned + Serialize + OutDim,
65{
66    /// Constructs [IqnModel].
67    pub fn build(
68        config: IqnModelConfig<F::Config, M::Config>,
69        device: Device,
70    ) -> Result<IqnModel<F, M>> {
71        let f_config = config.f_config.context("f_config is not set.")?;
72        let m_config = config.m_config.context("m_config is not set.")?;
73        let feature_dim = config.feature_dim;
74        let embed_dim = config.embed_dim;
75        let out_dim = m_config.get_out_dim();
76        let opt_config = config.opt_config;
77        let var_store = nn::VarStore::new(device);
78
79        // Feature extractor
80        let psi = F::build(&var_store, f_config);
81
82        // Cosine embedding
83        let phi = IqnModel::<F, M>::cos_embed_nn(&var_store, feature_dim, embed_dim);
84
85        // Merge
86        let f = M::build(&var_store, m_config);
87
88        // Optimizer
89        let opt = opt_config.build(&var_store)?;
90
91        // // let mut adam = nn::Adam::default();
92        // // adam.eps = 0.01 / 32.0;
93        // // let opt = adam.build(&var_store, learning_rate).unwrap();
94        // let opt = nn::Adam::default()
95        //     .build(&var_store, learning_rate)
96        //     .unwrap();
97
98        Ok(IqnModel {
99            device,
100            var_store,
101            feature_dim,
102            embed_dim,
103            out_dim,
104            psi,
105            phi,
106            f,
107            opt_config,
108            opt,
109            phantom: PhantomData,
110        })
111    }
112
113    /// Constructs [IqnModel] with the given configurations of sub models.
114    pub fn build_with_submodel_configs(
115        config: IqnModelConfig<F::Config, M::Config>,
116        f_config: F::Config,
117        m_config: M::Config,
118        device: Device,
119    ) -> IqnModel<F, M> {
120        let feature_dim = config.feature_dim;
121        let embed_dim = config.embed_dim;
122        let out_dim = m_config.get_out_dim();
123        let opt_config = config.opt_config.clone();
124        let var_store = nn::VarStore::new(device);
125
126        // Feature extractor
127        let psi = F::build(&var_store, f_config);
128
129        // Cosine embedding
130        let phi = IqnModel::<F, M>::cos_embed_nn(&var_store, feature_dim, embed_dim);
131
132        // Merge
133        let f = M::build(&var_store, m_config);
134
135        // Optimizer
136        // TODO: remove unwrap()
137        let opt = opt_config.build(&var_store).unwrap();
138
139        // let mut adam = nn::Adam::default();
140        // adam.eps = 0.01 / 32.0;
141        // let opt = adam.build(&var_store, learning_rate).unwrap();
142        // let opt = nn::Adam::default()
143        //     .build(&var_store, learning_rate)
144        //     .unwrap();
145
146        IqnModel {
147            device,
148            var_store,
149            feature_dim,
150            embed_dim,
151            out_dim,
152            psi,
153            phi,
154            f,
155            opt_config,
156            opt,
157            phantom: PhantomData,
158        }
159    }
160
161    // Cosine embedding.
162    fn cos_embed_nn(var_store: &VarStore, feature_dim: i64, embed_dim: i64) -> nn::Sequential {
163        let p = &var_store.root();
164        let device = p.device();
165        nn::seq()
166            .add_fn(move |tau| {
167                let batch_size = tau.size().as_slice()[0];
168                let n_percent_points = tau.size().as_slice()[1];
169                let tau = tau.unsqueeze(-1);
170                let i = Tensor::range(1, embed_dim, (Float, device))
171                    .unsqueeze(0)
172                    .unsqueeze(0);
173                debug_assert_eq!(tau.size().as_slice(), &[batch_size, n_percent_points, 1]);
174                debug_assert_eq!(i.size().as_slice(), &[1, 1, embed_dim]);
175
176                let cos = Tensor::cos(&(tau * (PI * i)));
177                debug_assert_eq!(
178                    cos.size().as_slice(),
179                    &[batch_size, n_percent_points, embed_dim]
180                );
181
182                cos.reshape(&[-1, embed_dim])
183            })
184            .add(nn::linear(
185                p / "iqn_cos_to_feature",
186                embed_dim,
187                feature_dim,
188                Default::default(),
189            ))
190            .add_fn(|x| x.relu())
191    }
192
193    /// Returns the tensor of action-value quantiles.
194    ///
195    /// * The shape of` psi(x)` (feature vector) is [batch_size, feature_dim].
196    /// * The shape of `tau` is [batch_size, n_percent_points].
197    /// * The shape of the output is [batch_size, n_percent_points, self.out_dim].
198    pub fn forward(&self, x: &F::Input, tau: &Tensor) -> Tensor {
199        // Used to check tensor size
200        let feature_dim = self.feature_dim;
201        let n_percent_points = tau.size().as_slice()[1];
202
203        // Feature extraction
204        let psi = self.psi.forward(x);
205        let batch_size = psi.size().as_slice()[0];
206        debug_assert_eq!(psi.size().as_slice(), &[batch_size, feature_dim]);
207
208        // Cosine embedding of percent points, eq. (4) in the paper
209        debug_assert_eq!(tau.size().as_slice(), &[batch_size, n_percent_points]);
210        let phi = self.phi.forward(tau);
211        debug_assert_eq!(
212            phi.size().as_slice(),
213            &[batch_size * n_percent_points, self.feature_dim]
214        );
215        let phi = phi.reshape(&[batch_size, n_percent_points, self.feature_dim]);
216
217        // Merge features and embedded quantiles by elem-wise multiplication
218        let psi = psi.unsqueeze(1);
219        debug_assert_eq!(psi.size().as_slice(), &[batch_size, 1, self.feature_dim]);
220        let m = psi * phi;
221        debug_assert_eq!(
222            m.size().as_slice(),
223            &[batch_size, n_percent_points, self.feature_dim]
224        );
225
226        // Action-value
227        let a = self.f.forward(&m);
228        debug_assert_eq!(
229            a.size().as_slice(),
230            &[batch_size, n_percent_points, self.out_dim]
231        );
232
233        a
234    }
235}
236
237impl<F, M> Clone for IqnModel<F, M>
238where
239    F: SubModel<Output = Tensor>,
240    M: SubModel<Input = Tensor, Output = Tensor>,
241    F::Config: DeserializeOwned + Serialize,
242    M::Config: DeserializeOwned + Serialize + OutDim,
243{
244    fn clone(&self) -> Self {
245        let device = self.device;
246        let feature_dim = self.feature_dim;
247        let embed_dim = self.embed_dim;
248        let out_dim = self.out_dim;
249        let opt_config = self.opt_config.clone();
250        let mut var_store = nn::VarStore::new(device);
251
252        // Feature extractor
253        let psi = self.psi.clone_with_var_store(&var_store);
254
255        // Cos-embedding
256        let phi = IqnModel::<F, M>::cos_embed_nn(&var_store, feature_dim, embed_dim);
257
258        // Merge
259        let f = self.f.clone_with_var_store(&var_store);
260
261        // Optimizer
262        let opt = opt_config.build(&var_store).unwrap();
263
264        // let mut adam = nn::Adam::default();
265        // adam.eps = 0.01 / 32.0;
266        // let opt = adam.build(&var_store, learning_rate).unwrap();
267        // let opt = nn::Adam::default()
268        //     .build(&var_store, learning_rate)
269        //     .unwrap();
270
271        var_store.copy(&self.var_store).unwrap();
272
273        Self {
274            device,
275            var_store,
276            feature_dim,
277            embed_dim,
278            out_dim,
279            psi,
280            phi,
281            f,
282            opt_config,
283            opt,
284            phantom: PhantomData,
285        }
286    }
287}
288
289impl<F, M> ModelBase for IqnModel<F, M>
290where
291    F: SubModel<Output = Tensor>,
292    M: SubModel<Input = Tensor, Output = Tensor>,
293    F::Config: DeserializeOwned + Serialize,
294    M::Config: DeserializeOwned + Serialize,
295{
296    fn backward_step(&mut self, loss: &Tensor) {
297        self.opt.backward_step(loss);
298    }
299
300    fn get_var_store(&self) -> &nn::VarStore {
301        &self.var_store
302    }
303
304    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
305        &mut self.var_store
306    }
307
308    fn save<T: AsRef<Path>>(&self, path: T) -> Result<()> {
309        self.var_store.save(&path)?;
310        info!("Save IQN model to {:?}", path.as_ref());
311        let vs = self.var_store.variables();
312        for (name, _) in vs.iter() {
313            trace!("Save variable {}", name);
314        }
315        Ok(())
316    }
317
318    fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()> {
319        self.var_store.load(&path)?;
320        info!("Load IQN model from {:?}", path.as_ref());
321        Ok(())
322    }
323}
324
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
/// The way of taking percent points.
pub enum IqnSample {
    /// Samples over percent points `0.05:0.1:0.95`.
    ///
    /// The percent points are constants.
    Const10,

    /// 32 percent points; the percent points are constants.
    Const32,

    /// 10 samples from uniform distribution.
    Uniform10,

    /// 8 samples from uniform distribution.
    Uniform8,

    /// 32 samples from uniform distribution.
    Uniform32,

    /// 64 samples from uniform distribution.
    Uniform64,

    /// Single sample, median.
    Median,
}
351
352impl IqnSample {
353    /// Returns samples of percent points.
354    pub fn sample(&self, batch_size: i64) -> Tensor {
355        match self {
356            Self::Const10 => Tensor::from_slice(&[
357                0.05_f32, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95,
358            ])
359            .unsqueeze(0)
360            .repeat(&[batch_size, 1]),
361            Self::Const32 => {
362                let t: Tensor = (1.0 / 32.0) * Tensor::range(0, 32, tch::kind::FLOAT_CPU);
363                t.unsqueeze(0).repeat(&[batch_size, 1])
364            }
365            Self::Uniform10 => Tensor::rand(&[batch_size, 10], tch::kind::FLOAT_CPU),
366            Self::Uniform8 => Tensor::rand(&[batch_size, 8], tch::kind::FLOAT_CPU),
367            Self::Uniform32 => Tensor::rand(&[batch_size, 32], tch::kind::FLOAT_CPU),
368            Self::Uniform64 => Tensor::rand(&[batch_size, 64], tch::kind::FLOAT_CPU),
369            Self::Median => Tensor::from_slice(&[0.5_f32])
370                .unsqueeze(0)
371                .repeat(&[batch_size, 1]),
372        }
373    }
374
375    /// Returns the number of percent points generated by this way.
376    pub fn n_percent_points(&self) -> i64 {
377        match self {
378            Self::Const10 => 10,
379            Self::Const32 => 32,
380            Self::Uniform10 => 10,
381            Self::Uniform8 => 8,
382            Self::Uniform32 => 32,
383            Self::Uniform64 => 64,
384            Self::Median => 1,
385        }
386    }
387}
388
389/// Takes an average over percent points specified by `mode`.
390///
391/// * `obs` - Observations.
392/// * `iqn` - IQN model.
393/// * `mode` - The way of taking percent points.
394pub fn average<F, M>(
395    batch_size: i64,
396    obs: &F::Input,
397    iqn: &IqnModel<F, M>,
398    mode: &IqnSample,
399    device: Device,
400) -> Tensor
401where
402    F: SubModel<Output = Tensor>,
403    M: SubModel<Input = Tensor, Output = Tensor>,
404    F::Config: DeserializeOwned + Serialize,
405    M::Config: DeserializeOwned + Serialize + OutDim,
406{
407    let tau = mode.sample(batch_size).to(device);
408    let averaged_action_value = iqn
409        .forward(obs, &tau)
410        .mean_dim(Some([1].as_slice()), false, Float);
411    let batch_size = averaged_action_value.size()[0];
412    let n_action = iqn.out_dim;
413    debug_assert_eq!(
414        averaged_action_value.size().as_slice(),
415        &[batch_size, n_action]
416    );
417    averaged_action_value
418}
419
#[cfg(test)]
mod test {
    use super::super::IqnModelConfig;
    use super::*;
    use crate::util::OutDim;
    use std::default::Default;
    use tch::{nn, Device, Tensor};

    /// Config of the test feature extractor (no parameters).
    #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
    struct FeatureExtractorConfig {}

    /// Identity feature extractor: forwards a copy of the observation, so
    /// the input dimension must equal `feature_dim` in the tests below.
    struct FeatureExtractor {}

    impl SubModel for FeatureExtractor {
        type Config = FeatureExtractorConfig;
        type Input = Tensor;
        type Output = Tensor;

        fn clone_with_var_store(&self, _var_store: &nn::VarStore) -> Self {
            Self {}
        }

        fn build(_var_store: &VarStore, _config: Self::Config) -> Self {
            Self {}
        }

        fn forward(&self, input: &Self::Input) -> Self::Output {
            input.copy()
        }
    }

    /// Config of the test merge network; carries only the output dimension.
    #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
    struct MergeConfig {
        out_dim: i64,
    }

    impl OutDim for MergeConfig {
        fn get_out_dim(&self) -> i64 {
            self.out_dim
        }

        fn set_out_dim(&mut self, v: i64) {
            self.out_dim = v;
        }
    }

    /// Identity merge network: forwards a copy of its input, so `out_dim`
    /// must equal `feature_dim` in the tests below.
    struct Merge {}

    impl SubModel for Merge {
        type Config = MergeConfig;
        type Input = Tensor;
        type Output = Tensor;

        fn clone_with_var_store(&self, _var_store: &nn::VarStore) -> Self {
            Self {}
        }

        fn build(_var_store: &VarStore, _config: Self::Config) -> Self {
            Self {}
        }

        fn forward(&self, input: &Self::Input) -> Self::Output {
            input.copy()
        }
    }

    /// Builds a CPU [IqnModel] from the identity sub models above.
    fn iqn_model(
        feature_dim: i64,
        embed_dim: i64,
        out_dim: i64,
    ) -> IqnModel<FeatureExtractor, Merge> {
        let config = IqnModelConfig::default()
            .feature_dim(feature_dim)
            .embed_dim(embed_dim)
            .learning_rate(1e-4);

        IqnModel::build_with_submodel_configs(
            config,
            FeatureExtractorConfig {},
            MergeConfig { out_dim },
            Device::Cpu,
        )
    }

    #[test]
    /// Check shape of tensors in IQNModel.
    fn test_iqn_model() {
        let in_dim = 100;
        let feature_dim = 100;
        let embed_dim = 64;
        let out_dim = 100;
        let n_percent_points = 8;
        let batch_size = 32;

        // Forward a random observation through the model; the internal
        // `debug_assert_eq!`s in `forward` verify every intermediate shape.
        let model = iqn_model(feature_dim, embed_dim, out_dim);
        let obs = Tensor::rand(&[batch_size, in_dim], tch::kind::FLOAT_CPU);
        let tau = Tensor::rand(&[batch_size, n_percent_points], tch::kind::FLOAT_CPU);
        let _q = model.forward(&obs, &tau);
    }
}
519}