Skip to main content

apr_cli/federation/
routing.rs

1//! Router - Intelligent node selection for inference requests
2//!
3//! Combines catalog, health, and policy data to select the optimal
4//! node for each request.
5
6use super::catalog::ModelCatalog;
7use super::health::{CircuitBreaker, HealthChecker};
8use super::policy::CompositePolicy;
9use super::traits::*;
10use std::sync::Arc;
11use std::time::Duration;
12
13// ============================================================================
14// Route Decision (exported type)
15// ============================================================================
16
17/// Final routing decision with reasoning
18#[derive(Debug, Clone)]
19pub struct RouteDecision {
20    pub target: RouteTarget,
21    pub alternatives: Vec<RouteTarget>,
22    pub reasoning: String,
23}
24
25// ============================================================================
26// Router Implementation
27// ============================================================================
28
29/// Configuration for the router
30#[derive(Debug, Clone)]
31pub struct RouterConfig {
32    /// Maximum candidates to evaluate
33    pub max_candidates: usize,
34    /// Minimum score to be considered
35    pub min_score: f64,
36    /// Load balance strategy
37    pub strategy: LoadBalanceStrategy,
38}
39
40impl Default for RouterConfig {
41    fn default() -> Self {
42        Self {
43            max_candidates: 10,
44            min_score: 0.1,
45            strategy: LoadBalanceStrategy::LeastLatency,
46        }
47    }
48}
49
50/// The main router implementation
51pub struct Router {
52    config: RouterConfig,
53    catalog: Arc<ModelCatalog>,
54    health: Arc<HealthChecker>,
55    circuit_breaker: Arc<CircuitBreaker>,
56    policy: CompositePolicy,
57}
58
59impl Router {
60    pub fn new(
61        config: RouterConfig,
62        catalog: Arc<ModelCatalog>,
63        health: Arc<HealthChecker>,
64        circuit_breaker: Arc<CircuitBreaker>,
65    ) -> Self {
66        Self {
67            config,
68            catalog,
69            health,
70            circuit_breaker,
71            policy: CompositePolicy::enterprise_default(),
72        }
73    }
74
75    /// Create router with custom policy
76    #[must_use]
77    pub fn with_policy(mut self, policy: CompositePolicy) -> Self {
78        self.policy = policy;
79        self
80    }
81
82    /// Build candidates from catalog and health data
83    fn build_candidates(&self, capability: &Capability) -> Vec<RouteCandidate> {
84        // This would be async in production, but for simplicity we use sync here
85        // The actual routing is async via the trait
86
87        let mut candidates = Vec::new();
88
89        // Get nodes from catalog that support this capability
90        let entries = self.catalog.all_entries();
91
92        for entry in entries {
93            // Check if model supports the capability
94            let has_capability = entry.metadata.capabilities.iter().any(|c| c == capability);
95            if !has_capability {
96                continue;
97            }
98
99            for deployment in &entry.deployments {
100                // Check circuit breaker
101                if self.circuit_breaker.is_open(&deployment.node_id) {
102                    continue;
103                }
104
105                // Get health status
106                let health = self
107                    .health
108                    .get_cached_health(&deployment.node_id)
109                    .unwrap_or_else(|| NodeHealth {
110                        node_id: deployment.node_id.clone(),
111                        status: HealthState::Unknown,
112                        latency_p50: Duration::from_secs(1),
113                        latency_p99: Duration::from_secs(5),
114                        throughput: 0,
115                        gpu_utilization: None,
116                        queue_depth: 0,
117                        last_check: std::time::Instant::now(),
118                    });
119
120                // Skip unhealthy nodes
121                if health.status == HealthState::Unhealthy {
122                    continue;
123                }
124
125                let target = RouteTarget {
126                    node_id: deployment.node_id.clone(),
127                    region_id: deployment.region_id.clone(),
128                    endpoint: deployment.endpoint.clone(),
129                    estimated_latency: health.latency_p50,
130                    score: 0.0, // Will be calculated by policy
131                };
132
133                let health_score = match health.status {
134                    HealthState::Healthy => 1.0,
135                    HealthState::Degraded => 0.5,
136                    HealthState::Unknown => 0.3,
137                    HealthState::Unhealthy => 0.0,
138                };
139
140                let scores = RouteScores {
141                    latency_score: 1.0 - (health.latency_p50.as_millis() as f64 / 5000.0).min(1.0),
142                    throughput_score: (health.throughput as f64 / 1000.0).min(1.0),
143                    cost_score: 0.5,     // Would come from region pricing
144                    locality_score: 0.5, // Would be calculated based on request origin
145                    health_score,
146                    total: 0.0,
147                };
148
149                candidates.push(RouteCandidate {
150                    target,
151                    scores,
152                    eligible: true,
153                    rejection_reason: None,
154                });
155            }
156        }
157
158        candidates
159    }
160
161    /// Score and rank candidates
162    fn rank_candidates(&self, candidates: &mut [RouteCandidate], request: &InferenceRequest) {
163        for candidate in candidates.iter_mut() {
164            // Check eligibility
165            if !self.policy.is_eligible(candidate, request) {
166                candidate.eligible = false;
167                candidate.rejection_reason = Some("Policy rejected".to_string());
168                continue;
169            }
170
171            // Calculate score
172            let score = self.policy.score(candidate, request);
173            candidate.target.score = score;
174            candidate.scores.total = score;
175        }
176
177        // Sort by score descending
178        candidates.sort_by(|a, b| {
179            b.scores
180                .total
181                .partial_cmp(&a.scores.total)
182                .unwrap_or(std::cmp::Ordering::Equal)
183        });
184    }
185
186    /// Select best candidate based on strategy
187    fn select_best(&self, candidates: &[RouteCandidate]) -> Option<RouteCandidate> {
188        let eligible: Vec<_> = candidates
189            .iter()
190            .filter(|c| c.eligible && c.scores.total >= self.config.min_score)
191            .take(self.config.max_candidates)
192            .collect();
193
194        if eligible.is_empty() {
195            return None;
196        }
197
198        match self.config.strategy {
199            LoadBalanceStrategy::LeastLatency => {
200                // Already sorted by score (which factors in latency)
201                eligible.first().map(|c| (*c).clone())
202            }
203            LoadBalanceStrategy::LeastConnections => {
204                // Would use queue_depth in production
205                eligible.first().map(|c| (*c).clone())
206            }
207            LoadBalanceStrategy::RoundRobin => {
208                // Would maintain counter state
209                eligible.first().map(|c| (*c).clone())
210            }
211            LoadBalanceStrategy::WeightedRandom => {
212                // Weighted random selection based on scores
213                use std::collections::hash_map::DefaultHasher;
214                use std::hash::{Hash, Hasher};
215
216                let total_weight: f64 = eligible.iter().map(|c| c.scores.total).sum();
217                if total_weight <= 0.0 {
218                    return eligible.first().map(|c| (*c).clone());
219                }
220
221                // Simple pseudo-random using current time
222                let mut hasher = DefaultHasher::new();
223                std::time::SystemTime::now()
224                    .duration_since(std::time::UNIX_EPOCH)
225                    .unwrap_or_default()
226                    .as_nanos()
227                    .hash(&mut hasher);
228                let random = (hasher.finish() as f64) / (u64::MAX as f64);
229
230                let target = random * total_weight;
231                let mut cumulative = 0.0;
232
233                for candidate in &eligible {
234                    cumulative += candidate.scores.total;
235                    if cumulative >= target {
236                        return Some((*candidate).clone());
237                    }
238                }
239
240                eligible.last().map(|c| (*c).clone())
241            }
242            LoadBalanceStrategy::ConsistentHash => {
243                // Would hash request ID to consistent bucket
244                eligible.first().map(|c| (*c).clone())
245            }
246        }
247    }
248}
249
250impl RouterTrait for Router {
251    fn route(&self, request: &InferenceRequest) -> BoxFuture<'_, FederationResult<RouteTarget>> {
252        // Clone request to avoid lifetime issues with async block
253        let request = request.clone();
254
255        Box::pin(async move {
256            let mut candidates = self.build_candidates(&request.capability);
257
258            if candidates.is_empty() {
259                return Err(FederationError::NoCapacity(request.capability.clone()));
260            }
261
262            self.rank_candidates(&mut candidates, &request);
263
264            self.select_best(&candidates)
265                .map(|c| c.target)
266                .ok_or_else(|| FederationError::AllNodesUnhealthy(request.capability.clone()))
267        })
268    }
269
270    fn get_candidates(
271        &self,
272        request: &InferenceRequest,
273    ) -> BoxFuture<'_, FederationResult<Vec<RouteCandidate>>> {
274        // Clone request to avoid lifetime issues with async block
275        let request = request.clone();
276
277        Box::pin(async move {
278            let mut candidates = self.build_candidates(&request.capability);
279            self.rank_candidates(&mut candidates, &request);
280            Ok(candidates)
281        })
282    }
283}
284
285// ============================================================================
286// Builder for Router
287// ============================================================================
288
289/// Builder for creating routers with custom configuration
290pub struct RouterBuilder {
291    config: RouterConfig,
292    catalog: Option<Arc<ModelCatalog>>,
293    health: Option<Arc<HealthChecker>>,
294    circuit_breaker: Option<Arc<CircuitBreaker>>,
295    policy: Option<CompositePolicy>,
296}
297
298impl RouterBuilder {
299    pub fn new() -> Self {
300        Self {
301            config: RouterConfig::default(),
302            catalog: None,
303            health: None,
304            circuit_breaker: None,
305            policy: None,
306        }
307    }
308
309    #[must_use]
310    pub fn config(mut self, config: RouterConfig) -> Self {
311        self.config = config;
312        self
313    }
314
315    #[must_use]
316    pub fn catalog(mut self, catalog: Arc<ModelCatalog>) -> Self {
317        self.catalog = Some(catalog);
318        self
319    }
320
321    #[must_use]
322    pub fn health(mut self, health: Arc<HealthChecker>) -> Self {
323        self.health = Some(health);
324        self
325    }
326
327    #[must_use]
328    pub fn circuit_breaker(mut self, cb: Arc<CircuitBreaker>) -> Self {
329        self.circuit_breaker = Some(cb);
330        self
331    }
332
333    #[must_use]
334    pub fn policy(mut self, policy: CompositePolicy) -> Self {
335        self.policy = Some(policy);
336        self
337    }
338
339    pub fn build(self) -> Router {
340        let catalog = self
341            .catalog
342            .unwrap_or_else(|| Arc::new(ModelCatalog::new()));
343        let health = self
344            .health
345            .unwrap_or_else(|| Arc::new(HealthChecker::default()));
346        let circuit_breaker = self
347            .circuit_breaker
348            .unwrap_or_else(|| Arc::new(CircuitBreaker::default()));
349
350        let router = Router::new(self.config, catalog, health, circuit_breaker);
351
352        if let Some(policy) = self.policy {
353            router.with_policy(policy)
354        } else {
355            router
356        }
357    }
358}
359
360impl Default for RouterBuilder {
361    fn default() -> Self {
362        Self::new()
363    }
364}
365
366// ============================================================================
367// Tests
368// ============================================================================
369
370#[cfg(test)]
371mod tests {
372    use super::*;
373
374    fn setup_test_router() -> (Router, Arc<ModelCatalog>, Arc<HealthChecker>) {
375        let catalog = Arc::new(ModelCatalog::new());
376        let health = Arc::new(HealthChecker::default());
377        let circuit_breaker = Arc::new(CircuitBreaker::default());
378
379        let router = Router::new(
380            RouterConfig::default(),
381            Arc::clone(&catalog),
382            Arc::clone(&health),
383            circuit_breaker,
384        );
385
386        (router, catalog, health)
387    }
388
389    #[tokio::test]
390    async fn test_route_no_nodes() {
391        let (router, _, _) = setup_test_router();
392
393        let request = InferenceRequest {
394            capability: Capability::Transcribe,
395            input: vec![],
396            qos: QoSRequirements::default(),
397            request_id: "test-1".to_string(),
398            tenant_id: None,
399        };
400
401        let result = router.route(&request).await;
402        assert!(matches!(result, Err(FederationError::NoCapacity(_))));
403    }
404
405    #[tokio::test]
406    async fn test_route_single_node() {
407        let (router, catalog, health) = setup_test_router();
408
409        // Register a node
410        catalog
411            .register(
412                ModelId("whisper".to_string()),
413                NodeId("node-1".to_string()),
414                RegionId("us-west".to_string()),
415                vec![Capability::Transcribe],
416            )
417            .await
418            .expect("registration failed");
419
420        health.register_node(NodeId("node-1".to_string()));
421        health.report_success(&NodeId("node-1".to_string()), Duration::from_millis(50));
422
423        let request = InferenceRequest {
424            capability: Capability::Transcribe,
425            input: vec![],
426            qos: QoSRequirements::default(),
427            request_id: "test-2".to_string(),
428            tenant_id: None,
429        };
430
431        let result = router.route(&request).await;
432        assert!(result.is_ok());
433
434        let target = result.expect("routing failed");
435        assert_eq!(target.node_id, NodeId("node-1".to_string()));
436    }
437
438    #[tokio::test]
439    async fn test_route_prefers_healthy() {
440        let (router, catalog, health) = setup_test_router();
441
442        // Register two nodes
443        catalog
444            .register(
445                ModelId("llama".to_string()),
446                NodeId("healthy-node".to_string()),
447                RegionId("us-west".to_string()),
448                vec![Capability::Generate],
449            )
450            .await
451            .expect("registration failed");
452
453        catalog
454            .register(
455                ModelId("llama".to_string()),
456                NodeId("degraded-node".to_string()),
457                RegionId("us-east".to_string()),
458                vec![Capability::Generate],
459            )
460            .await
461            .expect("registration failed");
462
463        // Make one healthy, one degraded
464        health.register_node(NodeId("healthy-node".to_string()));
465        health.register_node(NodeId("degraded-node".to_string()));
466
467        for _ in 0..5 {
468            health.report_success(
469                &NodeId("healthy-node".to_string()),
470                Duration::from_millis(20),
471            );
472            health.report_failure(&NodeId("degraded-node".to_string()));
473        }
474
475        let request = InferenceRequest {
476            capability: Capability::Generate,
477            input: vec![],
478            qos: QoSRequirements::default(),
479            request_id: "test-3".to_string(),
480            tenant_id: None,
481        };
482
483        let result = router.route(&request).await;
484        assert!(result.is_ok());
485
486        let target = result.expect("routing failed");
487        assert_eq!(target.node_id, NodeId("healthy-node".to_string()));
488    }
489
490    #[tokio::test]
491    async fn test_get_candidates_returns_all() {
492        let (router, catalog, health) = setup_test_router();
493
494        // Register multiple nodes
495        for i in 0..3 {
496            catalog
497                .register(
498                    ModelId("embed".to_string()),
499                    NodeId(format!("node-{}", i)),
500                    RegionId("us-west".to_string()),
501                    vec![Capability::Embed],
502                )
503                .await
504                .expect("registration failed");
505
506            health.register_node(NodeId(format!("node-{}", i)));
507            health.report_success(&NodeId(format!("node-{}", i)), Duration::from_millis(50));
508        }
509
510        let request = InferenceRequest {
511            capability: Capability::Embed,
512            input: vec![],
513            qos: QoSRequirements::default(),
514            request_id: "test-4".to_string(),
515            tenant_id: None,
516        };
517
518        let candidates = router
519            .get_candidates(&request)
520            .await
521            .expect("get_candidates failed");
522
523        assert_eq!(candidates.len(), 3);
524    }
525
526    #[test]
527    fn test_router_builder() {
528        let router = RouterBuilder::new()
529            .config(RouterConfig {
530                max_candidates: 5,
531                min_score: 0.2,
532                strategy: LoadBalanceStrategy::RoundRobin,
533            })
534            .build();
535
536        assert_eq!(router.config.max_candidates, 5);
537        assert_eq!(router.config.min_score, 0.2);
538    }
539
540    // =========================================================================
541    // RouterConfig tests
542    // =========================================================================
543
544    #[test]
545    fn test_router_config_default() {
546        let config = RouterConfig::default();
547        assert_eq!(config.max_candidates, 10);
548        assert_eq!(config.min_score, 0.1);
549        assert!(matches!(config.strategy, LoadBalanceStrategy::LeastLatency));
550    }
551
552    #[test]
553    fn test_router_config_clone() {
554        let config = RouterConfig {
555            max_candidates: 20,
556            min_score: 0.5,
557            strategy: LoadBalanceStrategy::ConsistentHash,
558        };
559        let cloned = config.clone();
560        assert_eq!(cloned.max_candidates, 20);
561        assert_eq!(cloned.min_score, 0.5);
562    }
563
564    // =========================================================================
565    // RouteDecision tests
566    // =========================================================================
567
568    #[test]
569    fn test_route_decision_construction() {
570        let decision = RouteDecision {
571            target: RouteTarget {
572                node_id: NodeId("n1".to_string()),
573                region_id: RegionId("r1".to_string()),
574                endpoint: "http://n1:8080".to_string(),
575                estimated_latency: Duration::from_millis(50),
576                score: 0.95,
577            },
578            alternatives: vec![],
579            reasoning: "best latency score".to_string(),
580        };
581        assert_eq!(decision.target.score, 0.95);
582        assert!(decision.alternatives.is_empty());
583    }
584
585    #[test]
586    fn test_route_decision_with_alternatives() {
587        let primary = RouteTarget {
588            node_id: NodeId("n1".to_string()),
589            region_id: RegionId("us-west".to_string()),
590            endpoint: "http://n1:8080".to_string(),
591            estimated_latency: Duration::from_millis(50),
592            score: 0.95,
593        };
594        let alt = RouteTarget {
595            node_id: NodeId("n2".to_string()),
596            region_id: RegionId("eu-west".to_string()),
597            endpoint: "http://n2:8080".to_string(),
598            estimated_latency: Duration::from_millis(120),
599            score: 0.7,
600        };
601        let decision = RouteDecision {
602            target: primary,
603            alternatives: vec![alt],
604            reasoning: "latency-based".to_string(),
605        };
606        assert_eq!(decision.alternatives.len(), 1);
607    }
608
609    // =========================================================================
610    // Router with_policy tests
611    // =========================================================================
612
613    #[test]
614    fn test_router_with_custom_policy() {
615        let catalog = Arc::new(ModelCatalog::new());
616        let health = Arc::new(HealthChecker::default());
617        let circuit_breaker = Arc::new(CircuitBreaker::default());
618
619        let custom_policy =
620            CompositePolicy::new().with_policy(super::super::policy::LatencyPolicy::default());
621
622        let router = Router::new(RouterConfig::default(), catalog, health, circuit_breaker)
623            .with_policy(custom_policy);
624
625        assert_eq!(router.config.max_candidates, 10);
626    }
627
628    // =========================================================================
629    // RouterBuilder chaining tests
630    // =========================================================================
631
632    #[test]
633    fn test_router_builder_default() {
634        let builder = RouterBuilder::default();
635        assert!(builder.catalog.is_none());
636        assert!(builder.health.is_none());
637        assert!(builder.circuit_breaker.is_none());
638        assert!(builder.policy.is_none());
639    }
640
641    #[test]
642    fn test_router_builder_with_catalog() {
643        let catalog = Arc::new(ModelCatalog::new());
644        let builder = RouterBuilder::new().catalog(Arc::clone(&catalog));
645        assert!(builder.catalog.is_some());
646    }
647
648    #[test]
649    fn test_router_builder_with_health() {
650        let health = Arc::new(HealthChecker::default());
651        let builder = RouterBuilder::new().health(health);
652        assert!(builder.health.is_some());
653    }
654
655    #[test]
656    fn test_router_builder_with_circuit_breaker() {
657        let cb = Arc::new(CircuitBreaker::default());
658        let builder = RouterBuilder::new().circuit_breaker(cb);
659        assert!(builder.circuit_breaker.is_some());
660    }
661
662    #[test]
663    fn test_router_builder_with_policy() {
664        let policy = CompositePolicy::new();
665        let builder = RouterBuilder::new().policy(policy);
666        assert!(builder.policy.is_some());
667    }
668
669    #[test]
670    fn test_router_builder_full_chain() {
671        let catalog = Arc::new(ModelCatalog::new());
672        let health = Arc::new(HealthChecker::default());
673        let cb = Arc::new(CircuitBreaker::default());
674
675        let router = RouterBuilder::new()
676            .config(RouterConfig {
677                max_candidates: 20,
678                min_score: 0.05,
679                strategy: LoadBalanceStrategy::WeightedRandom,
680            })
681            .catalog(catalog)
682            .health(health)
683            .circuit_breaker(cb)
684            .policy(CompositePolicy::enterprise_default())
685            .build();
686
687        assert_eq!(router.config.max_candidates, 20);
688        assert_eq!(router.config.min_score, 0.05);
689    }
690
691    #[test]
692    fn test_router_builder_build_without_policy() {
693        let router = RouterBuilder::new().build();
694        // Should use enterprise default policy
695        assert_eq!(router.config.max_candidates, 10);
696    }
697
698    // =========================================================================
699    // build_candidates tests
700    // =========================================================================
701
702    #[tokio::test]
703    async fn test_build_candidates_skips_circuit_open() {
704        let catalog = Arc::new(ModelCatalog::new());
705        let health = Arc::new(HealthChecker::default());
706        let circuit_breaker = Arc::new(CircuitBreaker::default());
707
708        // Register node
709        catalog
710            .register(
711                ModelId("model".to_string()),
712                NodeId("n1".to_string()),
713                RegionId("us-west".to_string()),
714                vec![Capability::Generate],
715            )
716            .await
717            .expect("registration failed");
718
719        health.register_node(NodeId("n1".to_string()));
720        for _ in 0..3 {
721            health.report_success(&NodeId("n1".to_string()), Duration::from_millis(50));
722        }
723
724        // Open circuit breaker for the node
725        for _ in 0..5 {
726            circuit_breaker.record_failure(&NodeId("n1".to_string()));
727        }
728
729        let router = Router::new(RouterConfig::default(), catalog, health, circuit_breaker);
730
731        let request = InferenceRequest {
732            capability: Capability::Generate,
733            input: vec![],
734            qos: QoSRequirements::default(),
735            request_id: "test".to_string(),
736            tenant_id: None,
737        };
738
739        // Should have no candidates since circuit is open
740        let result = router.route(&request).await;
741        assert!(matches!(result, Err(FederationError::NoCapacity(_))));
742    }
743
744    #[tokio::test]
745    async fn test_build_candidates_skips_unhealthy() {
746        let catalog = Arc::new(ModelCatalog::new());
747        let health = Arc::new(HealthChecker::new(super::super::health::HealthConfig {
748            failure_threshold: 2,
749            ..Default::default()
750        }));
751        let circuit_breaker = Arc::new(CircuitBreaker::default());
752
753        catalog
754            .register(
755                ModelId("model".to_string()),
756                NodeId("unhealthy-node".to_string()),
757                RegionId("us-west".to_string()),
758                vec![Capability::Generate],
759            )
760            .await
761            .expect("registration failed");
762
763        health.register_node(NodeId("unhealthy-node".to_string()));
764        // Make unhealthy
765        for _ in 0..3 {
766            health.report_failure(&NodeId("unhealthy-node".to_string()));
767        }
768
769        let router = Router::new(RouterConfig::default(), catalog, health, circuit_breaker);
770
771        let request = InferenceRequest {
772            capability: Capability::Generate,
773            input: vec![],
774            qos: QoSRequirements::default(),
775            request_id: "test".to_string(),
776            tenant_id: None,
777        };
778
779        // Unhealthy nodes are skipped
780        let result = router.route(&request).await;
781        assert!(result.is_err());
782    }
783
784    #[tokio::test]
785    async fn test_build_candidates_wrong_capability() {
786        let catalog = Arc::new(ModelCatalog::new());
787        let health = Arc::new(HealthChecker::default());
788        let circuit_breaker = Arc::new(CircuitBreaker::default());
789
790        catalog
791            .register(
792                ModelId("whisper".to_string()),
793                NodeId("n1".to_string()),
794                RegionId("us-west".to_string()),
795                vec![Capability::Transcribe],
796            )
797            .await
798            .expect("registration failed");
799
800        let router = Router::new(RouterConfig::default(), catalog, health, circuit_breaker);
801
802        let request = InferenceRequest {
803            capability: Capability::Generate, // Different capability
804            input: vec![],
805            qos: QoSRequirements::default(),
806            request_id: "test".to_string(),
807            tenant_id: None,
808        };
809
810        let result = router.route(&request).await;
811        assert!(matches!(result, Err(FederationError::NoCapacity(_))));
812    }
813
814    // =========================================================================
815    // select_best strategy tests
816    // =========================================================================
817
818    #[tokio::test]
819    async fn test_route_with_round_robin_strategy() {
820        let catalog = Arc::new(ModelCatalog::new());
821        let health = Arc::new(HealthChecker::default());
822        let circuit_breaker = Arc::new(CircuitBreaker::default());
823
824        catalog
825            .register(
826                ModelId("model".to_string()),
827                NodeId("n1".to_string()),
828                RegionId("us-west".to_string()),
829                vec![Capability::Generate],
830            )
831            .await
832            .expect("registration failed");
833
834        health.register_node(NodeId("n1".to_string()));
835        for _ in 0..3 {
836            health.report_success(&NodeId("n1".to_string()), Duration::from_millis(50));
837        }
838
839        let router = Router::new(
840            RouterConfig {
841                strategy: LoadBalanceStrategy::RoundRobin,
842                ..Default::default()
843            },
844            catalog,
845            health,
846            circuit_breaker,
847        );
848
849        let request = InferenceRequest {
850            capability: Capability::Generate,
851            input: vec![],
852            qos: QoSRequirements::default(),
853            request_id: "test".to_string(),
854            tenant_id: None,
855        };
856
857        let result = router.route(&request).await;
858        assert!(result.is_ok());
859    }
860
861    #[tokio::test]
862    async fn test_route_with_consistent_hash_strategy() {
863        let catalog = Arc::new(ModelCatalog::new());
864        let health = Arc::new(HealthChecker::default());
865        let circuit_breaker = Arc::new(CircuitBreaker::default());
866
867        catalog
868            .register(
869                ModelId("model".to_string()),
870                NodeId("n1".to_string()),
871                RegionId("us-west".to_string()),
872                vec![Capability::Embed],
873            )
874            .await
875            .expect("registration failed");
876
877        health.register_node(NodeId("n1".to_string()));
878        for _ in 0..3 {
879            health.report_success(&NodeId("n1".to_string()), Duration::from_millis(50));
880        }
881
882        let router = Router::new(
883            RouterConfig {
884                strategy: LoadBalanceStrategy::ConsistentHash,
885                ..Default::default()
886            },
887            catalog,
888            health,
889            circuit_breaker,
890        );
891
892        let request = InferenceRequest {
893            capability: Capability::Embed,
894            input: vec![],
895            qos: QoSRequirements::default(),
896            request_id: "test".to_string(),
897            tenant_id: None,
898        };
899
900        let result = router.route(&request).await;
901        assert!(result.is_ok());
902    }
903
904    #[tokio::test]
905    async fn test_route_with_weighted_random_strategy() {
906        let catalog = Arc::new(ModelCatalog::new());
907        let health = Arc::new(HealthChecker::default());
908        let circuit_breaker = Arc::new(CircuitBreaker::default());
909
910        catalog
911            .register(
912                ModelId("model".to_string()),
913                NodeId("n1".to_string()),
914                RegionId("us-west".to_string()),
915                vec![Capability::Generate],
916            )
917            .await
918            .expect("registration failed");
919
920        health.register_node(NodeId("n1".to_string()));
921        for _ in 0..3 {
922            health.report_success(&NodeId("n1".to_string()), Duration::from_millis(50));
923        }
924
925        let router = Router::new(
926            RouterConfig {
927                strategy: LoadBalanceStrategy::WeightedRandom,
928                ..Default::default()
929            },
930            catalog,
931            health,
932            circuit_breaker,
933        );
934
935        let request = InferenceRequest {
936            capability: Capability::Generate,
937            input: vec![],
938            qos: QoSRequirements::default(),
939            request_id: "test".to_string(),
940            tenant_id: None,
941        };
942
943        let result = router.route(&request).await;
944        assert!(result.is_ok());
945    }
946
947    #[tokio::test]
948    async fn test_route_with_least_connections_strategy() {
949        let catalog = Arc::new(ModelCatalog::new());
950        let health = Arc::new(HealthChecker::default());
951        let circuit_breaker = Arc::new(CircuitBreaker::default());
952
953        catalog
954            .register(
955                ModelId("model".to_string()),
956                NodeId("n1".to_string()),
957                RegionId("us-west".to_string()),
958                vec![Capability::Generate],
959            )
960            .await
961            .expect("registration failed");
962
963        health.register_node(NodeId("n1".to_string()));
964        for _ in 0..3 {
965            health.report_success(&NodeId("n1".to_string()), Duration::from_millis(50));
966        }
967
968        let router = Router::new(
969            RouterConfig {
970                strategy: LoadBalanceStrategy::LeastConnections,
971                ..Default::default()
972            },
973            catalog,
974            health,
975            circuit_breaker,
976        );
977
978        let request = InferenceRequest {
979            capability: Capability::Generate,
980            input: vec![],
981            qos: QoSRequirements::default(),
982            request_id: "test".to_string(),
983            tenant_id: None,
984        };
985
986        let result = router.route(&request).await;
987        assert!(result.is_ok());
988    }
989
990    // =========================================================================
991    // rank_candidates with policy rejection
992    // =========================================================================
993
994    #[tokio::test]
995    async fn test_get_candidates_with_rejected() {
996        let catalog = Arc::new(ModelCatalog::new());
997        let health = Arc::new(HealthChecker::default());
998        let circuit_breaker = Arc::new(CircuitBreaker::default());
999
1000        // Register node
1001        catalog
1002            .register(
1003                ModelId("model".to_string()),
1004                NodeId("n1".to_string()),
1005                RegionId("us-west".to_string()),
1006                vec![Capability::Generate],
1007            )
1008            .await
1009            .expect("registration failed");
1010
1011        health.register_node(NodeId("n1".to_string()));
1012        for _ in 0..3 {
1013            health.report_success(&NodeId("n1".to_string()), Duration::from_millis(50));
1014        }
1015
1016        // Use a policy that rejects based on privacy
1017        let policy = CompositePolicy::new().with_policy(
1018            super::super::policy::PrivacyPolicy::default()
1019                .with_region(RegionId("us-west".to_string()), PrivacyLevel::Public),
1020        );
1021
1022        let router = Router::new(RouterConfig::default(), catalog, health, circuit_breaker)
1023            .with_policy(policy);
1024
1025        // Request requires Confidential, but us-west is Public
1026        let request = InferenceRequest {
1027            capability: Capability::Generate,
1028            input: vec![],
1029            qos: QoSRequirements {
1030                privacy: PrivacyLevel::Confidential,
1031                ..Default::default()
1032            },
1033            request_id: "test".to_string(),
1034            tenant_id: None,
1035        };
1036
1037        let candidates = router
1038            .get_candidates(&request)
1039            .await
1040            .expect("get_candidates failed");
1041        assert!(!candidates.is_empty());
1042        // Candidate should be marked as not eligible
1043        assert!(!candidates[0].eligible);
1044        assert!(candidates[0].rejection_reason.is_some());
1045    }
1046
1047    // =========================================================================
1048    // min_score filtering
1049    // =========================================================================
1050
1051    #[tokio::test]
1052    async fn test_route_min_score_filters_candidates() {
1053        let catalog = Arc::new(ModelCatalog::new());
1054        let health = Arc::new(HealthChecker::default());
1055        let circuit_breaker = Arc::new(CircuitBreaker::default());
1056
1057        catalog
1058            .register(
1059                ModelId("model".to_string()),
1060                NodeId("n1".to_string()),
1061                RegionId("us-west".to_string()),
1062                vec![Capability::Generate],
1063            )
1064            .await
1065            .expect("registration failed");
1066
1067        health.register_node(NodeId("n1".to_string()));
1068        // Leave at Unknown state (low health score)
1069
1070        let router = Router::new(
1071            RouterConfig {
1072                min_score: 0.99, // Very high minimum
1073                ..Default::default()
1074            },
1075            catalog,
1076            health,
1077            circuit_breaker,
1078        );
1079
1080        let request = InferenceRequest {
1081            capability: Capability::Generate,
1082            input: vec![],
1083            qos: QoSRequirements::default(),
1084            request_id: "test".to_string(),
1085            tenant_id: None,
1086        };
1087
1088        let result = router.route(&request).await;
1089        // Might fail if score is below min_score
1090        // Unknown health = 0.3 score, which is below 0.99
1091        assert!(result.is_err());
1092    }
1093}