use super::{IqnModelConfig, IqnSample};
use crate::{
iqn::{EpsilonGreedy, IqnExplorer},
model::SubModel,
util::OutDim,
Device,
};
use anyhow::Result;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{
default::Default,
fs::File,
io::{BufReader, Write},
marker::PhantomData,
path::Path,
};
/// Configuration of an IQN agent.
///
/// The struct is generic over two [`SubModel`] types, `F` and `M`; only their
/// associated `Config` types are stored (inside `model_config`). It derives
/// `Serialize`/`Deserialize`, and the inherent `load`/`save` methods use YAML
/// as the on-disk format.
///
/// Most fields are `pub(super)` and are meant to be set through the
/// builder-style methods of this type.
#[derive(Debug, Deserialize, Serialize, PartialEq)]
pub struct IqnConfig<F, M>
where
F: SubModel,
M: SubModel,
F::Config: DeserializeOwned + Serialize + Clone,
M::Config: DeserializeOwned + Serialize + Clone + OutDim,
{
/// Configuration of the underlying IQN model, built from the configs of `F` and `M`.
pub(super) model_config: IqnModelConfig<F::Config, M::Config>,
/// Defaults to 1.
/// NOTE(review): presumably the number of optimization steps between soft
/// updates of a target network; confirm against the agent code.
pub(super) soft_update_interval: usize,
/// Defaults to 1.
/// NOTE(review): presumably the number of parameter updates per optimization
/// call; confirm against the agent code.
pub(super) n_updates_per_opt: usize,
/// Defaults to 1.
/// NOTE(review): presumably the number of transitions to collect before
/// training starts; confirm against the agent code.
pub(super) min_transitions_warmup: usize,
/// Defaults to 1.
/// NOTE(review): presumably the mini-batch size used for training; confirm.
pub(super) batch_size: usize,
/// Defaults to 0.99.
/// NOTE(review): presumably the reward discount factor (gamma); confirm.
pub(super) discount_factor: f64,
/// Defaults to 0.005.
/// NOTE(review): presumably the soft-update mixing coefficient; confirm.
pub(super) tau: f64,
/// Defaults to `false`. No builder method in this file sets it.
/// NOTE(review): presumably the initial training-mode flag; confirm.
pub(super) train: bool,
/// Exploration strategy; defaults to `IqnExplorer::EpsilonGreedy` with a
/// default `EpsilonGreedy`.
pub(super) explorer: IqnExplorer,
/// Defaults to `IqnSample::Uniform64`.
/// NOTE(review): presumably how percent points are sampled for the
/// prediction side of the loss; confirm.
pub(super) sample_percents_pred: IqnSample,
/// Defaults to `IqnSample::Uniform64`.
/// NOTE(review): presumably how percent points are sampled for the target
/// side of the loss; confirm.
pub(super) sample_percents_tgt: IqnSample,
/// Defaults to `IqnSample::Uniform32`.
/// NOTE(review): presumably how percent points are sampled when choosing
/// actions; confirm.
pub(super) sample_percents_act: IqnSample,
/// Optional device setting; `None` by default.
/// NOTE(review): how `None` is resolved is decided by the consumer of this
/// config and is not visible here.
pub device: Option<Device>,
/// Zero-sized marker carrying the type parameters `F` and `M`.
phantom: PhantomData<(F, M)>,
}
impl<F, M> Default for IqnConfig<F, M>
where
F: SubModel,
M: SubModel,
F::Config: DeserializeOwned + Serialize + Clone,
M::Config: DeserializeOwned + Serialize + Clone + OutDim,
{
fn default() -> Self {
Self {
model_config: Default::default(),
soft_update_interval: 1,
n_updates_per_opt: 1,
min_transitions_warmup: 1,
batch_size: 1,
discount_factor: 0.99,
tau: 0.005,
sample_percents_pred: IqnSample::Uniform64,
sample_percents_tgt: IqnSample::Uniform64,
sample_percents_act: IqnSample::Uniform32, train: false,
explorer: IqnExplorer::EpsilonGreedy(EpsilonGreedy::default()),
device: None,
phantom: PhantomData,
}
}
}
impl<F, M> IqnConfig<F, M>
where
    F: SubModel,
    M: SubModel,
    F::Config: DeserializeOwned + Serialize + Clone,
    M::Config: DeserializeOwned + Serialize + Clone + OutDim,
{
    /// Replaces the model configuration and returns the updated config.
    pub fn model_config(self, model_config: IqnModelConfig<F::Config, M::Config>) -> Self {
        Self {
            model_config,
            ..self
        }
    }

    /// Sets `soft_update_interval` and returns the updated config.
    pub fn soft_update_interval(self, v: usize) -> Self {
        Self {
            soft_update_interval: v,
            ..self
        }
    }

    /// Sets `n_updates_per_opt` and returns the updated config.
    pub fn n_updates_per_opt(self, v: usize) -> Self {
        Self {
            n_updates_per_opt: v,
            ..self
        }
    }

    /// Sets `min_transitions_warmup` and returns the updated config.
    pub fn min_transitions_warmup(self, v: usize) -> Self {
        Self {
            min_transitions_warmup: v,
            ..self
        }
    }

    /// Sets `batch_size` and returns the updated config.
    pub fn batch_size(self, v: usize) -> Self {
        Self {
            batch_size: v,
            ..self
        }
    }

    /// Sets `discount_factor` and returns the updated config.
    pub fn discount_factor(self, v: f64) -> Self {
        Self {
            discount_factor: v,
            ..self
        }
    }

    /// Sets `tau` and returns the updated config.
    pub fn tau(self, v: f64) -> Self {
        Self { tau: v, ..self }
    }

    /// Sets the explorer and returns the updated config.
    pub fn explorer(self, v: IqnExplorer) -> Self {
        Self {
            explorer: v,
            ..self
        }
    }

    /// Applies `out_dim` to the model configuration and returns the updated config.
    ///
    /// The new model configuration is whatever `IqnModelConfig::out_dim`
    /// yields when called on a clone of the current one.
    pub fn out_dim(self, out_dim: i64) -> Self {
        let model_config = self.model_config.clone().out_dim(out_dim);
        Self {
            model_config,
            ..self
        }
    }

    /// Sets `sample_percents_pred` and returns the updated config.
    pub fn sample_percent_pred(self, v: IqnSample) -> Self {
        Self {
            sample_percents_pred: v,
            ..self
        }
    }

    /// Sets `sample_percents_tgt` and returns the updated config.
    pub fn sample_percent_tgt(self, v: IqnSample) -> Self {
        Self {
            sample_percents_tgt: v,
            ..self
        }
    }

    /// Sets `sample_percents_act` and returns the updated config.
    pub fn sample_percent_act(self, v: IqnSample) -> Self {
        Self {
            sample_percents_act: v,
            ..self
        }
    }

    /// Stores `device`, converted into this crate's `Device`, as `Some(..)`.
    pub fn device(self, device: tch::Device) -> Self {
        Self {
            device: Some(device.into()),
            ..self
        }
    }

    /// Reads a configuration from the YAML file at `path`.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be opened or its contents cannot be
    /// deserialized into `Self`.
    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
        let reader = BufReader::new(File::open(path)?);
        let config: Self = serde_yaml::from_reader(reader)?;
        Ok(config)
    }

    /// Writes this configuration as YAML to the file at `path`.
    ///
    /// The file is created (or truncated) before serialization happens.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be created, the configuration cannot be
    /// serialized, or the write itself fails.
    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
        let mut out = File::create(path)?;
        let yaml = serde_yaml::to_string(&self)?;
        out.write_all(yaml.as_bytes())?;
        Ok(())
    }
}
impl<F, M> Clone for IqnConfig<F, M>
where
    F: SubModel,
    M: SubModel,
    F::Config: DeserializeOwned + Serialize + Clone,
    M::Config: DeserializeOwned + Serialize + Clone + OutDim,
{
    /// Produces a field-by-field copy of the configuration.
    ///
    /// Written by hand so that only `F::Config` and `M::Config` need to be
    /// `Clone`; no `Clone` bound is placed on `F` or `M` themselves.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field to the struct makes this
        // pattern fail to compile until the new field is handled here.
        let Self {
            model_config,
            soft_update_interval,
            n_updates_per_opt,
            min_transitions_warmup,
            batch_size,
            discount_factor,
            tau,
            train,
            explorer,
            sample_percents_pred,
            sample_percents_tgt,
            sample_percents_act,
            device,
            phantom: _,
        } = self;

        Self {
            model_config: model_config.clone(),
            soft_update_interval: *soft_update_interval,
            n_updates_per_opt: *n_updates_per_opt,
            min_transitions_warmup: *min_transitions_warmup,
            batch_size: *batch_size,
            discount_factor: *discount_factor,
            tau: *tau,
            sample_percents_pred: sample_percents_pred.clone(),
            sample_percents_tgt: sample_percents_tgt.clone(),
            sample_percents_act: sample_percents_act.clone(),
            train: *train,
            explorer: explorer.clone(),
            device: device.clone(),
            phantom: PhantomData,
        }
    }
}