use std::collections::BTreeMap;
use crate::tensor::Tensor;
use crate::config::{BrainModelConfig, ModalityDims, ModelBuildArgs, TribeV2Config};
use crate::weights::{WeightMap, load_weights};
use super::projector::Projector;
use super::encoder::XTransformerEncoder;
use super::subject_layers::SubjectLayers;
use super::temporal_smoothing::TemporalSmoothing;
/// A modality projector tagged with the modality name it serves.
#[derive(Debug, Clone)]
pub struct NamedProjector {
    /// Modality name; matches `ModalityDims::name` and the keys of the
    /// feature map passed to the model.
    pub name: String,
    /// Projection from this modality's raw feature width into the
    /// model's hidden space.
    pub projector: Projector,
}
/// The TRIBE v2 brain-encoding model: per-modality projectors, an optional
/// transformer trunk with positional/subject embeddings, and a
/// subject-conditioned prediction head.
#[derive(Debug, Clone)]
pub struct TribeV2 {
    /// One projector per modality that declared feature dims (see `new`).
    pub projectors: Vec<NamedProjector>,
    /// Optional linear layer applied after modality aggregation.
    pub combiner: Option<Projector>,
    /// Learned time positional embedding, shape (1, max_seq_len, hidden).
    pub time_pos_embed: Option<Tensor>,
    /// Optional transformer encoder applied over the time axis.
    pub encoder: Option<XTransformerEncoder>,
    /// Additive per-subject embedding, shape (n_subjects, hidden).
    pub subject_embed: Option<Tensor>,
    /// Optional low-rank projection weight, shape (hidden, rank).
    pub low_rank_head: Option<Tensor>,
    /// Subject-conditioned output head mapping to `n_outputs` channels.
    pub predictor: SubjectLayers,
    /// Optional temporal smoothing applied right after aggregation.
    pub temporal_smoothing: Option<TemporalSmoothing>,
    /// Modality descriptors, in aggregation order.
    pub feature_dims: Vec<ModalityDims>,
    /// Number of predicted output channels.
    pub n_outputs: usize,
    /// Output length after adaptive pooling in `forward`.
    pub n_output_timesteps: usize,
    /// Full model configuration this instance was built from.
    pub config: BrainModelConfig,
}
impl TribeV2 {
/// Builds a `TribeV2` from scratch with zero-initialized weights.
///
/// `feature_dims` lists one entry per modality; entries whose `dims` is
/// `None` get no projector (their features are replaced with zeros at
/// aggregation time). `n_outputs` is the number of predicted channels and
/// `n_output_timesteps` the pooled output length used by `forward`.
pub fn new(
    feature_dims: Vec<ModalityDims>,
    n_outputs: usize,
    n_output_timesteps: usize,
    config: &BrainModelConfig,
) -> Self {
    let hidden = config.hidden;
    let n_modalities = feature_dims.len();
    // One projector per modality that declares (num_layers, feature_dim).
    let mut projectors = Vec::new();
    for md in &feature_dims {
        if let Some((num_layers, feature_dim)) = md.dims {
            // "cat" layer aggregation flattens all layers into the channel
            // dim, so the projector input grows by num_layers.
            let input_dim = if config.layer_aggregation == "cat" {
                feature_dim * num_layers
            } else {
                feature_dim
            };
            // "cat" extractor aggregation concatenates modalities, so each
            // projector emits only an equal share of `hidden`.
            // NOTE(review): integer division — assumes hidden is divisible
            // by n_modalities (and n_modalities > 0); confirm upstream.
            let output_dim = if config.extractor_aggregation == "cat" {
                hidden / n_modalities
            } else {
                hidden
            };
            // Projector flavor: subject-conditioned layers, MLP (when
            // hidden sizes are non-empty), or plain linear otherwise.
            let proj = if config.projector.name.as_deref() == Some("SubjectLayers") {
                let sl_cfg = config.subject_layers.clone().unwrap_or_default();
                Projector::new_subject_layers(input_dim, output_dim, &sl_cfg)
            } else if let Some(ref hs) = config.projector.hidden_sizes {
                if !hs.is_empty() {
                    let has_norm = config.projector.norm_layer.as_deref() == Some("layer");
                    Projector::new_mlp(input_dim, output_dim, hs, has_norm)
                } else {
                    Projector::new_linear(input_dim, output_dim)
                }
            } else {
                Projector::new_linear(input_dim, output_dim)
            };
            projectors.push(NamedProjector {
                name: md.name.clone(),
                projector: proj,
            });
        }
    }
    // The combiner maps the merged multimodal representation back to
    // `hidden` width.
    let combiner_input_dim = if config.extractor_aggregation == "cat" {
        (hidden / n_modalities) * n_modalities
    } else {
        hidden
    };
    let combiner = if config.combiner.is_some() {
        Some(Projector::new_linear(combiner_input_dim, hidden))
    } else {
        None
    };
    // Learned time positional embedding, (1, max_seq_len, hidden);
    // disabled for the linear baseline.
    let time_pos_embed = if config.time_pos_embedding && !config.linear_baseline {
        Some(Tensor::zeros(&[1, config.max_seq_len, hidden]))
    } else {
        None
    };
    let encoder = if !config.linear_baseline {
        config.encoder.as_ref().map(|enc_config| {
            XTransformerEncoder::new(hidden, enc_config)
        })
    } else {
        None
    };
    // Additive per-subject embedding, (n_subjects, hidden). Falls back to
    // 200 subjects when no subject-layer config is present.
    let subject_embed = if config.subject_embedding && !config.linear_baseline {
        let n_subjects = config.subject_layers.as_ref().map_or(200, |sl| sl.n_subjects);
        Some(Tensor::zeros(&[n_subjects, hidden]))
    } else {
        None
    };
    // Optional low-rank bottleneck weight (hidden, lr) ahead of the head;
    // the head's input width shrinks to `lr` when configured.
    let low_rank_head = config.low_rank_head.map(|lr| {
        Tensor::zeros(&[hidden, lr])
    });
    let bottleneck = config.low_rank_head.unwrap_or(hidden);
    let sl_config = config.subject_layers.clone().unwrap_or_default();
    let predictor = SubjectLayers::new(bottleneck, n_outputs, &sl_config);
    // Gaussian smoothing when sigma is given, otherwise a learnable kernel.
    let temporal_smoothing = config.temporal_smoothing.as_ref().map(|ts| {
        if let Some(sigma) = ts.sigma {
            TemporalSmoothing::new_gaussian(hidden, ts.kernel_size, sigma)
        } else {
            TemporalSmoothing::new_learnable(hidden, ts.kernel_size)
        }
    });
    Self {
        projectors,
        combiner,
        time_pos_embed,
        subject_embed,
        encoder,
        low_rank_head,
        predictor,
        temporal_smoothing,
        feature_dims,
        n_outputs,
        n_output_timesteps,
        config: config.clone(),
    }
}
/// Loads a `TribeV2` from a YAML config file plus a safetensors weight file.
///
/// When `build_args_path` is given, modality dims and output sizes come
/// from that JSON; otherwise pretrained defaults are used (20484 outputs,
/// timestep count from the data config).
///
/// # Errors
/// Returns an error if any file cannot be read or parsed, or if the
/// weights do not load into the built model.
pub fn from_pretrained(
    config_path: &str,
    weights_path: &str,
    build_args_path: Option<&str>,
) -> anyhow::Result<Self> {
    let yaml = std::fs::read_to_string(config_path)?;
    let mut config: TribeV2Config = serde_yaml::from_str(&yaml)?;
    // Inference-time override: average across subject layers rather than
    // using per-subject weights. NOTE(review): n_subjects is forced to 0
    // here — presumably the loaded weights determine the real subject
    // count; confirm against `load_weights`.
    if let Some(ref mut sl) = config.brain_model_config.subject_layers {
        sl.average_subjects = true;
        sl.n_subjects = 0;
    }
    let (feature_dims, n_outputs, n_output_timesteps) = if let Some(ba_path) = build_args_path {
        let ba = ModelBuildArgs::from_json(ba_path)?;
        (ba.to_modality_dims(), ba.n_outputs, ba.n_output_timesteps)
    } else {
        (ModalityDims::pretrained(), 20484, config.data.duration_trs)
    };
    let mut model = Self::new(
        feature_dims,
        n_outputs,
        n_output_timesteps,
        &config.brain_model_config,
    );
    let mut wm = WeightMap::from_safetensors(weights_path)?;
    load_weights(&mut wm, &mut model)?;
    Ok(model)
}
/// Convenience wrapper around [`Self::aggregate_features_with_subjects`]
/// with no subject conditioning.
pub fn aggregate_features(
    &self,
    features: &BTreeMap<String, Tensor>,
) -> Tensor {
    self.aggregate_features_with_subjects(features, None)
}
/// Projects each modality's features into the hidden space and merges them.
///
/// For every configured modality (in `feature_dims` order): a tensor of
/// shape (B, L, D, T) — or (B, D, T), treated as L = 1 — is layer-aggregated
/// ("mean" averages over L; any other mode flattens layers into channels),
/// permuted to (B, T, C) and projected. Modalities without a projector or
/// without an entry in `features` are replaced by zeros of the projected
/// width.
///
/// Merging follows `config.extractor_aggregation`: "cat" concatenates over
/// the last dim, "sum" adds tensors element-wise, anything else
/// concatenates along dim 1.
///
/// NOTE(review): all feature tensors are assumed to share the batch size
/// and time length of the map's first entry — confirm with callers.
pub fn aggregate_features_with_subjects(
    &self,
    features: &BTreeMap<String, Tensor>,
    subject_ids: Option<&[usize]>,
) -> Tensor {
    let n_modalities = self.feature_dims.len();
    let hidden = self.config.hidden;
    // Batch size and time length come from an arbitrary first feature.
    let first = features.values().next().expect("no features provided");
    let b = first.shape[0];
    let t = *first.shape.last().unwrap();
    let mut tensors = Vec::new();
    for md in &self.feature_dims {
        let projector = self.projectors.iter().find(|np| np.name == md.name);
        let has_projector = projector.is_some();
        if has_projector && features.contains_key(&md.name) {
            let projector = &projector.unwrap().projector;
            let data = features.get(&md.name).unwrap();
            let mut data = data.clone();
            // Promote (B, D, T) to (B, 1, D, T) so both layouts share one path.
            if data.ndim() == 3 {
                data = data.reshape(&[b, 1, data.shape[1], t]);
            }
            let data = if self.config.layer_aggregation == "mean" {
                // Mean over the layer axis: (B, L, D, T) -> (B, D, T).
                // Index math assumes contiguous row-major layout.
                let l = data.shape[1];
                let d = data.shape[2];
                let mut mean_data = vec![0.0f32; b * d * t];
                for bi in 0..b {
                    for di in 0..d {
                        for ti in 0..t {
                            let mut sum = 0.0f32;
                            for li in 0..l {
                                sum += data.data[bi * l * d * t + li * d * t + di * t + ti];
                            }
                            mean_data[bi * d * t + di * t + ti] = sum / l as f32;
                        }
                    }
                }
                Tensor::from_vec(mean_data, vec![b, d, t])
            } else {
                // Flatten layers into channels: (B, L*D, T).
                let l = data.shape[1];
                let d = data.shape[2];
                data.reshape(&[b, l * d, t])
            };
            // Channels-last for the projector: (B, T, C).
            let data = data.permute(&[0, 2, 1]);
            let data = projector.forward_with_subjects(&data, subject_ids);
            tensors.push(data);
        } else {
            // Missing modality: stand in with zeros of the projected width.
            let out_dim = if self.config.extractor_aggregation == "cat" {
                hidden / n_modalities
            } else {
                hidden
            };
            tensors.push(Tensor::zeros(&[b, t, out_dim]));
        }
    }
    if self.config.extractor_aggregation == "cat" {
        let refs: Vec<&Tensor> = tensors.iter().collect();
        Tensor::cat_last(&refs)
    } else if self.config.extractor_aggregation == "sum" {
        let refs: Vec<&Tensor> = tensors.iter().collect();
        Tensor::sum_tensors(&refs)
    } else {
        // Fallback: concatenate along dim 1 (the time axis in (B, T, C)).
        // NOTE(review): confirm this fallback mode is actually intended.
        let refs: Vec<&Tensor> = tensors.iter().collect();
        Tensor::cat_dim1(&refs)
    }
}
/// Applies the combiner, additive positional/subject embeddings, and the
/// transformer encoder. Input and output layout is (B, T, H); each stage
/// is skipped when the corresponding component was not configured.
fn transformer_forward(&self, x: &Tensor, subject_ids: Option<&[usize]>) -> Tensor {
    let mut out = match self.combiner {
        Some(ref combiner) => combiner.forward(x),
        None => x.clone(),
    };
    // Learned time positional embedding, truncated to the current length
    // and broadcast over the batch.
    if let Some(ref tpe) = self.time_pos_embed {
        let b = out.shape[0];
        let t = out.shape[1];
        let h = self.config.hidden;
        let table = tpe.slice_dim1(0, t).reshape(&[1, t, h]);
        for bi in 0..b {
            let batch_base = bi * t * h;
            // (ti, hi) flattened into a single t*h offset.
            for j in 0..t * h {
                out.data[batch_base + j] += table.data[j];
            }
        }
    }
    // Additive per-subject embedding; missing ids fall back to subject 0.
    if let Some(ref se) = self.subject_embed {
        if let Some(sids) = subject_ids {
            let (b, t, h) = (out.shape[0], out.shape[1], out.shape[2]);
            for bi in 0..b {
                let sid = sids.get(bi).copied().unwrap_or(0);
                let emb = &se.data[sid * h..(sid + 1) * h];
                for ti in 0..t {
                    let row_base = bi * t * h + ti * h;
                    for hi in 0..h {
                        out.data[row_base + hi] += emb[hi];
                    }
                }
            }
        }
    }
    match self.encoder {
        Some(ref encoder) => encoder.forward(&out),
        None => out,
    }
}
/// Full model forward pass.
///
/// `features` maps modality name to tensor (see
/// `aggregate_features_with_subjects` for the expected layouts). Returns a
/// (batch, n_outputs, time) tensor; when `pool_outputs` is set, the time
/// axis is adaptively average-pooled down to `n_output_timesteps`.
pub fn forward(
    &self,
    features: &BTreeMap<String, Tensor>,
    subject_ids: Option<&[usize]>,
    pool_outputs: bool,
) -> Tensor {
    // Project each modality and merge: (B, T, H).
    let mut x = self.aggregate_features_with_subjects(features, subject_ids);
    // Temporal smoothing operates channels-first, so permute around it.
    if let Some(ref ts) = self.temporal_smoothing {
        let channels_first = x.permute(&[0, 2, 1]);
        x = ts.forward(&channels_first).permute(&[0, 2, 1]);
    }
    if !self.config.linear_baseline {
        x = self.transformer_forward(&x, subject_ids);
    }
    // The prediction head expects (B, H, T).
    x = x.permute(&[0, 2, 1]);
    // Optional low-rank bottleneck before the subject-specific head.
    if let Some(ref lr_weight) = self.low_rank_head {
        let (b, h, t) = (x.shape[0], x.shape[1], x.shape[2]);
        let bottlenecked = x
            .permute(&[0, 2, 1])
            .reshape(&[b * t, h])
            .matmul(lr_weight)
            .reshape(&[b, t, lr_weight.shape[1]]);
        x = bottlenecked.permute(&[0, 2, 1]);
    }
    x = self.predictor.forward(&x, subject_ids);
    if pool_outputs {
        // Adaptive average pooling over time: bin i covers
        // [floor(i*T/N), ceil((i+1)*T/N)); adjacent bins may share a sample.
        let (b, d, t_in) = (x.shape[0], x.shape[1], x.shape[2]);
        let n_bins = self.n_output_timesteps;
        let mut pooled = Vec::with_capacity(b * d * n_bins);
        for bi in 0..b {
            for di in 0..d {
                let row = &x.data[(bi * d + di) * t_in..(bi * d + di + 1) * t_in];
                for i in 0..n_bins {
                    let start = (i * t_in) / n_bins;
                    let end = ((i + 1) * t_in + n_bins - 1) / n_bins;
                    let total: f32 = row[start..end].iter().sum();
                    pooled.push(total / (end - start) as f32);
                }
            }
        }
        x = Tensor::from_vec(pooled, vec![b, d, n_bins]);
    }
    x
}
/// Runs the model on a single text-modality sequence.
///
/// `text_features` holds one `d`-dimensional feature vector per timestep;
/// at most `n_timesteps` of them are used. The features are packed into a
/// (1, d, t) tensor under the "text" modality and pushed through
/// [`Self::forward`] with output pooling enabled.
///
/// Returns one row per pooled output timestep, each with `n_outputs` values.
///
/// # Errors
/// Fails when no timesteps are provided, when the first feature vector is
/// empty, or when rows have inconsistent lengths (previously the ragged
/// case caused an index-out-of-bounds panic).
pub fn predict_from_text_features(
    &self,
    text_features: &[Vec<f32>],
    n_timesteps: usize,
) -> anyhow::Result<Vec<Vec<f32>>> {
    let t = text_features.len().min(n_timesteps);
    if t == 0 {
        anyhow::bail!("no timesteps provided");
    }
    let d = text_features[0].len();
    if d == 0 {
        anyhow::bail!("empty feature vectors provided");
    }
    // Validate up front so ragged input yields an error, not a panic.
    for (ti, row) in text_features[..t].iter().enumerate() {
        if row.len() != d {
            anyhow::bail!(
                "inconsistent feature dims: row {} has {} values, expected {}",
                ti,
                row.len(),
                d
            );
        }
    }
    // Pack channel-major into (1, d, t): element (di, ti) lands at di * t + ti.
    let mut text_data = vec![0.0f32; d * t];
    for ti in 0..t {
        for di in 0..d {
            text_data[di * t + ti] = text_features[ti][di];
        }
    }
    let mut features = BTreeMap::new();
    features.insert("text".to_string(), Tensor::from_vec(text_data, vec![1, d, t]));
    let out = self.forward(&features, None, true);
    // Output is (1, n_outputs, t_out); transpose into per-timestep rows.
    let n_out = out.shape[1];
    let t_out = out.shape[2];
    let mut result = Vec::with_capacity(t_out);
    for ti in 0..t_out {
        let mut row = Vec::with_capacity(n_out);
        for di in 0..n_out {
            row.push(out.data[di * t_out + ti]);
        }
        result.push(row);
    }
    Ok(result)
}
/// Per-modality importance via leave-one-out ablation.
///
/// For each modality present in `features`, runs the model once with that
/// modality zeroed out and reports, per output channel, the absolute change
/// of the time-averaged prediction relative to the un-ablated run.
/// NOTE(review): the averages index the output as (channel, time) — assumes
/// batch size 1; confirm with callers.
pub fn modality_ablation(
    &self,
    features: &BTreeMap<String, Tensor>,
    subject_ids: Option<&[usize]>,
) -> BTreeMap<String, Vec<f32>> {
    let baseline = self.forward(features, subject_ids, true);
    let n_out = baseline.shape[1];
    let n_t = baseline.shape[2];
    // Mean over the time axis for output channel `di`.
    let time_mean = |out: &Tensor, di: usize| -> f32 {
        let base = di * n_t;
        out.data[base..base + n_t].iter().sum::<f32>() / n_t as f32
    };
    let baseline_means: Vec<f32> = (0..n_out).map(|di| time_mean(&baseline, di)).collect();
    let mut contributions = BTreeMap::new();
    for md in &self.feature_dims {
        let shape = match features.get(&md.name) {
            Some(tensor) => tensor.shape.clone(),
            None => continue,
        };
        // Re-run with this modality replaced by zeros of the same shape.
        let mut ablated = features.clone();
        ablated.insert(md.name.clone(), Tensor::zeros(&shape));
        let out = self.forward(&ablated, subject_ids, true);
        let deltas: Vec<f32> = (0..n_out)
            .map(|di| (baseline_means[di] - time_mean(&out, di)).abs())
            .collect();
        contributions.insert(md.name.clone(), deltas);
    }
    contributions
}
}