// candle_optimisers/nadam.rs

use candle_core::{Result, Var};
use candle_nn::optim::Optimizer;

use crate::{Decay, OptimParams};

/// NAdam optimiser: Adam with Nesterov momentum.
#[derive(Debug)]
pub struct NAdam {
    vars: Vec<VarNAdam>,
    params: ParamsNAdam,
    // Momentum-decay schedule state: `mu_t` / `mu_t2` are the current and
    // next momentum coefficients, `prod` / `prod2` their running products.
    mu_t: f64,
    mu_t2: f64,
    prod: f64,
    prod2: f64,
    t: f64,
}

// Per-parameter state: the variable itself plus its first and second
// moment estimates.
#[derive(Debug)]
struct VarNAdam {
    theta: Var,
    m: Var,
    v: Var,
}

/// Parameters for the NAdam optimiser.
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct ParamsNAdam {
    /// Learning rate.
    pub lr: f64,
    /// Coefficient for the first moment (momentum) estimate.
    pub beta_1: f64,
    /// Coefficient for the second moment (uncentred variance) estimate.
    pub beta_2: f64,
    /// Term added to the denominator to improve numerical stability.
    pub eps: f64,
    /// Optional weight decay, either coupled or decoupled.
    pub weight_decay: Option<Decay>,
    /// Momentum decay used in the `mu_t` schedule.
    pub momentum_decay: f64,
}
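
// The `Default` impl below follows the conventional NAdam hyper-parameters
// (the same values as `torch.optim.NAdam`: lr = 2e-3, betas = (0.9, 0.999),
// eps = 1e-8, momentum_decay = 4e-3).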
83
84impl Default for ParamsNAdam {
85 fn default() -> Self {
86 Self {
87 lr: 0.002,
88 beta_1: 0.9,
89 beta_2: 0.999,
90 eps: 1e-8,
91 weight_decay: None,
92 momentum_decay: 0.004,
93 }
94 }
95}

impl Optimizer for NAdam {
    type Config = ParamsNAdam;

    fn new(vars: Vec<Var>, params: ParamsNAdam) -> Result<Self> {
        let vars = vars
            .into_iter()
            .filter(|var| var.dtype().is_float())
            .map(|var| {
                let dtype = var.dtype();
                let shape = var.shape();
                let device = var.device();
                let m = Var::zeros(shape, dtype, device)?;
                let v = Var::zeros(shape, dtype, device)?;
                Ok(VarNAdam { theta: var, m, v })
            })
            .collect::<Result<Vec<VarNAdam>>>()?;
        // Seed the momentum schedule: `mu_t2` holds
        // mu_1 = beta_1 * (1 - 0.5 * 0.96^momentum_decay) and `prod2` its
        // running product, so the first call to `step` can shift them into
        // `mu_t` / `prod`.
        let t = 1.;
        let mu_t2 = params.beta_1 * 0.5f64.mul_add(-(0.96_f64.powf(t * params.momentum_decay)), 1.);
        Ok(Self {
            vars,
            params,
            t: 1.,
            mu_t: 1.,
            mu_t2,
            prod: 1.,
            prod2: mu_t2,
        })
    }

    fn learning_rate(&self) -> f64 {
        self.params.lr
    }

    fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
        // Advance the momentum-decay schedule: shift the "next" coefficient
        // and product into the "current" slots and compute the values for t + 1.
        let mu_t = self.mu_t2;
        let mu_t2 = self.params.beta_1
            * 0.5f64.mul_add(
                -(0.96_f64.powf((self.t + 1.) * self.params.momentum_decay)),
                1.,
            );
        let prod = self.prod2;
        let prod2 = prod * mu_t2;
        self.mu_t = mu_t;
        self.mu_t2 = mu_t2;
        self.prod = prod;
        self.prod2 = prod2;
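        // NAdam update, as implemented in the branches below (following
        // Dozat, "Incorporating Nesterov Momentum into Adam"):
        //     mu_t  = beta_1 * (1 - 0.5 * 0.96^(t * momentum_decay))
        //     m_t   = beta_1 * m_{t-1} + (1 - beta_1) * g_t
        //     v_t   = beta_2 * v_{t-1} + (1 - beta_2) * g_t^2
        //     m_hat = mu_{t+1} * m_t / (1 - prod_{i<=t+1} mu_i)
        //           + (1 - mu_t) * g_t / (1 - prod_{i<=t} mu_i)
        //     v_hat = v_t / (1 - beta_2^t)
        //     theta = theta - lr * m_hat / (sqrt(v_hat) + eps)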
        if let Some(decay) = self.params.weight_decay {
            match decay {
                // Coupled weight decay: add `decay * theta` to the gradient
                // before the moment updates (L2 regularisation).
                Decay::WeightDecay(decay) => {
                    for var in &self.vars {
                        let theta = &var.theta;
                        let m = &var.m;
                        let v = &var.v;
                        if let Some(grad) = grads.get(theta) {
                            let grad = &(grad + (decay * theta.as_tensor())?)?;
                            let m_next = ((self.params.beta_1 * m.as_tensor())?
                                + ((1. - self.params.beta_1) * grad)?)?;
                            let v_next = ((self.params.beta_2 * v.as_tensor())?
                                + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?;
                            let m_hat = (((mu_t2 / (1. - prod2)) * &m_next)?
                                + (((1. - mu_t) / (1. - prod)) * grad)?)?;
                            let v_hat = (&v_next / (1. - self.params.beta_2.powf(self.t)))?;
                            let delta = (m_hat * self.params.lr)?
                                .div(&(v_hat.powf(0.5)? + self.params.eps)?)?;
                            theta.set(&theta.sub(&delta)?)?;
                            m.set(&m_next)?;
                            v.set(&v_next)?;
                        }
                    }
                }
                // Decoupled weight decay (AdamW style): shrink `theta`
                // directly and leave the gradient untouched.
                Decay::DecoupledWeightDecay(decay) => {
                    for var in &self.vars {
                        let theta = &var.theta;
                        let m = &var.m;
                        let v = &var.v;
                        if let Some(grad) = grads.get(theta) {
                            theta
                                .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?;
                            let m_next = ((self.params.beta_1 * m.as_tensor())?
                                + ((1. - self.params.beta_1) * grad)?)?;
                            let v_next = ((self.params.beta_2 * v.as_tensor())?
                                + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?;
                            let m_hat = (((mu_t2 / (1. - prod2)) * &m_next)?
                                + (((1. - mu_t) / (1. - prod)) * grad)?)?;
                            let v_hat = (&v_next / (1. - self.params.beta_2.powf(self.t)))?;
                            let delta = (m_hat * self.params.lr)?
                                .div(&(v_hat.powf(0.5)? + self.params.eps)?)?;
                            theta.set(&theta.sub(&delta)?)?;
                            m.set(&m_next)?;
                            v.set(&v_next)?;
                        }
                    }
                }
            }
        } else {
            // No weight decay: plain NAdam update.
            for var in &self.vars {
                let theta = &var.theta;
                let m = &var.m;
                let v = &var.v;
                if let Some(grad) = grads.get(theta) {
                    let m_next = ((self.params.beta_1 * m.as_tensor())?
                        + ((1. - self.params.beta_1) * grad)?)?;
                    let v_next = ((self.params.beta_2 * v.as_tensor())?
                        + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?;
                    let m_hat = (((mu_t2 / (1. - prod2)) * &m_next)?
                        + (((1. - mu_t) / (1. - prod)) * grad)?)?;
                    let v_hat = (&v_next / (1. - self.params.beta_2.powf(self.t)))?;
                    let delta =
                        (m_hat * self.params.lr)?.div(&(v_hat.powf(0.5)? + self.params.eps)?)?;
                    theta.set(&theta.sub(&delta)?)?;
                    m.set(&m_next)?;
                    v.set(&v_next)?;
                }
            }
        }

        self.t += 1.;
        Ok(())
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.params.lr = lr;
    }
}
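
// A typical training loop drives the optimiser through `backward_step` from
// the `Optimizer` trait, which computes the gradients and calls `step`.
// A minimal sketch (`varmap`, `model`, `xs` and `ys` are hypothetical names):
//
//     let mut optim = NAdam::new(varmap.all_vars(), ParamsNAdam::default())?;
//     for _epoch in 0..100 {
//         let loss = model.forward(&xs)?.sub(&ys)?.sqr()?.mean_all()?;
//         optim.backward_step(&loss)?;
//     }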

impl OptimParams for NAdam {
    fn params(&self) -> &Self::Config {
        &self.params
    }

    fn set_params(&mut self, config: Self::Config) {
        self.params = config;
    }
}

impl NAdam {
    /// Consume the optimiser and return the underlying parameter variables.
    #[must_use]
    pub fn into_inner(self) -> Vec<Var> {
        self.vars.into_iter().map(|v| v.theta).collect()
    }
}

#[cfg(test)]
mod tests {
    use anyhow::Result;
    use assert_approx_eq::assert_approx_eq;
    use candle_core::{Device, Var};
    use candle_nn::Optimizer;

    use super::*;

    #[test]
    fn lr_test() -> Result<()> {
        let params = ParamsNAdam {
            lr: 0.004,
            ..Default::default()
        };
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = NAdam::new(vec![w.clone(), b.clone()], params)?;
        assert_approx_eq!(0.004, optim.learning_rate());
        optim.set_learning_rate(0.002);
        assert_approx_eq!(0.002, optim.learning_rate());
        Ok(())
    }

    #[test]
    fn into_inner_test() -> Result<()> {
        let params = ParamsNAdam::default();
        let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
        let b = Var::new(-2f32, &Device::Cpu)?;
        let optim = NAdam::new(vec![w.clone(), b.clone()], params)?;
        let inner = optim.into_inner();
        assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
        assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
        Ok(())
    }

    #[test]
    fn params_test() -> Result<()> {
        let params = ParamsNAdam {
            lr: 0.004,
            ..Default::default()
        };
        let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
        let b = Var::new(0f32, &Device::Cpu)?;
        let mut optim = NAdam::new(vec![w.clone(), b.clone()], params.clone())?;
        assert_eq!(params, optim.params().clone());
        let new_params = ParamsNAdam {
            lr: 0.002,
            ..Default::default()
        };
        optim.set_params(new_params.clone());
        assert_eq!(new_params, optim.params().clone());
        Ok(())
    }
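
    // A minimal end-to-end sketch (not part of the original suite): run one
    // `backward_step` on a simple least-squares loss and check it succeeds.
    // The objective and tensor values here are illustrative only.
    #[test]
    fn step_smoke_test() -> Result<()> {
        let w = Var::new(&[[1f32, 2.]], &Device::Cpu)?;
        let b = Var::new(0.5f32, &Device::Cpu)?;
        let mut optim = NAdam::new(vec![w.clone(), b.clone()], ParamsNAdam::default())?;
        // loss = sum((w . x + b)^2): a scalar that depends on both vars.
        let xs = candle_core::Tensor::new(&[[3f32], [4.]], &Device::Cpu)?;
        let loss = w.matmul(&xs)?.broadcast_add(b.as_tensor())?.sqr()?.sum_all()?;
        optim.backward_step(&loss)?;
        // One step should have moved the weights away from their initial values.
        assert_ne!(w.as_tensor().to_vec2::<f32>()?, vec![vec![1f32, 2.]]);
        Ok(())
    }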
}