use super::var_store::{VarStore, Variables};
use crate::wrappers::optimizer::COptimizer;
use crate::Tensor;
use failure::Fallible;
use std::sync::{Arc, Mutex};
/// An optimizer bound to the variables of a `VarStore`.
///
/// Values of this type are created by `OptimizerConfig::build`. The type
/// parameter `T` is the configuration type the optimizer was built from.
#[derive(Debug)]
pub struct Optimizer<T> {
/// Backend optimizer returned by `OptimizerConfig::build_copt`.
opt: COptimizer,
/// Shared handle to the variable store's variables (a clone of the store's `Arc`).
variables: Arc<Mutex<Variables>>,
/// Number of trainable variables that have already been passed to
/// `opt.add_parameters`; used to detect variables added to the store later.
variables_in_optimizer: usize,
/// The configuration value this optimizer was built from.
config: T,
}
/// A configuration from which an [`Optimizer`] can be built.
///
/// Implementors only provide [`OptimizerConfig::build_copt`]; the provided
/// [`OptimizerConfig::build`] method takes care of registering the trainable
/// variables of a `VarStore`.
pub trait OptimizerConfig: Sized {
    /// Creates the backend optimizer for this configuration using learning rate `lr`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the backend constructor reports.
    fn build_copt(&self, lr: f64) -> Fallible<COptimizer>;

    /// Builds an [`Optimizer`] over the trainable variables currently held by `vs`.
    ///
    /// The configuration is consumed and stored inside the returned optimizer.
    ///
    /// # Errors
    ///
    /// Fails if the backend optimizer cannot be created or if registering the
    /// trainable variables with it fails.
    ///
    /// # Panics
    ///
    /// Panics if the variable store's mutex is poisoned.
    fn build(self, vs: &VarStore, lr: f64) -> Fallible<Optimizer<Self>> {
        let mut copt = self.build_copt(lr)?;
        // Register the variables and take the count under a single lock so
        // that the count matches exactly what was handed to the backend.
        let registered = {
            let guard = vs.variables_.lock().unwrap();
            copt.add_parameters(&guard.trainable_variables)?;
            guard.trainable_variables.len()
        };
        Ok(Optimizer {
            opt: copt,
            variables: vs.variables_.clone(),
            variables_in_optimizer: registered,
            config: self,
        })
    }
}
/// Parameters for the SGD optimizer.
///
/// The fields are forwarded unchanged to `COptimizer::sgd` when the optimizer
/// is built; their exact numerical meaning is defined by that backend.
/// Implements `PartialEq` so that configurations can be compared.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Sgd {
    /// Momentum value; `0.` by default.
    pub momentum: f64,
    /// Dampening value; `0.` by default.
    pub dampening: f64,
    /// Presumably the weight-decay coefficient (TODO confirm against the
    /// backend); `0.` by default.
    pub wd: f64,
    /// Nesterov flag; `false` by default.
    pub nesterov: bool,
}

impl Default for Sgd {
    /// Returns a configuration with every numeric field set to zero and
    /// `nesterov` disabled.
    fn default() -> Self {
        Self {
            momentum: 0.,
            dampening: 0.,
            wd: 0.,
            nesterov: false,
        }
    }
}

/// Creates an [`Sgd`] configuration from explicit values for each field.
pub fn sgd(momentum: f64, dampening: f64, wd: f64, nesterov: bool) -> Sgd {
    Sgd {
        momentum,
        dampening,
        wd,
        nesterov,
    }
}
impl OptimizerConfig for Sgd {
fn build_copt(&self, lr: f64) -> Fallible<COptimizer> {
COptimizer::sgd(lr, self.momentum, self.dampening, self.wd, self.nesterov)
}
}
/// Parameters for the Adam optimizer.
///
/// The fields are forwarded unchanged to `COptimizer::adam` when the optimizer
/// is built; their exact numerical meaning is defined by that backend.
/// Implements `PartialEq` so that configurations can be compared.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Adam {
    /// First beta coefficient; `0.9` by default.
    pub beta1: f64,
    /// Second beta coefficient; `0.999` by default.
    pub beta2: f64,
    /// Presumably the weight-decay coefficient (TODO confirm against the
    /// backend); `0.` by default.
    pub wd: f64,
}

impl Default for Adam {
    /// Returns a configuration with `beta1 = 0.9`, `beta2 = 0.999` and `wd = 0`.
    fn default() -> Self {
        Self {
            beta1: 0.9,
            beta2: 0.999,
            wd: 0.,
        }
    }
}

/// Creates an [`Adam`] configuration from explicit values for each field.
pub fn adam(beta1: f64, beta2: f64, wd: f64) -> Adam {
    Adam { beta1, beta2, wd }
}
impl OptimizerConfig for Adam {
fn build_copt(&self, lr: f64) -> Fallible<COptimizer> {
COptimizer::adam(lr, self.beta1, self.beta2, self.wd)
}
}
/// Parameters for the RMSprop optimizer.
///
/// The fields are forwarded unchanged to `COptimizer::rms_prop` when the
/// optimizer is built; their exact numerical meaning is defined by that
/// backend. Implements `PartialEq` so that configurations can be compared.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct RmsProp {
    /// Alpha value; `0.99` by default.
    pub alpha: f64,
    /// Epsilon value; `1e-8` by default.
    pub eps: f64,
    /// Presumably the weight-decay coefficient (TODO confirm against the
    /// backend); `0.` by default.
    pub wd: f64,
    /// Momentum value; `0.` by default.
    pub momentum: f64,
    /// Centered flag; `false` by default.
    pub centered: bool,
}

impl Default for RmsProp {
    /// Returns a configuration with `alpha = 0.99`, `eps = 1e-8`, zero `wd`
    /// and `momentum`, and `centered` disabled.
    fn default() -> Self {
        Self {
            alpha: 0.99,
            eps: 1e-8,
            wd: 0.,
            momentum: 0.,
            centered: false,
        }
    }
}

/// Creates an [`RmsProp`] configuration from explicit values for each field.
pub fn rms_prop(alpha: f64, eps: f64, wd: f64, momentum: f64, centered: bool) -> RmsProp {
    RmsProp {
        alpha,
        eps,
        wd,
        momentum,
        centered,
    }
}
impl OptimizerConfig for RmsProp {
fn build_copt(&self, lr: f64) -> Fallible<COptimizer> {
COptimizer::rms_prop(
lr,
self.alpha,
self.eps,
self.wd,
self.momentum,
self.centered,
)
}
}
impl<T> Optimizer<T> {
    /// Registers with the backend optimizer any trainable variables that were
    /// added to the variable store since the last registration.
    ///
    /// Panics if the variables mutex is poisoned or if the backend rejects
    /// the new parameters.
    fn add_missing_variables(&mut self) {
        let guard = self.variables.lock().unwrap();
        let total = guard.trainable_variables.len();
        let missing = total - self.variables_in_optimizer;
        if missing == 0 {
            return;
        }
        // Only the tail of the list is new; earlier entries are already known
        // to the backend.
        let fresh = &guard.trainable_variables[self.variables_in_optimizer..];
        self.opt.add_parameters(fresh).unwrap();
        self.variables_in_optimizer = total;
    }

    /// Registers any newly added variables, then calls `zero_grad` on the
    /// backend optimizer.
    ///
    /// # Panics
    ///
    /// Panics if the backend call fails.
    pub fn zero_grad(&mut self) {
        self.add_missing_variables();
        self.opt.zero_grad().unwrap()
    }

    /// Calls `clamp_(-max, max)` on the gradient tensor of every trainable
    /// variable in the store.
    ///
    /// # Panics
    ///
    /// Panics if the variables mutex is poisoned.
    pub fn clip_grad_value(&self, max: f64) {
        let guard = self.variables.lock().unwrap();
        for var in &guard.trainable_variables {
            // The returned tensor is not needed; it is bound only so that the
            // result is explicitly discarded.
            let _clamped = var.grad().clamp_(-max, max);
        }
    }

    /// Registers any newly added variables, then calls `step` on the backend
    /// optimizer.
    ///
    /// # Panics
    ///
    /// Panics if the backend call fails.
    pub fn step(&mut self) {
        self.add_missing_variables();
        self.opt.step().unwrap()
    }

    /// Zeroes the gradients, runs `loss.backward()` and performs one
    /// optimizer step.
    ///
    /// # Panics
    ///
    /// Panics if any backend call fails.
    pub fn backward_step(&mut self, loss: &Tensor) {
        // `zero_grad` also registers any variables added since the last call.
        self.zero_grad();
        loss.backward();
        self.opt.step().unwrap()
    }

    /// Same as [`Optimizer::backward_step`], but calls
    /// [`Optimizer::clip_grad_value`] with `max` between the backward pass and
    /// the optimizer step.
    ///
    /// # Panics
    ///
    /// Panics if any backend call fails.
    pub fn backward_step_clip(&mut self, loss: &Tensor, max: f64) {
        // `zero_grad` also registers any variables added since the last call.
        self.zero_grad();
        loss.backward();
        self.clip_grad_value(max);
        self.opt.step().unwrap()
    }

    /// Sets the learning rate of the backend optimizer to `lr`.
    ///
    /// # Panics
    ///
    /// Panics if the backend call fails.
    pub fn set_lr(&mut self, lr: f64) {
        self.opt.set_learning_rate(lr).unwrap()
    }

    /// Sets the momentum of the backend optimizer to `m`.
    ///
    /// # Panics
    ///
    /// Panics if the backend call fails.
    pub fn set_momentum(&mut self, m: f64) {
        self.opt.set_momentum(m).unwrap()
    }
}