use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

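/// Adam optimizer (adaptive moment estimation) with bias-corrected first and
/// second moment estimates, optional L2 weight decay (added to the gradient),
/// and an optional AMSGrad variant.
///
/// A minimal usage sketch, mirroring the tests at the bottom of this file
/// (in real training the gradient would come from a backward pass rather than
/// `set_grad`):
///
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_nn::Parameter;
/// use axonml_tensor::Tensor;
///
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
/// let param = Parameter::from_variable(var);
/// let mut optimizer = Adam::new(vec![param.clone()], 1e-3).weight_decay(0.01);
///
/// param
///     .variable()
///     .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());
/// optimizer.step();
/// optimizer.zero_grad();
/// ```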
pub struct Adam {
    params: Vec<Parameter>,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    amsgrad: bool,
    state: Vec<AdamState>,
}

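/// Per-parameter optimizer state shared by [`Adam`] and [`AdamW`].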
#[derive(Debug, Clone)]
struct AdamState {
    /// Exponential moving average of the gradient (first moment, `m`).
    exp_avg: Vec<f32>,
    /// Exponential moving average of the squared gradient (second moment, `v`).
    exp_avg_sq: Vec<f32>,
    /// Running maximum of `exp_avg_sq`, kept only when AMSGrad is enabled.
    max_exp_avg_sq: Option<Vec<f32>>,
    /// Number of optimization steps taken for this parameter.
    step: usize,
}

impl AdamState {
    fn new(size: usize, amsgrad: bool) -> Self {
        Self {
            exp_avg: vec![0.0; size],
            exp_avg_sq: vec![0.0; size],
            max_exp_avg_sq: if amsgrad { Some(vec![0.0; size]) } else { None },
            step: 0,
        }
    }
}

impl Adam {
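    /// Creates an Adam optimizer with the given learning rate and the defaults
    /// used throughout this module: betas `(0.9, 0.999)`, `eps = 1e-8`,
    /// no weight decay, and AMSGrad disabled.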
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

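    /// Creates an Adam optimizer with custom `(beta1, beta2)` coefficients;
    /// the remaining options keep their defaults.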
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            state: Vec::new(),
        }
    }

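    /// Creates an Adam optimizer with every option specified explicitly.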
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

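    /// Sets the `(beta1, beta2)` moment-decay coefficients (builder style).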
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

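    /// Sets the denominator epsilon used for numerical stability (builder style).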
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

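    /// Sets the L2 weight-decay coefficient (builder style).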
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

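    /// Enables or disables the AMSGrad variant (builder style).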
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

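    /// Lazily allocates per-parameter state on the first call to `step`,
    /// sized to match each parameter's element count.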
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| AdamState::new(p.numel(), self.amsgrad))
                .collect();
        }
    }
}

impl Optimizer for Adam {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            // Skip parameters that have no gradient accumulated yet.
            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let grad_vec = grad.to_vec();
            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();
            let mut param_vec = param_data.to_vec();

            // Classic (coupled) L2 regularization: add `weight_decay * p` to the gradient.
            let grad_vec: Vec<f32> = if self.weight_decay == 0.0 {
                grad_vec
            } else {
                grad_vec
                    .iter()
                    .zip(param_vec.iter())
                    .map(|(g, p)| g + self.weight_decay * p)
                    .collect()
            };

            // Update biased first moment estimate: m = beta1 * m + (1 - beta1) * g.
            for (m, g) in state.exp_avg.iter_mut().zip(grad_vec.iter()) {
                *m = self.beta1 * *m + (1.0 - self.beta1) * g;
            }

            // Update biased second moment estimate: v = beta2 * v + (1 - beta2) * g^2.
            for (v, g) in state.exp_avg_sq.iter_mut().zip(grad_vec.iter()) {
                *v = self.beta2 * *v + (1.0 - self.beta2) * g * g;
            }

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

            let step_size = self.lr / bias_correction1;

            if self.amsgrad {
                // AMSGrad: normalize by the running maximum of the second moment.
                let max_exp_avg_sq = state.max_exp_avg_sq.as_mut().unwrap();
                for (max_v, v) in max_exp_avg_sq.iter_mut().zip(state.exp_avg_sq.iter()) {
                    *max_v = max_v.max(*v);
                }
                for (p, (m, max_v)) in param_vec
                    .iter_mut()
                    .zip(state.exp_avg.iter().zip(max_exp_avg_sq.iter()))
                {
                    let denom = (max_v / bias_correction2).sqrt() + self.eps;
                    *p -= step_size * m / denom;
                }
            } else {
                for (p, (m, v)) in param_vec
                    .iter_mut()
                    .zip(state.exp_avg.iter().zip(state.exp_avg_sq.iter()))
                {
                    let denom = (v / bias_correction2).sqrt() + self.eps;
                    *p -= step_size * m / denom;
                }
            }

            // Write the updated values back into the parameter tensor.
            let update = Tensor::from_vec(param_vec, param_data.shape()).unwrap();
            param.update_data(update);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

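/// AdamW optimizer: Adam with decoupled weight decay (Loshchilov & Hutter).
///
/// Unlike [`Adam`], the weight-decay term is not added to the gradient;
/// instead each parameter is scaled by `1 - lr * weight_decay` before the
/// Adam update. The default weight decay is `0.01`.
///
/// A minimal construction sketch (not compiled as a doctest; it assumes
/// `params: Vec<Parameter>` was collected from a model elsewhere):
///
/// ```ignore
/// let mut optimizer = AdamW::new(params, 1e-3) // weight_decay defaults to 0.01
///     .weight_decay(0.1)
///     .amsgrad(true);
/// ```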
pub struct AdamW {
    params: Vec<Parameter>,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    amsgrad: bool,
    state: Vec<AdamState>,
}

impl AdamW {
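    /// Creates an AdamW optimizer with the defaults used in this module:
    /// betas `(0.9, 0.999)`, `eps = 1e-8`, `weight_decay = 0.01`, and AMSGrad disabled.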
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

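    /// Creates an AdamW optimizer with custom `(beta1, beta2)` coefficients;
    /// the remaining options keep their defaults.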
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.01,
            amsgrad: false,
            state: Vec::new(),
        }
    }

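    /// Creates an AdamW optimizer with every option specified explicitly.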
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

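    /// Sets the `(beta1, beta2)` moment-decay coefficients (builder style).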
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

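    /// Sets the denominator epsilon used for numerical stability (builder style).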
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

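    /// Sets the decoupled weight-decay coefficient (builder style).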
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

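    /// Enables or disables the AMSGrad variant (builder style).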
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

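    /// Lazily allocates per-parameter state on the first call to `step`,
    /// sized to match each parameter's element count.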
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| AdamState::new(p.numel(), self.amsgrad))
                .collect();
        }
    }
}

impl Optimizer for AdamW {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            // Skip parameters that have no gradient accumulated yet.
            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let grad_vec = grad.to_vec();
            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();
            let mut param_vec = param_data.to_vec();

            // Decoupled weight decay: shrink the parameters directly,
            // independently of the gradient-based update below.
            if self.weight_decay != 0.0 {
                for p in &mut param_vec {
                    *p *= 1.0 - self.lr * self.weight_decay;
                }
            }

            // Update biased first moment estimate: m = beta1 * m + (1 - beta1) * g.
            for (m, g) in state.exp_avg.iter_mut().zip(grad_vec.iter()) {
                *m = self.beta1 * *m + (1.0 - self.beta1) * g;
            }

            // Update biased second moment estimate: v = beta2 * v + (1 - beta2) * g^2.
            for (v, g) in state.exp_avg_sq.iter_mut().zip(grad_vec.iter()) {
                *v = self.beta2 * *v + (1.0 - self.beta2) * g * g;
            }

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

            let step_size = self.lr / bias_correction1;

            if self.amsgrad {
                // AMSGrad: normalize by the running maximum of the second moment.
                let max_exp_avg_sq = state.max_exp_avg_sq.as_mut().unwrap();
                for (max_v, v) in max_exp_avg_sq.iter_mut().zip(state.exp_avg_sq.iter()) {
                    *max_v = max_v.max(*v);
                }
                for (p, (m, max_v)) in param_vec
                    .iter_mut()
                    .zip(state.exp_avg.iter().zip(max_exp_avg_sq.iter()))
                {
                    let denom = (max_v / bias_correction2).sqrt() + self.eps;
                    *p -= step_size * m / denom;
                }
            } else {
                for (p, (m, v)) in param_vec
                    .iter_mut()
                    .zip(state.exp_avg.iter().zip(state.exp_avg_sq.iter()))
                {
                    let denom = (v / bias_correction2).sqrt() + self.eps;
                    *p -= step_size * m / denom;
                }
            }

            // Write the updated values back into the parameter tensor.
            let update = Tensor::from_vec(param_vec, param_data.shape()).unwrap();
            param.update_data(update);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_adam_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = Adam::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_adam_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = Adam::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_adamw_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = AdamW::new(vec![param], 0.001);

        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_adam_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = Adam::new(vec![param], 0.001)
            .betas((0.95, 0.9999))
            .eps(1e-7)
            .weight_decay(0.01)
            .amsgrad(true);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!(optimizer.amsgrad);
    }
}