use burn::nn::{
conv::{Conv1d, Conv1dConfig},
pool::{AdaptiveAvgPool1d, AdaptiveAvgPool1dConfig, MaxPool1d, MaxPool1dConfig},
BatchNorm, BatchNormConfig, Linear, LinearConfig, Relu,
};
use burn::prelude::*;
use burn::tensor::activation::softmax;
use serde::{Deserialize, Serialize};
/// Hyperparameters for [`InceptionTimePlus`], a time-series classifier built
/// from stacked inception blocks with a residual shortcut after every third block.
///
/// Build one with [`InceptionTimePlusConfig::new`] (data dimensions plus the
/// remaining defaults) or [`Default`], then call `init` to create the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InceptionTimePlusConfig {
/// Number of input channels (variables); used as the input channel count of
/// the first block and of the first residual shortcut.
pub n_vars: usize,
/// Intended input sequence length.
/// NOTE(review): not read when building or running the model in this file; the
/// convolutions use same padding and the head uses adaptive pooling, so the
/// model appears to accept any length — confirm whether this is still needed.
pub seq_len: usize,
/// Number of output classes (output width of the final linear layer).
pub n_classes: usize,
/// Number of stacked inception blocks. A residual shortcut is added after
/// every third block, so blocks past the last multiple of 3 get none.
pub n_blocks: usize,
/// Filters per branch; each block has four branches, so it outputs
/// `4 * n_filters` channels.
pub n_filters: usize,
/// Kernel sizes of the three parallel convolutions in each block; all use
/// same padding.
/// NOTE(review): burn's same padding may reject even kernel sizes — confirm.
pub kernel_sizes: [usize; 3],
/// Output channels of the bias-free 1x1 bottleneck conv that feeds the three
/// kernel-size branches; `0` disables the bottleneck.
pub bottleneck_dim: usize,
/// Dropout probability.
/// NOTE(review): not read anywhere in this file — the model currently applies
/// no dropout regardless of this value. Either wire it in or remove it.
pub dropout: f64,
}
impl Default for InceptionTimePlusConfig {
    /// Baseline hyperparameters: univariate input, two classes, six blocks of
    /// 32 filters per branch with a 32-channel bottleneck, and `dropout = 0.0`.
    fn default() -> Self {
        let kernel_sizes = [9, 19, 39];
        Self {
            n_blocks: 6,
            n_filters: 32,
            kernel_sizes,
            bottleneck_dim: 32,
            dropout: 0.0,
            n_vars: 1,
            seq_len: 100,
            n_classes: 2,
        }
    }
}
impl InceptionTimePlusConfig {
    /// Builds a config for the given data dimensions; every architecture
    /// hyperparameter is taken from [`Default`].
    pub fn new(n_vars: usize, seq_len: usize, n_classes: usize) -> Self {
        let base = Self::default();
        Self {
            n_vars,
            seq_len,
            n_classes,
            ..base
        }
    }

    /// Instantiates the model described by this config on `device`.
    pub fn init<B: Backend>(&self, device: &B::Device) -> InceptionTimePlus<B> {
        let config = self.clone();
        InceptionTimePlus::new(config, device)
    }
}
/// One inception block. It has four branches:
/// - three parallel same-padded convolutions over an optional bottleneck
///   projection of the input;
/// - a max-pool followed by a 1x1 conv over the raw input.
///
/// The branch outputs are concatenated along dim 1 (channels), then passed
/// through batch norm and ReLU.
#[derive(Module, Debug)]
pub struct InceptionBlock<B: Backend> {
/// Bias-free 1x1 conv projecting the input to `bottleneck_dim` channels;
/// `None` when `bottleneck_dim == 0`.
bottleneck: Option<Conv1d<B>>,
/// Branch convs for `kernel_sizes[0]`, `[1]` and `[2]` respectively; they read
/// the bottleneck output (or the raw input when there is no bottleneck).
conv1: Conv1d<B>,
conv2: Conv1d<B>,
conv3: Conv1d<B>,
/// Kernel-3, stride-1, same-padded max pool applied to the raw block input.
maxpool: MaxPool1d,
/// 1x1 conv mapping the pooled input to `n_filters` channels.
conv_maxpool: Conv1d<B>,
/// Batch norm over the `4 * n_filters` concatenated channels.
bn: BatchNorm<B, 1>,
}
impl<B: Backend> InceptionBlock<B> {
pub fn new(
in_channels: usize,
n_filters: usize,
kernel_sizes: [usize; 3],
bottleneck_dim: usize,
device: &B::Device,
) -> Self {
let (conv_in, bottleneck) = if bottleneck_dim > 0 {
let bn_conv = Conv1dConfig::new(in_channels, bottleneck_dim, 1)
.with_bias(false)
.init(device);
(bottleneck_dim, Some(bn_conv))
} else {
(in_channels, None)
};
let conv1 = Conv1dConfig::new(conv_in, n_filters, kernel_sizes[0])
.with_padding(burn::nn::PaddingConfig1d::Same)
.with_bias(false)
.init(device);
let conv2 = Conv1dConfig::new(conv_in, n_filters, kernel_sizes[1])
.with_padding(burn::nn::PaddingConfig1d::Same)
.with_bias(false)
.init(device);
let conv3 = Conv1dConfig::new(conv_in, n_filters, kernel_sizes[2])
.with_padding(burn::nn::PaddingConfig1d::Same)
.with_bias(false)
.init(device);
let maxpool = MaxPool1dConfig::new(3)
.with_padding(burn::nn::PaddingConfig1d::Same)
.with_stride(1)
.init();
let conv_maxpool = Conv1dConfig::new(in_channels, n_filters, 1)
.with_bias(false)
.init(device);
let out_channels = n_filters * 4;
let bn = BatchNormConfig::new(out_channels).init(device);
Self {
bottleneck,
conv1,
conv2,
conv3,
maxpool,
conv_maxpool,
bn,
}
}
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
let x_bn = if let Some(ref bottleneck) = self.bottleneck {
bottleneck.forward(x.clone())
} else {
x.clone()
};
let out1 = self.conv1.forward(x_bn.clone());
let out2 = self.conv2.forward(x_bn.clone());
let out3 = self.conv3.forward(x_bn);
let out_pool = self.maxpool.forward(x);
let out_pool = self.conv_maxpool.forward(out_pool);
let out = Tensor::cat(vec![out1, out2, out3, out_pool], 1);
let out = self.bn.forward(out);
Relu::new().forward(out)
}
}
/// InceptionTime-style classifier with these stages:
/// - stacked [`InceptionBlock`]s, with a residual shortcut (1x1 conv + batch norm)
///   merged in after every third block;
/// - global average pooling;
/// - a linear head producing class logits.
#[derive(Module, Debug)]
pub struct InceptionTimePlus<B: Backend> {
/// Inception blocks, applied in order.
blocks: Vec<InceptionBlock<B>>,
/// Bias-free 1x1 shortcut convs, one per completed group of three blocks.
residual_convs: Vec<Conv1d<B>>,
/// Batch norms applied to each shortcut conv's output (index-aligned with
/// `residual_convs`).
residual_bns: Vec<BatchNorm<B, 1>>,
/// Adaptive average pool reducing the last axis to length 1.
gap: AdaptiveAvgPool1d,
/// Final linear layer mapping pooled features to class logits.
fc: Linear<B>,
}
impl<B: Backend> InceptionTimePlus<B> {
    /// Builds the network described by `config` on `device`.
    ///
    /// Layers are created in a fixed order:
    /// - each block;
    /// - after every third block, its shortcut conv and batch norm;
    /// - finally the pooling layer and the linear head.
    ///
    /// The head's input width is the channel count actually produced by the
    /// block stack. That is `4 * n_filters`, or `n_vars` when `n_blocks == 0`,
    /// in which case the input goes straight to pooling.
    pub fn new(config: InceptionTimePlusConfig, device: &B::Device) -> Self {
        let n_filters = config.n_filters;
        // Every inception block concatenates four branches of `n_filters` channels.
        let out_channels = n_filters * 4;
        let n_shortcuts = config.n_blocks / 3;

        let mut blocks = Vec::with_capacity(config.n_blocks);
        let mut residual_convs = Vec::with_capacity(n_shortcuts);
        let mut residual_bns = Vec::with_capacity(n_shortcuts);

        // Channel count of the tensor reaching the next layer; starts at the raw input.
        let mut channels = config.n_vars;
        for i in 0..config.n_blocks {
            blocks.push(InceptionBlock::new(
                channels,
                n_filters,
                config.kernel_sizes,
                config.bottleneck_dim,
                device,
            ));
            channels = out_channels;

            if (i + 1) % 3 == 0 {
                // The first shortcut projects the raw input; later ones project the
                // previous shortcut's output, which already has `out_channels`.
                let res_in = if i < 3 { config.n_vars } else { out_channels };
                residual_convs.push(
                    Conv1dConfig::new(res_in, out_channels, 1)
                        .with_bias(false)
                        .init(device),
                );
                residual_bns.push(BatchNormConfig::new(out_channels).init(device));
            }
        }

        let gap = AdaptiveAvgPool1dConfig::new(1).init();
        // Sized from the real feature width. A hard-coded `out_channels` panics in
        // `forward` when there are no blocks.
        let fc = LinearConfig::new(channels, config.n_classes).init(device);

        Self {
            blocks,
            residual_convs,
            residual_bns,
            gap,
            fc,
        }
    }

    /// Computes class logits of shape `[batch, n_classes]`.
    ///
    /// `x` is assumed to follow burn's `Conv1d` layout `[batch, n_vars, length]`.
    /// After every third block, the shortcut branch (1x1 conv + batch norm) is
    /// added to the block output and a ReLU is applied. The shortcut reads:
    /// - the original input, for the first shortcut;
    /// - the previous shortcut's result, for later ones.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
        let mut shortcuts = self.residual_convs.iter().zip(&self.residual_bns);
        let mut residual = x.clone();
        let mut out = x;
        for (i, block) in self.blocks.iter().enumerate() {
            out = block.forward(out);
            if (i + 1) % 3 != 0 {
                continue;
            }
            if let Some((conv, bn)) = shortcuts.next() {
                // `residual` is moved here and re-assigned below before any further use.
                let shortcut = bn.forward(conv.forward(residual));
                out = Relu::new().forward(out + shortcut);
                residual = out.clone();
            }
        }

        // Average over the last axis, then drop the resulting length-1 dimension.
        let pooled = self.gap.forward(out);
        let [batch, channels, _] = pooled.dims();
        self.fc.forward(pooled.reshape([batch, channels]))
    }

    /// Class probabilities: a softmax over dim 1 of [`Self::forward`]'s logits.
    pub fn forward_probs(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
        let logits = self.forward(x);
        softmax(logits, 1)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config should carry the documented block and filter counts.
    #[test]
    fn test_config_default() {
        let InceptionTimePlusConfig {
            n_blocks,
            n_filters,
            ..
        } = InceptionTimePlusConfig::default();
        assert_eq!(n_blocks, 6);
        assert_eq!(n_filters, 32);
    }
}