candle_optimisers/
nadam.rs

/*!
NAdam optimiser: Adam with Nesterov momentum

Described in [Incorporating Nesterov Momentum into Adam](https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ)

Pseudocode (including decoupling of weight decay):

$$
\\begin{aligned}
    &\\rule{110mm}{0.4pt}                                                                 \\\\
    &\\textbf{input}      : \\gamma_t \\text{ (lr)}, \\: \\beta_1,\\beta_2 \\text{ (betas)},
        \\: \\theta_0 \\text{ (params)}, \\: f(\\theta) \\text{ (objective)}                   \\\\
    &\\hspace{12mm} \\: \\lambda \\text{ (weight decay)}, \\:\\psi \\text{ (momentum decay)}    \\\\
    &\\textbf{initialize} :  m_0 \\leftarrow 0 \\text{ ( first moment)},
        v_0 \\leftarrow 0 \\text{ ( second moment)}                                 \\\\[-1.ex]
    &\\rule{110mm}{0.4pt}                                                                 \\\\
    &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do}                         \\\\
    &\\hspace{5mm}g_t           \\leftarrow   \\nabla_{\\theta} f_t (\\theta_{t-1})           \\\\
    &\\hspace{5mm} \\theta_t \\leftarrow \\theta_{t-1}                                       \\\\
    &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some}                        \\\\
    &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled}                                      \\\\
    &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1}         \\\\
    &\\hspace{10mm}\\textbf{else}                                                         \\\\
    &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda  \\theta_{t-1}                            \\\\
    &\\hspace{5mm} \\mu_t \\leftarrow \\beta_1 \\big(1 - \\frac{1}{2}  0.96^{t \\psi} \\big)     \\\\
    &\\hspace{5mm} \\mu_{t+1} \\leftarrow \\beta_1 \\big(1 - \\frac{1}{2} 0.96^{(t+1)\\psi}\\big)\\\\
    &\\hspace{5mm}m_t           \\leftarrow   \\beta_1 m_{t-1} + (1 - \\beta_1) g_t          \\\\
    &\\hspace{5mm}v_t           \\leftarrow   \\beta_2 v_{t-1} + (1-\\beta_2) g^2_t          \\\\
    &\\hspace{5mm}\\widehat{m_t} \\leftarrow \\mu_{t+1} m_t/(1-\\prod_{i=1}^{t+1}\\mu_i)\\\\[-1.ex]
    & \\hspace{11mm} + (1-\\mu_t) g_t /(1-\\prod_{i=1}^{t} \\mu_{i})                         \\\\
    &\\hspace{5mm}\\widehat{v_t} \\leftarrow   v_t/\\big(1-\\beta_2^t \\big)                   \\\\
    &\\hspace{5mm}\\theta_t \\leftarrow \\theta_t - \\gamma \\widehat{m_t}/
        \\big(\\sqrt{\\widehat{v_t}} + \\epsilon \\big)                                       \\\\
    &\\rule{110mm}{0.4pt}                                                          \\\\[-1.ex]
    &\\bf{return} \\:  \\theta_t                                                     \\\\[-1.ex]
    &\\rule{110mm}{0.4pt}                                                          \\\\[-1.ex]
\\end{aligned}
$$
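
A minimal usage sketch (assuming the crate is used as `candle_optimisers` with a public
`nadam` module, and that a scalar loss tensor is built from the tracked `Var`s; the data
and step count below are illustrative only):

```no_run
use candle_core::{Device, Result, Tensor, Var};
use candle_nn::Optimizer;
use candle_optimisers::nadam::{NAdam, ParamsNAdam};

fn train() -> Result<()> {
    // Two trainable variables for a toy linear model y = x w^T + b
    let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
    let b = Var::new(0f32, &Device::Cpu)?;
    let mut optim = NAdam::new(vec![w.clone(), b.clone()], ParamsNAdam::default())?;

    let x = Tensor::new(&[[2f32, 1.], [7., 4.]], &Device::Cpu)?;
    let y = Tensor::new(&[[7f32], [26.]], &Device::Cpu)?;
    for _step in 0..100 {
        // Squared-error loss; `backward_step` computes the gradients and applies one NAdam update
        let loss = (x.matmul(&w.t()?)?.broadcast_add(b.as_tensor())? - &y)?
            .sqr()?
            .sum_all()?;
        optim.backward_step(&loss)?;
    }
    Ok(())
}
```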
*/

use candle_core::{Result, Var};
use candle_nn::optim::Optimizer;

use crate::{Decay, OptimParams};

/// Adam optimiser with Nesterov momentum
///
/// Described in <https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ>
#[derive(Debug)]
pub struct NAdam {
    vars: Vec<VarNAdam>,
    params: ParamsNAdam,
    /// Momentum decay coefficient for the current step (`\mu_t` in the pseudocode)
    mu_t: f64,
    /// Momentum decay coefficient for the next step (`\mu_{t+1}`)
    mu_t2: f64,
    /// Running product of the `\mu_i` up to the current step
    prod: f64,
    /// Running product of the `\mu_i` up to the next step
    prod2: f64,
    /// Step counter
    t: f64,
}

#[derive(Debug)]
struct VarNAdam {
    /// Parameter being optimised
    theta: Var,
    /// First-moment estimate
    m: Var,
    /// Second-moment estimate
    v: Var,
}

/// Parameters for the NAdam optimiser
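///
/// A sketch of constructing parameters with decoupled weight decay (assuming `Decay` is
/// re-exported at the crate root, as the `use crate::{Decay, ..}` import here suggests;
/// the values are illustrative, not recommendations):
///
/// ```no_run
/// use candle_optimisers::{nadam::ParamsNAdam, Decay};
///
/// let params = ParamsNAdam {
///     lr: 1e-3,
///     weight_decay: Some(Decay::DecoupledWeightDecay(0.01)),
///     ..Default::default()
/// };
/// ```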
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct ParamsNAdam {
    /// Learning rate
    pub lr: f64,
    /// Coefficient for moving average of first moment
    pub beta_1: f64,
    /// Coefficient for moving average of second moment
    pub beta_2: f64,
    /// Term added to denominator to improve numerical stability
    pub eps: f64,
    /// Weight decay
    pub weight_decay: Option<Decay>,
    /// Momentum decay
    pub momentum_decay: f64,
}

impl Default for ParamsNAdam {
    fn default() -> Self {
        Self {
            lr: 0.002,
            beta_1: 0.9,
            beta_2: 0.999,
            eps: 1e-8,
            weight_decay: None,
            momentum_decay: 0.004,
        }
    }
}

impl Optimizer for NAdam {
    type Config = ParamsNAdam;

    fn new(vars: Vec<Var>, params: ParamsNAdam) -> Result<Self> {
        let vars = vars
            .into_iter()
            .filter(|var| var.dtype().is_float())
            .map(|var| {
                let dtype = var.dtype();
                let shape = var.shape();
                let device = var.device();
                let m = Var::zeros(shape, dtype, device)?;
                let v = Var::zeros(shape, dtype, device)?;
                Ok(VarNAdam { theta: var, m, v })
            })
            .collect::<Result<Vec<VarNAdam>>>()?;
        // The step counter starts at 1: `mu_t2` holds the first momentum decay coefficient
        // and `prod2` its running product; `step` shifts them into `mu_t`/`prod` before
        // each update.
        let t = 1.;
        let mu_t2 = params.beta_1 * 0.5f64.mul_add(-(0.96_f64.powf(t * params.momentum_decay)), 1.);
        Ok(Self {
            vars,
            params,
            t: 1.,
            mu_t: 1.,
            mu_t2,
            prod: 1.,
            prod2: mu_t2,
        })
    }

    fn learning_rate(&self) -> f64 {
        self.params.lr
    }

    fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
        // Advance the momentum decay schedule: `mu_t` holds the coefficient for the current
        // step and `mu_t2` the one for the next step, while `prod` and `prod2` are the
        // corresponding running products used to bias-correct the Nesterov-adjusted first moment.
        let mu_t = self.mu_t2;
        let mu_t2 = self.params.beta_1
            * 0.5f64.mul_add(
                -(0.96_f64.powf((self.t + 1.) * self.params.momentum_decay)),
                1.,
            );
        let prod = self.prod2;
        let prod2 = prod * mu_t2;
        self.mu_t = mu_t;
        self.mu_t2 = mu_t2;
        self.prod = prod;
        self.prod2 = prod2;
        if let Some(decay) = self.params.weight_decay {
            match decay {
                // Standard weight decay: fold the decay term into the gradient
                Decay::WeightDecay(decay) => {
                    for var in &self.vars {
                        let theta = &var.theta;
                        let m = &var.m;
                        let v = &var.v;
                        if let Some(grad) = grads.get(theta) {
                            let grad = &(grad + (decay * theta.as_tensor())?)?;
                            let m_next = ((self.params.beta_1 * m.as_tensor())?
                                + ((1. - self.params.beta_1) * grad)?)?;
                            let v_next = ((self.params.beta_2 * v.as_tensor())?
                                + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?;
                            let m_hat = (((mu_t2 / (1. - prod2)) * &m_next)?
                                + (((1. - mu_t) / (1. - prod)) * grad)?)?;
                            let v_hat = (&v_next / (1. - self.params.beta_2.powf(self.t)))?;
                            let delta = (m_hat * self.params.lr)?
                                .div(&(v_hat.powf(0.5)? + self.params.eps)?)?;
                            theta.set(&theta.sub(&delta)?)?;
                            m.set(&m_next)?;
                            v.set(&v_next)?;
                        }
                    }
                }
                // Decoupled weight decay (AdamW style): shrink the parameters directly
                // before applying the NAdam update
                Decay::DecoupledWeightDecay(decay) => {
                    for var in &self.vars {
                        let theta = &var.theta;
                        let m = &var.m;
                        let v = &var.v;
                        if let Some(grad) = grads.get(theta) {
                            theta
                                .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?;
                            let m_next = ((self.params.beta_1 * m.as_tensor())?
                                + ((1. - self.params.beta_1) * grad)?)?;
                            let v_next = ((self.params.beta_2 * v.as_tensor())?
                                + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?;
                            let m_hat = (((mu_t2 / (1. - prod2)) * &m_next)?
                                + (((1. - mu_t) / (1. - prod)) * grad)?)?;
                            let v_hat = (&v_next / (1. - self.params.beta_2.powf(self.t)))?;
                            let delta = (m_hat * self.params.lr)?
                                .div(&(v_hat.powf(0.5)? + self.params.eps)?)?;
                            theta.set(&theta.sub(&delta)?)?;
                            m.set(&m_next)?;
                            v.set(&v_next)?;
                        }
                    }
                }
            }
        } else {
            // No weight decay: plain NAdam update
            for var in &self.vars {
                let theta = &var.theta;
                let m = &var.m;
                let v = &var.v;
                if let Some(grad) = grads.get(theta) {
                    let m_next = ((self.params.beta_1 * m.as_tensor())?
                        + ((1. - self.params.beta_1) * grad)?)?;
                    let v_next = ((self.params.beta_2 * v.as_tensor())?
                        + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?;
                    let m_hat = (((mu_t2 / (1. - prod2)) * &m_next)?
                        + (((1. - mu_t) / (1. - prod)) * grad)?)?;
                    let v_hat = (&v_next / (1. - self.params.beta_2.powf(self.t)))?;
                    let delta =
                        (m_hat * self.params.lr)?.div(&(v_hat.powf(0.5)? + self.params.eps)?)?;
                    theta.set(&theta.sub(&delta)?)?;
                    m.set(&m_next)?;
                    v.set(&v_next)?;
                }
            }
        }

        self.t += 1.;
        Ok(())
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.params.lr = lr;
    }
}

impl OptimParams for NAdam {
    fn params(&self) -> &Self::Config {
        &self.params
    }

    fn set_params(&mut self, config: Self::Config) {
        self.params = config;
    }
}

impl NAdam {
    /// Return the vars being optimised
    #[must_use]
    pub fn into_inner(self) -> Vec<Var> {
        self.vars.into_iter().map(|v| v.theta).collect()
    }
}

#[cfg(test)]
mod tests {

    use anyhow::Result;
    use assert_approx_eq::assert_approx_eq;
    use candle_core::{Device, Var};
    use candle_nn::Optimizer;

    use super::*;
    #[test]
    fn lr_test() -> Result<()> {
        let params = ParamsNAdam {
            lr: 0.004,
            ..Default::default()
        };
        // Build the optimiser and check that the learning rate can be read and updated.
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = NAdam::new(vec![w.clone(), b.clone()], params)?;
        assert_approx_eq!(0.004, optim.learning_rate());
        optim.set_learning_rate(0.002);
        assert_approx_eq!(0.002, optim.learning_rate());
        Ok(())
    }

    #[test]
    fn into_inner_test() -> Result<()> {
        let params = ParamsNAdam::default();
        let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
        let b = Var::new(-2f32, &Device::Cpu)?;
        let optim = NAdam::new(vec![w.clone(), b.clone()], params)?;
        let inner = optim.into_inner();
        assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
        assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
        Ok(())
    }

    #[test]
    fn params_test() -> Result<()> {
        let params = ParamsNAdam {
            lr: 0.004,
            ..Default::default()
        };
        // Build the optimiser and check that its parameters can be read and replaced.
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = NAdam::new(vec![w.clone(), b.clone()], params.clone())?;
        assert_eq!(params, optim.params().clone());
        let new_params = ParamsNAdam {
            lr: 0.002,
            ..Default::default()
        };
        optim.set_params(new_params.clone());
        assert_eq!(new_params, optim.params().clone());
        Ok(())
    }
}