use crate::mixtral::{config::MixtralConfig, model::MixtralForCausalLM};
use trustformers_core::{errors::Result, tensor::Tensor, traits::Model};
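
/// Causal language-modeling task that wraps a [`MixtralForCausalLM`] model,
/// bundling construction, forward passes, and accessors in one place.
///
/// A minimal usage sketch (marked `ignore` because it assumes a fully built
/// crate and a valid configuration):
///
/// ```ignore
/// let task = MixtralCausalLMTask::new(MixtralConfig::default())?;
/// let logits = task.forward(vec![1, 2, 3])?;
/// ```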
pub struct MixtralCausalLMTask {
model: MixtralForCausalLM,
}

impl MixtralCausalLMTask {
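    /// Builds the underlying [`MixtralForCausalLM`] from `config`, propagating
    /// any construction error.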
pub fn new(config: MixtralConfig) -> Result<Self> {
let model = MixtralForCausalLM::new(config)?;
Ok(Self { model })
}
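
    /// Runs a single forward pass over `input_ids` and returns the resulting
    /// logits tensor.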
pub fn forward(&self, input_ids: Vec<u32>) -> Result<Tensor> {
self.model.forward(input_ids)
}
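
    /// Returns a reference to the wrapped model.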
pub fn model(&self) -> &MixtralForCausalLM {
&self.model
}
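
    /// Returns the configuration of the wrapped model.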
pub fn config(&self) -> &MixtralConfig {
self.model.get_config()
}
}

#[cfg(test)]
mod tests {
use super::*;
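
    /// A deliberately tiny configuration so the tests construct and run
    /// quickly; it mirrors the real Mixtral layout (grouped-query attention,
    /// top-2 routing over a small expert pool) at toy scale.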
fn small_config() -> MixtralConfig {
MixtralConfig {
hidden_size: 64,
intermediate_size: 128,
num_hidden_layers: 2,
num_attention_heads: 4,
num_key_value_heads: 2,
num_local_experts: 4,
num_experts_per_tok: 2,
sliding_window: None,
vocab_size: 256,
max_position_embeddings: 64,
rope_theta: 1_000_000.0,
rms_norm_eps: 1e-5,
hidden_act: "silu".to_string(),
router_aux_loss_coef: 0.02,
model_type: "mixtral".to_string(),
}
}
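
    /// Builds a task from `small_config`, panicking with a clear message if
    /// initialization fails.
    fn make_task() -> MixtralCausalLMTask {
        MixtralCausalLMTask::new(small_config()).unwrap_or_else(|_| panic!("init failed"))
    }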

    #[test]
fn test_task_construction() {
let result = MixtralCausalLMTask::new(small_config());
assert!(result.is_ok(), "MixtralCausalLMTask must construct");
}

    #[test]
fn test_config_accessor() {
let cfg = small_config();
let vocab = cfg.vocab_size;
let task = MixtralCausalLMTask::new(cfg).unwrap_or_else(|_| panic!("init failed"));
assert_eq!(task.config().vocab_size, vocab);
}

    #[test]
    fn test_model_accessor_num_parameters() {
        let task = make_task();
assert!(
task.model().num_parameters() > 0,
"model must have parameters"
);
}

    #[test]
    fn test_forward_succeeds() {
        let task = make_task();
let result = task.forward(vec![1u32, 2, 3]);
assert!(result.is_ok(), "forward must succeed");
}

    #[test]
fn test_config_num_experts() {
let cfg = small_config();
assert_eq!(cfg.num_local_experts, 4);
assert_eq!(cfg.num_experts_per_tok, 2);
}

    #[test]
fn test_config_model_type() {
let cfg = small_config();
assert_eq!(cfg.model_type, "mixtral");
}

    #[test]
fn test_config_router_aux_loss_coef() {
let cfg = small_config();
assert!((cfg.router_aux_loss_coef - 0.02).abs() < 1e-6);
}

    #[test]
    fn test_forward_tensor_non_empty() {
        let task = make_task();
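        // Only assert when the backend yields an F32 tensor; any other
        // variant falls through and the test passes vacuously.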
if let Ok(Tensor::F32(arr)) = task.forward(vec![0u32, 1]).as_ref() {
assert!(!arr.is_empty(), "output tensor must be non-empty");
}
}

    #[test]
    fn test_forward_output_finite() {
        let task = make_task();
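        // As above, the finiteness check applies only to F32 outputs.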
if let Ok(Tensor::F32(arr)) = task.forward(vec![1u32]).as_ref() {
for &v in arr.iter() {
assert!(v.is_finite(), "logit {v} must be finite");
}
}
}

    #[test]
    fn test_forward_deterministic() {
        let task = make_task();
let ids = vec![1u32, 2, 3];
let r1 = task.forward(ids.clone());
let r2 = task.forward(ids);
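        // Compare elementwise only when both passes return F32 tensors; the
        // forward pass is expected to be deterministic, so they must agree.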
if let (Ok(a), Ok(b)) = (r1, r2) {
if let (Tensor::F32(arr_a), Tensor::F32(arr_b)) = (&a, &b) {
let v1: Vec<f32> = arr_a.iter().copied().collect();
let v2: Vec<f32> = arr_b.iter().copied().collect();
assert_eq!(v1, v2, "forward must be deterministic");
}
}
}

    #[test]
fn test_default_config_experts() {
let cfg = MixtralConfig::default();
assert_eq!(cfg.num_local_experts, 8);
assert_eq!(cfg.num_experts_per_tok, 2);
}
}