//! LAMB (Layer-wise Adaptive Moments for Batch training) optimizer
//! (You et al., 2019, "Large Batch Optimization for Deep Learning:
//! Training BERT in 76 Minutes", arXiv:1904.00962).
use axonml_core::Device;
use axonml_nn::Parameter;
use axonml_tensor::Tensor;
use crate::optimizer::Optimizer;
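/// Per-parameter LAMB state: exponential moving averages of the gradient
/// (`exp_avg`) and of its square (`exp_avg_sq`), plus the step count used
/// for bias correction.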
#[derive(Debug, Clone)]
struct LambState {
exp_avg: Tensor<f32>,
exp_avg_sq: Tensor<f32>,
step: usize,
}
impl LambState {
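    /// Allocates zero-initialized moment buffers of the given shape, moving
    /// them onto `device` when it is a GPU.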
    fn new(shape: &[usize], device: Device) -> Self {
let size: usize = shape.iter().product();
let mut exp_avg =
Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
let mut exp_avg_sq =
Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
if device.is_gpu() {
exp_avg = exp_avg.to_device(device).expect("device transfer failed");
exp_avg_sq = exp_avg_sq
.to_device(device)
.expect("device transfer failed");
}
Self {
exp_avg,
exp_avg_sq,
step: 0,
}
}
}
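/// The LAMB optimizer.
///
/// Each step computes Adam-style bias-corrected first and second moment
/// estimates, optionally folds decoupled weight decay into the update, and
/// then rescales the learning rate per parameter by the trust ratio
/// `||w|| / ||update||` (falling back to `1.0` when either norm is zero).
///
/// # Examples
///
/// A minimal usage sketch, marked `ignore` because it assumes the sibling
/// crates and the [`Optimizer`] trait are in scope and elides the
/// forward/backward pass that populates gradients:
///
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_nn::Parameter;
/// use axonml_tensor::Tensor;
///
/// let var = Variable::new(
///     Tensor::from_vec(vec![1.0f32, 2.0, 3.0], &[3]).unwrap(),
///     true,
/// );
/// let param = Parameter::from_variable(var);
/// let mut optimizer = LAMB::new(vec![param.clone()], 1e-3).weight_decay(0.01);
/// // ...forward/backward pass sets `param`'s gradient...
/// optimizer.step();
/// optimizer.zero_grad();
/// ```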
pub struct LAMB {
params: Vec<Parameter>,
lr: f32,
beta1: f32,
beta2: f32,
eps: f32,
weight_decay: f32,
bias_correction: bool,
state: Vec<LambState>,
}
impl LAMB {
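    /// Creates a LAMB optimizer with the defaults `beta1 = 0.9`,
    /// `beta2 = 0.999`, `eps = 1e-6`, zero weight decay, and bias
    /// correction enabled.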
#[must_use]
pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
Self {
params,
lr,
beta1: 0.9,
beta2: 0.999,
eps: 1e-6,
weight_decay: 0.0,
bias_correction: true,
state: Vec::new(),
}
}
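    /// Creates a LAMB optimizer with caller-supplied `(beta1, beta2)` and
    /// the default `eps` and weight decay.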
#[must_use]
pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
Self {
params,
lr,
beta1: betas.0,
beta2: betas.1,
eps: 1e-6,
weight_decay: 0.0,
bias_correction: true,
state: Vec::new(),
}
}
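    /// Creates a LAMB optimizer with explicit betas, epsilon, and weight
    /// decay; bias correction is enabled.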
#[must_use]
pub fn with_options(
params: Vec<Parameter>,
lr: f32,
betas: (f32, f32),
eps: f32,
weight_decay: f32,
) -> Self {
Self {
params,
lr,
beta1: betas.0,
beta2: betas.1,
eps,
weight_decay,
bias_correction: true,
state: Vec::new(),
}
}
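    /// Sets the moment decay rates `(beta1, beta2)` (builder style).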
#[must_use]
pub fn betas(mut self, beta1: f32, beta2: f32) -> Self {
self.beta1 = beta1;
self.beta2 = beta2;
self
}
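    /// Sets the denominator epsilon used for numerical stability (builder
    /// style).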
#[must_use]
pub fn eps(mut self, eps: f32) -> Self {
self.eps = eps;
self
}
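    /// Sets the decoupled weight-decay coefficient (builder style).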
#[must_use]
pub fn weight_decay(mut self, weight_decay: f32) -> Self {
self.weight_decay = weight_decay;
self
}
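    /// Enables or disables bias correction of the moment estimates (builder
    /// style).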
#[must_use]
pub fn bias_correction(mut self, enabled: bool) -> Self {
self.bias_correction = enabled;
self
}
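    /// Lazily allocates per-parameter moment buffers the first time `step`
    /// runs.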
fn ensure_state_initialized(&mut self) {
if self.state.is_empty() {
self.state = self
.params
.iter()
.map(|p| {
let data = p.data();
LambState::new(data.shape(), data.device())
})
.collect();
}
}
}
impl Optimizer for LAMB {
fn step(&mut self) {
self.ensure_state_initialized();
for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }
            // Skip parameters that have no gradient this step.
            let Some(grad) = param.grad() else {
                continue;
            };
            let state = &mut self.state[i];
            state.step += 1;
            let param_data = param.data();
            // First moment: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t.
            state.exp_avg = state
                .exp_avg
                .mul_scalar(self.beta1)
                .add(&grad.mul_scalar(1.0 - self.beta1))
                .unwrap();
            // Second moment: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2.
            let grad_sq = grad.mul(&grad).unwrap();
            state.exp_avg_sq = state
                .exp_avg_sq
                .mul_scalar(self.beta2)
                .add(&grad_sq.mul_scalar(1.0 - self.beta2))
                .unwrap();
            // Bias correction compensates for the zero-initialized moments;
            // with correction disabled both factors are 1.0.
            let (bias_correction1, bias_correction2) = if self.bias_correction {
                (
                    1.0 - self.beta1.powi(state.step as i32),
                    1.0 - self.beta2.powi(state.step as i32),
                )
            } else {
                (1.0, 1.0)
            };
            let m_hat = state.exp_avg.mul_scalar(1.0 / bias_correction1);
            let v_hat = state.exp_avg_sq.mul_scalar(1.0 / bias_correction2);
            // Adam-style update direction: m_hat / (sqrt(v_hat) + eps).
            let adam_update = m_hat.div(&v_hat.sqrt().add_scalar(self.eps)).unwrap();
            // Decoupled weight decay is folded into the update direction, so
            // it is scaled by the trust ratio along with the Adam term.
            let update = if self.weight_decay > 0.0 {
                adam_update
                    .add(&param_data.mul_scalar(self.weight_decay))
                    .unwrap()
            } else {
                adam_update
            };
            // Layer-wise trust ratio: ||w|| / ||update||, falling back to 1.0
            // when either norm vanishes.
            let weight_norm_sq = param_data.mul(&param_data).unwrap().sum();
            let update_norm_sq = update.mul(&update).unwrap().sum();
            let weight_norm = weight_norm_sq.to_vec()[0].sqrt();
            let update_norm = update_norm_sq.to_vec()[0].sqrt();
            let trust_ratio = if weight_norm > 0.0 && update_norm > 0.0 {
                weight_norm / update_norm
            } else {
                1.0
            };
            // Apply the trust-ratio-scaled update: w <- w - lr * ratio * u.
            let effective_lr = self.lr * trust_ratio;
            let new_param = param_data.sub(&update.mul_scalar(effective_lr)).unwrap();
            param.update_data(new_param);
}
}
fn zero_grad(&mut self) {
for param in &self.params {
param.zero_grad();
}
}
fn get_lr(&self) -> f32 {
self.lr
}
fn set_lr(&mut self, lr: f32) {
self.lr = lr;
}
fn parameters(&self) -> &[Parameter] {
&self.params
}
}
#[cfg(test)]
mod tests {
use super::*;
use axonml_autograd::Variable;
#[test]
fn test_lamb_creation() {
let var = Variable::new(
Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
true,
);
let param = Parameter::from_variable(var);
let optimizer = LAMB::new(vec![param], 0.001);
assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
}
#[test]
fn test_lamb_step() {
let var = Variable::new(
Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
true,
);
let param = Parameter::from_variable(var);
param
.variable()
.set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));
let mut optimizer = LAMB::new(vec![param.clone()], 0.1);
optimizer.step();
let new_data = param.data().to_vec();
assert!((new_data[0] - 1.0).abs() > 1e-6);
}
#[test]
fn test_lamb_with_weight_decay() {
let var = Variable::new(
Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
true,
);
let param = Parameter::from_variable(var);
param
.variable()
.set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));
let mut optimizer = LAMB::new(vec![param.clone()], 0.1).weight_decay(0.01);
optimizer.step();
let new_data = param.data().to_vec();
assert!((new_data[0] - 1.0).abs() > 1e-6);
}
#[test]
fn test_lamb_builder_pattern() {
let var = Variable::new(
Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
true,
);
let param = Parameter::from_variable(var);
let optimizer = LAMB::new(vec![param], 0.001)
.betas(0.95, 0.9999)
.eps(1e-7)
.weight_decay(0.01);
assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
assert!((optimizer.eps - 1e-7).abs() < 1e-9);
assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
}
#[test]
fn test_lamb_trust_ratio() {
let var = Variable::new(
Tensor::from_vec(vec![3.0, 4.0], &[2]).expect("tensor creation failed"),
true,
);
let param = Parameter::from_variable(var);
param
.variable()
.set_grad(Tensor::from_vec(vec![1.0, 1.0], &[2]).expect("tensor creation failed"));
let mut optimizer = LAMB::new(vec![param.clone()], 0.1);
let old_data = param.data().to_vec();
        // With w = (3.0, 4.0) and a uniform gradient the trust ratio is
        // finite and nonzero, so both components must move.
        optimizer.step();
let new_data = param.data().to_vec();
assert!((new_data[0] - old_data[0]).abs() > 1e-6);
assert!((new_data[1] - old_data[1]).abs() > 1e-6);
}
#[test]
fn test_lamb_zero_grad() {
let var = Variable::new(
Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
true,
);
let param = Parameter::from_variable(var);
param
.variable()
.set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));
let mut optimizer = LAMB::new(vec![param.clone()], 0.001);
assert!(param.grad().is_some());
        // `zero_grad` must run without panicking; how cleared gradients are
        // represented (zeroed tensor vs. `None`) is owned by `Parameter`.
        optimizer.zero_grad();
}
#[test]
fn test_l2_norm_via_tensor() {
let t = Tensor::from_vec(vec![3.0f32, 4.0], &[2]).expect("tensor creation failed");
let norm_sq = t.mul(&t).unwrap().sum();
let norm = norm_sq.to_vec()[0].sqrt();
assert!((norm - 5.0).abs() < 1e-6);
}
}