border_tch_agent/sac/
config.rs

1//! Configuration of SAC agent.
2use super::{ActorConfig, CriticConfig};
3use crate::{
4    model::{SubModel, SubModel2},
5    sac::ent_coef::EntCoefMode,
6    util::CriticLoss,
7    util::OutDim,
8    Device,
9};
10use anyhow::Result;
11use log::info;
12use serde::{de::DeserializeOwned, Deserialize, Serialize};
13use std::{
14    fmt::Debug,
15    fs::File,
16    io::{BufReader, Write},
17    path::Path,
18};
19use tch::Tensor;
20
21/// Configuration of [`Sac`](super::Sac).
22#[derive(Debug, Deserialize, Serialize, PartialEq)]
23pub struct SacConfig<Q, P>
24where
25    Q: SubModel2<Output = Tensor>,
26    Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
27    P: SubModel<Output = (Tensor, Tensor)>,
28    P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
29{
30    pub actor_config: ActorConfig<P::Config>,
31    pub critic_config: CriticConfig<Q::Config>,
32    pub gamma: f64,
33    pub tau: f64,
34    pub ent_coef_mode: EntCoefMode,
35    pub epsilon: f64,
36    pub min_lstd: f64,
37    pub max_lstd: f64,
38    pub n_updates_per_opt: usize,
39    pub batch_size: usize,
40    pub train: bool,
41    pub critic_loss: CriticLoss,
42    pub reward_scale: f32,
43    pub n_critics: usize,
44    pub seed: Option<i64>,
45    pub device: Option<Device>,
46    // expr_sampling: ExperienceSampling,
47}
48
49impl<Q, P> Clone for SacConfig<Q, P>
50where
51    Q: SubModel2<Output = Tensor>,
52    Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
53    P: SubModel<Output = (Tensor, Tensor)>,
54    P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
55{
56    fn clone(&self) -> Self {
57        Self {
58            actor_config: self.actor_config.clone(),
59            critic_config: self.critic_config.clone(),
60            gamma: self.gamma.clone(),
61            tau: self.tau.clone(),
62            ent_coef_mode: self.ent_coef_mode.clone(),
63            epsilon: self.epsilon.clone(),
64            min_lstd: self.min_lstd.clone(),
65            max_lstd: self.max_lstd.clone(),
66            n_updates_per_opt: self.n_updates_per_opt.clone(),
67            batch_size: self.batch_size.clone(),
68            train: self.train.clone(),
69            critic_loss: self.critic_loss.clone(),
70            reward_scale: self.reward_scale.clone(),
71            n_critics: self.n_critics.clone(),
72            seed: self.seed.clone(),
73            device: self.device.clone(),
74        }
75    }
76}
77
78impl<Q, P> Default for SacConfig<Q, P>
79where
80    Q: SubModel2<Output = Tensor>,
81    Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
82    P: SubModel<Output = (Tensor, Tensor)>,
83    P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
84{
85    fn default() -> Self {
86        Self {
87            actor_config: Default::default(),
88            critic_config: Default::default(),
89            gamma: 0.99,
90            tau: 0.005,
91            ent_coef_mode: EntCoefMode::Fix(1.0),
92            epsilon: 1e-4,
93            min_lstd: -20.0,
94            max_lstd: 2.0,
95            n_updates_per_opt: 1,
96            batch_size: 1,
97            train: false,
98            critic_loss: CriticLoss::Mse,
99            reward_scale: 1.0,
100            n_critics: 1,
101            seed: None,
102            device: None,
103            // expr_sampling: ExperienceSampling::Uniform,
104        }
105    }
106}
107
108impl<Q, P> SacConfig<Q, P>
109where
110    Q: SubModel2<Output = Tensor>,
111    Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
112    P: SubModel<Output = (Tensor, Tensor)>,
113    P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
114{
115    /// Sets the numper of parameter update steps per optimization step.
116    pub fn n_updates_per_opt(mut self, v: usize) -> Self {
117        self.n_updates_per_opt = v;
118        self
119    }
120
121    /// Batch size.
122    pub fn batch_size(mut self, v: usize) -> Self {
123        self.batch_size = v;
124        self
125    }
126
127    /// Discount factor.
128    pub fn discount_factor(mut self, v: f64) -> Self {
129        self.gamma = v;
130        self
131    }
132
133    /// Sets soft update coefficient.
134    pub fn tau(mut self, v: f64) -> Self {
135        self.tau = v;
136        self
137    }
138
139    /// SAC-alpha.
140    pub fn ent_coef_mode(mut self, v: EntCoefMode) -> Self {
141        self.ent_coef_mode = v;
142        self
143    }
144
145    /// Reward scale.
146    ///
147    /// It works for obtaining target values, not the values in logs.
148    pub fn reward_scale(mut self, v: f32) -> Self {
149        self.reward_scale = v;
150        self
151    }
152
153    /// Critic loss.
154    pub fn critic_loss(mut self, v: CriticLoss) -> Self {
155        self.critic_loss = v;
156        self
157    }
158
159    /// Configuration of actor.
160    pub fn actor_config(mut self, actor_config: ActorConfig<P::Config>) -> Self {
161        self.actor_config = actor_config;
162        self
163    }
164
165    /// Configuration of critic.
166    pub fn critic_config(mut self, critic_config: CriticConfig<Q::Config>) -> Self {
167        self.critic_config = critic_config;
168        self
169    }
170
171    /// The number of critics.
172    pub fn n_critics(mut self, n_critics: usize) -> Self {
173        self.n_critics = n_critics;
174        self
175    }
176
177    /// Random seed.
178    pub fn seed(mut self, seed: i64) -> Self {
179        self.seed = Some(seed);
180        self
181    }
182
183    /// Device.
184    pub fn device(mut self, device: tch::Device) -> Self {
185        self.device = Some(device.into());
186        self
187    }
188
189    /// Constructs [SacConfig] from YAML file.
190    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
191        let path_ = path.as_ref().to_owned();
192        let file = File::open(path)?;
193        let rdr = BufReader::new(file);
194        let b = serde_yaml::from_reader(rdr)?;
195        info!("Load config of SAC agent from {}", path_.to_str().unwrap());
196        Ok(b)
197    }
198
199    /// Saves [SacConfig].
200    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
201        let path_ = path.as_ref().to_owned();
202        let mut file = File::create(path)?;
203        file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
204        info!("Save config of SAC agent into {}", path_.to_str().unwrap());
205        Ok(())
206    }
207}