border_candle_agent/util/
actor.rs

1//! Actor with Gaussian policy.
2use crate::{
3    model::SubModel1,
4    opt::{Optimizer, OptimizerConfig},
5    util::{atanh, log_jacobian_tanh, OutDim},
6};
7use anyhow::{Context, Result};
8use candle_core::{DType, Device, Tensor, D};
9use candle_nn::{VarBuilder, VarMap};
10use log::info;
11use serde::{de::DeserializeOwned, Deserialize, Serialize};
12use std::{
13    f32::consts::PI,
14    fs::File,
15    io::{BufReader, Write},
16    path::{Path, PathBuf},
17};
18
19fn normal_logp(x: &Tensor, mean: &Tensor, std: &Tensor) -> Result<Tensor> {
20    let var = std.powf(2.0)?;
21    let ps = (-0.5 * (2.0 * PI).ln() as f64
22        - (0.5 * var.log()?)?
23        - ((0.5 / var)? * (x - mean)?.powf(2.0))?)?;
24    Ok(ps.sum(D::Minus1)?)
25}
26
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
/// Action limit type for [`GaussianActor`].
pub enum ActionLimit {
    /// Squash the raw action with `tanh`, then multiply by `action_scale`.
    Tanh { action_scale: f32 },
    /// Clamp each action element into `[action_min, action_max]`.
    Clamp { action_min: f32, action_max: f32 },
}
33
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
/// Configuration of [`GaussianActor`].
pub struct GaussianActorConfig<P: OutDim> {
    /// Configuration of the policy model; `None` until set, and
    /// [`GaussianActor`] construction fails while it is `None`.
    pub policy_config: Option<P>,
    /// Optimizer configuration for the policy parameters.
    pub opt_config: OptimizerConfig,
    /// Lower clamp applied to the predicted log std.
    pub min_log_std: f32,
    /// Upper clamp applied to the predicted log std.
    pub max_log_std: f32,
    /// How raw Gaussian actions are mapped into the valid action range.
    pub action_limit: ActionLimit,
}
43
44impl<P: OutDim> Default for GaussianActorConfig<P> {
45    fn default() -> Self {
46        Self {
47            policy_config: None,
48            opt_config: OptimizerConfig::Adam { lr: 0.0003 },
49            min_log_std: -20.0,
50            max_log_std: 2.0,
51            action_limit: ActionLimit::Clamp {
52                action_min: -1.0,
53                action_max: 1.0,
54            },
55        }
56    }
57}
58
59impl<P> GaussianActorConfig<P>
60where
61    P: DeserializeOwned + Serialize + OutDim,
62{
63    /// Sets the minimum value of log std.
64    pub fn min_log_std(mut self, v: f32) -> Self {
65        self.min_log_std = v;
66        self
67    }
68
69    /// Sets the maximum value of log std.
70    pub fn max_log_std(mut self, v: f32) -> Self {
71        self.max_log_std = v;
72        self
73    }
74
75    /// Sets configurations for policy function.
76    pub fn policy_config(mut self, v: P) -> Self {
77        self.policy_config = Some(v);
78        self
79    }
80
81    /// Sets output dimension of the model.
82    pub fn out_dim(mut self, v: i64) -> Self {
83        match &mut self.policy_config {
84            None => {}
85            Some(pi_config) => pi_config.set_out_dim(v),
86        };
87        self
88    }
89
90    /// Sets optimizer configuration.
91    pub fn opt_config(mut self, v: OptimizerConfig) -> Self {
92        self.opt_config = v;
93        self
94    }
95
96    /// Sets action limit.
97    pub fn action_limit(mut self, action_limit: ActionLimit) -> Self {
98        self.action_limit = action_limit;
99        self
100    }
101
102    /// Loads [`GaussianActorConfig`] from YAML file.
103    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
104        let file = File::open(path)?;
105        let rdr = BufReader::new(file);
106        let b = serde_yaml::from_reader(rdr)?;
107        Ok(b)
108    }
109
110    /// Saves [`GaussianActorConfig`] as YAML file.
111    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
112        let mut file = File::create(path)?;
113        file.write_all(serde_yaml::to_string(&self)?.as_bytes())?;
114        Ok(())
115    }
116}
117
/// Actor with Gaussian policy.
pub struct GaussianActor<P>
where
    P: SubModel1<Output = (Tensor, Tensor)>,
    P::Config: DeserializeOwned + Serialize + OutDim + Clone,
{
    // Device onto which action tensors are moved (see `logp`).
    device: Device,
    // Holds all trainable variables of the policy network.
    varmap: VarMap,

    // Dimension of the action vector.
    out_dim: i64,

    // Policy network producing `(mean, log_std)`; the original comment
    // labeled this "Action-value function", but `forward`'s usage in
    // `logp`/`sample` (clamp + exp of the second output) shows it is the
    // Gaussian policy head.
    policy_config: P::Config,
    policy: P,

    // Optimizer over `varmap`'s variables.
    opt_config: OptimizerConfig,
    opt: Optimizer,

    // Min/max clamp for the predicted log std (stored widened to f64 for
    // `Tensor::clamp`).
    min_log_std: f64,
    max_log_std: f64,

    // How raw Gaussian actions are mapped into the action range.
    action_limit: ActionLimit,
}
144
145impl<P> GaussianActor<P>
146where
147    P: SubModel1<Output = (Tensor, Tensor)>,
148    P::Config: DeserializeOwned + Serialize + OutDim + Clone,
149{
150    /// Constructs [`GaussianActor`].
151    pub fn build(
152        config: GaussianActorConfig<P::Config>,
153        device: Device,
154    ) -> Result<GaussianActor<P>> {
155        let min_log_std = config.min_log_std as _;
156        let max_log_std = config.max_log_std as _;
157        let policy_config = config.policy_config.context("policy_config is not set.")?;
158        let out_dim = policy_config.get_out_dim();
159        let varmap = VarMap::new();
160        let policy = {
161            let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device).set_prefix("actor");
162            P::build(vb, policy_config.clone())
163        };
164        let opt_config = config.opt_config;
165        let opt = opt_config.build(varmap.all_vars()).unwrap();
166        let action_limit = config.action_limit;
167
168        Ok(Self {
169            device,
170            out_dim,
171            opt_config,
172            varmap,
173            opt,
174            policy,
175            policy_config,
176            min_log_std,
177            max_log_std,
178            action_limit,
179        })
180    }
181
182    /// Returns the parameters of Gaussian distribution given an observation.
183    ///
184    /// The type of return values is `(Tensor, Tensor)`.
185    /// The shape of the both tensors is `(batch_size, action_dimension)`.
186    pub fn forward(&self, x: &P::Input) -> (Tensor, Tensor) {
187        let (mean, std) = self.policy.forward(&x);
188        debug_assert_eq!(mean.dims()[1], self.out_dim as usize);
189        debug_assert_eq!(std.dims()[1], self.out_dim as usize);
190        debug_assert_eq!(mean.dims().len(), 2);
191        debug_assert_eq!(std.dims().len(), 2);
192        (mean, std)
193    }
194
195    /// Rerurns the log probabilities (densities) of the given actions
196    pub fn logp<'a>(&self, obs: &P::Input, act: &Tensor) -> Result<Tensor> {
197        // Distribution parameters on the given observation
198        let (mean, std) = {
199            let (mean, lstd) = self.forward(obs);
200            let std = lstd.clamp(self.min_log_std, self.max_log_std)?.exp()?;
201            (mean, std)
202        };
203
204        // Log probability
205        let act = act.to_device(&self.device)?;
206        match &self.action_limit {
207            ActionLimit::Clamp {
208                action_min: _,
209                action_max: _,
210            } => Ok(normal_logp(&act, &mean, &std)?),
211            ActionLimit::Tanh { action_scale } => {
212                // Back to normal distributed RV
213                let x = atanh(&(&act / *action_scale as f64)?)?;
214                // Log Jacobian
215                let lj = log_jacobian_tanh(&act)?;
216                // Log probability
217                Ok((normal_logp(&x, &mean, &std)? + lj)?)
218            }
219        }
220    }
221
222    /// Samples actions.
223    ///
224    /// If `train` is `true`, actions are sampled from a Gaussian distribution.
225    /// Otherwise, the mean of the Gaussian distribution is returned.
226    pub fn sample(&mut self, obs: &P::Input, train: bool) -> Result<Tensor> {
227        let (mean, lstd) = self.forward(&obs);
228        let std = lstd.clamp(self.min_log_std, self.max_log_std)?.exp()?;
229        let act = match train {
230            true => ((std * mean.randn_like(0., 1.)?)? + mean)?,
231            false => mean,
232        };
233        let act = match self.action_limit {
234            ActionLimit::Clamp {
235                action_min,
236                action_max,
237            } => act.clamp(action_min, action_max)?,
238            ActionLimit::Tanh { action_scale } => (action_scale as f64 * act.tanh()?)?,
239        };
240        Ok(act)
241    }
242
243    pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
244        self.opt.backward_step(loss)?;
245        Ok(())
246    }
247
248    /// Save variables to prefix + ".pt".
249    pub fn save(&self, prefix: impl AsRef<Path>) -> Result<PathBuf> {
250        let mut path = PathBuf::from(prefix.as_ref());
251        path.set_extension("pt");
252        self.varmap.save(&path.as_path())?;
253        info!("Save actor parameters to {:?}", path);
254
255        Ok(path.to_path_buf())
256    }
257
258    /// Load variables from prefix + ".pt".
259    pub fn load(&mut self, prefix: impl AsRef<Path>) -> Result<()> {
260        let mut path = PathBuf::from(prefix.as_ref());
261        path.set_extension("pt");
262        self.varmap.load(&path.as_path())?;
263        info!("Load actor parameters from {:?}", path);
264
265        Ok(())
266    }
267}
268
impl<P> Clone for GaussianActor<P>
where
    P: SubModel1<Output = (Tensor, Tensor)>,
    P::Config: DeserializeOwned + Serialize + OutDim + Clone,
{
    // Clones by rebuilding the policy and optimizer inside a fresh VarMap,
    // then copying the variables over from `self.varmap`.
    fn clone(&self) -> Self {
        let min_log_std = self.min_log_std;
        let max_log_std = self.max_log_std;
        let device = self.device.clone();
        let opt_config = self.opt_config.clone();
        let mut varmap = VarMap::new();
        let policy_config = self.policy_config.clone();
        let policy = {
            // NOTE(review): unlike `build`, no `.set_prefix("actor")` is
            // applied here, so the rebuilt variables are registered under
            // different names than the originals — confirm that the
            // `clone_from` below still transfers the values correctly.
            let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
            P::build(vb, policy_config.clone())
        };
        let out_dim = self.out_dim;
        // NOTE(review): panics if optimizer construction fails; `Clone` has
        // no way to report an error.
        let opt = opt_config.build(varmap.all_vars()).unwrap();
        let action_limit = self.action_limit.clone();

        // Copy variable values from the source actor's varmap.
        // NOTE(review): depends on `VarMap`'s `clone_from` semantics (deep
        // copy of values vs. shared storage) — verify against the candle
        // version in use.
        varmap.clone_from(&self.varmap);

        Self {
            device,
            out_dim,
            opt_config,
            varmap,
            opt,
            policy,
            policy_config,
            min_log_std,
            max_log_std,
            action_limit,
        }
    }
}
305}