use burn::nn::{
    BiLstm, BiLstmConfig, Dropout, DropoutConfig, Linear, LinearConfig, Lstm, LstmConfig,
};
use burn::prelude::*;
use burn::tensor::activation::softmax;
use serde::{Deserialize, Serialize};

use super::RNNType;
/// Scoring function used to weight the encoder's time steps.
///
/// `Default` is derived; the `#[default]` variant is [`AttentionType::ScaledDotProduct`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum AttentionType {
    /// Scores come from a learned projection of `tanh(W_q q + W_k k)`.
    Additive,
    /// Scores are the raw dot product between a query and each time step.
    DotProduct,
    /// Dot-product scores divided by the square root of the feature width.
    #[default]
    ScaledDotProduct,
}
/// Hyper-parameters for [`RNNAttention`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RNNAttentionConfig {
    /// Number of input variables; used as the input width of the recurrent encoder.
    pub n_vars: usize,
    /// Nominal sequence length. The model code in this file never reads it; the
    /// actual length is taken from the input tensor.
    pub seq_len: usize,
    /// Width of the classifier output.
    pub n_classes: usize,
    /// Hidden width of the recurrent encoder (per direction).
    pub hidden_size: usize,
    /// Requested number of recurrent layers.
    /// NOTE(review): not read by `RNNAttention::new` in this file.
    pub n_layers: usize,
    /// Requested recurrent cell type.
    /// NOTE(review): not read by `RNNAttention::new` in this file.
    pub rnn_type: RNNType,
    /// Bidirectional switch; when set, `output_dim` reports `2 * hidden_size`.
    pub bidirectional: bool,
    /// Which attention scoring function to build.
    pub attention_type: AttentionType,
    /// Inner width of the additive attention projections; only the
    /// [`AttentionType::Additive`] branch uses it.
    pub attention_dim: usize,
    /// Probability handed to `DropoutConfig::new` for the layer in front of the classifier.
    pub dropout: f64,
}
impl Default for RNNAttentionConfig {
    /// Baseline configuration: one input variable, 100 steps, two classes, a
    /// 128-wide bidirectional LSTM with scaled dot-product attention and 10% dropout.
    fn default() -> Self {
        // Data shape.
        let (n_vars, seq_len, n_classes) = (1, 100, 2);
        // Encoder.
        let (hidden_size, n_layers, bidirectional) = (128, 1, true);

        Self {
            n_vars,
            seq_len,
            n_classes,
            hidden_size,
            n_layers,
            rnn_type: RNNType::LSTM,
            bidirectional,
            attention_type: AttentionType::ScaledDotProduct,
            attention_dim: 64,
            dropout: 0.1,
        }
    }
}
impl RNNAttentionConfig {
    /// Creates a configuration for the given data shape; every other field keeps
    /// its `Default` value.
    pub fn new(n_vars: usize, seq_len: usize, n_classes: usize) -> Self {
        let defaults = Self::default();
        Self {
            n_vars,
            seq_len,
            n_classes,
            ..defaults
        }
    }

    /// Returns the configuration with `hidden_size` replaced.
    #[must_use]
    pub fn with_hidden_size(self, hidden_size: usize) -> Self {
        Self {
            hidden_size,
            ..self
        }
    }

    /// Returns the configuration with `bidirectional` replaced.
    #[must_use]
    pub fn with_bidirectional(self, bidirectional: bool) -> Self {
        Self {
            bidirectional,
            ..self
        }
    }

    /// Returns the configuration with `attention_type` replaced.
    #[must_use]
    pub fn with_attention_type(self, attention_type: AttentionType) -> Self {
        Self {
            attention_type,
            ..self
        }
    }

    /// Feature width of the encoder output: `hidden_size`, doubled when
    /// `bidirectional` is set.
    fn output_dim(&self) -> usize {
        let directions = if self.bidirectional { 2 } else { 1 };
        self.hidden_size * directions
    }

    /// Builds the model described by this configuration on `device`.
    pub fn init<B: Backend>(&self, device: &B::Device) -> RNNAttention<B> {
        let config = self.clone();
        RNNAttention::new(config, device)
    }
}
/// Additive attention: scores each key against a single query through
/// `score_proj(tanh(query_proj(q) + key_proj(k)))`.
#[derive(Module, Debug)]
pub struct AdditiveAttention<B: Backend> {
    /// Maps the query from `hidden_dim` to `attention_dim`.
    query_proj: Linear<B>,
    /// Maps every key from `hidden_dim` to `attention_dim`.
    key_proj: Linear<B>,
    /// Collapses each `attention_dim` vector to one scalar score.
    score_proj: Linear<B>,
}
impl<B: Backend> AdditiveAttention<B> {
pub fn new(hidden_dim: usize, attention_dim: usize, device: &B::Device) -> Self {
let query_proj = LinearConfig::new(hidden_dim, attention_dim)
.with_bias(false)
.init(device);
let key_proj = LinearConfig::new(hidden_dim, attention_dim)
.with_bias(false)
.init(device);
let score_proj = LinearConfig::new(attention_dim, 1)
.with_bias(false)
.init(device);
Self {
query_proj,
key_proj,
score_proj,
}
}
pub fn forward(&self, query: Tensor<B, 2>, keys: Tensor<B, 3>) -> Tensor<B, 2> {
let [batch, seq_len, _] = keys.dims();
let query_proj = self.query_proj.forward(query);
let query_proj = query_proj.unsqueeze_dim(1);
let [batch_k, seq_len_k, hidden_k] = keys.dims();
let keys_flat = keys.reshape([batch_k * seq_len_k, hidden_k]);
let keys_proj = self.key_proj.forward(keys_flat);
let attention_dim = query_proj.dims()[2];
let keys_proj = keys_proj.reshape([batch, seq_len, attention_dim]);
let combined = query_proj + keys_proj;
let combined = combined.tanh();
let [batch_c, seq_len_c, att_dim] = combined.dims();
let combined_flat = combined.reshape([batch_c * seq_len_c, att_dim]);
let scores = self.score_proj.forward(combined_flat);
let scores = scores.reshape([batch, seq_len]);
softmax(scores, 1)
}
}
#[derive(Module, Debug)]
pub struct RNNAttention<B: Backend> {
lstm: Lstm<B>,
additive_attention: Option<AdditiveAttention<B>>,
query_proj: Option<Linear<B>>,
dropout: Dropout,
fc: Linear<B>,
#[module(skip)]
attention_mode: u8,
#[module(skip)]
hidden_dim: usize,
}
impl<B: Backend> RNNAttention<B> {
pub fn new(config: RNNAttentionConfig, device: &B::Device) -> Self {
let lstm = LstmConfig::new(config.n_vars, config.hidden_size, config.bidirectional)
.init(device);
let hidden_dim = config.output_dim();
let attention_mode = match config.attention_type {
AttentionType::Additive => 0,
AttentionType::DotProduct => 1,
AttentionType::ScaledDotProduct => 2,
};
let (additive_attention, query_proj) = match config.attention_type {
AttentionType::Additive => {
let attn = AdditiveAttention::new(hidden_dim, config.attention_dim, device);
(Some(attn), None)
}
AttentionType::DotProduct | AttentionType::ScaledDotProduct => {
let query = LinearConfig::new(hidden_dim, hidden_dim)
.with_bias(false)
.init(device);
(None, Some(query))
}
};
let dropout = DropoutConfig::new(config.dropout).init();
let fc = LinearConfig::new(hidden_dim, config.n_classes).init(device);
Self {
lstm,
additive_attention,
query_proj,
dropout,
fc,
attention_mode,
hidden_dim,
}
}
fn dot_product_attention(&self, outputs: Tensor<B, 3>, scale: bool) -> Tensor<B, 2> {
let [batch, seq_len, hidden_dim] = outputs.dims();
let query = outputs.clone().mean_dim(1);
let query = if let Some(ref proj) = self.query_proj {
proj.forward(query)
} else {
query
};
let query = query.unsqueeze_dim(1); let keys = outputs.swap_dims(1, 2); let scores = query.matmul(keys).reshape([batch, seq_len]);
let scores = if scale {
scores / (hidden_dim as f32).sqrt()
} else {
scores
};
softmax(scores, 1)
}
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
let [batch, _n_vars, _seq_len] = x.dims();
let x = x.swap_dims(1, 2);
let (outputs, _) = self.lstm.forward(x, None);
let [_, seq_len, hidden_dim] = outputs.dims();
let attention_weights = match self.attention_mode {
0 => {
let last_hidden =
outputs
.clone()
.slice([0..batch, (seq_len - 1)..seq_len, 0..hidden_dim]);
let last_hidden = last_hidden.reshape([batch, hidden_dim]);
self.additive_attention
.as_ref()
.unwrap()
.forward(last_hidden, outputs.clone())
}
1 => self.dot_product_attention(outputs.clone(), false),
_ => self.dot_product_attention(outputs.clone(), true),
};
let weights = attention_weights.unsqueeze_dim(2); let context = outputs * weights; let context = context.sum_dim(1).squeeze(1);
let context = self.dropout.forward(context);
self.fc.forward(context)
}
pub fn forward_probs(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
let logits = self.forward(x);
softmax(logits, 1)
}
pub fn get_attention_weights(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
let [batch, _n_vars, _seq_len] = x.dims();
let x = x.swap_dims(1, 2);
let (outputs, _) = self.lstm.forward(x, None);
let [_, seq_len, hidden_dim] = outputs.dims();
match self.attention_mode {
0 => {
let last_hidden =
outputs
.clone()
.slice([0..batch, (seq_len - 1)..seq_len, 0..hidden_dim]);
let last_hidden = last_hidden.reshape([batch, hidden_dim]);
self.additive_attention
.as_ref()
.unwrap()
.forward(last_hidden, outputs)
}
1 => self.dot_product_attention(outputs, false),
_ => self.dot_product_attention(outputs, true),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration is a 128-wide bidirectional encoder with
    /// scaled dot-product attention.
    #[test]
    fn test_rnn_attention_config_default() {
        let RNNAttentionConfig {
            hidden_size,
            attention_type,
            bidirectional,
            ..
        } = RNNAttentionConfig::default();

        assert_eq!(hidden_size, 128);
        assert_eq!(attention_type, AttentionType::ScaledDotProduct);
        assert!(bidirectional);
    }

    /// `new` plus the `with_*` builders set exactly the fields they name.
    #[test]
    fn test_rnn_attention_config_builder() {
        let base = RNNAttentionConfig::new(3, 200, 10);
        let config = base
            .with_hidden_size(256)
            .with_bidirectional(false)
            .with_attention_type(AttentionType::Additive);

        assert_eq!(
            (config.n_vars, config.seq_len, config.n_classes),
            (3, 200, 10)
        );
        assert_eq!(config.hidden_size, 256);
        assert!(!config.bidirectional);
        assert_eq!(config.attention_type, AttentionType::Additive);
    }

    /// `output_dim` doubles the hidden width only for bidirectional configs.
    #[test]
    fn test_output_dim() {
        for (bidirectional, expected) in [(true, 256), (false, 128)] {
            let config = RNNAttentionConfig {
                bidirectional,
                ..Default::default()
            };
            assert_eq!(config.output_dim(), expected);
        }
    }
}