use burn_core as burn;

use burn::config::Config;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use burn::tensor::{backend::Backend, ops::Device};
use burn::{module::AutodiffModule, record::Record};

use super::{AdaptiveMomentumState, SimpleOptimizer, adaptor::OptimizerAdaptor};
use crate::{LearningRate, grad_clipping::GradientClippingConfig};

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;

/// Configuration for the [AdamW] optimizer.
#[derive(Config, Debug)]
pub struct AdamWConfig {
    /// Exponential decay rate for the first moment estimate.
    #[config(default = 0.9)]
    beta_1: f32,
    /// Exponential decay rate for the second moment estimate.
    #[config(default = 0.999)]
    beta_2: f32,
    /// A small constant added to the denominator for numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// Weight decay factor, applied to the parameters independently of the gradient update.
    #[config(default = 1e-4)]
    weight_decay: f32,

    /// When enabled, weight decay is only applied to entries whose sign agrees with the
    /// sign of the first moment estimate ("cautious" weight decay).
    #[config(default = false)]
    cautious_weight_decay: bool,

    /// [Gradient clipping](GradientClippingConfig) configuration.
    grad_clipping: Option<GradientClippingConfig>,
}

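// Illustrative sketch of building a config: `AdamWConfig::new()` and the
// `with_*` setters are generated by the `Config` derive for the fields above.
// The `GradientClippingConfig::Norm` variant used here is an assumption about
// the gradient-clipping module's API; clipping is optional.
//
//     let config = AdamWConfig::new()
//         .with_beta_1(0.9)
//         .with_beta_2(0.999)
//         .with_weight_decay(1e-2)
//         .with_grad_clipping(Some(GradientClippingConfig::Norm(1.0)));
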
/// AdamW optimizer (Adam with decoupled weight decay).
#[derive(Clone)]
pub struct AdamW {
    momentum: AdaptiveMomentumW,
    weight_decay: f32,
    cautious_weight_decay: bool,
}

/// AdamW optimizer state.
#[derive(Record, Clone, new)]
pub struct AdamWState<B: Backend, const D: usize> {
    /// The current adaptive momentum (first and second moment) state.
    pub momentum: AdaptiveMomentumState<B, D>,
}

impl<B: Backend> SimpleOptimizer<B> for AdamW {
    type State<const D: usize> = AdamWState<B, D>;

    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        // Bias-corrected adaptive momentum update computed from the gradient.
        let (raw_delta, momentum_state) = self.momentum.transform(grad, state.map(|s| s.momentum));

        // Decoupled weight decay: the parameters are decayed directly, scaled by the learning rate.
        let decay_rate = lr * (self.weight_decay as f64);

        let decayed_tensor = if decay_rate == 0.0 {
            tensor.clone()
        } else if self.cautious_weight_decay {
            // Only decay entries whose sign agrees with the sign of the first moment estimate.
            let tensor_pos = tensor.clone().greater_equal_elem(0.0);
            let grad_pos = momentum_state.moment_1.clone().greater_equal_elem(0.0);
            let differ = tensor_pos.not_equal(grad_pos);

            tensor.clone() - tensor.mul_scalar(decay_rate).mask_fill(differ, 0.0)
        } else {
            tensor.clone().mul_scalar(1.0 - decay_rate)
        };

        // Apply the adaptive momentum update on top of the decayed parameters.
        let tensor_updated = decayed_tensor - raw_delta.mul_scalar(lr);

        let state = AdamWState {
            momentum: momentum_state,
        };

        (tensor_updated, Some(state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.momentum = state.momentum.to_device(device);
        state
    }
}

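// For reference, `step` above implements the decoupled weight decay update of
// Loshchilov & Hutter ("Decoupled Weight Decay Regularization"), with the
// moments maintained by `AdaptiveMomentumW::transform`:
//
//     m_t = beta_1 * m_{t-1} + (1 - beta_1) * g_t
//     v_t = beta_2 * v_{t-1} + (1 - beta_2) * g_t^2
//     m_hat = m_t / (1 - beta_1^t),  v_hat = v_t / (1 - beta_2^t)
//     theta_t = theta_{t-1} * (1 - lr * weight_decay) - lr * m_hat / (sqrt(v_hat) + epsilon)
//
// With `cautious_weight_decay` enabled, the decay term is skipped for entries
// of `theta` whose sign differs from the sign of `m_t`.
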
impl AdamWConfig {
    /// Initialize the AdamW optimizer for a module of type `M` on backend `B`.
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<AdamW, M, B> {
        let optim = AdamW {
            momentum: AdaptiveMomentumW {
                beta_1: self.beta_1,
                beta_2: self.beta_2,
                epsilon: self.epsilon,
            },
            weight_decay: self.weight_decay,
            cautious_weight_decay: self.cautious_weight_decay,
        };

        let mut optim = OptimizerAdaptor::from(optim);
        if let Some(config) = &self.grad_clipping {
            optim = optim.with_grad_clipping(config.init());
        }
        optim
    }
}

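// Usage sketch of `init`, mirroring the tests below (a `Linear` module on an
// autodiff backend stands in for any `AutodiffModule`):
//
//     let mut optim = AdamWConfig::new().init();
//     let grads = model.forward(input).backward();
//     let grads = GradientsParams::from_grads(grads, &model);
//     let model = optim.step(learning_rate, model, grads);
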
/// Computes the bias-corrected adaptive momentum update used by [AdamW].
#[derive(Clone)]
struct AdaptiveMomentumW {
    beta_1: f32,
    beta_2: f32,
    epsilon: f32,
}

impl AdaptiveMomentumW {
    pub fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        state: Option<AdaptiveMomentumState<B, D>>,
    ) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
        let factor_1 = 1.0 - self.beta_1;
        let factor_2 = 1.0 - self.beta_2;

        let state = if let Some(mut state) = state {
            // Update the exponential moving averages of the gradient and the squared gradient.
            state.moment_1 = state
                .moment_1
                .mul_scalar(self.beta_1)
                .add(grad.clone().mul_scalar(factor_1));

            state.moment_2 = state
                .moment_2
                .mul_scalar(self.beta_2)
                .add(grad.powi_scalar(2).mul_scalar(factor_2));

            state.time += 1;

            state
        } else {
            // First step: initialize the moments from the current gradient.
            let moment_1 = grad.clone().mul_scalar(factor_1);
            let moment_2 = grad.powi_scalar(2).mul_scalar(factor_2);

            AdaptiveMomentumState::new(1, moment_1, moment_2)
        };

        let time: i32 = state.time as i32;

        // Bias correction for the moving averages.
        let moment_1_corrected = state
            .moment_1
            .clone()
            .div_scalar(1f32 - self.beta_1.powi(time));

        let moment_2_corrected = state
            .moment_2
            .clone()
            .div_scalar(1f32 - self.beta_2.powi(time));

        let update_delta =
            moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));

        (
            update_delta,
            AdaptiveMomentumState::new(state.time, state.moment_1, state.moment_2),
        )
    }
}

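// Worked example of the bias correction on the very first step (time = 1):
// `moment_1 = (1 - beta_1) * g`, and dividing by `1 - beta_1^1` recovers `g`
// exactly; likewise the corrected second moment is `g^2`. The first raw update
// is therefore `g / (|g| + epsilon)` (roughly +/-1 per element), which `step`
// then scales by the learning rate.
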
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use crate::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn::tensor::{Tolerance, ops::FloatElem};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    type FT = FloatElem<TestAutodiffBackend>;

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_adamw_optimizer_save_load_state() {
        let device = Default::default();
        let linear = LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_adamw();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_adamw"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_adamw();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    #[test]
    fn test_adamw_optimizer_with_numbers() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdamWConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(0.5)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534],
            [0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182],
            [-0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981],
            [-0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081],
            [0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993],
            [-0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.359580],
        ]);
        let bias_expected = TensorData::from([
            -0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        let tolerance = Tolerance::absolute(1e-2);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_adamw_optimizer_with_numbers_cautious() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, -0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdamWConfig::new()
            .with_cautious_weight_decay(true)
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(0.5)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534],
            [0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182],
            [-0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981],
            [-0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081],
            [0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993],
            [-0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.37061332],
        ]);
        let bias_expected = TensorData::from([
            -0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        let tolerance = Tolerance::absolute(1e-2);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_adamw_optimizer_no_nan() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );

        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &Default::default(),
        )
        .require_grad();

        let mut optimizer = AdamWConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(0.5)
            .init();

        let grads = linear.forward(x.clone()).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        assert!(!state_updated.weight.to_data().as_slice::<f32>().unwrap()[0].is_nan());
    }

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        LinearConfig::new(6, 6).init(&device).load_record(record)
    }

    fn create_adamw() -> OptimizerAdaptor<AdamW, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        let config = AdamWConfig::new();
        AdamW {
            momentum: AdaptiveMomentumW {
                beta_1: config.beta_1,
                beta_2: config.beta_2,
                epsilon: config.epsilon,
            },
            weight_decay: config.weight_decay,
            cautious_weight_decay: false,
        }
        .into()
    }
}