Skip to main content

burn_optim/optim/
adamw.rs

1use burn_core as burn;
2
3use burn::config::Config;
4use burn::tensor::{Tensor, backend::AutodiffBackend};
5use burn::tensor::{backend::Backend, ops::Device};
6use burn::{module::AutodiffModule, record::Record};
7
8use super::{AdaptiveMomentumState, SimpleOptimizer, adaptor::OptimizerAdaptor};
9use crate::{LearningRate, grad_clipping::GradientClippingConfig};
10
11#[cfg(not(feature = "std"))]
12#[allow(unused_imports)]
13use num_traits::Float as _;
14
/// [`AdamW`] Configuration.
///
/// Every field has a default value (or is optional), so `AdamWConfig::new()` yields a usable
/// configuration; individual values are overridden with the generated `with_*` setters.
#[derive(Config, Debug)]
pub struct AdamWConfig {
    /// Exponential decay rate of the first-moment (running gradient mean) estimate.
    #[config(default = 0.9)]
    beta_1: f32,
    /// Exponential decay rate of the second-moment (running squared-gradient) estimate.
    #[config(default = 0.999)]
    beta_2: f32,
    /// Added to the square root of the bias-corrected second moment before dividing,
    /// for numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// Decoupled weight-decay coefficient: each step the parameters are shrunk by
    /// `learning_rate * weight_decay` times their value, independently of the gradient update.
    #[config(default = 1e-4)]
    weight_decay: f32,

    /// Cautious weight decay: only decay coordinates whose parameter sign matches the sign
    /// of the first-moment estimate (the update direction).
    ///
    /// See: <https://arxiv.org/abs/2510.12402>
    #[config(default = false)]
    cautious_weight_decay: bool,

    /// Whether to use the AMSGrad variant, which normalizes by the running maximum of the
    /// second-moment estimate instead of its current value.
    #[config(default = false)]
    amsgrad: bool,
    /// Optional [Gradient Clipping](GradientClippingConfig) config, attached by
    /// [`AdamWConfig::init`] when set.
    grad_clipping: Option<GradientClippingConfig>,
}
43
44/// AdamW optimizer.
45///
46/// See:
47/// - [Decoupled Weight Decay Regularization, Loshchilov and Hutter, 2019](https://arxiv.org/abs/1711.05101).
48/// - [Cautious Weight Decay, 2025](https://arxiv.org/abs/2510.12402)
49/// - [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ)
50///
51/// Configured by [`AdamWConfig`].
52#[derive(Clone)]
53pub struct AdamW {
54    momentum: AdaptiveMomentumW,
55    weight_decay: f32,
56    cautious_weight_decay: bool,
57}
58
/// AdamW state.
///
/// Returned by [`AdamW`]'s `step` and passed back in on the next call for the same
/// parameter tensor.
#[derive(Record, Clone, new)]
pub struct AdamWState<B: Backend, const D: usize> {
    /// The current adaptive momentum state.
    pub momentum: AdaptiveMomentumState<B, D>,
}
65
66impl<B: Backend> SimpleOptimizer<B> for AdamW {
67    type State<const D: usize> = AdamWState<B, D>;
68
69    /// A single optimization step for any tensor that represents the parameters of a model.
70    fn step<const D: usize>(
71        &self,
72        // Learning rate.
73        lr: LearningRate,
74        // Any tensor that represents the parameters of a model.
75        tensor: Tensor<B, D>,
76        // Gradient of the loss w.r.t. the parameters.
77        grad: Tensor<B, D>,
78        // State of the optimizer.
79        state: Option<Self::State<D>>,
80    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
81        let (raw_delta, momentum_state) = self.momentum.transform(grad, state.map(|s| s.momentum));
82
83        let decay_rate = lr * (self.weight_decay as f64);
84
85        let decayed_tensor = if decay_rate == 0.0 {
86            tensor.clone()
87        } else if self.cautious_weight_decay {
88            // Cautious weight decay.
89            // See: https://arxiv.org/abs/2510.12402
90            let tensor_pos = tensor.clone().greater_equal_elem(0.0);
91            let grad_pos = momentum_state.moment_1.clone().greater_equal_elem(0.0);
92            let differ = tensor_pos.not_equal(grad_pos);
93
94            // Zero out the decay where the decay is counter to the update direction.
95            tensor.clone() - tensor.mul_scalar(decay_rate).mask_fill(differ, 0.0)
96        } else {
97            tensor.clone().mul_scalar(1.0 - decay_rate)
98        };
99
100        let tensor_updated = decayed_tensor - raw_delta.mul_scalar(lr);
101
102        let state = AdamWState {
103            momentum: momentum_state,
104        };
105
106        (tensor_updated, Some(state))
107    }
108
109    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
110        state.momentum = state.momentum.to_device(device);
111        state
112    }
113}
114
115impl AdamWConfig {
116    /// Build an [`AdamW`] from the config.
117    pub fn build(&self) -> AdamW {
118        AdamW {
119            momentum: AdaptiveMomentumW {
120                beta_1: self.beta_1,
121                beta_2: self.beta_2,
122                epsilon: self.epsilon,
123                amsgrad: self.amsgrad,
124            },
125            weight_decay: self.weight_decay,
126            cautious_weight_decay: self.cautious_weight_decay,
127        }
128    }
129
130    /// Initialize AdamW optimizer.
131    ///
132    /// # Returns
133    ///
134    /// Returns an optimizer that can be used to optimize a module.
135    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<AdamW, M, B> {
136        let mut optim = OptimizerAdaptor::from(self.build());
137        if let Some(config) = &self.grad_clipping {
138            optim = optim.with_grad_clipping(config.init());
139        }
140        optim
141    }
142}
143
/// Adam moment estimator used by [`AdamW`]: turns gradients into bias-corrected
/// adaptive updates.
#[derive(Clone)]
struct AdaptiveMomentumW {
    /// Normalize by the running maximum of the second moment (AMSGrad).
    amsgrad: bool,
    /// Decay rate of the first-moment estimate.
    beta_1: f32,
    /// Decay rate of the second-moment estimate.
    beta_2: f32,
    /// Added to the denominator for numerical stability.
    epsilon: f32,
}
151
152impl AdaptiveMomentumW {
153    pub fn transform<B: Backend, const D: usize>(
154        &self,
155        grad: Tensor<B, D>,
156        state: Option<AdaptiveMomentumState<B, D>>,
157    ) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
158        let factor_1 = 1.0 - self.beta_1;
159        let factor_2 = 1.0 - self.beta_2;
160
161        let state = if let Some(mut state) = state {
162            // Update first moment estimate.
163            state.moment_1 = state
164                .moment_1
165                .mul_scalar(self.beta_1)
166                .add(grad.clone().mul_scalar(factor_1));
167
168            // Update second moment estimate.
169            state.moment_2 = state
170                .moment_2
171                .mul_scalar(self.beta_2)
172                .add(grad.square().mul_scalar(factor_2));
173
174            if self.amsgrad {
175                let max_v = state
176                    .max_moment_2
177                    .take()
178                    .unwrap_or_else(|| state.moment_2.clone());
179                state.max_moment_2 = Some(max_v.max_pair(state.moment_2.clone()));
180            }
181
182            // Update time.
183            state.time += 1;
184
185            state
186        } else {
187            // Initialize first moment estimate.
188            let moment_1 = grad.clone().mul_scalar(factor_1);
189
190            // Initialize second moment estimate.
191            let moment_2 = grad.square().mul_scalar(factor_2);
192            let max_moment_2 = self.amsgrad.then(|| moment_2.clone());
193            AdaptiveMomentumState {
194                time: 1,
195                moment_1,
196                moment_2,
197                max_moment_2,
198            }
199        };
200
201        let time: i32 = state.time as i32;
202
203        // Compute bias-corrected first and second moment estimates.
204        let moment_1_corrected = state
205            .moment_1
206            .clone()
207            .div_scalar(1f32 - self.beta_1.powi(time));
208
209        let v_to_use = if self.amsgrad {
210            state.max_moment_2.as_ref().unwrap_or(&state.moment_2)
211        } else {
212            &state.moment_2
213        };
214
215        let moment_2_corrected = v_to_use.clone().div_scalar(1f32 - self.beta_2.powi(time));
216
217        let update_delta =
218            moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));
219
220        (update_delta, state)
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use crate::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn::tensor::{Tolerance, ops::FloatElem};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    type FT = FloatElem<TestAutodiffBackend>;

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_adamw_optimizer_save_load_state() {
        let device = Default::default();
        let linear = LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_adamw();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_adamw"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_adamw();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    #[test]
    fn test_adamw_optimizer_with_amsgrad_50_steps() {
        let device = Default::default();
        let mut linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );

        let mut optimizer = AdamWConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_amsgrad(true)
            .with_weight_decay(0.5)
            .init();

        for i in 1..=50 {
            let x = Tensor::<TestAutodiffBackend, 2>::ones([2, 6], &device)
                .mul_scalar(i as f32 * 0.1)
                .require_grad();

            let grads = linear.forward(x).backward();
            let grads = GradientsParams::from_grads(grads, &linear);
            linear = optimizer.step(LEARNING_RATE, linear, grads);
        }

        let state_updated = linear.into_record();
        let weight_updated = state_updated.weight.to_data();
        let bias_updated = state_updated.bias.unwrap().to_data();

        let weights_expected = TensorData::from([
            [
                -0.7822558283805847,
                -0.42578864097595215,
                -0.21805696189403534,
                -0.28366872668266296,
                -0.46587175130844116,
                -0.4805040955543518,
            ],
            [
                -0.4722539782524109,
                -0.5471276640892029,
                -0.8181359767913818,
                -0.33425918221473694,
                -0.3805687427520752,
                -0.7601516842842102,
            ],
            [
                -0.5475167632102966,
                -0.5057991743087769,
                -0.763265073299408,
                -0.3393959403038025,
                -0.7490996718406677,
                -0.28911691904067993,
            ],
            [
                -0.7646660208702087,
                -0.7050473093986511,
                -0.8218720555305481,
                -0.7647438049316406,
                -0.5919585227966309,
                -0.40617525577545166,
            ],
            [
                -0.27588561177253723,
                -0.7025567889213562,
                -0.24343004822731018,
                -0.6672990918159485,
                -0.23728127777576447,
                -0.556389570236206,
            ],
            [
                -0.5451040267944336,
                -0.5420684814453125,
                -0.4348171353340149,
                -0.3832150399684906,
                -0.5099242925643921,
                -0.23440153896808624,
            ],
        ]);
        let bias_expected = TensorData::from([
            -0.7473056316375732,
            -0.3745720386505127,
            -0.5188710689544678,
            -0.35184532403945923,
            -0.33705732226371765,
            -0.4332566559314728,
        ]);

        let tolerance = Tolerance::absolute(1e-5);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
    }

    #[test]
    fn test_adamw_optimizer_with_numbers() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdamWConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(0.5)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534],
            [
                0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182,
            ],
            [
                -0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981,
            ],
            [
                -0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081,
            ],
            [
                0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993,
            ],
            [-0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.359580],
        ]);
        let bias_expected = TensorData::from([
            -0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        let tolerance = Tolerance::absolute(1e-2);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_adamw_optimizer_with_numbers_cautious() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, -0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdamWConfig::new()
            .with_cautious_weight_decay(true)
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(0.5)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534],
            [
                0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182,
            ],
            [
                -0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981,
            ],
            [
                -0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081,
            ],
            [
                0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993,
            ],
            [
                -0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.37061332,
            ],
        ]);
        let bias_expected = TensorData::from([
            -0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        let tolerance = Tolerance::absolute(1e-2);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_adamw_optimizer_no_nan() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );

        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &Default::default(),
        )
        .require_grad();

        let mut optimizer = AdamWConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(0.5)
            .init();

        let grads = linear.forward(x.clone()).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weight = state_updated.weight.to_data();
        let bias = state_updated.bias.unwrap().to_data();

        // Check every element, not just the first one.
        assert!(
            weight.as_slice::<f32>().unwrap().iter().all(|v| !v.is_nan()),
            "AdamW produced NaN weights"
        );
        assert!(
            bias.as_slice::<f32>().unwrap().iter().all(|v| !v.is_nan()),
            "AdamW produced NaN biases"
        );
    }

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        LinearConfig::new(6, 6).init(&device).load_record(record)
    }

    /// AdamW with default configuration, built through the same path as production code.
    fn create_adamw() -> OptimizerAdaptor<AdamW, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        AdamWConfig::new().build().into()
    }
}