1#[derive(Debug, Clone, serde::Deserialize)]
4pub struct ModelConfig {
5 #[serde(default = "default_patch_size")]
6 pub patch_size: usize,
7 #[serde(default = "default_dim_feedforward")]
8 pub dim_feedforward: usize,
9 #[serde(default = "default_n_layer")]
10 pub n_layer: usize,
11 #[serde(default = "default_nhead")]
12 pub nhead: usize,
13 #[serde(default = "default_emb_dim")]
14 pub emb_dim: usize,
15 #[serde(default)]
16 pub n_outputs: usize,
17 #[serde(default)]
18 pub n_chans: usize,
19 #[serde(default)]
20 pub n_times: usize,
21}
22
23fn default_patch_size() -> usize { 200 }
24fn default_dim_feedforward() -> usize { 800 }
25fn default_n_layer() -> usize { 12 }
26fn default_nhead() -> usize { 8 }
27fn default_emb_dim() -> usize { 200 }
28
29impl Default for ModelConfig {
30 fn default() -> Self {
31 Self {
32 patch_size: 200, dim_feedforward: 800, n_layer: 12, nhead: 8,
33 emb_dim: 200, n_outputs: 4, n_chans: 22, n_times: 1000,
34 }
35 }
36}