1#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
9pub struct ModelConfig {
10 #[serde(default = "default_encoder_name")]
12 pub encoder_name: String,
13
14 #[serde(default = "default_num_leads")]
16 pub num_leads: usize,
17
18 #[serde(default = "default_patch_size_time")]
20 pub patch_size_time: usize,
21
22 #[serde(default = "default_patch_size_ch")]
24 pub patch_size_ch: usize,
25
26 #[serde(default = "default_lead_wise")]
28 pub lead_wise: usize,
29
30 #[serde(default = "default_sample_rate")]
32 pub sample_rate: usize,
33
34 #[serde(default = "default_window_size_sec")]
36 pub window_size_sec: usize,
37
38 #[serde(default = "default_seq_len")]
40 pub seq_len: usize,
41
42 #[serde(default = "default_width")]
44 pub width: usize,
45
46 #[serde(default = "default_depth")]
48 pub depth: usize,
49
50 #[serde(default = "default_heads")]
52 pub heads: usize,
53
54 #[serde(default = "default_mlp_dim")]
56 pub mlp_dim: usize,
57
58 #[serde(default = "default_dim_head")]
60 pub dim_head: usize,
61}
62
// Serde field defaults for `ModelConfig`. `impl Default for ModelConfig`
// calls the same helpers, so deserialization and `Default` always agree.

/// Default encoder preset name.
fn default_encoder_name() -> String {
    "vit_base".to_string()
}

/// Default number of input leads (same value as `NUM_PSG_CHANNELS`).
fn default_num_leads() -> usize {
    12
}

/// Default patch length along the time axis.
fn default_patch_size_time() -> usize {
    64
}

/// Default number of leads per patch row.
fn default_patch_size_ch() -> usize {
    4
}

/// Default lead-grouping flag (non-zero: group leads into rows).
fn default_lead_wise() -> usize {
    1
}

/// Default sampling rate (same value as `SAMPLE_RATE`).
fn default_sample_rate() -> usize {
    64
}

/// Default window length in seconds (same value as `EPOCH_SEC`).
fn default_window_size_sec() -> usize {
    30
}

/// Default sequence length, derived from the default sampling rate and
/// window length (64 * 30 = 1920) so these defaults cannot drift apart.
fn default_seq_len() -> usize {
    default_sample_rate() * default_window_size_sec()
}

/// Default encoder width.
fn default_width() -> usize {
    768
}

/// Default encoder depth.
fn default_depth() -> usize {
    12
}

/// Default number of attention heads.
fn default_heads() -> usize {
    12
}

/// Default feed-forward hidden size.
fn default_mlp_dim() -> usize {
    3072
}

/// Default per-head dimension.
fn default_dim_head() -> usize {
    64
}
76
77impl Default for ModelConfig {
78 fn default() -> Self {
79 Self {
80 encoder_name: default_encoder_name(),
81 num_leads: default_num_leads(),
82 patch_size_time: default_patch_size_time(),
83 patch_size_ch: default_patch_size_ch(),
84 lead_wise: default_lead_wise(),
85 sample_rate: default_sample_rate(),
86 window_size_sec: default_window_size_sec(),
87 seq_len: default_seq_len(),
88 width: default_width(),
89 depth: default_depth(),
90 heads: default_heads(),
91 mlp_dim: default_mlp_dim(),
92 dim_head: default_dim_head(),
93 }
94 }
95}
96
97impl ModelConfig {
98 pub fn num_patches_time(&self) -> usize {
100 self.seq_len / self.patch_size_time
101 }
102
103 pub fn num_lead_rows(&self) -> usize {
105 if self.lead_wise == 0 { 1 } else { self.num_leads / self.patch_size_ch }
106 }
107
108 pub fn num_patches(&self) -> usize {
110 self.num_lead_rows() * self.num_patches_time()
111 }
112
113 pub fn for_variant(name: &str) -> Self {
115 match name {
116 "vit_nano" => Self {
117 encoder_name: "vit_nano".into(),
118 width: 128, depth: 6, heads: 4, mlp_dim: 512, dim_head: 32,
119 ..Default::default()
120 },
121 "vit_tiny" => Self {
122 encoder_name: "vit_tiny".into(),
123 width: 192, depth: 12, heads: 3, mlp_dim: 768, dim_head: 64,
124 ..Default::default()
125 },
126 "vit_small" => Self {
127 encoder_name: "vit_small".into(),
128 width: 384, depth: 12, heads: 6, mlp_dim: 1536, dim_head: 64,
129 ..Default::default()
130 },
131 "vit_middle" => Self {
132 encoder_name: "vit_middle".into(),
133 width: 512, depth: 12, heads: 8, mlp_dim: 2048, dim_head: 64,
134 ..Default::default()
135 },
136 "vit_base" | _ => Self::default(),
137 }
138 }
139}
140
/// Names of the PSG channels, in canonical order.
///
/// NOTE(review): presumably this is the lead order fed to the model (the
/// default `ModelConfig::num_leads` equals its length); confirm against the
/// data loader.
pub const PSG_CHANNELS: &[&str] = &[
    "ECG",
    "EMG_Chin",
    "EMG_LLeg",
    "EMG_RLeg",
    "ABD",
    "THX",
    "NP",
    "SN",
    "EOG_E1_A2",
    "EOG_E2_A1",
    "EEG_C3_A2",
    "EEG_C4_A1",
];

/// Number of PSG channels. Derived from [`PSG_CHANNELS`] so the count can
/// never disagree with the list.
pub const NUM_PSG_CHANNELS: usize = PSG_CHANNELS.len();

/// Signal sampling rate (samples per second).
pub const SAMPLE_RATE: usize = 64;

/// Length of one epoch in seconds.
pub const EPOCH_SEC: usize = 30;

/// Number of samples in one epoch (`SAMPLE_RATE * EPOCH_SEC`).
pub const EPOCH_SAMPLES: usize = SAMPLE_RATE * EPOCH_SEC;