use super::{ExpertAffinity, ExpertId, MoeMetrics};
use std::time::Instant;
/// Bitmask recording which expert ids are currently marked as cache-resident.
///
/// Storage is chosen once at construction: a single inline word for up to
/// 64 experts, or a heap-allocated word vector for larger expert counts.
#[derive(Debug, Clone)]
struct CacheMask {
/// Inline bit storage; bit `i` corresponds to expert `i` when `num_experts <= 64`.
small: u64,
/// One `u64` per 64 experts; `Some` only when the mask was built for more than 64 experts.
extended: Option<Vec<u64>>,
/// Number of experts the mask tracks; ids at or above this are ignored by `set` and read as clear.
num_experts: usize,
}
impl CacheMask {
fn new(num_experts: usize) -> Self {
if num_experts <= 64 {
Self {
small: 0,
extended: None,
num_experts,
}
} else {
let num_words = (num_experts + 63) / 64;
Self {
small: 0,
extended: Some(vec![0u64; num_words]),
num_experts,
}
}
}
#[inline]
fn is_set(&self, id: ExpertId) -> bool {
if id >= self.num_experts {
return false;
}
if self.num_experts <= 64 {
(self.small & (1u64 << id)) != 0
} else {
let word = id / 64;
let bit = id % 64;
self.extended
.as_ref()
.map(|v| (v[word] & (1u64 << bit)) != 0)
.unwrap_or(false)
}
}
#[inline]
fn set(&mut self, id: ExpertId, resident: bool) {
if id >= self.num_experts {
return;
}
if self.num_experts <= 64 {
if resident {
self.small |= 1u64 << id;
} else {
self.small &= !(1u64 << id);
}
} else if let Some(ref mut v) = self.extended {
let word = id / 64;
let bit = id % 64;
if resident {
v[word] |= 1u64 << bit;
} else {
v[word] &= !(1u64 << bit);
}
}
}
#[inline]
fn clear(&mut self) {
self.small = 0;
if let Some(ref mut v) = self.extended {
v.fill(0);
}
}
fn resident_list(&self) -> Vec<ExpertId> {
let mut result = Vec::new();
if self.num_experts <= 64 {
let mut bits = self.small;
while bits != 0 {
let trailing = bits.trailing_zeros() as usize;
result.push(trailing);
bits &= bits - 1; }
} else if let Some(ref v) = self.extended {
for (word_idx, &word) in v.iter().enumerate() {
let mut bits = word;
while bits != 0 {
let trailing = bits.trailing_zeros() as usize;
let id = word_idx * 64 + trailing;
if id < self.num_experts {
result.push(id);
}
bits &= bits - 1;
}
}
}
result
}
#[inline]
fn count(&self) -> usize {
if self.num_experts <= 64 {
self.small.count_ones() as usize
} else {
self.extended
.as_ref()
.map(|v| v.iter().map(|w| w.count_ones() as usize).sum())
.unwrap_or(0)
}
}
}
/// Direction of the transfer described by a [`PagingRequest`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PagingDirection {
/// Direction used by `PagingRequest::page_in_urgent` and `PagingRequest::prefetch`
/// (presumably: load the expert into the cache — confirm against the consumer).
In,
/// Direction used by `PagingRequest::page_out`
/// (presumably: evict the expert from the cache — confirm against the consumer).
Out,
}
/// Priority class attached to a [`PagingRequest`].
///
/// The derived `PartialOrd`/`Ord` follow declaration order, so
/// `Normal < Urgent < Prefetch`.
/// NOTE(review): `Prefetch` therefore compares greater than `Urgent`; confirm
/// no consumer treats "greater" as "more important" before relying on `Ord`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum PagingPriority {
/// Priority used by `PagingRequest::page_out`.
Normal,
/// Priority used by `PagingRequest::page_in_urgent`, i.e. for experts selected by the current routing call.
Urgent,
/// Priority used by `PagingRequest::prefetch`, i.e. for speculative loads driven by affinity.
Prefetch,
}
/// A request to move one expert in or out, tagged with a priority.
#[derive(Debug, Clone)]
pub struct PagingRequest {
/// Expert the request refers to.
pub expert_id: ExpertId,
/// Whether the expert should be paged in or out.
pub direction: PagingDirection,
/// Priority class of the request.
pub priority: PagingPriority,
}
impl PagingRequest {
pub fn new(expert_id: ExpertId, direction: PagingDirection, priority: PagingPriority) -> Self {
Self {
expert_id,
direction,
priority,
}
}
pub fn page_in_urgent(expert_id: ExpertId) -> Self {
Self::new(expert_id, PagingDirection::In, PagingPriority::Urgent)
}
pub fn prefetch(expert_id: ExpertId) -> Self {
Self::new(expert_id, PagingDirection::In, PagingPriority::Prefetch)
}
pub fn page_out(expert_id: ExpertId) -> Self {
Self::new(expert_id, PagingDirection::Out, PagingPriority::Normal)
}
}
/// Tunable parameters of the memory-aware router.
#[derive(Debug, Clone)]
pub struct RouterConfig {
/// Amount added to the score of each cache-resident expert before selection
/// (only applied by `route` when `memory_aware` is `true`).
pub cache_bonus: f32,
/// Number of experts selected per routing call.
pub top_k: usize,
/// Total number of experts; `route` expects exactly this many gate logits.
pub num_experts: usize,
/// When `false`, `route` skips the cache bonus and ranks the raw gate logits.
pub memory_aware: bool,
/// Threshold stored in `[0, 1]` by `with_prefetch_threshold`.
/// NOTE(review): nothing in this module reads it — confirm the intended consumer.
pub prefetch_threshold: f32,
}
impl Default for RouterConfig {
fn default() -> Self {
Self {
cache_bonus: 0.15,
top_k: 2,
num_experts: 8,
memory_aware: true,
prefetch_threshold: 0.1,
}
}
}
impl RouterConfig {
    /// Creates a configuration with the given expert count and `top_k`,
    /// taking every other field from [`Default`].
    ///
    /// No validation happens here; call [`RouterConfig::validate`].
    pub fn new(num_experts: usize, top_k: usize) -> Self {
        Self {
            num_experts,
            top_k,
            ..Default::default()
        }
    }

    /// Sets the cache bonus, clamped to `[0.0, 1.0]`.
    ///
    /// A NaN input stays NaN, because `f32::clamp` returns NaN for a NaN receiver.
    pub fn with_cache_bonus(mut self, bonus: f32) -> Self {
        self.cache_bonus = bonus.clamp(0.0, 1.0);
        self
    }

    /// Enables or disables the cache bonus during routing.
    pub fn with_memory_aware(mut self, enabled: bool) -> Self {
        self.memory_aware = enabled;
        self
    }

    /// Sets the prefetch threshold, clamped to `[0.0, 1.0]`.
    pub fn with_prefetch_threshold(mut self, threshold: f32) -> Self {
        self.prefetch_threshold = threshold.clamp(0.0, 1.0);
        self
    }

    /// Checks that the expert count and `top_k` are usable.
    ///
    /// # Errors
    ///
    /// Returns a static message when `num_experts` is zero, when `top_k` is
    /// zero, or when `top_k` exceeds `num_experts`.
    pub fn validate(&self) -> Result<(), &'static str> {
        // Checked first: with zero experts every `top_k` would otherwise trip
        // one of the `top_k` checks and hide the real problem.
        if self.num_experts == 0 {
            return Err("num_experts must be at least 1");
        }
        if self.top_k == 0 {
            return Err("top_k must be at least 1");
        }
        if self.top_k > self.num_experts {
            return Err("top_k cannot exceed num_experts");
        }
        Ok(())
    }
}
/// Top-k expert router that can bias selection toward cache-resident experts
/// and reports which selected experts need to be paged in.
pub struct MemoryAwareRouter {
/// Routing parameters; validated once in `new`.
config: RouterConfig,
/// Affinity tracker updated with every successful `route` selection and
/// queried for prefetch candidates.
affinity: ExpertAffinity,
/// Bitmask of the experts currently marked resident.
cache_resident: CacheMask,
/// Routing metrics; `route` only records into it when the `routing-metrics` feature is enabled.
metrics: MoeMetrics,
/// Scratch copy of the gate logits that `route` adjusts in place.
score_buffer: Vec<f32>,
/// Scratch list of `(expert id, score)` pairs used during top-k selection.
index_buffer: Vec<(ExpertId, f32)>,
/// Scratch vector in which the selected ids are assembled before being returned.
result_buffer: Vec<ExpertId>,
}
impl MemoryAwareRouter {
pub fn new(config: RouterConfig, affinity: ExpertAffinity) -> Result<Self, &'static str> {
config.validate()?;
let num_experts = config.num_experts;
let top_k = config.top_k;
Ok(Self {
cache_resident: CacheMask::new(num_experts),
score_buffer: vec![0.0; num_experts],
index_buffer: Vec::with_capacity(num_experts),
result_buffer: Vec::with_capacity(top_k),
config,
affinity,
metrics: MoeMetrics::new(),
})
}
pub fn with_default_affinity(config: RouterConfig) -> Result<Self, &'static str> {
let affinity =
ExpertAffinity::new(super::AffinityConfig::with_num_experts(config.num_experts));
Self::new(config, affinity)
}
#[inline]
pub fn route(&mut self, gate_logits: &[f32]) -> (Vec<ExpertId>, Vec<PagingRequest>) {
#[cfg(feature = "routing-metrics")]
let start = Instant::now();
if gate_logits.len() != self.config.num_experts {
let selected: Vec<ExpertId> =
(0..self.config.top_k.min(self.config.num_experts)).collect();
return (selected, Vec::new());
}
let selected = self.route_into_buffer(gate_logits);
self.affinity.update(&selected);
let paging_requests = self.generate_paging_requests(&selected);
#[cfg(feature = "routing-metrics")]
{
let mut hits = 0usize;
for &id in &selected {
if self.cache_resident.is_set(id) {
hits += 1;
}
}
let misses = selected.len() - hits;
self.metrics.record_cache_hits(hits);
self.metrics.record_cache_misses(misses);
self.metrics.record_routing(start.elapsed());
}
(selected, paging_requests)
}
#[inline]
fn route_into_buffer(&mut self, gate_logits: &[f32]) -> Vec<ExpertId> {
let n = gate_logits.len();
self.score_buffer.clear();
self.score_buffer.extend_from_slice(gate_logits);
if self.config.memory_aware {
self.apply_cache_bonus_inplace_buffer();
}
self.select_top_k_buffered(n)
}
#[inline]
fn apply_cache_bonus_inplace_buffer(&mut self) {
let bonus = self.config.cache_bonus;
for (id, score) in self.score_buffer.iter_mut().enumerate() {
if !score.is_finite() {
*score = 0.0;
continue;
}
if self.cache_resident.is_set(id) {
*score += bonus;
}
}
}
#[inline]
fn select_top_k_buffered(&mut self, n: usize) -> Vec<ExpertId> {
let k = self.config.top_k.min(n);
self.result_buffer.clear();
if k == 0 || n == 0 {
return std::mem::take(&mut self.result_buffer);
}
self.index_buffer.clear();
self.index_buffer.extend(
self.score_buffer
.iter()
.enumerate()
.map(|(id, &s)| (id, if s.is_finite() { s } else { f32::NEG_INFINITY })),
);
if k == 2 && n >= 2 {
return self.select_top_2_unrolled();
}
if k < n / 2 {
self.index_buffer.select_nth_unstable_by(k - 1, |a, b| {
b.1.partial_cmp(&a.1)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.0.cmp(&b.0))
});
self.index_buffer[..k].sort_by(|a, b| {
b.1.partial_cmp(&a.1)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.0.cmp(&b.0))
});
} else {
self.index_buffer.sort_by(|a, b| {
b.1.partial_cmp(&a.1)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.0.cmp(&b.0))
});
}
self.result_buffer
.extend(self.index_buffer.iter().take(k).map(|(id, _)| *id));
std::mem::take(&mut self.result_buffer)
}
#[inline]
fn select_top_2_unrolled(&mut self) -> Vec<ExpertId> {
let mut best = (0, f32::NEG_INFINITY);
let mut second = (0, f32::NEG_INFINITY);
for &(id, score) in &self.index_buffer {
if score > best.1 || (score == best.1 && id < best.0) {
second = best;
best = (id, score);
} else if score > second.1 || (score == second.1 && id < second.0) {
second = (id, score);
}
}
self.result_buffer.clear();
self.result_buffer.push(best.0);
self.result_buffer.push(second.0);
std::mem::take(&mut self.result_buffer)
}
pub fn route_batch(
&mut self,
batch_logits: &[&[f32]],
) -> Vec<(Vec<ExpertId>, Vec<PagingRequest>)> {
let mut results = Vec::with_capacity(batch_logits.len());
for logits in batch_logits {
results.push(self.route(logits));
}
results
}
pub fn apply_cache_bonus_inplace(&self, scores: &mut [f32]) {
for (id, score) in scores.iter_mut().enumerate() {
if !score.is_finite() {
*score = 0.0;
continue;
}
if self.cache_resident.is_set(id) {
*score += self.config.cache_bonus;
}
}
}
pub fn apply_cache_bonus(&self, scores: &[f32]) -> Vec<f32> {
let mut result = scores.to_vec();
self.apply_cache_bonus_inplace(&mut result);
result
}
pub fn select_top_k(&self, scores: &[f32]) -> Vec<ExpertId> {
let n = scores.len();
let k = self.config.top_k.min(n);
if k == 0 || n == 0 {
return Vec::new();
}
let mut indexed: Vec<(ExpertId, f32)> = scores
.iter()
.enumerate()
.map(|(id, &s)| (id, if s.is_finite() { s } else { f32::NEG_INFINITY }))
.collect();
if k < n / 2 {
indexed.select_nth_unstable_by(k - 1, |a, b| {
b.1.partial_cmp(&a.1)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.0.cmp(&b.0))
});
indexed[..k].sort_by(|a, b| {
b.1.partial_cmp(&a.1)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.0.cmp(&b.0))
});
} else {
indexed.sort_by(|a, b| {
b.1.partial_cmp(&a.1)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.0.cmp(&b.0))
});
}
indexed.into_iter().take(k).map(|(id, _)| id).collect()
}
pub fn update_cache_state(&mut self, resident: &[ExpertId]) {
self.cache_resident.clear();
for &id in resident {
self.cache_resident.set(id, true);
}
}
pub fn set_resident(&mut self, expert_id: ExpertId, resident: bool) {
self.cache_resident.set(expert_id, resident);
}
pub fn is_resident(&self, expert_id: ExpertId) -> bool {
self.cache_resident.is_set(expert_id)
}
pub fn generate_paging_requests(&self, selected: &[ExpertId]) -> Vec<PagingRequest> {
let mut requests = Vec::new();
for &expert_id in selected {
if !self.is_resident(expert_id) {
requests.push(PagingRequest::page_in_urgent(expert_id));
}
}
requests
}
pub fn generate_prefetch_requests(&self, budget: usize) -> Vec<PagingRequest> {
let candidates = self.affinity.top_k_by_affinity(budget * 2);
candidates
.into_iter()
.filter(|&id| !self.is_resident(id))
.take(budget)
.map(PagingRequest::prefetch)
.collect()
}
pub fn metrics(&self) -> &MoeMetrics {
&self.metrics
}
pub fn reset_metrics(&mut self) {
self.metrics.reset();
}
pub fn affinity(&self) -> &ExpertAffinity {
&self.affinity
}
pub fn affinity_mut(&mut self) -> &mut ExpertAffinity {
&mut self.affinity
}
pub fn config(&self) -> &RouterConfig {
&self.config
}
pub fn hit_rate(&self) -> f32 {
self.metrics.hit_rate()
}
pub fn resident_experts(&self) -> Vec<ExpertId> {
self.cache_resident.resident_list()
}
pub fn num_experts(&self) -> usize {
self.config.num_experts
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::moe::AffinityConfig;

    /// Builds a router with default affinity for the given shape and bonus.
    fn make_router(num_experts: usize, top_k: usize, cache_bonus: f32) -> MemoryAwareRouter {
        let config = RouterConfig::new(num_experts, top_k).with_cache_bonus(cache_bonus);
        MemoryAwareRouter::with_default_affinity(config).expect("test config should be valid")
    }

    #[test]
    fn test_routing_basic() {
        let mut router = make_router(8, 2, 0.0);
        let gate_logits = vec![0.1, 0.3, 0.5, 0.2, 0.4, 0.1, 0.2, 0.15];
        let (selected, _) = router.route(&gate_logits);
        assert_eq!(selected.len(), 2);
        assert!(selected.contains(&2));
        assert!(selected.contains(&4));
    }

    #[test]
    fn test_cache_bonus_increases_resident_score() {
        let mut router = make_router(4, 1, 0.3);
        router.update_cache_state(&[1]);
        // Expert 1 scores 0.3 + 0.3 = 0.6 and overtakes expert 0 (0.4).
        let gate_logits = vec![0.4, 0.3, 0.2, 0.1];
        let (selected, _) = router.route(&gate_logits);
        assert_eq!(selected, vec![1]);
    }

    #[test]
    fn test_top_k_selection() {
        let mut router = make_router(8, 3, 0.0);
        let gate_logits = vec![0.8, 0.1, 0.2, 0.7, 0.3, 0.6, 0.4, 0.5];
        let (selected, _) = router.route(&gate_logits);
        assert_eq!(selected.len(), 3);
        assert_eq!(selected[0], 0);
        assert_eq!(selected[1], 3);
        assert_eq!(selected[2], 5);
    }

    #[test]
    fn test_paging_requests_for_non_resident() {
        let mut router = make_router(4, 2, 0.0);
        router.update_cache_state(&[0]);
        let gate_logits = vec![0.5, 0.6, 0.4, 0.3];
        let (selected, paging) = router.route(&gate_logits);
        assert!(selected.contains(&0));
        assert!(selected.contains(&1));
        assert_eq!(paging.len(), 1);
        assert_eq!(paging[0].expert_id, 1);
        assert_eq!(paging[0].direction, PagingDirection::In);
        assert_eq!(paging[0].priority, PagingPriority::Urgent);
    }

    #[test]
    fn test_router_determinism() {
        let mut router1 = make_router(8, 2, 0.15);
        let mut router2 = make_router(8, 2, 0.15);
        router1.update_cache_state(&[0, 3, 5]);
        router2.update_cache_state(&[0, 3, 5]);
        let gate_logits = vec![0.1, 0.3, 0.5, 0.2, 0.4, 0.1, 0.2, 0.15];
        let (selected1, paging1) = router1.route(&gate_logits);
        let (selected2, paging2) = router2.route(&gate_logits);
        assert_eq!(
            selected1, selected2,
            "INV-6 violation: different expert selection"
        );
        assert_eq!(
            paging1.len(),
            paging2.len(),
            "INV-6 violation: different paging count"
        );
        router1.reset_metrics();
        let (selected3, _) = router1.route(&gate_logits);
        assert_eq!(
            selected1, selected3,
            "INV-6 violation: non-deterministic routing"
        );
    }

    #[test]
    fn test_affinity_updates() {
        let mut router = make_router(4, 2, 0.0);
        let gate_logits = vec![0.4, 0.3, 0.5, 0.1];
        for _ in 0..5 {
            router.route(&gate_logits);
        }
        let top = router.affinity().top_k_by_affinity(2);
        assert!(top.contains(&2), "Expert 2 should have high affinity");
        assert!(top.contains(&0), "Expert 0 should have high affinity");
    }

    #[test]
    fn test_zero_cache_bonus_fallback() {
        let mut router = make_router(4, 2, 0.0);
        router.update_cache_state(&[0, 1, 2, 3]);
        let gate_logits = vec![0.1, 0.4, 0.3, 0.2];
        let (selected, _) = router.route(&gate_logits);
        assert_eq!(selected[0], 1);
        assert_eq!(selected[1], 2);
    }

    #[test]
    fn test_all_experts_resident() {
        let mut router = make_router(4, 2, 0.15);
        router.update_cache_state(&[0, 1, 2, 3]);
        let gate_logits = vec![0.1, 0.4, 0.3, 0.2];
        let (selected, paging) = router.route(&gate_logits);
        assert_eq!(selected.len(), 2);
        assert!(
            paging.is_empty(),
            "No paging should be needed when all selected are resident"
        );
        // `route` only records hit/miss counts with the `routing-metrics` feature.
        #[cfg(feature = "routing-metrics")]
        {
            assert_eq!(router.metrics().cache_hits, 2);
            assert_eq!(router.metrics().cache_misses, 0);
        }
    }

    #[test]
    fn test_no_experts_resident() {
        let mut router = make_router(4, 2, 0.15);
        router.update_cache_state(&[]);
        let gate_logits = vec![0.1, 0.4, 0.3, 0.2];
        let (selected, paging) = router.route(&gate_logits);
        assert_eq!(selected.len(), 2);
        assert_eq!(
            paging.len(),
            2,
            "Should need to page in all selected experts"
        );
        // `route` only records hit/miss counts with the `routing-metrics` feature.
        #[cfg(feature = "routing-metrics")]
        {
            assert_eq!(router.metrics().cache_misses, 2);
            assert_eq!(router.metrics().cache_hits, 0);
        }
    }

    #[test]
    fn test_config_validation() {
        let valid = RouterConfig::new(8, 2);
        assert!(valid.validate().is_ok());
        let invalid1 = RouterConfig {
            top_k: 0,
            ..RouterConfig::default()
        };
        assert!(invalid1.validate().is_err());
        let invalid2 = RouterConfig {
            top_k: 10,
            num_experts: 8,
            ..RouterConfig::default()
        };
        assert!(invalid2.validate().is_err());
        let invalid3 = RouterConfig {
            num_experts: 0,
            ..RouterConfig::default()
        };
        assert!(invalid3.validate().is_err());
    }

    #[test]
    fn test_memory_aware_disabled() {
        let config = RouterConfig::new(4, 2)
            .with_memory_aware(false)
            .with_cache_bonus(0.5);
        let mut router = MemoryAwareRouter::with_default_affinity(config).unwrap();
        router.update_cache_state(&[3]);
        // With the bonus disabled, resident expert 3 (0.2) must not be promoted.
        let gate_logits = vec![0.4, 0.3, 0.5, 0.2];
        let (selected, _) = router.route(&gate_logits);
        assert_eq!(selected[0], 2);
        assert_eq!(selected[1], 0);
    }

    // Hit-rate data is only recorded by `route` with the `routing-metrics` feature.
    #[cfg(feature = "routing-metrics")]
    #[test]
    fn test_hit_rate_tracking() {
        let mut router = make_router(4, 2, 0.0);
        router.update_cache_state(&[0, 2]);
        let gate_logits = vec![0.4, 0.3, 0.5, 0.2];
        router.route(&gate_logits);
        assert_eq!(router.hit_rate(), 1.0);
        router.reset_metrics();
        router.update_cache_state(&[1, 3]);
        router.route(&gate_logits);
        assert_eq!(router.hit_rate(), 0.0);
    }

    #[test]
    fn test_prefetch_requests() {
        let config = RouterConfig::new(4, 2).with_cache_bonus(0.0);
        let affinity_config = AffinityConfig::with_num_experts(4).with_decay(1.0);
        let affinity = ExpertAffinity::new(affinity_config);
        let mut router = MemoryAwareRouter::new(config, affinity).unwrap();
        let gate_logits = vec![0.4, 0.3, 0.5, 0.2];
        for _ in 0..10 {
            router.route(&gate_logits);
        }
        router.update_cache_state(&[1]);
        let prefetch = router.generate_prefetch_requests(2);
        for req in &prefetch {
            assert_ne!(req.expert_id, 1);
            assert_eq!(req.priority, PagingPriority::Prefetch);
        }
    }

    #[test]
    fn test_resident_experts_list() {
        let mut router = make_router(8, 2, 0.15);
        router.update_cache_state(&[1, 3, 5, 7]);
        let resident = router.resident_experts();
        assert_eq!(resident.len(), 4);
        assert!(resident.contains(&1));
        assert!(resident.contains(&3));
        assert!(resident.contains(&5));
        assert!(resident.contains(&7));
        assert!(!resident.contains(&0));
    }

    #[test]
    fn test_set_resident() {
        let mut router = make_router(4, 2, 0.15);
        assert!(!router.is_resident(0));
        router.set_resident(0, true);
        assert!(router.is_resident(0));
        router.set_resident(0, false);
        assert!(!router.is_resident(0));
    }

    #[test]
    fn test_tie_breaking_determinism() {
        let mut router = make_router(4, 2, 0.0);
        let gate_logits = vec![0.5, 0.5, 0.5, 0.5];
        let (selected1, _) = router.route(&gate_logits);
        let (selected2, _) = router.route(&gate_logits);
        assert_eq!(selected1, selected2);
        // Equal scores resolve to the lowest ids.
        assert_eq!(selected1, vec![0, 1]);
    }

    #[test]
    fn test_invalid_gate_logits_length() {
        let mut router = make_router(4, 2, 0.15);
        // Two logits for a four-expert router: `route` takes its fallback path.
        let gate_logits = vec![0.5, 0.3];
        let (selected, paging) = router.route(&gate_logits);
        assert_eq!(selected, vec![0, 1]);
        assert!(paging.is_empty());
    }

    #[test]
    fn test_apply_cache_bonus() {
        let mut router = make_router(4, 2, 0.2);
        router.update_cache_state(&[1, 2]);
        let scores = vec![0.1, 0.3, 0.4, 0.5];
        let adjusted = router.apply_cache_bonus(&scores);
        assert!((adjusted[0] - 0.1).abs() < 1e-6);
        assert!((adjusted[1] - 0.5).abs() < 1e-6);
        assert!((adjusted[2] - 0.6).abs() < 1e-6);
        assert!((adjusted[3] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_paging_request_constructors() {
        let req1 = PagingRequest::page_in_urgent(5);
        assert_eq!(req1.expert_id, 5);
        assert_eq!(req1.direction, PagingDirection::In);
        assert_eq!(req1.priority, PagingPriority::Urgent);
        let req2 = PagingRequest::prefetch(3);
        assert_eq!(req2.expert_id, 3);
        assert_eq!(req2.direction, PagingDirection::In);
        assert_eq!(req2.priority, PagingPriority::Prefetch);
        let req3 = PagingRequest::page_out(7);
        assert_eq!(req3.expert_id, 7);
        assert_eq!(req3.direction, PagingDirection::Out);
        assert_eq!(req3.priority, PagingPriority::Normal);
    }

    #[test]
    fn test_config_builder() {
        let config = RouterConfig::new(16, 4)
            .with_cache_bonus(0.25)
            .with_memory_aware(true)
            .with_prefetch_threshold(0.15);
        assert_eq!(config.num_experts, 16);
        assert_eq!(config.top_k, 4);
        assert!((config.cache_bonus - 0.25).abs() < 1e-6);
        assert!(config.memory_aware);
        assert!((config.prefetch_threshold - 0.15).abs() < 1e-6);
    }

    #[test]
    fn test_cache_bonus_clamping() {
        let config = RouterConfig::new(8, 2).with_cache_bonus(1.5);
        assert!(
            (config.cache_bonus - 1.0).abs() < 1e-6,
            "cache_bonus should be clamped to 1.0"
        );
        let config2 = RouterConfig::new(8, 2).with_cache_bonus(-0.5);
        assert!(
            (config2.cache_bonus - 0.0).abs() < 1e-6,
            "cache_bonus should be clamped to 0.0"
        );
    }

    #[test]
    fn test_cache_mask_small() {
        let mut mask = CacheMask::new(64);
        for i in 0..64 {
            assert!(!mask.is_set(i), "Bit {} should be clear initially", i);
        }
        mask.set(0, true);
        mask.set(31, true);
        mask.set(63, true);
        assert!(mask.is_set(0));
        assert!(mask.is_set(31));
        assert!(mask.is_set(63));
        assert!(!mask.is_set(1));
        assert!(!mask.is_set(32));
        assert_eq!(mask.count(), 3);
        let list = mask.resident_list();
        assert_eq!(list.len(), 3);
        assert!(list.contains(&0));
        assert!(list.contains(&31));
        assert!(list.contains(&63));
        mask.clear();
        assert_eq!(mask.count(), 0);
        assert!(!mask.is_set(0));
    }

    #[test]
    fn test_cache_mask_large() {
        let mut mask = CacheMask::new(128);
        mask.set(0, true);
        mask.set(63, true);
        // 64 is the first bit of the second storage word.
        mask.set(64, true);
        mask.set(127, true);
        assert!(mask.is_set(0));
        assert!(mask.is_set(63));
        assert!(mask.is_set(64));
        assert!(mask.is_set(127));
        assert!(!mask.is_set(65));
        assert_eq!(mask.count(), 4);
        let list = mask.resident_list();
        assert_eq!(list.len(), 4);
        mask.clear();
        assert_eq!(mask.count(), 0);
    }

    #[test]
    fn test_cache_mask_out_of_bounds() {
        let mut mask = CacheMask::new(8);
        mask.set(100, true);
        assert!(!mask.is_set(100));
        assert_eq!(mask.count(), 0);
    }

    #[test]
    fn test_router_with_many_experts() {
        let config = RouterConfig::new(128, 4);
        let mut router = MemoryAwareRouter::with_default_affinity(config).unwrap();
        router.update_cache_state(&[0, 32, 64, 96, 127]);
        assert!(router.is_resident(0));
        assert!(router.is_resident(64));
        assert!(router.is_resident(127));
        assert!(!router.is_resident(1));
        let resident = router.resident_experts();
        assert_eq!(resident.len(), 5);
    }

    #[test]
    fn test_empty_cache_state() {
        let mut router = make_router(8, 2, 0.15);
        router.update_cache_state(&[]);
        for i in 0..8 {
            assert!(
                !router.is_resident(i),
                "Expert {} should not be resident",
                i
            );
        }
        assert!(router.resident_experts().is_empty());
    }
}