tch 0.0.1

PyTorch wrappers for Rust
use super::c_optimizer::COptimizer;
use super::var_store::VarStore;
use crate::tensor::Tensor;

/// Optimizer handle wrapping a `COptimizer`.
///
/// Instances are created with [`Optimizer::sgd`], [`Optimizer::adam`] or
/// [`Optimizer::rms_prop`], each of which registers the trainable variables
/// of a `VarStore` with the wrapped optimizer.
pub struct Optimizer {
    // The wrapped optimizer that the methods of `Optimizer` forward to.
    opt: COptimizer,
}

/// Hyper-parameters handed to `COptimizer::sgd` by `Optimizer::sgd`.
///
/// The fields are public so that callers outside this module can override
/// individual values, e.g. `Sgd { momentum: 0.9, ..Default::default() }`.
/// The names mirror the arguments of PyTorch's SGD optimizer; the exact
/// semantics are those of the underlying `COptimizer` (NOTE(review): confirm).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Sgd {
    /// Momentum factor (`0.` by default).
    pub momentum: f64,
    /// Dampening applied to the momentum (`0.` by default).
    pub dampening: f64,
    /// Presumably the weight decay coefficient (`0.` by default).
    pub wd: f64,
    /// Whether Nesterov momentum is requested (`false` by default).
    pub nesterov: bool,
}

impl Default for Sgd {
    fn default() -> Self {
        Sgd {
            momentum: 0.,
            dampening: 0.,
            wd: 0.,
            nesterov: false,
        }
    }
}

/// Hyper-parameters handed to `COptimizer::adam` by `Optimizer::adam`.
///
/// The fields are public so that callers outside this module can override
/// individual values, e.g. `Adam { wd: 1e-4, ..Default::default() }`.
/// The names mirror the arguments of PyTorch's Adam optimizer; the exact
/// semantics are those of the underlying `COptimizer` (NOTE(review): confirm).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Adam {
    /// First beta coefficient (`0.9` by default).
    pub beta1: f64,
    /// Second beta coefficient (`0.999` by default).
    pub beta2: f64,
    /// Presumably the weight decay coefficient (`0.` by default).
    pub wd: f64,
}

impl Default for Adam {
    fn default() -> Self {
        Adam {
            beta1: 0.9,
            beta2: 0.999,
            wd: 0.,
        }
    }
}

/// Hyper-parameters handed to `COptimizer::rms_prop` by `Optimizer::rms_prop`.
///
/// The fields are public so that callers outside this module can override
/// individual values, e.g. `RmsProp { centered: true, ..Default::default() }`.
/// The names mirror the arguments of PyTorch's RMSprop optimizer; the exact
/// semantics are those of the underlying `COptimizer` (NOTE(review): confirm).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RmsProp {
    /// Smoothing constant (`0.99` by default).
    pub alpha: f64,
    /// Presumably the numerical-stability term (`1e-8` by default).
    pub eps: f64,
    /// Presumably the weight decay coefficient (`0.` by default).
    pub wd: f64,
    /// Momentum factor (`0.` by default).
    pub momentum: f64,
    /// Whether the centered variant is requested (`false` by default).
    pub centered: bool,
}

impl Default for RmsProp {
    fn default() -> Self {
        RmsProp {
            alpha: 0.99,
            eps: 1e-8,
            wd: 0.,
            momentum: 0.,
            centered: false,
        }
    }
}

impl Optimizer {
    pub fn sgd(vs: &VarStore, lr: f64, s: Sgd) -> Optimizer {
        let mut opt = COptimizer::sgd(lr, s.momentum, s.dampening, s.wd, s.nesterov);
        opt.add_parameters(&vs.trainable_variables());
        Optimizer { opt }
    }

    pub fn adam(vs: &VarStore, lr: f64, a: Adam) -> Optimizer {
        let mut opt = COptimizer::adam(lr, a.beta1, a.beta2, a.wd);
        opt.add_parameters(&vs.trainable_variables());
        Optimizer { opt }
    }

    pub fn rms_prop(vs: &VarStore, lr: f64, r: RmsProp) -> Optimizer {
        let mut opt = COptimizer::rms_prop(lr, r.alpha, r.eps, r.wd, r.momentum, r.centered);
        opt.add_parameters(&vs.trainable_variables());
        Optimizer { opt }
    }

    pub fn zero_grad(&self) {
        self.opt.zero_grad()
    }

    pub fn step(&self) {
        self.opt.step()
    }

    pub fn backward_step(&self, loss: &Tensor) {
        self.zero_grad();
        loss.backward();
        self.step();
    }

    pub fn set_lr(&mut self, lr: f64) {
        self.opt.set_learning_rate(lr)
    }

    pub fn set_momentum(&mut self, m: f64) {
        self.opt.set_momentum(m)
    }
}