// burn_core/optim/adam.rs

use crate::{
    self as burn, LearningRate, grad_clipping::GradientClippingConfig, module::AutodiffModule,
    record::Record,
};

use super::{
    SimpleOptimizer,
    decay::{WeightDecay, WeightDecayConfig},
};
use crate::config::Config;
use crate::optim::adaptor::OptimizerAdaptor;
use crate::tensor::{Tensor, backend::AutodiffBackend};
use burn_tensor::{backend::Backend, ops::Device};

#[cfg(not(feature = "std"))]
use num_traits::Float;

/// Adam configuration.
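///
/// A construction sketch using the `with_*` setters generated by
/// `#[derive(Config)]`; the values are illustrative (they mirror the tests
/// below), not recommendations:
///
/// ```rust,ignore
/// let config = AdamConfig::new()
///     .with_beta_1(0.9)
///     .with_beta_2(0.999)
///     .with_epsilon(1e-8)
///     .with_weight_decay(Some(WeightDecayConfig::new(0.5)));
/// ```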
#[derive(Config)]
pub struct AdamConfig {
    /// Exponential decay rate for the first moment (mean) estimates.
    #[config(default = 0.9)]
    beta_1: f32,
    /// Exponential decay rate for the second moment (uncentered variance) estimates.
    #[config(default = 0.999)]
    beta_2: f32,
    /// A small value added to the denominator for numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// [Weight decay](WeightDecayConfig) config.
    weight_decay: Option<WeightDecayConfig>,
    /// [Gradient Clipping](GradientClippingConfig) config.
    grad_clipping: Option<GradientClippingConfig>,
}

/// Adam optimizer as described in the paper [Adam: A Method for Stochastic Optimization](https://arxiv.org/pdf/1412.6980.pdf).
#[derive(Clone)]
pub struct Adam {
    momentum: AdaptiveMomentum,
    weight_decay: Option<WeightDecay>,
}

/// Adam state.
#[derive(Record, Clone, new)]
pub struct AdamState<B: Backend, const D: usize> {
    /// The current adaptive momentum.
    pub momentum: AdaptiveMomentumState<B, D>,
}

impl<B: Backend> SimpleOptimizer<B> for Adam {
    type State<const D: usize> = AdamState<B, D>;

    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        mut grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        let mut state_momentum = None;

        if let Some(state) = state {
            state_momentum = Some(state.momentum);
        }

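        // Apply weight decay (when configured) to the raw gradient before the momentum update.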
        if let Some(weight_decay) = &self.weight_decay {
            grad = weight_decay.transform(grad, tensor.clone());
        }

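        // Update the moment estimates and obtain the bias-corrected update direction.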
        let (grad, state_momentum) = self.momentum.transform(grad, state_momentum);

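        // Pack the new momentum state and take a step scaled by the learning rate.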
        let state = AdamState::new(state_momentum);
        let delta = grad.mul_scalar(lr);

        (tensor - delta, Some(state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.momentum = state.momentum.to_device(device);
        state
    }
}

impl AdamConfig {
    /// Initialize the Adam optimizer.
    ///
    /// # Returns
    ///
    /// Returns an optimizer that can be used to optimize a module.
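    ///
    /// # Example
    ///
    /// A usage sketch mirroring the tests below; `model`, `input`, and
    /// `learning_rate` are placeholders for a concrete [`AutodiffModule`],
    /// its input, and a [`LearningRate`]:
    ///
    /// ```rust,ignore
    /// let mut optimizer = AdamConfig::new().init();
    ///
    /// // One optimization step: backward pass, extract per-parameter
    /// // gradients, then update the module.
    /// let grads = model.forward(input).backward();
    /// let grads = GradientsParams::from_grads(grads, &model);
    /// let model = optimizer.step(learning_rate, model, grads);
    /// ```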
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<Adam, M, B> {
        let optim = Adam {
            momentum: AdaptiveMomentum {
                beta_1: self.beta_1,
                beta_2: self.beta_2,
                epsilon: self.epsilon,
            },
            weight_decay: self.weight_decay.as_ref().map(WeightDecay::new),
        };

        let mut optim = OptimizerAdaptor::from(optim);
        if let Some(config) = &self.grad_clipping {
            optim = optim.with_grad_clipping(config.init());
        }
        optim
    }
}

/// Adaptive momentum state.
#[derive(Record, new, Clone)]
pub struct AdaptiveMomentumState<B: Backend, const D: usize> {
    /// The number of iterations aggregated.
    pub time: usize,
    /// The first order momentum.
    pub moment_1: Tensor<B, D>,
    /// The second order momentum.
    pub moment_2: Tensor<B, D>,
}

#[derive(Clone)]
struct AdaptiveMomentum {
    beta_1: f32,
    beta_2: f32,
    epsilon: f32,
}

impl AdaptiveMomentum {
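    /// Computes the bias-corrected Adam update direction for a single gradient tensor.
    ///
    /// With `g` the incoming gradient and `t` the number of steps taken so far,
    /// the method performs (in the notation of the Adam paper):
    ///
    /// ```text
    /// m_t   = beta_1 * m_{t-1} + (1 - beta_1) * g
    /// v_t   = beta_2 * v_{t-1} + (1 - beta_2) * g^2
    /// m_hat = m_t / (1 - beta_1^t)
    /// v_hat = v_t / (1 - beta_2^t)
    /// out   = m_hat / (sqrt(v_hat) + epsilon)
    /// ```
    ///
    /// Scaling by the learning rate is left to the caller.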
    pub fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        momentum_state: Option<AdaptiveMomentumState<B, D>>,
    ) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
        let state = if let Some(mut state) = momentum_state {
            let factor = 1.0 - self.beta_1;
            state.moment_1 = state
                .moment_1
                .mul_scalar(self.beta_1)
                .add(grad.clone().mul_scalar(factor));

            let factor = 1.0 - self.beta_2;
            state.moment_2 = state
                .moment_2
                .mul_scalar(self.beta_2)
                .add(grad.powi_scalar(2).mul_scalar(factor));

            state.time += 1;

            state
        } else {
            let factor = 1.0 - self.beta_1;
            let moment_1 = grad.clone().mul_scalar(factor);

            let factor = 1.0 - self.beta_2;
            let moment_2 = grad.powi_scalar(2).mul_scalar(factor);

            AdaptiveMomentumState::new(1, moment_1, moment_2)
        };

        let time = state.time as i32;
        let moment_1_corrected = state
            .moment_1
            .clone()
            .div_scalar(1f32 - self.beta_1.powi(time));
        let moment_2_corrected = state
            .moment_2
            .clone()
            .div_scalar(1f32 - self.beta_2.powi(time));

        let grad = moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));

        (grad, state)
    }
}

impl<B: Backend, const D: usize> AdaptiveMomentumState<B, D> {
    /// Move state to device.
    ///
    /// # Arguments
    ///
    /// * `device` - Device to move state to.
    ///
    /// # Returns
    ///
    /// Returns state moved to device.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.moment_1 = self.moment_1.to_device(device);
        self.moment_2 = self.moment_2.to_device(device);
        self
    }
}

#[cfg(test)]
mod tests {
    use burn_tensor::Tolerance;
    use burn_tensor::ops::FloatElem;

    use super::*;
    use crate::module::{Module, Param};
    use crate::optim::{GradientsParams, Optimizer};
    use crate::tensor::{Distribution, Tensor, TensorData};
    use crate::{TestAutodiffBackend, nn};

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_adam_optimizer_save_load_state() {
        let device = Default::default();
        let linear = nn::LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_adam();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_adam"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use crate::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_adam();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    #[test]
    fn test_adam_optimizer_with_numbers() {
        let device = Default::default();
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdamConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [-0.340528, 0.118929, 0.384336, 0.300010, 0.066034, 0.047154],
            [
                0.057757, -0.036690, -0.386649, 0.235010, 0.175624, -0.312133,
            ],
            [
                -0.038940, 0.016306, -0.316151, 0.228410, -0.297819, 0.293047,
            ],
            [
                -0.317929, -0.239100, -0.391449, -0.318087, -0.095948, 0.142651,
            ],
            [
                0.310050, -0.235909, 0.351736, -0.192888, 0.359710, -0.050343,
            ],
            [-0.035840, -0.030203, 0.105840, 0.172110, 0.009440, 0.363346],
        ]);
        let bias_expected = TensorData::from([
            -0.410499, 0.068401, -0.116999, 0.097601, 0.116601, -0.006999,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        type FT = FloatElem<TestAutodiffBackend>;
        let tolerance = Tolerance::absolute(1e-2);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_adam_optimizer_no_nan() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );

        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &Default::default(),
        )
        .require_grad();

        let mut optimizer = AdamConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
            .init();

        let grads = linear.forward(x.clone()).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        assert!(!state_updated.weight.to_data().as_slice::<f32>().unwrap()[0].is_nan());
    }

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> nn::Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = nn::LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        nn::LinearConfig::new(6, 6)
            .init(&device)
            .load_record(record)
    }

    fn create_adam() -> OptimizerAdaptor<Adam, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend>
    {
        let config = AdamConfig::new();
        Adam {
            momentum: AdaptiveMomentum {
                beta_1: config.beta_1,
                beta_2: config.beta_2,
                epsilon: config.epsilon,
            },
            weight_decay: config.weight_decay.as_ref().map(WeightDecay::new),
        }
        .into()
    }
}