burn_core/optim/adamw.rs

use super::{AdaptiveMomentumState, SimpleOptimizer};
use crate::config::Config;
use crate::optim::adaptor::OptimizerAdaptor;
use crate::tensor::{Tensor, backend::AutodiffBackend};
use crate::{
    self as burn, LearningRate, grad_clipping::GradientClippingConfig, module::AutodiffModule,
    record::Record,
};
use burn_tensor::{backend::Backend, ops::Device};

#[cfg(not(feature = "std"))]
use num_traits::Float;

/// AdamW configuration.
#[derive(Config)]
pub struct AdamWConfig {
    /// Exponential decay rate for the first moment estimate.
    #[config(default = 0.9)]
    beta_1: f32,
    /// Exponential decay rate for the second moment estimate.
    #[config(default = 0.999)]
    beta_2: f32,
    /// A value required for numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// Weight decay coefficient.
    #[config(default = 1e-4)]
    weight_decay: f32,
    /// [Gradient Clipping](GradientClippingConfig) config.
    grad_clipping: Option<GradientClippingConfig>,
}

/// AdamW optimizer as described in the paper [Decoupled Weight Decay Regularization, Loshchilov and Hutter, 2019](https://arxiv.org/abs/1711.05101).
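///
/// Unlike Adam with L2 regularization, AdamW applies weight decay directly to the
/// parameters, decoupled from the gradient-based update.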
#[derive(Clone)]
pub struct AdamW {
    momentum: AdaptiveMomentumW,
    weight_decay: f32,
}

/// AdamW state.
#[derive(Record, Clone, new)]
pub struct AdamWState<B: Backend, const D: usize> {
    /// The current adaptive momentum state.
    pub momentum: AdaptiveMomentumState<B, D>,
}

impl<B: Backend> SimpleOptimizer<B> for AdamW {
    type State<const D: usize> = AdamWState<B, D>;

    /// A single optimization step for any tensor that represents the parameters of a model.
    fn step<const D: usize>(
        &self,
        // Learning rate.
        lr: LearningRate,
        // Any tensor that represents the parameters of a model.
        tensor: Tensor<B, D>,
        // Gradient of the loss w.r.t. the parameters.
        grad: Tensor<B, D>,
        // State of the optimizer.
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
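        // Decoupled weight decay is applied directly to the parameters, then the
        // Adam-style delta (scaled by the learning rate) is subtracted:
        //   tensor = tensor - lr * weight_decay * tensor
        //   tensor = tensor - lr * update_delta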
        let tensor_updated = tensor.clone() - tensor.mul_scalar(lr).mul_scalar(self.weight_decay);

        let (raw_delta, momentum_state) = self.momentum.transform(grad, state.map(|s| s.momentum));

        let state = AdamWState {
            momentum: momentum_state,
        };

        (tensor_updated - raw_delta.mul_scalar(lr), Some(state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.momentum = state.momentum.to_device(device);
        state
    }
}

impl AdamWConfig {
    /// Initialize AdamW optimizer.
    ///
    /// # Returns
    ///
    /// Returns an optimizer that can be used to optimize a module.
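    ///
    /// # Examples
    ///
    /// A minimal usage sketch; `MyBackend` and `MyModule` are placeholder names for an
    /// autodiff backend and module defined elsewhere.
    ///
    /// ```ignore
    /// let mut optim = AdamWConfig::new()
    ///     .with_weight_decay(1e-2)
    ///     .init::<MyBackend, MyModule<MyBackend>>();
    /// // Later, given a model and its gradients:
    /// // let model = optim.step(1.0e-3, model, grads);
    /// ```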
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<AdamW, M, B> {
        let optim = AdamW {
            momentum: AdaptiveMomentumW {
                beta_1: self.beta_1,
                beta_2: self.beta_2,
                epsilon: self.epsilon,
            },
            weight_decay: self.weight_decay,
        };

        let mut optim = OptimizerAdaptor::from(optim);
        if let Some(config) = &self.grad_clipping {
            optim = optim.with_grad_clipping(config.init());
        }
        optim
    }
}

#[derive(Clone)]
struct AdaptiveMomentumW {
    beta_1: f32,
    beta_2: f32,
    epsilon: f32,
}

impl AdaptiveMomentumW {
    pub fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        state: Option<AdaptiveMomentumState<B, D>>,
    ) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
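        // Moment estimates follow Adam:
        //   moment_1 = beta_1 * moment_1 + (1 - beta_1) * grad
        //   moment_2 = beta_2 * moment_2 + (1 - beta_2) * grad^2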
        let state = if let Some(mut state) = state {
            // Update first moment estimate.
            let factor = 1.0 - self.beta_1;
            state.moment_1 = state
                .moment_1
                .mul_scalar(self.beta_1)
                .add(grad.clone().mul_scalar(factor));

            // Update second moment estimate.
            let factor = 1.0 - self.beta_2;
            state.moment_2 = state
                .moment_2
                .mul_scalar(self.beta_2)
                .add(grad.powi_scalar(2).mul_scalar(factor));

            // Update time.
            state.time += 1;

            state
        } else {
            // Initialize first moment estimate.
            let factor = 1.0 - self.beta_1;
            let moment_1 = grad.clone().mul_scalar(factor);

            // Initialize second moment estimate.
            let factor = 1.0 - self.beta_2;
            let moment_2 = grad.powi_scalar(2).mul_scalar(factor);

            AdaptiveMomentumState::new(1, moment_1, moment_2)
        };

        let time: i32 = state.time as i32;

        // Compute bias-corrected first and second moment estimates.
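        //   moment_1_corrected = moment_1 / (1 - beta_1^time)
        //   moment_2_corrected = moment_2 / (1 - beta_2^time)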
        let moment_1_corrected = state
            .moment_1
            .clone()
            .div_scalar(1f32 - self.beta_1.powi(time));

        let moment_2_corrected = state
            .moment_2
            .clone()
            .div_scalar(1f32 - self.beta_2.powi(time));

        // Compute update delta. This still needs to be scaled by the learning rate.
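        //   update_delta = moment_1_corrected / (sqrt(moment_2_corrected) + epsilon)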
        let update_delta =
            moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));

        (
            update_delta,
            AdaptiveMomentumState::new(state.time, state.moment_1, state.moment_2),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::module::{Module, Param};
    use crate::optim::{GradientsParams, Optimizer};
    use crate::tensor::{Distribution, Tensor, TensorData};
    use crate::{TestAutodiffBackend, nn};
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestAutodiffBackend>;

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_adamw_optimizer_save_load_state() {
        let device = Default::default();
        let linear = nn::LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_adamw();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_adamw"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use crate::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_adamw();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    #[test]
    fn test_adamw_optimizer_with_numbers() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdamWConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(0.5)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534],
            [
                0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182,
            ],
            [
                -0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981,
            ],
            [
                -0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081,
            ],
            [
                0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993,
            ],
            [-0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.359580],
        ]);
        let bias_expected = TensorData::from([
            -0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        let tolerance = Tolerance::absolute(1e-2);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_adamw_optimizer_no_nan() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );

        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &Default::default(),
        )
        .require_grad();

        let mut optimizer = AdamWConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(0.5)
            .init();

        let grads = linear.forward(x.clone()).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        assert!(!state_updated.weight.to_data().as_slice::<f32>().unwrap()[0].is_nan());
    }

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> nn::Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = nn::LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        nn::LinearConfig::new(6, 6)
            .init(&device)
            .load_record(record)
    }

    fn create_adamw()
    -> OptimizerAdaptor<AdamW, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        let config = AdamWConfig::new();
        AdamW {
            momentum: AdaptiveMomentumW {
                beta_1: config.beta_1,
                beta_2: config.beta_2,
                epsilon: config.epsilon,
            },
            weight_decay: config.weight_decay,
        }
        .into()
    }
}