eegdino_rs/model/
classifier.rs1use burn::prelude::*;
10use burn::nn::Linear;
11
12use crate::config::ModelConfig;
13use super::linear_zeros;
14use super::encoder::EEGEncoder;
15
16#[derive(Module, Debug)]
17pub struct ClassificationModel<B: Backend> {
18 pub encoder: EEGEncoder<B>,
19 pub full_linear: Linear<B>,
20 pub channel_linear: Linear<B>,
21 pub cls_fc1: Linear<B>,
23 pub cls_fc2: Linear<B>,
24 pub cls_fc3: Linear<B>,
25 pub feature_size: usize,
26 pub num_global_tokens: usize,
27}
28
29impl<B: Backend> ClassificationModel<B> {
30 pub fn new(cfg: &ModelConfig, num_classes: usize, device: &B::Device) -> Self {
31 let d = cfg.feature_size;
32 Self {
33 encoder: EEGEncoder::new(cfg, device),
34 full_linear: linear_zeros::<B>(d, d, true, device),
35 channel_linear: linear_zeros::<B>(d, d, true, device),
36 cls_fc1: linear_zeros::<B>(d, d / 2, true, device),
37 cls_fc2: linear_zeros::<B>(d / 2, d / 4, true, device),
38 cls_fc3: linear_zeros::<B>(d / 4, num_classes, true, device),
39 feature_size: d,
40 num_global_tokens: cfg.num_global_tokens,
41 }
42 }
43
44 pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {
48 let [bs, ch, seq_len, _feat] = x.dims();
49 let d = self.feature_size;
50
51 let features = self.encoder.forward(x);
53
54 let total_seq = features.dims()[1];
56 let tokens = features.slice([0..bs, self.num_global_tokens..total_seq]);
57
58 let flat = tokens.reshape([bs * ch * seq_len, d]);
60 let processed = burn::tensor::activation::gelu(self.full_linear.forward(flat));
61
62 let reshaped = processed.reshape([bs, ch, seq_len, d]);
64
65 let channel_pooled = reshaped.mean_dim(1);
67
68 let flat = channel_pooled.reshape([bs * seq_len, d]);
70 let processed = burn::tensor::activation::gelu(self.channel_linear.forward(flat));
71 let processed = processed.reshape([bs, seq_len, d]);
72
73 let time_pooled = processed.mean_dim(1).reshape([bs, d]);
75
76 let h = burn::tensor::activation::gelu(self.cls_fc1.forward(time_pooled));
79 let h = burn::tensor::activation::gelu(self.cls_fc2.forward(h));
80 self.cls_fc3.forward(h)
81 }
82}