/// Hyper-parameters for the ViT-style encoder and for how the input signal is cut into patches.
///
/// Every field carries a serde default, so a partial (or empty) serialized config deserializes
/// with the missing fields filled from the matching `default_*` function. These are the same
/// values `ModelConfig::default()` produces, i.e. the `vit_base` preset.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct ModelConfig {
    /// Encoder variant name (default `"vit_base"`); see `ModelConfig::for_variant` for presets.
    #[serde(default = "default_encoder_name")]
    pub encoder_name: String,
    /// Number of input leads/channels (default 12).
    #[serde(default = "default_num_leads")]
    pub num_leads: usize,
    /// Patch length along the time axis (default 64); `seq_len / patch_size_time` is the
    /// number of time patches.
    #[serde(default = "default_patch_size_time")]
    pub patch_size_time: usize,
    /// Number of leads grouped into one patch row when `lead_wise != 0` (default 4).
    #[serde(default = "default_patch_size_ch")]
    pub patch_size_ch: usize,
    /// Used as a flag (default 1): `0` treats all leads as a single patch row, any other value
    /// groups leads `patch_size_ch` at a time.
    #[serde(default = "default_lead_wise")]
    pub lead_wise: usize,
    /// Signal sampling rate (default 64; presumably Hz — confirm against the data pipeline).
    #[serde(default = "default_sample_rate")]
    pub sample_rate: usize,
    /// Window length (default 30; presumably seconds, per the name).
    #[serde(default = "default_window_size_sec")]
    pub window_size_sec: usize,
    /// Input length along the time axis (default 1920 = 64 * 30). It is NOT recomputed from
    /// `sample_rate`/`window_size_sec`, so keep it consistent when overriding those.
    #[serde(default = "default_seq_len")]
    pub seq_len: usize,
    /// Encoder width (default 768).
    #[serde(default = "default_width")]
    pub width: usize,
    /// Encoder depth (default 12; presumably the number of transformer blocks).
    #[serde(default = "default_depth")]
    pub depth: usize,
    /// Number of attention heads (default 12).
    #[serde(default = "default_heads")]
    pub heads: usize,
    /// Hidden size of the feed-forward MLP (default 3072).
    #[serde(default = "default_mlp_dim")]
    pub mlp_dim: usize,
    /// Per-head dimension (default 64); in every built-in preset `heads * dim_head == width`.
    #[serde(default = "default_dim_head")]
    pub dim_head: usize,
}
// Serde `default = "..."` providers for `ModelConfig`. They are also used by its `Default`
// impl, so deserialized configs and `ModelConfig::default()` always agree (the `vit_base` preset).

/// Default encoder variant name.
fn default_encoder_name() -> String {
    "vit_base".to_string()
}

/// Default number of input leads.
fn default_num_leads() -> usize {
    12
}

/// Default patch length along the time axis.
fn default_patch_size_time() -> usize {
    64
}

/// Default number of leads per patch row.
fn default_patch_size_ch() -> usize {
    4
}

/// Default lead-wise flag (non-zero: group leads `patch_size_ch` at a time).
fn default_lead_wise() -> usize {
    1
}

/// Default sampling rate.
fn default_sample_rate() -> usize {
    64
}

/// Default window length.
fn default_window_size_sec() -> usize {
    30
}

/// Default input length: one window at the default sampling rate (64 * 30 = 1920).
///
/// Derived rather than hard-coded so it cannot drift from the two defaults it depends on.
fn default_seq_len() -> usize {
    default_sample_rate() * default_window_size_sec()
}

/// Default encoder width.
fn default_width() -> usize {
    768
}

/// Default encoder depth.
fn default_depth() -> usize {
    12
}

/// Default number of attention heads.
fn default_heads() -> usize {
    12
}

/// Default MLP hidden size.
fn default_mlp_dim() -> usize {
    3072
}

/// Default per-head dimension.
fn default_dim_head() -> usize {
    64
}
impl Default for ModelConfig {
    /// Returns the `vit_base` configuration.
    ///
    /// Each field comes from the same `default_*` helper that serde uses for missing fields,
    /// so a default-constructed config matches an empty deserialized one.
    fn default() -> Self {
        let encoder_name = default_encoder_name();
        let num_leads = default_num_leads();
        let patch_size_time = default_patch_size_time();
        let patch_size_ch = default_patch_size_ch();
        let lead_wise = default_lead_wise();
        let sample_rate = default_sample_rate();
        let window_size_sec = default_window_size_sec();
        let seq_len = default_seq_len();
        let width = default_width();
        let depth = default_depth();
        let heads = default_heads();
        let mlp_dim = default_mlp_dim();
        let dim_head = default_dim_head();

        Self {
            encoder_name,
            num_leads,
            patch_size_time,
            patch_size_ch,
            lead_wise,
            sample_rate,
            window_size_sec,
            seq_len,
            width,
            depth,
            heads,
            mlp_dim,
            dim_head,
        }
    }
}
impl ModelConfig {
    /// Number of patches along the time axis: `seq_len / patch_size_time`.
    ///
    /// Uses integer division, so trailing samples that do not fill a whole patch are not counted.
    ///
    /// # Panics
    ///
    /// Panics if `patch_size_time` is zero (possible for a deserialized config).
    pub fn num_patches_time(&self) -> usize {
        self.seq_len / self.patch_size_time
    }

    /// Number of patch rows along the lead axis.
    ///
    /// When `lead_wise == 0` all leads form a single row. Otherwise leads are grouped
    /// `patch_size_ch` at a time; this is integer division, so leftover leads that do not fill
    /// a whole group are not counted.
    ///
    /// # Panics
    ///
    /// Panics if `lead_wise != 0` and `patch_size_ch` is zero.
    pub fn num_lead_rows(&self) -> usize {
        if self.lead_wise == 0 {
            1
        } else {
            self.num_leads / self.patch_size_ch
        }
    }

    /// Total number of patches: lead rows times time patches.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Self::num_lead_rows`] and
    /// [`Self::num_patches_time`].
    pub fn num_patches(&self) -> usize {
        self.num_lead_rows() * self.num_patches_time()
    }

    /// Returns the preset for a named encoder variant, or `None` if the name is not recognised.
    ///
    /// Recognised names are `vit_nano`, `vit_tiny`, `vit_small`, `vit_middle` and `vit_base`.
    /// The smaller presets override only the encoder-size fields (`width`, `depth`, `heads`,
    /// `mlp_dim`, `dim_head`) plus `encoder_name`; everything else keeps its default value.
    pub fn try_for_variant(name: &str) -> Option<Self> {
        let config = match name {
            "vit_nano" => Self {
                encoder_name: "vit_nano".into(),
                width: 128,
                depth: 6,
                heads: 4,
                mlp_dim: 512,
                dim_head: 32,
                ..Default::default()
            },
            "vit_tiny" => Self {
                encoder_name: "vit_tiny".into(),
                width: 192,
                depth: 12,
                heads: 3,
                mlp_dim: 768,
                dim_head: 64,
                ..Default::default()
            },
            "vit_small" => Self {
                encoder_name: "vit_small".into(),
                width: 384,
                depth: 12,
                heads: 6,
                mlp_dim: 1536,
                dim_head: 64,
                ..Default::default()
            },
            "vit_middle" => Self {
                encoder_name: "vit_middle".into(),
                width: 512,
                depth: 12,
                heads: 8,
                mlp_dim: 2048,
                dim_head: 64,
                ..Default::default()
            },
            // The default configuration *is* the vit_base preset.
            "vit_base" => Self::default(),
            _ => return None,
        };
        Some(config)
    }

    /// Returns the preset for a named encoder variant.
    ///
    /// Any unrecognised name falls back to the `vit_base` preset (`ModelConfig::default()`).
    /// Use [`Self::try_for_variant`] to detect unknown names instead of silently getting the
    /// base model.
    pub fn for_variant(name: &str) -> Self {
        Self::try_for_variant(name).unwrap_or_default()
    }
}
/// Names of the PSG (presumably polysomnography) input channels, in order.
///
/// NOTE(review): each name's position presumably matches its channel index in the loaded
/// recordings — confirm against the data loader.
pub const PSG_CHANNELS: &[&str] = &[
    "ECG",
    "EMG_Chin",
    "EMG_LLeg",
    "EMG_RLeg",
    "ABD",
    "THX",
    "NP",
    "SN",
    "EOG_E1_A2",
    "EOG_E2_A1",
    "EEG_C3_A2",
    "EEG_C4_A1",
];

/// Number of PSG channels.
///
/// Derived from [`PSG_CHANNELS`] at compile time, so adding or removing a channel name cannot
/// leave this count stale.
pub const NUM_PSG_CHANNELS: usize = PSG_CHANNELS.len();

/// Sampling rate of the PSG signals (presumably Hz).
pub const SAMPLE_RATE: usize = 64;

/// Epoch length (presumably seconds).
pub const EPOCH_SEC: usize = 30;

/// Samples per epoch: `SAMPLE_RATE * EPOCH_SEC` (= 1920).
pub const EPOCH_SAMPLES: usize = SAMPLE_RATE * EPOCH_SEC;