use burn::nn::{
attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
Embedding, EmbeddingConfig, Linear, LinearConfig, LayerNorm, LayerNormConfig, Relu,
};
use burn::prelude::*;
use burn::tensor::activation::softmax;
use serde::{Deserialize, Serialize};
/// Hyper-parameters describing a [`TabTransformer`] model.
///
/// The configuration is serializable so it can be stored alongside a trained model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TabTransformerConfig {
/// Number of continuous input features (input width of the continuous projection).
pub n_continuous: usize,
/// Number of categorical feature columns; should equal `cat_cardinalities.len()`.
pub n_categorical: usize,
/// Number of distinct values for each categorical column; one embedding table
/// is created per entry.
pub cat_cardinalities: Vec<usize>,
/// Number of output classes (output width of the classification head).
pub n_classes: usize,
/// Width of every feature token inside the encoder.
pub d_model: usize,
/// Number of attention heads in each encoder layer.
pub n_heads: usize,
/// Number of stacked encoder layers.
pub n_layers: usize,
/// Dropout probability handed to each encoder layer's attention module.
pub dropout: f64,
}
impl Default for TabTransformerConfig {
    /// Returns a small reference configuration: 10 continuous features, 5 categorical
    /// columns with cardinalities 10..=50 (step 10), 2 classes, and a 2-layer encoder
    /// with 4 heads over 64-wide tokens.
    fn default() -> Self {
        let cat_cardinalities = vec![10, 20, 30, 40, 50];
        Self {
            n_continuous: 10,
            // Five cardinalities above, hence five categorical columns.
            n_categorical: cat_cardinalities.len(),
            cat_cardinalities,
            n_classes: 2,
            d_model: 64,
            n_heads: 4,
            n_layers: 2,
            dropout: 0.1,
        }
    }
}
impl TabTransformerConfig {
    /// Instantiates a [`TabTransformer`] on `device` from a copy of this configuration.
    pub fn init<B: Backend>(&self, device: &B::Device) -> TabTransformer<B> {
        let config = self.clone();
        TabTransformer::new(config, device)
    }
}
/// One encoder layer: multi-head self-attention followed by a two-layer
/// feed-forward network, each combined with a residual connection and then
/// layer-normalized.
#[derive(Module, Debug)]
struct TabEncoderLayer<B: Backend> {
/// Multi-head attention, applied as self-attention over the feature tokens.
attention: MultiHeadAttention<B>,
/// Layer norm applied after the attention residual.
norm1: LayerNorm<B>,
/// Feed-forward expansion (`d_model` -> `4 * d_model`).
ff_linear1: Linear<B>,
/// Feed-forward projection back (`4 * d_model` -> `d_model`).
ff_linear2: Linear<B>,
/// Layer norm applied after the feed-forward residual.
norm2: LayerNorm<B>,
}
impl<B: Backend> TabEncoderLayer<B> {
    /// Creates an encoder layer for tokens of width `d_model`.
    ///
    /// `dropout` is forwarded to the attention module only; the feed-forward
    /// network expands to `4 * d_model` and projects back to `d_model`.
    fn new(d_model: usize, n_heads: usize, dropout: f64, device: &B::Device) -> Self {
        let hidden_width = d_model * 4;
        // Field expressions are evaluated top to bottom, so the sub-modules are
        // initialized in the same order as they are declared.
        Self {
            attention: MultiHeadAttentionConfig::new(d_model, n_heads)
                .with_dropout(dropout)
                .init(device),
            norm1: LayerNormConfig::new(d_model).init(device),
            ff_linear1: LinearConfig::new(d_model, hidden_width).init(device),
            ff_linear2: LinearConfig::new(hidden_width, d_model).init(device),
            norm2: LayerNormConfig::new(d_model).init(device),
        }
    }

    /// Runs the layer on `x` and returns a tensor of the same rank.
    ///
    /// Computes `norm1(x + self_attention(x))`, then adds the feed-forward output
    /// (linear -> ReLU -> linear) to that result and applies `norm2`.
    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // Attention sub-layer with residual connection.
        let attended = self
            .attention
            .forward(MhaInput::self_attn(x.clone()))
            .context;
        let hidden = self.norm1.forward(x + attended);

        // Feed-forward sub-layer with residual connection.
        let expanded = Relu::new().forward(self.ff_linear1.forward(hidden.clone()));
        let projected = self.ff_linear2.forward(expanded);
        self.norm2.forward(hidden + projected)
    }
}
/// Transformer classifier over tabular data with continuous and categorical features.
///
/// The continuous features are projected into a single token, every categorical
/// column is embedded into its own token, the tokens pass through the encoder
/// stack, and the flattened result is fed to a linear classification head.
#[derive(Module, Debug)]
pub struct TabTransformer<B: Backend> {
/// One embedding table per categorical column.
cat_embeddings: Vec<Embedding<B>>,
/// Projects all continuous features into one `d_model`-wide token.
cont_proj: Linear<B>,
/// Encoder layers applied in order.
encoder_layers: Vec<TabEncoderLayer<B>>,
/// Classification head over the flattened tokens.
head: Linear<B>,
}
impl<B: Backend> TabTransformer<B> {
    /// Builds the model described by `config` on `device`.
    ///
    /// One embedding table is created per entry of `config.cat_cardinalities`.
    /// The classification head is sized from the number of embedding tables
    /// actually built (plus one token for the continuous features), so the model
    /// stays internally consistent even if `config.n_categorical` disagrees with
    /// `config.cat_cardinalities.len()`.
    pub fn new(config: TabTransformerConfig, device: &B::Device) -> Self {
        let cat_embeddings: Vec<_> = config
            .cat_cardinalities
            .iter()
            .map(|&card| EmbeddingConfig::new(card, config.d_model).init(device))
            .collect();
        let cont_proj = LinearConfig::new(config.n_continuous, config.d_model).init(device);
        let encoder_layers: Vec<_> = (0..config.n_layers)
            .map(|_| TabEncoderLayer::new(config.d_model, config.n_heads, config.dropout, device))
            .collect();

        // `forward` emits one token for the continuous block plus one per embedding
        // table; the head must match that count rather than `config.n_categorical`,
        // otherwise a mismatched config would make every forward pass fail.
        let n_tokens = cat_embeddings.len() + 1;
        let head = LinearConfig::new(config.d_model * n_tokens, config.n_classes).init(device);

        Self {
            cat_embeddings,
            cont_proj,
            encoder_layers,
            head,
        }
    }

    /// Computes class logits.
    ///
    /// * `x_continuous` - continuous features, `[batch, n_continuous]`.
    /// * `x_categorical` - integer category indices, one column per embedding
    ///   table; column `i` is looked up in table `i`. It is expected to have at
    ///   least as many columns as there are embedding tables and the same batch
    ///   size as `x_continuous`.
    ///
    /// Returns logits of shape `[batch, n_classes]`.
    pub fn forward(
        &self,
        x_continuous: Tensor<B, 2>,
        x_categorical: Tensor<B, 2, Int>,
    ) -> Tensor<B, 2> {
        let [batch, _] = x_continuous.dims();

        // All continuous features become a single token: [batch, 1, d_model].
        let cont_token: Tensor<B, 3> = self.cont_proj.forward(x_continuous).unsqueeze_dim(1);

        let mut tokens = Vec::with_capacity(self.cat_embeddings.len() + 1);
        tokens.push(cont_token);
        // Each categorical column is sliced out as [batch, 1] and embedded into
        // its own [batch, 1, d_model] token.
        tokens.extend(
            self.cat_embeddings
                .iter()
                .enumerate()
                .map(|(i, embedding)| {
                    let cat_col = x_categorical.clone().slice([0..batch, i..(i + 1)]);
                    embedding.forward(cat_col)
                }),
        );

        // Concatenate along the token axis and run the encoder stack.
        let features = Tensor::cat(tokens, 1);
        let encoded = self
            .encoder_layers
            .iter()
            .fold(features, |x, layer| layer.forward(x));

        // Flatten all tokens per sample before the classification head.
        let [_, n_feats, d_model] = encoded.dims();
        let flat = encoded.reshape([batch, n_feats * d_model]);
        self.head.forward(flat)
    }

    /// Computes class probabilities by applying softmax over the class axis of
    /// the logits returned by [`Self::forward`].
    pub fn forward_probs(
        &self,
        x_continuous: Tensor<B, 2>,
        x_categorical: Tensor<B, 2, Int>,
    ) -> Tensor<B, 2> {
        let logits = self.forward(x_continuous, x_categorical);
        softmax(logits, 1)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration reports the expected feature counts.
    #[test]
    fn test_tab_transformer_config() {
        let TabTransformerConfig {
            n_continuous,
            n_categorical,
            ..
        } = TabTransformerConfig::default();
        assert_eq!(n_continuous, 10);
        assert_eq!(n_categorical, 5);
    }
}