1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
//! Entropy coefficient of SAC.
use anyhow::Result;
use log::{info, trace};
use serde::{Deserialize, Serialize};
use std::{borrow::Borrow, path::Path};
use tch::{nn, nn::OptimizerConfig, Tensor};

/// Mode of the entropy coefficient (alpha) of SAC.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub enum EntCoefMode {
    /// Use a constant as alpha; no optimization is performed.
    Fix(f64),
    /// Automatic tuning given `(target_entropy, learning_rate)`:
    /// alpha is adjusted by an Adam optimizer so that the policy entropy
    /// approaches `target_entropy`.
    Auto(f64, f64),
}

/// The entropy coefficient of SAC.
pub struct EntCoef {
    /// Variable store owning `log_alpha`; used for save/load.
    var_store: nn::VarStore,
    /// Log of the entropy coefficient; the trained parameter in `Auto` mode.
    log_alpha: Tensor,
    /// Target entropy; `Some` only in [`EntCoefMode::Auto`].
    target_entropy: Option<f64>,
    /// Optimizer over `log_alpha`; `Some` only in [`EntCoefMode::Auto`].
    opt: Option<nn::Optimizer<nn::Adam>>,
}

impl EntCoef {
    /// Constructs an instance of `EntCoef`.
    ///
    /// For [`EntCoefMode::Fix`], `log_alpha` is initialized to `ln(alpha)` and no
    /// optimizer is created, so the coefficient stays constant.
    /// For [`EntCoefMode::Auto`], `log_alpha` starts at 0 (i.e., alpha = 1) and an
    /// Adam optimizer is built to tune it toward the given target entropy.
    pub fn new(mode: EntCoefMode, device: tch::Device) -> Self {
        let var_store = nn::VarStore::new(device);
        let path = var_store.root();
        let (log_alpha, target_entropy, opt) = match mode {
            EntCoefMode::Fix(alpha) => {
                // Store ln(alpha) so that alpha() == exp(log_alpha) == alpha.
                let init = nn::Init::Const(alpha.ln());
                let log_alpha = path.var("log_alpha", &[1], init);
                (log_alpha, None, None)
            }
            EntCoefMode::Auto(target_entropy, learning_rate) => {
                let init = nn::Init::Const(0.0);
                let log_alpha = path.var("log_alpha", &[1], init);
                let opt = nn::Adam::default()
                    .build(&var_store, learning_rate)
                    .expect("failed to build Adam optimizer for the entropy coefficient");
                (log_alpha, Some(target_entropy), Some(opt))
            }
        };

        Self {
            var_store,
            log_alpha,
            opt,
            target_entropy,
        }
    }

    /// Returns the entropy coefficient `alpha = exp(log_alpha)`.
    ///
    /// The result is detached so that no gradient flows back through it.
    pub fn alpha(&self) -> Tensor {
        self.log_alpha.detach().exp()
    }

    /// Does an optimization step on `log_alpha` given a loss.
    ///
    /// No-op in [`EntCoefMode::Fix`] mode (no optimizer is present).
    pub fn backward_step(&mut self, loss: &Tensor) {
        if let Some(opt) = &mut self.opt {
            opt.backward_step(loss);
        }
    }

    /// Updates `log_alpha` given log-probabilities of sampled actions.
    ///
    /// Minimizes `-mean(log_alpha * (logp + target_entropy))`, where the
    /// second factor is detached so the gradient flows only through
    /// `log_alpha`. No-op in [`EntCoefMode::Fix`] mode.
    pub fn update(&mut self, logp: &Tensor) {
        // f64 is Copy; binding by value avoids borrowing self across the
        // mutable call to backward_step below.
        if let Some(target_entropy) = self.target_entropy {
            let target_entropy = Tensor::from(target_entropy);
            let loss = -(&self.log_alpha * (logp + target_entropy).detach()).mean(tch::Kind::Float);
            self.backward_step(&loss);
        }
    }

    /// Saves the parameters into a file.
    pub fn save<T: AsRef<Path>>(&self, path: T) -> Result<()> {
        self.var_store.save(&path)?;
        info!("Save entropy coefficient to {:?}", path.as_ref());
        let vs = self.var_store.variables();
        for (name, _) in vs.iter() {
            trace!("Save variable {}", name);
        }
        Ok(())
    }

    /// Loads the parameters from a file.
    pub fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()> {
        self.var_store.load(&path)?;
        info!("Load entropy coefficient from {:?}", path.as_ref());
        Ok(())
    }
}