Skip to main content

eegdino_rs/
config.rs

1use serde::{Deserialize, Serialize};
2
3/// Model size variants matching the Python EEG-DINO codebase.
4#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
5#[serde(rename_all = "lowercase")]
6pub enum ModelSize {
7    Small,
8    Medium,
9    Large,
10}
11
12/// Full model configuration derived from model size.
13///
14/// All values match the Python EEG-DINO defaults exactly.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ModelConfig {
17    pub model_size: ModelSize,
18    /// Embedding dimension (d_model): 200 / 512 / 1024
19    pub feature_size: usize,
20    /// Number of attention heads: 8 / 16 / 24
21    pub num_heads: usize,
22    /// Number of transformer encoder layers: 12 / 12 / 24
23    pub num_layers: usize,
24    /// Feed-forward hidden dimension: 512 / 1024 / 2048
25    pub dim_feedforward: usize,
26    /// Number of learnable global tokens (default: 1)
27    pub num_global_tokens: usize,
28    /// Layer index (1-based) at which global tokens are injected (default: 1)
29    pub global_token_layer: usize,
30    /// Number of EEG channels (default: 19)
31    pub num_channels: usize,
32    /// Samples per patch (default: 200)
33    pub patch_size: usize,
34    /// Conv channel widths for the 3 conv layers in proj_in: [c1, c2, c3]
35    pub conv_channels: [usize; 3],
36    /// GroupNorm group counts for the 3 norm layers in proj_in
37    pub norm_groups: [usize; 3],
38    /// LayerNorm epsilon
39    pub layer_norm_eps: f64,
40}
41
42impl ModelConfig {
43    pub fn from_size(size: ModelSize) -> Self {
44        match size {
45            ModelSize::Small => Self {
46                model_size: size,
47                feature_size: 200,
48                num_heads: 8,
49                num_layers: 12,
50                dim_feedforward: 512,
51                num_global_tokens: 1,
52                global_token_layer: 1,
53                num_channels: 19,
54                patch_size: 200,
55                conv_channels: [25, 25, 25],
56                norm_groups: [5, 5, 5],
57                layer_norm_eps: 1e-5,
58            },
59            ModelSize::Medium => Self {
60                model_size: size,
61                feature_size: 512,
62                num_heads: 16,
63                num_layers: 16,
64                dim_feedforward: 1024,
65                num_global_tokens: 1,
66                global_token_layer: 1,
67                num_channels: 19,
68                patch_size: 200,
69                conv_channels: [64, 128, 64],
70                norm_groups: [8, 8, 8],
71                layer_norm_eps: 1e-5,
72            },
73            ModelSize::Large => Self {
74                model_size: size,
75                feature_size: 1024,
76                // NOTE: README claims 24 heads, but 1024/24 = 42.67 is non-integer.
77                // Weights confirm all_head_dim=1024, so num_heads must divide 1024.
78                // 16 heads (head_dim=64) matches the weight dimensions.
79                num_heads: 16,
80                num_layers: 24,
81                dim_feedforward: 2048,
82                num_global_tokens: 1,
83                global_token_layer: 1,
84                num_channels: 19,
85                patch_size: 200,
86                conv_channels: [128, 256, 128],
87                norm_groups: [16, 16, 16],
88                layer_norm_eps: 1e-5,
89            },
90        }
91    }
92
93    /// Load config from a JSON file.
94    pub fn from_file(path: &std::path::Path) -> crate::error::Result<Self> {
95        let data = std::fs::read_to_string(path)?;
96        serde_json::from_str(&data).map_err(|e| {
97            crate::error::EegDinoError::WeightLoad(format!("config parse error: {e}"))
98        })
99    }
100
101    /// Head dimension (feature_size / num_heads).
102    pub fn head_dim(&self) -> usize {
103        self.feature_size / self.num_heads
104    }
105
106    /// Number of spectral bins from rfft of patch_size samples: patch_size/2 + 1.
107    pub fn spectral_bins(&self) -> usize {
108        self.patch_size / 2 + 1
109    }
110
111    /// Temporal output dimension of the conv stack: floor((patch_size - 49 + 2*24) / 25) + 1.
112    pub fn temporal_conv_out(&self) -> usize {
113        (self.patch_size - 49 + 2 * 24) / 25 + 1
114    }
115}