use super::{AdaptiveMomentumState, SimpleOptimizer};
use crate::config::Config;
use crate::optim::adaptor::OptimizerAdaptor;
use crate::tensor::{Tensor, backend::AutodiffBackend};
use crate::{
    self as burn, LearningRate, grad_clipping::GradientClippingConfig, module::AutodiffModule,
    record::Record,
};
use burn_tensor::{backend::Backend, ops::Device};

#[cfg(not(feature = "std"))]
use num_traits::Float;

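/// AdamW optimizer configuration.
///
/// A minimal usage sketch (the `with_*` setters are generated by `#[derive(Config)]`;
/// `model` and `grads` are assumed to exist, as in the tests below):
///
/// ```ignore
/// let mut optim = AdamWConfig::new()
///     .with_weight_decay(1e-2)
///     .init();
/// let model = optim.step(1e-3, model, grads);
/// ```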
#[derive(Config)]
pub struct AdamWConfig {
    /// Coefficient used to compute the running average of the gradient (first moment).
    #[config(default = 0.9)]
    beta_1: f32,
    /// Coefficient used to compute the running average of the squared gradient (second moment).
    #[config(default = 0.999)]
    beta_2: f32,
    /// Small constant added to the denominator for numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// Weight decay factor, applied to the parameters directly (decoupled from the gradient update).
    #[config(default = 1e-4)]
    weight_decay: f32,
    /// Optional gradient clipping configuration.
    grad_clipping: Option<GradientClippingConfig>,
}

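/// AdamW optimizer: Adam with decoupled weight decay
/// (Loshchilov & Hutter, "Decoupled Weight Decay Regularization").
///
/// Each step first decays the parameters directly,
/// `theta <- theta - lr * weight_decay * theta`,
/// then applies the bias-corrected Adam update
/// `theta <- theta - lr * m_hat / (sqrt(v_hat) + epsilon)`.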
#[derive(Clone)]
pub struct AdamW {
    momentum: AdaptiveMomentumW,
    weight_decay: f32,
}

/// AdamW optimizer state.
#[derive(Record, Clone, new)]
pub struct AdamWState<B: Backend, const D: usize> {
    /// The current adaptive momentum (first and second moment) state.
    pub momentum: AdaptiveMomentumState<B, D>,
}

impl<B: Backend> SimpleOptimizer<B> for AdamW {
    type State<const D: usize> = AdamWState<B, D>;

    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        // Decoupled weight decay: theta <- theta - lr * weight_decay * theta,
        // applied to the parameters directly rather than folded into the gradient.
        let tensor_updated = tensor.clone() - tensor.mul_scalar(lr).mul_scalar(self.weight_decay);

        let (raw_delta, momentum_state) = self.momentum.transform(grad, state.map(|s| s.momentum));

        let state = AdamWState {
            momentum: momentum_state,
        };

        // Adam update on the decayed parameters: theta <- theta - lr * m_hat / (sqrt(v_hat) + epsilon).
        (tensor_updated - raw_delta.mul_scalar(lr), Some(state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.momentum = state.momentum.to_device(device);
        state
    }
}

impl AdamWConfig {
    /// Initialize the AdamW optimizer, wrapping it in an [`OptimizerAdaptor`] and
    /// attaching gradient clipping when it is configured.
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<AdamW, M, B> {
        let optim = AdamW {
            momentum: AdaptiveMomentumW {
                beta_1: self.beta_1,
                beta_2: self.beta_2,
                epsilon: self.epsilon,
            },
            weight_decay: self.weight_decay,
        };

        let mut optim = OptimizerAdaptor::from(optim);
        if let Some(config) = &self.grad_clipping {
            optim = optim.with_grad_clipping(config.init());
        }
        optim
    }
}

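/// Adam moment statistics used by the AdamW update.
///
/// For a gradient `g` at step `t`, `transform` maintains the running moments
/// `m_t = beta_1 * m_(t-1) + (1 - beta_1) * g` and
/// `v_t = beta_2 * v_(t-1) + (1 - beta_2) * g^2`,
/// and returns the bias-corrected step direction
/// `(m_t / (1 - beta_1^t)) / (sqrt(v_t / (1 - beta_2^t)) + epsilon)`.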
#[derive(Clone)]
struct AdaptiveMomentumW {
    beta_1: f32,
    beta_2: f32,
    epsilon: f32,
}

impl AdaptiveMomentumW {
    pub fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        state: Option<AdaptiveMomentumState<B, D>>,
    ) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
        let state = if let Some(mut state) = state {
            // Update the biased first moment estimate: m_t = beta_1 * m_(t-1) + (1 - beta_1) * g_t.
            let factor = 1.0 - self.beta_1;
            state.moment_1 = state
                .moment_1
                .mul_scalar(self.beta_1)
                .add(grad.clone().mul_scalar(factor));

            // Update the biased second raw moment estimate: v_t = beta_2 * v_(t-1) + (1 - beta_2) * g_t^2.
            let factor = 1.0 - self.beta_2;
            state.moment_2 = state
                .moment_2
                .mul_scalar(self.beta_2)
                .add(grad.powi_scalar(2).mul_scalar(factor));

            state.time += 1;

            state
        } else {
            // First step: initialize both moments from the current gradient.
            let factor = 1.0 - self.beta_1;
            let moment_1 = grad.clone().mul_scalar(factor);

            let factor = 1.0 - self.beta_2;
            let moment_2 = grad.powi_scalar(2).mul_scalar(factor);

            AdaptiveMomentumState::new(1, moment_1, moment_2)
        };

        let time: i32 = state.time as i32;

        // Bias correction: m_hat = m_t / (1 - beta_1^t), v_hat = v_t / (1 - beta_2^t).
        let moment_1_corrected = state
            .moment_1
            .clone()
            .div_scalar(1f32 - self.beta_1.powi(time));

        let moment_2_corrected = state
            .moment_2
            .clone()
            .div_scalar(1f32 - self.beta_2.powi(time));

        // Raw step direction m_hat / (sqrt(v_hat) + epsilon); the learning rate is applied by the caller.
        let update_delta =
            moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));

        (
            update_delta,
            AdaptiveMomentumState::new(state.time, state.moment_1, state.moment_2),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::module::{Module, Param};
    use crate::optim::{GradientsParams, Optimizer};
    use crate::tensor::{Distribution, Tensor, TensorData};
    use crate::{TestAutodiffBackend, nn};
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestAutodiffBackend>;

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_adamw_optimizer_save_load_state() {
        let device = Default::default();
        let linear = nn::LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_adamw();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_adamw"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use crate::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_adamw();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    #[test]
    fn test_adamw_optimizer_with_numbers() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdamWConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(0.5)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534],
            [0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182],
            [-0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981],
            [-0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081],
            [0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993],
            [-0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.359580],
        ]);
        let bias_expected = TensorData::from([
            -0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        let tolerance = Tolerance::absolute(1e-2);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_adamw_optimizer_no_nan() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );

        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &Default::default(),
        )
        .require_grad();

        let mut optimizer = AdamWConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(0.5)
            .init();

        let grads = linear.forward(x.clone()).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        assert!(!state_updated.weight.to_data().as_slice::<f32>().unwrap()[0].is_nan());
    }

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> nn::Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = nn::LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        nn::LinearConfig::new(6, 6)
            .init(&device)
            .load_record(record)
    }

    fn create_adamw()
    -> OptimizerAdaptor<AdamW, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        let config = AdamWConfig::new();
        AdamW {
            momentum: AdaptiveMomentumW {
                beta_1: config.beta_1,
                beta_2: config.beta_2,
                epsilon: config.epsilon,
            },
            weight_decay: config.weight_decay,
        }
        .into()
    }
}