/// Flat output buffers filled by the router.
///
/// Both vectors are sized to `batch * top_k` by [`RouterOutput::reset`]; the
/// routing code in this file writes the selections for row `b` at indices
/// `b * top_k .. (b + 1) * top_k`.
#[derive(Debug, Clone, PartialEq)]
pub struct RouterOutput {
    /// Selected expert index for every `(row, slot)` pair.
    pub expert_ids: Vec<u32>,
    /// Weight attached to the expert id stored at the same position.
    pub expert_weights: Vec<f32>,
}

impl RouterOutput {
    /// Creates an output with two empty, unallocated vectors.
    pub fn empty() -> Self {
        let expert_ids = Vec::new();
        let expert_weights = Vec::new();
        Self {
            expert_ids,
            expert_weights,
        }
    }

    /// Resizes both buffers to `batch * top_k` entries, all zeroed.
    ///
    /// The vectors are cleared before being resized, so no previous contents
    /// survive, while any existing capacity is kept for reuse.
    pub fn reset(&mut self, batch: usize, top_k: usize) {
        let slots = batch * top_k;
        let Self {
            expert_ids,
            expert_weights,
        } = self;
        expert_ids.clear();
        expert_weights.clear();
        expert_ids.resize(slots, 0);
        expert_weights.resize(slots, 0.0);
    }

    /// Returns `expert_ids.len()` divided by [`Self::batch_top_k_pair_count`].
    ///
    /// As written this evaluates to `0` for an empty output and `1` for any
    /// non-empty one.
    ///
    /// NOTE(review): despite the name, this does not recover the `batch`
    /// passed to `reset` -- `top_k` is not stored, so the divisor is the full
    /// slot count. Confirm what callers expect before relying on it.
    pub fn batch(&self) -> usize {
        let divisor = self.batch_top_k_pair_count();
        self.expert_ids.len() / divisor
    }

    /// Number of stored id slots, clamped to at least 1 so it is always a
    /// safe divisor.
    fn batch_top_k_pair_count(&self) -> usize {
        match self.expert_ids.len() {
            0 => 1,
            slots => slots,
        }
    }
}
/// Routes a batch of logits and returns a freshly allocated [`RouterOutput`].
///
/// Convenience wrapper around [`route_into`]: it starts from an empty output
/// and an empty scratch vector, forwards every argument unchanged, and hands
/// back the filled output. The scratch vector is dropped on return, so use
/// [`route_into`] directly when the buffers should be reused across calls.
///
/// # Panics
///
/// Panics under the same conditions as [`route_into`].
pub fn route(
    logits: &[f32],
    batch: usize,
    num_experts: usize,
    top_k: usize,
    norm_topk_prob: bool,
) -> RouterOutput {
    let mut result = RouterOutput::empty();
    let mut probs: Vec<f32> = Vec::new();
    route_into(
        logits,
        batch,
        num_experts,
        top_k,
        norm_topk_prob,
        &mut result,
        &mut probs,
    );
    result
}
/// Softmax + top-k expert routing into caller-provided buffers.
///
/// `logits` is read as `batch` consecutive rows of `num_experts` values. For
/// each row the function:
///
/// 1. computes a max-shifted softmax into `scratch_probs`,
/// 2. picks the `top_k` largest probabilities, in descending order, with ties
///    going to the lowest expert index,
/// 3. stores the chosen ids and probabilities in `out` at
///    `row * top_k .. (row + 1) * top_k`,
/// 4. if `norm_topk_prob` is set, rescales the stored weights of that row so
///    they sum to 1.
///
/// `out` is reset to `batch * top_k` entries and `scratch_probs` to
/// `num_experts` entries on every call; previous contents are discarded but
/// their allocations are reused. `scratch_probs` holds no meaningful data on
/// return.
///
/// # Degenerate rows
///
/// If a row's softmax denominator is not a finite positive number (the row
/// contains NaN or `+inf`, or consists only of `-inf`), the row is treated as
/// a uniform distribution over all experts. Without this, every probability
/// would be NaN, no candidate would ever compare greater than the running
/// best, and all `top_k` slots would name expert 0 with weight `-inf`.
///
/// # Panics
///
/// Panics if `logits.len() != batch * num_experts`, if `top_k == 0`, or if
/// `top_k > num_experts`.
pub fn route_into(
    logits: &[f32],
    batch: usize,
    num_experts: usize,
    top_k: usize,
    norm_topk_prob: bool,
    out: &mut RouterOutput,
    scratch_probs: &mut Vec<f32>,
) {
    assert_eq!(
        logits.len(),
        batch * num_experts,
        "router logits shape mismatch: expected {batch}×{num_experts}, got {}",
        logits.len()
    );
    assert!(top_k > 0, "top_k must be > 0");
    assert!(
        top_k <= num_experts,
        "top_k {top_k} exceeds num_experts {num_experts}"
    );

    out.reset(batch, top_k);
    scratch_probs.clear();
    scratch_probs.resize(num_experts, 0.0);

    // The asserts above guarantee `1 <= top_k <= num_experts`, so neither
    // chunk size is zero, and all three iterators yield exactly `batch` rows.
    let logit_rows = logits.chunks_exact(num_experts);
    let id_rows = out.expert_ids.chunks_exact_mut(top_k);
    let weight_rows = out.expert_weights.chunks_exact_mut(top_k);

    for ((row, ids), weights) in logit_rows.zip(id_rows).zip(weight_rows) {
        // Subtract the row maximum before exponentiating so `exp` cannot
        // overflow. `f32::max` skips NaN operands, like a plain `>` scan.
        let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);

        let mut sum = 0.0f32;
        for (p, &v) in scratch_probs.iter_mut().zip(row) {
            let e = (v - max).exp();
            *p = e;
            sum += e;
        }

        if sum.is_finite() && sum > 0.0 {
            let inv_sum = 1.0 / sum;
            for p in scratch_probs.iter_mut() {
                *p *= inv_sum;
            }
        } else {
            // Non-finite softmax: fall back to a uniform distribution so the
            // selection below still yields distinct ids and finite weights.
            scratch_probs.fill(1.0 / num_experts as f32);
        }

        // Repeated arg-max. Each winner is overwritten with -inf so it can
        // never be selected twice; every remaining probability is finite and
        // therefore strictly greater than the -inf starting value.
        let mut sel_sum = 0.0f32;
        for (id, weight) in ids.iter_mut().zip(weights.iter_mut()) {
            let mut best = f32::NEG_INFINITY;
            let mut best_idx = 0usize;
            for (i, &p) in scratch_probs.iter().enumerate() {
                if p > best {
                    best = p;
                    best_idx = i;
                }
            }
            // Expert ids are stored as u32; an index above u32::MAX would be
            // truncated here (only possible with more than 2^32 experts).
            *id = best_idx as u32;
            *weight = best;
            sel_sum += best;
            scratch_probs[best_idx] = f32::NEG_INFINITY;
        }

        if norm_topk_prob {
            if sel_sum > 0.0 {
                let scale = 1.0 / sel_sum;
                for w in weights.iter_mut() {
                    *w *= scale;
                }
            } else {
                // Defensive: nothing to rescale by, so share weight equally.
                weights.fill(1.0 / top_k as f32);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    /// Checks structural properties every routing result must satisfy: the
    /// buffers hold `batch * top_k` entries; within a row the ids are in
    /// range and distinct and the weights are non-increasing; normalized rows
    /// sum to 1.
    fn assert_routing_invariants(
        out: &RouterOutput,
        batch: usize,
        num_experts: usize,
        top_k: usize,
        norm: bool,
    ) {
        assert_eq!(out.expert_ids.len(), batch * top_k, "expert_ids length");
        assert_eq!(
            out.expert_weights.len(),
            batch * top_k,
            "expert_weights length"
        );
        let id_rows = out.expert_ids.chunks_exact(top_k);
        let weight_rows = out.expert_weights.chunks_exact(top_k);
        for (row, (ids, weights)) in id_rows.zip(weight_rows).enumerate() {
            for (k, &id) in ids.iter().enumerate() {
                assert!(
                    (id as usize) < num_experts,
                    "row {row}: expert id {id} out of range"
                );
                assert!(
                    !ids[..k].contains(&id),
                    "row {row}: expert id {id} selected twice"
                );
            }
            assert!(
                weights.windows(2).all(|w| w[0] >= w[1]),
                "row {row}: weights are not in descending order"
            );
            if norm {
                let total: f32 = weights.iter().sum();
                assert!(
                    (total - 1.0).abs() < 1e-5,
                    "row {row}: normalized weights sum to {total}"
                );
            }
        }
    }

    /// Routes seeded random logits through `route` and through `route_into`
    /// with deliberately dirty, wrongly sized buffers. The results must
    /// match, so this fails if stale buffer contents leak into the output,
    /// and they must also satisfy the routing invariants.
    fn run_parity(batch: usize, num_experts: usize, top_k: usize, norm: bool, seed: u64) {
        let mut rng = StdRng::seed_from_u64(seed);
        let logits: Vec<f32> = (0..batch * num_experts)
            .map(|_| rng.gen_range(-3.0..3.0_f32))
            .collect();

        let a = route(&logits, batch, num_experts, top_k, norm);

        // `route` delegates to `route_into`, so identical fresh buffers would
        // make this comparison trivial; start from garbage instead.
        let mut b = RouterOutput {
            expert_ids: vec![u32::MAX; 3],
            expert_weights: vec![f32::NAN; 5],
        };
        let mut probs = vec![f32::NAN; 7];
        route_into(&logits, batch, num_experts, top_k, norm, &mut b, &mut probs);

        assert_eq!(a.expert_ids, b.expert_ids, "expert_ids mismatch");
        // `zip` stops at the shorter side, so check the lengths explicitly.
        assert_eq!(
            a.expert_weights.len(),
            b.expert_weights.len(),
            "expert_weights length mismatch"
        );
        for (i, (&aw, &bw)) in a.expert_weights.iter().zip(&b.expert_weights).enumerate() {
            let delta = (aw - bw).abs();
            assert!(
                delta < 1e-6,
                "weight[{i}] mismatch: route={aw} route_into={bw} delta={delta}"
            );
        }

        assert_routing_invariants(&a, batch, num_experts, top_k, norm);
    }

    #[test]
    fn parity_qwen3_moe_shape() {
        run_parity(32, 128, 8, true, 0xDEADBEEF);
        run_parity(1, 128, 8, true, 0x1234);
        run_parity(64, 128, 8, true, 0x5678);
    }

    #[test]
    fn parity_no_renorm() {
        run_parity(8, 64, 4, false, 0xC0FFEE);
    }

    #[test]
    fn parity_topk_one() {
        run_parity(4, 16, 1, true, 0x42);
        run_parity(4, 16, 1, false, 0x42);
    }

    /// Pins the numeric result on a tiny hand-computed input: the softmax of
    /// `[0, 1, 2, 3]` gives top-2 probabilities 0.6439143 and 0.2368828,
    /// which renormalize to 0.7310586 and 0.2689414.
    #[test]
    fn known_answer_two_rows() {
        let logits = [0.0f32, 1.0, 2.0, 3.0, 3.0, 2.0, 1.0, 0.0];

        let normed = route(&logits, 2, 4, 2, true);
        assert_eq!(normed.expert_ids, vec![3, 2, 0, 1]);
        let want = [0.731_058_6f32, 0.268_941_4, 0.731_058_6, 0.268_941_4];
        for (i, (&got, &exp)) in normed.expert_weights.iter().zip(&want).enumerate() {
            assert!((got - exp).abs() < 1e-5, "normed[{i}]: got {got}, want {exp}");
        }

        let raw = route(&logits, 2, 4, 2, false);
        assert_eq!(raw.expert_ids, vec![3, 2, 0, 1]);
        let want = [0.643_914_3f32, 0.236_882_8, 0.643_914_3, 0.236_882_8];
        for (i, (&got, &exp)) in raw.expert_weights.iter().zip(&want).enumerate() {
            assert!((got - exp).abs() < 1e-5, "raw[{i}]: got {got}, want {exp}");
        }
    }

    /// After one warm-up call, repeated calls with the same shape must not
    /// grow any of the reused buffers.
    #[test]
    fn allocation_free_after_warmup() {
        let mut out = RouterOutput::empty();
        let mut probs = Vec::new();
        let logits = vec![0.5f32; 32 * 128];
        route_into(&logits, 32, 128, 8, true, &mut out, &mut probs);
        let cap_ids = out.expert_ids.capacity();
        let cap_w = out.expert_weights.capacity();
        let cap_p = probs.capacity();
        for _ in 0..16 {
            route_into(&logits, 32, 128, 8, true, &mut out, &mut probs);
            assert_eq!(out.expert_ids.capacity(), cap_ids);
            assert_eq!(out.expert_weights.capacity(), cap_w);
            assert_eq!(probs.capacity(), cap_p);
        }
    }
}