use burn_core as burn;
use burn::{module::AutodiffModule, record::Record};
use burn::config::Config;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use burn::tensor::{backend::Backend, ops::Device};
use super::{
SimpleOptimizer,
adaptor::OptimizerAdaptor,
decay::{WeightDecay, WeightDecayConfig},
};
use crate::{LearningRate, grad_clipping::GradientClippingConfig};
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;
/// Configuration used to build an [`Adam`] optimizer through [`AdamConfig::init`].
#[derive(Config, Debug)]
pub struct AdamConfig {
/// Decay coefficient of the running first moment (average of the gradient).
#[config(default = 0.9)]
beta_1: f32,
/// Decay coefficient of the running second moment (average of the squared gradient).
#[config(default = 0.999)]
beta_2: f32,
/// Constant added to the square root of the second moment in the denominator of the update.
#[config(default = 1e-5)]
epsilon: f32,
/// When `true`, the update divides by the running element-wise maximum of the second
/// moment instead of its current value (AMSGrad variant).
#[config(default = false)]
amsgrad: bool,
/// Optional weight decay configuration; when set, `init` builds a `WeightDecay` from it.
weight_decay: Option<WeightDecayConfig>,
/// Optional gradient clipping configuration; when set, `init` attaches it to the adaptor.
grad_clipping: Option<GradientClippingConfig>,
}
/// Adam optimizer: adaptive momentum hyper-parameters plus an optional weight decay.
#[derive(Clone)]
pub struct Adam {
// Hyper-parameters and update rule for the first/second moment estimates.
momentum: AdaptiveMomentum,
// Applied to the gradient before the moment update when present.
weight_decay: Option<WeightDecay>,
}
/// State that [`Adam`] carries from one step to the next for a single tensor.
#[derive(Record, Clone, new)]
pub struct AdamState<B: Backend, const D: usize> {
/// Moment estimates produced by the previous step.
pub momentum: AdaptiveMomentumState<B, D>,
}
impl<B: Backend> SimpleOptimizer<B> for Adam {
    type State<const D: usize> = AdamState<B, D>;

    /// Performs one Adam update on `tensor`.
    ///
    /// The gradient first goes through the optional weight decay, then through the
    /// adaptive momentum transform; the result is scaled by `lr` and subtracted from
    /// `tensor`. Returns the updated tensor together with the new optimizer state
    /// (always `Some`).
    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        mut grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        // No previous state means this is the first step for this tensor.
        let previous_momentum = state.map(|previous| previous.momentum);

        if let Some(decay) = self.weight_decay.as_ref() {
            grad = decay.transform(grad, tensor.clone());
        }

        let (direction, momentum) = self.momentum.transform(grad, previous_momentum);
        let next_state = AdamState::new(momentum);
        let updated = tensor - direction.mul_scalar(lr);

        (updated, Some(next_state))
    }

    /// Moves every tensor held by `state` to `device`.
    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        let relocated = state.momentum.to_device(device);
        state.momentum = relocated;
        state
    }
}
impl AdamConfig {
    /// Builds an [`Adam`] optimizer from this configuration, wrapped in an
    /// [`OptimizerAdaptor`] so it can drive a whole module.
    ///
    /// Gradient clipping is attached to the adaptor only when it is configured.
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<Adam, M, B> {
        let momentum = AdaptiveMomentum {
            beta_1: self.beta_1,
            beta_2: self.beta_2,
            epsilon: self.epsilon,
            amsgrad: self.amsgrad,
        };
        let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);
        let adaptor = OptimizerAdaptor::from(Adam {
            momentum,
            weight_decay,
        });

        match &self.grad_clipping {
            Some(clipping) => adaptor.with_grad_clipping(clipping.init()),
            None => adaptor,
        }
    }
}
/// Moment estimates tracked by the adaptive momentum update for a single tensor.
#[derive(Record, new, Clone)]
pub struct AdaptiveMomentumState<B: Backend, const D: usize> {
/// Number of updates folded into this state; the first update sets it to 1.
pub time: usize,
/// Exponential moving average of the gradient.
pub moment_1: Tensor<B, D>,
/// Exponential moving average of the squared gradient.
pub moment_2: Tensor<B, D>,
/// Element-wise running maximum of `moment_2`; only populated when AMSGrad is enabled.
#[new(default)]
pub max_moment_2: Option<Tensor<B, D>>,
}
// Hyper-parameters of the adaptive momentum update; values are copied from `AdamConfig`.
#[derive(Clone)]
struct AdaptiveMomentum {
// Decay coefficient of the first moment.
beta_1: f32,
// Decay coefficient of the second moment.
beta_2: f32,
// Constant added to the denominator of the update.
epsilon: f32,
// Whether to divide by the running maximum of the second moment (AMSGrad).
amsgrad: bool,
}
impl AdaptiveMomentum {
    /// Folds `grad` into the moment estimates and returns the resulting update
    /// direction together with the new state.
    ///
    /// With `t = state.time`, `m = moment_1` and `v` either `moment_2` or, when
    /// AMSGrad is enabled, its running maximum, the returned direction is
    /// `m * c / (sqrt(v) + epsilon * s)` where `s = sqrt(1 - beta_2^t)` and
    /// `c = s / (1 - beta_1^t)`.
    ///
    /// When `momentum_state` is `None` the moments are initialised from `grad` alone
    /// and `time` starts at 1.
    pub fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        momentum_state: Option<AdaptiveMomentumState<B, D>>,
    ) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
        let grad_weight_1 = 1.0 - self.beta_1;
        let grad_weight_2 = 1.0 - self.beta_2;

        let state = match momentum_state {
            Some(mut previous) => {
                let decayed_1 = previous.moment_1.mul_scalar(self.beta_1);
                previous.moment_1 = decayed_1.add(grad.clone().mul_scalar(grad_weight_1));

                let decayed_2 = previous.moment_2.mul_scalar(self.beta_2);
                previous.moment_2 = decayed_2.add(grad.square().mul_scalar(grad_weight_2));

                if self.amsgrad {
                    // A state recorded without a maximum falls back to the current
                    // second moment.
                    let running_max = match previous.max_moment_2.take() {
                        Some(max) => max,
                        None => previous.moment_2.clone(),
                    };
                    previous.max_moment_2 =
                        Some(running_max.max_pair(previous.moment_2.clone()));
                }

                previous.time += 1;
                previous
            }
            None => {
                let moment_1 = grad.clone().mul_scalar(grad_weight_1);
                let moment_2 = grad.square().mul_scalar(grad_weight_2);
                let max_moment_2 = if self.amsgrad {
                    Some(moment_2.clone())
                } else {
                    None
                };

                AdaptiveMomentumState {
                    time: 1,
                    moment_1,
                    moment_2,
                    max_moment_2,
                }
            }
        };

        // Both bias corrections are folded into scalars so that no extra tensor
        // division is needed.
        let step = state.time as i32;
        let bias_correction2_sqrt = (1.0 - self.beta_2.powi(step)).sqrt();
        let numerator_scale = bias_correction2_sqrt / (1.0 - self.beta_1.powi(step));

        let second_moment = match (self.amsgrad, state.max_moment_2.as_ref()) {
            (true, Some(max)) => max,
            _ => &state.moment_2,
        };

        let numerator = state.moment_1.clone().mul_scalar(numerator_scale);
        let denominator = second_moment
            .clone()
            .sqrt()
            .add_scalar(self.epsilon * bias_correction2_sqrt);
        let direction = numerator.div(denominator);

        (direction, state)
    }
}
impl<B: Backend, const D: usize> AdaptiveMomentumState<B, D> {
    /// Moves every tensor of the state to `device` and returns the relocated state.
    ///
    /// `time` is left untouched; `max_moment_2` is moved only when it is present.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        let first = self.moment_1.to_device(device);
        let second = self.moment_2.to_device(device);
        let running_max = self
            .max_moment_2
            .take()
            .map(|max| max.to_device(device));

        self.moment_1 = first;
        self.moment_2 = second;
        self.max_moment_2 = running_max;
        self
    }
}
#[cfg(test)]
mod tests {
use burn::tensor::Tolerance;
use burn::tensor::ops::FloatElem;
use super::*;
use crate::TestAutodiffBackend;
use crate::{GradientsParams, Optimizer};
use burn::module::{Module, Param};
use burn::tensor::{Distribution, Tensor, TensorData};
use burn_nn::{Linear, LinearConfig, LinearRecord};
// Learning rate passed to every `optimizer.step` call in this test module.
const LEARNING_RATE: LearningRate = 0.01;
#[test]
fn test_adam_optimizer_save_load_state() {
    // Run a single step so the optimizer holds some state worth recording.
    let device = Default::default();
    let layer = LinearConfig::new(6, 6).init(&device);
    let input = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
    let mut trained = create_adam();
    let grads = GradientsParams::from_grads(layer.forward(input).backward(), &layer);
    let _stepped = trained.step(LEARNING_RATE, layer, grads);

    // With `std` the record goes to a temporary file, otherwise to an in-memory buffer.
    #[cfg(feature = "std")]
    {
        use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

        let target = std::env::temp_dir().as_path().join("test_optim_adam");
        BinFileRecorder::<FullPrecisionSettings>::default()
            .record(trained.to_record(), target)
            .unwrap();
    }
    #[cfg(not(feature = "std"))]
    {
        use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

        let bytes = BinBytesRecorder::<FullPrecisionSettings>::default()
            .record(trained.to_record(), ())
            .unwrap();
        assert!(!bytes.is_empty());
    }

    // A fresh optimizer loaded from the record must expose as many entries as the source.
    let expected = trained.to_record();
    let restored = create_adam().load_record(trained.to_record());
    let actual = restored.to_record();

    assert_eq!(expected.len(), actual.len());
}
#[test]
fn test_adam_optimizer_with_amsgrad_50_steps() {
    let device = Default::default();
    let mut layer = given_linear_layer(
        TensorData::from([
            [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
            [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
            [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
            [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
            [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
            [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
        ]),
        TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
    );
    let mut adam = AdamConfig::new()
        .with_epsilon(1e-8)
        .with_beta_1(0.9)
        .with_beta_2(0.999)
        .with_amsgrad(true)
        .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
        .init();

    // Fifty steps; every input is a constant tensor whose value grows with the step index.
    for step in 1..=50 {
        let input = Tensor::<TestAutodiffBackend, 2>::ones([2, 6], &device)
            .mul_scalar(step as f32 * 0.1)
            .require_grad();
        let grads = GradientsParams::from_grads(layer.forward(input).backward(), &layer);
        layer = adam.step(LEARNING_RATE, layer, grads);
    }

    let record = layer.into_record();
    let actual_weight = record.weight.to_data();
    let actual_bias = record.bias.unwrap().to_data();

    let expected_weight = TensorData::from([
        [
            -0.9125810265541077,
            -0.45855265855789185,
            -0.1915993094444275,
            -0.2759990692138672,
            -0.5099529027938843,
            -0.5287043452262878,
        ],
        [
            -0.5181325674057007,
            -0.6139854788780212,
            -0.9574727416038513,
            -0.34102925658226013,
            -0.400514155626297,
            -0.8847861886024475,
        ],
        [
            -0.614483118057251,
            -0.5611032247543335,
            -0.8887064456939697,
            -0.34762972593307495,
            -0.8708556890487671,
            -0.2830044627189636,
        ],
        [
            -0.8904699683189392,
            -0.8151527643203735,
            -0.9621278643608093,
            -0.8905676603317261,
            -0.671261191368103,
            -0.4333854615688324,
        ],
        [
            -0.26599061489105225,
            -0.8119961023330688,
            -0.22424538433551788,
            -0.7672406435012817,
            -0.2163349837064743,
            -0.6258266568183899,
        ],
        [
            -0.611397922039032,
            -0.6075160503387451,
            -0.4701341986656189,
            -0.4039117991924286,
            -0.5663845539093018,
            -0.21262989938259125,
        ],
    ]);
    let expected_bias = TensorData::from([
        -0.8817203044891357,
        -0.4038999378681183,
        -0.5889149308204651,
        -0.37475723028182983,
        -0.3557940721511841,
        -0.47914788126945496,
    ]);

    type FT = FloatElem<TestAutodiffBackend>;
    let tolerance = Tolerance::absolute(1e-5);
    actual_weight.assert_approx_eq::<FT>(&expected_weight, tolerance);
    actual_bias.assert_approx_eq::<FT>(&expected_bias, tolerance);
}
#[test]
fn test_adam_optimizer_with_numbers() {
    let device = Default::default();
    let mut layer = given_linear_layer(
        TensorData::from([
            [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
            [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
            [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
            [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
            [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
            [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
        ]),
        TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
    );
    let first_input = Tensor::<TestAutodiffBackend, 2>::from_floats(
        [
            [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
            [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
        ],
        &device,
    )
    .require_grad();
    let second_input = Tensor::<TestAutodiffBackend, 2>::from_floats(
        [
            [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
            [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
        ],
        &device,
    )
    .require_grad();

    let mut adam = AdamConfig::new()
        .with_epsilon(1e-8)
        .with_beta_1(0.9)
        .with_beta_2(0.999)
        .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
        .init();

    // Two consecutive optimizer steps, one per input batch.
    for input in [first_input, second_input] {
        let grads = GradientsParams::from_grads(layer.forward(input).backward(), &layer);
        layer = adam.step(LEARNING_RATE, layer, grads);
    }

    let record = layer.into_record();
    let expected_weight = TensorData::from([
        [-0.340528, 0.118929, 0.384336, 0.300010, 0.066034, 0.047154],
        [
            0.057757, -0.036690, -0.386649, 0.235010, 0.175624, -0.312133,
        ],
        [
            -0.038940, 0.016306, -0.316151, 0.228410, -0.297819, 0.293047,
        ],
        [
            -0.317929, -0.239100, -0.391449, -0.318087, -0.095948, 0.142651,
        ],
        [
            0.310050, -0.235909, 0.351736, -0.192888, 0.359710, -0.050343,
        ],
        [-0.035840, -0.030203, 0.105840, 0.172110, 0.009440, 0.363346],
    ]);
    let expected_bias = TensorData::from([
        -0.410499, 0.068401, -0.116999, 0.097601, 0.116601, -0.006999,
    ]);

    let actual_weight = record.weight.to_data();
    let actual_bias = record.bias.unwrap().to_data();

    type FT = FloatElem<TestAutodiffBackend>;
    let tolerance = Tolerance::absolute(1e-2);
    actual_bias.assert_approx_eq::<FT>(&expected_bias, tolerance);
    actual_weight.assert_approx_eq::<FT>(&expected_weight, tolerance);
}
#[test]
fn test_adam_optimizer_no_nan() {
    let mut layer = given_linear_layer(
        TensorData::from([
            [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
            [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
            [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
            [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
            [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
            [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
        ]),
        TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
    );
    let input = Tensor::<TestAutodiffBackend, 2>::from_floats(
        [
            [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
            [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
        ],
        &Default::default(),
    )
    .require_grad();

    let mut adam = AdamConfig::new()
        .with_epsilon(1e-8)
        .with_beta_1(0.9)
        .with_beta_2(0.999)
        .with_weight_decay(Some(WeightDecayConfig::new(0.5)))
        .init();

    // Step twice on the same input.
    for _ in 0..2 {
        let grads = GradientsParams::from_grads(layer.forward(input.clone()).backward(), &layer);
        layer = adam.step(LEARNING_RATE, layer, grads);
    }

    // Only the first weight element is inspected.
    let weights = layer.into_record().weight.to_data();
    let first = weights.as_slice::<f32>().unwrap()[0];
    assert!(!first.is_nan());
}
// Builds a 6x6 linear layer whose parameters are replaced by the given weight and bias.
fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
    let device = Default::default();
    let parameters = LinearRecord {
        weight: Param::from_data(weight, &device),
        bias: Some(Param::from_data(bias, &device)),
    };
    let layer = LinearConfig::new(6, 6).init(&device);
    layer.load_record(parameters)
}
fn create_adam() -> OptimizerAdaptor<Adam, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
let config = AdamConfig::new();
Adam {
momentum: AdaptiveMomentum {
beta_1: config.beta_1,
beta_2: config.beta_2,
epsilon: config.epsilon,
amsgrad: config.amsgrad,
},
weight_decay: config.weight_decay.as_ref().map(WeightDecay::new),
}
.into()
}
}