use candle_core::{Result, Tensor, Device};
use std::collections::HashMap;
/// Hyper-parameters for the AdamW optimizer.
///
/// All fields are public so callers can build a config with struct-update
/// syntax, e.g. `AdamWConfig { lr: 0.01, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct AdamWConfig {
    /// Learning rate: scales both the Adam update and the weight-decay shrink.
    pub lr: f64,
    /// Decay rate of the running first-moment (gradient mean) estimate.
    pub beta1: f64,
    /// Decay rate of the running second-moment (squared-gradient mean) estimate.
    pub beta2: f64,
    /// Added to `sqrt(v_hat)` in the update denominator to avoid division by zero.
    pub eps: f64,
    /// Decoupled weight-decay coefficient. When positive, each parameter is
    /// multiplied by `1 - lr * weight_decay` on every step.
    pub weight_decay: f64,
}
impl Default for AdamWConfig {
fn default() -> Self {
Self {
lr: 1e-3,
beta1: 0.9,
beta2: 0.999,
eps: 1e-8,
weight_decay: 0.01,
}
}
}
/// Optimizer state carried between `step` calls for a single parameter.
#[derive(Debug, Clone)]
struct ParamState {
    /// Exponential moving average of the gradient (first moment).
    m: Tensor,
    /// Exponential moving average of the squared gradient (second moment).
    v: Tensor,
    /// How many updates this parameter has received; used as the exponent
    /// in the bias-correction terms `1 - beta^step`.
    step: usize,
}
/// AdamW optimizer: Adam moment estimates plus weight decay applied directly
/// to the parameters rather than folded into the gradient.
pub struct AdamW {
    /// Hyper-parameters; the learning rate may be changed between steps.
    config: AdamWConfig,
    /// Parameters being optimized; entries are replaced by each `step`.
    params: Vec<Tensor>,
    /// Moment buffers keyed by the parameter's index in `params`,
    /// created lazily the first time that parameter is updated.
    states: HashMap<usize, ParamState>,
}
impl AdamW {
pub fn new(params: Vec<Tensor>, config: AdamWConfig) -> Result<Self> {
Ok(Self {
config,
params,
states: HashMap::new(),
})
}
pub fn step(&mut self, grads: &[Tensor]) -> Result<()> {
if grads.len() != self.params.len() {
return Err(candle_core::Error::Msg(format!(
"Expected {} gradients, got {}",
self.params.len(),
grads.len()
)));
}
for (i, (param, grad)) in self.params.iter_mut().zip(grads.iter()).enumerate() {
let state = self.states.entry(i).or_insert_with(|| {
let device = param.device();
let shape = param.shape();
ParamState {
m: Tensor::zeros(shape, param.dtype(), device).unwrap(),
v: Tensor::zeros(shape, param.dtype(), device).unwrap(),
step: 0,
}
});
state.step += 1;
state.m = ((state.m.clone() * self.config.beta1)?
+ (grad * (1.0 - self.config.beta1))?)?;
state.v = ((state.v.clone() * self.config.beta2)?
+ (grad.sqr()? * (1.0 - self.config.beta2))?)?;
let beta1_t = self.config.beta1.powi(state.step as i32);
let m_hat = (state.m.clone() / (1.0 - beta1_t))?;
let beta2_t = self.config.beta2.powi(state.step as i32);
let v_hat = (state.v.clone() / (1.0 - beta2_t))?;
let update = ((m_hat / (v_hat.sqrt()? + self.config.eps)?)? * self.config.lr)?;
let param_decayed = if self.config.weight_decay > 0.0 {
(param.clone() * (1.0 - self.config.lr * self.config.weight_decay))?
} else {
param.clone()
};
*param = (param_decayed - update)?;
}
Ok(())
}
pub fn zero_grad(&mut self) {
}
pub fn get_lr(&self) -> f64 {
self.config.lr
}
pub fn set_lr(&mut self, lr: f64) {
self.config.lr = lr;
}
pub fn params(&self) -> &[Tensor] {
&self.params
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adamw_creation() -> Result<()> {
        let device = Device::Cpu;
        let param = Tensor::randn(0f32, 1.0, (10, 10), &device)?;
        let config = AdamWConfig::default();
        let optimizer = AdamW::new(vec![param], config)?;
        assert_eq!(optimizer.get_lr(), 1e-3);
        Ok(())
    }

    #[test]
    fn test_adamw_step() -> Result<()> {
        let device = Device::Cpu;
        let param = Tensor::randn(0f32, 1.0, (10, 10), &device)?;
        let grad = Tensor::ones((10, 10), param.dtype(), &device)?;
        let config = AdamWConfig {
            lr: 0.01,
            ..Default::default()
        };
        let mut optimizer = AdamW::new(vec![param.clone()], config)?;
        optimizer.step(&[grad])?;

        // With zero-initialised moments, bias correction makes the first step
        // lr * g / (|g| + eps), i.e. approximately lr for a unit gradient.
        // Decoupled decay scales the old value by (1 - lr * weight_decay),
        // where the default weight_decay is 0.01.
        let before = param.to_vec2::<f32>()?;
        let after = optimizer.params()[0].to_vec2::<f32>()?;
        let decay = 1.0 - 0.01 * 0.01;
        for (row_before, row_after) in before.iter().zip(&after) {
            for (&b, &a) in row_before.iter().zip(row_after) {
                let expected = b * decay - 0.01;
                assert!(
                    (a - expected).abs() < 1e-4,
                    "got {a}, expected {expected}"
                );
            }
        }
        Ok(())
    }

    #[test]
    fn test_adamw_rejects_gradient_count_mismatch() -> Result<()> {
        let device = Device::Cpu;
        let param = Tensor::randn(0f32, 1.0, (10, 10), &device)?;
        let mut optimizer = AdamW::new(vec![param], AdamWConfig::default())?;
        assert!(optimizer.step(&[]).is_err());
        Ok(())
    }

    #[test]
    fn test_adamw_rejects_gradient_shape_mismatch() -> Result<()> {
        let device = Device::Cpu;
        let param = Tensor::randn(0f32, 1.0, (10, 10), &device)?;
        let grad = Tensor::ones((5, 5), param.dtype(), &device)?;
        let mut optimizer = AdamW::new(vec![param], AdamWConfig::default())?;
        assert!(optimizer.step(&[grad]).is_err());
        Ok(())
    }

    #[test]
    fn test_adamw_lr_scheduling() -> Result<()> {
        let device = Device::Cpu;
        let param = Tensor::randn(0f32, 1.0, (10, 10), &device)?;
        let config = AdamWConfig::default();
        let mut optimizer = AdamW::new(vec![param], config)?;
        assert_eq!(optimizer.get_lr(), 1e-3);
        optimizer.set_lr(5e-4);
        assert_eq!(optimizer.get_lr(), 5e-4);
        Ok(())
    }
}