//! brainharmony 0.1.0
//!
//! Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML.
//!
//! Model and runtime configuration for Brain-Harmony inference.
//!
//! Configuration can be loaded from YAML files matching the Brain-Harmony Python format,
//! or constructed programmatically.
use serde::Deserialize;

// -- ModelConfig ------------------------------------------------------------------

/// Architecture hyperparameters for the Vision Transformer.
///
/// Every field carries a serde default, so any subset of keys may be omitted
/// from a YAML document; the defaults correspond to the `"vit_base"` variant
/// (see the `default_*` provider functions below and `Default for ModelConfig`).
#[derive(Debug, Clone, Deserialize)]
pub struct ModelConfig {
    /// Model variant name: "vit_small", "vit_base", or "vit_large"
    #[serde(default = "default_model_name")]
    pub model_name: String,

    /// Embedding dimension (384 / 768 / 1024)
    #[serde(default = "default_embed_dim")]
    pub embed_dim: usize,

    /// Number of transformer layers in the encoder
    #[serde(default = "default_depth")]
    pub depth: usize,

    /// Number of attention heads
    /// (`embed_dim` should divide evenly by this; `head_dim()` truncates)
    #[serde(default = "default_num_heads")]
    pub num_heads: usize,

    /// MLP hidden dim = embed_dim * mlp_ratio
    #[serde(default = "default_mlp_ratio")]
    pub mlp_ratio: f64,

    /// Predictor depth (transformer layers)
    #[serde(default = "default_pred_depth")]
    pub pred_depth: usize,

    /// Predictor embedding dimension
    #[serde(default = "default_pred_emb_dim")]
    pub pred_emb_dim: usize,

    /// Temporal patch size
    #[serde(default = "default_patch_size")]
    pub patch_size: usize,

    /// LayerNorm epsilon
    #[serde(default = "default_norm_eps")]
    pub norm_eps: f64,

    /// Positional embedding mode: "gradient_geoh" or "sincos"
    #[serde(default = "default_pos_mode")]
    pub pos_mode: String,

    /// Number of latent tokens appended to the input
    #[serde(default = "default_num_latent_tokens")]
    pub num_latent_tokens: usize,

    /// Whether to use a CLS token
    #[serde(default = "default_use_cls_token")]
    pub use_cls_token: bool,

    /// Whether to add a pre-mapping linear projection (768 -> embed_dim)
    /// (plain `#[serde(default)]`: defaults to `false` when omitted)
    #[serde(default)]
    pub add_pre_mapping: bool,

    /// Gradient dimension from CSV
    #[serde(default = "default_grad_dim")]
    pub grad_dim: usize,

    /// Geometric harmonics dimension from CSV
    #[serde(default = "default_geoh_dim")]
    pub geoh_dim: usize,
}

// Serde default providers for `ModelConfig`. Each one mirrors a value in
// the `Default` impl below; keep the two in sync when changing a default.
fn default_model_name() -> String {
    String::from("vit_base")
}
fn default_embed_dim() -> usize {
    768
}
fn default_depth() -> usize {
    12
}
fn default_num_heads() -> usize {
    12
}
fn default_mlp_ratio() -> f64 {
    4.0
}
fn default_pred_depth() -> usize {
    6
}
fn default_pred_emb_dim() -> usize {
    384
}
fn default_patch_size() -> usize {
    48
}
fn default_norm_eps() -> f64 {
    1e-6
}
fn default_pos_mode() -> String {
    String::from("gradient_geoh")
}
fn default_num_latent_tokens() -> usize {
    128
}
fn default_use_cls_token() -> bool {
    false
}
fn default_grad_dim() -> usize {
    30
}
fn default_geoh_dim() -> usize {
    200
}

impl Default for ModelConfig {
    fn default() -> Self {
        Self {
            model_name: default_model_name(),
            embed_dim: default_embed_dim(),
            depth: default_depth(),
            num_heads: default_num_heads(),
            mlp_ratio: default_mlp_ratio(),
            pred_depth: default_pred_depth(),
            pred_emb_dim: default_pred_emb_dim(),
            patch_size: default_patch_size(),
            norm_eps: default_norm_eps(),
            pos_mode: default_pos_mode(),
            num_latent_tokens: default_num_latent_tokens(),
            use_cls_token: default_use_cls_token(),
            add_pre_mapping: false,
            grad_dim: default_grad_dim(),
            geoh_dim: default_geoh_dim(),
        }
    }
}

impl ModelConfig {
    /// Construct a config for one of the standard ViT variants.
    pub fn from_variant(name: &str) -> crate::error::Result<Self> {
        match name {
            "vit_small" => Ok(Self {
                model_name: "vit_small".into(),
                embed_dim: 384,
                depth: 12,
                num_heads: 6,
                add_pre_mapping: true,
                ..Default::default()
            }),
            "vit_base" => Ok(Self::default()),
            "vit_large" => Ok(Self {
                model_name: "vit_large".into(),
                embed_dim: 1024,
                depth: 24,
                num_heads: 16,
                add_pre_mapping: true,
                ..Default::default()
            }),
            _ => Err(crate::error::BrainHarmonyError::UnknownVariant {
                name: name.to_string(),
            }),
        }
    }

    /// Head dimension (embed_dim / num_heads).
    pub fn head_dim(&self) -> usize {
        self.embed_dim / self.num_heads
    }

    /// MLP hidden dimension.
    pub fn mlp_hidden_dim(&self) -> usize {
        (self.embed_dim as f64 * self.mlp_ratio) as usize
    }
}

// -- DataConfig -------------------------------------------------------------------

/// Brain data parameters.
///
/// All fields carry serde defaults, so any subset may be omitted from YAML.
/// `total_tokens` is expected to equal
/// `n_cortical_rois * n_time_patches + n_structural_tokens` (8400 by default);
/// this invariant is not enforced here — take care when overriding fields
/// individually.
#[derive(Debug, Clone, Deserialize)]
pub struct DataConfig {
    /// Number of cortical ROIs
    #[serde(default = "default_n_cortical_rois")]
    pub n_cortical_rois: usize,

    /// Number of temporal patches per ROI
    #[serde(default = "default_n_time_patches")]
    pub n_time_patches: usize,

    /// Number of structural (T1) tokens
    #[serde(default = "default_n_structural_tokens")]
    pub n_structural_tokens: usize,

    /// Total number of tokens (cortical + structural)
    #[serde(default = "default_total_tokens")]
    pub total_tokens: usize,

    /// Input token embedding dimension (before any pre-mapping)
    #[serde(default = "default_input_dim")]
    pub input_dim: usize,

    /// Input signal size (n_rois, signal_length) for raw signal mode
    #[serde(default = "default_signal_size")]
    pub signal_size: (usize, usize),
}

// Serde default providers for `DataConfig`; mirrored by its `Default` impl.
fn default_n_cortical_rois() -> usize {
    400
}
fn default_n_time_patches() -> usize {
    18
}
fn default_n_structural_tokens() -> usize {
    1200
}
fn default_total_tokens() -> usize {
    // 400 cortical ROIs * 18 time patches + 1200 structural tokens.
    8400
}
fn default_input_dim() -> usize {
    768
}
fn default_signal_size() -> (usize, usize) {
    // (n_rois, n_time_patches * patch_size) = (400, 18 * 48).
    (400, 864)
}

impl Default for DataConfig {
    fn default() -> Self {
        Self {
            n_cortical_rois: default_n_cortical_rois(),
            n_time_patches: default_n_time_patches(),
            n_structural_tokens: default_n_structural_tokens(),
            total_tokens: default_total_tokens(),
            input_dim: default_input_dim(),
            signal_size: default_signal_size(),
        }
    }
}

impl DataConfig {
    /// Number of cortical fMRI tokens (n_cortical_rois * n_time_patches).
    pub fn n_cortical_tokens(&self) -> usize {
        self.n_cortical_rois * self.n_time_patches
    }
}

// -- Full YAML config -------------------------------------------------------------

/// Top-level YAML config matching Brain-Harmony's format.
///
/// All sections are optional; a missing section falls back to defaults when
/// converted via `to_model_config` / `to_data_config`.
#[derive(Debug, Clone, Deserialize)]
pub struct YamlConfig {
    /// `data:` section — loader parameters, signal size, and CSV paths.
    pub data: Option<YamlDataSection>,
    /// `mask:` section — masking hyperparameters; supplies `patch_size`
    /// to the model config.
    pub mask: Option<YamlMaskSection>,
    /// `meta:` section — model variant and architecture overrides.
    pub meta: Option<YamlMetaSection>,
    /// `optimization:` section — training hyperparameters; not consumed by
    /// the conversion helpers in this module.
    pub optimization: Option<YamlOptSection>,
}

/// `data:` section of the YAML config.
///
/// Only `signal_size` is consumed by [`YamlConfig::to_data_config`]; the
/// remaining fields mirror the Python training config and are retained for
/// compatibility.
#[derive(Debug, Clone, Deserialize)]
pub struct YamlDataSection {
    /// Training batch size (not used at inference in this module).
    pub batch_size: Option<usize>,
    /// Raw signal shape as a two-element list `[n_rois, signal_length]`.
    pub signal_size: Option<Vec<usize>>,
    /// Dataloader worker count (Python-side concern).
    pub num_workers: Option<usize>,
    /// Dataloader pinned-memory flag (Python-side concern).
    pub pin_mem: Option<bool>,
    /// Path to the gradient CSV file.
    pub gradient_csv_path: Option<String>,
    /// Path to the geometric harmonics CSV file.
    pub geo_harm_csv_path: Option<String>,
}

/// `mask:` section of the YAML config.
///
/// Only `patch_size` is consumed by [`YamlConfig::to_model_config`]; the
/// other fields describe training-time masking and are kept for format
/// compatibility.
// Field names must match the YAML keys exactly (serde derives names from
// fields), hence the uppercase `R`/`T` and the lint allow.
#[derive(Debug, Clone, Deserialize)]
#[allow(non_snake_case)]
pub struct YamlMaskSection {
    /// Temporal patch size; overrides `ModelConfig::patch_size` when set.
    pub patch_size: Option<usize>,
    /// Minimum number of tokens to keep when masking.
    pub min_keep: Option<usize>,
    /// Encoder mask scale range `[min, max]`.
    pub enc_mask_scale: Option<Vec<f64>>,
    /// Predictor mask scale range over R (presumably the ROI axis — TODO confirm).
    pub pred_mask_R_scale: Option<Vec<f64>>,
    /// Predictor mask scale range over T (presumably the time axis — TODO confirm).
    pub pred_mask_T_scale: Option<Vec<f64>>,
    /// Whether encoder and predictor masks may overlap.
    pub allow_overlap: Option<bool>,
}

/// `meta:` section of the YAML config.
///
/// Most fields overlay onto [`ModelConfig`] in [`YamlConfig::to_model_config`];
/// `use_bfloat16` and `attn_mode` are parsed but not consumed there.
#[derive(Debug, Clone, Deserialize)]
pub struct YamlMetaSection {
    /// ViT variant name; selects the base config (defaults to "vit_base").
    pub model_name: Option<String>,
    /// Overrides `ModelConfig::pred_depth`.
    pub pred_depth: Option<usize>,
    /// Overrides `ModelConfig::pred_emb_dim`.
    pub pred_emb_dim: Option<usize>,
    /// bfloat16 flag (not consumed by this module's conversions).
    pub use_bfloat16: Option<bool>,
    /// Attention mode string (not consumed by this module's conversions).
    pub attn_mode: Option<String>,
    /// Overrides `ModelConfig::pos_mode`.
    pub pos_mode: Option<String>,
    /// Overrides `ModelConfig::num_latent_tokens`.
    pub num_latent_tokens: Option<usize>,
    /// Overrides `ModelConfig::add_pre_mapping`.
    pub add_pre_mapping: Option<bool>,
    /// Overrides `ModelConfig::grad_dim`.
    pub grad_dim: Option<usize>,
    /// Overrides `ModelConfig::geoh_dim`.
    pub geoh_dim: Option<usize>,
}

/// `optimization:` section of the YAML config.
///
/// Training hyperparameters from the Python format; parsed for
/// compatibility but not consumed by this module's conversion helpers.
#[derive(Debug, Clone, Deserialize)]
pub struct YamlOptSection {
    /// Peak learning rate.
    pub lr: Option<f64>,
    /// Warmup starting learning rate.
    pub start_lr: Option<f64>,
    /// Final learning rate after decay.
    pub final_lr: Option<f64>,
    /// Number of warmup epochs (or steps — TODO confirm against Python side).
    pub warmup: Option<usize>,
    /// Weight decay coefficient.
    pub weight_decay: Option<f64>,
    /// Total training epochs.
    pub epochs: Option<usize>,
    /// EMA momentum range `[start, end]`.
    pub ema: Option<Vec<f64>>,
}

impl YamlConfig {
    /// Parse a [`YamlConfig`] from the YAML file at `path`.
    ///
    /// # Errors
    /// Fails when the file cannot be read or does not parse as YAML; the
    /// error message names the offending path so failures are diagnosable
    /// when several config files are in play.
    pub fn from_file(path: &str) -> anyhow::Result<Self> {
        let s = std::fs::read_to_string(path)
            .map_err(|e| anyhow::anyhow!("failed to read config file `{}`: {}", path, e))?;
        serde_yaml::from_str(&s)
            .map_err(|e| anyhow::anyhow!("failed to parse YAML config `{}`: {}", path, e))
    }

    /// Extract a [`ModelConfig`] from the `meta` and `mask` sections.
    ///
    /// Starts from the variant named by `meta.model_name` (defaulting to
    /// `"vit_base"`), overlays any explicitly-set `meta` fields, then applies
    /// `mask.patch_size` — which wins even when `meta` is absent.
    ///
    /// # Errors
    /// Propagates `UnknownVariant` from [`ModelConfig::from_variant`] for an
    /// unrecognized `model_name`.
    pub fn to_model_config(&self) -> crate::error::Result<ModelConfig> {
        let mut cfg = match &self.meta {
            Some(meta) => {
                let name = meta.model_name.as_deref().unwrap_or("vit_base");
                let mut c = ModelConfig::from_variant(name)?;
                // Overlay only the fields the YAML actually sets.
                if let Some(d) = meta.pred_depth { c.pred_depth = d; }
                if let Some(d) = meta.pred_emb_dim { c.pred_emb_dim = d; }
                if let Some(ref m) = meta.pos_mode { c.pos_mode = m.clone(); }
                if let Some(n) = meta.num_latent_tokens { c.num_latent_tokens = n; }
                if let Some(b) = meta.add_pre_mapping { c.add_pre_mapping = b; }
                if let Some(d) = meta.grad_dim { c.grad_dim = d; }
                if let Some(d) = meta.geoh_dim { c.geoh_dim = d; }
                c
            }
            None => ModelConfig::default(),
        };
        // `patch_size` lives under `mask:` in the Python config format.
        if let Some(ps) = self.mask.as_ref().and_then(|m| m.patch_size) {
            cfg.patch_size = ps;
        }
        Ok(cfg)
    }

    /// Extract a [`DataConfig`], overriding `signal_size` when the `data`
    /// section supplies a well-formed two-element list.
    ///
    /// NOTE(review): a `signal_size` of any other length is silently
    /// ignored, and the derived fields (`n_time_patches`, `total_tokens`, …)
    /// are NOT recomputed from the override — confirm this matches the
    /// Python side's behavior.
    pub fn to_data_config(&self) -> DataConfig {
        let mut cfg = DataConfig::default();
        if let Some(ss) = self.data.as_ref().and_then(|d| d.signal_size.as_deref()) {
            // Slice pattern matches exactly two elements.
            if let [n_rois, signal_len] = *ss {
                cfg.signal_size = (n_rois, signal_len);
            }
        }
        cfg
    }
}