eegdino_rs/model/encoder.rs

//! EEG-DINO Encoder: patch embedding → transformer layers with global tokens.
//!
//! Matches the Python `EEGEncoder` class from `models/eeg_encoder.py`.
//!
//! Input:  `[B, C, P, L]`  (batch, channels, patches, patch_length)
//! Output: `[B, num_global + C*P, D]`
//!
//! Global tokens are injected after layer `global_token_layer` (1-indexed).
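//!
//! A sketch of intended use (the backend and config names here are
//! illustrative, not part of this crate):
//!
//! ```ignore
//! let encoder = EEGEncoder::<MyBackend>::new(&cfg, &device);
//! let tokens = encoder.forward(x); // x: [B, C, P, L] -> [B, G + C*P, D]
//! ```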

use burn::prelude::*;
use burn::module::{Param, ParamId};

use crate::config::ModelConfig;
use super::embedding::{EmbeddingCache, PatchEmbedding};
use super::transformer::TransformerEncoderLayer;

#[derive(Module, Debug)]
pub struct EEGEncoder<B: Backend> {
    pub patch_embedding: PatchEmbedding<B>,
    pub encoder_layers: Vec<TransformerEncoderLayer<B>>,
    /// Learnable global tokens: `[1, num_global_tokens, feature_size]`
    pub global_tokens: Param<Tensor<B, 3>>,
    pub global_token_layer: usize,
    pub num_global_tokens: usize,
}

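// Illustrative shapes (example numbers only): with C = 19 channels, P = 30
// patches and num_global_tokens = 4, the patch embedding yields
// [B, 19, 30, D]; `run_transformer` flattens this to [B, 570, D] and, once
// the global tokens are injected, the sequence is [B, 4 + 570, D] with the
// global tokens at positions 0..4.
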
impl<B: Backend> EEGEncoder<B> {
    pub fn new(cfg: &ModelConfig, device: &B::Device) -> Self {
        let layers: Vec<_> = (0..cfg.num_layers)
            .map(|_| TransformerEncoderLayer::new(cfg, device))
            .collect();

        let global_tokens = Param::initialized(
            ParamId::new(),
            Tensor::zeros([1, cfg.num_global_tokens, cfg.feature_size], device),
        );

        Self {
            patch_embedding: PatchEmbedding::new(cfg, device),
            encoder_layers: layers,
            global_tokens,
            global_token_layer: cfg.global_token_layer,
            num_global_tokens: cfg.num_global_tokens,
        }
    }

    /// Forward pass using a pre-built embedding cache (fast path).
    ///
    /// `x_in`: `[B, C, P, L]` → `[B, num_global + C*P, D]`
    pub fn forward_cached(&self, x_in: Tensor<B, 4>, cache: &EmbeddingCache<B>) -> Tensor<B, 3> {
        self.run_transformer(self.patch_embedding.forward_cached(x_in, cache))
    }

    /// Forward pass without a cache (rebuilds constants on each call).
    ///
    /// `x_in`: `[B, C, P, L]` → `[B, num_global + C*P, D]`
    pub fn forward(&self, x_in: Tensor<B, 4>) -> Tensor<B, 3> {
        self.run_transformer(self.patch_embedding.forward(x_in))
    }

    /// Flatten the patch grid into a token sequence and run the encoder
    /// stack, prepending the global tokens after `global_token_layer`.
    fn run_transformer(&self, emb: Tensor<B, 4>) -> Tensor<B, 3> {
        // [B, C, P, D] -> [B, C*P, D]: one token per (channel, patch) pair,
        // ordered channel-major.
        let [b, c, p, d] = emb.dims();
        let mut x = emb.reshape([b, c * p, d]);

        // Broadcast the learnable global tokens across the batch.
        let global = self.global_tokens.val().expand([b, self.num_global_tokens, d]);

        for (i, layer) in self.encoder_layers.iter().enumerate() {
            x = layer.forward(x);
            // `global_token_layer` is 1-indexed: the tokens are prepended
            // right after that layer, so every later layer attends over
            // global and patch tokens together.
            if i + 1 == self.global_token_layer {
                x = Tensor::cat(vec![global.clone(), x], 1);
            }
        }
        x
    }
}
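
// ---------------------------------------------------------------------------
// Minimal shape-check sketch, not part of the original module. Assumptions:
// the `burn-ndarray` backend is available as a dev-dependency, `ModelConfig`
// offers a `Default` impl (hypothetical; substitute the real constructor),
// the example dims below match the config's patch geometry, and
// `global_token_layer` lies in `1..=num_layers` so injection occurs.
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn output_prepends_global_tokens() {
        let device = Default::default();
        let cfg = ModelConfig::default(); // hypothetical constructor
        let encoder = EEGEncoder::<NdArray>::new(&cfg, &device);

        let (b, c, p, l) = (2, 4, 8, 16); // batch, channels, patches, patch_length
        let x = Tensor::<NdArray, 4>::zeros([b, c, p, l], &device);

        let out = encoder.forward(x);
        assert_eq!(
            out.dims(),
            [b, cfg.num_global_tokens + c * p, cfg.feature_size]
        );
    }
}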