burn_optim/optim/adamw.rs

use burn_core as burn;

use burn::config::Config;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use burn::tensor::{backend::Backend, ops::Device};
use burn::{module::AutodiffModule, record::Record};

use super::{AdaptiveMomentumState, SimpleOptimizer, adaptor::OptimizerAdaptor};
use crate::{LearningRate, grad_clipping::GradientClippingConfig};

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;

/// [`AdamW`] Configuration.
#[derive(Config, Debug)]
pub struct AdamWConfig {
    /// Coefficient used for the running average of the gradient (first moment).
    #[config(default = 0.9)]
    beta_1: f32,
    /// Coefficient used for the running average of the squared gradient (second moment).
    #[config(default = 0.999)]
    beta_2: f32,
    /// A value required for numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// Weight decay factor.
    #[config(default = 1e-4)]
    weight_decay: f32,

    /// Whether to apply cautious weight decay.
    ///
    /// See: <https://arxiv.org/abs/2510.12402>
    #[config(default = false)]
    cautious_weight_decay: bool,

    /// [Gradient Clipping](GradientClippingConfig) config.
    grad_clipping: Option<GradientClippingConfig>,
}
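
// Illustrative sketch of building a configuration with the `Config`-derived setters
// (the same setters are exercised by the tests at the bottom of this file). The values
// below are arbitrary placeholders rather than recommended settings.
#[allow(dead_code)]
fn example_adamw_config() -> AdamWConfig {
    AdamWConfig::new()
        .with_beta_1(0.9)
        .with_beta_2(0.999)
        .with_epsilon(1e-8)
        .with_weight_decay(1e-2)
        .with_cautious_weight_decay(false)
}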

/// AdamW optimizer.
///
/// See:
/// - [Decoupled Weight Decay Regularization, Loshchilov and Hutter, 2019](https://arxiv.org/abs/1711.05101).
/// - [Cautious Weight Decay, 2025](https://arxiv.org/abs/2510.12402)
///
/// Configured by [`AdamWConfig`].
#[derive(Clone)]
pub struct AdamW {
    momentum: AdaptiveMomentumW,
    weight_decay: f32,
    cautious_weight_decay: bool,
}
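
// For reference, the decoupled update implemented by `step` and `AdaptiveMomentumW::transform`
// below is, at step `t` with gradient `g_t`, parameters `theta`, and learning rate `lr`:
//
//   m_t     = beta_1 * m_{t-1} + (1 - beta_1) * g_t
//   v_t     = beta_2 * v_{t-1} + (1 - beta_2) * g_t^2
//   m_hat_t = m_t / (1 - beta_1^t)
//   v_hat_t = v_t / (1 - beta_2^t)
//   theta_t = theta_{t-1} * (1 - lr * weight_decay) - lr * m_hat_t / (sqrt(v_hat_t) + epsilon)
//
// With `cautious_weight_decay` enabled, the decay term is applied element-wise only where
// sign(theta_{t-1}) matches sign(m_t), i.e. where shrinking the weight agrees with the
// momentum-driven update; elsewhere the decay is zeroed out.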

/// AdamW state.
#[derive(Record, Clone, new)]
pub struct AdamWState<B: Backend, const D: usize> {
    /// The current adaptive momentum state.
    pub momentum: AdaptiveMomentumState<B, D>,
}

impl<B: Backend> SimpleOptimizer<B> for AdamW {
    type State<const D: usize> = AdamWState<B, D>;

    /// A single optimization step for any tensor that represents the parameters of a model.
    fn step<const D: usize>(
        &self,
        // Learning rate.
        lr: LearningRate,
        // Any tensor that represents the parameters of a model.
        tensor: Tensor<B, D>,
        // Gradient of the loss w.r.t. the parameters.
        grad: Tensor<B, D>,
        // State of the optimizer.
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        let (raw_delta, momentum_state) = self.momentum.transform(grad, state.map(|s| s.momentum));

        let decay_rate = lr * (self.weight_decay as f64);

        let decayed_tensor = if decay_rate == 0.0 {
            tensor.clone()
        } else if self.cautious_weight_decay {
            // Cautious weight decay.
            // See: https://arxiv.org/abs/2510.12402
            let tensor_pos = tensor.clone().greater_equal_elem(0.0);
            let grad_pos = momentum_state.moment_1.clone().greater_equal_elem(0.0);
            let differ = tensor_pos.not_equal(grad_pos);

            // Zero out the decay where the decay is counter to the update direction.
            tensor.clone() - tensor.mul_scalar(decay_rate).mask_fill(differ, 0.0)
        } else {
            tensor.clone().mul_scalar(1.0 - decay_rate)
        };

        let tensor_updated = decayed_tensor - raw_delta.mul_scalar(lr);

        let state = AdamWState {
            momentum: momentum_state,
        };

        (tensor_updated, Some(state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.momentum = state.momentum.to_device(device);
        state
    }
}
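
// Minimal sketch of driving `step` directly for a single parameter tensor; most users go
// through the `OptimizerAdaptor` returned by `AdamWConfig::init` instead. Here `adamw`,
// `params`, `grads`, and `new_grads` are placeholders:
//
//     let (params, state) = adamw.step(0.01, params, grads, None);
//     // Thread the returned state back in so the moment estimates keep accumulating.
//     let (params, _state) = adamw.step(0.01, params, new_grads, state);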

impl AdamWConfig {
    /// Initialize AdamW optimizer.
    ///
    /// # Returns
    ///
    /// Returns an optimizer that can be used to optimize a module.
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<AdamW, M, B> {
        let optim = AdamW {
            momentum: AdaptiveMomentumW {
                beta_1: self.beta_1,
                beta_2: self.beta_2,
                epsilon: self.epsilon,
            },
            weight_decay: self.weight_decay,
            cautious_weight_decay: self.cautious_weight_decay,
        };

        let mut optim = OptimizerAdaptor::from(optim);
        if let Some(config) = &self.grad_clipping {
            optim = optim.with_grad_clipping(config.init());
        }
        optim
    }
}
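
// Illustrative end-to-end usage, mirroring the tests at the bottom of this file; `model`,
// `x`, and `MyAutodiffBackend` are placeholders:
//
//     let mut optim = AdamWConfig::new().init::<MyAutodiffBackend, Linear<MyAutodiffBackend>>();
//     let grads = model.forward(x).backward();
//     let grads = GradientsParams::from_grads(grads, &model);
//     let model = optim.step(1.0e-2, model, grads);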

#[derive(Clone)]
struct AdaptiveMomentumW {
    beta_1: f32,
    beta_2: f32,
    epsilon: f32,
}

impl AdaptiveMomentumW {
    pub fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        state: Option<AdaptiveMomentumState<B, D>>,
    ) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
        let factor_1 = 1.0 - self.beta_1;
        let factor_2 = 1.0 - self.beta_2;

        let state = if let Some(mut state) = state {
            // Update first moment estimate.
            state.moment_1 = state
                .moment_1
                .mul_scalar(self.beta_1)
                .add(grad.clone().mul_scalar(factor_1));

            // Update second moment estimate.
            state.moment_2 = state
                .moment_2
                .mul_scalar(self.beta_2)
                .add(grad.powi_scalar(2).mul_scalar(factor_2));

            // Update time.
            state.time += 1;

            state
        } else {
            // Initialize first moment estimate.
            let moment_1 = grad.clone().mul_scalar(factor_1);

            // Initialize second moment estimate.
            let moment_2 = grad.powi_scalar(2).mul_scalar(factor_2);

            AdaptiveMomentumState::new(1, moment_1, moment_2)
        };

        let time: i32 = state.time as i32;

        // Compute bias-corrected first and second moment estimates.
        let moment_1_corrected = state
            .moment_1
            .clone()
            .div_scalar(1f32 - self.beta_1.powi(time));

        let moment_2_corrected = state
            .moment_2
            .clone()
            .div_scalar(1f32 - self.beta_2.powi(time));

        // Compute update delta. This still needs to be scaled by the learning rate.
        let update_delta =
            moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));

        (
            update_delta,
            AdaptiveMomentumState::new(state.time, state.moment_1, state.moment_2),
        )
    }
}
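
// Worked note on the very first step (t = 1): the freshly initialized moments are
// m_1 = (1 - beta_1) * g and v_1 = (1 - beta_2) * g^2, and the bias corrections divide by
// exactly (1 - beta_1) and (1 - beta_2), so m_hat = g and v_hat = g^2. The first raw delta
// is therefore g / (|g| + epsilon), roughly sign(g), before the learning-rate scaling
// applied in `step`.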

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use crate::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn::tensor::{Tolerance, ops::FloatElem};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    type FT = FloatElem<TestAutodiffBackend>;

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_adamw_optimizer_save_load_state() {
        let device = Default::default();
        let linear = LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_adamw();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_adamw"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_adamw();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    #[test]
    fn test_adamw_optimizer_with_numbers() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdamWConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(0.5)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534],
            [
                0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182,
            ],
            [
                -0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981,
            ],
            [
                -0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081,
            ],
            [
                0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993,
            ],
            [-0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.359580],
        ]);
        let bias_expected = TensorData::from([
            -0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        let tolerance = Tolerance::absolute(1e-2);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_adamw_optimizer_with_numbers_cautious() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, -0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdamWConfig::new()
            .with_cautious_weight_decay(true)
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(0.5)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534],
            [
                0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182,
            ],
            [
                -0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981,
            ],
            [
                -0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081,
            ],
            [
                0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993,
            ],
            [
                -0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.37061332,
            ],
        ]);
        let bias_expected = TensorData::from([
            -0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        let tolerance = Tolerance::absolute(1e-2);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_adamw_optimizer_no_nan() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );

        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &Default::default(),
        )
        .require_grad();

        let mut optimizer = AdamWConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(0.5)
            .init();

        let grads = linear.forward(x.clone()).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        assert!(!state_updated.weight.to_data().as_slice::<f32>().unwrap()[0].is_nan());
    }

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        LinearConfig::new(6, 6).init(&device).load_record(record)
    }

    fn create_adamw() -> OptimizerAdaptor<AdamW, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        let config = AdamWConfig::new();
        AdamW {
            momentum: AdaptiveMomentumW {
                beta_1: config.beta_1,
                beta_2: config.beta_2,
                epsilon: config.epsilon,
            },
            weight_decay: config.weight_decay,
            cautious_weight_decay: false,
        }
        .into()
    }
}