use burn_core as burn;

use burn::config::Config;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use burn::tensor::{backend::Backend, ops::Device};
use burn::{module::AutodiffModule, record::Record};

use super::{SimpleOptimizer, adaptor::OptimizerAdaptor};
use crate::{LearningRate, grad_clipping::GradientClippingConfig};

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;

/// Configuration to create the [Adan](Adan) optimizer.
#[derive(Config, Debug)]
pub struct AdanConfig {
    /// Coefficient used for computing the running average of the gradient (first moment).
    #[config(default = 0.98)]
    beta_1: f32,
    /// Coefficient used for computing the running average of the gradient difference.
    #[config(default = 0.92)]
    beta_2: f32,
    /// Coefficient used for computing the running average of the squared update (second moment).
    #[config(default = 0.99)]
    beta_3: f32,
    /// A value required for numerical stability.
    #[config(default = 1e-8)]
    epsilon: f32,
    /// Weight decay penalty.
    #[config(default = 0.0)]
    weight_decay: f32,
    /// When enabled, applies decoupled (AdamW-style) weight decay instead of the
    /// default proximal update.
    #[config(default = false)]
    no_prox: bool,
    /// [Gradient clipping](GradientClippingConfig) configuration.
    grad_clipping: Option<GradientClippingConfig>,
}

/// Adan optimizer (ADAptive Nesterov momentum algorithm).
///
/// Described in "Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing
/// Deep Models", <https://arxiv.org/abs/2208.06677>.
#[derive(Clone)]
pub struct Adan {
    momentum: AdaptiveNesterovMomentum,
    weight_decay: f32,
    no_prox: bool,
}

/// State of [Adan](Adan).
#[derive(Record, Clone, new)]
pub struct AdanState<B: Backend, const D: usize> {
    /// The adaptive Nesterov momentum state.
    pub momentum: AdaptiveNesterovMomentumState<B, D>,
}

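// A sketch of the recursion that `AdaptiveNesterovMomentum::transform` and `step` below
// implement, loosely following the notation of the Adan paper (g_t is the gradient, p_t
// the parameter, lr the learning rate, wd the weight decay); the names in parentheses are
// the corresponding `AdaptiveNesterovMomentumState` fields:
//
//   m_t = beta_1 * m_{t-1} + (1 - beta_1) * g_t                                 (exp_avg)
//   v_t = beta_2 * v_{t-1} + (1 - beta_2) * (g_t - g_{t-1})                     (exp_avg_diff)
//   n_t = beta_3 * n_{t-1} + (1 - beta_3) * (g_t + beta_2 * (g_t - g_{t-1}))^2  (exp_avg_sq)
//
// with the bias-corrected update
//
//   u_t = (m_t / (1 - beta_1^t) + beta_2 * v_t / (1 - beta_2^t))
//         / (sqrt(n_t / (1 - beta_3^t)) + epsilon)
//
// and the parameter step (proximal by default, decoupled when `no_prox` is set):
//
//   p_{t+1} = (p_t - lr * u_t) / (1 + lr * wd)      // no_prox == false
//   p_{t+1} = (1 - lr * wd) * p_t - lr * u_t        // no_prox == true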
impl<B: Backend> SimpleOptimizer<B> for Adan {
    type State<const D: usize> = AdanState<B, D>;

    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        let (raw_delta, momentum_state) = self.momentum.transform(grad, state.map(|s| s.momentum));

        let decay_rate = lr * (self.weight_decay as f64);
        let delta = raw_delta.mul_scalar(lr);

        let tensor_updated = if self.no_prox {
            // Decoupled (AdamW-style) weight decay: shrink the parameter, then apply the update.
            if decay_rate == 0.0 {
                tensor - delta
            } else {
                tensor.mul_scalar(1.0 - decay_rate) - delta
            }
        } else {
            // Proximal weight decay: apply the update, then rescale the parameter.
            let updated = tensor - delta;
            if decay_rate == 0.0 {
                updated
            } else {
                updated.div_scalar(1.0 + decay_rate)
            }
        };

        (tensor_updated, Some(AdanState::new(momentum_state)))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.momentum = state.momentum.to_device(device);
        state
    }
}

impl AdanConfig {
    /// Build the [Adan](Adan) optimizer from this configuration.
    pub fn build(&self) -> Adan {
        Adan {
            momentum: AdaptiveNesterovMomentum {
                beta_1: self.beta_1,
                beta_2: self.beta_2,
                beta_3: self.beta_3,
                epsilon: self.epsilon,
            },
            weight_decay: self.weight_decay,
            no_prox: self.no_prox,
        }
    }

    /// Initialize the [Adan](Adan) optimizer for a module, applying gradient clipping
    /// when it is configured.
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<Adan, M, B> {
        let mut optim = OptimizerAdaptor::from(self.build());
        if let Some(config) = &self.grad_clipping {
            optim = optim.with_grad_clipping(config.init());
        }
        optim
    }
}
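
// A minimal usage sketch (hypothetical names, not from this file): it assumes an autodiff
// backend `B`, a module type `M` with an instance `model`, and a scalar `loss` tensor, and
// drives the optimizer through the generic `Optimizer::step(lr, module, grads)` interface:
//
//     let mut optim = AdanConfig::new().with_weight_decay(1e-2).init::<B, M>();
//     let grads = GradientsParams::from_grads(loss.backward(), &model);
//     let model = optim.step(1e-3, model, grads);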

/// State of [AdaptiveNesterovMomentum](AdaptiveNesterovMomentum).
#[derive(Record, Clone, new)]
pub struct AdaptiveNesterovMomentumState<B: Backend, const D: usize> {
    /// The number of optimization steps performed so far, used for bias correction.
    pub time: usize,
    /// Exponential moving average of the gradient (first moment).
    pub exp_avg: Tensor<B, D>,
    /// Exponential moving average of the squared Nesterov-corrected gradient (second moment).
    pub exp_avg_sq: Tensor<B, D>,
    /// Exponential moving average of the gradient difference.
    pub exp_avg_diff: Tensor<B, D>,
    /// The negated gradient of the previous step.
    pub neg_pre_grad: Tensor<B, D>,
}

/// Adaptive Nesterov momentum transform used by the [Adan](Adan) optimizer.
#[derive(Clone)]
struct AdaptiveNesterovMomentum {
    beta_1: f32,
    beta_2: f32,
    beta_3: f32,
    epsilon: f32,
}

impl AdaptiveNesterovMomentum {
    pub fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        state: Option<AdaptiveNesterovMomentumState<B, D>>,
    ) -> (Tensor<B, D>, AdaptiveNesterovMomentumState<B, D>) {
        let state = if let Some(mut state) = state {
            // g_t - g_{t-1}, computed from the stored negated previous gradient.
            let grad_diff = state.neg_pre_grad.clone().add(grad.clone());
            // (g_t + beta_2 * (g_t - g_{t-1}))^2, the squared Nesterov-corrected gradient.
            let grad_diff_sq = grad_diff
                .clone()
                .mul_scalar(self.beta_2)
                .add(grad.clone())
                .square();

            state.exp_avg = state
                .exp_avg
                .mul_scalar(self.beta_1)
                .add(grad.clone().mul_scalar(1.0 - self.beta_1));
            state.exp_avg_diff = state
                .exp_avg_diff
                .mul_scalar(self.beta_2)
                .add(grad_diff.mul_scalar(1.0 - self.beta_2));
            state.exp_avg_sq = state
                .exp_avg_sq
                .mul_scalar(self.beta_3)
                .add(grad_diff_sq.mul_scalar(1.0 - self.beta_3));
            state.neg_pre_grad = grad.mul_scalar(-1.0);
            state.time += 1;
            state
        } else {
            // First step: the moving averages start from zero and the previous gradient
            // is taken to be the current gradient, so the gradient difference is zero.
            AdaptiveNesterovMomentumState::new(
                1,
                grad.clone().mul_scalar(1.0 - self.beta_1),
                grad.clone().square().mul_scalar(1.0 - self.beta_3),
                grad.zeros_like(),
                grad.clone().mul_scalar(-1.0),
            )
        };

        // Bias-corrected denominator: sqrt(n_t / (1 - beta_3^t)) + epsilon.
        let time = state.time as i32;
        let denom = state
            .exp_avg_sq
            .clone()
            .sqrt()
            .div_scalar((1.0 - self.beta_3.powi(time)).sqrt())
            .add_scalar(self.epsilon);
        // Bias-corrected update: (m_t / (1 - beta_1^t) + beta_2 * v_t / (1 - beta_2^t)) / denom.
        // The learning rate is applied by the caller in `step`.
        let update = state
            .exp_avg
            .clone()
            .div_scalar(1.0 - self.beta_1.powi(time))
            .div(denom.clone())
            .add(
                state
                    .exp_avg_diff
                    .clone()
                    .mul_scalar(self.beta_2)
                    .div_scalar(1.0 - self.beta_2.powi(time))
                    .div(denom),
            );

        (update, state)
    }
}

impl<B: Backend, const D: usize> AdaptiveNesterovMomentumState<B, D> {
    #[allow(clippy::wrong_self_convention)]
    fn to_device(mut self, device: &B::Device) -> Self {
        self.exp_avg = self.exp_avg.to_device(device);
        self.exp_avg_sq = self.exp_avg_sq.to_device(device);
        self.exp_avg_diff = self.exp_avg_diff.to_device(device);
        self.neg_pre_grad = self.neg_pre_grad.to_device(device);
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use crate::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn::tensor::{Tolerance, ops::FloatElem};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    type FT = FloatElem<TestAutodiffBackend>;

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_adan_optimizer_save_load_state() {
        let device = Default::default();
        let linear = LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_adan();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_adan"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_adan();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    #[test]
    fn test_adan_optimizer_with_numbers() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let device = Default::default();
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdanConfig::new()
            .with_beta_1(0.98)
            .with_beta_2(0.92)
            .with_beta_3(0.99)
            .with_epsilon(1e-8)
            .with_weight_decay(0.02)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [
                -0.34034607,
                0.11747075,
                0.38426402,
                0.29999772,
                0.06599136,
                0.04719888,
            ],
            [
                0.0644293,
                -0.031732224,
                -0.37979296,
                0.24165839,
                0.18218218,
                -0.30532277,
            ],
            [
                -0.038910445,
                0.01466812,
                -0.31599957,
                0.2283826,
                -0.29780683,
                0.2929568,
            ],
            [
                -0.3178632,
                -0.24129382,
                -0.39133376,
                -0.31796312,
                -0.09605193,
                0.14255258,
            ],
            [
                0.31026322,
                -0.23771758,
                0.3519465,
                -0.19243571,
                0.35984334,
                -0.049992695,
            ],
            [
                -0.03577819,
                -0.031879753,
                0.10586514,
                0.17213862,
                0.009403733,
                0.36326218,
            ],
        ]);
        let bias_expected = TensorData::from([
            -0.4103378,
            0.06837065,
            -0.116955206,
            0.097558975,
            0.11655137,
            -0.006999196,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.to_data(),
            state_updated.bias.unwrap().to_data(),
        );

        let tolerance = Tolerance::absolute(1e-5);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    #[test]
    fn test_adan_optimizer_no_nan() {
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );

        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &Default::default(),
        )
        .require_grad();

        let mut optimizer = AdanConfig::new()
            .with_epsilon(1e-8)
            .with_weight_decay(0.02)
            .init();

        let grads = linear.forward(x.clone()).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        assert!(!state_updated.weight.to_data().as_slice::<f32>().unwrap()[0].is_nan());
    }

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        LinearConfig::new(6, 6).init(&device).load_record(record)
    }

    fn create_adan() -> OptimizerAdaptor<Adan, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        let config = AdanConfig::new();
        Adan {
            momentum: AdaptiveNesterovMomentum {
                beta_1: config.beta_1,
                beta_2: config.beta_2,
                beta_3: config.beta_3,
                epsilon: config.epsilon,
            },
            weight_decay: config.weight_decay,
            no_prox: config.no_prox,
        }
        .into()
    }
}