use burn::nn::{
conv::{Conv1d, Conv1dConfig},
pool::{AdaptiveAvgPool1d, AdaptiveAvgPool1dConfig, MaxPool1d, MaxPool1dConfig},
BatchNorm, BatchNormConfig, Linear, LinearConfig, Relu,
};
use burn::prelude::*;
use burn::tensor::activation::softmax;
use serde::{Deserialize, Serialize};
/// Hyper-parameters used to build an [`XceptionTime`] model.
///
/// Create one with [`XceptionTimeConfig::new`] or [`Default::default`] and
/// adjust it with the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct XceptionTimeConfig {
/// Number of input channels fed to the entry convolution.
pub n_vars: usize,
/// Length of the input sequences.
// NOTE(review): not read by any model-building code in this file; it appears
// to be informational only — confirm against callers.
pub seq_len: usize,
/// Output size of the final linear layer.
pub n_classes: usize,
/// Channel count produced by the entry convolution; every block width is a
/// multiple of this value.
pub n_filters: usize,
/// Kernel size of the entry convolution and of every depthwise convolution.
// NOTE(review): all of these convolutions use `PaddingConfig1d::Same`; some
// burn versions reject even kernel sizes with that padding — confirm for the
// burn version in use.
pub kernel_size: usize,
/// Number of [`XceptionBlock`]s stacked after the entry convolution.
pub n_blocks: usize,
/// Dropout probability.
// NOTE(review): no code in this file reads this field and no dropout layer is
// constructed — confirm whether that is intentional.
pub dropout: f64,
}
impl Default for XceptionTimeConfig {
    /// Baseline configuration: one input variable, sequences of length 100,
    /// two classes, 128 base filters, kernel size 39, four blocks and a
    /// dropout value of `0.0`.
    fn default() -> Self {
        // Problem shape.
        let (n_vars, seq_len, n_classes) = (1, 100, 2);
        // Architecture hyper-parameters.
        let (n_filters, kernel_size, n_blocks) = (128, 39, 4);

        Self {
            n_vars,
            seq_len,
            n_classes,
            n_filters,
            kernel_size,
            n_blocks,
            dropout: 0.0,
        }
    }
}
impl XceptionTimeConfig {
    /// Creates a configuration for the given problem shape; every other
    /// hyper-parameter takes its [`Default`] value.
    pub fn new(n_vars: usize, seq_len: usize, n_classes: usize) -> Self {
        let defaults = Self::default();
        Self {
            n_vars,
            seq_len,
            n_classes,
            ..defaults
        }
    }

    /// Returns this configuration with `n_filters` replaced.
    #[must_use]
    pub fn with_filters(self, n_filters: usize) -> Self {
        Self { n_filters, ..self }
    }

    /// Returns this configuration with `kernel_size` replaced.
    #[must_use]
    pub fn with_kernel_size(self, kernel_size: usize) -> Self {
        Self {
            kernel_size,
            ..self
        }
    }

    /// Returns this configuration with `n_blocks` replaced.
    #[must_use]
    pub fn with_n_blocks(self, n_blocks: usize) -> Self {
        Self { n_blocks, ..self }
    }

    /// Builds an [`XceptionTime`] model from a copy of this configuration on
    /// the given device.
    pub fn init<B: Backend>(&self, device: &B::Device) -> XceptionTime<B> {
        let config = self.clone();
        XceptionTime::new(config, device)
    }
}
/// Depthwise-separable 1D convolution: a grouped (one group per input channel)
/// convolution, then a kernel-size-1 pointwise convolution, then batch
/// normalization and ReLU.
#[derive(Module, Debug)]
pub struct SeparableConv1d<B: Backend> {
/// Grouped convolution mapping `in_channels` to `in_channels`, one group per
/// channel, `Same` padding, no bias.
depthwise: Conv1d<B>,
/// Kernel-size-1 convolution mapping `in_channels` to `out_channels`, no bias.
pointwise: Conv1d<B>,
/// Batch normalization over the `out_channels` produced by `pointwise`.
bn: BatchNorm<B, 1>,
}
impl<B: Backend> SeparableConv1d<B> {
    /// Builds a separable convolution mapping `in_channels` to `out_channels`.
    ///
    /// `kernel_size` applies to the depthwise stage only; the pointwise stage
    /// always uses a kernel of size 1. Neither convolution has a bias term.
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        device: &B::Device,
    ) -> Self {
        // Field initializers are evaluated top to bottom, so the layers are
        // created in depthwise -> pointwise -> batch-norm order.
        Self {
            depthwise: Conv1dConfig::new(in_channels, in_channels, kernel_size)
                .with_groups(in_channels)
                .with_bias(false)
                .with_padding(burn::nn::PaddingConfig1d::Same)
                .init(device),
            pointwise: Conv1dConfig::new(in_channels, out_channels, 1)
                .with_bias(false)
                .init(device),
            bn: BatchNormConfig::new(out_channels).init(device),
        }
    }

    /// Applies depthwise conv -> pointwise conv -> batch norm -> ReLU to `x`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let per_channel = self.depthwise.forward(x);
        let mixed = self.pointwise.forward(per_channel);
        Relu::new().forward(self.bn.forward(mixed))
    }
}
/// Residual block: two separable convolutions followed by a max-pool on the
/// main path, added to a kernel-size-1 convolution + batch-norm shortcut, with
/// a final ReLU.
#[derive(Module, Debug)]
pub struct XceptionBlock<B: Backend> {
/// First separable convolution (`in_channels` -> `out_channels`).
sep_conv1: SeparableConv1d<B>,
/// Second separable convolution (`out_channels` -> `out_channels`).
sep_conv2: SeparableConv1d<B>,
/// Shortcut projection: kernel-size-1 convolution from `in_channels` to
/// `out_channels`, no bias.
residual_conv: Conv1d<B>,
/// Batch normalization applied to the shortcut projection.
residual_bn: BatchNorm<B, 1>,
/// Max-pool with kernel size 3, stride 1 and `Same` padding on the main path.
maxpool: MaxPool1d,
}
impl<B: Backend> XceptionBlock<B> {
    /// Builds a block mapping `in_channels` to `out_channels`.
    ///
    /// `kernel_size` is forwarded to both separable convolutions; the shortcut
    /// projection always uses a kernel of size 1.
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        device: &B::Device,
    ) -> Self {
        // Field initializers are evaluated top to bottom, so the layers are
        // created in the same order as the fields are listed here.
        Self {
            sep_conv1: SeparableConv1d::new(in_channels, out_channels, kernel_size, device),
            sep_conv2: SeparableConv1d::new(out_channels, out_channels, kernel_size, device),
            residual_conv: Conv1dConfig::new(in_channels, out_channels, 1)
                .with_bias(false)
                .init(device),
            residual_bn: BatchNormConfig::new(out_channels).init(device),
            maxpool: MaxPool1dConfig::new(3)
                .with_stride(1)
                .with_padding(burn::nn::PaddingConfig1d::Same)
                .init(),
        }
    }

    /// Runs `x` through the main path and the shortcut, adds the two results
    /// and applies ReLU.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // Main path: separable conv -> separable conv -> max-pool.
        let features = self.sep_conv1.forward(x.clone());
        let main_path = self.maxpool.forward(self.sep_conv2.forward(features));

        // Shortcut: project the input to `out_channels` so it can be added.
        let shortcut = self.residual_bn.forward(self.residual_conv.forward(x));

        Relu::new().forward(main_path + shortcut)
    }
}
/// Classifier made of an entry convolution, a stack of [`XceptionBlock`]s,
/// adaptive average pooling down to length 1 and a linear output layer.
#[derive(Module, Debug)]
pub struct XceptionTime<B: Backend> {
/// Entry convolution mapping `n_vars` input channels to `n_filters`.
entry_conv: Conv1d<B>,
/// Batch normalization applied to the entry convolution output.
entry_bn: BatchNorm<B, 1>,
/// Residual blocks applied in order after the entry stage.
blocks: Vec<XceptionBlock<B>>,
/// Adaptive average pooling that reduces the last dimension to length 1.
gap: AdaptiveAvgPool1d,
/// Linear layer mapping the pooled channels to `n_classes` outputs.
fc: Linear<B>,
}
impl<B: Backend> XceptionTime<B> {
    /// Builds the network described by `config` on `device`.
    ///
    /// Layout: entry convolution and batch norm -> `n_blocks` residual blocks
    /// -> adaptive average pooling to length 1 -> linear head.
    ///
    /// Block `i` outputs `n_filters * min(2^i, 4)` channels: the width doubles
    /// with each block and saturates at four times `n_filters`.
    ///
    /// `n_blocks == 0` is accepted; the linear head is then attached directly
    /// to the `n_filters` channels produced by the entry convolution.
    ///
    /// `config.seq_len` and `config.dropout` are not used here.
    pub fn new(config: XceptionTimeConfig, device: &B::Device) -> Self {
        let entry_conv = Conv1dConfig::new(config.n_vars, config.n_filters, config.kernel_size)
            .with_padding(burn::nn::PaddingConfig1d::Same)
            .with_bias(false)
            .init(device);
        let entry_bn = BatchNormConfig::new(config.n_filters).init(device);

        let mut blocks = Vec::with_capacity(config.n_blocks);
        // Channel count flowing into the next layer; starts at the entry width
        // and is updated after every block, so it is also the head's input
        // width once the loop finishes (including when there are no blocks).
        let mut channels = config.n_filters;
        for block_idx in 0..config.n_blocks {
            // `1 << min(i, 2)` equals `min(2^i, 4)` but never computes a power
            // that could overflow `usize` for deep stacks.
            let out_channels = config.n_filters * (1_usize << block_idx.min(2));
            blocks.push(XceptionBlock::new(
                channels,
                out_channels,
                config.kernel_size,
                device,
            ));
            channels = out_channels;
        }

        let gap = AdaptiveAvgPool1dConfig::new(1).init();
        let fc = LinearConfig::new(channels, config.n_classes).init(device);

        Self {
            entry_conv,
            entry_bn,
            blocks,
            gap,
            fc,
        }
    }

    /// Computes unnormalized class scores (logits) for `x`.
    ///
    /// The result has shape `[batch, n_classes]`: the pooled tensor is
    /// reshaped to `[batch, channels]` before the linear layer.
    // NOTE(review): `x` is presumably laid out as `[batch, n_vars, seq_len]`
    // (the channel-first layout burn's `Conv1d` expects) — confirm with callers.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
        let mut out = self.entry_conv.forward(x);
        out = self.entry_bn.forward(out);
        out = Relu::new().forward(out);
        for block in &self.blocks {
            out = block.forward(out);
        }

        // Pool the last dimension down to length 1, then drop it.
        let out = self.gap.forward(out);
        let [batch, channels, _] = out.dims();
        let out = out.reshape([batch, channels]);

        self.fc.forward(out)
    }

    /// Computes class probabilities by applying softmax over dimension 1 of
    /// the logits returned by [`Self::forward`].
    pub fn forward_probs(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
        let logits = self.forward(x);
        softmax(logits, 1)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` yields the baseline architecture hyper-parameters.
    #[test]
    fn test_xception_config_default() {
        let XceptionTimeConfig {
            n_vars,
            n_filters,
            kernel_size,
            n_blocks,
            ..
        } = XceptionTimeConfig::default();

        assert_eq!(n_vars, 1);
        assert_eq!(n_filters, 128);
        assert_eq!(kernel_size, 39);
        assert_eq!(n_blocks, 4);
    }

    /// `new` plus the `with_*` builders set exactly the requested fields.
    #[test]
    fn test_xception_config_builder() {
        let built = XceptionTimeConfig::new(3, 200, 10)
            .with_filters(64)
            .with_kernel_size(15)
            .with_n_blocks(3);
        let XceptionTimeConfig {
            n_vars,
            seq_len,
            n_classes,
            n_filters,
            kernel_size,
            n_blocks,
            ..
        } = built;

        assert_eq!(n_vars, 3);
        assert_eq!(seq_len, 200);
        assert_eq!(n_classes, 10);
        assert_eq!(n_filters, 64);
        assert_eq!(kernel_size, 15);
        assert_eq!(n_blocks, 3);
    }
}