eegpt-rs 0.0.1

EEGPT EEG foundation model — inference in Rust with the Burn ML framework
Documentation
//! EEGPT — full model.
//!
//! Architecture:
//!   1. Optional channel projection (Conv1d or Linear)
//!   2. EEGTransformer: patch_embed → chan_embed → summary_token → transformer → norm
//!   3. Final layer: LinearConstraintProbe or flatten+linear
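//!
//! Shape flow through `forward`:
//!   [B, C, T] → encoder → [B, n_patches, embed_num, embed_dim]
//!   → flatten + probe1 → [B, n_patches, probe_hidden_dim]
//!   → flatten + probe2 → [B, n_outputs]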

use burn::prelude::*;
use burn::nn::{Linear, LinearConfig};
use crate::model::transformer::EEGTransformer;

#[derive(Module, Debug)]
pub struct EEGPT<B: Backend> {
    pub target_encoder: EEGTransformer<B>,
    pub probe1: Linear<B>,
    pub probe2: Linear<B>,
    pub n_outputs: usize,
    pub embed_num: usize,
    pub embed_dim: usize,
    pub n_patches: usize,
    pub probe_hidden_dim: usize,
}

impl<B: Backend> EEGPT<B> {
    pub fn new(
        n_outputs: usize, n_chans: usize, n_times: usize,
        patch_size: usize, patch_stride: usize,
        embed_num: usize, embed_dim: usize,
        depth: usize, n_heads: usize, mlp_ratio: f64,
        qkv_bias: bool, n_chan_embeddings: usize,
        probe_hidden_dim: usize, eps: f64,
        device: &B::Device,
    ) -> Self {
        let target_encoder = EEGTransformer::new(
            n_chans, n_times, patch_size, patch_stride,
            embed_num, embed_dim, depth, n_heads, mlp_ratio,
            qkv_bias, n_chan_embeddings, eps, device,
        );
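        // n_patches is computed by EEGTransformer from n_times, patch_size and
        // patch_stride (typically (n_times - patch_size) / patch_stride + 1).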
        let n_patches = target_encoder.n_patches;

        // Probe head, mirroring the original LinearWithConstraint pair:
        //   probe1: embed_num*embed_dim → probe_hidden_dim (applied per patch)
        //   probe2: n_patches*probe_hidden_dim → n_outputs
        // The max-norm weight constraint is a training-time regularizer; plain
        // Linear layers are equivalent at inference once the loaded weights
        // already satisfy it.
        let probe1 = LinearConfig::new(embed_num * embed_dim, probe_hidden_dim)
            .with_bias(true)
            .init(device);
        let probe2 = LinearConfig::new(n_patches * probe_hidden_dim, n_outputs)
            .with_bias(true)
            .init(device);
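        // Worked sizes (illustrative): with embed_num=4, embed_dim=512,
        // probe_hidden_dim=16 and n_patches=16, probe1 maps 2048 → 16 per
        // patch and probe2 maps 256 → n_outputs.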

        Self { target_encoder, probe1, probe2, n_outputs, embed_num, embed_dim, n_patches, probe_hidden_dim }
    }

    /// Forward pass: `x` is raw EEG of shape [B, C, T] and `chan_ids` gives
    /// the channel indices as [1, C]; returns [B, n_outputs].
    pub fn forward(&self, x: Tensor<B, 3>, chan_ids: Tensor<B, 2, Int>) -> Tensor<B, 2> {
        let [batch, _, _] = x.dims();

        // Encoder: [B, C, T] → [B, n_patches, embed_num, embed_dim]
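        // (the embed_num axis carries the encoder's per-patch summary tokens)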
        let z = self.target_encoder.forward(x, chan_ids);

        // Probe
        // Flatten last two dims: [B, n_patches, embed_num*embed_dim]
        let z = z.reshape([batch, self.n_patches, self.embed_num * self.embed_dim]);
        // probe1: [B, n_patches, hidden]
        let h = self.probe1.forward(z);
        // Flatten: [B, n_patches*hidden]
        let h = h.reshape([batch, self.n_patches * self.probe_hidden_dim]);
        // probe2: [B, n_outputs]
        self.probe2.forward(h)
    }
}
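
Example

A minimal inference sketch using the NdArray backend (Burn's `ndarray` feature).
The hyperparameter values are illustrative, not those of a pretrained
checkpoint, and the import path for `EEGPT` is assumed; a real run would load
checkpoint weights via Burn's record API rather than use random initialization.

use burn::backend::NdArray;
use burn::prelude::*;
use eegpt_rs::model::EEGPT; // assumed re-export path

fn main() {
    type B = NdArray<f32>;
    let device = Default::default();

    // Illustrative configuration: 2 classes, 19 channels, 1024 samples,
    // non-overlapping 64-sample patches (→ 16 patches), 4 summary tokens
    // of width 512, and a 16-unit probe hidden layer.
    let model = EEGPT::<B>::new(
        2, 19, 1024, // n_outputs, n_chans, n_times
        64, 64,      // patch_size, patch_stride
        4, 512,      // embed_num, embed_dim
        8, 8, 4.0,   // depth, n_heads, mlp_ratio
        true, 62,    // qkv_bias, n_chan_embeddings
        16, 1e-6,    // probe_hidden_dim, eps
        &device,
    );

    // One dummy EEG window: [B=1, C=19, T=1024].
    let x = Tensor::<B, 3>::zeros([1, 19, 1024], &device);
    // Channel indices into the channel-embedding table, shape [1, C].
    let chan_ids = Tensor::<B, 1, Int>::arange(0..19, &device).reshape([1, 19]);

    let logits = model.forward(x, chan_ids);
    assert_eq!(logits.dims(), [1, 2]); // [B, n_outputs]
}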