use burn::nn::{
conv::{Conv1d, Conv1dConfig},
pool::{AdaptiveAvgPool1d, AdaptiveAvgPool1dConfig},
BatchNorm, BatchNormConfig, Dropout, DropoutConfig, Embedding, EmbeddingConfig, Linear,
LinearConfig, Relu,
};
use burn::prelude::*;
use burn::tensor::activation::softmax;
use serde::{Deserialize, Serialize};
/// Architecture used to extract features from the time-series input.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq)]
pub enum BackboneType {
    /// Stack of `Conv1d -> BatchNorm -> ReLU` blocks with kernel size 3.
    CNN,
    /// Stack of residual blocks (two kernel-3 convolutions plus a skip connection).
    #[default]
    ResNet,
    /// Stack of `Conv1d -> BatchNorm -> ReLU` blocks with kernel size 8.
    FCN,
}
/// Strategy for combining time-series and tabular feature vectors.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq)]
pub enum FusionType {
    /// Concatenate the two feature vectors along the feature axis.
    #[default]
    Concat,
    /// Project both feature vectors to `final_hidden_dim` and sum them.
    Add,
    /// Concatenate, then scale element-wise by a learned sigmoid gate.
    Gated,
}
/// Hyper-parameters for [`MultiInputNet`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiInputNetConfig {
    /// Number of variables (input channels) in the time-series input.
    pub n_ts_vars: usize,
    /// Length of each input series. Not read by `MultiInputNet::new`; the backbone ends in
    /// global average pooling.
    pub ts_seq_len: usize,
    /// Number of continuous tabular features.
    pub n_continuous: usize,
    /// Number of categorical tabular features; each contributes `cat_embed_dim` inputs to the
    /// tabular MLP. NOTE(review): expected to equal `cat_cardinalities.len()` — not enforced.
    pub n_categorical: usize,
    /// Vocabulary size of each categorical feature; one embedding table is built per entry.
    pub cat_cardinalities: Vec<usize>,
    /// Embedding width used for every categorical feature.
    pub cat_embed_dim: usize,
    /// Number of output classes (width of the classification head).
    pub n_classes: usize,
    /// Time-series feature extractor architecture.
    pub backbone: BackboneType,
    /// How time-series and tabular features are combined.
    pub fusion: FusionType,
    /// Output channels of each backbone stage, in order.
    pub backbone_filters: Vec<usize>,
    /// Hidden width of the tabular MLP.
    pub tab_hidden_dim: usize,
    /// Width of the hidden layer before the head (and the projection width for `Add` fusion).
    pub final_hidden_dim: usize,
    /// Dropout probability applied in the tabular branch and before the head.
    pub dropout: f64,
}
impl Default for MultiInputNetConfig {
    /// Univariate series of length 100 with 10 continuous and 5 categorical features, binary
    /// output, a three-stage ResNet backbone and concatenation fusion.
    fn default() -> Self {
        Self {
            // Time-series input and task.
            n_ts_vars: 1,
            ts_seq_len: 100,
            n_classes: 2,
            // Tabular input.
            n_continuous: 10,
            n_categorical: 5,
            cat_cardinalities: vec![10, 20, 30, 40, 50],
            cat_embed_dim: 8,
            // Architecture.
            backbone: BackboneType::default(),
            fusion: FusionType::default(),
            backbone_filters: vec![64, 128, 256],
            tab_hidden_dim: 128,
            final_hidden_dim: 256,
            dropout: 0.1,
        }
    }
}
impl MultiInputNetConfig {
    /// Creates a config for `n_ts_vars` series of length `ts_seq_len` with `n_classes` outputs.
    /// All other fields take their [`Default`] values.
    pub fn new(n_ts_vars: usize, ts_seq_len: usize, n_classes: usize) -> Self {
        let base = Self::default();
        Self {
            n_ts_vars,
            ts_seq_len,
            n_classes,
            ..base
        }
    }

    /// Sets the number of continuous tabular features.
    #[must_use]
    pub fn with_n_continuous(self, n: usize) -> Self {
        Self {
            n_continuous: n,
            ..self
        }
    }

    /// Sets the categorical feature count and the vocabulary size of each feature.
    #[must_use]
    pub fn with_categorical(self, n_categorical: usize, cardinalities: Vec<usize>) -> Self {
        Self {
            n_categorical,
            cat_cardinalities: cardinalities,
            ..self
        }
    }

    /// Sets the embedding width shared by all categorical features.
    #[must_use]
    pub fn with_cat_embed_dim(self, dim: usize) -> Self {
        Self {
            cat_embed_dim: dim,
            ..self
        }
    }

    /// Selects the time-series backbone.
    #[must_use]
    pub fn with_backbone(self, backbone: BackboneType) -> Self {
        Self { backbone, ..self }
    }

    /// Selects the fusion strategy.
    #[must_use]
    pub fn with_fusion(self, fusion: FusionType) -> Self {
        Self { fusion, ..self }
    }

    /// Sets the output channels of each backbone stage.
    #[must_use]
    pub fn with_backbone_filters(self, filters: Vec<usize>) -> Self {
        Self {
            backbone_filters: filters,
            ..self
        }
    }

    /// Sets the hidden width of the tabular MLP.
    #[must_use]
    pub fn with_tab_hidden_dim(self, dim: usize) -> Self {
        Self {
            tab_hidden_dim: dim,
            ..self
        }
    }

    /// Sets the width of the hidden layer before the classification head.
    #[must_use]
    pub fn with_final_hidden_dim(self, dim: usize) -> Self {
        Self {
            final_hidden_dim: dim,
            ..self
        }
    }

    /// Sets the dropout probability.
    #[must_use]
    pub fn with_dropout(self, dropout: f64) -> Self {
        Self { dropout, ..self }
    }

    /// Builds a [`MultiInputNet`] from a copy of this config on `device`.
    pub fn init<B: Backend>(&self, device: &B::Device) -> MultiInputNet<B> {
        let config = self.clone();
        MultiInputNet::new(config, device)
    }
}
/// Residual block: `relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))`.
#[derive(Module, Debug)]
struct ResBlock<B: Backend> {
    conv1: Conv1d<B>,
    bn1: BatchNorm<B, 1>,
    conv2: Conv1d<B>,
    bn2: BatchNorm<B, 1>,
    // 1x1 projection on the skip path; `None` when input and output channel counts match.
    shortcut: Option<Conv1d<B>>,
}
impl<B: Backend> ResBlock<B> {
fn new(in_channels: usize, out_channels: usize, device: &B::Device) -> Self {
let conv1 = Conv1dConfig::new(in_channels, out_channels, 3)
.with_padding(burn::nn::PaddingConfig1d::Same)
.with_bias(false)
.init(device);
let bn1 = BatchNormConfig::new(out_channels).init(device);
let conv2 = Conv1dConfig::new(out_channels, out_channels, 3)
.with_padding(burn::nn::PaddingConfig1d::Same)
.with_bias(false)
.init(device);
let bn2 = BatchNormConfig::new(out_channels).init(device);
let shortcut = if in_channels != out_channels {
Some(
Conv1dConfig::new(in_channels, out_channels, 1)
.with_bias(false)
.init(device),
)
} else {
None
};
Self {
conv1,
bn1,
conv2,
bn2,
shortcut,
}
}
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
let relu = Relu::new();
let out = relu.forward(self.bn1.forward(self.conv1.forward(x.clone())));
let out = self.bn2.forward(self.conv2.forward(out));
let shortcut = match &self.shortcut {
Some(sc) => sc.forward(x),
None => x,
};
relu.forward(out + shortcut)
}
}
/// Plain convolutional stage: `relu(bn(conv(x)))`.
#[derive(Module, Debug)]
struct ConvBlock<B: Backend> {
    conv: Conv1d<B>,
    bn: BatchNorm<B, 1>,
}
impl<B: Backend> ConvBlock<B> {
    /// Builds a `Conv1d -> BatchNorm -> ReLU` stage.
    ///
    /// Odd kernels use `Same` padding, which preserves the sequence length. Recent burn
    /// releases reject `Same` padding for even kernel sizes at `Conv1dConfig::init`. The FCN
    /// backbone uses kernel 8, so even kernels instead get explicit symmetric padding of
    /// `kernel_size / 2`. That is the value older burn derived for `Same`. The output is then one
    /// step longer than the input; the global average pooling after the backbone absorbs this.
    fn new(in_channels: usize, out_channels: usize, kernel_size: usize, device: &B::Device) -> Self {
        let padding = if kernel_size % 2 == 0 {
            burn::nn::PaddingConfig1d::Explicit(kernel_size / 2)
        } else {
            burn::nn::PaddingConfig1d::Same
        };
        let conv = Conv1dConfig::new(in_channels, out_channels, kernel_size)
            .with_padding(padding)
            .init(device);
        let bn = BatchNormConfig::new(out_channels).init(device);
        Self { conv, bn }
    }

    /// Applies convolution, batch norm and ReLU in sequence.
    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        Relu::new().forward(self.bn.forward(self.conv.forward(x)))
    }
}
/// Classifier that fuses a time-series backbone with a tabular MLP branch.
///
/// The series goes through a stack of [`ResBlock`]s or [`ConvBlock`]s (per [`BackboneType`])
/// and global average pooling. Continuous features and embedded categorical features are
/// concatenated and passed through a two-layer MLP. The two feature vectors are combined per
/// [`FusionType`], then passed through a hidden layer and a linear head that emits logits.
#[derive(Module, Debug)]
pub struct MultiInputNet<B: Backend> {
    ts_res_blocks: Vec<ResBlock<B>>,
    ts_conv_blocks: Vec<ConvBlock<B>>,
    gap: AdaptiveAvgPool1d,
    cat_embeddings: Vec<Embedding<B>>,
    tab_fc1: Linear<B>,
    tab_fc2: Linear<B>,
    // Present only for `FusionType::Gated`.
    fusion_gate: Option<Linear<B>>,
    // Present only for `FusionType::Add`.
    ts_proj: Option<Linear<B>>,
    tab_proj: Option<Linear<B>>,
    final_fc: Linear<B>,
    head: Linear<B>,
    dropout: Dropout,
    #[module(skip)]
    use_resnet: bool,
    #[module(skip)]
    use_gated_fusion: bool,
    #[module(skip)]
    use_add_fusion: bool,
    // Whether any tabular input (continuous or categorical) was configured.
    #[module(skip)]
    has_tabular: bool,
    // Width of the tabular feature vector fed to fusion (0 without tabular inputs).
    #[module(skip)]
    tab_out_dim: usize,
}

impl<B: Backend> MultiInputNet<B> {
    /// Builds the network described by `config` on `device`.
    ///
    /// When the config has no tabular inputs (`n_continuous == 0` and `n_categorical == 0`),
    /// the tabular MLP is still allocated as a 1-wide placeholder but is never run.
    pub fn new(config: MultiInputNetConfig, device: &B::Device) -> Self {
        let mut ts_res_blocks = Vec::new();
        let mut ts_conv_blocks = Vec::new();
        let mut in_channels = config.n_ts_vars;
        for &out_channels in &config.backbone_filters {
            match config.backbone {
                BackboneType::ResNet => {
                    ts_res_blocks.push(ResBlock::new(in_channels, out_channels, device));
                }
                BackboneType::CNN | BackboneType::FCN => {
                    let kernel = if config.backbone == BackboneType::FCN { 8 } else { 3 };
                    ts_conv_blocks.push(ConvBlock::new(in_channels, out_channels, kernel, device));
                }
            }
            in_channels = out_channels;
        }
        // Channels actually leaving the backbone. With no stages configured this is the raw
        // input channel count, not a fixed fallback.
        let ts_out_dim = in_channels;
        let gap = AdaptiveAvgPool1dConfig::new(1).init();

        let cat_embeddings: Vec<_> = config
            .cat_cardinalities
            .iter()
            .map(|&card| EmbeddingConfig::new(card, config.cat_embed_dim).init(device))
            .collect();

        // Assumes each categorical tensor has a single column, so each contributes
        // `cat_embed_dim` features after embedding.
        let tab_in_dim = config.n_continuous + config.n_categorical * config.cat_embed_dim;
        let has_tabular = tab_in_dim > 0;
        let tab_out_dim = if has_tabular { config.tab_hidden_dim } else { 0 };
        // Placeholder widths keep every layer non-degenerate when there is no tabular input.
        let tab_fc_in = tab_in_dim.max(1);
        let tab_fc_out = if has_tabular { tab_out_dim } else { 1 };
        let tab_fc1 = LinearConfig::new(tab_fc_in, config.tab_hidden_dim).init(device);
        let tab_fc2 = LinearConfig::new(config.tab_hidden_dim, tab_fc_out).init(device);

        let (fusion_gate, ts_proj, tab_proj, final_in_dim) = match config.fusion {
            FusionType::Concat => (None, None, None, ts_out_dim + tab_out_dim),
            FusionType::Add => {
                let proj_dim = config.final_hidden_dim;
                let ts_proj = LinearConfig::new(ts_out_dim, proj_dim).init(device);
                let tab_proj = LinearConfig::new(tab_fc_out, proj_dim).init(device);
                (None, Some(ts_proj), Some(tab_proj), proj_dim)
            }
            FusionType::Gated => {
                let combined_dim = ts_out_dim + tab_out_dim;
                let gate_dim = combined_dim.max(1);
                let gate = LinearConfig::new(gate_dim, gate_dim).init(device);
                (Some(gate), None, None, combined_dim)
            }
        };

        let final_fc = LinearConfig::new(final_in_dim, config.final_hidden_dim).init(device);
        let head = LinearConfig::new(config.final_hidden_dim, config.n_classes).init(device);
        let dropout = DropoutConfig::new(config.dropout).init();

        Self {
            ts_res_blocks,
            ts_conv_blocks,
            gap,
            cat_embeddings,
            tab_fc1,
            tab_fc2,
            fusion_gate,
            ts_proj,
            tab_proj,
            final_fc,
            head,
            dropout,
            use_resnet: config.backbone == BackboneType::ResNet,
            use_gated_fusion: config.fusion == FusionType::Gated,
            use_add_fusion: config.fusion == FusionType::Add,
            has_tabular,
            tab_out_dim,
        }
    }

    /// Classifies from the time-series input alone and returns logits `[batch, n_classes]`.
    ///
    /// The series still goes through the configured fusion stage, so projections and gates
    /// learned during training are applied. If the network has a tabular branch, that branch
    /// is treated as absent. Concat and gated fusion see an all-zero tabular feature vector;
    /// additive fusion drops the tabular term.
    pub fn forward_ts_only(&self, ts: Tensor<B, 3>) -> Tensor<B, 2> {
        let ts_features = self.extract_ts_features(ts);
        let fused = self.fuse(ts_features, None);
        self.classify(fused)
    }

    /// Runs the backbone and global average pooling, producing `[batch, channels]`.
    fn extract_ts_features(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
        let mut out = x;
        if self.use_resnet {
            for block in &self.ts_res_blocks {
                out = block.forward(out);
            }
        } else {
            for block in &self.ts_conv_blocks {
                out = block.forward(out);
            }
        }
        let out = self.gap.forward(out);
        let [batch, channels, _] = out.dims();
        out.reshape([batch, channels])
    }

    /// Embeds categorical inputs, concatenates them after `continuous` and runs the tabular MLP.
    ///
    /// Each categorical tensor's embedding is flattened to `[batch, seq * cat_embed_dim]`.
    /// Categorical tensors beyond the number of embedding tables are ignored.
    fn extract_tab_features(
        &self,
        continuous: Tensor<B, 2>,
        categorical: Vec<Tensor<B, 2, burn::tensor::Int>>,
    ) -> Tensor<B, 2> {
        let batch_size = continuous.dims()[0];
        let mut tab_features = Vec::with_capacity(1 + self.cat_embeddings.len());
        tab_features.push(continuous);
        for (embedding, cat_tensor) in self.cat_embeddings.iter().zip(categorical) {
            let embedded = embedding.forward(cat_tensor);
            let [_, seq, embed_dim] = embedded.dims();
            tab_features.push(embedded.reshape([batch_size, seq * embed_dim]));
        }
        let combined = if tab_features.len() > 1 {
            Tensor::cat(tab_features, 1)
        } else {
            tab_features.into_iter().next().unwrap()
        };
        let out = Relu::new().forward(self.tab_fc1.forward(combined));
        self.dropout.forward(Relu::new().forward(self.tab_fc2.forward(out)))
    }

    /// Combines time-series features with optional tabular features per the fusion strategy.
    fn fuse(&self, ts_features: Tensor<B, 2>, tab_features: Option<Tensor<B, 2>>) -> Tensor<B, 2> {
        if self.use_add_fusion {
            let ts_proj = self
                .ts_proj
                .as_ref()
                .expect("additive fusion always builds ts_proj");
            let projected = ts_proj.forward(ts_features);
            match tab_features {
                Some(tab) => {
                    let tab_proj = self
                        .tab_proj
                        .as_ref()
                        .expect("additive fusion always builds tab_proj");
                    projected + tab_proj.forward(tab)
                }
                None => projected,
            }
        } else {
            let combined = self.concat_features(ts_features, tab_features);
            if self.use_gated_fusion {
                let gate = self
                    .fusion_gate
                    .as_ref()
                    .expect("gated fusion always builds fusion_gate");
                let gate_values = burn::tensor::activation::sigmoid(gate.forward(combined.clone()));
                combined * gate_values
            } else {
                combined
            }
        }
    }

    /// Concatenates time-series and tabular features along the feature axis.
    ///
    /// If `tab` is `None` and the network has a tabular branch, zeros of width `tab_out_dim`
    /// are used so the result matches the width `final_fc` / the gate were built for.
    fn concat_features(&self, ts: Tensor<B, 2>, tab: Option<Tensor<B, 2>>) -> Tensor<B, 2> {
        let tab = match tab {
            Some(tab) => tab,
            None if self.tab_out_dim > 0 => {
                let [batch, _] = ts.dims();
                Tensor::zeros([batch, self.tab_out_dim], &ts.device())
            }
            None => return ts,
        };
        Tensor::cat(vec![ts, tab], 1)
    }

    /// Hidden layer (ReLU + dropout) followed by the linear head.
    fn classify(&self, fused: Tensor<B, 2>) -> Tensor<B, 2> {
        let hidden = self
            .dropout
            .forward(Relu::new().forward(self.final_fc.forward(fused)));
        self.head.forward(hidden)
    }

    /// Full forward pass returning logits `[batch, n_classes]`.
    ///
    /// `ts` is `[batch, n_ts_vars, length]`. `continuous` holds the continuous features
    /// (presumably `[batch, n_continuous]`). `categorical` holds one integer tensor per
    /// categorical feature. When the network was configured without tabular inputs,
    /// `continuous` and `categorical` are ignored.
    pub fn forward(
        &self,
        ts: Tensor<B, 3>,
        continuous: Tensor<B, 2>,
        categorical: Vec<Tensor<B, 2, burn::tensor::Int>>,
    ) -> Tensor<B, 2> {
        let ts_features = self.extract_ts_features(ts);
        let tab_features = self
            .has_tabular
            .then(|| self.extract_tab_features(continuous, categorical));
        let fused = self.fuse(ts_features, tab_features);
        self.classify(fused)
    }

    /// Same as [`Self::forward`] but returns class probabilities (softmax over dim 1).
    pub fn forward_probs(
        &self,
        ts: Tensor<B, 3>,
        continuous: Tensor<B, 2>,
        categorical: Vec<Tensor<B, 2, burn::tensor::Int>>,
    ) -> Tensor<B, 2> {
        let logits = self.forward(ts, continuous, categorical);
        softmax(logits, 1)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn_ndarray::NdArray;

    type TestBackend = NdArray<f32>;

    /// 3-variable series of length 100, 5 classes, 10 continuous + 2 categorical features.
    fn tabular_config(fusion: FusionType) -> MultiInputNetConfig {
        MultiInputNetConfig::new(3, 100, 5)
            .with_n_continuous(10)
            .with_categorical(2, vec![5, 10])
            .with_fusion(fusion)
    }

    #[test]
    fn test_multi_input_config() {
        let cfg = MultiInputNetConfig::default();
        assert_eq!(cfg.backbone, BackboneType::ResNet);
        assert_eq!(cfg.fusion, FusionType::Concat);
        assert_eq!(cfg.n_ts_vars, 1);
    }

    #[test]
    fn test_multi_input_forward_ts_only() {
        let device = Default::default();
        let model: MultiInputNet<TestBackend> = MultiInputNetConfig::new(3, 100, 5)
            .with_n_continuous(0)
            .with_categorical(0, vec![])
            .init(&device);
        let series = Tensor::<TestBackend, 3>::zeros([4, 3, 100], &device);
        assert_eq!(model.forward_ts_only(series).dims(), [4, 5]);
    }

    #[test]
    fn test_multi_input_fusion_types() {
        let device = Default::default();
        let series = Tensor::<TestBackend, 3>::zeros([4, 3, 100], &device);
        let continuous = Tensor::<TestBackend, 2>::zeros([4, 10], &device);
        let categorical: Vec<_> = (0..2)
            .map(|_| Tensor::<TestBackend, 2, burn::tensor::Int>::zeros([4, 1], &device))
            .collect();

        for fusion in [FusionType::Concat, FusionType::Add, FusionType::Gated] {
            let model: MultiInputNet<TestBackend> = tabular_config(fusion).init(&device);
            let logits = model.forward(series.clone(), continuous.clone(), categorical.clone());
            assert_eq!(logits.dims(), [4, 5]);
        }
    }
}