use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

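/// Adam optimizer (adaptive moment estimation).
///
/// Maintains exponential moving averages of the gradient (`exp_avg`) and of
/// the squared gradient (`exp_avg_sq`) for every parameter and applies the
/// bias-corrected update
///
/// ```text
/// m_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
/// v_t     = beta2 * v_{t-1} + (1 - beta2) * g_t^2
/// theta_t = theta_{t-1} - lr * (m_t / (1 - beta1^t)) / (sqrt(v_t / (1 - beta2^t)) + eps)
/// ```
///
/// When `weight_decay` is non-zero it is added to the gradient as an L2 term
/// (`g_t + weight_decay * theta_{t-1}`), matching the classic Adam
/// formulation. The `amsgrad` flag is stored as configuration; the current
/// `step` implementation applies the standard (non-AMSGrad) update.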
pub struct Adam {
    params: Vec<Parameter>,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    amsgrad: bool,
    state: Vec<AdamState>,
}

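/// Per-parameter optimizer state: first/second moment buffers and the step
/// count used for bias correction.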
#[derive(Debug, Clone)]
struct AdamState {
    exp_avg: Tensor<f32>,
    exp_avg_sq: Tensor<f32>,
    step: usize,
}

impl AdamState {
    fn new(shape: &[usize], device: axonml_core::Device) -> Self {
        let size: usize = shape.iter().product();
        let mut exp_avg = Tensor::from_vec(vec![0.0f32; size], shape).unwrap();
        let mut exp_avg_sq = Tensor::from_vec(vec![0.0f32; size], shape).unwrap();
        if device.is_gpu() {
            exp_avg = exp_avg.to_device(device.clone()).unwrap();
            exp_avg_sq = exp_avg_sq.to_device(device).unwrap();
        }
        Self {
            exp_avg,
            exp_avg_sq,
            step: 0,
        }
    }
}

impl Adam {
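    /// Creates an Adam optimizer with default hyperparameters:
    /// betas `(0.9, 0.999)`, eps `1e-8`, no weight decay, AMSGrad off.
    ///
    /// A minimal usage sketch, mirroring this crate's tests:
    ///
    /// ```ignore
    /// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
    /// let param = Parameter::from_variable(var);
    /// let mut optimizer = Adam::new(vec![param], 1e-3);
    /// ```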
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            state: Vec::new(),
        }
    }

    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

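    /// Lazily allocates zeroed moment buffers (one per parameter) on the
    /// parameter's device the first time `step` is called.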
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    AdamState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

impl Optimizer for Adam {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

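            // Fused GPU path: when the parameter lives on a CUDA device, the
            // whole update is delegated to a single in-place kernel.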
            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    self.weight_decay,
                    bias_correction1,
                    bias_correction2,
                );
                continue;
            }

            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd = self.weight_decay;

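            // Scalar CPU fallback: fold L2 weight decay into the gradient,
            // update the biased moment estimates, then take the
            // bias-corrected step.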
            for i in 0..param_vec.len() {
                let g = if wd == 0.0 {
                    grad_vec[i]
                } else {
                    grad_vec[i] + wd * param_vec[i]
                };
                exp_avg_vec[i] = beta1 * exp_avg_vec[i] + one_minus_beta1 * g;
                exp_avg_sq_vec[i] = beta2 * exp_avg_sq_vec[i] + one_minus_beta2 * g * g;
                let denom = (exp_avg_sq_vec[i] / bias_correction2).sqrt() + eps;
                param_vec[i] -= step_size * exp_avg_vec[i] / denom;
            }

            state.exp_avg = Tensor::from_vec(exp_avg_vec, param_data.shape()).unwrap();
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape()).unwrap();
            param.update_data(Tensor::from_vec(param_vec, param_data.shape()).unwrap());
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

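/// AdamW optimizer: Adam with decoupled weight decay.
///
/// Unlike [`Adam`], the weight decay is not folded into the gradient.
/// Instead each parameter is first scaled by `1 - lr * weight_decay` and then
/// receives the standard bias-corrected Adam update. The default
/// `weight_decay` is `0.01`.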
pub struct AdamW {
    params: Vec<Parameter>,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    amsgrad: bool,
    state: Vec<AdamState>,
}

impl AdamW {
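    /// Creates an AdamW optimizer with default hyperparameters:
    /// betas `(0.9, 0.999)`, eps `1e-8`, weight decay `0.01`, AMSGrad off.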
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.01,
            amsgrad: false,
            state: Vec::new(),
        }
    }

    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    AdamState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

impl Optimizer for AdamW {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

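            // Fused GPU path: delegate the update to the same in-place kernel
            // used by `Adam`, passing `weight_decay` through to the kernel.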
            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    self.weight_decay,
                    bias_correction1,
                    bias_correction2,
                );
                continue;
            }

            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd_factor = 1.0 - self.lr * self.weight_decay;
            let has_wd = self.weight_decay != 0.0;

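            // Scalar CPU fallback: apply decoupled weight decay by scaling the
            // parameter first, then take the standard bias-corrected Adam step
            // on the raw gradient.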
            for i in 0..param_vec.len() {
                if has_wd {
                    param_vec[i] *= wd_factor;
                }
                let g = grad_vec[i];
                exp_avg_vec[i] = beta1 * exp_avg_vec[i] + one_minus_beta1 * g;
                exp_avg_sq_vec[i] = beta2 * exp_avg_sq_vec[i] + one_minus_beta2 * g * g;
                let denom = (exp_avg_sq_vec[i] / bias_correction2).sqrt() + eps;
                param_vec[i] -= step_size * exp_avg_vec[i] / denom;
            }

            state.exp_avg = Tensor::from_vec(exp_avg_vec, param_data.shape()).unwrap();
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape()).unwrap();
            param.update_data(Tensor::from_vec(param_vec, param_data.shape()).unwrap());
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_adam_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = Adam::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_adam_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = Adam::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_adamw_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = AdamW::new(vec![param], 0.001);

        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_adam_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = Adam::new(vec![param], 0.001)
            .betas((0.95, 0.9999))
            .eps(1e-7)
            .weight_decay(0.01)
            .amsgrad(true);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!(optimizer.amsgrad);
    }
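
    // Sketch of an additional check (an assumption, not part of the original
    // test suite): with a zero gradient, one AdamW step should shrink the
    // parameter purely through decoupled weight decay (factor 1 - lr * wd).
    #[test]
    fn test_adamw_decoupled_weight_decay() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.0, 0.0, 0.0], &[3]).unwrap());

        // Default weight_decay is 0.01; with lr = 0.1 each element is scaled
        // by 1 - 0.1 * 0.01 = 0.999 and the moment update contributes nothing.
        let mut optimizer = AdamW::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 0.999).abs() < 1e-5);
    }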
}