candle_optimisers/adamax.rs

/*!
Adamax optimiser

An Adam optimiser based on infinity norm, described in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)

Pseudocode (including decoupling of weight decay):

$$
\\begin{aligned}
    &\\rule{110mm}{0.4pt}                                                                 \\\\
    &\\textbf{input}      : \\gamma \\text{ (lr)}, \\beta_1, \\beta_2
        \\text{ (betas)},\\theta_0 \\text{ (params)},f(\\theta) \\text{ (objective)},
        \\: \\lambda \\text{ (weight decay)},                                                \\\\
    &\\hspace{13mm}    \\epsilon \\text{ (epsilon)}                                          \\\\
    &\\textbf{initialize} :  m_0 \\leftarrow 0 \\text{ (first moment)},
        u_0 \\leftarrow 0 \\text{ (infinity norm)}                                 \\\\[-1.ex]
    &\\rule{110mm}{0.4pt}                                                                 \\\\
    &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do}                         \\\\
    &\\hspace{5mm}g_t           \\leftarrow   \\nabla_{\\theta} f_t (\\theta_{t-1})           \\\\
    &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some}                        \\\\
    &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled}                       \\\\
    &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1}                    \\\\
    &\\hspace{10mm}\\textbf{else}                                                              \\\\
    &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda  \\theta_{t-1}                            \\\\
    &\\hspace{5mm}m_t      \\leftarrow   \\beta_1 m_{t-1} + (1 - \\beta_1) g_t               \\\\
    &\\hspace{5mm}u_t      \\leftarrow   \\mathrm{max}(\\beta_2 u_{t-1}, |g_{t}|+\\epsilon)   \\\\
    &\\hspace{5mm}\\theta_t \\leftarrow \\theta_{t-1} - \\frac{\\gamma m_t}{(1-\\beta^t_1) u_t} \\\\
    &\\rule{110mm}{0.4pt}                                                          \\\\[-1.ex]
    &\\textbf{return} \\:  \\theta_t                                                     \\\\[-1.ex]
    &\\rule{110mm}{0.4pt}                                                          \\\\[-1.ex]
\\end{aligned}
$$
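
Usage sketch (the `candle_optimisers::adamax` import path below is an assumption based on this file's location; adjust it to however the crate actually exposes the module):

```ignore
use candle_core::{Device, Result, Var};
use candle_nn::Optimizer;
use candle_optimisers::adamax::{Adamax, ParamsAdaMax};

fn main() -> Result<()> {
    // A single trainable parameter.
    let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;

    // Adamax with a custom learning rate; the remaining parameters keep their defaults.
    let params = ParamsAdaMax {
        lr: 0.004,
        ..Default::default()
    };
    let mut optim = Adamax::new(vec![w.clone()], params)?;

    // Toy objective: sum of squares of `w`.
    let loss = w.as_tensor().sqr()?.sum_all()?;
    // `backward_step` computes the gradients and applies one Adamax update.
    optim.backward_step(&loss)?;
    Ok(())
}
```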
*/

use candle_core::{Result, Var};
use candle_nn::optim::Optimizer;

use crate::{Decay, OptimParams};

/// Adamax optimiser
///
/// An Adam optimiser based on infinity norm, described in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)
#[derive(Debug)]
pub struct Adamax {
    vars: Vec<VarAdaMax>,
    params: ParamsAdaMax,
    t: f64,
}

/// Per-variable state: the parameter plus its Adamax buffers
#[derive(Debug)]
struct VarAdaMax {
    /// Parameter being optimised
    theta: Var,
    /// Exponential moving average of the gradient (first moment)
    m: Var,
    /// Exponentially weighted infinity norm of the gradient
    u: Var,
}

/// Parameters for the Adamax optimiser
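///
/// Unset fields can be taken from `Default`, as in the crate's own tests (a
/// sketch; the `candle_optimisers::adamax` path is an assumption based on this
/// file's location):
///
/// ```ignore
/// use candle_optimisers::adamax::ParamsAdaMax;
///
/// let params = ParamsAdaMax {
///     lr: 0.004,
///     ..Default::default()
/// };
/// assert_eq!(params.beta_1, 0.9);
/// assert_eq!(params.eps, 1e-8);
/// ```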
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct ParamsAdaMax {
    /// Learning rate
    pub lr: f64,
    /// Coefficient for the moving average of the first moment
    pub beta_1: f64,
    /// Coefficient for the exponentially weighted infinity norm
    /// (used in place of Adam's second moment)
    pub beta_2: f64,
    /// Weight decay (coupled or decoupled)
    pub weight_decay: Option<Decay>,
    /// Term added to the denominator to improve numerical stability
    pub eps: f64,
}

impl Default for ParamsAdaMax {
    fn default() -> Self {
        Self {
            lr: 1.0,
            beta_1: 0.9,
            beta_2: 0.999,
            weight_decay: None,
            eps: 1e-8,
        }
    }
}

impl Optimizer for Adamax {
    type Config = ParamsAdaMax;

    fn new(vars: Vec<Var>, params: ParamsAdaMax) -> Result<Self> {
        // Only floating point tensors are optimised; each gets zeroed buffers
        // for the first moment (m) and the infinity norm (u).
        let vars = vars
            .into_iter()
            .filter(|var| var.dtype().is_float())
            .map(|var| {
                let dtype = var.dtype();
                let shape = var.shape();
                let device = var.device();
                let m = Var::zeros(shape, dtype, device)?;
                let u = Var::zeros(shape, dtype, device)?;
                Ok(VarAdaMax { theta: var, m, u })
            })
            .collect::<Result<Vec<VarAdaMax>>>()?;
        // The step counter starts at 1 so the first update already applies the
        // bias correction 1 - beta_1^t.
        Ok(Self {
            vars,
            params,
            t: 1.,
        })
    }

    fn learning_rate(&self) -> f64 {
        self.params.lr
    }

    fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
        if let Some(decay) = self.params.weight_decay {
            match decay {
                Decay::WeightDecay(decay) => {
                    for var in &self.vars {
                        let theta = &var.theta;
                        let m = &var.m;
                        let u = &var.u;
                        if let Some(grad) = grads.get(theta) {
                            // Coupled weight decay: g_t <- g_t + lambda * theta_{t-1}
                            let grad = &(grad + (decay * theta.as_tensor())?)?;
                            // m_t = beta_1 * m_{t-1} + (1 - beta_1) * g_t
                            let m_next = ((self.params.beta_1 * m.as_tensor())?
                                + (1. - self.params.beta_1) * grad)?;
                            // u_t = max(beta_2 * u_{t-1}, |g_t| + eps)
                            let u_next = (self.params.beta_2 * u.as_tensor())?
                                .maximum(&(grad.abs()? + self.params.eps)?)?;
                            // theta_t = theta_{t-1} - lr * m_t / ((1 - beta_1^t) * u_t)
                            let delta = (&m_next * self.params.lr)?
                                .div(&(&u_next * (1. - self.params.beta_1.powf(self.t)))?)?;
                            theta.set(&theta.sub(&delta)?)?;
                            m.set(&m_next)?;
                            u.set(&u_next)?;
                        }
                    }
                }
                Decay::DecoupledWeightDecay(decay) => {
                    for var in &self.vars {
                        let theta = &var.theta;
                        let m = &var.m;
                        let u = &var.u;
                        if let Some(grad) = grads.get(theta) {
                            // Decoupled weight decay: theta <- theta * (1 - lr * lambda)
                            // before the Adamax update.
                            theta
                                .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?;
                            let m_next = ((self.params.beta_1 * m.as_tensor())?
                                + (1. - self.params.beta_1) * grad)?;
                            let u_next = (self.params.beta_2 * u.as_tensor())?
                                .maximum(&(grad.abs()? + self.params.eps)?)?;
                            let delta = (&m_next * self.params.lr)?
                                .div(&(&u_next * (1. - self.params.beta_1.powf(self.t)))?)?;
                            theta.set(&theta.sub(&delta)?)?;
                            m.set(&m_next)?;
                            u.set(&u_next)?;
                        }
                    }
                }
            }
        } else {
            // No weight decay: plain Adamax update.
            for var in &self.vars {
                let theta = &var.theta;
                let m = &var.m;
                let u = &var.u;
                if let Some(grad) = grads.get(theta) {
                    let m_next =
                        ((self.params.beta_1 * m.as_tensor())? + (1. - self.params.beta_1) * grad)?;
                    let u_next = (self.params.beta_2 * u.as_tensor())?
                        .maximum(&(grad.abs()? + self.params.eps)?)?;
                    let delta = (&m_next * self.params.lr)?
                        .div(&(&u_next * (1. - self.params.beta_1.powf(self.t)))?)?;
                    theta.set(&theta.sub(&delta)?)?;
                    m.set(&m_next)?;
                    u.set(&u_next)?;
                }
            }
        }
        // Advance the step count used for bias correction.
        self.t += 1.;
        Ok(())
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.params.lr = lr;
    }
}

impl OptimParams for Adamax {
    fn params(&self) -> &Self::Config {
        &self.params
    }

    fn set_params(&mut self, config: Self::Config) {
        self.params = config;
    }
}

impl Adamax {
    /// Return the vars being optimised
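    ///
    /// A sketch mirroring the crate's tests (the `candle_optimisers::adamax`
    /// path is an assumption based on this file's location):
    ///
    /// ```ignore
    /// use candle_core::{Device, Var};
    /// use candle_nn::Optimizer;
    /// use candle_optimisers::adamax::{Adamax, ParamsAdaMax};
    ///
    /// let w = Var::new(&[[3f32, 1.]], &Device::Cpu).unwrap();
    /// let optim = Adamax::new(vec![w.clone()], ParamsAdaMax::default()).unwrap();
    /// // Recover ownership of the optimised variables once training is done.
    /// let vars: Vec<Var> = optim.into_inner();
    /// ```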
    #[must_use]
    pub fn into_inner(self) -> Vec<Var> {
        self.vars.into_iter().map(|v| v.theta).collect()
    }
}

#[cfg(test)]
mod tests {
    use anyhow::Result;
    use assert_approx_eq::assert_approx_eq;
    use candle_core::{Device, Var};
    use candle_nn::Optimizer;

    use super::*;

    #[test]
    fn lr_test() -> Result<()> {
        let params = ParamsAdaMax {
            lr: 0.004,
            ..Default::default()
        };
        // Build an optimiser over dummy variables and check the learning rate accessors.
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = Adamax::new(vec![w.clone(), b.clone()], params)?;
        assert_approx_eq!(0.004, optim.learning_rate());
        optim.set_learning_rate(0.002);
        assert_approx_eq!(0.002, optim.learning_rate());
        Ok(())
    }

    #[test]
    fn into_inner_test() -> Result<()> {
        let params = ParamsAdaMax::default();
        let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
        let b = Var::new(-2f32, &Device::Cpu)?;
        let optim = Adamax::new(vec![w.clone(), b.clone()], params)?;
        let inner = optim.into_inner();
        assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
        assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
        Ok(())
    }

    #[test]
    fn params_test() -> Result<()> {
        let params = ParamsAdaMax {
            lr: 0.004,
            ..Default::default()
        };
        // Check that the optimiser's parameters can be read back and replaced.
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = Adamax::new(vec![w.clone(), b.clone()], params.clone())?;
        assert_eq!(params, optim.params().clone());
        let new_params = ParamsAdaMax {
            lr: 0.002,
            ..Default::default()
        };
        optim.set_params(new_params.clone());
        assert_eq!(new_params, optim.params().clone());
        Ok(())
    }
}