#[derive(Debug, thiserror::Error)]
pub enum MoeRouterError {
#[error(
"k {k} must satisfy 1 ≤ k ≤ scores.len() ({n})"
)]
BadK { k: usize, n: usize },
#[error("scores empty")]
EmptyScores,
#[error(
"indices length {got} != k {k}"
)]
IndicesLen { got: usize, k: usize },
#[error(
"weights length {got} != k {k}"
)]
WeightsLen { got: usize, k: usize },
}
/// Replaces `x` in place with its softmax, `exp(x_i - max) / sum_j exp(x_j - max)`.
///
/// The maximum is subtracted before exponentiating so that large inputs do not
/// overflow `exp`.
///
/// # Errors
///
/// Returns [`MoeRouterError::EmptyScores`] if `x` is empty.
pub fn softmax(x: &mut [f32]) -> Result<(), MoeRouterError> {
    let (&first, rest) = match x.split_first() {
        Some(parts) => parts,
        None => return Err(MoeRouterError::EmptyScores),
    };
    // Strict `>` keeps the earliest element on ties (and a leading NaN stays the peak).
    let peak = rest
        .iter()
        .fold(first, |best, &candidate| if candidate > best { candidate } else { best });

    // Exponentiate the shifted values and accumulate the normalizer in slice order.
    let mut total = 0.0f32;
    x.iter_mut().for_each(|slot| {
        *slot = (*slot - peak).exp();
        total += *slot;
    });

    let scale = 1.0f32 / total;
    x.iter_mut().for_each(|slot| *slot *= scale);
    Ok(())
}
/// Selects the `k` largest entries of `scores`.
///
/// On success, `indices[j]` is the position in `scores` of the `j`-th selected entry and
/// `values[j]` is that score. The selected pairs are **not** sorted; slot order is an
/// artifact of the replacement scan. On ties the earlier index is kept. NaN scores rank
/// below every real score, so a NaN is only selected when fewer than `k` non-NaN scores
/// exist.
///
/// Runs in O(n·k) time, which suits the small `k` typical of expert routing.
///
/// # Errors
///
/// * [`MoeRouterError::BadK`] if `k == 0` or `k > scores.len()`.
/// * [`MoeRouterError::IndicesLen`] if `indices.len() != k`.
/// * [`MoeRouterError::WeightsLen`] if `values.len() != k`.
pub fn topk(
    scores: &[f32],
    k: usize,
    indices: &mut [i32],
    values: &mut [f32],
) -> Result<(), MoeRouterError> {
    let n = scores.len();
    if k == 0 || k > n {
        return Err(MoeRouterError::BadK { k, n });
    }
    if indices.len() != k {
        return Err(MoeRouterError::IndicesLen {
            got: indices.len(),
            k,
        });
    }
    if values.len() != k {
        return Err(MoeRouterError::WeightsLen {
            got: values.len(),
            k,
        });
    }

    // `a` should displace `b`: strict `>` on real numbers, and any real number beats NaN.
    let beats = |a: f32, b: f32| a > b || (b.is_nan() && !a.is_nan());

    // Seed the slots with the first `k` scores rather than a sentinel value. Every slot
    // then always holds a genuine (index, score) pair, even for -inf or very negative
    // scores, and no index is ever reported twice.
    values.copy_from_slice(&scores[..k]);
    for (slot, idx) in indices.iter_mut().enumerate() {
        // slot < k <= n; an expert count beyond i32::MAX is not a realistic input.
        *idx = slot as i32;
    }

    for (i, &score) in scores.iter().enumerate().skip(k) {
        // Weakest slot: the first one wins among equally weak slots.
        let weakest = (1..k).fold(0, |w, s| if beats(values[w], values[s]) { s } else { w });
        if beats(score, values[weakest]) {
            values[weakest] = score;
            indices[weakest] = i as i32;
        }
    }
    Ok(())
}
/// Rescales `weights` in place so that they sum to one.
///
/// The slice is left untouched when its sum is not strictly positive. This covers an
/// empty slice, an all-zero slice and a NaN sum, where no meaningful scale exists.
pub fn normalize_weights(weights: &mut [f32]) {
    let total = weights.iter().fold(0.0f32, |acc, &w| acc + w);
    if total > 0.0 {
        let scale = 1.0f32 / total;
        weights.iter_mut().for_each(|w| *w *= scale);
    }
}
pub fn moe_router_cpu(
scores: &mut [f32],
k: usize,
indices: &mut [i32],
weights: &mut [f32],
) -> Result<(), MoeRouterError> {
softmax(scores)?;
topk(scores, k, indices, weights)?;
normalize_weights(weights);
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn softmax_normalizes_to_one() {
        let mut x = [1.0f32, 2.0, 3.0, 4.0];
        softmax(&mut x).unwrap();
        let sum: f32 = x.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6, "softmax sum = {sum}");
    }

    #[test]
    fn softmax_rejects_empty() {
        let mut x: [f32; 0] = [];
        assert!(matches!(softmax(&mut x), Err(MoeRouterError::EmptyScores)));
    }

    #[test]
    fn softmax_is_shift_invariant() {
        // Subtracting the max means huge logits must not overflow to inf/NaN.
        let mut small = [1.0f32, 2.0, 3.0];
        let mut large = [1001.0f32, 1002.0, 1003.0];
        softmax(&mut small).unwrap();
        softmax(&mut large).unwrap();
        for (a, b) in small.iter().zip(&large) {
            assert!((a - b).abs() < 1e-6, "{a} vs {b}");
        }
    }

    #[test]
    fn topk_picks_largest() {
        let scores = [0.1f32, 0.5, 0.3, 0.9, 0.2];
        let mut idx = [0i32; 3];
        let mut val = [0.0f32; 3];
        topk(&scores, 3, &mut idx, &mut val).unwrap();
        let mut pairs: Vec<(i32, f32)> =
            idx.iter().copied().zip(val.iter().copied()).collect();
        pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        assert_eq!(pairs[0].0, 3);
        assert_eq!(pairs[1].0, 1);
        assert_eq!(pairs[2].0, 2);
    }

    #[test]
    fn topk_rejects_bad_arguments() {
        let scores = [0.1f32, 0.2, 0.3];
        let mut idx = [0i32; 2];
        let mut val = [0.0f32; 2];
        assert!(matches!(
            topk(&scores, 0, &mut idx, &mut val),
            Err(MoeRouterError::BadK { k: 0, n: 3 })
        ));
        assert!(matches!(
            topk(&scores, 4, &mut idx, &mut val),
            Err(MoeRouterError::BadK { k: 4, n: 3 })
        ));
        assert!(matches!(
            topk(&scores, 2, &mut [0i32; 1], &mut val),
            Err(MoeRouterError::IndicesLen { got: 1, k: 2 })
        ));
        assert!(matches!(
            topk(&scores, 2, &mut idx, &mut [0.0f32; 3]),
            Err(MoeRouterError::WeightsLen { got: 3, k: 2 })
        ));
    }

    #[test]
    fn normalize_sums_to_one() {
        let mut w = [0.5f32, 1.5, 2.0];
        normalize_weights(&mut w);
        let sum: f32 = w.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6, "normalize sum = {sum}");
    }

    #[test]
    fn normalize_leaves_nonpositive_sum_untouched() {
        let mut w = [0.0f32, 0.0];
        normalize_weights(&mut w);
        assert_eq!(w, [0.0, 0.0]);
    }

    #[test]
    fn router_selects_and_renormalizes() {
        let mut scores = [1.0f32, 2.0, 3.0, 4.0];
        let mut idx = [0i32; 2];
        let mut w = [0.0f32; 2];
        moe_router_cpu(&mut scores, 2, &mut idx, &mut w).unwrap();
        let mut pairs: Vec<(i32, f32)> = idx.iter().copied().zip(w.iter().copied()).collect();
        pairs.sort_by_key(|p| p.0);
        assert_eq!(pairs[0].0, 2);
        assert_eq!(pairs[1].0, 3);
        // Renormalized top-2 of softmax(x) equals softmax over the kept logits [3, 4].
        let expected_hi = 1.0 / (1.0 + (-1.0f32).exp());
        assert!((pairs[1].1 - expected_hi).abs() < 1e-6, "{} vs {expected_hi}", pairs[1].1);
        assert!((pairs[0].1 + pairs[1].1 - 1.0).abs() < 1e-6);
    }
}