use std::cell::RefCell;
use crate::autograd::Variable;
use crate::tensor::{Device, DType, Result, RnnParams, Tensor, TensorOptions};
use super::lstmcell::LSTMCell;
use super::parameter::Parameter;
use super::Module;
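/// Multi-layer LSTM that runs whole sequences through a single cached
/// kernel call (`Tensor::lstm_seq_cached`) instead of stepping cell by
/// cell. One `LSTMCell` is created per layer: layer 0 consumes the raw
/// input and each later layer consumes the previous layer's hidden states.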
pub struct LSTM {
cells: Vec<LSTMCell>,
hidden_size: i64,
num_layers: usize,
batch_first: bool,
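    /// Flattened parameter pack for the cached sequence kernel, built
    /// lazily on the first forward pass; `RefCell` lets `forward_seq`
    /// populate it through `&self`.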
rnn_params: RefCell<Option<RnnParams>>,
}
impl LSTM {
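    /// Creates an LSTM on the CPU with the default time-major
    /// (`[seq_len, batch, feature]`) input layout.
    ///
    /// A minimal sketch (marked `ignore` because the enclosing crate
    /// path is not shown in this file):
    ///
    /// ```ignore
    /// // 4 input features, 8 hidden units, 2 stacked layers.
    /// let lstm = LSTM::new(4, 8, 2)?;
    /// ```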
pub fn new(input_size: i64, hidden_size: i64, num_layers: usize) -> Result<Self> {
Self::on_device(input_size, hidden_size, num_layers, false, Device::CPU)
}
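    /// Creates an LSTM whose per-layer cells live on `device`. Layer 0
    /// takes `input_size` features; every later layer takes the previous
    /// layer's `hidden_size` outputs. Panics if `num_layers` is 0.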
pub fn on_device(
input_size: i64,
hidden_size: i64,
num_layers: usize,
batch_first: bool,
device: Device,
) -> Result<Self> {
assert!(num_layers >= 1, "LSTM requires at least 1 layer");
let mut cells = Vec::with_capacity(num_layers);
for layer in 0..num_layers {
let in_size = if layer == 0 { input_size } else { hidden_size };
cells.push(LSTMCell::on_device(in_size, hidden_size, device)?);
}
Ok(LSTM {
cells,
hidden_size,
num_layers,
batch_first,
rnn_params: RefCell::new(None),
})
}
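    /// Builder-style toggle: when `true`, inputs and outputs use the
    /// `[batch, seq_len, feature]` layout instead of the time-major default.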
pub fn batch_first(mut self, batch_first: bool) -> Self {
self.batch_first = batch_first;
self
}
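    /// Runs the full input sequence through all layers in one call.
    ///
    /// `state_0` is an optional `(h_0, c_0)` pair, each shaped
    /// `[num_layers, batch, hidden_size]`; when `None`, zero states are
    /// allocated on the first cell's device. Returns the per-step outputs
    /// together with the final `(h_n, c_n)`.
    ///
    /// A shape sketch for the time-major layout (marked `ignore`; the
    /// crate path and a suitable `opts: TensorOptions` are assumed):
    ///
    /// ```ignore
    /// let lstm = LSTM::on_device(4, 8, 2, false, Device::CPU)?;
    /// let x = Variable::new(Tensor::randn(&[5, 3, 4], opts)?, false); // [seq, batch, in]
    /// let (out, (h_n, c_n)) = lstm.forward_seq(&x, None)?;
    /// assert_eq!(out.shape(), vec![5, 3, 8]); // [seq, batch, hidden]
    /// assert_eq!(h_n.shape(), vec![2, 3, 8]); // [layers, batch, hidden]
    /// ```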
pub fn forward_seq(
&self,
input: &Variable,
state_0: Option<(&Variable, &Variable)>,
) -> Result<(Variable, (Variable, Variable))> {
let shape = input.shape();
let batch = if self.batch_first { shape[0] } else { shape[1] };
let nl = self.num_layers as i64;
let hs = self.hidden_size;
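        // Default (zero) initial states are allocated as Float32 on the
        // same device as the first cell's parameters.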
let opts = TensorOptions {
dtype: DType::Float32,
device: self.cells[0].parameters()[0].variable.device(),
};
let (h_0, c_0) = match state_0 {
Some((h, c)) => (h.data(), c.data()),
None => (
Tensor::zeros(&[nl, batch, hs], opts)?,
Tensor::zeros(&[nl, batch, hs], opts)?,
),
};
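        // On first use, pack every cell's parameter tensors into the
        // layout `lstm_seq_cached` expects and memoize the result so
        // later calls skip the repacking.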
{
let mut cache = self.rnn_params.borrow_mut();
if cache.is_none() {
let params: Vec<Tensor> = self.cells.iter()
.flat_map(|cell| cell.parameters().into_iter().map(|p| p.variable.data()))
.collect();
                *cache = Some(RnnParams::new(&params, 2, nl, self.batch_first, true)?);
}
}
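        // Run the whole sequence through the cached kernel in one call.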
let cache = self.rnn_params.borrow();
let (output, h_n, c_n) = input.data().lstm_seq_cached(
&h_0, &c_0, cache.as_ref().unwrap(), nl, self.batch_first,
)?;
Ok((Variable::wrap(output), (Variable::wrap(h_n), Variable::wrap(c_n))))
}
}
impl Module for LSTM {
fn name(&self) -> &str { "lstm" }
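    /// `Module::forward` returns only the per-step outputs, discarding
    /// the final `(h_n, c_n)` state.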
fn forward(&self, input: &Variable) -> Result<Variable> {
let (output, _) = self.forward_seq(input, None)?;
Ok(output)
}
fn parameters(&self) -> Vec<Parameter> {
self.cells.iter().flat_map(|c| c.parameters()).collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_lstm_shapes() {
let dev = crate::tensor::test_device();
let opts = crate::tensor::test_opts();
let lstm = LSTM::on_device(4, 8, 2, false, dev).unwrap();
let x = Variable::new(Tensor::randn(&[5, 3, 4], opts).unwrap(), false);
let (output, (h_n, c_n)) = lstm.forward_seq(&x, None).unwrap();
        assert_eq!(output.shape(), vec![5, 3, 8]);
        assert_eq!(h_n.shape(), vec![2, 3, 8]);
        assert_eq!(c_n.shape(), vec![2, 3, 8]);
    }
#[test]
fn test_lstm_batch_first() {
let dev = crate::tensor::test_device();
let opts = crate::tensor::test_opts();
let lstm = LSTM::on_device(4, 8, 2, true, dev).unwrap();
let x = Variable::new(Tensor::randn(&[3, 5, 4], opts).unwrap(), false);
let (output, (h_n, c_n)) = lstm.forward_seq(&x, None).unwrap();
        assert_eq!(output.shape(), vec![3, 5, 8]);
        assert_eq!(h_n.shape(), vec![2, 3, 8]);
assert_eq!(c_n.shape(), vec![2, 3, 8]);
}
#[test]
fn test_lstm_with_initial_state() {
let dev = crate::tensor::test_device();
let opts = crate::tensor::test_opts();
let lstm = LSTM::on_device(4, 8, 2, false, dev).unwrap();
let x = Variable::new(Tensor::randn(&[5, 3, 4], opts).unwrap(), false);
let h0 = Variable::new(Tensor::randn(&[2, 3, 8], opts).unwrap(), false);
let c0 = Variable::new(Tensor::randn(&[2, 3, 8], opts).unwrap(), false);
let (output, (h_n, c_n)) = lstm.forward_seq(&x, Some((&h0, &c0))).unwrap();
assert_eq!(output.shape(), vec![5, 3, 8]);
assert_eq!(h_n.shape(), vec![2, 3, 8]);
assert_eq!(c_n.shape(), vec![2, 3, 8]);
}
#[test]
fn test_lstm_single_layer() {
let dev = crate::tensor::test_device();
let opts = crate::tensor::test_opts();
let lstm = LSTM::on_device(4, 8, 1, false, dev).unwrap();
let x = Variable::new(Tensor::randn(&[5, 3, 4], opts).unwrap(), false);
let (output, (h_n, c_n)) = lstm.forward_seq(&x, None).unwrap();
assert_eq!(output.shape(), vec![5, 3, 8]);
assert_eq!(h_n.shape(), vec![1, 3, 8]);
assert_eq!(c_n.shape(), vec![1, 3, 8]);
}
#[test]
fn test_lstm_gradient() {
let dev = crate::tensor::test_device();
let opts = crate::tensor::test_opts();
let lstm = LSTM::on_device(4, 8, 2, false, dev).unwrap();
let x = Variable::new(Tensor::randn(&[5, 3, 4], opts).unwrap(), true);
let (output, _) = lstm.forward_seq(&x, None).unwrap();
let loss = output.sum().unwrap();
loss.backward().unwrap();
for p in lstm.parameters() {
assert!(p.variable.grad().is_some(), "missing grad for {}", p.name);
}
assert!(x.grad().is_some());
}
#[test]
fn test_lstm_module_forward() {
let dev = crate::tensor::test_device();
let opts = crate::tensor::test_opts();
let lstm = LSTM::on_device(4, 8, 2, false, dev).unwrap();
let x = Variable::new(Tensor::randn(&[5, 3, 4], opts).unwrap(), false);
let y = lstm.forward(&x).unwrap();
assert_eq!(y.shape(), vec![5, 3, 8]);
}
#[test]
fn test_lstm_parameters_count() {
let dev = crate::tensor::test_device();
let lstm = LSTM::on_device(4, 8, 2, false, dev).unwrap();
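        // 2 layers x 4 parameter tensors per cell = 8.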
assert_eq!(lstm.parameters().len(), 8);
}
#[test]
fn test_lstm_builder_pattern() {
let dev = crate::tensor::test_device();
let lstm = LSTM::on_device(4, 8, 1, false, dev).unwrap().batch_first(true);
let opts = crate::tensor::test_opts();
let x = Variable::new(Tensor::randn(&[3, 5, 4], opts).unwrap(), false);
let (output, _) = lstm.forward_seq(&x, None).unwrap();
        assert_eq!(output.shape(), vec![3, 5, 8]);
    }
}