use burn::nn::{
conv::{Conv1d, Conv1dConfig},
Dropout, DropoutConfig, Linear, LinearConfig, Relu,
};
use burn::prelude::*;
use burn::tensor::activation::softmax;
use serde::{Deserialize, Serialize};
/// Configuration for a single [`TCNBlock`].
///
/// Holds the channel counts, kernel size, dilation and dropout probability
/// that `TCNBlock::new` reads when it builds the block's layers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TCNBlockConfig {
/// Input channel count of the block's first convolution (and of the
/// shortcut projection, when one is created).
pub in_channels: usize,
/// Output channel count of both convolutions.
pub out_channels: usize,
/// Kernel size shared by both convolutions.
pub kernel_size: usize,
/// Dilation shared by both convolutions.
pub dilation: usize,
/// Probability handed to `DropoutConfig::new`.
pub dropout: f64,
}
impl TCNBlockConfig {
    /// Creates a block configuration from the channel counts, kernel size and
    /// dilation. The dropout probability starts at `0.1`; change it with
    /// [`Self::with_dropout`].
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        dilation: usize,
    ) -> Self {
        let dropout = 0.1;
        Self {
            in_channels,
            out_channels,
            kernel_size,
            dilation,
            dropout,
        }
    }

    /// Returns this configuration with its dropout probability replaced by
    /// `dropout`; every other field is carried over unchanged.
    #[must_use]
    pub fn with_dropout(self, dropout: f64) -> Self {
        Self { dropout, ..self }
    }

    /// Builds a [`TCNBlock`] on `device` from a copy of this configuration.
    pub fn init<B: Backend>(&self, device: &B::Device) -> TCNBlock<B> {
        let config = self.clone();
        TCNBlock::new(config, device)
    }
}
/// One residual block of the temporal convolutional network.
///
/// `forward` runs `conv1` and `conv2` in sequence (each followed by ReLU and
/// dropout), adds a shortcut of the block input, and applies a final ReLU.
#[derive(Module, Debug)]
pub struct TCNBlock<B: Backend> {
/// First dilated convolution (`in_channels` -> `out_channels`).
conv1: Conv1d<B>,
/// Second dilated convolution (`out_channels` -> `out_channels`).
conv2: Conv1d<B>,
/// Dropout applied after the ReLU that follows each convolution.
dropout: Dropout,
/// Kernel-size-1 convolution applied to the shortcut path; only present
/// when the input and output channel counts differ.
residual: Option<Conv1d<B>>,
/// Explicit padding given to both convolutions,
/// `(kernel_size - 1) * dilation`. Stored at construction; `forward` does
/// not read it.
#[module(skip)]
padding: usize,
}
impl<B: Backend> TCNBlock<B> {
pub fn new(config: TCNBlockConfig, device: &B::Device) -> Self {
let padding = (config.kernel_size - 1) * config.dilation;
let conv1 = Conv1dConfig::new(config.in_channels, config.out_channels, config.kernel_size)
.with_dilation(config.dilation)
.with_padding(burn::nn::PaddingConfig1d::Explicit(padding))
.init(device);
let conv2 = Conv1dConfig::new(config.out_channels, config.out_channels, config.kernel_size)
.with_dilation(config.dilation)
.with_padding(burn::nn::PaddingConfig1d::Explicit(padding))
.init(device);
let residual = if config.in_channels != config.out_channels {
Some(
Conv1dConfig::new(config.in_channels, config.out_channels, 1)
.init(device),
)
} else {
None
};
let dropout = DropoutConfig::new(config.dropout).init();
Self {
conv1,
conv2,
dropout,
residual,
padding,
}
}
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
let [_batch, _, seq_len] = x.dims();
let out = self.conv1.forward(x.clone());
let out_dims = out.dims();
let out = out.slice([0..out_dims[0], 0..out_dims[1], 0..seq_len]);
let out = Relu::new().forward(out);
let out = self.dropout.forward(out);
let out = self.conv2.forward(out);
let out_dims = out.dims();
let out = out.slice([0..out_dims[0], 0..out_dims[1], 0..seq_len]);
let out = Relu::new().forward(out);
let out = self.dropout.forward(out);
let residual = match &self.residual {
Some(conv) => conv.forward(x),
None => x,
};
Relu::new().forward(out + residual)
}
}
/// Configuration for the [`TCN`] classifier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TCNConfig {
/// Input channel count of the first block.
pub n_vars: usize,
/// Sequence length recorded with the configuration; `TCN::new` does not
/// read it.
pub seq_len: usize,
/// Output width of the final linear classifier.
pub n_classes: usize,
/// Output channel count of each block, in order; one block is built per
/// entry.
pub n_channels: Vec<usize>,
/// Kernel size used by every block.
pub kernel_size: usize,
/// Dropout probability used by every block.
pub dropout: f64,
}
impl Default for TCNConfig {
    /// Returns the baseline configuration: one input variable, sequence
    /// length 100, two classes, four 64-channel blocks, kernel size 3 and
    /// dropout 0.1.
    fn default() -> Self {
        let n_channels = vec![64; 4];
        Self {
            n_vars: 1,
            seq_len: 100,
            n_classes: 2,
            n_channels,
            kernel_size: 3,
            dropout: 0.1,
        }
    }
}
impl TCNConfig {
    /// Creates a configuration with the given input variable count, sequence
    /// length and class count; all other fields take their `Default` values.
    pub fn new(n_vars: usize, seq_len: usize, n_classes: usize) -> Self {
        Self {
            n_vars,
            seq_len,
            n_classes,
            ..Default::default()
        }
    }

    /// Replaces the per-block output channel counts (one block per entry).
    #[must_use]
    pub fn with_channels(mut self, channels: Vec<usize>) -> Self {
        self.n_channels = channels;
        self
    }

    /// Replaces the kernel size used by every block.
    #[must_use]
    pub fn with_kernel_size(mut self, kernel_size: usize) -> Self {
        self.kernel_size = kernel_size;
        self
    }

    /// Replaces the dropout probability used by every block.
    #[must_use]
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }

    /// Returns the receptive field of the configured network, in time steps.
    ///
    /// Computed as `1 + 2 * (kernel_size - 1) * sum(2^i)` over the block
    /// indices `i`: block `i` uses dilation `2^i` and contains two
    /// convolutions. With no blocks, or with a `kernel_size` of 0 or 1, the
    /// result is `1`.
    pub fn receptive_field(&self) -> usize {
        let dilation_sum: usize = (0..self.n_channels.len())
            .map(|layer| 1usize << layer)
            .sum();
        // `saturating_sub` keeps a zero kernel size from underflowing `usize`
        // (a panic in debug builds, a wrapped garbage value in release).
        let extra_taps = self.kernel_size.saturating_sub(1);
        1 + 2 * extra_taps * dilation_sum
    }

    /// Builds a [`TCN`] on `device` from a copy of this configuration.
    pub fn init<B: Backend>(&self, device: &B::Device) -> TCN<B> {
        TCN::new(self.clone(), device)
    }
}
/// Temporal convolutional network classifier: a stack of [`TCNBlock`]s
/// followed by a linear layer.
#[derive(Module, Debug)]
pub struct TCN<B: Backend> {
/// Residual blocks, applied in order by `forward`.
blocks: Vec<TCNBlock<B>>,
/// Linear layer applied to the block output after averaging over
/// dimension 2.
classifier: Linear<B>,
}
impl<B: Backend> TCN<B> {
pub fn new(config: TCNConfig, device: &B::Device) -> Self {
let mut blocks = Vec::new();
let n_layers = config.n_channels.len();
for i in 0..n_layers {
let in_channels = if i == 0 {
config.n_vars
} else {
config.n_channels[i - 1]
};
let out_channels = config.n_channels[i];
let dilation = 1 << i;
let block_config = TCNBlockConfig::new(in_channels, out_channels, config.kernel_size, dilation)
.with_dropout(config.dropout);
blocks.push(block_config.init(device));
}
let final_channels = *config.n_channels.last().unwrap_or(&64);
let classifier = LinearConfig::new(final_channels, config.n_classes).init(device);
Self { blocks, classifier }
}
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
let mut out = x;
for block in &self.blocks {
out = block.forward(out);
}
let [batch, channels, _] = out.dims();
let out: Tensor<B, 2> = out.mean_dim(2).reshape([batch, channels]);
self.classifier.forward(out)
}
pub fn forward_probs(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
let logits = self.forward(x);
softmax(logits, 1)
}
pub fn num_layers(&self) -> usize {
self.blocks.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` yields one input variable, kernel size 3 and four blocks.
    #[test]
    fn test_tcn_config_default() {
        let TCNConfig {
            n_vars,
            kernel_size,
            n_channels,
            ..
        } = TCNConfig::default();
        assert_eq!(n_vars, 1);
        assert_eq!(kernel_size, 3);
        assert_eq!(n_channels.len(), 4);
    }

    /// `new` stores its three arguments in the matching fields.
    #[test]
    fn test_tcn_config_new() {
        let TCNConfig {
            n_vars,
            seq_len,
            n_classes,
            ..
        } = TCNConfig::new(3, 200, 10);
        assert_eq!(n_vars, 3);
        assert_eq!(seq_len, 200);
        assert_eq!(n_classes, 10);
    }

    /// Each `with_*` builder overrides exactly its own field.
    #[test]
    fn test_tcn_config_builder() {
        let built = TCNConfig::new(3, 100, 5)
            .with_channels(vec![32, 64, 128])
            .with_kernel_size(5)
            .with_dropout(0.2);
        assert_eq!(built.n_channels, vec![32, 64, 128]);
        assert_eq!(built.kernel_size, 5);
        assert_eq!(built.dropout, 0.2);
    }

    /// Receptive field for the default network and for a 3-block, kernel-5 one.
    #[test]
    fn test_receptive_field() {
        let default_field = TCNConfig::default().receptive_field();
        assert_eq!(default_field, 61);

        let custom_field = TCNConfig::new(1, 100, 2)
            .with_channels(vec![64, 64, 64])
            .with_kernel_size(5)
            .receptive_field();
        assert_eq!(custom_field, 57);
    }

    /// Block configuration keeps its constructor arguments and dropout override.
    #[test]
    fn test_tcn_block_config() {
        let TCNBlockConfig {
            in_channels,
            out_channels,
            kernel_size,
            dilation,
            dropout,
        } = TCNBlockConfig::new(32, 64, 3, 4).with_dropout(0.2);
        assert_eq!(in_channels, 32);
        assert_eq!(out_channels, 64);
        assert_eq!(kernel_size, 3);
        assert_eq!(dilation, 4);
        assert_eq!(dropout, 0.2);
    }
}