use super::types::MoeExpertWeights;
/// Result of routing one token: the selected expert indices and their
/// renormalized routing weights (parallel vectors, `weights.len() == indices.len()`).
#[derive(Debug)]
pub struct MoeRouteResult {
    pub indices: Vec<usize>,
    pub weights: Vec<f32>,
}

/// Softmax-routes a single token over `num_experts` experts and returns the
/// top `num_experts_per_tok` experts with renormalized weights.
///
/// `gate_weight` is a row-major `[num_experts, hidden_dim]` matrix. The
/// softmax subtracts the max logit first for numerical stability. The
/// selected weights are renormalized to sum to 1; if their sum is zero or
/// non-finite (degenerate input), routing falls back to a uniform split.
pub fn moe_route(
    hidden_state: &[f32],
    gate_weight: &[f32],
    num_experts: usize,
    num_experts_per_tok: usize,
    hidden_dim: usize,
) -> MoeRouteResult {
    // logits[e] = <gate_weight row e, hidden_state>
    let logits: Vec<f32> = (0..num_experts)
        .map(|e| {
            let row = &gate_weight[e * hidden_dim..(e + 1) * hidden_dim];
            row.iter().zip(hidden_state).map(|(w, x)| w * x).sum()
        })
        .collect();

    // Numerically stable softmax: shift by the max logit before exp().
    let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut probs: Vec<f32> = logits.iter().map(|&l| (l - max_logit).exp()).collect();
    let exp_sum: f32 = probs.iter().sum();
    for p in &mut probs {
        *p /= exp_sum;
    }

    // Rank experts by probability, descending, with the index as tie-break
    // (matching a stable sort's tie order). `total_cmp` is a total order even
    // with NaN; `partial_cmp(..).unwrap_or(Equal)` is not, and an inconsistent
    // comparator can make `sort_by` panic in current std.
    let mut ranked: Vec<(usize, f32)> = probs.into_iter().enumerate().collect();
    ranked.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
    ranked.truncate(num_experts_per_tok);

    let weight_sum: f32 = ranked.iter().map(|(_, w)| w).sum();
    let indices: Vec<usize> = ranked.iter().map(|(i, _)| *i).collect();
    // Renormalize the selected weights; NaN/zero sums take the uniform branch.
    let weights: Vec<f32> = if weight_sum > 0.0 {
        ranked.iter().map(|(_, w)| w / weight_sum).collect()
    } else {
        vec![1.0 / num_experts_per_tok as f32; num_experts_per_tok]
    };
    MoeRouteResult { indices, weights }
}
/// One expert's SwiGLU feed-forward: `down( silu(gate · x) ⊙ (up · x) )`.
///
/// Matrices are row-major: `gate_proj` and `up_proj` are
/// `[intermediate, hidden_dim]`, `down_proj` is `[hidden_dim, intermediate]`.
fn expert_swiglu(
    x: &[f32],
    gate_proj: &[f32],
    up_proj: &[f32],
    down_proj: &[f32],
    hidden_dim: usize,
    intermediate: usize,
) -> Vec<f32> {
    // Dot product of one matrix row with a vector.
    let dot = |row: &[f32], v: &[f32]| -> f32 { row.iter().zip(v).map(|(a, b)| a * b).sum() };

    // hidden[i] = silu(gate_i · x) * (up_i · x)
    let hidden: Vec<f32> = gate_proj
        .chunks_exact(hidden_dim)
        .zip(up_proj.chunks_exact(hidden_dim))
        .take(intermediate)
        .map(|(g_row, u_row)| {
            let g = dot(g_row, x);
            let u = dot(u_row, x);
            // silu(g) = g * sigmoid(g)
            (g / (1.0 + (-g).exp())) * u
        })
        .collect();

    // output[i] = down_i · hidden
    down_proj
        .chunks_exact(intermediate)
        .take(hidden_dim)
        .map(|d_row| dot(d_row, &hidden))
        .collect()
}
/// Runs one token through the MoE layer: a routed weighted sum over the
/// selected experts, plus the always-active shared expert.
///
/// Packing: `expert_gate_up` stores, per expert, the gate projection followed
/// by the up projection (each `intermediate * hidden_dim` values);
/// `expert_down` stores one `hidden_dim * intermediate` block per expert.
pub fn moe_forward_token(
    hidden_state: &[f32],
    moe: &MoeExpertWeights,
    hidden_dim: usize,
) -> Vec<f32> {
    let intermediate = moe.expert_intermediate;
    let route = moe_route(
        hidden_state,
        &moe.gate_weight,
        moe.num_experts,
        moe.num_experts_per_tok,
        hidden_dim,
    );

    // Accumulate the routed experts, each scaled by its routing weight.
    let gate_up_stride = 2 * intermediate * hidden_dim;
    let down_stride = hidden_dim * intermediate;
    let mut output = vec![0.0f32; hidden_dim];
    for (&expert_id, &weight) in route.indices.iter().zip(route.weights.iter()) {
        let gu = &moe.expert_gate_up[expert_id * gate_up_stride..(expert_id + 1) * gate_up_stride];
        // First half of the packed block is the gate projection, second the up.
        let (gate_proj, up_proj) = gu.split_at(intermediate * hidden_dim);
        let down_proj = &moe.expert_down[expert_id * down_stride..(expert_id + 1) * down_stride];
        let expert_out = expert_swiglu(
            hidden_state,
            gate_proj,
            up_proj,
            down_proj,
            hidden_dim,
            intermediate,
        );
        for (acc, v) in output.iter_mut().zip(expert_out) {
            *acc += weight * v;
        }
    }

    // The shared expert always contributes; with a shared-expert gate present
    // its contribution is scaled by sigmoid(gate · hidden_state), otherwise
    // it is added unscaled.
    let shared_out = expert_swiglu(
        hidden_state,
        &moe.shared_gate,
        &moe.shared_up,
        &moe.shared_down,
        hidden_dim,
        intermediate,
    );
    let shared_scale = if moe.shared_expert_gate_weight.is_empty() {
        1.0
    } else {
        let logit: f32 = moe
            .shared_expert_gate_weight
            .iter()
            .zip(hidden_state)
            .take(hidden_dim)
            .map(|(w, x)| w * x)
            .sum();
        1.0 / (1.0 + (-logit).exp())
    };
    for (acc, v) in output.iter_mut().zip(shared_out) {
        *acc += shared_scale * v;
    }
    output
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Logits of 1000.0 and 999.0 overflow a naive softmax; the
    /// max-subtraction trick must keep every weight finite.
    #[test]
    fn test_softmax_stability_large_logits() {
        let hidden_dim = 4;
        let num_experts = 8;
        let hidden_state = vec![1.0f32; hidden_dim];
        let mut gate_weight = vec![0.0f32; num_experts * hidden_dim];
        // Expert 0 logit = 4 * 250.0, expert 1 logit = 4 * 249.75.
        for j in 0..hidden_dim {
            gate_weight[j] = 250.0;
            gate_weight[hidden_dim + j] = 249.75;
        }
        let result = moe_route(&hidden_state, &gate_weight, num_experts, 2, hidden_dim);
        for &w in &result.weights {
            assert!(w.is_finite(), "weight is not finite: {w}");
        }
        assert_eq!(result.indices[0], 0);
        assert_eq!(result.indices[1], 1);
        let sum: f32 = result.weights.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6, "weights sum = {sum}");
    }

    /// An all-zero gate makes every logit identical, so the renormalized
    /// top-k weights must be exactly uniform (1/k).
    #[test]
    fn test_zero_gate_uniform_routing() {
        let hidden_dim = 4;
        let num_experts = 8;
        let k = 4;
        let hidden_state = vec![1.0f32; hidden_dim];
        let gate_weight = vec![0.0f32; num_experts * hidden_dim];
        let result = moe_route(&hidden_state, &gate_weight, num_experts, k, hidden_dim);
        assert_eq!(result.indices.len(), k);
        for &w in &result.weights {
            assert!((w - 0.25).abs() < 1e-6, "expected 0.25, got {w}");
        }
    }

    /// With a zero router gate and k == num_experts, routing is uniform, so
    /// the layer output must equal the average of the experts' SwiGLU
    /// outputs — checked against a hand-computed value, not just finiteness.
    #[test]
    fn test_uniform_routing_averages_experts() {
        let hidden_dim = 4;
        let intermediate = 2;
        let num_experts = 4;
        let k = 4;
        let mut expert_gate_up = vec![0.1f32; num_experts * 2 * intermediate * hidden_dim];
        let expert_down = vec![0.1f32; num_experts * hidden_dim * intermediate];
        // Make expert 0 stand out so the average is not trivially uniform.
        for w in expert_gate_up.iter_mut().take(2 * intermediate * hidden_dim) {
            *w = 0.2;
        }
        let moe = MoeExpertWeights {
            gate_weight: vec![0.0f32; num_experts * hidden_dim],
            expert_gate_up,
            expert_down,
            // All-zero shared expert contributes nothing (silu(0) == 0).
            shared_gate: vec![0.0f32; intermediate * hidden_dim],
            shared_up: vec![0.0f32; intermediate * hidden_dim],
            shared_down: vec![0.0f32; hidden_dim * intermediate],
            shared_expert_gate_weight: vec![],
            num_experts,
            num_experts_per_tok: k,
            expert_intermediate: intermediate,
        };
        let x = vec![1.0f32; hidden_dim];
        let output = moe_forward_token(&x, &moe, hidden_dim);
        // Each expert's gate/up pre-activation is 4*w (x is all ones), giving
        // silu(4w) * 4w per intermediate unit; the 0.1 down projection sums
        // two such units, i.e. 0.2 * silu(4w) * 4w per output element.
        let silu = |v: f32| v / (1.0 + (-v).exp());
        let h0 = silu(0.8) * 0.8; // expert 0 (w = 0.2)
        let h1 = silu(0.4) * 0.4; // experts 1..=3 (w = 0.1)
        let expected = 0.25 * (0.2 * h0 + 3.0 * 0.2 * h1);
        for &v in &output {
            assert!((v - expected).abs() < 1e-5, "expected {expected}, got {v}");
        }
    }

    /// The shared expert contributes even when every routed expert is zero.
    #[test]
    fn test_shared_expert_always_active() {
        let hidden_dim = 4;
        let intermediate = 2;
        let num_experts = 4;
        let k = 2;
        let moe = MoeExpertWeights {
            gate_weight: vec![0.0f32; num_experts * hidden_dim],
            expert_gate_up: vec![0.0f32; num_experts * 2 * intermediate * hidden_dim],
            expert_down: vec![0.0f32; num_experts * hidden_dim * intermediate],
            shared_gate: vec![0.1f32; intermediate * hidden_dim],
            shared_up: vec![0.1f32; intermediate * hidden_dim],
            shared_down: vec![0.1f32; hidden_dim * intermediate],
            shared_expert_gate_weight: vec![],
            num_experts,
            num_experts_per_tok: k,
            expert_intermediate: intermediate,
        };
        let x = vec![1.0f32; hidden_dim];
        let output = moe_forward_token(&x, &moe, hidden_dim);
        let norm: f32 = output.iter().map(|v| v * v).sum();
        assert!(norm > 0.0, "shared expert output should be non-zero");
    }

    /// Renormalizing the top-k weights must not change their relative order.
    #[test]
    fn test_renorm_preserves_order() {
        let hidden_dim = 4;
        let num_experts = 8;
        let k = 4;
        let mut gate_weight = vec![0.0f32; num_experts * hidden_dim];
        // Strictly decreasing logits: expert e scores (num_experts - e) * 4.
        for e in 0..num_experts {
            for j in 0..hidden_dim {
                gate_weight[e * hidden_dim + j] = (num_experts - e) as f32;
            }
        }
        let hidden_state = vec![1.0f32; hidden_dim];
        let result = moe_route(&hidden_state, &gate_weight, num_experts, k, hidden_dim);
        for i in 1..k {
            assert!(
                result.weights[i - 1] >= result.weights[i],
                "weights not ordered: {} < {}",
                result.weights[i - 1],
                result.weights[i]
            );
        }
    }

    /// A strongly negative shared-expert gate drives sigmoid(logit) toward
    /// zero, suppressing the shared expert's contribution.
    #[test]
    fn test_shared_expert_gate_scales_output() {
        let hidden_dim = 4;
        let intermediate = 2;
        let num_experts = 4;
        let k = 2;
        let base_moe = MoeExpertWeights {
            gate_weight: vec![0.0f32; num_experts * hidden_dim],
            expert_gate_up: vec![0.0f32; num_experts * 2 * intermediate * hidden_dim],
            expert_down: vec![0.0f32; num_experts * hidden_dim * intermediate],
            shared_gate: vec![0.1f32; intermediate * hidden_dim],
            shared_up: vec![0.1f32; intermediate * hidden_dim],
            shared_down: vec![0.1f32; hidden_dim * intermediate],
            shared_expert_gate_weight: vec![],
            num_experts,
            num_experts_per_tok: k,
            expert_intermediate: intermediate,
        };
        let x = vec![1.0f32; hidden_dim];
        let ungated = moe_forward_token(&x, &base_moe, hidden_dim);
        let mut gated_moe = base_moe.clone();
        // gate logit = -40 => sigmoid ~ 0.
        gated_moe.shared_expert_gate_weight = vec![-10.0f32; hidden_dim];
        let gated = moe_forward_token(&x, &gated_moe, hidden_dim);
        let ungated_norm: f32 = ungated.iter().map(|v| v * v).sum::<f32>().sqrt();
        let gated_norm: f32 = gated.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!(
            gated_norm < ungated_norm * 0.1,
            "shared expert gate should suppress output: ungated={ungated_norm}, gated={gated_norm}"
        );
    }
}