use super::{
explorer::{DQNExplorer, Softmax},
DQNModelConfig,
};
use crate::{model::SubModel, util::OutDim, Device};
use anyhow::Result;
use log::info;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{
default::Default,
fs::File,
io::{BufReader, Write},
marker::PhantomData,
path::Path,
};
use tch::Tensor;
/// Configuration of a DQN agent, generic over the Q-network sub-model `Q`.
///
/// Only the sub-model's configuration type (`Q::Config`) is stored, never a
/// value of `Q` itself. The struct derives serde's `Serialize`/`Deserialize`;
/// the `load` and `save` methods of this type read and write it as YAML.
///
/// Most fields are `pub(super)` and are set through the builder-style methods
/// of the same name.
// The type name keeps the acronym "DQN" in upper case, so the corresponding
// clippy naming lint is silenced for this item.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Deserialize, Serialize, PartialEq)]
pub struct DQNConfig<Q>
where
Q: SubModel<Output = Tensor>,
Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
{
/// Configuration of the DQN model, wrapping the Q-network configuration
/// `Q::Config`.
pub(super) model_config: DQNModelConfig<Q::Config>,
/// Soft-update interval (`1` in the default configuration).
/// NOTE(review): presumably counted in optimization steps between updates
/// of a target network — confirm against the agent implementation.
pub(super) soft_update_interval: usize,
/// Number of updates per optimization (`1` in the default configuration).
/// NOTE(review): exact semantics are defined by the agent, not visible here.
pub(super) n_updates_per_opt: usize,
/// Warm-up threshold (`1` in the default configuration).
/// NOTE(review): presumably the minimum number of collected transitions
/// before optimization starts — confirm against the agent implementation.
pub(super) min_transitions_warmup: usize,
/// Batch size (`1` in the default configuration).
pub(super) batch_size: usize,
/// Discount factor (`0.99` in the default configuration).
pub(super) discount_factor: f64,
/// Coefficient `tau` (`0.005` in the default configuration).
/// NOTE(review): presumably the soft-update mixing rate — confirm.
pub(super) tau: f64,
/// Training flag (`false` in the default configuration).
pub(super) train: bool,
/// Exploration strategy (`DQNExplorer::Softmax` in the default
/// configuration).
pub(super) explorer: DQNExplorer,
/// Optional reward-clipping parameter. Because of `#[serde(default)]` it
/// deserializes to `None` when the key is absent.
/// NOTE(review): how the value is applied is not visible here.
#[serde(default)]
pub(super) clip_reward: Option<f64>,
/// Double-DQN switch. Because of `#[serde(default)]` it deserializes to
/// `false` when the key is absent.
#[serde(default)]
pub(super) double_dqn: bool,
/// Optional pair of values for clipping the TD error.
/// NOTE(review): presumably `(min, max)` — confirm against the agent.
pub(super) clip_td_err: Option<(f64, f64)>,
/// Optional device; `None` in the default configuration. This is the only
/// field that is fully public.
pub device: Option<Device>,
/// Zero-sized marker that ties the configuration to `Q` without storing a
/// value of that type.
phantom: PhantomData<Q>,
}
/// Hand-written `Clone` for [`DQNConfig`].
///
/// A derived implementation would place a `Clone` bound on `Q` itself; this
/// one only relies on the bounds listed below (in particular
/// `Q::Config: Clone`), since no value of `Q` is ever stored.
impl<Q> Clone for DQNConfig<Q>
where
    Q: SubModel<Output = Tensor>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
{
    /// Returns a field-by-field copy of the configuration.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: a field added to the struct later must be
        // handled here explicitly, otherwise this pattern stops compiling.
        let Self {
            model_config,
            soft_update_interval,
            n_updates_per_opt,
            min_transitions_warmup,
            batch_size,
            discount_factor,
            tau,
            train,
            explorer,
            clip_reward,
            double_dqn,
            clip_td_err,
            device,
            phantom: _,
        } = self;

        Self {
            model_config: model_config.clone(),
            soft_update_interval: *soft_update_interval,
            n_updates_per_opt: *n_updates_per_opt,
            min_transitions_warmup: *min_transitions_warmup,
            batch_size: *batch_size,
            discount_factor: *discount_factor,
            tau: *tau,
            train: *train,
            explorer: explorer.clone(),
            clip_reward: *clip_reward,
            double_dqn: *double_dqn,
            clip_td_err: *clip_td_err,
            device: device.clone(),
            phantom: PhantomData,
        }
    }
}
impl<Q> Default for DQNConfig<Q>
where
Q: SubModel<Output = Tensor>,
Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
{
fn default() -> Self {
Self {
model_config: Default::default(),
soft_update_interval: 1,
n_updates_per_opt: 1,
min_transitions_warmup: 1,
batch_size: 1,
discount_factor: 0.99,
tau: 0.005,
train: false,
explorer: DQNExplorer::Softmax(Softmax::new()),
clip_reward: None,
double_dqn: false,
clip_td_err: None,
device: None,
phantom: PhantomData,
}
}
}
/// Builder-style setters and YAML persistence for [`DQNConfig`].
///
/// Every setter consumes the configuration and returns the updated value, so
/// calls can be chained starting from `DQNConfig::default()`.
impl<Q> DQNConfig<Q>
where
    Q: SubModel<Output = Tensor>,
    Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
{
    /// Sets the soft-update interval.
    pub fn soft_update_interval(mut self, v: usize) -> Self {
        self.soft_update_interval = v;
        self
    }

    /// Sets the number of updates per optimization.
    pub fn n_updates_per_opt(mut self, v: usize) -> Self {
        self.n_updates_per_opt = v;
        self
    }

    /// Sets the warm-up threshold (`min_transitions_warmup`).
    pub fn min_transitions_warmup(mut self, v: usize) -> Self {
        self.min_transitions_warmup = v;
        self
    }

    /// Sets the batch size.
    pub fn batch_size(mut self, v: usize) -> Self {
        self.batch_size = v;
        self
    }

    /// Sets the discount factor.
    pub fn discount_factor(mut self, v: f64) -> Self {
        self.discount_factor = v;
        self
    }

    /// Sets the coefficient `tau`.
    pub fn tau(mut self, v: f64) -> Self {
        self.tau = v;
        self
    }

    /// Sets the exploration strategy.
    pub fn explorer(mut self, v: DQNExplorer) -> Self {
        self.explorer = v;
        self
    }

    /// Replaces the model configuration.
    pub fn model_config(mut self, model_config: DQNModelConfig<Q::Config>) -> Self {
        self.model_config = model_config;
        self
    }

    /// Sets the output dimension by delegating to the model configuration's
    /// own `out_dim` setter.
    pub fn out_dim(mut self, out_dim: i64) -> Self {
        // `self` is owned here, so the setter can be applied to the field
        // directly; no intermediate clone of the model configuration is needed.
        self.model_config = self.model_config.out_dim(out_dim);
        self
    }

    /// Sets the optional reward-clipping parameter (`None` disables it).
    pub fn clip_reward(mut self, clip_reward: Option<f64>) -> Self {
        self.clip_reward = clip_reward;
        self
    }

    /// Enables or disables double DQN.
    pub fn double_dqn(mut self, double_dqn: bool) -> Self {
        self.double_dqn = double_dqn;
        self
    }

    /// Sets the optional pair of values used to clip the TD error
    /// (`None` disables it).
    pub fn clip_td_err(mut self, clip_td_err: Option<(f64, f64)>) -> Self {
        self.clip_td_err = clip_td_err;
        self
    }

    /// Sets the device, converting from `tch::Device` into the crate's own
    /// `Device` type.
    pub fn device(mut self, device: tch::Device) -> Self {
        self.device = Some(device.into());
        self
    }

    /// Reads a configuration from the YAML file at `path`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or if its content cannot
    /// be deserialized into a [`DQNConfig`].
    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
        let path = path.as_ref();
        let reader = BufReader::new(File::open(path)?);
        let config: Self = serde_yaml::from_reader(reader)?;
        // `display()` never panics, unlike `to_str().unwrap()`, which aborts
        // on paths that are not valid UTF-8.
        info!("Load config of DQN agent from {}", path.display());
        Ok(config)
    }

    /// Writes the configuration as YAML to the file at `path`, creating the
    /// file or replacing its content.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration cannot be serialized or if the
    /// file cannot be created or written.
    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
        let path = path.as_ref();
        // Serialize before creating the file so that a serialization failure
        // cannot leave a truncated file behind.
        let yaml = serde_yaml::to_string(self)?;
        let mut file = File::create(path)?;
        file.write_all(yaml.as_bytes())?;
        info!("Save config of DQN agent into {}", path.display());
        Ok(())
    }
}