1use super::catalog::ModelCatalog;
7use super::health::{CircuitBreaker, HealthChecker};
8use super::policy::CompositePolicy;
9use super::traits::*;
10use std::sync::Arc;
11use std::time::Duration;
12
/// A routing outcome bundling the chosen target with fallbacks and a rationale.
///
/// NOTE(review): nothing in this file's non-test code constructs this type;
/// `Router::route` returns a bare `RouteTarget`. Confirm where it is consumed.
#[derive(Debug, Clone)]
pub struct RouteDecision {
    /// The target selected for the request.
    pub target: RouteTarget,
    /// Other targets that could serve the request (may be empty).
    pub alternatives: Vec<RouteTarget>,
    /// Free-form, human-readable explanation of the choice.
    pub reasoning: String,
}
24
/// Tunables that control how [`Router`] narrows ranked candidates to one target.
#[derive(Debug, Clone)]
pub struct RouterConfig {
    /// Upper bound on how many candidates (after eligibility and score
    /// filtering) are considered during selection.
    pub max_candidates: usize,
    /// Candidates whose total score is below this value are never selected.
    pub min_score: f64,
    /// Strategy used to pick one candidate from the filtered set.
    pub strategy: LoadBalanceStrategy,
}
39
40impl Default for RouterConfig {
41 fn default() -> Self {
42 Self {
43 max_candidates: 10,
44 min_score: 0.1,
45 strategy: LoadBalanceStrategy::LeastLatency,
46 }
47 }
48}
49
/// Routes inference requests to a deployment by combining the model catalog,
/// cached node health, circuit-breaker state and a scoring policy.
pub struct Router {
    /// Selection tunables (candidate cap, score floor, strategy).
    config: RouterConfig,
    /// Source of model entries and their deployments.
    catalog: Arc<ModelCatalog>,
    /// Provides cached per-node health snapshots.
    health: Arc<HealthChecker>,
    /// Nodes whose breaker is open are excluded from candidates.
    circuit_breaker: Arc<CircuitBreaker>,
    /// Decides eligibility and computes the total score of each candidate.
    policy: CompositePolicy,
}
58
59impl Router {
60 pub fn new(
61 config: RouterConfig,
62 catalog: Arc<ModelCatalog>,
63 health: Arc<HealthChecker>,
64 circuit_breaker: Arc<CircuitBreaker>,
65 ) -> Self {
66 Self {
67 config,
68 catalog,
69 health,
70 circuit_breaker,
71 policy: CompositePolicy::enterprise_default(),
72 }
73 }
74
75 #[must_use]
77 pub fn with_policy(mut self, policy: CompositePolicy) -> Self {
78 self.policy = policy;
79 self
80 }
81
82 fn build_candidates(&self, capability: &Capability) -> Vec<RouteCandidate> {
84 let mut candidates = Vec::new();
88
89 let entries = self.catalog.all_entries();
91
92 for entry in entries {
93 let has_capability = entry.metadata.capabilities.iter().any(|c| c == capability);
95 if !has_capability {
96 continue;
97 }
98
99 for deployment in &entry.deployments {
100 if self.circuit_breaker.is_open(&deployment.node_id) {
102 continue;
103 }
104
105 let health = self
107 .health
108 .get_cached_health(&deployment.node_id)
109 .unwrap_or_else(|| NodeHealth {
110 node_id: deployment.node_id.clone(),
111 status: HealthState::Unknown,
112 latency_p50: Duration::from_secs(1),
113 latency_p99: Duration::from_secs(5),
114 throughput: 0,
115 gpu_utilization: None,
116 queue_depth: 0,
117 last_check: std::time::Instant::now(),
118 });
119
120 if health.status == HealthState::Unhealthy {
122 continue;
123 }
124
125 let target = RouteTarget {
126 node_id: deployment.node_id.clone(),
127 region_id: deployment.region_id.clone(),
128 endpoint: deployment.endpoint.clone(),
129 estimated_latency: health.latency_p50,
130 score: 0.0, };
132
133 let health_score = match health.status {
134 HealthState::Healthy => 1.0,
135 HealthState::Degraded => 0.5,
136 HealthState::Unknown => 0.3,
137 HealthState::Unhealthy => 0.0,
138 };
139
140 let scores = RouteScores {
141 latency_score: 1.0 - (health.latency_p50.as_millis() as f64 / 5000.0).min(1.0),
142 throughput_score: (health.throughput as f64 / 1000.0).min(1.0),
143 cost_score: 0.5, locality_score: 0.5, health_score,
146 total: 0.0,
147 };
148
149 candidates.push(RouteCandidate {
150 target,
151 scores,
152 eligible: true,
153 rejection_reason: None,
154 });
155 }
156 }
157
158 candidates
159 }
160
161 fn rank_candidates(&self, candidates: &mut [RouteCandidate], request: &InferenceRequest) {
163 for candidate in candidates.iter_mut() {
164 if !self.policy.is_eligible(candidate, request) {
166 candidate.eligible = false;
167 candidate.rejection_reason = Some("Policy rejected".to_string());
168 continue;
169 }
170
171 let score = self.policy.score(candidate, request);
173 candidate.target.score = score;
174 candidate.scores.total = score;
175 }
176
177 candidates.sort_by(|a, b| {
179 b.scores
180 .total
181 .partial_cmp(&a.scores.total)
182 .unwrap_or(std::cmp::Ordering::Equal)
183 });
184 }
185
186 fn select_best(&self, candidates: &[RouteCandidate]) -> Option<RouteCandidate> {
188 let eligible: Vec<_> = candidates
189 .iter()
190 .filter(|c| c.eligible && c.scores.total >= self.config.min_score)
191 .take(self.config.max_candidates)
192 .collect();
193
194 if eligible.is_empty() {
195 return None;
196 }
197
198 match self.config.strategy {
199 LoadBalanceStrategy::LeastLatency => {
200 eligible.first().map(|c| (*c).clone())
202 }
203 LoadBalanceStrategy::LeastConnections => {
204 eligible.first().map(|c| (*c).clone())
206 }
207 LoadBalanceStrategy::RoundRobin => {
208 eligible.first().map(|c| (*c).clone())
210 }
211 LoadBalanceStrategy::WeightedRandom => {
212 use std::collections::hash_map::DefaultHasher;
214 use std::hash::{Hash, Hasher};
215
216 let total_weight: f64 = eligible.iter().map(|c| c.scores.total).sum();
217 if total_weight <= 0.0 {
218 return eligible.first().map(|c| (*c).clone());
219 }
220
221 let mut hasher = DefaultHasher::new();
223 std::time::SystemTime::now()
224 .duration_since(std::time::UNIX_EPOCH)
225 .unwrap_or_default()
226 .as_nanos()
227 .hash(&mut hasher);
228 let random = (hasher.finish() as f64) / (u64::MAX as f64);
229
230 let target = random * total_weight;
231 let mut cumulative = 0.0;
232
233 for candidate in &eligible {
234 cumulative += candidate.scores.total;
235 if cumulative >= target {
236 return Some((*candidate).clone());
237 }
238 }
239
240 eligible.last().map(|c| (*c).clone())
241 }
242 LoadBalanceStrategy::ConsistentHash => {
243 eligible.first().map(|c| (*c).clone())
245 }
246 }
247 }
248}
249
250impl RouterTrait for Router {
251 fn route(&self, request: &InferenceRequest) -> BoxFuture<'_, FederationResult<RouteTarget>> {
252 let request = request.clone();
254
255 Box::pin(async move {
256 let mut candidates = self.build_candidates(&request.capability);
257
258 if candidates.is_empty() {
259 return Err(FederationError::NoCapacity(request.capability.clone()));
260 }
261
262 self.rank_candidates(&mut candidates, &request);
263
264 self.select_best(&candidates)
265 .map(|c| c.target)
266 .ok_or_else(|| FederationError::AllNodesUnhealthy(request.capability.clone()))
267 })
268 }
269
270 fn get_candidates(
271 &self,
272 request: &InferenceRequest,
273 ) -> BoxFuture<'_, FederationResult<Vec<RouteCandidate>>> {
274 let request = request.clone();
276
277 Box::pin(async move {
278 let mut candidates = self.build_candidates(&request.capability);
279 self.rank_candidates(&mut candidates, &request);
280 Ok(candidates)
281 })
282 }
283}
284
/// Step-by-step constructor for [`Router`]; any collaborator left unset is
/// replaced with a freshly created default in `build`.
pub struct RouterBuilder {
    /// Selection tunables; starts as `RouterConfig::default()`.
    config: RouterConfig,
    /// Catalog to route over, if supplied.
    catalog: Option<Arc<ModelCatalog>>,
    /// Health checker to consult, if supplied.
    health: Option<Arc<HealthChecker>>,
    /// Circuit breaker to consult, if supplied.
    circuit_breaker: Option<Arc<CircuitBreaker>>,
    /// Policy override; when `None` the router keeps the policy from `Router::new`.
    policy: Option<CompositePolicy>,
}
297
298impl RouterBuilder {
299 pub fn new() -> Self {
300 Self {
301 config: RouterConfig::default(),
302 catalog: None,
303 health: None,
304 circuit_breaker: None,
305 policy: None,
306 }
307 }
308
309 #[must_use]
310 pub fn config(mut self, config: RouterConfig) -> Self {
311 self.config = config;
312 self
313 }
314
315 #[must_use]
316 pub fn catalog(mut self, catalog: Arc<ModelCatalog>) -> Self {
317 self.catalog = Some(catalog);
318 self
319 }
320
321 #[must_use]
322 pub fn health(mut self, health: Arc<HealthChecker>) -> Self {
323 self.health = Some(health);
324 self
325 }
326
327 #[must_use]
328 pub fn circuit_breaker(mut self, cb: Arc<CircuitBreaker>) -> Self {
329 self.circuit_breaker = Some(cb);
330 self
331 }
332
333 #[must_use]
334 pub fn policy(mut self, policy: CompositePolicy) -> Self {
335 self.policy = Some(policy);
336 self
337 }
338
339 pub fn build(self) -> Router {
340 let catalog = self
341 .catalog
342 .unwrap_or_else(|| Arc::new(ModelCatalog::new()));
343 let health = self
344 .health
345 .unwrap_or_else(|| Arc::new(HealthChecker::default()));
346 let circuit_breaker = self
347 .circuit_breaker
348 .unwrap_or_else(|| Arc::new(CircuitBreaker::default()));
349
350 let router = Router::new(self.config, catalog, health, circuit_breaker);
351
352 if let Some(policy) = self.policy {
353 router.with_policy(policy)
354 } else {
355 router
356 }
357 }
358}
359
360impl Default for RouterBuilder {
361 fn default() -> Self {
362 Self::new()
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------------
    // Shared fixtures
    // ---------------------------------------------------------------------

    /// Builds a default-configured router over fresh, empty collaborators and
    /// returns the shared catalog and health checker so tests can populate them.
    fn setup_test_router() -> (Router, Arc<ModelCatalog>, Arc<HealthChecker>) {
        let catalog = Arc::new(ModelCatalog::new());
        let health = Arc::new(HealthChecker::default());
        let circuit_breaker = Arc::new(CircuitBreaker::default());

        let router = Router::new(
            RouterConfig::default(),
            Arc::clone(&catalog),
            Arc::clone(&health),
            circuit_breaker,
        );

        (router, catalog, health)
    }

    /// Builds a request for `capability` with empty input, default QoS and no tenant.
    fn request_for(capability: Capability, request_id: &str) -> InferenceRequest {
        InferenceRequest {
            capability,
            input: vec![],
            qos: QoSRequirements::default(),
            request_id: request_id.to_string(),
            tenant_id: None,
        }
    }

    /// Registers one deployment of `model` on `node` in `region`, panicking if
    /// the catalog rejects it.
    async fn register_deployment(
        catalog: &ModelCatalog,
        model: &str,
        node: &str,
        region: &str,
        capability: Capability,
    ) {
        catalog
            .register(
                ModelId(model.to_string()),
                NodeId(node.to_string()),
                RegionId(region.to_string()),
                vec![capability],
            )
            .await
            .expect("registration failed");
    }

    /// Registers `node` with the health checker and reports `successes`
    /// successful calls of 50 ms each.
    fn report_healthy(health: &HealthChecker, node: &str, successes: usize) {
        let node_id = NodeId(node.to_string());
        health.register_node(node_id.clone());
        for _ in 0..successes {
            health.report_success(&node_id, Duration::from_millis(50));
        }
    }

    /// Fresh collaborators with a single deployment ("model" on "n1" in
    /// "us-west") that has three successful health reports.
    async fn single_node_fixture(
        capability: Capability,
    ) -> (Arc<ModelCatalog>, Arc<HealthChecker>, Arc<CircuitBreaker>) {
        let catalog = Arc::new(ModelCatalog::new());
        let health = Arc::new(HealthChecker::default());
        let circuit_breaker = Arc::new(CircuitBreaker::default());

        register_deployment(&catalog, "model", "n1", "us-west", capability).await;
        report_healthy(&health, "n1", 3);

        (catalog, health, circuit_breaker)
    }

    /// Asserts that routing succeeds against the single-node fixture when the
    /// router is configured with `strategy`.
    async fn assert_strategy_routes(strategy: LoadBalanceStrategy, capability: Capability) {
        let (catalog, health, circuit_breaker) = single_node_fixture(capability.clone()).await;

        let router = Router::new(
            RouterConfig {
                strategy,
                ..Default::default()
            },
            catalog,
            health,
            circuit_breaker,
        );

        let request = request_for(capability, "test");
        let result = router.route(&request).await;
        assert!(result.is_ok());
    }

    // ---------------------------------------------------------------------
    // Basic routing
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_route_no_nodes() {
        let (router, _, _) = setup_test_router();

        let request = request_for(Capability::Transcribe, "test-1");

        let result = router.route(&request).await;
        assert!(matches!(result, Err(FederationError::NoCapacity(_))));
    }

    #[tokio::test]
    async fn test_route_single_node() {
        let (router, catalog, health) = setup_test_router();

        register_deployment(&catalog, "whisper", "node-1", "us-west", Capability::Transcribe)
            .await;
        report_healthy(&health, "node-1", 1);

        let request = request_for(Capability::Transcribe, "test-2");

        let target = router.route(&request).await.expect("routing failed");
        assert_eq!(target.node_id, NodeId("node-1".to_string()));
    }

    #[tokio::test]
    async fn test_route_prefers_healthy() {
        let (router, catalog, health) = setup_test_router();

        register_deployment(&catalog, "llama", "healthy-node", "us-west", Capability::Generate)
            .await;
        register_deployment(&catalog, "llama", "degraded-node", "us-east", Capability::Generate)
            .await;

        let healthy = NodeId("healthy-node".to_string());
        let degraded = NodeId("degraded-node".to_string());
        health.register_node(healthy.clone());
        health.register_node(degraded.clone());

        // One node only ever succeeds, the other only ever fails.
        for _ in 0..5 {
            health.report_success(&healthy, Duration::from_millis(20));
            health.report_failure(&degraded);
        }

        let request = request_for(Capability::Generate, "test-3");

        let target = router.route(&request).await.expect("routing failed");
        assert_eq!(target.node_id, healthy);
    }

    #[tokio::test]
    async fn test_get_candidates_returns_all() {
        let (router, catalog, health) = setup_test_router();

        for i in 0..3 {
            let node = format!("node-{}", i);
            register_deployment(&catalog, "embed", &node, "us-west", Capability::Embed).await;
            report_healthy(&health, &node, 1);
        }

        let request = request_for(Capability::Embed, "test-4");

        let candidates = router
            .get_candidates(&request)
            .await
            .expect("get_candidates failed");

        assert_eq!(candidates.len(), 3);
    }

    // ---------------------------------------------------------------------
    // Config and decision types
    // ---------------------------------------------------------------------

    #[test]
    fn test_router_builder() {
        let router = RouterBuilder::new()
            .config(RouterConfig {
                max_candidates: 5,
                min_score: 0.2,
                strategy: LoadBalanceStrategy::RoundRobin,
            })
            .build();

        assert_eq!(router.config.max_candidates, 5);
        assert_eq!(router.config.min_score, 0.2);
    }

    #[test]
    fn test_router_config_default() {
        let config = RouterConfig::default();
        assert_eq!(config.max_candidates, 10);
        assert_eq!(config.min_score, 0.1);
        assert!(matches!(config.strategy, LoadBalanceStrategy::LeastLatency));
    }

    #[test]
    fn test_router_config_clone() {
        let config = RouterConfig {
            max_candidates: 20,
            min_score: 0.5,
            strategy: LoadBalanceStrategy::ConsistentHash,
        };
        let cloned = config.clone();
        assert_eq!(cloned.max_candidates, 20);
        assert_eq!(cloned.min_score, 0.5);
    }

    #[test]
    fn test_route_decision_construction() {
        let decision = RouteDecision {
            target: RouteTarget {
                node_id: NodeId("n1".to_string()),
                region_id: RegionId("r1".to_string()),
                endpoint: "http://n1:8080".to_string(),
                estimated_latency: Duration::from_millis(50),
                score: 0.95,
            },
            alternatives: vec![],
            reasoning: "best latency score".to_string(),
        };
        assert_eq!(decision.target.score, 0.95);
        assert!(decision.alternatives.is_empty());
    }

    #[test]
    fn test_route_decision_with_alternatives() {
        let primary = RouteTarget {
            node_id: NodeId("n1".to_string()),
            region_id: RegionId("us-west".to_string()),
            endpoint: "http://n1:8080".to_string(),
            estimated_latency: Duration::from_millis(50),
            score: 0.95,
        };
        let alt = RouteTarget {
            node_id: NodeId("n2".to_string()),
            region_id: RegionId("eu-west".to_string()),
            endpoint: "http://n2:8080".to_string(),
            estimated_latency: Duration::from_millis(120),
            score: 0.7,
        };
        let decision = RouteDecision {
            target: primary,
            alternatives: vec![alt],
            reasoning: "latency-based".to_string(),
        };
        assert_eq!(decision.alternatives.len(), 1);
    }

    // ---------------------------------------------------------------------
    // Construction: with_policy and RouterBuilder
    // ---------------------------------------------------------------------

    #[test]
    fn test_router_with_custom_policy() {
        let catalog = Arc::new(ModelCatalog::new());
        let health = Arc::new(HealthChecker::default());
        let circuit_breaker = Arc::new(CircuitBreaker::default());

        let custom_policy =
            CompositePolicy::new().with_policy(super::super::policy::LatencyPolicy::default());

        let router = Router::new(RouterConfig::default(), catalog, health, circuit_breaker)
            .with_policy(custom_policy);

        assert_eq!(router.config.max_candidates, 10);
    }

    #[test]
    fn test_router_builder_default() {
        let builder = RouterBuilder::default();
        assert!(builder.catalog.is_none());
        assert!(builder.health.is_none());
        assert!(builder.circuit_breaker.is_none());
        assert!(builder.policy.is_none());
    }

    #[test]
    fn test_router_builder_with_catalog() {
        let catalog = Arc::new(ModelCatalog::new());
        let builder = RouterBuilder::new().catalog(Arc::clone(&catalog));
        assert!(builder.catalog.is_some());
    }

    #[test]
    fn test_router_builder_with_health() {
        let health = Arc::new(HealthChecker::default());
        let builder = RouterBuilder::new().health(health);
        assert!(builder.health.is_some());
    }

    #[test]
    fn test_router_builder_with_circuit_breaker() {
        let cb = Arc::new(CircuitBreaker::default());
        let builder = RouterBuilder::new().circuit_breaker(cb);
        assert!(builder.circuit_breaker.is_some());
    }

    #[test]
    fn test_router_builder_with_policy() {
        let policy = CompositePolicy::new();
        let builder = RouterBuilder::new().policy(policy);
        assert!(builder.policy.is_some());
    }

    #[test]
    fn test_router_builder_full_chain() {
        let catalog = Arc::new(ModelCatalog::new());
        let health = Arc::new(HealthChecker::default());
        let cb = Arc::new(CircuitBreaker::default());

        let router = RouterBuilder::new()
            .config(RouterConfig {
                max_candidates: 20,
                min_score: 0.05,
                strategy: LoadBalanceStrategy::WeightedRandom,
            })
            .catalog(catalog)
            .health(health)
            .circuit_breaker(cb)
            .policy(CompositePolicy::enterprise_default())
            .build();

        assert_eq!(router.config.max_candidates, 20);
        assert_eq!(router.config.min_score, 0.05);
    }

    #[test]
    fn test_router_builder_build_without_policy() {
        let router = RouterBuilder::new().build();
        assert_eq!(router.config.max_candidates, 10);
    }

    // ---------------------------------------------------------------------
    // Candidate filtering
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_build_candidates_skips_circuit_open() {
        let (catalog, health, circuit_breaker) = single_node_fixture(Capability::Generate).await;

        // Presumably enough failures to open the default breaker for "n1";
        // the assertion below relies on that.
        let node = NodeId("n1".to_string());
        for _ in 0..5 {
            circuit_breaker.record_failure(&node);
        }

        let router = Router::new(RouterConfig::default(), catalog, health, circuit_breaker);

        let request = request_for(Capability::Generate, "test");

        // The only deployment is filtered out, so no candidate is built.
        let result = router.route(&request).await;
        assert!(matches!(result, Err(FederationError::NoCapacity(_))));
    }

    #[tokio::test]
    async fn test_build_candidates_skips_unhealthy() {
        let catalog = Arc::new(ModelCatalog::new());
        let health = Arc::new(HealthChecker::new(super::super::health::HealthConfig {
            failure_threshold: 2,
            ..Default::default()
        }));
        let circuit_breaker = Arc::new(CircuitBreaker::default());

        register_deployment(&catalog, "model", "unhealthy-node", "us-west", Capability::Generate)
            .await;

        // Three failures against a threshold of two.
        let node = NodeId("unhealthy-node".to_string());
        health.register_node(node.clone());
        for _ in 0..3 {
            health.report_failure(&node);
        }

        let router = Router::new(RouterConfig::default(), catalog, health, circuit_breaker);

        let request = request_for(Capability::Generate, "test");

        let result = router.route(&request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_build_candidates_wrong_capability() {
        let catalog = Arc::new(ModelCatalog::new());
        let health = Arc::new(HealthChecker::default());
        let circuit_breaker = Arc::new(CircuitBreaker::default());

        register_deployment(&catalog, "whisper", "n1", "us-west", Capability::Transcribe).await;

        let router = Router::new(RouterConfig::default(), catalog, health, circuit_breaker);

        // Ask for a capability the only registered model does not offer.
        let request = request_for(Capability::Generate, "test");

        let result = router.route(&request).await;
        assert!(matches!(result, Err(FederationError::NoCapacity(_))));
    }

    // ---------------------------------------------------------------------
    // Load-balancing strategies
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_route_with_round_robin_strategy() {
        assert_strategy_routes(LoadBalanceStrategy::RoundRobin, Capability::Generate).await;
    }

    #[tokio::test]
    async fn test_route_with_consistent_hash_strategy() {
        assert_strategy_routes(LoadBalanceStrategy::ConsistentHash, Capability::Embed).await;
    }

    #[tokio::test]
    async fn test_route_with_weighted_random_strategy() {
        assert_strategy_routes(LoadBalanceStrategy::WeightedRandom, Capability::Generate).await;
    }

    #[tokio::test]
    async fn test_route_with_least_connections_strategy() {
        assert_strategy_routes(LoadBalanceStrategy::LeastConnections, Capability::Generate).await;
    }

    // ---------------------------------------------------------------------
    // Policy rejection and score floor
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_get_candidates_with_rejected() {
        let (catalog, health, circuit_breaker) = single_node_fixture(Capability::Generate).await;

        // The node's region is only cleared for public data.
        let policy = CompositePolicy::new().with_policy(
            super::super::policy::PrivacyPolicy::default()
                .with_region(RegionId("us-west".to_string()), PrivacyLevel::Public),
        );

        let router = Router::new(RouterConfig::default(), catalog, health, circuit_breaker)
            .with_policy(policy);

        // The request demands a stricter privacy level than the region has.
        let request = InferenceRequest {
            qos: QoSRequirements {
                privacy: PrivacyLevel::Confidential,
                ..Default::default()
            },
            ..request_for(Capability::Generate, "test")
        };

        let candidates = router
            .get_candidates(&request)
            .await
            .expect("get_candidates failed");
        assert!(!candidates.is_empty());
        assert!(!candidates[0].eligible);
        assert!(candidates[0].rejection_reason.is_some());
    }

    #[tokio::test]
    async fn test_route_min_score_filters_candidates() {
        let catalog = Arc::new(ModelCatalog::new());
        let health = Arc::new(HealthChecker::default());
        let circuit_breaker = Arc::new(CircuitBreaker::default());

        register_deployment(&catalog, "model", "n1", "us-west", Capability::Generate).await;

        // Registered with the health checker but with no reports at all.
        report_healthy(&health, "n1", 0);

        let router = Router::new(
            RouterConfig {
                min_score: 0.99,
                ..Default::default()
            },
            catalog,
            health,
            circuit_breaker,
        );

        let request = request_for(Capability::Generate, "test");

        let result = router.route(&request).await;
        assert!(result.is_err());
    }
}