1use crate::types::Expert;
10use async_trait::async_trait;
11use std::collections::hash_map::DefaultHasher;
12use std::hash::{Hash, Hasher};
13use std::sync::Arc;
14
/// Outcome of a routing decision.
#[derive(Debug, Clone)]
pub struct RouteResult {
    /// Index of the chosen expert.
    pub expert_index: usize,
    /// Router's confidence in the choice, always within `[0.0, 1.0]`.
    pub confidence: f64,
    /// Fallback expert indices, best candidates first.
    pub alternatives: Vec<usize>,
    /// Free-form key/value annotations attached by the router.
    pub metadata: std::collections::HashMap<String, String>,
}

impl RouteResult {
    /// Result pointing at `expert_index` with full confidence and no
    /// alternatives or metadata.
    pub fn new(expert_index: usize) -> Self {
        Self {
            expert_index,
            confidence: 1.0,
            alternatives: Vec::new(),
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Set the confidence, clamped into `[0.0, 1.0]`.
    pub fn with_confidence(mut self, confidence: f64) -> Self {
        self.confidence = confidence.clamp(0.0, 1.0);
        self
    }

    /// Replace the list of alternative expert indices.
    pub fn with_alternatives(mut self, alternatives: Vec<usize>) -> Self {
        self.alternatives = alternatives;
        self
    }

    /// Attach one metadata entry. Completes the builder API so `metadata`
    /// can be populated fluently like `confidence` and `alternatives`.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }
}
51
52#[async_trait]
80pub trait MoERouter: Send + Sync {
81 async fn route(&self, input: &str, num_experts: usize) -> RouteResult;
90
91 async fn route_with_experts(&self, input: &str, experts: &[Expert]) -> RouteResult {
95 let available: Vec<_> = experts.iter().filter(|e| e.available()).collect();
96 if available.is_empty() {
97 self.route(input, experts.len()).await
99 } else {
100 let result = self.route(input, available.len()).await;
101 RouteResult::new(available[result.expert_index].index)
102 .with_confidence(result.confidence)
103 }
104 }
105
106 fn name(&self) -> &str {
108 "custom"
109 }
110}
111
/// Stateless router that hash-maps inputs to expert indices.
#[derive(Debug, Clone)]
pub struct DefaultMoERouter {
    /// Number of virtual shards (invariant: `>= 1`).
    ///
    /// NOTE(review): currently stored but never consulted by `route` —
    /// confirm whether shard expansion is still planned.
    virtual_shards: usize,
}

impl Default for DefaultMoERouter {
    /// Same configuration as [`DefaultMoERouter::new`].
    ///
    /// Implemented manually: `#[derive(Default)]` would set
    /// `virtual_shards` to 0, violating the `>= 1` invariant that
    /// `new` and `with_virtual_shards` maintain.
    fn default() -> Self {
        Self::new()
    }
}

impl DefaultMoERouter {
    /// Router with a single virtual shard.
    pub fn new() -> Self {
        Self { virtual_shards: 1 }
    }

    /// Set the shard count; values below 1 are raised to 1.
    pub fn with_virtual_shards(mut self, shards: usize) -> Self {
        self.virtual_shards = shards.max(1);
        self
    }

    /// Deterministic 64-bit hash of `input` (stable within one process).
    fn hash_input(&self, input: &str) -> u64 {
        let mut hasher = DefaultHasher::new();
        input.hash(&mut hasher);
        hasher.finish()
    }
}
140
141#[async_trait]
142impl MoERouter for DefaultMoERouter {
143 async fn route(&self, input: &str, num_experts: usize) -> RouteResult {
144 if num_experts == 0 {
145 return RouteResult::new(0);
146 }
147
148 let hash = self.hash_input(input);
149 let expert_index = (hash % num_experts as u64) as usize;
150
151 RouteResult::new(expert_index).with_confidence(1.0)
152 }
153
154 fn name(&self) -> &str {
155 "default-hash"
156 }
157}
158
/// Router that balances load while keeping some input-to-expert affinity.
#[derive(Debug, Clone)]
pub struct LoadAwareMoERouter {
    /// Max load gap (vs. the least-loaded expert) within which the
    /// hash-affine expert is still preferred; clamped to `[0.0, 1.0]`.
    affinity_threshold: f64,
    /// Hash router used for plain index routing and the affinity pick.
    fallback: DefaultMoERouter,
}
170
171impl LoadAwareMoERouter {
172 pub fn new() -> Self {
174 Self {
175 affinity_threshold: 0.1,
176 fallback: DefaultMoERouter::new(),
177 }
178 }
179
180 pub fn with_affinity_threshold(mut self, threshold: f64) -> Self {
184 self.affinity_threshold = threshold.clamp(0.0, 1.0);
185 self
186 }
187}
188
impl Default for LoadAwareMoERouter {
    /// Equivalent to [`LoadAwareMoERouter::new`].
    fn default() -> Self {
        Self::new()
    }
}
194
195#[async_trait]
196impl MoERouter for LoadAwareMoERouter {
197 async fn route(&self, input: &str, num_experts: usize) -> RouteResult {
198 self.fallback.route(input, num_experts).await
200 }
201
202 async fn route_with_experts(&self, input: &str, experts: &[Expert]) -> RouteResult {
203 let available: Vec<_> = experts.iter().filter(|e| e.available()).collect();
204
205 if available.is_empty() {
206 return RouteResult::new(0);
207 }
208
209 let min_load = available
211 .iter()
212 .map(|e| e.load)
213 .fold(f64::INFINITY, f64::min);
214
215 let affinity_result = self.fallback.route(input, experts.len()).await;
217 let affinity_expert = experts.get(affinity_result.expert_index);
218
219 if let Some(expert) = affinity_expert {
221 if expert.available() && (expert.load - min_load) < self.affinity_threshold {
222 return RouteResult::new(expert.index).with_confidence(0.9);
223 }
224 }
225
226 let selected = available
228 .iter()
229 .min_by(|a, b| a.load.partial_cmp(&b.load).unwrap_or(std::cmp::Ordering::Equal))
230 .expect("available is non-empty, checked above");
231
232 let alternatives: Vec<_> = available
233 .iter()
234 .filter(|e| e.index != selected.index)
235 .take(2)
236 .map(|e| e.index)
237 .collect();
238
239 RouteResult::new(selected.index)
240 .with_confidence(0.8)
241 .with_alternatives(alternatives)
242 }
243
244 fn name(&self) -> &str {
245 "load-aware"
246 }
247}
248
/// Router that cycles through experts in order, ignoring the input.
#[derive(Debug)]
pub struct RoundRobinMoERouter {
    /// Monotonic ticket counter; `ticket % num_experts` picks the expert.
    counter: std::sync::atomic::AtomicUsize,
}
256
257impl RoundRobinMoERouter {
258 pub fn new() -> Self {
260 Self {
261 counter: std::sync::atomic::AtomicUsize::new(0),
262 }
263 }
264}
265
impl Default for RoundRobinMoERouter {
    /// Equivalent to [`RoundRobinMoERouter::new`].
    fn default() -> Self {
        Self::new()
    }
}
271
272#[async_trait]
273impl MoERouter for RoundRobinMoERouter {
274 async fn route(&self, _input: &str, num_experts: usize) -> RouteResult {
275 if num_experts == 0 {
276 return RouteResult::new(0);
277 }
278
279 let count = self
280 .counter
281 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
282 let expert_index = count % num_experts;
283
284 RouteResult::new(expert_index)
285 }
286
287 fn name(&self) -> &str {
288 "round-robin"
289 }
290}
291
/// Router that prefers experts backed by suitable GPUs.
#[derive(Debug, Clone)]
pub struct GpuAwareMoERouter {
    /// Minimum free GPU memory (MB) an expert's GPU must report.
    min_memory_mb: u64,
    /// Softly prefer experts whose GPU reports tensor cores.
    prefer_tensor_cores: bool,
    /// Used when no expert passes the GPU filter.
    fallback: LoadAwareMoERouter,
}
305
306impl GpuAwareMoERouter {
307 pub fn new() -> Self {
309 Self {
310 min_memory_mb: 0,
311 prefer_tensor_cores: false,
312 fallback: LoadAwareMoERouter::new(),
313 }
314 }
315
316 pub fn with_min_memory(mut self, memory_mb: u64) -> Self {
318 self.min_memory_mb = memory_mb;
319 self
320 }
321
322 pub fn prefer_tensor_cores(mut self, prefer: bool) -> Self {
324 self.prefer_tensor_cores = prefer;
325 self
326 }
327}
328
impl Default for GpuAwareMoERouter {
    /// Equivalent to [`GpuAwareMoERouter::new`].
    fn default() -> Self {
        Self::new()
    }
}
334
#[async_trait]
impl MoERouter for GpuAwareMoERouter {
    // Index-only routing has no GPU info; delegate to the load-aware fallback.
    async fn route(&self, input: &str, num_experts: usize) -> RouteResult {
        self.fallback.route(input, num_experts).await
    }

    async fn route_with_experts(&self, input: &str, experts: &[Expert]) -> RouteResult {
        // Keep only available experts whose GPU is itself available and has
        // at least `min_memory_mb` of free memory.
        let gpu_experts: Vec<_> = experts
            .iter()
            .filter(|e| {
                e.available() && e.gpu.as_ref().map(|g| {
                    g.available() && g.available_memory_mb() >= self.min_memory_mb
                }).unwrap_or(false)
            })
            .collect();

        // No suitable GPU expert: degrade gracefully to load-aware routing
        // over the full expert set.
        if gpu_experts.is_empty() {
            return self.fallback.route_with_experts(input, experts).await;
        }

        // Tensor-core preference is soft: if no candidate has tensor cores,
        // fall back to the full GPU candidate list rather than failing.
        let candidates: Vec<_> = if self.prefer_tensor_cores {
            let tensor_experts: Vec<_> = gpu_experts
                .iter()
                .filter(|e| e.gpu.as_ref().map(|g| g.tensor_cores).unwrap_or(false))
                .copied()
                .collect();
            if tensor_experts.is_empty() { gpu_experts } else { tensor_experts }
        } else {
            gpu_experts
        };

        // Pick the lowest GPU utilization; ties broken by the most free
        // memory. `unwrap()` is safe: the filter above guarantees every
        // candidate has `gpu == Some(_)`.
        let selected = candidates
            .iter()
            .min_by(|a, b| {
                let a_gpu = a.gpu.as_ref().unwrap();
                let b_gpu = b.gpu.as_ref().unwrap();
                a_gpu.utilization.partial_cmp(&b_gpu.utilization)
                    .unwrap_or(std::cmp::Ordering::Equal)
                    // Operands reversed on purpose: larger free memory sorts first.
                    .then_with(|| b_gpu.available_memory_mb().cmp(&a_gpu.available_memory_mb()))
            })
            .expect("candidates is non-empty");

        // Up to two other candidates, in slice order, offered as fallbacks.
        let alternatives: Vec<_> = candidates
            .iter()
            .filter(|e| e.index != selected.index)
            .take(2)
            .map(|e| e.index)
            .collect();

        RouteResult::new(selected.index)
            .with_confidence(0.9)
            .with_alternatives(alternatives)
    }

    fn name(&self) -> &str {
        "gpu-aware"
    }
}
399
/// Router that pins traffic to a model version, with an optional canary split.
#[derive(Debug, Clone)]
pub struct VersionAwareMoERouter {
    /// Version served to non-canary traffic (`None` = any version).
    target_version: Option<String>,
    /// Percent of inputs (by sticky hash) sent to the canary version, 0-100.
    canary_percent: u8,
    /// Version served to canary traffic.
    canary_version: Option<String>,
    /// Load-aware routing applied after the version filter.
    fallback: LoadAwareMoERouter,
}
412
413impl VersionAwareMoERouter {
414 pub fn new() -> Self {
416 Self {
417 target_version: None,
418 canary_percent: 0,
419 canary_version: None,
420 fallback: LoadAwareMoERouter::new(),
421 }
422 }
423
424 pub fn with_version(mut self, version: impl Into<String>) -> Self {
426 self.target_version = Some(version.into());
427 self
428 }
429
430 pub fn with_canary(mut self, version: impl Into<String>, percent: u8) -> Self {
432 self.canary_version = Some(version.into());
433 self.canary_percent = percent.min(100);
434 self
435 }
436}
437
impl Default for VersionAwareMoERouter {
    /// Equivalent to [`VersionAwareMoERouter::new`].
    fn default() -> Self {
        Self::new()
    }
}
443
#[async_trait]
impl MoERouter for VersionAwareMoERouter {
    // Index-only routing has no version info; delegate to the load-aware fallback.
    async fn route(&self, input: &str, num_experts: usize) -> RouteResult {
        self.fallback.route(input, num_experts).await
    }

    async fn route_with_experts(&self, input: &str, experts: &[Expert]) -> RouteResult {
        // Sticky canary split: hash the input so the same input always lands
        // on the same side of the canary boundary.
        let use_canary = if self.canary_percent > 0 && self.canary_version.is_some() {
            let hash = {
                use std::hash::{Hash, Hasher};
                let mut hasher = std::collections::hash_map::DefaultHasher::new();
                input.hash(&mut hasher);
                hasher.finish()
            };
            (hash % 100) < self.canary_percent as u64
        } else {
            false
        };

        // Decide which version constraint applies to this input.
        let target = if use_canary {
            self.canary_version.as_ref()
        } else {
            self.target_version.as_ref()
        };

        // Filter to available experts running the requested version
        // (or all available experts when no version is pinned).
        let versioned_experts: Vec<_> = if let Some(version) = target {
            experts
                .iter()
                .filter(|e| e.available() && e.model_version.as_ref() == Some(version))
                .collect()
        } else {
            experts.iter().filter(|e| e.available()).collect()
        };

        // NOTE(review): when the pinned/canary version has no available
        // expert, this falls back to load-aware routing across ALL
        // versions — confirm that serving a different version is acceptable.
        if versioned_experts.is_empty() {
            return self.fallback.route_with_experts(input, experts).await;
        }

        // Least-loaded expert among the version-matching candidates.
        let selected = versioned_experts
            .iter()
            .min_by(|a, b| a.load.partial_cmp(&b.load).unwrap_or(std::cmp::Ordering::Equal))
            .expect("versioned_experts is non-empty");

        // Canary picks carry lower confidence so callers can treat them as
        // more tentative.
        RouteResult::new(selected.index)
            .with_confidence(if use_canary { 0.7 } else { 0.9 })
    }

    fn name(&self) -> &str {
        "version-aware"
    }
}
499
500pub type BoxedMoERouter = Arc<dyn MoERouter>;
502
503pub fn default_router() -> BoxedMoERouter {
505 Arc::new(DefaultMoERouter::new())
506}
507
508pub fn load_aware_router() -> BoxedMoERouter {
510 Arc::new(LoadAwareMoERouter::new())
511}
512
513pub fn gpu_aware_router() -> BoxedMoERouter {
515 Arc::new(GpuAwareMoERouter::new())
516}
517
518pub fn version_aware_router() -> BoxedMoERouter {
520 Arc::new(VersionAwareMoERouter::new())
521}
522
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::NodeId;

    // Same input + same expert count must always map to the same expert.
    #[tokio::test]
    async fn test_default_router_consistency() {
        let router = DefaultMoERouter::new();

        let result1 = router.route("test-input", 8).await;
        let result2 = router.route("test-input", 8).await;

        assert_eq!(result1.expert_index, result2.expert_index);
    }

    // Hash routing should spread 1000 distinct inputs roughly evenly
    // across 4 experts (~250 each; wide tolerance band of 150-350).
    #[tokio::test]
    async fn test_default_router_distribution() {
        let router = DefaultMoERouter::new();
        let mut counts = vec![0usize; 4];

        for i in 0..1000 {
            let input = format!("input-{}", i);
            let result = router.route(&input, 4).await;
            counts[result.expert_index] += 1;
        }

        for count in counts {
            assert!(count > 150 && count < 350, "Uneven distribution: {}", count);
        }
    }

    // The heavily loaded expert (index 1, load 0.9) must not be selected
    // while idle experts (0 and 2) are available.
    #[tokio::test]
    async fn test_load_aware_router() {
        let router = LoadAwareMoERouter::new();

        let experts = vec![
            Expert::new(0, NodeId::new()),
            {
                let mut e = Expert::new(1, NodeId::new());
                e.update_load(0.9);
                e
            },
            Expert::new(2, NodeId::new()),
        ];

        let result = router.route_with_experts("test", &experts).await;

        assert_ne!(result.expert_index, 1);
    }

    // Round-robin cycles 0, 1, 2 and then wraps back to 0.
    #[tokio::test]
    async fn test_round_robin_router() {
        let router = RoundRobinMoERouter::new();

        let r0 = router.route("a", 3).await;
        let r1 = router.route("b", 3).await;
        let r2 = router.route("c", 3).await;
        let r3 = router.route("d", 3).await;

        assert_eq!(r0.expert_index, 0);
        assert_eq!(r1.expert_index, 1);
        assert_eq!(r2.expert_index, 2);
        assert_eq!(r3.expert_index, 0);
    }
}