// border_candle_agent/util/actor.rs

use crate::{
    model::SubModel1,
    opt::{Optimizer, OptimizerConfig},
    util::{atanh, log_jacobian_tanh, OutDim},
};
use anyhow::{Context, Result};
use candle_core::{DType, Device, Tensor, D};
use candle_nn::{VarBuilder, VarMap};
use log::info;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{
    f32::consts::PI,
    fs::File,
    io::{BufReader, Write},
    path::{Path, PathBuf},
};

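/// Log probability of `x` under a diagonal Gaussian with the given `mean` and
/// `std`, summed over the last dimension.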
fn normal_logp(x: &Tensor, mean: &Tensor, std: &Tensor) -> Result<Tensor> {
    let var = std.powf(2.0)?;
    let ps = (-0.5 * (2.0 * PI).ln() as f64
        - (0.5 * var.log()?)?
        - ((0.5 / var)? * (x - mean)?.powf(2.0))?)?;
    Ok(ps.sum(D::Minus1)?)
}

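/// How actions produced by [`GaussianActor`] are kept within bounds.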
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub enum ActionLimit {
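    /// Squash actions with `tanh`, then scale them by `action_scale`.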
    Tanh { action_scale: f32 },
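    /// Clamp actions into `[action_min, action_max]`.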
    Clamp { action_min: f32, action_max: f32 },
}

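/// Configuration of [`GaussianActor`].
///
/// A minimal builder sketch (`MlpConfig` is a hypothetical policy
/// configuration type implementing [`OutDim`], shown for illustration only):
///
/// ```ignore
/// let config = GaussianActorConfig::<MlpConfig>::default()
///     .policy_config(MlpConfig::default())
///     .out_dim(2)
///     .action_limit(ActionLimit::Tanh { action_scale: 1.0 });
/// ```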
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct GaussianActorConfig<P: OutDim> {
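    /// Configuration of the policy network, which outputs the mean and the
    /// log standard deviation of a Gaussian distribution over actions.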
    pub policy_config: Option<P>,
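    /// Configuration of the optimizer.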
    pub opt_config: OptimizerConfig,
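    /// Lower bound used when clamping the log standard deviation.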
    pub min_log_std: f32,
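    /// Upper bound used when clamping the log standard deviation.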
    pub max_log_std: f32,
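    /// How actions are limited (see [`ActionLimit`]).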
    pub action_limit: ActionLimit,
}

impl<P: OutDim> Default for GaussianActorConfig<P> {
    fn default() -> Self {
        Self {
            policy_config: None,
            opt_config: OptimizerConfig::Adam { lr: 0.0003 },
            min_log_std: -20.0,
            max_log_std: 2.0,
            action_limit: ActionLimit::Clamp {
                action_min: -1.0,
                action_max: 1.0,
            },
        }
    }
}

impl<P> GaussianActorConfig<P>
where
    P: DeserializeOwned + Serialize + OutDim,
{
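    /// Sets the lower bound of the log standard deviation.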
    pub fn min_log_std(mut self, v: f32) -> Self {
        self.min_log_std = v;
        self
    }

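    /// Sets the upper bound of the log standard deviation.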
    pub fn max_log_std(mut self, v: f32) -> Self {
        self.max_log_std = v;
        self
    }

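    /// Sets the configuration of the policy network.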
    pub fn policy_config(mut self, v: P) -> Self {
        self.policy_config = Some(v);
        self
    }

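    /// Sets the output dimension of the policy network configuration, if one
    /// has been set.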
    pub fn out_dim(mut self, v: i64) -> Self {
        match &mut self.policy_config {
            None => {}
            Some(pi_config) => pi_config.set_out_dim(v),
        };
        self
    }

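    /// Sets the optimizer configuration.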
    pub fn opt_config(mut self, v: OptimizerConfig) -> Self {
        self.opt_config = v;
        self
    }

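    /// Sets how actions are limited.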
    pub fn action_limit(mut self, action_limit: ActionLimit) -> Self {
        self.action_limit = action_limit;
        self
    }

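    /// Loads a configuration from the YAML file at `path`.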
    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
        let file = File::open(path)?;
        let rdr = BufReader::new(file);
        let b = serde_yaml::from_reader(rdr)?;
        Ok(b)
    }

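    /// Saves the configuration as YAML to `path`.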
    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
        let mut file = File::create(path)?;
        file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
        Ok(())
    }
}

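/// Stochastic policy whose actions follow a Gaussian distribution with mean
/// and standard deviation produced by the sub-model `P`.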
pub struct GaussianActor<P>
where
    P: SubModel1<Output = (Tensor, Tensor)>,
    P::Config: DeserializeOwned + Serialize + OutDim + Clone,
{
    device: Device,
    varmap: VarMap,

    out_dim: i64,

    policy_config: P::Config,
    policy: P,

    opt_config: OptimizerConfig,
    opt: Optimizer,

    min_log_std: f64,
    max_log_std: f64,

    action_limit: ActionLimit,
}

impl<P> GaussianActor<P>
where
    P: SubModel1<Output = (Tensor, Tensor)>,
    P::Config: DeserializeOwned + Serialize + OutDim + Clone,
{
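    /// Builds the actor from `config` on `device`.
    ///
    /// Returns an error if `policy_config` is not set.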
    pub fn build(
        config: GaussianActorConfig<P::Config>,
        device: Device,
    ) -> Result<GaussianActor<P>> {
        let min_log_std = config.min_log_std as _;
        let max_log_std = config.max_log_std as _;
        let policy_config = config.policy_config.context("policy_config is not set.")?;
        let out_dim = policy_config.get_out_dim();
        let varmap = VarMap::new();
        let policy = {
            let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device).set_prefix("actor");
            P::build(vb, policy_config.clone())
        };
        let opt_config = config.opt_config;
        let opt = opt_config.build(varmap.all_vars()).unwrap();
        let action_limit = config.action_limit;

        Ok(Self {
            device,
            out_dim,
            opt_config,
            varmap,
            opt,
            policy,
            policy_config,
            min_log_std,
            max_log_std,
            action_limit,
        })
    }

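    /// Applies the policy network to the observation `x`, returning the mean
    /// and (as consumed by `logp` and `sample`) the log standard deviation of
    /// the action distribution.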
    pub fn forward(&self, x: &P::Input) -> (Tensor, Tensor) {
        let (mean, std) = self.policy.forward(&x);
        debug_assert_eq!(mean.dims()[1], self.out_dim as usize);
        debug_assert_eq!(std.dims()[1], self.out_dim as usize);
        debug_assert_eq!(mean.dims().len(), 2);
        debug_assert_eq!(std.dims().len(), 2);
        (mean, std)
    }

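    /// Log probability of action `act` given observation `obs` under the
    /// current policy, including the tanh correction term when
    /// [`ActionLimit::Tanh`] is used.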
    pub fn logp(&self, obs: &P::Input, act: &Tensor) -> Result<Tensor> {
        let (mean, std) = {
            let (mean, lstd) = self.forward(obs);
            let std = lstd.clamp(self.min_log_std, self.max_log_std)?.exp()?;
            (mean, std)
        };

        let act = act.to_device(&self.device)?;
        match &self.action_limit {
            ActionLimit::Clamp {
                action_min: _,
                action_max: _,
            } => Ok(normal_logp(&act, &mean, &std)?),
            ActionLimit::Tanh { action_scale } => {
                let x = atanh(&(&act / *action_scale as f64)?)?;
                let lj = log_jacobian_tanh(&act)?;
                Ok((normal_logp(&x, &mean, &std)? + lj)?)
            }
        }
    }

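    /// Samples an action for observation `obs`.
    ///
    /// When `train` is true the action is drawn from the Gaussian; otherwise
    /// the mean is returned. The result is squashed or clamped according to
    /// [`ActionLimit`].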
    pub fn sample(&mut self, obs: &P::Input, train: bool) -> Result<Tensor> {
        let (mean, lstd) = self.forward(&obs);
        let std = lstd.clamp(self.min_log_std, self.max_log_std)?.exp()?;
        let act = match train {
            true => ((std * mean.randn_like(0., 1.)?)? + mean)?,
            false => mean,
        };
        let act = match self.action_limit {
            ActionLimit::Clamp {
                action_min,
                action_max,
            } => act.clamp(action_min, action_max)?,
            ActionLimit::Tanh { action_scale } => (action_scale as f64 * act.tanh()?)?,
        };
        Ok(act)
    }

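    /// Runs a backward pass on `loss` and applies an optimizer step to the
    /// actor parameters.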
    pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
        self.opt.backward_step(loss)?;
        Ok(())
    }

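    /// Saves the actor parameters to `prefix` with the extension set to `pt`,
    /// returning the resulting path.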
    pub fn save(&self, prefix: impl AsRef<Path>) -> Result<PathBuf> {
        let mut path = PathBuf::from(prefix.as_ref());
        path.set_extension("pt");
        self.varmap.save(&path.as_path())?;
        info!("Save actor parameters to {:?}", path);

        Ok(path.to_path_buf())
    }

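    /// Loads the actor parameters from `prefix` with the extension set to `pt`.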
    pub fn load(&mut self, prefix: impl AsRef<Path>) -> Result<()> {
        let mut path = PathBuf::from(prefix.as_ref());
        path.set_extension("pt");
        self.varmap.load(&path.as_path())?;
        info!("Load actor parameters from {:?}", path);

        Ok(())
    }
}

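// Cloning rebuilds the policy network from the stored configuration against a
// fresh `VarMap`, then calls `clone_from` on it with the source actor's
// `VarMap`.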
impl<P> Clone for GaussianActor<P>
where
    P: SubModel1<Output = (Tensor, Tensor)>,
    P::Config: DeserializeOwned + Serialize + OutDim + Clone,
{
    fn clone(&self) -> Self {
        let min_log_std = self.min_log_std;
        let max_log_std = self.max_log_std;
        let device = self.device.clone();
        let opt_config = self.opt_config.clone();
        let mut varmap = VarMap::new();
        let policy_config = self.policy_config.clone();
        let policy = {
            let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
            P::build(vb, policy_config.clone())
        };
        let out_dim = self.out_dim;
        let opt = opt_config.build(varmap.all_vars()).unwrap();
        let action_limit = self.action_limit.clone();

        varmap.clone_from(&self.varmap);

        Self {
            device,
            out_dim,
            opt_config,
            varmap,
            opt,
            policy,
            policy_config,
            min_log_std,
            max_log_std,
            action_limit,
        }
    }
}