Struct neuronika::nn::LSTMCell
pub struct LSTMCell {
    pub weight_ih: Learnable<Ix2>,
    pub weight_hh: Learnable<Ix2>,
    pub bias_ih: Learnable<Ix1>,
    pub bias_hh: Learnable<Ix1>,
}
A long short-term memory (LSTM) cell.
Fields
weight_ih: Learnable<Ix2>
weight_hh: Learnable<Ix2>
bias_ih: Learnable<Ix1>
bias_hh: Learnable<Ix1>
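The field shapes are not stated here. As a sketch, assuming the common PyTorch convention in which the projections for the four gates are stacked along the first axis, the shapes follow from the constructor's input_size and hidden_size. LstmShapes and lstm_shapes are hypothetical names used only for illustration, not part of this crate.

```rust
/// Hypothetical helper: expected parameter shapes of an LSTM cell, assuming
/// the four gates (input, forget, cell, output) are stacked along axis 0.
#[derive(Debug, PartialEq)]
struct LstmShapes {
    weight_ih: (usize, usize),
    weight_hh: (usize, usize),
    bias_ih: usize,
    bias_hh: usize,
}

fn lstm_shapes(input_size: usize, hidden_size: usize) -> LstmShapes {
    // Each matrix holds one block of `hidden_size` rows per gate.
    let gates = 4 * hidden_size;
    LstmShapes {
        weight_ih: (gates, input_size),
        weight_hh: (gates, hidden_size),
        bias_ih: gates,
        bias_hh: gates,
    }
}

fn main() {
    // input_size = 10, hidden_size = 20
    println!("{:?}", lstm_shapes(10, 20));
}
```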
Implementations
Creates a new LSTMCell.
Arguments
- input_size - number of expected features in the input.
- hidden_size - number of features in the hidden state.
All the weights and biases are initialized from U(-k, k), where k = (1. / hidden_size as f32).sqrt().
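The initialization rule can be sketched without external crates. The bound k is computed exactly as stated above; the xorshift generator (XorShift) and uniform_init are hypothetical stand-ins chosen for illustration and are not the library's random source.

```rust
/// Bound k of the uniform initialization U(-k, k).
fn init_bound(hidden_size: usize) -> f32 {
    (1. / hidden_size as f32).sqrt()
}

/// Minimal xorshift64 PRNG (hypothetical stand-in for a real RNG crate).
struct XorShift(u64);

impl XorShift {
    /// Returns a value in [0, 1) built from the top 24 bits of the state.
    fn next_f32(&mut self) -> f32 {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        (self.0 >> 40) as f32 / (1u64 << 24) as f32
    }
}

/// Draws `n` values from U(-k, k), with k derived from `hidden_size`.
fn uniform_init(n: usize, hidden_size: usize, rng: &mut XorShift) -> Vec<f32> {
    let k = init_bound(hidden_size);
    (0..n).map(|_| rng.next_f32() * 2.0 * k - k).collect()
}

fn main() {
    // Seed must be non-zero for xorshift.
    let mut rng = XorShift(0x9E37_79B9_7F4A_7C15);
    let bias = uniform_init(16, 4, &mut rng);
    println!("k = {}, first value = {}", init_bound(4), bias[0]);
}
```

For hidden_size = 4 the bound is k = 0.5, so every sampled value lies in [-0.5, 0.5].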
pub fn forward<Cf, Cb, Hf, Hb, I, T, U>(
    &self,
    state: (VarDiff<Cf, Cb>, VarDiff<Hf, Hb>),
    input: I
) -> (VarDiff<impl Data<Dim = Ix2> + Forward, impl Gradient<Dim = Ix2> + Overwrite + Backward>, VarDiff<impl Data<Dim = Ix2> + Forward, impl Gradient<Dim = Ix2> + Overwrite + Backward>) where
    Cf: Data<Dim = Ix2>,
    Cb: Gradient<Dim = Ix2> + Overwrite,
    Hf: Data<Dim = Ix2>,
    Hb: Gradient<Dim = Ix2> + Overwrite,
    I: MatMatMulT<Learnable<Ix2>>,
    I::Output: Into<VarDiff<T, U>>,
    T: Data<Dim = Ix2>,
    U: Gradient<Dim = Ix2> + Overwrite,
Computes a single LSTM step.
Arguments
- state - a tuple of tensors, both of shape (batch, hidden_size), containing the initial hidden state and the initial cell state for each element in the batch.
- input - a variable containing the input features, of shape (batch, input_size).
The output is a tuple of two tensors, both of shape (batch, hidden_size), holding the next hidden state and the next cell state for each element in the batch.
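A single step can be sketched in plain Rust for one batch element using the standard LSTM gate equations: gates = W_ih x + b_ih + W_hh h + b_hh, then c' = f * c + i * g and h' = o * tanh(c'). Assumptions: weights are row-major with the four gate blocks stacked in the order input, forget, cell, output (PyTorch's convention); lstm_step and matvec are hypothetical helpers; and returning (next_hidden, next_cell) follows the order used in the text above, not necessarily the library's tuple order. A batch would apply the same step to each row.

```rust
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// Computes W x for a row-major matrix `w` of shape (rows, cols).
fn matvec(w: &[f32], rows: usize, cols: usize, x: &[f32]) -> Vec<f32> {
    (0..rows)
        .map(|r| (0..cols).map(|c| w[r * cols + c] * x[c]).sum::<f32>())
        .collect()
}

/// One LSTM step for a single batch element (hypothetical helper).
/// Gate blocks are assumed to be ordered (input, forget, cell, output).
/// Returns (next_hidden, next_cell).
fn lstm_step(
    x: &[f32], h: &[f32], c: &[f32],
    w_ih: &[f32], w_hh: &[f32], b_ih: &[f32], b_hh: &[f32],
) -> (Vec<f32>, Vec<f32>) {
    let hs = h.len();
    let gi = matvec(w_ih, 4 * hs, x.len(), x);
    let gh = matvec(w_hh, 4 * hs, hs, h);
    // Pre-activations of all four gates.
    let gates: Vec<f32> = (0..4 * hs).map(|j| gi[j] + b_ih[j] + gh[j] + b_hh[j]).collect();
    let mut h_next = vec![0.0f32; hs];
    let mut c_next = vec![0.0f32; hs];
    for j in 0..hs {
        let i = sigmoid(gates[j]); // input gate
        let f = sigmoid(gates[hs + j]); // forget gate
        let g = gates[2 * hs + j].tanh(); // candidate cell value
        let o = sigmoid(gates[3 * hs + j]); // output gate
        c_next[j] = f * c[j] + i * g;
        h_next[j] = o * c_next[j].tanh();
    }
    (h_next, c_next)
}

fn main() {
    // All-zero parameters with input_size = 3 and hidden_size = 2.
    let (is, hs) = (3, 2);
    let w_ih = vec![0.0f32; 4 * hs * is];
    let w_hh = vec![0.0f32; 4 * hs * hs];
    let b = vec![0.0f32; 4 * hs];
    let (h, c) = lstm_step(&[1.0, 2.0, 3.0], &[0.5, 0.5], &[1.0, -2.0], &w_ih, &w_hh, &b, &b);
    println!("h = {:?}, c = {:?}", h, c);
}
```

With all weights and biases set to zero, every sigmoid gate evaluates to 0.5 and the candidate value to 0, so the next cell state is half the previous one and the next hidden state is 0.5 * tanh of that.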