use crate::error::{Error, Result};
use numr::dtype::DType;
use numr::ops::{ActivationOps, ReduceOps, ScalarOps, SortingOps, TypeConversionOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
pub fn moe_top_k_routing_impl<R, C>(
client: &C,
logits: &Tensor<R>,
k: usize,
) -> Result<(Tensor<R>, Tensor<R>)>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R>
+ ActivationOps<R>
+ SortingOps<R>
+ ReduceOps<R>
+ ScalarOps<R>
+ TypeConversionOps<R>,
{
let shape = logits.shape();
if shape.len() != 2 {
return Err(Error::InvalidArgument {
arg: "logits",
reason: format!(
"expected 2D [num_tokens, num_experts], got {}D",
shape.len()
),
});
}
let num_experts = shape[1];
if k == 0 || k > num_experts {
return Err(Error::InvalidArgument {
arg: "k",
reason: format!("k={} must be in [1, num_experts={}]", k, num_experts),
});
}
let probs = client.softmax(logits, -1).map_err(Error::Numr)?;
let (top_values, top_indices_i64) = client
.topk(&probs, k, -1, true, true)
.map_err(Error::Numr)?;
let top_indices = client
.cast(&top_indices_i64, DType::I32)
.map_err(Error::Numr)?;
let weight_sum = client.sum(&top_values, &[1], true).map_err(Error::Numr)?;
let normalized = client.div(&top_values, &weight_sum).map_err(Error::Numr)?;
Ok((top_indices, normalized))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::cpu_setup;
    use numr::runtime::cpu::CpuRuntime;

    /// Checks that both outputs have shape `[num_tokens, k]` and that each
    /// token's routing weights sum to 1 within a small tolerance.
    #[test]
    fn test_top_k_routing_shapes() {
        let (client, device) = cpu_setup();
        let (num_tokens, num_experts, k): (usize, usize, usize) = (6, 4, 2);

        // Deterministic, non-constant logits so the top-k selection is non-trivial.
        let element_count = num_tokens * num_experts;
        let logits_data: Vec<f32> = (0..element_count)
            .map(|i| (i as f32 * 0.3).sin())
            .collect();
        let logits =
            Tensor::<CpuRuntime>::from_slice(&logits_data, &[num_tokens, num_experts], &device);

        let (indices, weights) = moe_top_k_routing_impl(&client, &logits, k).unwrap();
        assert_eq!(indices.shape(), &[num_tokens, k]);
        assert_eq!(weights.shape(), &[num_tokens, k]);

        // The flattened weights are read as `num_tokens` consecutive rows of
        // `k` values each.
        let w_vec = weights.to_vec::<f32>();
        for token in 0..num_tokens {
            let row = &w_vec[token * k..(token + 1) * k];
            let total: f32 = row.iter().sum();
            assert!(
                (total - 1.0).abs() < 1e-5,
                "token {} weights sum to {}, expected 1.0",
                token,
                total
            );
        }
    }

    /// Checks that `k` outside `[1, num_experts]` is rejected on both sides
    /// of the valid range.
    #[test]
    fn test_routing_invalid_k() {
        let (client, device) = cpu_setup();
        let logits = Tensor::<CpuRuntime>::from_slice(&[1.0f32; 8], &[2, 4], &device);

        // Below the valid range.
        let zero_k = moe_top_k_routing_impl(&client, &logits, 0);
        assert!(zero_k.is_err());

        // Above the valid range: 5 experts requested, only 4 available.
        let oversized_k = moe_top_k_routing_impl(&client, &logits, 5);
        assert!(oversized_k.is_err());
    }
}