border_candle_agent/sac/
config.rs

//! Configuration of the SAC agent.
2use crate::{
3    model::{SubModel1, SubModel2},
4    sac::ent_coef::EntCoefMode,
5    util::{actor::GaussianActorConfig, critic::MultiCriticConfig, CriticLoss, OutDim},
6    Device,
7};
8use anyhow::Result;
9use candle_core::Tensor;
10use log::info;
11use serde::{de::DeserializeOwned, Deserialize, Serialize};
12use std::{
13    fmt::Debug,
14    fs::File,
15    io::{BufReader, Write},
16    path::Path,
17};
18
19/// Configuration of [`Sac`](super::Sac).
20#[allow(clippy::upper_case_acronyms)]
21#[derive(Debug, Deserialize, Serialize, PartialEq)]
22pub struct SacConfig<Q, P>
23where
24    Q: SubModel2<Output = Tensor>,
25    Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
26    P: SubModel1<Output = (Tensor, Tensor)>,
27    P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
28{
29    /// Configuration of the actor model.
30    pub actor_config: GaussianActorConfig<P::Config>,
31
32    /// Configuration of the critic model.
33    pub critic_config: MultiCriticConfig<Q::Config>,
34
35    /// Discont factor.
36    pub gamma: f64,
37
38    /// How to update entropy coefficient.
39    pub ent_coef_mode: EntCoefMode,
40
41    /// Number of parameter updates per optimization step.
42    pub n_updates_per_opt: usize,
43
44    /// Batch size for training.
45    pub batch_size: usize,
46
47    /// Type of critic loss function.
48    pub critic_loss: CriticLoss,
49
50    /// Device for actor/critic models.
51    pub device: Option<Device>,
52}
53
54impl<Q, P> Clone for SacConfig<Q, P>
55where
56    Q: SubModel2<Output = Tensor>,
57    Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
58    P: SubModel1<Output = (Tensor, Tensor)>,
59    P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
60{
61    fn clone(&self) -> Self {
62        Self {
63            actor_config: self.actor_config.clone(),
64            critic_config: self.critic_config.clone(),
65            gamma: self.gamma.clone(),
66            ent_coef_mode: self.ent_coef_mode.clone(),
67            n_updates_per_opt: self.n_updates_per_opt.clone(),
68            batch_size: self.batch_size.clone(),
69            critic_loss: self.critic_loss.clone(),
70            device: self.device.clone(),
71        }
72    }
73}
74
75impl<Q, P> Default for SacConfig<Q, P>
76where
77    Q: SubModel2<Output = Tensor>,
78    Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
79    P: SubModel1<Output = (Tensor, Tensor)>,
80    P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
81{
82    fn default() -> Self {
83        Self {
84            actor_config: Default::default(),
85            critic_config: Default::default(),
86            gamma: 0.99,
87            ent_coef_mode: EntCoefMode::Fix(1.0),
88            n_updates_per_opt: 1,
89            batch_size: 1,
90            critic_loss: CriticLoss::Mse,
91            device: None,
92        }
93    }
94}
95
96impl<Q, P> SacConfig<Q, P>
97where
98    Q: SubModel2<Output = Tensor>,
99    Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
100    P: SubModel1<Output = (Tensor, Tensor)>,
101    P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
102{
103    /// Sets the numper of parameter update steps per optimization step.
104    pub fn n_updates_per_opt(mut self, v: usize) -> Self {
105        self.n_updates_per_opt = v;
106        self
107    }
108
109    /// Batch size.
110    pub fn batch_size(mut self, v: usize) -> Self {
111        self.batch_size = v;
112        self
113    }
114
115    /// Discount factor.
116    pub fn discount_factor(mut self, v: f64) -> Self {
117        self.gamma = v;
118        self
119    }
120
121    /// SAC-alpha.
122    pub fn ent_coef_mode(mut self, v: EntCoefMode) -> Self {
123        self.ent_coef_mode = v;
124        self
125    }
126
127    /// Critic loss.
128    pub fn critic_loss(mut self, v: CriticLoss) -> Self {
129        self.critic_loss = v;
130        self
131    }
132
133    /// Configuration of actor.
134    pub fn actor_config(mut self, actor_config: GaussianActorConfig<P::Config>) -> Self {
135        self.actor_config = actor_config;
136        self
137    }
138
139    /// Configuration of critic.
140    pub fn critic_config(mut self, critic_config: MultiCriticConfig<Q::Config>) -> Self {
141        self.critic_config = critic_config;
142        self
143    }
144
145    /// Device.
146    pub fn device(mut self, device: candle_core::Device) -> Self {
147        self.device = Some(device.into());
148        self
149    }
150
151    /// Constructs [`SacConfig`] from YAML file.
152    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
153        let path_ = path.as_ref().to_owned();
154        let file = File::open(path)?;
155        let rdr = BufReader::new(file);
156        let b = serde_yaml::from_reader(rdr)?;
157        info!("Load config of SAC agent from {}", path_.to_str().unwrap());
158        Ok(b)
159    }
160
161    /// Saves [`SacConfig`].
162    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
163        let path_ = path.as_ref().to_owned();
164        let mut file = File::create(path)?;
165        file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
166        info!("Save config of SAC agent into {}", path_.to_str().unwrap());
167        Ok(())
168    }
169}