candle_optimisers/esgd.rs
/*!
Stochastic Gradient Descent

This incorporates Nesterov and classical momentum, as well as weight decay and decoupled weight decay
(described as SGDW in [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101)).

$$
\\begin{aligned}
    &\\rule{110mm}{0.4pt} \\\\
    &\\textbf{input} : \\gamma \\text{ (lr)}, \\: \\theta_0 \\text{ (params)}, \\: f(\\theta)
        \\text{ (objective)}, \\: \\lambda \\text{ (weight decay)}, \\\\
    &\\hspace{13mm} \\:\\mu \\text{ (momentum)}, \\:\\tau \\text{ (dampening)} \\\\[-1.ex]
    &\\rule{110mm}{0.4pt} \\\\
    &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\
    &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\
    &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\
    &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\
    &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\
    &\\hspace{10mm}\\textbf{else} \\\\
    &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\
    &\\hspace{5mm}\\textbf{if} \\: \\mu \\textbf{ is } \\text{Some} \\\\
    &\\hspace{10mm}\\textbf{if} \\: t>1 \\\\
    &\\hspace{15mm} b_t \\leftarrow \\mu b_{t-1} + (1-\\tau) g_t \\\\
    &\\hspace{10mm}\\textbf{else} \\\\
    &\\hspace{15mm} b_t \\leftarrow g_t \\\\
    &\\hspace{10mm}\\textbf{if} \\: \\textit{nesterov} \\\\
    &\\hspace{15mm} g_t \\leftarrow g_t + \\mu b_t \\\\
    &\\hspace{10mm}\\textbf{else} \\\\
    &\\hspace{15mm} g_t \\leftarrow b_t \\\\
    &\\hspace{5mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma g_t \\\\
    &\\rule{110mm}{0.4pt} \\\\[-1.ex]
    &\\textbf{return} \\: \\theta_t \\\\[-1.ex]
    &\\rule{110mm}{0.4pt} \\\\[-1.ex]
\\end{aligned}
$$
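
A minimal usage sketch (a sketch only: it assumes this module is exposed as
`candle_optimisers::esgd` and that `Momentum` is re-exported from the crate root;
the loss here is purely illustrative):

```no_run
use candle_core::{Device, Result, Var};
use candle_nn::Optimizer;
use candle_optimisers::esgd::{ParamsSGD, SGD};
use candle_optimisers::Momentum;

fn main() -> Result<()> {
    let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
    let params = ParamsSGD {
        lr: 0.01,
        momentum: Some(Momentum::Classical(0.9)),
        ..Default::default()
    };
    let mut optimiser = SGD::new(vec![w.clone()], params)?;
    // a stand-in scalar loss computed from `w`
    let loss = w.as_tensor().sqr()?.sum_all()?;
    optimiser.step(&loss.backward()?)?;
    Ok(())
}
```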
*/

use candle_core::{Result, Var};
use candle_nn::optim::Optimizer;

use crate::{Decay, Momentum, OptimParams};

/// Optimizer for Stochastic Gradient Descent with momentum.
#[derive(Debug)]
pub struct SGD {
    vars: Vec<VarSGD>,
    params: ParamsSGD,
}

#[derive(Debug)]
struct VarSGD {
    theta: Var,
    b: Option<Var>,
}

/// Parameters for SGD
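///
/// A configuration sketch (assuming, as above, that this module is exposed as
/// `candle_optimisers::esgd` and that `Decay` and `Momentum` live at the crate root):
///
/// ```no_run
/// use candle_optimisers::esgd::ParamsSGD;
/// use candle_optimisers::{Decay, Momentum};
///
/// // Nesterov momentum of 0.9 with decoupled weight decay (SGDW-style)
/// let params = ParamsSGD {
///     lr: 0.01,
///     weight_decay: Some(Decay::DecoupledWeightDecay(0.01)),
///     momentum: Some(Momentum::Nesterov(0.9)),
///     dampening: 0.0,
/// };
/// ```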
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct ParamsSGD {
    /// Learning rate
    pub lr: f64,
    /// Weight decay
    pub weight_decay: Option<Decay>,
    /// Momentum
    pub momentum: Option<Momentum>,
    /// Dampening
    pub dampening: f64,
}

impl Default for ParamsSGD {
    fn default() -> Self {
        Self {
            lr: 0.1,
            weight_decay: None,
            momentum: None, //Momentum::Classical(0.1)
            dampening: 0.0,
            // nesterov: false,
        }
    }
}

impl Optimizer for SGD {
    type Config = ParamsSGD;

    fn new(vars: Vec<Var>, params: ParamsSGD) -> Result<Self> {
        // only floating-point vars can carry gradients, so any others are skipped
        let vars = vars
            .into_iter()
            .filter(|var| var.dtype().is_float())
            .map(|var| VarSGD {
                theta: var,
                b: None,
            })
            .collect::<Vec<VarSGD>>();
        // Err(SGDError::NoMomentum)?;
        Ok(Self { vars, params })
    }

    fn learning_rate(&self) -> f64 {
        self.params.lr
    }

    #[allow(clippy::too_many_lines)]
    fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
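        // The update is dispatched once on the optional momentum type (classical vs Nesterov)
        // and then on the optional weight decay type (L2-style vs decoupled), so no branching
        // happens inside the per-variable loops. The branches differ only in how the gradient
        // and the momentum buffer `b` are combined before updating `theta`.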
        if let Some(momentum) = self.params.momentum {
            match momentum {
                Momentum::Classical(momentum) => {
                    if let Some(decay) = self.params.weight_decay {
                        match decay {
                            Decay::WeightDecay(decay) => {
                                for var in &mut self.vars {
                                    let theta = &var.theta;
                                    // let prev_step = var.b;
                                    if let Some(grad) = grads.get(theta) {
                                        let grad = &(grad + (decay * theta.as_tensor())?)?;
                                        if let Some(prev_step) = &(var.b) {
                                            // println!("Exists");
                                            // bt←μbt−1+(1−τ)gt
                                            let bt = ((prev_step.as_tensor() * momentum)?
                                                + (1. - self.params.dampening) * (grad))?;

                                            // if not nesterov gt = bt
                                            theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?;
                                            // println!("Momentum {}", bt);
                                            prev_step.set(&bt)?;
                                        } else {
                                            // println!("Doesn't Exist");
                                            // bt←μbt−1+(1−τ)gt
                                            // if there is no history, bt = gt
                                            let bt = grad.clone(); // clone is needed so the value can be stored as the momentum buffer

                                            // if not nesterov gt = bt
                                            theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?;
                                            // println!("Momentum {}", bt);
                                            var.b = Some(Var::from_tensor(&bt)?);
                                        }
                                    }
                                }
                            }
                            Decay::DecoupledWeightDecay(decay) => {
                                for var in &mut self.vars {
                                    let theta = &var.theta;
                                    // let prev_step = var.b;
                                    if let Some(grad) = grads.get(theta) {
                                        // decoupled weight decay step
                                        theta.set(
                                            &(theta.as_tensor()
                                                * self.params.lr.mul_add(-decay, 1.))?,
                                        )?;
                                        if let Some(prev_step) = &(var.b) {
                                            // println!("Exists");
                                            // bt←μbt−1+(1−τ)gt
                                            let bt = ((prev_step.as_tensor() * momentum)?
                                                + (1. - self.params.dampening) * (grad))?;

                                            // if not nesterov gt = bt
                                            theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?;
                                            // println!("Momentum {}", bt);
                                            prev_step.set(&bt)?;
                                        } else {
                                            // println!("Doesn't Exist");
                                            // bt←μbt−1+(1−τ)gt
                                            // if there is no history, bt = gt
                                            let bt = grad.clone(); // clone is needed so the value can be stored as the momentum buffer

                                            // if not nesterov gt = bt
                                            theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?;
                                            // println!("Momentum {}", bt);
                                            var.b = Some(Var::from_tensor(&bt)?);
                                        }
                                    }
                                }
                            }
                        }
                    } else {
                        for var in &mut self.vars {
                            let theta = &var.theta;
                            // let prev_step = var.b;
                            if let Some(grad) = grads.get(theta) {
                                if let Some(prev_step) = &(var.b) {
                                    // println!("Exists");
                                    // bt←μbt−1+(1−τ)gt
                                    let bt = ((prev_step.as_tensor() * momentum)?
                                        + (1. - self.params.dampening) * (grad))?;

                                    // if not nesterov gt = bt
                                    theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?;
                                    // println!("Momentum {}", bt);
                                    prev_step.set(&bt)?;
                                } else {
                                    // println!("Doesn't Exist");
                                    // bt←μbt−1+(1−τ)gt
                                    // if there is no history, bt = gt
                                    let bt = grad.clone(); // clone is needed so the value can be stored as the momentum buffer

                                    // if not nesterov gt = bt
                                    theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?;
                                    // println!("Momentum {}", bt);
                                    var.b = Some(Var::from_tensor(&bt)?);
                                }
                            }
                        }
                    }
                }
                Momentum::Nesterov(momentum) => {
                    if let Some(decay) = self.params.weight_decay {
                        match decay {
                            Decay::WeightDecay(decay) => {
                                for var in &mut self.vars {
                                    let theta = &var.theta;
                                    // let prev_step = var.b;
                                    if let Some(grad) = grads.get(theta) {
                                        let grad = &(grad + (decay * theta.as_tensor())?)?;
                                        if let Some(prev_step) = &(var.b) {
                                            // println!("Exists");
                                            // bt←μbt−1+(1−τ)gt
                                            let bt = ((prev_step.as_tensor() * momentum)?
                                                + (1. - self.params.dampening) * (grad))?;

                                            let gt = (grad + (momentum * &bt)?)?;
                                            // println!("Momentum {}", bt);
                                            prev_step.set(&bt)?;
                                            theta.set(&theta.sub(&(gt * self.params.lr)?)?)?;
                                        } else {
                                            // println!("Doesn't Exist");
                                            // bt←μbt−1+(1−τ)gt
                                            // if there is no history, bt = gt
                                            let bt = grad.clone(); // clone is needed so the value can be stored as the momentum buffer

                                            let gt = (grad + (momentum * &bt)?)?;
                                            // println!("Momentum {}", bt);
                                            var.b = Some(Var::from_tensor(&bt)?);
                                            theta.set(&theta.sub(&(gt * self.params.lr)?)?)?;
                                        }
                                    }
                                }
                            }
                            Decay::DecoupledWeightDecay(decay) => {
                                for var in &mut self.vars {
                                    let theta = &var.theta;
                                    // let prev_step = var.b;
                                    if let Some(grad) = grads.get(theta) {
                                        // decoupled weight decay step
                                        theta.set(
                                            &(theta.as_tensor()
                                                * self.params.lr.mul_add(-decay, 1.))?,
                                        )?;
                                        if let Some(prev_step) = &(var.b) {
                                            // println!("Exists");
                                            // bt←μbt−1+(1−τ)gt
                                            let bt = ((prev_step.as_tensor() * momentum)?
                                                + (1. - self.params.dampening) * (grad))?;

                                            let gt = (grad + (momentum * &bt)?)?;
                                            // println!("Momentum {}", bt);
                                            prev_step.set(&bt)?;
                                            theta.set(&theta.sub(&(gt * self.params.lr)?)?)?;
                                        } else {
                                            // println!("Doesn't Exist");
                                            // bt←μbt−1+(1−τ)gt
                                            // if there is no history, bt = gt
                                            let bt = grad.clone(); // clone is needed so the value can be stored as the momentum buffer

                                            let gt = (grad + (momentum * &bt)?)?;
                                            // println!("Momentum {}", bt);
                                            var.b = Some(Var::from_tensor(&bt)?);
                                            theta.set(&theta.sub(&(gt * self.params.lr)?)?)?;
                                        }
                                    }
                                }
                            }
                        }
                    } else {
                        for var in &mut self.vars {
                            let theta = &var.theta;
                            // let prev_step = var.b;
                            if let Some(grad) = grads.get(theta) {
                                if let Some(prev_step) = &(var.b) {
                                    // println!("Exists");
                                    // bt←μbt−1+(1−τ)gt
                                    let bt = ((prev_step.as_tensor() * momentum)?
                                        + (1. - self.params.dampening) * (grad))?;

                                    let gt = (grad + (momentum * &bt)?)?;
                                    // println!("Momentum {}", bt);
                                    prev_step.set(&bt)?;
                                    theta.set(&theta.sub(&(gt * self.params.lr)?)?)?;
                                } else {
                                    // println!("Doesn't Exist");
                                    // bt←μbt−1+(1−τ)gt
                                    // if there is no history, bt = gt
                                    let bt = grad.clone(); // clone is needed so the value can be stored as the momentum buffer

                                    let gt = (grad + (momentum * &bt)?)?;
                                    // println!("Momentum {}", bt);
                                    var.b = Some(Var::from_tensor(&bt)?);
                                    theta.set(&theta.sub(&(gt * self.params.lr)?)?)?;
                                }
                            }
                        }
                    }
                }
            }
        } else if let Some(decay) = self.params.weight_decay {
            // These should be the same up to numeric precision:
            // for SGD with no momentum, decoupled weight decay and L2 regularisation are equivalent.
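            // A quick check of the claim: the L2 branch computes
            //     theta <- theta - lr * (grad + decay * theta) = (1 - lr * decay) * theta - lr * grad,
            // while the decoupled branch first scales theta by (1 - lr * decay) and then subtracts
            // lr * grad, giving the same result (up to floating-point rounding).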
            match decay {
                Decay::WeightDecay(decay) => {
                    for var in &mut self.vars {
                        let theta = &var.theta;
                        // let prev_step = var.b;
                        if let Some(grad) = grads.get(theta) {
                            let grad = &(grad + (decay * theta.as_tensor())?)?; // weight decay grad
                            theta.set(&theta.sub(&(grad * self.params.lr)?)?)?; // update theta
                        }
                    }
                }
                Decay::DecoupledWeightDecay(decay) => {
                    for var in &mut self.vars {
                        let theta = &var.theta;
                        // let prev_step = var.b;
                        if let Some(grad) = grads.get(theta) {
                            theta
                                .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?;
                            theta.set(&theta.sub(&(grad * self.params.lr)?)?)?; // update theta based on grad
                        }
                    }
                }
            }
        } else {
            for var in &mut self.vars {
                let theta = &var.theta;
                // let prev_step = var.b;
                if let Some(grad) = grads.get(theta) {
                    theta.set(&theta.sub(&(grad * self.params.lr)?)?)?; // update theta based on grad
                }
            }
        }

        Ok(())
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.params.lr = lr;
    }
}

impl OptimParams for SGD {
    fn params(&self) -> &Self::Config {
        &self.params
    }

    fn set_params(&mut self, config: Self::Config) {
        self.params = config;
    }
}

impl SGD {
    /// Return the vars being optimised
    #[must_use]
    pub fn into_inner(self) -> Vec<Var> {
        self.vars.into_iter().map(|v| v.theta).collect()
    }

    // pub fn push(&mut self, var: &Var) {
    //     self.vars.push(var.clone());
    // }
}

#[cfg(test)]
mod tests {
    // use candle_core::test_utils::{to_vec0_round, to_vec2_round};

    use anyhow::Result;
    use assert_approx_eq::assert_approx_eq;
    use candle_core::{Device, Var};
    use candle_nn::Optimizer;

    use super::*;

    #[test]
    fn lr_test() -> Result<()> {
        let params = ParamsSGD {
            lr: 0.004,
            ..Default::default()
        };
        // dummy variables to construct the optimiser with
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = SGD::new(vec![w.clone(), b.clone()], params)?;
        assert_approx_eq!(0.004, optim.learning_rate());
        optim.set_learning_rate(0.002);
        assert_approx_eq!(0.002, optim.learning_rate());
        Ok(())
    }

    #[test]
    fn into_inner_test() -> Result<()> {
        let params = ParamsSGD::default();
        let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
        let b = Var::new(-2f32, &Device::Cpu)?;
        let optim = SGD::new(vec![w.clone(), b.clone()], params)?;
        let inner = optim.into_inner();
        assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
        assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
        Ok(())
    }
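
    // A minimal extra check (a sketch, not from the original test suite): with no momentum and
    // no weight decay a single step should apply the plain update theta <- theta - lr * grad.
    // For loss = sum(w^2) the gradient is 2 * w, so with lr = 0.1 one step scales w by 0.8.
    #[test]
    fn plain_step_test() -> Result<()> {
        let params = ParamsSGD {
            lr: 0.1,
            ..Default::default()
        };
        let w = Var::new(&[1f32, 2.], &Device::Cpu)?;
        let mut optim = SGD::new(vec![w.clone()], params)?;
        let loss = w.as_tensor().sqr()?.sum_all()?;
        optim.step(&loss.backward()?)?;
        let out = w.as_tensor().to_vec1::<f32>()?;
        assert_approx_eq!(out[0], 0.8_f32);
        assert_approx_eq!(out[1], 1.6_f32);
        Ok(())
    }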

    #[test]
    fn params_test() -> Result<()> {
        let params = ParamsSGD {
            lr: 0.004,
            ..Default::default()
        };
        // dummy variables to construct the optimiser with
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = SGD::new(vec![w.clone(), b.clone()], params.clone())?;
        assert_eq!(params, optim.params().clone());
        let new_params = ParamsSGD {
            lr: 0.002,
            ..Default::default()
        };
        optim.set_params(new_params.clone());
        assert_eq!(new_params, optim.params().clone());
        Ok(())
    }
426}