use serde::Deserialize;
use std::collections::BTreeMap;
/// Top-level TRIBE v2 configuration, deserialized with serde.
///
/// `brain_model_config` and `data` are required; the remaining fields fall
/// back to their `Default` values when absent from the input.
#[derive(Debug, Clone, Deserialize)]
pub struct TribeV2Config {
/// Model architecture settings.
pub brain_model_config: BrainModelConfig,
/// Data / feature settings.
pub data: DataConfig,
/// `false` when absent.
/// NOTE(review): `SubjectLayersConfig` also has an `average_subjects` field —
/// confirm which one the model actually reads.
#[serde(default)]
pub average_subjects: bool,
/// `None` when absent (presumably an RNG seed — confirm against callers).
#[serde(default)]
pub seed: Option<u64>,
}
/// Data-side configuration: which feature modalities are used or masked,
/// segment sizing (`duration_trs`, `overlap_trs_val`, `stride_drop_incomplete`)
/// and the optional per-modality extractor settings.
#[derive(Debug, Clone, Deserialize)]
pub struct DataConfig {
/// Modalities to use; defaults to `["text", "audio", "video"]`.
#[serde(default = "default_features_to_use")]
pub features_to_use: Vec<String>,
/// Modalities to mask; empty when absent.
#[serde(default)]
pub features_to_mask: Vec<String>,
/// Defaults to 100. Presumably a segment length measured in fMRI TRs — confirm.
#[serde(default = "default_duration_trs")]
pub duration_trs: usize,
/// Defaults to 0. Presumably the overlap (in TRs) between validation segments — confirm.
#[serde(default)]
pub overlap_trs_val: usize,
/// Defaults to `false`.
#[serde(default)]
pub stride_drop_incomplete: bool,
/// `None` when absent. Units are not visible here (presumably Hz — confirm).
#[serde(default)]
pub frequency: Option<f64>,
/// Text extractor settings; `None` when absent.
pub text_feature: Option<TextFeatureConfig>,
/// Audio extractor settings; `None` when absent.
pub audio_feature: Option<AudioFeatureConfig>,
/// Video extractor settings; `None` when absent.
pub video_feature: Option<VideoFeatureConfig>,
/// Subject-id mapping settings; `None` when absent.
pub subject_id: Option<SubjectIdConfig>,
}
/// Serde default for `DataConfig::features_to_use`: every modality, in the
/// order text, audio, video.
fn default_features_to_use() -> Vec<String> {
    ["text", "audio", "video"]
        .iter()
        .map(|modality| modality.to_string())
        .collect()
}
/// Serde default for `DataConfig::duration_trs`.
fn default_duration_trs() -> usize {
    100
}
/// Text feature-extractor settings.
///
/// NOTE(review): field-for-field identical to `AudioFeatureConfig`; a shared
/// type may be worth considering.
#[derive(Debug, Clone, Deserialize)]
pub struct TextFeatureConfig {
/// Extractor model name; `None` when absent.
pub model_name: Option<String>,
/// Layer selectors; empty when absent.
/// NOTE(review): `f64` suggests relative layer positions rather than indices — confirm.
#[serde(default)]
pub layers: Vec<f64>,
/// Layer aggregation mode for this modality; `None` when absent.
#[serde(default)]
pub layer_aggregation: Option<String>,
/// Defaults to 2.0 via `default_frequency` (units presumably Hz — confirm).
#[serde(default = "default_frequency")]
pub frequency: f64,
}
/// Audio feature-extractor settings.
///
/// NOTE(review): field-for-field identical to `TextFeatureConfig`.
#[derive(Debug, Clone, Deserialize)]
pub struct AudioFeatureConfig {
/// Extractor model name; `None` when absent.
pub model_name: Option<String>,
/// Layer selectors; empty when absent (`f64`, presumably relative positions — confirm).
#[serde(default)]
pub layers: Vec<f64>,
/// Layer aggregation mode for this modality; `None` when absent.
#[serde(default)]
pub layer_aggregation: Option<String>,
/// Defaults to 2.0 via `default_frequency` (units presumably Hz — confirm).
#[serde(default = "default_frequency")]
pub frequency: f64,
}
/// Video feature-extractor settings. Unlike text/audio, the model name lives
/// in the nested `image` config rather than on this struct.
#[derive(Debug, Clone, Deserialize)]
pub struct VideoFeatureConfig {
/// Nested image-model settings; `None` when absent.
pub image: Option<VideoImageConfig>,
/// Layer selectors; empty when absent (`f64`, presumably relative positions — confirm).
#[serde(default)]
pub layers: Vec<f64>,
/// Layer aggregation mode; `None` when absent.
#[serde(default)]
pub layer_aggregation: Option<String>,
/// Defaults to 2.0 via `default_frequency` (units presumably Hz — confirm).
#[serde(default = "default_frequency")]
pub frequency: f64,
}
/// Image-model settings nested inside `VideoFeatureConfig::image`.
/// Has no `frequency` field of its own.
#[derive(Debug, Clone, Deserialize)]
pub struct VideoImageConfig {
/// Image model name; `None` when absent.
pub model_name: Option<String>,
/// Layer selectors; empty when absent.
#[serde(default)]
pub layers: Vec<f64>,
/// Layer aggregation mode; `None` when absent.
#[serde(default)]
pub layer_aggregation: Option<String>,
}
/// Serde default for the `frequency` field of `TextFeatureConfig`,
/// `AudioFeatureConfig` and `VideoFeatureConfig`.
fn default_frequency() -> f64 {
    2.0
}
/// Subject-id settings.
#[derive(Debug, Clone, Deserialize)]
pub struct SubjectIdConfig {
/// Optional fixed mapping from string key to integer index; `None` when absent.
/// Presumably subject identifier -> subject index — confirm against callers.
#[serde(default)]
pub predefined_mapping: Option<BTreeMap<String, usize>>,
}
/// Architecture settings for the brain-encoding model.
///
/// Every field is optional in the serialized form; defaults are listed per field.
#[derive(Debug, Clone, Deserialize)]
pub struct BrainModelConfig {
/// Projector MLP; falls back to `MlpConfig::default()` when absent.
#[serde(default)]
pub projector: MlpConfig,
/// Optional combiner MLP; `None` when absent.
#[serde(default)]
pub combiner: Option<MlpConfig>,
/// Optional transformer encoder; `None` when absent.
#[serde(default)]
pub encoder: Option<EncoderConfig>,
/// Defaults to `true`.
#[serde(default = "default_true")]
pub time_pos_embedding: bool,
/// Defaults to `false`.
#[serde(default)]
pub subject_embedding: bool,
/// Optional per-subject layers; `None` when absent.
#[serde(default)]
pub subject_layers: Option<SubjectLayersConfig>,
/// Hidden width; defaults to 1152.
#[serde(default = "default_hidden")]
pub hidden: usize,
/// Defaults to 1024.
#[serde(default = "default_max_seq_len")]
pub max_seq_len: usize,
/// Defaults to 0.0.
#[serde(default)]
pub dropout: f64,
/// Defaults to `"cat"`.
#[serde(default = "default_cat")]
pub extractor_aggregation: String,
/// Defaults to `"cat"`. NOTE(review): the per-modality feature configs also
/// carry an optional `layer_aggregation`; confirm which takes precedence.
#[serde(default = "default_cat")]
pub layer_aggregation: String,
/// Defaults to `false`.
#[serde(default)]
pub linear_baseline: bool,
/// Defaults to 0.0.
#[serde(default)]
pub modality_dropout: f64,
/// Defaults to 0.0.
#[serde(default)]
pub temporal_dropout: f64,
/// `None` when absent (presumably the rank of a low-rank output head — confirm).
#[serde(default)]
pub low_rank_head: Option<usize>,
/// Optional temporal smoothing; `None` when absent.
#[serde(default)]
pub temporal_smoothing: Option<TemporalSmoothingConfig>,
}
/// Serde default that yields `true`; shared by several boolean fields.
fn default_true() -> bool {
    true
}

/// Serde default for `BrainModelConfig::hidden`.
fn default_hidden() -> usize {
    1152
}

/// Serde default for `BrainModelConfig::max_seq_len`.
fn default_max_seq_len() -> usize {
    1024
}

/// Serde default (`"cat"`) for `BrainModelConfig::extractor_aggregation`
/// and `BrainModelConfig::layer_aggregation`.
fn default_cat() -> String {
    String::from("cat")
}
/// Multi-layer perceptron settings (used for `BrainModelConfig::projector`
/// and `BrainModelConfig::combiner`).
///
/// `Default` is implemented by hand rather than derived so that it agrees with
/// the serde field defaults. In particular, `bias` is `true` both when the field
/// is omitted from a config and when the whole struct is defaulted (e.g. an
/// absent `BrainModelConfig::projector`, which is `#[serde(default)]`). A derived
/// `Default` would have produced `bias: false` in the latter case.
#[derive(Debug, Clone, Deserialize)]
pub struct MlpConfig {
    /// Optional input width; `None` when absent.
    pub input_size: Option<usize>,
    /// Hidden-layer widths; `None` and an empty list both mean "no hidden layers".
    pub hidden_sizes: Option<Vec<usize>>,
    /// Optional normalization-layer name (presumably a layer-type identifier — confirm).
    pub norm_layer: Option<String>,
    /// Optional activation-layer name (presumably a layer-type identifier — confirm).
    pub activation_layer: Option<String>,
    /// Bias flag; defaults to `true`.
    #[serde(default = "default_true")]
    pub bias: bool,
    /// Dropout rate; defaults to 0.0.
    #[serde(default)]
    pub dropout: f64,
    /// Optional name; `None` when absent.
    #[serde(default)]
    pub name: Option<String>,
}

impl Default for MlpConfig {
    fn default() -> Self {
        Self {
            input_size: None,
            hidden_sizes: None,
            norm_layer: None,
            activation_layer: None,
            // Must stay in sync with `#[serde(default = "default_true")]` above.
            bias: true,
            dropout: 0.0,
            name: None,
        }
    }
}

impl MlpConfig {
    /// `true` when no hidden layers are configured (`None` or empty list).
    fn has_no_hidden_layers(&self) -> bool {
        match &self.hidden_sizes {
            Some(sizes) => sizes.is_empty(),
            None => true,
        }
    }

    /// `true` when the MLP has no hidden layers and no output size is requested.
    ///
    /// Only `hidden_sizes` and `output_size` are consulted; norm, activation
    /// and dropout settings are ignored.
    pub fn is_identity(&self, output_size: Option<usize>) -> bool {
        self.has_no_hidden_layers() && output_size.is_none()
    }

    /// `true` when the MLP has no hidden layers but an output size is requested.
    ///
    /// Only `hidden_sizes` and `output_size` are consulted.
    pub fn is_single_linear(&self, output_size: Option<usize>) -> bool {
        self.has_no_hidden_layers() && output_size.is_some()
    }
}
/// Transformer encoder settings. The option names resemble those of the
/// x-transformers `Encoder` (presumably ported from it — confirm).
///
/// Every field is optional in the serialized form; defaults are listed per field.
#[derive(Debug, Clone, Deserialize)]
pub struct EncoderConfig {
/// Number of attention heads; defaults to 8.
#[serde(default = "default_heads")]
pub heads: usize,
/// Number of layers; defaults to 8.
#[serde(default = "default_depth")]
pub depth: usize,
/// Defaults to `false`.
#[serde(default)]
pub cross_attend: bool,
/// Defaults to `false`.
#[serde(default)]
pub causal: bool,
/// Defaults to `false`.
#[serde(default)]
pub attn_flash: bool,
/// Defaults to 0.0.
#[serde(default)]
pub attn_dropout: f64,
/// Feed-forward width multiplier (see `ff_inner_dim`); defaults to 4.
#[serde(default = "default_ff_mult")]
pub ff_mult: usize,
/// Defaults to 0.0.
#[serde(default)]
pub ff_dropout: f64,
/// Defaults to `true`.
#[serde(default = "default_true")]
pub use_scalenorm: bool,
/// Defaults to `false`.
#[serde(default)]
pub use_rmsnorm: bool,
/// Defaults to `false`.
#[serde(default)]
pub rel_pos_bias: bool,
/// Defaults to `false`.
#[serde(default)]
pub alibi_pos_bias: bool,
/// Defaults to `true`.
#[serde(default = "default_true")]
pub rotary_pos_emb: bool,
/// Defaults to `false`.
#[serde(default)]
pub rotary_xpos: bool,
/// Defaults to `false`.
#[serde(default)]
pub residual_attn: bool,
/// Defaults to `true`.
#[serde(default = "default_true")]
pub scale_residual: bool,
/// Defaults to 0.0.
#[serde(default)]
pub layer_dropout: f64,
/// Optional name; `None` when absent.
#[serde(default)]
pub name: Option<String>,
}
/// Serde default for `EncoderConfig::heads`.
fn default_heads() -> usize {
    8
}

/// Serde default for `EncoderConfig::depth`.
fn default_depth() -> usize {
    8
}

/// Serde default for `EncoderConfig::ff_mult`.
fn default_ff_mult() -> usize {
    4
}
impl Default for EncoderConfig {
    /// Same values the serde field defaults produce: 8 heads, depth 8,
    /// feed-forward multiplier 4; ScaleNorm, rotary position embeddings and
    /// residual scaling switched on; every other flag off and every dropout 0.
    fn default() -> Self {
        Self {
            // Stack shape.
            depth: 8,
            heads: 8,
            ff_mult: 4,
            // Flags that are on by default.
            use_scalenorm: true,
            rotary_pos_emb: true,
            scale_residual: true,
            // Flags that are off by default.
            cross_attend: false,
            causal: false,
            attn_flash: false,
            use_rmsnorm: false,
            rel_pos_bias: false,
            alibi_pos_bias: false,
            rotary_xpos: false,
            residual_attn: false,
            // Dropout rates.
            attn_dropout: 0.0,
            ff_dropout: 0.0,
            layer_dropout: 0.0,
            name: None,
        }
    }
}
impl EncoderConfig {
    /// Per-head width: `dim` divided by `heads` (integer division, so any
    /// remainder is discarded).
    ///
    /// # Panics
    ///
    /// Panics with a division-by-zero error if `heads` is 0.
    pub fn dim_head(&self, dim: usize) -> usize {
        let head_count = self.heads;
        dim / head_count
    }

    /// Rotary-embedding width: half of [`Self::dim_head`], floored at 32.
    pub fn rotary_emb_dim(&self, dim: usize) -> usize {
        let half_head = self.dim_head(dim) / 2;
        std::cmp::max(half_head, 32)
    }

    /// Feed-forward inner width: `dim` scaled by `ff_mult`.
    pub fn ff_inner_dim(&self, dim: usize) -> usize {
        self.ff_mult * dim
    }
}
/// Per-subject layer settings.
///
/// NOTE(review): when `subject_dropout` is omitted from a config it
/// deserializes to `None`, but `SubjectLayersConfig::default()` sets it to
/// `Some(0.1)`; the two paths disagree on `num_weight_subjects()`.
#[derive(Debug, Clone, Deserialize)]
pub struct SubjectLayersConfig {
/// Defaults to 25.
#[serde(default = "default_n_subjects")]
pub n_subjects: usize,
/// Defaults to `true`.
#[serde(default = "default_true")]
pub bias: bool,
/// Defaults to `false`.
#[serde(default)]
pub init_id: bool,
/// Defaults to `"gather"`.
#[serde(default = "default_gather")]
pub mode: String,
/// `None` when absent; when `Some`, `num_weight_subjects()` adds one extra slot.
#[serde(default)]
pub subject_dropout: Option<f64>,
/// Defaults to `false`. NOTE(review): `TribeV2Config` has a field of the same name.
#[serde(default)]
pub average_subjects: bool,
/// Optional name; `None` when absent.
#[serde(default)]
pub name: Option<String>,
}
/// Serde default for `SubjectLayersConfig::n_subjects`.
fn default_n_subjects() -> usize {
    25
}

/// Serde default for `SubjectLayersConfig::mode`.
fn default_gather() -> String {
    String::from("gather")
}
impl Default for SubjectLayersConfig {
    /// Programmatic defaults: 25 subjects, bias on, `"gather"` mode, subject
    /// dropout of 0.1, no identity init, no subject averaging, no name.
    ///
    /// NOTE(review): this disagrees with the serde path. A config that omits
    /// `subject_dropout` deserializes to `None` (plain `#[serde(default)]` on an
    /// `Option`), whereas this impl yields `Some(0.1)`. `num_weight_subjects()`
    /// therefore differs by one between the two. Confirm which value is intended.
    fn default() -> Self {
        let mode = String::from("gather");
        Self {
            name: None,
            n_subjects: 25,
            mode,
            bias: true,
            init_id: false,
            average_subjects: false,
            subject_dropout: Some(0.1),
        }
    }
}
impl SubjectLayersConfig {
    /// Number of per-subject weight slots: `n_subjects`, plus one extra slot
    /// whenever `subject_dropout` is set (presumably a shared slot used for
    /// dropped-out subjects — confirm against the model code).
    pub fn num_weight_subjects(&self) -> usize {
        let dropout_slot = usize::from(self.subject_dropout.is_some());
        self.n_subjects + dropout_slot
    }
}
/// Temporal smoothing settings.
#[derive(Debug, Clone, Deserialize)]
pub struct TemporalSmoothingConfig {
/// Smoothing kernel length; defaults to 9.
#[serde(default = "default_kernel_size")]
pub kernel_size: usize,
/// `None` when absent (presumably the width of a Gaussian kernel — confirm).
#[serde(default)]
pub sigma: Option<f64>,
/// Optional name; `None` when absent.
#[serde(default)]
pub name: Option<String>,
}
/// Serde default for `TemporalSmoothingConfig::kernel_size`.
fn default_kernel_size() -> usize {
    9
}
/// Input shape of one modality: its name plus an optional
/// `(num_layers, feature_dim)` pair. `dims == None` marks a modality that
/// contributes no features.
#[derive(Debug, Clone)]
pub struct ModalityDims {
    pub name: String,
    pub dims: Option<(usize, usize)>,
}

impl ModalityDims {
    /// Modality with `num_layers` layers of width `feature_dim`.
    pub fn new(name: &str, num_layers: usize, feature_dim: usize) -> Self {
        let dims = Some((num_layers, feature_dim));
        Self { name: String::from(name), dims }
    }

    /// Modality with no features (`dims == None`).
    pub fn none(name: &str) -> Self {
        Self { name: name.to_owned(), dims: None }
    }

    /// Layer count, or 0 when the modality has no features.
    pub fn num_layers(&self) -> usize {
        match self.dims {
            Some((layers, _)) => layers,
            None => 0,
        }
    }

    /// Per-layer feature width, or 0 when the modality has no features.
    pub fn feature_dim(&self) -> usize {
        if let Some((_, width)) = self.dims {
            width
        } else {
            0
        }
    }

    /// Hard-coded dims for `text`, `audio` and `video` (3 layers each),
    /// presumably matching the released pretrained checkpoint — confirm.
    pub fn pretrained() -> Vec<Self> {
        [("text", 3072), ("audio", 1024), ("video", 1408)]
            .iter()
            .map(|&(modality, width)| Self::new(modality, 3, width))
            .collect()
    }
}
/// Arguments used to build the model, loaded from JSON via `ModelBuildArgs::from_json`.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelBuildArgs {
/// Modality name -> optional dims list. `to_modality_dims` reads a
/// two-element list as `(num_layers, feature_dim)` and anything else as
/// "no features".
pub feature_dims: BTreeMap<String, Option<Vec<usize>>>,
/// Presumably the number of output channels (e.g. brain vertices) — confirm.
pub n_outputs: usize,
/// Presumably the number of output timesteps — confirm.
pub n_output_timesteps: usize,
}
impl ModelBuildArgs {
pub fn from_json(path: &str) -> anyhow::Result<Self> {
let json = std::fs::read_to_string(path)?;
Ok(serde_json::from_str(&json)?)
}
pub fn to_modality_dims(&self) -> Vec<ModalityDims> {
self.feature_dims.iter().map(|(name, dims)| {
match dims {
Some(v) if v.len() == 2 => ModalityDims::new(name, v[0], v[1]),
_ => ModalityDims::none(name),
}
}).collect()
}
}