1#[derive(Debug, Clone, serde::Deserialize)]
6pub struct ModelConfig {
7 #[serde(default = "default_embed_dim")]
9 pub embed_dim: usize,
10
11 #[serde(default = "default_depth")]
13 pub depth: usize,
14
15 #[serde(default = "default_heads")]
17 pub heads: usize,
18
19 #[serde(default = "default_head_dim")]
21 pub head_dim: usize,
22
23 #[serde(default = "default_mlp_dim_ratio")]
25 pub mlp_dim_ratio: f64,
26
27 #[serde(default = "default_use_geglu")]
29 pub use_geglu: bool,
30
31 #[serde(default = "default_freqs")]
33 pub freqs: usize,
34
35 #[serde(default = "default_patch_size")]
37 pub patch_size: usize,
38
39 #[serde(default = "default_patch_overlap")]
41 pub patch_overlap: usize,
42
43 #[serde(default)]
45 pub attention_pooling: bool,
46
47 #[serde(default)]
49 pub n_outputs: usize,
50
51 #[serde(default)]
53 pub n_chans: usize,
54
55 #[serde(default)]
57 pub n_times: usize,
58}
59
/// Default for [`ModelConfig::embed_dim`].
fn default_embed_dim() -> usize {
    512
}

/// Default for [`ModelConfig::depth`].
fn default_depth() -> usize {
    22
}

/// Default for [`ModelConfig::heads`].
fn default_heads() -> usize {
    8
}

/// Default for [`ModelConfig::head_dim`].
fn default_head_dim() -> usize {
    64
}

/// Default for [`ModelConfig::mlp_dim_ratio`].
///
/// Combined with the default `embed_dim` of 512, this gives `512 * 2.66 = 1361.92`, which
/// truncates to a feed-forward width of 1361.
fn default_mlp_dim_ratio() -> f64 {
    2.66
}

/// Default for [`ModelConfig::use_geglu`].
fn default_use_geglu() -> bool {
    true
}

/// Default for [`ModelConfig::freqs`].
fn default_freqs() -> usize {
    4
}

/// Default for [`ModelConfig::patch_size`].
fn default_patch_size() -> usize {
    200
}

/// Default for [`ModelConfig::patch_overlap`].
fn default_patch_overlap() -> usize {
    20
}
69
70impl Default for ModelConfig {
71 fn default() -> Self {
72 Self {
73 embed_dim: default_embed_dim(),
74 depth: default_depth(),
75 heads: default_heads(),
76 head_dim: default_head_dim(),
77 mlp_dim_ratio: default_mlp_dim_ratio(),
78 use_geglu: default_use_geglu(),
79 freqs: default_freqs(),
80 patch_size: default_patch_size(),
81 patch_overlap: default_patch_overlap(),
82 attention_pooling: false,
83 n_outputs: 4,
84 n_chans: 22,
85 n_times: 1000,
86 }
87 }
88}
89
90impl ModelConfig {
91 pub fn inner_dim(&self) -> usize {
93 self.head_dim * self.heads
94 }
95
96 pub fn mlp_dim(&self) -> usize {
98 (self.embed_dim as f64 * self.mlp_dim_ratio) as usize
99 }
100
101 pub fn ffn_in_features(&self) -> usize {
103 let mlp = self.mlp_dim();
104 if self.use_geglu { mlp * 2 } else { mlp }
105 }
106}