use super::ExpertId;
/// Multiplies every element of `scores` by `decay`, in place.
///
/// The implementation is chosen at compile time: the NEON routine on
/// `aarch64` builds with the `neon` target feature, the AVX2 routine on
/// `x86_64` builds with the `avx2` target feature, and the scalar loop
/// everywhere else.
#[inline]
fn decay_scores_simd(scores: &mut [f32], decay: f32) {
    // The three `cfg` predicates are mutually exclusive, so exactly one of
    // the calls below survives conditional compilation.
    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    decay_scores_neon(scores, decay);

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    decay_scores_avx2(scores, decay);

    #[cfg(not(any(
        all(target_arch = "aarch64", target_feature = "neon"),
        all(target_arch = "x86_64", target_feature = "avx2")
    )))]
    decay_scores_scalar(scores, decay);
}
/// Portable fallback: scales each element of `scores` by `decay`, in place.
#[inline]
fn decay_scores_scalar(scores: &mut [f32], decay: f32) {
    scores.iter_mut().for_each(|value| *value *= decay);
}
/// NEON implementation of the in-place decay: multiplies every element of
/// `scores` by `decay`, four lanes at a time, then finishes the tail that
/// does not fill a whole vector with ordinary scalar multiplication.
///
/// Only compiled on `aarch64` builds with the `neon` target feature.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
#[inline]
fn decay_scores_neon(scores: &mut [f32], decay: f32) {
    use std::arch::aarch64::{vdupq_n_f32, vld1q_f32, vmulq_f32, vst1q_f32};

    /// Number of `f32` lanes processed per vector operation.
    const LANES: usize = 4;

    let mut blocks = scores.chunks_exact_mut(LANES);

    // SAFETY: this function is only compiled when the `neon` target feature
    // is statically enabled, so the intrinsics are available. Every `block`
    // yielded by `chunks_exact_mut(LANES)` is exactly `LANES` contiguous,
    // initialized `f32` values, so the four-lane load and store through
    // `block.as_mut_ptr()` stay inside that block.
    unsafe {
        let decay_vec = vdupq_n_f32(decay);
        for block in &mut blocks {
            let ptr = block.as_mut_ptr();
            let scaled = vmulq_f32(vld1q_f32(ptr), decay_vec);
            vst1q_f32(ptr, scaled);
        }
    }

    // Fewer than `LANES` elements may be left over; handle them safely.
    for score in blocks.into_remainder() {
        *score *= decay;
    }
}
/// AVX2 implementation of the in-place decay: multiplies every element of
/// `scores` by `decay`, eight lanes at a time, then finishes the tail that
/// does not fill a whole vector with ordinary scalar multiplication.
///
/// Only compiled on `x86_64` builds with the `avx2` target feature.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
fn decay_scores_avx2(scores: &mut [f32], decay: f32) {
    use std::arch::x86_64::*;

    /// Number of `f32` lanes processed per vector operation.
    const LANES: usize = 8;

    let mut blocks = scores.chunks_exact_mut(LANES);

    // SAFETY: this function is only compiled when the `avx2` target feature
    // is statically enabled, so the intrinsics are available. Every `block`
    // yielded by `chunks_exact_mut(LANES)` is exactly `LANES` contiguous,
    // initialized `f32` values, and the unaligned load/store variants are
    // used, so the eight-lane accesses stay inside that block.
    unsafe {
        let factor = _mm256_set1_ps(decay);
        for block in &mut blocks {
            let ptr = block.as_mut_ptr();
            let scaled = _mm256_mul_ps(_mm256_loadu_ps(ptr), factor);
            _mm256_storeu_ps(ptr, scaled);
        }
    }

    // Fewer than `LANES` elements may be left over; handle them one by one.
    for tail in blocks.into_remainder() {
        *tail *= decay;
    }
}
/// Tuning parameters for expert-affinity tracking.
#[derive(Debug, Clone)]
pub struct AffinityConfig {
    /// Number of experts to track; the tracker allocates one score and one
    /// activation counter per expert.
    pub num_experts: usize,
    /// Multiplicative factor applied to every score on each update.
    pub decay: f32,
    /// Amount added to an expert's score each time it is activated.
    pub activation_boost: f32,
    /// Upper bound applied to a score after the activation boost is added.
    pub max_score: f32,
}

impl Default for AffinityConfig {
    /// Eight experts, decay `0.99`, activation boost `1.0`, max score `1.0`.
    fn default() -> Self {
        Self {
            num_experts: 8,
            decay: 0.99,
            activation_boost: 1.0,
            max_score: 1.0,
        }
    }
}

impl AffinityConfig {
    /// Creates a configuration for `num_experts` experts with every other
    /// field taken from [`Default`].
    pub fn with_num_experts(num_experts: usize) -> Self {
        Self {
            num_experts,
            ..Default::default()
        }
    }

    /// Sets the decay factor, clamped to `[0.0, 1.0]`.
    ///
    /// A NaN argument is ignored and the current decay factor is kept:
    /// `f32::clamp` propagates NaN, and a NaN decay would turn every tracked
    /// score into NaN on the next update.
    #[must_use]
    pub fn with_decay(mut self, decay: f32) -> Self {
        if !decay.is_nan() {
            self.decay = decay.clamp(0.0, 1.0);
        }
        self
    }

    /// Sets the activation boost; negative values are raised to `0.0`.
    ///
    /// A NaN argument also yields `0.0`, because `f32::max` returns the
    /// non-NaN operand.
    #[must_use]
    pub fn with_activation_boost(mut self, boost: f32) -> Self {
        self.activation_boost = boost.max(0.0);
        self
    }

    /// Sets the score ceiling; negative values are raised to `0.0`.
    ///
    /// A NaN argument also yields `0.0`, because `f32::max` returns the
    /// non-NaN operand.
    #[must_use]
    pub fn with_max_score(mut self, max_score: f32) -> Self {
        self.max_score = max_score.max(0.0);
        self
    }
}
/// Tracks a decaying affinity score and a cumulative activation count for
/// each expert.
#[derive(Debug, Clone)]
pub struct ExpertAffinity {
    /// Current affinity score per expert, indexed by expert id.
    scores: Vec<f32>,
    /// Parameters (expert count, decay, boost, score ceiling) this tracker
    /// was built with.
    config: AffinityConfig,
    /// Number of activations recorded per expert, indexed by expert id;
    /// zeroed by `reset`.
    total_activations: Vec<u64>,
}
impl ExpertAffinity {
    /// Builds a tracker with `config.num_experts` slots; every score and
    /// activation counter starts at zero.
    pub fn new(config: AffinityConfig) -> Self {
        let slots = config.num_experts;
        Self {
            scores: vec![0.0; slots],
            total_activations: vec![0; slots],
            config,
        }
    }

    /// Applies one update step: first decays every score by the configured
    /// factor, then, for each id in `activated`, adds the activation boost
    /// (capped at `max_score`) and increments that expert's counter.
    ///
    /// Ids outside the tracked range are skipped. An id that appears several
    /// times in `activated` is boosted and counted once per occurrence.
    pub fn update(&mut self, activated: &[ExpertId]) {
        decay_scores_simd(&mut self.scores, self.config.decay);

        let boost = self.config.activation_boost;
        let ceiling = self.config.max_score;
        for &id in activated {
            if let Some(slot) = self.scores.get_mut(id) {
                *slot = (*slot + boost).min(ceiling);
                self.total_activations[id] += 1;
            }
        }
    }

    /// Returns the current score of `expert_id`, or `0.0` when the id is out
    /// of range.
    pub fn score(&self, expert_id: ExpertId) -> f32 {
        match self.scores.get(expert_id) {
            Some(&value) => value,
            None => 0.0,
        }
    }

    /// Alias for [`Self::score`].
    #[inline]
    pub fn get_score(&self, expert_id: ExpertId) -> f32 {
        self.score(expert_id)
    }

    /// Returns all scores, indexed by expert id.
    #[inline]
    pub fn scores(&self) -> &[f32] {
        self.scores.as_slice()
    }

    /// Alias for [`Self::scores`].
    #[inline]
    pub fn get_scores(&self) -> &[f32] {
        self.scores()
    }

    /// Returns how many activations were recorded for `expert_id`, or `0`
    /// when the id is out of range.
    #[inline]
    pub fn activation_count(&self, expert_id: ExpertId) -> u64 {
        self.total_activations
            .get(expert_id)
            .map_or(0, |&count| count)
    }

    /// Returns all activation counters, indexed by expert id.
    #[inline]
    pub fn get_activation_counts(&self) -> &[u64] {
        self.total_activations.as_slice()
    }

    /// Returns up to `k` expert ids ordered by descending score.
    ///
    /// Non-finite scores (NaN or either infinity) are ranked as negative
    /// infinity, i.e. last. Equal scores are ordered by ascending id.
    pub fn top_k_by_affinity(&self, k: usize) -> Vec<ExpertId> {
        let mut ranked: Vec<(ExpertId, f32)> = Vec::with_capacity(self.scores.len());
        for (id, &raw) in self.scores.iter().enumerate() {
            let key = if raw.is_finite() {
                raw
            } else {
                f32::NEG_INFINITY
            };
            ranked.push((id, key));
        }

        ranked.sort_by(|lhs, rhs| match rhs.1.partial_cmp(&lhs.1) {
            // Ties (and incomparable values) fall back to the lower id first.
            Some(std::cmp::Ordering::Equal) | None => lhs.0.cmp(&rhs.0),
            Some(order) => order,
        });

        ranked.truncate(k);
        ranked.into_iter().map(|(id, _)| id).collect()
    }

    /// Returns up to `k` expert ids ordered by descending activation count.
    ///
    /// The sort is stable, so experts with equal counts stay in ascending
    /// id order.
    pub fn top_k_by_frequency(&self, k: usize) -> Vec<ExpertId> {
        let mut ranked: Vec<(ExpertId, u64)> = self
            .total_activations
            .iter()
            .copied()
            .enumerate()
            .collect();
        ranked.sort_by(|lhs, rhs| rhs.1.cmp(&lhs.1));
        ranked.truncate(k);
        ranked.into_iter().map(|(id, _)| id).collect()
    }

    /// Returns the candidate with the lowest score, or `None` when
    /// `candidates` is empty.
    ///
    /// Non-finite scores are treated as negative infinity; out-of-range ids
    /// score `0.0`. Equal scores are resolved in favour of the lower id.
    pub fn least_affinity(&self, candidates: &[ExpertId]) -> Option<ExpertId> {
        let comparable = |id: ExpertId| -> f32 {
            let raw = self.score(id);
            if raw.is_finite() {
                raw
            } else {
                f32::NEG_INFINITY
            }
        };

        candidates.iter().copied().min_by(|&a, &b| {
            comparable(a)
                .partial_cmp(&comparable(b))
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.cmp(&b))
        })
    }

    /// Returns the mid-rank percentile of `expert_id`'s activation count
    /// among all experts: `(lower + 0.5 * equal) / n`, where `lower` and
    /// `equal` are the numbers of experts with a smaller or identical count.
    ///
    /// Returns `0.5` when no experts are tracked. An out-of-range id is
    /// ranked as if its count were `0`.
    pub fn frequency_percentile(&self, expert_id: ExpertId) -> f32 {
        let n = self.total_activations.len();
        if n == 0 {
            return 0.5;
        }

        let count = self.activation_count(expert_id);
        let mut lower = 0usize;
        let mut equal = 0usize;
        for &other in &self.total_activations {
            match other.cmp(&count) {
                std::cmp::Ordering::Less => lower += 1,
                std::cmp::Ordering::Equal => equal += 1,
                std::cmp::Ordering::Greater => {}
            }
        }

        (lower as f32 + 0.5 * equal as f32) / n as f32
    }

    /// Zeroes every score and every activation counter.
    pub fn reset(&mut self) {
        self.scores.iter_mut().for_each(|score| *score = 0.0);
        self.total_activations
            .iter_mut()
            .for_each(|count| *count = 0);
    }

    /// Returns the configured number of experts.
    pub fn num_experts(&self) -> usize {
        self.config.num_experts
    }

    /// Returns the configuration this tracker was built with.
    pub fn config(&self) -> &AffinityConfig {
        &self.config
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh tracker has one zeroed score and counter per expert.
    #[test]
    fn test_affinity_creation() {
        let config = AffinityConfig::with_num_experts(8);
        let affinity = ExpertAffinity::new(config);
        assert_eq!(affinity.num_experts(), 8);
        assert_eq!(affinity.scores().len(), 8);
        assert_eq!(affinity.get_scores().len(), 8);
        assert!(affinity.scores().iter().all(|&s| s == 0.0));
        assert!(affinity.get_activation_counts().iter().all(|&c| c == 0));
    }

    // Every score is decayed on update, including experts that are not activated.
    #[test]
    fn test_update_decays_all() {
        let config = AffinityConfig::with_num_experts(4)
            .with_decay(0.5)
            .with_activation_boost(1.0);
        let mut affinity = ExpertAffinity::new(config);
        affinity.update(&[0, 1, 2, 3]);
        for &score in affinity.scores() {
            assert!((score - 1.0).abs() < 1e-6);
        }
        affinity.update(&[0]);
        assert!((affinity.score(0) - 1.0).abs() < 1e-6);
        for id in 1..4 {
            assert!(
                (affinity.score(id) - 0.5).abs() < 1e-6,
                "Expert {} should decay to 0.5, got {}",
                id,
                affinity.score(id)
            );
        }
    }

    // Only activated experts receive the boost.
    #[test]
    fn test_update_boosts_activated() {
        let config = AffinityConfig::with_num_experts(4)
            .with_decay(0.9)
            .with_activation_boost(0.5);
        let mut affinity = ExpertAffinity::new(config);
        affinity.update(&[1, 3]);
        assert!((affinity.score(1) - 0.5).abs() < 1e-6);
        assert!((affinity.score(3) - 0.5).abs() < 1e-6);
        assert_eq!(affinity.score(0), 0.0);
        assert_eq!(affinity.score(2), 0.0);
    }

    // Without activations, scores never increase from one update to the next.
    #[test]
    fn test_monotonic_decay() {
        let config = AffinityConfig::with_num_experts(8).with_decay(0.95);
        let mut affinity = ExpertAffinity::new(config);
        affinity.update(&[1, 3, 5, 7]);
        let scores_t0 = affinity.scores().to_vec();
        for iteration in 0..10 {
            let scores_before = affinity.scores().to_vec();
            affinity.update(&[]);
            let scores_after = affinity.scores().to_vec();
            for (i, (&before, &after)) in scores_before.iter().zip(scores_after.iter()).enumerate()
            {
                assert!(
                    after <= before,
                    "INV-2 violated at iteration {}: score[{}] increased from {} to {}",
                    iteration,
                    i,
                    before,
                    after
                );
            }
        }
        for (i, (&t0, &current)) in scores_t0.iter().zip(affinity.scores().iter()).enumerate() {
            if t0 > 0.0 {
                assert!(
                    current < t0,
                    "Score[{}] did not decay: {} -> {}",
                    i,
                    t0,
                    current
                );
            }
        }
    }

    // Ranking by score returns the highest-scoring experts first and caps at the expert count.
    #[test]
    fn test_top_k_by_affinity() {
        let config = AffinityConfig::with_num_experts(6)
            .with_decay(1.0)
            .with_activation_boost(0.1);
        let mut affinity = ExpertAffinity::new(config);
        for _ in 0..5 {
            affinity.update(&[3]);
        }
        for _ in 0..3 {
            affinity.update(&[1]);
        }
        affinity.update(&[5]);
        assert!(
            (affinity.score(3) - 0.5).abs() < 1e-6,
            "Expert 3 score: {}",
            affinity.score(3)
        );
        assert!(
            (affinity.score(1) - 0.3).abs() < 1e-6,
            "Expert 1 score: {}",
            affinity.score(1)
        );
        assert!(
            (affinity.score(5) - 0.1).abs() < 1e-6,
            "Expert 5 score: {}",
            affinity.score(5)
        );
        let top2 = affinity.top_k_by_affinity(2);
        assert_eq!(top2.len(), 2);
        assert_eq!(top2[0], 3, "Expert 3 should be top");
        assert_eq!(top2[1], 1, "Expert 1 should be second");
        let top4 = affinity.top_k_by_affinity(4);
        assert_eq!(top4.len(), 4);
        assert_eq!(top4[0], 3);
        assert_eq!(top4[1], 1);
        assert_eq!(top4[2], 5);
        let top10 = affinity.top_k_by_affinity(10);
        assert_eq!(top10.len(), 6);
    }

    // Repeated activation saturates at the default max score of 1.0.
    #[test]
    fn test_score_clamped_to_one() {
        let config = AffinityConfig::with_num_experts(4)
            .with_decay(0.99)
            .with_activation_boost(1.0);
        let mut affinity = ExpertAffinity::new(config);
        for _ in 0..100 {
            affinity.update(&[0]);
        }
        assert!(
            (affinity.score(0) - 1.0).abs() < 1e-6,
            "Score should be clamped to 1.0, got {}",
            affinity.score(0)
        );
        assert!(
            affinity.score(0) <= 1.0,
            "Score {} exceeds max_score",
            affinity.score(0)
        );
    }

    // Counters accumulate per expert; out-of-range ids report zero.
    #[test]
    fn test_activation_counting() {
        let config = AffinityConfig::with_num_experts(4);
        let mut affinity = ExpertAffinity::new(config);
        affinity.update(&[0, 1]);
        affinity.update(&[0, 2]);
        affinity.update(&[0]);
        assert_eq!(affinity.activation_count(0), 3);
        assert_eq!(affinity.activation_count(1), 1);
        assert_eq!(affinity.activation_count(2), 1);
        assert_eq!(affinity.activation_count(3), 0);
        assert_eq!(affinity.activation_count(100), 0);
    }

    // `reset` zeroes both scores and counters.
    #[test]
    fn test_reset() {
        let config = AffinityConfig::with_num_experts(4);
        let mut affinity = ExpertAffinity::new(config);
        for _ in 0..10 {
            affinity.update(&[0, 1, 2, 3]);
        }
        assert!(affinity.score(0) > 0.0);
        assert!(affinity.activation_count(0) > 0);
        affinity.reset();
        for &score in affinity.scores() {
            assert_eq!(score, 0.0);
        }
        for &count in affinity.get_activation_counts() {
            assert_eq!(count, 0);
        }
    }

    // An update with no activations leaves the counters untouched.
    #[test]
    fn test_empty_update() {
        let config = AffinityConfig::with_num_experts(4).with_decay(0.9);
        let mut affinity = ExpertAffinity::new(config);
        affinity.update(&[0, 1, 2, 3]);
        let counts_before = affinity.get_activation_counts().to_vec();
        affinity.update(&[]);
        assert_eq!(affinity.get_activation_counts(), &counts_before);
    }

    // Decay, boost and clamping compose as expected over several updates.
    #[test]
    fn test_multiple_updates_sequence() {
        let config = AffinityConfig::with_num_experts(8)
            .with_decay(0.8)
            .with_activation_boost(0.5);
        let mut affinity = ExpertAffinity::new(config);
        affinity.update(&[0, 1, 7]);
        assert!((affinity.score(0) - 0.5).abs() < 1e-6);
        assert!((affinity.score(7) - 0.5).abs() < 1e-6);
        affinity.update(&[0, 1]);
        assert!((affinity.score(0) - 0.9).abs() < 1e-6);
        assert!((affinity.score(7) - 0.4).abs() < 1e-6);
        affinity.update(&[0, 1]);
        assert!((affinity.score(0) - 1.0).abs() < 1e-6);
        assert!((affinity.score(7) - 0.32).abs() < 1e-6);
        let top2 = affinity.top_k_by_affinity(2);
        assert!(top2.contains(&0));
        assert!(top2.contains(&1));
        assert_eq!(affinity.activation_count(0), 3);
        assert_eq!(affinity.activation_count(1), 3);
        assert_eq!(affinity.activation_count(7), 1);
    }

    // Ids beyond the tracked range are skipped without affecting valid ones.
    #[test]
    fn test_out_of_bounds_experts_ignored() {
        let config = AffinityConfig::with_num_experts(4);
        let mut affinity = ExpertAffinity::new(config);
        affinity.update(&[0, 1, 100, 200, 3]);
        assert!(affinity.score(0) > 0.0);
        assert!(affinity.score(1) > 0.0);
        assert!(affinity.score(3) > 0.0);
        assert_eq!(affinity.score(2), 0.0);
        assert_eq!(affinity.activation_count(0), 1);
        assert_eq!(affinity.activation_count(100), 0);
    }

    // Builder methods store the values they are given.
    #[test]
    fn test_config_builders() {
        let config = AffinityConfig::with_num_experts(16)
            .with_decay(0.95)
            .with_activation_boost(0.75)
            .with_max_score(2.0);
        assert_eq!(config.num_experts, 16);
        assert!((config.decay - 0.95).abs() < 1e-6);
        assert!((config.activation_boost - 0.75).abs() < 1e-6);
        assert!((config.max_score - 2.0).abs() < 1e-6);
    }

    // Decay values outside [0, 1] are clamped.
    #[test]
    fn test_decay_clamp() {
        let config = AffinityConfig::with_num_experts(4).with_decay(1.5);
        assert!(
            (config.decay - 1.0).abs() < 1e-6,
            "Decay should be clamped to 1.0"
        );
        let config2 = AffinityConfig::with_num_experts(4).with_decay(-0.5);
        assert!(
            (config2.decay - 0.0).abs() < 1e-6,
            "Decay should be clamped to 0.0"
        );
    }

    // A more frequently activated expert has a higher percentile.
    #[test]
    fn test_frequency_percentile() {
        let config = AffinityConfig::with_num_experts(4);
        let mut affinity = ExpertAffinity::new(config);
        affinity.update(&[0]);
        for _ in 0..5 {
            affinity.update(&[1]);
        }
        for _ in 0..3 {
            affinity.update(&[2]);
        }
        let pct_1 = affinity.frequency_percentile(1);
        let pct_3 = affinity.frequency_percentile(3);
        assert!(
            pct_1 > pct_3,
            "Expert 1 should have higher percentile than 3"
        );
        assert!(pct_1 > 0.5, "Expert 1 should be above median");
    }

    // The lowest-scoring candidate is returned; an empty candidate list yields None.
    #[test]
    fn test_least_affinity() {
        let config = AffinityConfig::with_num_experts(4)
            .with_decay(1.0)
            .with_activation_boost(0.1);
        let mut affinity = ExpertAffinity::new(config);
        for _ in 0..5 {
            affinity.update(&[0]);
        }
        for _ in 0..2 {
            affinity.update(&[1]);
        }
        affinity.update(&[2]);
        assert!((affinity.score(0) - 0.5).abs() < 1e-6);
        assert!((affinity.score(1) - 0.2).abs() < 1e-6);
        assert!((affinity.score(2) - 0.1).abs() < 1e-6);
        let candidates = vec![0, 1, 2];
        let least = affinity.least_affinity(&candidates);
        assert_eq!(least, Some(2));
        let empty: Vec<ExpertId> = vec![];
        assert_eq!(affinity.least_affinity(&empty), None);
    }

    // Ranking by activation count returns the most frequently activated experts first.
    #[test]
    fn test_top_k_by_frequency() {
        let config = AffinityConfig::with_num_experts(4);
        let mut affinity = ExpertAffinity::new(config);
        affinity.update(&[0]);
        affinity.update(&[1]);
        affinity.update(&[1]);
        affinity.update(&[2]);
        affinity.update(&[2]);
        affinity.update(&[2]);
        let top_2 = affinity.top_k_by_frequency(2);
        assert_eq!(top_2.len(), 2);
        assert_eq!(top_2[0], 2);
        assert_eq!(top_2[1], 1);
    }

    // `Default` yields the documented baseline values.
    #[test]
    fn test_default_config() {
        let config = AffinityConfig::default();
        assert_eq!(config.num_experts, 8);
        assert!((config.decay - 0.99).abs() < 1e-6);
        assert!((config.activation_boost - 1.0).abs() < 1e-6);
        assert!((config.max_score - 1.0).abs() < 1e-6);
    }

    // Lengths that are not a multiple of the vector width exercise the tail handling.
    #[test]
    fn test_simd_decay_non_aligned() {
        for size in [1, 3, 5, 7, 9, 15, 17, 33, 65] {
            let config = AffinityConfig::with_num_experts(size).with_decay(0.5);
            let mut affinity = ExpertAffinity::new(config);
            let all_experts: Vec<usize> = (0..size).collect();
            affinity.update(&all_experts);
            for &score in affinity.scores() {
                assert!((score - 1.0).abs() < 1e-6);
            }
            affinity.update(&[]);
            for (i, &score) in affinity.scores().iter().enumerate() {
                assert!(
                    (score - 0.5).abs() < 1e-6,
                    "Expert {} score should be 0.5, got {}",
                    i,
                    score
                );
            }
        }
    }

    // Repeated decay over a large score vector matches the closed-form power.
    #[test]
    fn test_simd_decay_large() {
        let config = AffinityConfig::with_num_experts(256).with_decay(0.9);
        let mut affinity = ExpertAffinity::new(config);
        let activated: Vec<usize> = (0..128).collect();
        affinity.update(&activated);
        for _ in 0..10 {
            affinity.update(&[]);
        }
        let expected = 0.9f32.powi(10);
        for i in 0..128 {
            let score = affinity.score(i);
            assert!(
                (score - expected).abs() < 1e-5,
                "Expert {} score should be ~{}, got {}",
                i,
                expected,
                score
            );
        }
        for i in 128..256 {
            assert_eq!(affinity.score(i), 0.0);
        }
    }

    // One decay step multiplies each score by exactly the configured factor.
    #[test]
    fn test_simd_decay_correctness() {
        let config = AffinityConfig::with_num_experts(64)
            .with_decay(0.87)
            .with_activation_boost(0.33);
        let mut affinity = ExpertAffinity::new(config);
        affinity.update(&[0, 7, 15, 23, 31, 39, 47, 55, 63]);
        let scores_before: Vec<f32> = affinity.scores().to_vec();
        affinity.update(&[]);
        for (i, (&before, &after)) in scores_before
            .iter()
            .zip(affinity.scores().iter())
            .enumerate()
        {
            let expected = before * 0.87;
            assert!(
                (after - expected).abs() < 1e-6,
                "Expert {} decay incorrect: {} * 0.87 = {}, got {}",
                i,
                before,
                expected,
                after
            );
        }
    }
}