/// Model hyper-parameters, deserializable with serde.
///
/// Every field carries a serde fallback, so any key may be omitted from the
/// serialized input.
///
/// NOTE(review): for `n_outputs`, `n_chans` and `n_times` the serde fallback is
/// the type default (`0`), whereas the `Default` impl in this file sets them to
/// 4 / 22 / 1000. A config deserialized from an empty document therefore differs
/// from `ModelConfig::default()` — confirm this is intended.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct ModelConfig {
/// Base dimension that `mlp_dim()` scales by `mlp_dim_ratio`; presumably the
/// token embedding width — confirm. Falls back to 512.
#[serde(default = "default_embed_dim")]
pub embed_dim: usize,
/// Presumably the number of stacked layers — confirm against the model code.
/// Falls back to 22.
#[serde(default = "default_depth")]
pub depth: usize,
/// Head count; multiplied with `head_dim` by `inner_dim()`. Falls back to 8.
#[serde(default = "default_heads")]
pub heads: usize,
/// Per-head dimension; multiplied with `heads` by `inner_dim()`. Falls back to 64.
#[serde(default = "default_head_dim")]
pub head_dim: usize,
/// Factor applied to `embed_dim` by `mlp_dim()` (the product is truncated to an
/// integer). Falls back to 2.66.
#[serde(default = "default_mlp_dim_ratio")]
pub mlp_dim_ratio: f64,
/// When `true`, `ffn_in_features()` returns twice `mlp_dim()`. Falls back to `true`.
#[serde(default = "default_use_geglu")]
pub use_geglu: bool,
/// Not read by any code in this section; meaning must be taken from the
/// consumer of this config — TODO confirm. Falls back to 4.
#[serde(default = "default_freqs")]
pub freqs: usize,
/// Presumably the length of one input patch — confirm units with the caller.
/// Falls back to 200.
#[serde(default = "default_patch_size")]
pub patch_size: usize,
/// Presumably the overlap between consecutive patches, in the same units as
/// `patch_size` — confirm. Falls back to 20.
#[serde(default = "default_patch_overlap")]
pub patch_overlap: usize,
/// Flag presumably selecting an attention-based pooling stage — confirm.
/// Falls back to `false` when the key is missing.
#[serde(default)]
pub attention_pooling: bool,
/// Presumably the number of model outputs. Falls back to `0` when the key is
/// missing (the `Default` impl uses 4 instead).
#[serde(default)]
pub n_outputs: usize,
/// Presumably the number of input channels. Falls back to `0` when the key is
/// missing (the `Default` impl uses 22 instead).
#[serde(default)]
pub n_chans: usize,
/// Presumably the number of time samples per input. Falls back to `0` when the
/// key is missing (the `Default` impl uses 1000 instead).
#[serde(default)]
pub n_times: usize,
}
/// Fallback for `ModelConfig::embed_dim` when the key is absent from the input.
fn default_embed_dim() -> usize {
    512
}

/// Fallback for `ModelConfig::depth` when the key is absent from the input.
fn default_depth() -> usize {
    22
}

/// Fallback for `ModelConfig::heads` when the key is absent from the input.
fn default_heads() -> usize {
    8
}

/// Fallback for `ModelConfig::head_dim` when the key is absent from the input.
fn default_head_dim() -> usize {
    64
}

/// Fallback for `ModelConfig::mlp_dim_ratio` when the key is absent from the input.
fn default_mlp_dim_ratio() -> f64 {
    2.66
}

/// Fallback for `ModelConfig::use_geglu` when the key is absent from the input.
fn default_use_geglu() -> bool {
    true
}

/// Fallback for `ModelConfig::freqs` when the key is absent from the input.
fn default_freqs() -> usize {
    4
}

/// Fallback for `ModelConfig::patch_size` when the key is absent from the input.
fn default_patch_size() -> usize {
    200
}

/// Fallback for `ModelConfig::patch_overlap` when the key is absent from the input.
fn default_patch_overlap() -> usize {
    20
}
impl Default for ModelConfig {
fn default() -> Self {
Self {
embed_dim: default_embed_dim(),
depth: default_depth(),
heads: default_heads(),
head_dim: default_head_dim(),
mlp_dim_ratio: default_mlp_dim_ratio(),
use_geglu: default_use_geglu(),
freqs: default_freqs(),
patch_size: default_patch_size(),
patch_overlap: default_patch_overlap(),
attention_pooling: false,
n_outputs: 4,
n_chans: 22,
n_times: 1000,
}
}
}
impl ModelConfig {
    /// Returns `head_dim * heads`.
    ///
    /// Presumably the combined width of all attention heads — confirm against
    /// the attention module that consumes it.
    pub fn inner_dim(&self) -> usize {
        let per_head = self.head_dim;
        per_head * self.heads
    }

    /// Returns `embed_dim` scaled by `mlp_dim_ratio`, converted back to `usize`.
    ///
    /// The float-to-integer `as` cast truncates toward zero and saturates, so a
    /// fractional product is rounded down and a negative or NaN product gives `0`.
    pub fn mlp_dim(&self) -> usize {
        let base = self.embed_dim as f64;
        let scaled = base * self.mlp_dim_ratio;
        scaled as usize
    }

    /// Returns `mlp_dim()`, doubled when `use_geglu` is set.
    ///
    /// Presumably the doubling accounts for a gated layer that needs two
    /// projections of that width — confirm against the feed-forward module.
    pub fn ffn_in_features(&self) -> usize {
        let hidden = self.mlp_dim();
        let multiplier = if self.use_geglu { 2 } else { 1 };
        hidden * multiplier
    }
}