// border_candle_agent/sac/ent_coef.rs

//! Entropy coefficient of SAC.
2use std::convert::TryFrom;
3
4use crate::opt::{Optimizer, OptimizerConfig};
5use anyhow::Result;
6use candle_core::{DType, Device, Tensor};
7use candle_nn::{init::Init, VarBuilder, VarMap};
8use log::info;
9use serde::{Deserialize, Serialize};
10use std::path::Path;
11
/// Mode of the entropy coefficient of SAC.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub enum EntCoefMode {
    /// Use a constant as alpha (the coefficient is never updated).
    Fix(f64),
    /// Automatic tuning given `(target_entropy, learning_rate)`.
    ///
    /// The coefficient is optimized during training so that the policy
    /// entropy is pushed toward `target_entropy`.
    Auto(f64, f64),
}
20
/// The entropy coefficient of SAC.
///
/// Depending on [`EntCoefMode`], the coefficient is either a fixed constant
/// or a single trainable parameter tuned during training.
pub struct EntCoef {
    // Holds the `log_alpha` variable so it can be saved/loaded and optimized.
    varmap: VarMap,
    // Logarithm of the entropy coefficient; `alpha = exp(log_alpha)`.
    log_alpha: Tensor,
    // Target entropy for automatic tuning; `None` when alpha is fixed.
    target_entropy: Option<f64>,
    // Optimizer over `varmap`; present only for `EntCoefMode::Auto`.
    opt: Option<Optimizer>,
}
28
29impl EntCoef {
30    /// Constructs an instance of `EntCoef`.
31    pub fn new(mode: EntCoefMode, device: Device) -> Result<Self> {
32        let varmap = VarMap::new();
33        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
34        let (log_alpha, target_entropy, opt) = match mode {
35            EntCoefMode::Fix(alpha) => {
36                let init = Init::Const(alpha.ln());
37                let log_alpha = vb.get_with_hints(1, "log_alpha", init)?;
38                (log_alpha, None, None)
39            }
40            EntCoefMode::Auto(target_entropy, learning_rate) => {
41                let init = Init::Const(0.0);
42                let log_alpha = vb.get_with_hints(1, "log_alpha", init)?;
43                let opt = OptimizerConfig::default()
44                    .learning_rate(learning_rate)
45                    .build(varmap.all_vars())?;
46                (log_alpha, Some(target_entropy), Some(opt))
47            }
48        };
49
50        Ok(Self {
51            varmap,
52            log_alpha,
53            opt,
54            target_entropy,
55        })
56    }
57
58    /// Returns the entropy coefficient.
59    pub fn alpha(&self) -> Result<Tensor> {
60        Ok(self.log_alpha.detach().exp()?)
61    }
62
63    /// Does an optimization step given a loss.
64    pub fn backward_step(&mut self, loss: &Tensor) {
65        if let Some(opt) = &mut self.opt {
66            opt.backward_step(loss).unwrap();
67        }
68    }
69
70    /// Update the parameter given an action probability vector.
71    pub fn update(&mut self, logp: &Tensor) -> Result<()> {
72        if let Some(target_entropy) = &self.target_entropy {
73            let target_entropy =
74                Tensor::try_from(*target_entropy as f32)?.to_device(logp.device())?;
75            let loss = {
76                // let tmp = ((&self.log_alpha * (logp + target_entropy)?.detach())? * -1f64)?;
77                let tmp = (&self.log_alpha * -1f64)?
78                    .broadcast_mul(&logp.broadcast_add(&target_entropy)?.detach())?;
79                tmp.mean(0)?
80            };
81            self.backward_step(&loss);
82        }
83        Ok(())
84    }
85
86    /// Save the parameter into a file.
87    pub fn save<T: AsRef<Path>>(&self, path: T) -> Result<()> {
88        self.varmap.save(&path)?;
89        info!("Save entropy coefficient to {:?}", path.as_ref());
90        Ok(())
91    }
92
93    /// Save the parameter from a file.
94    pub fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()> {
95        self.varmap.load(&path)?;
96        info!("Load entropy coefficient from {:?}", path.as_ref());
97        Ok(())
98    }
99}