use burn::nn::{
attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
Dropout, DropoutConfig, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
Linear, LinearConfig,
};
use burn::prelude::*;
use burn::tensor::activation::{gelu, softmax};
use serde::{Deserialize, Serialize};
/// Hyper-parameters for building a [`GatedTabTransformer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatedTabTransformerConfig {
    /// Number of continuous input features; the continuous projection uses
    /// `max(n_continuous, 1)` input features.
    pub n_continuous: usize,
    /// Number of categorical input columns consumed by the model (one embedding
    /// token per column).
    pub n_categorical: usize,
    /// Number of distinct categories per categorical column; one embedding table is
    /// created per entry. Should provide at least `n_categorical` entries.
    pub cat_cardinalities: Vec<usize>,
    /// Number of output classes (width of the logit vector).
    pub n_classes: usize,
    /// Embedding / token width used throughout the encoder.
    pub d_model: usize,
    /// Number of attention heads in each encoder layer.
    pub n_heads: usize,
    /// Number of stacked encoder layers.
    pub n_layers: usize,
    /// Feed-forward hidden width multiplier (`hidden = d_model * ff_mult`).
    pub ff_mult: usize,
    /// Dropout probability used inside attention and on both residual branches.
    pub dropout: f64,
    /// When `true`, the projected continuous features become an extra token seen by
    /// attention; when `false`, they go through a separate linear + GELU branch
    /// (only if `n_continuous > 0`) that is concatenated before the head.
    pub attn_on_continuous: bool,
}
impl Default for GatedTabTransformerConfig {
    /// A small reference configuration: 10 continuous and 5 categorical inputs
    /// (cardinalities 10, 20, 30, 40, 50), binary classification, a 2-layer encoder
    /// with 64-wide tokens and 4 heads, 4x feed-forward expansion, 10% dropout, and
    /// attention over the continuous token enabled.
    fn default() -> Self {
        // 10, 20, 30, 40, 50
        let cat_cardinalities: Vec<usize> = (1..=5).map(|k| k * 10).collect();
        Self {
            n_continuous: 10,
            n_categorical: 5,
            cat_cardinalities,
            n_classes: 2,
            d_model: 64,
            n_heads: 4,
            n_layers: 2,
            ff_mult: 4,
            dropout: 0.1,
            attn_on_continuous: true,
        }
    }
}
impl GatedTabTransformerConfig {
    /// Creates a configuration with the given input/output sizes; every other field
    /// comes from [`Default`].
    ///
    /// `cat_cardinalities` keeps its five default entries regardless of
    /// `n_categorical`; use [`with_cardinalities`](Self::with_cardinalities) to supply
    /// one cardinality per categorical column.
    pub fn new(n_continuous: usize, n_categorical: usize, n_classes: usize) -> Self {
        Self {
            n_continuous,
            n_categorical,
            n_classes,
            ..Self::default()
        }
    }

    /// Replaces the per-column category counts (one embedding table per entry).
    #[must_use]
    pub fn with_cardinalities(self, cardinalities: Vec<usize>) -> Self {
        Self {
            cat_cardinalities: cardinalities,
            ..self
        }
    }

    /// Sets the token / embedding width.
    #[must_use]
    pub fn with_d_model(self, d_model: usize) -> Self {
        Self { d_model, ..self }
    }

    /// Sets the number of attention heads per encoder layer.
    #[must_use]
    pub fn with_n_heads(self, n_heads: usize) -> Self {
        Self { n_heads, ..self }
    }

    /// Sets the number of stacked encoder layers.
    #[must_use]
    pub fn with_n_layers(self, n_layers: usize) -> Self {
        Self { n_layers, ..self }
    }

    /// Sets the feed-forward hidden width multiplier.
    #[must_use]
    pub fn with_ff_mult(self, ff_mult: usize) -> Self {
        Self { ff_mult, ..self }
    }

    /// Sets the dropout probability.
    #[must_use]
    pub fn with_dropout(self, dropout: f64) -> Self {
        Self { dropout, ..self }
    }

    /// Chooses whether continuous features are attended to as a token (`true`) or
    /// handled by a separate branch (`false`).
    #[must_use]
    pub fn with_attn_on_continuous(self, attn_on_continuous: bool) -> Self {
        Self {
            attn_on_continuous,
            ..self
        }
    }

    /// Builds a [`GatedTabTransformer`] on `device` from a copy of this configuration.
    pub fn init<B: Backend>(&self, device: &B::Device) -> GatedTabTransformer<B> {
        let config = self.clone();
        GatedTabTransformer::new(config, device)
    }
}
/// GELU-gated linear unit: a single projection produces a "value" half and a "gate"
/// half, and the output is `value * gelu(gate)`.
#[derive(Module, Debug)]
struct GEGLU<B: Backend> {
    /// Projects `in_features` to `2 * out_features` (value half followed by gate half).
    proj: Linear<B>,
}
impl<B: Backend> GEGLU<B> {
    /// Creates a GEGLU mapping `in_features` to `out_features`; the underlying
    /// projection is twice as wide to hold both the value and the gate halves.
    fn new(in_features: usize, out_features: usize, device: &B::Device) -> Self {
        Self {
            proj: LinearConfig::new(in_features, 2 * out_features).init(device),
        }
    }

    /// Applies the gated projection to a rank-3 input.
    ///
    /// The projection output is split in half along the last axis: the first half is
    /// the value, the second half the gate. The result is `value * gelu(gate)`,
    /// with `out_features` channels in the last axis.
    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let projected = self.proj.forward(x);
        let [b, s, width] = projected.dims();
        let mid = width / 2;

        let gate = projected.clone().slice([0..b, 0..s, mid..width]);
        let value = projected.slice([0..b, 0..s, 0..mid]);
        value * gelu(gate)
    }
}
/// One post-norm encoder layer: self-attention and a GEGLU feed-forward block, each
/// wrapped in a residual connection followed by layer normalisation.
#[derive(Module, Debug)]
struct GatedEncoderLayer<B: Backend> {
    /// Multi-head self-attention over the token axis.
    attention: MultiHeadAttention<B>,
    /// Normalises the attention residual sum.
    norm1: LayerNorm<B>,
    /// Gated expansion `d_model -> d_model * ff_mult`.
    geglu: GEGLU<B>,
    /// Contraction back to `d_model`.
    ff_out: Linear<B>,
    /// Normalises the feed-forward residual sum.
    norm2: LayerNorm<B>,
    /// Applied to both the attention output and the feed-forward output before the
    /// residual additions.
    dropout: Dropout,
}
impl<B: Backend> GatedEncoderLayer<B> {
    /// Builds a layer with `d_model`-wide tokens, `n_heads` attention heads, a
    /// feed-forward hidden width of `d_model * ff_mult`, and dropout probability
    /// `dropout` (also passed to the attention module).
    fn new(d_model: usize, n_heads: usize, ff_mult: usize, dropout: f64, device: &B::Device) -> Self {
        let hidden = d_model * ff_mult;
        // Struct-literal fields are evaluated top to bottom, so sub-modules are
        // initialised in the same order as they are declared.
        Self {
            attention: MultiHeadAttentionConfig::new(d_model, n_heads)
                .with_dropout(dropout)
                .init(device),
            norm1: LayerNormConfig::new(d_model).init(device),
            geglu: GEGLU::new(d_model, hidden, device),
            ff_out: LinearConfig::new(hidden, d_model).init(device),
            norm2: LayerNormConfig::new(d_model).init(device),
            dropout: DropoutConfig::new(dropout).init(),
        }
    }

    /// `h = LayerNorm(x + Dropout(SelfAttn(x)))`, then
    /// `LayerNorm(h + Dropout(FF(h)))`, where `FF = ff_out ∘ GEGLU`.
    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // Attention sub-block with residual.
        let attended = self.attention.forward(MhaInput::self_attn(x.clone())).context;
        let h = self.norm1.forward(x + self.dropout.forward(attended));

        // Gated feed-forward sub-block with residual.
        let projected = self.ff_out.forward(self.geglu.forward(h.clone()));
        self.norm2.forward(h + self.dropout.forward(projected))
    }
}
/// Tabular classifier: categorical columns (and optionally the projected continuous
/// features) are embedded as tokens, passed through a stack of gated encoder layers,
/// flattened, and mapped to class logits.
#[derive(Module, Debug)]
pub struct GatedTabTransformer<B: Backend> {
    /// One embedding table per configured cardinality; `forward` uses the first
    /// `n_categorical` of them, one per categorical input column.
    cat_embeddings: Vec<Embedding<B>>,
    /// Projects the continuous feature vector to a single `d_model`-wide embedding.
    cont_proj: Linear<B>,
    /// Stacked encoder layers applied to the token sequence.
    encoder_layers: Vec<GatedEncoderLayer<B>>,
    /// Normalisation applied after the last encoder layer.
    final_norm: LayerNorm<B>,
    /// Separate continuous branch; present only when `attn_on_continuous` is false
    /// and `n_continuous > 0`.
    cont_mlp: Option<Linear<B>>,
    /// Maps the concatenated features to `n_classes` logits.
    head: Linear<B>,
    /// Token width.
    #[module(skip)]
    d_model: usize,
    /// Number of categorical columns consumed in `forward`.
    #[module(skip)]
    n_categorical: usize,
    /// Whether the continuous embedding is prepended as an attention token.
    #[module(skip)]
    attn_on_continuous: bool,
}
impl<B: Backend> GatedTabTransformer<B> {
    /// Builds a model from `config`, initialising every sub-module on `device`.
    ///
    /// One embedding table is created for each entry of `config.cat_cardinalities`;
    /// [`forward`](Self::forward) uses the first `config.n_categorical` of them, one per
    /// categorical input column.
    ///
    /// # Panics
    ///
    /// - If `config.n_categorical > config.cat_cardinalities.len()`. There would be no
    ///   embedding table for the extra columns, and the head's input width would not
    ///   match the features `forward` produces.
    /// - If the configuration leaves the classification head with no input features
    ///   (`n_categorical == 0`, `attn_on_continuous == false` and `n_continuous == 0`).
    pub fn new(config: GatedTabTransformerConfig, device: &B::Device) -> Self {
        assert!(
            config.n_categorical <= config.cat_cardinalities.len(),
            "n_categorical ({}) exceeds the number of categorical cardinalities ({})",
            config.n_categorical,
            config.cat_cardinalities.len()
        );

        // Tokens seen by the encoder: one per categorical column, plus the continuous
        // token when attention covers continuous features.
        let n_tokens = config.n_categorical + usize::from(config.attn_on_continuous);
        // The separate continuous branch exists only when continuous features bypass
        // attention AND there actually are continuous features.
        let has_cont_mlp = !config.attn_on_continuous && config.n_continuous > 0;
        assert!(
            n_tokens > 0 || has_cont_mlp,
            "configuration produces no input features for the classification head"
        );

        // Initialisation order is kept stable (parameter init may draw from a shared RNG).
        let cat_embeddings: Vec<_> = config
            .cat_cardinalities
            .iter()
            .map(|&card| EmbeddingConfig::new(card, config.d_model).init(device))
            .collect();
        // At least one input feature so the projection is well-formed even when
        // `n_continuous == 0`.
        let cont_proj = LinearConfig::new(config.n_continuous.max(1), config.d_model).init(device);
        let encoder_layers: Vec<_> = (0..config.n_layers)
            .map(|_| {
                GatedEncoderLayer::new(
                    config.d_model,
                    config.n_heads,
                    config.ff_mult,
                    config.dropout,
                    device,
                )
            })
            .collect();
        let final_norm = LayerNormConfig::new(config.d_model).init(device);
        let cont_mlp =
            has_cont_mlp.then(|| LinearConfig::new(config.d_model, config.d_model).init(device));

        // Head input = flattened encoder tokens (+ continuous-branch output, if any).
        // Derived from `has_cont_mlp`, not from `attn_on_continuous` alone, so it always
        // matches what `forward` concatenates.
        let head_input = config.d_model * (n_tokens + usize::from(has_cont_mlp));
        let head = LinearConfig::new(head_input, config.n_classes).init(device);

        Self {
            cat_embeddings,
            cont_proj,
            encoder_layers,
            final_norm,
            cont_mlp,
            head,
            d_model: config.d_model,
            n_categorical: config.n_categorical,
            attn_on_continuous: config.attn_on_continuous,
        }
    }

    /// Computes class logits of shape `[batch, n_classes]`.
    ///
    /// * `x_continuous`: rank-2 `[batch, features]` float input. The number of columns
    ///   must equal the projection's input width, `max(n_continuous, 1)`.
    /// * `x_categorical`: rank-2 `[batch, columns]` integer category indices. Column
    ///   `i` (for `i < n_categorical`) is looked up in the `i`-th embedding table.
    pub fn forward(
        &self,
        x_continuous: Tensor<B, 2>,
        x_categorical: Tensor<B, 2, Int>,
    ) -> Tensor<B, 2> {
        let [batch, _] = x_continuous.dims();
        let cont_embedded = self.cont_proj.forward(x_continuous);

        // Build the token sequence: optional continuous token first, then one token
        // per used categorical column. Each token is `[batch, 1, d_model]`.
        let mut tokens = Vec::with_capacity(self.n_categorical + 1);
        if self.attn_on_continuous {
            tokens.push(cont_embedded.clone().reshape([batch, 1, self.d_model]));
        }
        tokens.extend(
            self.cat_embeddings
                .iter()
                .take(self.n_categorical)
                .enumerate()
                .map(|(i, embedding)| {
                    let column = x_categorical.clone().slice([0..batch, i..(i + 1)]);
                    embedding.forward(column)
                }),
        );

        let transformer_out = if tokens.is_empty() {
            None
        } else {
            Some(self.encode(tokens, batch))
        };
        let cont_out = self
            .cont_mlp
            .as_ref()
            .map(|cont_mlp| gelu(cont_mlp.forward(cont_embedded)));

        let final_features = match (transformer_out, cont_out) {
            (Some(tokens_flat), Some(cont)) => Tensor::cat(vec![tokens_flat, cont], 1),
            (Some(tokens_flat), None) => tokens_flat,
            (None, Some(cont)) => cont,
            (None, None) => unreachable!("`new` guarantees at least one feature source"),
        };
        self.head.forward(final_features)
    }

    /// Runs the encoder stack and final norm over a non-empty list of
    /// `[batch, 1, d_model]` tokens, and flattens the result to
    /// `[batch, n_tokens * d_model]`.
    fn encode(&self, tokens: Vec<Tensor<B, 3>>, batch: usize) -> Tensor<B, 2> {
        let x = self
            .encoder_layers
            .iter()
            .fold(Tensor::cat(tokens, 1), |x, layer| layer.forward(x));
        let x = self.final_norm.forward(x);
        let [_, n_tokens, d_model] = x.dims();
        x.reshape([batch, n_tokens * d_model])
    }

    /// Class probabilities: softmax of [`forward`](Self::forward)'s logits over the
    /// class axis (dim 1).
    pub fn forward_probs(
        &self,
        x_continuous: Tensor<B, 2>,
        x_categorical: Tensor<B, 2, Int>,
    ) -> Tensor<B, 2> {
        let logits = self.forward(x_continuous, x_categorical);
        softmax(logits, 1)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gated_tab_transformer_config_default() {
        let GatedTabTransformerConfig {
            n_continuous,
            n_categorical,
            ff_mult,
            attn_on_continuous,
            ..
        } = GatedTabTransformerConfig::default();
        assert_eq!((n_continuous, n_categorical, ff_mult), (10, 5, 4));
        assert!(attn_on_continuous);
    }

    #[test]
    fn test_gated_tab_transformer_config_new() {
        let cfg = GatedTabTransformerConfig::new(20, 8, 10);
        assert_eq!((cfg.n_continuous, cfg.n_categorical, cfg.n_classes), (20, 8, 10));
    }

    #[test]
    fn test_gated_tab_transformer_config_builder() {
        let cfg = GatedTabTransformerConfig::new(10, 5, 3)
            .with_d_model(128)
            .with_n_heads(8)
            .with_n_layers(4)
            .with_ff_mult(6)
            .with_dropout(0.2)
            .with_attn_on_continuous(false);
        assert_eq!(
            (cfg.d_model, cfg.n_heads, cfg.n_layers, cfg.ff_mult),
            (128, 8, 4, 6)
        );
        assert_eq!(cfg.dropout, 0.2);
        assert!(!cfg.attn_on_continuous);
    }
}