forge_orchestration/
moe.rs

1//! Mixture of Experts (MoE) routing for Forge
2//!
3//! ## Table of Contents
4//! - **MoERouter**: Trait for implementing custom routing logic
5//! - **DefaultMoERouter**: Hash-based default router
6//! - **LoadAwareMoERouter**: Load-balanced routing
//! - **RoundRobinMoERouter**: Round-robin routing
//! - **RouteResult**: Routing decision with metadata
8
9use crate::types::Expert;
10use async_trait::async_trait;
11use std::collections::hash_map::DefaultHasher;
12use std::hash::{Hash, Hasher};
13use std::sync::Arc;
14
/// Outcome of a routing decision.
///
/// Built with [`RouteResult::new`] and refined with the `with_*` builder
/// methods. All fields are public, so callers may also read or set them
/// directly.
#[derive(Debug, Clone)]
pub struct RouteResult {
    /// Index of the expert selected for the input
    pub expert_index: usize,
    /// Confidence score (0.0 - 1.0); [`RouteResult::with_confidence`] enforces the range
    pub confidence: f64,
    /// Alternative expert indices (for fallback)
    pub alternatives: Vec<usize>,
    /// Free-form key/value routing metadata
    pub metadata: std::collections::HashMap<String, String>,
}

impl RouteResult {
    /// Create a result for `expert_index` with full confidence (1.0),
    /// no alternatives and empty metadata.
    pub fn new(expert_index: usize) -> Self {
        Self {
            expert_index,
            confidence: 1.0,
            alternatives: Vec::new(),
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Set the confidence score.
    ///
    /// Values outside `0.0..=1.0` are clamped into that range. A NaN input is
    /// stored as `0.0`: `f64::clamp` would hand NaN back unchanged, which
    /// would break the documented range for every later comparison.
    pub fn with_confidence(mut self, confidence: f64) -> Self {
        self.confidence = if confidence.is_nan() {
            0.0
        } else {
            confidence.clamp(0.0, 1.0)
        };
        self
    }

    /// Replace the list of alternative experts.
    pub fn with_alternatives(mut self, alternatives: Vec<usize>) -> Self {
        self.alternatives = alternatives;
        self
    }
}
51
52/// Trait for implementing MoE routing logic
53///
54/// Implement this trait to create custom routing strategies for
55/// distributing work across expert instances.
56///
57/// # Example
58///
59/// ```rust
60/// use forge::moe::{MoERouter, RouteResult};
61/// use async_trait::async_trait;
62///
63/// struct TypeBasedRouter;
64///
65/// #[async_trait]
66/// impl MoERouter for TypeBasedRouter {
67///     async fn route(&self, input: &str, num_experts: usize) -> RouteResult {
68///         let expert = if input.starts_with("code:") {
69///             0  // Code expert
70///         } else if input.starts_with("math:") {
71///             1  // Math expert
72///         } else {
73///             2  // General expert
74///         };
75///         RouteResult::new(expert % num_experts)
76///     }
77/// }
78/// ```
79#[async_trait]
80pub trait MoERouter: Send + Sync {
81    /// Route an input to an expert
82    ///
83    /// # Arguments
84    /// * `input` - The input string/key to route
85    /// * `num_experts` - Total number of available experts
86    ///
87    /// # Returns
88    /// A `RouteResult` containing the selected expert and metadata
89    async fn route(&self, input: &str, num_experts: usize) -> RouteResult;
90
91    /// Route with expert health information
92    ///
93    /// Override this for load-aware routing
94    async fn route_with_experts(&self, input: &str, experts: &[Expert]) -> RouteResult {
95        let available: Vec<_> = experts.iter().filter(|e| e.available()).collect();
96        if available.is_empty() {
97            // Fallback to any expert if none available
98            self.route(input, experts.len()).await
99        } else {
100            let result = self.route(input, available.len()).await;
101            RouteResult::new(available[result.expert_index].index)
102                .with_confidence(result.confidence)
103        }
104    }
105
106    /// Get router name for metrics/logging
107    fn name(&self) -> &str {
108        "custom"
109    }
110}
111
/// Default hash-based MoE router
///
/// Routes inputs consistently using hash-based sharding.
/// Same input always routes to the same expert (assuming stable expert count).
///
/// `Default::default()` and [`DefaultMoERouter::new`] produce the same value.
#[derive(Debug, Clone)]
pub struct DefaultMoERouter {
    /// Number of virtual shards per expert (for better distribution).
    /// Every constructor keeps this at 1 or more.
    // NOTE(review): the `route` implementation visible in this file does not
    // read this value yet; it is only stored.
    virtual_shards: usize,
}

impl Default for DefaultMoERouter {
    /// Same as [`DefaultMoERouter::new`]. A derived `Default` would produce
    /// zero virtual shards, which `with_virtual_shards` never allows.
    fn default() -> Self {
        Self::new()
    }
}

impl DefaultMoERouter {
    /// Create a new default router with a single virtual shard per expert
    pub fn new() -> Self {
        Self { virtual_shards: 1 }
    }

    /// Create with virtual sharding for better distribution
    ///
    /// A `shards` value of zero is raised to 1.
    pub fn with_virtual_shards(mut self, shards: usize) -> Self {
        self.virtual_shards = shards.max(1);
        self
    }

    /// Hash `input` with a freshly created `DefaultHasher`.
    ///
    /// Hashers created through `DefaultHasher::new` are all the same, so equal
    /// inputs hash equally within one build. The standard library does not
    /// promise the algorithm stays the same across Rust releases.
    fn hash_input(&self, input: &str) -> u64 {
        let mut hasher = DefaultHasher::new();
        input.hash(&mut hasher);
        hasher.finish()
    }
}
140
141#[async_trait]
142impl MoERouter for DefaultMoERouter {
143    async fn route(&self, input: &str, num_experts: usize) -> RouteResult {
144        if num_experts == 0 {
145            return RouteResult::new(0);
146        }
147
148        let hash = self.hash_input(input);
149        let expert_index = (hash % num_experts as u64) as usize;
150
151        RouteResult::new(expert_index).with_confidence(1.0)
152    }
153
154    fn name(&self) -> &str {
155        "default-hash"
156    }
157}
158
159/// Load-aware MoE router
160///
161/// Routes to the least loaded available expert, with optional
162/// affinity for consistent routing when loads are similar.
163#[derive(Debug, Clone)]
164pub struct LoadAwareMoERouter {
165    /// Load difference threshold for preferring affinity
166    affinity_threshold: f64,
167    /// Fallback router for affinity decisions
168    fallback: DefaultMoERouter,
169}
170
171impl LoadAwareMoERouter {
172    /// Create a new load-aware router
173    pub fn new() -> Self {
174        Self {
175            affinity_threshold: 0.1,
176            fallback: DefaultMoERouter::new(),
177        }
178    }
179
180    /// Set affinity threshold
181    ///
182    /// If load difference is below this threshold, prefer consistent routing
183    pub fn with_affinity_threshold(mut self, threshold: f64) -> Self {
184        self.affinity_threshold = threshold.clamp(0.0, 1.0);
185        self
186    }
187}
188
189impl Default for LoadAwareMoERouter {
190    fn default() -> Self {
191        Self::new()
192    }
193}
194
195#[async_trait]
196impl MoERouter for LoadAwareMoERouter {
197    async fn route(&self, input: &str, num_experts: usize) -> RouteResult {
198        // Without expert info, fall back to hash routing
199        self.fallback.route(input, num_experts).await
200    }
201
202    async fn route_with_experts(&self, input: &str, experts: &[Expert]) -> RouteResult {
203        let available: Vec<_> = experts.iter().filter(|e| e.available()).collect();
204
205        if available.is_empty() {
206            return RouteResult::new(0);
207        }
208
209        // Find least loaded expert
210        let min_load = available
211            .iter()
212            .map(|e| e.load)
213            .fold(f64::INFINITY, f64::min);
214
215        // Get affinity expert from hash
216        let affinity_result = self.fallback.route(input, experts.len()).await;
217        let affinity_expert = experts.get(affinity_result.expert_index);
218
219        // If affinity expert is available and load is close to minimum, use it
220        if let Some(expert) = affinity_expert {
221            if expert.available() && (expert.load - min_load) < self.affinity_threshold {
222                return RouteResult::new(expert.index).with_confidence(0.9);
223            }
224        }
225
226        // Otherwise, pick least loaded
227        let selected = available
228            .iter()
229            .min_by(|a, b| a.load.partial_cmp(&b.load).unwrap_or(std::cmp::Ordering::Equal))
230            .expect("available is non-empty, checked above");
231
232        let alternatives: Vec<_> = available
233            .iter()
234            .filter(|e| e.index != selected.index)
235            .take(2)
236            .map(|e| e.index)
237            .collect();
238
239        RouteResult::new(selected.index)
240            .with_confidence(0.8)
241            .with_alternatives(alternatives)
242    }
243
244    fn name(&self) -> &str {
245        "load-aware"
246    }
247}
248
/// Router that cycles through experts in order.
///
/// Each routed request takes the next value of a shared counter, so
/// consecutive requests land on consecutive experts.
#[derive(Debug, Default)]
pub struct RoundRobinMoERouter {
    /// Monotonic request counter; starts at zero
    counter: std::sync::atomic::AtomicUsize,
}

impl RoundRobinMoERouter {
    /// Build a router whose counter starts at zero.
    pub fn new() -> Self {
        Self::default()
    }
}
271
272#[async_trait]
273impl MoERouter for RoundRobinMoERouter {
274    async fn route(&self, _input: &str, num_experts: usize) -> RouteResult {
275        if num_experts == 0 {
276            return RouteResult::new(0);
277        }
278
279        let count = self
280            .counter
281            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
282        let expert_index = count % num_experts;
283
284        RouteResult::new(expert_index)
285    }
286
287    fn name(&self) -> &str {
288        "round-robin"
289    }
290}
291
292/// Type alias for boxed router
293pub type BoxedMoERouter = Arc<dyn MoERouter>;
294
295/// Create a boxed default router
296pub fn default_router() -> BoxedMoERouter {
297    Arc::new(DefaultMoERouter::new())
298}
299
300/// Create a boxed load-aware router
301pub fn load_aware_router() -> BoxedMoERouter {
302    Arc::new(LoadAwareMoERouter::new())
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::NodeId;

    /// Routing the same key twice must land on the same expert.
    #[tokio::test]
    async fn test_default_router_consistency() {
        let router = DefaultMoERouter::new();

        let first = router.route("test-input", 8).await.expert_index;
        let second = router.route("test-input", 8).await.expert_index;

        assert_eq!(first, second);
    }

    /// 1000 distinct keys over 4 experts should spread roughly evenly.
    #[tokio::test]
    async fn test_default_router_distribution() {
        const EXPERTS: usize = 4;

        let router = DefaultMoERouter::new();
        let mut hits = vec![0usize; EXPERTS];

        for n in 0..1000 {
            let key = format!("input-{}", n);
            let picked = router.route(&key, EXPERTS).await.expert_index;
            hits[picked] += 1;
        }

        // Each expert should get roughly 25% (allow 15-35%)
        for count in hits {
            assert!((151..350).contains(&count), "Uneven distribution: {}", count);
        }
    }

    /// With one expert at 0.9 load, the router must choose one of the others.
    #[tokio::test]
    async fn test_load_aware_router() {
        let router = LoadAwareMoERouter::new();

        let first_idle = Expert::new(0, NodeId::new());
        let mut busy = Expert::new(1, NodeId::new());
        busy.update_load(0.9);
        let last_idle = Expert::new(2, NodeId::new());
        let experts = vec![first_idle, busy, last_idle];

        let decision = router.route_with_experts("test", &experts).await;

        // Should not pick the heavily loaded expert (index 1)
        assert_ne!(decision.expert_index, 1);
    }

    /// Four calls over three experts visit 0, 1, 2 and wrap back to 0.
    #[tokio::test]
    async fn test_round_robin_router() {
        let router = RoundRobinMoERouter::new();

        let mut picks = Vec::new();
        for key in ["a", "b", "c", "d"].iter() {
            picks.push(router.route(*key, 3).await.expert_index);
        }

        assert_eq!(picks, vec![0, 1, 2, 0]);
    }
}