// forge_orchestration/moe.rs

//! Mixture of Experts (MoE) routing for Forge
//!
//! ## Table of Contents
//! - **MoERouter**: Trait for implementing custom routing logic
//! - **DefaultMoERouter**: Hash-based default router
//! - **LoadAwareMoERouter**: Load-balanced routing
//! - **RouteResult**: Routing decision with metadata

use crate::types::Expert;
use async_trait::async_trait;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

/// Result of a routing decision
///
/// Carries the chosen expert plus an optional confidence score, fallback
/// candidates, and free-form metadata about how the decision was made.
#[derive(Debug, Clone)]
pub struct RouteResult {
    /// Selected expert index
    pub expert_index: usize,
    /// Confidence score (0.0 - 1.0)
    pub confidence: f64,
    /// Alternative experts (for fallback)
    pub alternatives: Vec<usize>,
    /// Routing metadata
    pub metadata: std::collections::HashMap<String, String>,
}

impl RouteResult {
    /// Build a result targeting `expert_index` with full confidence and no
    /// alternatives or metadata.
    pub fn new(expert_index: usize) -> Self {
        RouteResult {
            expert_index,
            confidence: 1.0,
            alternatives: Vec::new(),
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Builder-style setter for the confidence score; the value is clamped
    /// into the [0.0, 1.0] range.
    pub fn with_confidence(mut self, confidence: f64) -> Self {
        self.confidence = confidence.clamp(0.0, 1.0);
        self
    }

    /// Builder-style setter for the fallback expert list.
    pub fn with_alternatives(mut self, alternatives: Vec<usize>) -> Self {
        self.alternatives = alternatives;
        self
    }
}

52/// Trait for implementing MoE routing logic
53///
54/// Implement this trait to create custom routing strategies for
55/// distributing work across expert instances.
56///
57/// # Example
58///
59/// ```rust,ignore
60/// use forge_orchestration::moe::{MoERouter, RouteResult};
61/// use async_trait::async_trait;
62///
63/// struct TypeBasedRouter;
64///
65/// #[async_trait]
66/// impl MoERouter for TypeBasedRouter {
67///     async fn route(&self, input: &str, num_experts: usize) -> RouteResult {
68///         let expert = if input.starts_with("code:") {
69///             0  // Code expert
70///         } else if input.starts_with("math:") {
71///             1  // Math expert
72///         } else {
73///             2  // General expert
74///         };
75///         RouteResult::new(expert % num_experts)
76///     }
77/// }
78/// ```
79#[async_trait]
80pub trait MoERouter: Send + Sync {
81    /// Route an input to an expert
82    ///
83    /// # Arguments
84    /// * `input` - The input string/key to route
85    /// * `num_experts` - Total number of available experts
86    ///
87    /// # Returns
88    /// A `RouteResult` containing the selected expert and metadata
89    async fn route(&self, input: &str, num_experts: usize) -> RouteResult;
90
91    /// Route with expert health information
92    ///
93    /// Override this for load-aware routing
94    async fn route_with_experts(&self, input: &str, experts: &[Expert]) -> RouteResult {
95        let available: Vec<_> = experts.iter().filter(|e| e.available()).collect();
96        if available.is_empty() {
97            // Fallback to any expert if none available
98            self.route(input, experts.len()).await
99        } else {
100            let result = self.route(input, available.len()).await;
101            RouteResult::new(available[result.expert_index].index)
102                .with_confidence(result.confidence)
103        }
104    }
105
106    /// Get router name for metrics/logging
107    fn name(&self) -> &str {
108        "custom"
109    }
110}
111
/// Default hash-based MoE router
///
/// Routes inputs consistently using hash-based sharding.
/// Same input always routes to the same expert (assuming stable expert count).
#[derive(Debug, Clone)]
pub struct DefaultMoERouter {
    /// Number of virtual shards per expert (for better distribution);
    /// invariant: always >= 1.
    virtual_shards: usize,
}

impl DefaultMoERouter {
    /// Create a new default router
    pub fn new() -> Self {
        Self { virtual_shards: 1 }
    }

    /// Create with virtual sharding for better distribution
    ///
    /// A value of 0 is clamped up to 1, the minimum meaningful shard count.
    pub fn with_virtual_shards(mut self, shards: usize) -> Self {
        self.virtual_shards = shards.max(1);
        self
    }

    /// Hash an input key with the standard library's default hasher.
    fn hash_input(&self, input: &str) -> u64 {
        let mut hasher = DefaultHasher::new();
        input.hash(&mut hasher);
        hasher.finish()
    }
}

// `Default` is implemented manually rather than derived: the derive would
// produce `virtual_shards: 0`, which breaks the `>= 1` invariant that
// `new()` and `with_virtual_shards()` both maintain.
impl Default for DefaultMoERouter {
    fn default() -> Self {
        Self::new()
    }
}

141#[async_trait]
142impl MoERouter for DefaultMoERouter {
143    async fn route(&self, input: &str, num_experts: usize) -> RouteResult {
144        if num_experts == 0 {
145            return RouteResult::new(0);
146        }
147
148        let hash = self.hash_input(input);
149        let expert_index = (hash % num_experts as u64) as usize;
150
151        RouteResult::new(expert_index).with_confidence(1.0)
152    }
153
154    fn name(&self) -> &str {
155        "default-hash"
156    }
157}
158
/// Load-aware MoE router
///
/// Routes to the least loaded available expert, with optional
/// affinity for consistent routing when loads are similar.
#[derive(Debug, Clone)]
pub struct LoadAwareMoERouter {
    /// Load difference threshold for preferring affinity; kept within
    /// [0.0, 1.0] by `with_affinity_threshold`.
    affinity_threshold: f64,
    /// Fallback router used for hash-based affinity decisions
    fallback: DefaultMoERouter,
}

171impl LoadAwareMoERouter {
172    /// Create a new load-aware router
173    pub fn new() -> Self {
174        Self {
175            affinity_threshold: 0.1,
176            fallback: DefaultMoERouter::new(),
177        }
178    }
179
180    /// Set affinity threshold
181    ///
182    /// If load difference is below this threshold, prefer consistent routing
183    pub fn with_affinity_threshold(mut self, threshold: f64) -> Self {
184        self.affinity_threshold = threshold.clamp(0.0, 1.0);
185        self
186    }
187}
188
189impl Default for LoadAwareMoERouter {
190    fn default() -> Self {
191        Self::new()
192    }
193}
194
195#[async_trait]
196impl MoERouter for LoadAwareMoERouter {
197    async fn route(&self, input: &str, num_experts: usize) -> RouteResult {
198        // Without expert info, fall back to hash routing
199        self.fallback.route(input, num_experts).await
200    }
201
202    async fn route_with_experts(&self, input: &str, experts: &[Expert]) -> RouteResult {
203        let available: Vec<_> = experts.iter().filter(|e| e.available()).collect();
204
205        if available.is_empty() {
206            return RouteResult::new(0);
207        }
208
209        // Find least loaded expert
210        let min_load = available
211            .iter()
212            .map(|e| e.load)
213            .fold(f64::INFINITY, f64::min);
214
215        // Get affinity expert from hash
216        let affinity_result = self.fallback.route(input, experts.len()).await;
217        let affinity_expert = experts.get(affinity_result.expert_index);
218
219        // If affinity expert is available and load is close to minimum, use it
220        if let Some(expert) = affinity_expert {
221            if expert.available() && (expert.load - min_load) < self.affinity_threshold {
222                return RouteResult::new(expert.index).with_confidence(0.9);
223            }
224        }
225
226        // Otherwise, pick least loaded
227        let selected = available
228            .iter()
229            .min_by(|a, b| a.load.partial_cmp(&b.load).unwrap_or(std::cmp::Ordering::Equal))
230            .expect("available is non-empty, checked above");
231
232        let alternatives: Vec<_> = available
233            .iter()
234            .filter(|e| e.index != selected.index)
235            .take(2)
236            .map(|e| e.index)
237            .collect();
238
239        RouteResult::new(selected.index)
240            .with_confidence(0.8)
241            .with_alternatives(alternatives)
242    }
243
244    fn name(&self) -> &str {
245        "load-aware"
246    }
247}
248
/// Round-robin MoE router
///
/// Distributes requests evenly across experts in order.
#[derive(Debug)]
pub struct RoundRobinMoERouter {
    /// Monotonically increasing request counter; the expert is derived as
    /// `counter % num_experts`.
    counter: std::sync::atomic::AtomicUsize,
}

impl RoundRobinMoERouter {
    /// Create a round-robin router whose counter starts at zero.
    pub fn new() -> Self {
        let counter = std::sync::atomic::AtomicUsize::new(0);
        RoundRobinMoERouter { counter }
    }
}

impl Default for RoundRobinMoERouter {
    /// Same as [`RoundRobinMoERouter::new`].
    fn default() -> Self {
        RoundRobinMoERouter::new()
    }
}

272#[async_trait]
273impl MoERouter for RoundRobinMoERouter {
274    async fn route(&self, _input: &str, num_experts: usize) -> RouteResult {
275        if num_experts == 0 {
276            return RouteResult::new(0);
277        }
278
279        let count = self
280            .counter
281            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
282        let expert_index = count % num_experts;
283
284        RouteResult::new(expert_index)
285    }
286
287    fn name(&self) -> &str {
288        "round-robin"
289    }
290}
291
/// GPU-aware MoE router for AI/ML workloads
///
/// Routes requests to experts with available GPU resources,
/// considering memory requirements and utilization.
#[derive(Debug, Clone)]
pub struct GpuAwareMoERouter {
    /// Minimum GPU memory required (MB); 0 means no memory requirement
    min_memory_mb: u64,
    /// Prefer tensor core capable GPUs when any are available
    prefer_tensor_cores: bool,
    /// Fallback router when no GPU experts available
    fallback: LoadAwareMoERouter,
}

306impl GpuAwareMoERouter {
307    /// Create a new GPU-aware router
308    pub fn new() -> Self {
309        Self {
310            min_memory_mb: 0,
311            prefer_tensor_cores: false,
312            fallback: LoadAwareMoERouter::new(),
313        }
314    }
315
316    /// Set minimum GPU memory requirement
317    pub fn with_min_memory(mut self, memory_mb: u64) -> Self {
318        self.min_memory_mb = memory_mb;
319        self
320    }
321
322    /// Prefer experts with tensor core support
323    pub fn prefer_tensor_cores(mut self, prefer: bool) -> Self {
324        self.prefer_tensor_cores = prefer;
325        self
326    }
327}
328
329impl Default for GpuAwareMoERouter {
330    fn default() -> Self {
331        Self::new()
332    }
333}
334
335#[async_trait]
336impl MoERouter for GpuAwareMoERouter {
337    async fn route(&self, input: &str, num_experts: usize) -> RouteResult {
338        // Without expert info, fall back to load-aware routing
339        self.fallback.route(input, num_experts).await
340    }
341
342    async fn route_with_experts(&self, input: &str, experts: &[Expert]) -> RouteResult {
343        // Filter to GPU-capable experts with sufficient resources
344        let gpu_experts: Vec<_> = experts
345            .iter()
346            .filter(|e| {
347                e.available() && e.gpu.as_ref().map(|g| {
348                    g.available() && g.available_memory_mb() >= self.min_memory_mb
349                }).unwrap_or(false)
350            })
351            .collect();
352
353        if gpu_experts.is_empty() {
354            // Fall back to load-aware routing if no GPU experts
355            return self.fallback.route_with_experts(input, experts).await;
356        }
357
358        // If preferring tensor cores, filter further
359        let candidates: Vec<_> = if self.prefer_tensor_cores {
360            let tensor_experts: Vec<_> = gpu_experts
361                .iter()
362                .filter(|e| e.gpu.as_ref().map(|g| g.tensor_cores).unwrap_or(false))
363                .copied()
364                .collect();
365            if tensor_experts.is_empty() { gpu_experts } else { tensor_experts }
366        } else {
367            gpu_experts
368        };
369
370        // Select expert with most available GPU memory and lowest utilization
371        let selected = candidates
372            .iter()
373            .min_by(|a, b| {
374                let a_gpu = a.gpu.as_ref().unwrap();
375                let b_gpu = b.gpu.as_ref().unwrap();
376                // Primary: GPU utilization, Secondary: available memory (higher is better)
377                a_gpu.utilization.partial_cmp(&b_gpu.utilization)
378                    .unwrap_or(std::cmp::Ordering::Equal)
379                    .then_with(|| b_gpu.available_memory_mb().cmp(&a_gpu.available_memory_mb()))
380            })
381            .expect("candidates is non-empty");
382
383        let alternatives: Vec<_> = candidates
384            .iter()
385            .filter(|e| e.index != selected.index)
386            .take(2)
387            .map(|e| e.index)
388            .collect();
389
390        RouteResult::new(selected.index)
391            .with_confidence(0.9)
392            .with_alternatives(alternatives)
393    }
394
395    fn name(&self) -> &str {
396        "gpu-aware"
397    }
398}
399
/// Model version-aware router for A/B testing and canary deployments
#[derive(Debug, Clone)]
pub struct VersionAwareMoERouter {
    /// Target model version (if specified)
    target_version: Option<String>,
    /// Percentage of traffic to route to canary (0-100)
    canary_percent: u8,
    /// Canary version; canary routing is inactive while this is `None`
    canary_version: Option<String>,
    /// Fallback router used when no expert matches the requested version
    fallback: LoadAwareMoERouter,
}

413impl VersionAwareMoERouter {
414    /// Create a new version-aware router
415    pub fn new() -> Self {
416        Self {
417            target_version: None,
418            canary_percent: 0,
419            canary_version: None,
420            fallback: LoadAwareMoERouter::new(),
421        }
422    }
423
424    /// Set target model version
425    pub fn with_version(mut self, version: impl Into<String>) -> Self {
426        self.target_version = Some(version.into());
427        self
428    }
429
430    /// Configure canary deployment
431    pub fn with_canary(mut self, version: impl Into<String>, percent: u8) -> Self {
432        self.canary_version = Some(version.into());
433        self.canary_percent = percent.min(100);
434        self
435    }
436}
437
438impl Default for VersionAwareMoERouter {
439    fn default() -> Self {
440        Self::new()
441    }
442}
443
444#[async_trait]
445impl MoERouter for VersionAwareMoERouter {
446    async fn route(&self, input: &str, num_experts: usize) -> RouteResult {
447        self.fallback.route(input, num_experts).await
448    }
449
450    async fn route_with_experts(&self, input: &str, experts: &[Expert]) -> RouteResult {
451        // Determine if this request should go to canary
452        let use_canary = if self.canary_percent > 0 && self.canary_version.is_some() {
453            // Simple hash-based routing for consistency
454            let hash = {
455                use std::hash::{Hash, Hasher};
456                let mut hasher = std::collections::hash_map::DefaultHasher::new();
457                input.hash(&mut hasher);
458                hasher.finish()
459            };
460            (hash % 100) < self.canary_percent as u64
461        } else {
462            false
463        };
464
465        let target = if use_canary {
466            self.canary_version.as_ref()
467        } else {
468            self.target_version.as_ref()
469        };
470
471        // Filter experts by version if specified
472        let versioned_experts: Vec<_> = if let Some(version) = target {
473            experts
474                .iter()
475                .filter(|e| e.available() && e.model_version.as_ref() == Some(version))
476                .collect()
477        } else {
478            experts.iter().filter(|e| e.available()).collect()
479        };
480
481        if versioned_experts.is_empty() {
482            return self.fallback.route_with_experts(input, experts).await;
483        }
484
485        // Route to least loaded matching expert
486        let selected = versioned_experts
487            .iter()
488            .min_by(|a, b| a.load.partial_cmp(&b.load).unwrap_or(std::cmp::Ordering::Equal))
489            .expect("versioned_experts is non-empty");
490
491        RouteResult::new(selected.index)
492            .with_confidence(if use_canary { 0.7 } else { 0.9 })
493    }
494
495    fn name(&self) -> &str {
496        "version-aware"
497    }
498}
499
500/// Type alias for boxed router
501pub type BoxedMoERouter = Arc<dyn MoERouter>;
502
503/// Create a boxed default router
504pub fn default_router() -> BoxedMoERouter {
505    Arc::new(DefaultMoERouter::new())
506}
507
508/// Create a boxed load-aware router
509pub fn load_aware_router() -> BoxedMoERouter {
510    Arc::new(LoadAwareMoERouter::new())
511}
512
513/// Create a boxed GPU-aware router
514pub fn gpu_aware_router() -> BoxedMoERouter {
515    Arc::new(GpuAwareMoERouter::new())
516}
517
518/// Create a boxed version-aware router
519pub fn version_aware_router() -> BoxedMoERouter {
520    Arc::new(VersionAwareMoERouter::new())
521}
522
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::NodeId;

    /// Same key + same expert count must always pick the same expert.
    #[tokio::test]
    async fn test_default_router_consistency() {
        let router = DefaultMoERouter::new();

        let first = router.route("test-input", 8).await;
        let second = router.route("test-input", 8).await;

        assert_eq!(first.expert_index, second.expert_index);
    }

    /// 1000 distinct keys over 4 experts should spread roughly evenly.
    #[tokio::test]
    async fn test_default_router_distribution() {
        let router = DefaultMoERouter::new();
        let mut hits = vec![0usize; 4];

        for i in 0..1000 {
            let key = format!("input-{}", i);
            hits[router.route(&key, 4).await.expert_index] += 1;
        }

        // Each expert should get roughly 25% (allow 15-35%)
        for &hit in hits.iter() {
            assert!(hit > 150 && hit < 350, "Uneven distribution: {}", hit);
        }
    }

    /// The heavily loaded expert must not win load-aware selection.
    #[tokio::test]
    async fn test_load_aware_router() {
        let router = LoadAwareMoERouter::new();

        let mut busy = Expert::new(1, NodeId::new());
        busy.update_load(0.9);
        let experts = vec![
            Expert::new(0, NodeId::new()),
            busy,
            Expert::new(2, NodeId::new()),
        ];

        let result = router.route_with_experts("test", &experts).await;

        // Should not pick the heavily loaded expert (index 1)
        assert_ne!(result.expert_index, 1);
    }

    /// Four sequential requests over three experts wrap around: 0, 1, 2, 0.
    #[tokio::test]
    async fn test_round_robin_router() {
        let router = RoundRobinMoERouter::new();

        let mut order = Vec::new();
        for key in &["a", "b", "c", "d"] {
            order.push(router.route(key, 3).await.expert_index);
        }

        assert_eq!(order, vec![0, 1, 2, 0]);
    }
}