forge_orchestration/
moe.rs1use crate::types::Expert;
10use async_trait::async_trait;
11use std::collections::hash_map::DefaultHasher;
12use std::hash::{Hash, Hasher};
13use std::sync::Arc;
14
/// Outcome of a routing decision.
#[derive(Debug, Clone)]
pub struct RouteResult {
    /// Index of the expert selected for the input.
    pub expert_index: usize,
    /// Confidence attached to the decision. [`RouteResult::with_confidence`]
    /// keeps it within `0.0..=1.0`; the field is public, so direct writes
    /// are not validated.
    pub confidence: f64,
    /// Other expert indices that could serve the input; empty by default.
    pub alternatives: Vec<usize>,
    /// Free-form key/value annotations; empty by default.
    pub metadata: std::collections::HashMap<String, String>,
}

impl RouteResult {
    /// Creates a result for `expert_index` with confidence `1.0`, no
    /// alternatives and no metadata.
    pub fn new(expert_index: usize) -> Self {
        Self {
            expert_index,
            confidence: 1.0,
            alternatives: Vec::new(),
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Sets the confidence, clamped to `0.0..=1.0`.
    ///
    /// A NaN input is stored as `0.0`.
    pub fn with_confidence(mut self, confidence: f64) -> Self {
        // `f64::clamp` returns NaN for a NaN receiver, which would leave the
        // field outside the range this builder exists to enforce.
        self.confidence = if confidence.is_nan() {
            0.0
        } else {
            confidence.clamp(0.0, 1.0)
        };
        self
    }

    /// Replaces the list of alternative expert indices.
    pub fn with_alternatives(mut self, alternatives: Vec<usize>) -> Self {
        self.alternatives = alternatives;
        self
    }
}
51
52#[async_trait]
80pub trait MoERouter: Send + Sync {
81 async fn route(&self, input: &str, num_experts: usize) -> RouteResult;
90
91 async fn route_with_experts(&self, input: &str, experts: &[Expert]) -> RouteResult {
95 let available: Vec<_> = experts.iter().filter(|e| e.available()).collect();
96 if available.is_empty() {
97 self.route(input, experts.len()).await
99 } else {
100 let result = self.route(input, available.len()).await;
101 RouteResult::new(available[result.expert_index].index)
102 .with_confidence(result.confidence)
103 }
104 }
105
106 fn name(&self) -> &str {
108 "custom"
109 }
110}
111
/// Router that selects an expert from a hash of the input string.
#[derive(Debug, Clone)]
pub struct DefaultMoERouter {
    /// Virtual shard count; every constructor keeps it at 1 or more.
    /// No code visible in this section reads it.
    virtual_shards: usize,
}

impl DefaultMoERouter {
    /// Creates a router with a single virtual shard.
    pub fn new() -> Self {
        Self { virtual_shards: 1 }
    }

    /// Sets the virtual shard count; values below 1 are raised to 1.
    pub fn with_virtual_shards(mut self, shards: usize) -> Self {
        self.virtual_shards = shards.max(1);
        self
    }

    /// Hashes `input` with a freshly constructed [`DefaultHasher`].
    ///
    /// Equal inputs hash to equal values within one build of the program.
    /// The standard library does not specify the algorithm, so values may
    /// differ between Rust releases.
    fn hash_input(&self, input: &str) -> u64 {
        let mut hasher = DefaultHasher::new();
        input.hash(&mut hasher);
        hasher.finish()
    }
}

impl Default for DefaultMoERouter {
    /// Same as [`DefaultMoERouter::new`].
    ///
    /// Implemented by hand because a derived `Default` would set
    /// `virtual_shards` to 0, breaking the "at least 1" invariant.
    fn default() -> Self {
        Self::new()
    }
}
140
141#[async_trait]
142impl MoERouter for DefaultMoERouter {
143 async fn route(&self, input: &str, num_experts: usize) -> RouteResult {
144 if num_experts == 0 {
145 return RouteResult::new(0);
146 }
147
148 let hash = self.hash_input(input);
149 let expert_index = (hash % num_experts as u64) as usize;
150
151 RouteResult::new(expert_index).with_confidence(1.0)
152 }
153
154 fn name(&self) -> &str {
155 "default-hash"
156 }
157}
158
159#[derive(Debug, Clone)]
164pub struct LoadAwareMoERouter {
165 affinity_threshold: f64,
167 fallback: DefaultMoERouter,
169}
170
171impl LoadAwareMoERouter {
172 pub fn new() -> Self {
174 Self {
175 affinity_threshold: 0.1,
176 fallback: DefaultMoERouter::new(),
177 }
178 }
179
180 pub fn with_affinity_threshold(mut self, threshold: f64) -> Self {
184 self.affinity_threshold = threshold.clamp(0.0, 1.0);
185 self
186 }
187}
188
189impl Default for LoadAwareMoERouter {
190 fn default() -> Self {
191 Self::new()
192 }
193}
194
195#[async_trait]
196impl MoERouter for LoadAwareMoERouter {
197 async fn route(&self, input: &str, num_experts: usize) -> RouteResult {
198 self.fallback.route(input, num_experts).await
200 }
201
202 async fn route_with_experts(&self, input: &str, experts: &[Expert]) -> RouteResult {
203 let available: Vec<_> = experts.iter().filter(|e| e.available()).collect();
204
205 if available.is_empty() {
206 return RouteResult::new(0);
207 }
208
209 let min_load = available
211 .iter()
212 .map(|e| e.load)
213 .fold(f64::INFINITY, f64::min);
214
215 let affinity_result = self.fallback.route(input, experts.len()).await;
217 let affinity_expert = experts.get(affinity_result.expert_index);
218
219 if let Some(expert) = affinity_expert {
221 if expert.available() && (expert.load - min_load) < self.affinity_threshold {
222 return RouteResult::new(expert.index).with_confidence(0.9);
223 }
224 }
225
226 let selected = available
228 .iter()
229 .min_by(|a, b| a.load.partial_cmp(&b.load).unwrap_or(std::cmp::Ordering::Equal))
230 .expect("available is non-empty, checked above");
231
232 let alternatives: Vec<_> = available
233 .iter()
234 .filter(|e| e.index != selected.index)
235 .take(2)
236 .map(|e| e.index)
237 .collect();
238
239 RouteResult::new(selected.index)
240 .with_confidence(0.8)
241 .with_alternatives(alternatives)
242 }
243
244 fn name(&self) -> &str {
245 "load-aware"
246 }
247}
248
/// Router that cycles through experts in order, ignoring the input.
#[derive(Debug)]
pub struct RoundRobinMoERouter {
    /// Number of routing decisions handed out so far.
    counter: std::sync::atomic::AtomicUsize,
}

impl RoundRobinMoERouter {
    /// Creates a router whose counter starts at zero.
    pub fn new() -> Self {
        use std::sync::atomic::AtomicUsize;

        let counter = AtomicUsize::new(0);
        Self { counter }
    }
}

impl Default for RoundRobinMoERouter {
    /// Same as [`RoundRobinMoERouter::new`].
    fn default() -> Self {
        Self::new()
    }
}
271
272#[async_trait]
273impl MoERouter for RoundRobinMoERouter {
274 async fn route(&self, _input: &str, num_experts: usize) -> RouteResult {
275 if num_experts == 0 {
276 return RouteResult::new(0);
277 }
278
279 let count = self
280 .counter
281 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
282 let expert_index = count % num_experts;
283
284 RouteResult::new(expert_index)
285 }
286
287 fn name(&self) -> &str {
288 "round-robin"
289 }
290}
291
292pub type BoxedMoERouter = Arc<dyn MoERouter>;
294
295pub fn default_router() -> BoxedMoERouter {
297 Arc::new(DefaultMoERouter::new())
298}
299
300pub fn load_aware_router() -> BoxedMoERouter {
302 Arc::new(LoadAwareMoERouter::new())
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::NodeId;

    /// Routing the same input twice must land on the same expert.
    #[tokio::test]
    async fn test_default_router_consistency() {
        let router = DefaultMoERouter::new();

        let first = router.route("test-input", 8).await.expert_index;
        let second = router.route("test-input", 8).await.expert_index;

        assert_eq!(first, second);
    }

    /// 1000 distinct inputs over 4 experts should give each expert between
    /// 150 and 350 hits (exclusive).
    #[tokio::test]
    async fn test_default_router_distribution() {
        let router = DefaultMoERouter::new();
        let mut counts = vec![0usize; 4];

        for i in 0..1000 {
            let key = format!("input-{}", i);
            let picked = router.route(&key, 4).await.expert_index;
            counts[picked] += 1;
        }

        for count in counts {
            assert!(count > 150 && count < 350, "Uneven distribution: {}", count);
        }
    }

    /// Expert 1 has its load updated to 0.9 and must not be selected.
    #[tokio::test]
    async fn test_load_aware_router() {
        let router = LoadAwareMoERouter::new();

        let idle_first = Expert::new(0, NodeId::new());
        let mut busy = Expert::new(1, NodeId::new());
        busy.update_load(0.9);
        let idle_last = Expert::new(2, NodeId::new());
        let experts = vec![idle_first, busy, idle_last];

        let result = router.route_with_experts("test", &experts).await;

        assert_ne!(result.expert_index, 1);
    }

    /// Four consecutive calls over three experts yield 0, 1, 2, then wrap
    /// back to 0.
    #[tokio::test]
    async fn test_round_robin_router() {
        let router = RoundRobinMoERouter::new();

        let inputs = ["a", "b", "c", "d"];
        let expected = [0usize, 1, 2, 0];

        for (&input, &want) in inputs.iter().zip(expected.iter()) {
            let got = router.route(input, 3).await;
            assert_eq!(got.expert_index, want);
        }
    }
}