/// Flattened routing decisions produced by [`route`].
///
/// Both vectors are row-major with `top_k` entries per batch row: entry `b * top_k + j`
/// is the `j`-th choice for row `b`, and `expert_ids[n]` pairs with `expert_weights[n]`.
/// Note that `top_k` itself is not stored here.
#[derive(Debug, Clone, PartialEq)]
pub struct RouterOutput {
    /// Selected expert indices, `top_k` per row.
    pub expert_ids: Vec<u32>,
    /// Combine weight for the expert at the same position in `expert_ids`.
    pub expert_weights: Vec<f32>,
}

impl RouterOutput {
    /// Legacy row count; prefer [`RouterOutput::batch_for_top_k`].
    ///
    /// Because `top_k` is not recorded in the output, the number of rows cannot be
    /// recovered from it alone. This method returns `1` for any non-empty output and `0`
    /// for an empty one. That is only correct when the batch had at most one row. The
    /// result is kept unchanged so existing callers see the same values.
    pub fn batch(&self) -> usize {
        usize::from(!self.expert_ids.is_empty())
    }

    /// Number of batch rows, given the `top_k` that produced this output.
    ///
    /// # Panics
    ///
    /// Panics if `top_k == 0`, or if the number of routed pairs is not a multiple of
    /// `top_k` (i.e. the output was not produced with this `top_k`).
    pub fn batch_for_top_k(&self, top_k: usize) -> usize {
        assert!(top_k > 0, "top_k must be > 0");
        debug_assert_eq!(
            self.expert_ids.len(),
            self.expert_weights.len(),
            "expert_ids and expert_weights must have the same length"
        );
        let pairs = self.expert_ids.len();
        assert!(
            pairs % top_k == 0,
            "{pairs} routed pairs is not a multiple of top_k {top_k}"
        );
        pairs / top_k
    }
}
/// Computes top-k expert routing for a batch of router logits.
///
/// `logits` is read as `batch` consecutive rows of `num_experts` values each. For every
/// row, the logits are softmax-normalised and the `top_k` most probable experts are
/// selected, most probable first. Ties go to the lower expert index. Each choice gets a
/// combine weight:
/// - the raw softmax probability, or
/// - when `norm_topk_prob` is set, that probability renormalised over the selected experts.
///
/// The result is flattened row-major: entries `b * top_k .. (b + 1) * top_k` of both
/// output vectors belong to row `b`.
///
/// # Panics
///
/// - if `batch * num_experts` overflows `usize`;
/// - if `logits.len() != batch * num_experts`;
/// - if `top_k == 0` or `top_k > num_experts`;
/// - if a selected expert index does not fit in `u32`.
pub fn route(
    logits: &[f32],
    batch: usize,
    num_experts: usize,
    top_k: usize,
    norm_topk_prob: bool,
) -> RouterOutput {
    // Checked so a huge shape fails loudly instead of wrapping in release builds.
    let expected_len = batch
        .checked_mul(num_experts)
        .unwrap_or_else(|| panic!("router logits shape {batch}×{num_experts} overflows usize"));
    assert_eq!(
        logits.len(),
        expected_len,
        "router logits shape mismatch: expected {batch}×{num_experts}, got {}",
        logits.len()
    );
    assert!(top_k > 0, "top_k must be > 0");
    assert!(
        top_k <= num_experts,
        "top_k {top_k} exceeds num_experts {num_experts}"
    );

    // top_k <= num_experts, so this product is bounded by the checked one above.
    let mut expert_ids = Vec::with_capacity(batch * top_k);
    let mut expert_weights = Vec::with_capacity(batch * top_k);

    // num_experts >= top_k >= 1 here, so chunks_exact cannot panic, and the exact length
    // check above guarantees it yields exactly `batch` full rows.
    for row in logits.chunks_exact(num_experts) {
        let probs = softmax(row);
        let topk = top_k_indices(&probs, top_k);
        let combine_weights = if norm_topk_prob {
            renormalise(&topk, &probs)
        } else {
            topk.iter().map(|&i| probs[i]).collect::<Vec<_>>()
        };
        expert_ids.extend(
            topk.iter()
                .map(|&i| u32::try_from(i).expect("expert index exceeds u32::MAX")),
        );
        expert_weights.extend(combine_weights);
    }

    RouterOutput {
        expert_ids,
        expert_weights,
    }
}
/// Numerically stable softmax of one logit row.
///
/// The row maximum is subtracted before exponentiating so large logits do not overflow;
/// each shifted exponential is then divided by their total.
fn softmax(row: &[f32]) -> Vec<f32> {
    let peak = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let shifted: Vec<f32> = row.iter().map(|&logit| (logit - peak).exp()).collect();
    let total: f32 = shifted.iter().sum();
    shifted.into_iter().map(|e| e / total).collect()
}
/// Indices of the `top_k` largest entries of `probs`, largest first.
///
/// Equal values are ordered by ascending index. NaN entries rank as `-inf`, i.e. below
/// every number. If `top_k >= probs.len()`, every index is returned in ranked order.
fn top_k_indices(probs: &[f32], top_k: usize) -> Vec<usize> {
    // Treating NaN as "equal to everything" (the previous approach) makes the comparator
    // non-transitive. For example, with [0.5, NaN, 0.9] it yields 0 < 1 < 2 < 0, which
    // can make the sort panic or return an unspecified order. Mapping NaN to -inf keeps
    // a genuine total order.
    let rank = |i: usize| {
        let p = probs[i];
        if p.is_nan() {
            f32::NEG_INFINITY
        } else {
            p
        }
    };
    let mut order: Vec<usize> = (0..probs.len()).collect();
    // The index tie-break makes every key distinct, so an unstable sort is deterministic.
    order.sort_unstable_by(|&a, &b| {
        rank(b)
            .partial_cmp(&rank(a))
            .expect("NaN is mapped to -inf, so ranks are always comparable")
            .then(a.cmp(&b))
    });
    order.truncate(top_k);
    order
}
/// Rescales the probabilities of the `selected` experts so they sum to one.
///
/// When the selected mass is not strictly positive (zero, or NaN), every selected expert
/// instead receives the uniform weight `1 / selected.len()`.
fn renormalise(selected: &[usize], probs: &[f32]) -> Vec<f32> {
    let picked: Vec<f32> = selected.iter().map(|&idx| probs[idx]).collect();
    let mass: f32 = picked.iter().sum();
    let count = selected.len();
    if mass > 0.0 {
        picked.into_iter().map(|p| p / mass).collect()
    } else {
        vec![(count as f32).recip(); count]
    }
}