use burn_core as burn;

use burn::{module::AutodiffModule, record::Record};

use burn::config::Config;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use burn::tensor::{backend::Backend, ops::Device};

use super::{
    SimpleOptimizer,
    adaptor::OptimizerAdaptor,
    decay::{WeightDecay, WeightDecayConfig},
};
use crate::{LearningRate, grad_clipping::GradientClippingConfig};

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;

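/// Configuration to create the [Adam](Adam) optimizer.
///
/// The snippet below is an illustrative sketch only: `MyAutodiffBackend` and the
/// `Linear` module are stand-ins for whatever backend and module you train with
/// (see the tests at the bottom of this file for concrete, runnable usage).
///
/// ```ignore
/// let mut optim = AdamConfig::new()
///     .with_beta_1(0.9)
///     .with_beta_2(0.999)
///     .with_epsilon(1e-8)
///     .init::<MyAutodiffBackend, Linear<MyAutodiffBackend>>();
/// ```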
#[derive(Config, Debug)]
pub struct AdamConfig {
    /// Exponential decay rate for the first moment (gradient) estimates.
    #[config(default = 0.9)]
    beta_1: f32,
    /// Exponential decay rate for the second moment (squared gradient) estimates.
    #[config(default = 0.999)]
    beta_2: f32,
    /// A small constant added to the denominator for numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// Whether to use the AMSGrad variant, which keeps a running maximum of the
    /// second moment estimates.
    #[config(default = false)]
    amsgrad: bool,
    /// Optional [weight decay](WeightDecayConfig) applied to the gradients.
    weight_decay: Option<WeightDecayConfig>,
    /// Optional [gradient clipping](GradientClippingConfig).
    grad_clipping: Option<GradientClippingConfig>,
}

/// Adam optimizer as described in the paper
/// [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980).
#[derive(Clone)]
pub struct Adam {
    momentum: AdaptiveMomentum,
    weight_decay: Option<WeightDecay>,
}

/// Adam optimizer state.
#[derive(Record, Clone, new)]
pub struct AdamState<B: Backend, const D: usize> {
    /// The current adaptive momentum state.
    pub momentum: AdaptiveMomentumState<B, D>,
}

impl<B: Backend> SimpleOptimizer<B> for Adam {
    type State<const D: usize> = AdamState<B, D>;

    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        mut grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        let mut state_momentum = None;

        if let Some(state) = state {
            state_momentum = Some(state.momentum);
        }

        // Weight decay (when configured) is folded into the gradient before the
        // adaptive momentum transform.
        if let Some(weight_decay) = &self.weight_decay {
            grad = weight_decay.transform(grad, tensor.clone());
        }

        let (grad, state_momentum) = self.momentum.transform(grad, state_momentum);

        let state = AdamState::new(state_momentum);
        let delta = grad.mul_scalar(lr);

        (tensor - delta, Some(state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.momentum = state.momentum.to_device(device);
        state
    }
}

impl AdamConfig {
    /// Build the core [Adam] optimizer, which implements [SimpleOptimizer].
    pub fn build(&self) -> Adam {
        Adam {
            momentum: AdaptiveMomentum {
                beta_1: self.beta_1,
                beta_2: self.beta_2,
                epsilon: self.epsilon,
                amsgrad: self.amsgrad,
            },
            weight_decay: self.weight_decay.as_ref().map(WeightDecay::new),
        }
    }

    /// Initialize the Adam optimizer, wrapped in an [OptimizerAdaptor] so it can
    /// be used with any [AutodiffModule].
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<Adam, M, B> {
        let mut optim = OptimizerAdaptor::from(self.build());
        if let Some(config) = &self.grad_clipping {
            optim = optim.with_grad_clipping(config.init());
        }
        optim
    }
}

/// Adaptive momentum state.
#[derive(Record, new, Clone)]
pub struct AdaptiveMomentumState<B: Backend, const D: usize> {
    /// The number of optimizer steps taken so far.
    pub time: usize,
    /// First moment: exponential moving average of the gradients.
    pub moment_1: Tensor<B, D>,
    /// Second moment: exponential moving average of the squared gradients.
    pub moment_2: Tensor<B, D>,
    /// Running maximum of the second moment, tracked only when AMSGrad is enabled.
    #[new(default)]
    pub max_moment_2: Option<Tensor<B, D>>,
}

#[derive(Clone)]
struct AdaptiveMomentum {
    beta_1: f32,
    beta_2: f32,
    epsilon: f32,
    amsgrad: bool,
}

impl AdaptiveMomentum {
    pub fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        momentum_state: Option<AdaptiveMomentumState<B, D>>,
    ) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
        let state = if let Some(mut state) = momentum_state {
            // Update the exponential moving averages of the gradient (first moment)
            // and of the squared gradient (second moment).
            let factor = 1.0 - self.beta_1;
            state.moment_1 = state
                .moment_1
                .mul_scalar(self.beta_1)
                .add(grad.clone().mul_scalar(factor));

            let factor = 1.0 - self.beta_2;
            state.moment_2 = state
                .moment_2
                .mul_scalar(self.beta_2)
                .add(grad.square().mul_scalar(factor));

            // AMSGrad keeps the element-wise maximum of all second moments seen so far.
            if self.amsgrad {
                let max_v = state
                    .max_moment_2
                    .take()
                    .unwrap_or_else(|| state.moment_2.clone());

                let new_max = max_v.max_pair(state.moment_2.clone());
                state.max_moment_2 = Some(new_max);
            }

            state.time += 1;

            state
        } else {
            // First step: initialize the moments directly from the gradient.
            let factor = 1.0 - self.beta_1;
            let moment_1 = grad.clone().mul_scalar(factor);

            let factor = 1.0 - self.beta_2;
            let moment_2 = grad.square().mul_scalar(factor);
            let max_moment_2 = self.amsgrad.then(|| moment_2.clone());
            AdaptiveMomentumState {
                time: 1,
                moment_1,
                moment_2,
                max_moment_2,
            }
        };

        // Bias correction, written in the numerically equivalent form
        //   m_hat / (sqrt(v_hat) + eps)
        //     = m * sqrt(1 - beta_2^t) / (1 - beta_1^t) / (sqrt(v) + eps * sqrt(1 - beta_2^t)),
        // which avoids materializing the bias-corrected tensors m_hat and v_hat.
        let time = state.time as i32;
        let bias_correction2_sqrt = (1.0 - self.beta_2.powi(time)).sqrt();
        let combined_factor = bias_correction2_sqrt / (1.0 - self.beta_1.powi(time));

        // With AMSGrad, the running maximum replaces the raw second moment in the denominator.
        let v_to_use = if self.amsgrad {
            state.max_moment_2.as_ref().unwrap_or(&state.moment_2)
        } else {
            &state.moment_2
        };

        let grad = state.moment_1.clone().mul_scalar(combined_factor).div(
            v_to_use
                .clone()
                .sqrt()
                .add_scalar(self.epsilon * bias_correction2_sqrt),
        );
        (grad, state)
    }
}

impl<B: Backend, const D: usize> AdaptiveMomentumState<B, D> {
    /// Move the state tensors to the given device.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.moment_1 = self.moment_1.to_device(device);
        self.moment_2 = self.moment_2.to_device(device);
        self.max_moment_2 = self.max_moment_2.map(|tensor| tensor.to_device(device));
        self
    }
}

#[cfg(test)]
mod tests {
    use burn::tensor::Tolerance;
    use burn::tensor::ops::FloatElem;

    use super::*;
    use crate::TestAutodiffBackend;
    use crate::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_adam_optimizer_save_load_state() {
        let device = Default::default();
        let linear = LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_adam();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_adam"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_adam();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }
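
    /// Additional sanity check, a minimal sketch not in the original suite:
    /// a freshly built optimizer should have no per-parameter state recorded
    /// before any step has been taken.
    #[test]
    fn test_adam_optimizer_fresh_record_is_empty() {
        let optimizer = create_adam();
        assert_eq!(optimizer.to_record().len(), 0);
    }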

    #[test]
    fn test_adam_optimizer_with_amsgrad_50_steps() {
        let device = Default::default();
        let mut linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );

        let mut optimizer = AdamConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_amsgrad(true)
            .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
            .init();

        for i in 1..=50 {
            let x = Tensor::<TestAutodiffBackend, 2>::ones([2, 6], &device)
                .mul_scalar(i as f32 * 0.1)
                .require_grad();

            let grads = linear.forward(x).backward();
            let grads = GradientsParams::from_grads(grads, &linear);
            linear = optimizer.step(LEARNING_RATE, linear, grads);
        }

        let state_updated = linear.into_record();
        let weight_updated = state_updated.weight.to_data();
        let bias_updated = state_updated.bias.unwrap().to_data();

        let weights_expected = TensorData::from([
            [
                -0.9125810265541077,
                -0.45855265855789185,
                -0.1915993094444275,
                -0.2759990692138672,
                -0.5099529027938843,
                -0.5287043452262878,
            ],
            [
                -0.5181325674057007,
                -0.6139854788780212,
                -0.9574727416038513,
                -0.34102925658226013,
                -0.400514155626297,
                -0.8847861886024475,
            ],
            [
                -0.614483118057251,
                -0.5611032247543335,
                -0.8887064456939697,
                -0.34762972593307495,
                -0.8708556890487671,
                -0.2830044627189636,
            ],
            [
                -0.8904699683189392,
                -0.8151527643203735,
                -0.9621278643608093,
                -0.8905676603317261,
                -0.671261191368103,
                -0.4333854615688324,
            ],
            [
                -0.26599061489105225,
                -0.8119961023330688,
                -0.22424538433551788,
                -0.7672406435012817,
                -0.2163349837064743,
                -0.6258266568183899,
            ],
            [
                -0.611397922039032,
                -0.6075160503387451,
                -0.4701341986656189,
                -0.4039117991924286,
                -0.5663845539093018,
                -0.21262989938259125,
            ],
        ]);
        let bias_expected = TensorData::from([
            -0.8817203044891357,
            -0.4038999378681183,
            -0.5889149308204651,
            -0.37475723028182983,
            -0.3557940721511841,
            -0.47914788126945496,
        ]);

        type FT = FloatElem<TestAutodiffBackend>;
        let tolerance = Tolerance::absolute(1e-5);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
    }
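
    /// A minimal additional sketch, not part of the original suite: exercise the
    /// private `AdaptiveMomentum::transform` directly to check that AMSGrad
    /// tracks a running maximum of the second moment across steps.
    #[test]
    fn test_amsgrad_keeps_max_second_moment() {
        let device = Default::default();
        let momentum = AdaptiveMomentum {
            beta_1: 0.9,
            beta_2: 0.999,
            epsilon: 1e-8,
            amsgrad: true,
        };

        let grad = Tensor::<TestAutodiffBackend, 1>::from_floats([1.0, -2.0, 3.0], &device);
        let (_, state) = momentum.transform(grad, None);
        assert_eq!(state.time, 1);
        assert!(state.max_moment_2.is_some());

        // A second step with a smaller gradient should still keep the maximum around.
        let small_grad = Tensor::<TestAutodiffBackend, 1>::from_floats([0.1, -0.2, 0.3], &device);
        let (_, state) = momentum.transform(small_grad, Some(state));
        assert_eq!(state.time, 2);
        assert!(state.max_moment_2.is_some());
    }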

    #[test]
    fn test_adam_optimizer_with_numbers() {
        let device = Default::default();
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdamConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [-0.340528, 0.118929, 0.384336, 0.300010, 0.066034, 0.047154],
            [
                0.057757, -0.036690, -0.386649, 0.235010, 0.175624, -0.312133,
            ],
            [
                -0.038940, 0.016306, -0.316151, 0.228410, -0.297819, 0.293047,
            ],
            [
                -0.317929, -0.239100, -0.391449, -0.318087, -0.095948, 0.142651,
            ],
            [
                0.310050, -0.235909, 0.351736, -0.192888, 0.359710, -0.050343,
            ],
            [-0.035840, -0.030203, 0.105840, 0.172110, 0.009440, 0.363346],
        ]);
        let bias_expected = TensorData::from([
            -0.410499, 0.068401, -0.116999, 0.097601, 0.116601, -0.006999,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        type FT = FloatElem<TestAutodiffBackend>;
        let tolerance = Tolerance::absolute(1e-2);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_adam_optimizer_no_nan() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );

        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &Default::default(),
        )
        .require_grad();

        let mut optimizer = AdamConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
            .init();

        let grads = linear.forward(x.clone()).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        assert!(!state_updated.weight.to_data().as_slice::<f32>().unwrap()[0].is_nan());
    }
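
    /// A hedged sketch (not in the original suite) of a well-known Adam property:
    /// on the very first step the bias-corrected update is approximately
    /// `sign(grad)`, so each parameter moves by roughly the learning rate.
    #[test]
    fn test_adam_first_step_update_is_close_to_sign_of_gradient() {
        let device = Default::default();
        let momentum = AdaptiveMomentum {
            beta_1: 0.9,
            beta_2: 0.999,
            epsilon: 1e-8,
            amsgrad: false,
        };

        let grad = Tensor::<TestAutodiffBackend, 1>::from_floats([0.5, -2.0, 10.0], &device);
        let (update, _) = momentum.transform(grad, None);

        let expected = TensorData::from([1.0, -1.0, 1.0]);
        type FT = FloatElem<TestAutodiffBackend>;
        update
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::absolute(1e-3));
    }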

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        LinearConfig::new(6, 6).init(&device).load_record(record)
    }

    fn create_adam() -> OptimizerAdaptor<Adam, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        let config = AdamConfig::new();
        Adam {
            momentum: AdaptiveMomentum {
                beta_1: config.beta_1,
                beta_2: config.beta_2,
                epsilon: config.epsilon,
                amsgrad: config.amsgrad,
            },
            weight_decay: config.weight_decay.as_ref().map(WeightDecay::new),
        }
        .into()
    }
}