use crate::{Device, Kind, Tensor};
/// A trait shared by recurrent neural networks (e.g. LSTM, GRU).
pub trait RNN {
/// The recurrent state carried between steps.
type State;
/// Builds an all-zero initial state for a batch of the given size.
fn zero_state(&self, batch_dim: i64) -> Self::State;
/// Applies a single step of the recurrent network on `input`, returning the
/// updated state.
fn step(&self, input: &Tensor, state: &Self::State) -> Self::State;
/// Applies the recurrent network to a whole sequence, starting from an
/// all-zero state, returning the output sequence and the final state.
/// NOTE(review): the batch size is taken from dimension 0 of `input`, which
/// assumes a batch-first layout — confirm callers use `batch_first: true`.
fn seq(&self, input: &Tensor) -> (Tensor, Self::State) {
let batch_dim = input.size()[0];
let state = self.zero_state(batch_dim);
self.seq_init(input, &state)
}
/// Applies the recurrent network to a whole sequence, starting from the
/// given state, returning the output sequence and the final state.
fn seq_init(&self, input: &Tensor, state: &Self::State) -> (Tensor, Self::State);
}
/// The state of an LSTM network: a pair of (hidden state, cell state) tensors.
#[derive(Debug)]
pub struct LSTMState(pub (Tensor, Tensor));
impl LSTMState {
pub fn h(&self) -> Tensor {
(self.0).0.shallow_clone()
}
pub fn c(&self) -> Tensor {
(self.0).1.shallow_clone()
}
}
/// Configuration shared by the recurrent network layers (LSTM/GRU).
#[derive(Debug, Clone, Copy)]
pub struct RNNConfig {
// Whether bias tensors are created alongside the weights.
pub has_biases: bool,
// Number of stacked recurrent layers.
pub num_layers: i64,
// Dropout probability passed to the underlying RNN op; presumably applied
// between layers when training — confirm against the `lstm`/`gru` op docs.
pub dropout: f64,
// Whether the layer runs in training mode.
pub train: bool,
// Whether a second, reversed direction is run over each sequence.
pub bidirectional: bool,
// Whether inputs use a (batch, seq, feature) layout rather than
// (seq, batch, feature).
pub batch_first: bool,
}
impl Default for RNNConfig {
fn default() -> Self {
RNNConfig {
has_biases: true,
num_layers: 1,
dropout: 0.,
train: true,
bidirectional: false,
batch_first: true,
}
}
}
/// A Long Short-Term Memory (LSTM) layer.
#[derive(Debug)]
pub struct LSTM {
// Weight (and optional bias) tensors, in the order expected by the
// underlying `Tensor::lstm` op.
flat_weights: Vec<Tensor>,
// Size of the hidden state, per direction.
hidden_dim: i64,
config: RNNConfig,
// Device the zero states are allocated on.
device: Device,
}
/// Creates an LSTM layer.
///
/// Weight (and optional bias) tensors are registered in the variable store
/// using the PyTorch naming scheme (`weight_ih_l{n}`, `weight_hh_l{n}`,
/// `bias_ih_l{n}`, `bias_hh_l{n}`, with a `_reverse` suffix for the backward
/// direction), one set per layer and direction.
///
/// * `vs` - path in the variable store where the parameters are created.
/// * `in_dim` - dimension of the input features.
/// * `hidden_dim` - dimension of the hidden state, per direction.
/// * `c` - layer configuration.
pub fn lstm(vs: &super::var_store::Path, in_dim: i64, hidden_dim: i64, c: RNNConfig) -> LSTM {
    let num_directions = if c.bidirectional { 2 } else { 1 };
    // The four LSTM gates (input, forget, cell, output) are stored stacked
    // along the first dimension of each weight tensor.
    let gate_dim = 4 * hidden_dim;
    let mut flat_weights = vec![];
    for layer_idx in 0..c.num_layers {
        for direction_idx in 0..num_directions {
            // Layers above the first consume the (possibly concatenated
            // bidirectional) hidden state of the layer below.
            let in_dim = if layer_idx == 0 {
                in_dim
            } else {
                hidden_dim * num_directions
            };
            let suffix = if direction_idx == 1 { "_reverse" } else { "" };
            let w_ih = vs.kaiming_uniform(
                &format!("weight_ih_l{}{}", layer_idx, suffix),
                &[gate_dim, in_dim],
            );
            let w_hh = vs.kaiming_uniform(
                &format!("weight_hh_l{}{}", layer_idx, suffix),
                &[gate_dim, hidden_dim],
            );
            flat_weights.push(w_ih);
            flat_weights.push(w_hh);
            if c.has_biases {
                let b_ih = vs.zeros(&format!("bias_ih_l{}{}", layer_idx, suffix), &[gate_dim]);
                let b_hh = vs.zeros(&format!("bias_hh_l{}{}", layer_idx, suffix), &[gate_dim]);
                flat_weights.push(b_ih);
                flat_weights.push(b_hh);
            }
        }
    }
    if vs.device().is_cuda() && crate::Cuda::cudnn_is_available() {
        // Let cuDNN repack the weights into a single contiguous buffer; the
        // returned tensor is intentionally discarded.
        let _ = Tensor::internal_cudnn_rnn_flatten_weight(
            &flat_weights,
            // weight_stride0: number of parameter tensors per layer/direction.
            // Only two tensors (w_ih, w_hh) are pushed when biases are off, so
            // hardcoding 4 here would mis-stride the flattened weight array.
            if c.has_biases { 4 } else { 2 },
            in_dim,
            2, // cuDNN mode 2 = LSTM
            hidden_dim,
            0, // no projection
            c.num_layers,
            c.batch_first,
            c.bidirectional,
        );
    }
    LSTM {
        flat_weights,
        hidden_dim,
        config: c,
        device: vs.device(),
    }
}
impl RNN for LSTM {
    type State = LSTMState;

    fn zero_state(&self, batch_dim: i64) -> LSTMState {
        let dirs = if self.config.bidirectional { 2 } else { 1 };
        let shape = [self.config.num_layers * dirs, batch_dim, self.hidden_dim];
        let zeros = Tensor::zeros(&shape, (Kind::Float, self.device));
        LSTMState((zeros.shallow_clone(), zeros.shallow_clone()))
    }

    fn step(&self, input: &Tensor, in_state: &LSTMState) -> LSTMState {
        // Run the input as a length-one sequence and keep only the state.
        self.seq_init(&input.unsqueeze(1), in_state).1
    }

    fn seq_init(&self, input: &Tensor, in_state: &LSTMState) -> (Tensor, LSTMState) {
        let LSTMState((h0, c0)) = in_state;
        // `Tensor::lstm` takes the weights as a slice of references.
        let weights: Vec<_> = self.flat_weights.iter().collect();
        let (output, h, c) = input.lstm(
            &[h0, c0],
            &weights,
            self.config.has_biases,
            self.config.num_layers,
            self.config.dropout,
            self.config.train,
            self.config.bidirectional,
            self.config.batch_first,
        );
        (output, LSTMState((h, c)))
    }
}
/// The state of a GRU network: a single hidden-state tensor.
#[derive(Debug)]
pub struct GRUState(pub Tensor);
impl GRUState {
pub fn value(&self) -> Tensor {
self.0.shallow_clone()
}
}
/// A Gated Recurrent Unit (GRU) layer.
#[derive(Debug)]
pub struct GRU {
// Weight (and optional bias) tensors, in the order expected by the
// underlying `Tensor::gru` op.
flat_weights: Vec<Tensor>,
// Size of the hidden state, per direction.
hidden_dim: i64,
config: RNNConfig,
// Device the zero states are allocated on.
device: Device,
}
/// Creates a GRU layer.
///
/// Weight (and optional bias) tensors are registered in the variable store
/// using the PyTorch naming scheme (`weight_ih_l{n}`, `weight_hh_l{n}`,
/// `bias_ih_l{n}`, `bias_hh_l{n}`, with a `_reverse` suffix for the backward
/// direction), one set per layer and direction.
///
/// * `vs` - path in the variable store where the parameters are created.
/// * `in_dim` - dimension of the input features.
/// * `hidden_dim` - dimension of the hidden state, per direction.
/// * `c` - layer configuration.
pub fn gru(vs: &super::var_store::Path, in_dim: i64, hidden_dim: i64, c: RNNConfig) -> GRU {
    let num_directions = if c.bidirectional { 2 } else { 1 };
    // The three GRU gates (reset, update, new) are stored stacked along the
    // first dimension of each weight tensor.
    let gate_dim = 3 * hidden_dim;
    let mut flat_weights = vec![];
    for layer_idx in 0..c.num_layers {
        for direction_idx in 0..num_directions {
            // Layers above the first consume the (possibly concatenated
            // bidirectional) hidden state of the layer below.
            let in_dim = if layer_idx == 0 {
                in_dim
            } else {
                hidden_dim * num_directions
            };
            let suffix = if direction_idx == 1 { "_reverse" } else { "" };
            let w_ih = vs.kaiming_uniform(
                &format!("weight_ih_l{}{}", layer_idx, suffix),
                &[gate_dim, in_dim],
            );
            let w_hh = vs.kaiming_uniform(
                &format!("weight_hh_l{}{}", layer_idx, suffix),
                &[gate_dim, hidden_dim],
            );
            flat_weights.push(w_ih);
            flat_weights.push(w_hh);
            if c.has_biases {
                let b_ih = vs.zeros(&format!("bias_ih_l{}{}", layer_idx, suffix), &[gate_dim]);
                let b_hh = vs.zeros(&format!("bias_hh_l{}{}", layer_idx, suffix), &[gate_dim]);
                flat_weights.push(b_ih);
                flat_weights.push(b_hh);
            }
        }
    }
    if vs.device().is_cuda() && crate::Cuda::cudnn_is_available() {
        // Let cuDNN repack the weights into a single contiguous buffer; the
        // returned tensor is intentionally discarded.
        let _ = Tensor::internal_cudnn_rnn_flatten_weight(
            &flat_weights,
            // weight_stride0: number of parameter tensors per layer/direction.
            // Only two tensors (w_ih, w_hh) are pushed when biases are off, so
            // hardcoding 4 here would mis-stride the flattened weight array.
            if c.has_biases { 4 } else { 2 },
            in_dim,
            3, // cuDNN mode 3 = GRU
            hidden_dim,
            0, // no projection
            c.num_layers,
            c.batch_first,
            c.bidirectional,
        );
    }
    GRU {
        flat_weights,
        hidden_dim,
        config: c,
        device: vs.device(),
    }
}
impl RNN for GRU {
    type State = GRUState;

    fn zero_state(&self, batch_dim: i64) -> GRUState {
        let dirs = if self.config.bidirectional { 2 } else { 1 };
        let shape = [self.config.num_layers * dirs, batch_dim, self.hidden_dim];
        GRUState(Tensor::zeros(&shape, (Kind::Float, self.device)))
    }

    fn step(&self, input: &Tensor, in_state: &GRUState) -> GRUState {
        // Run the input as a length-one sequence and keep only the state.
        self.seq_init(&input.unsqueeze(1), in_state).1
    }

    fn seq_init(&self, input: &Tensor, in_state: &GRUState) -> (Tensor, GRUState) {
        let GRUState(h0) = in_state;
        let (output, h) = input.gru(
            h0,
            &self.flat_weights,
            self.config.has_biases,
            self.config.num_layers,
            self.config.dropout,
            self.config.train,
            self.config.bidirectional,
            self.config.batch_first,
        );
        (output, GRUState(h))
    }
}