//! Configuration of the IQN model.
use crate::{
opt::OptimizerConfig,
util::OutDim,
};
use anyhow::Result;
use serde::{Deserialize, de::DeserializeOwned, Serialize};
use std::{
default::Default,
fs::File,
io::{BufReader, Write},
path::Path,
};
#[cfg(not(feature = "adam_eps"))]
impl<F, M> IqnModelConfig<F, M>
where
    F: DeserializeOwned + Serialize,
    M: DeserializeOwned + Serialize + OutDim,
{
    /// Replaces the learning rate stored in the optimizer configuration.
    ///
    /// The builder is consumed and returned so that calls can be chained.
    /// This impl is only compiled when the `adam_eps` feature is disabled.
    pub fn learning_rate(mut self, v: f64) -> Self {
        // The match is kept exhaustive on purpose: a new optimizer variant
        // must be handled here explicitly before the crate compiles again.
        self.opt_config = match self.opt_config {
            OptimizerConfig::Adam { .. } => OptimizerConfig::Adam { lr: v },
        };
        self
    }
}
// #[cfg(feature = "adam_eps")]
// impl<F: SubModel, M: SubModel> IqnModelConfig<F, M>
// where
// F::Config: DeserializeOwned + Serialize,
// M::Config: DeserializeOwned + Serialize,
// {
// /// Sets the learning rate.
// pub fn learning_rate(mut self, v: f64) -> Self {
// match &self.opt_config {
// OptimizerConfig::Adam { lr: _ } => self.opt_config = OptimizerConfig::Adam { lr: v },
// _ => unimplemented!(),
// };
// self
// }
// }
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
/// Configuration of [IqnModel](super::IqnModel).
///
/// The type parameter `F` represents a configuration struct of a feature extractor.
/// The type parameter `M` represents a configuration struct of a model for merging
/// cosine-embedded percent points and feature vectors. `M` must implement
/// [`OutDim`].
///
/// The struct derives `Serialize` and `Deserialize`, so it can be written to and
/// read back from a serialized form.
pub struct IqnModelConfig<F, M>
where
M: OutDim,
{
/// Dimension of feature vector.
pub feature_dim: i64,
/// Dimension of the cosine embedding of percent points.
pub embed_dim: i64,
/// Configuration of feature extractor, or `None` if it has not been set.
pub f_config: Option<F>,
/// Configuration of a model for merging percent points and feature vectors,
/// or `None` if it has not been set.
pub m_config: Option<M>,
/// Configuration of optimizer.
pub opt_config: OptimizerConfig,
}
impl<F, M> Default for IqnModelConfig<F, M>
where
    M: OutDim,
{
    /// Creates a placeholder configuration.
    ///
    /// Both dimensions are zero, neither sub-model configuration is set, and
    /// the optimizer is Adam with a learning rate of `0.0`. Callers are
    /// expected to fill in real values through the builder methods.
    fn default() -> Self {
        let opt_config = OptimizerConfig::Adam { lr: 0.0 };
        let (feature_dim, embed_dim) = (0, 0);

        Self {
            feature_dim,
            embed_dim,
            f_config: None,
            m_config: None,
            opt_config,
        }
    }
}
impl<F, M> IqnModelConfig<F, M>
where
    F: DeserializeOwned + Serialize,
    M: DeserializeOwned + Serialize + OutDim,
{
    /// Sets the dimension of cos-embedding of percent points.
    pub fn embed_dim(mut self, v: i64) -> Self {
        self.embed_dim = v;
        self
    }

    /// Sets the dimension of feature vectors.
    pub fn feature_dim(mut self, v: i64) -> Self {
        self.feature_dim = v;
        self
    }

    /// Sets configurations for feature extractor.
    pub fn f_config(mut self, v: F) -> Self {
        self.f_config = Some(v);
        self
    }

    /// Sets configurations for output model.
    pub fn m_config(mut self, v: M) -> Self {
        self.m_config = Some(v);
        self
    }

    /// Sets output dimension of the model.
    ///
    /// The value is forwarded to `set_out_dim` of the output model
    /// configuration. If `m_config` has not been set yet, this call does
    /// nothing, so `m_config()` should be called before `out_dim()`.
    pub fn out_dim(mut self, out_dim: i64) -> Self {
        if let Some(m_config) = self.m_config.as_mut() {
            m_config.set_out_dim(out_dim);
        }
        self
    }

    /// Sets optimizer configuration.
    pub fn opt_config(mut self, v: OptimizerConfig) -> Self {
        self.opt_config = v;
        self
    }

    /// Constructs [IqnModelConfig] from YAML file.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or if its contents
    /// cannot be deserialized into this configuration.
    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
        let rdr = BufReader::new(File::open(path)?);
        let config = serde_yaml::from_reader(rdr)?;
        Ok(config)
    }

    /// Saves [IqnModelConfig] as a YAML file.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration cannot be serialized, or if the
    /// file cannot be created or written.
    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
        // Serialize before touching the filesystem: `File::create` truncates
        // an existing file, so a serialization failure must not be allowed to
        // wipe out a previously saved configuration.
        let yaml = serde_yaml::to_string(self)?;
        let mut file = File::create(path)?;
        file.write_all(yaml.as_bytes())?;
        Ok(())
    }
}