border_candle_agent/awac/
config.rs

1//! Configuration of AWAC agent.
2use crate::{
3    model::{SubModel1, SubModel2},
4    util::{actor::GaussianActorConfig, critic::MultiCriticConfig, CriticLoss, OutDim},
5    Device,
6};
7use anyhow::Result;
8use candle_core::Tensor;
9use log::info;
10use serde::{de::DeserializeOwned, Deserialize, Serialize};
11use std::{
12    fmt::Debug,
13    fs::File,
14    io::{BufReader, Write},
15    path::Path,
16};
17
18/// Configuration of [`Awac`](super::Awac).
19#[allow(clippy::upper_case_acronyms)]
20#[derive(Debug, Deserialize, Serialize, PartialEq)]
21pub struct AwacConfig<Q, P>
22where
23    Q: SubModel2<Output = Tensor>,
24    Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
25    P: SubModel1<Output = (Tensor, Tensor)>,
26    P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
27{
28    /// Configuration of the actor model.
29    pub actor_config: GaussianActorConfig<P::Config>,
30
31    /// Configuration of the critic model.
32    pub critic_config: MultiCriticConfig<Q::Config>,
33
34    /// Discont factor.
35    pub gamma: f64,
36
37    /// The inverse of lambda in the paper.
38    pub inv_lambda: f64,
39
40    /// Target smoothing coefficient.
41    ///
42    /// This is a real number between 0 and 1.
43    /// A value of 0.001 makes the target network parameters adapt very slowly to the critic network parameters.
44    ///
45    /// Formula: target_params = tau * critic_params + (1.0 - tau) * target_params
46    pub tau: f64,
47
48    /// Minimum value of the log of the standard deviation of the action distribution.
49    pub min_lstd: f64,
50
51    /// Maximum value of the log of the standard deviation of the action distribution.
52    pub max_lstd: f64,
53
54    /// Number of parameter updates per optimization step.
55    pub n_updates_per_opt: usize,
56
57    /// Batch size for training.
58    pub batch_size: usize,
59
60    // /// If `true`, the agent is
61    // pub train: bool,
62    /// Type of critic loss function.
63    pub critic_loss: CriticLoss,
64
65    /// Scaling factor for rewards.
66    pub reward_scale: f32,
67
68    /// Number of critics used.
69    pub n_critics: usize,
70
71    /// Maximum of exponent of advantage.
72    pub exp_adv_max: f64,
73
74    /// Random seed value (optional).
75    pub seed: Option<i64>,
76
77    /// Device used for the actor and critic models (e.g., CPU or GPU).
78    pub device: Option<Device>,
79
80    /// If true, advantage weights are calculated with softmax within each mini-batch.
81    pub adv_softmax: bool,
82}
83
84impl<Q, P> Clone for AwacConfig<Q, P>
85where
86    Q: SubModel2<Output = Tensor>,
87    Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
88    P: SubModel1<Output = (Tensor, Tensor)>,
89    P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
90{
91    fn clone(&self) -> Self {
92        Self {
93            actor_config: self.actor_config.clone(),
94            critic_config: self.critic_config.clone(),
95            gamma: self.gamma.clone(),
96            inv_lambda: self.inv_lambda.clone(),
97            tau: self.tau.clone(),
98            min_lstd: self.min_lstd.clone(),
99            max_lstd: self.max_lstd.clone(),
100            n_updates_per_opt: self.n_updates_per_opt.clone(),
101            batch_size: self.batch_size.clone(),
102            critic_loss: self.critic_loss.clone(),
103            reward_scale: self.reward_scale.clone(),
104            n_critics: self.n_critics.clone(),
105            exp_adv_max: self.exp_adv_max,
106            seed: self.seed.clone(),
107            device: self.device.clone(),
108            adv_softmax: self.adv_softmax,
109        }
110    }
111}
112
113impl<Q, P> Default for AwacConfig<Q, P>
114where
115    Q: SubModel2<Output = Tensor>,
116    Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
117    P: SubModel1<Output = (Tensor, Tensor)>,
118    P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
119{
120    fn default() -> Self {
121        Self {
122            actor_config: Default::default(),
123            critic_config: Default::default(),
124            gamma: 0.99,
125            inv_lambda: 10.0,
126            tau: 0.005,
127            min_lstd: -20.0,
128            max_lstd: 2.0,
129            n_updates_per_opt: 1,
130            batch_size: 1,
131            critic_loss: CriticLoss::Mse,
132            reward_scale: 1.0,
133            n_critics: 2,
134            exp_adv_max: 100.0,
135            seed: None,
136            device: None,
137            adv_softmax: false,
138        }
139    }
140}
141
142impl<Q, P> AwacConfig<Q, P>
143where
144    Q: SubModel2<Output = Tensor>,
145    Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
146    P: SubModel1<Output = (Tensor, Tensor)>,
147    P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
148{
149    /// Sets lambda.
150    pub fn lambda(mut self, v: f64) -> Self {
151        self.inv_lambda = 1.0 / v;
152        self
153    }
154
155    /// Sets the numper of parameter update steps per optimization step.
156    pub fn n_updates_per_opt(mut self, v: usize) -> Self {
157        self.n_updates_per_opt = v;
158        self
159    }
160
161    /// Batch size.
162    pub fn batch_size(mut self, v: usize) -> Self {
163        self.batch_size = v;
164        self
165    }
166
167    /// Discount factor.
168    pub fn discount_factor(mut self, v: f64) -> Self {
169        self.gamma = v;
170        self
171    }
172
173    /// Sets soft update coefficient.
174    pub fn tau(mut self, v: f64) -> Self {
175        self.tau = v;
176        self
177    }
178
179    /// Reward scale.
180    ///
181    /// It works for obtaining target values, not the values in logs.
182    pub fn reward_scale(mut self, v: f32) -> Self {
183        self.reward_scale = v;
184        self
185    }
186
187    /// Critic loss.
188    pub fn critic_loss(mut self, v: CriticLoss) -> Self {
189        self.critic_loss = v;
190        self
191    }
192
193    /// Configuration of actor.
194    pub fn actor_config(mut self, actor_config: GaussianActorConfig<P::Config>) -> Self {
195        self.actor_config = actor_config;
196        self
197    }
198
199    /// Configuration of critic.
200    pub fn critic_config(mut self, critic_config: MultiCriticConfig<Q::Config>) -> Self {
201        self.critic_config = critic_config;
202        self
203    }
204
205    /// The number of critics.
206    pub fn n_critics(mut self, n_critics: usize) -> Self {
207        self.n_critics = n_critics;
208        self
209    }
210
211    /// Random seed.
212    pub fn seed(mut self, seed: i64) -> Self {
213        self.seed = Some(seed);
214        self
215    }
216
217    /// Device.
218    pub fn device(mut self, device: candle_core::Device) -> Self {
219        self.device = Some(device.into());
220        self
221    }
222
223    /// If true, advantage weights are calculated with softmax within each mini-batch.
224    pub fn adv_softmax(mut self, b: bool) -> Self {
225        self.adv_softmax = b;
226        self
227    }
228
229    /// Constructs [`AwacConfig`] from YAML file.
230    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
231        let path_ = path.as_ref().to_owned();
232        let file = File::open(path)?;
233        let rdr = BufReader::new(file);
234        let b = serde_yaml::from_reader(rdr)?;
235        info!("Load config of AWAC agent from {}", path_.to_str().unwrap());
236        Ok(b)
237    }
238
239    /// Saves [`AwacConfig`] to YAML file.
240    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
241        let path_ = path.as_ref().to_owned();
242        let mut file = File::create(path)?;
243        file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
244        info!("Save config of AWAC agent into {}", path_.to_str().unwrap());
245        Ok(())
246    }
247}