use crate::error::Result;
use crate::ops::FusedOptimizerOps;
use crate::optimizer::traits::Optimizer;
use numr::autograd::GradStore;
use numr::dtype::DType;
use numr::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::{Tensor, TensorId};
use std::collections::HashMap;
/// Hyperparameters for the [`Sgd`] optimizer.
///
/// Every field is forwarded unchanged to the fused SGD kernel on each
/// optimizer step; the exact update formula is defined by that kernel.
#[derive(Debug, Clone, PartialEq)]
pub struct SgdConfig {
    /// Learning rate (step size).
    pub lr: f64,
    /// Momentum factor. Momentum buffers are only retained between steps
    /// when this is greater than `0.0`.
    pub momentum: f64,
    /// Weight decay coefficient.
    pub weight_decay: f64,
    /// Dampening factor handed to the kernel alongside `momentum`.
    pub dampening: f64,
    /// Whether to request the Nesterov variant of the momentum update.
    pub nesterov: bool,
}
impl Default for SgdConfig {
fn default() -> Self {
Self {
lr: 0.01,
momentum: 0.0,
weight_decay: 0.0,
dampening: 0.0,
nesterov: false,
}
}
}
/// Stochastic gradient descent optimizer.
///
/// Owns its hyperparameters together with the per-parameter buffers that
/// have to survive from one optimizer step to the next.
pub struct Sgd<R: Runtime> {
    /// Hyperparameters handed to the fused kernel on every step.
    config: SgdConfig,
    /// Momentum buffers keyed by parameter id. An entry is only written
    /// when the configured momentum is greater than zero.
    velocity: HashMap<TensorId, Tensor<R>>,
}
impl<R: Runtime<DType = DType>> Sgd<R> {
pub fn new(config: SgdConfig) -> Self {
Self {
config,
velocity: HashMap::new(),
}
}
pub fn config(&self) -> &SgdConfig {
&self.config
}
}
impl<R: Runtime<DType = DType>> Optimizer<R> for Sgd<R> {
fn step<C>(
&mut self,
client: &C,
params: &mut HashMap<TensorId, Tensor<R>>,
grads: &GradStore<R>,
) -> Result<()>
where
C: RuntimeClient<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ ScalarOps<R>
+ ReduceOps<R>
+ FusedOptimizerOps<R>,
{
let lr = self.config.lr;
let momentum = self.config.momentum;
let wd = self.config.weight_decay;
let dampening = self.config.dampening;
let nesterov = self.config.nesterov;
let param_ids: Vec<TensorId> = params.keys().copied().collect();
for id in param_ids {
let grad = match grads.get(id) {
Some(g) => g,
None => continue,
};
let param = params.get(&id).expect("id collected from params.keys()");
let momentum_buf = self.velocity.get(&id);
let (new_param, new_buf) = client.fused_sgd_step(
param,
grad,
momentum_buf,
lr,
momentum,
dampening,
wd,
nesterov,
)?;
if momentum > 0.0 {
self.velocity.insert(id, new_buf);
}
params.insert(id, new_param);
}
Ok(())
}
fn set_lr(&mut self, lr: f64) {
self.config.lr = lr;
}
fn lr(&self) -> f64 {
self.config.lr
}
fn reset(&mut self) {
self.velocity.clear();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::cpu_setup;
    use numr::autograd::{Var, backward, var_mean, var_mul, var_sub};
    use numr::runtime::cpu::CpuRuntime;

    /// Trains a zero-initialised 2x2 weight towards a fixed target for 50
    /// steps using a mean-squared-error loss.
    ///
    /// Returns `(first_loss, last_loss)`: the loss measured at the start of
    /// the first iteration and at the start of the final iteration.
    fn train_towards_target(config: SgdConfig) -> (f64, f64) {
        let (client, device) = cpu_setup();
        let target = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &device);
        let w_init = Tensor::<CpuRuntime>::from_slice(&[0.0f32, 0.0, 0.0, 0.0], &[2, 2], &device);
        let w_id = w_init.id();
        let mut params = HashMap::new();
        params.insert(w_id, w_init);
        let mut opt = Sgd::<CpuRuntime>::new(config);
        let mut first_loss = 0.0f64;
        let mut last_loss = 0.0f64;
        for i in 0..50 {
            // Rebuild the graph from the current weights on every iteration.
            let w_tensor = params.get(&w_id).unwrap().clone();
            let w = Var::with_id(w_tensor, w_id, true);
            let t = Var::new(target.clone(), false);
            let diff = var_sub(&w, &t, &client).unwrap();
            let sq = var_mul(&diff, &diff, &client).unwrap();
            let loss = var_mean(&sq, &[0, 1], false, &client).unwrap();
            let loss_val = loss.tensor().to_vec::<f32>()[0] as f64;
            if i == 0 {
                first_loss = loss_val;
            }
            last_loss = loss_val;
            let grads = backward(&loss, &client).unwrap();
            opt.step(&client, &mut params, &grads).unwrap();
        }
        (first_loss, last_loss)
    }

    #[test]
    fn test_sgd_default_config() {
        let config = SgdConfig::default();
        assert_eq!(config.lr, 0.01);
        assert_eq!(config.momentum, 0.0);
        assert_eq!(config.weight_decay, 0.0);
        assert_eq!(config.dampening, 0.0);
        assert!(!config.nesterov);
    }

    #[test]
    fn test_sgd_vanilla_step() {
        let (client, device) = cpu_setup();
        let w_tensor = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &device);
        let w_id = w_tensor.id();
        let grad = Tensor::<CpuRuntime>::from_slice(&[0.1f32, 0.2, 0.3, 0.4], &[2, 2], &device);
        let mut grads = GradStore::new();
        grads.insert(w_id, grad);
        let mut params = HashMap::new();
        params.insert(w_id, w_tensor);
        let config = SgdConfig {
            lr: 0.1,
            ..Default::default()
        };
        let mut opt = Sgd::<CpuRuntime>::new(config);
        opt.step(&client, &mut params, &grads).unwrap();
        let updated = params.get(&w_id).unwrap().to_vec::<f32>();
        // Plain SGD: w - lr * grad for every element.
        let expected = [0.99f32, 1.98, 2.97, 3.96];
        // Guard against `zip` silently truncating on a length mismatch.
        assert_eq!(updated.len(), expected.len());
        for (index, (&got, want)) in updated.iter().zip(expected).enumerate() {
            assert!(
                (got - want).abs() < 1e-6,
                "element {index}: got {got}, want {want}"
            );
        }
    }

    #[test]
    fn test_sgd_momentum_converges() {
        let (first_loss, last_loss) = train_towards_target(SgdConfig {
            lr: 0.1,
            momentum: 0.9,
            ..Default::default()
        });
        assert!(
            last_loss < first_loss * 0.01,
            "loss should decrease: first={first_loss} last={last_loss}"
        );
    }

    #[test]
    fn test_sgd_nesterov() {
        let (first_loss, last_loss) = train_towards_target(SgdConfig {
            lr: 0.1,
            momentum: 0.9,
            nesterov: true,
            ..Default::default()
        });
        assert!(
            last_loss < first_loss * 0.01,
            "nesterov should converge: first={first_loss} last={last_loss}"
        );
    }

    #[test]
    fn test_sgd_weight_decay() {
        let (client, device) = cpu_setup();
        let w_tensor = Tensor::<CpuRuntime>::from_slice(&[5.0f32, 5.0], &[2], &device);
        let w_id = w_tensor.id();
        let zero_grad = Tensor::<CpuRuntime>::zeros(&[2], DType::F32, &device);
        let mut grads = GradStore::new();
        grads.insert(w_id, zero_grad);
        let mut params = HashMap::new();
        params.insert(w_id, w_tensor);
        let config = SgdConfig {
            lr: 0.1,
            weight_decay: 0.1,
            ..Default::default()
        };
        let mut opt = Sgd::<CpuRuntime>::new(config);
        opt.step(&client, &mut params, &grads).unwrap();
        let updated = params.get(&w_id).unwrap().to_vec::<f32>();
        // With a zero gradient only the decay term moves the weights:
        // 5.0 - 0.1 * 0.1 * 5.0 = 4.95, for every element.
        assert_eq!(updated.len(), 2);
        for &value in &updated {
            assert!((value - 4.95).abs() < 1e-5, "weight decay: got {value}");
        }
    }

    #[test]
    fn test_sgd_skips_missing_grads() {
        let (client, device) = cpu_setup();
        let w_tensor = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
        let w_id = w_tensor.id();
        let mut params = HashMap::new();
        params.insert(w_id, w_tensor);
        let grads = GradStore::new();
        let mut opt = Sgd::<CpuRuntime>::new(SgdConfig::default());
        opt.step(&client, &mut params, &grads).unwrap();
        let updated = params.get(&w_id).unwrap().to_vec::<f32>();
        assert_eq!(updated, vec![1.0, 2.0]);
    }

    #[test]
    fn test_sgd_reset() {
        let (client, device) = cpu_setup();
        let w_tensor = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
        let w_id = w_tensor.id();
        let grad = Tensor::<CpuRuntime>::from_slice(&[0.1f32, 0.2], &[2], &device);
        let mut grads = GradStore::new();
        grads.insert(w_id, grad);
        let mut params = HashMap::new();
        params.insert(w_id, w_tensor);
        let mut opt = Sgd::<CpuRuntime>::new(SgdConfig {
            momentum: 0.9,
            ..Default::default()
        });
        // Take one step with momentum first, so there is state for `reset`
        // to clear; resetting a fresh optimizer would pass trivially.
        opt.step(&client, &mut params, &grads).unwrap();
        assert!(
            opt.velocity.contains_key(&w_id),
            "a momentum step should store a velocity buffer"
        );
        opt.reset();
        assert!(opt.velocity.is_empty());
    }

    #[test]
    fn test_sgd_set_lr() {
        let mut opt = Sgd::<CpuRuntime>::new(SgdConfig::default());
        opt.set_lr(0.05);
        assert_eq!(opt.lr(), 0.05);
    }
}