use candle_core::{Result, Var};
use candle_nn::optim::Optimizer;
use log::warn;

use crate::{Decay, OptimParams};

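/// Common constructor/stepping interface implemented by the plain-Adam and
/// AMSGrad variable stores, so [`Adam`] can dispatch to whichever store it was
/// built with without duplicating bookkeeping.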
trait AdamInner {
    fn new(vars: Vec<Var>) -> Result<Self>
    where
        Self: Sized;
    fn into_inner(self) -> Vec<Var>;
    fn inner_step(
        &self,
        params: &ParamsAdam,
        grads: &candle_core::backprop::GradStore,
        t: f64,
    ) -> Result<()>;
}

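/// Adam optimiser (Kingma and Ba, 2015).
///
/// Supports the AMSGrad variant, classic L2 weight decay
/// (`Decay::WeightDecay`) and decoupled, AdamW-style weight decay
/// (`Decay::DecoupledWeightDecay`); see [`ParamsAdam`] for configuration.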
#[derive(Debug)]
pub struct Adam {
    vars: VarAdam,
    params: ParamsAdam,
    t: f64,
}

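/// Per-parameter Adam state: the parameter itself (`theta`) plus its first
/// (`m`) and second (`v`) moment estimates.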
#[derive(Debug)]
struct VarAdamBase {
    theta: Var,
    m: Var,
    v: Var,
}

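/// Variable store for plain Adam (no AMSGrad).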
#[derive(Debug)]
struct VecAdamBase(Vec<VarAdamBase>);

impl AdamInner for VecAdamBase {
    fn new(vars: Vec<Var>) -> Result<Self>
    where
        Self: Sized,
    {
        Ok(VecAdamBase(
            vars.into_iter()
                .filter(|var| var.dtype().is_float())
                .map(|var| {
                    let dtype = var.dtype();
                    let shape = var.shape();
                    let device = var.device();
                    let m = Var::zeros(shape, dtype, device)?;
                    let v = Var::zeros(shape, dtype, device)?;
                    Ok(VarAdamBase { theta: var, m, v })
                })
                .collect::<Result<Vec<VarAdamBase>>>()?,
        ))
    }

    fn into_inner(self) -> Vec<Var> {
        self.0.into_iter().map(|var| var.theta).collect()
    }

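    // Update rule (Adam) for each parameter theta with gradient g at step t:
    //   m_t = beta_1 * m_{t-1} + (1 - beta_1) * g
    //   v_t = beta_2 * v_{t-1} + (1 - beta_2) * g^2
    //   m_hat = m_t / (1 - beta_1^t),  v_hat = v_t / (1 - beta_2^t)
    //   theta <- theta - lr * m_hat / (sqrt(v_hat) + eps)
    // `Decay::WeightDecay` adds `decay * theta` to the gradient first (classic L2
    // penalty); `Decay::DecoupledWeightDecay` instead scales theta by
    // (1 - lr * decay) before the update (AdamW-style decoupled decay).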
    fn inner_step(
        &self,
        params: &ParamsAdam,
        grads: &candle_core::backprop::GradStore,
        t: f64,
    ) -> Result<()> {
        if let Some(decay) = params.weight_decay {
            match decay {
                Decay::WeightDecay(decay) => {
                    for var in &self.0 {
                        let theta = &var.theta;
                        let m = &var.m;
                        let v = &var.v;
                        if let Some(grad) = grads.get(theta) {
                            let grad = &(grad + (decay * theta.as_tensor())?)?;
                            let m_next = ((params.beta_1 * m.as_tensor())?
                                + ((1. - params.beta_1) * grad)?)?;
                            let v_next = ((params.beta_2 * v.as_tensor())?
                                + ((1. - params.beta_2) * grad.powf(2.)?)?)?;
                            let m_hat = (&m_next / (1. - params.beta_1.powf(t)))?;
                            let v_hat = (&v_next / (1. - params.beta_2.powf(t)))?;
                            let delta =
                                (m_hat * params.lr)?.div(&(v_hat.powf(0.5)? + params.eps)?)?;
                            theta.set(&theta.sub(&delta)?)?;
                            m.set(&m_next)?;
                            v.set(&v_next)?;
                        }
                    }
                }
                Decay::DecoupledWeightDecay(decay) => {
                    for var in &self.0 {
                        let theta = &var.theta;
                        let m = &var.m;
                        let v = &var.v;
                        if let Some(grad) = grads.get(theta) {
                            theta.set(&(theta.as_tensor() * params.lr.mul_add(-decay, 1.))?)?;
                            let m_next = ((params.beta_1 * m.as_tensor())?
                                + ((1. - params.beta_1) * grad)?)?;
                            let v_next = ((params.beta_2 * v.as_tensor())?
                                + ((1. - params.beta_2) * grad.powf(2.)?)?)?;
                            let m_hat = (&m_next / (1. - params.beta_1.powf(t)))?;
                            let v_hat = (&v_next / (1. - params.beta_2.powf(t)))?;
                            let delta =
                                (m_hat * params.lr)?.div(&(v_hat.powf(0.5)? + params.eps)?)?;
                            theta.set(&theta.sub(&delta)?)?;
                            m.set(&m_next)?;
                            v.set(&v_next)?;
                        }
                    }
                }
            }
        } else {
            for var in &self.0 {
                let theta = &var.theta;
                let m = &var.m;
                let v = &var.v;
                if let Some(grad) = grads.get(theta) {
                    let m_next =
                        ((params.beta_1 * m.as_tensor())? + ((1. - params.beta_1) * grad)?)?;
                    let v_next = ((params.beta_2 * v.as_tensor())?
                        + ((1. - params.beta_2) * grad.powf(2.)?)?)?;
                    let m_hat = (&m_next / (1. - params.beta_1.powf(t)))?;
                    let v_hat = (&v_next / (1. - params.beta_2.powf(t)))?;
                    let delta = (m_hat * params.lr)?.div(&(v_hat.powf(0.5)? + params.eps)?)?;
                    theta.set(&theta.sub(&delta)?)?;
                    m.set(&m_next)?;
                    v.set(&v_next)?;
                }
            }
        }
        Ok(())
    }
}

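/// Per-parameter AMSGrad state: as [`VarAdamBase`], plus the running
/// element-wise maximum of the second moment (`vmax`).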
#[derive(Debug)]
struct VarAdamAmsgrad {
    theta: Var,
    m: Var,
    v: Var,
    vmax: Var,
}

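/// Variable store for the AMSGrad variant.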
#[derive(Debug)]
struct VecAdamAmsgrad(Vec<VarAdamAmsgrad>);

impl AdamInner for VecAdamAmsgrad {
    fn new(vars: Vec<Var>) -> Result<Self>
    where
        Self: Sized,
    {
        Ok(VecAdamAmsgrad(
            vars.into_iter()
                .filter(|var| var.dtype().is_float())
                .map(|var| {
                    let dtype = var.dtype();
                    let shape = var.shape();
                    let device = var.device();
                    let m = Var::zeros(shape, dtype, device)?;
                    let v = Var::zeros(shape, dtype, device)?;
                    let vmax = Var::zeros(shape, dtype, device)?;
                    Ok(VarAdamAmsgrad {
                        theta: var,
                        m,
                        v,
                        vmax,
                    })
                })
                .collect::<Result<Vec<VarAdamAmsgrad>>>()?,
        ))
    }

    fn into_inner(self) -> Vec<Var> {
        self.0.into_iter().map(|var| var.theta).collect()
    }

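    // Identical to the base update, except the bias-corrected second moment is
    // taken from `vmax`, the element-wise running maximum of all past `v_t`
    // (the AMSGrad variant of Reddi et al., 2018), which is persisted alongside
    // `m` and `v`.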
    fn inner_step(
        &self,
        params: &ParamsAdam,
        grads: &candle_core::backprop::GradStore,
        t: f64,
    ) -> Result<()> {
        if let Some(decay) = params.weight_decay {
            match decay {
                Decay::WeightDecay(decay) => {
                    for var in &self.0 {
                        let theta = &var.theta;
                        let m = &var.m;
                        let v = &var.v;
                        let vmax = &var.vmax;
                        if let Some(grad) = grads.get(theta) {
                            let grad = &(grad + (decay * theta.as_tensor())?)?;
                            let m_next = ((params.beta_1 * m.as_tensor())?
                                + ((1. - params.beta_1) * grad)?)?;
                            let v_next = ((params.beta_2 * v.as_tensor())?
                                + ((1. - params.beta_2) * grad.powf(2.)?)?)?;
                            let m_hat = (&m_next / (1. - params.beta_1.powf(t)))?;
                            let vmax_next = vmax.maximum(&v_next)?;
                            let v_hat = (&vmax_next / (1. - params.beta_2.powf(t)))?;
                            let delta =
                                (m_hat * params.lr)?.div(&(v_hat.powf(0.5)? + params.eps)?)?;
                            theta.set(&theta.sub(&delta)?)?;
                            m.set(&m_next)?;
                            v.set(&v_next)?;
                            vmax.set(&vmax_next)?;
                        }
                    }
                }
                Decay::DecoupledWeightDecay(decay) => {
                    for var in &self.0 {
                        let theta = &var.theta;
                        let m = &var.m;
                        let v = &var.v;
                        let vmax = &var.vmax;
                        if let Some(grad) = grads.get(theta) {
                            theta.set(&(theta.as_tensor() * params.lr.mul_add(-decay, 1.))?)?;
                            let m_next = ((params.beta_1 * m.as_tensor())?
                                + ((1. - params.beta_1) * grad)?)?;
                            let v_next = ((params.beta_2 * v.as_tensor())?
                                + ((1. - params.beta_2) * grad.powf(2.)?)?)?;
                            let m_hat = (&m_next / (1. - params.beta_1.powf(t)))?;
                            let vmax_next = vmax.maximum(&v_next)?;
                            let v_hat = (&vmax_next / (1. - params.beta_2.powf(t)))?;
                            let delta =
                                (m_hat * params.lr)?.div(&(v_hat.powf(0.5)? + params.eps)?)?;
                            theta.set(&theta.sub(&delta)?)?;
                            m.set(&m_next)?;
                            v.set(&v_next)?;
                            vmax.set(&vmax_next)?;
                        }
                    }
                }
            }
        } else {
            for var in &self.0 {
                let theta = &var.theta;
                let m = &var.m;
                let v = &var.v;
                let vmax = &var.vmax;
                if let Some(grad) = grads.get(theta) {
                    let m_next =
                        ((params.beta_1 * m.as_tensor())? + ((1. - params.beta_1) * grad)?)?;
                    let v_next = ((params.beta_2 * v.as_tensor())?
                        + ((1. - params.beta_2) * grad.powf(2.)?)?)?;
                    let m_hat = (&m_next / (1. - params.beta_1.powf(t)))?;
                    let vmax_next = vmax.maximum(&v_next)?;
                    let v_hat = (&vmax_next / (1. - params.beta_2.powf(t)))?;
                    let delta = (m_hat * params.lr)?.div(&(v_hat.powf(0.5)? + params.eps)?)?;
                    theta.set(&theta.sub(&delta)?)?;
                    m.set(&m_next)?;
                    v.set(&v_next)?;
                    vmax.set(&vmax_next)?;
                }
            }
        }
        Ok(())
    }
}

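/// Variable store chosen at construction time: AMSGrad bookkeeping (`vmax`) is
/// only allocated when the `amsgrad` option is set.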
#[derive(Debug)]
enum VarAdam {
    VecAdamBase(VecAdamBase),
    VecAdamAmsgrad(VecAdamAmsgrad),
}

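/// Configuration for the [`Adam`] optimiser.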
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct ParamsAdam {
    /// Learning rate.
    pub lr: f64,
    /// Exponential decay rate for the first moment estimate.
    pub beta_1: f64,
    /// Exponential decay rate for the second moment estimate.
    pub beta_2: f64,
    /// Term added to the denominator to improve numerical stability.
    pub eps: f64,
    /// Optional weight decay: classic L2 or decoupled (AdamW-style).
    pub weight_decay: Option<Decay>,
    /// Whether to use the AMSGrad variant.
    pub amsgrad: bool,
}

impl Default for ParamsAdam {
    fn default() -> Self {
        Self {
            lr: 0.001,
            beta_1: 0.9,
            beta_2: 0.999,
            eps: 1e-8,
            weight_decay: None,
            amsgrad: false,
        }
    }
}

impl Optimizer for Adam {
    type Config = ParamsAdam;

    fn new(vars: Vec<Var>, params: ParamsAdam) -> Result<Self> {
        if params.amsgrad {
            let vars = VarAdam::VecAdamAmsgrad(VecAdamAmsgrad::new(vars)?);
            Ok(Self {
                vars,
                params,
                t: 1.,
            })
        } else {
            let vars = VarAdam::VecAdamBase(VecAdamBase::new(vars)?);
            Ok(Self {
                vars,
                params,
                t: 1.,
            })
        }
    }

    fn learning_rate(&self) -> f64 {
        self.params.lr
    }

    fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
        match &self.vars {
            VarAdam::VecAdamBase(vars) => vars.inner_step(&self.params, grads, self.t)?,
            VarAdam::VecAdamAmsgrad(vars) => vars.inner_step(&self.params, grads, self.t)?,
        }
        self.t += 1.;
        Ok(())
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.params.lr = lr;
    }
}

impl OptimParams for Adam {
    fn params(&self) -> &Self::Config {
        &self.params
    }

    fn set_params(&mut self, config: Self::Config) {
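        // The AMSGrad flag determines which variable store was allocated in
        // `new`, so it cannot be toggled after construction; keep the existing
        // value and warn if the caller tries to change it.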
        let ams_grad = self.params.amsgrad;
        if ams_grad == config.amsgrad {
            self.params = config;
        } else {
            warn!("AMSGrad cannot be changed once set");
            let mut config = config;
            config.amsgrad = ams_grad;
            self.params = config;
        }
    }
}

impl Adam {
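    /// Consumes the optimiser and returns the optimised variables.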
    #[must_use]
    pub fn into_inner(self) -> Vec<Var> {
        match self.vars {
            VarAdam::VecAdamBase(vars) => vars.into_inner(),
            VarAdam::VecAdamAmsgrad(vars) => vars.into_inner(),
        }
    }

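    /// Sets `beta_1` and `beta_2`, the exponential decay rates for the first
    /// and second moment estimates.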
    pub fn set_betas(&mut self, beta_1: f64, beta_2: f64) {
        self.params.beta_1 = beta_1;
        self.params.beta_2 = beta_2;
    }
}

#[cfg(test)]
mod tests {
    use anyhow::Result;
    use assert_approx_eq::assert_approx_eq;
    use candle_core::{Device, Var};
    use candle_nn::Optimizer;

    use super::*;

    #[test]
    fn lr_test() -> Result<()> {
        let params = ParamsAdam {
            lr: 0.004,
            ..Default::default()
        };
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = Adam::new(vec![w.clone(), b.clone()], params)?;
        assert_approx_eq!(0.004, optim.learning_rate());
        optim.set_learning_rate(0.002);
        assert_approx_eq!(0.002, optim.learning_rate());
        Ok(())
    }

    #[test]
    fn into_inner_test() -> Result<()> {
        let params = ParamsAdam::default();
        let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
        let b = Var::new(-2f32, &Device::Cpu)?;
        let optim = Adam::new(vec![w.clone(), b.clone()], params)?;
        let inner = optim.into_inner();
        assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
        assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
        let params = ParamsAdam {
            amsgrad: true,
            ..Default::default()
        };
        let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
        let b = Var::new(-2f32, &Device::Cpu)?;
        let optim_amsgrad = Adam::new(vec![w.clone(), b.clone()], params)?;
        let inner = optim_amsgrad.into_inner();
        assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
        assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
        Ok(())
    }

    #[test]
    fn params_test() -> Result<()> {
        let params = ParamsAdam {
            lr: 0.004,
            ..Default::default()
        };
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = Adam::new(vec![w.clone(), b.clone()], params.clone())?;
        assert_eq!(params, optim.params().clone());
        let new_params = ParamsAdam {
            lr: 0.002,
            ..Default::default()
        };
        optim.set_params(new_params.clone());
        assert_eq!(new_params, optim.params().clone());

        let ams_params = ParamsAdam {
            lr: 0.002,
            amsgrad: true,
            ..Default::default()
        };
        optim.set_params(ams_params);
        assert_eq!(new_params, optim.params().clone());

        optim.set_betas(0.1, 0.1);
        let final_params = ParamsAdam {
            lr: 0.002,
            beta_1: 0.1,
            beta_2: 0.1,
            ..Default::default()
        };
        assert_eq!(final_params, optim.params().clone());
        Ok(())
    }
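
    // An additional end-to-end sketch (not part of the original suite): run a
    // single optimisation step on `loss = sum(w^2)` and check that the
    // parameter moves towards zero. Uses only candle_core ops
    // (`sqr`, `sum_all`, `backward`).
    #[test]
    fn step_test() -> Result<()> {
        let w = Var::new(&[[1f32, 2.]], &Device::Cpu)?;
        let mut optim = Adam::new(vec![w.clone()], ParamsAdam::default())?;
        // The gradient of sum(w^2) is 2*w, so Adam should shrink both entries.
        let loss = w.as_tensor().sqr()?.sum_all()?;
        let grads = loss.backward()?;
        optim.step(&grads)?;
        let updated = w.as_tensor().to_vec2::<f32>()?;
        assert!(updated[0][0] < 1.0 && updated[0][1] < 2.0);
        Ok(())
    }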
}