use crate::optim::Optimizer;
use crate::tensor::Tensor;
use std::collections::HashMap;
use std::ops::Add;
/// AdamW optimizer: Adam-style moment estimates combined with weight decay
/// applied directly to the parameter, with an optional AMSGrad running maximum.
pub struct AdamW {
/// Step size that multiplies the final update (and the weight-decay factor).
learning_rate: f32,
/// Decay rate of the moving average of the gradient (`momentum`).
beta1: f32,
/// Decay rate of the moving average of the squared gradient (`velocity`).
beta2: f32,
/// Small constant added to the square root of the velocity in the denominator.
eps: f32,
/// Decay coefficient; when positive the parameter is scaled by
/// `1 - learning_rate * weight_decay` before the update is subtracted.
weight_decay: f32,
/// Whether a running element-wise maximum of the velocity is tracked and used.
amsgrad: bool,
/// Counter incremented on every call to `Optimizer::step`; drives bias correction.
step: usize,
/// Per-parameter moment estimates, keyed by the parameter tensor's address.
state: HashMap<usize, AdamWState>,
}
/// Running statistics that [`AdamW`] keeps for a single parameter tensor.
struct AdamWState {
    /// Exponential moving average of the gradient.
    momentum: Tensor<f32>,
    /// Exponential moving average of the element-wise squared gradient.
    velocity: Tensor<f32>,
    /// Element-wise running maximum of the bias-corrected velocity.
    /// `Some` only when the state was created while AMSGrad was enabled.
    max_velocity: Option<Tensor<f32>>,
}
impl AdamW {
    /// Default decay rate for the first-moment estimate used by the convenience constructors.
    const DEFAULT_BETA1: f32 = 0.9;
    /// Default decay rate for the second-moment estimate used by the convenience constructors.
    const DEFAULT_BETA2: f32 = 0.999;
    /// Default denominator constant used by the convenience constructors.
    const DEFAULT_EPS: f32 = 1e-8;
    /// Default weight-decay coefficient used by [`AdamW::default_params`].
    const DEFAULT_WEIGHT_DECAY: f32 = 0.01;

    /// Creates an optimizer with every hyper-parameter given explicitly.
    ///
    /// The step counter starts at zero and no per-parameter state exists yet.
    ///
    /// # Panics
    ///
    /// Panics if `learning_rate` or `eps` is not strictly positive, if `beta1`
    /// or `beta2` lies outside `[0, 1)`, or if `weight_decay` is negative.
    /// NaN fails every one of these checks.
    pub fn new(
        learning_rate: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        assert!(learning_rate > 0.0, "Learning rate must be positive");
        assert!((0.0..1.0).contains(&beta1), "Beta1 must be in [0, 1)");
        assert!((0.0..1.0).contains(&beta2), "Beta2 must be in [0, 1)");
        assert!(eps > 0.0, "Epsilon must be positive");
        assert!(weight_decay >= 0.0, "Weight decay must be non-negative");

        Self {
            learning_rate,
            beta1,
            beta2,
            eps,
            weight_decay,
            amsgrad,
            step: 0,
            state: HashMap::new(),
        }
    }

    /// Creates an optimizer with the default betas, epsilon and weight decay
    /// (0.9, 0.999, 1e-8, 0.01) and AMSGrad disabled.
    ///
    /// # Panics
    ///
    /// Panics if `learning_rate` is not strictly positive.
    pub fn default_params(learning_rate: f32) -> Self {
        Self::with_weight_decay(learning_rate, Self::DEFAULT_WEIGHT_DECAY)
    }

    /// Creates an optimizer with the default betas and epsilon, a custom weight
    /// decay, and AMSGrad disabled.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`AdamW::new`].
    pub fn with_weight_decay(learning_rate: f32, weight_decay: f32) -> Self {
        Self::new(
            learning_rate,
            Self::DEFAULT_BETA1,
            Self::DEFAULT_BETA2,
            Self::DEFAULT_EPS,
            weight_decay,
            false,
        )
    }

    /// Creates an optimizer with the default betas and epsilon, a custom weight
    /// decay, and AMSGrad enabled.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`AdamW::new`].
    pub fn with_amsgrad(learning_rate: f32, weight_decay: f32) -> Self {
        Self::new(
            learning_rate,
            Self::DEFAULT_BETA1,
            Self::DEFAULT_BETA2,
            Self::DEFAULT_EPS,
            weight_decay,
            true,
        )
    }

    /// Returns the `(momentum, velocity)` tensors stored for `param_id`, or
    /// `None` if no state has been created under that id.
    pub fn get_state(&self, param_id: usize) -> Option<(&Tensor<f32>, &Tensor<f32>)> {
        let entry = self.state.get(&param_id)?;
        Some((&entry.momentum, &entry.velocity))
    }

    /// Returns the state stored under `param_id`, inserting zero-initialised
    /// tensors of the given `shape` on first use.
    ///
    /// The AMSGrad maximum is allocated only if AMSGrad is enabled at the
    /// moment the entry is created.
    fn get_or_create_state(&mut self, param_id: usize, shape: &[usize]) -> &mut AdamWState {
        // Copy the flag out first so the closure does not have to borrow `self`
        // while `self.state` is mutably borrowed by `entry`.
        let amsgrad = self.amsgrad;
        self.state.entry(param_id).or_insert_with(|| AdamWState {
            momentum: Tensor::zeros(shape),
            velocity: Tensor::zeros(shape),
            max_velocity: amsgrad.then(|| Tensor::zeros(shape)),
        })
    }
}
impl Optimizer for AdamW {
    /// Applies one AdamW update to `param`, overwriting its data in place.
    ///
    /// The update proceeds as follows:
    /// 1. refresh the moving averages of the gradient and squared gradient;
    /// 2. bias-correct both using the optimizer's step counter;
    /// 3. with AMSGrad, replace the corrected velocity by its running maximum;
    /// 4. scale the parameter by `1 - learning_rate * weight_decay` (when the
    ///    decay is positive) and subtract `learning_rate * m / (sqrt(v) + eps)`.
    ///
    /// NOTE(review): the step counter is shared by all parameters, so calling
    /// `step` for several parameters per iteration advances the bias correction
    /// once per call rather than once per iteration. Confirm this is intended.
    ///
    /// # Panics
    ///
    /// Panics if any involved tensor cannot be viewed as a slice, or if the
    /// computed update does not have the same number of elements as `param`.
    fn step(&mut self, param: &Tensor<f32>, grad: &Tensor<f32>) {
        self.step += 1;

        // State is keyed by the address of the parameter tensor, so the same
        // tensor object must be passed on every call to reuse its moments.
        let param_id = param as *const _ as usize;

        // Copy the hyper-parameters out before `state` mutably borrows `self`.
        let beta1 = self.beta1;
        let beta2 = self.beta2;
        let eps = self.eps;
        let weight_decay = self.weight_decay;
        let learning_rate = self.learning_rate;
        let amsgrad = self.amsgrad;
        // Saturate instead of wrapping: a wrapped (negative) exponent would make
        // `powi` blow up, while `beta^i32::MAX` is simply 0 for `beta < 1`.
        let exponent = self.step.min(i32::MAX as usize) as i32;

        let state = self.get_or_create_state(param_id, param.shape());

        // First moment: m = beta1 * m + (1 - beta1) * g
        let decayed_momentum = &state.momentum * beta1;
        let grad_term = grad * (1.0 - beta1);
        state.momentum = &decayed_momentum + &grad_term;

        // Second moment: v = beta2 * v + (1 - beta2) * g^2
        let grad_squared = grad * grad;
        let decayed_velocity = &state.velocity * beta2;
        let grad_squared_term = &grad_squared * (1.0 - beta2);
        state.velocity = &decayed_velocity + &grad_squared_term;

        let bias_correction1 = 1.0 - beta1.powi(exponent);
        let bias_correction2 = 1.0 - beta2.powi(exponent);
        let momentum_corrected = &state.momentum / bias_correction1;
        let velocity_corrected = &state.velocity / bias_correction2;

        let velocity_hat = match state.max_velocity.as_mut() {
            Some(max_vel) if amsgrad => {
                // AMSGrad: keep the element-wise maximum of every corrected
                // velocity seen so far and use it in the denominator.
                let running_max: Vec<f32> = max_vel
                    .as_slice()
                    .expect("max velocity must be viewable as a slice")
                    .iter()
                    .zip(
                        velocity_corrected
                            .as_slice()
                            .expect("velocity must be viewable as a slice")
                            .iter(),
                    )
                    .map(|(previous, current)| previous.max(*current))
                    .collect();
                let shape = max_vel.shape().to_vec();
                *max_vel = Tensor::from_vec(running_max, shape);
                max_vel.clone()
            }
            // AMSGrad disabled, or the state was created without a maximum.
            _ => velocity_corrected,
        };

        let velocity_sqrt = velocity_hat.sqrt();
        let denominator = &velocity_sqrt + eps;
        let update = &momentum_corrected / &denominator;

        // Decoupled weight decay acts on the parameter itself, not the gradient.
        let param_with_decay = if weight_decay > 0.0 {
            param * (1.0 - learning_rate * weight_decay)
        } else {
            param.clone()
        };
        let scaled_update = &update * learning_rate;
        let new_param = &param_with_decay - &scaled_update;

        let src = new_param
            .as_slice()
            .expect("updated parameter must be viewable as a slice");
        let dst = param
            .as_slice()
            .expect("parameter must be viewable as a slice");
        assert_eq!(
            src.len(),
            dst.len(),
            "updated parameter must have as many elements as the parameter"
        );

        // SAFETY: `src` and `dst` are valid for `dst.len()` `f32` elements (the
        // lengths were just checked to be equal), and `new_param` is a tensor
        // produced by the arithmetic above, presumably with its own storage, so
        // the two ranges should not overlap.
        // NOTE(review): the write goes through a pointer derived from a shared
        // `&Tensor`; this is only sound if `Tensor` permits mutation behind a
        // shared reference. The trait signature forces it here — confirm.
        unsafe {
            std::ptr::copy_nonoverlapping(src.as_ptr(), dst.as_ptr() as *mut f32, dst.len());
        }
    }

    /// No-op: this optimizer stores no gradients of its own.
    fn zero_grad(&mut self) {}

    /// Returns the current learning rate.
    fn learning_rate(&self) -> f32 {
        self.learning_rate
    }

    /// Replaces the learning rate.
    ///
    /// # Panics
    ///
    /// Panics if `lr` is not strictly positive.
    fn set_learning_rate(&mut self, lr: f32) {
        assert!(lr > 0.0, "Learning rate must be positive");
        self.learning_rate = lr;
    }

    /// Exports the hyper-parameters and the step counter as a flat map.
    ///
    /// `amsgrad` is encoded as `1.0`/`0.0` and `step` is converted to `f32`.
    /// The per-parameter moment tensors are not included.
    fn state_dict(&self) -> HashMap<String, f32> {
        let entries = [
            ("learning_rate", self.learning_rate),
            ("beta1", self.beta1),
            ("beta2", self.beta2),
            ("eps", self.eps),
            ("weight_decay", self.weight_decay),
            ("amsgrad", if self.amsgrad { 1.0 } else { 0.0 }),
            ("step", self.step as f32),
        ];
        entries
            .iter()
            .map(|&(key, value)| (key.to_string(), value))
            .collect()
    }

    /// Restores whichever of the keys written by `state_dict` are present in
    /// `state`; missing keys leave the current value untouched.
    ///
    /// NOTE(review): unlike `new` and `set_learning_rate`, the loaded values are
    /// not range-checked, and existing per-parameter state is left as it is.
    fn load_state_dict(&mut self, state: HashMap<String, f32>) {
        if let Some(&lr) = state.get("learning_rate") {
            self.learning_rate = lr;
        }
        if let Some(&beta1) = state.get("beta1") {
            self.beta1 = beta1;
        }
        if let Some(&beta2) = state.get("beta2") {
            self.beta2 = beta2;
        }
        if let Some(&eps) = state.get("eps") {
            self.eps = eps;
        }
        if let Some(&wd) = state.get("weight_decay") {
            self.weight_decay = wd;
        }
        if let Some(&amsgrad) = state.get("amsgrad") {
            // The flag is stored as 1.0 / 0.0; anything above 0.5 counts as on.
            self.amsgrad = amsgrad > 0.5;
        }
        if let Some(&step) = state.get("step") {
            self.step = step as usize;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `default_params` keeps the given learning rate and uses a 0.01 decay.
    #[test]
    fn test_adamw_creation() {
        let optimizer = AdamW::default_params(0.001);
        assert_eq!(optimizer.learning_rate, 0.001);
        assert_eq!(optimizer.weight_decay, 0.01);
    }

    /// `with_weight_decay` stores the requested decay coefficient.
    #[test]
    fn test_adamw_with_weight_decay() {
        let optimizer = AdamW::with_weight_decay(0.001, 0.1);
        assert_eq!(optimizer.weight_decay, 0.1);
    }

    /// A step with positive gradients must lower every parameter value.
    #[test]
    fn test_adamw_step() {
        let mut optimizer = AdamW::default_params(0.1);
        let param = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
        let grad = Tensor::from_vec(vec![0.1, 0.2, 0.3], vec![3]);

        optimizer.step(&param, &grad);

        let param_data = param.as_slice().unwrap();
        assert!(param_data[0] < 1.0);
        assert!(param_data[1] < 2.0);
        assert!(param_data[2] < 3.0);
    }

    /// With AMSGrad enabled, the state created by a step tracks a maximum.
    #[test]
    fn test_adamw_amsgrad() {
        let mut optimizer = AdamW::with_amsgrad(0.001, 0.01);
        assert!(optimizer.amsgrad);

        let param = Tensor::randn(&[10]);
        let grad = Tensor::randn(&[10]);
        optimizer.step(&param, &grad);

        // Same key derivation as `step`: the address of the parameter tensor.
        let param_id = &param as *const _ as usize;
        assert!(optimizer
            .state
            .get(&param_id)
            .unwrap()
            .max_velocity
            .is_some());
    }
}