use candle_core::{Result, Var};
use candle_nn::optim::Optimizer;
use crate::{Decay, Momentum, OptimParams};
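/// Stochastic gradient descent (SGD) optimiser, with optional classical or
/// Nesterov momentum, dampening, and coupled or decoupled weight decay.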
#[derive(Debug)]
pub struct SGD {
vars: Vec<VarSGD>,
params: ParamsSGD,
}
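/// Per-parameter state: the variable being optimised (`theta`) and, when
/// momentum is enabled, its momentum buffer (`b`), created lazily on the first step.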
#[derive(Debug)]
struct VarSGD {
theta: Var,
b: Option<Var>,
}
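/// Hyper-parameters for the [`SGD`] optimiser.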
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct ParamsSGD {
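    /// Learning rate (step size).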
pub lr: f64,
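    /// Optional weight decay, either coupled (added to the gradient as an L2
    /// penalty) or decoupled (applied directly to the parameters, as in SGDW).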
pub weight_decay: Option<Decay>,
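    /// Optional momentum, either classical (heavy-ball) or Nesterov.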
pub momentum: Option<Momentum>,
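    /// Dampening factor applied to the gradient term of the momentum update.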
pub dampening: f64,
}
impl Default for ParamsSGD {
fn default() -> Self {
Self {
lr: 0.1,
weight_decay: None,
            momentum: None,
            dampening: 0.0,
}
}
}
impl Optimizer for SGD {
type Config = ParamsSGD;
fn new(vars: Vec<Var>, params: ParamsSGD) -> Result<Self> {
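        // Keep only floating-point variables; only these can carry gradients.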
let vars = vars
.into_iter()
.filter(|var| var.dtype().is_float())
.map(|var| VarSGD {
theta: var,
b: None,
})
.collect::<Vec<VarSGD>>();
Ok(Self { vars, params })
}
fn learning_rate(&self) -> f64 {
self.params.lr
}
#[allow(clippy::too_many_lines)]
fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
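        // The branches below all apply the same update, specialised per configuration:
        // with momentum the buffer is b_t = mu * b_{t-1} + (1 - dampening) * g_t
        // (b_1 = g_1), used either directly (classical) or as g_t + mu * b_t (Nesterov);
        // weight decay is folded into the gradient (coupled) or applied
        // multiplicatively to the parameters (decoupled).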
if let Some(momentum) = self.params.momentum {
match momentum {
Momentum::Classical(momentum) => {
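                    // Classical (heavy-ball) momentum.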
if let Some(decay) = self.params.weight_decay {
match decay {
Decay::WeightDecay(decay) => {
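                                // Coupled weight decay: add decay * theta to the
                                // gradient before the momentum update.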
for var in &mut self.vars {
let theta = &var.theta;
if let Some(grad) = grads.get(theta) {
let grad = &(grad + (decay * theta.as_tensor())?)?;
if let Some(prev_step) = &(var.b) {
let bt = ((prev_step.as_tensor() * momentum)?
+ (1. - self.params.dampening) * (grad))?;
theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?;
prev_step.set(&bt)?;
} else {
let bt = grad.clone();
theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?;
var.b = Some(Var::from_tensor(&bt)?);
}
}
}
}
Decay::DecoupledWeightDecay(decay) => {
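                                // Decoupled weight decay: scale theta by
                                // (1 - lr * decay) before the momentum step.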
for var in &mut self.vars {
let theta = &var.theta;
if let Some(grad) = grads.get(theta) {
theta.set(
&(theta.as_tensor()
* self.params.lr.mul_add(-decay, 1.))?,
)?;
if let Some(prev_step) = &(var.b) {
let bt = ((prev_step.as_tensor() * momentum)?
+ (1. - self.params.dampening) * (grad))?;
theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?;
prev_step.set(&bt)?;
} else {
let bt = grad.clone();
theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?;
var.b = Some(Var::from_tensor(&bt)?);
}
}
}
}
}
} else {
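                        // Classical momentum without weight decay.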
for var in &mut self.vars {
let theta = &var.theta;
if let Some(grad) = grads.get(theta) {
if let Some(prev_step) = &(var.b) {
let bt = ((prev_step.as_tensor() * momentum)?
+ (1. - self.params.dampening) * (grad))?;
theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?;
prev_step.set(&bt)?;
} else {
let bt = grad.clone();
theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?;
var.b = Some(Var::from_tensor(&bt)?);
}
}
}
}
}
Momentum::Nesterov(momentum) => {
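                    // Nesterov momentum: parameters move along g_t + mu * b_t.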
if let Some(decay) = self.params.weight_decay {
match decay {
Decay::WeightDecay(decay) => {
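                                // Coupled weight decay folded into the gradient.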
for var in &mut self.vars {
let theta = &var.theta;
if let Some(grad) = grads.get(theta) {
let grad = &(grad + (decay * theta.as_tensor())?)?;
if let Some(prev_step) = &(var.b) {
let bt = ((prev_step.as_tensor() * momentum)?
+ (1. - self.params.dampening) * (grad))?;
let gt = (grad + (momentum * &bt)?)?;
prev_step.set(&bt)?;
theta.set(&theta.sub(&(gt * self.params.lr)?)?)?;
} else {
let bt = grad.clone();
let gt = (grad + (momentum * &bt)?)?;
var.b = Some(Var::from_tensor(&bt)?);
theta.set(&theta.sub(&(gt * self.params.lr)?)?)?;
}
}
}
}
Decay::DecoupledWeightDecay(decay) => {
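                                // Decoupled weight decay applied to theta first.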
for var in &mut self.vars {
let theta = &var.theta;
if let Some(grad) = grads.get(theta) {
theta.set(
&(theta.as_tensor()
* self.params.lr.mul_add(-decay, 1.))?,
)?;
if let Some(prev_step) = &(var.b) {
let bt = ((prev_step.as_tensor() * momentum)?
+ (1. - self.params.dampening) * (grad))?;
let gt = (grad + (momentum * &bt)?)?;
prev_step.set(&bt)?;
theta.set(&theta.sub(&(gt * self.params.lr)?)?)?;
} else {
let bt = grad.clone();
let gt = (grad + (momentum * &bt)?)?;
var.b = Some(Var::from_tensor(&bt)?);
theta.set(&theta.sub(&(gt * self.params.lr)?)?)?;
}
}
}
}
}
} else {
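                        // Nesterov momentum without weight decay.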
for var in &mut self.vars {
let theta = &var.theta;
if let Some(grad) = grads.get(theta) {
if let Some(prev_step) = &(var.b) {
let bt = ((prev_step.as_tensor() * momentum)?
+ (1. - self.params.dampening) * (grad))?;
let gt = (grad + (momentum * &bt)?)?;
prev_step.set(&bt)?;
theta.set(&theta.sub(&(gt * self.params.lr)?)?)?;
} else {
let bt = grad.clone();
let gt = (grad + (momentum * &bt)?)?;
var.b = Some(Var::from_tensor(&bt)?);
theta.set(&theta.sub(&(gt * self.params.lr)?)?)?;
}
}
}
}
}
}
} else if let Some(decay) = self.params.weight_decay {
match decay {
Decay::WeightDecay(decay) => {
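                    // No momentum, coupled weight decay:
                    // theta <- theta - lr * (grad + decay * theta).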
for var in &mut self.vars {
let theta = &var.theta;
if let Some(grad) = grads.get(theta) {
                            let grad = &(grad + (decay * theta.as_tensor())?)?;
                            theta.set(&theta.sub(&(grad * self.params.lr)?)?)?;
                        }
}
}
Decay::DecoupledWeightDecay(decay) => {
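                    // No momentum, decoupled weight decay:
                    // theta <- theta * (1 - lr * decay) - lr * grad.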
for var in &mut self.vars {
let theta = &var.theta;
if let Some(grad) = grads.get(theta) {
theta
.set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?;
                            theta.set(&theta.sub(&(grad * self.params.lr)?)?)?;
                        }
}
}
}
} else {
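            // Plain SGD: theta <- theta - lr * grad.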
for var in &mut self.vars {
let theta = &var.theta;
if let Some(grad) = grads.get(theta) {
                    theta.set(&theta.sub(&(grad * self.params.lr)?)?)?;
                }
}
}
Ok(())
}
fn set_learning_rate(&mut self, lr: f64) {
self.params.lr = lr;
}
}
impl OptimParams for SGD {
fn params(&self) -> &Self::Config {
&self.params
}
fn set_params(&mut self, config: Self::Config) {
self.params = config;
}
}
impl SGD {
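    /// Consume the optimiser and return the variables it was optimising.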
#[must_use]
pub fn into_inner(self) -> Vec<Var> {
self.vars.into_iter().map(|v| v.theta).collect()
}
}
#[cfg(test)]
mod tests {
use anyhow::Result;
use assert_approx_eq::assert_approx_eq;
use candle_core::{Device, Var};
use candle_nn::Optimizer;
use super::*;
#[test]
fn lr_test() -> Result<()> {
let params = ParamsSGD {
lr: 0.004,
..Default::default()
};
let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
let b = Var::new(0f32, &Device::Cpu)?;
let mut optim = SGD::new(vec![w.clone(), b.clone()], params)?;
assert_approx_eq!(0.004, optim.learning_rate());
optim.set_learning_rate(0.002);
assert_approx_eq!(0.002, optim.learning_rate());
Ok(())
}
#[test]
fn into_inner_test() -> Result<()> {
let params = ParamsSGD::default();
let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
let b = Var::new(-2f32, &Device::Cpu)?;
let optim = SGD::new(vec![w.clone(), b.clone()], params)?;
let inner = optim.into_inner();
assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
Ok(())
}
#[test]
fn params_test() -> Result<()> {
let params = ParamsSGD {
lr: 0.004,
..Default::default()
};
let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
let b = Var::new(0f32, &Device::Cpu)?;
let mut optim = SGD::new(vec![w.clone(), b.clone()], params.clone())?;
assert_eq!(params, optim.params().clone());
let new_params = ParamsSGD {
lr: 0.002,
..Default::default()
};
optim.set_params(new_params.clone());
assert_eq!(new_params, optim.params().clone());
Ok(())
}
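    // A minimal single-step sanity check (a sketch): with no momentum and no weight
    // decay the update reduces to theta <- theta - lr * grad. For loss = sum(w * x)
    // the gradient with respect to w is x.
    #[test]
    fn step_test() -> Result<()> {
        let params = ParamsSGD {
            lr: 0.1,
            ..Default::default()
        };
        let w = Var::new(&[[1f32, 2.]], &Device::Cpu)?;
        let x = candle_core::Tensor::new(&[[3f32, 4.]], &Device::Cpu)?;
        let mut optim = SGD::new(vec![w.clone()], params)?;
        let loss = (w.as_tensor() * &x)?.sum_all()?;
        let grads = loss.backward()?;
        optim.step(&grads)?;
        // Expected: [1 - 0.1 * 3, 2 - 0.1 * 4] = [0.7, 1.6]
        let w_after = w.as_tensor().to_vec2::<f32>()?;
        assert_approx_eq!(w_after[0][0], 0.7_f32);
        assert_approx_eq!(w_after[0][1], 1.6_f32);
        Ok(())
    }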
}