candle_optimisers/radam.rs

/*!
RAdam optimiser

Described in [On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265)

Since decoupled weight decay is implemented, this optimiser can be used either as in the paper (which uses decoupled weight decay)
or as in the PyTorch implementation (which does not).

Pseudocode (including decoupling of weight decay):

$$
\\begin{aligned}
    &\\rule{110mm}{0.4pt}                                                                 \\\\
    &\\textbf{input}      : \\gamma \\text{ (lr)}, \\: \\beta_1, \\beta_2
        \\text{ (betas)}, \\: \\theta_0 \\text{ (params)}, \\:f(\\theta) \\text{ (objective)}, \\:
        \\lambda \\text{ (weight decay)},                                                  \\\\
    &\\hspace{13mm} \\epsilon \\text{ (epsilon)}                                            \\\\
    &\\textbf{initialize} :  m_0 \\leftarrow 0 \\text{ (first moment)},
        v_0 \\leftarrow 0 \\text{ (second moment)},                                        \\\\
    &\\hspace{18mm} \\rho_{\\infty} \\leftarrow 2/(1-\\beta_2) -1                      \\\\[-1.ex]
    &\\rule{110mm}{0.4pt}  \\\\
    &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do}                         \\\\
    &\\hspace{5mm}g_t           \\leftarrow   \\nabla_{\\theta} f_t (\\theta_{t-1})           \\\\
    &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some}                        \\\\
    &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled}                       \\\\
    &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1}                    \\\\
    &\\hspace{10mm}\\textbf{else}                                                              \\\\
    &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda  \\theta_{t-1}                            \\\\
    &\\hspace{5mm}m_t           \\leftarrow   \\beta_1 m_{t-1} + (1 - \\beta_1) g_t          \\\\
    &\\hspace{5mm}v_t           \\leftarrow   \\beta_2 v_{t-1} + (1-\\beta_2) g^2_t          \\\\
    &\\hspace{5mm}\\widehat{m_t} \\leftarrow   m_t/\\big(1-\\beta_1^t \\big)                   \\\\
    &\\hspace{5mm}\\rho_t \\leftarrow \\rho_{\\infty} -
        2 t \\beta^t_2 /\\big(1-\\beta_2^t \\big)                                    \\\\[0.1ex]
    &\\hspace{5mm}\\textbf{if} \\: \\rho_t > 5                                               \\\\
    &\\hspace{10mm} l_t \\leftarrow \\frac{\\sqrt{ (1-\\beta^t_2) }}{ \\sqrt{v_t} +\\epsilon  } \\\\
    &\\hspace{10mm} r_t \\leftarrow
    \\sqrt{\\frac{(\\rho_t-4)(\\rho_t-2)\\rho_{\\infty}}{(\\rho_{\\infty}-4)(\\rho_{\\infty}-2) \\rho_t}} \\\\
    &\\hspace{10mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\widehat{m_t} r_t l_t        \\\\
    &\\hspace{5mm}\\textbf{else}                                                           \\\\
    &\\hspace{10mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\widehat{m_t}                \\\\
    &\\rule{110mm}{0.4pt}                                                          \\\\[-1.ex]
    &\\bf{return} \\:  \\theta_t                                                     \\\\[-1.ex]
    &\\rule{110mm}{0.4pt}                                                          \\\\[-1.ex]
\\end{aligned}
$$
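
A minimal usage sketch, driving the optimiser through the `candle_nn::Optimizer` interface; the `candle_optimisers::radam` import path and the toy objective are illustrative assumptions:

```no_run
use candle_core::{Device, Result, Var};
use candle_nn::Optimizer;
use candle_optimisers::radam::{ParamsRAdam, RAdam};

fn main() -> Result<()> {
    // a single trainable parameter on the CPU
    let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
    let params = ParamsRAdam {
        lr: 0.004,
        ..Default::default()
    };
    let mut optim = RAdam::new(vec![w.clone()], params)?;

    // toy objective: sum of squared entries of `w`
    let loss = w.as_tensor().sqr()?.sum_all()?;
    // compute gradients and take one RAdam step
    let grads = loss.backward()?;
    optim.step(&grads)?;
    Ok(())
}
```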
*/

use candle_core::{Result, Var};
use candle_nn::optim::Optimizer;

use crate::{Decay, OptimParams};

/// RAdam optimiser
///
/// Described in [On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265)
#[derive(Debug)]
pub struct RAdam {
    vars: Vec<VarRAdam>,
    params: ParamsRAdam,
    rho_inf: f64,
    t: f64,
}

#[derive(Debug)]
struct VarRAdam {
    theta: Var,
    m: Var,
    v: Var,
}

/// Parameters for the RAdam optimiser
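///
/// A sketch of configuring weight decay; the `candle_optimisers` import path is assumed and the
/// 0.01 value is illustrative only:
///
/// ```no_run
/// use candle_optimisers::{radam::ParamsRAdam, Decay};
///
/// // decoupled weight decay, as in the paper
/// let params = ParamsRAdam {
///     weight_decay: Some(Decay::DecoupledWeightDecay(0.01)),
///     ..Default::default()
/// };
/// // non-decoupled weight decay, as in the PyTorch implementation,
/// // would instead use `Some(Decay::WeightDecay(0.01))`
/// ```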
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct ParamsRAdam {
    /// Learning rate
    pub lr: f64,
    /// Coefficient for moving average of first moment
    pub beta_1: f64,
    /// Coefficient for moving average of second moment
    pub beta_2: f64,
    /// Weight decay
    pub weight_decay: Option<Decay>,
    /// Term added to denominator to improve numerical stability
    pub eps: f64,
}

impl Default for ParamsRAdam {
    fn default() -> Self {
        Self {
            lr: 0.001,
            beta_1: 0.9,
            beta_2: 0.999,
            eps: 1e-8,
            weight_decay: None,
        }
    }
}

impl Optimizer for RAdam {
    type Config = ParamsRAdam;

    fn new(vars: Vec<Var>, params: ParamsRAdam) -> Result<Self> {
        let vars = vars
            .into_iter()
            .filter(|var| var.dtype().is_float())
            .map(|var| {
                let dtype = var.dtype();
                let shape = var.shape();
                let device = var.device();
                let m = Var::zeros(shape, dtype, device)?;
                let v = Var::zeros(shape, dtype, device)?;
                Ok(VarRAdam { theta: var, m, v })
            })
            .collect::<Result<Vec<VarRAdam>>>()?;
        // maximum length of the approximated simple moving average (SMA)
        let rho_inf = 2. / (1. - params.beta_2) - 1.;
        Ok(Self {
            vars,
            params,
            rho_inf,
            t: 1.,
        })
    }

    fn learning_rate(&self) -> f64 {
        self.params.lr
    }

    fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
        // length of the approximated SMA at the current step
        let rho_t = self.rho_inf
            - 2. * self.t * self.params.beta_2.powf(self.t)
                / (1. - self.params.beta_2.powf(self.t));

        if let Some(wd) = self.params.weight_decay {
            match wd {
                Decay::WeightDecay(wd) => {
                    for var in &self.vars {
                        let theta = &var.theta;
                        let m = &var.m;
                        let v = &var.v;
                        if let Some(grad) = grads.get(theta) {
                            // non-decoupled weight decay: fold the decay term into the gradient
                            let grad = &(grad + (wd * theta.as_tensor())?)?;
                            let m_next = ((self.params.beta_1 * m.as_tensor())?
                                + ((1. - self.params.beta_1) * grad)?)?;
                            let v_next = ((self.params.beta_2 * v.as_tensor())?
                                + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?;
                            let m_hat = (&m_next / (1. - self.params.beta_1.powf(self.t)))?;

                            let delta = if rho_t > 5. {
                                // variance of the adaptive learning rate is tractable:
                                // apply the rectified, variance-adapted update
                                let l = ((1. - self.params.beta_2.powf(self.t)).sqrt()
                                    / (&v_next.sqrt()? + self.params.eps)?)?;
                                let r = ((rho_t - 4.) * (rho_t - 2.) * self.rho_inf
                                    / ((self.rho_inf - 4.) * (self.rho_inf - 2.) * rho_t))
                                    .sqrt();
                                (self.params.lr * r * (l * m_hat)?)?
                            } else {
                                // otherwise fall back to the un-adapted, bias-corrected momentum update
                                (self.params.lr * m_hat)?
                            };
                            theta.set(&theta.sub(&delta)?)?;
                            m.set(&m_next)?;
                            v.set(&v_next)?;
                        }
                    }
                }
                Decay::DecoupledWeightDecay(decay) => {
                    for var in &self.vars {
                        let theta = &var.theta;
                        let m = &var.m;
                        let v = &var.v;
                        if let Some(grad) = grads.get(theta) {
                            // decoupled weight decay step
                            theta
                                .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?;
                            let m_next = ((self.params.beta_1 * m.as_tensor())?
                                + ((1. - self.params.beta_1) * grad)?)?;
                            let v_next = ((self.params.beta_2 * v.as_tensor())?
                                + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?;
                            let m_hat = (&m_next / (1. - self.params.beta_1.powf(self.t)))?;

                            let delta = if rho_t > 5. {
                                let l = ((1. - self.params.beta_2.powf(self.t)).sqrt()
                                    / (&v_next.sqrt()? + self.params.eps)?)?;
                                let r = ((rho_t - 4.) * (rho_t - 2.) * self.rho_inf
                                    / ((self.rho_inf - 4.) * (self.rho_inf - 2.) * rho_t))
                                    .sqrt();
                                (self.params.lr * r * (l * m_hat)?)?
                            } else {
                                (self.params.lr * m_hat)?
                            };
                            theta.set(&theta.sub(&delta)?)?;
                            m.set(&m_next)?;
                            v.set(&v_next)?;
                        }
                    }
                }
            }
        } else {
            for var in &self.vars {
                let theta = &var.theta;
                let m = &var.m;
                let v = &var.v;
                if let Some(grad) = grads.get(theta) {
                    let m_next = ((self.params.beta_1 * m.as_tensor())?
                        + ((1. - self.params.beta_1) * grad)?)?;
                    let v_next = ((self.params.beta_2 * v.as_tensor())?
                        + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?;
                    let m_hat = (&m_next / (1. - self.params.beta_1.powf(self.t)))?;

                    let delta = if rho_t > 5. {
                        let l = ((1. - self.params.beta_2.powf(self.t)).sqrt()
                            / (&v_next.sqrt()? + self.params.eps)?)?;
                        let r = ((rho_t - 4.) * (rho_t - 2.) * self.rho_inf
                            / ((self.rho_inf - 4.) * (self.rho_inf - 2.) * rho_t))
                            .sqrt();
                        (self.params.lr * r * (l * m_hat)?)?
                    } else {
                        (self.params.lr * m_hat)?
                    };
                    theta.set(&theta.sub(&delta)?)?;
                    m.set(&m_next)?;
                    v.set(&v_next)?;
                }
            }
        }

        self.t += 1.;
        Ok(())
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.params.lr = lr;
    }
}

impl OptimParams for RAdam {
    fn params(&self) -> &Self::Config {
        &self.params
    }

    fn set_params(&mut self, config: Self::Config) {
        self.params = config;
    }
}

impl RAdam {
    /// Return the vars being optimised
    #[must_use]
    pub fn into_inner(self) -> Vec<Var> {
        self.vars.into_iter().map(|v| v.theta).collect()
    }
}

#[cfg(test)]
mod tests {

    use anyhow::Result;
    use assert_approx_eq::assert_approx_eq;
    use candle_core::{Device, Var};
    use candle_nn::Optimizer;

    use super::*;
    #[test]
    fn lr_test() -> Result<()> {
        let params = ParamsRAdam {
            lr: 0.004,
            ..Default::default()
        };
        // dummy variables to construct the optimiser against
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = RAdam::new(vec![w.clone(), b.clone()], params)?;
        assert_approx_eq!(0.004, optim.learning_rate());
        optim.set_learning_rate(0.002);
        assert_approx_eq!(0.002, optim.learning_rate());
        Ok(())
    }

    #[test]
    fn into_inner_test() -> Result<()> {
        let params = ParamsRAdam::default();
        let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
        let b = Var::new(-2f32, &Device::Cpu)?;
        let optim = RAdam::new(vec![w.clone(), b.clone()], params)?;
        let inner = optim.into_inner();
        assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
        assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
        Ok(())
    }

    #[test]
    fn params_test() -> Result<()> {
        let params = ParamsRAdam {
            lr: 0.004,
            ..Default::default()
        };
        // dummy variables to construct the optimiser against
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = RAdam::new(vec![w.clone(), b.clone()], params.clone())?;
        assert_eq!(params, optim.params().clone());
        let new_params = ParamsRAdam {
            lr: 0.002,
            ..Default::default()
        };
        optim.set_params(new_params.clone());
        assert_eq!(new_params, optim.params().clone());
        Ok(())
    }
}