border_tch_agent/iqn/model/
config.rs

1//! IQN model.
2use crate::{opt::OptimizerConfig, util::OutDim};
3use anyhow::Result;
4use serde::{de::DeserializeOwned, Deserialize, Serialize};
5use std::{
6    default::Default,
7    fs::File,
8    io::{BufReader, Write},
9    path::Path,
10};
11
12impl<F, M> IqnModelConfig<F, M>
13where
14    F: DeserializeOwned + Serialize,
15    M: DeserializeOwned + Serialize + OutDim,
16{
17    /// Sets the learning rate.
18    pub fn learning_rate(mut self, v: f64) -> Self {
19        match &self.opt_config {
20            OptimizerConfig::Adam { lr: _ } => self.opt_config = OptimizerConfig::Adam { lr: v },
21            _ => unimplemented!(),
22        };
23        self
24    }
25}
26
27#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
28/// Configuration of [`IqnModel`](super::IqnModel).
29///
30/// The type parameter `F` represents a configuration struct of a feature extractor.
31/// The type parameter `M` represents a configuration struct of a model for merging
32/// cosine-embedded percent points and feature vectors.
33pub struct IqnModelConfig<F, M>
34where
35    M: OutDim,
36{
37    /// Dimension of feature vector.
38    pub feature_dim: i64,
39
40    /// Embedding dimension.
41    pub embed_dim: i64,
42
43    /// Configuration of feature extractor.
44    pub f_config: Option<F>,
45
46    /// Configuration of a model for merging percentils and feature vectors.
47    pub m_config: Option<M>,
48
49    /// Configuration of optimizer.
50    pub opt_config: OptimizerConfig,
51}
52
53impl<F, M> Default for IqnModelConfig<F, M>
54where
55    M: OutDim,
56{
57    fn default() -> Self {
58        Self {
59            feature_dim: 0,
60            embed_dim: 0,
61            f_config: None,
62            m_config: None,
63            opt_config: OptimizerConfig::Adam { lr: 0.0 },
64        }
65    }
66}
67
68impl<F, M> IqnModelConfig<F, M>
69where
70    F: DeserializeOwned + Serialize,
71    M: DeserializeOwned + Serialize + OutDim,
72{
73    /// Sets the dimension of cos-embedding of percent points.
74    pub fn embed_dim(mut self, v: i64) -> Self {
75        self.embed_dim = v;
76        self
77    }
78
79    /// Sets the dimension of feature vectors.
80    pub fn feature_dim(mut self, v: i64) -> Self {
81        self.feature_dim = v;
82        self
83    }
84
85    /// Sets configurations for feature extractor.
86    pub fn f_config(mut self, v: F) -> Self {
87        self.f_config = Some(v);
88        self
89    }
90
91    /// Sets configurations for output model.
92    pub fn m_config(mut self, v: M) -> Self {
93        self.m_config = Some(v);
94        self
95    }
96
97    /// Sets output dimension of the model.
98    pub fn out_dim(mut self, out_dim: i64) -> Self {
99        if self.m_config.is_some() {
100            self.m_config.as_mut().unwrap().set_out_dim(out_dim);
101        }
102        self
103    }
104
105    /// Sets optimizer configuration.
106    pub fn opt_config(mut self, v: OptimizerConfig) -> Self {
107        self.opt_config = v;
108        self
109    }
110
111    /// Constructs [IqnModelConfig] from YAML file.
112    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
113        let file = File::open(path)?;
114        let rdr = BufReader::new(file);
115        let b = serde_yaml::from_reader(rdr)?;
116        Ok(b)
117    }
118
119    /// Saves [IqnModelConfig].
120    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
121        let mut file = File::create(path)?;
122        file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
123        Ok(())
124    }
125}