use burn::nn::{
attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
Dropout, DropoutConfig, Gelu, LayerNorm, LayerNormConfig, Linear, LinearConfig,
};
use burn::prelude::*;
use burn::tensor::activation::softmax;
use serde::{Deserialize, Serialize};
/// Hyper-parameters used to build a [`TSPerceiver`].
///
/// Construct with [`TSPerceiverConfig::new`] (or [`Default`]) and refine with
/// the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TSPerceiverConfig {
/// Number of input variables; used as the input width of the model's input projection.
pub n_vars: usize,
/// Length of the input series. Stored on the config only; the model code
/// visible in this file never reads it.
pub seq_len: usize,
/// Output width of the final linear head.
pub n_classes: usize,
/// Feature width of the latent array and of every attention / feed-forward block.
pub d_latent: usize,
/// Number of latent vectors per sample.
pub n_latents: usize,
/// Number of heads passed to each multi-head attention layer.
pub n_heads: usize,
/// Number of cross-attention stages executed in the forward pass.
pub n_cross_layers: usize,
/// Number of self-attention blocks run after each cross-attention stage.
pub n_self_layers: usize,
/// Feed-forward hidden width expressed as a multiple of `d_latent`.
pub ff_mult: usize,
/// Dropout probability handed to the attention layers and to every dropout layer.
pub dropout: f64,
/// When `true`, a single cross-attention block is built and reused for every
/// cross-attention stage; otherwise one block is built per stage.
pub share_weights: bool,
}
impl Default for TSPerceiverConfig {
    /// Returns the baseline configuration: a univariate series of length 100,
    /// two classes, 32 latents of width 128, 8 heads, 2 cross-attention stages
    /// with 4 self-attention blocks, a 4x feed-forward expansion, 10% dropout
    /// and a shared cross-attention block.
    fn default() -> Self {
        // Data shape.
        let (n_vars, seq_len, n_classes) = (1, 100, 2);
        // Latent array geometry.
        let (d_latent, n_latents, n_heads) = (128, 32, 8);
        // Depth of the stack.
        let (n_cross_layers, n_self_layers) = (2, 4);

        Self {
            n_vars,
            seq_len,
            n_classes,
            d_latent,
            n_latents,
            n_heads,
            n_cross_layers,
            n_self_layers,
            ff_mult: 4,
            dropout: 0.1,
            share_weights: true,
        }
    }
}
impl TSPerceiverConfig {
    /// Creates a configuration for the given data shape, taking every other
    /// setting from [`Default`].
    pub fn new(n_vars: usize, seq_len: usize, n_classes: usize) -> Self {
        Self {
            n_vars,
            seq_len,
            n_classes,
            ..Default::default()
        }
    }

    /// Sets the feature width of the latent array.
    #[must_use]
    pub fn with_d_latent(mut self, d_latent: usize) -> Self {
        self.d_latent = d_latent;
        self
    }

    /// Sets the number of latent vectors per sample.
    #[must_use]
    pub fn with_n_latents(mut self, n_latents: usize) -> Self {
        self.n_latents = n_latents;
        self
    }

    /// Sets the number of attention heads.
    #[must_use]
    pub fn with_n_heads(mut self, n_heads: usize) -> Self {
        self.n_heads = n_heads;
        self
    }

    /// Sets the number of cross-attention stages.
    #[must_use]
    pub fn with_n_cross_layers(mut self, n_cross_layers: usize) -> Self {
        self.n_cross_layers = n_cross_layers;
        self
    }

    /// Sets the number of self-attention blocks run after each cross-attention stage.
    #[must_use]
    pub fn with_n_self_layers(mut self, n_self_layers: usize) -> Self {
        self.n_self_layers = n_self_layers;
        self
    }

    /// Sets the feed-forward expansion factor (hidden width = `d_latent * ff_mult`).
    ///
    /// Provided so that every tunable field has a matching builder method.
    #[must_use]
    pub fn with_ff_mult(mut self, ff_mult: usize) -> Self {
        self.ff_mult = ff_mult;
        self
    }

    /// Sets the dropout probability.
    #[must_use]
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }

    /// Chooses whether one cross-attention block is reused for every stage.
    #[must_use]
    pub fn with_share_weights(mut self, share_weights: bool) -> Self {
        self.share_weights = share_weights;
        self
    }

    /// Builds a [`TSPerceiver`] from a copy of this configuration on `device`.
    pub fn init<B: Backend>(&self, device: &B::Device) -> TSPerceiver<B> {
        TSPerceiver::new(self.clone(), device)
    }
}
/// Block that updates a latent array by attending to an input sequence and
/// then applying a feed-forward sub-layer, each wrapped in a residual connection.
#[derive(Module, Debug)]
struct CrossAttentionBlock<B: Backend> {
/// Layer norm applied to the latents before they are used as attention queries.
latent_norm: LayerNorm<B>,
/// Layer norm applied to the input sequence ahead of attention.
input_norm: LayerNorm<B>,
/// Multi-head attention with the latents on the query side.
attention: MultiHeadAttention<B>,
/// Layer norm applied before the feed-forward sub-layer.
ff_norm: LayerNorm<B>,
/// First feed-forward projection (`d_latent -> d_ff`).
ff_linear1: Linear<B>,
/// Second feed-forward projection (`d_ff -> d_latent`).
ff_linear2: Linear<B>,
/// Dropout reused on the attention output, after the activation and on the
/// feed-forward output.
dropout: Dropout,
}
impl<B: Backend> CrossAttentionBlock<B> {
    /// Builds the block's layers on `device`.
    ///
    /// * `d_latent` - feature width of the latents and of the attention layer.
    /// * `d_input` - feature width of the sequence handed to [`Self::forward`];
    ///   `input_norm` is sized for it. The sequence is fed to the attention layer
    ///   as keys and values, so this is expected to equal `d_latent` (the caller
    ///   in this file passes `d_latent` for both).
    /// * `n_heads` - number of attention heads.
    /// * `d_ff` - hidden width of the feed-forward sub-layer.
    /// * `dropout` - probability used for attention dropout and the dropout layer.
    fn new(
        d_latent: usize,
        d_input: usize,
        n_heads: usize,
        d_ff: usize,
        dropout: f64,
        device: &B::Device,
    ) -> Self {
        let latent_norm = LayerNormConfig::new(d_latent).init(device);
        let input_norm = LayerNormConfig::new(d_input).init(device);
        let attention = MultiHeadAttentionConfig::new(d_latent, n_heads)
            .with_dropout(dropout)
            .init(device);
        let ff_norm = LayerNormConfig::new(d_latent).init(device);
        let ff_linear1 = LinearConfig::new(d_latent, d_ff).init(device);
        let ff_linear2 = LinearConfig::new(d_ff, d_latent).init(device);
        let dropout_layer = DropoutConfig::new(dropout).init();
        Self {
            latent_norm,
            input_norm,
            attention,
            ff_norm,
            ff_linear1,
            ff_linear2,
            dropout: dropout_layer,
        }
    }

    /// Runs pre-norm cross-attention (latents attend to `input`) followed by a
    /// pre-norm feed-forward sub-layer, each added back through a residual.
    ///
    /// `input` must already have the feature width this block was built for
    /// (`d_input`). `TSPerceiver::forward` projects the raw series with its
    /// input projection before calling this method, so the projection is NOT
    /// applied again here: that layer maps `n_vars -> d_latent`, and feeding it
    /// an already-projected `d_latent`-wide sequence is a feature-width mismatch
    /// unless `n_vars == d_latent`.
    ///
    /// `_input_proj` is kept only so existing call sites continue to compile;
    /// it is not used.
    ///
    /// Returns the updated latents with the same shape as `latent`.
    fn forward(
        &self,
        latent: Tensor<B, 3>,
        input: Tensor<B, 3>,
        _input_proj: &Linear<B>,
    ) -> Tensor<B, 3> {
        let latent_normed = self.latent_norm.forward(latent.clone());
        let input_normed = self.input_norm.forward(input);

        // Latents are the queries; the normalised input supplies keys and values.
        let attn_input = MhaInput::new(latent_normed, input_normed.clone(), input_normed);
        let attn_out = self.attention.forward(attn_input).context;
        let latent = latent + self.dropout.forward(attn_out);

        // Feed-forward sub-layer with its own residual connection.
        let normed = self.ff_norm.forward(latent.clone());
        let ff_out = self.ff_linear1.forward(normed);
        let ff_out = Gelu::new().forward(ff_out);
        let ff_out = self.dropout.forward(ff_out);
        let ff_out = self.ff_linear2.forward(ff_out);
        latent + self.dropout.forward(ff_out)
    }
}
/// Block that applies self-attention and a feed-forward sub-layer to a
/// sequence, each wrapped in a residual connection.
#[derive(Module, Debug)]
struct SelfAttentionBlock<B: Backend> {
/// Layer norm applied before self-attention.
attn_norm: LayerNorm<B>,
/// Multi-head attention used with query, key and value taken from the same tensor.
attention: MultiHeadAttention<B>,
/// Layer norm applied before the feed-forward sub-layer.
ff_norm: LayerNorm<B>,
/// First feed-forward projection (`d_latent -> d_ff`).
ff_linear1: Linear<B>,
/// Second feed-forward projection (`d_ff -> d_latent`).
ff_linear2: Linear<B>,
/// Dropout reused on the attention output, after the activation and on the
/// feed-forward output.
dropout: Dropout,
}
impl<B: Backend> SelfAttentionBlock<B> {
    /// Builds the block's layers on `device`.
    ///
    /// * `d_latent` - feature width of the sequence and of the attention layer.
    /// * `n_heads` - number of attention heads.
    /// * `d_ff` - hidden width of the feed-forward sub-layer.
    /// * `dropout` - probability used for attention dropout and the dropout layer.
    fn new(d_latent: usize, n_heads: usize, d_ff: usize, dropout: f64, device: &B::Device) -> Self {
        // Field initialisers run top to bottom, matching the layer order of the block.
        Self {
            attn_norm: LayerNormConfig::new(d_latent).init(device),
            attention: MultiHeadAttentionConfig::new(d_latent, n_heads)
                .with_dropout(dropout)
                .init(device),
            ff_norm: LayerNormConfig::new(d_latent).init(device),
            ff_linear1: LinearConfig::new(d_latent, d_ff).init(device),
            ff_linear2: LinearConfig::new(d_ff, d_latent).init(device),
            dropout: DropoutConfig::new(dropout).init(),
        }
    }

    /// Applies pre-norm self-attention and then a pre-norm feed-forward
    /// sub-layer to `x`, adding each result back through a residual connection.
    ///
    /// Returns a tensor with the same shape as `x`.
    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // Attention sub-layer.
        let attended = self
            .attention
            .forward(MhaInput::self_attn(self.attn_norm.forward(x.clone())))
            .context;
        let residual = x + self.dropout.forward(attended);

        // Feed-forward sub-layer: expand, activate, drop, project back.
        let hidden = self.ff_linear1.forward(self.ff_norm.forward(residual.clone()));
        let hidden = self.dropout.forward(Gelu::new().forward(hidden));
        let projected = self.ff_linear2.forward(hidden);

        residual + self.dropout.forward(projected)
    }
}
/// Perceiver-style classifier for multivariate time series: a latent array
/// repeatedly cross-attends to the projected input, is refined by
/// self-attention blocks, and is finally pooled and fed to a linear head.
#[derive(Module, Debug)]
pub struct TSPerceiver<B: Backend> {
/// Linear projection from `n_vars` to `d_latent`, applied to the input series.
input_proj: Linear<B>,
/// Cross-attention blocks: a single block when weights are shared, otherwise
/// one per cross-attention stage.
cross_attn_blocks: Vec<CrossAttentionBlock<B>>,
/// Self-attention blocks, run in order after every cross-attention stage.
self_attn_blocks: Vec<SelfAttentionBlock<B>>,
/// Layer norm applied to the latents before pooling.
final_norm: LayerNorm<B>,
/// Linear head mapping the pooled latent to `n_classes` outputs.
head: Linear<B>,
/// Dropout applied to the pooled latent before the head.
head_dropout: Dropout,
/// Number of cross-attention stages run in the forward pass.
n_cross_layers: usize,
/// Copied from the configuration; not read by the methods visible in this file.
n_self_layers: usize,
/// Feature width of the latent array.
d_latent: usize,
/// Number of latent vectors per sample.
n_latents: usize,
}
impl<B: Backend> TSPerceiver<B> {
pub fn new(config: TSPerceiverConfig, device: &B::Device) -> Self {
let d_ff = config.d_latent * config.ff_mult;
let input_proj = LinearConfig::new(config.n_vars, config.d_latent).init(device);
let n_cross = if config.share_weights { 1 } else { config.n_cross_layers };
let cross_attn_blocks: Vec<_> = (0..n_cross)
.map(|_| {
CrossAttentionBlock::new(
config.d_latent,
config.d_latent,
config.n_heads,
d_ff,
config.dropout,
device,
)
})
.collect();
let self_attn_blocks: Vec<_> = (0..config.n_self_layers)
.map(|_| {
SelfAttentionBlock::new(
config.d_latent,
config.n_heads,
d_ff,
config.dropout,
device,
)
})
.collect();
let final_norm = LayerNormConfig::new(config.d_latent).init(device);
let head = LinearConfig::new(config.d_latent, config.n_classes).init(device);
let head_dropout = DropoutConfig::new(config.dropout).init();
Self {
input_proj,
cross_attn_blocks,
self_attn_blocks,
final_norm,
head,
head_dropout,
n_cross_layers: config.n_cross_layers,
n_self_layers: config.n_self_layers,
d_latent: config.d_latent,
n_latents: config.n_latents,
}
}
fn init_latents(&self, batch_size: usize, device: &B::Device) -> Tensor<B, 3> {
Tensor::random(
[batch_size, self.n_latents, self.d_latent],
burn::tensor::Distribution::Normal(0.0, 0.02),
device,
)
}
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
let [batch, _n_vars, _seq_len] = x.dims();
let device = x.device();
let x = x.swap_dims(1, 2);
let input = self.input_proj.forward(x);
let mut latent = self.init_latents(batch, &device);
for i in 0..self.n_cross_layers {
let cross_idx = if self.cross_attn_blocks.len() == 1 { 0 } else { i };
latent = self.cross_attn_blocks[cross_idx].forward(
latent,
input.clone(),
&self.input_proj,
);
for self_attn in &self.self_attn_blocks {
latent = self_attn.forward(latent);
}
}
let latent = self.final_norm.forward(latent);
let pooled = latent.mean_dim(1).reshape([batch, self.d_latent]);
let out = self.head_dropout.forward(pooled);
self.head.forward(out)
}
pub fn forward_probs(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
let logits = self.forward(x);
softmax(logits, 1)
}
pub fn n_latents(&self) -> usize {
self.n_latents
}
pub fn d_latent(&self) -> usize {
self.d_latent
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration exposes the documented baseline values.
    #[test]
    fn test_perceiver_config_default() {
        let TSPerceiverConfig {
            d_latent,
            n_latents,
            n_heads,
            n_cross_layers,
            n_self_layers,
            share_weights,
            ..
        } = TSPerceiverConfig::default();

        assert_eq!(d_latent, 128);
        assert_eq!(n_latents, 32);
        assert_eq!(n_heads, 8);
        assert_eq!(n_cross_layers, 2);
        assert_eq!(n_self_layers, 4);
        assert!(share_weights);
    }

    /// `new` stores the data-shape arguments it is given.
    #[test]
    fn test_perceiver_config_new() {
        let TSPerceiverConfig {
            n_vars,
            seq_len,
            n_classes,
            ..
        } = TSPerceiverConfig::new(3, 200, 10);

        assert_eq!(n_vars, 3);
        assert_eq!(seq_len, 200);
        assert_eq!(n_classes, 10);
    }

    /// Each builder method overrides exactly the field it is named after.
    #[test]
    fn test_perceiver_config_builder() {
        let built = TSPerceiverConfig::new(3, 100, 5)
            .with_d_latent(64)
            .with_n_latents(16)
            .with_n_heads(4)
            .with_n_cross_layers(3)
            .with_n_self_layers(2)
            .with_dropout(0.2)
            .with_share_weights(false);

        let TSPerceiverConfig {
            d_latent,
            n_latents,
            n_heads,
            n_cross_layers,
            n_self_layers,
            dropout,
            share_weights,
            ..
        } = built;

        assert_eq!(d_latent, 64);
        assert_eq!(n_latents, 16);
        assert_eq!(n_heads, 4);
        assert_eq!(n_cross_layers, 3);
        assert_eq!(n_self_layers, 2);
        assert_eq!(dropout, 0.2);
        assert!(!share_weights);
    }
}