#[cfg(test)]
#[allow(clippy::module_inception)]
mod tests {
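    // Unit tests for the Mixtral mixture-of-experts stack: reference config
    // invariants, router and expert output shapes, top-2 routing, and the
    // auxiliary load-balancing loss.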
use crate::mixtral::config::MixtralConfig;
use crate::mixtral::model::{
compute_load_balancing_loss, MixtralForCausalLM, MixtralModel, MixtralSparseMoeBlock,
};
use crate::mixtral::tasks::MixtralCausalLMTask;
use trustformers_core::{tensor::Tensor, traits::Config, traits::Layer, traits::Model};
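    /// A deliberately small config for fast CPU tests. `hidden_size` (64) divides
    /// evenly by `num_attention_heads` (8), and 2 of 4 experts are routed per
    /// token, mirroring Mixtral's top-2 scheme at a fraction of the size.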
fn tiny_config() -> MixtralConfig {
MixtralConfig {
hidden_size: 64,
intermediate_size: 256,
num_hidden_layers: 2,
num_attention_heads: 8,
num_key_value_heads: 2,
num_local_experts: 4,
num_experts_per_tok: 2,
vocab_size: 512,
max_position_embeddings: 128,
sliding_window: None,
rope_theta: 10000.0,
rms_norm_eps: 1e-5,
hidden_act: "silu".to_string(),
router_aux_loss_coef: 0.02,
model_type: "mixtral".to_string(),
}
}
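    // The named constructors should reproduce the published Mixtral
    // hyperparameters (8x7B: hidden 4096, 8 experts, top-2 routing, 32k vocab;
    // 8x22B: hidden 6144, 8 experts).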
#[test]
fn test_mixtral_config() {
let config = MixtralConfig::mixtral_8x7b();
assert_eq!(config.hidden_size, 4096);
assert_eq!(config.num_local_experts, 8);
assert_eq!(config.num_experts_per_tok, 2);
assert_eq!(config.vocab_size, 32000);
config.validate().expect("8x7b config should be valid");
let config22 = MixtralConfig::mixtral_8x22b();
assert_eq!(config22.hidden_size, 6144);
assert_eq!(config22.num_local_experts, 8);
config22.validate().expect("8x22b config should be valid");
}
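    // The router scores every token against every expert, so its logits are
    // shaped [batch * seq, num_local_experts].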
#[test]
fn test_moe_router_output_shape() {
let config = tiny_config();
let moe = MixtralSparseMoeBlock::new(&config).expect("moe creation");
let batch = 2usize;
let seq = 3usize;
let h = config.hidden_size;
let hidden =
Tensor::from_vec(vec![0.1f32; batch * seq * h], &[batch, seq, h]).expect("tensor");
let router_logits = moe.router_logits(&hidden).expect("router forward");
let shape = router_logits.shape().to_vec();
assert_eq!(shape[0], batch * seq);
assert_eq!(shape[1], config.num_local_experts);
}
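    // Top-2 routing mixes the selected experts' outputs back into the input
    // shape: the MoE block is shape-preserving regardless of which experts fire.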
#[test]
fn test_moe_top_k_selection() {
let config = tiny_config();
let moe = MixtralSparseMoeBlock::new(&config).expect("moe creation");
let h = config.hidden_size;
let hidden = Tensor::from_vec(vec![0.5f32; h], &[1, 1, h]).expect("tensor");
let (output, router_logits) = moe.forward_with_router_logits(hidden).expect("forward");
let shape = output.shape().to_vec();
assert_eq!(shape, vec![1, 1, h]);
let r_shape = router_logits.shape().to_vec();
assert_eq!(r_shape[1], config.num_local_experts);
}
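    // The top-k routing weights are renormalized before mixing, so the combined
    // output should stay finite; this test asserts finiteness rather than
    // inspecting the weights directly.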
#[test]
fn test_moe_routing_weights_sum() {
let config = tiny_config();
let moe = MixtralSparseMoeBlock::new(&config).expect("moe creation");
let h = config.hidden_size;
let num_tokens = 8usize;
        // A non-constant ramp so different tokens produce different router logits.
        let data: Vec<f32> = (0..num_tokens * h).map(|i| i as f32 * 0.01).collect();
let hidden = Tensor::from_vec(data, &[1, num_tokens, h]).expect("tensor");
let (output, _) = moe.forward_with_router_logits(hidden).expect("forward");
match &output {
Tensor::F32(arr) => {
for &v in arr.iter() {
assert!(v.is_finite(), "output must be finite");
}
},
_ => panic!("expected F32"),
}
}
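    // A decoder layer (attention + MoE FFN, each with a residual connection)
    // must also preserve [batch, seq, hidden].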
#[test]
fn test_mixtral_decoder_layer_shape() {
use crate::mixtral::model::MixtralDecoderLayer;
let config = tiny_config();
let layer = MixtralDecoderLayer::new(&config).expect("layer creation");
let batch = 1usize;
let seq = 4usize;
let h = config.hidden_size;
let input =
Tensor::from_vec(vec![0.1f32; batch * seq * h], &[batch, seq, h]).expect("tensor");
match layer.forward(input) {
Ok(output) => {
let shape = output.shape().to_vec();
assert_eq!(shape, vec![batch, seq, h]);
},
            Err(_) => {} // tolerated: shape assertions apply only when forward succeeds
}
}
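    // End-to-end: token ids -> embeddings -> decoder stack, yielding
    // [batch, seq_len, hidden_size].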
#[test]
fn test_mixtral_model_shapes() {
let config = tiny_config();
let model = MixtralModel::new(config.clone()).expect("model creation");
let input_ids: Vec<u32> = vec![1, 2, 3, 4];
match model.forward(input_ids) {
Ok(output) => {
let shape = output.shape().to_vec();
                assert_eq!(shape[1], 4);
                assert_eq!(shape[2], config.hidden_size);
},
            Err(_) => {} // tolerated, as above
}
}
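    // HuggingFace Mixtral checkpoints keep the MoE tensors under
    // "model.layers.{i}.block_sparse_moe.*" (the gate plus experts.{j}.w1/w2/w3),
    // so the weight map must cover embeddings, lm_head, and those MoE keys.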
#[test]
fn test_mixtral_weight_map() {
let map = MixtralForCausalLM::weight_map();
assert!(!map.is_empty());
let hf_keys: Vec<&str> = map.iter().map(|(hf, _)| *hf).collect();
assert!(hf_keys.contains(&"model.embed_tokens.weight"));
assert!(hf_keys.contains(&"lm_head.weight"));
assert!(hf_keys.iter().any(|k| k.contains("block_sparse_moe")));
}
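    // Auxiliary load-balancing loss in the style of Switch Transformers:
    // roughly coef * N * sum_i f_i * P_i, where f_i is the fraction of tokens
    // routed to expert i and P_i its mean router probability (the exact
    // formulation is implementation-defined). Identical logits route uniformly,
    // so the loss must at least be finite and non-negative.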
#[test]
fn test_moe_load_balancing_loss() {
let num_experts = 4usize;
let num_tokens = 8usize;
let logits_data = vec![1.0f32; num_tokens * num_experts];
let logits =
Tensor::from_vec(logits_data, &[num_tokens, num_experts]).expect("logits tensor");
let loss =
compute_load_balancing_loss(&logits, num_experts, 2, 0.02).expect("loss computation");
assert!(loss.is_finite(), "loss must be finite: {}", loss);
assert!(loss >= 0.0, "loss must be non-negative: {}", loss);
}
#[test]
fn test_mixtral_for_causal_lm() {
let config = tiny_config();
let model = MixtralForCausalLM::new(config.clone()).expect("causal lm");
let ids: Vec<u32> = vec![10, 20, 30];
match model.forward(ids) {
Ok(logits) => {
let shape = logits.shape().to_vec();
assert_eq!(shape[1], 3);
assert_eq!(shape[2], config.vocab_size);
},
            Err(_) => {} // tolerated, as above
}
}
#[test]
fn test_mixtral_task_wrapper() {
let config = tiny_config();
let task = MixtralCausalLMTask::new(config.clone()).expect("task creation");
let ids: Vec<u32> = vec![1, 2];
match task.forward(ids) {
Ok(logits) => {
let shape = logits.shape().to_vec();
assert_eq!(shape[2], config.vocab_size);
},
            Err(_) => {} // tolerated, as above
}
}
#[test]
fn test_mixtral_8x7b_experts_per_tok_invariant() {
let cfg = MixtralConfig::mixtral_8x7b();
assert_eq!(
cfg.num_experts_per_tok, 2,
"Mixtral 8x7B must route to exactly 2 experts"
);
}
#[test]
fn test_mixtral_8x7b_num_local_experts_invariant() {
let cfg = MixtralConfig::mixtral_8x7b();
assert_eq!(
cfg.num_local_experts, 8,
"Mixtral 8x7B has 8 experts per layer"
);
}
#[test]
fn test_mixtral_8x22b_architecture_label() {
let cfg = MixtralConfig::mixtral_8x22b();
assert_eq!(cfg.architecture(), "Mixtral");
cfg.validate().expect("8x22b must be valid");
}
#[test]
fn test_mixtral_8x22b_experts_per_tok_invariant() {
let cfg = MixtralConfig::mixtral_8x22b();
assert_eq!(
cfg.num_experts_per_tok, 2,
"Mixtral 8x22B must also route to 2 experts"
);
}
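    // head_dim = hidden_size / num_attention_heads; for 8x7B that is
    // 4096 / 32 = 128.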
#[test]
fn test_mixtral_head_dim_formula() {
let cfg = MixtralConfig::mixtral_8x7b();
        let expected = cfg.hidden_size / cfg.num_attention_heads;
        assert_eq!(
cfg.head_dim(),
expected,
"head_dim must equal hidden_size / num_heads"
);
}
#[test]
fn test_mixtral_8x22b_head_dim_formula() {
let cfg = MixtralConfig::mixtral_8x22b();
let expected = cfg.hidden_size / cfg.num_attention_heads;
assert_eq!(cfg.head_dim(), expected);
}
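    // Grouped-query attention: 8x7B's 32 query heads share 8 KV heads,
    // i.e. 4 query heads per KV group.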
#[test]
fn test_mixtral_8x7b_gqa_groups() {
let cfg = MixtralConfig::mixtral_8x7b();
assert_eq!(cfg.num_query_groups(), 4);
assert!(cfg.num_attention_heads > cfg.num_key_value_heads);
}
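    // Unlike Mistral 7B, Mixtral attends over the full context window.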
#[test]
fn test_mixtral_sliding_window_none() {
let cfg = MixtralConfig::mixtral_8x7b();
assert!(
cfg.sliding_window.is_none(),
"Mixtral uses full attention, sliding_window must be None"
);
}
#[test]
fn test_mixtral_8x7b_rope_theta() {
let cfg = MixtralConfig::mixtral_8x7b();
assert!(
(cfg.rope_theta - 1_000_000.0).abs() < 1.0,
"Mixtral 8x7B rope_theta should be 1e6, got {}",
cfg.rope_theta
);
}
#[test]
fn test_mixtral_moe_activation_silu() {
let cfg = MixtralConfig::mixtral_8x7b();
assert_eq!(
cfg.hidden_act, "silu",
"Mixtral experts use SwiGLU (SiLU gate)"
);
}
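    // Validation must reject configs that break the MoE or attention invariants.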
#[test]
fn test_mixtral_config_invalid_experts_per_tok_exceeds_num_experts() {
let cfg = MixtralConfig {
num_local_experts: 4,
            num_experts_per_tok: 5,
            ..tiny_config()
};
assert!(
cfg.validate().is_err(),
"num_experts_per_tok > num_local_experts must be invalid"
);
}
#[test]
fn test_mixtral_config_invalid_zero_experts_per_tok() {
let cfg = MixtralConfig {
num_experts_per_tok: 0,
..tiny_config()
};
assert!(
cfg.validate().is_err(),
"num_experts_per_tok=0 must be invalid"
);
}
#[test]
fn test_mixtral_config_invalid_hidden_not_divisible_by_heads() {
let cfg = MixtralConfig {
            hidden_size: 65,
            num_attention_heads: 8,
..tiny_config()
};
assert!(
cfg.validate().is_err(),
"hidden_size % num_attention_heads != 0 must fail"
);
}
#[test]
fn test_mixtral_config_zero_local_experts_invalid() {
let cfg = MixtralConfig {
num_local_experts: 0,
num_experts_per_tok: 1,
..tiny_config()
};
assert!(
cfg.validate().is_err(),
"num_local_experts=0 must be invalid"
);
}
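    // All-zero logits give a uniform router distribution, the balanced optimum
    // of the auxiliary loss.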
#[test]
fn test_moe_load_balancing_loss_uniform_balanced() {
let ne = 4usize;
let nt = 16usize;
let logits_data = vec![0.0f32; nt * ne];
let logits = Tensor::from_vec(logits_data, &[nt, ne]).expect("logits");
let loss = compute_load_balancing_loss(&logits, ne, 2, 0.02).expect("lb loss");
assert!(
loss.is_finite() && loss >= 0.0,
"uniform loss must be finite non-neg: {loss}"
);
}
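    // The expert dimension of the logits (3) deliberately disagrees with the
    // declared num_experts (4).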
#[test]
fn test_moe_load_balancing_loss_wrong_shape_error() {
let logits = Tensor::from_vec(vec![1.0f32; 8 * 3], &[8, 3]).expect("logits");
let result = compute_load_balancing_loss(&logits, 4, 2, 0.02);
assert!(
result.is_err(),
"mismatched num_experts must return an error"
);
}
#[test]
fn test_mixtral_causal_lm_parameter_count_positive() {
let cfg = tiny_config();
let model = MixtralForCausalLM::new(cfg).expect("model");
assert!(model.num_parameters() > 0, "model must have > 0 parameters");
}
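    // The router also accepts pre-flattened [num_tokens, hidden] input.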
#[test]
fn test_moe_router_logits_for_single_token() {
let config = tiny_config();
let moe = MixtralSparseMoeBlock::new(&config).expect("moe");
let h = config.hidden_size;
let hidden = Tensor::from_vec(vec![0.3f32; h], &[1, h]).expect("tensor");
let logits = moe.router_logits(&hidden).expect("router logits");
let shape = logits.shape().to_vec();
assert_eq!(
shape[1], config.num_local_experts,
"router width must be num_local_experts"
);
}
#[test]
fn test_mixtral_task_config_accessor() {
let config = tiny_config();
let task = MixtralCausalLMTask::new(config.clone()).expect("task");
assert_eq!(task.config().hidden_size, config.hidden_size);
assert_eq!(task.config().num_local_experts, config.num_local_experts);
assert_eq!(
task.config().num_experts_per_tok,
config.num_experts_per_tok
);
}
#[test]
fn test_mixtral_router_aux_loss_coef() {
let cfg = MixtralConfig::mixtral_8x7b();
assert!(
(cfg.router_aux_loss_coef - 0.02).abs() < 1e-6,
"default router_aux_loss_coef should be 0.02"
);
}
}