use std::collections::HashMap;
use crate::Optimizer;
use crate::common::zeros_entry;
/// Stochastic gradient descent with optional momentum, Nesterov look-ahead
/// and weight decay.
///
/// Hyper-parameters are public and may be adjusted between steps; the
/// per-parameter velocity state is kept private.
#[derive(Debug, Clone)]
pub struct Sgd {
    /// Learning rate: the factor applied to the update before it is
    /// subtracted from the parameter.
    pub lr: f32,
    /// Momentum coefficient. When it is exactly `0.0` the update uses the
    /// gradient directly and the velocity buffer is left unmodified.
    pub momentum: f32,
    /// Selects the Nesterov form of the momentum update
    /// (`g + momentum * v` instead of `v`).
    pub nesterov: bool,
    /// Weight-decay coefficient; `weight_decay * param` is added to the
    /// gradient before the update is computed.
    pub weight_decay: f32,
    /// Velocity buffers, keyed by parameter name.
    v: HashMap<String, Vec<f32>>,
}

impl Sgd {
    /// Creates a plain SGD optimizer with learning rate `lr`, no momentum,
    /// no Nesterov look-ahead, no weight decay and no stored state.
    pub fn new(lr: f32) -> Self {
        let v = HashMap::new();
        Self {
            lr,
            momentum: 0.0,
            nesterov: false,
            weight_decay: 0.0,
            v,
        }
    }

    /// Builder-style setter: returns the optimizer with `momentum` and
    /// `nesterov` replaced and every other field carried over unchanged.
    pub fn with_momentum(self, momentum: f32, nesterov: bool) -> Self {
        Self {
            momentum,
            nesterov,
            ..self
        }
    }

    /// Builder-style setter: returns the optimizer with the weight-decay
    /// coefficient set to `wd` and every other field carried over unchanged.
    pub fn with_weight_decay(self, wd: f32) -> Self {
        Self {
            weight_decay: wd,
            ..self
        }
    }
}
impl Optimizer for Sgd {
    /// Applies one SGD update to `param` in place using `grad`.
    ///
    /// For every element the effective gradient is
    /// `g = grad + weight_decay * param`. Without momentum
    /// (`momentum == 0.0`) the parameter moves by `-lr * g`. Otherwise the
    /// velocity stored under `name` is updated as `v = momentum * v + g`, and
    /// the parameter moves by `-lr * v`, or by `-lr * (g + momentum * v)`
    /// when `nesterov` is set.
    ///
    /// `_shape` is not used. `param` and `grad` are expected to have equal
    /// length; this is only checked in debug builds.
    fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
        debug_assert_eq!(param.len(), grad.len());

        // NOTE(review): `zeros_entry` presumably returns the buffer for
        // `name`, creating it zero-filled with `param.len()` elements on
        // first use; confirm against the helper's definition.
        let velocity = zeros_entry(&mut self.v, name, param.len());

        // Read the hyper-parameters once; they do not change during a step.
        let (lr, mu, wd) = (self.lr, self.momentum, self.weight_decay);
        let nesterov = self.nesterov;
        let use_momentum = mu != 0.0;

        for (idx, p) in param.iter_mut().enumerate() {
            // Weight decay is folded into the gradient.
            let g = grad[idx] + wd * *p;

            let delta = if use_momentum {
                velocity[idx] = mu * velocity[idx] + g;
                if nesterov {
                    // Look-ahead: gradient plus the freshly updated velocity.
                    g + mu * velocity[idx]
                } else {
                    velocity[idx]
                }
            } else {
                g
            };

            *p -= lr * delta;
        }
    }
}