use super::{Optimizer, Param, Penalty};
use ndarray::{ArrayD, ArrayViewMutD, Zip};
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
use std::cell::{Cell, RefCell};
/// The AMSGrad optimizer: an Adam-style method that tracks an element-wise
/// running maximum of the second-moment estimate and uses it in the update
/// denominator.
///
/// Hyperparameters live in [`Cell`]s and the parameter state in a [`RefCell`],
/// so they can be read, tuned and stepped through a shared `&self`.
// The all-caps acronym mirrors the algorithm's published name, hence the allow.
#[allow(clippy::upper_case_acronyms)]
pub struct AMSGrad<'a, T: Penalty> {
    /// Per-parameter optimizer state (views onto data/grad plus moment buffers).
    params: RefCell<Vec<AMSGradParam<'a>>>,
    /// Penalty whose `penalize` output is added to each gradient element in `step`.
    penalty: T,
    /// Learning rate.
    lr: Cell<f32>,
    /// Decay rates `(beta1, beta2)` for the first and second moment estimates.
    betas: Cell<(f32, f32)>,
    /// Value added to the update denominator.
    eps: Cell<f32>,
}
impl<'a, T: Penalty> AMSGrad<'a, T> {
pub fn new(params: Vec<Param<'a>>, lr: f32, betas: (f32, f32), penalty: T, eps: f32) -> Self {
let params = RefCell::new(Self::build_params(params));
let lr = Cell::new(lr);
Self {
params,
lr,
penalty,
betas: Cell::new(betas),
eps: Cell::new(eps),
}
}
pub fn get_lr(&self) -> f32 {
Optimizer::get_lr(self)
}
pub fn set_lr(&self, lr: f32) {
Optimizer::set_lr(self, lr);
}
pub fn get_betas(&self) -> (f32, f32) {
self.betas.get()
}
pub fn set_betas(&self, betas: (f32, f32)) {
self.betas.set(betas)
}
pub fn get_eps(&self) -> f32 {
self.eps.get()
}
pub fn set_eps(&self, eps: f32) {
self.eps.set(eps)
}
pub fn step(&self) {
Optimizer::step(self);
}
pub fn zero_grad(&self) {
Optimizer::zero_grad(self);
}
}
/// AMSGrad state attached to a single parameter tensor.
// The all-caps acronym mirrors the algorithm's published name, hence the allow.
#[allow(clippy::upper_case_acronyms)]
pub struct AMSGradParam<'a> {
    /// Mutable view onto the parameter values; updated in place by `step`.
    data: ArrayViewMutD<'a, f32>,
    /// Mutable view onto the gradient; read by `step`, cleared by `zero_grad`.
    grad: ArrayViewMutD<'a, f32>,
    /// Exponential moving average of the (penalized) gradient.
    exp_avg: ArrayD<f32>,
    /// Exponential moving average of the squared (penalized) gradient.
    exp_avg_sq: ArrayD<f32>,
    /// Element-wise running maximum of `exp_avg_sq`.
    max_exp_avg_sq: ArrayD<f32>,
    /// Number of steps applied so far; used for bias correction.
    step: usize,
}
impl<'a> From<Param<'a>> for AMSGradParam<'a> {
    /// Wraps a parameter/gradient pair, starting at step 0 with every moment
    /// buffer zero-filled to the gradient's shape.
    fn from(param: Param<'a>) -> Self {
        let shape = param.grad.raw_dim();
        Self {
            exp_avg: ArrayD::zeros(shape.clone()),
            exp_avg_sq: ArrayD::zeros(shape.clone()),
            max_exp_avg_sq: ArrayD::zeros(shape),
            step: 0,
            data: param.data,
            grad: param.grad,
        }
    }
}
impl<'a, T: Penalty> Optimizer<'a> for AMSGrad<'a, T> {
type ParamRepr = AMSGradParam<'a>;
fn step(&self) {
let (lr, penalty, mut params, (beta1, beta2), eps) = (
self.lr.get(),
&self.penalty,
self.params.borrow_mut(),
&self.betas.get(),
&self.eps.get(),
);
params.par_iter_mut().for_each(|param| {
let (step, exp_avg, exp_avg_sq, max_exp_avg_sq) = (
&mut param.step,
&mut param.exp_avg,
&mut param.exp_avg_sq,
&mut param.max_exp_avg_sq,
);
*step += 1;
let bias_correction1 = 1. - beta1.powi(*step as i32);
let bias_correction2 = 1. - beta2.powi(*step as i32);
let mut p_grad = param.grad.to_owned();
Zip::from(&mut p_grad)
.and(¶m.data)
.for_each(|p_grad_el, data_el| *p_grad_el += penalty.penalize(data_el));
Zip::from(exp_avg)
.and(&p_grad)
.for_each(|exp_avg_el, p_grad_el| {
*exp_avg_el = *exp_avg_el * beta1 + p_grad_el * (1. - beta1)
});
Zip::from(exp_avg_sq)
.and(&p_grad)
.for_each(|exp_avg_sq_el, p_grad_el| {
*exp_avg_sq_el = *exp_avg_sq_el * beta2 + p_grad_el * p_grad_el * (1. - beta2)
});
Zip::from(max_exp_avg_sq).and(¶m.exp_avg_sq).for_each(
|max_exp_avg_sq_el, exp_avg_sq_el| {
*max_exp_avg_sq_el = max_exp_avg_sq_el.max(*exp_avg_sq_el)
},
);
Zip::from(&mut param.data)
.and(¶m.exp_avg)
.and(¶m.max_exp_avg_sq)
.for_each(|data_el, exp_avg_el, max_exp_avg_sq_el| {
*data_el += exp_avg_el
/ ((max_exp_avg_sq_el.sqrt() / bias_correction2.sqrt()) + *eps)
* (-lr / bias_correction1)
})
});
}
fn zero_grad(&self) {
self.params.borrow_mut().par_iter_mut().for_each(|param| {
let grad = &mut param.grad;
Zip::from(grad).for_each(|grad_el| *grad_el = 0.);
});
}
fn get_lr(&self) -> f32 {
self.lr.get()
}
fn set_lr(&self, lr: f32) {
self.lr.set(lr)
}
}
// Unit tests for this module, loaded from a separate `test` source file and
// compiled only for test builds.
#[cfg(test)]
mod test;