border_candle_agent/awac/
config.rs1use crate::{
3 model::{SubModel1, SubModel2},
4 util::{actor::GaussianActorConfig, critic::MultiCriticConfig, CriticLoss, OutDim},
5 Device,
6};
7use anyhow::Result;
8use candle_core::Tensor;
9use log::info;
10use serde::{de::DeserializeOwned, Deserialize, Serialize};
11use std::{
12 fmt::Debug,
13 fs::File,
14 io::{BufReader, Write},
15 path::Path,
16};
17
18#[allow(clippy::upper_case_acronyms)]
20#[derive(Debug, Deserialize, Serialize, PartialEq)]
21pub struct AwacConfig<Q, P>
22where
23 Q: SubModel2<Output = Tensor>,
24 Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
25 P: SubModel1<Output = (Tensor, Tensor)>,
26 P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
27{
28 pub actor_config: GaussianActorConfig<P::Config>,
30
31 pub critic_config: MultiCriticConfig<Q::Config>,
33
34 pub gamma: f64,
36
37 pub inv_lambda: f64,
39
40 pub tau: f64,
47
48 pub min_lstd: f64,
50
51 pub max_lstd: f64,
53
54 pub n_updates_per_opt: usize,
56
57 pub batch_size: usize,
59
60 pub critic_loss: CriticLoss,
64
65 pub reward_scale: f32,
67
68 pub n_critics: usize,
70
71 pub exp_adv_max: f64,
73
74 pub seed: Option<i64>,
76
77 pub device: Option<Device>,
79
80 pub adv_softmax: bool,
82}
83
84impl<Q, P> Clone for AwacConfig<Q, P>
85where
86 Q: SubModel2<Output = Tensor>,
87 Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
88 P: SubModel1<Output = (Tensor, Tensor)>,
89 P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
90{
91 fn clone(&self) -> Self {
92 Self {
93 actor_config: self.actor_config.clone(),
94 critic_config: self.critic_config.clone(),
95 gamma: self.gamma.clone(),
96 inv_lambda: self.inv_lambda.clone(),
97 tau: self.tau.clone(),
98 min_lstd: self.min_lstd.clone(),
99 max_lstd: self.max_lstd.clone(),
100 n_updates_per_opt: self.n_updates_per_opt.clone(),
101 batch_size: self.batch_size.clone(),
102 critic_loss: self.critic_loss.clone(),
103 reward_scale: self.reward_scale.clone(),
104 n_critics: self.n_critics.clone(),
105 exp_adv_max: self.exp_adv_max,
106 seed: self.seed.clone(),
107 device: self.device.clone(),
108 adv_softmax: self.adv_softmax,
109 }
110 }
111}
112
113impl<Q, P> Default for AwacConfig<Q, P>
114where
115 Q: SubModel2<Output = Tensor>,
116 Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
117 P: SubModel1<Output = (Tensor, Tensor)>,
118 P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
119{
120 fn default() -> Self {
121 Self {
122 actor_config: Default::default(),
123 critic_config: Default::default(),
124 gamma: 0.99,
125 inv_lambda: 10.0,
126 tau: 0.005,
127 min_lstd: -20.0,
128 max_lstd: 2.0,
129 n_updates_per_opt: 1,
130 batch_size: 1,
131 critic_loss: CriticLoss::Mse,
132 reward_scale: 1.0,
133 n_critics: 2,
134 exp_adv_max: 100.0,
135 seed: None,
136 device: None,
137 adv_softmax: false,
138 }
139 }
140}
141
142impl<Q, P> AwacConfig<Q, P>
143where
144 Q: SubModel2<Output = Tensor>,
145 Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
146 P: SubModel1<Output = (Tensor, Tensor)>,
147 P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
148{
149 pub fn lambda(mut self, v: f64) -> Self {
151 self.inv_lambda = 1.0 / v;
152 self
153 }
154
155 pub fn n_updates_per_opt(mut self, v: usize) -> Self {
157 self.n_updates_per_opt = v;
158 self
159 }
160
161 pub fn batch_size(mut self, v: usize) -> Self {
163 self.batch_size = v;
164 self
165 }
166
167 pub fn discount_factor(mut self, v: f64) -> Self {
169 self.gamma = v;
170 self
171 }
172
173 pub fn tau(mut self, v: f64) -> Self {
175 self.tau = v;
176 self
177 }
178
179 pub fn reward_scale(mut self, v: f32) -> Self {
183 self.reward_scale = v;
184 self
185 }
186
187 pub fn critic_loss(mut self, v: CriticLoss) -> Self {
189 self.critic_loss = v;
190 self
191 }
192
193 pub fn actor_config(mut self, actor_config: GaussianActorConfig<P::Config>) -> Self {
195 self.actor_config = actor_config;
196 self
197 }
198
199 pub fn critic_config(mut self, critic_config: MultiCriticConfig<Q::Config>) -> Self {
201 self.critic_config = critic_config;
202 self
203 }
204
205 pub fn n_critics(mut self, n_critics: usize) -> Self {
207 self.n_critics = n_critics;
208 self
209 }
210
211 pub fn seed(mut self, seed: i64) -> Self {
213 self.seed = Some(seed);
214 self
215 }
216
217 pub fn device(mut self, device: candle_core::Device) -> Self {
219 self.device = Some(device.into());
220 self
221 }
222
223 pub fn adv_softmax(mut self, b: bool) -> Self {
225 self.adv_softmax = b;
226 self
227 }
228
229 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
231 let path_ = path.as_ref().to_owned();
232 let file = File::open(path)?;
233 let rdr = BufReader::new(file);
234 let b = serde_yaml::from_reader(rdr)?;
235 info!("Load config of AWAC agent from {}", path_.to_str().unwrap());
236 Ok(b)
237 }
238
239 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
241 let path_ = path.as_ref().to_owned();
242 let mut file = File::create(path)?;
243 file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
244 info!("Save config of AWAC agent into {}", path_.to_str().unwrap());
245 Ok(())
246 }
247}