use burn::prelude::*;
use burn::module::{Param, ParamId};
use burn::nn::{Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig};
use crate::model::patch_embed::PatchEmbed;
use crate::model::block::TransformerBlock;
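
/// Channel-aware EEG transformer encoder. Each channel's signal is split into
/// temporal patches, projected to `embed_dim`, and tagged with a learned
/// channel embedding; self-attention then mixes channels *within* each patch,
/// and `embed_num` learned summary tokens collect a per-patch summary.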
#[derive(Module, Debug)]
pub struct EEGTransformer<B: Backend> {
    /// Splits each channel's signal into temporal patches projected to `embed_dim`.
    pub patch_embed: PatchEmbed<B>,
    /// Lookup table mapping integer channel ids to learned channel embeddings.
    pub chan_embed: Embedding<B>,
    pub blocks: Vec<TransformerBlock<B>>,
    pub norm: LayerNorm<B>,
    /// Learnable summary tokens of shape `[1, embed_num, embed_dim]`.
    pub summary_token: Param<Tensor<B, 3>>,
    pub embed_dim: usize,
    pub embed_num: usize,
    pub n_patches: usize,
}

impl<B: Backend> EEGTransformer<B> {
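    /// Builds the encoder.
    ///
    /// `embed_num` is the number of learned summary tokens appended per patch,
    /// `n_chan_embeddings` is the size of the channel-id vocabulary, and `eps`
    /// is the LayerNorm epsilon used throughout.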
    pub fn new(
        n_chans: usize, n_times: usize,
        patch_size: usize, patch_stride: usize,
        embed_num: usize, embed_dim: usize,
        depth: usize, n_heads: usize, mlp_ratio: f64,
        qkv_bias: bool, n_chan_embeddings: usize,
        eps: f64, device: &B::Device,
    ) -> Self {
        let patch_embed =
            PatchEmbed::new(n_chans, n_times, patch_size, patch_stride, embed_dim, device);
        let n_patches = patch_embed.n_patches;
        let chan_embed = EmbeddingConfig::new(n_chan_embeddings, embed_dim).init(device);
        let blocks = (0..depth)
            .map(|_| TransformerBlock::new(embed_dim, n_heads, mlp_ratio, qkv_bias, eps, device))
            .collect();
        let norm = LayerNormConfig::new(embed_dim).with_epsilon(eps).init(device);
        let summary_token = Param::initialized(
            ParamId::new(),
            Tensor::zeros([1, embed_num, embed_dim], device),
        );
        Self {
            patch_embed,
            chan_embed,
            blocks,
            norm,
            summary_token,
            embed_dim,
            embed_num,
            n_patches,
        }
    }
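
    /// Encodes a batch of multi-channel signals.
    ///
    /// * `x`: `[batch, n_chans, n_times]` raw signals.
    /// * `chan_ids`: `[batch, n_chans]` integer ids into the channel-embedding table.
    ///
    /// Returns the normalized summary tokens, `[batch, n_patches, embed_num, embed_dim]`.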
    pub fn forward(&self, x: Tensor<B, 3>, chan_ids: Tensor<B, 2, Int>) -> Tensor<B, 4> {
        let [batch, n_chans, _] = x.dims();
        // Tokenize: [batch, n_chans, n_times] -> [batch, n_patches, n_chans, embed_dim].
        let x = self.patch_embed.forward(x);
        let n_patches = x.dims()[1];
        // Unsqueeze at dim 1 so the channel embeddings broadcast over patches:
        // [batch, n_chans, embed_dim] -> [batch, 1, n_chans, embed_dim].
        let chan_emb = self.chan_embed.forward(chan_ids).unsqueeze_dim::<4>(1);
        let x = x + chan_emb;
        // Fold patches into the batch: attention mixes channels within one patch.
        let x = x.reshape([batch * n_patches, n_chans, self.embed_dim]);
        // Append the learned summary tokens to each per-patch channel sequence.
        let summary = self.summary_token.val()
            .expand([batch * n_patches, self.embed_num, self.embed_dim]);
        let mut x = Tensor::cat(vec![x, summary], 1);
        for block in &self.blocks {
            x = block.forward(x);
        }
        // Keep only the summary tokens appended at the tail of the sequence.
        let seq_len = x.dims()[1];
        let x = x.narrow(1, seq_len - self.embed_num, self.embed_num);
        let x = self.norm.forward(x);
        // Unfold the patch axis: [batch, n_patches, embed_num, embed_dim].
        x.reshape([batch, n_patches, self.embed_num, self.embed_dim])
    }
}
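
// A minimal usage sketch, assuming a concrete backend such as
// `burn::backend::NdArray` and purely illustrative hyperparameters
// (not values prescribed by this crate):
//
//     let device = Default::default();
//     let model: EEGTransformer<NdArray> = EEGTransformer::new(
//         62, 1000,      // n_chans, n_times
//         64, 32,        // patch_size, patch_stride
//         1, 512,        // embed_num, embed_dim
//         8, 8, 4.0,     // depth, n_heads, mlp_ratio
//         true, 64,      // qkv_bias, n_chan_embeddings
//         1e-6, &device, // eps, device
//     );
//     let x = Tensor::<NdArray, 3>::zeros([2, 62, 1000], &device);
//     let chan_ids = Tensor::<NdArray, 2, Int>::zeros([2, 62], &device);
//     let out = model.forward(x, chan_ids); // [2, n_patches, 1, 512]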