/// Distillation variant selector.
///
/// `Mf` is the `Default` variant and also the fallback produced by
/// [`DistillationType::from_str`] for any unrecognized name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DistillationType {
    /// Default variant; returned for every name other than `"lsd"` / `"psd"`.
    #[default]
    Mf,
    /// Selected by the exact name `"lsd"`.
    Lsd,
    /// Selected by the exact name `"psd"`.
    Psd,
}

impl DistillationType {
    /// Maps a name onto a variant without ever failing.
    ///
    /// The comparison is exact: it is case-sensitive and performs no
    /// trimming. `"lsd"` yields [`DistillationType::Lsd`], `"psd"` yields
    /// [`DistillationType::Psd`], and every other input (including the empty
    /// string) yields [`DistillationType::Mf`].
    ///
    /// This is an inherent, infallible method; it is not an implementation of
    /// `std::str::FromStr`.
    pub fn from_str(s: &str) -> Self {
        if s == "lsd" {
            Self::Lsd
        } else if s == "psd" {
            Self::Psd
        } else {
            // Unknown names are deliberately not an error: fall back to `Mf`.
            Self::Mf
        }
    }
}
/// Bundle of dimensions and hyperparameters describing an RL setup.
///
/// Only a handful of fields are interpreted by the methods defined on this
/// type (`state_dim`, `action_dim`, `batch`, and the three `fmq_*` values used
/// by [`RlSpec::fmq_eta`]); the remaining fields are plain data whose meaning
/// is defined by whatever code consumes the spec.
#[derive(Debug, Clone)]
pub struct RlSpec {
    /// State size; counted in both [`RlSpec::actor_in_dim`] and
    /// [`RlSpec::critic_in_dim`].
    pub state_dim: usize,
    /// Action size; counted in both [`RlSpec::actor_in_dim`] and
    /// [`RlSpec::critic_in_dim`].
    pub action_dim: usize,
    /// Batch size; [`RlSpec::with_batch`] returns a copy differing only here.
    pub batch: usize,
    /// Hidden sizes. NOTE(review): presumably per-layer widths of the
    /// networks; confirm against the model-building code.
    pub hidden: Vec<usize>,
    /// NOTE(review): presumably the reward discount factor; not read here.
    pub gamma: f32,
    /// `eta` coefficient; its use is not visible in this section.
    pub eta: f32,
    /// `eta_beta` coefficient; its use is not visible in this section.
    pub eta_beta: f32,
    /// `eta_kappa` coefficient; its use is not visible in this section.
    pub eta_kappa: f32,
    /// NOTE(review): presumably the actor learning rate; confirm.
    pub actor_lr: f32,
    /// NOTE(review): presumably the critic learning rate; confirm.
    pub critic_lr: f32,
    /// NOTE(review): presumably a soft target-update rate; confirm.
    pub tau: f32,
    /// Appears in the denominator of the derived value returned by
    /// [`RlSpec::fmq_eta`]: `fmq_sigma_sq / (2 * fmq_alpha)`.
    pub fmq_alpha: f32,
    /// Numerator of the derived value returned by [`RlSpec::fmq_eta`].
    pub fmq_sigma_sq: f32,
    /// When `Some`, [`RlSpec::fmq_eta`] returns this value verbatim.
    pub fmq_eta_override: Option<f32>,
    /// Flag consumed outside this section.
    pub fmq_adaptive_eta: bool,
    /// Coefficient consumed outside this section.
    pub fmq_beta: f32,
    /// Flag consumed outside this section.
    pub fmq_grad_at_online: bool,
    /// Flag consumed outside this section.
    pub fmq_normalize_grad: bool,
    /// Sample count consumed outside this section.
    pub actor_num_samples: usize,
    /// NOTE(review): presumably a bound applied to actions; confirm.
    pub action_clip: f32,
    /// Step count consumed outside this section.
    pub flow_map_warmup_steps: usize,
    /// Step index consumed outside this section.
    pub flow_map_anneal_end_step: usize,
    /// Step count consumed outside this section.
    pub esd_warmup_steps: usize,
    /// Step index consumed outside this section.
    pub esd_anneal_end_step: usize,
    /// Which distillation variant is selected.
    pub distillation_type: DistillationType,
    /// Weight consumed outside this section.
    pub esd_weight: f32,
    /// Weight consumed outside this section.
    pub diag_weight: f32,
    /// Coefficient consumed outside this section.
    pub qgbs_eta: f32,
}

impl RlSpec {
    /// Builds a small preset spec with the given `batch`.
    ///
    /// Every other field is a fixed constant: a 4-wide state, a 2-wide
    /// action, hidden sizes `[64, 64]`, no eta override, and
    /// [`DistillationType::Mf`] with both `esd_weight` and `diag_weight` at
    /// zero.
    pub fn toy(batch: usize) -> Self {
        let hidden = vec![64, 64];
        Self {
            // Shapes.
            state_dim: 4,
            action_dim: 2,
            batch,
            hidden,

            // General scalar hyperparameters.
            gamma: 0.99,
            eta: 0.5,
            eta_beta: 0.3,
            eta_kappa: 1e-4,
            actor_lr: 3e-4,
            critic_lr: 3e-4,
            tau: 0.005,

            // `fmq_*` settings; with alpha = sigma_sq = 1 and no override,
            // `fmq_eta()` evaluates to 0.5.
            fmq_alpha: 1.0,
            fmq_sigma_sq: 1.0,
            fmq_eta_override: None,
            fmq_adaptive_eta: false,
            fmq_beta: 0.3,
            fmq_grad_at_online: false,
            fmq_normalize_grad: true,

            actor_num_samples: 32,
            action_clip: 1.0,

            // Warmup / anneal step settings.
            flow_map_warmup_steps: 5,
            flow_map_anneal_end_step: 50,
            esd_warmup_steps: 0,
            esd_anneal_end_step: 50,

            // Distillation settings.
            distillation_type: DistillationType::Mf,
            esd_weight: 0.0,
            diag_weight: 0.0,
            qgbs_eta: 0.3,
        }
    }

    /// Returns the effective `fmq` eta.
    ///
    /// An explicit `fmq_eta_override` wins; otherwise the value is derived as
    /// `fmq_sigma_sq / (2 * fmq_alpha)`. No guard is applied to `fmq_alpha`,
    /// so a zero alpha yields a non-finite result per IEEE-754 division.
    pub fn fmq_eta(&self) -> f32 {
        self.fmq_eta_override
            .unwrap_or_else(|| self.fmq_sigma_sq / (2.0 * self.fmq_alpha))
    }

    /// Returns a clone of this spec whose `batch` is replaced by `batch`.
    ///
    /// `self` is left untouched; all other fields are copied as-is.
    pub fn with_batch(&self, batch: usize) -> Self {
        Self {
            batch,
            ..self.clone()
        }
    }

    /// Actor input width: `state_dim + action_dim + 2`.
    ///
    /// NOTE(review): the two extra inputs are presumably scalar conditioning
    /// values fed alongside state and action; confirm against the actor
    /// network definition.
    pub fn actor_in_dim(&self) -> usize {
        self.critic_in_dim() + 2
    }

    /// Critic input width: `state_dim + action_dim`.
    pub fn critic_in_dim(&self) -> usize {
        let Self {
            state_dim,
            action_dim,
            ..
        } = self;
        state_dim + action_dim
    }
}