candle_optimisers/adadelta.rs
/*!
Adadelta optimiser

Described in [ADADELTA: An Adaptive Learning Rate Method](https://arxiv.org/abs/1212.5701)

Pseudocode (including decoupling of weight decay):
$$
\\begin{aligned}
    &\\rule{110mm}{0.4pt} \\\\
    &\\textbf{input} : \\gamma \\text{ (lr)}, \\: \\theta_0 \\text{ (params)},
        \\: f(\\theta) \\text{ (objective)}, \\: \\rho \\text{ (decay)},
        \\: \\lambda \\text{ (weight decay)} \\\\
    &\\textbf{initialize} : v_0 \\leftarrow 0 \\: \\text{ (square avg)},
        \\: u_0 \\leftarrow 0 \\: \\text{ (accumulate variables)} \\\\[-1.ex]
    &\\rule{110mm}{0.4pt} \\\\
    &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\
    &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\
    &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\
    &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\
    &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\
    &\\hspace{10mm}\\textbf{else} \\\\
    &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\
    &\\hspace{5mm} v_t \\leftarrow v_{t-1} \\rho + g^2_t (1 - \\rho) \\\\
    &\\hspace{5mm}\\Delta x_t \\leftarrow \\frac{\\sqrt{u_{t-1} + \\epsilon }}{ \\sqrt{v_t + \\epsilon} }g_t \\hspace{21mm} \\\\
    &\\hspace{5mm} u_t \\leftarrow u_{t-1} \\rho + \\Delta x^2_t (1 - \\rho) \\\\
    &\\hspace{5mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\Delta x_t \\\\
    &\\rule{110mm}{0.4pt} \\\\[-1.ex]
    &\\bf{return} \\: \\theta_t \\\\[-1.ex]
    &\\rule{110mm}{0.4pt} \\\\[-1.ex]
\\end{aligned}
$$
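
# Examples

A minimal usage sketch. The external crate path and the `loss` tensor are
assumptions for illustration, so the block is marked `ignore` rather than
compiled as a doctest.

```ignore
use candle_core::{Device, Var};
use candle_nn::Optimizer;
use candle_optimisers::adadelta::{Adadelta, ParamsAdaDelta};

// A toy parameter tensor to optimise.
let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;

let params = ParamsAdaDelta {
    lr: 1.0,
    ..Default::default()
};
let mut optimiser = Adadelta::new(vec![w.clone()], params)?;

// `loss` is assumed to be a scalar tensor produced by the forward pass;
// `backward_step` computes gradients and applies one Adadelta update.
optimiser.backward_step(&loss)?;
```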
*/

use candle_core::{Result, Var};
use candle_nn::optim::Optimizer;

use crate::{Decay, OptimParams};

/// Adadelta optimiser
///
/// Described in [ADADELTA: An Adaptive Learning Rate Method](https://arxiv.org/abs/1212.5701)
#[derive(Debug)]
pub struct Adadelta {
    vars: Vec<VarAdaDelta>,
    params: ParamsAdaDelta,
    // avg_acc: HashMap<TensorId, (Tensor, Tensor)>,
}

/// Per-parameter state: the variable being optimised plus its running averages
/// (`v` = squared-gradient average, `u` = accumulated squared updates).
#[derive(Debug)]
struct VarAdaDelta {
    theta: Var,
    v: Var,
    u: Var,
}

/// Parameters for the Adadelta optimiser
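///
/// # Examples
///
/// A configuration sketch; the external crate path is an assumption, so the
/// block is left un-compiled (`ignore`).
///
/// ```ignore
/// use candle_optimisers::{adadelta::ParamsAdaDelta, Decay};
///
/// // Smaller step size plus decoupled weight decay of 0.01.
/// let params = ParamsAdaDelta {
///     lr: 0.1,
///     weight_decay: Some(Decay::DecoupledWeightDecay(0.01)),
///     ..Default::default()
/// };
/// ```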
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct ParamsAdaDelta {
    /// Learning rate (γ)
    pub lr: f64,
    /// Decay rate (ρ) for the running averages of squared gradients and squared updates
    pub rho: f64,
    /// Term (ε) added under the square roots to improve numerical stability
    pub eps: f64,
    /// Optional weight decay (λ), either coupled (L2) or decoupled
    pub weight_decay: Option<Decay>,
}

impl Default for ParamsAdaDelta {
    fn default() -> Self {
        Self {
            lr: 1.0,
            rho: 0.9,
            eps: 1e-6,
            weight_decay: None,
        }
    }
}

impl Optimizer for Adadelta {
    type Config = ParamsAdaDelta;

    fn new(vars: Vec<Var>, params: ParamsAdaDelta) -> Result<Self> {
        let vars = vars
            .into_iter()
            // Skip non-floating-point variables: only float tensors carry gradients.
            .filter(|var| var.dtype().is_float())
            .map(|var| {
                let dtype = var.dtype();
                let shape = var.shape();
                let device = var.device();
                // Zero-initialise the running averages: v (squared gradients)
                // and u (accumulated squared updates).
                let v = Var::zeros(shape, dtype, device)?;
                let u = Var::zeros(shape, dtype, device)?;
                Ok(VarAdaDelta { theta: var, v, u })
            })
            .collect::<Result<Vec<VarAdaDelta>>>()?;
        Ok(Self {
            vars,
            params,
            // avg_acc: HashMap::new(),
        })
    }

    fn learning_rate(&self) -> f64 {
        self.params.lr
    }

    fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
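        // One Adadelta update per tracked variable. Depending on `weight_decay`:
        //   * coupled (L2) decay adds `decay * theta` to the gradient before the update,
        //   * decoupled decay first scales `theta` by `1 - lr * decay`,
        //   * `None` applies the plain update.
        // In all cases `v` and `u` hold running averages of squared gradients and squared
        // updates, the step is `delta_x = sqrt(u + eps) / sqrt(v + eps) * grad`, and the
        // parameter moves by `-lr * delta_x`.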
        if let Some(decay) = self.params.weight_decay {
            match decay {
                Decay::WeightDecay(decay) => {
                    for var in &self.vars {
                        let theta = &var.theta;
                        let v = &var.v;
                        let u = &var.u;
                        if let Some(grad) = grads.get(theta) {
                            // coupled (L2) weight decay: fold the decay term into the gradient
                            let grad = &(grad + (decay * theta.as_tensor())?)?;
                            let v_next = ((v.as_tensor() * self.params.rho)?
                                + (1. - self.params.rho) * grad.powf(2.)?)?;
                            let delta_x = (((u.as_tensor() + self.params.eps)?.powf(0.5)?)
                                .div(&((&v_next + self.params.eps)?.powf(0.5)?))?
                                * grad)?;
                            let u_next = ((u.as_tensor() * self.params.rho)?
                                + (1. - self.params.rho) * delta_x.powf(2.)?)?;
                            theta.set(&theta.sub(&(delta_x * self.params.lr)?)?)?;
                            v.set(&v_next)?;
                            u.set(&u_next)?;
                        }
                    }
                }
                Decay::DecoupledWeightDecay(decay) => {
                    for var in &self.vars {
                        let theta = &var.theta;
                        let v = &var.v;
                        let u = &var.u;
                        if let Some(grad) = grads.get(theta) {
                            // decoupled weight decay: shrink the parameter before the Adadelta update
                            theta
                                .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?;
                            let v_next = ((v.as_tensor() * self.params.rho)?
                                + (1. - self.params.rho) * grad.powf(2.)?)?;
                            let delta_x = (((u.as_tensor() + self.params.eps)?.powf(0.5)?)
                                .div(&((&v_next + self.params.eps)?.powf(0.5)?))?
                                * grad)?;
                            let u_next = ((u.as_tensor() * self.params.rho)?
                                + (1. - self.params.rho) * delta_x.powf(2.)?)?;
                            theta.set(&theta.sub(&(delta_x * self.params.lr)?)?)?;
                            v.set(&v_next)?;
                            u.set(&u_next)?;
                        }
                    }
                }
            }
        } else {
            for var in &self.vars {
                let theta = &var.theta;
                let v = &var.v;
                let u = &var.u;
                if let Some(grad) = grads.get(theta) {
                    let v_next = ((v.as_tensor() * self.params.rho)?
                        + (1. - self.params.rho) * grad.powf(2.)?)?;
                    let delta_x = (((u.as_tensor() + self.params.eps)?.powf(0.5)?)
                        .div(&((&v_next + self.params.eps)?.powf(0.5)?))?
                        * grad)?;
                    let u_next = ((u.as_tensor() * self.params.rho)?
                        + (1. - self.params.rho) * delta_x.powf(2.)?)?;
                    theta.set(&theta.sub(&(delta_x * self.params.lr)?)?)?;
                    v.set(&v_next)?;
                    u.set(&u_next)?;
                }
            }
        }

        Ok(())
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.params.lr = lr;
    }
}

impl OptimParams for Adadelta {
    fn params(&self) -> &Self::Config {
        &self.params
    }

    fn set_params(&mut self, config: Self::Config) {
        self.params = config;
    }
}

impl Adadelta {
    /// Return the vars being optimised
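    ///
    /// # Examples
    ///
    /// A sketch of recovering the parameters after training; `optimiser` is
    /// assumed to be an `Adadelta` built as in the module-level example.
    ///
    /// ```ignore
    /// let trained: Vec<candle_core::Var> = optimiser.into_inner();
    /// ```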
    #[must_use]
    pub fn into_inner(self) -> Vec<Var> {
        self.vars.into_iter().map(|v| v.theta).collect()
    }

    // pub fn push(&mut self, var: &Var) {
    //     self.vars.push(var.clone());
    // }
}

#[cfg(test)]
mod tests {
    // use candle_core::test_utils::{to_vec0_round, to_vec2_round};

    use anyhow::Result;
    use assert_approx_eq::assert_approx_eq;
    use candle_core::{Device, Var};
    use candle_nn::Optimizer;

    use super::*;

    #[test]
    fn lr_test() -> Result<()> {
        let params = ParamsAdaDelta {
            lr: 0.004,
            ..Default::default()
        };
        // Dummy variables: this test only checks getting and setting the learning rate.
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = Adadelta::new(vec![w.clone(), b.clone()], params)?;
        assert_approx_eq!(0.004, optim.learning_rate());
        optim.set_learning_rate(0.002);
        assert_approx_eq!(0.002, optim.learning_rate());
        Ok(())
    }

    #[test]
    fn into_inner_test() -> Result<()> {
        let params = ParamsAdaDelta::default();
        let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
        let b = Var::new(-2f32, &Device::Cpu)?;
        let optim = Adadelta::new(vec![w.clone(), b.clone()], params)?;
        let inner = optim.into_inner();
        assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
        assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
        Ok(())
    }

    #[test]
    fn params_test() -> Result<()> {
        let params = ParamsAdaDelta {
            lr: 0.004,
            ..Default::default()
        };
        // Dummy variables: this test only checks getting and setting the optimiser parameters.
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = Adadelta::new(vec![w.clone(), b.clone()], params.clone())?;
        assert_eq!(params, optim.params().clone());
        let new_params = ParamsAdaDelta {
            lr: 0.002,
            ..Default::default()
        };
        optim.set_params(new_params.clone());
        assert_eq!(new_params, optim.params().clone());
        Ok(())
    }
}