//! eegpt-rs 0.0.1
//!
//! EEGPT EEG Foundation Model — inference in Rust with Burn ML.
// Patch Embedding for EEGPT.
//
// Python: _PatchEmbed uses Conv2d(1, embed_dim, (1, patch_size), stride=(1, patch_stride))
// Input: [B, C, T] → pad → [B, 1, C, T_padded] → Conv2d → [B, embed_dim, C, n_patches]
// → rearrange to [B, n_patches, C, embed_dim]

use burn::prelude::*;
use burn::nn::conv::{Conv2d, Conv2dConfig};

/// Patch embedding layer: splits the time axis of an EEG trial `[B, C, T]`
/// into overlapping/abutting patches and projects each patch to `embed_dim`
/// via a `Conv2d(1, embed_dim, (1, patch_size))` with stride `(1, patch_stride)`.
#[derive(Module, Debug)]
pub struct PatchEmbed<B: Backend> {
    /// Conv2d projection: 1 input channel → `embed_dim` output channels,
    /// kernel `(1, patch_size)` so channels are never mixed.
    pub proj: Conv2d<B>,
    /// Expected number of EEG channels (C) in the input.
    pub n_chans: usize,
    /// Expected number of time samples (T) in the input, before padding.
    pub n_times: usize,
    /// Temporal width of each patch (conv kernel width).
    pub patch_size: usize,
    /// Temporal step between consecutive patches (conv stride).
    pub patch_stride: usize,
    /// Zeros appended to the time axis so the last patch window fits exactly.
    pub padding_size: usize,
    /// Number of patches produced along the (padded) time axis.
    pub n_patches: usize,
}

impl<B: Backend> PatchEmbed<B> {
    /// Build the patch-embedding layer.
    ///
    /// # Arguments
    /// * `n_chans` - number of EEG channels (C).
    /// * `n_times` - number of time samples per trial (T).
    /// * `patch_size` - temporal width of the conv kernel (patch length).
    /// * `patch_stride` - temporal stride between consecutive patches.
    /// * `embed_dim` - embedding dimension produced per patch.
    /// * `device` - backend device on which weights are initialized.
    ///
    /// # Panics
    /// Panics if `patch_stride == 0` (division by zero) or if
    /// `patch_size > n_times` — without the guard, `n_times - patch_size`
    /// would wrap around in release builds and yield a nonsense patch count.
    pub fn new(
        n_chans: usize, n_times: usize,
        patch_size: usize, patch_stride: usize,
        embed_dim: usize, device: &B::Device,
    ) -> Self {
        assert!(patch_stride > 0, "patch_stride must be > 0");
        assert!(
            patch_size <= n_times,
            "patch_size ({patch_size}) must not exceed n_times ({n_times})"
        );

        // Right-pad the time axis just enough that the final patch window
        // ends exactly at T_padded (no partial patch is dropped).
        let remainder = (n_times - patch_size) % patch_stride;
        let padding_size = if remainder != 0 { patch_stride - remainder } else { 0 };
        let n_times_padded = n_times + padding_size;
        let n_patches = (n_times_padded - patch_size) / patch_stride + 1;

        // Mirrors the Python reference: Conv2d(1, embed_dim, (1, patch_size),
        // stride=(1, patch_stride)). Kernel height 1 keeps channels separate.
        let proj = Conv2dConfig::new([1, embed_dim], [1, patch_size])
            .with_stride([1, patch_stride])
            .with_padding(burn::nn::PaddingConfig2d::Valid)
            .with_bias(true)
            .init(device);

        Self { proj, n_chans, n_times, patch_size, patch_stride, padding_size, n_patches }
    }

    /// Embed raw EEG into per-patch tokens.
    ///
    /// `x`: `[B, C, T]` → returns `[B, n_patches, C, embed_dim]`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 4> {
        let [batch, n_chans, n_times] = x.dims();
        // Catch shape mismatches early in debug builds; padding/patch math
        // below assumes the configured geometry.
        debug_assert_eq!(n_chans, self.n_chans, "channel count mismatch");
        debug_assert_eq!(n_times, self.n_times, "time-sample count mismatch");
        let device = x.device();

        // Zero-pad the end of the time axis so every patch window is full.
        let x = if self.padding_size > 0 {
            let pad = Tensor::zeros([batch, n_chans, self.padding_size], &device);
            Tensor::cat(vec![x, pad], 2)
        } else {
            x
        };

        // [B, C, T] → [B, 1, C, T]: insert the single Conv2d input channel.
        let x = x.unsqueeze_dim::<4>(1);

        // Conv2d: [B, 1, C, T_padded] → [B, embed_dim, C, n_patches]
        let x = self.proj.forward(x);

        // [B, embed_dim, C, n_patches] → [B, n_patches, C, embed_dim]
        // (swapping dims 1 and 3 ≡ einops 'b e c n -> b n c e').
        x.swap_dims(1, 3)
    }
}