use burn_core as burn;

use burn::config::Config;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use burn::tensor::{backend::Backend, ops::Device};
use burn::{module::AutodiffModule, record::Record};

use super::{AdaptiveMomentumState, SimpleOptimizer, adaptor::OptimizerAdaptor};
use crate::{LearningRate, grad_clipping::GradientClippingConfig};

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;

/// Configuration for the [AdamW] optimizer.
#[derive(Config, Debug)]
pub struct AdamWConfig {
    /// Coefficient used to compute the running average of the gradient (first moment).
    #[config(default = 0.9)]
    beta_1: f32,
    /// Coefficient used to compute the running average of the squared gradient (second moment).
    #[config(default = 0.999)]
    beta_2: f32,
    /// A small value added to the denominator to improve numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// Weight decay factor, applied to the parameter directly (decoupled from the
    /// gradient-based update).
    #[config(default = 1e-4)]
    weight_decay: f32,
    /// When enabled, weight decay is applied only to entries whose parameter sign
    /// agrees with the sign of the first moment estimate; all other entries are
    /// left undecayed.
    #[config(default = false)]
    cautious_weight_decay: bool,
    /// Whether to use the AMSGrad variant, which normalizes by the element-wise
    /// maximum of all second moment estimates seen so far.
    #[config(default = false)]
    amsgrad: bool,
    /// [Gradient Clipping](GradientClippingConfig) config.
    grad_clipping: Option<GradientClippingConfig>,
}

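/// AdamW optimizer with decoupled weight decay, as described in
/// [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101).
///
/// Unlike classic Adam with L2 regularization, the decay term is applied to the
/// parameter directly instead of being folded into the gradient:
///
///   theta <- (1 - lr * lambda) * theta - lr * m_hat / (sqrt(v_hat) + eps)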
#[derive(Clone)]
pub struct AdamW {
    momentum: AdaptiveMomentumW,
    weight_decay: f32,
    cautious_weight_decay: bool,
}

/// AdamW optimizer state.
#[derive(Record, Clone, new)]
pub struct AdamWState<B: Backend, const D: usize> {
    /// The current adaptive momentum state.
    pub momentum: AdaptiveMomentumState<B, D>,
}

impl<B: Backend> SimpleOptimizer<B> for AdamW {
    type State<const D: usize> = AdamWState<B, D>;

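    // The step below follows the decoupled weight decay rule:
    //
    //   theta <- theta - lr * lambda * theta                 (decay on the raw parameter)
    //   theta <- theta - lr * m_hat / (sqrt(v_hat) + eps)    (Adam update on the decayed parameter)
    //
    // where m_hat and v_hat are the bias-corrected moment estimates produced by
    // `AdaptiveMomentumW::transform`.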
    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        let (raw_delta, momentum_state) = self.momentum.transform(grad, state.map(|s| s.momentum));

        let decay_rate = lr * (self.weight_decay as f64);

        let decayed_tensor = if decay_rate == 0.0 {
            tensor.clone()
        } else if self.cautious_weight_decay {
            // Only decay entries whose parameter sign agrees with the sign of the
            // first moment estimate; mask out the decay term everywhere else.
            let tensor_pos = tensor.clone().greater_equal_elem(0.0);
            let grad_pos = momentum_state.moment_1.clone().greater_equal_elem(0.0);
            let differ = tensor_pos.not_equal(grad_pos);

            tensor.clone() - tensor.mul_scalar(decay_rate).mask_fill(differ, 0.0)
        } else {
            tensor.clone().mul_scalar(1.0 - decay_rate)
        };

        let tensor_updated = decayed_tensor - raw_delta.mul_scalar(lr);

        let state = AdamWState {
            momentum: momentum_state,
        };

        (tensor_updated, Some(state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.momentum = state.momentum.to_device(device);
        state
    }
}

impl AdamWConfig {
    /// Build the [AdamW] optimizer from the config.
    pub fn build(&self) -> AdamW {
        AdamW {
            momentum: AdaptiveMomentumW {
                beta_1: self.beta_1,
                beta_2: self.beta_2,
                epsilon: self.epsilon,
                amsgrad: self.amsgrad,
            },
            weight_decay: self.weight_decay,
            cautious_weight_decay: self.cautious_weight_decay,
        }
    }

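    /// Initialize the AdamW optimizer, wrapped in an [OptimizerAdaptor] so it can
    /// drive a whole module.
    ///
    /// A minimal usage sketch, mirroring the tests below (the `linear` module and
    /// input `x` are illustrative):
    ///
    /// ```rust,ignore
    /// let mut optim = AdamWConfig::new().init();
    /// let grads = linear.forward(x).backward();
    /// let grads = GradientsParams::from_grads(grads, &linear);
    /// let linear = optim.step(1.0e-2, linear, grads);
    /// ```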
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<AdamW, M, B> {
        let mut optim = OptimizerAdaptor::from(self.build());
        if let Some(config) = &self.grad_clipping {
            optim = optim.with_grad_clipping(config.init());
        }
        optim
    }
}
143
144#[derive(Clone)]
145struct AdaptiveMomentumW {
146 beta_1: f32,
147 beta_2: f32,
148 epsilon: f32,
149 amsgrad: bool,
150}
151
impl AdaptiveMomentumW {
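    /// Update the moment estimates with the incoming gradient and return the
    /// bias-corrected update direction along with the new state:
    ///
    ///   m_t = beta_1 * m_{t-1} + (1 - beta_1) * g_t
    ///   v_t = beta_2 * v_{t-1} + (1 - beta_2) * g_t^2
    ///   m_hat = m_t / (1 - beta_1^t),  v_hat = v_t / (1 - beta_2^t)
    ///   delta = m_hat / (sqrt(v_hat) + epsilon)
    ///
    /// With AMSGrad enabled, v_hat is computed from the running element-wise
    /// maximum of all v_t seen so far instead of v_t itself.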
    pub fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        state: Option<AdaptiveMomentumState<B, D>>,
    ) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
        let factor_1 = 1.0 - self.beta_1;
        let factor_2 = 1.0 - self.beta_2;

        let state = if let Some(mut state) = state {
            // Exponential moving averages of the gradient and its square.
            state.moment_1 = state
                .moment_1
                .mul_scalar(self.beta_1)
                .add(grad.clone().mul_scalar(factor_1));

            state.moment_2 = state
                .moment_2
                .mul_scalar(self.beta_2)
                .add(grad.square().mul_scalar(factor_2));

            if self.amsgrad {
                // Track the element-wise maximum of all second moments seen so far.
                let max_v = state
                    .max_moment_2
                    .take()
                    .unwrap_or_else(|| state.moment_2.clone());
                state.max_moment_2 = Some(max_v.max_pair(state.moment_2.clone()));
            }

            state.time += 1;

            state
        } else {
            // First step: the previous moments are implicitly zero.
            let moment_1 = grad.clone().mul_scalar(factor_1);
            let moment_2 = grad.square().mul_scalar(factor_2);
            let max_moment_2 = self.amsgrad.then(|| moment_2.clone());

            AdaptiveMomentumState {
                time: 1,
                moment_1,
                moment_2,
                max_moment_2,
            }
        };

        let time = state.time as i32;

        // Bias correction for the moment estimates.
        let moment_1_corrected = state
            .moment_1
            .clone()
            .div_scalar(1f32 - self.beta_1.powi(time));

        let v_to_use = if self.amsgrad {
            state.max_moment_2.as_ref().unwrap_or(&state.moment_2)
        } else {
            &state.moment_2
        };

        let moment_2_corrected = v_to_use.clone().div_scalar(1f32 - self.beta_2.powi(time));

        let update_delta =
            moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));

        (update_delta, state)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use crate::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn::tensor::{Tolerance, ops::FloatElem};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    type FT = FloatElem<TestAutodiffBackend>;

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_adamw_optimizer_save_load_state() {
        let device = Default::default();
        let linear = LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_adamw();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_adamw"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_adamw();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    #[test]
    fn test_adamw_optimizer_with_amsgrad_50_steps() {
        let device = Default::default();
        let mut linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );

        let mut optimizer = AdamWConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_amsgrad(true)
            .with_weight_decay(0.5)
            .init();

        for i in 1..=50 {
            let x = Tensor::<TestAutodiffBackend, 2>::ones([2, 6], &device)
                .mul_scalar(i as f32 * 0.1)
                .require_grad();

            let grads = linear.forward(x).backward();
            let grads = GradientsParams::from_grads(grads, &linear);
            linear = optimizer.step(LEARNING_RATE, linear, grads);
        }

        let state_updated = linear.into_record();
        let weight_updated = state_updated.weight.to_data();
        let bias_updated = state_updated.bias.unwrap().to_data();

        let weights_expected = TensorData::from([
            [
                -0.7822558283805847,
                -0.42578864097595215,
                -0.21805696189403534,
                -0.28366872668266296,
                -0.46587175130844116,
                -0.4805040955543518,
            ],
            [
                -0.4722539782524109,
                -0.5471276640892029,
                -0.8181359767913818,
                -0.33425918221473694,
                -0.3805687427520752,
                -0.7601516842842102,
            ],
            [
                -0.5475167632102966,
                -0.5057991743087769,
                -0.763265073299408,
                -0.3393959403038025,
                -0.7490996718406677,
                -0.28911691904067993,
            ],
            [
                -0.7646660208702087,
                -0.7050473093986511,
                -0.8218720555305481,
                -0.7647438049316406,
                -0.5919585227966309,
                -0.40617525577545166,
            ],
            [
                -0.27588561177253723,
                -0.7025567889213562,
                -0.24343004822731018,
                -0.6672990918159485,
                -0.23728127777576447,
                -0.556389570236206,
            ],
            [
                -0.5451040267944336,
                -0.5420684814453125,
                -0.4348171353340149,
                -0.3832150399684906,
                -0.5099242925643921,
                -0.23440153896808624,
            ],
        ]);
        let bias_expected = TensorData::from([
            -0.7473056316375732,
            -0.3745720386505127,
            -0.5188710689544678,
            -0.35184532403945923,
            -0.33705732226371765,
            -0.4332566559314728,
        ]);

        let tolerance = Tolerance::absolute(1e-5);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
    }

    #[test]
    fn test_adamw_optimizer_with_numbers() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdamWConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(0.5)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534],
            [0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182],
            [-0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981],
            [-0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081],
            [0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993],
            [-0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.359580],
        ]);
        let bias_expected = TensorData::from([
            -0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        let tolerance = Tolerance::absolute(1e-2);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_adamw_optimizer_with_numbers_cautious() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
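        // The sign of the last entry of `x_2` is flipped relative to
        // `test_adamw_optimizer_with_numbers`, so some parameters end up with a
        // first moment whose sign disagrees with theirs and are left undecayed.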
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, -0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdamWConfig::new()
            .with_cautious_weight_decay(true)
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(0.5)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534],
            [0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182],
            [-0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981],
            [-0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081],
            [0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993],
            [-0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.37061332],
        ]);
        let bias_expected = TensorData::from([
            -0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        let tolerance = Tolerance::absolute(1e-2);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_adamw_optimizer_no_nan() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );

        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &Default::default(),
        )
        .require_grad();

        let mut optimizer = AdamWConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(0.5)
            .init();

        let grads = linear.forward(x.clone()).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        assert!(!state_updated.weight.to_data().as_slice::<f32>().unwrap()[0].is_nan());
    }

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        LinearConfig::new(6, 6).init(&device).load_record(record)
    }

    fn create_adamw() -> OptimizerAdaptor<AdamW, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        let config = AdamWConfig::new();
        AdamW {
            momentum: AdaptiveMomentumW {
                beta_1: config.beta_1,
                beta_2: config.beta_2,
                epsilon: config.epsilon,
                amsgrad: config.amsgrad,
            },
            weight_decay: config.weight_decay,
            cautious_weight_decay: false,
        }
        .into()
    }
}