use axonml_nn::Parameter;
use axonml_tensor::Tensor;
use crate::optimizer::Optimizer;
/// Adam optimizer: adaptive moment estimation (Kingma & Ba, 2015).
///
/// Maintains exponential moving averages of the gradient and of the squared
/// gradient per parameter (see [`AdamState`]) with bias correction.
pub struct Adam {
    /// Parameters this optimizer updates.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Exponential decay rate for the first-moment estimate (default 0.9).
    beta1: f32,
    /// Exponential decay rate for the second-moment estimate (default 0.999).
    beta2: f32,
    /// Term added to the denominator for numerical stability (default 1e-8).
    eps: f32,
    /// Coupled (L2-style) weight decay coefficient (default 0.0).
    weight_decay: f32,
    /// AMSGrad flag. NOTE(review): stored and settable via the builder, but
    /// never consulted by `step()` — confirm whether AMSGrad is intended.
    amsgrad: bool,
    /// Per-parameter moment buffers, lazily created on the first `step()`.
    state: Vec<AdamState>,
}
/// Per-parameter optimizer state shared by [`Adam`] and [`AdamW`].
#[derive(Debug, Clone)]
struct AdamState {
    /// First-moment (mean) exponential moving average of the gradient.
    exp_avg: Tensor<f32>,
    /// Second-moment (uncentered variance) EMA of the gradient.
    exp_avg_sq: Tensor<f32>,
    /// Number of updates applied to this parameter; drives bias correction.
    step: usize,
}
impl AdamState {
    /// Builds zero-initialized moment buffers matching `shape`, moving them
    /// onto `device` when it refers to a GPU.
    fn new(shape: &[usize], device: axonml_core::Device) -> Self {
        let numel: usize = shape.iter().product();
        // Both buffers start at zero; bias correction in `step()` compensates.
        let zeros = || Tensor::from_vec(vec![0.0f32; numel], shape).unwrap();
        let mut exp_avg = zeros();
        let mut exp_avg_sq = zeros();
        if device.is_gpu() {
            exp_avg = exp_avg.to_device(device).unwrap();
            exp_avg_sq = exp_avg_sq.to_device(device).unwrap();
        }
        Self {
            exp_avg,
            exp_avg_sq,
            step: 0,
        }
    }
}
impl Adam {
    /// Creates an Adam optimizer with the conventional defaults:
    /// betas = (0.9, 0.999), eps = 1e-8, no weight decay, no AMSGrad.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates an Adam optimizer with custom beta coefficients and default
    /// eps / weight decay / AMSGrad settings.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self::with_options(params, lr, betas, 1e-8, 0.0, false)
    }

    /// Creates an Adam optimizer with every hyperparameter spelled out.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        let (beta1, beta2) = betas;
        Self {
            params,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Builder-style setter for the beta coefficients.
    #[must_use]
    pub fn betas(self, betas: (f32, f32)) -> Self {
        Self {
            beta1: betas.0,
            beta2: betas.1,
            ..self
        }
    }

    /// Builder-style setter for the numerical-stability epsilon.
    #[must_use]
    pub fn eps(self, eps: f32) -> Self {
        Self { eps, ..self }
    }

    /// Builder-style setter for the weight-decay coefficient.
    #[must_use]
    pub fn weight_decay(self, weight_decay: f32) -> Self {
        Self {
            weight_decay,
            ..self
        }
    }

    /// Builder-style setter for the AMSGrad flag.
    #[must_use]
    pub fn amsgrad(self, amsgrad: bool) -> Self {
        Self { amsgrad, ..self }
    }

    /// Allocates per-parameter moment buffers on first use; no-op afterwards.
    fn ensure_state_initialized(&mut self) {
        if !self.state.is_empty() {
            return;
        }
        self.state = self
            .params
            .iter()
            .map(|param| {
                let tensor = param.data();
                AdamState::new(tensor.shape(), tensor.device())
            })
            .collect();
    }
}
impl Optimizer for Adam {
    /// Applies one Adam update to every trainable parameter with a gradient.
    ///
    /// Weight decay here is *coupled* (classic Adam / L2 regularization): it
    /// is folded into the gradient before the moment updates. Compare with
    /// `AdamW`, which decays the parameter directly.
    ///
    /// NOTE(review): `self.amsgrad` is never read in this method, so the
    /// AMSGrad variant is effectively unimplemented — confirm intent.
    fn step(&mut self) {
        self.ensure_state_initialized();
        for (i, param) in self.params.iter().enumerate() {
            // Skip frozen parameters and parameters with no gradient.
            if !param.requires_grad() {
                continue;
            }
            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };
            let state = &mut self.state[i];
            state.step += 1; // 1-based step count used for bias correction
            let param_data = param.data();
            // Fused GPU fast path: the kernel updates the parameter and both
            // moment buffers in place, so the scalar CPU loop is skipped.
            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    self.weight_decay,
                    bias_correction1,
                    bias_correction2,
                );
                continue;
            }
            // CPU path: pull everything into host vectors, update
            // element-wise, then rebuild the tensors.
            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();
            // bias_correction_k = 1 - beta_k^t; counteracts the zero
            // initialization of the moment estimates at small t.
            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd = self.weight_decay;
            for i in 0..param_vec.len() {
                // Coupled weight decay: g = grad + wd * theta.
                let g = if wd == 0.0 {
                    grad_vec[i]
                } else {
                    grad_vec[i] + wd * param_vec[i]
                };
                // m_t = beta1 * m_{t-1} + (1 - beta1) * g
                exp_avg_vec[i] = beta1 * exp_avg_vec[i] + one_minus_beta1 * g;
                // v_t = beta2 * v_{t-1} + (1 - beta2) * g^2
                exp_avg_sq_vec[i] = beta2 * exp_avg_sq_vec[i] + one_minus_beta2 * g * g;
                // theta -= (lr / bc1) * m_t / (sqrt(v_t / bc2) + eps)
                let denom = (exp_avg_sq_vec[i] / bias_correction2).sqrt() + eps;
                param_vec[i] -= step_size * exp_avg_vec[i] / denom;
            }
            // NOTE(review): `from_vec` appears to allocate host tensors; if a
            // GPU parameter ever reaches this fallback (non-"cuda" build),
            // state/parameter device placement may silently change — confirm.
            state.exp_avg = Tensor::from_vec(exp_avg_vec, param_data.shape()).unwrap();
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape()).unwrap();
            param.update_data(Tensor::from_vec(param_vec, param_data.shape()).unwrap());
        }
    }

    /// Clears the gradient of every managed parameter.
    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    /// Returns the current learning rate.
    fn get_lr(&self) -> f32 {
        self.lr
    }

    /// Sets the learning rate (e.g. from an LR scheduler).
    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    /// Returns the parameters managed by this optimizer.
    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}
/// AdamW optimizer: Adam with *decoupled* weight decay
/// (Loshchilov & Hutter, 2019).
///
/// Unlike [`Adam`], the decay term shrinks the parameter directly instead of
/// being added to the gradient, so it does not interact with the adaptive
/// moment estimates.
pub struct AdamW {
    /// Parameters this optimizer updates.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Exponential decay rate for the first-moment estimate (default 0.9).
    beta1: f32,
    /// Exponential decay rate for the second-moment estimate (default 0.999).
    beta2: f32,
    /// Term added to the denominator for numerical stability (default 1e-8).
    eps: f32,
    /// Decoupled weight-decay coefficient (default 0.01).
    weight_decay: f32,
    /// AMSGrad flag. NOTE(review): stored and settable via the builder, but
    /// never consulted by `step()` — confirm whether AMSGrad is intended.
    amsgrad: bool,
    /// Per-parameter moment buffers, lazily created on the first `step()`.
    state: Vec<AdamState>,
}
impl AdamW {
    /// Creates an AdamW optimizer with the conventional defaults:
    /// betas = (0.9, 0.999), eps = 1e-8, weight decay 0.01, no AMSGrad.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates an AdamW optimizer with custom beta coefficients and default
    /// eps / weight decay / AMSGrad settings.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self::with_options(params, lr, betas, 1e-8, 0.01, false)
    }

    /// Creates an AdamW optimizer with every hyperparameter spelled out.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        let (beta1, beta2) = betas;
        Self {
            params,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Builder-style setter for the beta coefficients.
    #[must_use]
    pub fn betas(self, betas: (f32, f32)) -> Self {
        Self {
            beta1: betas.0,
            beta2: betas.1,
            ..self
        }
    }

    /// Builder-style setter for the numerical-stability epsilon.
    #[must_use]
    pub fn eps(self, eps: f32) -> Self {
        Self { eps, ..self }
    }

    /// Builder-style setter for the weight-decay coefficient.
    #[must_use]
    pub fn weight_decay(self, weight_decay: f32) -> Self {
        Self {
            weight_decay,
            ..self
        }
    }

    /// Builder-style setter for the AMSGrad flag.
    #[must_use]
    pub fn amsgrad(self, amsgrad: bool) -> Self {
        Self { amsgrad, ..self }
    }

    /// Allocates per-parameter moment buffers on first use; no-op afterwards.
    fn ensure_state_initialized(&mut self) {
        if !self.state.is_empty() {
            return;
        }
        self.state = self
            .params
            .iter()
            .map(|param| {
                let tensor = param.data();
                AdamState::new(tensor.shape(), tensor.device())
            })
            .collect();
    }
}
impl Optimizer for AdamW {
    /// Applies one AdamW update to every trainable parameter with a gradient.
    ///
    /// Weight decay is *decoupled*: the parameter is multiplied by
    /// `1 - lr * weight_decay` before the Adam moment update, and the decay
    /// never enters the gradient or the moment estimates.
    ///
    /// NOTE(review): `self.amsgrad` is never read in this method, so the
    /// AMSGrad variant is effectively unimplemented — confirm intent.
    fn step(&mut self) {
        self.ensure_state_initialized();
        for (i, param) in self.params.iter().enumerate() {
            // Skip frozen parameters and parameters with no gradient.
            if !param.requires_grad() {
                continue;
            }
            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };
            let state = &mut self.state[i];
            state.step += 1; // 1-based step count used for bias correction
            let param_data = param.data();
            // Fused GPU fast path.
            // NOTE(review): this reuses the same `adam_step_inplace` kernel as
            // `Adam` with `weight_decay` passed through — confirm the kernel
            // applies *decoupled* decay here, otherwise GPU and CPU results
            // diverge for AdamW.
            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    self.weight_decay,
                    bias_correction1,
                    bias_correction2,
                );
                continue;
            }
            // CPU path: pull everything into host vectors, update
            // element-wise, then rebuild the tensors.
            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();
            // bias_correction_k = 1 - beta_k^t; counteracts the zero
            // initialization of the moment estimates at small t.
            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            // Decoupled decay: theta *= 1 - lr * wd, applied before the
            // moment update so the decay never touches m_t / v_t.
            let wd_factor = 1.0 - self.lr * self.weight_decay;
            let has_wd = self.weight_decay != 0.0;
            for i in 0..param_vec.len() {
                if has_wd {
                    param_vec[i] *= wd_factor;
                }
                let g = grad_vec[i];
                // m_t = beta1 * m_{t-1} + (1 - beta1) * g
                exp_avg_vec[i] = beta1 * exp_avg_vec[i] + one_minus_beta1 * g;
                // v_t = beta2 * v_{t-1} + (1 - beta2) * g^2
                exp_avg_sq_vec[i] = beta2 * exp_avg_sq_vec[i] + one_minus_beta2 * g * g;
                // theta -= (lr / bc1) * m_t / (sqrt(v_t / bc2) + eps)
                let denom = (exp_avg_sq_vec[i] / bias_correction2).sqrt() + eps;
                param_vec[i] -= step_size * exp_avg_vec[i] / denom;
            }
            // NOTE(review): `from_vec` appears to allocate host tensors; if a
            // GPU parameter ever reaches this fallback (non-"cuda" build),
            // state/parameter device placement may silently change — confirm.
            state.exp_avg = Tensor::from_vec(exp_avg_vec, param_data.shape()).unwrap();
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape()).unwrap();
            param.update_data(Tensor::from_vec(param_vec, param_data.shape()).unwrap());
        }
    }

    /// Clears the gradient of every managed parameter.
    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    /// Returns the current learning rate.
    fn get_lr(&self) -> f32 {
        self.lr
    }

    /// Sets the learning rate (e.g. from an LR scheduler).
    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    /// Returns the parameters managed by this optimizer.
    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    /// Builds a trainable 3-element parameter with values [1.0, 2.0, 3.0].
    fn make_param() -> Parameter {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        Parameter::from_variable(var)
    }

    #[test]
    fn test_adam_creation() {
        let optimizer = Adam::new(vec![make_param()], 0.001);
        // Default hyperparameters: lr as given, betas = (0.9, 0.999).
        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_adam_step() {
        let param = make_param();
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());
        let mut optimizer = Adam::new(vec![param.clone()], 0.1);
        optimizer.step();
        // A step with a non-zero gradient must move the parameter.
        let updated = param.data().to_vec();
        assert!((updated[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_adamw_creation() {
        let optimizer = AdamW::new(vec![make_param()], 0.001);
        // AdamW defaults to a non-zero decoupled weight decay of 0.01.
        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_adam_builder_pattern() {
        let optimizer = Adam::new(vec![make_param()], 0.001)
            .betas((0.95, 0.9999))
            .eps(1e-7)
            .weight_decay(0.01)
            .amsgrad(true);
        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!(optimizer.amsgrad);
    }
}