use burn::module::Param;
use burn::nn::{
    attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
    Dropout, DropoutConfig, Initializer, LayerNorm, LayerNormConfig, Linear, LinearConfig,
    PositionalEncoding, PositionalEncodingConfig, Relu,
};
use burn::prelude::*;
use burn::tensor::activation::softmax;
use serde::{Deserialize, Serialize};
/// Hyper-parameters for [`PatchTST`].
///
/// [`PatchTSTConfig::for_classification`] and [`PatchTSTConfig::for_forecasting`]
/// override the task-specific fields on top of [`Default`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PatchTSTConfig {
    /// Number of input variables; the head's input width is `d_model * n_patches * n_vars`.
    pub n_vars: usize,
    /// Input length along the time axis; used by [`PatchTSTConfig::n_patches`].
    pub seq_len: usize,
    /// Width of the output head: `n_classes` for classification, `horizon` for forecasting.
    pub n_outputs: usize,
    /// Length of one patch; also the input width of the patch embedding.
    pub patch_len: usize,
    /// Step between the starts of consecutive patches.
    pub stride: usize,
    /// Token width used by the patch embedding and the encoder layers.
    pub d_model: usize,
    /// Number of attention heads per encoder layer.
    pub n_heads: usize,
    /// Number of encoder layers.
    pub n_layers: usize,
    /// Hidden width of each encoder layer's feed-forward network.
    pub d_ff: usize,
    /// Dropout probability.
    pub dropout: f64,
    /// Positional-embedding flag; see [`PatchTST::new`] for how (or whether) it is consumed.
    pub learnable_pe: bool,
    /// `true` when built by `for_classification`, `false` when built by `for_forecasting`.
    pub is_classification: bool,
}
impl Default for PatchTSTConfig {
    /// Forecasting defaults: one variable, 512-step input, 96 outputs,
    /// 16-step patches with stride 8, and a 3-layer / 8-head / 128-wide encoder.
    fn default() -> Self {
        let (patch_len, stride) = (16, 8);
        let (d_model, n_heads, n_layers, d_ff) = (128, 8, 3, 256);
        Self {
            n_vars: 1,
            seq_len: 512,
            n_outputs: 96,
            patch_len,
            stride,
            d_model,
            n_heads,
            n_layers,
            d_ff,
            dropout: 0.1,
            learnable_pe: true,
            is_classification: false,
        }
    }
}
impl PatchTSTConfig {
    /// Classification preset: `n_classes` outputs for inputs with `n_vars`
    /// variables of length `seq_len`; every other field comes from [`Default`].
    pub fn for_classification(n_vars: usize, seq_len: usize, n_classes: usize) -> Self {
        Self {
            n_vars,
            seq_len,
            n_outputs: n_classes,
            is_classification: true,
            ..Default::default()
        }
    }

    /// Forecasting preset: `horizon` outputs for inputs with `n_vars`
    /// variables of length `seq_len`; every other field comes from [`Default`].
    pub fn for_forecasting(n_vars: usize, seq_len: usize, horizon: usize) -> Self {
        Self {
            n_vars,
            seq_len,
            n_outputs: horizon,
            is_classification: false,
            ..Default::default()
        }
    }

    /// Number of windows of length `patch_len` that fit in `seq_len` steps when
    /// advancing by `stride`: `(seq_len - patch_len) / stride + 1`.
    ///
    /// Trailing steps that do not fill a whole window are not counted.
    ///
    /// # Panics
    /// If `patch_len` or `stride` is zero, or if `seq_len < patch_len`. Previously
    /// the subtraction could underflow (silently wrapping in release builds) and a
    /// zero stride produced an opaque divide-by-zero.
    pub fn n_patches(&self) -> usize {
        assert!(self.patch_len > 0, "PatchTSTConfig: patch_len must be > 0");
        assert!(self.stride > 0, "PatchTSTConfig: stride must be > 0");
        let Some(span) = self.seq_len.checked_sub(self.patch_len) else {
            panic!(
                "PatchTSTConfig: seq_len ({}) must be >= patch_len ({})",
                self.seq_len, self.patch_len
            );
        };
        span / self.stride + 1
    }

    /// Builds a [`PatchTST`] on `device` from a copy of this config.
    pub fn init<B: Backend>(&self, device: &B::Device) -> PatchTST<B> {
        PatchTST::new(self.clone(), device)
    }
}
/// One post-norm Transformer encoder block: multi-head self-attention followed by
/// a two-layer ReLU feed-forward network, each sub-layer wrapped as
/// `LayerNorm(x + dropout(sublayer(x)))`.
#[derive(Module, Debug)]
struct TransformerEncoderLayer<B: Backend> {
    /// Multi-head self-attention (`d_model` wide, `n_heads` heads).
    attention: MultiHeadAttention<B>,
    /// LayerNorm applied after the attention residual.
    norm1: LayerNorm<B>,
    /// Feed-forward expansion `d_model -> d_ff`.
    ff_linear1: Linear<B>,
    /// Feed-forward projection `d_ff -> d_model`.
    ff_linear2: Linear<B>,
    /// LayerNorm applied after the feed-forward residual.
    norm2: LayerNorm<B>,
    /// Dropout shared by both residual branches and the feed-forward hidden layer.
    dropout: Dropout,
}
impl<B: Backend> TransformerEncoderLayer<B> {
    /// Builds one encoder block.
    ///
    /// The block works on `d_model`-wide tokens and uses `n_heads` attention heads
    /// and a `d_ff`-wide feed-forward hidden layer. The dropout probability
    /// `dropout` is used both inside attention and on the sub-layer outputs.
    fn new(d_model: usize, n_heads: usize, d_ff: usize, dropout: f64, device: &B::Device) -> Self {
        // Struct-literal initialisers are evaluated top to bottom, so parameters are
        // created in a fixed order (relevant for reproducible random initialisation).
        Self {
            attention: MultiHeadAttentionConfig::new(d_model, n_heads)
                .with_dropout(dropout)
                .init(device),
            norm1: LayerNormConfig::new(d_model).init(device),
            ff_linear1: LinearConfig::new(d_model, d_ff).init(device),
            ff_linear2: LinearConfig::new(d_ff, d_model).init(device),
            norm2: LayerNormConfig::new(d_model).init(device),
            dropout: DropoutConfig::new(dropout).init(),
        }
    }

    /// Applies attention then feed-forward, each as a post-norm residual.
    ///
    /// `x` is `[batch, seq_len, d_model]`, the layout burn's `MhaInput::self_attn`
    /// expects; the output has the same shape.
    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // Attention sub-layer.
        let attended = self.attention.forward(MhaInput::self_attn(x.clone())).context;
        let after_attn = self.norm1.forward(x + self.dropout.forward(attended));

        // Feed-forward sub-layer: expand, ReLU, dropout, project back.
        let hidden = Relu::new().forward(self.ff_linear1.forward(after_attn.clone()));
        let projected = self.ff_linear2.forward(self.dropout.forward(hidden));
        self.norm2.forward(after_attn + self.dropout.forward(projected))
    }
}
#[derive(Module, Debug)]
pub struct PatchTST<B: Backend> {
patch_embed: Linear<B>,
encoder_layers: Vec<TransformerEncoderLayer<B>>,
head: Linear<B>,
dropout: Dropout,
}
impl<B: Backend> PatchTST<B> {
pub fn new(config: PatchTSTConfig, device: &B::Device) -> Self {
let n_patches = config.n_patches();
let patch_embed = LinearConfig::new(config.patch_len, config.d_model).init(device);
let encoder_layers: Vec<_> = (0..config.n_layers)
.map(|_| {
TransformerEncoderLayer::new(
config.d_model,
config.n_heads,
config.d_ff,
config.dropout,
device,
)
})
.collect();
let head_in = config.d_model * n_patches * config.n_vars;
let head = LinearConfig::new(head_in, config.n_outputs).init(device);
let dropout = DropoutConfig::new(config.dropout).init();
Self {
patch_embed,
encoder_layers,
head,
dropout,
}
}
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
let [batch, n_vars, _seq_len] = x.dims();
let embedded = self.patch_embed.forward(x.clone());
let [_, _, _d_model] = embedded.dims();
let mut out = embedded;
for layer in &self.encoder_layers {
out = layer.forward(out);
}
let [_, out_seq, out_dim] = out.dims();
let out = out.reshape([batch, n_vars * out_seq * out_dim]);
self.head.forward(out)
}
pub fn forward_probs(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
let logits = self.forward(x);
softmax(logits, 1)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_patchtst_config() {
        let config = PatchTSTConfig::default();
        assert_eq!(config.patch_len, 16);
        assert_eq!(config.stride, 8);
        // (512 - 16) / 8 + 1
        assert_eq!(config.n_patches(), 63);
    }

    #[test]
    fn test_classification_config() {
        let config = PatchTSTConfig::for_classification(3, 100, 5);
        assert!(config.is_classification);
        assert_eq!(config.n_outputs, 5);
        assert_eq!(config.n_vars, 3);
        // (100 - 16) / 8 + 1 = 11; the last 4 steps do not fill a patch.
        assert_eq!(config.n_patches(), 11);
    }

    #[test]
    fn test_forecasting_config() {
        let config = PatchTSTConfig::for_forecasting(7, 336, 96);
        assert!(!config.is_classification);
        assert_eq!(config.n_outputs, 96);
        assert_eq!(config.n_vars, 7);
        assert_eq!(config.seq_len, 336);
        // (336 - 16) / 8 + 1
        assert_eq!(config.n_patches(), 41);
    }

    #[test]
    #[should_panic]
    fn test_zero_stride_panics() {
        let config = PatchTSTConfig {
            stride: 0,
            ..Default::default()
        };
        let _ = config.n_patches();
    }
}