use burn::module::Param;
use burn::nn::{
attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
Dropout, DropoutConfig, Linear, LinearConfig, LayerNorm, LayerNormConfig,
};
use burn::prelude::*;
use burn::tensor::activation::{gelu, softmax};
use serde::{Deserialize, Serialize};
/// Hyper-parameters used to build a [`TransformerModel`].
///
/// Construct with [`TransformerModelConfig::new`] and refine with the `with_*` builders.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformerModelConfig {
    /// Number of input variables (channels); model inputs are `[batch, n_vars, seq_len]`.
    pub n_vars: usize,
    /// Configured sequence length. Sizes the learnable positional table and, for
    /// `AggregationType::Flatten`, the input width of the output head.
    pub seq_len: usize,
    /// Output width of the final linear head.
    pub n_outputs: usize,
    /// Embedding width used by the input projection and every encoder layer.
    pub d_model: usize,
    /// Number of attention heads in each encoder layer.
    pub n_heads: usize,
    /// Number of stacked encoder layers.
    pub n_layers: usize,
    /// Hidden width of each layer's feed-forward block.
    pub d_ff: usize,
    /// Dropout probability after the input embedding, on residual branches, and inside
    /// the feed-forward block.
    pub dropout: f64,
    /// Dropout probability handed to the multi-head attention module.
    pub attn_dropout: f64,
    /// Whether a positional encoding is added to the embedded input.
    /// NOTE(review): prefer `with_pos_encoding`, which keeps this flag consistent with
    /// `pos_encoding_type`.
    pub use_pos_encoding: bool,
    /// Which positional encoding to use when `use_pos_encoding` is set.
    pub pos_encoding_type: PositionalEncodingType,
    /// `true` selects pre-LayerNorm layers (plus a final LayerNorm); `false` post-LayerNorm.
    pub pre_norm: bool,
    /// Activation used inside the feed-forward block.
    pub activation: ActivationType,
    /// How encoder outputs are reduced over the sequence axis before the head.
    pub aggregation: AggregationType,
}
/// Positional-encoding scheme added to the embedded input sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum PositionalEncodingType {
    /// Fixed sine/cosine encoding computed from position and feature index (the default).
    #[default]
    Sinusoidal,
    /// Trainable per-position embedding table.
    Learnable,
    /// Requests no positional encoding; `TransformerModelConfig::with_pos_encoding`
    /// also clears `use_pos_encoding` for this variant.
    None,
}
/// Activation used inside each encoder layer's feed-forward block.
// Variant names are part of the public and serde API, so the acronym casing stays.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum ActivationType {
    /// Rectified linear unit, `max(x, 0)`.
    ReLU,
    /// Gaussian error linear unit (the default).
    #[default]
    GELU,
}
/// How encoder outputs are reduced over the sequence axis before the output head.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum AggregationType {
    /// Average over all time steps (the default).
    #[default]
    Mean,
    /// Keep only the first time step.
    First,
    /// Keep only the last time step.
    Last,
    /// Concatenate every time step; the head input becomes `seq_len * d_model` wide.
    Flatten,
}
impl Default for TransformerModelConfig {
    /// Returns the baseline configuration:
    /// - a single-variable, 100-step, two-output task;
    /// - a 3-layer, 8-head, 128-wide encoder with feed-forward width 512;
    /// - 10% dropout and no attention dropout;
    /// - sinusoidal positions, post-norm layers, GELU activations and mean aggregation.
    fn default() -> Self {
        // Encoder geometry; the feed-forward block is four times the model width.
        let (d_model, n_heads, n_layers) = (128, 8, 3);
        Self {
            // Task shape.
            n_vars: 1,
            seq_len: 100,
            n_outputs: 2,
            // Encoder geometry.
            d_model,
            n_heads,
            n_layers,
            d_ff: 4 * d_model,
            // Regularisation.
            dropout: 0.1,
            attn_dropout: 0.0,
            // Architecture switches.
            use_pos_encoding: true,
            pos_encoding_type: PositionalEncodingType::Sinusoidal,
            pre_norm: false,
            activation: ActivationType::GELU,
            aggregation: AggregationType::Mean,
        }
    }
}
impl TransformerModelConfig {
    /// Creates a configuration for `n_vars` input variables, sequences of `seq_len`
    /// steps and `n_outputs` outputs. Every other field takes its `Default` value.
    pub fn new(n_vars: usize, seq_len: usize, n_outputs: usize) -> Self {
        Self {
            n_vars,
            seq_len,
            n_outputs,
            ..Default::default()
        }
    }

    /// Sets the embedding width used throughout the encoder.
    #[must_use]
    pub fn with_d_model(mut self, d_model: usize) -> Self {
        self.d_model = d_model;
        self
    }

    /// Sets the number of attention heads per encoder layer.
    #[must_use]
    pub fn with_n_heads(mut self, n_heads: usize) -> Self {
        self.n_heads = n_heads;
        self
    }

    /// Sets the number of stacked encoder layers.
    #[must_use]
    pub fn with_n_layers(mut self, n_layers: usize) -> Self {
        self.n_layers = n_layers;
        self
    }

    /// Sets the hidden width of each feed-forward block.
    #[must_use]
    pub fn with_d_ff(mut self, d_ff: usize) -> Self {
        self.d_ff = d_ff;
        self
    }

    /// Sets the general dropout probability.
    /// It applies to the input embedding, the residual branches and the feed-forward block.
    #[must_use]
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }

    /// Sets the dropout probability passed to the multi-head attention module.
    #[must_use]
    pub fn with_attn_dropout(mut self, attn_dropout: f64) -> Self {
        self.attn_dropout = attn_dropout;
        self
    }

    /// Selects the positional-encoding scheme.
    /// `use_pos_encoding` is kept consistent: it is cleared exactly when `pos_type`
    /// is [`PositionalEncodingType::None`].
    #[must_use]
    pub fn with_pos_encoding(mut self, pos_type: PositionalEncodingType) -> Self {
        self.pos_encoding_type = pos_type;
        self.use_pos_encoding = !matches!(pos_type, PositionalEncodingType::None);
        self
    }

    /// Chooses pre-LayerNorm (`true`) or post-LayerNorm (`false`) encoder layers.
    #[must_use]
    pub fn with_pre_norm(mut self, pre_norm: bool) -> Self {
        self.pre_norm = pre_norm;
        self
    }

    /// Sets the feed-forward activation.
    #[must_use]
    pub fn with_activation(mut self, activation: ActivationType) -> Self {
        self.activation = activation;
        self
    }

    /// Sets how encoder outputs are reduced over the sequence before the head.
    #[must_use]
    pub fn with_aggregation(mut self, aggregation: AggregationType) -> Self {
        self.aggregation = aggregation;
        self
    }

    /// Builds a [`TransformerModel`] on `device` from a copy of this configuration.
    pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerModel<B> {
        TransformerModel::new(self.clone(), device)
    }
}
/// One Transformer encoder layer.
///
/// It applies multi-head self-attention, then a position-wise feed-forward block.
/// Each sub-block has a residual connection and a LayerNorm, placed according to
/// `pre_norm`.
#[derive(Module, Debug)]
pub struct TransformerEncoderLayer<B: Backend> {
    /// Multi-head self-attention (`d_model` wide, `n_heads` heads).
    attention: MultiHeadAttention<B>,
    /// LayerNorm for the attention sub-block.
    /// Applied to its input when `pre_norm`, after the residual add otherwise.
    norm1: LayerNorm<B>,
    /// Feed-forward expansion `d_model -> d_ff`.
    ff_linear1: Linear<B>,
    /// Feed-forward projection `d_ff -> d_model`.
    ff_linear2: Linear<B>,
    /// LayerNorm for the feed-forward sub-block.
    norm2: LayerNorm<B>,
    /// Dropout on the attention output before the residual add.
    dropout1: Dropout,
    /// Dropout used twice in the feed-forward sub-block: after the activation and on
    /// the block output.
    dropout2: Dropout,
    /// Pre-LayerNorm (`true`) vs post-LayerNorm (`false`) ordering.
    #[module(skip)]
    pre_norm: bool,
    /// GELU when `true`; ReLU (implemented as `clamp_min(0.0)`) otherwise.
    #[module(skip)]
    use_gelu: bool,
}
impl<B: Backend> TransformerEncoderLayer<B> {
fn new(config: &TransformerModelConfig, device: &B::Device) -> Self {
let attention = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_dropout(config.attn_dropout)
.init(device);
let norm1 = LayerNormConfig::new(config.d_model).init(device);
let ff_linear1 = LinearConfig::new(config.d_model, config.d_ff).init(device);
let ff_linear2 = LinearConfig::new(config.d_ff, config.d_model).init(device);
let norm2 = LayerNormConfig::new(config.d_model).init(device);
let dropout1 = DropoutConfig::new(config.dropout).init();
let dropout2 = DropoutConfig::new(config.dropout).init();
Self {
attention,
norm1,
ff_linear1,
ff_linear2,
norm2,
dropout1,
dropout2,
pre_norm: config.pre_norm,
use_gelu: matches!(config.activation, ActivationType::GELU),
}
}
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
if self.pre_norm {
let x_norm = self.norm1.forward(x.clone());
let attn_input = MhaInput::self_attn(x_norm);
let attn_out = self.attention.forward(attn_input).context;
let x = x + self.dropout1.forward(attn_out);
let x_norm = self.norm2.forward(x.clone());
let ff_out = self.ff_linear1.forward(x_norm);
let ff_out = if self.use_gelu {
gelu(ff_out)
} else {
ff_out.clamp_min(0.0) };
let ff_out = self.dropout2.forward(ff_out);
let ff_out = self.ff_linear2.forward(ff_out);
x + self.dropout2.forward(ff_out)
} else {
let attn_input = MhaInput::self_attn(x.clone());
let attn_out = self.attention.forward(attn_input).context;
let x = self.norm1.forward(x + self.dropout1.forward(attn_out));
let ff_out = self.ff_linear1.forward(x.clone());
let ff_out = if self.use_gelu {
gelu(ff_out)
} else {
ff_out.clamp_min(0.0) };
let ff_out = self.dropout2.forward(ff_out);
let ff_out = self.ff_linear2.forward(ff_out);
self.norm2.forward(x + self.dropout2.forward(ff_out))
}
}
}
/// Transformer-encoder model over multivariate sequences.
///
/// Pipeline: input projection, optional positional encoding, a stack of
/// [`TransformerEncoderLayer`]s, an optional final LayerNorm, sequence aggregation,
/// and a linear head.
#[derive(Module, Debug)]
pub struct TransformerModel<B: Backend> {
    /// Projects each time step from `n_vars` features to `d_model`.
    input_embedding: Linear<B>,
    /// Learnable positional table of shape `[seq_len, d_model]`.
    /// Present only for `PositionalEncodingType::Learnable`.
    pos_embedding: Option<Param<Tensor<B, 2>>>,
    /// Dropout applied to the embedded (and position-encoded) input.
    input_dropout: Dropout,
    /// The stacked encoder layers, applied in order.
    encoder_layers: Vec<TransformerEncoderLayer<B>>,
    /// Final LayerNorm; built only for pre-norm configurations.
    final_norm: Option<LayerNorm<B>>,
    /// Linear output head mapping the aggregated features to `n_outputs`.
    head: Linear<B>,
    /// Embedding width, exposed through `d_model()`.
    #[module(skip)]
    d_model: usize,
    /// Configured sequence length, exposed through `seq_len()`.
    #[module(skip)]
    seq_len: usize,
    /// Whether a positional encoding is added in the forward pass.
    #[module(skip)]
    use_pos_encoding: bool,
    /// Aggregation strategy encoded as 0 = Mean, 1 = First, 2 = Last, 3 = Flatten.
    /// Any other value is treated as Flatten.
    #[module(skip)]
    aggregation: u8,
}
impl<B: Backend> TransformerModel<B> {
pub fn new(config: TransformerModelConfig, device: &B::Device) -> Self {
let input_embedding = LinearConfig::new(config.n_vars, config.d_model).init(device);
let pos_embedding = if config.use_pos_encoding
&& matches!(config.pos_encoding_type, PositionalEncodingType::Learnable)
{
let pe = Tensor::random(
[config.seq_len, config.d_model],
burn::tensor::Distribution::Normal(0.0, 0.02),
device,
);
Some(Param::from_tensor(pe))
} else {
None
};
let input_dropout = DropoutConfig::new(config.dropout).init();
let encoder_layers: Vec<_> = (0..config.n_layers)
.map(|_| TransformerEncoderLayer::new(&config, device))
.collect();
let final_norm = if config.pre_norm {
Some(LayerNormConfig::new(config.d_model).init(device))
} else {
None
};
let head_input_size = match config.aggregation {
AggregationType::Flatten => config.d_model * config.seq_len,
_ => config.d_model,
};
let head = LinearConfig::new(head_input_size, config.n_outputs).init(device);
Self {
input_embedding,
pos_embedding,
input_dropout,
encoder_layers,
final_norm,
head,
d_model: config.d_model,
seq_len: config.seq_len,
use_pos_encoding: config.use_pos_encoding,
aggregation: match config.aggregation {
AggregationType::Mean => 0,
AggregationType::First => 1,
AggregationType::Last => 2,
AggregationType::Flatten => 3,
},
}
}
fn sinusoidal_encoding(&self, seq_len: usize, d_model: usize, device: &B::Device) -> Tensor<B, 2> {
let mut pe = vec![0.0f32; seq_len * d_model];
for pos in 0..seq_len {
for i in 0..d_model {
let angle = pos as f32 / (10000.0f32).powf((2 * (i / 2)) as f32 / d_model as f32);
pe[pos * d_model + i] = if i % 2 == 0 { angle.sin() } else { angle.cos() };
}
}
Tensor::<B, 1>::from_floats(pe.as_slice(), device).reshape([seq_len, d_model])
}
fn aggregate(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
let [batch, seq_len, d_model] = x.dims();
match self.aggregation {
0 => x.mean_dim(1).reshape([batch, d_model]), 1 => x.slice([0..batch, 0..1, 0..d_model]).reshape([batch, d_model]), 2 => x.slice([0..batch, seq_len - 1..seq_len, 0..d_model]).reshape([batch, d_model]), _ => x.reshape([batch, seq_len * d_model]), }
}
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
let [_batch, _n_vars, seq_len] = x.dims();
let device = x.device();
let x = x.swap_dims(1, 2);
let mut x = self.input_embedding.forward(x);
let [_, _, d_model] = x.dims();
if self.use_pos_encoding {
let pos_enc = match &self.pos_embedding {
Some(pe) => pe.val().clone(),
None => self.sinusoidal_encoding(seq_len, d_model, &device),
};
x = x + pos_enc.unsqueeze::<3>();
}
x = self.input_dropout.forward(x);
for layer in &self.encoder_layers {
x = layer.forward(x);
}
if let Some(ref norm) = self.final_norm {
x = norm.forward(x);
}
let x = self.aggregate(x);
self.head.forward(x)
}
pub fn forward_probs(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
let logits = self.forward(x);
softmax(logits, 1)
}
pub fn encode(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
let [_batch, _n_vars, seq_len] = x.dims();
let device = x.device();
let x = x.swap_dims(1, 2);
let mut x = self.input_embedding.forward(x);
let [_, _, d_model] = x.dims();
if self.use_pos_encoding {
let pos_enc = match &self.pos_embedding {
Some(pe) => pe.val().clone(),
None => self.sinusoidal_encoding(seq_len, d_model, &device),
};
x = x + pos_enc.unsqueeze::<3>();
}
x = self.input_dropout.forward(x);
for layer in &self.encoder_layers {
x = layer.forward(x);
}
if let Some(ref norm) = self.final_norm {
x = norm.forward(x);
}
x
}
pub fn d_model(&self) -> usize {
self.d_model
}
pub fn seq_len(&self) -> usize {
self.seq_len
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn_ndarray::NdArray;

    type TestBackend = NdArray;

    #[test]
    fn test_transformer_model_default() {
        let device = Default::default();
        let config = TransformerModelConfig::new(3, 50, 5);
        let model: TransformerModel<TestBackend> = config.init(&device);
        let input = Tensor::random([4, 3, 50], burn::tensor::Distribution::Normal(0.0, 1.0), &device);
        let output = model.forward(input);
        assert_eq!(output.dims(), [4, 5]);
    }

    #[test]
    fn test_transformer_model_with_pre_norm() {
        let device = Default::default();
        let config = TransformerModelConfig::new(2, 100, 10)
            .with_d_model(64)
            .with_n_layers(4)
            .with_n_heads(4)
            .with_pre_norm(true);
        let model: TransformerModel<TestBackend> = config.init(&device);
        let input = Tensor::random([2, 2, 100], burn::tensor::Distribution::Normal(0.0, 1.0), &device);
        let output = model.forward(input);
        assert_eq!(output.dims(), [2, 10]);
    }

    #[test]
    fn test_transformer_model_learnable_pos() {
        let device = Default::default();
        let config = TransformerModelConfig::new(1, 32, 3)
            .with_pos_encoding(PositionalEncodingType::Learnable);
        let model: TransformerModel<TestBackend> = config.init(&device);
        let input = Tensor::random([8, 1, 32], burn::tensor::Distribution::Normal(0.0, 1.0), &device);
        let output = model.forward(input);
        assert_eq!(output.dims(), [8, 3]);
    }

    #[test]
    fn test_transformer_model_without_pos_encoding() {
        let device = Default::default();
        let config = TransformerModelConfig::new(1, 16, 3)
            .with_d_model(16)
            .with_n_heads(4)
            .with_pos_encoding(PositionalEncodingType::None);
        assert!(!config.use_pos_encoding);
        let model: TransformerModel<TestBackend> = config.init(&device);
        let input = Tensor::random([2, 1, 16], burn::tensor::Distribution::Normal(0.0, 1.0), &device);
        let output = model.forward(input);
        assert_eq!(output.dims(), [2, 3]);
    }

    #[test]
    fn test_transformer_model_relu_activation() {
        let device = Default::default();
        let config = TransformerModelConfig::new(3, 20, 2)
            .with_d_model(32)
            .with_n_heads(4)
            .with_activation(ActivationType::ReLU);
        let model: TransformerModel<TestBackend> = config.init(&device);
        let input = Tensor::random([4, 3, 20], burn::tensor::Distribution::Normal(0.0, 1.0), &device);
        let output = model.forward(input);
        assert_eq!(output.dims(), [4, 2]);
    }

    #[test]
    fn test_transformer_encode() {
        let device = Default::default();
        let config = TransformerModelConfig::new(3, 50, 5).with_d_model(32);
        let model: TransformerModel<TestBackend> = config.init(&device);
        let input = Tensor::random([4, 3, 50], burn::tensor::Distribution::Normal(0.0, 1.0), &device);
        let encoded = model.encode(input);
        assert_eq!(encoded.dims(), [4, 50, 32]);
    }

    #[test]
    fn test_forward_matches_encode_then_head() {
        let device = Default::default();
        // Dropout disabled so both paths are deterministic.
        let config = TransformerModelConfig::new(3, 12, 2)
            .with_d_model(16)
            .with_n_heads(4)
            .with_dropout(0.0);
        let model: TransformerModel<TestBackend> = config.init(&device);
        let input: Tensor<TestBackend, 3> =
            Tensor::random([2, 3, 12], burn::tensor::Distribution::Normal(0.0, 1.0), &device);
        let direct = model.forward(input.clone());
        let manual = model.head.forward(model.aggregate(model.encode(input)));
        let diff = (direct - manual).abs().max().into_scalar();
        assert!(diff < 1e-6, "forward must equal head(aggregate(encode(x)))");
    }

    #[test]
    fn test_aggregation_types() {
        let device = Default::default();
        for aggregation in [
            AggregationType::Mean,
            AggregationType::First,
            AggregationType::Last,
            AggregationType::Flatten,
        ] {
            let config = TransformerModelConfig::new(2, 10, 4)
                .with_d_model(16)
                .with_aggregation(aggregation);
            let model: TransformerModel<TestBackend> = config.init(&device);
            let input = Tensor::random([2, 2, 10], burn::tensor::Distribution::Normal(0.0, 1.0), &device);
            let output = model.forward(input);
            assert_eq!(output.dims(), [2, 4], "unexpected output shape for {aggregation:?}");
        }
    }

    #[test]
    fn test_forward_probs() {
        let device = Default::default();
        let config = TransformerModelConfig::new(3, 50, 5);
        let model: TransformerModel<TestBackend> = config.init(&device);
        let input = Tensor::random([4, 3, 50], burn::tensor::Distribution::Normal(0.0, 1.0), &device);
        let probs = model.forward_probs(input);
        assert_eq!(probs.dims(), [4, 5]);
        let sum = probs.clone().sum_dim(1);
        let expected = Tensor::ones([4, 1], &device);
        let diff = (sum - expected).abs().max().into_scalar();
        assert!(diff < 1e-5, "Probabilities should sum to 1");
    }
}