use super::super::super::{SeqPacked, SeqSerial};
use super::super::seq_serial_map;
use super::{RnnBase, RnnBaseConfig, RnnImpl, RnnLayerWeights};
use crate::torch::packed::PackedTensor;
use serde::{Deserialize, Serialize};
use tch::{Kind, Tensor};
pub type LstmConfig = RnnBaseConfig<LstmImpl>;
pub type Lstm = RnnBase<LstmImpl>;
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct LstmImpl;
impl RnnImpl for LstmImpl {
type CellState = (Tensor, Tensor);
const CUDNN_MODE: u32 = 2;
const GATES_MULTIPLE: usize = 4;
fn initial_cell_state(rnn: &RnnBase<Self>, batch_size: usize) -> Self::CellState {
let hidden_state = Tensor::zeros(
&[batch_size as i64, rnn.hidden_size as i64],
(Kind::Float, rnn.device),
);
let cell_state = hidden_state.shallow_clone();
(hidden_state, cell_state)
}
fn cell_batch_step(
_: &RnnBase<Self>,
state: &mut Self::CellState,
w: &RnnLayerWeights,
batch_input: &Tensor,
) -> Tensor {
let (ref hidden_state, ref cell_state) = state;
let (new_hidden_state, new_cell_state) = batch_input.lstm_cell(
&[hidden_state, cell_state],
w.w_ih(),
w.w_hh(),
w.b_ih(),
w.b_hh(),
);
*state = (new_hidden_state, new_cell_state);
state.0.shallow_clone()
}
}
impl SeqSerial for Lstm {
fn seq_serial(&self, inputs: &Tensor, seq_lengths: &[usize]) -> Tensor {
let shape = inputs.size();
assert_eq!(
shape.len(),
3,
"Input must have 3 dimensions: [BATCH_SIZE, SEQ_LEN, NUM_FEATURES]"
);
let num_layers = self.weights.num_layers() as i64;
let batch_size = shape[0] as i64;
let zeros = Tensor::zeros(
&[num_layers, batch_size, self.hidden_size as i64],
(inputs.kind(), inputs.device()),
);
let initial_state = [zeros.shallow_clone(), zeros];
seq_serial_map(inputs, seq_lengths, |seq_input| {
let (seq_output, _, _) = seq_input.lstm(
&initial_state,
self.weights.flat_weights(),
self.weights.has_biases,
num_layers,
self.dropout,
true, false, true, );
seq_output
})
}
}
impl SeqPacked for Lstm {
fn seq_packed(&self, inputs: &PackedTensor) -> PackedTensor {
let initial_batch_size = match inputs.first_batch_size() {
Some(size) => size,
None => {
return inputs.clone();
}
};
let num_layers: i64 = self.weights.num_layers() as i64;
let zeros = Tensor::zeros(
&[num_layers, initial_batch_size, self.hidden_size as i64],
(inputs.kind(), inputs.device()),
);
let initial_state = [zeros.shallow_clone(), zeros];
let batch_sizes = inputs.batch_sizes_tensor();
let lstm = |input_tensor| {
let (outputs, _, _) = Tensor::lstm_data(
input_tensor,
&batch_sizes,
&initial_state,
self.weights.flat_weights(),
self.weights.has_biases,
num_layers,
self.dropout,
true, false, );
outputs
};
inputs.batch_map_ref(lstm)
}
}
#[cfg(test)]
#[allow(
clippy::needless_pass_by_value,
clippy::used_underscore_binding,
clippy::no_effect_underscore_binding
)]
mod tests {
use super::super::super::super::testing::{
self, RunIterStep, RunModule, RunSeqPacked, RunSeqSerial,
};
use super::super::super::super::Module;
use super::*;
use rstest::{fixture, rstest};
use tch::Device;
#[fixture]
fn lstm() -> (Lstm, usize, usize) {
let in_dim: usize = 3;
let out_dim: usize = 2;
let lstm = Lstm::new(in_dim, out_dim, Device::Cpu, &LstmConfig::default());
(lstm, in_dim, out_dim)
}
#[rstest]
fn seq_serial(lstm: (Lstm, usize, usize)) {
let (lstm, in_dim, out_dim) = lstm;
testing::check_seq_serial(&lstm, in_dim, out_dim);
}
#[rstest]
fn seq_packed(lstm: (Lstm, usize, usize)) {
let (lstm, in_dim, out_dim) = lstm;
testing::check_seq_packed(&lstm, in_dim, out_dim);
}
#[rstest]
fn seq_step(lstm: (Lstm, usize, usize)) {
let (lstm, in_dim, out_dim) = lstm;
testing::check_step(&lstm, in_dim, out_dim);
}
#[rstest]
fn seq_packed_matches_iter_steps(lstm: (Lstm, usize, usize)) {
let (lstm, in_dim, out_dim) = lstm;
testing::check_seq_packed_matches_iter_steps(&lstm, in_dim, out_dim);
}
#[test]
fn seq_packed_matches_iter_steps_2layers() {
let in_dim: usize = 3;
let out_dim: usize = 2;
let config = LstmConfig {
num_layers: 2,
..LstmConfig::default()
};
let lstm = Lstm::new(in_dim, out_dim, Device::Cpu, &config);
testing::check_seq_packed_matches_iter_steps(&lstm, in_dim, out_dim);
}
#[test]
fn seq_packed_matches_iter_steps_nobias() {
let in_dim: usize = 3;
let out_dim: usize = 2;
let config = LstmConfig {
bias_init: None,
..LstmConfig::default()
};
let lstm = Lstm::new(in_dim, out_dim, Device::Cpu, &config);
testing::check_seq_packed_matches_iter_steps(&lstm, in_dim, out_dim);
}
#[rstest]
#[case::seq_serial(RunSeqSerial)]
#[case::seq_packed(RunSeqPacked)]
#[case::iter_step(RunIterStep)]
fn gradient_descent<R: RunModule<Lstm>>(#[case] _runner: R) {
testing::check_config_gradient_descent::<R, _>(&LstmConfig::default());
}
#[rstest]
#[case::seq_serial(RunSeqSerial)]
#[case::seq_packed(RunSeqPacked)]
#[case::iter_step(RunIterStep)]
fn clone_to_new_device<R: RunModule<Lstm>>(#[case] _runner: R) {
testing::check_config_clone_to_new_device::<R, _>(&LstmConfig::default());
}
#[test]
fn clone_to_same_device() {
testing::check_config_clone_to_same_device::<RunSeqPacked, _>(&LstmConfig::default());
}
#[rstest]
#[case::seq_serial(RunSeqSerial)]
#[case::seq_packed(RunSeqPacked)]
#[case::iter_step(RunIterStep)]
fn ser_de_matches<R: RunModule<Lstm>>(#[case] _runner: R, lstm: (Lstm, usize, usize)) {
let (module, in_dim, _) = lstm;
testing::check_ser_de_matches::<R, _>(&module, in_dim);
}
#[rstest]
fn variables_count(lstm: (Lstm, usize, usize)) {
let (lstm, _, _) = lstm;
assert_eq!(lstm.variables().count(), 4);
}
#[rstest]
fn trainable_variables_count(lstm: (Lstm, usize, usize)) {
let (lstm, _, _) = lstm;
assert_eq!(lstm.trainable_variables().count(), 4);
}
}