use burn::nn::{
pool::{AdaptiveAvgPool1d, AdaptiveAvgPool1dConfig},
Dropout, DropoutConfig, LayerNorm, LayerNormConfig, Linear, LinearConfig,
Lstm, LstmConfig,
};
use burn::prelude::*;
use burn::tensor::activation::softmax;
use serde::{Deserialize, Serialize};
/// Hyper-parameters for [`TSSequencerPlus`], a classifier built from a stack of
/// LSTM token-mixing ("sequencer") blocks.
///
/// Create one with [`TSSequencerPlusConfig::new`] (or `Default`), refine it with the
/// `with_*` builder methods, then build the model with [`TSSequencerPlusConfig::init`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TSSequencerPlusConfig {
    /// Number of input variables (channels); input width of the model's input projection.
    pub n_vars: usize,
    /// Expected sequence length.
    /// NOTE(review): not read when building or running the model in the code visible here.
    pub seq_len: usize,
    /// Number of output classes; output width of the classification head.
    pub n_classes: usize,
    /// Width of the residual stream inside every block.
    pub d_model: usize,
    /// Number of stacked sequencer blocks.
    pub n_layers: usize,
    /// Hidden size of each LSTM direction.
    pub lstm_hidden: usize,
    /// Feed-forward expansion factor: the feed-forward hidden width is `d_model * ff_mult`.
    pub ff_mult: usize,
    /// Dropout probability used by every dropout site in each block.
    pub dropout: f64,
    /// When `true`, each block also runs a second LSTM over the time-reversed sequence.
    pub bidirectional: bool,
}
impl Default for TSSequencerPlusConfig {
    /// A univariate (1 variable, 100 steps), binary (2 classes) setup with a
    /// 4-layer, 128-wide bidirectional encoder.
    fn default() -> Self {
        // Data-shape fields; `TSSequencerPlusConfig::new` overrides exactly these three.
        let (n_vars, seq_len, n_classes) = (1, 100, 2);
        Self {
            d_model: 128,
            n_layers: 4,
            lstm_hidden: 64,
            ff_mult: 4,
            dropout: 0.1,
            bidirectional: true,
            n_vars,
            seq_len,
            n_classes,
        }
    }
}
impl TSSequencerPlusConfig {
    /// Configuration for data with `n_vars` variables, `seq_len` time steps and
    /// `n_classes` classes; every architecture hyper-parameter keeps its `Default` value.
    pub fn new(n_vars: usize, seq_len: usize, n_classes: usize) -> Self {
        let base = Self::default();
        Self {
            n_vars,
            seq_len,
            n_classes,
            ..base
        }
    }

    /// Returns the configuration with the residual-stream width replaced.
    #[must_use]
    pub fn with_d_model(self, d_model: usize) -> Self {
        Self { d_model, ..self }
    }

    /// Returns the configuration with the number of sequencer blocks replaced.
    #[must_use]
    pub fn with_n_layers(self, n_layers: usize) -> Self {
        Self { n_layers, ..self }
    }

    /// Returns the configuration with the per-direction LSTM hidden size replaced.
    #[must_use]
    pub fn with_lstm_hidden(self, lstm_hidden: usize) -> Self {
        Self { lstm_hidden, ..self }
    }

    /// Returns the configuration with the feed-forward expansion factor replaced.
    #[must_use]
    pub fn with_ff_mult(self, ff_mult: usize) -> Self {
        Self { ff_mult, ..self }
    }

    /// Returns the configuration with the dropout probability replaced.
    #[must_use]
    pub fn with_dropout(self, dropout: f64) -> Self {
        Self { dropout, ..self }
    }

    /// Returns the configuration with the bidirectional-LSTM switch replaced.
    #[must_use]
    pub fn with_bidirectional(self, bidirectional: bool) -> Self {
        Self {
            bidirectional,
            ..self
        }
    }

    /// Builds a [`TSSequencerPlus`] from a copy of this configuration, placing its
    /// parameters on `device`.
    pub fn init<B: Backend>(&self, device: &B::Device) -> TSSequencerPlus<B> {
        let owned = self.clone();
        TSSequencerPlus::new(owned, device)
    }
}
/// One pre-norm residual block: an LSTM token-mixing sublayer followed by a GELU
/// feed-forward sublayer, each applied as `x + dropout(sublayer(norm(x)))`.
#[derive(Module, Debug)]
struct SequencerBlock<B: Backend> {
    /// Normalization applied before the LSTM sublayer.
    norm1: LayerNorm<B>,
    /// LSTM run over the sequence in its original order.
    lstm_fwd: Lstm<B>,
    /// LSTM run over the time-reversed sequence; `Some` only when built bidirectional.
    lstm_bwd: Option<Lstm<B>>,
    /// Projects the LSTM output (concatenated across directions when bidirectional)
    /// back to `d_model`.
    lstm_proj: Linear<B>,
    /// Normalization applied before the feed-forward sublayer.
    norm2: LayerNorm<B>,
    /// Feed-forward expansion: `d_model -> d_model * ff_mult`.
    ff_linear1: Linear<B>,
    /// Feed-forward contraction: `d_model * ff_mult -> d_model`.
    ff_linear2: Linear<B>,
    /// Single dropout module reused on both residual branches and inside the feed-forward.
    dropout: Dropout,
    /// Mirrors whether `lstm_bwd` is `Some`. NOTE(review): not read by `forward`.
    #[module(skip)]
    bidirectional: bool,
    /// Per-direction LSTM hidden size. NOTE(review): not read by `forward`.
    #[module(skip)]
    lstm_hidden: usize,
}
impl<B: Backend> SequencerBlock<B> {
    /// Builds a block for a residual stream of width `d_model`.
    ///
    /// * `lstm_hidden` – hidden size of each LSTM direction.
    /// * `ff_mult` – feed-forward expansion factor (hidden width `d_model * ff_mult`).
    /// * `dropout` – probability shared by every dropout site in the block.
    /// * `bidirectional` – also allocate a backward LSTM; the LSTM projection then
    ///   consumes `2 * lstm_hidden` features instead of `lstm_hidden`.
    ///
    /// Sub-modules are initialized in declaration order so seeded runs reproduce the
    /// same parameters.
    fn new(
        d_model: usize,
        lstm_hidden: usize,
        ff_mult: usize,
        dropout: f64,
        bidirectional: bool,
        device: &B::Device,
    ) -> Self {
        let norm1 = LayerNormConfig::new(d_model).init(device);
        // Both directions have identical shapes; only the order of their input differs.
        let lstm_config = LstmConfig::new(d_model, lstm_hidden, true);
        let lstm_fwd = lstm_config.init(device);
        let lstm_bwd = bidirectional.then(|| lstm_config.init(device));
        let directions = if bidirectional { 2 } else { 1 };
        let lstm_proj = LinearConfig::new(lstm_hidden * directions, d_model).init(device);
        let norm2 = LayerNormConfig::new(d_model).init(device);
        let d_ff = d_model * ff_mult;
        let ff_linear1 = LinearConfig::new(d_model, d_ff).init(device);
        let ff_linear2 = LinearConfig::new(d_ff, d_model).init(device);
        Self {
            norm1,
            lstm_fwd,
            lstm_bwd,
            lstm_proj,
            norm2,
            ff_linear1,
            ff_linear2,
            dropout: DropoutConfig::new(dropout).init(),
            bidirectional,
            lstm_hidden,
        }
    }

    /// Applies the block and returns a tensor of the same shape as `x`.
    ///
    /// Presumably `x` is `[batch, seq_len, d_model]` (burn's batch-first LSTM layout);
    /// TODO confirm against callers.
    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // Token-mixing sublayer (pre-norm residual).
        let residual = x.clone();
        let mixed = self.mix_tokens(self.norm1.forward(x));
        let x = residual + self.dropout.forward(mixed);

        // Feed-forward sublayer (pre-norm residual).
        let residual = x.clone();
        let ff_out = self.feed_forward(self.norm2.forward(x));
        residual + self.dropout.forward(ff_out)
    }

    /// LSTM token-mixing sublayer.
    ///
    /// Runs the forward LSTM over `normed`. If a backward LSTM is present, it runs over
    /// the time-reversed input, and its output is flipped back so step `t` of both
    /// directions lines up. The outputs are concatenated along the feature axis and
    /// projected back to `d_model`.
    fn mix_tokens(&self, normed: Tensor<B, 3>) -> Tensor<B, 3> {
        let (fwd_out, _) = self.lstm_fwd.forward(normed.clone(), None);
        let lstm_out = match &self.lstm_bwd {
            Some(lstm_bwd) => {
                let (bwd_out, _) = lstm_bwd.forward(normed.flip([1]), None);
                Tensor::cat(vec![fwd_out, bwd_out.flip([1])], 2)
            }
            None => fwd_out,
        };
        self.lstm_proj.forward(lstm_out)
    }

    /// Position-wise feed-forward sublayer: `linear2(dropout(gelu(linear1(normed))))`.
    fn feed_forward(&self, normed: Tensor<B, 3>) -> Tensor<B, 3> {
        let hidden = burn::tensor::activation::gelu(self.ff_linear1.forward(normed));
        self.ff_linear2.forward(self.dropout.forward(hidden))
    }
}
/// Sequence classifier: input projection, a stack of [`SequencerBlock`]s, a final
/// layer norm, global average pooling over time, and a linear classification head.
///
/// Build it with [`TSSequencerPlusConfig::init`] or [`TSSequencerPlus::new`].
#[derive(Module, Debug)]
pub struct TSSequencerPlus<B: Backend> {
    /// Maps each time step's `n_vars` features to `d_model`.
    input_proj: Linear<B>,
    /// Residual sequencer blocks, applied in order.
    blocks: Vec<SequencerBlock<B>>,
    /// Layer norm applied after the last block.
    final_norm: LayerNorm<B>,
    /// Adaptive average pool with output length 1, i.e. a global average over time.
    gap: AdaptiveAvgPool1d,
    /// Maps the pooled `d_model` features to `n_classes` logits.
    head: Linear<B>,
    /// Residual-stream width; used to flatten the pooled output in `forward`.
    #[module(skip)]
    d_model: usize,
}
impl<B: Backend> TSSequencerPlus<B> {
pub fn new(config: TSSequencerPlusConfig, device: &B::Device) -> Self {
let input_proj = LinearConfig::new(config.n_vars, config.d_model).init(device);
let blocks: Vec<_> = (0..config.n_layers)
.map(|_| {
SequencerBlock::new(
config.d_model,
config.lstm_hidden,
config.ff_mult,
config.dropout,
config.bidirectional,
device,
)
})
.collect();
let final_norm = LayerNormConfig::new(config.d_model).init(device);
let gap = AdaptiveAvgPool1dConfig::new(1).init();
let head = LinearConfig::new(config.d_model, config.n_classes).init(device);
Self {
input_proj,
blocks,
final_norm,
gap,
head,
d_model: config.d_model,
}
}
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
let [batch_size, _n_vars, _seq_len] = x.dims();
let out = x.swap_dims(1, 2);
let out = self.input_proj.forward(out);
let mut out = out;
for block in &self.blocks {
out = block.forward(out);
}
let out = self.final_norm.forward(out);
let out = out.swap_dims(1, 2);
let out = self.gap.forward(out);
let out = out.reshape([batch_size, self.d_model]);
self.head.forward(out)
}
pub fn forward_probs(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
let logits = self.forward(x);
softmax(logits, 1)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` yields the hard-coded architecture sizes.
    #[test]
    fn test_ts_sequencer_config_default() {
        let TSSequencerPlusConfig {
            d_model,
            n_layers,
            lstm_hidden,
            bidirectional,
            ..
        } = TSSequencerPlusConfig::default();
        assert_eq!((d_model, n_layers, lstm_hidden), (128, 4, 64));
        assert!(bidirectional);
    }

    /// `new` stores the three data-shape arguments verbatim.
    #[test]
    fn test_ts_sequencer_config_new() {
        let config = TSSequencerPlusConfig::new(3, 200, 10);
        assert_eq!(
            [config.n_vars, config.seq_len, config.n_classes],
            [3, 200, 10]
        );
    }

    /// Every `with_*` builder overwrites exactly its own field.
    #[test]
    fn test_ts_sequencer_config_builder() {
        let built = TSSequencerPlusConfig::new(3, 100, 5)
            .with_d_model(256)
            .with_n_layers(6)
            .with_lstm_hidden(128)
            .with_ff_mult(2)
            .with_dropout(0.2)
            .with_bidirectional(false);
        let sizes = [built.d_model, built.n_layers, built.lstm_hidden, built.ff_mult];
        assert_eq!(sizes, [256, 6, 128, 2]);
        assert_eq!(built.dropout, 0.2);
        assert!(!built.bidirectional);
    }
}