Rnn

Trait Rnn 

Source
/// Provides the functionality for a backend to support RNN operations.
///
/// Every method is required and returns a `Result` whose error type is `Error`.
/// `Self::CRNN` is not declared in this trait, so it is an associated type
/// supplied by the `NN<F>` supertrait; it is the RNN configuration type produced
/// by [`Rnn::new_rnn_config`] and consumed by the remaining methods.
pub trait Rnn<F>: NN<F> {
    // Required methods

    /// Creates an RNN configuration (`Self::CRNN`).
    ///
    /// # Arguments
    /// * `src` - tensor the configuration is created for
    ///   (NOTE(review): presumably the input tensor; layout not shown here — confirm).
    /// * `dropout_probability`, `dropout_seed` - optional dropout settings; how
    ///   `None` is treated is backend-defined and not visible from this declaration.
    /// * `sequence_length`, `hidden_size`, `num_layers`, `batch_size` - network
    ///   dimensions, passed as `i32`.
    /// * `network_mode`, `input_mode`, `direction_mode`, `algorithm` - mode
    ///   selectors; see the respective enum types for the available variants.
    fn new_rnn_config(
        &self,
        src: &SharedTensor<F>,
        dropout_probability: Option<f32>,
        dropout_seed: Option<u64>,
        sequence_length: i32,
        network_mode: RnnNetworkMode,
        input_mode: RnnInputMode,
        direction_mode: DirectionMode,
        algorithm: RnnAlgorithm,
        hidden_size: i32,
        num_layers: i32,
        batch_size: i32,
    ) -> Result<Self::CRNN, Error>;

    /// Generates the weight description for the RNN described by `rnn_config`,
    /// given `batch_size` and `input_size`.
    ///
    /// Returns the description as a `Vec<usize>`
    /// (NOTE(review): presumably the dimensions of the weight tensor — confirm).
    fn generate_rnn_weight_description(
        &self,
        rnn_config: &Self::CRNN,
        batch_size: i32,
        input_size: i32,
    ) -> Result<Vec<usize>, Error>;

    /// Forward step of the RNN described by `rnn_config`.
    ///
    /// `src` and `weight` are borrowed immutably; `output` and `workspace` are
    /// borrowed mutably. The method itself returns `()` on success
    /// (NOTE(review): presumably the results are written into `output` and
    /// `workspace` is scratch memory — confirm against the backends).
    fn rnn_forward(
        &self,
        src: &SharedTensor<F>,
        output: &mut SharedTensor<F>,
        rnn_config: &Self::CRNN,
        weight: &SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
    ) -> Result<(), Error>;

    /// Calculates RNN gradients for input/hidden/cell.
    ///
    /// `src_gradient` and `workspace` are the only mutably-borrowed tensors
    /// (NOTE(review): presumably `src_gradient` receives the computed
    /// gradient — confirm).
    fn rnn_backward_data(
        &self,
        src: &SharedTensor<F>,
        src_gradient: &mut SharedTensor<F>,
        output: &SharedTensor<F>,
        output_gradient: &SharedTensor<F>,
        rnn_config: &Self::CRNN,
        weight: &SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
    ) -> Result<(), Error>;

    /// Calculates RNN gradients for weights.
    ///
    /// `filter` and `workspace` are the only mutably-borrowed tensors
    /// (NOTE(review): presumably `filter` receives the weight gradients — confirm).
    fn rnn_backward_weights(
        &self,
        src: &SharedTensor<F>,
        output: &SharedTensor<F>,
        filter: &mut SharedTensor<F>,
        rnn_config: &Self::CRNN,
        workspace: &mut SharedTensor<u8>,
    ) -> Result<(), Error>;
}
Expand description

Provide the functionality for a Backend to support RNN operations

Required Methods§

Source

fn new_rnn_config( &self, src: &SharedTensor<F>, dropout_probability: Option<f32>, dropout_seed: Option<u64>, sequence_length: i32, network_mode: RnnNetworkMode, input_mode: RnnInputMode, direction_mode: DirectionMode, algorithm: RnnAlgorithm, hidden_size: i32, num_layers: i32, batch_size: i32, ) -> Result<Self::CRNN, Error>

Create an RnnConfig

Source

fn generate_rnn_weight_description( &self, rnn_config: &Self::CRNN, batch_size: i32, input_size: i32, ) -> Result<Vec<usize>, Error>

Generate the weight description for an RNN

Source

fn rnn_forward( &self, src: &SharedTensor<F>, output: &mut SharedTensor<F>, rnn_config: &Self::CRNN, weight: &SharedTensor<F>, workspace: &mut SharedTensor<u8>, ) -> Result<(), Error>

Train an LSTM network; the results are written to `output`

§Arguments
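  • weight_desc Previously initialised FilterDescriptor for Weights (note: the signature above has no `weight_desc` parameter; this entry most likely refers to `weight`)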
Source

fn rnn_backward_data( &self, src: &SharedTensor<F>, src_gradient: &mut SharedTensor<F>, output: &SharedTensor<F>, output_gradient: &SharedTensor<F>, rnn_config: &Self::CRNN, weight: &SharedTensor<F>, workspace: &mut SharedTensor<u8>, ) -> Result<(), Error>

Calculates RNN Gradients for Input/Hidden/Cell

Source

fn rnn_backward_weights( &self, src: &SharedTensor<F>, output: &SharedTensor<F>, filter: &mut SharedTensor<F>, rnn_config: &Self::CRNN, workspace: &mut SharedTensor<u8>, ) -> Result<(), Error>

Calculates RNN Gradients for Weights

Dyn Compatibility§

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.

Implementations on Foreign Types§

Source§

impl<T> Rnn<T> for Backend<Cuda>
where T: Float + DataTypeInfo,

Source§

fn rnn_forward( &self, src: &SharedTensor<T>, output: &mut SharedTensor<T>, rnn_config: &Self::CRNN, weight: &SharedTensor<T>, workspace: &mut SharedTensor<u8>, ) -> Result<(), Error>

Train and output an RNN

Source§

fn generate_rnn_weight_description( &self, rnn_config: &Self::CRNN, batch_size: i32, input_size: i32, ) -> Result<Vec<usize>, Error>

Source§

fn new_rnn_config( &self, src: &SharedTensor<T>, dropout_probability: Option<f32>, dropout_seed: Option<u64>, sequence_length: i32, network_mode: RnnNetworkMode, input_mode: RnnInputMode, direction_mode: DirectionMode, algorithm: RnnAlgorithm, hidden_size: i32, num_layers: i32, batch_size: i32, ) -> Result<Self::CRNN, Error>

Source§

fn rnn_backward_data( &self, src: &SharedTensor<T>, src_gradient: &mut SharedTensor<T>, output: &SharedTensor<T>, output_gradient: &SharedTensor<T>, rnn_config: &Self::CRNN, weight: &SharedTensor<T>, workspace: &mut SharedTensor<u8>, ) -> Result<(), Error>

Source§

fn rnn_backward_weights( &self, src: &SharedTensor<T>, output: &SharedTensor<T>, filter: &mut SharedTensor<T>, rnn_config: &Self::CRNN, workspace: &mut SharedTensor<u8>, ) -> Result<(), Error>

Source§

impl<T> Rnn<T> for Backend<Native>
where T: Float + Default + Copy + PartialOrd + Bounded,

Source§

fn new_rnn_config( &self, src: &SharedTensor<T>, dropout_probability: Option<f32>, dropout_seed: Option<u64>, sequence_length: i32, network_mode: RnnNetworkMode, input_mode: RnnInputMode, direction_mode: DirectionMode, algorithm: RnnAlgorithm, hidden_size: i32, num_layers: i32, batch_size: i32, ) -> Result<Self::CRNN, Error>

Source§

fn generate_rnn_weight_description( &self, rnn_config: &Self::CRNN, batch_size: i32, input_size: i32, ) -> Result<Vec<usize>, Error>

Source§

fn rnn_forward( &self, src: &SharedTensor<T>, output: &mut SharedTensor<T>, rnn_config: &Self::CRNN, weight: &SharedTensor<T>, workspace: &mut SharedTensor<u8>, ) -> Result<(), Error>

Source§

fn rnn_backward_data( &self, src: &SharedTensor<T>, src_gradient: &mut SharedTensor<T>, output: &SharedTensor<T>, output_gradient: &SharedTensor<T>, rnn_config: &Self::CRNN, weight: &SharedTensor<T>, workspace: &mut SharedTensor<u8>, ) -> Result<(), Error>

Source§

fn rnn_backward_weights( &self, src: &SharedTensor<T>, output: &SharedTensor<T>, filter: &mut SharedTensor<T>, rnn_config: &Self::CRNN, workspace: &mut SharedTensor<u8>, ) -> Result<(), Error>

Implementors§