1use burn_core as burn;
2
3use burn::{module::AutodiffModule, record::Record};
4
5use burn::config::Config;
6use burn::tensor::{Tensor, backend::AutodiffBackend};
7use burn::tensor::{backend::Backend, ops::Device};
8
9use super::{
10 SimpleOptimizer,
11 adaptor::OptimizerAdaptor,
12 decay::{WeightDecay, WeightDecayConfig},
13};
14use crate::{LearningRate, grad_clipping::GradientClippingConfig};
15
16#[cfg(not(feature = "std"))]
17#[allow(unused_imports)]
18use num_traits::Float as _;
19
/// Configuration used to build an [`Adam`] optimizer through [`AdamConfig::init`].
#[derive(Config, Debug)]
pub struct AdamConfig {
    /// Decay rate of the running average of the gradient (first moment).
    #[config(default = 0.9)]
    beta_1: f32,
    /// Decay rate of the running average of the squared gradient (second moment).
    #[config(default = 0.999)]
    beta_2: f32,
    /// Constant added to the square root of the corrected second moment in the
    /// update denominator, for numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// Optional weight decay; when set, it is applied to the gradient before the
    /// moment estimates are updated.
    weight_decay: Option<WeightDecayConfig>,
    /// Optional gradient clipping, installed on the returned optimizer adaptor.
    grad_clipping: Option<GradientClippingConfig>,
}
37
/// Adam optimizer (adaptive moment estimation) with optional weight decay.
///
/// Construct it with [`AdamConfig::init`].
#[derive(Clone)]
pub struct Adam {
    // Hyper-parameters of the moment estimates (beta_1, beta_2, epsilon).
    momentum: AdaptiveMomentum,
    // Applied to the gradient before the moment update when present.
    weight_decay: Option<WeightDecay>,
}
44
/// Per-parameter state that [`Adam`] carries from one step to the next.
#[derive(Record, Clone, new)]
pub struct AdamState<B: Backend, const D: usize> {
    /// Running moment estimates and step counter for this parameter.
    pub momentum: AdaptiveMomentumState<B, D>,
}
51
52impl<B: Backend> SimpleOptimizer<B> for Adam {
53 type State<const D: usize> = AdamState<B, D>;
54
55 fn step<const D: usize>(
56 &self,
57 lr: LearningRate,
58 tensor: Tensor<B, D>,
59 mut grad: Tensor<B, D>,
60 state: Option<Self::State<D>>,
61 ) -> (Tensor<B, D>, Option<Self::State<D>>) {
62 let mut state_momentum = None;
63
64 if let Some(state) = state {
65 state_momentum = Some(state.momentum);
66 }
67
68 if let Some(weight_decay) = &self.weight_decay {
69 grad = weight_decay.transform(grad, tensor.clone());
70 }
71
72 let (grad, state_momentum) = self.momentum.transform(grad, state_momentum);
73
74 let state = AdamState::new(state_momentum);
75 let delta = grad.mul_scalar(lr);
76
77 (tensor - delta, Some(state))
78 }
79
80 fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
81 state.momentum = state.momentum.to_device(device);
82 state
83 }
84}
85
86impl AdamConfig {
87 pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<Adam, M, B> {
93 let optim = Adam {
94 momentum: AdaptiveMomentum {
95 beta_1: self.beta_1,
96 beta_2: self.beta_2,
97 epsilon: self.epsilon,
98 },
99 weight_decay: self.weight_decay.as_ref().map(WeightDecay::new),
100 };
101
102 let mut optim = OptimizerAdaptor::from(optim);
103 if let Some(config) = &self.grad_clipping {
104 optim = optim.with_grad_clipping(config.init());
105 }
106 optim
107 }
108}
109
/// Running moment estimates used by the Adam update for one parameter tensor.
#[derive(Record, new, Clone)]
pub struct AdaptiveMomentumState<B: Backend, const D: usize> {
    /// Number of gradients folded into the moments so far (1 after the first step).
    pub time: usize,
    /// Exponential moving average of the gradient.
    pub moment_1: Tensor<B, D>,
    /// Exponential moving average of the element-wise squared gradient.
    pub moment_2: Tensor<B, D>,
}
120
// Hyper-parameters of Adam's moment estimates; the per-parameter running
// values live in `AdaptiveMomentumState`.
#[derive(Clone)]
struct AdaptiveMomentum {
    // Decay rate of the first-moment (gradient) average.
    beta_1: f32,
    // Decay rate of the second-moment (squared gradient) average.
    beta_2: f32,
    // Added to sqrt of the corrected second moment in the update denominator.
    epsilon: f32,
}
127
128impl AdaptiveMomentum {
129 pub fn transform<B: Backend, const D: usize>(
130 &self,
131 grad: Tensor<B, D>,
132 momentum_state: Option<AdaptiveMomentumState<B, D>>,
133 ) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
134 let state = if let Some(mut state) = momentum_state {
135 let factor = 1.0 - self.beta_1;
136 state.moment_1 = state
137 .moment_1
138 .mul_scalar(self.beta_1)
139 .add(grad.clone().mul_scalar(factor));
140
141 let factor = 1.0 - self.beta_2;
142 state.moment_2 = state
143 .moment_2
144 .mul_scalar(self.beta_2)
145 .add(grad.powi_scalar(2).mul_scalar(factor));
146
147 state.time += 1;
148
149 state
150 } else {
151 let factor = 1.0 - self.beta_1;
152 let moment_1 = grad.clone().mul_scalar(factor);
153
154 let factor = 1.0 - self.beta_2;
155 let moment_2 = grad.powi_scalar(2).mul_scalar(factor);
156
157 AdaptiveMomentumState::new(1, moment_1, moment_2)
158 };
159
160 let time = state.time as i32;
161 let moment_1_corrected = state
162 .moment_1
163 .clone()
164 .div_scalar(1f32 - self.beta_1.powi(time));
165 let moment_2_corrected = state
166 .moment_2
167 .clone()
168 .div_scalar(1f32 - self.beta_2.powi(time));
169
170 let grad = moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));
171
172 (grad, state)
173 }
174}
175
176impl<B: Backend, const D: usize> AdaptiveMomentumState<B, D> {
177 pub fn to_device(mut self, device: &B::Device) -> Self {
187 self.moment_1 = self.moment_1.to_device(device);
188 self.moment_2 = self.moment_2.to_device(device);
189 self
190 }
191}
192
#[cfg(test)]
mod tests {
    use burn::tensor::Tolerance;
    use burn::tensor::ops::FloatElem;

    use super::*;
    use crate::TestAutodiffBackend;
    use crate::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    const LEARNING_RATE: LearningRate = 0.01;

    /// Takes one step, serializes the optimizer record, then reloads it into a
    /// fresh optimizer and checks the record size survives the round trip.
    #[test]
    fn test_adam_optimizer_save_load_state() {
        let device = Default::default();
        let linear = LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_adam();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().join("test_optim_adam"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_adam();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    /// Two Adam steps with weight decay on fixed inputs must reproduce the
    /// reference weights and bias within tolerance.
    #[test]
    fn test_adam_optimizer_with_numbers() {
        let device = Default::default();
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdamConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [-0.340528, 0.118929, 0.384336, 0.300010, 0.066034, 0.047154],
            [
                0.057757, -0.036690, -0.386649, 0.235010, 0.175624, -0.312133,
            ],
            [
                -0.038940, 0.016306, -0.316151, 0.228410, -0.297819, 0.293047,
            ],
            [
                -0.317929, -0.239100, -0.391449, -0.318087, -0.095948, 0.142651,
            ],
            [
                0.310050, -0.235909, 0.351736, -0.192888, 0.359710, -0.050343,
            ],
            [-0.035840, -0.030203, 0.105840, 0.172110, 0.009440, 0.363346],
        ]);
        let bias_expected = TensorData::from([
            -0.410499, 0.068401, -0.116999, 0.097601, 0.116601, -0.006999,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        type FT = FloatElem<TestAutodiffBackend>;
        let tolerance = Tolerance::absolute(1e-2);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    /// Two Adam steps on the same input must not produce NaN anywhere in the
    /// updated weights or bias.
    #[test]
    fn test_adam_optimizer_no_nan() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );

        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &Default::default(),
        )
        .require_grad();

        let mut optimizer = AdamConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
            .init();

        let grads = linear.forward(x.clone()).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();

        // Check every element, not just the first: a NaN can appear anywhere.
        let weight = state_updated.weight.to_data();
        assert!(
            weight
                .as_slice::<f32>()
                .unwrap()
                .iter()
                .all(|value| !value.is_nan()),
            "Adam produced NaN in the updated weights"
        );

        let bias = state_updated.bias.unwrap().to_data();
        assert!(
            bias.as_slice::<f32>()
                .unwrap()
                .iter()
                .all(|value| !value.is_nan()),
            "Adam produced NaN in the updated bias"
        );
    }

    /// Builds a 6x6 linear layer with the given weight and bias values.
    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        LinearConfig::new(6, 6).init(&device).load_record(record)
    }

    /// Adam optimizer with the default configuration, built through the public
    /// `AdamConfig::init` so the tests exercise the real construction path.
    fn create_adam() -> OptimizerAdaptor<Adam, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        AdamConfig::new().init()
    }
}