Trait Relu

pub trait Relu<F>: NN<F> {
    // Required methods
    fn relu(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
    ) -> Result<(), Error>;
    fn relu_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
    ) -> Result<(), Error>;
}

Provides the functionality for a Backend to support ReLU operations.

Required Methods

fn relu( &self, x: &SharedTensor<F>, result: &mut SharedTensor<F>, ) -> Result<(), Error>

Computes the [Rectified linear units](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) over the input Tensor x.

Saves the result to result.
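The operation itself is an elementwise max(x, 0). As a minimal illustration of what a backend computes here, independent of any Backend or SharedTensor (the slice-based signature below is purely illustrative, not the trait's API):

```rust
/// Elementwise ReLU: writes max(v, 0) for each element of x into result.
fn relu(x: &[f32], result: &mut [f32]) {
    assert_eq!(x.len(), result.len());
    for (r, &v) in result.iter_mut().zip(x) {
        *r = v.max(0.0);
    }
}

fn main() {
    let x = [-2.0_f32, -0.5, 0.0, 1.5, 3.0];
    let mut result = [0.0_f32; 5];
    relu(&x, &mut result);
    println!("{:?}", result); // [0.0, 0.0, 0.0, 1.5, 3.0]
}
```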

fn relu_grad( &self, x: &SharedTensor<F>, x_diff: &SharedTensor<F>, result: &SharedTensor<F>, result_diff: &mut SharedTensor<F>, ) -> Result<(), Error>

Computes the gradient of [ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) over the input Tensor x.

Saves the result to result_diff.
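Since d/dx max(x, 0) is 1 for x > 0 and 0 otherwise, the backward pass passes the incoming gradient through where the forward input was positive and blocks it elsewhere. A scalar sketch, assuming the usual convention that x_diff carries the incoming (upstream) gradient (the slice-based signature is illustrative only):

```rust
/// ReLU gradient: result_diff[i] = x_diff[i] if x[i] > 0, else 0.
fn relu_grad(x: &[f32], x_diff: &[f32], result_diff: &mut [f32]) {
    assert_eq!(x.len(), x_diff.len());
    assert_eq!(x.len(), result_diff.len());
    for ((rd, &v), &d) in result_diff.iter_mut().zip(x).zip(x_diff) {
        *rd = if v > 0.0 { d } else { 0.0 };
    }
}

fn main() {
    let x = [-2.0_f32, 0.0, 1.5, 3.0];
    let x_diff = [0.5_f32, 0.5, 0.5, 0.5];
    let mut result_diff = [0.0_f32; 4];
    relu_grad(&x, &x_diff, &mut result_diff);
    println!("{:?}", result_diff); // [0.0, 0.0, 0.5, 0.5]
}
```

Note that the gradient at exactly x = 0 is taken as 0 here; frameworks differ on this convention, and the actual behavior of a given Backend implementation may vary.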

Dyn Compatibility

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.

Implementations on Foreign Types

impl Relu<f32> for Backend<Native>

fn relu( &self, x: &SharedTensor<f32>, result: &mut SharedTensor<f32>, ) -> Result<(), Error>

fn relu_grad( &self, x: &SharedTensor<f32>, x_diff: &SharedTensor<f32>, result: &SharedTensor<f32>, result_diff: &mut SharedTensor<f32>, ) -> Result<(), Error>

impl Relu<f64> for Backend<Native>

fn relu( &self, x: &SharedTensor<f64>, result: &mut SharedTensor<f64>, ) -> Result<(), Error>

fn relu_grad( &self, x: &SharedTensor<f64>, x_diff: &SharedTensor<f64>, result: &SharedTensor<f64>, result_diff: &mut SharedTensor<f64>, ) -> Result<(), Error>

impl<T> Relu<T> for Backend<Cuda>
where
    T: Float + Default + DataTypeInfo,

fn relu( &self, x: &SharedTensor<T>, result: &mut SharedTensor<T>, ) -> Result<(), Error>

fn relu_grad( &self, x: &SharedTensor<T>, x_diff: &SharedTensor<T>, result: &SharedTensor<T>, result_diff: &mut SharedTensor<T>, ) -> Result<(), Error>

Implementors