pub trait Rnn<F>: NN<F> {
// Required methods
fn new_rnn_config(
&self,
src: &SharedTensor<F>,
dropout_probability: Option<f32>,
dropout_seed: Option<u64>,
sequence_length: i32,
network_mode: RnnNetworkMode,
input_mode: RnnInputMode,
direction_mode: DirectionMode,
algorithm: RnnAlgorithm,
hidden_size: i32,
num_layers: i32,
batch_size: i32,
) -> Result<Self::CRNN, Error>;
fn generate_rnn_weight_description(
&self,
rnn_config: &Self::CRNN,
batch_size: i32,
input_size: i32,
) -> Result<Vec<usize>, Error>;
fn rnn_forward(
&self,
src: &SharedTensor<F>,
output: &mut SharedTensor<F>,
rnn_config: &Self::CRNN,
weight: &SharedTensor<F>,
workspace: &mut SharedTensor<u8>,
) -> Result<(), Error>;
fn rnn_backward_data(
&self,
src: &SharedTensor<F>,
src_gradient: &mut SharedTensor<F>,
output: &SharedTensor<F>,
output_gradient: &SharedTensor<F>,
rnn_config: &Self::CRNN,
weight: &SharedTensor<F>,
workspace: &mut SharedTensor<u8>,
) -> Result<(), Error>;
fn rnn_backward_weights(
&self,
src: &SharedTensor<F>,
output: &SharedTensor<F>,
filter: &mut SharedTensor<F>,
rnn_config: &Self::CRNN,
workspace: &mut SharedTensor<u8>,
) -> Result<(), Error>;
}Expand description
Provide the functionality for a Backend to support RNN operations
Required Methods
fn new_rnn_config(
&self,
src: &SharedTensor<F>,
dropout_probability: Option<f32>,
dropout_seed: Option<u64>,
sequence_length: i32,
network_mode: RnnNetworkMode,
input_mode: RnnInputMode,
direction_mode: DirectionMode,
algorithm: RnnAlgorithm,
hidden_size: i32,
num_layers: i32,
batch_size: i32,
) -> Result<Self::CRNN, Error>
fn new_rnn_config( &self, src: &SharedTensor<F>, dropout_probability: Option<f32>, dropout_seed: Option<u64>, sequence_length: i32, network_mode: RnnNetworkMode, input_mode: RnnInputMode, direction_mode: DirectionMode, algorithm: RnnAlgorithm, hidden_size: i32, num_layers: i32, batch_size: i32, ) -> Result<Self::CRNN, Error>
Create an RNN configuration (`Self::CRNN`).
fn generate_rnn_weight_description(
&self,
rnn_config: &Self::CRNN,
batch_size: i32,
input_size: i32,
) -> Result<Vec<usize>, Error>
fn generate_rnn_weight_description( &self, rnn_config: &Self::CRNN, batch_size: i32, input_size: i32, ) -> Result<Vec<usize>, Error>
Generate the weight description (a `Vec<usize>`) for the RNN.
fn rnn_forward(
&self,
src: &SharedTensor<F>,
output: &mut SharedTensor<F>,
rnn_config: &Self::CRNN,
weight: &SharedTensor<F>,
workspace: &mut SharedTensor<u8>,
) -> Result<(), Error>
fn rnn_forward( &self, src: &SharedTensor<F>, output: &mut SharedTensor<F>, rnn_config: &Self::CRNN, weight: &SharedTensor<F>, workspace: &mut SharedTensor<u8>, ) -> Result<(), Error>
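Run the forward pass of the RNN (e.g. an LSTM) and write the results into `output`.
Arguments
`weight` – previously initialised weights. (This text originally named the argument `weight_desc`, a FilterDescriptor for the weights; the signature has no such parameter, so it presumably refers to `weight`.)
fn rnn_backward_data(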
&self,
src: &SharedTensor<F>,
src_gradient: &mut SharedTensor<F>,
output: &SharedTensor<F>,
output_gradient: &SharedTensor<F>,
rnn_config: &Self::CRNN,
weight: &SharedTensor<F>,
workspace: &mut SharedTensor<u8>,
) -> Result<(), Error>
fn rnn_backward_data( &self, src: &SharedTensor<F>, src_gradient: &mut SharedTensor<F>, output: &SharedTensor<F>, output_gradient: &SharedTensor<F>, rnn_config: &Self::CRNN, weight: &SharedTensor<F>, workspace: &mut SharedTensor<u8>, ) -> Result<(), Error>
Calculates RNN gradients for the input, hidden, and cell data.
fn rnn_backward_weights(
&self,
src: &SharedTensor<F>,
output: &SharedTensor<F>,
filter: &mut SharedTensor<F>,
rnn_config: &Self::CRNN,
workspace: &mut SharedTensor<u8>,
) -> Result<(), Error>
fn rnn_backward_weights( &self, src: &SharedTensor<F>, output: &SharedTensor<F>, filter: &mut SharedTensor<F>, rnn_config: &Self::CRNN, workspace: &mut SharedTensor<u8>, ) -> Result<(), Error>
Calculates RNN Gradients for Weights
Dyn Compatibility
This trait is not dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
Implementations on Foreign Types
impl<T> Rnn<T> for Backend<Cuda> where
T: Float + DataTypeInfo,
impl<T> Rnn<T> for Backend<Cuda> where
T: Float + DataTypeInfo,
fn rnn_forward(
&self,
src: &SharedTensor<T>,
output: &mut SharedTensor<T>,
rnn_config: &Self::CRNN,
weight: &SharedTensor<T>,
workspace: &mut SharedTensor<u8>,
) -> Result<(), Error>
fn rnn_forward( &self, src: &SharedTensor<T>, output: &mut SharedTensor<T>, rnn_config: &Self::CRNN, weight: &SharedTensor<T>, workspace: &mut SharedTensor<u8>, ) -> Result<(), Error>
Run the forward pass of the RNN and write the results into `output`.