Skip to main content

eegpt_rs/
config.rs

1/// Model configuration for EEGPT.
2
3#[derive(Debug, Clone, serde::Deserialize)]
4pub struct ModelConfig {
5    #[serde(default = "d64")]  pub patch_size: usize,
6    #[serde(default = "d32")]  pub patch_stride: usize,
7    #[serde(default = "d4")]   pub embed_num: usize,
8    #[serde(default = "d512")] pub embed_dim: usize,
9    #[serde(default = "d8")]   pub depth: usize,
10    #[serde(default = "d8")]   pub num_heads: usize,
11    #[serde(default = "d4f")]  pub mlp_ratio: f64,
12    #[serde(default = "dtrue")] pub qkv_bias: bool,
13    #[serde(default = "d62")]  pub n_chan_embeddings: usize,
14    #[serde(default = "d16")]  pub probe_hidden_dim: usize,
15    #[serde(default)]          pub n_outputs: usize,
16    #[serde(default)]          pub n_chans: usize,
17    #[serde(default)]          pub n_times: usize,
18}
19
// Default-value helpers referenced by `#[serde(default = "...")]` attributes on
// `ModelConfig`. Each function's name encodes the value it returns. They must
// return the same values as `impl Default for ModelConfig`, or deserialized and
// `Default` configs will diverge.

/// Returns `64usize`.
fn d64() -> usize {
    64
}

/// Returns `32usize`.
fn d32() -> usize {
    32
}

/// Returns `4usize`.
fn d4() -> usize {
    4
}

/// Returns `512usize`.
fn d512() -> usize {
    512
}

/// Returns `8usize`.
fn d8() -> usize {
    8
}

/// Returns `4.0f64`.
fn d4f() -> f64 {
    4.0
}

/// Returns `true`.
fn dtrue() -> bool {
    true
}

/// Returns `62usize`.
fn d62() -> usize {
    62
}

/// Returns `16usize`.
fn d16() -> usize {
    16
}

30impl Default for ModelConfig {
31    fn default() -> Self {
32        Self {
33            patch_size: 64, patch_stride: 32, embed_num: 4, embed_dim: 512,
34            depth: 8, num_heads: 8, mlp_ratio: 4.0, qkv_bias: true,
35            n_chan_embeddings: 62, probe_hidden_dim: 16,
36            n_outputs: 4, n_chans: 22, n_times: 1000,
37        }
38    }
39}