candle_optimisers/
esgd.rs

/*!
Stochastic Gradient Descent

This incorporates Nesterov and classical momentum, as well as weight decay and decoupled weight decay
(described as SGDW in [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101)).

$$
\\begin{aligned}
    &\\rule{110mm}{0.4pt}                                                                 \\\\
    &\\textbf{input}      : \\gamma \\text{ (lr)}, \\: \\theta_0 \\text{ (params)}, \\: f(\\theta)
        \\text{ (objective)}, \\: \\lambda \\text{ (weight decay)},                          \\\\
    &\\hspace{13mm} \\:\\mu \\text{ (momentum)}, \\:\\tau \\text{ (dampening)}          \\\\[-1.ex]
    &\\rule{110mm}{0.4pt}                                                                 \\\\
    &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do}                         \\\\
    &\\hspace{5mm}g_t           \\leftarrow   \\nabla_{\\theta} f_t (\\theta_{t-1})           \\\\
    &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some}                        \\\\
    &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled}                       \\\\
    &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1}                    \\\\
    &\\hspace{10mm}\\textbf{else}                                                              \\\\
    &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda  \\theta_{t-1}                            \\\\
    &\\hspace{5mm}\\textbf{if} \\: \\mu \\textbf{ is } \\text{Some}                        \\\\
    &\\hspace{10mm}\\textbf{if} \\: t>1                      \\\\
    &\\hspace{15mm} b_t \\leftarrow \\mu b_{t-1} + (1-\\tau) g_{t}                   \\\\
    &\\hspace{10mm}\\textbf{else}                                                              \\\\
    &\\hspace{15mm} b_t \\leftarrow g_{t}                                    \\\\
    &\\hspace{10mm}\\textbf{if} \\: \\textit{nesterov}                       \\\\
    &\\hspace{15mm} g_t \\leftarrow g_t + \\mu b_t                   \\\\
    &\\hspace{10mm}\\textbf{else}                                                              \\\\
    &\\hspace{15mm} g_t \\leftarrow b_t                           \\\\
    &\\hspace{5mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma g_t \\\\
    &\\rule{110mm}{0.4pt}                                                          \\\\[-1.ex]
    &\\bf{return} \\:  \\theta_t                                                     \\\\[-1.ex]
    &\\rule{110mm}{0.4pt}                                                          \\\\[-1.ex]
\\end{aligned}
$$
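
A training step might look roughly like the following sketch (the crate and module paths shown
assume this file is exposed as the `esgd` module; the loss here is only a placeholder):

```no_run
use candle_core::{Device, Result, Var};
use candle_nn::Optimizer;
use candle_optimisers::{esgd::{ParamsSGD, SGD}, Momentum};

fn main() -> Result<()> {
    // parameters to be optimised
    let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
    let b = Var::new(0f32, &Device::Cpu)?;

    // classical momentum of 0.9, everything else left at the defaults
    let params = ParamsSGD {
        lr: 0.01,
        momentum: Some(Momentum::Classical(0.9)),
        ..Default::default()
    };
    let mut optimiser = SGD::new(vec![w.clone(), b.clone()], params)?;

    // inside a training loop: compute a scalar loss and take a step
    let loss = ((w.as_tensor() * w.as_tensor())?.sum_all()? + (b.as_tensor() * b.as_tensor())?)?;
    optimiser.backward_step(&loss)?;
    Ok(())
}
```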
*/

use candle_core::{Result, Var};
use candle_nn::optim::Optimizer;

use crate::{Decay, Momentum, OptimParams};

/// Optimizer for Stochastic Gradient Descent with momentum.
#[derive(Debug)]
pub struct SGD {
    vars: Vec<VarSGD>,
    params: ParamsSGD,
}

#[derive(Debug)]
struct VarSGD {
    theta: Var,
    b: Option<Var>,
}

/// Parameters for SGD
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct ParamsSGD {
    /// Learning rate
    pub lr: f64,
    /// Weight decay
    pub weight_decay: Option<Decay>,
    /// Momentum
    pub momentum: Option<Momentum>,
    /// Dampening
    pub dampening: f64,
}

impl Default for ParamsSGD {
    fn default() -> Self {
        Self {
            lr: 0.1,
            weight_decay: None,
            momentum: None,
            dampening: 0.0,
        }
    }
}
impl Optimizer for SGD {
    type Config = ParamsSGD;

    fn new(vars: Vec<Var>, params: ParamsSGD) -> Result<Self> {
        let vars = vars
            .into_iter()
            .filter(|var| var.dtype().is_float())
            .map(|var| VarSGD {
                theta: var,
                b: None,
            })
            .collect::<Vec<VarSGD>>();
        Ok(Self { vars, params })
    }

    fn learning_rate(&self) -> f64 {
        self.params.lr
    }

    #[allow(clippy::too_many_lines)]
    fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
        if let Some(momentum) = self.params.momentum {
            match momentum {
                Momentum::Classical(momentum) => {
                    if let Some(decay) = self.params.weight_decay {
                        match decay {
                            Decay::WeightDecay(decay) => {
                                for var in &mut self.vars {
                                    let theta = &var.theta;
                                    if let Some(grad) = grads.get(theta) {
                                        // L2 weight decay: gt <- gt + lambda * theta
                                        let grad = &(grad + (decay * theta.as_tensor())?)?;
                                        if let Some(prev_step) = &(var.b) {
                                            // bt <- mu * b(t-1) + (1 - tau) * gt
                                            let bt = ((prev_step.as_tensor() * momentum)?
                                                + (1. - self.params.dampening) * (grad))?;

                                            // not nesterov, so gt = bt
                                            theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?;
                                            prev_step.set(&bt)?;
                                        } else {
                                            // if there is no history, bt = gt
                                            let bt = grad.clone(); // clone is needed as bt must be kept for the next step

                                            // not nesterov, so gt = bt
                                            theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?;
                                            var.b = Some(Var::from_tensor(&bt)?);
                                        }
                                    }
                                }
                            }
                            Decay::DecoupledWeightDecay(decay) => {
                                for var in &mut self.vars {
                                    let theta = &var.theta;
                                    if let Some(grad) = grads.get(theta) {
                                        // decoupled weight decay: theta <- theta * (1 - lr * lambda)
                                        theta.set(
                                            &(theta.as_tensor()
                                                * self.params.lr.mul_add(-decay, 1.))?,
                                        )?;
                                        if let Some(prev_step) = &(var.b) {
                                            // bt <- mu * b(t-1) + (1 - tau) * gt
                                            let bt = ((prev_step.as_tensor() * momentum)?
                                                + (1. - self.params.dampening) * (grad))?;

                                            // not nesterov, so gt = bt
                                            theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?;
                                            prev_step.set(&bt)?;
                                        } else {
                                            // if there is no history, bt = gt
                                            let bt = grad.clone(); // clone is needed as bt must be kept for the next step

                                            // not nesterov, so gt = bt
                                            theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?;
                                            var.b = Some(Var::from_tensor(&bt)?);
                                        }
                                    }
                                }
                            }
                        }
                    } else {
                        for var in &mut self.vars {
                            let theta = &var.theta;
                            if let Some(grad) = grads.get(theta) {
                                if let Some(prev_step) = &(var.b) {
                                    // bt <- mu * b(t-1) + (1 - tau) * gt
                                    let bt = ((prev_step.as_tensor() * momentum)?
                                        + (1. - self.params.dampening) * (grad))?;

                                    // not nesterov, so gt = bt
                                    theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?;
                                    prev_step.set(&bt)?;
                                } else {
                                    // if there is no history, bt = gt
                                    let bt = grad.clone(); // clone is needed as bt must be kept for the next step

                                    // not nesterov, so gt = bt
                                    theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?;
                                    var.b = Some(Var::from_tensor(&bt)?);
                                }
                            }
                        }
                    }
                }
                Momentum::Nesterov(momentum) => {
                    if let Some(decay) = self.params.weight_decay {
                        match decay {
                            Decay::WeightDecay(decay) => {
                                for var in &mut self.vars {
                                    let theta = &var.theta;
                                    if let Some(grad) = grads.get(theta) {
                                        // L2 weight decay: gt <- gt + lambda * theta
                                        let grad = &(grad + (decay * theta.as_tensor())?)?;
                                        if let Some(prev_step) = &(var.b) {
                                            // bt <- mu * b(t-1) + (1 - tau) * gt
                                            let bt = ((prev_step.as_tensor() * momentum)?
                                                + (1. - self.params.dampening) * (grad))?;

                                            // nesterov: gt <- gt + mu * bt
                                            let gt = (grad + (momentum * &bt)?)?;
                                            prev_step.set(&bt)?;
                                            theta.set(&theta.sub(&(gt * self.params.lr)?)?)?;
                                        } else {
                                            // if there is no history, bt = gt
                                            let bt = grad.clone(); // clone is needed as bt must be kept for the next step

                                            // nesterov: gt <- gt + mu * bt
                                            let gt = (grad + (momentum * &bt)?)?;
                                            var.b = Some(Var::from_tensor(&bt)?);
                                            theta.set(&theta.sub(&(gt * self.params.lr)?)?)?;
                                        }
                                    }
                                }
                            }
                            Decay::DecoupledWeightDecay(decay) => {
                                for var in &mut self.vars {
                                    let theta = &var.theta;
                                    if let Some(grad) = grads.get(theta) {
                                        // decoupled weight decay: theta <- theta * (1 - lr * lambda)
                                        theta.set(
                                            &(theta.as_tensor()
                                                * self.params.lr.mul_add(-decay, 1.))?,
                                        )?;
                                        if let Some(prev_step) = &(var.b) {
                                            // bt <- mu * b(t-1) + (1 - tau) * gt
                                            let bt = ((prev_step.as_tensor() * momentum)?
                                                + (1. - self.params.dampening) * (grad))?;

                                            // nesterov: gt <- gt + mu * bt
                                            let gt = (grad + (momentum * &bt)?)?;
                                            prev_step.set(&bt)?;
                                            theta.set(&theta.sub(&(gt * self.params.lr)?)?)?;
                                        } else {
                                            // if there is no history, bt = gt
                                            let bt = grad.clone(); // clone is needed as bt must be kept for the next step

                                            // nesterov: gt <- gt + mu * bt
                                            let gt = (grad + (momentum * &bt)?)?;
                                            var.b = Some(Var::from_tensor(&bt)?);
                                            theta.set(&theta.sub(&(gt * self.params.lr)?)?)?;
                                        }
                                    }
                                }
                            }
                        }
                    } else {
                        for var in &mut self.vars {
                            let theta = &var.theta;
                            if let Some(grad) = grads.get(theta) {
                                if let Some(prev_step) = &(var.b) {
                                    // bt <- mu * b(t-1) + (1 - tau) * gt
                                    let bt = ((prev_step.as_tensor() * momentum)?
                                        + (1. - self.params.dampening) * (grad))?;

                                    // nesterov: gt <- gt + mu * bt
                                    let gt = (grad + (momentum * &bt)?)?;
                                    prev_step.set(&bt)?;
                                    theta.set(&theta.sub(&(gt * self.params.lr)?)?)?;
                                } else {
                                    // if there is no history, bt = gt
                                    let bt = grad.clone(); // clone is needed as bt must be kept for the next step

                                    // nesterov: gt <- gt + mu * bt
                                    let gt = (grad + (momentum * &bt)?)?;
                                    var.b = Some(Var::from_tensor(&bt)?);
                                    theta.set(&theta.sub(&(gt * self.params.lr)?)?)?;
                                }
                            }
                        }
                    }
                }
            }
        } else if let Some(decay) = self.params.weight_decay {
            // These should be the same up to numeric precision:
            // for SGD with no momentum, decoupled weight decay and L2 regularisation are equivalent
            match decay {
                Decay::WeightDecay(decay) => {
                    for var in &mut self.vars {
                        let theta = &var.theta;
                        if let Some(grad) = grads.get(theta) {
                            let grad = &(grad + (decay * theta.as_tensor())?)?; // weight decay grad
                            theta.set(&theta.sub(&(grad * self.params.lr)?)?)?; // update theta
                        }
                    }
                }
                Decay::DecoupledWeightDecay(decay) => {
                    for var in &mut self.vars {
                        let theta = &var.theta;
                        if let Some(grad) = grads.get(theta) {
                            theta
                                .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?;
                            theta.set(&theta.sub(&(grad * self.params.lr)?)?)?; // update theta based on grad
                        }
                    }
                }
            }
        } else {
            for var in &mut self.vars {
                let theta = &var.theta;
                if let Some(grad) = grads.get(theta) {
                    theta.set(&theta.sub(&(grad * self.params.lr)?)?)?; // update theta based on grad
                }
            }
        }

        Ok(())
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.params.lr = lr;
    }
}

impl OptimParams for SGD {
    fn params(&self) -> &Self::Config {
        &self.params
    }

    fn set_params(&mut self, config: Self::Config) {
        self.params = config;
    }
}

impl SGD {
    /// Return the vars being optimised
    #[must_use]
    pub fn into_inner(self) -> Vec<Var> {
        self.vars.into_iter().map(|v| v.theta).collect()
    }

    // pub fn push(&mut self, var: &Var) {
    //     self.vars.push(var.clone());
    // }
}
#[cfg(test)]
mod tests {
    // use candle_core::test_utils::{to_vec0_round, to_vec2_round};

    use anyhow::Result;
    use assert_approx_eq::assert_approx_eq;
    use candle_core::{Device, Var};
    use candle_nn::Optimizer;

    use super::*;

    #[test]
    fn lr_test() -> Result<()> {
        let params = ParamsSGD {
            lr: 0.004,
            ..Default::default()
        };
        // create some dummy vars to attach the optimiser to
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = SGD::new(vec![w.clone(), b.clone()], params)?;
        assert_approx_eq!(0.004, optim.learning_rate());
        optim.set_learning_rate(0.002);
        assert_approx_eq!(0.002, optim.learning_rate());
        Ok(())
    }
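
    // Minimal sanity-check sketch: a single step with no momentum and no weight decay should
    // reduce to theta <- theta - lr * grad; a sum loss makes the gradient exactly ones.
    #[test]
    fn basic_step_test() -> Result<()> {
        let params = ParamsSGD {
            lr: 0.1,
            ..Default::default()
        };
        let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
        let mut optim = SGD::new(vec![w.clone()], params)?;
        // loss = sum(w)  =>  d(loss)/dw = [[1, 1]]
        let loss = w.as_tensor().sum_all()?;
        optim.backward_step(&loss)?;
        // expect [[3 - 0.1, 1 - 0.1]]
        let out = w.as_tensor().to_vec2::<f32>()?;
        assert_approx_eq!(out[0][0], 2.9f32);
        assert_approx_eq!(out[0][1], 0.9f32);
        Ok(())
    }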

    #[test]
    fn into_inner_test() -> Result<()> {
        let params = ParamsSGD::default();
        let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
        let b = Var::new(-2f32, &Device::Cpu)?;
        let optim = SGD::new(vec![w.clone(), b.clone()], params)?;
        let inner = optim.into_inner();
        assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
        assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
        Ok(())
    }
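
    // Sketch of the classical momentum update over two steps with lr = 0.1 and mu = 0.9:
    // a sum loss has a gradient of all ones, so b1 = 1 and b2 = 0.9 * 1 + 1 = 1.9,
    // giving parameter updates of 0.1 and then 0.19.
    #[test]
    fn classical_momentum_step_test() -> Result<()> {
        let params = ParamsSGD {
            lr: 0.1,
            momentum: Some(Momentum::Classical(0.9)),
            ..Default::default()
        };
        let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
        let mut optim = SGD::new(vec![w.clone()], params)?;
        for _ in 0..2 {
            let loss = w.as_tensor().sum_all()?;
            optim.backward_step(&loss)?;
        }
        // expect [[3 - 0.1 - 0.19, 1 - 0.1 - 0.19]]
        let out = w.as_tensor().to_vec2::<f32>()?;
        assert_approx_eq!(out[0][0], 2.71f32);
        assert_approx_eq!(out[0][1], 0.71f32);
        Ok(())
    }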

    #[test]
    fn params_test() -> Result<()> {
        let params = ParamsSGD {
            lr: 0.004,
            ..Default::default()
        };
        // create some dummy vars to attach the optimiser to
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = SGD::new(vec![w.clone(), b.clone()], params.clone())?;
        assert_eq!(params, optim.params().clone());
        let new_params = ParamsSGD {
            lr: 0.002,
            ..Default::default()
        };
        optim.set_params(new_params.clone());
        assert_eq!(new_params, optim.params().clone());
        Ok(())
    }
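
    // Sketch checking the comment in `step`: with no momentum, one step with L2 weight decay and
    // one step with decoupled weight decay should coincide, theta * (1 - lr * lambda) - lr * grad,
    // here with lr = 0.1, lambda = 0.1 and a sum loss whose gradient is all ones.
    #[test]
    fn weight_decay_equivalence_test() -> Result<()> {
        for decay in [Decay::WeightDecay(0.1), Decay::DecoupledWeightDecay(0.1)] {
            let params = ParamsSGD {
                lr: 0.1,
                weight_decay: Some(decay),
                ..Default::default()
            };
            let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
            let mut optim = SGD::new(vec![w.clone()], params)?;
            let loss = w.as_tensor().sum_all()?;
            optim.backward_step(&loss)?;
            // both variants: [[3 * 0.99 - 0.1, 1 * 0.99 - 0.1]] = [[2.87, 0.89]]
            let out = w.as_tensor().to_vec2::<f32>()?;
            assert_approx_eq!(out[0][0], 2.87f32);
            assert_approx_eq!(out[0][1], 0.89f32);
        }
        Ok(())
    }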
}