eegdino_rs/model/transformer.rs

use burn::prelude::*;
use burn::nn::{LayerNorm, LayerNormConfig};

use crate::config::ModelConfig;
use super::attention::Attention;
use super::mlp::Mlp;

#[derive(Module, Debug)]
pub struct TransformerEncoderLayer<B: Backend> {
    pub norm1: LayerNorm<B>,
    pub attn: Attention<B>,
    pub norm2: LayerNorm<B>,
    pub mlp: Mlp<B>,
}

impl<B: Backend> TransformerEncoderLayer<B> {
    pub fn new(cfg: &ModelConfig, device: &B::Device) -> Self {
        let d = cfg.feature_size;
        Self {
            norm1: LayerNormConfig::new(d)
                .with_epsilon(cfg.layer_norm_eps)
                .init(device),
            attn: Attention::new(d, cfg.num_heads, device),
            norm2: LayerNormConfig::new(d)
                .with_epsilon(cfg.layer_norm_eps)
                .init(device),
            mlp: Mlp::new(d, cfg.dim_feedforward, device),
        }
    }

    /// Pre-norm transformer block: x -> x + Attn(LN(x)), then h -> h + MLP(LN(h)).
    /// Input and output share the shape [batch, seq_len, feature_size].
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // Self-attention sub-block with a residual connection.
        let h = x.clone() + self.attn.forward(self.norm1.forward(x));
        // Feed-forward sub-block with a residual connection.
        h.clone() + self.mlp.forward(self.norm2.forward(h))
    }
}
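
For context, a full encoder typically chains several of these blocks back to back. The sketch below shows one way to do that with the types defined above; the `TransformerEncoder` wrapper and its `num_layers` argument are illustrative assumptions and are not part of the original file.

// A minimal sketch of stacking the pre-norm blocks into an encoder.
// Assumed, not from the original source: the wrapper type and `num_layers`.
#[derive(Module, Debug)]
pub struct TransformerEncoder<B: Backend> {
    pub layers: Vec<TransformerEncoderLayer<B>>,
}

impl<B: Backend> TransformerEncoder<B> {
    pub fn new(cfg: &ModelConfig, num_layers: usize, device: &B::Device) -> Self {
        Self {
            layers: (0..num_layers)
                .map(|_| TransformerEncoderLayer::new(cfg, device))
                .collect(),
        }
    }

    pub fn forward(&self, mut x: Tensor<B, 3>) -> Tensor<B, 3> {
        // Each block preserves the [batch, seq_len, feature_size] shape,
        // so the layers can simply be applied in sequence.
        for layer in &self.layers {
            x = layer.forward(x);
        }
        x
    }
}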