Trait Softmax
pub trait Softmax<F>: NN<F> {
    // Required methods
    fn softmax(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
    ) -> Result<(), Error>;
    fn softmax_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
    ) -> Result<(), Error>;
}

Provides the functionality for a Backend to support Softmax operations.
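
A minimal usage sketch on the native backend, assuming the setup pattern from the coaster-nn README (Backend::<Native>::default(), SharedTensor::new, and write_only/read for host memory); exact constructor and memory-access signatures may differ between coaster versions:

use coaster::prelude::*;
use coaster_nn::*;

fn main() {
    // Assumed constructors, following the coaster-nn README pattern.
    let backend = Backend::<Native>::default().unwrap();
    let mut x = SharedTensor::<f32>::new(&(1, 1, 3));
    let mut result = SharedTensor::<f32>::new(&(1, 1, 3));

    // Fill x with data on the native (host) device.
    x.write_only(backend.device())
        .unwrap()
        .as_mut_slice::<f32>()
        .copy_from_slice(&[1.0, 2.0, 3.0]);

    // Run the softmax operation provided by this trait.
    backend.softmax(&x, &mut result).unwrap();

    // Read the normalized probabilities back on the host.
    println!("{:?}", result.read(backend.device()).unwrap().as_slice::<f32>());
}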

Required Methods

fn softmax( &self, x: &SharedTensor<F>, result: &mut SharedTensor<F>, ) -> Result<(), Error>

Computes a [Softmax](https://en.wikipedia.org/wiki/Softmax_function) over the input Tensor x.

Saves the result to result.
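
For reference, the underlying function maps an input vector x to softmax(x)_i = exp(x_i) / Σ_j exp(x_j), so every output lies in (0, 1) and the outputs sum to 1. Which axes the normalizing sum runs over (e.g. per sample or across the whole tensor) is left to the backend implementation.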

fn softmax_grad( &self, x: &SharedTensor<F>, x_diff: &SharedTensor<F>, result_diff: &mut SharedTensor<F>, ) -> Result<(), Error>

Computes the gradient of a [Softmax](https://en.wikipedia.org/wiki/Softmax_function) over the input Tensor x.

Saves the result to result_diff.
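
As a sketch of the usual backward formula: writing s for the forward softmax output and assuming x_diff holds the upstream gradient, result_diff_i = s_i * (x_diff_i - Σ_j x_diff_j * s_j). Whether x here is expected to hold the original input or the forward output s is backend-defined (cuDNN's softmax backward, for example, takes the forward output), so check the implementation you are targeting.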

Dyn Compatibility

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
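
In practice, backend-agnostic code therefore takes the backend as a generic type parameter bounded by Softmax<F> instead of as a dyn Softmax<F> trait object. A minimal sketch, assuming the SharedTensor and Error types come from coaster as in the trait definition above:

use coaster::error::Error;
use coaster::prelude::*;
use coaster_nn::Softmax;

// `dyn Softmax<F>` is rejected by the compiler, so accept any implementing
// backend through a generic bound instead.
fn run_softmax<F, B>(
    backend: &B,
    x: &SharedTensor<F>,
    out: &mut SharedTensor<F>,
) -> Result<(), Error>
where
    B: Softmax<F>,
{
    backend.softmax(x, out)
}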

Implementations on Foreign Types

impl Softmax<f32> for Backend<Native>

fn softmax( &self, x: &SharedTensor<f32>, result: &mut SharedTensor<f32>, ) -> Result<(), Error>

fn softmax_grad( &self, x: &SharedTensor<f32>, x_diff: &SharedTensor<f32>, result_diff: &mut SharedTensor<f32>, ) -> Result<(), Error>

impl Softmax<f64> for Backend<Native>

fn softmax( &self, x: &SharedTensor<f64>, result: &mut SharedTensor<f64>, ) -> Result<(), Error>

fn softmax_grad( &self, x: &SharedTensor<f64>, x_diff: &SharedTensor<f64>, result_diff: &mut SharedTensor<f64>, ) -> Result<(), Error>

impl<T> Softmax<T> for Backend<Cuda>
where T: Float + Default + DataTypeInfo,

fn softmax( &self, x: &SharedTensor<T>, result: &mut SharedTensor<T>, ) -> Result<(), Error>

fn softmax_grad( &self, x: &SharedTensor<T>, x_diff: &SharedTensor<T>, result_diff: &mut SharedTensor<T>, ) -> Result<(), Error>

Implementors