eegpt-rs 0.0.1

EEGPT EEG Foundation Model — inference in Rust with Burn ML
Documentation
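//! Transformer block for EEGPT.
//!
//! Mirrors the Python `_Block`: norm1 → attention → residual add → norm2 → MLP → residual add,
//! where MLP = Linear(D, H) → GELU → Linear(H, D) with H = ⌊D · mlp_ratio⌋ (4D when mlp_ratio = 4).
//! DropPath is the identity at inference time, so it is omitted here.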

use burn::prelude::*;
use burn::nn::{Linear, LinearConfig, LayerNorm, LayerNormConfig};
use burn::tensor::activation::gelu;
use crate::model::attention::Attention;

#[derive(Module, Debug)]
/// One pre-norm transformer block of EEGPT.
///
/// Computes `x + attn(norm1(x))`, then `x + mlp(norm2(x))`, where the MLP is
/// `mlp_fc1 → GELU → mlp_fc2`. No dropout or DropPath is applied.
pub struct TransformerBlock<B: Backend> {
    /// Layer norm applied to the input of the attention sub-layer.
    pub norm1: LayerNorm<B>,
    /// Attention sub-layer (see `crate::model::attention::Attention`).
    pub attn: Attention<B>,
    /// Layer norm applied to the input of the MLP sub-layer.
    pub norm2: LayerNorm<B>,
    /// MLP expansion projection, `dim → hidden` (with bias).
    pub mlp_fc1: Linear<B>,
    /// MLP output projection, `hidden → dim` (with bias).
    pub mlp_fc2: Linear<B>,
}

impl<B: Backend> TransformerBlock<B> {
    /// Creates a block with freshly initialised parameters on `device`.
    ///
    /// * `dim` – model width `D`; input and output feature size of the block.
    /// * `n_heads` – forwarded unchanged to `Attention::new`.
    /// * `mlp_ratio` – the MLP hidden width is `dim * mlp_ratio`, truncated to an integer.
    /// * `qkv_bias` – forwarded unchanged to `Attention::new`.
    /// * `eps` – epsilon used by both layer norms.
    /// * `device` – device on which every sub-module is initialised.
    pub fn new(
        dim: usize,
        n_heads: usize,
        mlp_ratio: f64,
        qkv_bias: bool,
        eps: f64,
        device: &B::Device,
    ) -> Self {
        // `as usize` truncates toward zero (e.g. 4.5 → 4).
        let hidden = (dim as f64 * mlp_ratio) as usize;

        let norm_config = LayerNormConfig::new(dim).with_epsilon(eps);
        let fc1_config = LinearConfig::new(dim, hidden).with_bias(true);
        let fc2_config = LinearConfig::new(hidden, dim).with_bias(true);

        // Sub-modules are initialised in field order (norm1, attn, norm2, fc1, fc2);
        // keep it that way, since reordering would change which random draws (if any)
        // each layer receives.
        Self {
            norm1: norm_config.init(device),
            attn: Attention::new(dim, n_heads, qkv_bias, device),
            norm2: norm_config.init(device),
            mlp_fc1: fc1_config.init(device),
            mlp_fc2: fc2_config.init(device),
        }
    }

    /// Applies the block to `x` (expected layout `[B, S, D]`) and returns a tensor of the
    /// same layout.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // Attention sub-layer: normalise, attend, then add the un-normalised input back.
        let shortcut = x.clone();
        let x = shortcut + self.attn.forward(self.norm1.forward(x));

        // MLP sub-layer: same pre-norm + residual pattern.
        x.clone() + self.feed_forward(self.norm2.forward(x))
    }

    /// MLP body: `mlp_fc1 → GELU → mlp_fc2`.
    fn feed_forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let expanded = gelu(self.mlp_fc1.forward(x));
        self.mlp_fc2.forward(expanded)
    }
}