// border_tch_agent/sac/ent_coef.rs — entropy coefficient of SAC.
use anyhow::{Context, Result};
use log::{info, trace};
use serde::{Deserialize, Serialize};
use std::path::Path;
use tch::{nn, nn::OptimizerConfig, Tensor};

8#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
10pub enum EntCoefMode {
11 Fix(f64),
13 Auto(f64, f64),
15}
16
17pub struct EntCoef {
19 var_store: nn::VarStore,
20 log_alpha: Tensor,
21 target_entropy: Option<f64>,
22 opt: Option<nn::Optimizer>,
23}
24
25impl EntCoef {
26 pub fn new(mode: EntCoefMode, device: tch::Device) -> Self {
28 let var_store = nn::VarStore::new(device);
29 let path = &var_store.root();
30 let (log_alpha, target_entropy, opt) = match mode {
31 EntCoefMode::Fix(alpha) => {
32 let init = nn::Init::Const(alpha.ln());
33 let log_alpha = path.var("log_alpha", &[1], init);
35 (log_alpha, None, None)
36 }
37 EntCoefMode::Auto(target_entropy, learning_rate) => {
38 let init = nn::Init::Const(0.0);
39 let log_alpha = path.var("log_alpha", &[1], init);
41 let opt = nn::Adam::default()
42 .build(&var_store, learning_rate)
43 .unwrap();
44 (log_alpha, Some(target_entropy), Some(opt))
45 }
46 };
47
48 Self {
49 var_store,
50 log_alpha,
51 opt,
52 target_entropy,
53 }
54 }
55
56 pub fn alpha(&self) -> Tensor {
58 self.log_alpha.detach().exp()
59 }
60
61 pub fn backward_step(&mut self, loss: &Tensor) {
63 if let Some(opt) = &mut self.opt {
64 opt.backward_step(loss);
65 }
66 }
67
68 pub fn update(&mut self, logp: &Tensor) {
70 if let Some(target_entropy) = &self.target_entropy {
71 let target_entropy = Tensor::from(*target_entropy);
72 let loss = -(&self.log_alpha * (logp + target_entropy).detach()).mean(tch::Kind::Float);
73 self.backward_step(&loss);
74 }
75 }
76
77 pub fn save<T: AsRef<Path>>(&self, path: T) -> Result<()> {
79 self.var_store.save(&path)?;
80 info!("Save entropy coefficient to {:?}", path.as_ref());
81 let vs = self.var_store.variables();
82 for (name, _) in vs.iter() {
83 trace!("Save variable {}", name);
84 }
85 Ok(())
86 }
87
88 pub fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()> {
90 self.var_store.load(&path)?;
91 info!("Load entropy coefficient from {:?}", path.as_ref());
92 Ok(())
93 }
94}