1use super::{ActorConfig, CriticConfig};
3use crate::{
4 model::{SubModel, SubModel2},
5 sac::ent_coef::EntCoefMode,
6 util::CriticLoss,
7 util::OutDim,
8 Device,
9};
10use anyhow::Result;
11use log::info;
12use serde::{de::DeserializeOwned, Deserialize, Serialize};
13use std::{
14 fmt::Debug,
15 fs::File,
16 io::{BufReader, Write},
17 path::Path,
18};
19use tch::Tensor;
20
21#[derive(Debug, Deserialize, Serialize, PartialEq)]
23pub struct SacConfig<Q, P>
24where
25 Q: SubModel2<Output = Tensor>,
26 Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
27 P: SubModel<Output = (Tensor, Tensor)>,
28 P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
29{
30 pub actor_config: ActorConfig<P::Config>,
31 pub critic_config: CriticConfig<Q::Config>,
32 pub gamma: f64,
33 pub tau: f64,
34 pub ent_coef_mode: EntCoefMode,
35 pub epsilon: f64,
36 pub min_lstd: f64,
37 pub max_lstd: f64,
38 pub n_updates_per_opt: usize,
39 pub batch_size: usize,
40 pub train: bool,
41 pub critic_loss: CriticLoss,
42 pub reward_scale: f32,
43 pub n_critics: usize,
44 pub seed: Option<i64>,
45 pub device: Option<Device>,
46 }
48
49impl<Q, P> Clone for SacConfig<Q, P>
50where
51 Q: SubModel2<Output = Tensor>,
52 Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
53 P: SubModel<Output = (Tensor, Tensor)>,
54 P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
55{
56 fn clone(&self) -> Self {
57 Self {
58 actor_config: self.actor_config.clone(),
59 critic_config: self.critic_config.clone(),
60 gamma: self.gamma.clone(),
61 tau: self.tau.clone(),
62 ent_coef_mode: self.ent_coef_mode.clone(),
63 epsilon: self.epsilon.clone(),
64 min_lstd: self.min_lstd.clone(),
65 max_lstd: self.max_lstd.clone(),
66 n_updates_per_opt: self.n_updates_per_opt.clone(),
67 batch_size: self.batch_size.clone(),
68 train: self.train.clone(),
69 critic_loss: self.critic_loss.clone(),
70 reward_scale: self.reward_scale.clone(),
71 n_critics: self.n_critics.clone(),
72 seed: self.seed.clone(),
73 device: self.device.clone(),
74 }
75 }
76}
77
78impl<Q, P> Default for SacConfig<Q, P>
79where
80 Q: SubModel2<Output = Tensor>,
81 Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
82 P: SubModel<Output = (Tensor, Tensor)>,
83 P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
84{
85 fn default() -> Self {
86 Self {
87 actor_config: Default::default(),
88 critic_config: Default::default(),
89 gamma: 0.99,
90 tau: 0.005,
91 ent_coef_mode: EntCoefMode::Fix(1.0),
92 epsilon: 1e-4,
93 min_lstd: -20.0,
94 max_lstd: 2.0,
95 n_updates_per_opt: 1,
96 batch_size: 1,
97 train: false,
98 critic_loss: CriticLoss::Mse,
99 reward_scale: 1.0,
100 n_critics: 1,
101 seed: None,
102 device: None,
103 }
105 }
106}
107
108impl<Q, P> SacConfig<Q, P>
109where
110 Q: SubModel2<Output = Tensor>,
111 Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
112 P: SubModel<Output = (Tensor, Tensor)>,
113 P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
114{
115 pub fn n_updates_per_opt(mut self, v: usize) -> Self {
117 self.n_updates_per_opt = v;
118 self
119 }
120
121 pub fn batch_size(mut self, v: usize) -> Self {
123 self.batch_size = v;
124 self
125 }
126
127 pub fn discount_factor(mut self, v: f64) -> Self {
129 self.gamma = v;
130 self
131 }
132
133 pub fn tau(mut self, v: f64) -> Self {
135 self.tau = v;
136 self
137 }
138
139 pub fn ent_coef_mode(mut self, v: EntCoefMode) -> Self {
141 self.ent_coef_mode = v;
142 self
143 }
144
145 pub fn reward_scale(mut self, v: f32) -> Self {
149 self.reward_scale = v;
150 self
151 }
152
153 pub fn critic_loss(mut self, v: CriticLoss) -> Self {
155 self.critic_loss = v;
156 self
157 }
158
159 pub fn actor_config(mut self, actor_config: ActorConfig<P::Config>) -> Self {
161 self.actor_config = actor_config;
162 self
163 }
164
165 pub fn critic_config(mut self, critic_config: CriticConfig<Q::Config>) -> Self {
167 self.critic_config = critic_config;
168 self
169 }
170
171 pub fn n_critics(mut self, n_critics: usize) -> Self {
173 self.n_critics = n_critics;
174 self
175 }
176
177 pub fn seed(mut self, seed: i64) -> Self {
179 self.seed = Some(seed);
180 self
181 }
182
183 pub fn device(mut self, device: tch::Device) -> Self {
185 self.device = Some(device.into());
186 self
187 }
188
189 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
191 let path_ = path.as_ref().to_owned();
192 let file = File::open(path)?;
193 let rdr = BufReader::new(file);
194 let b = serde_yaml::from_reader(rdr)?;
195 info!("Load config of SAC agent from {}", path_.to_str().unwrap());
196 Ok(b)
197 }
198
199 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
201 let path_ = path.as_ref().to_owned();
202 let mut file = File::create(path)?;
203 file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
204 info!("Save config of SAC agent into {}", path_.to_str().unwrap());
205 Ok(())
206 }
207}