use crate::backend::Backend;
use crate::backend::cpu::simd;
use crate::tensor::{DType, Tensor};
use rayon::prelude::*;
/// Hyperparameters for a Mixture-of-Experts layer.
#[derive(Debug, Clone)]
pub struct MoeConfig {
    /// Total number of routed experts.
    pub num_experts: usize,
    /// How many experts the router selects per token (top-k).
    pub num_experts_per_token: usize,
    /// Intermediate (FFN) width of each expert.
    pub expert_hidden_dim: usize,
    /// Experts that run for every token regardless of routing (0 = none).
    pub num_shared_experts: usize,
    /// Coefficient for an auxiliary load-balancing loss.
    /// NOTE(review): not read anywhere in this file — confirm the training
    /// code consumes it.
    pub aux_loss_coef: f32,
    /// When true, router logits are max-shifted before top-k selection.
    pub normalize_router_logits: bool,
}
impl Default for MoeConfig {
    /// Defaults mirror a Mixtral-8x7B-style layout: 8 routed experts, top-2
    /// routing, a 14336-wide expert FFN, and no shared experts.
    fn default() -> Self {
        Self {
            num_experts_per_token: 2,
            num_experts: 8,
            expert_hidden_dim: 14336,
            normalize_router_logits: true,
            aux_loss_coef: 0.01,
            num_shared_experts: 0,
        }
    }
}
impl MoeConfig {
    /// Preset matching the Mixtral 8x7B layout (identical to `Default`).
    pub fn mixtral() -> Self {
        Self::default()
    }

    /// DeepSeek-MoE-style preset: caller-chosen routed/shared expert counts
    /// and an 11008-wide expert FFN; remaining fields use the defaults
    /// (top-2 routing, aux coefficient 0.01, normalized logits).
    pub fn deepseek(num_experts: usize, num_shared: usize) -> Self {
        Self {
            num_experts,
            num_shared_experts: num_shared,
            expert_hidden_dim: 11008,
            ..Self::default()
        }
    }
}
/// Per-token routing decision produced by `MoeRouter::route`.
#[derive(Debug, Clone)]
pub struct ExpertSelection {
    /// For each token, the indices of the selected experts (length top-k).
    pub indices: Vec<Vec<usize>>,
    /// For each token, softmax weights aligned with `indices`
    /// (they sum to 1 per token).
    pub weights: Vec<Vec<f32>>,
}
/// Linear top-k router: scores every expert with a dot product against the
/// token and keeps the `top_k` highest-scoring ones.
#[derive(Debug)]
pub struct MoeRouter {
    /// Router weight; `new` allocates it as `[num_experts, hidden_dim]`.
    pub weight: Tensor,
    /// Number of experts this router scores.
    num_experts: usize,
    /// Experts selected per token.
    top_k: usize,
    /// When true, logits are max-shifted before ranking (a constant shift,
    /// so the selection itself is unchanged).
    normalize: bool,
}
impl MoeRouter {
    /// Creates a router with a zero-initialized weight of shape
    /// `[num_experts, hidden_dim]`.
    pub fn new(hidden_dim: usize, num_experts: usize, top_k: usize, normalize: bool) -> Self {
        let weight = Tensor::zeros(vec![num_experts, hidden_dim], DType::F32);
        Self {
            weight,
            num_experts,
            top_k,
            normalize,
        }
    }

    /// Builds a router around an existing weight tensor.
    ///
    /// The weight must use the layout `new` produces — `[num_experts,
    /// hidden_dim]`, row-major — so the expert count is the FIRST dimension.
    /// (The previous code read `shape[1]`, contradicting both `new` and the
    /// row-major indexing in `route`.)
    pub fn from_weight(weight: Tensor, top_k: usize, normalize: bool) -> Self {
        let num_experts = weight.shape().first().copied().unwrap_or(0);
        Self {
            weight,
            num_experts,
            top_k,
            normalize,
        }
    }

    /// Scores every expert for each token and returns the top-k experts with
    /// softmax-normalized weights.
    ///
    /// `hidden_states` may be `[hidden_dim]` (single token) or
    /// `[batch, hidden_dim]`.
    ///
    /// # Errors
    /// Propagates backend errors from reading either tensor as `f32`.
    pub fn route(
        &self,
        hidden_states: &Tensor,
    ) -> Result<ExpertSelection, crate::backend::BackendError> {
        let h_data = hidden_states.as_f32()?;
        let w_data = self.weight.as_f32()?;
        // The weight layout is [num_experts, hidden_dim] (see `new`), so the
        // token width is the LAST dimension. The previous code read
        // shape()[0] — the expert count — and therefore scored only a short
        // prefix of each token.
        let hidden_dim = self.weight.shape().last().copied().unwrap_or(0);
        let h_shape = hidden_states.shape();
        let (batch_size, h_stride) = if h_shape.len() == 1 {
            (1, 0)
        } else {
            (h_shape[0], hidden_dim)
        };
        // Guard: top_k > num_experts would otherwise panic on the slice below.
        let k = self.top_k.min(self.num_experts);
        let mut all_indices = Vec::with_capacity(batch_size);
        let mut all_weights = Vec::with_capacity(batch_size);
        for b in 0..batch_size {
            let h_offset = b * h_stride;
            let h_slice = &h_data[h_offset..h_offset + hidden_dim];
            // One dot product per expert against its router row.
            let mut logits = vec![0.0f32; self.num_experts];
            for (e, logit) in logits.iter_mut().enumerate() {
                let row = &w_data[e * hidden_dim..(e + 1) * hidden_dim];
                *logit = simd::dot_f32(h_slice, row);
            }
            if self.normalize {
                // Max-shift for numerical safety; a constant offset changes
                // neither the ranking nor the softmax below.
                let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                for l in &mut logits {
                    *l -= max_logit;
                }
            }
            // Rank experts by logit, descending; NaNs compare as equal.
            let mut indexed: Vec<(usize, f32)> =
                logits.iter().copied().enumerate().collect();
            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            let top = &indexed[..k];
            let top_indices: Vec<usize> = top.iter().map(|&(i, _)| i).collect();
            // Softmax over the selected logits only (standard top-k renorm).
            let max_val = top
                .iter()
                .map(|&(_, l)| l)
                .fold(f32::NEG_INFINITY, f32::max);
            let exp_sum: f32 = top.iter().map(|&(_, l)| (l - max_val).exp()).sum();
            let weights: Vec<f32> = top
                .iter()
                .map(|&(_, l)| (l - max_val).exp() / exp_sum)
                .collect();
            all_indices.push(top_indices);
            all_weights.push(weights);
        }
        Ok(ExpertSelection {
            indices: all_indices,
            weights: all_weights,
        })
    }
}
/// One expert: a SiLU-gated feed-forward block (gate/up/down projections).
#[derive(Debug)]
pub struct MoeExpert {
    /// Gate projection, shape `[hidden_dim, intermediate_dim]`.
    pub gate_proj: Tensor,
    /// Up projection, shape `[hidden_dim, intermediate_dim]`.
    pub up_proj: Tensor,
    /// Down projection, shape `[intermediate_dim, hidden_dim]`.
    pub down_proj: Tensor,
}
impl MoeExpert {
    /// Allocates a zero-initialized expert: gate and up project
    /// `hidden -> intermediate`, down projects back to `hidden`.
    pub fn new(hidden_dim: usize, intermediate_dim: usize) -> Self {
        Self {
            gate_proj: Tensor::zeros(vec![hidden_dim, intermediate_dim], DType::F32),
            up_proj: Tensor::zeros(vec![hidden_dim, intermediate_dim], DType::F32),
            down_proj: Tensor::zeros(vec![intermediate_dim, hidden_dim], DType::F32),
        }
    }

    /// Vector-matrix product of `x` with `w` into a fresh `[out_dim]` F32
    /// tensor, dispatching to the quantized kernel when the weight dtype
    /// requires it.
    fn project(
        x: &Tensor,
        w: &Tensor,
        out_dim: usize,
        backend: &dyn Backend,
    ) -> Result<Tensor, crate::backend::BackendError> {
        let mut out = Tensor::zeros(vec![out_dim], DType::F32);
        if w.dtype().is_quantized() {
            backend.vec_mat_q(x, w, &mut out)?;
        } else {
            backend.vec_mat(x, w, &mut out)?;
        }
        Ok(out)
    }

    /// SiLU-gated feed-forward: `down( silu(gate(x)) * up(x) )`.
    ///
    /// # Errors
    /// Propagates any backend error from the projections, `silu`, or `mul`.
    pub fn forward(
        &self,
        x: &Tensor,
        backend: &dyn Backend,
    ) -> Result<Tensor, crate::backend::BackendError> {
        let intermediate_dim = self.gate_proj.shape()[1];
        let hidden_dim = self.down_proj.shape()[1];
        let gate_out = Self::project(x, &self.gate_proj, intermediate_dim, backend)?;
        let mut gate_silu = Tensor::zeros(vec![intermediate_dim], DType::F32);
        backend.silu(&gate_out, &mut gate_silu)?;
        let up_out = Self::project(x, &self.up_proj, intermediate_dim, backend)?;
        let mut gated = Tensor::zeros(vec![intermediate_dim], DType::F32);
        backend.mul(&gate_silu, &up_out, &mut gated)?;
        Self::project(&gated, &self.down_proj, hidden_dim, backend)
    }
}
/// A full Mixture-of-Experts layer: router + routed experts + optional
/// shared experts that run for every token.
#[derive(Debug)]
pub struct MoeLayer {
    config: MoeConfig,
    /// Selects the top-k experts for each token.
    pub router: MoeRouter,
    /// Routed experts; only the selected ones run for a given token.
    pub experts: Vec<MoeExpert>,
    /// Experts applied to every token (may be empty).
    pub shared_experts: Vec<MoeExpert>,
    /// Optional gate vector; when present, the shared-expert contribution is
    /// scaled by `sigmoid(gate · token)` in `forward`.
    pub shared_expert_gate: Option<Tensor>,
}
impl MoeLayer {
    /// Builds a layer with zero-initialized router, routed experts, and
    /// shared experts, all sized from `config`.
    pub fn new(hidden_dim: usize, config: MoeConfig) -> Self {
        let router = MoeRouter::new(
            hidden_dim,
            config.num_experts,
            config.num_experts_per_token,
            config.normalize_router_logits,
        );
        let experts = (0..config.num_experts)
            .map(|_| MoeExpert::new(hidden_dim, config.expert_hidden_dim))
            .collect();
        let shared_experts = (0..config.num_shared_experts)
            .map(|_| MoeExpert::new(hidden_dim, config.expert_hidden_dim))
            .collect();
        Self {
            config,
            router,
            experts,
            shared_experts,
            // No gate by default; shared experts then contribute at scale 1.0.
            shared_expert_gate: None,
        }
    }

    /// MoE forward pass.
    ///
    /// Routes each token, runs the selected experts in parallel (rayon),
    /// accumulates their outputs weighted by the router's softmax weights,
    /// then adds the (optionally gated) shared-expert outputs. Accepts
    /// `[hidden_dim]` or `[batch, hidden_dim]` input and returns a tensor of
    /// the same rank.
    ///
    /// # Errors
    /// Propagates backend errors from routing and tensor construction.
    /// NOTE(review): expert failures inside the rayon pool panic via
    /// `expect()` rather than returning `Err` — confirm this is intended.
    pub fn forward(
        &self,
        hidden_states: &Tensor,
        backend: &dyn Backend,
    ) -> Result<Tensor, crate::backend::BackendError> {
        let h_shape = hidden_states.shape();
        let hidden_dim = *h_shape.last().unwrap_or(&0);
        let selection = self.router.route(hidden_states)?;
        let h_data = hidden_states.as_f32()?;
        let batch_size = if h_shape.len() == 1 { 1 } else { h_shape[0] };
        // Flat [batch, hidden] accumulator, zero-initialized.
        let mut output_data = vec![0.0f32; batch_size * hidden_dim];
        for (b, (indices, weights)) in selection
            .indices
            .iter()
            .zip(selection.weights.iter())
            .enumerate()
        {
            let h_offset = b * hidden_dim;
            // Materialize the current token as a 1-D tensor for the experts.
            let token_input = if h_shape.len() == 1 {
                hidden_states.clone()
            } else {
                Tensor::from_f32(&h_data[h_offset..h_offset + hidden_dim], vec![hidden_dim])?
            };
            // Run the selected experts in parallel, keeping each output
            // paired with its routing weight.
            let expert_results: Vec<(Vec<f32>, f32)> = indices
                .par_iter()
                .zip(weights.par_iter())
                .map(|(&expert_idx, &weight)| {
                    let out = self.experts[expert_idx]
                        .forward(&token_input, backend)
                        .expect("expert forward failed");
                    (out.as_f32().unwrap().to_vec(), weight)
                })
                .collect();
            // Weighted sum of expert outputs into this token's output row.
            for (expert_data, weight) in &expert_results {
                let out_slice = &mut output_data[b * hidden_dim..(b + 1) * hidden_dim];
                for (o, &e) in out_slice.iter_mut().zip(expert_data.iter()) {
                    *o += weight * e;
                }
            }
            if !self.shared_experts.is_empty() {
                // Shared experts always run; their contribution is scaled by
                // a sigmoid gate when a gate vector is configured.
                let gate_scale = if let Some(ref gate_w) = self.shared_expert_gate {
                    let gw = gate_w.as_f32()?;
                    let h_slice = if h_shape.len() == 1 {
                        h_data
                    } else {
                        &h_data[h_offset..h_offset + hidden_dim]
                    };
                    // Dot over the shorter of token/gate lengths to avoid an
                    // out-of-bounds slice if the gate is shorter than a token.
                    let len = hidden_dim.min(gw.len());
                    let dot = simd::dot_f32(&h_slice[..len], &gw[..len]);
                    // Logistic sigmoid of the dot product.
                    1.0 / (1.0 + (-dot).exp())
                } else {
                    1.0
                };
                let shared_results: Vec<Vec<f32>> = self.shared_experts
                    .par_iter()
                    .map(|expert| {
                        let out = expert.forward(&token_input, backend)
                            .expect("shared expert forward failed");
                        out.as_f32().unwrap().to_vec()
                    })
                    .collect();
                for shared_data in &shared_results {
                    let out_slice = &mut output_data[b * hidden_dim..(b + 1) * hidden_dim];
                    for (o, &s) in out_slice.iter_mut().zip(shared_data.iter()) {
                        *o += gate_scale * s;
                    }
                }
            }
        }
        // Preserve the input's rank on the way out.
        if h_shape.len() == 1 {
            Ok(Tensor::from_f32(&output_data, vec![hidden_dim])?)
        } else {
            Ok(Tensor::from_f32(
                &output_data,
                vec![batch_size, hidden_dim],
            )?)
        }
    }

    /// Total number of routed experts in this layer.
    pub fn num_experts(&self) -> usize {
        self.config.num_experts
    }

    /// Number of experts the router selects per token.
    pub fn num_experts_per_token(&self) -> usize {
        self.config.num_experts_per_token
    }
}
/// Running routing statistics for monitoring expert load balance.
#[derive(Debug, Clone, Default)]
pub struct MoeStats {
    /// Number of tokens recorded so far.
    pub total_tokens: usize,
    /// How many times each expert has been selected.
    pub expert_counts: Vec<usize>,
    /// Accumulated routing weight per expert.
    pub expert_weights: Vec<f32>,
}
impl MoeStats {
    /// Fresh (all-zero) counters for `num_experts` experts.
    pub fn new(num_experts: usize) -> Self {
        Self {
            total_tokens: 0,
            expert_counts: vec![0; num_experts],
            expert_weights: vec![0.0; num_experts],
        }
    }

    /// Folds one routing decision into the running counters: one token per
    /// (indices, weights) pair, one count/weight bump per selected expert.
    pub fn record(&mut self, selection: &ExpertSelection) {
        let tokens = selection.indices.iter().zip(selection.weights.iter());
        for (token_experts, token_weights) in tokens {
            self.total_tokens += 1;
            for (&expert, &w) in token_experts.iter().zip(token_weights.iter()) {
                self.expert_counts[expert] += 1;
                self.expert_weights[expert] += w;
            }
        }
    }

    /// Load-balance score in (0, 1]; 1.0 means perfectly uniform expert use.
    /// Computed as `1 / (1 + variance / ideal^2)` over the selection counts.
    pub fn load_balance_factor(&self) -> f32 {
        if self.total_tokens == 0 {
            return 1.0;
        }
        let n = self.expert_counts.len() as f32;
        let ideal = self.total_tokens as f32 / n;
        let sum_sq_dev: f32 = self
            .expert_counts
            .iter()
            .map(|&count| {
                let dev = count as f32 - ideal;
                dev * dev
            })
            .sum();
        let variance = sum_sq_dev / n;
        1.0 / (1.0 + variance / (ideal * ideal))
    }

    /// Clears all counters back to the just-constructed state.
    pub fn reset(&mut self) {
        self.total_tokens = 0;
        self.expert_counts.fill(0);
        self.expert_weights.fill(0.0);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::cpu::CpuBackend;

    // Default config should expose the 8-expert / top-2 layout.
    #[test]
    fn test_moe_config_default() {
        let config = MoeConfig::default();
        assert_eq!(config.num_experts, 8);
        assert_eq!(config.num_experts_per_token, 2);
    }

    // Mixtral preset matches the same expert counts.
    #[test]
    fn test_moe_config_mixtral() {
        let config = MoeConfig::mixtral();
        assert_eq!(config.num_experts, 8);
        assert_eq!(config.num_experts_per_token, 2);
    }

    // Routing one token through 4 experts with top-2 selection: exactly two
    // experts chosen, with softmax weights summing to ~1.
    #[test]
    fn test_moe_router() {
        let router = MoeRouter::new(64, 4, 2, true);
        let hidden = Tensor::from_f32(&vec![0.1f32; 64], vec![64]).unwrap();
        let selection = router.route(&hidden).unwrap();
        assert_eq!(selection.indices.len(), 1);
        assert_eq!(selection.indices[0].len(), 2);
        assert_eq!(selection.weights[0].len(), 2);
        let weight_sum: f32 = selection.weights[0].iter().sum();
        assert!((weight_sum - 1.0).abs() < 0.01);
    }

    // A single expert maps a [64] input back to a [64] output.
    #[test]
    fn test_moe_expert() {
        let backend = CpuBackend::new();
        let expert = MoeExpert::new(64, 256);
        let input = Tensor::from_f32(&vec![0.1f32; 64], vec![64]).unwrap();
        let output = expert.forward(&input, &backend).unwrap();
        assert_eq!(output.shape(), &[64]);
    }

    // Full layer forward preserves the single-token input shape.
    #[test]
    fn test_moe_layer() {
        let backend = CpuBackend::new();
        let config = MoeConfig {
            num_experts: 4,
            num_experts_per_token: 2,
            expert_hidden_dim: 128,
            num_shared_experts: 0,
            aux_loss_coef: 0.01,
            normalize_router_logits: true,
        };
        let layer = MoeLayer::new(64, config);
        let input = Tensor::from_f32(&vec![0.1f32; 64], vec![64]).unwrap();
        let output = layer.forward(&input, &backend).unwrap();
        assert_eq!(output.shape(), &[64]);
    }

    // Two recorded tokens: expert 1 is selected twice, experts 0 and 2 once,
    // expert 3 never.
    #[test]
    fn test_moe_stats() {
        let mut stats = MoeStats::new(4);
        let selection = ExpertSelection {
            indices: vec![vec![0, 1], vec![1, 2]],
            weights: vec![vec![0.6, 0.4], vec![0.7, 0.3]],
        };
        stats.record(&selection);
        assert_eq!(stats.total_tokens, 2);
        assert_eq!(stats.expert_counts[0], 1);
        assert_eq!(stats.expert_counts[1], 2);
        assert_eq!(stats.expert_counts[2], 1);
        assert_eq!(stats.expert_counts[3], 0);
    }
}