use std::collections::BTreeMap;
use tribev2::config::*;
use tribev2::model::feedforward::FeedForward;
use tribev2::model::projector::Projector;
use tribev2::model::residual::Residual;
use tribev2::model::rotary::RotaryEmbedding;
use tribev2::model::scalenorm::ScaleNorm;
use tribev2::model::subject_layers::SubjectLayers;
use tribev2::model::tribe::TribeV2;
use tribev2::tensor::Tensor;
fn approx_eq(a: f32, b: f32, tol: f32) -> bool {
(a - b).abs() < tol
}
#[test]
fn test_scalenorm_matches_python() {
let sn = ScaleNorm::new(2);
let x = Tensor::from_vec(vec![3.0, 4.0], vec![1, 2]);
let out = sn.forward(&x);
let sqrt2 = (2.0f32).sqrt();
assert!(approx_eq(out.data[0], 0.6 * sqrt2, 1e-5));
assert!(approx_eq(out.data[1], 0.8 * sqrt2, 1e-5));
}
#[test]
fn test_scalenorm_with_custom_g() {
let mut sn = ScaleNorm::new(4);
sn.g = 2.0;
let x = Tensor::from_vec(vec![1.0, 0.0, 0.0, 0.0], vec![1, 4]);
let out = sn.forward(&x);
assert!(approx_eq(out.data[0], 4.0, 1e-5));
assert!(approx_eq(out.data[1], 0.0, 1e-5));
}
#[test]
fn test_rotary_embedding_shape() {
let rot = RotaryEmbedding::new(8);
let freqs = rot.forward(10);
assert_eq!(freqs.shape, vec![10, 8]);
for i in 0..8 {
assert!(approx_eq(freqs.data[i], 0.0, 1e-8));
}
assert!(approx_eq(freqs.data[8 + 0], freqs.data[8 + 4], 1e-8)); }
#[test]
fn test_rotary_inv_freq() {
let rot = RotaryEmbedding::new(4);
assert_eq!(rot.inv_freq.len(), 2);
assert!(approx_eq(rot.inv_freq[0], 1.0, 1e-5)); assert!(approx_eq(rot.inv_freq[1], 0.01, 1e-4)); }
#[test]
fn test_residual_no_scale() {
let res = Residual::new(3, false);
let x = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]);
let r = Tensor::from_vec(vec![0.1, 0.2, 0.3], vec![1, 3]);
let out = res.forward(&x, &r);
assert!(approx_eq(out.data[0], 1.1, 1e-6));
assert!(approx_eq(out.data[1], 2.2, 1e-6));
assert!(approx_eq(out.data[2], 3.3, 1e-6));
}
#[test]
fn test_residual_with_scale() {
let mut res = Residual::new(2, true);
res.residual_scale = Some(Tensor::from_vec(vec![2.0, 3.0], vec![2]));
let x = Tensor::from_vec(vec![1.0, 1.0], vec![1, 2]);
let r = Tensor::from_vec(vec![0.5, 0.5], vec![1, 2]);
let out = res.forward(&x, &r);
assert!(approx_eq(out.data[0], 2.0, 1e-6));
assert!(approx_eq(out.data[1], 2.5, 1e-6));
}
#[test]
fn test_feedforward_identity_weights() {
let ff = FeedForward::new(4, 2); let x = Tensor::zeros(&[1, 3, 4]); let out = ff.forward(&x);
assert_eq!(out.shape, vec![1, 3, 4]);
}
#[test]
fn test_subject_layers_average_mode() {
let config = SubjectLayersConfig {
n_subjects: 2,
bias: true,
subject_dropout: Some(0.1),
average_subjects: true,
..Default::default()
};
let mut sl = SubjectLayers::new(3, 2, &config);
assert_eq!(sl.weights.shape, vec![3, 3, 2]);
sl.weights.data[2 * 3 * 2 + 0] = 1.0; sl.weights.data[2 * 3 * 2 + 3] = 1.0;
if let Some(ref mut b) = sl.bias {
b.data[2 * 2 + 0] = 0.5;
b.data[2 * 2 + 1] = 0.5;
}
let x = Tensor::from_vec(vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, ], vec![1, 3, 2]);
let out = sl.forward(&x, None);
assert_eq!(out.shape, vec![1, 2, 2]);
assert!(approx_eq(out.data[0], 1.5, 1e-5)); assert!(approx_eq(out.data[1], 2.5, 1e-5)); assert!(approx_eq(out.data[2], 3.5, 1e-5)); assert!(approx_eq(out.data[3], 4.5, 1e-5)); }
#[test]
fn test_subject_layers_per_subject_gather() {
let config = SubjectLayersConfig {
n_subjects: 2,
bias: false,
subject_dropout: None,
average_subjects: false,
..Default::default()
};
let mut sl = SubjectLayers::new(2, 2, &config);
sl.weights.data = vec![
1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, ];
let x = Tensor::from_vec(vec![
3.0, 4.0, 5.0, 6.0, ], vec![2, 2, 1]);
let out = sl.forward(&x, Some(&[0, 1]));
assert_eq!(out.shape, vec![2, 2, 1]);
assert!(approx_eq(out.data[0], 3.0, 1e-5));
assert!(approx_eq(out.data[1], 4.0, 1e-5));
assert!(approx_eq(out.data[2], 6.0, 1e-5));
assert!(approx_eq(out.data[3], 5.0, 1e-5));
}
#[test]
fn test_linear_projector() {
let mut proj = Projector::new_linear(3, 2);
proj.layers[0].weight = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0], vec![3, 2]);
proj.layers[0].bias = Tensor::from_vec(vec![0.1, 0.2], vec![2]);
let x = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]);
let out = proj.forward(&x);
assert_eq!(out.shape, vec![1, 2]);
assert!(approx_eq(out.data[0], 1.1, 1e-5)); assert!(approx_eq(out.data[1], 2.2, 1e-5)); }
#[test]
fn test_adaptive_avg_pool_identity() {
let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]);
let out = x.adaptive_avg_pool1d(4);
assert_eq!(out.data, vec![1.0, 2.0, 3.0, 4.0]);
}
#[test]
fn test_adaptive_avg_pool_downsample() {
let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![1, 6]);
let out = x.adaptive_avg_pool1d(2);
assert_eq!(out.shape, vec![1, 2]);
assert!(approx_eq(out.data[0], 2.0, 1e-5));
assert!(approx_eq(out.data[1], 5.0, 1e-5));
}
#[test]
fn test_einsum_bct_cd_bdt() {
let x = Tensor::from_vec(vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, ], vec![1, 2, 3]);
let w = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]);
let out = x.einsum_bct_cd_bdt(&w);
assert_eq!(out.shape, vec![1, 2, 3]);
assert_eq!(out.data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
}
#[test]
fn test_model_forward_shape() {
let feature_dims = vec![
ModalityDims::new("text", 1, 128),
ModalityDims::new("audio", 1, 64),
];
let config = BrainModelConfig {
hidden: 128,
max_seq_len: 128,
extractor_aggregation: "cat".into(),
layer_aggregation: "cat".into(),
linear_baseline: true, time_pos_embedding: false,
subject_embedding: false,
dropout: 0.0,
modality_dropout: 0.0,
temporal_dropout: 0.0,
low_rank_head: None,
combiner: None,
temporal_smoothing: None,
projector: Default::default(),
encoder: None,
subject_layers: Some(SubjectLayersConfig {
n_subjects: 1,
bias: true,
subject_dropout: None,
average_subjects: false,
..Default::default()
}),
};
let model = TribeV2::new(feature_dims, 50, 5, &config);
let mut features = BTreeMap::new();
features.insert("text".to_string(), Tensor::zeros(&[2, 128, 10]));
features.insert("audio".to_string(), Tensor::zeros(&[2, 64, 10]));
let out = model.forward(&features, Some(&[0, 0]), true);
assert_eq!(out.shape, vec![2, 50, 5]);
}
#[test]
fn test_model_with_none_modality() {
let feature_dims = vec![
ModalityDims::new("text", 1, 64),
ModalityDims::none("audio"), ];
let config = BrainModelConfig {
hidden: 64,
max_seq_len: 128,
extractor_aggregation: "cat".into(),
layer_aggregation: "cat".into(),
linear_baseline: true,
time_pos_embedding: false,
subject_embedding: false,
dropout: 0.0,
modality_dropout: 0.0,
temporal_dropout: 0.0,
low_rank_head: None,
combiner: None,
temporal_smoothing: None,
projector: Default::default(),
encoder: None,
subject_layers: Some(SubjectLayersConfig {
n_subjects: 1,
bias: true,
subject_dropout: None,
average_subjects: false,
..Default::default()
}),
};
let model = TribeV2::new(feature_dims, 20, 5, &config);
let mut features = BTreeMap::new();
features.insert("text".to_string(), Tensor::zeros(&[1, 64, 10]));
let out = model.forward(&features, Some(&[0]), true);
assert_eq!(out.shape, vec![1, 20, 5]);
}
#[test]
fn test_aggregate_cat() {
let feature_dims = vec![
ModalityDims::new("a", 1, 4),
ModalityDims::new("b", 1, 4),
];
let config = BrainModelConfig {
hidden: 8,
extractor_aggregation: "cat".into(),
layer_aggregation: "cat".into(),
linear_baseline: true,
time_pos_embedding: false,
subject_embedding: false,
combiner: None,
encoder: None,
subject_layers: Some(SubjectLayersConfig { n_subjects: 1, bias: false, subject_dropout: None, average_subjects: false, ..Default::default() }),
..default_brain_config()
};
let model = TribeV2::new(feature_dims, 10, 1, &config);
let mut features = BTreeMap::new();
features.insert("a".to_string(), Tensor::from_vec(vec![1.0; 4 * 3], vec![1, 4, 3]));
features.insert("b".to_string(), Tensor::from_vec(vec![2.0; 4 * 3], vec![1, 4, 3]));
let agg = model.aggregate_features(&features);
assert_eq!(agg.shape, vec![1, 3, 8]);
}
#[test]
fn test_aggregate_sum() {
let feature_dims = vec![
ModalityDims::new("a", 1, 4),
ModalityDims::new("b", 1, 4),
];
let config = BrainModelConfig {
hidden: 4,
extractor_aggregation: "sum".into(),
layer_aggregation: "cat".into(),
linear_baseline: true,
time_pos_embedding: false,
subject_embedding: false,
combiner: None,
encoder: None,
subject_layers: Some(SubjectLayersConfig { n_subjects: 1, bias: false, subject_dropout: None, average_subjects: false, ..Default::default() }),
..default_brain_config()
};
let model = TribeV2::new(feature_dims, 10, 1, &config);
let mut features = BTreeMap::new();
features.insert("a".to_string(), Tensor::from_vec(vec![1.0; 4 * 3], vec![1, 4, 3]));
features.insert("b".to_string(), Tensor::from_vec(vec![2.0; 4 * 3], vec![1, 4, 3]));
let agg = model.aggregate_features(&features);
assert_eq!(agg.shape, vec![1, 3, 4]);
}
#[test]
fn test_attention_numerical() {
use tribev2::model::attention::Attention;
let mut attn = Attention::new(4, 2);
attn.w_q = Tensor::from_vec(vec![
1.0, 0.0, 0.0, 0.0,
0.0, 1.0, 0.0, 0.0,
0.0, 0.0, 1.0, 0.0,
0.0, 0.0, 0.0, 1.0,
], vec![4, 4]);
attn.w_k = attn.w_q.clone();
attn.w_v = attn.w_q.clone();
attn.w_out = attn.w_q.clone();
let x = Tensor::from_vec(vec![
1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ], vec![1, 2, 4]);
let out = attn.forward(&x, None);
assert_eq!(out.shape, vec![1, 2, 4]);
let sum: f32 = out.data.iter().sum();
assert!(sum.is_finite());
assert!(out.data[0] > 0.0, "token 0 dim 0 should be > 0, got {}", out.data[0]);
}
#[test]
fn test_encoder_deterministic() {
use tribev2::model::encoder::XTransformerEncoder;
let config = EncoderConfig {
heads: 2,
depth: 1,
ff_mult: 2,
use_scalenorm: true,
rotary_pos_emb: true,
scale_residual: true,
..Default::default()
};
let enc = XTransformerEncoder::new(64, &config);
let x = Tensor::from_vec(vec![0.1f32; 1 * 3 * 64], vec![1, 3, 64]);
let out = enc.forward(&x);
assert_eq!(out.shape, vec![1, 3, 64]);
for v in &out.data {
assert!(approx_eq(*v, 1.0, 1e-4), "expected ~1.0, got {}", v);
}
}
#[test]
fn test_full_forward_numerical() {
let feature_dims = vec![ModalityDims::new("text", 1, 4)];
let config = BrainModelConfig {
hidden: 4,
max_seq_len: 16,
extractor_aggregation: "cat".into(),
layer_aggregation: "cat".into(),
linear_baseline: true, time_pos_embedding: false,
subject_embedding: false,
dropout: 0.0,
modality_dropout: 0.0,
temporal_dropout: 0.0,
low_rank_head: None,
combiner: None,
temporal_smoothing: None,
projector: Default::default(),
encoder: None,
subject_layers: Some(SubjectLayersConfig {
n_subjects: 1,
bias: false,
subject_dropout: None,
average_subjects: false,
..Default::default()
}),
};
let mut model = TribeV2::new(feature_dims, 2, 3, &config);
model.projectors[0].projector.layers[0].weight =
Tensor::from_vec(vec![
1.0, 0.0, 0.0, 0.0,
0.0, 1.0, 0.0, 0.0,
0.0, 0.0, 1.0, 0.0,
0.0, 0.0, 0.0, 1.0,
], vec![4, 4]);
model.projectors[0].projector.layers[0].bias =
Tensor::from_vec(vec![0.0; 4], vec![4]);
model.predictor.weights = Tensor::from_vec(vec![
1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ], vec![1, 4, 2]);
let mut features = BTreeMap::new();
let text_data: Vec<f32> = (0..24).map(|i| (i as f32) * 0.1).collect();
features.insert("text".to_string(), Tensor::from_vec(text_data, vec![1, 4, 6]));
let out = model.forward(&features, Some(&[0]), true);
assert_eq!(out.shape, vec![1, 2, 3]);
assert!(approx_eq(out.data[0], 0.05, 1e-5), "got {}", out.data[0]);
assert!(approx_eq(out.data[1], 0.25, 1e-5), "got {}", out.data[1]);
assert!(approx_eq(out.data[2], 0.45, 1e-5), "got {}", out.data[2]);
assert!(approx_eq(out.data[3], 0.65, 1e-5), "got {}", out.data[3]);
assert!(approx_eq(out.data[4], 0.85, 1e-5), "got {}", out.data[4]);
assert!(approx_eq(out.data[5], 1.05, 1e-5), "got {}", out.data[5]);
}
#[test]
fn test_layer_aggregation_mean() {
let feature_dims = vec![ModalityDims::new("text", 2, 4)];
let config = BrainModelConfig {
hidden: 4,
extractor_aggregation: "cat".into(),
layer_aggregation: "mean".into(),
linear_baseline: true,
time_pos_embedding: false,
subject_embedding: false,
combiner: None,
encoder: None,
subject_layers: Some(SubjectLayersConfig {
n_subjects: 1, bias: false, subject_dropout: None, average_subjects: false, ..Default::default()
}),
..default_brain_config()
};
let mut model = TribeV2::new(feature_dims, 2, 2, &config);
model.projectors[0].projector.layers[0].weight =
Tensor::from_vec(vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0], vec![4, 4]);
model.projectors[0].projector.layers[0].bias = Tensor::from_vec(vec![0.0; 4], vec![4]);
let data: Vec<f32> = vec![
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0,
];
let mut features = BTreeMap::new();
features.insert("text".to_string(), Tensor::from_vec(data, vec![1, 2, 4, 3]));
let agg = model.aggregate_features(&features);
assert_eq!(agg.shape, vec![1, 3, 4]);
for v in &agg.data {
assert!(approx_eq(*v, 2.0, 1e-5), "expected 2.0, got {}", v);
}
}
fn default_brain_config() -> BrainModelConfig {
BrainModelConfig {
hidden: 64,
max_seq_len: 128,
extractor_aggregation: "cat".into(),
layer_aggregation: "cat".into(),
linear_baseline: false,
time_pos_embedding: false,
subject_embedding: false,
dropout: 0.0,
modality_dropout: 0.0,
temporal_dropout: 0.0,
low_rank_head: None,
combiner: None,
temporal_smoothing: None,
projector: Default::default(),
encoder: None,
subject_layers: Some(SubjectLayersConfig::default()),
}
}
#[test]
fn test_parse_real_config() {
let yaml = std::fs::read_to_string("/tmp/tribev2_config.yaml").unwrap();
let config: Result<TribeV2Config, _> = serde_yaml::from_str(&yaml);
match &config {
Ok(c) => {
assert_eq!(c.brain_model_config.hidden, 1152);
assert_eq!(c.brain_model_config.encoder.as_ref().unwrap().depth, 8);
assert_eq!(c.brain_model_config.encoder.as_ref().unwrap().heads, 8);
assert_eq!(c.brain_model_config.low_rank_head, Some(2048));
assert_eq!(c.data.features_to_use, vec!["text", "audio", "video"]);
assert_eq!(c.data.duration_trs, 100);
}
Err(e) => panic!("Failed to parse real config.yaml: {}", e),
}
}