border_candle_agent/sac/
config.rs1use crate::{
3 model::{SubModel1, SubModel2},
4 sac::ent_coef::EntCoefMode,
5 util::{actor::GaussianActorConfig, critic::MultiCriticConfig, CriticLoss, OutDim},
6 Device,
7};
8use anyhow::Result;
9use candle_core::Tensor;
10use log::info;
11use serde::{de::DeserializeOwned, Deserialize, Serialize};
12use std::{
13 fmt::Debug,
14 fs::File,
15 io::{BufReader, Write},
16 path::Path,
17};
18
19#[allow(clippy::upper_case_acronyms)]
21#[derive(Debug, Deserialize, Serialize, PartialEq)]
22pub struct SacConfig<Q, P>
23where
24 Q: SubModel2<Output = Tensor>,
25 Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
26 P: SubModel1<Output = (Tensor, Tensor)>,
27 P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
28{
29 pub actor_config: GaussianActorConfig<P::Config>,
31
32 pub critic_config: MultiCriticConfig<Q::Config>,
34
35 pub gamma: f64,
37
38 pub ent_coef_mode: EntCoefMode,
40
41 pub n_updates_per_opt: usize,
43
44 pub batch_size: usize,
46
47 pub critic_loss: CriticLoss,
49
50 pub device: Option<Device>,
52}
53
54impl<Q, P> Clone for SacConfig<Q, P>
55where
56 Q: SubModel2<Output = Tensor>,
57 Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
58 P: SubModel1<Output = (Tensor, Tensor)>,
59 P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
60{
61 fn clone(&self) -> Self {
62 Self {
63 actor_config: self.actor_config.clone(),
64 critic_config: self.critic_config.clone(),
65 gamma: self.gamma.clone(),
66 ent_coef_mode: self.ent_coef_mode.clone(),
67 n_updates_per_opt: self.n_updates_per_opt.clone(),
68 batch_size: self.batch_size.clone(),
69 critic_loss: self.critic_loss.clone(),
70 device: self.device.clone(),
71 }
72 }
73}
74
75impl<Q, P> Default for SacConfig<Q, P>
76where
77 Q: SubModel2<Output = Tensor>,
78 Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
79 P: SubModel1<Output = (Tensor, Tensor)>,
80 P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
81{
82 fn default() -> Self {
83 Self {
84 actor_config: Default::default(),
85 critic_config: Default::default(),
86 gamma: 0.99,
87 ent_coef_mode: EntCoefMode::Fix(1.0),
88 n_updates_per_opt: 1,
89 batch_size: 1,
90 critic_loss: CriticLoss::Mse,
91 device: None,
92 }
93 }
94}
95
96impl<Q, P> SacConfig<Q, P>
97where
98 Q: SubModel2<Output = Tensor>,
99 Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone,
100 P: SubModel1<Output = (Tensor, Tensor)>,
101 P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone,
102{
103 pub fn n_updates_per_opt(mut self, v: usize) -> Self {
105 self.n_updates_per_opt = v;
106 self
107 }
108
109 pub fn batch_size(mut self, v: usize) -> Self {
111 self.batch_size = v;
112 self
113 }
114
115 pub fn discount_factor(mut self, v: f64) -> Self {
117 self.gamma = v;
118 self
119 }
120
121 pub fn ent_coef_mode(mut self, v: EntCoefMode) -> Self {
123 self.ent_coef_mode = v;
124 self
125 }
126
127 pub fn critic_loss(mut self, v: CriticLoss) -> Self {
129 self.critic_loss = v;
130 self
131 }
132
133 pub fn actor_config(mut self, actor_config: GaussianActorConfig<P::Config>) -> Self {
135 self.actor_config = actor_config;
136 self
137 }
138
139 pub fn critic_config(mut self, critic_config: MultiCriticConfig<Q::Config>) -> Self {
141 self.critic_config = critic_config;
142 self
143 }
144
145 pub fn device(mut self, device: candle_core::Device) -> Self {
147 self.device = Some(device.into());
148 self
149 }
150
151 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
153 let path_ = path.as_ref().to_owned();
154 let file = File::open(path)?;
155 let rdr = BufReader::new(file);
156 let b = serde_yaml::from_reader(rdr)?;
157 info!("Load config of SAC agent from {}", path_.to_str().unwrap());
158 Ok(b)
159 }
160
161 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
163 let path_ = path.as_ref().to_owned();
164 let mut file = File::create(path)?;
165 file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
166 info!("Save config of SAC agent into {}", path_.to_str().unwrap());
167 Ok(())
168 }
169}