Trait LogSoftmax

pub trait LogSoftmax<F>: NN<F> {
    // Required methods
    fn log_softmax(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
    ) -> Result<(), Error>;
    fn log_softmax_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
    ) -> Result<(), Error>;
}

Provides the functionality for a Backend to support LogSoftmax operations.

Required Methods

fn log_softmax( &self, x: &SharedTensor<F>, result: &mut SharedTensor<F>, ) -> Result<(), Error>

Computes a logarithmic softmax over the input Tensor x.

Saves the result to result.
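
The forward computation is the standard numerically stable log-softmax. As a self-contained sketch in plain Rust (a hypothetical helper operating on slices, not this crate's SharedTensor API; whether a given backend uses exactly this max-subtraction trick is an implementation detail):

```rust
/// Numerically stable log-softmax over a slice:
/// log_softmax(x)_i = x_i - m - ln(Σ_j exp(x_j - m)), with m = max_j x_j.
/// Subtracting the maximum first avoids overflow in exp().
fn log_softmax(x: &[f64]) -> Vec<f64> {
    let m = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum = x.iter().map(|&v| (v - m).exp()).sum::<f64>().ln();
    x.iter().map(|&v| v - m - log_sum).collect()
}

fn main() {
    let y = log_softmax(&[1.0, 2.0, 3.0]);
    // The exponentials of the outputs sum to 1, as for any softmax.
    let total: f64 = y.iter().map(|&v| v.exp()).sum();
    println!("y = {:?}, Σ exp(y) = {}", y, total);
}
```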

fn log_softmax_grad( &self, x: &SharedTensor<F>, x_diff: &SharedTensor<F>, result_diff: &mut SharedTensor<F>, ) -> Result<(), Error>

Computes the gradient of a logarithmic softmax over the input Tensor x.

Saves the result to result_diff.
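
The gradient follows the standard log-softmax backward rule: with upstream gradient dy (here x_diff), dx_i = dy_i - softmax(x)_i · Σ_j dy_j. Whether x denotes the original input or the saved forward output is a backend convention; the sketch below (a hypothetical slice-based helper, not this crate's API) takes the original input:

```rust
/// Standard log-softmax backward:
/// dx_i = dy_i - softmax(x)_i * Σ_j dy_j.
fn log_softmax_grad(x: &[f64], dy: &[f64]) -> Vec<f64> {
    let m = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let sum_exp: f64 = x.iter().map(|&v| (v - m).exp()).sum();
    let dy_sum: f64 = dy.iter().sum();
    x.iter()
        .zip(dy)
        .map(|(&xi, &dyi)| dyi - ((xi - m).exp() / sum_exp) * dy_sum)
        .collect()
}

fn main() {
    // Upstream gradient selecting component 0 (e.g. NLL loss on class 0).
    let dx = log_softmax_grad(&[1.0, 2.0, 3.0], &[1.0, 0.0, 0.0]);
    // The components always sum to zero: Σ dx = Σ dy - 1 · Σ dy.
    println!("dx = {:?}, Σ dx = {}", dx, dx.iter().sum::<f64>());
}
```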

Dyn Compatibility

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.

Implementations on Foreign Types
impl LogSoftmax<f32> for Backend<Native>

fn log_softmax( &self, x: &SharedTensor<f32>, result: &mut SharedTensor<f32>, ) -> Result<(), Error>

fn log_softmax_grad( &self, x: &SharedTensor<f32>, x_diff: &SharedTensor<f32>, result_diff: &mut SharedTensor<f32>, ) -> Result<(), Error>

impl LogSoftmax<f64> for Backend<Native>

fn log_softmax( &self, x: &SharedTensor<f64>, result: &mut SharedTensor<f64>, ) -> Result<(), Error>

fn log_softmax_grad( &self, x: &SharedTensor<f64>, x_diff: &SharedTensor<f64>, result_diff: &mut SharedTensor<f64>, ) -> Result<(), Error>

impl<T> LogSoftmax<T> for Backend<Cuda>
where T: Float + Default + DataTypeInfo,

fn log_softmax( &self, x: &SharedTensor<T>, result: &mut SharedTensor<T>, ) -> Result<(), Error>

fn log_softmax_grad( &self, x: &SharedTensor<T>, x_diff: &SharedTensor<T>, result_diff: &mut SharedTensor<T>, ) -> Result<(), Error>
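
Whichever backend is used, the two methods are linked mathematically: log_softmax_grad should agree with the finite-difference derivative of log_softmax. A self-contained consistency check in plain Rust (hypothetical free functions standing in for the trait methods):

```rust
/// Forward: numerically stable log-softmax.
fn log_softmax(x: &[f64]) -> Vec<f64> {
    let m = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum = x.iter().map(|&v| (v - m).exp()).sum::<f64>().ln();
    x.iter().map(|&v| v - m - log_sum).collect()
}

/// Backward: dx_i = dy_i - softmax(x)_i * Σ_j dy_j.
fn log_softmax_grad(x: &[f64], dy: &[f64]) -> Vec<f64> {
    let m = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let sum_exp: f64 = x.iter().map(|&v| (v - m).exp()).sum();
    let dy_sum: f64 = dy.iter().sum();
    x.iter()
        .zip(dy)
        .map(|(&xi, &dyi)| dyi - ((xi - m).exp() / sum_exp) * dy_sum)
        .collect()
}

fn main() {
    let x = [0.5, -1.0, 2.0];
    let dy = [1.0, 0.0, 0.0];
    let analytic = log_softmax_grad(&x, &dy);
    // Numerically differentiate f(x) = Σ_i dy_i * log_softmax(x)_i.
    let f = |v: &[f64]| -> f64 {
        log_softmax(v).iter().zip(&dy).map(|(a, &b)| a * b).sum()
    };
    let eps = 1e-6;
    for j in 0..x.len() {
        let (mut xp, mut xm) = (x, x);
        xp[j] += eps;
        xm[j] -= eps;
        let numeric = (f(&xp) - f(&xm)) / (2.0 * eps);
        assert!((numeric - analytic[j]).abs() < 1e-6);
    }
    println!("analytic gradient matches finite differences");
}
```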

Implementors