// eegpt-rs 0.0.1
//
// EEGPT EEG Foundation Model — inference in Rust with Burn ML.
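//! Multi-head self-attention for EEGPT.
//!
//! Port of the Python `_Attention` module: a fused `qkv` linear layer, scaled dot-product
//! attention, and an output `proj` linear layer. Rotary position embeddings (RoPE) are not
//! applied, matching the Python default `use_rope=False` used at inference.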

use burn::prelude::*;
use burn::nn::{Linear, LinearConfig};
use burn::tensor::activation::softmax;

/// Multi-head self-attention layer: a fused query/key/value projection, per-head scaled
/// dot-product attention, and an output projection.
#[derive(Module, Debug)]
pub struct Attention<B: Backend> {
    /// Fused query/key/value projection; [`Attention::new`] builds it as `dim -> 3 * dim`.
    pub qkv: Linear<B>,
    /// Output projection applied after the heads are merged; built as `dim -> dim` with bias.
    pub proj: Linear<B>,
    /// Number of attention heads.
    pub n_heads: usize,
    /// Width of each head; [`Attention::new`] sets it to `dim / n_heads`.
    pub head_dim: usize,
}

impl<B: Backend> Attention<B> {
    /// Creates an attention layer over embeddings of width `dim`, split across `n_heads` heads.
    ///
    /// `qkv_bias` toggles the bias of the fused query/key/value projection; the output
    /// projection always has a bias.
    ///
    /// # Panics
    ///
    /// Panics if `n_heads` is zero or does not evenly divide `dim`. Without this check the
    /// integer division below would silently drop the remainder, `forward` would slice the
    /// fused QKV output at the wrong offsets, and `proj` would receive a tensor narrower than
    /// the `dim` inputs it was built for.
    pub fn new(dim: usize, n_heads: usize, qkv_bias: bool, device: &B::Device) -> Self {
        assert!(n_heads > 0, "Attention: n_heads must be greater than zero");
        assert!(
            dim % n_heads == 0,
            "Attention: dim ({dim}) must be divisible by n_heads ({n_heads})"
        );
        let head_dim = dim / n_heads;
        Self {
            qkv: LinearConfig::new(dim, dim * 3).with_bias(qkv_bias).init(device),
            proj: LinearConfig::new(dim, dim).with_bias(true).init(device),
            n_heads,
            head_dim,
        }
    }

    /// Runs multi-head self-attention over a sequence.
    ///
    /// `x` is `[B, S, D]` and the result is `[B, S, D]`, where `D = n_heads * head_dim`.
    /// No attention mask is applied: every position attends to every position. Attention
    /// weights are `softmax(q · kᵀ · head_dim^-0.5)` taken over the key axis.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let [b, s, _] = x.dims();
        let (h, dh) = (self.n_heads, self.head_dim);
        let dim = h * dh;

        // The fused projection yields [B, S, 3 * D], sliced here as [q | k | v] along the last
        // axis. Each slice is split into heads and moved to [B, H, S, Dh] so the matmuls below
        // batch over (B, H). `narrow` consumes its receiver, hence the clones for q and k.
        let qkv = self.qkv.forward(x);
        let q = qkv.clone().narrow(2, 0, dim).reshape([b, s, h, dh]).swap_dims(1, 2);
        let k = qkv.clone().narrow(2, dim, dim).reshape([b, s, h, dh]).swap_dims(1, 2);
        let v = qkv.narrow(2, dim * 2, dim).reshape([b, s, h, dh]).swap_dims(1, 2);

        // Scale logits by 1/sqrt(Dh), then normalise over the key axis (dim 3).
        let scale = (dh as f64).powf(-0.5) as f32;
        let attn = softmax(q.matmul(k.transpose()).mul_scalar(scale), 3);

        // [B, H, S, Dh] -> [B, S, H, Dh] -> [B, S, D]: concatenate heads before projecting.
        let out = attn.matmul(v).swap_dims(1, 2).reshape([b, s, dim]);
        self.proj.forward(out)
    }
}