use burn::nn::{
BatchNorm, BatchNormConfig, Dropout, DropoutConfig, Embedding, EmbeddingConfig,
Linear, LinearConfig, Relu,
};
use burn::prelude::*;
use burn::tensor::activation::softmax;
use serde::{Deserialize, Serialize};
/// Hyperparameters for [`TabModel`]: input column counts, embedding and hidden-layer
/// sizes, regularization, and the number of output classes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TabModelConfig {
/// Number of continuous (float) input columns; also the channel count of the
/// continuous-feature batch norm.
pub n_continuous: usize,
/// Number of categorical input columns. Only the first `n_categorical` entries of
/// `cat_cardinalities` are used to build embedding tables.
pub n_categorical: usize,
/// Vocabulary size of each categorical column, in column order; each entry becomes
/// the first argument of `EmbeddingConfig::new(card, embed_dim)`.
pub cat_cardinalities: Vec<usize>,
/// Embedding width produced for every categorical column.
pub embed_dim: usize,
/// Width of the output layer (number of logits).
pub n_classes: usize,
/// Output width of each hidden MLP block, in order; one block per entry.
pub hidden_sizes: Vec<usize>,
/// Value passed to `DropoutConfig::new` for every hidden block.
pub dropout: f64,
/// Enables batch norm inside each hidden block and on the continuous inputs.
pub use_bn: bool,
}
impl Default for TabModelConfig {
    /// Ten continuous columns and five categorical columns with cardinalities
    /// 10, 20, 30, 40 and 50. Also 8-wide embeddings, two output classes, hidden
    /// layers of 200 then 100 units, dropout 0.1, and batch norm enabled.
    fn default() -> Self {
        let cat_cardinalities = vec![10, 20, 30, 40, 50];
        let hidden_sizes = vec![200, 100];
        Self {
            use_bn: true,
            dropout: 0.1,
            hidden_sizes,
            n_classes: 2,
            embed_dim: 8,
            cat_cardinalities,
            n_categorical: 5,
            n_continuous: 10,
        }
    }
}
impl TabModelConfig {
    /// Creates a config with the given column and class counts.
    ///
    /// Every other field comes from [`TabModelConfig::default`]. That includes
    /// `cat_cardinalities`, which keeps its five default entries. A caller whose
    /// `n_categorical` is not 5 should therefore also call
    /// [`TabModelConfig::with_cardinalities`].
    pub fn new(n_continuous: usize, n_categorical: usize, n_classes: usize) -> Self {
        let base = Self::default();
        Self {
            n_continuous,
            n_categorical,
            n_classes,
            ..base
        }
    }

    /// Replaces the per-column categorical vocabulary sizes.
    #[must_use]
    pub fn with_cardinalities(self, cardinalities: Vec<usize>) -> Self {
        Self {
            cat_cardinalities: cardinalities,
            ..self
        }
    }

    /// Replaces the embedding width used for every categorical column.
    #[must_use]
    pub fn with_embed_dim(self, embed_dim: usize) -> Self {
        Self { embed_dim, ..self }
    }

    /// Replaces the list of hidden-layer widths.
    #[must_use]
    pub fn with_hidden_sizes(self, sizes: Vec<usize>) -> Self {
        Self {
            hidden_sizes: sizes,
            ..self
        }
    }

    /// Replaces the dropout value applied in every hidden block.
    #[must_use]
    pub fn with_dropout(self, dropout: f64) -> Self {
        Self { dropout, ..self }
    }

    /// Turns batch norm on or off, for both the hidden blocks and the continuous inputs.
    #[must_use]
    pub fn with_bn(self, use_bn: bool) -> Self {
        Self { use_bn, ..self }
    }

    /// Builds a [`TabModel`] from a copy of this config on `device`.
    pub fn init<B: Backend>(&self, device: &B::Device) -> TabModel<B> {
        let config = self.clone();
        TabModel::new(config, device)
    }
}
/// One hidden layer of the tabular MLP: a linear projection followed by ReLU,
/// an optional batch norm, and dropout (in that order; see `forward`).
#[derive(Module, Debug)]
struct TabMLPBlock<B: Backend> {
/// Projection from the previous layer's width to this block's width.
linear: Linear<B>,
/// Present only when batch norm was requested at construction.
bn: Option<BatchNorm<B, 1>>,
/// Dropout applied to the block's output.
dropout: Dropout,
}
impl<B: Backend> TabMLPBlock<B> {
    /// Builds a block that maps `in_features` inputs to `out_features` outputs.
    ///
    /// When `use_bn` is true, a batch norm with `out_features` channels is attached.
    /// `dropout` is forwarded to [`DropoutConfig::new`].
    fn new(
        in_features: usize,
        out_features: usize,
        dropout: f64,
        use_bn: bool,
        device: &B::Device,
    ) -> Self {
        Self {
            linear: LinearConfig::new(in_features, out_features).init(device),
            bn: use_bn.then(|| BatchNormConfig::new(out_features).init(device)),
            dropout: DropoutConfig::new(dropout).init(),
        }
    }

    /// Applies `linear -> ReLU -> (batch norm, if present) -> dropout` to a 2-D
    /// `[batch, features]` tensor.
    fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let activated = Relu::new().forward(self.linear.forward(x));
        let normalized = match &self.bn {
            None => activated,
            Some(norm) => {
                let [rows, cols] = activated.dims();
                // BatchNorm<B, 1> takes a rank-3 input, so add a trailing unit axis
                // and remove it again afterwards.
                norm.forward(activated.reshape([rows, cols, 1]))
                    .reshape([rows, cols])
            }
        };
        self.dropout.forward(normalized)
    }
}
/// Tabular classifier. It concatenates (optionally batch-normalized) continuous
/// features with one embedding per categorical column. The result is fed through
/// a stack of [`TabMLPBlock`]s and a linear head that produces logits.
#[derive(Module, Debug)]
pub struct TabModel<B: Backend> {
/// One embedding table per categorical column, in column order.
cat_embeddings: Vec<Embedding<B>>,
/// Batch norm over the continuous columns. `None` when batch norm is disabled
/// or there are no continuous columns.
cont_bn: Option<BatchNorm<B, 1>>,
/// Hidden MLP blocks, applied in order.
blocks: Vec<TabMLPBlock<B>>,
/// Final projection to `n_classes` logits.
head: Linear<B>,
// Plain size hyperparameters used for reshaping in the forward passes; they are
// skipped by the Module derive.
#[module(skip)]
embed_dim: usize,
#[module(skip)]
n_categorical: usize,
#[module(skip)]
n_continuous: usize,
}
impl<B: Backend> TabModel<B> {
    /// Builds a tabular model from `config`, allocating all parameters on `device`.
    ///
    /// The model contains:
    /// - one embedding table per categorical column, built from the first
    ///   `config.n_categorical` entries of `config.cat_cardinalities`;
    /// - an optional batch norm over the continuous columns;
    /// - one [`TabMLPBlock`] per entry of `config.hidden_sizes`;
    /// - a linear head producing `config.n_classes` logits.
    ///
    /// # Panics
    ///
    /// Panics if `config.cat_cardinalities` has fewer than `config.n_categorical`
    /// entries.
    pub fn new(config: TabModelConfig, device: &B::Device) -> Self {
        // The first linear layer is sized from `n_categorical`, so every categorical
        // column needs an embedding table. Without this check, too few tables would be
        // built silently and the failure would surface later as an opaque shape
        // mismatch inside `forward`.
        assert!(
            config.cat_cardinalities.len() >= config.n_categorical,
            "TabModelConfig: n_categorical is {} but only {} cat_cardinalities were given",
            config.n_categorical,
            config.cat_cardinalities.len()
        );
        let cat_embeddings: Vec<_> = config
            .cat_cardinalities
            .iter()
            .take(config.n_categorical)
            .map(|&card| EmbeddingConfig::new(card, config.embed_dim).init(device))
            .collect();
        let cont_bn = if config.use_bn && config.n_continuous > 0 {
            Some(BatchNormConfig::new(config.n_continuous).init(device))
        } else {
            None
        };
        let input_size = config.n_continuous + config.n_categorical * config.embed_dim;
        let mut blocks = Vec::with_capacity(config.hidden_sizes.len());
        let mut prev_size = input_size;
        for &hidden_size in &config.hidden_sizes {
            blocks.push(TabMLPBlock::new(
                prev_size,
                hidden_size,
                config.dropout,
                config.use_bn,
                device,
            ));
            prev_size = hidden_size;
        }
        let head = LinearConfig::new(prev_size, config.n_classes).init(device);
        Self {
            cat_embeddings,
            cont_bn,
            blocks,
            head,
            embed_dim: config.embed_dim,
            n_categorical: config.n_categorical,
            n_continuous: config.n_continuous,
        }
    }

    /// Computes class logits from continuous and categorical inputs.
    ///
    /// `x_continuous` is reshaped to `[batch, n_continuous]`, so it must have that
    /// shape. Column `i` of `x_categorical` is looked up in the `i`-th embedding
    /// table. The result is `[batch, n_classes]`.
    pub fn forward(
        &self,
        x_continuous: Tensor<B, 2>,
        x_categorical: Tensor<B, 2, Int>,
    ) -> Tensor<B, 2> {
        let [batch, _] = x_continuous.dims();
        let mut features = Vec::with_capacity(1 + self.cat_embeddings.len());
        features.push(self.normalize_continuous(x_continuous));
        features.extend(self.cat_embeddings.iter().enumerate().map(|(i, embedding)| {
            let column = x_categorical.clone().slice([0..batch, i..(i + 1)]);
            // A `[batch, 1]` index column embeds to `[batch, 1, embed_dim]`;
            // drop the unit axis so it can be concatenated along dim 1.
            embedding.forward(column).reshape([batch, self.embed_dim])
        }));
        let combined = Tensor::cat(features, 1);
        self.run_mlp(combined)
    }

    /// Same as [`TabModel::forward`], followed by a softmax over the class dimension.
    pub fn forward_probs(
        &self,
        x_continuous: Tensor<B, 2>,
        x_categorical: Tensor<B, 2, Int>,
    ) -> Tensor<B, 2> {
        let logits = self.forward(x_continuous, x_categorical);
        softmax(logits, 1)
    }

    /// Computes class logits from continuous inputs only.
    ///
    /// # Panics
    ///
    /// Panics if the model has a non-zero categorical input width
    /// (`n_categorical * embed_dim > 0`). In that case the first layer expects
    /// embeddings to be concatenated, so use [`TabModel::forward`] instead.
    pub fn forward_continuous(&self, x_continuous: Tensor<B, 2>) -> Tensor<B, 2> {
        assert!(
            self.n_categorical * self.embed_dim == 0,
            "forward_continuous requires a model without categorical features \
             (n_categorical = {}, embed_dim = {}); use forward instead",
            self.n_categorical,
            self.embed_dim
        );
        let cont = self.normalize_continuous(x_continuous);
        self.run_mlp(cont)
    }

    /// Applies the continuous-feature batch norm, if configured, to a
    /// `[batch, n_continuous]` tensor. Without a batch norm the input is returned
    /// unchanged.
    fn normalize_continuous(&self, x_continuous: Tensor<B, 2>) -> Tensor<B, 2> {
        match &self.cont_bn {
            Some(bn) => {
                let [batch, _] = x_continuous.dims();
                // BatchNorm<B, 1> takes a rank-3 input, so add a unit trailing axis.
                let normed = bn.forward(x_continuous.reshape([batch, self.n_continuous, 1]));
                normed.reshape([batch, self.n_continuous])
            }
            None => x_continuous,
        }
    }

    /// Runs every hidden block in order, then the output head.
    fn run_mlp(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let hidden = self.blocks.iter().fold(x, |acc, block| block.forward(acc));
        self.head.forward(hidden)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tab_model_config_default() {
        let cfg = TabModelConfig::default();
        assert_eq!(
            (cfg.n_continuous, cfg.n_categorical, cfg.embed_dim),
            (10, 5, 8)
        );
        assert!(cfg.use_bn);
    }

    #[test]
    fn test_tab_model_config_new() {
        let TabModelConfig {
            n_continuous,
            n_categorical,
            n_classes,
            ..
        } = TabModelConfig::new(20, 8, 10);
        assert_eq!([n_continuous, n_categorical, n_classes], [20, 8, 10]);
    }

    #[test]
    fn test_tab_model_config_builder() {
        let cfg = TabModelConfig::new(10, 5, 3)
            .with_cardinalities(vec![5, 10, 15, 20, 25])
            .with_embed_dim(16)
            .with_hidden_sizes(vec![128, 64])
            .with_dropout(0.2)
            .with_bn(false);
        assert_eq!(cfg.cat_cardinalities, [5, 10, 15, 20, 25]);
        assert_eq!(cfg.embed_dim, 16);
        assert_eq!(cfg.hidden_sizes, [128, 64]);
        assert_eq!(cfg.dropout, 0.2);
        assert!(!cfg.use_bn);
    }
}