eegdino_rs/model/transformer.rs

//! Transformer encoder layer with pre-norm residual connections.
//!
//! Matches the Python `TransformerEncoderLayer` in `models/transformer.py`.
//!
//! Architecture (inference path, gamma=None):
//! ```text
//! x = x + Attn(LayerNorm(x))
//! x = x + MLP(LayerNorm(x))
//! ```
//!
//! DropPath and gamma scaling are training-only and omitted here.
use burn::prelude::*;
use burn::nn::{LayerNorm, LayerNormConfig};

use crate::config::ModelConfig;
use super::attention::Attention;
use super::mlp::Mlp;

#[derive(Module, Debug)]
pub struct TransformerEncoderLayer<B: Backend> {
    pub norm1: LayerNorm<B>,
    pub attn: Attention<B>,
    pub norm2: LayerNorm<B>,
    pub mlp: Mlp<B>,
}

impl<B: Backend> TransformerEncoderLayer<B> {
    pub fn new(cfg: &ModelConfig, device: &B::Device) -> Self {
        let d = cfg.feature_size;
        Self {
            norm1: LayerNormConfig::new(d)
                .with_epsilon(cfg.layer_norm_eps)
                .init(device),
            attn: Attention::new(d, cfg.num_heads, device),
            norm2: LayerNormConfig::new(d)
                .with_epsilon(cfg.layer_norm_eps)
                .init(device),
            mlp: Mlp::new(d, cfg.dim_feedforward, device),
        }
    }

    /// `x`: `[B, N, D]` -> `[B, N, D]`
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // Attention residual (clone is O(1) — Arc increment, not a data copy)
        let h = x.clone() + self.attn.forward(self.norm1.forward(x));
        // MLP residual
        h.clone() + self.mlp.forward(self.norm2.forward(h))
    }
}