use burn::module::Module;
use burn::nn::{
    conv::{Conv1d, Conv1dConfig},
    BatchNorm, BatchNormConfig, LayerNorm, LayerNormConfig, Linear, LinearConfig,
    PaddingConfig1d,
};
use burn::prelude::*;
use crate::config::SeizureTransformerConfig;
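
/// Multi-head self-attention with a single packed QKV projection.
///
/// The one `in_proj` layer emits query, key, and value concatenated along the
/// feature axis. This layout is assumed to mirror PyTorch's
/// `nn.MultiheadAttention` (`in_proj_weight`/`in_proj_bias`), so weights
/// exported from a PyTorch checkpoint can be loaded without reshuffling.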
#[derive(Module, Debug)]
pub struct FusedMultiheadAttention<B: Backend> {
    pub in_proj: Linear<B>,
    pub out_proj: Linear<B>,
    pub num_heads: usize,
    pub head_dim: usize,
}
impl<B: Backend> FusedMultiheadAttention<B> {
    pub fn new(d_model: usize, num_heads: usize, device: &B::Device) -> Self {
        // Guard against silent truncation in the integer division below.
        debug_assert_eq!(d_model % num_heads, 0, "d_model must be divisible by num_heads");
        Self {
            in_proj: LinearConfig::new(d_model, 3 * d_model)
                .with_bias(true)
                .init(device),
            out_proj: LinearConfig::new(d_model, d_model)
                .with_bias(true)
                .init(device),
            num_heads,
            head_dim: d_model / num_heads,
        }
    }
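
    /// Scaled dot-product self-attention over `[batch, seq, d_model]`: split
    /// the packed projection into Q/K/V, compute
    /// `softmax(Q K^T / sqrt(head_dim)) V` per head, then merge heads.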
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let [b, s, d] = x.dims();
        let h = self.num_heads;
        let hd = self.head_dim;
        // Packed projection: the feature axis holds [q | k | v].
        let qkv = self.in_proj.forward(x);
        let q = qkv.clone().slice([0..b, 0..s, 0..d]);
        let k = qkv.clone().slice([0..b, 0..s, d..(2 * d)]);
        let v = qkv.slice([0..b, 0..s, (2 * d)..(3 * d)]);
        // [b, s, d] -> [b, h, s, hd]
        let q = q.reshape([b, s, h, hd]).swap_dims(1, 2);
        let k = k.reshape([b, s, h, hd]).swap_dims(1, 2);
        let v = v.reshape([b, s, h, hd]).swap_dims(1, 2);
        let scores = q.matmul(k.swap_dims(2, 3)) / (hd as f64).sqrt();
        let attn = burn::tensor::activation::softmax(scores, 3);
        let out = attn.matmul(v);
        // Merge heads back: [b, h, s, hd] -> [b, s, d].
        let out = out.swap_dims(1, 2).reshape([b, s, d]);
        self.out_proj.forward(out)
    }
}
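
/// One transformer encoder layer in the post-norm arrangement (`LayerNorm`
/// after each residual sum) with a ReLU feed-forward block and no dropout,
/// structurally matching the defaults of PyTorch's
/// `nn.TransformerEncoderLayer` apart from the missing dropout.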
#[derive(Module, Debug)]
pub struct TransformerEncoderLayer<B: Backend> {
    pub self_attn: FusedMultiheadAttention<B>,
    pub linear1: Linear<B>,
    pub linear2: Linear<B>,
    pub norm1: LayerNorm<B>,
    pub norm2: LayerNorm<B>,
}
impl<B: Backend> TransformerEncoderLayer<B> {
    pub fn new(d_model: usize, num_heads: usize, ffn_dim: usize, device: &B::Device) -> Self {
        Self {
            self_attn: FusedMultiheadAttention::new(d_model, num_heads, device),
            linear1: LinearConfig::new(d_model, ffn_dim)
                .with_bias(true)
                .init(device),
            linear2: LinearConfig::new(ffn_dim, d_model)
                .with_bias(true)
                .init(device),
            norm1: LayerNormConfig::new(d_model).init(device),
            norm2: LayerNormConfig::new(d_model).init(device),
        }
    }
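
    /// Post-norm residual wiring: `x = LN(x + Attn(x))`, then
    /// `x = LN(x + FFN(x))`.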
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // Attention sublayer.
        let x2 = self.self_attn.forward(x.clone());
        let x = self.norm1.forward(x + x2);
        // Feed-forward sublayer: Linear -> ReLU -> Linear.
        let x2 = self.linear1.forward(x.clone());
        let x2 = burn::tensor::activation::relu(x2);
        let x2 = self.linear2.forward(x2);
        self.norm2.forward(x + x2)
    }
}
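
/// Convolutional encoder: each level applies a "same"-padded `Conv1d` + ELU,
/// then a kernel-2/stride-2 max pool. The constructor simulates the length
/// arithmetic once (`n -> (n + n % 2) / 2`) so that odd-length levels get a
/// one-sample right pad before pooling; `Decoder` undoes those pads.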
#[derive(Module, Debug)]
pub struct Encoder<B: Backend> {
    pub convs: Vec<Conv1d<B>>,
    /// Per-level right padding (0 or 1) applied before each stride-2 pool.
    pub paddings: Vec<usize>,
}
impl<B: Backend> Encoder<B> {
    pub fn new(
        input_channels: usize,
        filters: &[usize],
        kernel_sizes: &[usize],
        in_samples: usize,
        device: &B::Device,
    ) -> Self {
        let mut convs = Vec::new();
        let mut paddings = Vec::new();
        let mut n = in_samples;
        // `zip` stops at the shorter slice, so only the first `filters.len()`
        // kernel sizes are used.
        for (i, (&out_ch, &k)) in filters.iter().zip(kernel_sizes.iter()).enumerate() {
            let in_ch = if i == 0 { input_channels } else { filters[i - 1] };
            let conv = Conv1dConfig::new(in_ch, out_ch, k)
                .with_padding(PaddingConfig1d::Explicit(k / 2))
                .with_bias(true)
                .init(device);
            convs.push(conv);
            // Track the running length: odd lengths need one pad sample
            // before the stride-2 pool.
            let pad = n % 2;
            paddings.push(pad);
            n = (n + pad) / 2;
        }
        Self { convs, paddings }
    }
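
    /// Returns the bottleneck features plus each level's pre-pool activation
    /// (the skip connections consumed by `Decoder`).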
    pub fn forward(&self, mut x: Tensor<B, 3>) -> (Tensor<B, 3>, Vec<Tensor<B, 3>>) {
        let mut skips = Vec::with_capacity(self.convs.len());
        for (conv, pad) in self.convs.iter().zip(self.paddings.iter()) {
            x = elu(conv.forward(x));
            skips.push(x.clone());
            if *pad != 0 {
                // Pad with a huge negative value so the pad sample can never
                // win the subsequent max pool.
                x = pad_right_const(x, *pad, -1e10);
            }
            x = max_pool1d_stride2(x);
        }
        (x, skips)
    }
}
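
/// Pre-activation residual block:
/// `x + Conv(ReLU(BN(Conv(ReLU(BN(x))))))`, with lengths preserved for both
/// odd and even kernels (the latter via a manual right pad).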
#[derive(Module, Debug)]
pub struct ResCNNBlock<B: Backend> {
    // `BatchNorm` is generic over the number of spatial dims: `1` for
    // `[batch, channels, length]` inputs.
    pub norm1: BatchNorm<B, 1>,
    pub conv1: Conv1d<B>,
    pub norm2: BatchNorm<B, 1>,
    pub conv2: Conv1d<B>,
    pub manual_padding: bool,
}
impl<B: Backend> ResCNNBlock<B> {
    pub fn new(filters: usize, ker: usize, device: &B::Device) -> Self {
        // Only kernel sizes 3 and 2 are used here: size 3 gets symmetric
        // padding from the conv itself, size 2 gets a manual one-sample
        // right pad in `forward` (an even kernel has no symmetric "same"
        // padding).
        let (manual_padding, padding) = if ker == 3 { (false, 1) } else { (true, 0) };
        Self {
            norm1: BatchNormConfig::new(filters)
                .with_epsilon(1e-3)
                .init(device),
            conv1: Conv1dConfig::new(filters, filters, ker)
                .with_padding(PaddingConfig1d::Explicit(padding))
                .with_bias(true)
                .init(device),
            norm2: BatchNormConfig::new(filters)
                .with_epsilon(1e-3)
                .init(device),
            conv2: Conv1dConfig::new(filters, filters, ker)
                .with_padding(PaddingConfig1d::Explicit(padding))
                .with_bias(true)
                .init(device),
            manual_padding,
        }
    }
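
    /// Adds the two-conv residual branch onto the identity shortcut.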
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // Pre-activation branch: BN -> ReLU -> (pad) -> Conv, twice.
        let mut y = burn::tensor::activation::relu(self.norm1.forward(x.clone()));
        if self.manual_padding {
            y = pad_right_const(y, 1, 0.0);
        }
        y = self.conv1.forward(y);
        y = burn::tensor::activation::relu(self.norm2.forward(y));
        if self.manual_padding {
            y = pad_right_const(y, 1, 0.0);
        }
        y = self.conv2.forward(y);
        // Identity shortcut.
        x + y
    }
}
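
/// A sequence of `ResCNNBlock`s at a fixed channel width.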
#[derive(Module, Debug)]
pub struct ResCNNStack<B: Backend> {
    pub blocks: Vec<ResCNNBlock<B>>,
}
impl<B: Backend> ResCNNStack<B> {
    pub fn new(filters: usize, kernels: &[usize], device: &B::Device) -> Self {
        Self {
            blocks: kernels
                .iter()
                .map(|&k| ResCNNBlock::new(filters, k, device))
                .collect(),
        }
    }
    pub fn forward(&self, mut x: Tensor<B, 3>) -> Tensor<B, 3> {
        for block in self.blocks.iter() {
            x = block.forward(x);
        }
        x
    }
}
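
/// Convolutional decoder mirroring `Encoder`: nearest-neighbour 2x upsample,
/// a one-sample crop wherever the encoder padded, a "same"-padded `Conv1d` +
/// ELU, and an additive skip connection from the matching encoder level.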
#[derive(Module, Debug)]
pub struct Decoder<B: Backend> {
    pub convs: Vec<Conv1d<B>>,
    /// Decoder levels whose upsampled output must be cropped by one sample
    /// (mirrors the encoder levels that padded before pooling).
    pub crops: Vec<usize>,
}
impl<B: Backend> Decoder<B> {
    pub fn new(
        input_channels: usize,
        filters: &[usize],
        kernel_sizes: &[usize],
        out_samples: usize,
        device: &B::Device,
    ) -> Self {
        // Replay the encoder's length arithmetic: whenever encoder level `i`
        // padded, the mirrored decoder level `len - 1 - i` must crop one
        // sample after upsampling.
        let mut crops = Vec::new();
        let mut current_samples = out_samples;
        for (i, _) in filters.iter().enumerate() {
            let pad = current_samples % 2;
            current_samples = (current_samples + pad) / 2;
            if pad == 1 {
                crops.push(filters.len() - 1 - i);
            }
        }
        let mut convs = Vec::new();
        for (i, (&out_ch, &k)) in filters.iter().zip(kernel_sizes.iter()).enumerate() {
            let in_ch = if i == 0 { input_channels } else { filters[i - 1] };
            let conv = Conv1dConfig::new(in_ch, out_ch, k)
                .with_padding(PaddingConfig1d::Explicit(k / 2))
                .with_bias(true)
                .init(device);
            convs.push(conv);
        }
        Self { convs, crops }
    }
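
    /// Upsample, crop where the encoder padded, convolve, and add the
    /// matching skip connection at each level.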
    pub fn forward(&self, mut x: Tensor<B, 3>, skips: &[Tensor<B, 3>]) -> Tensor<B, 3> {
        for (i, conv) in self.convs.iter().enumerate() {
            x = upsample_nearest_2x(x);
            if self.crops.contains(&i) {
                x = crop_right_one(x);
            }
            x = elu(conv.forward(x));
            if i < skips.len() {
                // Skips were pushed shallow-to-deep; consume them in reverse.
                let skip = skips[skips.len() - 1 - i].clone();
                x = x + skip;
            }
        }
        x
    }
}
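
/// Sinusoidal positional encoding (Vaswani et al., 2017):
/// `pe[pos, i] = sin(pos * 10000^(-i/d_model))` for even `i`, with the
/// matching `cos` at `i + 1`. The table is precomputed once and stored as a
/// plain tensor, which Burn's `Module` derive treats as a non-trainable
/// constant.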
#[derive(Module, Debug)]
pub struct PositionalEncoding<B: Backend> {
    pub pe: Tensor<B, 3>,
}
impl<B: Backend> PositionalEncoding<B> {
    pub fn new(d_model: usize, max_len: usize, device: &B::Device) -> Self {
        let mut data = vec![0f32; max_len * d_model];
        for pos in 0..max_len {
            for i in (0..d_model).step_by(2) {
                // Frequency schedule 10000^(-i/d_model): even indices get
                // sin, the following odd index gets cos.
                let div = (10000.0f32).powf(-(i as f32) / (d_model as f32));
                data[pos * d_model + i] = (pos as f32 * div).sin();
                if i + 1 < d_model {
                    data[pos * d_model + i + 1] = (pos as f32 * div).cos();
                }
            }
        }
        let pe = Tensor::<B, 3>::from_data(TensorData::new(data, [1, max_len, d_model]), device);
        Self { pe }
    }
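
    /// Adds the first `t` rows of the precomputed table to `x`
    /// (`[batch, t, d_model]`).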
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let [b, t, d] = x.dims();
        let pe = self.pe.clone().slice([0..1, 0..t, 0..d]).expand([b, t, d]);
        x + pe
    }
}
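
/// The full model: a 1D U-Net (encoder/decoder with additive skips) whose
/// bottleneck is refined by a residual CNN stack and a transformer encoder
/// with a residual connection around the whole stack, ending in a sigmoid
/// detection head that emits one seizure probability per input sample.
///
/// A minimal construction sketch, assuming `SeizureTransformerConfig` is a
/// plain struct with exactly the fields read below; the values are
/// hypothetical, chosen only for illustration:
///
/// ```ignore
/// let config = SeizureTransformerConfig {
///     in_channels: 19,     // hypothetical: 19-channel EEG montage
///     in_samples: 15360,   // hypothetical: 60 s at 256 Hz
///     max_pos_len: 480,    // must cover in_samples / 2^5
///     num_layers: 8,       // hypothetical depth
///     num_heads: 4,
///     dim_feedforward: 2048,
/// };
/// let model = SeizureTransformer::<MyBackend>::new(&config, &device);
/// let probs = model.forward(eeg); // eeg: [batch, 19, 15360] -> [batch, 15360]
/// ```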
#[derive(Module, Debug)]
pub struct SeizureTransformer<B: Backend> {
    pub in_channels: usize,
    pub in_samples: usize,
    pub encoder: Encoder<B>,
    pub res_cnn_stack: ResCNNStack<B>,
    pub position_encoding: PositionalEncoding<B>,
    pub transformer_encoder: Vec<TransformerEncoderLayer<B>>,
    pub decoder_d: Decoder<B>,
    pub conv_d: Conv1d<B>,
}
impl<B: Backend> SeizureTransformer<B> {
    pub fn new(config: &SeizureTransformerConfig, device: &B::Device) -> Self {
        let filters = [32, 64, 128, 256, 512];
        // `Encoder`/`Decoder` zip kernel sizes against the filter list, so
        // only the first five entries of `kernel_sizes` and `rev_kernels`
        // are actually consumed; `res_cnn_kernels` is used in full.
        let kernel_sizes = [11, 9, 7, 7, 5, 5, 3];
        let res_cnn_kernels = [3, 3, 3, 3, 2, 3, 2];
        let encoder = Encoder::new(
            config.in_channels,
            &filters,
            &kernel_sizes,
            config.in_samples,
            device,
        );
        let res_cnn_stack = ResCNNStack::new(*filters.last().unwrap(), &res_cnn_kernels, device);
        let position_encoding =
            PositionalEncoding::new(*filters.last().unwrap(), config.max_pos_len, device);
        let transformer_encoder = (0..config.num_layers)
            .map(|_| {
                TransformerEncoderLayer::new(
                    *filters.last().unwrap(),
                    config.num_heads,
                    config.dim_feedforward,
                    device,
                )
            })
            .collect();
        let rev_filters = [512, 256, 128, 64, 32];
        let rev_kernels = [3, 5, 5, 7, 7, 9, 11];
        let decoder_d = Decoder::new(
            *filters.last().unwrap(),
            &rev_filters,
            &rev_kernels,
            config.in_samples,
            device,
        );
        // Detection head: collapse to one channel (kernel 11, "same"
        // padding 5); the sigmoid is applied in `forward`.
        let conv_d = Conv1dConfig::new(*rev_filters.last().unwrap(), 1, 11)
            .with_padding(PaddingConfig1d::Explicit(5))
            .with_bias(true)
            .init(device);
        Self {
            in_channels: config.in_channels,
            in_samples: config.in_samples,
            encoder,
            res_cnn_stack,
            position_encoding,
            transformer_encoder,
            decoder_d,
            conv_d,
        }
    }
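
    /// Maps `[batch, in_channels, in_samples]` to per-sample seizure
    /// probabilities of shape `[batch, in_samples]`.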
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
        let [_b, c, s] = x.dims();
        assert_eq!(c, self.in_channels);
        assert_eq!(s, self.in_samples);
        let (x, skips) = self.encoder.forward(x);
        let res_x = self.res_cnn_stack.forward(x);
        // Convs run on [batch, channels, length]; the transformer expects
        // [batch, seq, features].
        let mut x = res_x.clone().swap_dims(1, 2);
        x = self.position_encoding.forward(x);
        for layer in self.transformer_encoder.iter() {
            x = layer.forward(x);
        }
        // Residual connection around the whole transformer stack.
        x = x.swap_dims(1, 2) + res_x;
        let detection = self.decoder_d.forward(x, &skips);
        let detection = burn::tensor::activation::sigmoid(self.conv_d.forward(detection));
        let [bb, _, tt] = detection.dims();
        detection.reshape([bb, tt])
    }
}
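
/// ELU with α = 1 built from primitives:
/// `elu(x) = max(x, 0) + (exp(min(x, 0)) - 1)`.
/// At most one of the two terms is non-zero for any `x`, so the sum is exact.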
fn elu<B: Backend>(x: Tensor<B, 3>) -> Tensor<B, 3> {
    let zeros = x.zeros_like();
    let positive = x.clone().max_pair(zeros.clone());
    let negative = x.min_pair(zeros).exp().sub_scalar(1.0);
    positive + negative
}
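
/// Appends `pad` time steps filled with `value` on the right.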
fn pad_right_const<B: Backend>(x: Tensor<B, 3>, pad: usize, value: f32) -> Tensor<B, 3> {
    if pad == 0 {
        return x;
    }
    let [b, c, _t] = x.dims();
    let rhs = Tensor::<B, 3>::full([b, c, pad], value, &x.device());
    Tensor::cat(vec![x, rhs], 2)
}
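
/// Drops the last time step (undoes a one-sample encoder pad after 2x
/// upsampling).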
fn crop_right_one<B: Backend>(x: Tensor<B, 3>) -> Tensor<B, 3> {
    let [b, c, t] = x.dims();
    x.slice([0..b, 0..c, 0..(t - 1)])
}
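
/// Nearest-neighbour 2x upsampling along time: broadcast a trailing pair
/// axis and flatten it back in, duplicating every sample.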
fn upsample_nearest_2x<B: Backend>(x: Tensor<B, 3>) -> Tensor<B, 3> {
    let [b, c, t] = x.dims();
    x.unsqueeze_dim::<4>(3)
        .expand([b, c, t, 2])
        .reshape([b, c, t * 2])
}
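
/// Max pool with kernel 2 / stride 2 via reshape: pair adjacent samples and
/// take the per-pair max. Requires an even length, which `Encoder`
/// guarantees by right-padding with a large negative constant first.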
fn max_pool1d_stride2<B: Backend>(x: Tensor<B, 3>) -> Tensor<B, 3> {
    let [b, c, t] = x.dims();
    let out_t = t / 2;
    // [b, c, t] -> [b, c, t/2, 2], then reduce over the pair axis.
    let x = x.reshape([b, c, out_t, 2]);
    x.max_dim(3).reshape([b, c, out_t])
}