1use super::catalog::ModelCatalog;
7use super::health::{CircuitBreaker, HealthChecker};
8use super::policy::CompositePolicy;
9use super::traits::*;
10use std::sync::Arc;
11use std::time::Duration;
12
13#[derive(Debug, Clone)]
19pub struct RouteDecision {
20 pub target: RouteTarget,
21 pub alternatives: Vec<RouteTarget>,
22 pub reasoning: String,
23}
24
25#[derive(Debug, Clone)]
31pub struct RouterConfig {
32 pub max_candidates: usize,
34 pub min_score: f64,
36 pub strategy: LoadBalanceStrategy,
38}
39
40impl Default for RouterConfig {
41 fn default() -> Self {
42 Self {
43 max_candidates: 10,
44 min_score: 0.1,
45 strategy: LoadBalanceStrategy::LeastLatency,
46 }
47 }
48}
49
50pub struct Router {
52 config: RouterConfig,
53 catalog: Arc<ModelCatalog>,
54 health: Arc<HealthChecker>,
55 circuit_breaker: Arc<CircuitBreaker>,
56 policy: CompositePolicy,
57}
58
59impl Router {
60 pub fn new(
61 config: RouterConfig,
62 catalog: Arc<ModelCatalog>,
63 health: Arc<HealthChecker>,
64 circuit_breaker: Arc<CircuitBreaker>,
65 ) -> Self {
66 Self {
67 config,
68 catalog,
69 health,
70 circuit_breaker,
71 policy: CompositePolicy::enterprise_default(),
72 }
73 }
74
75 #[must_use]
77 pub fn with_policy(mut self, policy: CompositePolicy) -> Self {
78 self.policy = policy;
79 self
80 }
81
82 fn build_candidates(&self, capability: &Capability) -> Vec<RouteCandidate> {
84 let mut candidates = Vec::new();
88
89 let entries = self.catalog.all_entries();
91
92 for entry in entries {
93 let has_capability = entry.metadata.capabilities.iter().any(|c| c == capability);
95 if !has_capability {
96 continue;
97 }
98
99 for deployment in &entry.deployments {
100 if self.circuit_breaker.is_open(&deployment.node_id) {
102 continue;
103 }
104
105 let health = self
107 .health
108 .get_cached_health(&deployment.node_id)
109 .unwrap_or_else(|| NodeHealth {
110 node_id: deployment.node_id.clone(),
111 status: HealthState::Unknown,
112 latency_p50: Duration::from_secs(1),
113 latency_p99: Duration::from_secs(5),
114 throughput: 0,
115 gpu_utilization: None,
116 queue_depth: 0,
117 last_check: std::time::Instant::now(),
118 });
119
120 if health.status == HealthState::Unhealthy {
122 continue;
123 }
124
125 let target = RouteTarget {
126 node_id: deployment.node_id.clone(),
127 region_id: deployment.region_id.clone(),
128 endpoint: deployment.endpoint.clone(),
129 estimated_latency: health.latency_p50,
130 score: 0.0, };
132
133 let health_score = match health.status {
134 HealthState::Healthy => 1.0,
135 HealthState::Degraded => 0.5,
136 HealthState::Unknown => 0.3,
137 HealthState::Unhealthy => 0.0,
138 };
139
140 let scores = RouteScores {
141 latency_score: 1.0 - (health.latency_p50.as_millis() as f64 / 5000.0).min(1.0),
142 throughput_score: (health.throughput as f64 / 1000.0).min(1.0),
143 cost_score: 0.5, locality_score: 0.5, health_score,
146 total: 0.0,
147 };
148
149 candidates.push(RouteCandidate {
150 target,
151 scores,
152 eligible: true,
153 rejection_reason: None,
154 });
155 }
156 }
157
158 candidates
159 }
160
161 fn rank_candidates(&self, candidates: &mut [RouteCandidate], request: &InferenceRequest) {
163 for candidate in candidates.iter_mut() {
164 if !self.policy.is_eligible(candidate, request) {
166 candidate.eligible = false;
167 candidate.rejection_reason = Some("Policy rejected".to_string());
168 continue;
169 }
170
171 let score = self.policy.score(candidate, request);
173 candidate.target.score = score;
174 candidate.scores.total = score;
175 }
176
177 candidates.sort_by(|a, b| {
179 b.scores
180 .total
181 .partial_cmp(&a.scores.total)
182 .unwrap_or(std::cmp::Ordering::Equal)
183 });
184 }
185
186 fn select_best(&self, candidates: &[RouteCandidate]) -> Option<RouteCandidate> {
188 let eligible: Vec<_> = candidates
189 .iter()
190 .filter(|c| c.eligible && c.scores.total >= self.config.min_score)
191 .take(self.config.max_candidates)
192 .collect();
193
194 if eligible.is_empty() {
195 return None;
196 }
197
198 match self.config.strategy {
199 LoadBalanceStrategy::LeastLatency => {
200 eligible.first().map(|c| (*c).clone())
202 }
203 LoadBalanceStrategy::LeastConnections => {
204 eligible.first().map(|c| (*c).clone())
206 }
207 LoadBalanceStrategy::RoundRobin => {
208 eligible.first().map(|c| (*c).clone())
210 }
211 LoadBalanceStrategy::WeightedRandom => {
212 use std::collections::hash_map::DefaultHasher;
214 use std::hash::{Hash, Hasher};
215
216 let total_weight: f64 = eligible.iter().map(|c| c.scores.total).sum();
217 if total_weight <= 0.0 {
218 return eligible.first().map(|c| (*c).clone());
219 }
220
221 let mut hasher = DefaultHasher::new();
223 std::time::SystemTime::now()
224 .duration_since(std::time::UNIX_EPOCH)
225 .unwrap_or_default()
226 .as_nanos()
227 .hash(&mut hasher);
228 let random = (hasher.finish() as f64) / (u64::MAX as f64);
229
230 let target = random * total_weight;
231 let mut cumulative = 0.0;
232
233 for candidate in &eligible {
234 cumulative += candidate.scores.total;
235 if cumulative >= target {
236 return Some((*candidate).clone());
237 }
238 }
239
240 eligible.last().map(|c| (*c).clone())
241 }
242 LoadBalanceStrategy::ConsistentHash => {
243 eligible.first().map(|c| (*c).clone())
245 }
246 }
247 }
248}
249
250impl RouterTrait for Router {
251 fn route(&self, request: &InferenceRequest) -> BoxFuture<'_, FederationResult<RouteTarget>> {
252 let request = request.clone();
254
255 Box::pin(async move {
256 let mut candidates = self.build_candidates(&request.capability);
257
258 if candidates.is_empty() {
259 return Err(FederationError::NoCapacity(request.capability.clone()));
260 }
261
262 self.rank_candidates(&mut candidates, &request);
263
264 self.select_best(&candidates)
265 .map(|c| c.target)
266 .ok_or_else(|| FederationError::AllNodesUnhealthy(request.capability.clone()))
267 })
268 }
269
270 fn get_candidates(
271 &self,
272 request: &InferenceRequest,
273 ) -> BoxFuture<'_, FederationResult<Vec<RouteCandidate>>> {
274 let request = request.clone();
276
277 Box::pin(async move {
278 let mut candidates = self.build_candidates(&request.capability);
279 self.rank_candidates(&mut candidates, &request);
280 Ok(candidates)
281 })
282 }
283}
284
285pub struct RouterBuilder {
291 config: RouterConfig,
292 catalog: Option<Arc<ModelCatalog>>,
293 health: Option<Arc<HealthChecker>>,
294 circuit_breaker: Option<Arc<CircuitBreaker>>,
295 policy: Option<CompositePolicy>,
296}
297
298impl RouterBuilder {
299 pub fn new() -> Self {
300 Self {
301 config: RouterConfig::default(),
302 catalog: None,
303 health: None,
304 circuit_breaker: None,
305 policy: None,
306 }
307 }
308
309 #[must_use]
310 pub fn config(mut self, config: RouterConfig) -> Self {
311 self.config = config;
312 self
313 }
314
315 #[must_use]
316 pub fn catalog(mut self, catalog: Arc<ModelCatalog>) -> Self {
317 self.catalog = Some(catalog);
318 self
319 }
320
321 #[must_use]
322 pub fn health(mut self, health: Arc<HealthChecker>) -> Self {
323 self.health = Some(health);
324 self
325 }
326
327 #[must_use]
328 pub fn circuit_breaker(mut self, cb: Arc<CircuitBreaker>) -> Self {
329 self.circuit_breaker = Some(cb);
330 self
331 }
332
333 #[must_use]
334 pub fn policy(mut self, policy: CompositePolicy) -> Self {
335 self.policy = Some(policy);
336 self
337 }
338
339 pub fn build(self) -> Router {
340 let catalog = self
341 .catalog
342 .unwrap_or_else(|| Arc::new(ModelCatalog::new()));
343 let health = self
344 .health
345 .unwrap_or_else(|| Arc::new(HealthChecker::default()));
346 let circuit_breaker = self
347 .circuit_breaker
348 .unwrap_or_else(|| Arc::new(CircuitBreaker::default()));
349
350 let router = Router::new(self.config, catalog, health, circuit_breaker);
351
352 if let Some(policy) = self.policy {
353 router.with_policy(policy)
354 } else {
355 router
356 }
357 }
358}
359
360impl Default for RouterBuilder {
361 fn default() -> Self {
362 Self::new()
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a string slice in a `NodeId`.
    fn node(name: &str) -> NodeId {
        NodeId(name.to_string())
    }

    /// Minimal request for `capability`: empty input, default QoS, no tenant.
    fn request_for(capability: Capability, request_id: &str) -> InferenceRequest {
        InferenceRequest {
            capability,
            input: vec![],
            qos: QoSRequirements::default(),
            request_id: request_id.to_string(),
            tenant_id: None,
        }
    }

    /// Registers `model` on node `node_name` in `region` as serving one capability.
    async fn register(
        catalog: &Arc<ModelCatalog>,
        model: &str,
        node_name: &str,
        region: &str,
        capability: Capability,
    ) {
        catalog
            .register(
                ModelId(model.to_string()),
                node(node_name),
                RegionId(region.to_string()),
                vec![capability],
            )
            .await
            .expect("registration failed");
    }

    /// Router with default config over fresh components. Returns the shared catalog
    /// and health checker so tests can populate them.
    fn setup_test_router() -> (Router, Arc<ModelCatalog>, Arc<HealthChecker>) {
        let catalog = Arc::new(ModelCatalog::new());
        let health = Arc::new(HealthChecker::default());
        let breaker = Arc::new(CircuitBreaker::default());

        let router = Router::new(
            RouterConfig::default(),
            Arc::clone(&catalog),
            Arc::clone(&health),
            breaker,
        );

        (router, catalog, health)
    }

    #[tokio::test]
    async fn test_route_no_nodes() {
        let (router, _, _) = setup_test_router();

        let outcome = router
            .route(&request_for(Capability::Transcribe, "test-1"))
            .await;

        assert!(matches!(outcome, Err(FederationError::NoCapacity(_))));
    }

    #[tokio::test]
    async fn test_route_single_node() {
        let (router, catalog, health) = setup_test_router();

        register(&catalog, "whisper", "node-1", "us-west", Capability::Transcribe).await;
        health.register_node(node("node-1"));
        health.report_success(&node("node-1"), Duration::from_millis(50));

        let target = router
            .route(&request_for(Capability::Transcribe, "test-2"))
            .await
            .expect("routing failed");

        assert_eq!(target.node_id, node("node-1"));
    }

    #[tokio::test]
    async fn test_route_prefers_healthy() {
        let (router, catalog, health) = setup_test_router();

        register(&catalog, "llama", "healthy-node", "us-west", Capability::Generate).await;
        register(&catalog, "llama", "degraded-node", "us-east", Capability::Generate).await;

        health.register_node(node("healthy-node"));
        health.register_node(node("degraded-node"));

        // Interleave reports so one node accumulates successes and the other failures.
        for _ in 0..5 {
            health.report_success(&node("healthy-node"), Duration::from_millis(20));
            health.report_failure(&node("degraded-node"));
        }

        let target = router
            .route(&request_for(Capability::Generate, "test-3"))
            .await
            .expect("routing failed");

        assert_eq!(target.node_id, node("healthy-node"));
    }

    #[tokio::test]
    async fn test_get_candidates_returns_all() {
        let (router, catalog, health) = setup_test_router();

        for i in 0..3 {
            let name = format!("node-{i}");
            register(&catalog, "embed", &name, "us-west", Capability::Embed).await;
            health.register_node(node(&name));
            health.report_success(&node(&name), Duration::from_millis(50));
        }

        let candidates = router
            .get_candidates(&request_for(Capability::Embed, "test-4"))
            .await
            .expect("get_candidates failed");

        assert_eq!(candidates.len(), 3);
    }

    #[test]
    fn test_router_builder() {
        let custom = RouterConfig {
            max_candidates: 5,
            min_score: 0.2,
            strategy: LoadBalanceStrategy::RoundRobin,
        };

        let router = RouterBuilder::new().config(custom).build();

        assert_eq!(router.config.max_candidates, 5);
        assert_eq!(router.config.min_score, 0.2);
    }
}