use burn::nn::{
conv::{Conv1d, Conv1dConfig},
pool::{AdaptiveAvgPool1d, AdaptiveAvgPool1dConfig},
BatchNorm, BatchNormConfig, Linear, LinearConfig, Relu,
};
use burn::prelude::*;
use burn::tensor::activation::softmax;
use serde::{Deserialize, Serialize};
/// Hyper-parameters of an [`FCN`] classifier.
///
/// Start from [`FCNConfig::new`] (or [`Default`]), optionally adjust the layer
/// sizes with [`FCNConfig::with_filters`] / [`FCNConfig::with_kernel_sizes`],
/// and build the network with [`FCNConfig::init`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FCNConfig {
    /// Number of input channels fed to the first convolution.
    pub n_vars: usize,
    /// Sequence length. Kept for reference only: `FCN::new` does not read it.
    pub seq_len: usize,
    /// Number of classes, i.e. the width of the logit vector.
    pub n_classes: usize,
    /// Output channels of the first convolution block.
    pub n_filters_1: usize,
    /// Output channels of the second convolution block.
    pub n_filters_2: usize,
    /// Output channels of the third convolution block (and classifier input width).
    pub n_filters_3: usize,
    /// Kernel size of the first convolution block.
    pub kernel_size_1: usize,
    /// Kernel size of the second convolution block.
    pub kernel_size_2: usize,
    /// Kernel size of the third convolution block.
    pub kernel_size_3: usize,
}
impl Default for FCNConfig {
fn default() -> Self {
Self {
n_vars: 1,
seq_len: 100,
n_classes: 2,
n_filters_1: 128,
n_filters_2: 256,
n_filters_3: 128,
kernel_size_1: 8,
kernel_size_2: 5,
kernel_size_3: 3,
}
}
}
impl FCNConfig {
    /// Creates a configuration for the given problem shape.
    ///
    /// Every layer size (filters and kernel sizes) takes its [`Default`] value.
    pub fn new(n_vars: usize, seq_len: usize, n_classes: usize) -> Self {
        let base = Self::default();
        Self {
            n_vars,
            seq_len,
            n_classes,
            ..base
        }
    }

    /// Returns this configuration with the three blocks' filter counts replaced.
    #[must_use]
    pub fn with_filters(self, n_filters_1: usize, n_filters_2: usize, n_filters_3: usize) -> Self {
        Self {
            n_filters_1,
            n_filters_2,
            n_filters_3,
            ..self
        }
    }

    /// Returns this configuration with the three blocks' kernel sizes replaced.
    #[must_use]
    pub fn with_kernel_sizes(self, k1: usize, k2: usize, k3: usize) -> Self {
        Self {
            kernel_size_1: k1,
            kernel_size_2: k2,
            kernel_size_3: k3,
            ..self
        }
    }

    /// Builds an [`FCN`] from a copy of this configuration, with its
    /// parameters allocated on `device`.
    pub fn init<B: Backend>(&self, device: &B::Device) -> FCN<B> {
        let config = self.clone();
        FCN::new(config, device)
    }
}
#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
conv: Conv1d<B>,
bn: BatchNorm<B, 1>,
}
impl<B: Backend> ConvBlock<B> {
pub fn new(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
device: &B::Device,
) -> Self {
let conv = Conv1dConfig::new(in_channels, out_channels, kernel_size)
.with_padding(burn::nn::PaddingConfig1d::Same)
.with_bias(false)
.init(device);
let bn = BatchNormConfig::new(out_channels).init(device);
Self { conv, bn }
}
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
let out = self.conv.forward(x);
let out = self.bn.forward(out);
Relu::new().forward(out)
}
}
/// Fully convolutional sequence classifier.
///
/// Three [`ConvBlock`]s are followed by average pooling of the time axis down
/// to one step and a linear layer that produces one logit per class.
#[derive(Module, Debug)]
pub struct FCN<B: Backend> {
    /// First convolution block (`n_vars` → `n_filters_1` channels).
    block1: ConvBlock<B>,
    /// Second convolution block (`n_filters_1` → `n_filters_2` channels).
    block2: ConvBlock<B>,
    /// Third convolution block (`n_filters_2` → `n_filters_3` channels).
    block3: ConvBlock<B>,
    /// Adaptive average pool with output size 1, i.e. global average over time.
    gap: AdaptiveAvgPool1d,
    /// Classifier from `n_filters_3` pooled features to `n_classes` logits.
    fc: Linear<B>,
}
impl<B: Backend> FCN<B> {
pub fn new(config: FCNConfig, device: &B::Device) -> Self {
let block1 = ConvBlock::new(
config.n_vars,
config.n_filters_1,
config.kernel_size_1,
device,
);
let block2 = ConvBlock::new(
config.n_filters_1,
config.n_filters_2,
config.kernel_size_2,
device,
);
let block3 = ConvBlock::new(
config.n_filters_2,
config.n_filters_3,
config.kernel_size_3,
device,
);
let gap = AdaptiveAvgPool1dConfig::new(1).init();
let fc = LinearConfig::new(config.n_filters_3, config.n_classes).init(device);
Self {
block1,
block2,
block3,
gap,
fc,
}
}
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
let out = self.block1.forward(x);
let out = self.block2.forward(out);
let out = self.block3.forward(out);
let out = self.gap.forward(out);
let [batch, channels, _] = out.dims();
let out = out.reshape([batch, channels]);
self.fc.forward(out)
}
pub fn forward_probs(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
let logits = self.forward(x);
softmax(logits, 1)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every default field has the documented value.
    #[test]
    fn test_fcn_config_default() {
        let config = FCNConfig::default();
        assert_eq!(config.n_vars, 1);
        assert_eq!(config.seq_len, 100);
        assert_eq!(config.n_classes, 2);
        assert_eq!(config.n_filters_1, 128);
        assert_eq!(config.n_filters_2, 256);
        assert_eq!(config.n_filters_3, 128);
        assert_eq!(config.kernel_size_1, 8);
        assert_eq!(config.kernel_size_2, 5);
        assert_eq!(config.kernel_size_3, 3);
    }

    /// `new` sets the problem shape and keeps default layer sizes.
    #[test]
    fn test_fcn_config_new() {
        let config = FCNConfig::new(3, 200, 10);
        assert_eq!(config.n_vars, 3);
        assert_eq!(config.seq_len, 200);
        assert_eq!(config.n_classes, 10);

        let defaults = FCNConfig::default();
        assert_eq!(config.n_filters_1, defaults.n_filters_1);
        assert_eq!(config.n_filters_2, defaults.n_filters_2);
        assert_eq!(config.n_filters_3, defaults.n_filters_3);
        assert_eq!(config.kernel_size_1, defaults.kernel_size_1);
        assert_eq!(config.kernel_size_2, defaults.kernel_size_2);
        assert_eq!(config.kernel_size_3, defaults.kernel_size_3);
    }

    /// Builders set each argument on the matching field and leave the rest untouched.
    #[test]
    fn test_fcn_config_builder() {
        let config = FCNConfig::new(3, 100, 5)
            .with_filters(64, 128, 64)
            .with_kernel_sizes(7, 5, 3);
        assert_eq!(config.n_filters_1, 64);
        assert_eq!(config.n_filters_2, 128);
        assert_eq!(config.n_filters_3, 64);
        assert_eq!(config.kernel_size_1, 7);
        assert_eq!(config.kernel_size_2, 5);
        assert_eq!(config.kernel_size_3, 3);
        assert_eq!(config.n_vars, 3);
        assert_eq!(config.seq_len, 100);
        assert_eq!(config.n_classes, 5);
    }
}