use burn::nn::{
pool::{AdaptiveAvgPool1d, AdaptiveAvgPool1dConfig},
BatchNorm, BatchNormConfig, Dropout, DropoutConfig, Linear, LinearConfig, Relu,
};
use burn::prelude::*;
use burn::tensor::activation::softmax;
use serde::{Deserialize, Serialize};
/// Hyper-parameters for the [`MLP`] classifier.
///
/// Create one with [`MLPConfig::new`] (or `MLPConfig::default()`), refine it
/// with the `with_*` builder methods, then build the network with
/// [`MLPConfig::init`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLPConfig {
/// Number of input variables (channels): the middle dimension of the
/// `[batch, n_vars, seq_len]` tensor passed to [`MLP::forward`].
pub n_vars: usize,
/// Number of steps per variable: the last input dimension. It only changes the
/// first layer's width when the input is flattened (i.e. `pool != "gap"`).
pub seq_len: usize,
/// Number of output logits produced by the classification head.
pub n_classes: usize,
/// Output widths of the hidden blocks, applied in order.
pub hidden_sizes: Vec<usize>,
/// Drop probability given to the dropout layer of every hidden block.
pub dropout: f64,
/// When true, each hidden block applies batch normalization after its ReLU.
pub use_bn: bool,
/// Input reduction mode. `"gap"` average-pools each variable over its
/// `seq_len` axis. Any other value (the default is `"flatten"`) flattens
/// variables and steps into a single feature vector.
pub pool: String,
}
impl Default for MLPConfig {
    /// Baseline configuration: one input variable of 100 steps and two classes.
    ///
    /// It has three hidden layers of width 500 with 10% dropout, no batch
    /// normalization, and flattening (rather than pooling) of the input.
    fn default() -> Self {
        let hidden_sizes = vec![500; 3];
        let pool = String::from("flatten");
        Self {
            pool,
            use_bn: false,
            dropout: 0.1,
            hidden_sizes,
            n_classes: 2,
            seq_len: 100,
            n_vars: 1,
        }
    }
}
impl MLPConfig {
    /// Creates a configuration for `n_vars` input variables of length `seq_len`
    /// and `n_classes` outputs.
    ///
    /// Every other field takes its `MLPConfig::default()` value.
    pub fn new(n_vars: usize, seq_len: usize, n_classes: usize) -> Self {
        let base = Self::default();
        Self {
            n_vars,
            seq_len,
            n_classes,
            ..base
        }
    }

    /// Replaces the hidden-layer widths.
    #[must_use]
    pub fn with_hidden_sizes(self, sizes: Vec<usize>) -> Self {
        Self {
            hidden_sizes: sizes,
            ..self
        }
    }

    /// Replaces the dropout probability used by every hidden block.
    #[must_use]
    pub fn with_dropout(self, dropout: f64) -> Self {
        Self { dropout, ..self }
    }

    /// Enables or disables batch normalization in the hidden blocks.
    #[must_use]
    pub fn with_bn(self, use_bn: bool) -> Self {
        Self { use_bn, ..self }
    }

    /// Sets the input reduction mode (`"gap"` or `"flatten"`).
    #[must_use]
    pub fn with_pool(self, pool: &str) -> Self {
        Self {
            pool: pool.to_owned(),
            ..self
        }
    }

    /// Builds an [`MLP`] from a copy of this configuration on `device`.
    pub fn init<B: Backend>(&self, device: &B::Device) -> MLP<B> {
        let config = self.clone();
        MLP::new(config, device)
    }
}
/// One hidden layer of [`MLP`]: `Linear -> ReLU -> (optional BatchNorm) -> Dropout`.
#[derive(Module, Debug)]
struct MLPBlock<B: Backend> {
/// Fully connected projection from the previous width to this block's width.
linear: Linear<B>,
/// Batch normalization over the block's output features; `None` when disabled.
bn: Option<BatchNorm<B, 1>>,
/// Dropout applied to the block's output.
dropout: Dropout,
}
impl<B: Backend> MLPBlock<B> {
    /// Builds a hidden block mapping `in_features` to `out_features`.
    ///
    /// A batch-norm layer over `out_features` channels is created only when
    /// `use_bn` is true. `dropout` is the probability given to `DropoutConfig::new`.
    fn new(
        in_features: usize,
        out_features: usize,
        dropout: f64,
        use_bn: bool,
        device: &B::Device,
    ) -> Self {
        // Build the linear layer first so layer construction order stays fixed.
        let linear = LinearConfig::new(in_features, out_features).init(device);
        let bn = use_bn.then(|| BatchNormConfig::new(out_features).init(device));
        Self {
            linear,
            bn,
            dropout: DropoutConfig::new(dropout).init(),
        }
    }

    /// Applies `Linear -> ReLU -> optional BatchNorm -> Dropout` to a
    /// `[batch, features]` tensor.
    fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let activated = Relu::new().forward(self.linear.forward(x));
        let normalized = match &self.bn {
            Some(bn) => {
                // `BatchNorm<_, 1>` works on rank-3 tensors, so treat the
                // features as channels with a length-1 trailing axis.
                let [batch, features] = activated.dims();
                bn.forward(activated.reshape([batch, features, 1]))
                    .reshape([batch, features])
            }
            None => activated,
        };
        self.dropout.forward(normalized)
    }
}
/// Multilayer-perceptron classifier for `[batch, n_vars, seq_len]` inputs.
///
/// The input is either flattened or globally average-pooled (see
/// [`MLPConfig::pool`]). It then passes through the hidden blocks, and a final
/// linear head projects it to `n_classes` logits.
#[derive(Module, Debug)]
pub struct MLP<B: Backend> {
/// Global average pooling layer; `Some` only when the config's `pool` is `"gap"`.
gap: Option<AdaptiveAvgPool1d>,
/// Hidden blocks, applied in order.
blocks: Vec<MLPBlock<B>>,
/// Final projection to class logits.
head: Linear<B>,
/// Width of the feature vector fed into the first layer: `n_vars` when pooling,
/// `n_vars * seq_len` when flattening. It is a plain value, not a learnable
/// parameter.
#[module(skip)]
in_features: usize,
/// Whether the config requested global average pooling; set together with `gap`.
#[module(skip)]
use_gap: bool,
}
impl<B: Backend> MLP<B> {
    /// Builds the network described by `config` and allocates its layers on `device`.
    ///
    /// - `config.pool == "gap"` selects global average pooling, so the first
    ///   layer sees `n_vars` features.
    /// - Any other `pool` value silently falls back to flattening the input into
    ///   `n_vars * seq_len` features.
    /// - Each entry of `config.hidden_sizes` adds one hidden block.
    /// - A final linear head maps the last width to `config.n_classes` logits.
    pub fn new(config: MLPConfig, device: &B::Device) -> Self {
        let use_gap = config.pool == "gap";
        let (gap, in_features) = if use_gap {
            (Some(AdaptiveAvgPool1dConfig::new(1).init()), config.n_vars)
        } else {
            (None, config.n_vars * config.seq_len)
        };

        // Chain the hidden blocks: each one consumes the previous block's width.
        let mut blocks = Vec::with_capacity(config.hidden_sizes.len());
        let mut prev_size = in_features;
        for &hidden_size in &config.hidden_sizes {
            blocks.push(MLPBlock::new(
                prev_size,
                hidden_size,
                config.dropout,
                config.use_bn,
                device,
            ));
            prev_size = hidden_size;
        }
        // With no hidden blocks, the head reads the input features directly.
        let head = LinearConfig::new(prev_size, config.n_classes).init(device);

        Self {
            gap,
            blocks,
            head,
            in_features,
            use_gap,
        }
    }

    /// Computes class logits of shape `[batch, n_classes]` for an input of shape
    /// `[batch, n_vars, seq_len]`.
    ///
    /// NOTE(review): the input's `n_vars`/`seq_len` must match the configuration
    /// the model was built with. A mismatch presumably surfaces as a shape panic
    /// inside the first linear layer; confirm against Burn's behavior.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
        let [batch_size, n_vars, seq_len] = x.dims();
        // The optional pooling layer is the single source of truth for the
        // reduction mode, so no `unwrap` (and no way to panic) is needed here.
        let features = match &self.gap {
            // Pooling to length 1 leaves `[batch, n_vars, 1]`; drop the unit axis.
            Some(gap) => gap.forward(x).reshape([batch_size, n_vars]),
            None => x.reshape([batch_size, n_vars * seq_len]),
        };
        let hidden = self
            .blocks
            .iter()
            .fold(features, |acc, block| block.forward(acc));
        self.head.forward(hidden)
    }

    /// Like [`MLP::forward`], but returns class probabilities.
    ///
    /// Applies a softmax over dimension 1, the class axis of the logits.
    pub fn forward_probs(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
        softmax(self.forward(x), 1)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration matches the documented baseline.
    #[test]
    fn test_mlp_config_default() {
        let MLPConfig {
            hidden_sizes,
            dropout,
            use_bn,
            ..
        } = MLPConfig::default();
        assert_eq!(hidden_sizes, [500, 500, 500]);
        assert_eq!(dropout, 0.1);
        assert!(!use_bn);
    }

    /// `new` stores the three required dimensions verbatim.
    #[test]
    fn test_mlp_config_new() {
        let MLPConfig {
            n_vars,
            seq_len,
            n_classes,
            ..
        } = MLPConfig::new(3, 200, 10);
        assert_eq!((n_vars, seq_len, n_classes), (3, 200, 10));
    }

    /// Every builder method overrides exactly its own field.
    #[test]
    fn test_mlp_config_builder() {
        let built = MLPConfig::new(3, 100, 5)
            .with_hidden_sizes(vec![256, 128])
            .with_dropout(0.3)
            .with_bn(true)
            .with_pool("gap");
        assert_eq!(built.hidden_sizes, [256, 128]);
        assert_eq!(built.dropout, 0.3);
        assert!(built.use_bn);
        assert_eq!(built.pool.as_str(), "gap");
    }
}