/// Groups routing pairs by the expert each one was routed to.
///
/// `selected_experts` is read as a row-major `batch × top_k` matrix. Entry
/// `b * top_k + k` is the expert chosen for slot `k` of batch row `b`, and that flat
/// position is used as the pair's id.
///
/// Returns `(tpe, ids, max_per_expert)` where:
/// - `tpe[e]` is the number of pairs routed to expert `e` (`tpe.len() == num_experts`);
/// - `ids` is a padded, row-major `num_experts × max_per_expert` matrix. Row `e` holds
///   the ids of the pairs routed to expert `e` in ascending order. Only the first
///   `tpe[e]` slots of a row are meaningful; the rest are `0`. Because `0` is also a
///   valid pair id, use `tpe` (not the padding) to find where a row ends;
/// - `max_per_expert` is the size of the largest bucket, clamped to at least `1`.
///
/// Pairs whose expert index is `>= num_experts` are silently skipped.
///
/// # Panics
///
/// Panics if:
/// - `batch * top_k` overflows `usize`;
/// - `batch * top_k` exceeds `i32::MAX` (pair ids and counts are returned as `i32`);
/// - `selected_experts` is shorter than `batch * top_k`.
///
/// In debug builds it also panics unless `selected_experts.len() == batch * top_k`.
pub fn compute_ids_tpe(
    selected_experts: &[u32],
    num_experts: usize,
    batch: usize,
    top_k: usize,
) -> (Vec<i32>, Vec<i32>, usize) {
    let num_pairs = batch
        .checked_mul(top_k)
        .expect("compute_ids_tpe: batch * top_k overflows usize");
    debug_assert_eq!(selected_experts.len(), num_pairs);
    // Every pair id is < num_pairs and every bucket length is <= num_pairs, so this
    // single check makes all of the `as i32` casts below lossless.
    assert!(
        num_pairs <= i32::MAX as usize,
        "compute_ids_tpe: batch * top_k = {} does not fit in i32",
        num_pairs
    );

    // Bucket pair ids by expert. Pairs are visited in flat order, so each bucket
    // ends up sorted ascending. `get_mut` returns None for out-of-range experts,
    // which are skipped.
    let mut buckets: Vec<Vec<i32>> = vec![Vec::new(); num_experts];
    for (pair_idx, &expert) in selected_experts[..num_pairs].iter().enumerate() {
        if let Some(bucket) = buckets.get_mut(expert as usize) {
            bucket.push(pair_idx as i32);
        }
    }

    // Clamp to 1 so the row width used by `chunks_exact_mut` below is never zero.
    let max_per_expert = buckets.iter().map(Vec::len).max().unwrap_or(0).max(1);

    let mut tpe = Vec::with_capacity(num_experts);
    let mut ids = vec![0i32; num_experts * max_per_expert];
    // Row `e` of `ids` is the slice [e * max_per_expert, (e + 1) * max_per_expert).
    for (bucket, row) in buckets.iter().zip(ids.chunks_exact_mut(max_per_expert)) {
        tpe.push(bucket.len() as i32);
        row[..bucket.len()].copy_from_slice(bucket);
    }
    (tpe, ids, max_per_expert)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn compute_ids_tpe_simple() {
        // Flat pair ids (b * top_k + k): 0 -> e1, 1 -> e3, 2 -> e1, 3 -> e0.
        let selected = [1u32, 3, 1, 0];
        let (tpe, ids, mpe) = compute_ids_tpe(&selected, 4, 2, 2);
        assert_eq!(tpe, vec![1, 2, 0, 1]);
        assert_eq!(mpe, 2);
        // Rows of width `mpe`, zero-padded: e0=[3,0] e1=[0,2] e2=[0,0] e3=[1,0].
        assert_eq!(ids, vec![3, 0, 0, 2, 0, 0, 1, 0]);
    }

    #[test]
    fn compute_ids_tpe_skips_out_of_range_experts() {
        // Expert 5 does not exist when num_experts == 2, so pair 0 is dropped.
        let (tpe, ids, mpe) = compute_ids_tpe(&[5, 0], 2, 1, 2);
        assert_eq!(tpe, vec![1, 0]);
        assert_eq!(mpe, 1);
        assert_eq!(ids, vec![1, 0]);
    }

    #[test]
    fn compute_ids_tpe_empty_batch() {
        // No pairs: every expert is empty, but the row width is still clamped to 1.
        let (tpe, ids, mpe) = compute_ids_tpe(&[], 3, 0, 2);
        assert_eq!(tpe, vec![0, 0, 0]);
        assert_eq!(mpe, 1);
        assert_eq!(ids, vec![0, 0, 0]);
    }

    #[test]
    fn compute_ids_tpe_no_experts() {
        let (tpe, ids, mpe) = compute_ids_tpe(&[], 0, 0, 0);
        assert!(tpe.is_empty());
        assert!(ids.is_empty());
        assert_eq!(mpe, 1);
    }
}