use serde::Deserialize;
/// Model hyper-parameters.
///
/// Every field may be omitted when deserializing: a missing key falls back to
/// its matching `default_*` function (or `bool::default()` for
/// `add_pre_mapping`), so an empty mapping produces the same values as
/// `ModelConfig::default()`, i.e. the `vit_base` preset.
#[derive(Clone, Debug, Deserialize)]
pub struct ModelConfig {
    /// Variant name; default `"vit_base"`.
    #[serde(default = "default_model_name")]
    pub model_name: String,
    /// Embedding width; default 768.
    #[serde(default = "default_embed_dim")]
    pub embed_dim: usize,
    /// Depth; default 12.
    #[serde(default = "default_depth")]
    pub depth: usize,
    /// Number of attention heads; default 12. Must be non-zero for `head_dim`.
    #[serde(default = "default_num_heads")]
    pub num_heads: usize,
    /// Multiplier applied to `embed_dim` by `mlp_hidden_dim`; default 4.0.
    #[serde(default = "default_mlp_ratio")]
    pub mlp_ratio: f64,
    /// Predictor depth; default 6.
    #[serde(default = "default_pred_depth")]
    pub pred_depth: usize,
    /// Predictor embedding width; default 384.
    #[serde(default = "default_pred_emb_dim")]
    pub pred_emb_dim: usize,
    /// Patch size; default 48.
    #[serde(default = "default_patch_size")]
    pub patch_size: usize,
    /// Normalization epsilon; default 1e-6.
    #[serde(default = "default_norm_eps")]
    pub norm_eps: f64,
    /// Positional mode string; default `"gradient_geoh"`.
    #[serde(default = "default_pos_mode")]
    pub pos_mode: String,
    /// Number of latent tokens; default 128.
    #[serde(default = "default_num_latent_tokens")]
    pub num_latent_tokens: usize,
    /// Whether a CLS token is used; default `false`.
    #[serde(default = "default_use_cls_token")]
    pub use_cls_token: bool,
    /// Defaults to `false`; the `vit_small`/`vit_large` presets set it to `true`.
    #[serde(default)]
    pub add_pre_mapping: bool,
    /// Gradient dimension; default 30.
    #[serde(default = "default_grad_dim")]
    pub grad_dim: usize,
    /// Geometric-harmonics dimension; default 200.
    #[serde(default = "default_geoh_dim")]
    pub geoh_dim: usize,
}
// Serde defaults for `ModelConfig`; `impl Default for ModelConfig` reuses
// them so deserializing an empty mapping and calling `default()` agree.

fn default_model_name() -> String {
    String::from("vit_base")
}

fn default_embed_dim() -> usize {
    768
}

fn default_depth() -> usize {
    12
}

fn default_num_heads() -> usize {
    12
}

fn default_mlp_ratio() -> f64 {
    4.0_f64
}

fn default_pred_depth() -> usize {
    6
}

fn default_pred_emb_dim() -> usize {
    384
}

fn default_patch_size() -> usize {
    48
}

fn default_norm_eps() -> f64 {
    1.0e-6
}

fn default_pos_mode() -> String {
    String::from("gradient_geoh")
}

fn default_num_latent_tokens() -> usize {
    128
}

fn default_use_cls_token() -> bool {
    false
}

fn default_grad_dim() -> usize {
    30
}

fn default_geoh_dim() -> usize {
    200
}
impl Default for ModelConfig {
    /// The `vit_base` preset.
    ///
    /// Each field takes the value of the same default function serde uses for
    /// it, so `ModelConfig::default()` matches deserializing an empty mapping.
    fn default() -> Self {
        Self {
            model_name: default_model_name(),
            embed_dim: default_embed_dim(),
            depth: default_depth(),
            num_heads: default_num_heads(),
            mlp_ratio: default_mlp_ratio(),
            pred_depth: default_pred_depth(),
            pred_emb_dim: default_pred_emb_dim(),
            patch_size: default_patch_size(),
            norm_eps: default_norm_eps(),
            pos_mode: default_pos_mode(),
            num_latent_tokens: default_num_latent_tokens(),
            use_cls_token: default_use_cls_token(),
            // Mirrors the bare `#[serde(default)]` on this field.
            add_pre_mapping: bool::default(),
            grad_dim: default_grad_dim(),
            geoh_dim: default_geoh_dim(),
        }
    }
}
impl ModelConfig {
pub fn from_variant(name: &str) -> crate::error::Result<Self> {
match name {
"vit_small" => Ok(Self {
model_name: "vit_small".into(),
embed_dim: 384,
depth: 12,
num_heads: 6,
add_pre_mapping: true,
..Default::default()
}),
"vit_base" => Ok(Self::default()),
"vit_large" => Ok(Self {
model_name: "vit_large".into(),
embed_dim: 1024,
depth: 24,
num_heads: 16,
add_pre_mapping: true,
..Default::default()
}),
_ => Err(crate::error::BrainHarmonyError::UnknownVariant {
name: name.to_string(),
}),
}
}
pub fn head_dim(&self) -> usize {
self.embed_dim / self.num_heads
}
pub fn mlp_hidden_dim(&self) -> usize {
(self.embed_dim as f64 * self.mlp_ratio) as usize
}
}
/// Data layout parameters.
///
/// Every field may be omitted when deserializing; a missing key falls back to
/// its matching `default_*` function, so an empty mapping produces the same
/// values as `DataConfig::default()`.
#[derive(Clone, Debug, Deserialize)]
pub struct DataConfig {
    /// Number of cortical ROIs; default 400.
    #[serde(default = "default_n_cortical_rois")]
    pub n_cortical_rois: usize,
    /// Number of time patches; default 18.
    #[serde(default = "default_n_time_patches")]
    pub n_time_patches: usize,
    /// Number of structural tokens; default 1200.
    #[serde(default = "default_n_structural_tokens")]
    pub n_structural_tokens: usize,
    /// Total token count; default 8400 (= 400 * 18 + 1200).
    #[serde(default = "default_total_tokens")]
    pub total_tokens: usize,
    /// Input dimension; default 768.
    #[serde(default = "default_input_dim")]
    pub input_dim: usize,
    /// Signal size as a pair; default `(400, 864)`. Deserialized from a
    /// two-element sequence.
    #[serde(default = "default_signal_size")]
    pub signal_size: (usize, usize),
}
// Serde defaults for `DataConfig`; `impl Default for DataConfig` reuses them.

fn default_n_cortical_rois() -> usize {
    400
}

fn default_n_time_patches() -> usize {
    18
}

fn default_n_structural_tokens() -> usize {
    1200
}

/// `n_cortical_rois * n_time_patches + n_structural_tokens` for the defaults
/// (400 * 18 + 1200 = 8400). It is written as a computation so the three
/// defaults cannot drift apart silently.
fn default_total_tokens() -> usize {
    default_n_cortical_rois() * default_n_time_patches() + default_n_structural_tokens()
}

fn default_input_dim() -> usize {
    768
}

fn default_signal_size() -> (usize, usize) {
    (400, 864)
}
impl Default for DataConfig {
    /// Builds the config from the same default functions serde uses, so
    /// `DataConfig::default()` matches deserializing an empty mapping.
    fn default() -> Self {
        let signal_size = default_signal_size();
        Self {
            n_cortical_rois: default_n_cortical_rois(),
            n_time_patches: default_n_time_patches(),
            n_structural_tokens: default_n_structural_tokens(),
            total_tokens: default_total_tokens(),
            input_dim: default_input_dim(),
            signal_size,
        }
    }
}
impl DataConfig {
    /// Product of `n_cortical_rois` and `n_time_patches`.
    pub fn n_cortical_tokens(&self) -> usize {
        let Self {
            n_cortical_rois,
            n_time_patches,
            ..
        } = *self;
        n_cortical_rois * n_time_patches
    }
}
/// Top-level layout of a YAML training config.
///
/// Each section is optional; a missing section deserializes to `None`.
#[derive(Clone, Debug, Deserialize)]
pub struct YamlConfig {
    /// `data:` section.
    pub data: Option<YamlDataSection>,
    /// `mask:` section.
    pub mask: Option<YamlMaskSection>,
    /// `meta:` section.
    pub meta: Option<YamlMetaSection>,
    /// `optimization:` section.
    pub optimization: Option<YamlOptSection>,
}
/// `data:` section of the YAML config. Every key is optional.
///
/// Of these fields, `YamlConfig::to_data_config` reads only `signal_size`.
#[derive(Clone, Debug, Deserialize)]
pub struct YamlDataSection {
    pub batch_size: Option<usize>,
    /// Used by `to_data_config` only when it has exactly two entries.
    pub signal_size: Option<Vec<usize>>,
    pub num_workers: Option<usize>,
    pub pin_mem: Option<bool>,
    pub gradient_csv_path: Option<String>,
    pub geo_harm_csv_path: Option<String>,
}
/// `mask:` section of the YAML config. Every key is optional.
///
/// Of these fields, `YamlConfig::to_model_config` reads only `patch_size`.
#[derive(Clone, Debug, Deserialize)]
// serde uses each field name verbatim as the YAML key, and the config files
// spell `pred_mask_R_scale` / `pred_mask_T_scale` with upper-case letters, so
// the fields keep that spelling instead of being renamed.
#[allow(non_snake_case)]
pub struct YamlMaskSection {
    /// Overrides `ModelConfig::patch_size` when present.
    pub patch_size: Option<usize>,
    pub min_keep: Option<usize>,
    pub enc_mask_scale: Option<Vec<f64>>,
    pub pred_mask_R_scale: Option<Vec<f64>>,
    pub pred_mask_T_scale: Option<Vec<f64>>,
    pub allow_overlap: Option<bool>,
}
/// `meta:` section of the YAML config. Every key is optional.
///
/// `YamlConfig::to_model_config` uses `model_name` to pick a preset and
/// applies the other model fields as overrides. `use_bfloat16` and
/// `attn_mode` are not read there.
#[derive(Clone, Debug, Deserialize)]
pub struct YamlMetaSection {
    /// Preset passed to `ModelConfig::from_variant`; `"vit_base"` when absent.
    pub model_name: Option<String>,
    pub pred_depth: Option<usize>,
    pub pred_emb_dim: Option<usize>,
    pub use_bfloat16: Option<bool>,
    pub attn_mode: Option<String>,
    pub pos_mode: Option<String>,
    pub num_latent_tokens: Option<usize>,
    pub add_pre_mapping: Option<bool>,
    pub grad_dim: Option<usize>,
    pub geoh_dim: Option<usize>,
}
/// `optimization:` section of the YAML config. Every key is optional.
///
/// None of these fields is read by `YamlConfig::to_model_config` or
/// `YamlConfig::to_data_config`.
#[derive(Clone, Debug, Deserialize)]
pub struct YamlOptSection {
    pub lr: Option<f64>,
    pub start_lr: Option<f64>,
    pub final_lr: Option<f64>,
    pub warmup: Option<usize>,
    pub weight_decay: Option<f64>,
    pub epochs: Option<usize>,
    pub ema: Option<Vec<f64>>,
}
impl YamlConfig {
pub fn from_file(path: &str) -> anyhow::Result<Self> {
let s = std::fs::read_to_string(path)?;
Ok(serde_yaml::from_str(&s)?)
}
pub fn to_model_config(&self) -> crate::error::Result<ModelConfig> {
let mut cfg = if let Some(meta) = &self.meta {
let name = meta.model_name.as_deref().unwrap_or("vit_base");
let mut c = ModelConfig::from_variant(name)?;
if let Some(d) = meta.pred_depth { c.pred_depth = d; }
if let Some(d) = meta.pred_emb_dim { c.pred_emb_dim = d; }
if let Some(ref m) = meta.pos_mode { c.pos_mode = m.clone(); }
if let Some(n) = meta.num_latent_tokens { c.num_latent_tokens = n; }
if let Some(b) = meta.add_pre_mapping { c.add_pre_mapping = b; }
if let Some(d) = meta.grad_dim { c.grad_dim = d; }
if let Some(d) = meta.geoh_dim { c.geoh_dim = d; }
c
} else {
ModelConfig::default()
};
if let Some(mask) = &self.mask {
if let Some(ps) = mask.patch_size { cfg.patch_size = ps; }
}
Ok(cfg)
}
pub fn to_data_config(&self) -> DataConfig {
let mut cfg = DataConfig::default();
if let Some(data) = &self.data {
if let Some(ref ss) = data.signal_size {
if ss.len() == 2 {
cfg.signal_size = (ss[0], ss[1]);
}
}
}
cfg
}
}