border_candle_agent/sac/
ent_coef.rs1use std::convert::TryFrom;
3
4use crate::opt::{Optimizer, OptimizerConfig};
5use anyhow::Result;
6use candle_core::{DType, Device, Tensor};
7use candle_nn::{init::Init, VarBuilder, VarMap};
8use log::info;
9use serde::{Deserialize, Serialize};
10use std::path::Path;
11
12#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
14pub enum EntCoefMode {
15 Fix(f64),
17 Auto(f64, f64),
19}
20
21pub struct EntCoef {
23 varmap: VarMap,
24 log_alpha: Tensor,
25 target_entropy: Option<f64>,
26 opt: Option<Optimizer>,
27}
28
29impl EntCoef {
30 pub fn new(mode: EntCoefMode, device: Device) -> Result<Self> {
32 let varmap = VarMap::new();
33 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
34 let (log_alpha, target_entropy, opt) = match mode {
35 EntCoefMode::Fix(alpha) => {
36 let init = Init::Const(alpha.ln());
37 let log_alpha = vb.get_with_hints(1, "log_alpha", init)?;
38 (log_alpha, None, None)
39 }
40 EntCoefMode::Auto(target_entropy, learning_rate) => {
41 let init = Init::Const(0.0);
42 let log_alpha = vb.get_with_hints(1, "log_alpha", init)?;
43 let opt = OptimizerConfig::default()
44 .learning_rate(learning_rate)
45 .build(varmap.all_vars())?;
46 (log_alpha, Some(target_entropy), Some(opt))
47 }
48 };
49
50 Ok(Self {
51 varmap,
52 log_alpha,
53 opt,
54 target_entropy,
55 })
56 }
57
58 pub fn alpha(&self) -> Result<Tensor> {
60 Ok(self.log_alpha.detach().exp()?)
61 }
62
63 pub fn backward_step(&mut self, loss: &Tensor) {
65 if let Some(opt) = &mut self.opt {
66 opt.backward_step(loss).unwrap();
67 }
68 }
69
70 pub fn update(&mut self, logp: &Tensor) -> Result<()> {
72 if let Some(target_entropy) = &self.target_entropy {
73 let target_entropy =
74 Tensor::try_from(*target_entropy as f32)?.to_device(logp.device())?;
75 let loss = {
76 let tmp = (&self.log_alpha * -1f64)?
78 .broadcast_mul(&logp.broadcast_add(&target_entropy)?.detach())?;
79 tmp.mean(0)?
80 };
81 self.backward_step(&loss);
82 }
83 Ok(())
84 }
85
86 pub fn save<T: AsRef<Path>>(&self, path: T) -> Result<()> {
88 self.varmap.save(&path)?;
89 info!("Save entropy coefficient to {:?}", path.as_ref());
90 Ok(())
91 }
92
93 pub fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()> {
95 self.varmap.load(&path)?;
96 info!("Load entropy coefficient from {:?}", path.as_ref());
97 Ok(())
98 }
99}