// eegdino_rs/model/classifier.rs
//! Classification model: encoder + two-stage pooling + 3-layer MLP head.
//!
//! Matches the Python `ClassificationModel` from `run_finetuning.py`.
//!
//! Pipeline:
//!   encoder(x) → strip global tokens → full_linear → GELU
//!   → reshape → channel pool (mean over C) → channel_linear → GELU
//!   → time pool (mean over P) → classifier MLP → logits
9use burn::prelude::*;
10use burn::nn::Linear;
11
12use crate::config::ModelConfig;
13use super::linear_zeros;
14use super::encoder::EEGEncoder;
15
#[derive(Module, Debug)]
pub struct ClassificationModel<B: Backend> {
    /// Backbone that maps `[B, C, P, L]` input to per-token features
    /// (global tokens first, then one token per channel/patch pair).
    pub encoder: EEGEncoder<B>,
    /// D → D projection applied to every patch token (then GELU) before channel pooling.
    pub full_linear: Linear<B>,
    /// D → D projection applied per time step (then GELU) after channel pooling.
    pub channel_linear: Linear<B>,
    /// 3-layer classifier: D → D/2 → D/4 → num_classes
    pub cls_fc1: Linear<B>,
    pub cls_fc2: Linear<B>,
    pub cls_fc3: Linear<B>,
    /// Embedding width D (`cfg.feature_size`); cached for the reshapes in `forward`.
    pub feature_size: usize,
    /// Count of leading global tokens the encoder prepends; stripped in `forward`.
    pub num_global_tokens: usize,
}
28
29impl<B: Backend> ClassificationModel<B> {
30    pub fn new(cfg: &ModelConfig, num_classes: usize, device: &B::Device) -> Self {
31        let d = cfg.feature_size;
32        Self {
33            encoder: EEGEncoder::new(cfg, device),
34            full_linear: linear_zeros::<B>(d, d, true, device),
35            channel_linear: linear_zeros::<B>(d, d, true, device),
36            cls_fc1: linear_zeros::<B>(d, d / 2, true, device),
37            cls_fc2: linear_zeros::<B>(d / 2, d / 4, true, device),
38            cls_fc3: linear_zeros::<B>(d / 4, num_classes, true, device),
39            feature_size: d,
40            num_global_tokens: cfg.num_global_tokens,
41        }
42    }
43
44    /// Forward pass.
45    ///
46    /// x: [B, C, P, L] → [B, num_classes]
47    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {
48        let [bs, ch, seq_len, _feat] = x.dims();
49        let d = self.feature_size;
50
51        // Encoder: [B, C, P, L] → [B, num_global + C*P, D]
52        let features = self.encoder.forward(x);
53
54        // Strip global tokens: [B, C*P, D]
55        let total_seq = features.dims()[1];
56        let tokens = features.slice([0..bs, self.num_global_tokens..total_seq]);
57
58        // full_linear + GELU on flattened tokens
59        let flat = tokens.reshape([bs * ch * seq_len, d]);
60        let processed = burn::tensor::activation::gelu(self.full_linear.forward(flat));
61
62        // Reshape back: [B, C, P, D]
63        let reshaped = processed.reshape([bs, ch, seq_len, d]);
64
65        // Channel pool: mean over dim=1 (channels) → [B, P, D]
66        let channel_pooled = reshaped.mean_dim(1);
67
68        // channel_linear + GELU on flattened time steps
69        let flat = channel_pooled.reshape([bs * seq_len, d]);
70        let processed = burn::tensor::activation::gelu(self.channel_linear.forward(flat));
71        let processed = processed.reshape([bs, seq_len, d]);
72
73        // Time pool: mean over dim=1 (patches) → [B, D]
74        let time_pooled = processed.mean_dim(1).reshape([bs, d]);
75
76        // Classifier MLP: D → D/2 → D/4 → num_classes
77        // (Dropout is omitted for inference)
78        let h = burn::tensor::activation::gelu(self.cls_fc1.forward(time_pooled));
79        let h = burn::tensor::activation::gelu(self.cls_fc2.forward(h));
80        self.cls_fc3.forward(h)
81    }
82}