use std::cell::RefCell;
use crate::autograd::Variable;
use crate::tensor::{Device, DType, Result, RnnParams, Tensor, TensorOptions};
use super::grucell::GRUCell;
use super::parameter::Parameter;
use super::Module;
/// Multi-layer gated recurrent unit (GRU) applied over whole input sequences.
///
/// Owns one `GRUCell` per stacked layer and a lazily-built cache of the
/// flattened parameter layout consumed by the fused sequence kernel.
pub struct GRU {
    /// One cell per stacked layer; layer 0 consumes the input features,
    /// deeper layers consume the previous layer's hidden state.
    cells: Vec<GRUCell>,
    /// Width of the hidden state, shared by every layer.
    hidden_size: i64,
    /// Number of stacked layers (asserted >= 1 at construction).
    num_layers: usize,
    /// If true, the batch dimension comes first in `forward_seq` inputs;
    /// otherwise the sequence dimension comes first.
    batch_first: bool,
    /// Fused-kernel parameter layout, built on the first forward pass.
    /// `RefCell` permits filling the cache from `&self` in `forward_seq`.
    rnn_params: RefCell<Option<RnnParams>>,
}
impl GRU {
    /// Creates a stacked GRU on the CPU with sequence-first input layout.
    pub fn new(input_size: i64, hidden_size: i64, num_layers: usize) -> Result<Self> {
        Self::on_device(input_size, hidden_size, num_layers, false, Device::CPU)
    }

    /// Creates a stacked GRU on `device`.
    ///
    /// `batch_first` selects whether `forward_seq` inputs are laid out
    /// batch-first (`[batch, seq, feat]`) or sequence-first
    /// (`[seq, batch, feat]`).
    ///
    /// # Panics
    /// Panics if `num_layers` is 0.
    ///
    /// # Errors
    /// Propagates any error from constructing a `GRUCell`.
    pub fn on_device(
        input_size: i64,
        hidden_size: i64,
        num_layers: usize,
        batch_first: bool,
        device: Device,
    ) -> Result<Self> {
        assert!(num_layers >= 1, "GRU requires at least 1 layer");
        let mut cells = Vec::with_capacity(num_layers);
        for layer in 0..num_layers {
            // Layer 0 reads the raw input; deeper layers read the hidden
            // state produced by the layer below.
            let in_size = if layer == 0 { input_size } else { hidden_size };
            cells.push(GRUCell::on_device(in_size, hidden_size, device)?);
        }
        Ok(GRU {
            cells,
            hidden_size,
            num_layers,
            batch_first,
            rnn_params: RefCell::new(None),
        })
    }

    /// Builder-style setter for the input layout flag.
    pub fn batch_first(mut self, batch_first: bool) -> Self {
        self.batch_first = batch_first;
        self
    }

    /// Runs a full sequence through all layers.
    ///
    /// Returns `(output, h_n)`: `output` is the top layer's hidden state at
    /// every time step (same layout as `input`), and `h_n` has shape
    /// `[num_layers, batch, hidden_size]`. When `h_0` is `None`, a zeroed
    /// initial hidden state of that shape is used.
    pub fn forward_seq(
        &self,
        input: &Variable,
        h_0: Option<&Variable>,
    ) -> Result<(Variable, Variable)> {
        let shape = input.shape();
        // Batch dimension position depends on the configured layout.
        let batch = if self.batch_first { shape[0] } else { shape[1] };
        let nl = self.num_layers as i64;
        let hs = self.hidden_size;
        let opts = TensorOptions {
            // NOTE(review): dtype is hard-coded to Float32; if the cells ever
            // support other dtypes this should follow the input — confirm.
            dtype: DType::Float32,
            device: self.cells[0].parameters()[0].variable.device(),
        };
        let h0 = match h_0 {
            Some(h) => h.data(),
            None => Tensor::zeros(&[nl, batch, hs], opts)?,
        };
        {
            // Lazily flatten the per-cell weights into the fused-kernel
            // layout. Built once and reused on subsequent calls.
            let mut cache = self.rnn_params.borrow_mut();
            if cache.is_none() {
                let params: Vec<Tensor> = self
                    .cells
                    .iter()
                    .flat_map(|cell| cell.parameters().into_iter().map(|p| p.variable.data()))
                    .collect();
                // Fix: this argument was corrupted to `¶ms` (mojibake for
                // `&params`), which does not compile.
                *cache = Some(RnnParams::new(&params, 3, nl, self.batch_first, true)?);
            }
        }
        let cache = self.rnn_params.borrow();
        let (output, h_n) = input
            .data()
            .gru_seq_cached(&h0, cache.as_ref().unwrap(), nl, self.batch_first)?;
        Ok((Variable::wrap(output), Variable::wrap(h_n)))
    }
}
impl Module for GRU {
    fn name(&self) -> &str {
        "gru"
    }

    /// Convenience forward pass: runs the sequence with a zero initial
    /// hidden state and discards the final hidden state.
    fn forward(&self, input: &Variable) -> Result<Variable> {
        let (output, _) = self.forward_seq(input, None)?;
        Ok(output)
    }

    /// Collects the parameters of every layer's cell, layer 0 first.
    fn parameters(&self) -> Vec<Parameter> {
        let mut all = Vec::new();
        for cell in &self.cells {
            all.extend(cell.parameters());
        }
        all
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sequence-first, 2-layer: output is [seq, batch, hidden],
    /// h_n is [layers, batch, hidden].
    #[test]
    fn test_gru_shapes() {
        let device = crate::tensor::test_device();
        let options = crate::tensor::test_opts();
        let gru = GRU::on_device(4, 8, 2, false, device).unwrap();
        let input = Variable::new(Tensor::randn(&[5, 3, 4], options).unwrap(), false);
        let (output, h_n) = gru.forward_seq(&input, None).unwrap();
        assert_eq!(output.shape(), vec![5, 3, 8]);
        assert_eq!(h_n.shape(), vec![2, 3, 8]);
    }

    /// Batch-first layout: output keeps [batch, seq, hidden];
    /// h_n stays [layers, batch, hidden].
    #[test]
    fn test_gru_batch_first() {
        let device = crate::tensor::test_device();
        let options = crate::tensor::test_opts();
        let gru = GRU::on_device(4, 8, 2, true, device).unwrap();
        let input = Variable::new(Tensor::randn(&[3, 5, 4], options).unwrap(), false);
        let (output, h_n) = gru.forward_seq(&input, None).unwrap();
        assert_eq!(output.shape(), vec![3, 5, 8]);
        assert_eq!(h_n.shape(), vec![2, 3, 8]);
    }

    /// Passing an explicit h_0 of the right shape is accepted.
    #[test]
    fn test_gru_with_initial_hidden() {
        let device = crate::tensor::test_device();
        let options = crate::tensor::test_opts();
        let gru = GRU::on_device(4, 8, 2, false, device).unwrap();
        let input = Variable::new(Tensor::randn(&[5, 3, 4], options).unwrap(), false);
        let h0 = Variable::new(Tensor::randn(&[2, 3, 8], options).unwrap(), false);
        let (output, h_n) = gru.forward_seq(&input, Some(&h0)).unwrap();
        assert_eq!(output.shape(), vec![5, 3, 8]);
        assert_eq!(h_n.shape(), vec![2, 3, 8]);
    }

    /// Degenerate case: a single layer still works and h_n has one slot.
    #[test]
    fn test_gru_single_layer() {
        let device = crate::tensor::test_device();
        let options = crate::tensor::test_opts();
        let gru = GRU::on_device(4, 8, 1, false, device).unwrap();
        let input = Variable::new(Tensor::randn(&[5, 3, 4], options).unwrap(), false);
        let (output, h_n) = gru.forward_seq(&input, None).unwrap();
        assert_eq!(output.shape(), vec![5, 3, 8]);
        assert_eq!(h_n.shape(), vec![1, 3, 8]);
    }

    /// Backprop reaches every parameter and the input.
    #[test]
    fn test_gru_gradient() {
        let device = crate::tensor::test_device();
        let options = crate::tensor::test_opts();
        let gru = GRU::on_device(4, 8, 2, false, device).unwrap();
        let input = Variable::new(Tensor::randn(&[5, 3, 4], options).unwrap(), true);
        let (output, _h_n) = gru.forward_seq(&input, None).unwrap();
        let loss = output.sum().unwrap();
        loss.backward().unwrap();
        for param in gru.parameters() {
            assert!(param.variable.grad().is_some(), "missing grad for {}", param.name);
        }
        assert!(input.grad().is_some());
    }

    /// The Module trait's forward is equivalent to forward_seq without h_0.
    #[test]
    fn test_gru_module_forward() {
        let device = crate::tensor::test_device();
        let options = crate::tensor::test_opts();
        let gru = GRU::on_device(4, 8, 2, false, device).unwrap();
        let input = Variable::new(Tensor::randn(&[5, 3, 4], options).unwrap(), false);
        let output = gru.forward(&input).unwrap();
        assert_eq!(output.shape(), vec![5, 3, 8]);
    }

    /// 2 layers x 4 parameters per GRUCell = 8 parameters total.
    #[test]
    fn test_gru_parameters_count() {
        let device = crate::tensor::test_device();
        let gru = GRU::on_device(4, 8, 2, false, device).unwrap();
        assert_eq!(gru.parameters().len(), 8);
    }

    /// The consuming builder flips the layout flag after construction.
    #[test]
    fn test_gru_builder_pattern() {
        let device = crate::tensor::test_device();
        let gru = GRU::on_device(4, 8, 1, false, device).unwrap().batch_first(true);
        let options = crate::tensor::test_opts();
        let input = Variable::new(Tensor::randn(&[3, 5, 4], options).unwrap(), false);
        let (output, _) = gru.forward_seq(&input, None).unwrap();
        assert_eq!(output.shape(), vec![3, 5, 8]);
    }
}