1#[derive(Debug, Clone, serde::Deserialize)]
4pub struct ModelConfig {
5 #[serde(default = "d200")] pub patch_size: usize,
6 #[serde(default = "d200")] pub embed_dim: usize,
7 #[serde(default = "d12")] pub num_layers: usize,
8 #[serde(default = "d10")] pub num_heads: usize,
9 #[serde(default = "d4f")] pub mlp_ratio: f64,
10 #[serde(default = "d8")] pub conv_out_channels: usize,
11 #[serde(default = "d128")] pub n_pos_embeddings: usize,
12 #[serde(default)] pub n_outputs: usize,
13 #[serde(default)] pub n_chans: usize,
14 #[serde(default)] pub n_times: usize,
15}
16
// Fallback values for `#[serde(default = "...")]` on `ModelConfig` fields.
// serde's `default = "path"` attribute takes a function path rather than a
// literal, so each fallback value gets its own zero-argument function.

/// Fallback of 200, used by `patch_size` and `embed_dim`.
const fn d200() -> usize {
    200
}

/// Fallback of 12, used by `num_layers`.
const fn d12() -> usize {
    12
}

/// Fallback of 10, used by `num_heads`.
const fn d10() -> usize {
    10
}

/// Fallback of 4.0, used by `mlp_ratio`.
const fn d4f() -> f64 {
    4.0
}

/// Fallback of 8, used by `conv_out_channels`.
const fn d8() -> usize {
    8
}

/// Fallback of 128, used by `n_pos_embeddings`.
const fn d128() -> usize {
    128
}
23
24impl Default for ModelConfig {
25 fn default() -> Self {
26 Self {
27 patch_size: 200, embed_dim: 200, num_layers: 12, num_heads: 10,
28 mlp_ratio: 4.0, conv_out_channels: 8, n_pos_embeddings: 128,
29 n_outputs: 4, n_chans: 64, n_times: 1600,
30 }
31 }
32}