1#[derive(Debug, Clone, serde::Deserialize)]
4pub struct ModelConfig {
5 #[serde(default = "d64")] pub patch_size: usize,
6 #[serde(default = "d32")] pub patch_stride: usize,
7 #[serde(default = "d4")] pub embed_num: usize,
8 #[serde(default = "d512")] pub embed_dim: usize,
9 #[serde(default = "d8")] pub depth: usize,
10 #[serde(default = "d8")] pub num_heads: usize,
11 #[serde(default = "d4f")] pub mlp_ratio: f64,
12 #[serde(default = "dtrue")] pub qkv_bias: bool,
13 #[serde(default = "d62")] pub n_chan_embeddings: usize,
14 #[serde(default = "d16")] pub probe_hidden_dim: usize,
15 #[serde(default)] pub n_outputs: usize,
16 #[serde(default)] pub n_chans: usize,
17 #[serde(default)] pub n_times: usize,
18}
19
// Zero-argument default providers referenced by the
// `#[serde(default = "...")]` field attributes on `ModelConfig`.
// serde's `default = "path"` attribute takes a function path, not a value,
// so each constant is wrapped in its own fn.

// --- usize helpers -------------------------------------------------------

/// Returns `64`.
fn d64() -> usize {
    64
}

/// Returns `32`.
fn d32() -> usize {
    32
}

/// Returns `4`.
fn d4() -> usize {
    4
}

/// Returns `512`.
fn d512() -> usize {
    512
}

/// Returns `8`.
fn d8() -> usize {
    8
}

/// Returns `62`.
fn d62() -> usize {
    62
}

/// Returns `16`.
fn d16() -> usize {
    16
}

// --- non-usize helpers ---------------------------------------------------

/// Returns `4.0`.
fn d4f() -> f64 {
    4.0
}

/// Returns `true`.
fn dtrue() -> bool {
    true
}
29
30impl Default for ModelConfig {
31 fn default() -> Self {
32 Self {
33 patch_size: 64, patch_stride: 32, embed_num: 4, embed_dim: 512,
34 depth: 8, num_heads: 8, mlp_ratio: 4.0, qkv_bias: true,
35 n_chan_embeddings: 62, probe_hidden_dim: 16,
36 n_outputs: 4, n_chans: 22, n_times: 1000,
37 }
38 }
39}