candle_optimisers/
adadelta.rs

/*!
Adadelta optimiser

Described in [ADADELTA: An Adaptive Learning Rate Method](https://arxiv.org/abs/1212.5701)

Pseudocode (including optional decoupled weight decay):
$$
\\begin{aligned}
            &\\rule{110mm}{0.4pt}                                                                 \\\\
            &\\textbf{input}      : \\gamma \\text{ (lr)}, \\: \\theta_0 \\text{ (params)},
                \\: f(\\theta) \\text{ (objective)}, \\: \\rho \\text{ (decay)},
                \\: \\lambda \\text{ (weight decay)}                                                \\\\
            &\\textbf{initialize} :  v_0  \\leftarrow 0 \\: \\text{ (square avg)},
                \\: u_0 \\leftarrow 0 \\: \\text{ (accumulate variables)}                     \\\\[-1.ex]
            &\\rule{110mm}{0.4pt}                                                                 \\\\
            &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do}                     \\\\
            &\\hspace{5mm}g_t           \\leftarrow   \\nabla_{\\theta} f_t (\\theta_{t-1})          \\\\
            &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some}                        \\\\
            &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled}                       \\\\
            &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1}                    \\\\
            &\\hspace{10mm}\\textbf{else}                                                              \\\\
            &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda  \\theta_{t-1}                            \\\\
            &\\hspace{5mm} v_t      \\leftarrow v_{t-1} \\rho + g^2_t (1 - \\rho)                    \\\\
            &\\hspace{5mm}\\Delta x_t    \\leftarrow   \\frac{\\sqrt{u_{t-1} +
                \\epsilon }}{ \\sqrt{v_t + \\epsilon}  }g_t \\hspace{21mm}                           \\\\
            &\\hspace{5mm} u_t  \\leftarrow   u_{t-1}  \\rho +
                 \\Delta x^2_t  (1 - \\rho)                                                        \\\\
            &\\hspace{5mm}\\theta_t      \\leftarrow   \\theta_{t-1} - \\gamma  \\Delta x_t            \\\\
            &\\rule{110mm}{0.4pt}                                                          \\\\[-1.ex]
            &\\textbf{return} \\:  \\theta_t                                                     \\\\[-1.ex]
            &\\rule{110mm}{0.4pt}                                                          \\\\[-1.ex]
       \\end{aligned}
$$
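
A minimal usage sketch, assuming the crate compiles as `candle_optimisers` and this
module is exposed as `candle_optimisers::adadelta`; the construction mirrors the unit
tests at the bottom of this file:

```no_run
# fn main() -> candle_core::Result<()> {
use candle_core::{Device, Var};
use candle_nn::Optimizer;
use candle_optimisers::adadelta::{Adadelta, ParamsAdaDelta};

// variables to be optimised
let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
let b = Var::new(0f32, &Device::Cpu)?;

// default rho and eps, explicit learning rate
let params = ParamsAdaDelta {
    lr: 0.1,
    ..Default::default()
};
let mut optimiser = Adadelta::new(vec![w.clone(), b.clone()], params)?;
let _current_lr = optimiser.learning_rate();

// given a `GradStore` from a backward pass over a loss involving `w` and `b`:
// optimiser.step(&grads)?;
# Ok(())
# }
```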
*/

use candle_core::{Result, Var};
use candle_nn::optim::Optimizer;

use crate::{Decay, OptimParams};

/// Adadelta optimiser
///
/// Described in [ADADELTA: An Adaptive Learning Rate Method](https://arxiv.org/abs/1212.5701)
#[derive(Debug)]
pub struct Adadelta {
    vars: Vec<VarAdaDelta>,
    params: ParamsAdaDelta,
    // avg_acc: HashMap<TensorId, (Tensor, Tensor)>,
}
#[derive(Debug)]
struct VarAdaDelta {
    /// Parameter being optimised
    theta: Var,
    /// Running average of squared gradients (the "square avg" in the pseudocode)
    v: Var,
    /// Running average of squared updates (the "accumulate variables" in the pseudocode)
    u: Var,
}

/// Parameters for the Adadelta optimiser
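///
/// Weight decay is optional: `Decay::WeightDecay` adds `lambda * theta` to the gradient,
/// while `Decay::DecoupledWeightDecay` scales the parameters directly before the update.
/// A hedged sketch of constructing both, assuming `Decay` is re-exported from the crate
/// root (as the `use crate::Decay` import above suggests) and the crate is named
/// `candle_optimisers`:
///
/// ```no_run
/// use candle_optimisers::{adadelta::ParamsAdaDelta, Decay};
///
/// // classic L2-style weight decay folded into the gradient
/// let l2 = ParamsAdaDelta {
///     weight_decay: Some(Decay::WeightDecay(1e-4)),
///     ..Default::default()
/// };
///
/// // decoupled weight decay applied to the parameters before the update
/// let decoupled = ParamsAdaDelta {
///     weight_decay: Some(Decay::DecoupledWeightDecay(1e-4)),
///     ..Default::default()
/// };
/// # let _ = (l2, decoupled);
/// ```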
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct ParamsAdaDelta {
    /// Learning rate
    pub lr: f64,
    /// Decay rate rho for the running averages of squared gradients and squared updates
    pub rho: f64,
    /// Term added to the denominator to improve numerical stability
    pub eps: f64,
    /// Optional weight decay, either standard or decoupled
    pub weight_decay: Option<Decay>,
}

impl Default for ParamsAdaDelta {
    fn default() -> Self {
        Self {
            lr: 1.0,
            rho: 0.9,
            weight_decay: None,
            eps: 1e-6,
        }
    }
}

impl Optimizer for Adadelta {
    type Config = ParamsAdaDelta;

    fn new(vars: Vec<Var>, params: ParamsAdaDelta) -> Result<Self> {
        let vars = vars
            .into_iter()
            .filter(|var| var.dtype().is_float())
            .map(|var| {
                let dtype = var.dtype();
                let shape = var.shape();
                let device = var.device();
                let v = Var::zeros(shape, dtype, device)?;
                let u = Var::zeros(shape, dtype, device)?;
                Ok(VarAdaDelta { theta: var, v, u })
            })
            .collect::<Result<Vec<VarAdaDelta>>>()?;
        Ok(Self {
            vars,
            params,
            // avg_acc: HashMap::new(),
        })
    }

    fn learning_rate(&self) -> f64 {
        self.params.lr
    }

    fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
        if let Some(decay) = self.params.weight_decay {
            match decay {
                Decay::WeightDecay(decay) => {
                    for var in &self.vars {
                        let theta = &var.theta;
                        let v = &var.v;
                        let u = &var.u;
                        if let Some(grad) = grads.get(theta) {
                            // standard weight decay: g_t <- g_t + lambda * theta_{t-1}
                            let grad = &(grad + (decay * theta.as_tensor())?)?;
                            // v_t <- rho * v_{t-1} + (1 - rho) * g_t^2
                            let v_next = ((v.as_tensor() * self.params.rho)?
                                + (1. - self.params.rho) * grad.powf(2.)?)?;
                            // delta_x_t <- sqrt(u_{t-1} + eps) / sqrt(v_t + eps) * g_t
                            let delta_x = (((u.as_tensor() + self.params.eps)?.powf(0.5)?)
                                .div(&((&v_next + self.params.eps)?.powf(0.5)?))?
                                * grad)?;
                            // u_t <- rho * u_{t-1} + (1 - rho) * delta_x_t^2
                            let u_next = ((u.as_tensor() * self.params.rho)?
                                + (1. - self.params.rho) * delta_x.powf(2.)?)?;
                            // theta_t <- theta_{t-1} - lr * delta_x_t
                            theta.set(&theta.sub(&(delta_x * self.params.lr)?)?)?;
                            v.set(&v_next)?;
                            u.set(&u_next)?;
                        }
                    }
                }
                Decay::DecoupledWeightDecay(decay) => {
                    for var in &self.vars {
                        let theta = &var.theta;
                        let v = &var.v;
                        let u = &var.u;
                        if let Some(grad) = grads.get(theta) {
                            // decoupled weight decay: theta_{t-1} <- theta_{t-1} * (1 - lr * lambda)
                            theta
                                .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?;
                            let v_next = ((v.as_tensor() * self.params.rho)?
                                + (1. - self.params.rho) * grad.powf(2.)?)?;
                            let delta_x = (((u.as_tensor() + self.params.eps)?.powf(0.5)?)
                                .div(&((&v_next + self.params.eps)?.powf(0.5)?))?
                                * grad)?;
                            let u_next = ((u.as_tensor() * self.params.rho)?
                                + (1. - self.params.rho) * delta_x.powf(2.)?)?;
                            theta.set(&theta.sub(&(delta_x * self.params.lr)?)?)?;
                            v.set(&v_next)?;
                            u.set(&u_next)?;
                        }
                    }
                }
            }
        } else {
            // no weight decay: plain Adadelta update
            for var in &self.vars {
                let theta = &var.theta;
                let v = &var.v;
                let u = &var.u;
                if let Some(grad) = grads.get(theta) {
                    let v_next = ((v.as_tensor() * self.params.rho)?
                        + (1. - self.params.rho) * grad.powf(2.)?)?;
                    let delta_x = (((u.as_tensor() + self.params.eps)?.powf(0.5)?)
                        .div(&((&v_next + self.params.eps)?.powf(0.5)?))?
                        * grad)?;
                    let u_next = ((u.as_tensor() * self.params.rho)?
                        + (1. - self.params.rho) * delta_x.powf(2.)?)?;
                    theta.set(&theta.sub(&(delta_x * self.params.lr)?)?)?;
                    v.set(&v_next)?;
                    u.set(&u_next)?;
                }
            }
        }

        Ok(())
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.params.lr = lr;
    }
}

impl OptimParams for Adadelta {
    fn params(&self) -> &Self::Config {
        &self.params
    }

    fn set_params(&mut self, config: Self::Config) {
        self.params = config;
    }
}

impl Adadelta {
    /// Consumes the optimiser and returns the variables being optimised
    #[must_use]
    pub fn into_inner(self) -> Vec<Var> {
        self.vars.into_iter().map(|v| v.theta).collect()
    }

    // pub fn push(&mut self, var: &Var) {
    //     self.vars.push(var.clone());
    // }
}

#[cfg(test)]
mod tests {
    // use candle_core::test_utils::{to_vec0_round, to_vec2_round};

    use anyhow::Result;
    use assert_approx_eq::assert_approx_eq;
    use candle_core::{Device, Var};
    use candle_nn::Optimizer;

    use super::*;
    #[test]
    fn lr_test() -> Result<()> {
        let params = ParamsAdaDelta {
            lr: 0.004,
            ..Default::default()
        };
        // dummy variables to construct the optimiser with
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = Adadelta::new(vec![w.clone(), b.clone()], params)?;
        assert_approx_eq!(0.004, optim.learning_rate());
        optim.set_learning_rate(0.002);
        assert_approx_eq!(0.002, optim.learning_rate());
        Ok(())
    }

    #[test]
    fn into_inner_test() -> Result<()> {
        let params = ParamsAdaDelta::default();
        let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
        let b = Var::new(-2f32, &Device::Cpu)?;
        let optim = Adadelta::new(vec![w.clone(), b.clone()], params)?;
        let inner = optim.into_inner();
        assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
        assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
        Ok(())
    }

    #[test]
    fn params_test() -> Result<()> {
        let params = ParamsAdaDelta {
            lr: 0.004,
            ..Default::default()
        };
        // dummy variables to construct the optimiser with
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = Adadelta::new(vec![w.clone(), b.clone()], params.clone())?;
        assert_eq!(params, optim.params().clone());
        let new_params = ParamsAdaDelta {
            lr: 0.002,
            ..Default::default()
        };
        optim.set_params(new_params.clone());
        assert_eq!(new_params, optim.params().clone());
        Ok(())
    }
}