use brainharmony::{ModelConfig, DataConfig, BrainHarmonyError};
/// The `vit_base` preset: 768-dim embeddings across 12 layers with 12 heads,
/// which yields 64-dim heads and a 3072-wide MLP hidden layer.
#[test]
fn from_variant_vit_base() {
    let config = ModelConfig::from_variant("vit_base").unwrap();

    // Backbone shape.
    assert_eq!(config.embed_dim, 768);
    assert_eq!(config.depth, 12);
    assert_eq!(config.num_heads, 12);

    // Derived sizes exposed via helper methods.
    assert_eq!(config.head_dim(), 64);
    assert_eq!(config.mlp_hidden_dim(), 3072);
}
/// The `vit_small` preset: 384-dim embeddings, 12 layers, 6 heads (64-dim
/// heads), with the pre-mapping flag enabled.
#[test]
fn from_variant_vit_small() {
    let small = ModelConfig::from_variant("vit_small").unwrap();

    // Backbone shape.
    assert_eq!(small.embed_dim, 384);
    assert_eq!(small.depth, 12);
    assert_eq!(small.num_heads, 6);
    assert_eq!(small.head_dim(), 64);

    // Variant-specific flag.
    assert!(small.add_pre_mapping);
}
/// The `vit_large` preset: 1024-dim embeddings, 24 layers, 16 heads, with the
/// pre-mapping flag enabled.
#[test]
fn from_variant_vit_large() {
    let large = ModelConfig::from_variant("vit_large").unwrap();

    // Backbone shape.
    assert_eq!(large.embed_dim, 1024);
    assert_eq!(large.depth, 24);
    assert_eq!(large.num_heads, 16);

    // Variant-specific flag.
    assert!(large.add_pre_mapping);
}
/// An unrecognised variant name must be rejected with
/// `BrainHarmonyError::UnknownVariant`, echoing back the offending name.
#[test]
fn from_variant_unknown_returns_error() {
    let outcome = ModelConfig::from_variant("vit_huge");

    match outcome.unwrap_err() {
        BrainHarmonyError::UnknownVariant { name } => {
            assert_eq!(name, "vit_huge");
        }
        // Any other error kind is a test failure; its Display text is shown.
        other => panic!("expected UnknownVariant, got: {other}"),
    }
}
/// Pins the default `DataConfig` values.
///
/// The expected numbers are mutually consistent: 400 ROIs x 18 time patches
/// = 7200 cortical tokens, and 7200 + 1200 structural tokens = 8400 total.
#[test]
fn default_data_config() {
    let data = DataConfig::default();

    // Raw counts.
    assert_eq!(data.n_cortical_rois, 400);
    assert_eq!(data.n_time_patches, 18);
    assert_eq!(data.n_structural_tokens, 1200);

    // Token totals.
    assert_eq!(data.total_tokens, 8400);
    assert_eq!(data.n_cortical_tokens(), 7200);

    // Signal dimensions.
    assert_eq!(data.signal_size, (400, 864));
}
/// `ModelConfig::default()` must be the `vit_base` preset, both by name and by
/// backbone dimensions, and must carry the other asserted default values.
#[test]
fn model_config_default_is_vit_base() {
    let cfg = ModelConfig::default();
    assert_eq!(cfg.model_name, "vit_base");

    // The name alone does not prove the default actually uses vit_base's
    // architecture; compare the backbone dimensions against the named preset
    // so a mislabelled default is caught.
    let base = ModelConfig::from_variant("vit_base").unwrap();
    assert_eq!(cfg.embed_dim, base.embed_dim);
    assert_eq!(cfg.depth, base.depth);
    assert_eq!(cfg.num_heads, base.num_heads);

    // Remaining defaults pinned to their expected values.
    assert_eq!(cfg.patch_size, 48);
    assert_eq!(cfg.pred_depth, 6);
    assert_eq!(cfg.pred_emb_dim, 384);
    assert_eq!(cfg.num_latent_tokens, 128);
    assert_eq!(cfg.grad_dim, 30);
    assert_eq!(cfg.geoh_dim, 200);
}