// candle_optimisers/adam.rs

/*!
Adam optimiser (including AdamW)

This includes AdamW via the use of decoupled weight decay.

Described in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)
and [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101)

The AMSGrad variant is also implemented, described in [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ)

Pseudocode (including the decoupled weight decay of AdamW):

Note that the AMSGrad branch differs from the PyTorch pseudocode; as far as I can tell it is nevertheless equivalent to the PyTorch implementation.

$$
\\begin{aligned}
    &\\rule{110mm}{0.4pt}                                                                 \\\\
    &\\textbf{input}      : \\gamma \\text{ (lr)}, \\beta_1, \\beta_2
    \\text{ (betas)},\\theta_0 \\text{ (params)},f(\\theta) \\text{ (objective)}          \\\\
    &\\hspace{13mm}      \\lambda \\text{ (weight decay)},  \\: \\textit{amsgrad}    \\\\
    &\\textbf{initialize} :  m_0 \\leftarrow 0 \\text{ (first moment)},
                v_0\\leftarrow 0 \\text{ (second moment)},\\: v_0^{max}\\leftarrow 0                          \\\\[-1.ex]
    &\\rule{110mm}{0.4pt}                                                                 \\\\
    &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do}                         \\\\
    &\\hspace{5mm}g_t           \\leftarrow   \\nabla_{\\theta} f_t (\\theta_{t-1})           \\\\
    &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some}                        \\\\
    &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled}                       \\\\
    &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1}                    \\\\
    &\\hspace{10mm}\\textbf{else}                                                              \\\\
    &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda  \\theta_{t-1}                            \\\\
    &\\hspace{5mm}m_t           \\leftarrow   \\beta_1 m_{t-1} + (1 - \\beta_1) g_t          \\\\
    &\\hspace{5mm}v_t           \\leftarrow   \\beta_2 v_{t-1} + (1-\\beta_2) g^2_t          \\\\
    &\\hspace{5mm}\\widehat{m_t} \\leftarrow   m_t/\\big(1-\\beta_1^t \\big)                   \\\\
    &\\hspace{5mm}\\textbf{if} \\: \\textit{amsgrad}                                          \\\\
    &\\hspace{10mm}v_t^{max} \\leftarrow \\mathrm{max}(v_{t-1}^{max}, v_t)    \\\\
    &\\hspace{10mm}\\widehat{v_t}^{max} \\leftarrow v_t^{max}   /\\big(1-\\beta_2^t \\big)  \\\\
    &\\hspace{10mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\widehat{m_t}/
        \\big(\\sqrt{\\widehat{v_t}^{max}} + \\epsilon \\big)                                 \\\\
    &\\hspace{5mm}\\textbf{else}                                                           \\\\
    &\\hspace{10mm}\\widehat{v_t} \\leftarrow   v_t/\\big(1-\\beta_2^t \\big)                   \\\\
    &\\hspace{10mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\widehat{m_t}/
        \\big(\\sqrt{\\widehat{v_t}} + \\epsilon \\big)                                       \\\\
    &\\rule{110mm}{0.4pt}                                                          \\\\[-1.ex]
    &\\textbf{return} \\:  \\theta_t                                                     \\\\[-1.ex]
    &\\rule{110mm}{0.4pt}                                                          \\\\[-1.ex]
\\end{aligned}
$$
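
For example, a minimal usage sketch (hyper-parameter values are illustrative;
using the decoupled weight decay variant gives AdamW behaviour, and
`backward_step` comes from the `candle_nn::Optimizer` trait):

```no_run
use candle_core::{Device, Result, Var};
use candle_nn::Optimizer;
use candle_optimisers::{
    adam::{Adam, ParamsAdam},
    Decay,
};

fn main() -> Result<()> {
    // parameters to be optimised
    let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
    let b = Var::new(0f32, &Device::Cpu)?;
    // decoupled weight decay => AdamW
    let params = ParamsAdam {
        lr: 0.001,
        weight_decay: Some(Decay::DecoupledWeightDecay(0.01)),
        ..Default::default()
    };
    let mut optimiser = Adam::new(vec![w.clone(), b.clone()], params)?;
    // per training step: compute the loss, then call
    // `optimiser.backward_step(&loss)?` (or `optimiser.step(&grads)?`)
    Ok(())
}
```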
*/

use candle_core::{Result, Var};
use candle_nn::optim::Optimizer;
use log::warn;

use crate::{Decay, OptimParams};

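/// Common interface over the per-variable optimiser state, letting the plain
/// Adam and AMSGrad code paths share the public `Adam` wrapper.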
trait AdamInner {
    fn new(vars: Vec<Var>) -> Result<Self>
    where
        Self: Sized;
    fn into_inner(self) -> Vec<Var>;
    fn inner_step(
        &self,
        params: &ParamsAdam,
        grads: &candle_core::backprop::GradStore,
        t: f64,
    ) -> Result<()>;
}

/// Adam optimiser
///
/// This includes AdamW via the use of decoupled weight decay.
///
/// Described in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)
/// and [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101)
///
/// The AMSGrad variant is also implemented, described in [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ)
#[derive(Debug)]
pub struct Adam {
    vars: VarAdam,
    params: ParamsAdam,
    t: f64,
}

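/// Per-variable state for plain Adam: the parameter `theta` and the exponential
/// moving averages of the gradient (`m`) and of the squared gradient (`v`).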
#[derive(Debug)]
struct VarAdamBase {
    theta: Var,
    m: Var,
    v: Var,
}

#[derive(Debug)]
struct VecAdamBase(Vec<VarAdamBase>);

impl AdamInner for VecAdamBase {
    fn new(vars: Vec<Var>) -> Result<Self>
    where
        Self: Sized,
    {
        Ok(VecAdamBase(
            vars.into_iter()
                .filter(|var| var.dtype().is_float())
                .map(|var| {
                    let dtype = var.dtype();
                    let shape = var.shape();
                    let device = var.device();
                    let m = Var::zeros(shape, dtype, device)?;
                    let v = Var::zeros(shape, dtype, device)?;
                    Ok(VarAdamBase { theta: var, m, v })
                })
                .collect::<Result<Vec<VarAdamBase>>>()?,
        ))
    }

    fn into_inner(self) -> Vec<Var> {
        self.0.into_iter().map(|var| var.theta).collect()
    }

    fn inner_step(
        &self,
        params: &ParamsAdam,
        grads: &candle_core::backprop::GradStore,
        t: f64,
    ) -> Result<()> {
        if let Some(decay) = params.weight_decay {
            match decay {
                Decay::WeightDecay(decay) => {
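                    // classic Adam weight decay: fold the L2 penalty into the gradient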
                    for var in &self.0 {
                        let theta = &var.theta;
                        let m = &var.m;
                        let v = &var.v;
                        if let Some(grad) = grads.get(theta) {
                            let grad = &(grad + (decay * theta.as_tensor())?)?;
                            let m_next = ((params.beta_1 * m.as_tensor())?
                                + ((1. - params.beta_1) * grad)?)?;
                            let v_next = ((params.beta_2 * v.as_tensor())?
                                + ((1. - params.beta_2) * grad.powf(2.)?)?)?;
                            let m_hat = (&m_next / (1. - (params.beta_1).powf(t)))?;
                            let v_hat = (&v_next / (1. - params.beta_2.powf(t)))?;
                            let delta =
                                (m_hat * params.lr)?.div(&(v_hat.powf(0.5)? + params.eps)?)?;
                            theta.set(&theta.sub(&(delta))?)?;
                            m.set(&m_next)?;
                            v.set(&v_next)?;
                        }
                    }
                }
                Decay::DecoupledWeightDecay(decay) => {
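                    // decoupled weight decay (AdamW): shrink the parameters directly,
                    // independently of the gradient-based update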
                    for var in &self.0 {
                        let theta = &var.theta;
                        let m = &var.m;
                        let v = &var.v;
                        if let Some(grad) = grads.get(theta) {
                            theta.set(&(theta.as_tensor() * params.lr.mul_add(-decay, 1.))?)?;
                            let m_next = ((params.beta_1 * m.as_tensor())?
                                + ((1. - params.beta_1) * grad)?)?;
                            let v_next = ((params.beta_2 * v.as_tensor())?
                                + ((1. - params.beta_2) * grad.powf(2.)?)?)?;
                            let m_hat = (&m_next / (1. - (params.beta_1).powf(t)))?;
                            let v_hat = (&v_next / (1. - params.beta_2.powf(t)))?;
                            let delta =
                                (m_hat * params.lr)?.div(&(v_hat.powf(0.5)? + params.eps)?)?;
                            theta.set(&theta.sub(&(delta))?)?;
                            m.set(&m_next)?;
                            v.set(&v_next)?;
                        }
                    }
                }
            }
        } else {
            for var in &self.0 {
                let theta = &var.theta;
                let m = &var.m;
                let v = &var.v;
                if let Some(grad) = grads.get(theta) {
                    let m_next =
                        ((params.beta_1 * m.as_tensor())? + ((1. - params.beta_1) * grad)?)?;
                    let v_next = ((params.beta_2 * v.as_tensor())?
                        + ((1. - params.beta_2) * grad.powf(2.)?)?)?;
                    let m_hat = (&m_next / (1. - (params.beta_1).powf(t)))?;
                    let v_hat = (&v_next / (1. - params.beta_2.powf(t)))?;
                    let delta = (m_hat * params.lr)?.div(&(v_hat.powf(0.5)? + params.eps)?)?;
                    theta.set(&theta.sub(&(delta))?)?;
                    m.set(&m_next)?;
                    v.set(&v_next)?;
                }
            }
        }
        Ok(())
    }
}

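/// Per-variable state for the AMSGrad variant: as `VarAdamBase`, plus the running
/// element-wise maximum of the second moment (`vmax`).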
#[derive(Debug)]
struct VarAdamAmsgrad {
    theta: Var,
    m: Var,
    v: Var,
    vmax: Var,
}

#[derive(Debug)]
struct VecAdamAmsgrad(Vec<VarAdamAmsgrad>);

impl AdamInner for VecAdamAmsgrad {
    fn new(vars: Vec<Var>) -> Result<Self>
    where
        Self: Sized,
    {
        Ok(VecAdamAmsgrad(
            vars.into_iter()
                .filter(|var| var.dtype().is_float())
                .map(|var| {
                    let dtype = var.dtype();
                    let shape = var.shape();
                    let device = var.device();
                    let m = Var::zeros(shape, dtype, device)?;
                    let v = Var::zeros(shape, dtype, device)?;
                    let vmax = Var::zeros(shape, dtype, device)?;
                    Ok(VarAdamAmsgrad {
                        theta: var,
                        m,
                        v,
                        vmax,
                    })
                })
                .collect::<Result<Vec<VarAdamAmsgrad>>>()?,
        ))
    }

    fn into_inner(self) -> Vec<Var> {
        self.0.into_iter().map(|var| var.theta).collect()
    }

    fn inner_step(
        &self,
        params: &ParamsAdam,
        grads: &candle_core::backprop::GradStore,
        t: f64,
    ) -> Result<()> {
        if let Some(decay) = params.weight_decay {
            match decay {
                Decay::WeightDecay(decay) => {
                    for var in &self.0 {
                        let theta = &var.theta;
                        let m = &var.m;
                        let v = &var.v;
                        let vmax = &var.vmax;
                        if let Some(grad) = grads.get(theta) {
                            let grad = &(grad + (decay * theta.as_tensor())?)?;
                            let m_next = ((params.beta_1 * m.as_tensor())?
                                + ((1. - params.beta_1) * grad)?)?;
                            let v_next = ((params.beta_2 * v.as_tensor())?
                                + ((1. - params.beta_2) * grad.powf(2.)?)?)?;
                            let m_hat = (&m_next / (1. - (params.beta_1).powf(t)))?;
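                            // AMSGrad: normalise by the bias-corrected running maximum
                            // of the second moment rather than the current estimate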
                            let vmax_next = vmax.maximum(&v_next)?;
                            let v_hat = (&vmax_next / (1. - params.beta_2.powf(t)))?;
                            let delta =
                                (m_hat * params.lr)?.div(&(v_hat.powf(0.5)? + params.eps)?)?;
                            theta.set(&theta.sub(&(delta))?)?;
                            m.set(&m_next)?;
                            v.set(&v_next)?;
                            vmax.set(&vmax_next)?;
                        }
                    }
                }
                Decay::DecoupledWeightDecay(decay) => {
                    for var in &self.0 {
                        let theta = &var.theta;
                        let m = &var.m;
                        let v = &var.v;
                        let vmax = &var.vmax;
                        if let Some(grad) = grads.get(theta) {
                            theta.set(&(theta.as_tensor() * params.lr.mul_add(-decay, 1.))?)?;
                            let m_next = ((params.beta_1 * m.as_tensor())?
                                + ((1. - params.beta_1) * grad)?)?;
                            let v_next = ((params.beta_2 * v.as_tensor())?
                                + ((1. - params.beta_2) * grad.powf(2.)?)?)?;
                            let m_hat = (&m_next / (1. - (params.beta_1).powf(t)))?;
                            let vmax_next = vmax.maximum(&v_next)?;
                            let v_hat = (&vmax_next / (1. - params.beta_2.powf(t)))?;
                            let delta =
                                (m_hat * params.lr)?.div(&(v_hat.powf(0.5)? + params.eps)?)?;
                            theta.set(&theta.sub(&(delta))?)?;
                            m.set(&m_next)?;
                            v.set(&v_next)?;
                            vmax.set(&vmax_next)?;
                        }
                    }
                }
            }
        } else {
            for var in &self.0 {
                let theta = &var.theta;
                let m = &var.m;
                let v = &var.v;
                let vmax = &var.vmax;
                if let Some(grad) = grads.get(theta) {
                    let m_next =
                        ((params.beta_1 * m.as_tensor())? + ((1. - params.beta_1) * grad)?)?;
                    let v_next = ((params.beta_2 * v.as_tensor())?
                        + ((1. - params.beta_2) * grad.powf(2.)?)?)?;
                    let m_hat = (&m_next / (1. - (params.beta_1).powf(t)))?;
                    let vmax_next = vmax.maximum(&v_next)?;
                    let v_hat = (&vmax_next / (1. - params.beta_2.powf(t)))?;
                    let delta = (m_hat * params.lr)?.div(&(v_hat.powf(0.5)? + params.eps)?)?;
                    theta.set(&theta.sub(&(delta))?)?;
                    m.set(&m_next)?;
                    v.set(&v_next)?;
                    vmax.set(&vmax_next)?;
                }
            }
        }
        Ok(())
    }
}

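/// Optimiser state: either plain Adam or the AMSGrad variant, which additionally
/// tracks the running maximum of the second moment.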
#[derive(Debug)]
enum VarAdam {
    VecAdamBase(VecAdamBase),
    VecAdamAmsgrad(VecAdamAmsgrad),
}

/// Parameters for the Adam optimiser
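///
/// For AdamW-style decoupled weight decay, set `weight_decay` accordingly
/// (a sketch; the values are illustrative):
///
/// ```
/// use candle_optimisers::{adam::ParamsAdam, Decay};
///
/// let params = ParamsAdam {
///     lr: 3e-4,
///     weight_decay: Some(Decay::DecoupledWeightDecay(0.01)),
///     ..Default::default()
/// };
/// assert!(params.weight_decay.is_some());
/// ```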
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct ParamsAdam {
    /// Learning rate
    pub lr: f64,
    /// Coefficient for the moving average of the first moment
    pub beta_1: f64,
    /// Coefficient for the moving average of the second moment
    pub beta_2: f64,
    /// Term added to the denominator to improve numerical stability
    pub eps: f64,
    /// Optional weight decay: `Decay::WeightDecay` gives classic L2 regularisation,
    /// `Decay::DecoupledWeightDecay` gives AdamW-style decoupled decay
    pub weight_decay: Option<Decay>,
    /// Whether to use the AMSGrad variant
    pub amsgrad: bool,
}

impl Default for ParamsAdam {
    fn default() -> Self {
        Self {
            lr: 0.001,
            beta_1: 0.9,
            beta_2: 0.999,
            eps: 1e-8,
            weight_decay: None,
            amsgrad: false,
        }
    }
}

impl Optimizer for Adam {
    type Config = ParamsAdam;

    fn new(vars: Vec<Var>, params: ParamsAdam) -> Result<Self> {
        if params.amsgrad {
            let vars = VarAdam::VecAdamAmsgrad(VecAdamAmsgrad::new(vars)?);
            Ok(Self {
                vars,
                params,
                t: 1.,
            })
        } else {
            let vars = VarAdam::VecAdamBase(VecAdamBase::new(vars)?);
            Ok(Self {
                vars,
                params,
                t: 1.,
            })
        }
    }

    fn learning_rate(&self) -> f64 {
        self.params.lr
    }

    fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
        match &self.vars {
            VarAdam::VecAdamBase(vars) => vars.inner_step(&self.params, grads, self.t)?,
            VarAdam::VecAdamAmsgrad(vars) => vars.inner_step(&self.params, grads, self.t)?,
        }
        self.t += 1.;
        Ok(())
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.params.lr = lr;
    }
}

impl OptimParams for Adam {
    fn params(&self) -> &Self::Config {
        &self.params
    }

    /// Set the parameters for the optimiser
    ///
    /// # Warning
    ///
    /// As the AMSGrad variant requires tracking an additional tensor,
    /// the `amsgrad` setting cannot be changed after the optimiser is created.
    fn set_params(&mut self, config: Self::Config) {
        let ams_grad = self.params.amsgrad;
        if ams_grad == config.amsgrad {
            self.params = config;
        } else {
            warn!("AMSGrad cannot be changed once set");
            let mut config = config;
            config.amsgrad = ams_grad;
            self.params = config;
        }
    }
}

impl Adam {
    /// Return the vars being optimised
    #[must_use]
    pub fn into_inner(self) -> Vec<Var> {
        match self.vars {
            VarAdam::VecAdamBase(vars) => vars.into_inner(),
            VarAdam::VecAdamAmsgrad(vars) => vars.into_inner(),
        }
    }

    /// Set the betas
    ///
    /// This can be combined with `set_learning_rate` for learning rate and momentum decay scheduling.
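    ///
    /// For example, to anneal both the learning rate and `beta_1` over epochs
    /// (a sketch; the schedule itself is purely illustrative):
    ///
    /// ```no_run
    /// # use candle_core::{Device, Result, Var};
    /// # use candle_nn::Optimizer;
    /// # use candle_optimisers::adam::{Adam, ParamsAdam};
    /// # fn main() -> Result<()> {
    /// # let w = Var::new(0f32, &Device::Cpu)?;
    /// # let mut optimiser = Adam::new(vec![w], ParamsAdam::default())?;
    /// for epoch in 0..10 {
    ///     optimiser.set_learning_rate(1e-3 * 0.9_f64.powi(epoch));
    ///     optimiser.set_betas(0.9 * 0.99_f64.powi(epoch), 0.999);
    ///     // ... run the training steps for this epoch ...
    /// }
    /// # Ok(())
    /// # }
    /// ```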
    pub fn set_betas(&mut self, beta_1: f64, beta_2: f64) {
        self.params.beta_1 = beta_1;
        self.params.beta_2 = beta_2;
    }
}

#[cfg(test)]
mod tests {
    use anyhow::Result;
    use assert_approx_eq::assert_approx_eq;
    use candle_core::{Device, Var};
    use candle_nn::Optimizer;

    use super::*;
    #[test]
    fn lr_test() -> Result<()> {
        let params = ParamsAdam {
            lr: 0.004,
            ..Default::default()
        };
        // Create a couple of dummy variables so the optimiser has something to hold.
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = Adam::new(vec![w.clone(), b.clone()], params)?;
        assert_approx_eq!(0.004, optim.learning_rate());
        optim.set_learning_rate(0.002);
        assert_approx_eq!(0.002, optim.learning_rate());
        Ok(())
    }

    #[test]
    fn into_inner_test() -> Result<()> {
        let params = ParamsAdam::default();
        let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
        let b = Var::new(-2f32, &Device::Cpu)?;
        let optim = Adam::new(vec![w.clone(), b.clone()], params)?;
        let inner = optim.into_inner();
        assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
        assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
        // repeat with the AMSGrad variant
        let params = ParamsAdam {
            amsgrad: true,
            ..Default::default()
        };
        let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
        let b = Var::new(-2f32, &Device::Cpu)?;
        let optim_amsgrad = Adam::new(vec![w.clone(), b.clone()], params)?;
        let inner = optim_amsgrad.into_inner();
        assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
        assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
        Ok(())
    }

    #[test]
    fn params_test() -> Result<()> {
        let params = ParamsAdam {
            lr: 0.004,
            ..Default::default()
        };
        // Create a couple of dummy variables so the optimiser has something to hold.
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = Adam::new(vec![w.clone(), b.clone()], params.clone())?;
        assert_eq!(params, optim.params().clone());
        let new_params = ParamsAdam {
            lr: 0.002,
            ..Default::default()
        };
        optim.set_params(new_params.clone());
        assert_eq!(new_params, optim.params().clone());

        let ams_params = ParamsAdam {
            lr: 0.002,
            amsgrad: true,
            ..Default::default()
        };
        optim.set_params(ams_params);
        // amsgrad cannot be changed once set
        assert_eq!(new_params, optim.params().clone());
        optim.set_betas(0.1, 0.1);
        let final_params = ParamsAdam {
            lr: 0.002,
            beta_1: 0.1,
            beta_2: 0.1,
            ..Default::default()
        };
        assert_eq!(final_params, optim.params().clone());
        Ok(())
    }
}