use scivex_core::{Float, Tensor};
use crate::variable::Variable;
use super::Optimizer;
/// Stochastic gradient descent optimizer with optional momentum and weight decay.
///
/// Construct with [`SGD::new`] and configure via the `with_*` builder methods.
pub struct SGD<T: Float> {
    /// Trainable parameters whose data is replaced on every `step`.
    params: Vec<Variable<T>>,
    /// Learning rate (step size) multiplying the update direction.
    lr: T,
    /// Momentum coefficient; `step` applies momentum only when this is > 0.
    momentum: T,
    /// Weight-decay coefficient; `step` applies it only when this is > 0.
    weight_decay: T,
    /// Per-parameter momentum buffers, index-aligned with `params`.
    /// A slot stays `None` until that parameter's first momentum step.
    velocities: Vec<Option<Tensor<T>>>,
}
impl<T: Float> SGD<T> {
    /// Creates a plain SGD optimizer over `params` with learning rate `lr`.
    ///
    /// Momentum and weight decay start at zero, i.e. disabled. Enable them with
    /// [`SGD::with_momentum`] and [`SGD::with_weight_decay`]. One empty velocity
    /// slot is allocated per parameter.
    #[must_use]
    pub fn new(params: Vec<Variable<T>>, lr: T) -> Self {
        let n = params.len();
        Self {
            params,
            lr,
            momentum: T::zero(),
            weight_decay: T::zero(),
            velocities: vec![None; n],
        }
    }

    /// Returns the optimizer with its momentum coefficient set to `momentum`.
    ///
    /// Momentum is only applied by `step` when `momentum` is strictly greater
    /// than zero. Zero or negative values leave momentum disabled.
    #[must_use]
    pub fn with_momentum(mut self, momentum: T) -> Self {
        self.momentum = momentum;
        self
    }

    /// Returns the optimizer with its weight-decay coefficient set to `wd`.
    ///
    /// When `wd` is strictly greater than zero, `step` adds `wd * param` to each
    /// gradient before the update (L2 regularisation). Zero or negative values
    /// leave weight decay disabled.
    #[must_use]
    pub fn with_weight_decay(mut self, wd: T) -> Self {
        self.weight_decay = wd;
        self
    }
}
impl<T: Float> Optimizer<T> for SGD<T> {
fn step(&mut self) {
let lr = self.lr;
let momentum = self.momentum;
let wd = self.weight_decay;
for (i, param) in self.params.iter().enumerate() {
let Some(grad) = param.grad() else {
continue;
};
let data = param.data();
let mut grad_with_wd = if wd > T::zero() {
grad.zip_map(&data, |g, p| g + wd * p)
.expect("grad and param shapes match")
} else {
grad
};
if momentum > T::zero() {
let v = match self.velocities[i].take() {
Some(prev_v) => {
prev_v
.zip_map(&grad_with_wd, |vi, gi| momentum * vi + gi)
.expect("velocity and grad shapes match")
}
None => grad_with_wd.clone(),
};
grad_with_wd = v.clone();
self.velocities[i] = Some(v);
}
let new_data = data
.zip_map(&grad_with_wd, |p, g| p - lr * g)
.expect("param and grad shapes match");
update_param_data(param, new_data);
}
}
fn zero_grad(&mut self) {
for param in &self.params {
param.zero_grad();
}
}
}
pub(crate) fn update_param_data<T: Float>(param: &Variable<T>, new_data: Tensor<T>) {
param.set_data(new_data);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::loss;
    use scivex_core::Tensor;

    /// Builds the 1-D problem: trainable `x = 0` and a fixed target of `3`.
    fn scalar_problem() -> (Variable<f64>, Variable<f64>) {
        let x = Variable::new(Tensor::from_vec(vec![0.0_f64], vec![1]).unwrap(), true);
        let target = Variable::new(Tensor::from_vec(vec![3.0], vec![1]).unwrap(), false);
        (x, target)
    }

    /// Current MSE loss between `x` and `target`.
    fn mse(x: &Variable<f64>, target: &Variable<f64>) -> f64 {
        loss::mse_loss(x, target).unwrap().data().as_slice()[0]
    }

    /// Runs `steps` rounds of zero_grad / backward / step.
    fn train(optim: &mut SGD<f64>, x: &Variable<f64>, target: &Variable<f64>, steps: usize) {
        for _ in 0..steps {
            optim.zero_grad();
            let l = loss::mse_loss(x, target).unwrap();
            l.backward();
            optim.step();
        }
    }

    #[test]
    fn test_sgd_reduces_loss() {
        let (x, target) = scalar_problem();
        let mut optim = SGD::new(vec![x.clone()], 0.1);
        let initial_loss = mse(&x, &target);
        train(&mut optim, &x, &target, 10);
        assert!(mse(&x, &target) < initial_loss, "loss did not decrease");
    }

    #[test]
    fn test_sgd_with_momentum_reduces_loss() {
        let (x, target) = scalar_problem();
        let mut optim = SGD::new(vec![x.clone()], 0.1).with_momentum(0.5);
        let initial_loss = mse(&x, &target);
        train(&mut optim, &x, &target, 10);
        assert!(mse(&x, &target) < initial_loss, "loss did not decrease with momentum");
    }

    #[test]
    fn test_sgd_with_weight_decay_reduces_loss() {
        let (x, target) = scalar_problem();
        let mut optim = SGD::new(vec![x.clone()], 0.1).with_weight_decay(0.01);
        let initial_loss = mse(&x, &target);
        train(&mut optim, &x, &target, 10);
        assert!(mse(&x, &target) < initial_loss, "loss did not decrease with weight decay");
    }

    #[test]
    fn test_step_without_backward_leaves_param_unchanged() {
        // No backward pass: the gradient is absent or zero, so a plain SGD
        // step must not move the parameter.
        let (x, _target) = scalar_problem();
        let mut optim = SGD::new(vec![x.clone()], 0.1);
        optim.zero_grad();
        optim.step();
        assert_eq!(x.data().as_slice()[0], 0.0);
    }
}