pub struct Lstm<B: Backend> { /* private fields */ }
The LSTM module. This implementation is a unidirectional, stateless LSTM.
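For reference, the textbook LSTM cell update is shown below. This is the standard formulation, not a statement of how `Lstm` lays out or parameterizes its gate weights internally, which this page does not show. "Unidirectional" means the sequence is processed in one direction only, for t = 1, ..., T with T = sequence_length. The initial states (c_0, h_0) come from the caller-supplied `state`, or are zeros when none is given.

```latex
\begin{aligned}
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
g_t &= \tanh(W_g x_t + U_g h_{t-1} + b_g) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```

Here \sigma is the logistic sigmoid and \odot is element-wise multiplication.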
Implementations
impl<B: Backend> Lstm<B>
pub fn forward(
    &self,
    batched_input: Tensor<B, 3>,
    state: Option<(Tensor<B, 2>, Tensor<B, 2>)>
) -> (Tensor<B, 3>, Tensor<B, 3>)
Applies the forward pass to the input tensor. This LSTM implementation
returns the cell state and hidden state for every element of the sequence (i.e., at each of the sequence_length time steps),
producing 3-dimensional tensors of shape [batch_size, sequence_length, hidden_size].
Parameters:
- batched_input: The input tensor of shape [batch_size, sequence_length, input_size].
- state: An optional tuple of tensors representing the initial cell state and hidden state. Each state tensor has shape [batch_size, hidden_size]. If no initial state is provided, both are initialized to zeros.
Returns: A tuple of tensors, where the first tensor represents the cell states and the second tensor represents the hidden states for each sequence element. Both output tensors have the shape [batch_size, sequence_length, hidden_size].
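To make the documented shapes and the zero-initialization rule concrete, here is a minimal, dependency-free sketch assuming the standard i/f/g/o gate formulation. `LstmSketch`, `Gate`, and the nested-`Vec` layout are hypothetical stand-ins for `Lstm<B>` and `Tensor<B, N>`. They are not burn's API and not its actual implementation; they only mirror the input, state, and output shapes described above.

```rust
// Hypothetical, dependency-free sketch of a unidirectional LSTM forward pass.
// Nested Vecs stand in for tensors: [batch][seq][features] and [batch][hidden].
type Batch3 = Vec<Vec<Vec<f32>>>; // [batch_size][sequence_length][features]
type Batch2 = Vec<Vec<f32>>; // [batch_size][hidden_size]

/// Weights for one gate: w is hidden x input, u is hidden x hidden, b is hidden.
struct Gate {
    w: Vec<Vec<f32>>,
    u: Vec<Vec<f32>>,
    b: Vec<f32>,
}

impl Gate {
    /// Computes W x + U h + b for a single batch element.
    fn pre_activation(&self, x: &[f32], h: &[f32]) -> Vec<f32> {
        (0..self.b.len())
            .map(|j| {
                let wx: f32 = self.w[j].iter().zip(x).map(|(w, xi)| w * xi).sum();
                let uh: f32 = self.u[j].iter().zip(h).map(|(u, hi)| u * hi).sum();
                wx + uh + self.b[j]
            })
            .collect()
    }
}

struct LstmSketch {
    input_gate: Gate,
    forget_gate: Gate,
    cell_gate: Gate,
    output_gate: Gate,
    hidden_size: usize,
}

fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

impl LstmSketch {
    /// Returns (cell_states, hidden_states), each [batch_size][sequence_length][hidden_size].
    fn forward(&self, batched_input: &Batch3, state: Option<(Batch2, Batch2)>) -> (Batch3, Batch3) {
        let batch_size = batched_input.len();
        // With no initial state, both the cell and hidden states start at zero.
        let (mut c, mut h) = state.unwrap_or_else(|| {
            let zeros = vec![vec![0.0; self.hidden_size]; batch_size];
            (zeros.clone(), zeros)
        });
        let mut cell_states = vec![Vec::new(); batch_size];
        let mut hidden_states = vec![Vec::new(); batch_size];

        for n in 0..batch_size {
            // Walk the sequence forward in time only (unidirectional).
            for x_t in &batched_input[n] {
                let i = self.input_gate.pre_activation(x_t, &h[n]);
                let f = self.forget_gate.pre_activation(x_t, &h[n]);
                let g = self.cell_gate.pre_activation(x_t, &h[n]);
                let o = self.output_gate.pre_activation(x_t, &h[n]);
                for j in 0..self.hidden_size {
                    c[n][j] = sigmoid(f[j]) * c[n][j] + sigmoid(i[j]) * g[j].tanh();
                    h[n][j] = sigmoid(o[j]) * c[n][j].tanh();
                }
                // Record the state at every time step, not just the last one.
                cell_states[n].push(c[n].clone());
                hidden_states[n].push(h[n].clone());
            }
        }
        (cell_states, hidden_states)
    }
}
```

The per-element loops are only for readability; a tensor library would typically batch these matrix-vector products across the whole batch. Because the module keeps nothing between calls, continuing a sequence across calls means passing the last time step's cell and hidden states back in as `state`.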