candle_optimisers/adagrad.rs

/*!
Adagrad optimiser

Described in [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](https://jmlr.org/papers/v12/duchi11a.html)

Pseudocode (including the decoupled weight decay variant):

$$
\\begin{aligned}
            &\\rule{110mm}{0.4pt}                                                                 \\\\
            &\\textbf{input}      : \\gamma \\text{ (lr)}, \\: \\theta_0 \\text{ (params)}, \\: f(\\theta)
                \\text{ (objective)}, \\: \\lambda \\text{ (weight decay)},                          \\\\
            &\\hspace{12mm}    \\tau \\text{ (initial accumulator value)}, \\: \\eta\\text{ (lr decay)}\\\\
            &\\textbf{initialize} :  statesum_0 \\leftarrow 0                             \\\\[-1.ex]
            &\\rule{110mm}{0.4pt}                                                                 \\\\
            &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do}                         \\\\
            &\\hspace{5mm}g_t           \\leftarrow   \\nabla_{\\theta} f_t (\\theta_{t-1})           \\\\
            &\\hspace{5mm} \\tilde{\\gamma}    \\leftarrow \\gamma / (1 +(t-1) \\eta)                  \\\\
            &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some}                        \\\\
            &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled}                       \\\\
            &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1}                    \\\\
            &\\hspace{10mm}\\textbf{else}                                                              \\\\
            &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda  \\theta_{t-1}                            \\\\
            &\\hspace{5mm}statesum_t  \\leftarrow  statesum_{t-1} + g^2_t                      \\\\
            &\\hspace{5mm}\\theta_t \\leftarrow
                \\theta_{t-1}- \\tilde{\\gamma} \\frac{g_t}{\\sqrt{statesum_t}+\\epsilon}            \\\\
            &\\rule{110mm}{0.4pt}                                                          \\\\[-1.ex]
            &\\bf{return} \\:  \\theta_t                                                     \\\\[-1.ex]
            &\\rule{110mm}{0.4pt}                                                          \\\\[-1.ex]
       \\end{aligned}
$$
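
A minimal usage sketch (illustrative rather than canonical; it assumes this module is exposed as
`candle_optimisers::adagrad` and uses the `Optimizer` trait re-exported by `candle_nn`):

```no_run
use candle_core::{Device, Var};
use candle_nn::Optimizer;
use candle_optimisers::adagrad::{Adagrad, ParamsAdaGrad};

fn main() -> candle_core::Result<()> {
    // A single parameter tensor to optimise.
    let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;

    let params = ParamsAdaGrad {
        lr: 0.1,
        ..Default::default()
    };
    let mut optimiser = Adagrad::new(vec![w.clone()], params)?;

    // One optimisation step: compute a loss, backpropagate, then update `w` in place.
    let loss = w.as_tensor().sqr()?.sum_all()?;
    let grads = loss.backward()?;
    optimiser.step(&grads)?;
    Ok(())
}
```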
*/

use candle_core::{Result, Var};
use candle_nn::optim::Optimizer;

use crate::{Decay, OptimParams};

/// Adagrad optimiser
///
/// Described in [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](https://jmlr.org/papers/v12/duchi11a.html)
#[derive(Debug)]
pub struct Adagrad {
    /// Per-variable state (parameter and accumulated squared gradients)
    vars: Vec<VarAdaGrad>,
    /// Optimiser hyperparameters
    params: ParamsAdaGrad,
    /// Number of steps taken so far
    t: f64,
}

/// State for a single optimised variable
#[derive(Debug)]
struct VarAdaGrad {
    /// The parameter being optimised
    theta: Var,
    /// Running sum of squared gradients
    sum: Var,
}

/// Parameters for the Adagrad optimiser
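///
/// The effective learning rate used at step `t` is `lr / (1 + t * lr_decay)`.
///
/// A configuration sketch (illustrative; it assumes the crate is used as `candle_optimisers`
/// and that `Decay` is exported from the crate root, as the `use crate::Decay` import above suggests):
///
/// ```no_run
/// use candle_optimisers::{adagrad::ParamsAdaGrad, Decay};
///
/// // Adagrad with decoupled weight decay and a mild learning-rate decay.
/// let params = ParamsAdaGrad {
///     lr: 0.01,
///     lr_decay: 1e-4,
///     weight_decay: Some(Decay::DecoupledWeightDecay(1e-2)),
///     ..Default::default()
/// };
/// ```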
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct ParamsAdaGrad {
    /// Learning rate
    pub lr: f64,
    /// Learning rate decay
    pub lr_decay: f64,
    /// Initial value of the accumulator
    pub initial_acc: f64,
    /// Weight decay
    pub weight_decay: Option<Decay>,
    /// Term added to the denominator to improve numerical stability
    pub eps: f64,
}

impl Default for ParamsAdaGrad {
    fn default() -> Self {
        Self {
            lr: 0.01,
            lr_decay: 0.0,
            initial_acc: 0.0,
            weight_decay: None,
            eps: 1e-10,
        }
    }
}

impl Optimizer for Adagrad {
    type Config = ParamsAdaGrad;

    fn new(vars: Vec<Var>, params: ParamsAdaGrad) -> Result<Self> {
        let vars = vars
            .into_iter()
            // Only floating-point variables are optimised.
            .filter(|var| var.dtype().is_float())
            .map(|var| {
                let dtype = var.dtype();
                let shape = var.shape();
                let device = var.device();
                // Accumulator for the running sum of squared gradients.
                let sum = Var::zeros(shape, dtype, device)?;
                Ok(VarAdaGrad { theta: var, sum })
            })
            .collect::<Result<Vec<VarAdaGrad>>>()?;
        Ok(Self {
            vars,
            t: 0.,
            params,
        })
    }

    fn learning_rate(&self) -> f64 {
        self.params.lr
    }

    fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
        if let Some(decay) = self.params.weight_decay {
            match decay {
                // Classic weight decay: fold the decay term into the gradient before the update.
                Decay::WeightDecay(decay) => {
                    for var in &self.vars {
                        let theta = &var.theta;
                        let sum = &var.sum;
                        if let Some(grad) = grads.get(theta) {
                            // Decayed learning rate: lr / (1 + t * lr_decay)
                            let gamma_tilde =
                                self.params.lr / self.t.mul_add(self.params.lr_decay, 1.);
                            let grad = &(grad + (decay * theta.as_tensor())?)?;
                            let current_sum = (sum.as_tensor() + grad.powf(2.)?)?;
                            let change = (gamma_tilde
                                * (grad.div(&(current_sum.powf(0.5)? + self.params.eps)?))?)?;
                            sum.set(&current_sum)?;
                            theta.set(&theta.sub(&change)?)?;
                        }
                    }
                }
                // Decoupled weight decay: shrink the parameters directly, then apply the plain Adagrad step.
                Decay::DecoupledWeightDecay(decay) => {
                    for var in &self.vars {
                        let theta = &var.theta;
                        let sum = &var.sum;
                        if let Some(grad) = grads.get(theta) {
                            // decoupled weight decay step
                            theta
                                .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?;
                            let gamma_tilde =
                                self.params.lr / self.t.mul_add(self.params.lr_decay, 1.);
                            let current_sum = (sum.as_tensor() + grad.powf(2.)?)?;
                            let change = (gamma_tilde
                                * (grad.div(&(current_sum.powf(0.5)? + self.params.eps)?))?)?;
                            sum.set(&current_sum)?;
                            theta.set(&theta.sub(&change)?)?;
                        }
                    }
                }
            }
        } else {
            // No weight decay: plain Adagrad update.
            for var in &self.vars {
                let theta = &var.theta;
                let sum = &var.sum;
                if let Some(grad) = grads.get(theta) {
                    let gamma_tilde = self.params.lr / self.t.mul_add(self.params.lr_decay, 1.);
                    let current_sum = (sum.as_tensor() + grad.powf(2.)?)?;
                    let change =
                        (gamma_tilde * (grad.div(&(current_sum.powf(0.5)? + self.params.eps)?))?)?;
                    sum.set(&current_sum)?;
                    theta.set(&theta.sub(&change)?)?;
                }
            }
        }
        self.t += 1.;
        Ok(())
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.params.lr = lr;
    }
}
impl OptimParams for Adagrad {
    fn params(&self) -> &Self::Config {
        &self.params
    }

    fn set_params(&mut self, config: Self::Config) {
        self.params = config;
    }
}

impl Adagrad {
    /// Return the vars being optimised
    #[must_use]
    pub fn into_inner(self) -> Vec<Var> {
        self.vars.into_iter().map(|v| v.theta).collect()
    }

    // pub fn push(&mut self, var: &Var) {
    //     self.vars.push(var.clone());
    // }
}

#[cfg(test)]
mod tests {
    // use candle_core::test_utils::{to_vec0_round, to_vec2_round};

    use anyhow::Result;
    use assert_approx_eq::assert_approx_eq;
    use candle_core::{Device, Var};
    use candle_nn::Optimizer;

    use super::*;
    #[test]
    fn lr_test() -> Result<()> {
        let params = ParamsAdaGrad {
            lr: 0.004,
            ..Default::default()
        };
        // Build an optimiser over dummy variables and exercise the learning-rate getter and setter.
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = Adagrad::new(vec![w.clone(), b.clone()], params)?;
        assert_approx_eq!(0.004, optim.learning_rate());
        optim.set_learning_rate(0.002);
        assert_approx_eq!(0.002, optim.learning_rate());
        Ok(())
    }

    #[test]
    fn into_inner_test() -> Result<()> {
        let params = ParamsAdaGrad::default();
        let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
        let b = Var::new(-2f32, &Device::Cpu)?;
        let optim = Adagrad::new(vec![w.clone(), b.clone()], params)?;
        let inner = optim.into_inner();
        assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
        assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
        Ok(())
    }

    #[test]
    fn params_test() -> Result<()> {
        let params = ParamsAdaGrad {
            lr: 0.004,
            ..Default::default()
        };
        // Build an optimiser over dummy variables and check that the parameters can be read and replaced.
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = Adagrad::new(vec![w.clone(), b.clone()], params.clone())?;
        assert_eq!(params, optim.params().clone());
        let new_params = ParamsAdaGrad {
            lr: 0.002,
            ..Default::default()
        };
        optim.set_params(new_params.clone());
        assert_eq!(new_params, optim.params().clone());
        Ok(())
    }
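
    // The tests below are illustrative additions: hand-computed single Adagrad steps whose
    // expected values follow the update rule documented in the module pseudocode.
    #[test]
    fn single_step_test() -> Result<()> {
        // One parameter with loss = sum(w), so the gradient is exactly 1.
        let params = ParamsAdaGrad {
            lr: 0.5,
            ..Default::default()
        };
        let w = Var::new(&[1f32], &Device::Cpu)?;
        let mut optim = Adagrad::new(vec![w.clone()], params)?;
        let loss = w.as_tensor().sum_all()?;
        let grads = loss.backward()?;
        optim.step(&grads)?;
        // statesum = 1, step = 0.5 * 1 / (sqrt(1) + eps), so w = 1 - 0.5 = 0.5
        assert_approx_eq!(w.as_tensor().to_vec1::<f32>()?[0], 0.5_f32);
        Ok(())
    }

    #[test]
    fn single_step_decoupled_decay_test() -> Result<()> {
        // Same setup, with decoupled weight decay of 0.1.
        let params = ParamsAdaGrad {
            lr: 0.5,
            weight_decay: Some(crate::Decay::DecoupledWeightDecay(0.1)),
            ..Default::default()
        };
        let w = Var::new(&[1f32], &Device::Cpu)?;
        let mut optim = Adagrad::new(vec![w.clone()], params)?;
        let loss = w.as_tensor().sum_all()?;
        let grads = loss.backward()?;
        optim.step(&grads)?;
        // Decay step: w <- w * (1 - lr * 0.1) = 0.95, then the Adagrad step subtracts 0.5.
        assert_approx_eq!(w.as_tensor().to_vec1::<f32>()?[0], 0.45_f32);
        Ok(())
    }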
}