1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
use anyhow::Result;
use log::{info, trace};
use serde::{Deserialize, Serialize};
use std::{borrow::Borrow, path::Path};
use tch::{nn, nn::OptimizerConfig, Tensor};
/// Selects how the entropy coefficient (alpha) held by `EntCoef` is set up.
///
/// The variant is consumed by `EntCoef::new`, which decides the initial value
/// of `log_alpha` and whether an optimizer is created.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub enum EntCoefMode {
/// Fixed coefficient. The payload is alpha itself; `EntCoef::new` initializes
/// `log_alpha` to its natural logarithm and creates no optimizer.
Fix(f64),
/// Trainable coefficient, given as `(target_entropy, learning_rate)`.
/// `EntCoef::new` initializes `log_alpha` to `0.0` and builds an Adam
/// optimizer with the given learning rate.
Auto(f64, f64),
}
/// Entropy coefficient stored as `log_alpha`, optionally trained by an optimizer.
pub struct EntCoef {
/// Variable store in which the `log_alpha` variable is registered; it is the
/// object that `save` and `load` operate on.
var_store: nn::VarStore,
/// Variable named `"log_alpha"`, created with dimensions `[1]`.
log_alpha: Tensor,
/// Target entropy; `Some` only when constructed with `EntCoefMode::Auto`.
target_entropy: Option<f64>,
/// Adam optimizer over `var_store`; `Some` only when constructed with
/// `EntCoefMode::Auto`.
opt: Option<nn::Optimizer<nn::Adam>>,
}
impl EntCoef {
    /// Creates an entropy coefficient on `device`, configured by `mode`.
    ///
    /// * `EntCoefMode::Fix(alpha)` starts `log_alpha` at `ln(alpha)` and
    ///   creates no optimizer and no target entropy.
    /// * `EntCoefMode::Auto(target_entropy, learning_rate)` starts `log_alpha`
    ///   at `0.0` and builds an Adam optimizer with `learning_rate`.
    ///
    /// # Panics
    ///
    /// Panics if building the Adam optimizer returns an error.
    pub fn new(mode: EntCoefMode, device: tch::Device) -> Self {
        let var_store = nn::VarStore::new(device);
        let root = &var_store.root();

        // Reduce the mode to the three scalars that differ between variants.
        let (initial_log_alpha, target_entropy, learning_rate) = match mode {
            EntCoefMode::Fix(alpha) => (alpha.ln(), None, None),
            EntCoefMode::Auto(target_entropy, learning_rate) => {
                (0.0, Some(target_entropy), Some(learning_rate))
            }
        };

        // The variable is registered before the optimizer is built.
        let log_alpha = root
            .borrow()
            .var("log_alpha", &[1], nn::Init::Const(initial_log_alpha));
        let opt = learning_rate.map(|lr| nn::Adam::default().build(&var_store, lr).unwrap());

        Self {
            var_store,
            log_alpha,
            opt,
            target_entropy,
        }
    }

    /// Returns `exp(log_alpha)`, computed from a detached `log_alpha`.
    pub fn alpha(&self) -> Tensor {
        let detached = self.log_alpha.detach();
        detached.exp()
    }

    /// Runs the optimizer's backward step on `loss`.
    ///
    /// Does nothing when no optimizer exists (the coefficient was created
    /// with `EntCoefMode::Fix`).
    pub fn backward_step(&mut self, loss: &Tensor) {
        if let Some(optimizer) = self.opt.as_mut() {
            optimizer.backward_step(loss);
        }
    }

    /// Updates `log_alpha` from the log-probabilities `logp`.
    ///
    /// The loss is `-mean(log_alpha * detach(logp + target_entropy))`, with the
    /// mean taken as `tch::Kind::Float`. Does nothing when no target entropy
    /// is set.
    pub fn update(&mut self, logp: &Tensor) {
        let target = match self.target_entropy {
            Some(value) => Tensor::from(value),
            None => return,
        };
        // Detached so that gradients reach only `log_alpha`.
        let entropy_gap = (logp + target).detach();
        let loss = -(&self.log_alpha * entropy_gap).mean(tch::Kind::Float);
        self.backward_step(&loss);
    }

    /// Saves the variable store to `path`, logging the path and each variable name.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the variable store if saving fails.
    pub fn save<T: AsRef<Path>>(&self, path: T) -> Result<()> {
        let destination = path.as_ref();
        self.var_store.save(destination)?;
        info!("Save entropy coefficient to {:?}", destination);
        self.var_store
            .variables()
            .keys()
            .for_each(|name| trace!("Save variable {}", name));
        Ok(())
    }

    /// Loads the variable store from `path` and logs the path.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the variable store if loading fails.
    pub fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()> {
        let source = path.as_ref();
        self.var_store.load(source)?;
        info!("Load entropy coefficient from {:?}", source);
        Ok(())
    }
}