candle_optimisers/radam.rs
/*!
RAdam optimiser

Described in [On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265)

As decoupled weight decay is implemented, this can be used equivalently to the paper (which uses decoupled weight decay)
or to the PyTorch implementation (which does not).

Pseudocode (including decoupling of weight decay):

$$
\\begin{aligned}
    &\\rule{110mm}{0.4pt} \\\\
    &\\textbf{input} : \\gamma \\text{ (lr)}, \\: \\beta_1, \\beta_2
        \\text{ (betas)}, \\: \\theta_0 \\text{ (params)}, \\:f(\\theta) \\text{ (objective)}, \\:
        \\lambda \\text{ (weight decay)}, \\\\
    &\\hspace{13mm} \\epsilon \\text{ (epsilon)} \\\\
    &\\textbf{initialize} : m_0 \\leftarrow 0 \\text{ (first moment)},
        v_0 \\leftarrow 0 \\text{ (second moment)}, \\\\
    &\\hspace{18mm} \\rho_{\\infty} \\leftarrow 2/(1-\\beta_2) -1 \\\\[-1.ex]
    &\\rule{110mm}{0.4pt} \\\\
    &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\
    &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\
    &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\
    &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\
    &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\
    &\\hspace{10mm}\\textbf{else} \\\\
    &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\
    &\\hspace{5mm}m_t \\leftarrow \\beta_1 m_{t-1} + (1 - \\beta_1) g_t \\\\
    &\\hspace{5mm}v_t \\leftarrow \\beta_2 v_{t-1} + (1-\\beta_2) g^2_t \\\\
    &\\hspace{5mm}\\widehat{m_t} \\leftarrow m_t/\\big(1-\\beta_1^t \\big) \\\\
    &\\hspace{5mm}\\rho_t \\leftarrow \\rho_{\\infty} -
        2 t \\beta^t_2 /\\big(1-\\beta_2^t \\big) \\\\[0.1ex]
    &\\hspace{5mm}\\textbf{if} \\: \\rho_t > 5 \\\\
    &\\hspace{10mm} l_t \\leftarrow \\frac{\\sqrt{ (1-\\beta^t_2) }}{ \\sqrt{v_t} +\\epsilon } \\\\
    &\\hspace{10mm} r_t \\leftarrow
        \\sqrt{\\frac{(\\rho_t-4)(\\rho_t-2)\\rho_{\\infty}}{(\\rho_{\\infty}-4)(\\rho_{\\infty}-2) \\rho_t}} \\\\
    &\\hspace{10mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\widehat{m_t} r_t l_t \\\\
    &\\hspace{5mm}\\textbf{else} \\\\
    &\\hspace{10mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\widehat{m_t} \\\\
    &\\rule{110mm}{0.4pt} \\\\[-1.ex]
    &\\bf{return} \\: \\theta_t \\\\[-1.ex]
    &\\rule{110mm}{0.4pt} \\\\[-1.ex]
\\end{aligned}
$$
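
A rough usage sketch (not a drop-in recipe: the tensors and the toy loss below are placeholders, and it assumes this module is exposed as `candle_optimisers::radam`):

```rust,ignore
use candle_core::{Device, Result, Var};
use candle_nn::Optimizer;
use candle_optimisers::radam::{ParamsRAdam, RAdam};

fn train_step() -> Result<()> {
    // Toy parameters to optimise
    let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
    let b = Var::new(0f32, &Device::Cpu)?;

    let params = ParamsRAdam {
        lr: 0.004,
        ..Default::default()
    };
    let mut optim = RAdam::new(vec![w.clone(), b.clone()], params)?;

    // One optimisation step: build a scalar loss, backpropagate, then step
    let loss = (w.as_tensor().sum_all()? + b.as_tensor())?;
    let grads = loss.backward()?;
    optim.step(&grads)?;
    Ok(())
}
```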
*/

use candle_core::{Result, Var};
use candle_nn::optim::Optimizer;

use crate::{Decay, OptimParams};

/// RAdam optimiser
///
/// Described in [On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265)
#[derive(Debug)]
pub struct RAdam {
    vars: Vec<VarRAdam>,
    params: ParamsRAdam,
    /// Maximum length of the approximated SMA: rho_inf = 2 / (1 - beta_2) - 1
    rho_inf: f64,
    /// Current timestep, starting from 1
    t: f64,
}

#[derive(Debug)]
struct VarRAdam {
    /// Parameters being optimised
    theta: Var,
    /// First moment estimate
    m: Var,
    /// Second moment estimate
    v: Var,
}

/// Parameters for the RAdam optimiser
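///
/// A minimal construction sketch (any field not set explicitly falls back to the defaults below;
/// the crate path is assumed to be `candle_optimisers`):
///
/// ```rust,ignore
/// use candle_optimisers::radam::ParamsRAdam;
///
/// let params = ParamsRAdam {
///     lr: 0.004,
///     ..Default::default()
/// };
/// ```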
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct ParamsRAdam {
    /// Learning rate
    pub lr: f64,
    /// Coefficient for moving average of first moment
    pub beta_1: f64,
    /// Coefficient for moving average of second moment
    pub beta_2: f64,
    /// Weight decay
    pub weight_decay: Option<Decay>,
    /// Term added to the denominator to improve numerical stability
    pub eps: f64,
}

impl Default for ParamsRAdam {
    fn default() -> Self {
        Self {
            lr: 0.001,
            beta_1: 0.9,
            beta_2: 0.999,
            eps: 1e-8,
            weight_decay: None,
        }
    }
}

impl Optimizer for RAdam {
    type Config = ParamsRAdam;

    fn new(vars: Vec<Var>, params: ParamsRAdam) -> Result<Self> {
        // Track only floating-point vars; allocate zeroed first and second moment buffers for each
        let vars = vars
            .into_iter()
            .filter(|var| var.dtype().is_float())
            .map(|var| {
                let dtype = var.dtype();
                let shape = var.shape();
                let device = var.device();
                let m = Var::zeros(shape, dtype, device)?;
                let v = Var::zeros(shape, dtype, device)?;
                Ok(VarRAdam { theta: var, m, v })
            })
            .collect::<Result<Vec<VarRAdam>>>()?;
        // rho_infinity: the maximum length of the approximated SMA
        let rho_inf = 2. / (1. - params.beta_2) - 1.;
        Ok(Self {
            vars,
            params,
            rho_inf,
            t: 1.,
        })
    }

    fn learning_rate(&self) -> f64 {
        self.params.lr
    }

    fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
        // Length of the approximated SMA at step t
        let rho_t = self.rho_inf
            - 2. * self.t * self.params.beta_2.powf(self.t)
                / (1. - self.params.beta_2.powf(self.t));

        if let Some(wd) = self.params.weight_decay {
            match wd {
                Decay::WeightDecay(wd) => {
                    for var in &self.vars {
                        let theta = &var.theta;
                        let m = &var.m;
                        let v = &var.v;
                        if let Some(grad) = grads.get(theta) {
                            // standard weight decay: fold lambda * theta into the gradient
                            let grad = &(grad + (wd * theta.as_tensor())?)?;
                            let m_next = ((self.params.beta_1 * m.as_tensor())?
                                + ((1. - self.params.beta_1) * grad)?)?;
                            let v_next = ((self.params.beta_2 * v.as_tensor())?
                                + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?;
                            let m_hat = (&m_next / (1. - self.params.beta_1.powf(self.t)))?;

                            // rectified update when the variance of the adaptive lr is tractable,
                            // plain momentum-style update otherwise
                            let delta = if rho_t > 5. {
                                let l = ((1. - self.params.beta_2.powf(self.t)).sqrt()
                                    / (&v_next.sqrt()? + self.params.eps)?)?;
                                let r = ((rho_t - 4.) * (rho_t - 2.) * self.rho_inf
                                    / ((self.rho_inf - 4.) * (self.rho_inf - 2.) * rho_t))
                                    .sqrt();
                                (self.params.lr * r * (l * m_hat)?)?
                            } else {
                                (self.params.lr * m_hat)?
                            };
                            theta.set(&theta.sub(&delta)?)?;
                            m.set(&m_next)?;
                            v.set(&v_next)?;
                        }
                    }
                }
                Decay::DecoupledWeightDecay(decay) => {
                    for var in &self.vars {
                        let theta = &var.theta;
                        let m = &var.m;
                        let v = &var.v;
                        if let Some(grad) = grads.get(theta) {
                            // decoupled weight decay: shrink the parameters directly by lr * lambda
                            theta
                                .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?;
                            let m_next = ((self.params.beta_1 * m.as_tensor())?
                                + ((1. - self.params.beta_1) * grad)?)?;
                            let v_next = ((self.params.beta_2 * v.as_tensor())?
                                + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?;
                            let m_hat = (&m_next / (1. - self.params.beta_1.powf(self.t)))?;

                            let delta = if rho_t > 5. {
                                let l = ((1. - self.params.beta_2.powf(self.t)).sqrt()
                                    / (&v_next.sqrt()? + self.params.eps)?)?;
                                let r = ((rho_t - 4.) * (rho_t - 2.) * self.rho_inf
                                    / ((self.rho_inf - 4.) * (self.rho_inf - 2.) * rho_t))
                                    .sqrt();
                                (self.params.lr * r * (l * m_hat)?)?
                            } else {
                                (self.params.lr * m_hat)?
                            };
                            theta.set(&theta.sub(&delta)?)?;
                            m.set(&m_next)?;
                            v.set(&v_next)?;
                        }
                    }
                }
            }
        } else {
            // no weight decay
            for var in &self.vars {
                let theta = &var.theta;
                let m = &var.m;
                let v = &var.v;
                if let Some(grad) = grads.get(theta) {
                    let m_next = ((self.params.beta_1 * m.as_tensor())?
                        + ((1. - self.params.beta_1) * grad)?)?;
                    let v_next = ((self.params.beta_2 * v.as_tensor())?
                        + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?;
                    let m_hat = (&m_next / (1. - self.params.beta_1.powf(self.t)))?;

                    let delta = if rho_t > 5. {
                        let l = ((1. - self.params.beta_2.powf(self.t)).sqrt()
                            / (&v_next.sqrt()? + self.params.eps)?)?;
                        let r = ((rho_t - 4.) * (rho_t - 2.) * self.rho_inf
                            / ((self.rho_inf - 4.) * (self.rho_inf - 2.) * rho_t))
                            .sqrt();
                        (self.params.lr * r * (l * m_hat)?)?
                    } else {
                        (self.params.lr * m_hat)?
                    };
                    theta.set(&theta.sub(&delta)?)?;
                    m.set(&m_next)?;
                    v.set(&v_next)?;
                }
            }
        }

        self.t += 1.;
        Ok(())
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.params.lr = lr;
    }
}

impl OptimParams for RAdam {
    fn params(&self) -> &Self::Config {
        &self.params
    }

    fn set_params(&mut self, config: Self::Config) {
        self.params = config;
    }
}

impl RAdam {
    /// Return the vars being optimised
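    ///
    /// For example (a sketch, reusing an `optim: RAdam` built elsewhere):
    ///
    /// ```rust,ignore
    /// let trained_vars: Vec<candle_core::Var> = optim.into_inner();
    /// ```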
    #[must_use]
    pub fn into_inner(self) -> Vec<Var> {
        self.vars.into_iter().map(|v| v.theta).collect()
    }
}

#[cfg(test)]
mod tests {
    use anyhow::Result;
    use assert_approx_eq::assert_approx_eq;
    use candle_core::{Device, Var};
    use candle_nn::Optimizer;

    use super::*;

    #[test]
    fn lr_test() -> Result<()> {
        let params = ParamsRAdam {
            lr: 0.004,
            ..Default::default()
        };
        // Dummy variables: this test only checks getting and setting the learning rate
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = RAdam::new(vec![w.clone(), b.clone()], params)?;
        assert_approx_eq!(0.004, optim.learning_rate());
        optim.set_learning_rate(0.002);
        assert_approx_eq!(0.002, optim.learning_rate());
        Ok(())
    }

    #[test]
    fn into_inner_test() -> Result<()> {
        let params = ParamsRAdam::default();
        let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
        let b = Var::new(-2f32, &Device::Cpu)?;
        let optim = RAdam::new(vec![w.clone(), b.clone()], params)?;
        let inner = optim.into_inner();
        assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
        assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
        Ok(())
    }

    #[test]
    fn params_test() -> Result<()> {
        let params = ParamsRAdam {
            lr: 0.004,
            ..Default::default()
        };
        // Dummy variables: this test only checks getting and setting the optimiser parameters
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = RAdam::new(vec![w.clone(), b.clone()], params.clone())?;
        assert_eq!(params, optim.params().clone());
        let new_params = ParamsRAdam {
            lr: 0.002,
            ..Default::default()
        };
        optim.set_params(new_params.clone());
        assert_eq!(new_params, optim.params().clone());
        Ok(())
    }
}