use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

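/// Adam optimizer with bias-corrected first and second moment estimates,
/// optional L2 weight decay, and an AMSGrad flag.
///
/// # Example
///
/// A minimal usage sketch (construction mirrors the unit tests in this module):
///
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_nn::Parameter;
/// use axonml_tensor::Tensor;
///
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
/// let param = Parameter::from_variable(var);
/// let mut optimizer = Adam::new(vec![param], 0.001);
/// // After a backward pass has populated gradients:
/// optimizer.step();
/// optimizer.zero_grad();
/// ```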
pub struct Adam {
    params: Vec<Parameter>,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    amsgrad: bool,
    state: Vec<AdamState>,
}

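/// Per-parameter Adam state: exponential moving averages of the gradient and
/// the squared gradient, plus the number of steps taken for this parameter.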
#[derive(Debug, Clone)]
struct AdamState {
    exp_avg: Tensor<f32>,
    exp_avg_sq: Tensor<f32>,
    step: usize,
}

impl AdamState {
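    /// Creates zero-initialized moment tensors with the given shape, moving
    /// them to `device` when it is a GPU device.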
    fn new(shape: &[usize], device: axonml_core::Device) -> Self {
        let size: usize = shape.iter().product();
        let mut exp_avg = Tensor::from_vec(vec![0.0f32; size], shape).unwrap();
        let mut exp_avg_sq = Tensor::from_vec(vec![0.0f32; size], shape).unwrap();
        if device.is_gpu() {
            exp_avg = exp_avg.to_device(device).unwrap();
            exp_avg_sq = exp_avg_sq.to_device(device).unwrap();
        }
        Self {
            exp_avg,
            exp_avg_sq,
            step: 0,
        }
    }
}

impl Adam {
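    /// Creates an Adam optimizer with the default betas `(0.9, 0.999)`,
    /// `eps = 1e-8`, no weight decay, and AMSGrad disabled.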
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

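    /// Creates an Adam optimizer with custom `(beta1, beta2)` coefficients.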
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            state: Vec::new(),
        }
    }

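    /// Creates an Adam optimizer with explicit values for every hyperparameter.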
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

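    /// Sets the `(beta1, beta2)` coefficients (builder style).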
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

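    /// Sets the epsilon added to the denominator for numerical stability (builder style).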
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

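    /// Sets the L2 weight-decay coefficient (builder style).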
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

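    /// Sets the AMSGrad flag (builder style).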
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

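    /// Lazily creates per-parameter state the first time the optimizer steps.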
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    AdamState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

impl Optimizer for Adam {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            // GPU path: delegate the fused Adam update to the in-place device kernel.
            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    self.weight_decay,
                    bias_correction1,
                    bias_correction2,
                );
                continue;
            }

            // CPU path: update the moment estimates and parameters element-wise.
            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd = self.weight_decay;

            for j in 0..param_vec.len() {
                // Classic Adam applies L2 weight decay by adding it to the gradient.
                let g = if wd == 0.0 {
                    grad_vec[j]
                } else {
                    grad_vec[j] + wd * param_vec[j]
                };
                exp_avg_vec[j] = beta1 * exp_avg_vec[j] + one_minus_beta1 * g;
                exp_avg_sq_vec[j] = beta2 * exp_avg_sq_vec[j] + one_minus_beta2 * g * g;
                let denom = (exp_avg_sq_vec[j] / bias_correction2).sqrt() + eps;
                param_vec[j] -= step_size * exp_avg_vec[j] / denom;
            }

            state.exp_avg = Tensor::from_vec(exp_avg_vec, param_data.shape()).unwrap();
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape()).unwrap();
            param.update_data(Tensor::from_vec(param_vec, param_data.shape()).unwrap());
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

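/// AdamW optimizer: Adam with decoupled weight decay (Loshchilov & Hutter,
/// "Decoupled Weight Decay Regularization").
///
/// On the CPU path, decay is applied by multiplying the parameters by
/// `1 - lr * weight_decay` before the moment update, rather than adding the
/// decay term to the gradient as `Adam` does.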
pub struct AdamW {
    params: Vec<Parameter>,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    amsgrad: bool,
    state: Vec<AdamState>,
}

impl AdamW {
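    /// Creates an AdamW optimizer with the default betas `(0.9, 0.999)`,
    /// `eps = 1e-8`, weight decay `0.01`, and AMSGrad disabled.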
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.01,
            amsgrad: false,
            state: Vec::new(),
        }
    }

    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    AdamState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

impl Optimizer for AdamW {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            // GPU path: delegate the fused update to the in-place device kernel.
            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    self.weight_decay,
                    bias_correction1,
                    bias_correction2,
                );
                continue;
            }

            // CPU path: decoupled weight decay followed by the Adam update.
            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd_factor = 1.0 - self.lr * self.weight_decay;
            let has_wd = self.weight_decay != 0.0;

            for j in 0..param_vec.len() {
                // Decoupled weight decay: shrink the parameter directly, not the gradient.
                if has_wd {
                    param_vec[j] *= wd_factor;
                }
                let g = grad_vec[j];
                exp_avg_vec[j] = beta1 * exp_avg_vec[j] + one_minus_beta1 * g;
                exp_avg_sq_vec[j] = beta2 * exp_avg_sq_vec[j] + one_minus_beta2 * g * g;
                let denom = (exp_avg_sq_vec[j] / bias_correction2).sqrt() + eps;
                param_vec[j] -= step_size * exp_avg_vec[j] / denom;
            }

            state.exp_avg = Tensor::from_vec(exp_avg_vec, param_data.shape()).unwrap();
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape()).unwrap();
            param.update_data(Tensor::from_vec(param_vec, param_data.shape()).unwrap());
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_adam_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = Adam::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_adam_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = Adam::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_adamw_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = AdamW::new(vec![param], 0.001);

        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_adam_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = Adam::new(vec![param], 0.001)
            .betas((0.95, 0.9999))
            .eps(1e-7)
            .weight_decay(0.01)
            .amsgrad(true);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!(optimizer.amsgrad);
    }
}