// border_tch_agent/sac/ent_coef.rs

1//! Entropy coefficient of SAC.
2use anyhow::Result;
3use log::{info, trace};
4use serde::{Deserialize, Serialize};
5use std::{/*borrow::Borrow,*/ path::Path};
6use tch::{nn, nn::OptimizerConfig, Tensor};
7
/// Mode of the entropy coefficient of SAC.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub enum EntCoefMode {
    /// Use a fixed constant as alpha; no optimization is performed.
    Fix(f64),
    /// Automatic tuning of alpha, given `(target_entropy, learning_rate)`.
    Auto(f64, f64),
}
16
/// The entropy coefficient of SAC.
pub struct EntCoef {
    // Holds the `log_alpha` variable so it can be saved/loaded and optimized.
    var_store: nn::VarStore,
    // Logarithm of the entropy coefficient alpha; a `[1]`-shaped variable.
    log_alpha: Tensor,
    // Target entropy for automatic tuning; `None` in `Fix` mode.
    target_entropy: Option<f64>,
    // Optimizer over `var_store`; `None` in `Fix` mode (alpha stays constant).
    opt: Option<nn::Optimizer>,
}
24
25impl EntCoef {
26    /// Constructs an instance of `EntCoef`.
27    pub fn new(mode: EntCoefMode, device: tch::Device) -> Self {
28        let var_store = nn::VarStore::new(device);
29        let path = &var_store.root();
30        let (log_alpha, target_entropy, opt) = match mode {
31            EntCoefMode::Fix(alpha) => {
32                let init = nn::Init::Const(alpha.ln());
33                // let log_alpha = path.borrow().var("log_alpha", &[1], init);
34                let log_alpha = path.var("log_alpha", &[1], init);
35                (log_alpha, None, None)
36            }
37            EntCoefMode::Auto(target_entropy, learning_rate) => {
38                let init = nn::Init::Const(0.0);
39                // let log_alpha = path.borrow().var("log_alpha", &[1], init);
40                let log_alpha = path.var("log_alpha", &[1], init);
41                let opt = nn::Adam::default()
42                    .build(&var_store, learning_rate)
43                    .unwrap();
44                (log_alpha, Some(target_entropy), Some(opt))
45            }
46        };
47
48        Self {
49            var_store,
50            log_alpha,
51            opt,
52            target_entropy,
53        }
54    }
55
56    /// Returns the entropy coefficient.
57    pub fn alpha(&self) -> Tensor {
58        self.log_alpha.detach().exp()
59    }
60
61    /// Does an optimization step given a loss.
62    pub fn backward_step(&mut self, loss: &Tensor) {
63        if let Some(opt) = &mut self.opt {
64            opt.backward_step(loss);
65        }
66    }
67
68    /// Update the parameter given an action probability vector.
69    pub fn update(&mut self, logp: &Tensor) {
70        if let Some(target_entropy) = &self.target_entropy {
71            let target_entropy = Tensor::from(*target_entropy);
72            let loss = -(&self.log_alpha * (logp + target_entropy).detach()).mean(tch::Kind::Float);
73            self.backward_step(&loss);
74        }
75    }
76
77    /// Save the parameter into a file.
78    pub fn save<T: AsRef<Path>>(&self, path: T) -> Result<()> {
79        self.var_store.save(&path)?;
80        info!("Save entropy coefficient to {:?}", path.as_ref());
81        let vs = self.var_store.variables();
82        for (name, _) in vs.iter() {
83            trace!("Save variable {}", name);
84        }
85        Ok(())
86    }
87
88    /// Save the parameter from a file.
89    pub fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()> {
90        self.var_store.load(&path)?;
91        info!("Load entropy coefficient from {:?}", path.as_ref());
92        Ok(())
93    }
94}