// gradients 0.3.4
//
// An OpenCL-, CUDA- and CPU-based deep learning library.
// Documentation
use crate::{GetParam, WithDevice};
use custos::{number::Float, CDatatype, GenericBlas};
use custos_math::Matrix;
use gradients_derive::NoParams;

/// ReLU activation layer.
///
/// Keeps a copy of the inputs passed to the most recent `forward` call so
/// that `backward` can evaluate `relu_grad` on them.
#[derive(NoParams)]
pub struct ReLU<'a, T> {
    /// Inputs of the last `forward` call; `None` until `forward` has run.
    inputs: Option<Matrix<'a, T>>,
}

impl<'a, T> ReLU<'a, T> {
    /// Creates a `ReLU` layer with no cached inputs.
    ///
    /// Construction does not touch any matrix operation, so no bounds on `T`
    /// are required here.
    pub fn new() -> ReLU<'a, T> {
        ReLU { inputs: None }
    }
}

impl<'a, T: Float + CDatatype> ReLU<'a, T> {
    /// Applies `relu` to `inputs` and returns the result.
    ///
    /// The inputs are stored (via `shallow_or_clone`) so that a later call to
    /// [`backward`](Self::backward) can compute the gradient from them.
    pub fn forward(&mut self, inputs: &Matrix<'a, T>) -> Matrix<'a, T> {
        self.inputs = Some(inputs.shallow_or_clone());
        inputs.relu()
    }

    /// Returns `relu_grad()` of the cached inputs multiplied by `grad`.
    ///
    /// # Panics
    ///
    /// Panics if `forward` has not been called before, because there are no
    /// cached inputs to differentiate.
    pub fn backward(&self, grad: &Matrix<'a, T>) -> Matrix<'a, T> {
        let inputs = self
            .inputs
            .as_ref()
            .expect("ReLU::backward called before ReLU::forward: no cached inputs");
        inputs.relu_grad() * grad
    }
}

impl<'a, T> Default for ReLU<'a, T> {
    fn default() -> Self {
        Self {
            inputs: Default::default(),
        }
    }
}

/// Softmax activation layer.
///
/// Keeps the output of the most recent `forward` call so that `backward`
/// can pass it to `softmax_grad`.
#[derive(NoParams)]
pub struct Softmax<'a, T> {
    /// Softmax output of the last `forward` call; `None` until `forward` has run.
    activated: Option<Matrix<'a, T>>,
}

impl<'a, T> Softmax<'a, T> {
    /// Creates a `Softmax` layer with no cached activation.
    ///
    /// Construction does not touch any matrix operation, so no bounds on `T`
    /// are required here.
    pub fn new() -> Self {
        Softmax { activated: None }
    }
}

impl<'a, T: CDatatype + GenericBlas> Softmax<'a, T> {
    /// Computes `x.softmax()` and returns it.
    ///
    /// The activation is also stored (via `shallow_or_clone`) so that
    /// [`backward`](Self::backward) can reuse it.
    pub fn forward(&mut self, x: &Matrix<'a, T>) -> Matrix<'a, T> {
        let activated = x.softmax();
        self.activated = Some(activated.shallow_or_clone());
        activated
    }

    /// Returns `grad.softmax_grad(..)` evaluated with the cached activation.
    ///
    /// # Panics
    ///
    /// Panics if `forward` has not been called before, because there is no
    /// cached activation to differentiate.
    pub fn backward(&self, grad: &Matrix<'a, T>) -> Matrix<'a, T> {
        let activated = self
            .activated
            .as_ref()
            .expect("Softmax::backward called before Softmax::forward: no cached activation");
        // `activated` is already a reference; no extra borrow is needed.
        grad.softmax_grad(activated)
    }
}

impl<'a, T> Default for Softmax<'a, T> {
    fn default() -> Self {
        Self {
            activated: Default::default(),
        }
    }
}

/// Tanh activation layer.
///
/// Keeps a copy of the inputs passed to the most recent `forward` call so
/// that `backward` can evaluate `tanh_grad` on them.
#[derive(NoParams)]
pub struct Tanh<'a, T> {
    /// Inputs of the last `forward` call; `None` until `forward` has run.
    inputs: Option<Matrix<'a, T>>,
}

impl<'a, T> Tanh<'a, T> {
    pub fn new() -> Tanh<'a, T> {
        Tanh { inputs: None }
    }
}

impl<'a, T: Float + CDatatype> Tanh<'a, T> {
    /// Applies `tanh` to `inputs` and returns the result.
    ///
    /// The inputs are stored (via `shallow_or_clone`) so that a later call to
    /// [`backward`](Self::backward) can compute the gradient from them.
    pub fn forward(&mut self, inputs: &Matrix<'a, T>) -> Matrix<'a, T> {
        self.inputs = Some(inputs.shallow_or_clone());
        inputs.tanh()
    }

    /// Returns `tanh_grad()` of the cached inputs multiplied by `grad`.
    ///
    /// # Panics
    ///
    /// Panics if `forward` has not been called before, because there are no
    /// cached inputs to differentiate.
    pub fn backward(&self, grad: &Matrix<'a, T>) -> Matrix<'a, T> {
        let inputs = self
            .inputs
            .as_ref()
            .expect("Tanh::backward called before Tanh::forward: no cached inputs");
        inputs.tanh_grad() * grad
    }
}

impl<'a, T> Default for Tanh<'a, T> {
    fn default() -> Self {
        Self {
            inputs: Default::default(),
        }
    }
}