use super::super::super::{SeqPacked, SeqSerial};
use super::super::seq_serial_map;
use super::{RnnBase, RnnBaseConfig, RnnImpl, RnnLayerWeights};
use crate::torch::packed::PackedTensor;
use serde::{Deserialize, Serialize};
use tch::{Kind, Tensor};
/// Configuration for a [`Gru`]: [`RnnBaseConfig`] specialized to [`GruImpl`].
pub type GruConfig = RnnBaseConfig<GruImpl>;
/// Gated recurrent unit (GRU) network: [`RnnBase`] specialized to [`GruImpl`].
pub type Gru = RnnBase<GruImpl>;
/// Field-less marker type that selects the GRU cell for [`RnnBase`].
///
/// All GRU-specific behaviour lives in this type's [`RnnImpl`] implementation.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct GruImpl;
/// GRU-specific pieces plugged into the generic [`RnnBase`] machinery.
impl RnnImpl for GruImpl {
    /// Recurrent state carried between steps: a single hidden-state tensor.
    type CellState = Tensor;

    // NOTE(review): presumably the cuDNN RNN mode identifier for GRU
    // (`CUDNN_GRU`) — confirm against the cuDNN `cudnnRNNMode_t` enum.
    const CUDNN_MODE: u32 = 3;

    // NOTE(review): presumably the number of gate blocks stacked in each weight
    // matrix (3 for a GRU) — confirm against how `RnnBase` sizes its weights.
    const GATES_MULTIPLE: usize = 3;

    /// Creates an all-zero `Float` state of shape `[batch_size, hidden_size]`
    /// on the device of `rnn`.
    fn initial_cell_state(rnn: &RnnBase<Self>, batch_size: usize) -> Self::CellState {
        let shape = [batch_size as i64, rnn.hidden_size as i64];
        Tensor::zeros(&shape, (Kind::Float, rnn.device))
    }

    /// Advances one layer by a single step for a batch of inputs.
    ///
    /// Overwrites `state` with the new hidden state produced by `gru_cell` and
    /// returns a shallow clone of it as the step output.
    fn cell_batch_step(
        _: &RnnBase<Self>,
        state: &mut Self::CellState,
        w: &RnnLayerWeights,
        batch_input: &Tensor,
    ) -> Tensor {
        let next_hidden = batch_input.gru_cell(state, w.w_ih(), w.w_hh(), w.b_ih(), w.b_hh());
        *state = next_hidden;
        // For a GRU the step output is the hidden state itself.
        state.shallow_clone()
    }
}
/// Sequence processing over a dense, batch-first input tensor.
impl SeqSerial for Gru {
    /// Applies the GRU to `inputs` of shape `[BATCH_SIZE, SEQ_LEN, NUM_FEATURES]`.
    ///
    /// `inputs` and `seq_lengths` are handed to `seq_serial_map`, which invokes
    /// the GRU on each piece it produces; every invocation starts from the same
    /// all-zero hidden state.
    ///
    /// # Panics
    /// Panics if `inputs` does not have exactly 3 dimensions.
    fn seq_serial(&self, inputs: &Tensor, seq_lengths: &[usize]) -> Tensor {
        let dims = inputs.size();
        assert_eq!(
            dims.len(),
            3,
            "Input must have 3 dimensions: [BATCH_SIZE, SEQ_LEN, NUM_FEATURES]"
        );
        let batch_size = dims[0];
        let num_layers = self.weights.num_layers() as i64;
        let hidden_size = self.hidden_size as i64;

        // Zero hidden state matching the kind and device of the inputs.
        let zero_state = Tensor::zeros(
            &[num_layers, batch_size, hidden_size],
            (inputs.kind(), inputs.device()),
        );

        seq_serial_map(inputs, seq_lengths, |seq_input| {
            // Only the per-step outputs are kept; the final hidden state is dropped.
            seq_input
                .gru(
                    &zero_state,
                    self.weights.flat_weights(),
                    self.weights.has_biases,
                    num_layers,
                    self.dropout,
                    true,  // train
                    false, // bidirectional
                    true,  // batch_first
                )
                .0
        })
    }
}
/// Sequence processing over packed (batch-size annotated) inputs.
impl SeqPacked for Gru {
    /// Applies the GRU to a packed tensor, starting from an all-zero hidden state.
    ///
    /// When `inputs` reports no first batch size, a clone of `inputs` is
    /// returned and the GRU is not run.
    fn seq_packed(&self, inputs: &PackedTensor) -> PackedTensor {
        let first_batch_size = match inputs.first_batch_size() {
            None => return inputs.clone(),
            Some(size) => size,
        };
        let num_layers = self.weights.num_layers() as i64;
        let hidden_size = self.hidden_size as i64;

        // Zero hidden state sized by the first batch, matching the kind and
        // device of the inputs.
        let zero_state = Tensor::zeros(
            &[num_layers, first_batch_size, hidden_size],
            (inputs.kind(), inputs.device()),
        );
        let batch_sizes = inputs.batch_sizes_tensor();

        inputs.batch_map_ref(|packed_data| {
            // Only the outputs are kept; the final hidden state is dropped.
            Tensor::gru_data(
                packed_data,
                &batch_sizes,
                &zero_state,
                self.weights.flat_weights(),
                self.weights.has_biases,
                num_layers,
                self.dropout,
                true,  // train
                false, // bidirectional
            )
            .0
        })
    }
}
#[cfg(test)]
// NOTE(review): these lints are presumably silenced because of the by-value
// fixture arguments and the `_runner` case parameters used with `rstest` — confirm.
#[allow(
    clippy::needless_pass_by_value,
    clippy::used_underscore_binding,
    clippy::no_effect_underscore_binding
)]
mod tests {
    use super::super::super::super::testing::{
        self, RunIterStep, RunModule, RunSeqPacked, RunSeqSerial,
    };
    use super::super::super::super::Module;
    use super::*;
    use rstest::{fixture, rstest};
    use tch::Device;

    // Input feature count shared by every model built in these tests.
    const IN_DIM: usize = 3;
    // Output (hidden) size shared by every model built in these tests.
    const OUT_DIM: usize = 2;

    // Builds a CPU GRU with the shared test dimensions and the given configuration.
    fn cpu_gru(config: &GruConfig) -> Gru {
        Gru::new(IN_DIM, OUT_DIM, Device::Cpu, config)
    }

    // Default-configured GRU together with its input and output dimensions.
    #[fixture]
    fn gru() -> (Gru, usize, usize) {
        (cpu_gru(&GruConfig::default()), IN_DIM, OUT_DIM)
    }

    #[rstest]
    fn seq_serial(gru: (Gru, usize, usize)) {
        let (model, in_dim, out_dim) = gru;
        testing::check_seq_serial(&model, in_dim, out_dim);
    }

    #[rstest]
    fn seq_packed(gru: (Gru, usize, usize)) {
        let (model, in_dim, out_dim) = gru;
        testing::check_seq_packed(&model, in_dim, out_dim);
    }

    #[rstest]
    fn seq_step(gru: (Gru, usize, usize)) {
        let (model, in_dim, out_dim) = gru;
        testing::check_step(&model, in_dim, out_dim);
    }

    #[rstest]
    fn seq_packed_matches_iter_steps(gru: (Gru, usize, usize)) {
        let (model, in_dim, out_dim) = gru;
        testing::check_seq_packed_matches_iter_steps(&model, in_dim, out_dim);
    }

    // Same consistency check with two stacked layers.
    #[test]
    fn seq_packed_matches_iter_steps_2layers() {
        let model = cpu_gru(&GruConfig {
            num_layers: 2,
            ..GruConfig::default()
        });
        testing::check_seq_packed_matches_iter_steps(&model, IN_DIM, OUT_DIM);
    }

    // Same consistency check with bias initialization disabled.
    #[test]
    fn seq_packed_matches_iter_steps_nobias() {
        let model = cpu_gru(&GruConfig {
            bias_init: None,
            ..GruConfig::default()
        });
        testing::check_seq_packed_matches_iter_steps(&model, IN_DIM, OUT_DIM);
    }

    #[rstest]
    #[case::seq_serial(RunSeqSerial)]
    #[case::seq_packed(RunSeqPacked)]
    #[case::iter_step(RunIterStep)]
    fn gradient_descent<R: RunModule<Gru>>(#[case] _runner: R) {
        // The case value only selects the runner type `R`.
        let config = GruConfig::default();
        testing::check_config_gradient_descent::<R, _>(&config);
    }

    #[rstest]
    #[case::seq_serial(RunSeqSerial)]
    #[case::seq_packed(RunSeqPacked)]
    #[case::iter_step(RunIterStep)]
    fn clone_to_new_device<R: RunModule<Gru>>(#[case] _runner: R) {
        // The case value only selects the runner type `R`.
        let config = GruConfig::default();
        testing::check_config_clone_to_new_device::<R, _>(&config);
    }

    #[test]
    fn clone_to_same_device() {
        let config = GruConfig::default();
        testing::check_config_clone_to_same_device::<RunSeqPacked, _>(&config);
    }

    #[rstest]
    #[case::seq_serial(RunSeqSerial)]
    #[case::seq_packed(RunSeqPacked)]
    #[case::iter_step(RunIterStep)]
    fn ser_de_matches<R: RunModule<Gru>>(#[case] _runner: R, gru: (Gru, usize, usize)) {
        let (model, in_dim, _) = gru;
        testing::check_ser_de_matches::<R, _>(&model, in_dim);
    }

    // NOTE(review): 4 presumably corresponds to w_ih, w_hh, b_ih and b_hh of the
    // single default layer — confirm against `RnnBase`.
    #[rstest]
    fn variables_count(gru: (Gru, usize, usize)) {
        let model = gru.0;
        assert_eq!(model.variables().count(), 4);
    }

    #[rstest]
    fn trainable_variables_count(gru: (Gru, usize, usize)) {
        let model = gru.0;
        assert_eq!(model.trainable_variables().count(), 4);
    }
}