use std::collections::HashMap;
use smol_str::format_smolstr;
use crate::{
Array, Result,
error::{Error, InvariantViolationPayload, NonFiniteScalarPayload, OutOfRangePayload},
lm::{
load::Weights,
tuner::optimizers::base::{LearningRate, Optimizer, zeros_like, zeros_like_map},
},
ops::arithmetic,
};
/// Stochastic gradient descent with optional momentum, dampening, Nesterov
/// momentum and L2 weight decay (the decay term is added to the gradient).
pub struct SGD {
    // ----- hyperparameters -----
    /// Learning-rate source; resolved for a given step via `LearningRate::try_current`.
    learning_rate: LearningRate,
    /// Momentum coefficient; a value `<= 0` makes the update skip the velocity buffer.
    momentum: f32,
    /// The gradient term of the velocity update is scaled by `1 - dampening`.
    dampening: f32,
    /// When set, the step direction is `g + momentum * v` instead of `v`.
    nesterov: bool,
    /// L2 coefficient; `weight_decay * param` is added to each gradient.
    weight_decay: f32,

    // ----- learning-rate bookkeeping -----
    /// Number of `apply_gradients` calls that have started so far.
    step_count: usize,
    /// Learning rate resolved for the step recorded in `lr_resolved_for_step`.
    current_lr: f32,
    /// Step index `current_lr` was resolved for; used to skip re-resolution.
    lr_resolved_for_step: Option<usize>,

    // ----- per-parameter state -----
    /// Velocity (momentum) buffers keyed by parameter name.
    state: HashMap<String, Array>,
}
impl SGD {
    /// Builds an SGD optimizer after validating every hyperparameter.
    ///
    /// For a parameter `w` with gradient `g`, one step computes
    /// `g' = g + weight_decay * w` and `v = momentum * v + (1 - dampening) * g'`.
    /// It then applies `w -= lr * v`, or `w -= lr * (g' + momentum * v)` when
    /// `nesterov` is set. With `momentum == 0` the step is plain `w -= lr * g'`.
    ///
    /// # Errors
    ///
    /// - [`Error::NonFiniteScalar`] if `momentum`, `weight_decay` or `dampening`
    ///   is NaN or infinite.
    /// - [`Error::OutOfRange`] if `momentum`, `weight_decay` or `dampening` is negative.
    /// - [`Error::InvariantViolation`] if `nesterov` is requested without
    ///   `momentum > 0` and `dampening == 0`.
    /// - Any error returned by `LearningRate::try_current(0)`.
    pub fn new(
        learning_rate: impl Into<LearningRate>,
        momentum: f32,
        weight_decay: f32,
        dampening: f32,
        nesterov: bool,
    ) -> Result<Self> {
        Self::validate_momentum(momentum)?;
        Self::validate_weight_decay(weight_decay)?;
        Self::validate_dampening(dampening)?;
        Self::validate_nesterov(momentum, dampening, nesterov)?;
        let lr = learning_rate.into();
        // Resolve step 0 eagerly: an invalid schedule fails at construction, and
        // the cached value is valid before the first step is taken.
        let current_lr = lr.try_current(0)?;
        Ok(Self {
            learning_rate: lr,
            momentum,
            weight_decay,
            dampening,
            nesterov,
            step_count: 0,
            current_lr,
            lr_resolved_for_step: Some(0),
            state: HashMap::new(),
        })
    }

    /// Plain SGD: no momentum, no weight decay, no dampening, no Nesterov.
    ///
    /// # Errors
    ///
    /// Propagates any error from resolving the learning rate for step 0.
    pub fn vanilla(learning_rate: impl Into<LearningRate>) -> Result<Self> {
        Self::new(learning_rate, 0.0, 0.0, 0.0, false)
    }

    /// Nesterov momentum is only defined for a positive, finite momentum and
    /// zero dampening; any other combination is rejected when `nesterov` is set.
    fn validate_nesterov(momentum: f32, dampening: f32, nesterov: bool) -> Result<()> {
        if nesterov && (!momentum.is_finite() || momentum <= 0.0 || dampening != 0.0) {
            return Err(Error::InvariantViolation(InvariantViolationPayload::new(
                "SGD: Nesterov momentum",
                "requires momentum > 0 (finite) and dampening == 0",
            )));
        }
        Ok(())
    }

    /// Rejects a NaN or infinite momentum (the sign is not checked here).
    fn validate_momentum_finite(momentum: f32) -> Result<()> {
        if !momentum.is_finite() {
            return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
                "SGD: momentum",
                momentum as f64,
            )));
        }
        Ok(())
    }

    /// Rejects a non-finite or negative momentum.
    ///
    /// Without the sign check, a negative momentum would be accepted and then
    /// silently treated as "no momentum" by the update step.
    fn validate_momentum(momentum: f32) -> Result<()> {
        Self::validate_momentum_finite(momentum)?;
        if momentum < 0.0 {
            return Err(Error::OutOfRange(OutOfRangePayload::new(
                "SGD: momentum",
                "must be >= 0.0",
                format_smolstr!("{momentum}"),
            )));
        }
        Ok(())
    }

    /// Rejects a non-finite or negative weight-decay coefficient.
    fn validate_weight_decay(weight_decay: f32) -> Result<()> {
        if !weight_decay.is_finite() {
            return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
                "SGD: weight_decay",
                weight_decay as f64,
            )));
        }
        if weight_decay < 0.0 {
            return Err(Error::OutOfRange(OutOfRangePayload::new(
                "SGD: weight_decay",
                "must be >= 0.0",
                format_smolstr!("{weight_decay}"),
            )));
        }
        Ok(())
    }

    /// Rejects a non-finite or negative dampening factor.
    fn validate_dampening(dampening: f32) -> Result<()> {
        if !dampening.is_finite() {
            return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
                "SGD: dampening",
                dampening as f64,
            )));
        }
        if dampening < 0.0 {
            return Err(Error::OutOfRange(OutOfRangePayload::new(
                "SGD: dampening",
                "must be >= 0.0",
                format_smolstr!("{dampening}"),
            )));
        }
        Ok(())
    }

    /// The configured learning-rate source (not the resolved per-step value).
    #[inline]
    pub fn learning_rate_ref(&self) -> &LearningRate {
        &self.learning_rate
    }

    /// Momentum coefficient.
    #[inline]
    pub fn momentum(&self) -> f32 {
        self.momentum
    }

    /// L2 weight-decay coefficient.
    #[inline]
    pub fn weight_decay(&self) -> f32 {
        self.weight_decay
    }

    /// Dampening factor applied to the gradient term of the velocity update.
    #[inline]
    pub fn dampening(&self) -> f32 {
        self.dampening
    }

    /// Whether Nesterov momentum is enabled.
    #[inline]
    pub fn nesterov(&self) -> bool {
        self.nesterov
    }

    /// Replaces the learning-rate source and re-resolves it for the current step.
    ///
    /// # Errors
    ///
    /// Propagates any error from `LearningRate::try_current`. The optimizer is
    /// consumed either way.
    pub fn with_learning_rate(mut self, learning_rate: impl Into<LearningRate>) -> Result<Self> {
        let lr = learning_rate.into();
        let current_lr = lr.try_current(self.step_count)?;
        self.learning_rate = lr;
        self.current_lr = current_lr;
        self.lr_resolved_for_step = Some(self.step_count);
        Ok(self)
    }

    /// Sets the momentum coefficient.
    ///
    /// # Errors
    ///
    /// Same momentum checks as [`SGD::new`], including the Nesterov invariant
    /// against the current dampening and `nesterov` flag.
    pub fn with_momentum(mut self, momentum: f32) -> Result<Self> {
        Self::validate_momentum(momentum)?;
        Self::validate_nesterov(momentum, self.dampening, self.nesterov)?;
        self.momentum = momentum;
        Ok(self)
    }

    /// Sets the L2 weight-decay coefficient.
    ///
    /// # Errors
    ///
    /// Fails if `weight_decay` is non-finite or negative.
    pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self> {
        Self::validate_weight_decay(weight_decay)?;
        self.weight_decay = weight_decay;
        Ok(self)
    }

    /// Sets the dampening factor.
    ///
    /// # Errors
    ///
    /// Fails if `dampening` is non-finite or negative, or if Nesterov is enabled
    /// and `dampening != 0`.
    pub fn with_dampening(mut self, dampening: f32) -> Result<Self> {
        Self::validate_dampening(dampening)?;
        Self::validate_nesterov(self.momentum, dampening, self.nesterov)?;
        self.dampening = dampening;
        Ok(self)
    }

    /// Enables or disables Nesterov momentum.
    ///
    /// # Errors
    ///
    /// Fails when enabling Nesterov without `momentum > 0` and `dampening == 0`.
    pub fn with_nesterov(mut self, nesterov: bool) -> Result<Self> {
        Self::validate_nesterov(self.momentum, self.dampening, nesterov)?;
        self.nesterov = nesterov;
        Ok(self)
    }
}
impl Optimizer for SGD {
fn init(&mut self, params: &Weights) -> Result<()> {
self.state = zeros_like_map(params)?;
Ok(())
}
fn preflight(&mut self) -> Result<()> {
if self.lr_resolved_for_step == Some(self.step_count) {
return Ok(()); }
self.current_lr = self.learning_rate.try_current(self.step_count)?;
self.lr_resolved_for_step = Some(self.step_count);
Ok(())
}
fn apply_gradients(&mut self, gradients: &Weights, params: &mut Weights) -> Result<()> {
if self.state.is_empty() {
self.init(gradients)?;
}
self.preflight()?;
self.step_count += 1;
let lr = self.current_lr;
for (key, grad) in gradients {
let Some(param) = params.get(key) else {
continue;
};
let effective_grad = if self.weight_decay != 0.0 {
let wd = Array::full::<f32>(&[0i32; 0], self.weight_decay)?;
let decay_term = arithmetic::multiply(&wd, param)?;
arithmetic::add(grad, &decay_term)?
} else {
grad.try_clone()?
};
if self.momentum <= 0.0 {
let lr_scalar = Array::full::<f32>(&[0i32; 0], lr)?;
let step = arithmetic::multiply(&lr_scalar, &effective_grad)?;
let new_w = arithmetic::subtract(param, &step)?;
params.insert(key.clone(), new_w);
continue;
}
let prev_v = match self.state.get(key) {
Some(v) => v.try_clone()?,
None => zeros_like(param)?,
};
let mu_scalar = Array::full::<f32>(&[0i32; 0], self.momentum)?;
let v_scaled = arithmetic::multiply(&mu_scalar, &prev_v)?;
let v_new = if self.dampening > 0.0 {
let one_minus_damp = Array::full::<f32>(&[0i32; 0], 1.0 - self.dampening)?;
let g_damped = arithmetic::multiply(&one_minus_damp, &effective_grad)?;
arithmetic::add(&v_scaled, &g_damped)?
} else {
arithmetic::add(&v_scaled, &effective_grad)?
};
let update = if self.nesterov {
let mu_v = arithmetic::multiply(&mu_scalar, &v_new)?;
arithmetic::add(&effective_grad, &mu_v)?
} else {
v_new.try_clone()?
};
let lr_scalar = Array::full::<f32>(&[0i32; 0], lr)?;
let step = arithmetic::multiply(&lr_scalar, &update)?;
let new_w = arithmetic::subtract(param, &step)?;
params.insert(key.clone(), new_w);
self.state.insert(key.clone(), v_new);
}
Ok(())
}
fn step(&self) -> usize {
self.step_count
}
fn learning_rate(&self) -> f32 {
self.current_lr
}
}
#[cfg(test)]
mod tests;