eegdino_rs/model/encoder.rs

use burn::prelude::*;
use burn::module::{Param, ParamId};

use crate::config::ModelConfig;
use super::embedding::{EmbeddingCache, PatchEmbedding};
use super::transformer::TransformerEncoderLayer;

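/// EEG encoder: a patch embedding followed by a stack of transformer
/// layers, with learnable global tokens prepended to the token sequence
/// after a configurable layer.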
#[derive(Module, Debug)]
pub struct EEGEncoder<B: Backend> {
    pub patch_embedding: PatchEmbedding<B>,
    pub encoder_layers: Vec<TransformerEncoderLayer<B>>,
    pub global_tokens: Param<Tensor<B, 3>>,
    pub global_token_layer: usize,
    pub num_global_tokens: usize,
}

impl<B: Backend> EEGEncoder<B> {
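    /// Builds the encoder from `cfg`. Global tokens are zero-initialized
    /// with shape `[1, num_global_tokens, feature_size]` and broadcast
    /// over the batch at forward time.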
    pub fn new(cfg: &ModelConfig, device: &B::Device) -> Self {
        let layers: Vec<_> = (0..cfg.num_layers)
            .map(|_| TransformerEncoderLayer::new(cfg, device))
            .collect();

        let global_tokens = Param::initialized(
            ParamId::new(),
            Tensor::zeros([1, cfg.num_global_tokens, cfg.feature_size], device),
        );

        Self {
            patch_embedding: PatchEmbedding::new(cfg, device),
            encoder_layers: layers,
            global_tokens,
            global_token_layer: cfg.global_token_layer,
            num_global_tokens: cfg.num_global_tokens,
        }
    }

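    /// Forward pass that reuses the precomputed patch-embedding state in
    /// `cache`. Input shape: `[batch, channels, patches, patch_len]`.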
    pub fn forward_cached(&self, x_in: Tensor<B, 4>, cache: &EmbeddingCache<B>) -> Tensor<B, 3> {
        let [b, _c, _p, _l] = x_in.dims();
        let x = self.patch_embedding.forward_cached(x_in, cache);
        self.run_transformer(x, b)
    }

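    /// Standard forward pass without an embedding cache.
    /// Input shape: `[batch, channels, patches, patch_len]`.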
    pub fn forward(&self, x_in: Tensor<B, 4>) -> Tensor<B, 3> {
        let [b, _c, _p, _l] = x_in.dims();
        let x = self.patch_embedding.forward(x_in);
        self.run_transformer(x, b)
    }

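    /// Flattens the `[batch, channels, patches, d]` embedding into a token
    /// sequence and runs it through the transformer stack, prepending the
    /// global tokens once after layer `global_token_layer`.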
    fn run_transformer(&self, emb: Tensor<B, 4>, b: usize) -> Tensor<B, 3> {
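        // Flatten (channels, patches) into a single sequence dimension.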
        let [_, c, p, d] = emb.dims();
        let seq_len = c * p;
        let mut x = emb.reshape([b, seq_len, d]);

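        // Broadcast the learned global tokens over the batch.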
        let global = self.global_tokens.val().expand([b, self.num_global_tokens, d]);

        for (i, layer) in self.encoder_layers.iter().enumerate() {
            x = layer.forward(x);
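            // Prepend the global tokens once, after layer
            // `global_token_layer` (1-indexed); they then attend through
            // all remaining layers.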
            if i + 1 == self.global_token_layer {
                x = Tensor::cat(vec![global.clone(), x], 1);
            }
        }
        x
    }
}
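
// A minimal usage sketch, kept as a comment because `ModelConfig`'s full
// field set is defined elsewhere and a concrete backend (e.g.
// `burn::backend::NdArray`) must be enabled; the input shape below is
// illustrative, not prescribed by this file.
//
//     let device = Default::default();
//     let encoder = EEGEncoder::<NdArray>::new(&cfg, &device);
//     // x: [batch, channels, patches, patch_len]
//     let x = Tensor::random([2, 19, 8, 200], Distribution::Normal(0.0, 1.0), &device);
//     let out = encoder.forward(x);
//     // out: [batch, num_global_tokens + 19 * 8, feature_size] once the
//     // global tokens have been prepended mid-stack.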