use crate::llama::model::{RMSNorm, RotaryEmbedding};
use crate::mixtral::config::MixtralConfig;
use std::io::Read;
use trustformers_core::{
device::Device,
errors::{Result, TrustformersError},
layers::{embedding::Embedding, linear::Linear},
tensor::Tensor,
traits::{Config, Layer, Model},
};
fn make_contiguous(t: Tensor) -> Result<Tensor> {
let shape = t.shape().to_vec();
t.reshape(&shape)
}
pub struct MixtralBlockSparseTop2MLP {
pub gate_proj: Linear, pub up_proj: Linear, pub down_proj: Linear, }
impl MixtralBlockSparseTop2MLP {
pub fn new(config: &MixtralConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &MixtralConfig, device: Device) -> Result<Self> {
Ok(Self {
gate_proj: Linear::new_with_device(
config.hidden_size,
config.intermediate_size,
false,
device,
),
up_proj: Linear::new_with_device(
config.hidden_size,
config.intermediate_size,
false,
device,
),
down_proj: Linear::new_with_device(
config.intermediate_size,
config.hidden_size,
false,
device,
),
})
}
}
impl Layer for MixtralBlockSparseTop2MLP {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let gate = self.gate_proj.forward(input.clone())?;
let gate = gate.silu()?;
let up = self.up_proj.forward(input)?;
let activated = gate.mul(&up)?;
self.down_proj.forward(activated)
}
}
pub struct MixtralSparseMoeBlock {
pub gate: Linear,
pub experts: Vec<MixtralBlockSparseTop2MLP>,
num_experts: usize,
num_experts_per_tok: usize,
}
impl MixtralSparseMoeBlock {
pub fn new(config: &MixtralConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &MixtralConfig, device: Device) -> Result<Self> {
let gate =
Linear::new_with_device(config.hidden_size, config.num_local_experts, false, device);
let mut experts = Vec::new();
for _ in 0..config.num_local_experts {
experts.push(MixtralBlockSparseTop2MLP::new_with_device(config, device)?);
}
Ok(Self {
gate,
experts,
num_experts: config.num_local_experts,
num_experts_per_tok: config.num_experts_per_tok,
})
}
pub fn router_logits(&self, hidden_states: &Tensor) -> Result<Tensor> {
let shape = hidden_states.shape().to_vec();
let flat = if shape.len() == 3 {
hidden_states.reshape(&[shape[0] * shape[1], shape[2]])?
} else {
hidden_states.clone()
};
self.gate.forward(flat)
}
pub fn forward_with_router_logits(&self, hidden_states: Tensor) -> Result<(Tensor, Tensor)> {
let in_shape = hidden_states.shape().to_vec();
let (batch, seq_len, hidden_size) = match in_shape.len() {
3 => (in_shape[0], in_shape[1], in_shape[2]),
2 => (1, in_shape[0], in_shape[1]),
_ => {
return Err(TrustformersError::shape_error(
"MixtralSparseMoeBlock expects 2D or 3D input".to_string(),
))
},
};
let num_tokens = batch * seq_len;
let flat = hidden_states.reshape(&[num_tokens, hidden_size])?;
let router_logits = self.gate.forward(flat.clone())?;
let router_probs = router_logits.softmax(-1)?;
let probs_data: Vec<f32> = match &router_probs {
Tensor::F32(arr) => arr.iter().cloned().collect(),
_ => {
return Err(TrustformersError::invalid_input_simple(
"Expected F32 router probs".to_string(),
))
},
};
let flat_data: Vec<f32> = match &flat {
Tensor::F32(arr) => arr.iter().cloned().collect(),
_ => {
return Err(TrustformersError::invalid_input_simple(
"Expected F32 hidden states".to_string(),
))
},
};
let k = self.num_experts_per_tok;
let ne = self.num_experts;
let mut output_data = vec![0.0f32; num_tokens * hidden_size];
for tok in 0..num_tokens {
let probs_offset = tok * ne;
let token_probs = &probs_data[probs_offset..probs_offset + ne];
let mut indexed: Vec<(usize, f32)> = token_probs.iter().cloned().enumerate().collect();
for i in 0..k {
for j in (i + 1)..ne {
if indexed[j].1 > indexed[i].1 {
indexed.swap(i, j);
}
}
}
let top_k = &indexed[..k];
let weight_sum: f32 = top_k.iter().map(|(_, w)| w).sum();
let norm_weights: Vec<f32> = if weight_sum > 1e-9 {
top_k.iter().map(|(_, w)| w / weight_sum).collect()
} else {
vec![1.0_f32 / k as f32; k]
};
let tok_slice = &flat_data[tok * hidden_size..(tok + 1) * hidden_size];
let tok_tensor = Tensor::from_vec(tok_slice.to_vec(), &[1, 1, hidden_size])?;
for (rank, &(expert_idx, _)) in top_k.iter().enumerate() {
let expert_out = self.experts[expert_idx].forward(tok_tensor.clone())?;
let expert_data: Vec<f32> = match &expert_out {
Tensor::F32(arr) => arr.iter().cloned().collect(),
_ => {
return Err(TrustformersError::invalid_input_simple(
"Expected F32 expert output".to_string(),
))
},
};
let w = norm_weights[rank];
for h in 0..hidden_size {
output_data[tok * hidden_size + h] += w * expert_data[h];
}
}
}
let output = Tensor::from_vec(output_data, &[batch, seq_len, hidden_size])?;
Ok((output, router_logits))
}
}
impl Layer for MixtralSparseMoeBlock {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let (output, _) = self.forward_with_router_logits(input)?;
Ok(output)
}
}
pub struct MixtralAttention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
#[allow(dead_code)]
rotary_emb: RotaryEmbedding,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
}
impl MixtralAttention {
pub fn new(config: &MixtralConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &MixtralConfig, device: Device) -> Result<Self> {
let head_dim = config.head_dim();
Ok(Self {
q_proj: Linear::new_with_device(
config.hidden_size,
config.num_attention_heads * head_dim,
false,
device,
),
k_proj: Linear::new_with_device(
config.hidden_size,
config.num_key_value_heads * head_dim,
false,
device,
),
v_proj: Linear::new_with_device(
config.hidden_size,
config.num_key_value_heads * head_dim,
false,
device,
),
o_proj: Linear::new_with_device(
config.num_attention_heads * head_dim,
config.hidden_size,
false,
device,
),
rotary_emb: RotaryEmbedding::new(
head_dim,
config.max_position_embeddings,
config.rope_theta,
),
num_heads: config.num_attention_heads,
num_kv_heads: config.num_key_value_heads,
head_dim,
})
}
fn create_causal_mask(&self, seq_len: usize) -> Result<Tensor> {
let mut mask_data = vec![0.0f32; seq_len * seq_len];
for i in 0..seq_len {
for j in (i + 1)..seq_len {
mask_data[i * seq_len + j] = f32::NEG_INFINITY;
}
}
Tensor::from_vec(mask_data, &[seq_len, seq_len])?.reshape(&[1, 1, seq_len, seq_len])
}
}
impl Layer for MixtralAttention {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let batch_size = input.shape()[0];
let seq_len = input.shape()[1];
let head_dim = self.head_dim;
let q = self.q_proj.forward(input.clone())?;
let k = self.k_proj.forward(input.clone())?;
let v = self.v_proj.forward(input)?;
let q = q.reshape(&[batch_size, seq_len, self.num_heads, head_dim])?.transpose(1, 2)?;
let k = k
.reshape(&[batch_size, seq_len, self.num_kv_heads, head_dim])?
.transpose(1, 2)?;
let v = v
.reshape(&[batch_size, seq_len, self.num_kv_heads, head_dim])?
.transpose(1, 2)?;
let (k, v) = if self.num_kv_heads < self.num_heads {
let repeats = self.num_heads / self.num_kv_heads;
let k_data: Vec<f32> = match &k {
Tensor::F32(arr) => arr.iter().cloned().collect(),
_ => {
return Err(TrustformersError::invalid_input_simple(
"Expected F32 k tensor".to_string(),
))
},
};
let v_data: Vec<f32> = match &v {
Tensor::F32(arr) => arr.iter().cloned().collect(),
_ => {
return Err(TrustformersError::invalid_input_simple(
"Expected F32 v tensor".to_string(),
))
},
};
let mut k_exp = vec![0.0f32; batch_size * self.num_heads * seq_len * head_dim];
let mut v_exp = vec![0.0f32; batch_size * self.num_heads * seq_len * head_dim];
for b in 0..batch_size {
for kv_h in 0..self.num_kv_heads {
for rep in 0..repeats {
let q_h = kv_h * repeats + rep;
for s in 0..seq_len {
for d in 0..head_dim {
let src = b * self.num_kv_heads * seq_len * head_dim
+ kv_h * seq_len * head_dim
+ s * head_dim
+ d;
let dst = b * self.num_heads * seq_len * head_dim
+ q_h * seq_len * head_dim
+ s * head_dim
+ d;
k_exp[dst] = k_data[src];
v_exp[dst] = v_data[src];
}
}
}
}
}
let k_new = Tensor::from_vec(k_exp, &[batch_size, self.num_heads, seq_len, head_dim])?;
let v_new = Tensor::from_vec(v_exp, &[batch_size, self.num_heads, seq_len, head_dim])?;
(k_new, v_new)
} else {
(k, v)
};
let k_t = k.transpose(2, 3)?;
let scores = q.matmul(&k_t)?;
let scale = (head_dim as f32).sqrt();
let scores = scores.div_scalar(scale)?;
let causal_mask = self.create_causal_mask(seq_len)?;
let scores = scores.add(&causal_mask)?;
let attn_weights = scores.softmax(-1)?;
let attn_out = attn_weights.matmul(&v)?;
let attn_out = attn_out.transpose(1, 2)?;
let attn_out = attn_out.reshape(&[batch_size, seq_len, self.num_heads * head_dim])?;
self.o_proj.forward(attn_out)
}
}
pub struct MixtralDecoderLayer {
self_attn: MixtralAttention,
block_sparse_moe: MixtralSparseMoeBlock,
input_layernorm: RMSNorm,
post_attention_layernorm: RMSNorm,
}
impl MixtralDecoderLayer {
pub fn new(config: &MixtralConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &MixtralConfig, device: Device) -> Result<Self> {
Ok(Self {
self_attn: MixtralAttention::new_with_device(config, device)?,
block_sparse_moe: MixtralSparseMoeBlock::new_with_device(config, device)?,
input_layernorm: RMSNorm::new(config.hidden_size, config.rms_norm_eps as f32)?,
post_attention_layernorm: RMSNorm::new(config.hidden_size, config.rms_norm_eps as f32)?,
})
}
}
impl Layer for MixtralDecoderLayer {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let normed = make_contiguous(self.input_layernorm.forward(input.clone())?)?;
let attn_out = self.self_attn.forward(normed)?;
let input_c = make_contiguous(input)?;
let residual1 = input_c.add(&make_contiguous(attn_out)?)?;
let normed2 = make_contiguous(self.post_attention_layernorm.forward(residual1.clone())?)?;
let moe_out = make_contiguous(self.block_sparse_moe.forward(normed2)?)?;
let residual1_c = make_contiguous(residual1)?;
residual1_c.add(&moe_out)
}
}
pub struct MixtralModel {
config: MixtralConfig,
embed_tokens: Embedding,
layers: Vec<MixtralDecoderLayer>,
norm: RMSNorm,
}
impl MixtralModel {
pub fn new(config: MixtralConfig) -> Result<Self> {
config.validate()?;
let embed_tokens = Embedding::new(config.vocab_size, config.hidden_size, None)?;
let mut layers = Vec::new();
for _ in 0..config.num_hidden_layers {
layers.push(MixtralDecoderLayer::new(&config)?);
}
let norm = RMSNorm::new(config.hidden_size, config.rms_norm_eps as f32)?;
Ok(Self {
config,
embed_tokens,
layers,
norm,
})
}
pub fn new_with_device(config: MixtralConfig, device: Device) -> Result<Self> {
config.validate()?;
let embed_tokens = Embedding::new(config.vocab_size, config.hidden_size, None)?;
let mut layers = Vec::new();
for _ in 0..config.num_hidden_layers {
layers.push(MixtralDecoderLayer::new_with_device(&config, device)?);
}
let norm = RMSNorm::new(config.hidden_size, config.rms_norm_eps as f32)?;
Ok(Self {
config,
embed_tokens,
layers,
norm,
})
}
}
impl Model for MixtralModel {
type Config = MixtralConfig;
type Input = Vec<u32>;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let seq_len = input.len();
let embeddings = self.embed_tokens.forward(input)?;
let mut hidden_states = embeddings.reshape(&[1, seq_len, self.config.hidden_size])?;
for layer in &self.layers {
hidden_states = layer.forward(hidden_states)?;
}
make_contiguous(self.norm.forward(hidden_states)?)
}
fn load_pretrained(&mut self, _reader: &mut dyn Read) -> Result<()> {
Err(TrustformersError::not_implemented(
"Use load_from_path for enhanced weight loading".to_string(),
))
}
fn get_config(&self) -> &Self::Config {
&self.config
}
fn num_parameters(&self) -> usize {
let c = &self.config;
let head_dim = c.head_dim();
let embedding_params = c.vocab_size * c.hidden_size;
let per_layer = {
let q = c.hidden_size * (c.num_attention_heads * head_dim);
let k = c.hidden_size * (c.num_key_value_heads * head_dim);
let v = c.hidden_size * (c.num_key_value_heads * head_dim);
let o = (c.num_attention_heads * head_dim) * c.hidden_size;
let attn = q + k + v + o;
let per_expert = c.hidden_size * c.intermediate_size * 3;
let moe = per_expert * c.num_local_experts + c.hidden_size * c.num_local_experts; let ln = c.hidden_size * 2;
attn + moe + ln
};
embedding_params + per_layer * c.num_hidden_layers + c.hidden_size
}
}
pub struct MixtralForCausalLM {
model: MixtralModel,
lm_head: Linear,
}
impl MixtralForCausalLM {
pub fn new(config: MixtralConfig) -> Result<Self> {
let lm_head = Linear::new(config.hidden_size, config.vocab_size, false);
let model = MixtralModel::new(config)?;
Ok(Self { model, lm_head })
}
pub fn new_with_device(config: MixtralConfig, device: Device) -> Result<Self> {
let lm_head = Linear::new_with_device(config.hidden_size, config.vocab_size, false, device);
let model = MixtralModel::new_with_device(config, device)?;
Ok(Self { model, lm_head })
}
pub fn weight_map() -> Vec<(&'static str, &'static str)> {
vec![
("model.embed_tokens.weight", "model.embed_tokens.weight"),
("model.norm.weight", "model.norm.weight"),
("lm_head.weight", "lm_head.weight"),
(
"model.layers.0.self_attn.q_proj.weight",
"model.layers.0.self_attn.q_proj.weight",
),
(
"model.layers.0.block_sparse_moe.gate.weight",
"model.layers.0.block_sparse_moe.gate.weight",
),
(
"model.layers.0.block_sparse_moe.experts.0.w1.weight",
"model.layers.0.block_sparse_moe.experts.0.gate_proj.weight",
),
]
}
}
impl Model for MixtralForCausalLM {
type Config = MixtralConfig;
type Input = Vec<u32>;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let hidden = self.model.forward(input)?;
self.lm_head.forward(hidden)
}
fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
self.model.load_pretrained(reader)
}
fn get_config(&self) -> &Self::Config {
self.model.get_config()
}
fn num_parameters(&self) -> usize {
let c = self.model.get_config();
self.model.num_parameters() + c.hidden_size * c.vocab_size
}
}
pub fn compute_load_balancing_loss(
router_logits: &Tensor,
num_experts: usize,
num_experts_per_tok: usize,
aux_loss_coef: f32,
) -> Result<f32> {
let shape = router_logits.shape().to_vec();
if shape.len() != 2 || shape[1] != num_experts {
return Err(TrustformersError::shape_error(
"router_logits must be [num_tokens, num_experts]".to_string(),
));
}
let num_tokens = shape[0];
let probs_data: Vec<f32> = match router_logits {
Tensor::F32(arr) => {
let raw: Vec<f32> = arr.iter().cloned().collect();
let mut softmaxed = vec![0.0f32; num_tokens * num_experts];
for tok in 0..num_tokens {
let offset = tok * num_experts;
let row = &raw[offset..offset + num_experts];
let max_val = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exps: Vec<f32> = row.iter().map(|&x| (x - max_val).exp()).collect();
let sum: f32 = exps.iter().sum();
for (e, &ex) in exps.iter().enumerate() {
softmaxed[offset + e] = ex / sum;
}
}
softmaxed
},
_ => {
return Err(TrustformersError::invalid_input_simple(
"Expected F32 router logits".to_string(),
))
},
};
let mut expert_token_fraction = vec![0.0f32; num_experts];
for tok in 0..num_tokens {
let offset = tok * num_experts;
let row = &probs_data[offset..offset + num_experts];
let mut indexed: Vec<(usize, f32)> = row.iter().cloned().enumerate().collect();
for i in 0..num_experts_per_tok {
for j in (i + 1)..num_experts {
if indexed[j].1 > indexed[i].1 {
indexed.swap(i, j);
}
}
}
for (ei, _) in &indexed[..num_experts_per_tok] {
expert_token_fraction[*ei] += 1.0;
}
}
let token_scale = 1.0 / (num_tokens as f32);
for f in expert_token_fraction.iter_mut() {
*f *= token_scale;
}
let mut expert_mean_prob = vec![0.0f32; num_experts];
for tok in 0..num_tokens {
for e in 0..num_experts {
expert_mean_prob[e] += probs_data[tok * num_experts + e];
}
}
for p in expert_mean_prob.iter_mut() {
*p /= num_tokens as f32;
}
let dot: f32 = expert_token_fraction
.iter()
.zip(expert_mean_prob.iter())
.map(|(&f, &p)| f * p)
.sum();
Ok(aux_loss_coef * num_experts as f32 * dot)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::mixtral::config::MixtralConfig;
use trustformers_core::traits::{Config, Model};
fn lcg_next(state: &mut u64) -> f32 {
*state = state.wrapping_mul(6364136223846793005u64).wrapping_add(1442695040888963407u64);
(*state as f32) / (u64::MAX as f32)
}
fn lcg_vec(n: usize, seed: u64) -> Vec<f32> {
let mut s = seed;
(0..n).map(|_| lcg_next(&mut s) * 2.0 - 1.0).collect()
}
fn small_cfg() -> MixtralConfig {
MixtralConfig {
hidden_size: 16,
intermediate_size: 32,
num_hidden_layers: 1,
num_attention_heads: 4,
num_key_value_heads: 2,
num_local_experts: 4,
num_experts_per_tok: 2,
sliding_window: None,
vocab_size: 64,
max_position_embeddings: 32,
rope_theta: 10000.0,
rms_norm_eps: 1e-5,
hidden_act: "silu".to_string(),
router_aux_loss_coef: 0.02,
model_type: "mixtral".to_string(),
}
}
#[test]
fn test_config_validate_passes_for_small() {
small_cfg().validate().expect("small config should pass validation");
}
#[test]
fn test_config_head_dim() {
let cfg = small_cfg();
assert_eq!(
cfg.head_dim(),
cfg.hidden_size / cfg.num_attention_heads,
"head_dim = hidden_size / num_attention_heads"
);
}
#[test]
fn test_config_num_query_groups() {
let cfg = small_cfg();
assert_eq!(
cfg.num_query_groups(),
cfg.num_attention_heads / cfg.num_key_value_heads,
"num_query_groups = num_attention_heads / num_key_value_heads"
);
}
#[test]
fn test_config_validate_fails_hidden_not_divisible_by_heads() {
let mut cfg = small_cfg();
cfg.num_attention_heads = 3; assert!(
cfg.validate().is_err(),
"should fail when hidden not divisible by heads"
);
}
#[test]
fn test_config_validate_fails_num_experts_per_tok_exceeds_local() {
let mut cfg = small_cfg();
cfg.num_experts_per_tok = 8; assert!(
cfg.validate().is_err(),
"should fail when num_experts_per_tok > num_local_experts"
);
}
#[test]
fn test_config_architecture_string() {
let cfg = small_cfg();
assert_eq!(cfg.architecture(), "Mixtral");
}
#[test]
fn test_sparse_mlp_forward_output_shape() {
let cfg = small_cfg();
let mlp = MixtralBlockSparseTop2MLP::new(&cfg).expect("should create sparse MLP");
let data = lcg_vec(cfg.hidden_size, 7);
let input = Tensor::from_vec(data, &[1, 1, cfg.hidden_size]).expect("build tensor");
let output = mlp.forward(input).expect("sparse MLP forward should succeed");
let shape = output.shape();
assert_eq!(
shape[shape.len() - 1],
cfg.hidden_size,
"sparse MLP output should have hidden_size as last dim"
);
}
#[test]
fn test_sparse_moe_router_logits_shape() {
let cfg = small_cfg();
let moe = MixtralSparseMoeBlock::new(&cfg).expect("should create SparseMoeBlock");
let data = lcg_vec(2 * 3 * cfg.hidden_size, 17);
let input = Tensor::from_vec(data, &[2, 3, cfg.hidden_size]).expect("build tensor");
let logits = moe.router_logits(&input).expect("router_logits should succeed");
let shape = logits.shape();
assert_eq!(shape[0], 6, "router logits first dim = batch * seq = 6");
assert_eq!(
shape[1], cfg.num_local_experts,
"router logits second dim = num_local_experts"
);
}
#[test]
fn test_sparse_moe_forward_output_shape() {
let cfg = small_cfg();
let moe = MixtralSparseMoeBlock::new(&cfg).expect("should create SparseMoeBlock");
let data = lcg_vec(2 * 3 * cfg.hidden_size, 23);
let input = Tensor::from_vec(data, &[2, 3, cfg.hidden_size]).expect("build tensor");
let output = moe.forward(input).expect("SparseMoeBlock forward should succeed");
let shape = output.shape();
assert_eq!(shape[0], 2, "batch preserved");
assert_eq!(shape[1], 3, "seq_len preserved");
assert_eq!(shape[2], cfg.hidden_size, "hidden_size preserved");
}
#[test]
fn test_sparse_moe_forward_with_router_logits_returns_correct_shapes() {
let cfg = small_cfg();
let moe = MixtralSparseMoeBlock::new(&cfg).expect("should create SparseMoeBlock");
let data = lcg_vec(cfg.hidden_size, 29);
let input = Tensor::from_vec(data, &[1, 1, cfg.hidden_size]).expect("build tensor");
let (output, router_logits) = moe
.forward_with_router_logits(input)
.expect("forward_with_router_logits should succeed");
let out_shape = output.shape();
let logit_shape = router_logits.shape();
assert_eq!(
out_shape[2], cfg.hidden_size,
"output hidden_size preserved"
);
assert_eq!(
logit_shape[1], cfg.num_local_experts,
"router logits dim = num_local_experts"
);
}
#[test]
fn test_attention_forward_output_shape() {
let cfg = small_cfg();
let attn = MixtralAttention::new(&cfg).expect("should create MixtralAttention");
let seq_len = 3usize;
let data = lcg_vec(seq_len * cfg.hidden_size, 37);
let input = Tensor::from_vec(data, &[1, seq_len, cfg.hidden_size]).expect("build tensor");
match attn.forward(input) {
Ok(output) => {
let shape = output.shape();
assert_eq!(shape[0], 1, "batch preserved");
assert_eq!(shape[1], seq_len, "seq_len preserved");
assert_eq!(shape[2], cfg.hidden_size, "hidden_size preserved");
},
Err(_) => { },
}
}
#[test]
fn test_decoder_layer_forward_output_shape() {
let cfg = small_cfg();
let layer = MixtralDecoderLayer::new(&cfg).expect("should create MixtralDecoderLayer");
let data = lcg_vec(2 * cfg.hidden_size, 43);
let input = Tensor::from_vec(data, &[1, 2, cfg.hidden_size]).expect("build tensor");
match layer.forward(input) {
Ok(output) => {
let shape = output.shape();
assert_eq!(
shape[2], cfg.hidden_size,
"decoder layer output hidden_size preserved"
);
},
Err(_) => { },
}
}
#[test]
fn test_mixtral_model_forward_output_shape() {
let cfg = small_cfg();
let model = MixtralModel::new(cfg.clone()).expect("should create MixtralModel");
let input_ids: Vec<u32> = vec![0, 1, 2];
match model.forward(input_ids) {
Ok(output) => {
let shape = output.shape();
assert_eq!(shape[1], 3, "seq_len matches input");
assert_eq!(shape[2], cfg.hidden_size, "hidden_size matches config");
},
Err(_) => { },
}
}
#[test]
fn test_mixtral_model_num_parameters_positive() {
let cfg = small_cfg();
let model = MixtralModel::new(cfg).expect("should create MixtralModel");
assert!(
model.num_parameters() > 0,
"MixtralModel should have positive param count"
);
}
#[test]
fn test_mixtral_model_get_config() {
let cfg = small_cfg();
let hs = cfg.hidden_size;
let model = MixtralModel::new(cfg).expect("should create MixtralModel");
assert_eq!(
model.get_config().hidden_size,
hs,
"get_config returns correct hidden_size"
);
}
#[test]
fn test_causal_lm_forward_logits_shape() {
let cfg = small_cfg();
let lm = MixtralForCausalLM::new(cfg.clone()).expect("should create MixtralForCausalLM");
let input_ids: Vec<u32> = vec![0, 1, 2];
match lm.forward(input_ids) {
Ok(logits) => {
let shape = logits.shape();
assert_eq!(
shape[shape.len() - 1],
cfg.vocab_size,
"logits last dim should equal vocab_size"
);
},
Err(_) => { },
}
}
#[test]
fn test_causal_lm_weight_map_non_empty() {
let wmap = MixtralForCausalLM::weight_map();
assert!(!wmap.is_empty(), "weight_map should have entries");
}
#[test]
fn test_causal_lm_num_parameters_exceeds_base() {
let cfg = small_cfg();
let base = MixtralModel::new(cfg.clone()).expect("should create MixtralModel");
let lm = MixtralForCausalLM::new(cfg.clone()).expect("should create MixtralForCausalLM");
assert!(
lm.num_parameters() > base.num_parameters(),
"CausalLM should have more params than base (lm_head added)"
);
}
#[test]
fn test_load_balancing_loss_uniform_logits_is_zero() {
let num_tokens = 4;
let num_experts = 4;
let logits_data = vec![0.0f32; num_tokens * num_experts];
let logits =
Tensor::from_vec(logits_data, &[num_tokens, num_experts]).expect("build logits tensor");
let loss = compute_load_balancing_loss(&logits, num_experts, 2, 0.02)
.expect("load balancing loss should succeed");
assert!(loss >= 0.0, "load balancing loss should be non-negative");
}
#[test]
fn test_load_balancing_loss_shape_error() {
let logits = Tensor::from_vec(vec![0.0f32; 8], &[4, 2]).expect("build logits tensor");
let result = compute_load_balancing_loss(&logits, 4, 2, 0.02);
assert!(result.is_err(), "should error when expert dim mismatch");
}
#[test]
fn test_load_balancing_loss_coef_scales_result() {
let num_tokens = 4;
let num_experts = 4;
let mut seed = 77u64;
let logits_data: Vec<f32> = (0..num_tokens * num_experts)
.map(|_| {
seed =
seed.wrapping_mul(6364136223846793005u64).wrapping_add(1442695040888963407u64);
(seed as f32) / (u64::MAX as f32) * 2.0 - 1.0
})
.collect();
let logits =
Tensor::from_vec(logits_data, &[num_tokens, num_experts]).expect("build logits tensor");
let loss1 = compute_load_balancing_loss(&logits, num_experts, 2, 0.01)
.expect("loss 1 should succeed");
let loss2 = compute_load_balancing_loss(&logits, num_experts, 2, 0.02)
.expect("loss 2 should succeed");
assert!(
(loss2 - 2.0 * loss1).abs() < 1e-5,
"loss should scale linearly with aux_loss_coef"
);
}
}