burn_optim/optim/adan.rs

use burn_core as burn;

use burn::config::Config;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use burn::tensor::{backend::Backend, ops::Device};
use burn::{module::AutodiffModule, record::Record};

use super::{SimpleOptimizer, adaptor::OptimizerAdaptor};
use crate::{LearningRate, grad_clipping::GradientClippingConfig};

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;

/// [`Adan`] configuration.
///
/// See:
/// - [Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models](https://arxiv.org/abs/2208.06677).
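///
/// A minimal configuration sketch; the values shown are illustrative, not
/// tuned recommendations:
///
/// ```rust,ignore
/// let config = AdanConfig::new()
///     .with_beta_1(0.98)
///     .with_beta_2(0.92)
///     .with_beta_3(0.99)
///     .with_weight_decay(0.02);
/// ```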
#[derive(Config, Debug)]
pub struct AdanConfig {
    /// Parameter for the first moment.
    #[config(default = 0.98)]
    beta_1: f32,
    /// Parameter for the gradient-difference momentum.
    #[config(default = 0.92)]
    beta_2: f32,
    /// Parameter for the second moment.
    #[config(default = 0.99)]
    beta_3: f32,
    /// A value required for numerical stability.
    #[config(default = 1e-8)]
    epsilon: f32,
    /// Weight decay factor.
    #[config(default = 0.0)]
    weight_decay: f32,
    /// Disable proximal weight decay and use the decoupled update instead.
    #[config(default = false)]
    no_prox: bool,
    /// [Gradient Clipping](GradientClippingConfig) config.
    grad_clipping: Option<GradientClippingConfig>,
}

/// Adan optimizer.
///
/// See:
/// - [Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models](https://arxiv.org/abs/2208.06677).
///
/// Configured by [`AdanConfig`].
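///
/// A minimal sketch of one parameter update through the [`SimpleOptimizer`]
/// implementation below; `tensor` and `grad` are placeholder tensors of
/// matching shape:
///
/// ```rust,ignore
/// let adan = AdanConfig::new().build();
/// let (updated, state) = adan.step(1e-3, tensor, grad, None);
/// ```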
#[derive(Clone)]
pub struct Adan {
    momentum: AdaptiveNesterovMomentum,
    weight_decay: f32,
    no_prox: bool,
}

/// Adan state.
#[derive(Record, Clone, new)]
pub struct AdanState<B: Backend, const D: usize> {
    /// The current adaptive Nesterov momentum state.
    pub momentum: AdaptiveNesterovMomentumState<B, D>,
}

impl<B: Backend> SimpleOptimizer<B> for Adan {
    type State<const D: usize> = AdanState<B, D>;

    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        let (raw_delta, momentum_state) = self.momentum.transform(grad, state.map(|s| s.momentum));

        let decay_rate = lr * (self.weight_decay as f64);
        let delta = raw_delta.mul_scalar(lr);

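        // Two weight-decay variants (λ = weight_decay, η = lr, u = raw_delta):
        // - `no_prox` (decoupled, AdamW-style): θ ← θ·(1 − η·λ) − η·u
        // - proximal (default):                 θ ← (θ − η·u) / (1 + η·λ)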
        let tensor_updated = if self.no_prox {
            if decay_rate == 0.0 {
                tensor - delta
            } else {
                tensor.mul_scalar(1.0 - decay_rate) - delta
            }
        } else {
            let updated = tensor - delta;
            if decay_rate == 0.0 {
                updated
            } else {
                updated.div_scalar(1.0 + decay_rate)
            }
        };

        (tensor_updated, Some(AdanState::new(momentum_state)))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.momentum = state.momentum.to_device(device);
        state
    }
}

impl AdanConfig {
    /// Build an [`Adan`] from the config.
    pub fn build(&self) -> Adan {
        Adan {
            momentum: AdaptiveNesterovMomentum {
                beta_1: self.beta_1,
                beta_2: self.beta_2,
                beta_3: self.beta_3,
                epsilon: self.epsilon,
            },
            weight_decay: self.weight_decay,
            no_prox: self.no_prox,
        }
    }

    /// Initialize the Adan optimizer.
    ///
    /// # Returns
    ///
    /// Returns an optimizer that can be used to optimize a module.
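    ///
    /// A minimal usage sketch; `model` and `x` are placeholders for a concrete
    /// [`AutodiffModule`] and its input:
    ///
    /// ```rust,ignore
    /// let mut optim = AdanConfig::new().init();
    /// let grads = GradientsParams::from_grads(model.forward(x).backward(), &model);
    /// let model = optim.step(1e-2, model, grads);
    /// ```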
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<Adan, M, B> {
        let mut optim = OptimizerAdaptor::from(self.build());
        if let Some(config) = &self.grad_clipping {
            optim = optim.with_grad_clipping(config.init());
        }
        optim
    }
}

/// Adaptive Nesterov momentum state.
#[derive(Record, Clone, new)]
pub struct AdaptiveNesterovMomentumState<B: Backend, const D: usize> {
    /// The number of iterations aggregated.
    pub time: usize,
    /// The first order momentum.
    pub exp_avg: Tensor<B, D>,
    /// The gradient-difference weighted second order momentum.
    pub exp_avg_sq: Tensor<B, D>,
    /// The gradient-difference momentum.
    pub exp_avg_diff: Tensor<B, D>,
    /// The negated previous gradient.
    pub neg_pre_grad: Tensor<B, D>,
}

#[derive(Clone)]
struct AdaptiveNesterovMomentum {
    beta_1: f32,
    beta_2: f32,
    beta_3: f32,
    epsilon: f32,
}

impl AdaptiveNesterovMomentum {
    pub fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        state: Option<AdaptiveNesterovMomentumState<B, D>>,
    ) -> (Tensor<B, D>, AdaptiveNesterovMomentumState<B, D>) {
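        // Adan tracks three exponential moving averages (β close to 1 means a
        // long memory, following this implementation's decay convention):
        //   m_k = β1·m_{k-1} + (1 − β1)·g_k                          (exp_avg)
        //   d_k = β2·d_{k-1} + (1 − β2)·(g_k − g_{k-1})              (exp_avg_diff)
        //   n_k = β3·n_{k-1} + (1 − β3)·(g_k + β2·(g_k − g_{k-1}))²  (exp_avg_sq)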
        let state = if let Some(mut state) = state {
            let grad_diff = state.neg_pre_grad.clone().add(grad.clone());
            let grad_diff_sq = grad_diff
                .clone()
                .mul_scalar(self.beta_2)
                .add(grad.clone())
                .square();

            state.exp_avg = state
                .exp_avg
                .mul_scalar(self.beta_1)
                .add(grad.clone().mul_scalar(1.0 - self.beta_1));
            state.exp_avg_diff = state
                .exp_avg_diff
                .mul_scalar(self.beta_2)
                .add(grad_diff.mul_scalar(1.0 - self.beta_2));
            state.exp_avg_sq = state
                .exp_avg_sq
                .mul_scalar(self.beta_3)
                .add(grad_diff_sq.mul_scalar(1.0 - self.beta_3));
            state.neg_pre_grad = grad.mul_scalar(-1.0);
            state.time += 1;
            state
        } else {
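            // First step: seed the averages as if the updates above were applied
            // to zero-initialized state with a zero gradient difference.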
            AdaptiveNesterovMomentumState::new(
                1,
                grad.clone().mul_scalar(1.0 - self.beta_1),
                grad.clone().square().mul_scalar(1.0 - self.beta_3),
                grad.zeros_like(),
                grad.clone().mul_scalar(-1.0),
            )
        };

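        // Bias-corrected combined update:
        //   u_k = m̂_k/(√n̂_k + ε) + β2·d̂_k/(√n̂_k + ε),
        // where each hat denotes division by the (1 − βᵢ^k) bias-correction term.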
        let time = state.time as i32;
        let denom = state
            .exp_avg_sq
            .clone()
            .sqrt()
            .div_scalar((1.0 - self.beta_3.powi(time)).sqrt())
            .add_scalar(self.epsilon);
        let update = state
            .exp_avg
            .clone()
            .div_scalar(1.0 - self.beta_1.powi(time))
            .div(denom.clone())
            .add(
                state
                    .exp_avg_diff
                    .clone()
                    .mul_scalar(self.beta_2)
                    .div_scalar(1.0 - self.beta_2.powi(time))
                    .div(denom),
            );

        (update, state)
    }
}

impl<B: Backend, const D: usize> AdaptiveNesterovMomentumState<B, D> {
    #[allow(clippy::wrong_self_convention)]
    fn to_device(mut self, device: &B::Device) -> Self {
        self.exp_avg = self.exp_avg.to_device(device);
        self.exp_avg_sq = self.exp_avg_sq.to_device(device);
        self.exp_avg_diff = self.exp_avg_diff.to_device(device);
        self.neg_pre_grad = self.neg_pre_grad.to_device(device);
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use crate::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn::tensor::{Tolerance, ops::FloatElem};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    type FT = FloatElem<TestAutodiffBackend>;

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_adan_optimizer_save_load_state() {
        let device = Default::default();
        let linear = LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_adan();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_adan"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_adan();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    #[test]
    fn test_adan_optimizer_with_numbers() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdanConfig::new()
            .with_beta_1(0.98)
            .with_beta_2(0.92)
            .with_beta_3(0.99)
            .with_epsilon(1e-8)
            .with_weight_decay(0.02)
            .init();

        // Two optimizer steps on distinct inputs.
        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        // Expected parameters after the two steps above.
        let weights_expected = TensorData::from([
            [
                -0.34034607,
                0.11747075,
                0.38426402,
                0.29999772,
                0.06599136,
                0.04719888,
            ],
            [
                0.0644293,
                -0.031732224,
                -0.37979296,
                0.24165839,
                0.18218218,
                -0.30532277,
            ],
            [
                -0.038910445,
                0.01466812,
                -0.31599957,
                0.2283826,
                -0.29780683,
                0.2929568,
            ],
            [
                -0.3178632,
                -0.24129382,
                -0.39133376,
                -0.31796312,
                -0.09605193,
                0.14255258,
            ],
            [
                0.31026322,
                -0.23771758,
                0.3519465,
                -0.19243571,
                0.35984334,
                -0.049992695,
            ],
            [
                -0.03577819,
                -0.031879753,
                0.10586514,
                0.17213862,
                0.009403733,
                0.36326218,
            ],
        ]);
        let bias_expected = TensorData::from([
            -0.4103378,
            0.06837065,
            -0.116955206,
            0.097558975,
            0.11655137,
            -0.006999196,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        let tolerance = Tolerance::absolute(1e-5);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_adan_optimizer_no_nan() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );

        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &Default::default(),
        )
        .require_grad();

        let mut optimizer = AdanConfig::new()
            .with_epsilon(1e-8)
            .with_weight_decay(0.02)
            .init();

        let grads = linear.forward(x.clone()).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        assert!(!state_updated.weight.to_data().as_slice::<f32>().unwrap()[0].is_nan());
    }

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        LinearConfig::new(6, 6).init(&device).load_record(record)
    }

    fn create_adan() -> OptimizerAdaptor<Adan, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        let config = AdanConfig::new();
        Adan {
            momentum: AdaptiveNesterovMomentum {
                beta_1: config.beta_1,
                beta_2: config.beta_2,
                beta_3: config.beta_3,
                epsilon: config.epsilon,
            },
            weight_decay: config.weight_decay,
            no_prox: config.no_prox,
        }
        .into()
    }
}