use burn::nn::{
Dropout, DropoutConfig, Linear, LinearConfig, Lstm, LstmConfig,
};
use burn::prelude::*;
use burn::tensor::activation::softmax;
use serde::{Deserialize, Serialize};
/// Recurrent cell family selectable through [`RNNPlusConfig`].
///
/// `LSTM` is the default variant.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RNNType {
    /// Long short-term memory cell.
    #[default]
    LSTM,
    /// Gated recurrent unit cell.
    GRU,
}
/// Hyper-parameters for an [`RNNPlus`] classifier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RNNPlusConfig {
/// Number of input variables per time step; used as the LSTM input size.
pub n_vars: usize,
/// Nominal sequence length. `RNNPlus::new` does not read it; `forward` takes the length
/// from the input tensor instead.
pub seq_len: usize,
/// Number of output classes; the output width of the linear head.
pub n_classes: usize,
/// Hidden units of the recurrent layer.
pub hidden_size: usize,
/// Requested number of recurrent layers.
/// NOTE(review): not read by `RNNPlus::new`, which builds a single LSTM layer.
pub n_layers: usize,
/// Requested recurrent cell type.
/// NOTE(review): not read by `RNNPlus::new`, which always builds an LSTM.
pub rnn_type: RNNType,
/// Requests a bidirectional recurrence; when set, `output_dim()` doubles.
/// NOTE(review): verify that `RNNPlus::new` actually honors this flag.
pub bidirectional: bool,
/// Dropout probability applied to the recurrent features before the classification head.
pub dropout: f64,
}
impl Default for RNNPlusConfig {
    /// Defaults: one input variable, 100 time steps and two classes. The recurrent part
    /// requests an LSTM with 128 hidden units, 2 layers, one direction and 10% dropout.
    fn default() -> Self {
        Self {
            rnn_type: RNNType::LSTM,
            bidirectional: false,
            hidden_size: 128,
            n_layers: 2,
            dropout: 0.1,
            n_vars: 1,
            seq_len: 100,
            n_classes: 2,
        }
    }
}
impl RNNPlusConfig {
    /// Creates a config for `n_vars` input variables, `seq_len` time steps and `n_classes`
    /// classes. Every other field takes its `Default` value.
    pub fn new(n_vars: usize, seq_len: usize, n_classes: usize) -> Self {
        let defaults = Self::default();
        Self {
            n_vars,
            seq_len,
            n_classes,
            ..defaults
        }
    }

    /// Input width the classification head is sized for: `hidden_size`, doubled when
    /// `bidirectional` is set.
    fn output_dim(&self) -> usize {
        let directions = if self.bidirectional { 2 } else { 1 };
        directions * self.hidden_size
    }

    /// Builds an [`RNNPlus`] model on `device` from a copy of this config.
    pub fn init<B: Backend>(&self, device: &B::Device) -> RNNPlus<B> {
        let config = self.clone();
        RNNPlus::new(config, device)
    }
}
/// Recurrent time-series classifier: an LSTM encoder, dropout, and a linear head.
#[derive(Module, Debug)]
pub struct RNNPlus<B: Backend> {
/// Recurrent encoder run over the time axis of the input.
lstm: Lstm<B>,
/// Dropout applied to the last-time-step features before the head.
dropout: Dropout,
/// Linear classification head producing one logit per class.
fc: Linear<B>,
}
impl<B: Backend> RNNPlus<B> {
    /// Builds the network described by `config` on `device`.
    ///
    /// The recurrent part is a single unidirectional LSTM (`n_vars -> hidden_size`) with bias
    /// terms. It is followed by dropout and a linear head (`hidden_size -> n_classes`).
    ///
    /// NOTE(review): `config.n_layers` and `config.rnn_type` are not read here. A multi-layer
    /// or GRU configuration therefore still yields a single-layer LSTM.
    ///
    /// # Panics
    ///
    /// Panics if `config.bidirectional` is `true`. Only a unidirectional LSTM is built, so a
    /// head sized by `output_dim()` (`2 * hidden_size`) could never accept its output.
    pub fn new(config: RNNPlusConfig, device: &B::Device) -> Self {
        assert!(
            !config.bidirectional,
            "RNNPlus: bidirectional recurrence is not supported (got bidirectional = true)"
        );
        // The third argument of `LstmConfig::new` is `bias`, not a direction flag.
        // Passing `config.bidirectional` there silently disabled the gate biases.
        let lstm = LstmConfig::new(config.n_vars, config.hidden_size, true).init(device);
        let dropout = DropoutConfig::new(config.dropout).init();
        // `bidirectional` is false here, so this equals `hidden_size`, the LSTM output width.
        let fc = LinearConfig::new(config.output_dim(), config.n_classes).init(device);
        Self { lstm, dropout, fc }
    }

    /// Returns class logits of shape `[batch, n_classes]`.
    ///
    /// `x` is read as `[batch, n_vars, seq_len]` and transposed to `[batch, seq_len, n_vars]`
    /// before the LSTM. Only the output at the final time step is passed through dropout and
    /// the linear head.
    ///
    /// # Panics
    ///
    /// Panics if the sequence dimension of `x` is empty.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
        let [batch, _n_vars, seq_len] = x.dims();
        // Without this guard, `seq_len - 1` below would underflow with an opaque panic.
        assert!(
            seq_len > 0,
            "RNNPlus::forward: input sequence length must be greater than zero"
        );
        // Put the time axis second: [batch, seq_len, n_vars].
        let x = x.swap_dims(1, 2);
        // NOTE(review): assumes the first value returned by `Lstm::forward` is the per-step
        // hidden output (burn >= 0.14 semantics). Confirm against the pinned burn version.
        let (output, _) = self.lstm.forward(x, None);
        let [_, _, hidden_dim] = output.dims();
        // Keep only the last time step, then drop the singleton time axis.
        let last_output = output
            .slice([0..batch, (seq_len - 1)..seq_len, 0..hidden_dim])
            .reshape([batch, hidden_dim]);
        let output = self.dropout.forward(last_output);
        self.fc.forward(output)
    }

    /// Like [`Self::forward`], but applies a softmax over the class dimension so each row
    /// sums to one.
    pub fn forward_probs(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
        let logits = self.forward(x);
        softmax(logits, 1)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config requests a 128-unit LSTM.
    #[test]
    fn test_rnn_config() {
        let RNNPlusConfig {
            hidden_size,
            rnn_type,
            ..
        } = RNNPlusConfig::default();
        assert_eq!(hidden_size, 128);
        assert_eq!(rnn_type, RNNType::LSTM);
    }

    /// `output_dim` equals `hidden_size`, and doubles when `bidirectional` is set.
    #[test]
    fn test_output_dim() {
        let unidirectional = RNNPlusConfig::default();
        assert_eq!(unidirectional.output_dim(), 128);

        let bidirectional = RNNPlusConfig {
            bidirectional: true,
            ..RNNPlusConfig::default()
        };
        assert_eq!(bidirectional.output_dim(), 256);
    }
}