use super::{Optimizer, Param, Penalty};
use ndarray::{ArrayD, ArrayViewMutD, Zip};
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
use std::cell::{Cell, RefCell};
/// Stochastic gradient descent optimizer without momentum.
///
/// All state sits behind `Cell`/`RefCell`, which is what lets `step`, `zero_grad` and
/// `set_lr` operate through a shared `&self` reference.
// The type keeps its all-caps acronym name, so the clippy naming lint is silenced.
#[allow(clippy::upper_case_acronyms)]
pub struct SGD<'a, T> {
/// Parameters being optimized; mutably borrowed by `step` and `zero_grad`.
params: RefCell<Vec<SGDParam<'a>>>,
/// Learning rate that scales every update.
lr: Cell<f32>,
/// Penalty evaluated on each data element and added to its gradient during `step`.
penalty: T,
}
/// Parameter representation used by [`SGD`]: a pair of mutable views, one over the
/// parameter's values and one over its gradient.
#[allow(clippy::upper_case_acronyms)]
pub struct SGDParam<'a> {
/// Mutable view of the parameter's values; updated in place by `step`.
data: ArrayViewMutD<'a, f32>,
/// Mutable view of the parameter's gradient; read by `step`, zeroed by `zero_grad`.
grad: ArrayViewMutD<'a, f32>,
}
impl<'a> From<Param<'a>> for SGDParam<'a> {
fn from(param: Param<'a>) -> Self {
let Param { data, grad } = param;
Self { data, grad }
}
}
/// [`Optimizer`] implementation for plain stochastic gradient descent.
impl<'a, T: Penalty> Optimizer<'a> for SGD<'a, T> {
    type ParamRepr = SGDParam<'a>;

    /// Performs one optimization step.
    ///
    /// Every element of every parameter is updated in place as
    /// `data += -(grad + penalty(data)) * lr`. Parameters are processed in parallel
    /// through rayon's `par_iter_mut`.
    ///
    /// # Panics
    ///
    /// Panics if the parameter list is already borrowed, since it is accessed through
    /// `RefCell::borrow_mut`.
    fn step(&self) {
        let lr = self.lr.get();
        let penalty = &self.penalty;
        self.params.borrow_mut().par_iter_mut().for_each(|param| {
            // `data` and `grad` are distinct fields, so they can be borrowed mutably and
            // immutably at the same time.
            Zip::from(&mut param.data)
                .and(&param.grad)
                .for_each(|data_el, grad_el| {
                    *data_el += -(grad_el + penalty.penalize(data_el)) * lr
                });
        });
    }

    /// Sets every gradient element of every parameter to zero.
    fn zero_grad(&self) {
        self.params
            .borrow_mut()
            .par_iter_mut()
            .for_each(|param| param.grad.fill(0.));
    }

    /// Returns the current learning rate.
    fn get_lr(&self) -> f32 {
        self.lr.get()
    }

    /// Replaces the learning rate with `lr`.
    fn set_lr(&self, lr: f32) {
        self.lr.set(lr)
    }
}
impl<'a, T: Penalty> SGD<'a, T> {
pub fn new(parameters: Vec<Param<'a>>, lr: f32, penalty: T) -> Self {
let params = RefCell::new(Self::build_params(parameters));
let lr = Cell::new(lr);
Self {
params,
lr,
penalty,
}
}
pub fn get_lr(&self) -> f32 {
Optimizer::get_lr(self)
}
pub fn set_lr(&self, lr: f32) {
Optimizer::set_lr(self, lr);
}
pub fn step(&self) {
Optimizer::step(self);
}
pub fn zero_grad(&self) {
Optimizer::zero_grad(self);
}
pub fn with_momentum(
self,
momentum: f32,
dampening: f32,
nesterov: bool,
) -> SGDWithMomentum<'a, T> {
let params: RefCell<Vec<SGDWithMomentumParam>> =
RefCell::new(Self::build_params(self.params.into_inner()));
SGDWithMomentum {
params,
lr: self.lr,
penalty: self.penalty,
momentum: Cell::new(momentum),
dampening: Cell::new(dampening),
nesterov: Cell::new(nesterov),
}
}
}
/// Stochastic gradient descent optimizer with momentum.
///
/// Obtained from [`SGD::with_momentum`]. Every setting is stored in a `Cell` so it can be
/// changed through a shared `&self` reference.
// The type keeps the all-caps acronym in its name, so the clippy naming lint is silenced.
#[allow(clippy::upper_case_acronyms)]
pub struct SGDWithMomentum<'a, T> {
/// Parameters being optimized; mutably borrowed by `step` and `zero_grad`.
params: RefCell<Vec<SGDWithMomentumParam<'a>>>,
/// Learning rate that scales every update.
lr: Cell<f32>,
/// Penalty evaluated on each data element and added to its gradient during `step`.
penalty: T,
/// Factor applied to the previous buffer value when the buffer is refreshed.
momentum: Cell<f32>,
/// The penalized gradient enters the buffer scaled by `1 - dampening`.
dampening: Cell<f32>,
/// Selects the Nesterov form of the update in `step`.
nesterov: Cell<bool>,
}
/// Parameter representation used by [`SGDWithMomentum`]: views over the parameter's values
/// and gradient, plus an owned momentum buffer.
#[allow(clippy::upper_case_acronyms)]
pub struct SGDWithMomentumParam<'a> {
/// Mutable view of the parameter's values; updated in place by `step`.
data: ArrayViewMutD<'a, f32>,
/// Mutable view of the parameter's gradient; read by `step`, zeroed by `zero_grad`.
grad: ArrayViewMutD<'a, f32>,
/// Momentum buffer; created zero-filled with the gradient's dimensions.
buffer: ArrayD<f32>,
}
impl<'a> From<Param<'a>> for SGDWithMomentumParam<'a> {
fn from(param: Param<'a>) -> Self {
let Param { data, grad } = param;
let buffer = ArrayD::zeros(grad.raw_dim());
Self { data, grad, buffer }
}
}
impl<'a> From<SGDParam<'a>> for SGDWithMomentumParam<'a> {
fn from(param: SGDParam<'a>) -> Self {
let (data, grad) = (param.data, param.grad);
let buffer = ArrayD::zeros(grad.raw_dim());
Self { data, grad, buffer }
}
}
/// [`Optimizer`] implementation for stochastic gradient descent with momentum.
impl<'a, T: Penalty> Optimizer<'a> for SGDWithMomentum<'a, T> {
    type ParamRepr = SGDWithMomentumParam<'a>;

    /// Performs one optimization step.
    ///
    /// For every element of every parameter:
    ///
    /// 1. `g = grad + penalty(data)`
    /// 2. `buffer = buffer * momentum + g * (1 - dampening)`
    /// 3. `data += -(g + buffer * momentum) * lr` when Nesterov is enabled, otherwise
    ///    `data += -buffer * lr`
    ///
    /// The three stages are element-wise, so they run in a single pass with no temporary
    /// copy of the gradient. Parameters are processed in parallel through rayon's
    /// `par_iter_mut`.
    ///
    /// # Panics
    ///
    /// Panics if the parameter list is already borrowed, since it is accessed through
    /// `RefCell::borrow_mut`.
    fn step(&self) {
        let lr = self.lr.get();
        let momentum = self.momentum.get();
        let dampening = self.dampening.get();
        let nesterov = self.nesterov.get();
        let penalty = &self.penalty;

        self.params.borrow_mut().par_iter_mut().for_each(|param| {
            Zip::from(&mut param.data)
                .and(&param.grad)
                .and(&mut param.buffer)
                .for_each(|data_el, grad_el, buffer_el| {
                    // The penalty is evaluated on the value from before this step's update.
                    let p_grad = *grad_el + penalty.penalize(&*data_el);
                    *buffer_el = *buffer_el * momentum + p_grad * (1. - dampening);
                    let update = if nesterov {
                        p_grad + *buffer_el * momentum
                    } else {
                        *buffer_el
                    };
                    *data_el += -update * lr;
                });
        });
    }

    /// Sets every gradient element of every parameter to zero.
    fn zero_grad(&self) {
        self.params
            .borrow_mut()
            .par_iter_mut()
            .for_each(|param| param.grad.fill(0.));
    }

    /// Returns the current learning rate.
    fn get_lr(&self) -> f32 {
        self.lr.get()
    }

    /// Replaces the learning rate with `lr`.
    fn set_lr(&self, lr: f32) {
        self.lr.set(lr)
    }
}
impl<'a, T: Penalty> SGDWithMomentum<'a, T> {
pub fn get_lr(&self) -> f32 {
Optimizer::get_lr(self)
}
pub fn set_lr(&self, lr: f32) {
Optimizer::set_lr(self, lr);
}
pub fn get_momentum(&self) -> f32 {
self.momentum.get()
}
pub fn set_momentum(&self, momentum: f32) {
self.momentum.set(momentum);
}
pub fn get_dampening(&self) -> f32 {
self.dampening.get()
}
pub fn set_dampening(&self, dampening: f32) {
self.dampening.set(dampening);
}
pub fn get_nesterov(&self) -> bool {
self.nesterov.get()
}
pub fn set_nesterov(&self, nesterov: bool) {
self.nesterov.set(nesterov);
}
pub fn step(&self) {
Optimizer::step(self);
}
pub fn zero_grad(&self) {
Optimizer::zero_grad(self);
}
}
// The tests for this module live in a separate `test` module file and are compiled only
// for test builds.
#[cfg(test)]
mod test;