use super::IqnModelConfig;
use crate::{
model::{ModelBase, SubModel},
opt::{Optimizer, OptimizerConfig},
util::OutDim,
};
use anyhow::{Context, Result};
use log::{info, trace};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{default::Default, f64::consts::PI, marker::PhantomData, path::Path};
use tch::{
nn,
nn::{Module, VarStore},
Device,
Kind::Float,
Tensor,
};
/// Implicit Quantile Network (IQN) model.
///
/// The model combines three parts that share one variable store: a feature
/// extractor `psi` of type `F`, a cosine embedding network `phi` applied to the
/// quantile fractions, and a merging network `f` of type `M` that maps the
/// element-wise product of both to the final output.
// NOTE(review): `IqnModel` is not an upper-case acronym, so this `allow` looks
// stale — confirm before removing.
#[allow(clippy::upper_case_acronyms)]
pub struct IqnModel<F, M>
where
F: SubModel<Output = Tensor>,
M: SubModel<Input = Tensor, Output = Tensor>,
F::Config: DeserializeOwned + Serialize,
M::Config: DeserializeOwned + Serialize,
{
/// Device the variable store is created on.
device: Device,
/// Variable store passed to `psi`, `phi`, `f` and the optimizer when they are built.
var_store: nn::VarStore,
/// Width of the feature vector produced by `psi` and by `phi`.
feature_dim: i64,
/// Number of cosine basis functions used to embed each quantile fraction.
embed_dim: i64,
/// Output width, taken from the merging network's configuration.
pub(super) out_dim: i64,
/// Feature extractor applied to the observation.
psi: F,
/// Cosine embedding network applied to the quantile fractions.
phi: nn::Sequential,
/// Merging network applied to the product of `psi` and `phi` outputs.
f: M,
/// Optimizer configuration, kept so that clones can rebuild their own optimizer.
opt_config: OptimizerConfig,
/// Optimizer built from `opt_config` over `var_store`.
opt: Optimizer,
/// Marker for the sub-model type parameters.
phantom: PhantomData<(F, M)>,
}
impl<F, M> IqnModel<F, M>
where
F: SubModel<Output = Tensor>,
M: SubModel<Input = Tensor, Output = Tensor>,
F::Config: DeserializeOwned + Serialize,
M::Config: DeserializeOwned + Serialize + OutDim,
{
pub fn build(
config: IqnModelConfig<F::Config, M::Config>,
device: Device,
) -> Result<IqnModel<F, M>> {
let f_config = config.f_config.context("f_config is not set.")?;
let m_config = config.m_config.context("m_config is not set.")?;
let feature_dim = config.feature_dim;
let embed_dim = config.embed_dim;
let out_dim = m_config.get_out_dim();
let opt_config = config.opt_config;
let var_store = nn::VarStore::new(device);
let psi = F::build(&var_store, f_config);
let phi = IqnModel::<F, M>::cos_embed_nn(&var_store, feature_dim, embed_dim);
let f = M::build(&var_store, m_config);
let opt = opt_config.build(&var_store)?;
Ok(IqnModel {
device,
var_store,
feature_dim,
embed_dim,
out_dim,
psi,
phi,
f,
opt_config,
opt,
phantom: PhantomData,
})
}
pub fn build_with_submodel_configs(
config: IqnModelConfig<F::Config, M::Config>,
f_config: F::Config,
m_config: M::Config,
device: Device,
) -> IqnModel<F, M> {
let feature_dim = config.feature_dim;
let embed_dim = config.embed_dim;
let out_dim = m_config.get_out_dim();
let opt_config = config.opt_config.clone();
let var_store = nn::VarStore::new(device);
let psi = F::build(&var_store, f_config);
let phi = IqnModel::<F, M>::cos_embed_nn(&var_store, feature_dim, embed_dim);
let f = M::build(&var_store, m_config);
let opt = opt_config.build(&var_store).unwrap();
IqnModel {
device,
var_store,
feature_dim,
embed_dim,
out_dim,
psi,
phi,
f,
opt_config,
opt,
phantom: PhantomData,
}
}
fn cos_embed_nn(var_store: &VarStore, feature_dim: i64, embed_dim: i64) -> nn::Sequential {
let p = &var_store.root();
let device = p.device();
nn::seq()
.add_fn(move |tau| {
let batch_size = tau.size().as_slice()[0];
let n_percent_points = tau.size().as_slice()[1];
let tau = tau.unsqueeze(-1);
let i = Tensor::range(1, embed_dim, (Float, device))
.unsqueeze(0)
.unsqueeze(0);
debug_assert_eq!(tau.size().as_slice(), &[batch_size, n_percent_points, 1]);
debug_assert_eq!(i.size().as_slice(), &[1, 1, embed_dim]);
let cos = Tensor::cos(&(tau * (PI * i)));
debug_assert_eq!(
cos.size().as_slice(),
&[batch_size, n_percent_points, embed_dim]
);
cos.reshape(&[-1, embed_dim])
})
.add(nn::linear(
p / "iqn_cos_to_feature",
embed_dim,
feature_dim,
Default::default(),
))
.add_fn(|x| x.relu())
}
pub fn forward(&self, x: &F::Input, tau: &Tensor) -> Tensor {
let feature_dim = self.feature_dim;
let n_percent_points = tau.size().as_slice()[1];
let psi = self.psi.forward(x);
let batch_size = psi.size().as_slice()[0];
debug_assert_eq!(psi.size().as_slice(), &[batch_size, feature_dim]);
debug_assert_eq!(tau.size().as_slice(), &[batch_size, n_percent_points]);
let phi = self.phi.forward(tau);
debug_assert_eq!(
phi.size().as_slice(),
&[batch_size * n_percent_points, self.feature_dim]
);
let phi = phi.reshape(&[batch_size, n_percent_points, self.feature_dim]);
let psi = psi.unsqueeze(1);
debug_assert_eq!(psi.size().as_slice(), &[batch_size, 1, self.feature_dim]);
let m = psi * phi;
debug_assert_eq!(
m.size().as_slice(),
&[batch_size, n_percent_points, self.feature_dim]
);
let a = self.f.forward(&m);
debug_assert_eq!(
a.size().as_slice(),
&[batch_size, n_percent_points, self.out_dim]
);
a
}
}
impl<F, M> Clone for IqnModel<F, M>
where
    F: SubModel<Output = Tensor>,
    M: SubModel<Input = Tensor, Output = Tensor>,
    F::Config: DeserializeOwned + Serialize,
    M::Config: DeserializeOwned + Serialize + OutDim,
{
    /// Creates an independent copy of the model backed by its own variable store.
    ///
    /// The networks and a new optimizer are rebuilt on a fresh variable store on the
    /// same device, after which the variable values of `self` are copied over.
    ///
    /// # Panics
    ///
    /// Panics when the optimizer cannot be built or when copying the variables fails.
    fn clone(&self) -> Self {
        let opt_config = self.opt_config.clone();
        let mut var_store = nn::VarStore::new(self.device);

        // Rebuild every network on the new store first so that the variables exist
        // before their values are copied from the source store.
        let psi = self.psi.clone_with_var_store(&var_store);
        let phi = Self::cos_embed_nn(&var_store, self.feature_dim, self.embed_dim);
        let f = self.f.clone_with_var_store(&var_store);
        let opt = opt_config.build(&var_store).unwrap();
        var_store.copy(&self.var_store).unwrap();

        Self {
            device: self.device,
            var_store,
            feature_dim: self.feature_dim,
            embed_dim: self.embed_dim,
            out_dim: self.out_dim,
            psi,
            phi,
            f,
            opt_config,
            opt,
            phantom: PhantomData,
        }
    }
}
impl<F, M> ModelBase for IqnModel<F, M>
where
    F: SubModel<Output = Tensor>,
    M: SubModel<Input = Tensor, Output = Tensor>,
    F::Config: DeserializeOwned + Serialize,
    M::Config: DeserializeOwned + Serialize,
{
    /// Delegates a backward pass and update step for `loss` to the optimizer.
    fn backward_step(&mut self, loss: &Tensor) {
        self.opt.backward_step(loss);
    }

    /// Returns a shared reference to the model's variable store.
    fn get_var_store(&self) -> &nn::VarStore {
        &self.var_store
    }

    /// Returns an exclusive reference to the model's variable store.
    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
        &mut self.var_store
    }

    /// Writes the variable store to `path` and logs the name of every variable.
    ///
    /// # Errors
    ///
    /// Returns an error when the variable store cannot be saved.
    fn save<T: AsRef<Path>>(&self, path: T) -> Result<()> {
        self.var_store.save(&path)?;
        info!("Save IQN model to {:?}", path.as_ref());
        for name in self.var_store.variables().keys() {
            trace!("Save variable {}", name);
        }
        Ok(())
    }

    /// Loads the variable store from `path`.
    ///
    /// # Errors
    ///
    /// Returns an error when the variable store cannot be loaded.
    fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()> {
        self.var_store.load(&path)?;
        info!("Load IQN model from {:?}", path.as_ref());
        Ok(())
    }
}
/// Strategy for choosing the quantile fractions fed to an IQN model.
///
/// The number in a variant name is the number of fractions produced per batch row.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub enum IqnSample {
/// Ten fixed fractions: 0.05, 0.15, ..., 0.95.
Const10,
/// Ten fractions drawn with `Tensor::rand`.
Uniform10,
/// Eight fractions drawn with `Tensor::rand`.
Uniform8,
/// Thirty-two fractions drawn with `Tensor::rand`.
Uniform32,
/// Sixty-four fractions drawn with `Tensor::rand`.
Uniform64,
/// The single fixed fraction 0.5.
Median,
}
impl IqnSample {
    /// Produces the quantile fractions for a batch.
    ///
    /// The result is a 32-bit float CPU tensor of shape
    /// `[batch_size, self.n_percent_points()]`. Fixed variants repeat the same
    /// fractions for every row; uniform variants draw each entry with `Tensor::rand`.
    pub fn sample(&self, batch_size: i64) -> Tensor {
        match self {
            Self::Const10 => Self::tiled(
                &[0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95],
                batch_size,
            ),
            Self::Median => Self::tiled(&[0.5], batch_size),
            Self::Uniform10 | Self::Uniform8 | Self::Uniform32 | Self::Uniform64 => {
                Tensor::rand(&[batch_size, self.n_percent_points()], tch::kind::FLOAT_CPU)
            }
        }
    }

    /// Repeats a fixed row of fractions `batch_size` times along a new leading axis.
    fn tiled(fractions: &[f32], batch_size: i64) -> Tensor {
        Tensor::of_slice(fractions)
            .unsqueeze(0)
            .repeat(&[batch_size, 1])
    }

    /// Returns how many quantile fractions [`IqnSample::sample`] yields per batch row.
    pub fn n_percent_points(&self) -> i64 {
        match self {
            Self::Median => 1,
            Self::Uniform8 => 8,
            Self::Const10 | Self::Uniform10 => 10,
            Self::Uniform32 => 32,
            Self::Uniform64 => 64,
        }
    }
}
/// Averages the IQN output over the sampled quantile fractions.
///
/// Quantile fractions are drawn with `mode` for `batch_size` rows, moved to
/// `device`, and passed to `iqn` together with `obs`. The model output is then
/// averaged over its second axis, giving a tensor of shape `[batch, iqn.out_dim]`
/// (checked with `debug_assert_eq!` only).
pub fn average<F, M>(
    batch_size: i64,
    obs: &F::Input,
    iqn: &IqnModel<F, M>,
    mode: &IqnSample,
    device: Device,
) -> Tensor
where
    F: SubModel<Output = Tensor>,
    M: SubModel<Input = Tensor, Output = Tensor>,
    F::Config: DeserializeOwned + Serialize,
    M::Config: DeserializeOwned + Serialize + OutDim,
{
    let tau = mode.sample(batch_size).to(device);
    let quantile_values = iqn.forward(obs, &tau);
    // Collapse the quantile axis (dimension 1).
    let mean = quantile_values.mean_dim(&[1], false, Float);

    // The expected row count is read back from the result itself, so this checks
    // the rank and the trailing dimension.
    let rows = mean.size()[0];
    debug_assert_eq!(mean.size().as_slice(), &[rows, iqn.out_dim]);
    mean
}
#[cfg(test)]
mod test {
    use super::super::IqnModelConfig;
    use super::*;
    use crate::util::OutDim;
    use std::default::Default;
    use tch::{nn, Device, Tensor};

    /// Empty configuration of the stub feature extractor.
    #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
    struct FeatureExtractorConfig {}

    /// Parameter-free feature extractor that returns a copy of its input.
    struct FeatureExtractor {}

    impl SubModel for FeatureExtractor {
        type Config = FeatureExtractorConfig;
        type Input = Tensor;
        type Output = Tensor;

        fn clone_with_var_store(&self, _var_store: &nn::VarStore) -> Self {
            Self {}
        }

        fn build(_var_store: &VarStore, _config: Self::Config) -> Self {
            Self {}
        }

        fn forward(&self, input: &Self::Input) -> Self::Output {
            input.copy()
        }
    }

    /// Configuration of the stub merging network; only carries the output width.
    #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
    struct MergeConfig {
        out_dim: i64,
    }

    impl OutDim for MergeConfig {
        fn get_out_dim(&self) -> i64 {
            self.out_dim
        }

        fn set_out_dim(&mut self, v: i64) {
            self.out_dim = v;
        }
    }

    /// Parameter-free merging network that returns a copy of its input.
    struct Merge {}

    impl SubModel for Merge {
        type Config = MergeConfig;
        type Input = Tensor;
        type Output = Tensor;

        fn clone_with_var_store(&self, _var_store: &nn::VarStore) -> Self {
            Self {}
        }

        fn build(_var_store: &VarStore, _config: Self::Config) -> Self {
            Self {}
        }

        fn forward(&self, input: &Self::Input) -> Self::Output {
            input.copy()
        }
    }

    /// Builds an IQN model on the CPU from the two stub sub-models.
    fn iqn_model(
        feature_dim: i64,
        embed_dim: i64,
        out_dim: i64,
    ) -> IqnModel<FeatureExtractor, Merge> {
        let fe_config = FeatureExtractorConfig {};
        let m_config = MergeConfig { out_dim };
        let device = Device::Cpu;
        let learning_rate = 1e-4;
        let config = IqnModelConfig::default()
            .feature_dim(feature_dim)
            .embed_dim(embed_dim)
            .learning_rate(learning_rate);
        IqnModel::build_with_submodel_configs(config, fe_config, m_config, device)
    }

    #[test]
    fn test_iqn_model() {
        // Both stub sub-models are identities, so the input, feature and output
        // widths have to be equal.
        let in_dim = 100;
        let feature_dim = 100;
        let embed_dim = 64;
        let out_dim = 100;
        let n_percent_points = 8;
        let batch_size = 32;

        let model = iqn_model(feature_dim, embed_dim, out_dim);
        let psi = Tensor::rand(&[batch_size, in_dim], tch::kind::FLOAT_CPU);
        let tau = Tensor::rand(&[batch_size, n_percent_points], tch::kind::FLOAT_CPU);

        // Check the shape explicitly so the test still verifies something when the
        // `debug_assert`s inside `forward` are compiled out.
        let q = model.forward(&psi, &tau);
        assert_eq!(q.size(), vec![batch_size, n_percent_points, out_dim]);
    }
}