//! eegpt-rs: EEGPT EEG foundation model, inference in Rust with the Burn ML framework.
//!
//! EEG Transformer backbone for EEGPT (Python: `_EEGTransformer`).
//!
//! Forward pass:
//!   1. `patch_embed`: Conv2d → [B, n_patches, C, embed_dim]
//!   2. Add channel embeddings
//!   3. For each patch group: concat the summary tokens, run through the blocks
//!   4. Extract the summary tokens, apply the final norm
//!   5. Reshape to [B, n_patches, embed_num, embed_dim]
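//!
//! A minimal usage sketch (not a doctest): the hyperparameter values are
//! illustrative, and the backend (`burn::backend::NdArray` here, behind the
//! `ndarray` feature) plus weight loading are assumed to be set up by the caller.
//!
//! ```ignore
//! let device = Default::default();
//! let model = EEGTransformer::<burn::backend::NdArray>::new(
//!     58, 1024,   // n_chans, n_times
//!     64, 64,     // patch_size, patch_stride
//!     4, 512,     // embed_num, embed_dim
//!     8, 8, 4.0,  // depth, n_heads, mlp_ratio
//!     true, 62,   // qkv_bias, n_chan_embeddings
//!     1e-6, &device,
//! );
//! // x: [B, C, T] EEG window, chan_ids: [1, C] indices into the channel table
//! let features = model.forward(x, chan_ids); // [B, n_patches, embed_num, embed_dim]
//! ```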

use burn::prelude::*;
use burn::module::{Param, ParamId};
use burn::nn::{Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig};

use crate::model::patch_embed::PatchEmbed;
use crate::model::block::TransformerBlock;

#[derive(Module, Debug)]
pub struct EEGTransformer<B: Backend> {
    /// Conv2d patching: [B, C, T] → [B, n_patches, C, embed_dim].
    pub patch_embed: PatchEmbed<B>,
    /// Learned per-channel embedding table, indexed by `chan_ids`.
    pub chan_embed: Embedding<B>,
    /// `depth` stacked transformer blocks.
    pub blocks: Vec<TransformerBlock<B>>,
    /// Final LayerNorm applied to the extracted summary tokens.
    pub norm: LayerNorm<B>,
    /// Learned summary tokens, [1, embed_num, embed_dim].
    pub summary_token: Param<Tensor<B, 3>>,
    /// Width of each token embedding.
    pub embed_dim: usize,
    /// Number of summary tokens per patch group.
    pub embed_num: usize,
    /// Number of temporal patches produced by `patch_embed`.
    pub n_patches: usize,
}

impl<B: Backend> EEGTransformer<B> {
    /// Build the backbone; `n_chan_embeddings` sizes the channel-embedding
    /// table and `eps` is the LayerNorm epsilon (also passed to each block).
    pub fn new(
        n_chans: usize, n_times: usize,
        patch_size: usize, patch_stride: usize,
        embed_num: usize, embed_dim: usize,
        depth: usize, n_heads: usize, mlp_ratio: f64,
        qkv_bias: bool, n_chan_embeddings: usize,
        eps: f64, device: &B::Device,
    ) -> Self {
        let patch_embed = PatchEmbed::new(n_chans, n_times, patch_size, patch_stride, embed_dim, device);
        let n_patches = patch_embed.n_patches;
        let chan_embed = EmbeddingConfig::new(n_chan_embeddings, embed_dim).init(device);
        let blocks = (0..depth)
            .map(|_| TransformerBlock::new(embed_dim, n_heads, mlp_ratio, qkv_bias, eps, device))
            .collect();
        let norm = LayerNormConfig::new(embed_dim).with_epsilon(eps).init(device);
        // Zero-initialized placeholder; the actual values are expected to be
        // loaded from a pretrained checkpoint before inference.
        let summary_token = Param::initialized(
            ParamId::new(),
            Tensor::zeros([1, embed_num, embed_dim], device),
        );

        Self { patch_embed, chan_embed, blocks, norm, summary_token, embed_dim, embed_num, n_patches }
    }
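
    /// Convenience helper (not in the Python reference implementation): the
    /// shape `forward` returns for a batch of size `batch`.
    pub fn output_shape(&self, batch: usize) -> [usize; 4] {
        [batch, self.n_patches, self.embed_num, self.embed_dim]
    }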

    /// Forward pass.
    ///
    /// * `x`: EEG window, [B, C, T]
    /// * `chan_ids`: channel-embedding indices, [1, C]
    ///
    /// Returns summary-token features, [B, n_patches, embed_num, embed_dim].
    pub fn forward(&self, x: Tensor<B, 3>, chan_ids: Tensor<B, 2, Int>) -> Tensor<B, 4> {
        let [batch, n_chans, _] = x.dims();

        // 1. Patch embedding: [B, C, T] → [B, n_patches, C, embed_dim]
        let x = self.patch_embed.forward(x);
        let n_patches = x.dims()[1];

        // 2. Add channel embedding: [1, C] → [1, 1, C, embed_dim],
        //    broadcast over the batch and patch dims.
        let chan_emb = self.chan_embed.forward(chan_ids).unsqueeze_dim::<4>(0);
        let x = x + chan_emb;

        // 3. Process each patch group through transformer
        // Flatten: [B, n_patches, C, D] → [B*n_patches, C, D]
        let x = x.reshape([batch * n_patches, n_chans, self.embed_dim]);

        // Concat summary tokens: [B*n_patches, C+embed_num, D]
        let summary = self.summary_token.val()
            .expand([batch * n_patches, self.embed_num, self.embed_dim]);
        let x = Tensor::cat(vec![x, summary], 1);

        // Run through the transformer blocks
        let x = self.blocks.iter().fold(x, |x, block| block.forward(x));

        // 4. Extract summary tokens (last embed_num tokens)
        let seq_len = x.dims()[1];
        let x = x.narrow(1, seq_len - self.embed_num, self.embed_num);
        // [B*n_patches, embed_num, embed_dim]

        // Apply norm
        let x = self.norm.forward(x);

        // 5. Reshape: [B*n_patches, embed_num, D] → [B, n_patches, embed_num, D]
        x.reshape([batch, n_patches, self.embed_num, self.embed_dim])
    }
}
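
// Shape sanity check: a sketch, not part of the original source. It assumes
// burn's `ndarray` feature for the test backend (any backend works) and uses
// small illustrative hyperparameters; real EEGPT configs come from the
// pretrained checkpoint.
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn forward_produces_summary_features() {
        type B = NdArray<f32>;
        let device = Default::default();
        let (n_chans, n_times) = (58, 1024);

        // Shallow depth keeps the test fast; shapes are what we care about.
        let model = EEGTransformer::<B>::new(
            n_chans, n_times, 64, 64, 4, 128, 2, 4, 4.0, true, 62, 1e-6, &device,
        );

        let x = Tensor::<B, 3>::zeros([2, n_chans, n_times], &device);
        let chan_ids = Tensor::<B, 2, Int>::zeros([1, n_chans], &device);

        let out = model.forward(x, chan_ids);
        assert_eq!(out.dims(), [2, model.n_patches, 4, 128]);
    }
}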