use std::iter::Rev;
use std::ops::Range;
use rten_gemm::{GemmExecutor, GemmInputA, GemmInputB, GemmOptions};
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, Tensor, TensorView};
use crate::buffer_pool::{AutoReturn, BufferPool};
use crate::operator::{
IntoOpResult, OpError, OpRunContext, Operator, OutputList, OutputType, OutputTypeList,
OutputTypesContext, static_dims,
};
use crate::ops::binary_elementwise::{add_in_place, mul_in_place};
use crate::ops::unary_elementwise::{sigmoid, tanh};
use crate::value::{DataType, ValueType};
/// Direction or directions in which an RNN operator processes the sequence.
#[derive(Copy, Clone, Debug)]
pub enum Direction {
    Forward,
    Reverse,
    Bidirectional,
}

impl Direction {
    /// Number of directions the operator runs over: 2 for
    /// [`Bidirectional`](Direction::Bidirectional), otherwise 1.
    pub fn num_directions(self) -> usize {
        if matches!(self, Direction::Bidirectional) {
            2
        } else {
            1
        }
    }
}
/// Iterator over the time-step indices of a sequence, in either direction.
enum Sequence {
    Forward(Range<usize>),
    Backward(Rev<Range<usize>>),
}

impl Iterator for Sequence {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        match self {
            Sequence::Forward(range) => range.next(),
            Sequence::Backward(rev_range) => rev_range.next(),
        }
    }

    /// Delegate to the inner range iterator, which knows its exact length.
    /// This lets consumers such as `collect` pre-allocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Sequence::Forward(range) => range.size_hint(),
            Sequence::Backward(rev_range) => rev_range.size_hint(),
        }
    }
}
/// Return the order in which time steps are visited for direction index
/// `dir` of an operator configured with directions `op_dirs`.
///
/// Direction 0 of a `Reverse` operator and direction 1 of a `Bidirectional`
/// operator walk the sequence backwards; all other cases walk forwards.
fn sequence_for_dir(op_dirs: Direction, dir: usize, seq_len: usize) -> Sequence {
    match (dir, op_dirs) {
        (0, Direction::Reverse) | (1, Direction::Bidirectional) => {
            Sequence::Backward((0..seq_len).rev())
        }
        _ => Sequence::Forward(0..seq_len),
    }
}
/// Zip three iterators into an iterator of 3-tuples, stopping at the
/// shortest input.
fn zip3<T1, T2, T3>(
    a: impl Iterator<Item = T1>,
    b: impl Iterator<Item = T2>,
    c: impl Iterator<Item = T3>,
) -> impl Iterator<Item = (T1, T2, T3)> {
    a.zip(b).zip(c).map(|((a, b), c)| (a, b, c))
}
/// Zip four iterators into an iterator of 4-tuples, stopping at the
/// shortest input.
fn zip4<T1, T2, T3, T4>(
    a: impl Iterator<Item = T1>,
    b: impl Iterator<Item = T2>,
    c: impl Iterator<Item = T3>,
    d: impl Iterator<Item = T4>,
) -> impl Iterator<Item = (T1, T2, T3, T4)> {
    a.zip(b).zip(c).zip(d).map(|(((a, b), c), d)| (a, b, c, d))
}
/// Minimum sequence length before weight matrices are pre-packed for GEMM,
/// so the packing cost is amortized over enough per-step matmuls.
const PREPACK_MIN_SEQ_LEN: usize = 5;
/// Gated Recurrent Unit operator.
#[derive(Debug)]
#[allow(clippy::upper_case_acronyms)]
pub struct GRU {
    pub direction: Direction,
    // Unused at run time: `gru` derives the hidden size from the weight
    // shapes instead.
    #[allow(unused)]
    pub hidden_size: usize,
    pub linear_before_reset: bool,
}
/// Gated Recurrent Unit recurrent layer.
///
/// `input` has shape `[seq, batch, input]`, `weights` has shape
/// `[dir, 3 * hidden, input]` and `recurrent_weights` has shape
/// `[dir, 3 * hidden, hidden]`. `bias`, if given, has shape
/// `[dir, 6 * hidden]` (input biases followed by recurrent biases).
/// `initial_hidden`, if given, has shape `[dir, batch, hidden]` and defaults
/// to zeros.
///
/// Returns `[hidden_seq, last_hidden]` where `hidden_seq` has shape
/// `[seq, dir, batch, hidden]` and `last_hidden` has shape
/// `[dir, batch, hidden]`.
///
/// # Errors
///
/// Returns `OpError::UnsupportedValue` if `linear_before_reset` is false,
/// and `OpError::InvalidValue` if the weight shapes are inconsistent.
pub fn gru(
    pool: &BufferPool,
    direction: Direction,
    input: TensorView,
    weights: TensorView,
    recurrent_weights: TensorView,
    bias: Option<TensorView>,
    initial_hidden: Option<TensorView>,
    linear_before_reset: bool,
) -> Result<Vec<Tensor>, OpError> {
    if !linear_before_reset {
        return Err(OpError::UnsupportedValue(
            "`linear_before_reset=0` is not supported",
        ));
    }

    let input = static_dims!(input, 3, "seq, batch, input")?;
    let weights = static_dims!(weights, 3, "dir, hidden x 3, input")?;
    let recurrent_weights = static_dims!(recurrent_weights, 3)?;
    let bias = bias
        .map(|bias| static_dims!(bias, 2, "dir, hidden x 6"))
        .transpose()?;

    let [seq_len, batch, _input_size] = input.shape();
    let [_directions, hidden_x3, _input_size] = weights.shape();

    // Validate the gate dimension before deriving `hidden_size` from it.
    // This matches the equivalent `% 4` check in `lstm`.
    if hidden_x3 % 3 != 0 {
        return Err(OpError::InvalidValue(
            "weights dim 1 must be 3 * hidden_size",
        ));
    }

    let initial_hidden = initial_hidden
        .map(|initial_hidden| static_dims!(initial_hidden, 3))
        .transpose()?;

    let num_directions = direction.num_directions();
    let hidden_size = hidden_x3 / 3;

    let mut hidden = initial_hidden
        .map(|t| t.to_tensor_in(pool))
        .unwrap_or_else(|| NdTensor::zeros_in(pool, [num_directions, batch, hidden_size]));
    let mut hidden_seq = NdTensor::zeros_in(pool, [seq_len, num_directions, batch, hidden_size]);

    // Order of gates within the weight / bias gate dimension: update (z),
    // reset (r), hidden (h).
    const UPDATE_GATE: usize = 0;
    const RESET_GATE: usize = 1;
    const HIDDEN_GATE: usize = 2;
    let n_gates = 3;

    // Scratch buffers holding the per-step input contribution
    // (`input @ W^T`) and recurrent contribution (`hidden @ R^T`) for all
    // gates at once.
    let mut gates = NdTensor::zeros_in(pool, [batch, n_gates * hidden_size]).auto_return(pool);
    let gate_range = |gate| (gate * hidden_size)..((gate + 1) * hidden_size);
    let mut hidden_scratch =
        NdTensor::zeros_in(pool, [batch, n_gates * hidden_size]).auto_return(pool);

    let gemm = GemmExecutor::new();

    for dir in 0..num_directions {
        // Pre-pack weights only when the sequence is long enough to
        // amortize the packing cost.
        let prepack = seq_len >= PREPACK_MIN_SEQ_LEN;

        let input_weights = weights.slice(dir).transposed();
        let packed_input_weights =
            prepack.then(|| gemm.prepack_b_in(pool, input_weights).auto_return(pool));
        let input_weights = packed_input_weights
            .as_ref()
            .map(|packed| GemmInputB::Packed(packed))
            .unwrap_or(GemmInputB::Unpacked(input_weights));

        let hidden_weights = recurrent_weights.slice(dir).transposed();
        let packed_hidden_weights =
            prepack.then(|| gemm.prepack_b_in(pool, hidden_weights).auto_return(pool));
        let hidden_weights = packed_hidden_weights
            .as_ref()
            .map(|packed| GemmInputB::Packed(packed))
            .unwrap_or(GemmInputB::Unpacked(hidden_weights));

        // First half of the bias applies to the input contribution, second
        // half to the recurrent contribution.
        let input_bias = bias
            .as_ref()
            .map(|b| b.slice((dir, ..(n_gates * hidden_size))));
        let hidden_bias = bias
            .as_ref()
            .map(|b| b.slice((dir, (n_gates * hidden_size)..)));

        for seq in sequence_for_dir(direction, dir, seq_len) {
            let in_item = input.slice([seq]);
            let hidden_item = hidden.slice([dir]);

            // gates = input @ W^T (+ input bias)
            gemm.gemm(
                gates.data_mut().expect("expected contiguous input"),
                GemmInputA::Unpacked(in_item),
                input_weights,
                GemmOptions::default(),
            )
            .unwrap();
            if let Some(input_bias) = input_bias {
                add_in_place(gates.as_dyn_mut(), input_bias.as_dyn());
            }

            // hidden_scratch = hidden @ R^T (+ recurrent bias)
            gemm.gemm(
                hidden_scratch.data_mut().unwrap(),
                GemmInputA::Unpacked(hidden_item),
                hidden_weights,
                GemmOptions::default(),
            )
            .unwrap();
            if let Some(hidden_bias) = hidden_bias {
                add_in_place(hidden_scratch.as_dyn_mut(), hidden_bias.as_dyn());
            }

            // Sum input and recurrent contributions for the update and
            // reset gates (contiguous in the gate dimension), then apply
            // the sigmoid activation.
            let mut update_reset_gates = gates.slice_mut((
                ..,
                gate_range(UPDATE_GATE).start..gate_range(RESET_GATE).end,
            ));
            let hidden_scratch_reset_update_gates = hidden_scratch.slice((
                ..,
                gate_range(UPDATE_GATE).start..gate_range(RESET_GATE).end,
            ));
            add_in_place(
                update_reset_gates.as_dyn_mut(),
                hidden_scratch_reset_update_gates.as_dyn(),
            );
            let update_reset_gates = sigmoid(pool, update_reset_gates.as_dyn()).auto_return(pool);
            let update_reset_gates = update_reset_gates.nd_view::<2>();
            let update_gate = update_reset_gates.slice((.., gate_range(UPDATE_GATE)));
            let reset_gate = update_reset_gates.slice((.., gate_range(RESET_GATE)));

            // Hidden gate with `linear_before_reset=1` semantics: the reset
            // gate scales the recurrent contribution before it is added to
            // the input contribution.
            let mut hidden_gate_recurrent = hidden_scratch.slice_mut((.., gate_range(HIDDEN_GATE)));
            mul_in_place(hidden_gate_recurrent.as_dyn_mut(), reset_gate.as_dyn());
            let mut hidden_gate = gates.slice_mut((.., gate_range(HIDDEN_GATE)));
            add_in_place(hidden_gate.as_dyn_mut(), hidden_gate_recurrent.as_dyn());
            let hidden_gate = tanh(pool, hidden_gate.as_dyn()).auto_return(pool);

            // New hidden state blends the hidden-gate output with the
            // previous hidden state, weighted by the update gate.
            let mut hidden_item = hidden.slice_mut([dir]);
            for (hidden, update, hidden_gate) in zip3(
                hidden_item.iter_mut(),
                update_gate.iter(),
                hidden_gate.iter(),
            ) {
                *hidden = (1. - update) * hidden_gate + update * (*hidden);
            }

            hidden_seq.slice_mut([seq, dir]).copy_from(&hidden_item);
        }
    }

    Ok([hidden_seq.into_dyn(), hidden.into_dyn()].into())
}
impl Operator for GRU {
    fn name(&self) -> &str {
        "GRU"
    }

    fn max_inputs(&self) -> Option<usize> {
        Some(6)
    }

    /// Fetch the operator inputs (input, weights, recurrent weights,
    /// optional bias, optional sequence lengths, optional initial hidden
    /// state) and run [`gru`].
    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let inputs = ctx.inputs();
        let x = inputs.require_as(0)?;
        let w = inputs.require_as(1)?;
        let rec_w = inputs.require_as(2)?;
        let b = inputs.get_as(3)?;
        // The sequence-lengths input is accepted but not currently used.
        let _seq_lens = inputs.get_as::<TensorView<i32>>(4)?;
        let initial_h = inputs.get_as(5)?;

        gru(
            ctx.pool(),
            self.direction,
            x,
            w,
            rec_w,
            b,
            initial_h,
            self.linear_before_reset,
        )
        .into_op_result()
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        // Both outputs (hidden sequence, final hidden state) are f32 tensors.
        let float_tensor = || OutputType::Fixed(ValueType::Tensor(DataType::Float));
        Some(OutputTypeList::from_slice(&[float_tensor(), float_tensor()]))
    }
}
/// Long Short-Term Memory operator.
#[derive(Debug)]
#[allow(clippy::upper_case_acronyms)]
pub struct LSTM {
    pub direction: Direction,
    // Unused at run time: `lstm` derives the hidden size from the weight
    // shapes instead.
    #[allow(unused)]
    pub hidden_size: usize,
}
/// Long Short-Term Memory recurrent layer.
///
/// `input` has shape `[seq, batch, input]`, `weights` has shape
/// `[dir, 4 * hidden, input]` and `recurrent_weights` has shape
/// `[dir, 4 * hidden, hidden]`. `bias`, if given, has shape
/// `[dir, 8 * hidden]` (input biases followed by recurrent biases).
/// `initial_hidden` and `initial_cell`, if given, have shape
/// `[dir, batch, hidden]` and default to zeros.
///
/// Returns `[hidden_seq, last_hidden, last_cell]` where `hidden_seq` has
/// shape `[seq, dir, batch, hidden]` and the two final states have shape
/// `[dir, batch, hidden]`.
pub fn lstm(
    pool: &BufferPool,
    direction: Direction,
    input: TensorView,
    weights: TensorView,
    recurrent_weights: TensorView,
    bias: Option<TensorView>,
    initial_hidden: Option<TensorView>,
    initial_cell: Option<TensorView>,
) -> Result<Vec<Tensor>, OpError> {
    let input = static_dims!(input, 3, "seq, batch, input")?;
    let [seq_len, batch, _input_size] = input.shape();
    let weights = static_dims!(weights, 3, "dir, hidden x 4, input")?;
    let [_directions, hidden_x4, _input_size] = weights.shape();
    let recurrent_weights = static_dims!(recurrent_weights, 3, "dir, hidden x 4, hidden")?;

    let num_directions = direction.num_directions();
    let hidden_size = hidden_x4 / 4;
    if weights.size(1) % 4 != 0 {
        return Err(OpError::InvalidValue(
            "weights dim 1 must be 4 * hidden_size",
        ));
    }
    let bias = bias.map(|bias| static_dims!(bias, 2)).transpose()?;
    if let Some(bias) = bias.as_ref()
        && bias.size(1) % 8 != 0
    {
        return Err(OpError::InvalidValue("bias dim 1 must be 8 * hidden_size"));
    }
    let initial_hidden = initial_hidden
        .map(|initial_hidden| static_dims!(initial_hidden, 3))
        .transpose()?;
    let initial_cell = initial_cell
        .map(|initial_cell| static_dims!(initial_cell, 3))
        .transpose()?;

    // Make the input contiguous so that per-step slices can be fed to GEMM.
    let input = input.to_contiguous_in(pool).auto_return(pool);
    let bias = bias.map(|t| t.to_contiguous());

    // Order of gates within the weight / bias gate dimension: input (i),
    // output (o), forget (f), cell (g).
    const INPUT_GATE: usize = 0;
    const OUTPUT_GATE: usize = 1;
    const FORGET_GATE: usize = 2;
    const CELL_GATE: usize = 3;
    let n_gates = 4;

    // Scratch buffer holding the combined pre-activation values for all
    // gates for the current time step.
    let mut gates = NdTensor::zeros_in(pool, [batch, n_gates * hidden_size]);
    let mut cell = initial_cell
        .map(|t| t.to_tensor_in(pool))
        .unwrap_or_else(|| NdTensor::zeros_in(pool, [num_directions, batch, hidden_size]));
    let mut hidden = initial_hidden
        .map(|t| t.to_tensor_in(pool))
        .unwrap_or_else(|| NdTensor::zeros_in(pool, [num_directions, batch, hidden_size]));
    let mut hidden_seq =
        NdTensor::<f32, 4>::zeros_in(pool, [seq_len, num_directions, batch, hidden_size]);

    let gemm = GemmExecutor::new();
    let gate_range = |gate| (gate * hidden_size)..((gate + 1) * hidden_size);

    for dir in 0..num_directions {
        // Pre-pack weights only when the sequence is long enough to
        // amortize the packing cost.
        let prepack = seq_len >= PREPACK_MIN_SEQ_LEN;

        let input_weights = weights.slice(dir).transposed();
        let packed_input_weights =
            prepack.then(|| gemm.prepack_b_in(pool, input_weights).auto_return(pool));
        let input_weights = packed_input_weights
            .as_ref()
            .map(|packed| GemmInputB::Packed(packed))
            .unwrap_or(GemmInputB::Unpacked(input_weights));

        let hidden_weights = recurrent_weights.slice(dir).transposed();
        let packed_hidden_weights =
            prepack.then(|| gemm.prepack_b_in(pool, hidden_weights).auto_return(pool));
        let hidden_weights = packed_hidden_weights
            .as_ref()
            .map(|packed| GemmInputB::Packed(packed))
            .unwrap_or(GemmInputB::Unpacked(hidden_weights));

        // First half of the bias applies to the input contribution, second
        // half to the recurrent contribution.
        let input_bias = bias
            .as_ref()
            .map(|b| b.slice((dir, ..(n_gates * hidden_size))));
        let hidden_bias = bias
            .as_ref()
            .map(|b| b.slice((dir, (n_gates * hidden_size)..)));

        for seq in sequence_for_dir(direction, dir, seq_len) {
            let in_item = input.slice([seq]);
            let hidden_item = hidden.slice([dir]);

            // gates = input @ W^T (+ input bias)
            gemm.gemm(
                gates.data_mut().expect("expected contiguous input"),
                GemmInputA::Unpacked(in_item),
                input_weights,
                GemmOptions::default(),
            )
            .unwrap();
            if let Some(input_bias) = input_bias {
                add_in_place(gates.as_dyn_mut(), input_bias.as_dyn());
            }

            // gates += hidden @ R^T (+ recurrent bias). `beta: 1.`
            // accumulates into the existing contents of `gates`.
            gemm.gemm(
                gates.data_mut().expect("expected contiguous input"),
                GemmInputA::Unpacked(hidden_item),
                hidden_weights,
                GemmOptions {
                    beta: 1.,
                    ..Default::default()
                },
            )
            .unwrap();
            if let Some(hidden_bias) = hidden_bias {
                add_in_place(gates.as_dyn_mut(), hidden_bias.as_dyn());
            }

            // Sigmoid over the input, output and forget gates, which are
            // contiguous in the gate dimension.
            let iof_gates = gates.slice((
                ..,
                gate_range(INPUT_GATE).start..gate_range(FORGET_GATE).end,
            ));
            let iof_gates = sigmoid(pool, iof_gates.as_dyn()).auto_return(pool);
            let iof_gates = iof_gates.nd_view::<2>();
            let input_gate = iof_gates.slice((.., gate_range(INPUT_GATE)));
            let out_gate = iof_gates.slice((.., gate_range(OUTPUT_GATE)));
            let forget_gate = iof_gates.slice((.., gate_range(FORGET_GATE)));

            // Tanh over the cell (candidate) gate.
            let cell_gate = gates.slice((.., gate_range(CELL_GATE)));
            let cell_gate = tanh(pool, cell_gate.as_dyn()).auto_return(pool);

            // c_t = f ⊙ c_{t-1} + i ⊙ g
            let mut cell_item = cell.slice_mut([dir]);
            for (cell, forget_gate, input_gate, cell_gate) in zip4(
                cell_item.iter_mut(),
                forget_gate.iter(),
                input_gate.iter(),
                cell_gate.iter(),
            ) {
                *cell = forget_gate * *cell + input_gate * cell_gate;
            }

            // h_t = o ⊙ tanh(c_t)
            let mut hidden_item = hidden.slice_mut([dir]);
            for (hidden, out_gate, cell) in
                zip3(hidden_item.iter_mut(), out_gate.iter(), cell_item.iter())
            {
                *hidden = out_gate * cell.tanh()
            }

            hidden_seq.slice_mut([seq, dir]).copy_from(&hidden_item);
        }
    }

    Ok([hidden_seq.into_dyn(), hidden.into_dyn(), cell.into_dyn()].into())
}
impl Operator for LSTM {
    fn name(&self) -> &str {
        "LSTM"
    }

    fn max_inputs(&self) -> Option<usize> {
        Some(7)
    }

    /// Fetch the operator inputs (input, weights, recurrent weights,
    /// optional bias, optional sequence lengths, optional initial hidden
    /// state, optional initial cell state) and run [`lstm`].
    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let inputs = ctx.inputs();
        let x = inputs.require_as(0)?;
        let w = inputs.require_as(1)?;
        let rec_w = inputs.require_as(2)?;
        let b = inputs.get_as(3)?;
        // The sequence-lengths input is accepted but not currently used.
        let _seq_lens = inputs.get_as::<TensorView<i32>>(4)?;
        let initial_h = inputs.get_as(5)?;
        let initial_c = inputs.get_as(6)?;

        lstm(
            ctx.pool(),
            self.direction,
            x,
            w,
            rec_w,
            b,
            initial_h,
            initial_c,
        )
        .into_op_result()
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        // All three outputs (hidden sequence, final hidden state, final
        // cell state) are f32 tensors.
        let float_tensor = || OutputType::Fixed(ValueType::Tensor(DataType::Float));
        Some(OutputTypeList::from_slice(&[
            float_tensor(),
            float_tensor(),
            float_tensor(),
        ]))
    }
}
#[cfg(test)]
mod tests {
use std::fs::File;
use std::io::BufReader;
use rten_tensor::prelude::*;
use rten_tensor::rng::XorShiftRng;
use rten_tensor::test_util::expect_equal;
use rten_tensor::{NdTensor, Tensor};
use rten_testing::TestCases;
use serde_json::Value;
use crate::buffer_pool::BufferPool;
use crate::ops::{Direction, concat, gru, lstm, split};
pub fn read_tensor(val: &Value) -> Result<Tensor<f32>, &'static str> {
let vec = match val {
Value::Array(vec) => vec,
_ => return Err("Expected array"),
};
let (shape, data) = match vec.as_slice() {
[Value::Array(shape), Value::Array(data)] => (shape, data),
_ => return Err("Expected [shape, data] array"),
};
let shape = shape
.iter()
.map(|v| v.as_i64().map(|v| v as usize).ok_or("Expected int array"))
.collect::<Result<Vec<usize>, _>>()?;
let data = data
.iter()
.map(|v| v.as_f64().map(|v| v as f32).ok_or("Expected float array"))
.collect::<Result<Vec<f32>, _>>()?;
Ok(Tensor::from_data(&shape, data))
}
pub fn read_json_file(path: &str) -> Value {
let file = File::open(path).unwrap();
let reader = BufReader::new(file);
serde_json::from_reader(reader).unwrap()
}
#[derive(Clone, Copy, Debug, PartialEq)]
enum Op {
Gru,
Lstm,
}
#[test]
fn test_rnn_ops_with_random_input() {
let batch = 2;
let seq_len = 5;
let dir = Direction::Bidirectional;
let hidden_size = 3;
let features = 2;
#[derive(Clone, Debug)]
struct Case {
op: Op,
with_bias: bool,
with_hidden_init: bool,
with_initial_cell: bool,
}
let cases = [
Case {
op: Op::Lstm,
with_bias: true,
with_hidden_init: true,
with_initial_cell: true,
},
Case {
op: Op::Lstm,
with_bias: false,
with_hidden_init: false,
with_initial_cell: false,
},
Case {
op: Op::Gru,
with_bias: true,
with_hidden_init: true,
with_initial_cell: false,
},
Case {
op: Op::Gru,
with_bias: false,
with_hidden_init: false,
with_initial_cell: false,
},
];
cases.test_each_clone(|case| {
let mut rng = XorShiftRng::new(1234);
let pool = BufferPool::new();
let num_gates = match case.op {
Op::Gru => 3,
Op::Lstm => 4,
};
let input =
NdTensor::<f32, 3>::rand([seq_len, batch, features], &mut rng).map(|x| x - 0.5);
let weights = NdTensor::<f32, 3>::rand(
[dir.num_directions(), num_gates * hidden_size, features],
&mut rng,
)
.map(|x| x - 0.5);
let recurrent_weights = NdTensor::<f32, 3>::rand(
[dir.num_directions(), num_gates * hidden_size, hidden_size],
&mut rng,
)
.map(|x| x - 0.5);
let bias = NdTensor::rand(
[dir.num_directions(), 2 * num_gates * hidden_size],
&mut rng,
);
let initial_hidden =
NdTensor::rand([dir.num_directions(), batch, hidden_size], &mut rng);
let initial_cell = NdTensor::rand([dir.num_directions(), batch, hidden_size], &mut rng);
let result = match case.op {
Op::Lstm => lstm(
&pool,
dir,
input.as_dyn(),
weights.as_dyn(),
recurrent_weights.as_dyn(),
case.with_bias.then_some(bias.as_dyn()),
case.with_hidden_init.then_some(initial_hidden.as_dyn()),
case.with_initial_cell.then_some(initial_cell.as_dyn()),
)
.expect("lstm op failed"),
Op::Gru => gru(
&pool,
dir,
input.as_dyn(),
weights.as_dyn(),
recurrent_weights.as_dyn(),
case.with_bias.then_some(bias.as_dyn()),
case.with_hidden_init.then_some(initial_hidden.as_dyn()),
true,
)
.expect("gru op failed"),
};
assert_eq!(
result.len(),
match case.op {
Op::Gru => 2,
Op::Lstm => 3,
}
);
let hidden_seq = &result[0];
assert_eq!(
hidden_seq.shape(),
&[seq_len, dir.num_directions(), batch, hidden_size]
);
let last_hidden = &result[1];
assert_eq!(
last_hidden.shape(),
&[dir.num_directions(), batch, hidden_size]
);
if case.op == Op::Lstm {
let last_cell = &result[2];
assert_eq!(
last_cell.shape(),
&[dir.num_directions(), batch, hidden_size]
);
}
let hidden_seq_fwd = hidden_seq.slice((
-1, 0, ));
let last_hidden_fwd = last_hidden.slice(0);
assert_eq!(hidden_seq_fwd, last_hidden_fwd);
let hidden_seq_rev = hidden_seq.slice((
0, 1, ));
let last_hidden_rev = last_hidden.slice(1);
assert_eq!(hidden_seq_rev, last_hidden_rev);
})
}
fn reorder_ifco_to_iofc(x: &Tensor, axis: isize) -> Tensor {
let pool = BufferPool::new();
let size = x.size(axis as usize) / 4;
let splits = &[size as i32; 4];
let ifco = split(&pool, x.view(), axis, splits.as_slice().into()).expect("split failed");
concat(
&pool,
&[
ifco[0].view(),
ifco[3].view(),
ifco[1].view(),
ifco[2].view(),
],
axis,
)
.expect("concat failed")
}
fn reorder_ruh_to_urh(x: &Tensor, axis: isize) -> Tensor {
let pool = BufferPool::new();
let size = x.size(axis as usize) / 3;
let splits = &[size as i32; 3];
let ruh = split(&pool, x.view(), axis, splits.as_slice().into()).expect("split failed");
concat(&pool, &[ruh[1].view(), ruh[0].view(), ruh[2].view()], axis).expect("concat failed")
}
    // Inputs and expected output for checking an RNN op against a reference
    // implementation (see `read_pytorch_ref_test`).
    struct RNNRefTest {
        input: Tensor,
        expected: Tensor,
        weights: Tensor,
        hidden_weights: Tensor,
        bias: Option<Tensor>,
        initial_hidden: Option<Tensor>,
        initial_cell: Option<Tensor>,
    }
    // Read one reference test case exported from PyTorch, re-arranging the
    // parameters into the layout and gate order that `gru`/`lstm` expect
    // (see `reorder_*` helpers) and inserting missing direction/batch axes.
    fn read_pytorch_ref_test(op: Op, case: &Value) -> RNNRefTest {
        let pool = BufferPool::new();
        let params = &case["params"];

        // Reverse-direction parameters are only present for bidirectional
        // test cases.
        let is_bidirectional = params.get("weight_ih_l0_reverse").is_some();

        // Insert batch axis at position 1.
        let mut input = read_tensor(&case["input"]).expect("failed to read input");
        input.insert_axis(1);

        let mut expected = read_tensor(&case["output"]).expect("failed to read output");
        if is_bidirectional {
            // Split the concatenated per-direction outputs into a separate
            // direction axis.
            let es = expected.shape();
            expected.reshape(&[es[0], 2, es[1] / 2]);
        } else {
            expected.insert_axis(1);
        }
        // Insert batch axis.
        expected.insert_axis(2);

        // Read a parameter tensor, re-ordering its gate dimension to the
        // order used by this crate.
        let read_param = |name| match op {
            Op::Lstm => reorder_ifco_to_iofc(
                &read_tensor(&params[name]).expect("failed to read weight"),
                0,
            ),
            Op::Gru => reorder_ruh_to_urh(
                &read_tensor(&params[name]).expect("failed to read weight"),
                0,
            ),
        };

        let mut weights = read_param("weight_ih_l0");
        weights.insert_axis(0);

        let mut hidden_weights = read_param("weight_hh_l0");
        hidden_weights.insert_axis(0);

        // Concatenate input and recurrent biases along the gate dimension,
        // as the ops expect a single combined bias tensor.
        let input_bias = read_param("bias_ih_l0");
        let hidden_bias = read_param("bias_hh_l0");
        let mut bias = concat(&pool, &[input_bias.view(), hidden_bias.view()], 0).unwrap();
        bias.insert_axis(0);

        if is_bidirectional {
            // Stack the reverse-direction parameters along the direction
            // axis (axis 0).
            let mut rev_weights = read_param("weight_ih_l0_reverse");
            rev_weights.insert_axis(0);
            weights = concat(&pool, &[weights.view(), rev_weights.view()], 0).unwrap();

            let mut rev_hidden_weights = read_param("weight_hh_l0_reverse");
            rev_hidden_weights.insert_axis(0);
            hidden_weights = concat(
                &pool,
                &[hidden_weights.view(), rev_hidden_weights.view()],
                0,
            )
            .unwrap();

            let rev_input_bias = read_param("bias_ih_l0_reverse");
            let rev_hidden_bias = read_param("bias_hh_l0_reverse");
            let mut rev_bias =
                concat(&pool, &[rev_input_bias.view(), rev_hidden_bias.view()], 0).unwrap();
            rev_bias.insert_axis(0);
            bias = concat(&pool, &[bias.view(), rev_bias.view()], 0).unwrap();
        }

        // Initial states are optional; insert the batch axis when present.
        let initial_hidden = case.get("initial_hidden").map(|param| {
            let mut init = read_tensor(param).expect("failed to read initial hidden state");
            init.insert_axis(1);
            init
        });
        let initial_cell = case.get("initial_cell").map(|param| {
            let mut init = read_tensor(param).expect("failed to read initial cell state");
            init.insert_axis(1);
            init
        });

        RNNRefTest {
            input,
            weights,
            hidden_weights,
            bias: Some(bias),
            expected,
            initial_hidden,
            initial_cell,
        }
    }
#[test]
fn test_rnn_pytorch() {
let dict = read_json_file("pytorch-ref-tests/rnn.json");
#[derive(Debug)]
struct Case {
name: &'static str,
dir: Direction,
}
let cases = &[
Case {
name: "lstm_forwards",
dir: Direction::Forward,
},
Case {
name: "lstm_initial",
dir: Direction::Forward,
},
Case {
name: "lstm_bidirectional",
dir: Direction::Bidirectional,
},
Case {
name: "gru_forwards",
dir: Direction::Forward,
},
Case {
name: "gru_initial",
dir: Direction::Forward,
},
Case {
name: "gru_bidirectional",
dir: Direction::Bidirectional,
},
];
cases.test_each(|case| {
let pool = BufferPool::new();
let op = if case.name.starts_with("lstm") {
Op::Lstm
} else {
Op::Gru
};
let data = read_pytorch_ref_test(op, &dict[case.name]);
let result = match op {
Op::Lstm => lstm(
&pool,
case.dir,
data.input.view(),
data.weights.view(),
data.hidden_weights.view(),
data.bias.as_ref().map(|b| b.view()),
data.initial_hidden.as_ref().map(|ih| ih.view()),
data.initial_cell.as_ref().map(|ic| ic.view()),
)
.expect("LSTM op failed"),
Op::Gru => gru(
&pool,
case.dir,
data.input.view(),
data.weights.view(),
data.hidden_weights.view(),
data.bias.as_ref().map(|b| b.view()),
data.initial_hidden.as_ref().map(|ih| ih.view()),
true,
)
.expect("GRU op failed"),
};
let output = &result[0];
expect_equal(output, &data.expected).unwrap();
})
}
}