border_candle_agent/
opt.rs

//! Optimizers.
use anyhow::Result;
use candle_core::{Tensor, Var};
use candle_nn::{AdamW, Optimizer as _, ParamsAdamW};
use candle_optimisers::adam::{Adam, ParamsAdam};
use serde::{Deserialize, Serialize};

/// Configuration of an optimizer for training neural networks in an RL agent.
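///
/// The configuration can be (de)serialized with `serde`. As a minimal sketch
/// (assuming the crate path `border_candle_agent::opt` and that `serde_json`
/// is available), omitted `AdamW` fields fall back to the [`ParamsAdamW`]
/// defaults via the `serde(default = "...")` attributes:
///
/// ```ignore
/// use border_candle_agent::opt::OptimizerConfig;
///
/// // Only `lr` is given; beta1, beta2, eps and weight_decay use their defaults.
/// let json = r#"{"AdamW": {"lr": 3e-4}}"#;
/// let config: OptimizerConfig = serde_json::from_str(json).unwrap();
/// ```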
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub enum OptimizerConfig {
    /// AdamW optimizer.
    AdamW {
        /// Learning rate.
        lr: f64,
        /// Exponential decay rate for the first moment estimates.
        #[serde(default = "default_beta1")]
        beta1: f64,
        /// Exponential decay rate for the second moment estimates.
        #[serde(default = "default_beta2")]
        beta2: f64,
        /// Term added to the denominator for numerical stability.
        #[serde(default = "default_eps")]
        eps: f64,
        /// Weight decay coefficient.
        #[serde(default = "default_weight_decay")]
        weight_decay: f64,
    },

    /// Adam optimizer.
    Adam {
        /// Learning rate.
        lr: f64,
    },
}

fn default_beta1() -> f64 {
    ParamsAdamW::default().beta1
}

fn default_beta2() -> f64 {
    ParamsAdamW::default().beta2
}

fn default_eps() -> f64 {
    ParamsAdamW::default().eps
}

fn default_weight_decay() -> f64 {
    ParamsAdamW::default().weight_decay
}

impl OptimizerConfig {
    /// Constructs an optimizer as specified by this configuration.
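    ///
    /// A minimal sketch of building an optimizer from a configuration (the crate
    /// path `border_candle_agent::opt` and the CPU device are assumptions here):
    ///
    /// ```ignore
    /// use border_candle_agent::opt::OptimizerConfig;
    /// use candle_core::{DType, Device, Var};
    ///
    /// // A single trainable 2x2 variable; a real agent would pass all of its
    /// // network parameters.
    /// let w = Var::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
    /// let opt = OptimizerConfig::default()
    ///     .learning_rate(1e-3)
    ///     .build(vec![w])
    ///     .unwrap();
    /// ```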
    pub fn build(&self, vars: Vec<Var>) -> Result<Optimizer> {
        match self {
            OptimizerConfig::AdamW {
                lr,
                beta1,
                beta2,
                eps,
                weight_decay,
            } => {
                let params = ParamsAdamW {
                    lr: *lr,
                    beta1: *beta1,
                    beta2: *beta2,
                    eps: *eps,
                    weight_decay: *weight_decay,
                };
                let opt = AdamW::new(vars, params)?;
                Ok(Optimizer::AdamW(opt))
            }
            OptimizerConfig::Adam { lr } => {
                let params = ParamsAdam {
                    lr: *lr,
                    ..ParamsAdam::default()
                };
                let opt = Adam::new(vars, params)?;
                Ok(Optimizer::Adam(opt))
            }
        }
    }

    /// Overrides the learning rate, leaving the other parameters unchanged.
    pub fn learning_rate(self, lr: f64) -> Self {
        match self {
            Self::AdamW {
                lr: _,
                beta1,
                beta2,
                eps,
                weight_decay,
            } => Self::AdamW {
                lr,
                beta1,
                beta2,
                eps,
                weight_decay,
            },
            Self::Adam { lr: _ } => Self::Adam { lr },
        }
    }
}

impl Default for OptimizerConfig {
    fn default() -> Self {
        let params = ParamsAdamW::default();
        Self::AdamW {
            lr: params.lr,
            beta1: params.beta1,
            beta2: params.beta2,
            eps: params.eps,
            weight_decay: params.weight_decay,
        }
    }
}

/// Optimizers.
///
/// This is a thin wrapper around [`candle_nn::optim::Optimizer`].
///
/// [`candle_nn::optim::Optimizer`]: https://docs.rs/candle-nn/0.4.1/candle_nn/optim/trait.Optimizer.html
pub enum Optimizer {
    /// AdamW optimizer.
    AdamW(AdamW),

    /// Adam optimizer.
    Adam(Adam),
}

impl Optimizer {
    /// Computes the gradients of the given loss and applies an optimization step.
    pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
        match self {
            Self::AdamW(opt) => Ok(opt.backward_step(loss)?),
            Self::Adam(opt) => Ok(opt.backward_step(loss)?),
        }
    }

    /// Applies an optimization step using precomputed gradients.
    pub fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
        match self {
            Self::AdamW(opt) => Ok(opt.step(grads)?),
            Self::Adam(opt) => Ok(opt.step(grads)?),
        }
    }
}
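
// A minimal smoke test of the AdamW path. This is a sketch rather than part of
// the original module: it assumes that running on the CPU device is acceptable
// and exercises only the default (AdamW) configuration.
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use candle_core::{DType, Device, Var};

    #[test]
    fn adamw_backward_step_updates_vars() -> Result<()> {
        // A single 2x2 trainable variable initialized to ones.
        let w = Var::ones((2, 2), DType::F32, &Device::Cpu)?;

        // Default AdamW configuration with an overridden learning rate.
        let mut opt = OptimizerConfig::default()
            .learning_rate(1e-1)
            .build(vec![w.clone()])?;

        // loss = sum(w^2), whose gradient with respect to `w` is 2 * w.
        let loss = w.as_tensor().sqr()?.sum_all()?;
        opt.backward_step(&loss)?;

        // After one step the parameter values should have moved below 1.0.
        let updated = w.as_tensor().to_vec2::<f32>()?;
        assert!(updated.iter().flatten().all(|v| *v < 1.0));
        Ok(())
    }
}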