use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

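/// Adam optimizer (adaptive moment estimation) with bias correction,
/// optional AMSGrad, and L2-style weight decay added to the gradient.
///
/// # Example
///
/// A minimal usage sketch. It is marked `ignore` because it omits imports
/// and assumes gradients have already been populated by a backward pass,
/// which lives outside this module:
///
/// ```ignore
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
/// let param = Parameter::from_variable(var);
/// let mut optimizer = Adam::new(vec![param], 1e-3).weight_decay(0.01);
///
/// // ... forward pass and backward pass populate the parameter's gradient ...
/// optimizer.step();
/// optimizer.zero_grad();
/// ```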
pub struct Adam {
    /// Parameters being optimized.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Exponential decay rate for the first-moment estimate.
    beta1: f32,
    /// Exponential decay rate for the second-moment estimate.
    beta2: f32,
    /// Term added to the denominator for numerical stability.
    eps: f32,
    /// L2 weight decay coefficient (added to the gradient).
    weight_decay: f32,
    /// Whether to use the AMSGrad variant.
    amsgrad: bool,
    /// Per-parameter state, lazily initialized on the first call to `step`.
    state: Vec<AdamState>,
}

/// Moment estimates tracked for a single parameter tensor.
#[derive(Debug, Clone)]
struct AdamState {
    /// Exponential moving average of gradients (first moment).
    exp_avg: Vec<f32>,
    /// Exponential moving average of squared gradients (second moment).
    exp_avg_sq: Vec<f32>,
    /// Running maximum of `exp_avg_sq`, kept only for AMSGrad.
    max_exp_avg_sq: Option<Vec<f32>>,
    /// Number of optimization steps taken for this parameter.
    step: usize,
}

impl AdamState {
    fn new(size: usize, amsgrad: bool) -> Self {
        Self {
            exp_avg: vec![0.0; size],
            exp_avg_sq: vec![0.0; size],
            max_exp_avg_sq: if amsgrad { Some(vec![0.0; size]) } else { None },
            step: 0,
        }
    }
}

impl Adam {
    /// Creates an Adam optimizer with the default betas `(0.9, 0.999)`.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates an Adam optimizer with custom beta coefficients.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates an Adam optimizer with every hyperparameter specified.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Sets the beta coefficients (builder style).
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Sets the numerical-stability epsilon (builder style).
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Sets the L2 weight decay coefficient (builder style).
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Enables or disables the AMSGrad variant (builder style).
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    /// Allocates per-parameter state on the first step.
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| AdamState::new(p.numel(), self.amsgrad))
                .collect();
        }
    }
}

impl Optimizer for Adam {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let grad_vec = grad.to_vec();
            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();
            let mut param_vec = param_data.to_vec();

            // Classic (coupled) weight decay: add `weight_decay * p` to the gradient.
            let grad_vec: Vec<f32> = if self.weight_decay == 0.0 {
                grad_vec
            } else {
                grad_vec
                    .iter()
                    .zip(param_vec.iter())
                    .map(|(g, p)| g + self.weight_decay * p)
                    .collect()
            };

            // Update biased first moment estimate: m = beta1 * m + (1 - beta1) * g.
            for (m, g) in state.exp_avg.iter_mut().zip(grad_vec.iter()) {
                *m = self.beta1 * *m + (1.0 - self.beta1) * g;
            }

            // Update biased second raw moment estimate: v = beta2 * v + (1 - beta2) * g^2.
            for (v, g) in state.exp_avg_sq.iter_mut().zip(grad_vec.iter()) {
                *v = self.beta2 * *v + (1.0 - self.beta2) * g * g;
            }

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

            let step_size = self.lr / bias_correction1;

            if self.amsgrad {
                // AMSGrad: normalize by the running maximum of the second moment.
                let max_exp_avg_sq = state.max_exp_avg_sq.as_mut().unwrap();
                for (max_v, v) in max_exp_avg_sq.iter_mut().zip(state.exp_avg_sq.iter()) {
                    *max_v = max_v.max(*v);
                }
                for (p, (m, max_v)) in param_vec
                    .iter_mut()
                    .zip(state.exp_avg.iter().zip(max_exp_avg_sq.iter()))
                {
                    let denom = (max_v / bias_correction2).sqrt() + self.eps;
                    *p -= step_size * m / denom;
                }
            } else {
                // p -= lr * m_hat / (sqrt(v_hat) + eps)
                for (p, (m, v)) in param_vec
                    .iter_mut()
                    .zip(state.exp_avg.iter().zip(state.exp_avg_sq.iter()))
                {
                    let denom = (v / bias_correction2).sqrt() + self.eps;
                    *p -= step_size * m / denom;
                }
            }

            let update = Tensor::from_vec(param_vec, param_data.shape()).unwrap();
            param.update_data(update);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

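/// AdamW optimizer: Adam with decoupled weight decay, where the decay
/// scales the parameters directly instead of being folded into the
/// gradient. The default weight decay is `0.01`.
///
/// # Example
///
/// A minimal usage sketch, marked `ignore` for the same reasons as the
/// [`Adam`] example (imports and the backward pass are omitted, and
/// `params` stands in for a `Vec<Parameter>` collected from a model):
///
/// ```ignore
/// let mut optimizer = AdamW::with_options(
///     params,        // Vec<Parameter> collected from the model
///     1e-3,          // learning rate
///     (0.9, 0.999),  // betas
///     1e-8,          // eps
///     0.05,          // decoupled weight decay
///     false,         // amsgrad
/// );
/// optimizer.step();
/// ```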
pub struct AdamW {
    /// Parameters being optimized.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Exponential decay rate for the first-moment estimate.
    beta1: f32,
    /// Exponential decay rate for the second-moment estimate.
    beta2: f32,
    /// Term added to the denominator for numerical stability.
    eps: f32,
    /// Decoupled weight decay coefficient (applied to the parameters).
    weight_decay: f32,
    /// Whether to use the AMSGrad variant.
    amsgrad: bool,
    /// Per-parameter state, lazily initialized on the first call to `step`.
    state: Vec<AdamState>,
}

impl AdamW {
    /// Creates an AdamW optimizer with the default betas `(0.9, 0.999)`
    /// and a weight decay of `0.01`.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates an AdamW optimizer with custom beta coefficients.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.01,
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates an AdamW optimizer with every hyperparameter specified.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Sets the beta coefficients (builder style).
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Sets the numerical-stability epsilon (builder style).
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Sets the decoupled weight decay coefficient (builder style).
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Enables or disables the AMSGrad variant (builder style).
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    /// Allocates per-parameter state on the first step.
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| AdamState::new(p.numel(), self.amsgrad))
                .collect();
        }
    }
}

impl Optimizer for AdamW {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let grad_vec = grad.to_vec();
            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();
            let mut param_vec = param_data.to_vec();

            // Decoupled weight decay: shrink the parameters directly,
            // leaving the gradient untouched.
            if self.weight_decay != 0.0 {
                for p in &mut param_vec {
                    *p *= 1.0 - self.lr * self.weight_decay;
                }
            }

            // Update biased first moment estimate: m = beta1 * m + (1 - beta1) * g.
            for (m, g) in state.exp_avg.iter_mut().zip(grad_vec.iter()) {
                *m = self.beta1 * *m + (1.0 - self.beta1) * g;
            }

            // Update biased second raw moment estimate: v = beta2 * v + (1 - beta2) * g^2.
            for (v, g) in state.exp_avg_sq.iter_mut().zip(grad_vec.iter()) {
                *v = self.beta2 * *v + (1.0 - self.beta2) * g * g;
            }

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

            let step_size = self.lr / bias_correction1;

            if self.amsgrad {
                // AMSGrad: normalize by the running maximum of the second moment.
                let max_exp_avg_sq = state.max_exp_avg_sq.as_mut().unwrap();
                for (max_v, v) in max_exp_avg_sq.iter_mut().zip(state.exp_avg_sq.iter()) {
                    *max_v = max_v.max(*v);
                }
                for (p, (m, max_v)) in param_vec
                    .iter_mut()
                    .zip(state.exp_avg.iter().zip(max_exp_avg_sq.iter()))
                {
                    let denom = (max_v / bias_correction2).sqrt() + self.eps;
                    *p -= step_size * m / denom;
                }
            } else {
                // p -= lr * m_hat / (sqrt(v_hat) + eps)
                for (p, (m, v)) in param_vec
                    .iter_mut()
                    .zip(state.exp_avg.iter().zip(state.exp_avg_sq.iter()))
                {
                    let denom = (v / bias_correction2).sqrt() + self.eps;
                    *p -= step_size * m / denom;
                }
            }

            let update = Tensor::from_vec(param_vec, param_data.shape()).unwrap();
            param.update_data(update);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_adam_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = Adam::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_adam_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = Adam::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

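    // A first-step sanity check (a minimal sketch; the name and tolerance
    // are illustrative): with fresh state the bias-corrected update is
    // m_hat / (sqrt(v_hat) + eps) ~= sign(g), so each element should move
    // by roughly the learning rate.
    #[test]
    fn test_adam_first_step_size() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = Adam::new(vec![param.clone()], 0.1);
        optimizer.step();

        // Each parameter should have moved by approximately lr = 0.1.
        let new_data = param.data().to_vec();
        let expected = [0.9_f32, 1.9, 2.9];
        for (i, &e) in expected.iter().enumerate() {
            assert!((new_data[i] - e).abs() < 1e-3);
        }
    }
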
    #[test]
    fn test_adamw_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = AdamW::new(vec![param], 0.001);

        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

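    // An AMSGrad sanity check (a minimal sketch mirroring test_adam_step):
    // the AMSGrad path should allocate the max second-moment buffer and
    // still update the parameters.
    #[test]
    fn test_adam_amsgrad_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = Adam::new(vec![param.clone()], 0.1).amsgrad(true);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
        assert!(optimizer.state[0].max_exp_avg_sq.is_some());
    }
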
    #[test]
    fn test_adam_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = Adam::new(vec![param], 0.001)
            .betas((0.95, 0.9999))
            .eps(1e-7)
            .weight_decay(0.01)
            .amsgrad(true);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!(optimizer.amsgrad);
    }
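
    // An AdamW first-step sketch: with the default decoupled decay of 0.01
    // and lr = 0.1, the first element should land near
    // 1.0 * (1 - lr * wd) - lr ~= 0.899. The tolerance is deliberately loose.
    #[test]
    fn test_adamw_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = AdamW::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 0.899).abs() < 1e-3);
    }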
}