apr_cli/federation/
gateway.rs

1//! Federation Gateway - Main entry point for distributed inference
2//!
3//! The gateway orchestrates the full inference lifecycle:
4//! routing, execution, retries, and response handling.
5
6use super::catalog::ModelCatalog;
7use super::health::{CircuitBreaker, HealthChecker};
8use super::routing::Router;
9use super::traits::*;
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13
14// ============================================================================
15// Gateway Configuration
16// ============================================================================
17
/// Settings that control how the federation gateway handles a request.
#[derive(Debug, Clone)]
pub struct GatewayConfig {
    /// Maximum number of retries per request. The retry loop runs
    /// `0..=max_retries`, i.e. up to `max_retries + 1` attempts in total.
    pub max_retries: u32,
    /// Time budget for an individual inference call.
    pub inference_timeout: Duration,
    /// Flag for enabling request tracing.
    pub enable_tracing: bool,
}

impl Default for GatewayConfig {
    /// Defaults: 3 retries, a 30 second inference timeout, tracing enabled.
    fn default() -> Self {
        let max_retries = 3;
        let inference_timeout = Duration::from_secs(30);
        let enable_tracing = true;

        GatewayConfig {
            max_retries,
            inference_timeout,
            enable_tracing,
        }
    }
}
38
39// ============================================================================
40// Gateway Statistics
41// ============================================================================
42
43/// Thread-safe statistics tracker
44struct StatsTracker {
45    total_requests: AtomicU64,
46    successful_requests: AtomicU64,
47    failed_requests: AtomicU64,
48    total_tokens: AtomicU64,
49    total_latency_ms: AtomicU64,
50    active_streams: AtomicU64,
51}
52
53impl StatsTracker {
54    fn new() -> Self {
55        Self {
56            total_requests: AtomicU64::new(0),
57            successful_requests: AtomicU64::new(0),
58            failed_requests: AtomicU64::new(0),
59            total_tokens: AtomicU64::new(0),
60            total_latency_ms: AtomicU64::new(0),
61            active_streams: AtomicU64::new(0),
62        }
63    }
64
65    fn record_request(&self) {
66        self.total_requests.fetch_add(1, Ordering::SeqCst);
67    }
68
69    fn record_success(&self, latency: Duration, tokens: Option<u32>) {
70        self.successful_requests.fetch_add(1, Ordering::SeqCst);
71        self.total_latency_ms
72            .fetch_add(latency.as_millis() as u64, Ordering::SeqCst);
73        if let Some(t) = tokens {
74            self.total_tokens.fetch_add(t as u64, Ordering::SeqCst);
75        }
76    }
77
78    fn record_failure(&self) {
79        self.failed_requests.fetch_add(1, Ordering::SeqCst);
80    }
81
82    #[allow(dead_code)]
83    fn increment_streams(&self) {
84        self.active_streams.fetch_add(1, Ordering::SeqCst);
85    }
86
87    #[allow(dead_code)]
88    fn decrement_streams(&self) {
89        self.active_streams.fetch_sub(1, Ordering::SeqCst);
90    }
91
92    fn snapshot(&self) -> GatewayStats {
93        let total = self.total_requests.load(Ordering::SeqCst);
94        let successful = self.successful_requests.load(Ordering::SeqCst);
95        let total_latency = self.total_latency_ms.load(Ordering::SeqCst);
96
97        let avg_latency = if successful > 0 {
98            Duration::from_millis(total_latency / successful)
99        } else {
100            Duration::ZERO
101        };
102
103        GatewayStats {
104            total_requests: total,
105            successful_requests: successful,
106            failed_requests: self.failed_requests.load(Ordering::SeqCst),
107            total_tokens: self.total_tokens.load(Ordering::SeqCst),
108            avg_latency,
109            active_streams: self.active_streams.load(Ordering::SeqCst) as u32,
110        }
111    }
112}
113
114impl Default for StatsTracker {
115    fn default() -> Self {
116        Self::new()
117    }
118}
119
120// ============================================================================
121// Federation Gateway
122// ============================================================================
123
124/// The main federation gateway
125pub struct FederationGateway {
126    config: GatewayConfig,
127    router: Arc<Router>,
128    health: Arc<HealthChecker>,
129    circuit_breaker: Arc<CircuitBreaker>,
130    middlewares: Vec<Box<dyn GatewayMiddleware>>,
131    stats: StatsTracker,
132}
133
134impl FederationGateway {
135    pub fn new(
136        config: GatewayConfig,
137        router: Arc<Router>,
138        health: Arc<HealthChecker>,
139        circuit_breaker: Arc<CircuitBreaker>,
140    ) -> Self {
141        Self {
142            config,
143            router,
144            health,
145            circuit_breaker,
146            middlewares: Vec::new(),
147            stats: StatsTracker::new(),
148        }
149    }
150
151    /// Add middleware to the gateway
152    #[must_use]
153    pub fn with_middleware(mut self, middleware: impl GatewayMiddleware + 'static) -> Self {
154        self.middlewares.push(Box::new(middleware));
155        self
156    }
157
158    /// Execute inference with retries
159    async fn execute_with_retries(
160        &self,
161        mut request: InferenceRequest,
162    ) -> FederationResult<InferenceResponse> {
163        // Apply before_route middlewares
164        for middleware in &self.middlewares {
165            middleware.before_route(&mut request)?;
166        }
167
168        let mut last_error = None;
169        let mut tried_nodes = Vec::new();
170
171        for attempt in 0..=self.config.max_retries {
172            // Route request (excluding already-tried nodes)
173            // In production, we'd modify the request to exclude tried_nodes
174            // For now, use the original request
175            let target = match self.router.route(&request).await {
176                Ok(t) => t,
177                Err(e) => {
178                    last_error = Some(e);
179                    continue;
180                }
181            };
182
183            // Check circuit breaker
184            if self.circuit_breaker.is_open(&target.node_id) {
185                last_error = Some(FederationError::CircuitOpen(target.node_id.clone()));
186                tried_nodes.push(target.node_id);
187                continue;
188            }
189
190            // Execute inference
191            let start = Instant::now();
192            match self.execute_on_node(&target, &request).await {
193                Ok(mut response) => {
194                    let latency = start.elapsed();
195
196                    // Record success
197                    self.health.report_success(&target.node_id, latency);
198                    self.circuit_breaker.record_success(&target.node_id);
199                    self.stats.record_success(latency, response.tokens);
200
201                    // Apply after_infer middlewares
202                    for middleware in &self.middlewares {
203                        middleware.after_infer(&request, &mut response)?;
204                    }
205
206                    return Ok(response);
207                }
208                Err(e) => {
209                    // Record failure
210                    self.health.report_failure(&target.node_id);
211                    self.circuit_breaker.record_failure(&target.node_id);
212
213                    // Notify middlewares
214                    for middleware in &self.middlewares {
215                        middleware.on_error(&request, &e);
216                    }
217
218                    last_error = Some(e);
219                    tried_nodes.push(target.node_id);
220
221                    if attempt < self.config.max_retries {
222                        // Brief backoff before retry
223                        tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
224                    }
225                }
226            }
227        }
228
229        self.stats.record_failure();
230        Err(last_error
231            .unwrap_or_else(|| FederationError::Internal("All retries exhausted".to_string())))
232    }
233
234    /// Execute inference on a specific node
235    #[allow(clippy::unused_async)] // Will be async when HTTP calls implemented
236    async fn execute_on_node(
237        &self,
238        target: &RouteTarget,
239        _request: &InferenceRequest,
240    ) -> FederationResult<InferenceResponse> {
241        // In production, this would make an HTTP/gRPC call to the target node
242        // For now, we simulate the response
243
244        if target.endpoint.is_empty() {
245            // Simulated response for testing
246            Ok(InferenceResponse {
247                output: b"simulated output".to_vec(),
248                served_by: target.node_id.clone(),
249                latency: Duration::from_millis(50),
250                tokens: Some(10),
251            })
252        } else {
253            // Would make actual HTTP call here
254            // For now, return simulated response
255            Ok(InferenceResponse {
256                output: b"simulated output".to_vec(),
257                served_by: target.node_id.clone(),
258                latency: Duration::from_millis(50),
259                tokens: Some(10),
260            })
261        }
262    }
263}
264
265impl GatewayTrait for FederationGateway {
266    fn infer(
267        &self,
268        request: InferenceRequest,
269    ) -> BoxFuture<'_, FederationResult<InferenceResponse>> {
270        Box::pin(async move {
271            self.stats.record_request();
272            self.execute_with_retries(request).await
273        })
274    }
275
276    fn infer_stream(
277        &self,
278        request: InferenceRequest,
279    ) -> BoxFuture<'_, FederationResult<Box<dyn TokenStream>>> {
280        Box::pin(async move {
281            self.stats.record_request();
282            self.stats.increment_streams();
283
284            // Route request
285            let target = self.router.route(&request).await?;
286
287            // Create streaming connection
288            let stream = FederationTokenStream::new(
289                target,
290                request,
291                Arc::clone(&self.health),
292                Arc::clone(&self.circuit_breaker),
293            );
294
295            let stream: Box<dyn TokenStream> = Box::new(stream);
296            Ok(stream)
297        })
298    }
299
300    fn stats(&self) -> GatewayStats {
301        self.stats.snapshot()
302    }
303}
304
305// ============================================================================
306// Token Stream Implementation
307// ============================================================================
308
309/// Streaming token response
310struct FederationTokenStream {
311    target: RouteTarget,
312    _request: InferenceRequest,
313    health: Arc<HealthChecker>,
314    circuit_breaker: Arc<CircuitBreaker>,
315    tokens_generated: u32,
316    finished: bool,
317}
318
319impl FederationTokenStream {
320    fn new(
321        target: RouteTarget,
322        request: InferenceRequest,
323        health: Arc<HealthChecker>,
324        circuit_breaker: Arc<CircuitBreaker>,
325    ) -> Self {
326        Self {
327            target,
328            _request: request,
329            health,
330            circuit_breaker,
331            tokens_generated: 0,
332            finished: false,
333        }
334    }
335}
336
337impl TokenStream for FederationTokenStream {
338    fn next_token(&mut self) -> BoxFuture<'_, Option<FederationResult<Vec<u8>>>> {
339        Box::pin(async move {
340            if self.finished {
341                return None;
342            }
343
344            // Simulate token generation (in production, would read from connection)
345            self.tokens_generated += 1;
346
347            if self.tokens_generated > 10 {
348                self.finished = true;
349                self.health
350                    .report_success(&self.target.node_id, Duration::from_millis(50));
351                self.circuit_breaker.record_success(&self.target.node_id);
352                return None;
353            }
354
355            Some(Ok(format!("token_{}", self.tokens_generated).into_bytes()))
356        })
357    }
358
359    fn cancel(&mut self) -> BoxFuture<'_, ()> {
360        Box::pin(async move {
361            self.finished = true;
362        })
363    }
364}
365
366// ============================================================================
367// Gateway Builder
368// ============================================================================
369
370/// Builder for creating federation gateways
371pub struct GatewayBuilder {
372    config: GatewayConfig,
373    catalog: Option<Arc<ModelCatalog>>,
374    health: Option<Arc<HealthChecker>>,
375    circuit_breaker: Option<Arc<CircuitBreaker>>,
376    router: Option<Arc<Router>>,
377    middlewares: Vec<Box<dyn GatewayMiddleware>>,
378}
379
380impl GatewayBuilder {
381    pub fn new() -> Self {
382        Self {
383            config: GatewayConfig::default(),
384            catalog: None,
385            health: None,
386            circuit_breaker: None,
387            router: None,
388            middlewares: Vec::new(),
389        }
390    }
391
392    #[must_use]
393    pub fn config(mut self, config: GatewayConfig) -> Self {
394        self.config = config;
395        self
396    }
397
398    #[must_use]
399    pub fn catalog(mut self, catalog: Arc<ModelCatalog>) -> Self {
400        self.catalog = Some(catalog);
401        self
402    }
403
404    #[must_use]
405    pub fn health(mut self, health: Arc<HealthChecker>) -> Self {
406        self.health = Some(health);
407        self
408    }
409
410    #[must_use]
411    pub fn circuit_breaker(mut self, cb: Arc<CircuitBreaker>) -> Self {
412        self.circuit_breaker = Some(cb);
413        self
414    }
415
416    #[must_use]
417    pub fn router(mut self, router: Arc<Router>) -> Self {
418        self.router = Some(router);
419        self
420    }
421
422    #[must_use]
423    pub fn middleware(mut self, middleware: impl GatewayMiddleware + 'static) -> Self {
424        self.middlewares.push(Box::new(middleware));
425        self
426    }
427
428    pub fn build(self) -> FederationGateway {
429        let catalog = self
430            .catalog
431            .unwrap_or_else(|| Arc::new(ModelCatalog::new()));
432        let health = self
433            .health
434            .unwrap_or_else(|| Arc::new(HealthChecker::default()));
435        let circuit_breaker = self
436            .circuit_breaker
437            .unwrap_or_else(|| Arc::new(CircuitBreaker::default()));
438
439        let router = self.router.unwrap_or_else(|| {
440            Arc::new(Router::new(
441                super::routing::RouterConfig::default(),
442                Arc::clone(&catalog),
443                Arc::clone(&health),
444                Arc::clone(&circuit_breaker),
445            ))
446        });
447
448        let mut gateway = FederationGateway::new(self.config, router, health, circuit_breaker);
449
450        for middleware in self.middlewares {
451            gateway.middlewares.push(middleware);
452        }
453
454        gateway
455    }
456}
457
458impl Default for GatewayBuilder {
459    fn default() -> Self {
460        Self::new()
461    }
462}
463
464// ============================================================================
465// Example Middlewares
466// ============================================================================
467
/// Middleware that writes one line to stderr per gateway event.
pub struct LoggingMiddleware {
    /// Tag printed in square brackets at the start of every line.
    prefix: String,
}

impl LoggingMiddleware {
    /// Creates a logger whose lines are tagged with `prefix`.
    pub fn new(prefix: impl Into<String>) -> Self {
        let prefix: String = prefix.into();
        LoggingMiddleware { prefix }
    }
}
480
481impl GatewayMiddleware for LoggingMiddleware {
482    fn before_route(&self, request: &mut InferenceRequest) -> FederationResult<()> {
483        eprintln!(
484            "[{}] Routing request {} for {:?}",
485            self.prefix, request.request_id, request.capability
486        );
487        Ok(())
488    }
489
490    fn after_infer(
491        &self,
492        request: &InferenceRequest,
493        response: &mut InferenceResponse,
494    ) -> FederationResult<()> {
495        eprintln!(
496            "[{}] Request {} served by {:?} in {:?}",
497            self.prefix, request.request_id, response.served_by, response.latency
498        );
499        Ok(())
500    }
501
502    fn on_error(&self, request: &InferenceRequest, error: &FederationError) {
503        eprintln!(
504            "[{}] Request {} failed: {}",
505            self.prefix, request.request_id, error
506        );
507    }
508}
509
/// Rate limiting middleware (placeholder).
///
/// Stores the configured limit but does not enforce it yet; a production
/// version would use a token bucket or sliding window.
pub struct RateLimitMiddleware {
    /// Configured limit; currently stored only, hence the lint allowance.
    #[allow(dead_code)]
    requests_per_second: u32,
}

impl RateLimitMiddleware {
    /// Creates the middleware with the given requests-per-second limit.
    pub fn new(requests_per_second: u32) -> Self {
        RateLimitMiddleware {
            requests_per_second,
        }
    }
}
524
525impl GatewayMiddleware for RateLimitMiddleware {
526    fn before_route(&self, _request: &mut InferenceRequest) -> FederationResult<()> {
527        // In production, would check rate limit and return error if exceeded
528        // For now, always allow
529        Ok(())
530    }
531
532    fn after_infer(
533        &self,
534        _request: &InferenceRequest,
535        _response: &mut InferenceResponse,
536    ) -> FederationResult<()> {
537        Ok(())
538    }
539
540    fn on_error(&self, _request: &InferenceRequest, _error: &FederationError) {}
541}
542
543// ============================================================================
544// Tests
545// ============================================================================
546
547#[cfg(test)]
548mod tests {
549    use super::*;
550
551    fn setup_test_gateway() -> (FederationGateway, Arc<ModelCatalog>, Arc<HealthChecker>) {
552        let catalog = Arc::new(ModelCatalog::new());
553        let health = Arc::new(HealthChecker::default());
554        let circuit_breaker = Arc::new(CircuitBreaker::default());
555
556        let router = Arc::new(Router::new(
557            super::super::routing::RouterConfig::default(),
558            Arc::clone(&catalog),
559            Arc::clone(&health),
560            Arc::clone(&circuit_breaker),
561        ));
562
563        let gateway = FederationGateway::new(
564            GatewayConfig::default(),
565            router,
566            Arc::clone(&health),
567            circuit_breaker,
568        );
569
570        (gateway, catalog, health)
571    }
572
573    #[tokio::test]
574    async fn test_infer_no_nodes() {
575        let (gateway, _, _) = setup_test_gateway();
576
577        let request = InferenceRequest {
578            capability: Capability::Generate,
579            input: b"hello".to_vec(),
580            qos: QoSRequirements::default(),
581            request_id: "test-1".to_string(),
582            tenant_id: None,
583        };
584
585        let result = gateway.infer(request).await;
586        assert!(result.is_err());
587    }
588
589    #[tokio::test]
590    async fn test_infer_with_node() {
591        let (gateway, catalog, health) = setup_test_gateway();
592
593        // Register a node
594        catalog
595            .register(
596                ModelId("test-model".to_string()),
597                NodeId("node-1".to_string()),
598                RegionId("us-west".to_string()),
599                vec![Capability::Generate],
600            )
601            .await
602            .expect("registration failed");
603
604        health.register_node(NodeId("node-1".to_string()));
605        health.report_success(&NodeId("node-1".to_string()), Duration::from_millis(10));
606
607        let request = InferenceRequest {
608            capability: Capability::Generate,
609            input: b"hello".to_vec(),
610            qos: QoSRequirements::default(),
611            request_id: "test-2".to_string(),
612            tenant_id: None,
613        };
614
615        let result = gateway.infer(request).await;
616        assert!(result.is_ok());
617
618        let response = result.expect("inference failed");
619        assert_eq!(response.served_by, NodeId("node-1".to_string()));
620    }
621
622    #[tokio::test]
623    async fn test_stats_tracking() {
624        let (gateway, catalog, health) = setup_test_gateway();
625
626        catalog
627            .register(
628                ModelId("test".to_string()),
629                NodeId("node-1".to_string()),
630                RegionId("us-west".to_string()),
631                vec![Capability::Embed],
632            )
633            .await
634            .expect("registration failed");
635
636        health.register_node(NodeId("node-1".to_string()));
637        health.report_success(&NodeId("node-1".to_string()), Duration::from_millis(10));
638
639        // Make some requests
640        for i in 0..3 {
641            let request = InferenceRequest {
642                capability: Capability::Embed,
643                input: vec![i],
644                qos: QoSRequirements::default(),
645                request_id: format!("test-{}", i),
646                tenant_id: None,
647            };
648
649            let _ = gateway.infer(request).await;
650        }
651
652        let stats = gateway.stats();
653        assert_eq!(stats.total_requests, 3);
654        assert_eq!(stats.successful_requests, 3);
655        assert_eq!(stats.failed_requests, 0);
656    }
657
658    #[tokio::test]
659    async fn test_streaming() {
660        let (gateway, catalog, health) = setup_test_gateway();
661
662        catalog
663            .register(
664                ModelId("stream-model".to_string()),
665                NodeId("node-1".to_string()),
666                RegionId("us-west".to_string()),
667                vec![Capability::Generate],
668            )
669            .await
670            .expect("registration failed");
671
672        health.register_node(NodeId("node-1".to_string()));
673        health.report_success(&NodeId("node-1".to_string()), Duration::from_millis(10));
674
675        let request = InferenceRequest {
676            capability: Capability::Generate,
677            input: b"stream test".to_vec(),
678            qos: QoSRequirements::default(),
679            request_id: "stream-1".to_string(),
680            tenant_id: None,
681        };
682
683        let result = gateway.infer_stream(request).await;
684        assert!(result.is_ok());
685
686        let mut stream = result.expect("stream creation failed");
687
688        // Read tokens
689        let mut token_count = 0;
690        while let Some(result) = stream.next_token().await {
691            assert!(result.is_ok());
692            token_count += 1;
693        }
694
695        assert_eq!(token_count, 10);
696    }
697
698    #[test]
699    fn test_gateway_builder() {
700        let gateway = GatewayBuilder::new()
701            .config(GatewayConfig {
702                max_retries: 5,
703                inference_timeout: Duration::from_secs(60),
704                enable_tracing: false,
705            })
706            .middleware(LoggingMiddleware::new("test"))
707            .build();
708
709        assert_eq!(gateway.config.max_retries, 5);
710        assert_eq!(gateway.middlewares.len(), 1);
711    }
712
713    /// Comprehensive integration test demonstrating full federation flow
714    #[tokio::test]
715    async fn test_full_federation_flow() {
716        use super::super::policy::CompositePolicy;
717
718        // =====================================================================
719        // Setup: Create multi-region deployment
720        // =====================================================================
721        let catalog = Arc::new(ModelCatalog::new());
722        let health = Arc::new(HealthChecker::default());
723        let circuit_breaker = Arc::new(CircuitBreaker::default());
724
725        // Register Whisper model in US-West (primary, fast)
726        catalog
727            .register(
728                ModelId("whisper-v3".to_string()),
729                NodeId("us-west-gpu".to_string()),
730                RegionId("us-west".to_string()),
731                vec![Capability::Transcribe],
732            )
733            .await
734            .expect("failed to register us-west");
735
736        // Register Whisper model in EU-West (GDPR compliant)
737        catalog
738            .register(
739                ModelId("whisper-v3".to_string()),
740                NodeId("eu-west-gpu".to_string()),
741                RegionId("eu-west".to_string()),
742                vec![Capability::Transcribe],
743            )
744            .await
745            .expect("failed to register eu-west");
746
747        // Register LLaMA in US-East
748        catalog
749            .register(
750                ModelId("llama-70b".to_string()),
751                NodeId("us-east-gpu".to_string()),
752                RegionId("us-east".to_string()),
753                vec![Capability::Generate, Capability::Code],
754            )
755            .await
756            .expect("failed to register llama");
757
758        // Register embedding model in multiple regions
759        for (node, region) in [("embed-us", "us-west"), ("embed-eu", "eu-west")] {
760            catalog
761                .register(
762                    ModelId("bge-large".to_string()),
763                    NodeId(node.to_string()),
764                    RegionId(region.to_string()),
765                    vec![Capability::Embed],
766                )
767                .await
768                .expect("failed to register embedding");
769        }
770
771        // =====================================================================
772        // Setup health states
773        // =====================================================================
774
775        // US-West: Healthy, fast (45ms)
776        health.register_node(NodeId("us-west-gpu".to_string()));
777        for _ in 0..3 {
778            health.report_success(
779                &NodeId("us-west-gpu".to_string()),
780                Duration::from_millis(45),
781            );
782        }
783
784        // EU-West: Healthy, slower (120ms)
785        health.register_node(NodeId("eu-west-gpu".to_string()));
786        for _ in 0..3 {
787            health.report_success(
788                &NodeId("eu-west-gpu".to_string()),
789                Duration::from_millis(120),
790            );
791        }
792
793        // US-East: Will be unknown/degraded (only 1 success)
794        health.register_node(NodeId("us-east-gpu".to_string()));
795        // Just one success keeps it in Unknown state (needs 2 for Healthy)
796        health.report_success(
797            &NodeId("us-east-gpu".to_string()),
798            Duration::from_millis(100),
799        );
800
801        // Embedding nodes: Healthy
802        for node in ["embed-us", "embed-eu"] {
803            health.register_node(NodeId(node.to_string()));
804            for _ in 0..3 {
805                health.report_success(&NodeId(node.to_string()), Duration::from_millis(15));
806            }
807        }
808
809        // =====================================================================
810        // Create Router with enterprise policies
811        // =====================================================================
812        let router = Arc::new(
813            Router::new(
814                super::super::routing::RouterConfig {
815                    max_candidates: 10,
816                    min_score: 0.1,
817                    strategy: LoadBalanceStrategy::LeastLatency,
818                },
819                Arc::clone(&catalog),
820                Arc::clone(&health),
821                Arc::clone(&circuit_breaker),
822            )
823            .with_policy(CompositePolicy::enterprise_default()),
824        );
825
826        // =====================================================================
827        // Create Gateway
828        // =====================================================================
829        let gateway = FederationGateway::new(
830            GatewayConfig {
831                max_retries: 3,
832                inference_timeout: Duration::from_secs(30),
833                enable_tracing: true,
834            },
835            Arc::clone(&router),
836            Arc::clone(&health),
837            Arc::clone(&circuit_breaker),
838        );
839
840        // =====================================================================
841        // Test 1: Transcribe routes to fastest healthy node (us-west)
842        // =====================================================================
843        let request = InferenceRequest {
844            capability: Capability::Transcribe,
845            input: b"audio data".to_vec(),
846            qos: QoSRequirements::default(),
847            request_id: "test-transcribe".to_string(),
848            tenant_id: Some("acme".to_string()),
849        };
850
851        let candidates = router
852            .get_candidates(&request)
853            .await
854            .expect("get_candidates failed");
855        assert_eq!(candidates.len(), 2, "Should have 2 Transcribe candidates");
856
857        let target = router.route(&request).await.expect("route failed");
858        // US-West should be selected (lower latency = higher score)
859        assert_eq!(target.node_id, NodeId("us-west-gpu".to_string()));
860
861        // =====================================================================
862        // Test 2: Generate routes to only available node
863        // =====================================================================
864        let request = InferenceRequest {
865            capability: Capability::Generate,
866            input: b"prompt".to_vec(),
867            qos: QoSRequirements::default(),
868            request_id: "test-generate".to_string(),
869            tenant_id: None,
870        };
871
872        let target = router.route(&request).await.expect("route failed");
873        assert_eq!(target.node_id, NodeId("us-east-gpu".to_string()));
874
875        // =====================================================================
876        // Test 3: Embed has multiple candidates
877        // =====================================================================
878        let request = InferenceRequest {
879            capability: Capability::Embed,
880            input: b"text".to_vec(),
881            qos: QoSRequirements::default(),
882            request_id: "test-embed".to_string(),
883            tenant_id: None,
884        };
885
886        let candidates = router
887            .get_candidates(&request)
888            .await
889            .expect("get_candidates failed");
890        assert_eq!(candidates.len(), 2, "Should have 2 Embed candidates");
891
892        // =====================================================================
893        // Test 4: Gateway inference with stats
894        // =====================================================================
895        let request = InferenceRequest {
896            capability: Capability::Transcribe,
897            input: b"audio".to_vec(),
898            qos: QoSRequirements::default(),
899            request_id: "test-infer".to_string(),
900            tenant_id: None,
901        };
902
903        let response = gateway.infer(request).await.expect("inference failed");
904        assert_eq!(response.served_by, NodeId("us-west-gpu".to_string()));
905        assert!(!response.output.is_empty());
906
907        let stats = gateway.stats();
908        assert_eq!(stats.total_requests, 1);
909        assert_eq!(stats.successful_requests, 1);
910        assert_eq!(stats.failed_requests, 0);
911
912        // =====================================================================
913        // Test 5: Streaming inference
914        // =====================================================================
915        let request = InferenceRequest {
916            capability: Capability::Generate,
917            input: b"stream prompt".to_vec(),
918            qos: QoSRequirements::default(),
919            request_id: "test-stream".to_string(),
920            tenant_id: None,
921        };
922
923        let mut stream = gateway.infer_stream(request).await.expect("stream failed");
924        let mut tokens = 0;
925        while let Some(result) = stream.next_token().await {
926            result.expect("token error");
927            tokens += 1;
928        }
929        assert_eq!(tokens, 10, "Should receive 10 tokens");
930
931        // =====================================================================
932        // Test 6: Circuit breaker
933        // =====================================================================
934        let bad_node = NodeId("failing-node".to_string());
935
936        // Initially closed
937        assert_eq!(circuit_breaker.state(&bad_node), CircuitState::Closed);
938
939        // Simulate failures
940        for _ in 0..5 {
941            circuit_breaker.record_failure(&bad_node);
942        }
943        assert_eq!(circuit_breaker.state(&bad_node), CircuitState::Open);
944        assert!(circuit_breaker.is_open(&bad_node));
945
946        // =====================================================================
947        // Verify catalog state
948        // =====================================================================
949        let all_models = catalog.list_all().await.expect("list failed");
950        assert_eq!(all_models.len(), 3); // whisper-v3, llama-70b, bge-large
951
952        // =====================================================================
953        // Verify health states are tracked (all nodes have cached health)
954        // =====================================================================
955        let nodes_with_health = [
956            "us-west-gpu",
957            "eu-west-gpu",
958            "us-east-gpu",
959            "embed-us",
960            "embed-eu",
961        ];
962        for node in nodes_with_health {
963            let h = health.get_cached_health(&NodeId(node.to_string()));
964            assert!(h.is_some(), "Health should be tracked for {}", node);
965        }
966
967        // Verify us-west has good health (3 successes)
968        let us_west_health = health
969            .get_cached_health(&NodeId("us-west-gpu".to_string()))
970            .unwrap();
971        assert_eq!(
972            us_west_health.status,
973            HealthState::Healthy,
974            "US-West should be healthy"
975        );
976
977        // =====================================================================
978        // Final summary - print results for visibility
979        // =====================================================================
980        println!("\n✅ Full Federation Flow Test PASSED!");
981        println!("   - 3 models registered across 5 nodes");
982        println!("   - 6 health entries tracked");
983        println!("   - Routing correctly prefers fastest healthy nodes");
984        println!("   - Gateway inference succeeds with stats tracking");
985        println!("   - Streaming returns expected token count");
986        println!("   - Circuit breaker opens after failures");
987    }
988}