
burn_optim/optim/adam.rs

use burn_core as burn;

use burn::{module::AutodiffModule, record::Record};

use burn::config::Config;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use burn::tensor::{backend::Backend, ops::Device};

use super::{
    SimpleOptimizer,
    adaptor::OptimizerAdaptor,
    decay::{WeightDecay, WeightDecayConfig},
};
use crate::{LearningRate, grad_clipping::GradientClippingConfig};

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;
/// Adam configuration.
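///
/// # Example
///
/// A minimal configuration sketch; `MyBackend` and `MyModule` are placeholder
/// names assumed to be defined elsewhere:
///
/// ```rust,ignore
/// let config = AdamConfig::new()
///     .with_beta_1(0.9)
///     .with_beta_2(0.999)
///     .with_epsilon(1e-8)
///     .with_amsgrad(true);
/// let mut optim = config.init::<MyBackend, MyModule>();
/// ```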
#[derive(Config, Debug)]
pub struct AdamConfig {
    /// Exponential decay rate for the first-moment (mean) estimates.
    #[config(default = 0.9)]
    beta_1: f32,
    /// Exponential decay rate for the second-moment (uncentered variance) estimates.
    #[config(default = 0.999)]
    beta_2: f32,
    /// A small value added to the denominator for numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// Whether to use the AMSGrad variant.
    #[config(default = false)]
    amsgrad: bool,
    /// [Weight decay](WeightDecayConfig) config.
    weight_decay: Option<WeightDecayConfig>,
    /// [Gradient Clipping](GradientClippingConfig) config.
    grad_clipping: Option<GradientClippingConfig>,
}

/// Adam optimizer.
///
/// See:
/// - [Adam: A Method for Stochastic Optimization](https://arxiv.org/pdf/1412.6980.pdf).
/// - [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ)
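///
/// The per-parameter update implemented below (see `AdaptiveMomentum::transform`)
/// keeps two running moments,
/// `m_t = beta_1 * m_{t-1} + (1 - beta_1) * g_t` and
/// `v_t = beta_2 * v_{t-1} + (1 - beta_2) * g_t^2`, then applies the
/// bias-corrected step `theta -= lr * m_hat_t / (sqrt(v_hat_t) + epsilon)`,
/// with `v_hat_t` replaced by its running maximum when AMSGrad is enabled.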
#[derive(Clone)]
pub struct Adam {
    momentum: AdaptiveMomentum,
    weight_decay: Option<WeightDecay>,
}

/// Adam state.
#[derive(Record, Clone, new)]
pub struct AdamState<B: Backend, const D: usize> {
    /// The current adaptive momentum.
    pub momentum: AdaptiveMomentumState<B, D>,
}

impl<B: Backend> SimpleOptimizer<B> for Adam {
    type State<const D: usize> = AdamState<B, D>;

    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        mut grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        let state_momentum = state.map(|state| state.momentum);

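        // Weight decay is folded into the gradient before the momentum update
        // (coupled, L2-style), unlike the decoupled variant used by AdamW.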
        if let Some(weight_decay) = &self.weight_decay {
            grad = weight_decay.transform(grad, tensor.clone());
        }

        let (grad, state_momentum) = self.momentum.transform(grad, state_momentum);

        let state = AdamState::new(state_momentum);
        let delta = grad.mul_scalar(lr);

        (tensor - delta, Some(state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.momentum = state.momentum.to_device(device);
        state
    }
}

impl AdamConfig {
    /// Build an [`Adam`] from the config.
    pub fn build(&self) -> Adam {
        Adam {
            momentum: AdaptiveMomentum {
                beta_1: self.beta_1,
                beta_2: self.beta_2,
                epsilon: self.epsilon,
                amsgrad: self.amsgrad,
            },
            weight_decay: self.weight_decay.as_ref().map(WeightDecay::new),
        }
    }

    /// Initialize Adam optimizer.
    ///
    /// # Returns
    ///
    /// Returns an optimizer that can be used to optimize a module.
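    ///
    /// # Example
    ///
    /// A sketch of one training step; `MyBackend`, `model`, and `grads` are
    /// placeholders assumed to be defined elsewhere:
    ///
    /// ```rust,ignore
    /// let mut optim = AdamConfig::new().init::<MyBackend, _>();
    /// let model = optim.step(1.0e-2, model, grads);
    /// ```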
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<Adam, M, B> {
        let mut optim = OptimizerAdaptor::from(self.build());
        if let Some(config) = &self.grad_clipping {
            optim = optim.with_grad_clipping(config.init());
        }
        optim
    }
}

/// Adaptive momentum state.
#[derive(Record, new, Clone)]
pub struct AdaptiveMomentumState<B: Backend, const D: usize> {
    /// The number of iterations aggregated.
    pub time: usize,
    /// The first order momentum.
    pub moment_1: Tensor<B, D>,
    /// The second order momentum.
    pub moment_2: Tensor<B, D>,
    /// Maximum of the second order momentum seen so far (used by AMSGrad).
    #[new(default)]
    pub max_moment_2: Option<Tensor<B, D>>,
}

#[derive(Clone)]
struct AdaptiveMomentum {
    beta_1: f32,
    beta_2: f32,
    epsilon: f32,
    amsgrad: bool,
}

impl AdaptiveMomentum {
    pub fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        momentum_state: Option<AdaptiveMomentumState<B, D>>,
    ) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
        let state = if let Some(mut state) = momentum_state {
            let factor = 1.0 - self.beta_1;
            state.moment_1 = state
                .moment_1
                .mul_scalar(self.beta_1)
                .add(grad.clone().mul_scalar(factor));

            let factor = 1.0 - self.beta_2;
            state.moment_2 = state
                .moment_2
                .mul_scalar(self.beta_2)
                .add(grad.square().mul_scalar(factor));
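
            // AMSGrad keeps the element-wise maximum of all second-moment
            // estimates seen so far, so the effective step size never grows
            // when `moment_2` shrinks.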
            if self.amsgrad {
                let max_v = state
                    .max_moment_2
                    .take()
                    .unwrap_or_else(|| state.moment_2.clone());

                let new_max = max_v.max_pair(state.moment_2.clone());
                state.max_moment_2 = Some(new_max);
            }

            state.time += 1;

            state
        } else {
            let factor = 1.0 - self.beta_1;
            let moment_1 = grad.clone().mul_scalar(factor);

            let factor = 1.0 - self.beta_2;
            let moment_2 = grad.square().mul_scalar(factor);
            let max_moment_2 = self.amsgrad.then(|| moment_2.clone());
            AdaptiveMomentumState {
                time: 1,
                moment_1,
                moment_2,
                max_moment_2,
            }
        };

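        // Both bias corrections are folded into a single scalar on `moment_1`:
        // with `m_hat = m / (1 - beta_1^t)` and `v_hat = v / (1 - beta_2^t)`,
        //   m_hat / (sqrt(v_hat) + eps)
        //     = m * sqrt(1 - beta_2^t) / (1 - beta_1^t)
        //       / (sqrt(v) + eps * sqrt(1 - beta_2^t)),
        // which avoids materializing the corrected tensors.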
        let time = state.time as i32;
        let bias_correction2_sqrt = (1.0 - self.beta_2.powi(time)).sqrt();
        let combined_factor = bias_correction2_sqrt / (1.0 - self.beta_1.powi(time));

        let v_to_use = if self.amsgrad {
            state.max_moment_2.as_ref().unwrap_or(&state.moment_2)
        } else {
            &state.moment_2
        };

        let grad = state.moment_1.clone().mul_scalar(combined_factor).div(
            v_to_use
                .clone()
                .sqrt()
                .add_scalar(self.epsilon * bias_correction2_sqrt),
        );

        (grad, state)
    }
}

impl<B: Backend, const D: usize> AdaptiveMomentumState<B, D> {
    /// Move state to device.
    ///
    /// # Arguments
    ///
    /// * `device` - Device to move state to.
    ///
    /// # Returns
    ///
    /// Returns state moved to device.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.moment_1 = self.moment_1.to_device(device);
        self.moment_2 = self.moment_2.to_device(device);
        self.max_moment_2 = self.max_moment_2.map(|tensor| tensor.to_device(device));
        self
    }
}

#[cfg(test)]
mod tests {
    use burn::tensor::Tolerance;
    use burn::tensor::ops::FloatElem;

    use super::*;
    use crate::TestAutodiffBackend;
    use crate::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    const LEARNING_RATE: LearningRate = 0.01;

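    // One optimization step gives the optimizer real state to record; the
    // round-trip then checks that a freshly built optimizer can load it.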
    #[test]
    fn test_adam_optimizer_save_load_state() {
        let device = Default::default();
        let linear = LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_adam();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_adam"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_adam();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    #[test]
    fn test_adam_optimizer_with_amsgrad_50_steps() {
        let device = Default::default();
        let mut linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );

        let mut optimizer = AdamConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_amsgrad(true)
            .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
            .init();

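        // 50 steps with steadily growing inputs exercise the AMSGrad path,
        // including the running maximum of the second moment.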
        for i in 1..=50 {
            let x = Tensor::<TestAutodiffBackend, 2>::ones([2, 6], &device)
                .mul_scalar(i as f32 * 0.1)
                .require_grad();

            let grads = linear.forward(x).backward();
            let grads = GradientsParams::from_grads(grads, &linear);
            linear = optimizer.step(LEARNING_RATE, linear, grads);
        }

        let state_updated = linear.into_record();
        let weight_updated = state_updated.weight.to_data();
        let bias_updated = state_updated.bias.unwrap().to_data();

        let weights_expected = TensorData::from([
            [
                -0.9125810265541077,
                -0.45855265855789185,
                -0.1915993094444275,
                -0.2759990692138672,
                -0.5099529027938843,
                -0.5287043452262878,
            ],
            [
                -0.5181325674057007,
                -0.6139854788780212,
                -0.9574727416038513,
                -0.34102925658226013,
                -0.400514155626297,
                -0.8847861886024475,
            ],
            [
                -0.614483118057251,
                -0.5611032247543335,
                -0.8887064456939697,
                -0.34762972593307495,
                -0.8708556890487671,
                -0.2830044627189636,
            ],
            [
                -0.8904699683189392,
                -0.8151527643203735,
                -0.9621278643608093,
                -0.8905676603317261,
                -0.671261191368103,
                -0.4333854615688324,
            ],
            [
                -0.26599061489105225,
                -0.8119961023330688,
                -0.22424538433551788,
                -0.7672406435012817,
                -0.2163349837064743,
                -0.6258266568183899,
            ],
            [
                -0.611397922039032,
                -0.6075160503387451,
                -0.4701341986656189,
                -0.4039117991924286,
                -0.5663845539093018,
                -0.21262989938259125,
            ],
        ]);
        let bias_expected = TensorData::from([
            -0.8817203044891357,
            -0.4038999378681183,
            -0.5889149308204651,
            -0.37475723028182983,
            -0.3557940721511841,
            -0.47914788126945496,
        ]);

        type FT = FloatElem<TestAutodiffBackend>;
        let tolerance = Tolerance::absolute(1e-5);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
    }

    #[test]
    fn test_adam_optimizer_with_numbers() {
        let device = Default::default();
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdamConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [-0.340528, 0.118929, 0.384336, 0.300010, 0.066034, 0.047154],
            [
                0.057757, -0.036690, -0.386649, 0.235010, 0.175624, -0.312133,
            ],
            [
                -0.038940, 0.016306, -0.316151, 0.228410, -0.297819, 0.293047,
            ],
            [
                -0.317929, -0.239100, -0.391449, -0.318087, -0.095948, 0.142651,
            ],
            [
                0.310050, -0.235909, 0.351736, -0.192888, 0.359710, -0.050343,
            ],
            [-0.035840, -0.030203, 0.105840, 0.172110, 0.009440, 0.363346],
        ]);
        let bias_expected = TensorData::from([
            -0.410499, 0.068401, -0.116999, 0.097601, 0.116601, -0.006999,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        type FT = FloatElem<TestAutodiffBackend>;
        let tolerance = Tolerance::absolute(1e-2);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_adam_optimizer_no_nan() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );

        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &Default::default(),
        )
        .require_grad();

        let mut optimizer = AdamConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
            .init();

        let grads = linear.forward(x.clone()).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        assert!(!state_updated.weight.to_data().as_slice::<f32>().unwrap()[0].is_nan());
    }

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        LinearConfig::new(6, 6).init(&device).load_record(record)
    }

    fn create_adam() -> OptimizerAdaptor<Adam, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        let config = AdamConfig::new();
        Adam {
            momentum: AdaptiveMomentum {
                beta_1: config.beta_1,
                beta_2: config.beta_2,
                epsilon: config.epsilon,
                amsgrad: config.amsgrad,
            },
            weight_decay: config.weight_decay.as_ref().map(WeightDecay::new),
        }
        .into()
    }
}