//! Adam and AdamW optimizers with bias correction, optional weight decay, and
//! an optional AMSGrad variant.

use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

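/// Adam optimizer (adaptive moment estimation) with bias-corrected first and
/// second moment estimates, optional L2 weight decay, and an optional AMSGrad
/// variant.
///
/// Each call to `step` applies, element-wise:
///
/// ```text
/// m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
/// v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
/// m_hat = m_t / (1 - beta1^t)
/// v_hat = v_t / (1 - beta2^t)       (running max of v_t when AMSGrad is on)
/// theta <- theta - lr * m_hat / (sqrt(v_hat) + eps)
/// ```
///
/// When `weight_decay` is non-zero, `g_t` is first replaced by
/// `g_t + weight_decay * theta` (classic coupled L2 decay, unlike [`AdamW`]).
///
/// A minimal usage sketch (marked `ignore`: `model_params` and `num_epochs` are
/// placeholders, and it assumes gradients were populated by a backward pass):
///
/// ```ignore
/// let mut optimizer = Adam::new(model_params, 1e-3)
///     .weight_decay(1e-4)
///     .amsgrad(true);
///
/// for _ in 0..num_epochs {
///     // ... forward pass, loss, backward pass ...
///     optimizer.step();      // apply one Adam update to every parameter
///     optimizer.zero_grad(); // clear gradients for the next iteration
/// }
/// ```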
pub struct Adam {
    /// Parameters being optimized.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Exponential decay rate for the first moment estimate.
    beta1: f32,
    /// Exponential decay rate for the second moment estimate.
    beta2: f32,
    /// Small constant added to the denominator for numerical stability.
    eps: f32,
    /// L2 weight decay coefficient, folded into the gradient.
    weight_decay: f32,
    /// Whether to use the AMSGrad variant.
    amsgrad: bool,
    /// Per-parameter state, lazily initialized on the first call to `step`.
    state: Vec<AdamState>,
}

/// Per-parameter Adam state.
#[derive(Debug, Clone)]
struct AdamState {
    /// Exponential moving average of the gradient (first moment).
    exp_avg: Tensor<f32>,
    /// Exponential moving average of the squared gradient (second moment).
    exp_avg_sq: Tensor<f32>,
    /// Running maximum of `exp_avg_sq`; `Some` only when AMSGrad is in use.
    max_exp_avg_sq: Option<Tensor<f32>>,
    /// Number of optimizer steps taken for this parameter.
    step: usize,
}

impl AdamState {
    fn new(shape: &[usize], device: axonml_core::Device) -> Self {
        let size: usize = shape.iter().product();
        // Moment estimates start at zero and live on the same device as the parameter.
        let mut exp_avg = Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
        let mut exp_avg_sq = Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
        if device.is_gpu() {
            exp_avg = exp_avg.to_device(device).expect("device transfer failed");
            exp_avg_sq = exp_avg_sq.to_device(device).expect("device transfer failed");
        }
        Self {
            exp_avg,
            exp_avg_sq,
            max_exp_avg_sq: None,
            step: 0,
        }
    }
}

impl Adam {
    /// Creates an Adam optimizer with the default betas `(0.9, 0.999)`.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates an Adam optimizer with custom beta coefficients.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates an Adam optimizer with every hyperparameter specified explicitly.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Sets the beta coefficients (builder style).
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Sets the numerical-stability epsilon (builder style).
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Sets the L2 weight decay coefficient (builder style).
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Enables or disables the AMSGrad variant (builder style).
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    /// Lazily allocates per-parameter state on the first `step` call.
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    AdamState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

impl Optimizer for Adam {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                // Make sure the gradient lives on the same device as the parameter.
                let grad = if !grad.device().is_gpu() {
                    grad.to_device(param_data.device())
                        .expect("Adam: failed to migrate CPU gradient to GPU")
                } else {
                    grad
                };
                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

                // Fused GPU kernel. Note: this path does not implement AMSGrad;
                // `max_exp_avg_sq` is neither read nor updated here.
                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    self.weight_decay,
                    bias_correction1,
                    bias_correction2,
                );
                continue;
            }

            // CPU fallback: element-wise update over flattened tensors.
            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd = self.weight_decay;

            let mut max_sq_vec = if self.amsgrad {
                state
                    .max_exp_avg_sq
                    .as_ref()
                    .map(|t| t.to_vec())
                    .unwrap_or_else(|| vec![0.0f32; param_vec.len()])
            } else {
                Vec::new()
            };

            for j in 0..param_vec.len() {
                // Classic (coupled) L2 weight decay: fold the penalty into the gradient.
                let g = if wd == 0.0 {
                    grad_vec[j]
                } else {
                    grad_vec[j] + wd * param_vec[j]
                };
                exp_avg_vec[j] = beta1 * exp_avg_vec[j] + one_minus_beta1 * g;
                exp_avg_sq_vec[j] = beta2 * exp_avg_sq_vec[j] + one_minus_beta2 * g * g;

                // AMSGrad uses the running maximum of the second moment.
                let v_hat = if self.amsgrad {
                    max_sq_vec[j] = max_sq_vec[j].max(exp_avg_sq_vec[j]);
                    max_sq_vec[j] / bias_correction2
                } else {
                    exp_avg_sq_vec[j] / bias_correction2
                };

                let denom = v_hat.sqrt() + eps;
                param_vec[j] -= step_size * exp_avg_vec[j] / denom;
            }

            state.exp_avg = Tensor::from_vec(exp_avg_vec, param_data.shape()).expect("tensor creation failed");
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape()).expect("tensor creation failed");
            if self.amsgrad {
                state.max_exp_avg_sq = Some(Tensor::from_vec(max_sq_vec, param_data.shape()).expect("tensor creation failed"));
            }
            param.update_data(Tensor::from_vec(param_vec, param_data.shape()).expect("tensor creation failed"));
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

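/// AdamW optimizer: Adam with decoupled weight decay.
///
/// Unlike [`Adam`], which folds the decay into the gradient, AdamW shrinks the
/// parameters directly before the moment-based update:
///
/// ```text
/// theta <- theta * (1 - lr * weight_decay)           (decoupled decay)
/// theta <- theta - lr * m_hat / (sqrt(v_hat) + eps)  (ordinary Adam step)
/// ```
///
/// The default `weight_decay` is `0.01`, not `0.0`.
///
/// A minimal usage sketch (marked `ignore`: `model_params` is a placeholder for
/// parameters collected from a model):
///
/// ```ignore
/// let mut optimizer = AdamW::new(model_params, 1e-3); // weight_decay defaults to 0.01
/// optimizer.step();
/// optimizer.zero_grad();
/// ```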
pub struct AdamW {
    /// Parameters being optimized.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Exponential decay rate for the first moment estimate.
    beta1: f32,
    /// Exponential decay rate for the second moment estimate.
    beta2: f32,
    /// Small constant added to the denominator for numerical stability.
    eps: f32,
    /// Decoupled weight decay coefficient (defaults to `0.01`).
    weight_decay: f32,
    /// Whether to use the AMSGrad variant.
    amsgrad: bool,
    /// Per-parameter state, lazily initialized on the first call to `step`.
    state: Vec<AdamState>,
}
339
340impl AdamW {
341 #[must_use]
343 pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
344 Self::with_betas(params, lr, (0.9, 0.999))
345 }
346
347 #[must_use]
349 pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
350 Self {
351 params,
352 lr,
353 beta1: betas.0,
354 beta2: betas.1,
355 eps: 1e-8,
356 weight_decay: 0.01, amsgrad: false,
358 state: Vec::new(),
359 }
360 }
361
362 #[must_use]
364 pub fn with_options(
365 params: Vec<Parameter>,
366 lr: f32,
367 betas: (f32, f32),
368 eps: f32,
369 weight_decay: f32,
370 amsgrad: bool,
371 ) -> Self {
372 Self {
373 params,
374 lr,
375 beta1: betas.0,
376 beta2: betas.1,
377 eps,
378 weight_decay,
379 amsgrad,
380 state: Vec::new(),
381 }
382 }
383
384 #[must_use]
386 pub fn betas(mut self, betas: (f32, f32)) -> Self {
387 self.beta1 = betas.0;
388 self.beta2 = betas.1;
389 self
390 }
391
392 #[must_use]
394 pub fn eps(mut self, eps: f32) -> Self {
395 self.eps = eps;
396 self
397 }
398
399 #[must_use]
401 pub fn weight_decay(mut self, weight_decay: f32) -> Self {
402 self.weight_decay = weight_decay;
403 self
404 }
405
406 #[must_use]
408 pub fn amsgrad(mut self, amsgrad: bool) -> Self {
409 self.amsgrad = amsgrad;
410 self
411 }
412
413 fn ensure_state_initialized(&mut self) {
414 if self.state.is_empty() {
415 self.state = self
416 .params
417 .iter()
418 .map(|p| {
419 let data = p.data();
420 AdamState::new(data.shape(), data.device())
421 })
422 .collect();
423 }
424 }
425}

impl Optimizer for AdamW {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                // Make sure the gradient lives on the same device as the parameter.
                let grad = if !grad.device().is_gpu() {
                    grad.to_device(param_data.device())
                        .expect("AdamW: failed to migrate CPU gradient to GPU")
                } else {
                    grad
                };

                // Decoupled weight decay: shrink the parameters directly,
                // before the moment-based update.
                if self.weight_decay > 0.0 {
                    let decay_factor = 1.0 - self.lr * self.weight_decay;
                    let decayed = param_data.mul_scalar(decay_factor);
                    param.update_data(decayed);
                }

                let param_data = param.data();

                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

                // Fused GPU kernel. Weight decay is passed as 0.0 because it was
                // already applied above; AMSGrad is not implemented on this path.
                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    0.0,
                    bias_correction1,
                    bias_correction2,
                );
                continue;
            }

            // CPU fallback: element-wise update over flattened tensors.
            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd_factor = 1.0 - self.lr * self.weight_decay;
            let has_wd = self.weight_decay != 0.0;

            let mut max_sq_vec = if self.amsgrad {
                state
                    .max_exp_avg_sq
                    .as_ref()
                    .map(|t| t.to_vec())
                    .unwrap_or_else(|| vec![0.0f32; param_vec.len()])
            } else {
                Vec::new()
            };

            for j in 0..param_vec.len() {
                if has_wd {
                    // Decoupled weight decay, applied to the parameter itself.
                    param_vec[j] *= wd_factor;
                }
                let g = grad_vec[j];
                exp_avg_vec[j] = beta1 * exp_avg_vec[j] + one_minus_beta1 * g;
                exp_avg_sq_vec[j] = beta2 * exp_avg_sq_vec[j] + one_minus_beta2 * g * g;
                // AMSGrad uses the running maximum of the second moment
                // (previously this flag was silently ignored on the CPU path).
                let v_hat = if self.amsgrad {
                    max_sq_vec[j] = max_sq_vec[j].max(exp_avg_sq_vec[j]);
                    max_sq_vec[j] / bias_correction2
                } else {
                    exp_avg_sq_vec[j] / bias_correction2
                };
                let denom = v_hat.sqrt() + eps;
                param_vec[j] -= step_size * exp_avg_vec[j] / denom;
            }

            state.exp_avg = Tensor::from_vec(exp_avg_vec, param_data.shape()).expect("tensor creation failed");
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape()).expect("tensor creation failed");
            if self.amsgrad {
                state.max_exp_avg_sq = Some(Tensor::from_vec(max_sq_vec, param_data.shape()).expect("tensor creation failed"));
            }
            param.update_data(Tensor::from_vec(param_vec, param_data.shape()).expect("tensor creation failed"));
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_adam_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"), true);
        let param = Parameter::from_variable(var);
        let optimizer = Adam::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_adam_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = Adam::new(vec![param.clone()], 0.1);
        optimizer.step();

        // The step should have moved the parameter away from its initial value.
        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }
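
    // Added check: the first Adam step has a closed form. With zero-initialized
    // moments the bias corrections cancel, so each parameter moves by
    // lr * g / (|g| + eps), i.e. ~lr in the gradient's direction. This pins
    // down the CPU update math using only the APIs exercised above.
    #[test]
    fn test_adam_first_step_closed_form() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = Adam::new(vec![param.clone()], 0.1);
        optimizer.step();

        // Every component should have moved by ~0.1 in the gradient direction.
        let new_data = param.data().to_vec();
        assert!((new_data[0] - 0.9).abs() < 1e-4);
        assert!((new_data[1] - 1.9).abs() < 1e-4);
        assert!((new_data[2] - 2.9).abs() < 1e-4);
    }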

    #[test]
    fn test_adamw_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"), true);
        let param = Parameter::from_variable(var);
        let optimizer = AdamW::new(vec![param], 0.001);

        // AdamW defaults to a non-zero decoupled weight decay.
        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }
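
    // Added companion to test_adam_step: with lr = 0.1 and the default decay of
    // 0.01, the first AdamW step applies the decoupled factor 1 - lr * wd = 0.999
    // and then a ~0.1 Adam step, so 1.0 * 0.999 - 0.1 ≈ 0.899.
    #[test]
    fn test_adamw_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = AdamW::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 0.899).abs() < 1e-3);
    }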

    #[test]
    fn test_adam_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"), true);
        let param = Parameter::from_variable(var);

        let optimizer = Adam::new(vec![param], 0.001)
            .betas((0.95, 0.9999))
            .eps(1e-7)
            .weight_decay(0.01)
            .amsgrad(true);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!(optimizer.amsgrad);
    }
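
    // Added sanity check for the Optimizer trait accessors: set_lr should be
    // reflected by get_lr.
    #[test]
    fn test_set_lr() {
        let var = Variable::new(Tensor::from_vec(vec![1.0], &[1]).expect("tensor creation failed"), true);
        let param = Parameter::from_variable(var);
        let mut optimizer = Adam::new(vec![param], 0.001);

        optimizer.set_lr(0.01);
        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
    }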
}