/// Validation failures reported by the routing helpers in this file.
///
/// The `Display` text of each variant is produced by the `#[error(...)]`
/// attribute attached to it.
#[derive(Debug, thiserror::Error)]
pub enum MoeRouterError {
/// A requested count `k` is zero or larger than its upper bound `n`.
///
/// Besides the top-k size itself, `noaux_tc_router_cpu` also reports invalid
/// `n_group` / `topk_group` arguments through this variant, in which case
/// `k` carries the offending argument and `n` the bound it was checked
/// against.
#[error(
"k {k} must satisfy 1 ≤ k ≤ scores.len() ({n})"
)]
BadK { k: usize, n: usize },
/// The score slice handed to the router was empty.
#[error("scores empty")]
EmptyScores,
/// The `indices` output slice does not hold exactly `k` entries.
#[error(
"indices length {got} != k {k}"
)]
IndicesLen { got: usize, k: usize },
/// The weight/value output slice does not hold exactly `k` entries.
///
/// `noaux_tc_router_cpu` also uses this variant when `correction_bias` has a
/// different length than the logits; `k` then carries the expected length.
#[error(
"weights length {got} != k {k}"
)]
WeightsLen { got: usize, k: usize },
}
/// Replaces `x` in place with its softmax.
///
/// The largest element is subtracted before exponentiating so that the
/// exponentials cannot overflow for finite inputs.
///
/// # Errors
///
/// Returns [`MoeRouterError::EmptyScores`] when `x` is empty.
pub fn softmax(x: &mut [f32]) -> Result<(), MoeRouterError> {
    // Running maximum, seeded with the first element.
    let peak = match x.split_first() {
        None => return Err(MoeRouterError::EmptyScores),
        Some((&first, rest)) => rest
            .iter()
            .fold(first, |best, &v| if v > best { v } else { best }),
    };

    // Exponentiate the shifted values and accumulate the normaliser.
    let mut total = 0.0f32;
    for v in x.iter_mut() {
        let e = (*v - peak).exp();
        *v = e;
        total += e;
    }

    // Multiply by the reciprocal once instead of dividing per element.
    let scale = 1.0f32 / total;
    x.iter_mut().for_each(|v| *v *= scale);
    Ok(())
}
/// Selects the `k` largest entries of `scores`.
///
/// On success `values[slot]` holds a selected score and `indices[slot]` its
/// position in `scores`. The slots are **not** sorted by score. The selected
/// indices are always pairwise distinct.
///
/// The slots are seeded with the first `k` scores and each later score
/// replaces the current weakest slot when it is strictly larger, so ties are
/// resolved in favour of the earlier index. NaN ranks below every number and
/// is therefore only returned when fewer than `k` non-NaN scores exist.
///
/// The scan is `O(n * k)`: the weakest slot is re-located for every score.
///
/// # Errors
///
/// * [`MoeRouterError::BadK`] when `k == 0` or `k > scores.len()`.
/// * [`MoeRouterError::IndicesLen`] when `indices.len() != k`.
/// * [`MoeRouterError::WeightsLen`] when `values.len() != k`.
pub fn topk(
    scores: &[f32],
    k: usize,
    indices: &mut [i32],
    values: &mut [f32],
) -> Result<(), MoeRouterError> {
    /// `true` when `a` ranks strictly below `b`; NaN ranks below any number.
    fn ranks_below(a: f32, b: f32) -> bool {
        if a.is_nan() {
            !b.is_nan()
        } else {
            a < b
        }
    }

    let n = scores.len();
    if k == 0 || k > n {
        return Err(MoeRouterError::BadK { k, n });
    }
    if indices.len() != k {
        return Err(MoeRouterError::IndicesLen {
            got: indices.len(),
            k,
        });
    }
    if values.len() != k {
        return Err(MoeRouterError::WeightsLen {
            got: values.len(),
            k,
        });
    }

    // Seed the slots with the first k candidates instead of a magic sentinel:
    // a sentinel such as -1e30 can never be displaced by -inf (or smaller)
    // scores and would leave slots pointing at index 0 with a fake value.
    values.copy_from_slice(&scores[..k]);
    for (slot, index) in indices.iter_mut().enumerate() {
        *index = slot as i32;
    }

    for (i, &score) in scores.iter().enumerate().skip(k) {
        // First slot holding the weakest value seen so far.
        let mut weakest = 0usize;
        for slot in 1..k {
            if ranks_below(values[slot], values[weakest]) {
                weakest = slot;
            }
        }
        if ranks_below(values[weakest], score) {
            values[weakest] = score;
            indices[weakest] = i as i32;
        }
    }
    Ok(())
}
/// Rescales `weights` in place so that they sum to one.
///
/// The slice is left untouched when its sum is not strictly positive (which
/// includes the empty slice and a NaN sum).
pub fn normalize_weights(weights: &mut [f32]) {
    let total = weights.iter().fold(0.0f32, |acc, &w| acc + w);
    if total > 0.0 {
        // One reciprocal, then a multiply per element.
        let scale = 1.0f32 / total;
        weights.iter_mut().for_each(|w| *w *= scale);
    }
}
/// Softmax top-k router.
///
/// Pipeline: `scores` is turned into probabilities in place by [`softmax`],
/// the `k` largest probabilities are written to `indices` / `weights` by
/// [`topk`], and the selected weights are rescaled to sum to one by
/// [`normalize_weights`].
///
/// # Errors
///
/// Propagates the errors of [`softmax`] and [`topk`]. When [`topk`] fails,
/// `scores` has already been overwritten with the softmax probabilities.
pub fn moe_router_cpu(
    scores: &mut [f32],
    k: usize,
    indices: &mut [i32],
    weights: &mut [f32],
) -> Result<(), MoeRouterError> {
    softmax(scores)?;
    let probabilities: &[f32] = scores;
    topk(probabilities, k, indices, weights).map(|()| normalize_weights(weights))
}
/// Group-limited sigmoid router with a selection-only correction bias.
///
/// Steps, in order:
///
/// 1. `score_logits` is overwritten in place with `sigmoid(logit)`.
/// 2. `correction_bias` is added to a scratch copy used **only** for
///    selection; the returned weights come from the unbiased sigmoid scores.
/// 3. The experts are split into `n_group` contiguous groups of equal size.
///    Each group is scored by the sum of its two largest biased scores (or
///    its single score when a group holds just one expert).
/// 4. The `topk_group` best groups are kept; all other groups are masked out.
/// 5. The `k` best biased scores among the kept groups are written to
///    `indices` (unsorted).
/// 6. `weights` receives the matching sigmoid scores, normalised to sum to one
///    when their sum is strictly positive, then multiplied by
///    `routed_scaling_factor`.
///
/// # Errors
///
/// All validation happens before anything is mutated.
///
/// * [`MoeRouterError::EmptyScores`] when `score_logits` is empty.
/// * [`MoeRouterError::WeightsLen`] when `correction_bias.len()` differs from
///   `score_logits.len()` (the expected length is reported in `k`), or when
///   `weights.len() != k`.
/// * [`MoeRouterError::BadK`] when `n_group` is zero or does not divide the
///   expert count, when `topk_group` is zero or exceeds `n_group`, when `k`
///   is zero or exceeds the expert count, or when `k` exceeds the number of
///   experts contained in the kept groups (`topk_group * group_size`).
/// * [`MoeRouterError::IndicesLen`] when `indices.len() != k`.
pub fn noaux_tc_router_cpu(
    score_logits: &mut [f32],
    correction_bias: &[f32],
    n_group: usize,
    topk_group: usize,
    k: usize,
    routed_scaling_factor: f32,
    indices: &mut [i32],
    weights: &mut [f32],
) -> Result<(), MoeRouterError> {
    let n = score_logits.len();
    if n == 0 {
        return Err(MoeRouterError::EmptyScores);
    }
    if correction_bias.len() != n {
        // There is no dedicated variant for a bias-length mismatch.
        return Err(MoeRouterError::WeightsLen {
            got: correction_bias.len(),
            k: n,
        });
    }
    if n_group == 0 || n % n_group != 0 {
        return Err(MoeRouterError::BadK { k: n_group, n });
    }
    if topk_group == 0 || topk_group > n_group {
        return Err(MoeRouterError::BadK {
            k: topk_group,
            n: n_group,
        });
    }
    if k == 0 || k > n {
        return Err(MoeRouterError::BadK { k, n });
    }
    if indices.len() != k {
        return Err(MoeRouterError::IndicesLen {
            got: indices.len(),
            k,
        });
    }
    if weights.len() != k {
        return Err(MoeRouterError::WeightsLen {
            got: weights.len(),
            k,
        });
    }
    let group_size = n / n_group;
    // Only experts inside the kept groups are selectable; asking for more
    // cannot yield k distinct experts.
    let eligible = topk_group * group_size;
    if k > eligible {
        return Err(MoeRouterError::BadK { k, n: eligible });
    }

    // 1. Sigmoid gate, written back into the caller's buffer.
    for s in score_logits.iter_mut() {
        *s = 1.0 / (1.0 + (-*s).exp());
    }
    let gates: &[f32] = score_logits;

    // 2. Biased copy used for selection only.
    let mut biased: Vec<f32> = gates
        .iter()
        .zip(correction_bias)
        .map(|(s, b)| s + b)
        .collect();

    // 3. Per-group score: sum of the top-2 biased scores. A one-expert group
    //    has no runner-up, so its score is just that expert's value (adding
    //    the -inf placeholder would make every group score -inf).
    let group_scores: Vec<f32> = biased
        .chunks_exact(group_size)
        .map(|group| {
            let mut top1 = f32::NEG_INFINITY;
            let mut top2 = f32::NEG_INFINITY;
            for &v in group {
                if v > top1 {
                    top2 = top1;
                    top1 = v;
                } else if v > top2 {
                    top2 = v;
                }
            }
            if group_size == 1 {
                top1
            } else {
                top1 + top2
            }
        })
        .collect();

    // 4. Keep the best groups, mask the rest.
    let mut group_idx = vec![0i32; topk_group];
    let mut group_vals = vec![0.0f32; topk_group];
    topk(&group_scores, topk_group, &mut group_idx, &mut group_vals)?;
    let mut keep = vec![false; n_group];
    for &g in &group_idx {
        keep[g as usize] = true;
    }
    for (group, &kept) in biased.chunks_exact_mut(group_size).zip(&keep) {
        if !kept {
            group.fill(f32::NEG_INFINITY);
        }
    }

    // 5. Final expert selection on the masked, biased scores; the biased
    //    values themselves are discarded.
    let mut discarded_vals = vec![0.0f32; k];
    topk(&biased, k, indices, &mut discarded_vals)?;

    // 6. Gather the unbiased gates, normalise, then apply the scaling factor.
    let mut sum = 0.0f32;
    for (w, &i) in weights.iter_mut().zip(indices.iter()) {
        *w = gates[i as usize];
        sum += *w;
    }
    if sum > 0.0 {
        let inv = 1.0f32 / sum;
        for w in weights.iter_mut() {
            *w *= inv;
        }
    }
    for w in weights.iter_mut() {
        *w *= routed_scaling_factor;
    }
    Ok(())
}
/// Token-to-expert assignments regrouped by expert, as produced by
/// `build_expert_buckets`.
///
/// Bucket `b` describes expert `expert_ids[b]`; its assignments occupy the
/// half-open range `offsets[b]..offsets[b + 1]` of `token_idx` and `weights`.
#[derive(Debug, Clone, PartialEq)]
pub struct ExpertBuckets {
/// Ids of the experts that received at least one assignment, in ascending
/// order. Experts with no assignment get no bucket.
pub expert_ids: Vec<i32>,
/// Running assignment counts: starts with 0 and has one more entry than
/// `expert_ids`; the last entry is the total number of assignments.
pub offsets: Vec<u32>,
/// Token index of every assignment, grouped by bucket.
pub token_idx: Vec<i32>,
/// Routing weight of every assignment, parallel to `token_idx`.
pub weights: Vec<f32>,
}
/// Regroups per-token routing decisions into per-expert buckets.
///
/// `per_token_indices` / `per_token_weights` are flat, token-major arrays:
/// entry `t * k_active + s` is the `s`-th expert (and its weight) chosen for
/// token `t`. The result lists, for every expert that was chosen at least
/// once (in ascending expert id), the tokens routed to it together with
/// their weights. Inside a bucket, assignments keep the order in which they
/// appear in the flat input.
///
/// This is a counting sort: one pass builds a per-expert histogram, a prefix
/// sum turns it into bucket offsets, and a second pass scatters every
/// assignment to its bucket's next free position.
///
/// # Panics
///
/// In debug builds, panics when either input slice does not hold
/// `n_tokens * k_active` entries or when an expert id lies outside
/// `[0, num_experts)`. In all builds, an out-of-range expert id or a slice
/// shorter than `n_tokens * k_active` panics on the out-of-bounds access.
pub fn build_expert_buckets(
    per_token_indices: &[i32],
    per_token_weights: &[f32],
    n_tokens: usize,
    k_active: usize,
    num_experts: usize,
) -> ExpertBuckets {
    debug_assert_eq!(per_token_indices.len(), n_tokens * k_active);
    debug_assert_eq!(per_token_weights.len(), n_tokens * k_active);

    // Pass 1: how many assignments does each expert receive?
    let mut histogram = vec![0u32; num_experts];
    for &expert in per_token_indices {
        debug_assert!(
            expert >= 0 && (expert as usize) < num_experts,
            "expert id {e} out of range [0, {num_experts})",
            e = expert,
            num_experts = num_experts
        );
        histogram[expert as usize] += 1;
    }

    let num_non_empty = histogram.iter().filter(|&&count| count > 0).count();
    let total_assignments = n_tokens * k_active;

    // Prefix sum over the non-empty experts; `bucket_of[e]` stays -1 for
    // experts without a bucket.
    let mut bucket_of = vec![-1i32; num_experts];
    let mut expert_ids: Vec<i32> = Vec::with_capacity(num_non_empty);
    let mut offsets: Vec<u32> = Vec::with_capacity(num_non_empty + 1);
    offsets.push(0);
    let mut assigned_so_far = 0u32;
    let occupied = histogram
        .iter()
        .enumerate()
        .filter(|&(_, &count)| count != 0);
    for (expert, &count) in occupied {
        bucket_of[expert] = expert_ids.len() as i32;
        expert_ids.push(expert as i32);
        assigned_so_far += count;
        offsets.push(assigned_so_far);
    }
    debug_assert_eq!(assigned_so_far as usize, total_assignments);

    // Pass 2: scatter. Each bucket's write cursor starts at its offset.
    let mut next_free: Vec<u32> = offsets[..num_non_empty].to_vec();
    let mut token_idx = vec![0i32; total_assignments];
    let mut weights = vec![0.0f32; total_assignments];
    for token in 0..n_tokens {
        let row = token * k_active..(token + 1) * k_active;
        let experts = &per_token_indices[row.clone()];
        let gates = &per_token_weights[row];
        for (&expert, &gate) in experts.iter().zip(gates) {
            let cursor = &mut next_free[bucket_of[expert as usize] as usize];
            let pos = *cursor as usize;
            token_idx[pos] = token as i32;
            weights[pos] = gate;
            *cursor += 1;
        }
    }

    ExpertBuckets {
        expert_ids,
        offsets,
        token_idx,
        weights,
    }
}
// Unit tests for the routing helpers and the expert-bucket builder.
#[cfg(test)]
mod tests {
use super::*;
// Softmax output must sum to one.
#[test]
fn softmax_normalizes_to_one() {
let mut x = [1.0f32, 2.0, 3.0, 4.0];
softmax(&mut x).unwrap();
let sum: f32 = x.iter().sum();
assert!((sum - 1.0).abs() < 1e-6, "softmax sum = {sum}");
}
// Top-3 of five scores; `topk` does not sort its slots, so the pairs are
// ordered by descending value before the indices are compared.
#[test]
fn topk_picks_largest() {
let scores = [0.1f32, 0.5, 0.3, 0.9, 0.2];
let mut idx = [0i32; 3];
let mut val = [0.0f32; 3];
topk(&scores, 3, &mut idx, &mut val).unwrap();
let mut pairs: Vec<(i32, f32)> =
idx.iter().copied().zip(val.iter().copied()).collect();
pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
assert_eq!(pairs[0].0, 3); assert_eq!(pairs[1].0, 1); assert_eq!(pairs[2].0, 2); }
// Weights with a positive sum are rescaled to sum to one.
#[test]
fn normalize_sums_to_one() {
let mut w = [0.5f32, 1.5, 2.0];
normalize_weights(&mut w);
let sum: f32 = w.iter().sum();
assert!((sum - 1.0).abs() < 1e-6, "normalize sum = {sum}");
}
// Small hand-checked case: 8 experts in 2 groups of 4, one group kept,
// top-2 experts, scaling factor 2.5. The second group (experts 4..=7) wins
// and experts 5 and 7 are selected; the expected weights are the normalised
// sigmoid scores of those two experts times 2.5.
#[test]
fn noaux_tc_against_hand_reference() {
let mut logits = [3.0f32, -2.0, 1.0, -1.0, -3.0, 2.0, 0.0, 4.0];
let bias = [0.1f32, 0.2, -0.3, 0.0, 0.0, 0.4, -0.1, 0.05];
let mut indices = [0i32; 2];
let mut weights = [0.0f32; 2];
noaux_tc_router_cpu(
&mut logits,
&bias,
2,
1,
2,
2.5,
&mut indices,
&mut weights,
)
.unwrap();
// Output slots are unordered, so compare the sorted index set.
let mut idx_sorted: Vec<i32> = indices.to_vec();
idx_sorted.sort();
assert_eq!(idx_sorted, vec![5, 7]);
// Pair each index with its weight and order by index for stable checks.
let mut pairs: Vec<(i32, f32)> =
indices.iter().copied().zip(weights.iter().copied()).collect();
pairs.sort_by_key(|&(i, _)| i);
let tol = 1e-3;
assert_eq!(pairs[0].0, 5);
assert!(
(pairs[0].1 - 1.1822).abs() < tol,
"weight for idx 5: got {}, want ~1.1822",
pairs[0].1
);
assert_eq!(pairs[1].0, 7);
assert!(
(pairs[1].1 - 1.3178).abs() < tol,
"weight for idx 7: got {}, want ~1.3178",
pairs[1].1
);
// After normalisation and scaling the weights sum to the scaling factor.
let sum: f32 = weights.iter().sum();
assert!((sum - 2.5).abs() < tol, "sum = {sum}");
}
// Larger configuration (256 experts, 8 groups, 4 groups kept, top-8) on
// deterministic synthetic logits/bias. Only structural properties are
// checked: weight sum, group limit, and index uniqueness.
#[test]
fn noaux_tc_full_cogito_shape() {
let n_experts = 256;
let n_group = 8;
let topk_group = 4;
let k = 8;
let scaling = 2.5;
let mut logits: Vec<f32> = (0..n_experts)
.map(|i| ((i as f32) * 0.137).sin() * 2.0)
.collect();
let bias: Vec<f32> = (0..n_experts)
.map(|i| ((i as f32) * 0.241).cos() * 0.5)
.collect();
let mut indices = vec![0i32; k];
let mut weights = vec![0.0f32; k];
noaux_tc_router_cpu(
&mut logits,
&bias,
n_group,
topk_group,
k,
scaling,
&mut indices,
&mut weights,
)
.unwrap();
let sum: f32 = weights.iter().sum();
assert!(
(sum - scaling).abs() < 1e-3,
"weight sum = {sum}, want {scaling}"
);
// Map every selected expert back to its group and count distinct groups.
let group_size = n_experts / n_group;
let chosen_groups: std::collections::HashSet<usize> = indices
.iter()
.map(|&i| (i as usize) / group_size)
.collect();
assert!(
chosen_groups.len() <= topk_group,
"selected experts span {} groups, must be ≤ {topk_group}",
chosen_groups.len()
);
let unique: std::collections::HashSet<i32> =
indices.iter().copied().collect();
assert_eq!(unique.len(), k, "duplicate expert indices in output");
}
// 3 tokens x 2 experts each over 4 experts: rebuild every token's
// (expert, weight) list from the buckets and compare with the input.
#[test]
fn build_expert_buckets_round_trips() {
let indices: [i32; 6] = [0, 3, 0, 2, 3, 0];
let weights: [f32; 6] = [0.7, 0.3, 0.4, 0.6, 0.5, 0.5];
let b = build_expert_buckets(&indices, &weights, 3, 2, 4);
// Expert 1 is never chosen, so it gets no bucket.
assert_eq!(b.expert_ids, vec![0, 2, 3]);
assert_eq!(b.offsets, vec![0, 3, 4, 6]);
// Invert the bucketing: collect (expert, weight) pairs per token.
let mut got: Vec<Vec<(i32, f32)>> = vec![vec![]; 3];
for (bi, &e) in b.expert_ids.iter().enumerate() {
let start = b.offsets[bi] as usize;
let end = b.offsets[bi + 1] as usize;
for j in start..end {
let t = b.token_idx[j] as usize;
got[t].push((e, b.weights[j]));
}
}
for t in 0..3 {
let mut want: Vec<(i32, f32)> = (0..2)
.map(|s| (indices[t * 2 + s], weights[t * 2 + s]))
.collect();
// Order both sides by expert id so the comparison ignores slot order.
want.sort_by_key(|&(e, _)| e);
got[t].sort_by_key(|&(e, _)| e);
assert_eq!(got[t].len(), 2, "token {t} count");
for (g, w) in got[t].iter().zip(want.iter()) {
assert_eq!(g.0, w.0, "token {t} expert");
assert!((g.1 - w.1).abs() < 1e-9, "token {t} weight");
}
}
}
// When every token picks distinct experts, no bucket may list the same
// token twice.
#[test]
fn build_expert_buckets_distinct_tokens_per_bucket() {
let indices: [i32; 8] = [0, 1, 1, 2, 0, 2, 1, 3];
let weights: [f32; 8] = [0.5, 0.5, 0.4, 0.6, 0.7, 0.3, 0.5, 0.5];
let b = build_expert_buckets(&indices, &weights, 4, 2, 4);
for bi in 0..b.expert_ids.len() {
let start = b.offsets[bi] as usize;
let end = b.offsets[bi + 1] as usize;
let mut sorted: Vec<i32> = b.token_idx[start..end].to_vec();
sorted.sort();
let len = sorted.len();
// `dedup` shrinks the vector only if a token appeared more than once.
sorted.dedup();
assert_eq!(
sorted.len(),
len,
"duplicate token in bucket {bi}"
);
}
}
// Only experts 2 and 7 out of 10 are used; the unused experts must not
// produce buckets.
#[test]
fn build_expert_buckets_empty_skipped() {
let indices: [i32; 4] = [2, 7, 7, 2];
let weights: [f32; 4] = [0.5; 4];
let b = build_expert_buckets(&indices, &weights, 2, 2, 10);
assert_eq!(b.expert_ids, vec![2, 7]);
assert_eq!(b.offsets, vec![0, 2, 4]);
let b0_tokens = &b.token_idx[0..2];
let b1_tokens = &b.token_idx[2..4];
let mut s0 = b0_tokens.to_vec();
s0.sort();
let mut s1 = b1_tokens.to_vec();
s1.sort();
// Both tokens were routed to both experts.
assert_eq!(s0, vec![0, 1]);
assert_eq!(s1, vec![0, 1]);
}
// All four tokens pick the same single expert (k_active = 1): one bucket
// holding the tokens in input order.
#[test]
fn build_expert_buckets_single_expert() {
let indices: [i32; 4] = [3, 3, 3, 3];
let weights: [f32; 4] = [1.0, 1.0, 1.0, 1.0];
let b = build_expert_buckets(&indices, &weights, 4, 1, 8);
assert_eq!(b.expert_ids, vec![3]);
assert_eq!(b.offsets, vec![0, 4]);
assert_eq!(b.token_idx, vec![0, 1, 2, 3]);
assert_eq!(b.weights, vec![1.0; 4]);
}
}