apr_cli/federation/
routing.rs

1//! Router - Intelligent node selection for inference requests
2//!
3//! Combines catalog, health, and policy data to select the optimal
4//! node for each request.
5
6use super::catalog::ModelCatalog;
7use super::health::{CircuitBreaker, HealthChecker};
8use super::policy::CompositePolicy;
9use super::traits::*;
10use std::sync::Arc;
11use std::time::Duration;
12
13// ============================================================================
14// Route Decision (exported type)
15// ============================================================================
16
17/// Final routing decision with reasoning
18#[derive(Debug, Clone)]
19pub struct RouteDecision {
20    pub target: RouteTarget,
21    pub alternatives: Vec<RouteTarget>,
22    pub reasoning: String,
23}
24
25// ============================================================================
26// Router Implementation
27// ============================================================================
28
29/// Configuration for the router
30#[derive(Debug, Clone)]
31pub struct RouterConfig {
32    /// Maximum candidates to evaluate
33    pub max_candidates: usize,
34    /// Minimum score to be considered
35    pub min_score: f64,
36    /// Load balance strategy
37    pub strategy: LoadBalanceStrategy,
38}
39
40impl Default for RouterConfig {
41    fn default() -> Self {
42        Self {
43            max_candidates: 10,
44            min_score: 0.1,
45            strategy: LoadBalanceStrategy::LeastLatency,
46        }
47    }
48}
49
50/// The main router implementation
51pub struct Router {
52    config: RouterConfig,
53    catalog: Arc<ModelCatalog>,
54    health: Arc<HealthChecker>,
55    circuit_breaker: Arc<CircuitBreaker>,
56    policy: CompositePolicy,
57}
58
59impl Router {
60    pub fn new(
61        config: RouterConfig,
62        catalog: Arc<ModelCatalog>,
63        health: Arc<HealthChecker>,
64        circuit_breaker: Arc<CircuitBreaker>,
65    ) -> Self {
66        Self {
67            config,
68            catalog,
69            health,
70            circuit_breaker,
71            policy: CompositePolicy::enterprise_default(),
72        }
73    }
74
75    /// Create router with custom policy
76    #[must_use]
77    pub fn with_policy(mut self, policy: CompositePolicy) -> Self {
78        self.policy = policy;
79        self
80    }
81
82    /// Build candidates from catalog and health data
83    fn build_candidates(&self, capability: &Capability) -> Vec<RouteCandidate> {
84        // This would be async in production, but for simplicity we use sync here
85        // The actual routing is async via the trait
86
87        let mut candidates = Vec::new();
88
89        // Get nodes from catalog that support this capability
90        let entries = self.catalog.all_entries();
91
92        for entry in entries {
93            // Check if model supports the capability
94            let has_capability = entry.metadata.capabilities.iter().any(|c| c == capability);
95            if !has_capability {
96                continue;
97            }
98
99            for deployment in &entry.deployments {
100                // Check circuit breaker
101                if self.circuit_breaker.is_open(&deployment.node_id) {
102                    continue;
103                }
104
105                // Get health status
106                let health = self
107                    .health
108                    .get_cached_health(&deployment.node_id)
109                    .unwrap_or_else(|| NodeHealth {
110                        node_id: deployment.node_id.clone(),
111                        status: HealthState::Unknown,
112                        latency_p50: Duration::from_secs(1),
113                        latency_p99: Duration::from_secs(5),
114                        throughput: 0,
115                        gpu_utilization: None,
116                        queue_depth: 0,
117                        last_check: std::time::Instant::now(),
118                    });
119
120                // Skip unhealthy nodes
121                if health.status == HealthState::Unhealthy {
122                    continue;
123                }
124
125                let target = RouteTarget {
126                    node_id: deployment.node_id.clone(),
127                    region_id: deployment.region_id.clone(),
128                    endpoint: deployment.endpoint.clone(),
129                    estimated_latency: health.latency_p50,
130                    score: 0.0, // Will be calculated by policy
131                };
132
133                let health_score = match health.status {
134                    HealthState::Healthy => 1.0,
135                    HealthState::Degraded => 0.5,
136                    HealthState::Unknown => 0.3,
137                    HealthState::Unhealthy => 0.0,
138                };
139
140                let scores = RouteScores {
141                    latency_score: 1.0 - (health.latency_p50.as_millis() as f64 / 5000.0).min(1.0),
142                    throughput_score: (health.throughput as f64 / 1000.0).min(1.0),
143                    cost_score: 0.5,     // Would come from region pricing
144                    locality_score: 0.5, // Would be calculated based on request origin
145                    health_score,
146                    total: 0.0,
147                };
148
149                candidates.push(RouteCandidate {
150                    target,
151                    scores,
152                    eligible: true,
153                    rejection_reason: None,
154                });
155            }
156        }
157
158        candidates
159    }
160
161    /// Score and rank candidates
162    fn rank_candidates(&self, candidates: &mut [RouteCandidate], request: &InferenceRequest) {
163        for candidate in candidates.iter_mut() {
164            // Check eligibility
165            if !self.policy.is_eligible(candidate, request) {
166                candidate.eligible = false;
167                candidate.rejection_reason = Some("Policy rejected".to_string());
168                continue;
169            }
170
171            // Calculate score
172            let score = self.policy.score(candidate, request);
173            candidate.target.score = score;
174            candidate.scores.total = score;
175        }
176
177        // Sort by score descending
178        candidates.sort_by(|a, b| {
179            b.scores
180                .total
181                .partial_cmp(&a.scores.total)
182                .unwrap_or(std::cmp::Ordering::Equal)
183        });
184    }
185
186    /// Select best candidate based on strategy
187    fn select_best(&self, candidates: &[RouteCandidate]) -> Option<RouteCandidate> {
188        let eligible: Vec<_> = candidates
189            .iter()
190            .filter(|c| c.eligible && c.scores.total >= self.config.min_score)
191            .take(self.config.max_candidates)
192            .collect();
193
194        if eligible.is_empty() {
195            return None;
196        }
197
198        match self.config.strategy {
199            LoadBalanceStrategy::LeastLatency => {
200                // Already sorted by score (which factors in latency)
201                eligible.first().map(|c| (*c).clone())
202            }
203            LoadBalanceStrategy::LeastConnections => {
204                // Would use queue_depth in production
205                eligible.first().map(|c| (*c).clone())
206            }
207            LoadBalanceStrategy::RoundRobin => {
208                // Would maintain counter state
209                eligible.first().map(|c| (*c).clone())
210            }
211            LoadBalanceStrategy::WeightedRandom => {
212                // Weighted random selection based on scores
213                use std::collections::hash_map::DefaultHasher;
214                use std::hash::{Hash, Hasher};
215
216                let total_weight: f64 = eligible.iter().map(|c| c.scores.total).sum();
217                if total_weight <= 0.0 {
218                    return eligible.first().map(|c| (*c).clone());
219                }
220
221                // Simple pseudo-random using current time
222                let mut hasher = DefaultHasher::new();
223                std::time::SystemTime::now()
224                    .duration_since(std::time::UNIX_EPOCH)
225                    .unwrap_or_default()
226                    .as_nanos()
227                    .hash(&mut hasher);
228                let random = (hasher.finish() as f64) / (u64::MAX as f64);
229
230                let target = random * total_weight;
231                let mut cumulative = 0.0;
232
233                for candidate in &eligible {
234                    cumulative += candidate.scores.total;
235                    if cumulative >= target {
236                        return Some((*candidate).clone());
237                    }
238                }
239
240                eligible.last().map(|c| (*c).clone())
241            }
242            LoadBalanceStrategy::ConsistentHash => {
243                // Would hash request ID to consistent bucket
244                eligible.first().map(|c| (*c).clone())
245            }
246        }
247    }
248}
249
250impl RouterTrait for Router {
251    fn route(&self, request: &InferenceRequest) -> BoxFuture<'_, FederationResult<RouteTarget>> {
252        // Clone request to avoid lifetime issues with async block
253        let request = request.clone();
254
255        Box::pin(async move {
256            let mut candidates = self.build_candidates(&request.capability);
257
258            if candidates.is_empty() {
259                return Err(FederationError::NoCapacity(request.capability.clone()));
260            }
261
262            self.rank_candidates(&mut candidates, &request);
263
264            self.select_best(&candidates)
265                .map(|c| c.target)
266                .ok_or_else(|| FederationError::AllNodesUnhealthy(request.capability.clone()))
267        })
268    }
269
270    fn get_candidates(
271        &self,
272        request: &InferenceRequest,
273    ) -> BoxFuture<'_, FederationResult<Vec<RouteCandidate>>> {
274        // Clone request to avoid lifetime issues with async block
275        let request = request.clone();
276
277        Box::pin(async move {
278            let mut candidates = self.build_candidates(&request.capability);
279            self.rank_candidates(&mut candidates, &request);
280            Ok(candidates)
281        })
282    }
283}
284
285// ============================================================================
286// Builder for Router
287// ============================================================================
288
289/// Builder for creating routers with custom configuration
290pub struct RouterBuilder {
291    config: RouterConfig,
292    catalog: Option<Arc<ModelCatalog>>,
293    health: Option<Arc<HealthChecker>>,
294    circuit_breaker: Option<Arc<CircuitBreaker>>,
295    policy: Option<CompositePolicy>,
296}
297
298impl RouterBuilder {
299    pub fn new() -> Self {
300        Self {
301            config: RouterConfig::default(),
302            catalog: None,
303            health: None,
304            circuit_breaker: None,
305            policy: None,
306        }
307    }
308
309    #[must_use]
310    pub fn config(mut self, config: RouterConfig) -> Self {
311        self.config = config;
312        self
313    }
314
315    #[must_use]
316    pub fn catalog(mut self, catalog: Arc<ModelCatalog>) -> Self {
317        self.catalog = Some(catalog);
318        self
319    }
320
321    #[must_use]
322    pub fn health(mut self, health: Arc<HealthChecker>) -> Self {
323        self.health = Some(health);
324        self
325    }
326
327    #[must_use]
328    pub fn circuit_breaker(mut self, cb: Arc<CircuitBreaker>) -> Self {
329        self.circuit_breaker = Some(cb);
330        self
331    }
332
333    #[must_use]
334    pub fn policy(mut self, policy: CompositePolicy) -> Self {
335        self.policy = Some(policy);
336        self
337    }
338
339    pub fn build(self) -> Router {
340        let catalog = self
341            .catalog
342            .unwrap_or_else(|| Arc::new(ModelCatalog::new()));
343        let health = self
344            .health
345            .unwrap_or_else(|| Arc::new(HealthChecker::default()));
346        let circuit_breaker = self
347            .circuit_breaker
348            .unwrap_or_else(|| Arc::new(CircuitBreaker::default()));
349
350        let router = Router::new(self.config, catalog, health, circuit_breaker);
351
352        if let Some(policy) = self.policy {
353            router.with_policy(policy)
354        } else {
355            router
356        }
357    }
358}
359
360impl Default for RouterBuilder {
361    fn default() -> Self {
362        Self::new()
363    }
364}
365
366// ============================================================================
367// Tests
368// ============================================================================
369
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a router with default config over a fresh catalog, health
    /// checker and circuit breaker; the catalog and health checker are
    /// returned too so tests can populate them.
    fn setup_test_router() -> (Router, Arc<ModelCatalog>, Arc<HealthChecker>) {
        let catalog = Arc::new(ModelCatalog::new());
        let health = Arc::new(HealthChecker::default());

        let router = Router::new(
            RouterConfig::default(),
            Arc::clone(&catalog),
            Arc::clone(&health),
            Arc::new(CircuitBreaker::default()),
        );

        (router, catalog, health)
    }

    /// Shorthand for a `NodeId` from a string slice.
    fn node(name: &str) -> NodeId {
        NodeId(name.to_string())
    }

    /// Minimal request for `capability`: empty input, default QoS, no tenant.
    fn request_for(capability: Capability, request_id: &str) -> InferenceRequest {
        InferenceRequest {
            capability,
            input: vec![],
            qos: QoSRequirements::default(),
            request_id: request_id.to_string(),
            tenant_id: None,
        }
    }

    /// Register `model` on `node_id` in `region` with a single capability,
    /// panicking if the catalog refuses the registration.
    async fn register_deployment(
        catalog: &Arc<ModelCatalog>,
        model: &str,
        node_id: NodeId,
        region: &str,
        capability: Capability,
    ) {
        catalog
            .register(
                ModelId(model.to_string()),
                node_id,
                RegionId(region.to_string()),
                vec![capability],
            )
            .await
            .expect("registration failed");
    }

    #[tokio::test]
    async fn test_route_no_nodes() {
        let (router, _, _) = setup_test_router();

        // Empty catalog: nothing can serve the request.
        let request = request_for(Capability::Transcribe, "test-1");

        let result = router.route(&request).await;
        assert!(matches!(result, Err(FederationError::NoCapacity(_))));
    }

    #[tokio::test]
    async fn test_route_single_node() {
        let (router, catalog, health) = setup_test_router();

        // One node serving the capability, with one successful health report.
        register_deployment(
            &catalog,
            "whisper",
            node("node-1"),
            "us-west",
            Capability::Transcribe,
        )
        .await;

        health.register_node(node("node-1"));
        health.report_success(&node("node-1"), Duration::from_millis(50));

        let request = request_for(Capability::Transcribe, "test-2");

        let result = router.route(&request).await;
        assert!(result.is_ok());

        let target = result.expect("routing failed");
        assert_eq!(target.node_id, node("node-1"));
    }

    #[tokio::test]
    async fn test_route_prefers_healthy() {
        let (router, catalog, health) = setup_test_router();

        // Two nodes serving the same model in different regions.
        register_deployment(
            &catalog,
            "llama",
            node("healthy-node"),
            "us-west",
            Capability::Generate,
        )
        .await;
        register_deployment(
            &catalog,
            "llama",
            node("degraded-node"),
            "us-east",
            Capability::Generate,
        )
        .await;

        health.register_node(node("healthy-node"));
        health.register_node(node("degraded-node"));

        // One node only reports successes, the other only failures.
        for _ in 0..5 {
            health.report_success(&node("healthy-node"), Duration::from_millis(20));
            health.report_failure(&node("degraded-node"));
        }

        let request = request_for(Capability::Generate, "test-3");

        let result = router.route(&request).await;
        assert!(result.is_ok());

        let target = result.expect("routing failed");
        assert_eq!(target.node_id, node("healthy-node"));
    }

    #[tokio::test]
    async fn test_get_candidates_returns_all() {
        let (router, catalog, health) = setup_test_router();

        // Three nodes, all serving the same model with a successful report each.
        for i in 0..3 {
            let id = NodeId(format!("node-{}", i));

            register_deployment(&catalog, "embed", id.clone(), "us-west", Capability::Embed).await;

            health.register_node(id.clone());
            health.report_success(&id, Duration::from_millis(50));
        }

        let request = request_for(Capability::Embed, "test-4");

        let candidates = router
            .get_candidates(&request)
            .await
            .expect("get_candidates failed");

        assert_eq!(candidates.len(), 3);
    }

    #[test]
    fn test_router_builder() {
        let custom = RouterConfig {
            max_candidates: 5,
            min_score: 0.2,
            strategy: LoadBalanceStrategy::RoundRobin,
        };

        let router = RouterBuilder::new().config(custom).build();

        assert_eq!(router.config.max_candidates, 5);
        assert_eq!(router.config.min_score, 0.2);
    }
}
539}