use crate::error::Result;
use crate::ops::FusedOptimizerOps;
use crate::optimizer::traits::Optimizer;
use numr::autograd::GradStore;
use numr::dtype::DType;
use numr::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::{Tensor, TensorId};
use std::collections::HashMap;
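/// Hyperparameters for AdamW (Adam with decoupled weight decay,
/// Loshchilov & Hutter 2019). The defaults match the commonly used
/// values: lr 1e-3, betas (0.9, 0.999), eps 1e-8, weight decay 0.01.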
#[derive(Debug, Clone)]
pub struct AdamWConfig {
    /// Learning rate.
    pub lr: f64,
    /// Exponential decay rate for the first-moment estimate.
    pub beta1: f64,
    /// Exponential decay rate for the second-moment estimate.
    pub beta2: f64,
    /// Small constant added to the denominator for numerical stability.
    pub eps: f64,
    /// Decoupled weight-decay coefficient (applied directly to the
    /// weights, not through the gradient).
    pub weight_decay: f64,
}
impl Default for AdamWConfig {
fn default() -> Self {
Self {
lr: 1e-3,
beta1: 0.9,
beta2: 0.999,
eps: 1e-8,
weight_decay: 0.01,
}
}
}
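/// Per-parameter optimizer state: running first-moment (`m`) and
/// second-moment (`v`) estimates, kept on the same device and dtype as
/// the parameter they track.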
struct ParamState<R: Runtime> {
m: Tensor<R>,
v: Tensor<R>,
}
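/// AdamW optimizer driving a fused multi-tensor kernel.
///
/// Moment buffers are allocated lazily the first time a parameter id
/// shows up with a gradient, so the optimizer can be constructed before
/// any parameters are registered.
///
/// A minimal usage sketch (client, params, and grads set up as in the
/// tests at the bottom of this file):
///
/// ```ignore
/// let mut opt = AdamW::<CpuRuntime>::new(AdamWConfig::default());
/// opt.step(&client, &mut params, &grads)?;
/// ```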
pub struct AdamW<R: Runtime> {
config: AdamWConfig,
state: HashMap<TensorId, ParamState<R>>,
timestep: u64,
}
impl<R: Runtime<DType = DType>> AdamW<R> {
pub fn new(config: AdamWConfig) -> Self {
Self {
config,
state: HashMap::new(),
timestep: 0,
}
}
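    /// Runs one fused AdamW update over every parameter in `params` that
    /// has an entry in `grads`; parameters without gradients are skipped.
    ///
    /// Per element, the intended update (with `g` the gradient) is
    ///
    ///   m <- beta1 * m + (1 - beta1) * g
    ///   v <- beta2 * v + (1 - beta2) * g^2
    ///   p <- p - lr * wd * p - step_size * m / (sqrt(v) + eps)
    ///
    /// where `step_size = lr * sqrt(1 - beta2^t) / (1 - beta1^t)` folds the
    /// bias corrections into the learning rate. The exact epsilon placement
    /// is delegated to the kernel behind `FusedOptimizerOps`.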
#[allow(clippy::type_complexity)]
pub fn step<C>(
&mut self,
client: &C,
params: &mut HashMap<TensorId, Tensor<R>>,
grads: &GradStore<R>,
) -> Result<()>
where
C: RuntimeClient<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ ScalarOps<R>
+ ReduceOps<R>
+ FusedOptimizerOps<R>,
{
        // Collect the ids that actually have gradients, lazily creating
        // zeroed moment buffers for any parameter seen for the first time.
        let mut ids_with_grads: Vec<TensorId> = Vec::new();
        for (&id, param) in params.iter() {
            if grads.get(id).is_some() {
                self.state.entry(id).or_insert_with(|| ParamState {
                    m: Tensor::<R>::zeros(param.shape(), param.dtype(), param.device()),
                    v: Tensor::<R>::zeros(param.shape(), param.dtype(), param.device()),
                });
                ids_with_grads.push(id);
            }
        }
        if ids_with_grads.is_empty() {
            // No gradients this call: leave the timestep alone so empty
            // steps do not skew the bias correction.
            return Ok(());
        }
        self.timestep += 1;
        let t = self.timestep;
        let beta1 = self.config.beta1;
        let beta2 = self.config.beta2;
        let lr = self.config.lr;
        let eps = self.config.eps;
        let wd = self.config.weight_decay;
        // Fold the bias corrections for m and v into a single scalar so the
        // fused kernel can scale m / (sqrt(v) + eps) by step_size directly.
        let bc1 = 1.0 - beta1.powi(t as i32);
        let bc2 = 1.0 - beta2.powi(t as i32);
        let step_size = lr * bc2.sqrt() / bc1;
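        // Gather (param, grad, m, v) references for every updated id so the
        // fused op can process all parameter groups in a single dispatch.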
let groups: Vec<(&Tensor<R>, &Tensor<R>, &Tensor<R>, &Tensor<R>)> = ids_with_grads
.iter()
.map(|id| {
let param = params
.get(id)
.expect("id came from ids_with_grads, which was built from params.keys()");
let grad = grads.get(*id).expect(
"id came from ids_with_grads, which was filtered to only ids with grads",
);
let state = self
.state
.get(id)
.expect("state was lazily initialized for all ids in ids_with_grads");
(param, grad, &state.m, &state.v)
})
.collect();
let results =
client.fused_multi_tensor_adamw(&groups, lr, beta1, beta2, eps, wd, step_size)?;
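        // Write the new parameter and moment tensors back into their maps.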
for (id, (new_param, new_m, new_v)) in ids_with_grads.iter().zip(results) {
let state_mut = self
.state
.get_mut(id)
.expect("state was initialized for all ids in ids_with_grads");
state_mut.m = new_m;
state_mut.v = new_v;
params.insert(*id, new_param);
}
Ok(())
}
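    /// Number of optimization steps completed so far.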
pub fn timestep(&self) -> u64 {
self.timestep
}
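    /// Current hyperparameter configuration.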
pub fn config(&self) -> &AdamWConfig {
&self.config
}
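    /// Clears all per-parameter moment state and resets the timestep.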
pub fn reset(&mut self) {
self.state.clear();
self.timestep = 0;
}
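    /// Sets the learning rate in place, e.g. from an LR scheduler.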
pub fn set_lr(&mut self, lr: f64) {
self.config.lr = lr;
}
}
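// The `Optimizer` trait impl forwards to the inherent methods so AdamW can
// be driven through the generic optimizer interface.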
impl<R: Runtime<DType = DType>> Optimizer<R> for AdamW<R> {
fn step<C>(
&mut self,
client: &C,
params: &mut HashMap<TensorId, Tensor<R>>,
grads: &GradStore<R>,
) -> Result<()>
where
C: RuntimeClient<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ ScalarOps<R>
+ ReduceOps<R>
+ FusedOptimizerOps<R>,
{
AdamW::step(self, client, params, grads)
}
fn set_lr(&mut self, lr: f64) {
AdamW::set_lr(self, lr);
}
fn lr(&self) -> f64 {
self.config.lr
}
fn reset(&mut self) {
AdamW::reset(self);
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::cpu_setup;
    use numr::autograd::{Var, backward, var_matmul, var_mean};
use numr::runtime::cpu::CpuRuntime;
#[test]
fn test_adamw_default_config() {
let config = AdamWConfig::default();
assert_eq!(config.lr, 1e-3);
assert_eq!(config.beta1, 0.9);
assert_eq!(config.beta2, 0.999);
assert_eq!(config.eps, 1e-8);
assert_eq!(config.weight_decay, 0.01);
}
#[test]
fn test_adamw_single_step() {
let (client, device) = cpu_setup();
let w_tensor = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &device);
let w = Var::new(w_tensor, true);
let w_id = w.id();
let x = Tensor::<CpuRuntime>::from_slice(&[0.5f32, 0.5, 0.5, 0.5], &[2, 2], &device);
let x_var = Var::new(x, false);
let out = var_matmul(&w, &x_var, &client).unwrap();
let loss = var_mean(&out, &[0, 1], false, &client).unwrap();
let grads = backward(&loss, &client).unwrap();
let mut params = HashMap::new();
let w_data = w.tensor().clone();
params.insert(w_id, w_data);
let config = AdamWConfig::default();
let mut opt = AdamW::<CpuRuntime>::new(config);
opt.step(&client, &mut params, &grads).unwrap();
assert_eq!(opt.timestep(), 1);
let updated = params.get(&w_id).unwrap();
let updated_data = updated.to_vec::<f32>();
let original = vec![1.0f32, 2.0, 3.0, 4.0];
assert_ne!(updated_data, original, "params should change after step");
}
#[test]
fn test_adamw_multiple_steps_decrease_loss() {
let (client, device) = cpu_setup();
let target = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &device);
let w_init = Tensor::<CpuRuntime>::from_slice(&[0.0f32, 0.0, 0.0, 0.0], &[2, 2], &device);
let w_id = w_init.id();
let mut params = HashMap::new();
params.insert(w_id, w_init);
        let config = AdamWConfig {
            lr: 0.1,
            weight_decay: 0.0,
            ..Default::default()
        };
let mut opt = AdamW::<CpuRuntime>::new(config);
let mut first_loss = 0.0f64;
let mut last_loss = 0.0f64;
for i in 0..20 {
let w_tensor = params.get(&w_id).unwrap().clone();
let w = Var::with_id(w_tensor, w_id, true);
let t = Var::new(target.clone(), false);
let diff = numr::autograd::var_sub(&w, &t, &client).unwrap();
let sq = numr::autograd::var_mul(&diff, &diff, &client).unwrap();
let loss = var_mean(&sq, &[0, 1], false, &client).unwrap();
let loss_val = loss.tensor().to_vec::<f32>()[0] as f64;
if i == 0 {
first_loss = loss_val;
}
last_loss = loss_val;
let grads = backward(&loss, &client).unwrap();
opt.step(&client, &mut params, &grads).unwrap();
}
assert!(
last_loss < first_loss * 0.1,
"loss should decrease significantly: first={} last={}",
first_loss,
last_loss
);
let final_w = params.get(&w_id).unwrap().to_vec::<f32>();
assert!(
(final_w[0] - 1.0).abs() < 0.3,
"w[0] should approach 1.0, got {}",
final_w[0]
);
}
#[test]
fn test_adamw_weight_decay() {
let (client, device) = cpu_setup();
let w_tensor = Tensor::<CpuRuntime>::from_slice(&[5.0f32, 5.0], &[2], &device);
let w_id = w_tensor.id();
let mut grads = GradStore::new();
let zero_grad = Tensor::<CpuRuntime>::zeros(&[2], DType::F32, &device);
grads.insert(w_id, zero_grad);
let mut params = HashMap::new();
params.insert(w_id, w_tensor);
let config = AdamWConfig {
lr: 0.1,
weight_decay: 0.1,
..Default::default()
};
let mut opt = AdamW::<CpuRuntime>::new(config);
opt.step(&client, &mut params, &grads).unwrap();
let updated = params.get(&w_id).unwrap().to_vec::<f32>();
assert!(
updated[0] < 5.0,
"weight decay should shrink params, got {}",
updated[0]
);
}
    #[test]
    fn test_adamw_reset() {
        let mut opt: AdamW<CpuRuntime> = AdamW::new(AdamWConfig::default());
        assert_eq!(opt.timestep(), 0);
        // Tests are a child module, so they may touch the private field to
        // simulate prior steps without needing a runtime client.
        opt.timestep = 3;
        opt.reset();
        assert_eq!(opt.timestep(), 0);
        assert!(opt.state.is_empty());
    }
#[test]
fn test_adamw_set_lr() {
let mut opt: AdamW<CpuRuntime> = AdamW::new(AdamWConfig::default());
opt.set_lr(0.01);
assert_eq!(opt.config().lr, 0.01);
}
#[test]
fn test_adamw_skips_missing_grads() {
let (client, device) = cpu_setup();
let w_tensor = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
let w_id = w_tensor.id();
let mut params = HashMap::new();
params.insert(w_id, w_tensor.clone());
let grads = GradStore::new();
let mut opt = AdamW::<CpuRuntime>::new(AdamWConfig::default());
opt.step(&client, &mut params, &grads).unwrap();
let updated = params.get(&w_id).unwrap().to_vec::<f32>();
assert_eq!(updated, vec![1.0, 2.0]);
}
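    // A small sketch that exercises the `Optimizer` trait surface via fully
    // qualified calls; it relies only on the trait methods defined above.
    #[test]
    fn test_adamw_optimizer_trait_lr() {
        let mut opt: AdamW<CpuRuntime> = AdamW::new(AdamWConfig::default());
        Optimizer::set_lr(&mut opt, 0.5);
        assert_eq!(Optimizer::lr(&opt), 0.5);
        Optimizer::reset(&mut opt);
        assert_eq!(opt.timestep(), 0);
    }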
}