burn_optim/optim/adam.rs

use burn_core as burn;

use burn::{module::AutodiffModule, record::Record};

use burn::config::Config;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use burn::tensor::{backend::Backend, ops::Device};

use super::{
    SimpleOptimizer,
    adaptor::OptimizerAdaptor,
    decay::{WeightDecay, WeightDecayConfig},
};
use crate::{LearningRate, grad_clipping::GradientClippingConfig};

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;

/// Adam configuration.
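///
/// # Example
///
/// A construction sketch; the values shown are illustrative, not recommended defaults:
///
/// ```rust,ignore
/// let config = AdamConfig::new()
///     .with_beta_1(0.9)
///     .with_beta_2(0.999)
///     .with_epsilon(1e-8)
///     .with_weight_decay(Some(WeightDecayConfig::new(1e-4)));
/// ```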
#[derive(Config, Debug)]
pub struct AdamConfig {
    /// Exponential decay rate for the first-moment (mean gradient) estimate.
    #[config(default = 0.9)]
    beta_1: f32,
    /// Exponential decay rate for the second-moment (squared gradient) estimate.
    #[config(default = 0.999)]
    beta_2: f32,
    /// A small value added to the denominator for numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// [Weight decay](WeightDecayConfig) config.
    weight_decay: Option<WeightDecayConfig>,
    /// [Gradient Clipping](GradientClippingConfig) config.
    grad_clipping: Option<GradientClippingConfig>,
}

/// Adam optimizer as described in the paper [Adam: A Method for Stochastic Optimization](https://arxiv.org/pdf/1412.6980.pdf).
#[derive(Clone)]
pub struct Adam {
    momentum: AdaptiveMomentum,
    weight_decay: Option<WeightDecay>,
}

/// Adam optimizer state for a single parameter tensor.
#[derive(Record, Clone, new)]
pub struct AdamState<B: Backend, const D: usize> {
    /// The current adaptive momentum.
    pub momentum: AdaptiveMomentumState<B, D>,
}

impl<B: Backend> SimpleOptimizer<B> for Adam {
    type State<const D: usize> = AdamState<B, D>;

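    // One Adam update: apply weight decay to the gradient (if configured), run the
    // adaptive-momentum transform, then scale the result by the learning rate and
    // subtract it from the parameter tensor.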
    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        mut grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        let mut state_momentum = None;

        if let Some(state) = state {
            state_momentum = Some(state.momentum);
        }

        if let Some(weight_decay) = &self.weight_decay {
            grad = weight_decay.transform(grad, tensor.clone());
        }

        let (grad, state_momentum) = self.momentum.transform(grad, state_momentum);

        let state = AdamState::new(state_momentum);
        let delta = grad.mul_scalar(lr);

        (tensor - delta, Some(state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.momentum = state.momentum.to_device(device);
        state
    }
}

impl AdamConfig {
    /// Initialize Adam optimizer.
    ///
    /// # Returns
    ///
    /// Returns an optimizer that can be used to optimize a module.
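    ///
    /// # Example
    ///
    /// A usage sketch; `model`, `x`, the module type `MyModule`, and the autodiff
    /// backend `B` are assumed to exist in the caller's code:
    ///
    /// ```rust,ignore
    /// let mut optim = AdamConfig::new().init::<B, MyModule<B>>();
    ///
    /// let grads = model.forward(x).backward();
    /// let grads = GradientsParams::from_grads(grads, &model);
    /// let model = optim.step(1.0e-2, model, grads);
    /// ```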
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<Adam, M, B> {
        let optim = Adam {
            momentum: AdaptiveMomentum {
                beta_1: self.beta_1,
                beta_2: self.beta_2,
                epsilon: self.epsilon,
            },
            weight_decay: self.weight_decay.as_ref().map(WeightDecay::new),
        };

        let mut optim = OptimizerAdaptor::from(optim);
        if let Some(config) = &self.grad_clipping {
            optim = optim.with_grad_clipping(config.init());
        }
        optim
    }
}

/// Adaptive momentum state.
#[derive(Record, new, Clone)]
pub struct AdaptiveMomentumState<B: Backend, const D: usize> {
    /// The number of iterations aggregated.
    pub time: usize,
    /// The first order momentum.
    pub moment_1: Tensor<B, D>,
    /// The second order momentum.
    pub moment_2: Tensor<B, D>,
}

#[derive(Clone)]
struct AdaptiveMomentum {
    beta_1: f32,
    beta_2: f32,
    epsilon: f32,
}

impl AdaptiveMomentum {
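    /// Applies the Adam moment updates and bias correction to a gradient:
    ///
    /// * `m_t = beta_1 * m_{t-1} + (1 - beta_1) * g`
    /// * `v_t = beta_2 * v_{t-1} + (1 - beta_2) * g^2`
    /// * `m_hat = m_t / (1 - beta_1^t)`, `v_hat = v_t / (1 - beta_2^t)`
    ///
    /// Returns `m_hat / (sqrt(v_hat) + epsilon)` together with the updated state.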
    pub fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        momentum_state: Option<AdaptiveMomentumState<B, D>>,
    ) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
        let state = if let Some(mut state) = momentum_state {
            let factor = 1.0 - self.beta_1;
            state.moment_1 = state
                .moment_1
                .mul_scalar(self.beta_1)
                .add(grad.clone().mul_scalar(factor));

            let factor = 1.0 - self.beta_2;
            state.moment_2 = state
                .moment_2
                .mul_scalar(self.beta_2)
                .add(grad.powi_scalar(2).mul_scalar(factor));

            state.time += 1;

            state
        } else {
            let factor = 1.0 - self.beta_1;
            let moment_1 = grad.clone().mul_scalar(factor);

            let factor = 1.0 - self.beta_2;
            let moment_2 = grad.powi_scalar(2).mul_scalar(factor);

            AdaptiveMomentumState::new(1, moment_1, moment_2)
        };

        let time = state.time as i32;
        let moment_1_corrected = state
            .moment_1
            .clone()
            .div_scalar(1f32 - self.beta_1.powi(time));
        let moment_2_corrected = state
            .moment_2
            .clone()
            .div_scalar(1f32 - self.beta_2.powi(time));

        let grad = moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));

        (grad, state)
    }
}

impl<B: Backend, const D: usize> AdaptiveMomentumState<B, D> {
    /// Move state to device.
    ///
    /// # Arguments
    ///
    /// * `device` - Device to move state to.
    ///
    /// # Returns
    ///
    /// Returns state moved to device.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.moment_1 = self.moment_1.to_device(device);
        self.moment_2 = self.moment_2.to_device(device);
        self
    }
}

#[cfg(test)]
mod tests {
    use burn::tensor::Tolerance;
    use burn::tensor::ops::FloatElem;

    use super::*;
    use crate::TestAutodiffBackend;
    use crate::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_adam_optimizer_save_load_state() {
        let device = Default::default();
        let linear = LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_adam();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_adam"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_adam();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    #[test]
    fn test_adam_optimizer_with_numbers() {
        let device = Default::default();
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdamConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [-0.340528, 0.118929, 0.384336, 0.300010, 0.066034, 0.047154],
            [0.057757, -0.036690, -0.386649, 0.235010, 0.175624, -0.312133],
            [-0.038940, 0.016306, -0.316151, 0.228410, -0.297819, 0.293047],
            [-0.317929, -0.239100, -0.391449, -0.318087, -0.095948, 0.142651],
            [0.310050, -0.235909, 0.351736, -0.192888, 0.359710, -0.050343],
            [-0.035840, -0.030203, 0.105840, 0.172110, 0.009440, 0.363346],
        ]);
        let bias_expected = TensorData::from([
            -0.410499, 0.068401, -0.116999, 0.097601, 0.116601, -0.006999,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        type FT = FloatElem<TestAutodiffBackend>;
        let tolerance = Tolerance::absolute(1e-2);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_adam_optimizer_no_nan() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );

        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &Default::default(),
        )
        .require_grad();

        let mut optimizer = AdamConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
            .init();

        let grads = linear.forward(x.clone()).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        assert!(!state_updated.weight.to_data().as_slice::<f32>().unwrap()[0].is_nan());
    }

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        LinearConfig::new(6, 6).init(&device).load_record(record)
    }

    fn create_adam() -> OptimizerAdaptor<Adam, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        let config = AdamConfig::new();
        Adam {
            momentum: AdaptiveMomentum {
                beta_1: config.beta_1,
                beta_2: config.beta_2,
                epsilon: config.epsilon,
            },
            weight_decay: config.weight_decay.as_ref().map(WeightDecay::new),
        }
        .into()
    }
}