use crate::{
    self as burn, LearningRate, grad_clipping::GradientClippingConfig, module::AutodiffModule,
    record::Record,
};

use super::{
    SimpleOptimizer,
    decay::{WeightDecay, WeightDecayConfig},
};
use crate::config::Config;
use crate::optim::adaptor::OptimizerAdaptor;
use crate::tensor::{Tensor, backend::AutodiffBackend};
use burn_tensor::{backend::Backend, ops::Device};

#[cfg(not(feature = "std"))]
use num_traits::Float;

/// Configuration to create the [Adam](Adam) optimizer.
#[derive(Config)]
pub struct AdamConfig {
    /// Exponential decay rate for the first moment estimate.
    #[config(default = 0.9)]
    beta_1: f32,
    /// Exponential decay rate for the second moment estimate.
    #[config(default = 0.999)]
    beta_2: f32,
    /// A small constant added to the denominator for numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// [Weight decay](WeightDecayConfig) config.
    weight_decay: Option<WeightDecayConfig>,
    /// [Gradient clipping](GradientClippingConfig) config.
    grad_clipping: Option<GradientClippingConfig>,
}

/// Adam optimizer as described in the paper
/// [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980).
#[derive(Clone)]
pub struct Adam {
    momentum: AdaptiveMomentum,
    weight_decay: Option<WeightDecay>,
}

/// State of the [Adam](Adam) optimizer for a single tensor.
#[derive(Record, Clone, new)]
pub struct AdamState<B: Backend, const D: usize> {
    /// The current adaptive momentum state.
    pub momentum: AdaptiveMomentumState<B, D>,
}

impl<B: Backend> SimpleOptimizer<B> for Adam {
    type State<const D: usize> = AdamState<B, D>;

    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        mut grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        let mut state_momentum = None;

        if let Some(state) = state {
            state_momentum = Some(state.momentum);
        }

        // Apply weight decay to the raw gradient before the momentum update.
        if let Some(weight_decay) = &self.weight_decay {
            grad = weight_decay.transform(grad, tensor.clone());
        }

        // Update the moment estimates and get the bias-corrected, normalized gradient.
        let (grad, state_momentum) = self.momentum.transform(grad, state_momentum);

        let state = AdamState::new(state_momentum);
        let delta = grad.mul_scalar(lr);

        (tensor - delta, Some(state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.momentum = state.momentum.to_device(device);
        state
    }
}

impl AdamConfig {
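    /// Initializes [Adam](Adam) wrapped in an [OptimizerAdaptor](OptimizerAdaptor), so it can be
    /// used as an [Optimizer](crate::optim::Optimizer) for any [AutodiffModule](AutodiffModule).
    ///
    /// A minimal usage sketch, assuming an autodiff backend, a module such as `nn::Linear`
    /// named `module`, and an input tensor `x` (see the tests at the bottom of this file for a
    /// concrete version):
    ///
    /// ```ignore
    /// let mut optim = AdamConfig::new().init();
    /// let grads = module.forward(x).backward();
    /// let grads = GradientsParams::from_grads(grads, &module);
    /// let module = optim.step(1e-2, module, grads);
    /// ```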
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<Adam, M, B> {
        let optim = Adam {
            momentum: AdaptiveMomentum {
                beta_1: self.beta_1,
                beta_2: self.beta_2,
                epsilon: self.epsilon,
            },
            weight_decay: self.weight_decay.as_ref().map(WeightDecay::new),
        };

        let mut optim = OptimizerAdaptor::from(optim);
        if let Some(config) = &self.grad_clipping {
            optim = optim.with_grad_clipping(config.init());
        }
        optim
    }
}

/// Running first and second moment estimates for a single tensor.
#[derive(Record, new, Clone)]
pub struct AdaptiveMomentumState<B: Backend, const D: usize> {
    /// Number of update steps applied so far, used for bias correction.
    pub time: usize,
    /// Exponential moving average of the gradient (first moment).
    pub moment_1: Tensor<B, D>,
    /// Exponential moving average of the squared gradient (second moment).
    pub moment_2: Tensor<B, D>,
}

#[derive(Clone)]
struct AdaptiveMomentum {
    beta_1: f32,
    beta_2: f32,
    epsilon: f32,
}

impl AdaptiveMomentum {
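    /// Updates the running moment estimates with the incoming gradient and returns the
    /// bias-corrected, normalized gradient together with the new state.
    ///
    /// This follows the standard Adam update (Kingma & Ba, 2015); with gradient `g` at time
    /// step `t`, the code below computes:
    ///
    /// ```text
    /// m_t = beta_1 * m_{t-1} + (1 - beta_1) * g
    /// v_t = beta_2 * v_{t-1} + (1 - beta_2) * g^2
    /// m_hat = m_t / (1 - beta_1^t)
    /// v_hat = v_t / (1 - beta_2^t)
    /// output = m_hat / (sqrt(v_hat) + epsilon)
    /// ```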
    pub fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        momentum_state: Option<AdaptiveMomentumState<B, D>>,
    ) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
        let state = if let Some(mut state) = momentum_state {
            // Update the exponential moving averages of the gradient and squared gradient.
            let factor = 1.0 - self.beta_1;
            state.moment_1 = state
                .moment_1
                .mul_scalar(self.beta_1)
                .add(grad.clone().mul_scalar(factor));

            let factor = 1.0 - self.beta_2;
            state.moment_2 = state
                .moment_2
                .mul_scalar(self.beta_2)
                .add(grad.powi_scalar(2).mul_scalar(factor));

            state.time += 1;

            state
        } else {
            // First step: initialize both moments from the current gradient.
            let factor = 1.0 - self.beta_1;
            let moment_1 = grad.clone().mul_scalar(factor);

            let factor = 1.0 - self.beta_2;
            let moment_2 = grad.powi_scalar(2).mul_scalar(factor);

            AdaptiveMomentumState::new(1, moment_1, moment_2)
        };

        // Bias-correct the moment estimates before computing the update direction.
        let time = state.time as i32;
        let moment_1_corrected = state
            .moment_1
            .clone()
            .div_scalar(1f32 - self.beta_1.powi(time));
        let moment_2_corrected = state
            .moment_2
            .clone()
            .div_scalar(1f32 - self.beta_2.powi(time));

        let grad = moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));

        (grad, state)
    }
}

impl<B: Backend, const D: usize> AdaptiveMomentumState<B, D> {
    /// Moves the state to the given device.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.moment_1 = self.moment_1.to_device(device);
        self.moment_2 = self.moment_2.to_device(device);
        self
    }
}

#[cfg(test)]
mod tests {
    use burn_tensor::Tolerance;
    use burn_tensor::ops::FloatElem;

    use super::*;
    use crate::module::{Module, Param};
    use crate::optim::{GradientsParams, Optimizer};
    use crate::tensor::{Distribution, Tensor, TensorData};
    use crate::{TestAutodiffBackend, nn};

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_adam_optimizer_save_load_state() {
        let device = Default::default();
        let linear = nn::LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_adam();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_adam"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use crate::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_adam();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    #[test]
    fn test_adam_optimizer_with_numbers() {
        let device = Default::default();
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdamConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [-0.340528, 0.118929, 0.384336, 0.300010, 0.066034, 0.047154],
            [0.057757, -0.036690, -0.386649, 0.235010, 0.175624, -0.312133],
            [-0.038940, 0.016306, -0.316151, 0.228410, -0.297819, 0.293047],
            [-0.317929, -0.239100, -0.391449, -0.318087, -0.095948, 0.142651],
            [0.310050, -0.235909, 0.351736, -0.192888, 0.359710, -0.050343],
            [-0.035840, -0.030203, 0.105840, 0.172110, 0.009440, 0.363346],
        ]);
        let bias_expected = TensorData::from([
            -0.410499, 0.068401, -0.116999, 0.097601, 0.116601, -0.006999,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        type FT = FloatElem<TestAutodiffBackend>;
        let tolerance = Tolerance::absolute(1e-2);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_adam_optimizer_no_nan() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );

        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &Default::default(),
        )
        .require_grad();

        let mut optimizer = AdamConfig::new()
            .with_epsilon(1e-8)
            .with_beta_1(0.9)
            .with_beta_2(0.999)
            .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
            .init();

        let grads = linear.forward(x.clone()).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        assert!(!state_updated.weight.to_data().as_slice::<f32>().unwrap()[0].is_nan());
    }

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> nn::Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = nn::LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        nn::LinearConfig::new(6, 6)
            .init(&device)
            .load_record(record)
    }

    fn create_adam() -> OptimizerAdaptor<Adam, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend>
    {
        let config = AdamConfig::new();
        Adam {
            momentum: AdaptiveMomentum {
                beta_1: config.beta_1,
                beta_2: config.beta_2,
                epsilon: config.epsilon,
            },
            weight_decay: config.weight_decay.as_ref().map(WeightDecay::new),
        }
        .into()
    }
}