/// Model hyper-parameters, deserializable with serde.
///
/// Every field may be omitted from the serialized input; a missing field
/// takes the per-field default noted on it below.
///
/// NOTE(review): `n_outputs`, `n_chans` and `n_times` use plain
/// `#[serde(default)]`, so they come back as `0` (`usize::default()`) when
/// absent. `ModelConfig::default()` sets them to 4, 22 and 1000 instead.
/// Presumably `0` means "to be filled in by the caller" — confirm before
/// relying on either value.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct ModelConfig {
    /// Defaults to 64 when absent.
    #[serde(default = "d64")]
    pub patch_size: usize,
    /// Defaults to 32 when absent.
    #[serde(default = "d32")]
    pub patch_stride: usize,
    /// Defaults to 4 when absent.
    #[serde(default = "d4")]
    pub embed_num: usize,
    /// Defaults to 512 when absent.
    #[serde(default = "d512")]
    pub embed_dim: usize,
    /// Defaults to 8 when absent.
    #[serde(default = "d8")]
    pub depth: usize,
    /// Defaults to 8 when absent.
    #[serde(default = "d8")]
    pub num_heads: usize,
    /// Defaults to 4.0 when absent.
    #[serde(default = "d4f")]
    pub mlp_ratio: f64,
    /// Defaults to `true` when absent.
    #[serde(default = "dtrue")]
    pub qkv_bias: bool,
    /// Defaults to 62 when absent.
    #[serde(default = "d62")]
    pub n_chan_embeddings: usize,
    /// Defaults to 16 when absent.
    #[serde(default = "d16")]
    pub probe_hidden_dim: usize,
    /// Defaults to 0 when absent (differs from `ModelConfig::default()`).
    #[serde(default)]
    pub n_outputs: usize,
    /// Defaults to 0 when absent (differs from `ModelConfig::default()`).
    #[serde(default)]
    pub n_chans: usize,
    /// Defaults to 0 when absent (differs from `ModelConfig::default()`).
    #[serde(default)]
    pub n_times: usize,
}
/// Serde default helper: 64 (referenced by `ModelConfig::patch_size`).
fn d64() -> usize {
    64
}

/// Serde default helper: 32 (referenced by `ModelConfig::patch_stride`).
fn d32() -> usize {
    32
}

/// Serde default helper: 4 (referenced by `ModelConfig::embed_num`).
fn d4() -> usize {
    4
}

/// Serde default helper: 512 (referenced by `ModelConfig::embed_dim`).
fn d512() -> usize {
    512
}

/// Serde default helper: 8 (referenced by `ModelConfig::depth` and `num_heads`).
fn d8() -> usize {
    8
}

/// Serde default helper: 4.0 (referenced by `ModelConfig::mlp_ratio`).
fn d4f() -> f64 {
    4.0
}

/// Serde default helper: `true` (referenced by `ModelConfig::qkv_bias`).
fn dtrue() -> bool {
    true
}

/// Serde default helper: 62 (referenced by `ModelConfig::n_chan_embeddings`).
fn d62() -> usize {
    62
}

/// Serde default helper: 16 (referenced by `ModelConfig::probe_hidden_dim`).
fn d16() -> usize {
    16
}
impl Default for ModelConfig {
fn default() -> Self {
Self {
patch_size: 64, patch_stride: 32, embed_num: 4, embed_dim: 512,
depth: 8, num_heads: 8, mlp_ratio: 4.0, qkv_bias: true,
n_chan_embeddings: 62, probe_hidden_dim: 16,
n_outputs: 4, n_chans: 22, n_times: 1000,
}
}
}