Skip to main content

apr_cli/federation/
gateway.rs

1//! Federation Gateway - Main entry point for distributed inference
2//!
3//! The gateway orchestrates the full inference lifecycle:
4//! routing, execution, retries, and response handling.
5
6use super::catalog::ModelCatalog;
7use super::health::{CircuitBreaker, HealthChecker};
8use super::routing::Router;
9use super::traits::*;
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13
14// ============================================================================
15// Gateway Configuration
16// ============================================================================
17
/// Tunables for the federation gateway.
///
/// Construct with struct-update syntax on top of [`GatewayConfig::default`] to override only
/// the fields that matter, e.g. `GatewayConfig { max_retries: 5, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct GatewayConfig {
    /// How many extra attempts a request may get after its first one fails.
    pub max_retries: u32,
    /// Timeout for a single inference call against one node.
    pub inference_timeout: Duration,
    /// Whether request tracing is switched on.
    pub enable_tracing: bool,
}

impl Default for GatewayConfig {
    /// Three retries, a 30-second per-call timeout, tracing enabled.
    fn default() -> Self {
        const DEFAULT_RETRIES: u32 = 3;
        const DEFAULT_TIMEOUT_SECS: u64 = 30;

        GatewayConfig {
            enable_tracing: true,
            inference_timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
            max_retries: DEFAULT_RETRIES,
        }
    }
}
38
39// ============================================================================
40// Gateway Statistics
41// ============================================================================
42
43/// Thread-safe statistics tracker
44struct StatsTracker {
45    total_requests: AtomicU64,
46    successful_requests: AtomicU64,
47    failed_requests: AtomicU64,
48    total_tokens: AtomicU64,
49    total_latency_ms: AtomicU64,
50    active_streams: AtomicU64,
51}
52
53impl StatsTracker {
54    fn new() -> Self {
55        Self {
56            total_requests: AtomicU64::new(0),
57            successful_requests: AtomicU64::new(0),
58            failed_requests: AtomicU64::new(0),
59            total_tokens: AtomicU64::new(0),
60            total_latency_ms: AtomicU64::new(0),
61            active_streams: AtomicU64::new(0),
62        }
63    }
64
65    fn record_request(&self) {
66        self.total_requests.fetch_add(1, Ordering::SeqCst);
67    }
68
69    fn record_success(&self, latency: Duration, tokens: Option<u32>) {
70        self.successful_requests.fetch_add(1, Ordering::SeqCst);
71        self.total_latency_ms
72            .fetch_add(latency.as_millis() as u64, Ordering::SeqCst);
73        if let Some(t) = tokens {
74            self.total_tokens.fetch_add(t as u64, Ordering::SeqCst);
75        }
76    }
77
78    fn record_failure(&self) {
79        self.failed_requests.fetch_add(1, Ordering::SeqCst);
80    }
81
82    #[allow(dead_code)]
83    fn increment_streams(&self) {
84        self.active_streams.fetch_add(1, Ordering::SeqCst);
85    }
86
87    #[allow(dead_code)]
88    fn decrement_streams(&self) {
89        self.active_streams.fetch_sub(1, Ordering::SeqCst);
90    }
91
92    fn snapshot(&self) -> GatewayStats {
93        let total = self.total_requests.load(Ordering::SeqCst);
94        let successful = self.successful_requests.load(Ordering::SeqCst);
95        let total_latency = self.total_latency_ms.load(Ordering::SeqCst);
96
97        let avg_latency = if successful > 0 {
98            Duration::from_millis(total_latency / successful)
99        } else {
100            Duration::ZERO
101        };
102
103        GatewayStats {
104            total_requests: total,
105            successful_requests: successful,
106            failed_requests: self.failed_requests.load(Ordering::SeqCst),
107            total_tokens: self.total_tokens.load(Ordering::SeqCst),
108            avg_latency,
109            active_streams: self.active_streams.load(Ordering::SeqCst) as u32,
110        }
111    }
112}
113
114impl Default for StatsTracker {
115    fn default() -> Self {
116        Self::new()
117    }
118}
119
120// ============================================================================
121// Federation Gateway
122// ============================================================================
123
124/// The main federation gateway
125pub struct FederationGateway {
126    config: GatewayConfig,
127    router: Arc<Router>,
128    health: Arc<HealthChecker>,
129    circuit_breaker: Arc<CircuitBreaker>,
130    middlewares: Vec<Box<dyn GatewayMiddleware>>,
131    stats: StatsTracker,
132}
133
134impl FederationGateway {
135    pub fn new(
136        config: GatewayConfig,
137        router: Arc<Router>,
138        health: Arc<HealthChecker>,
139        circuit_breaker: Arc<CircuitBreaker>,
140    ) -> Self {
141        Self {
142            config,
143            router,
144            health,
145            circuit_breaker,
146            middlewares: Vec::new(),
147            stats: StatsTracker::new(),
148        }
149    }
150
151    /// Add middleware to the gateway
152    #[must_use]
153    pub fn with_middleware(mut self, middleware: impl GatewayMiddleware + 'static) -> Self {
154        self.middlewares.push(Box::new(middleware));
155        self
156    }
157
158    /// Execute inference with retries
159    async fn execute_with_retries(
160        &self,
161        mut request: InferenceRequest,
162    ) -> FederationResult<InferenceResponse> {
163        // Apply before_route middlewares
164        for middleware in &self.middlewares {
165            middleware.before_route(&mut request)?;
166        }
167
168        let mut last_error = None;
169        let mut tried_nodes = Vec::new();
170
171        for attempt in 0..=self.config.max_retries {
172            // Route request (excluding already-tried nodes)
173            // In production, we'd modify the request to exclude tried_nodes
174            // For now, use the original request
175            let target = match self.router.route(&request).await {
176                Ok(t) => t,
177                Err(e) => {
178                    last_error = Some(e);
179                    continue;
180                }
181            };
182
183            // Check circuit breaker
184            if self.circuit_breaker.is_open(&target.node_id) {
185                last_error = Some(FederationError::CircuitOpen(target.node_id.clone()));
186                tried_nodes.push(target.node_id);
187                continue;
188            }
189
190            // Execute inference
191            let start = Instant::now();
192            match self.execute_on_node(&target, &request).await {
193                Ok(mut response) => {
194                    let latency = start.elapsed();
195
196                    // Record success
197                    self.health.report_success(&target.node_id, latency);
198                    self.circuit_breaker.record_success(&target.node_id);
199                    self.stats.record_success(latency, response.tokens);
200
201                    // Apply after_infer middlewares
202                    for middleware in &self.middlewares {
203                        middleware.after_infer(&request, &mut response)?;
204                    }
205
206                    return Ok(response);
207                }
208                Err(e) => {
209                    // Record failure
210                    self.health.report_failure(&target.node_id);
211                    self.circuit_breaker.record_failure(&target.node_id);
212
213                    // Notify middlewares
214                    for middleware in &self.middlewares {
215                        middleware.on_error(&request, &e);
216                    }
217
218                    last_error = Some(e);
219                    tried_nodes.push(target.node_id);
220
221                    if attempt < self.config.max_retries {
222                        // Brief backoff before retry
223                        tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
224                    }
225                }
226            }
227        }
228
229        self.stats.record_failure();
230        Err(last_error
231            .unwrap_or_else(|| FederationError::Internal("All retries exhausted".to_string())))
232    }
233
234    /// Execute inference on a specific node
235    #[allow(clippy::unused_async)] // Will be async when HTTP calls implemented
236    async fn execute_on_node(
237        &self,
238        target: &RouteTarget,
239        _request: &InferenceRequest,
240    ) -> FederationResult<InferenceResponse> {
241        // In production, this would make an HTTP/gRPC call to the target node
242        // For now, we simulate the response
243
244        if target.endpoint.is_empty() {
245            // Simulated response for testing
246            Ok(InferenceResponse {
247                output: b"simulated output".to_vec(),
248                served_by: target.node_id.clone(),
249                latency: Duration::from_millis(50),
250                tokens: Some(10),
251            })
252        } else {
253            // Would make actual HTTP call here
254            // For now, return simulated response
255            Ok(InferenceResponse {
256                output: b"simulated output".to_vec(),
257                served_by: target.node_id.clone(),
258                latency: Duration::from_millis(50),
259                tokens: Some(10),
260            })
261        }
262    }
263}
264
265impl GatewayTrait for FederationGateway {
266    fn infer(
267        &self,
268        request: InferenceRequest,
269    ) -> BoxFuture<'_, FederationResult<InferenceResponse>> {
270        Box::pin(async move {
271            self.stats.record_request();
272            self.execute_with_retries(request).await
273        })
274    }
275
276    fn infer_stream(
277        &self,
278        request: InferenceRequest,
279    ) -> BoxFuture<'_, FederationResult<Box<dyn TokenStream>>> {
280        Box::pin(async move {
281            self.stats.record_request();
282            self.stats.increment_streams();
283
284            // Route request
285            let target = self.router.route(&request).await?;
286
287            // Create streaming connection
288            let stream = FederationTokenStream::new(
289                target,
290                request,
291                Arc::clone(&self.health),
292                Arc::clone(&self.circuit_breaker),
293            );
294
295            let stream: Box<dyn TokenStream> = Box::new(stream);
296            Ok(stream)
297        })
298    }
299
300    fn stats(&self) -> GatewayStats {
301        self.stats.snapshot()
302    }
303}
304
305// ============================================================================
306// Token Stream Implementation
307// ============================================================================
308
309/// Streaming token response
310struct FederationTokenStream {
311    target: RouteTarget,
312    _request: InferenceRequest,
313    health: Arc<HealthChecker>,
314    circuit_breaker: Arc<CircuitBreaker>,
315    tokens_generated: u32,
316    finished: bool,
317}
318
319impl FederationTokenStream {
320    fn new(
321        target: RouteTarget,
322        request: InferenceRequest,
323        health: Arc<HealthChecker>,
324        circuit_breaker: Arc<CircuitBreaker>,
325    ) -> Self {
326        Self {
327            target,
328            _request: request,
329            health,
330            circuit_breaker,
331            tokens_generated: 0,
332            finished: false,
333        }
334    }
335}
336
337impl TokenStream for FederationTokenStream {
338    fn next_token(&mut self) -> BoxFuture<'_, Option<FederationResult<Vec<u8>>>> {
339        Box::pin(async move {
340            if self.finished {
341                return None;
342            }
343
344            // Simulate token generation (in production, would read from connection)
345            self.tokens_generated += 1;
346
347            if self.tokens_generated > 10 {
348                self.finished = true;
349                self.health
350                    .report_success(&self.target.node_id, Duration::from_millis(50));
351                self.circuit_breaker.record_success(&self.target.node_id);
352                return None;
353            }
354
355            Some(Ok(format!("token_{}", self.tokens_generated).into_bytes()))
356        })
357    }
358
359    fn cancel(&mut self) -> BoxFuture<'_, ()> {
360        Box::pin(async move {
361            self.finished = true;
362        })
363    }
364}
365
366// ============================================================================
367// Gateway Builder
368// ============================================================================
369
370/// Builder for creating federation gateways
371pub struct GatewayBuilder {
372    config: GatewayConfig,
373    catalog: Option<Arc<ModelCatalog>>,
374    health: Option<Arc<HealthChecker>>,
375    circuit_breaker: Option<Arc<CircuitBreaker>>,
376    router: Option<Arc<Router>>,
377    middlewares: Vec<Box<dyn GatewayMiddleware>>,
378}
379
380impl GatewayBuilder {
381    pub fn new() -> Self {
382        Self {
383            config: GatewayConfig::default(),
384            catalog: None,
385            health: None,
386            circuit_breaker: None,
387            router: None,
388            middlewares: Vec::new(),
389        }
390    }
391
392    #[must_use]
393    pub fn config(mut self, config: GatewayConfig) -> Self {
394        self.config = config;
395        self
396    }
397
398    #[must_use]
399    pub fn catalog(mut self, catalog: Arc<ModelCatalog>) -> Self {
400        self.catalog = Some(catalog);
401        self
402    }
403
404    #[must_use]
405    pub fn health(mut self, health: Arc<HealthChecker>) -> Self {
406        self.health = Some(health);
407        self
408    }
409
410    #[must_use]
411    pub fn circuit_breaker(mut self, cb: Arc<CircuitBreaker>) -> Self {
412        self.circuit_breaker = Some(cb);
413        self
414    }
415
416    #[must_use]
417    pub fn router(mut self, router: Arc<Router>) -> Self {
418        self.router = Some(router);
419        self
420    }
421
422    #[must_use]
423    pub fn middleware(mut self, middleware: impl GatewayMiddleware + 'static) -> Self {
424        self.middlewares.push(Box::new(middleware));
425        self
426    }
427
428    pub fn build(self) -> FederationGateway {
429        let catalog = self
430            .catalog
431            .unwrap_or_else(|| Arc::new(ModelCatalog::new()));
432        let health = self
433            .health
434            .unwrap_or_else(|| Arc::new(HealthChecker::default()));
435        let circuit_breaker = self
436            .circuit_breaker
437            .unwrap_or_else(|| Arc::new(CircuitBreaker::default()));
438
439        let router = self.router.unwrap_or_else(|| {
440            Arc::new(Router::new(
441                super::routing::RouterConfig::default(),
442                Arc::clone(&catalog),
443                Arc::clone(&health),
444                Arc::clone(&circuit_breaker),
445            ))
446        });
447
448        let mut gateway = FederationGateway::new(self.config, router, health, circuit_breaker);
449
450        for middleware in self.middlewares {
451            gateway.middlewares.push(middleware);
452        }
453
454        gateway
455    }
456}
457
458impl Default for GatewayBuilder {
459    fn default() -> Self {
460        Self::new()
461    }
462}
463
464// ============================================================================
465// Example Middlewares
466// ============================================================================
467
468/// Logging middleware
469pub struct LoggingMiddleware {
470    prefix: String,
471}
472
473impl LoggingMiddleware {
474    pub fn new(prefix: impl Into<String>) -> Self {
475        Self {
476            prefix: prefix.into(),
477        }
478    }
479}
480
481impl GatewayMiddleware for LoggingMiddleware {
482    fn before_route(&self, request: &mut InferenceRequest) -> FederationResult<()> {
483        eprintln!(
484            "[{}] Routing request {} for {:?}",
485            self.prefix, request.request_id, request.capability
486        );
487        Ok(())
488    }
489
490    fn after_infer(
491        &self,
492        request: &InferenceRequest,
493        response: &mut InferenceResponse,
494    ) -> FederationResult<()> {
495        eprintln!(
496            "[{}] Request {} served by {:?} in {:?}",
497            self.prefix, request.request_id, response.served_by, response.latency
498        );
499        Ok(())
500    }
501
502    fn on_error(&self, request: &InferenceRequest, error: &FederationError) {
503        eprintln!(
504            "[{}] Request {} failed: {}",
505            self.prefix, request.request_id, error
506        );
507    }
508}
509
510/// Rate limiting middleware
511pub struct RateLimitMiddleware {
512    #[allow(dead_code)]
513    requests_per_second: u32,
514    // In production, would use a token bucket or sliding window
515}
516
517impl RateLimitMiddleware {
518    pub fn new(requests_per_second: u32) -> Self {
519        Self {
520            requests_per_second,
521        }
522    }
523}
524
525impl GatewayMiddleware for RateLimitMiddleware {
526    fn before_route(&self, _request: &mut InferenceRequest) -> FederationResult<()> {
527        // In production, would check rate limit and return error if exceeded
528        // For now, always allow
529        Ok(())
530    }
531
532    fn after_infer(
533        &self,
534        _request: &InferenceRequest,
535        _response: &mut InferenceResponse,
536    ) -> FederationResult<()> {
537        Ok(())
538    }
539
540    fn on_error(&self, _request: &InferenceRequest, _error: &FederationError) {}
541}
542
543// ============================================================================
544// Tests
545// ============================================================================
546
547#[cfg(test)]
548mod tests {
549    use super::*;
550
551    fn setup_test_gateway() -> (FederationGateway, Arc<ModelCatalog>, Arc<HealthChecker>) {
552        let catalog = Arc::new(ModelCatalog::new());
553        let health = Arc::new(HealthChecker::default());
554        let circuit_breaker = Arc::new(CircuitBreaker::default());
555
556        let router = Arc::new(Router::new(
557            super::super::routing::RouterConfig::default(),
558            Arc::clone(&catalog),
559            Arc::clone(&health),
560            Arc::clone(&circuit_breaker),
561        ));
562
563        let gateway = FederationGateway::new(
564            GatewayConfig::default(),
565            router,
566            Arc::clone(&health),
567            circuit_breaker,
568        );
569
570        (gateway, catalog, health)
571    }
572
573    #[tokio::test]
574    async fn test_infer_no_nodes() {
575        let (gateway, _, _) = setup_test_gateway();
576
577        let request = InferenceRequest {
578            capability: Capability::Generate,
579            input: b"hello".to_vec(),
580            qos: QoSRequirements::default(),
581            request_id: "test-1".to_string(),
582            tenant_id: None,
583        };
584
585        let result = gateway.infer(request).await;
586        assert!(result.is_err());
587    }
588
589    #[tokio::test]
590    async fn test_infer_with_node() {
591        let (gateway, catalog, health) = setup_test_gateway();
592
593        // Register a node
594        catalog
595            .register(
596                ModelId("test-model".to_string()),
597                NodeId("node-1".to_string()),
598                RegionId("us-west".to_string()),
599                vec![Capability::Generate],
600            )
601            .await
602            .expect("registration failed");
603
604        health.register_node(NodeId("node-1".to_string()));
605        health.report_success(&NodeId("node-1".to_string()), Duration::from_millis(10));
606
607        let request = InferenceRequest {
608            capability: Capability::Generate,
609            input: b"hello".to_vec(),
610            qos: QoSRequirements::default(),
611            request_id: "test-2".to_string(),
612            tenant_id: None,
613        };
614
615        let result = gateway.infer(request).await;
616        assert!(result.is_ok());
617
618        let response = result.expect("inference failed");
619        assert_eq!(response.served_by, NodeId("node-1".to_string()));
620    }
621
622    #[tokio::test]
623    async fn test_stats_tracking() {
624        let (gateway, catalog, health) = setup_test_gateway();
625
626        catalog
627            .register(
628                ModelId("test".to_string()),
629                NodeId("node-1".to_string()),
630                RegionId("us-west".to_string()),
631                vec![Capability::Embed],
632            )
633            .await
634            .expect("registration failed");
635
636        health.register_node(NodeId("node-1".to_string()));
637        health.report_success(&NodeId("node-1".to_string()), Duration::from_millis(10));
638
639        // Make some requests
640        for i in 0..3 {
641            let request = InferenceRequest {
642                capability: Capability::Embed,
643                input: vec![i],
644                qos: QoSRequirements::default(),
645                request_id: format!("test-{}", i),
646                tenant_id: None,
647            };
648
649            let _ = gateway.infer(request).await;
650        }
651
652        let stats = gateway.stats();
653        assert_eq!(stats.total_requests, 3);
654        assert_eq!(stats.successful_requests, 3);
655        assert_eq!(stats.failed_requests, 0);
656    }
657
658    #[tokio::test]
659    async fn test_streaming() {
660        let (gateway, catalog, health) = setup_test_gateway();
661
662        catalog
663            .register(
664                ModelId("stream-model".to_string()),
665                NodeId("node-1".to_string()),
666                RegionId("us-west".to_string()),
667                vec![Capability::Generate],
668            )
669            .await
670            .expect("registration failed");
671
672        health.register_node(NodeId("node-1".to_string()));
673        health.report_success(&NodeId("node-1".to_string()), Duration::from_millis(10));
674
675        let request = InferenceRequest {
676            capability: Capability::Generate,
677            input: b"stream test".to_vec(),
678            qos: QoSRequirements::default(),
679            request_id: "stream-1".to_string(),
680            tenant_id: None,
681        };
682
683        let result = gateway.infer_stream(request).await;
684        assert!(result.is_ok());
685
686        let mut stream = result.expect("stream creation failed");
687
688        // Read tokens
689        let mut token_count = 0;
690        while let Some(result) = stream.next_token().await {
691            assert!(result.is_ok());
692            token_count += 1;
693        }
694
695        assert_eq!(token_count, 10);
696    }
697
698    #[test]
699    fn test_gateway_builder() {
700        let gateway = GatewayBuilder::new()
701            .config(GatewayConfig {
702                max_retries: 5,
703                inference_timeout: Duration::from_secs(60),
704                enable_tracing: false,
705            })
706            .middleware(LoggingMiddleware::new("test"))
707            .build();
708
709        assert_eq!(gateway.config.max_retries, 5);
710        assert_eq!(gateway.middlewares.len(), 1);
711    }
712
713    // =========================================================================
714    // GatewayConfig tests
715    // =========================================================================
716
717    #[test]
718    fn test_gateway_config_default() {
719        let config = GatewayConfig::default();
720        assert_eq!(config.max_retries, 3);
721        assert_eq!(config.inference_timeout, Duration::from_secs(30));
722        assert!(config.enable_tracing);
723    }
724
725    #[test]
726    fn test_gateway_config_clone() {
727        let config = GatewayConfig {
728            max_retries: 5,
729            inference_timeout: Duration::from_secs(60),
730            enable_tracing: false,
731        };
732        let cloned = config.clone();
733        assert_eq!(cloned.max_retries, 5);
734        assert!(!cloned.enable_tracing);
735    }
736
737    // =========================================================================
738    // GatewayBuilder extended tests
739    // =========================================================================
740
741    #[test]
742    fn test_gateway_builder_default() {
743        let builder = GatewayBuilder::default();
744        assert!(builder.catalog.is_none());
745        assert!(builder.health.is_none());
746        assert!(builder.circuit_breaker.is_none());
747        assert!(builder.router.is_none());
748        assert!(builder.middlewares.is_empty());
749    }
750
751    #[test]
752    fn test_gateway_builder_with_catalog() {
753        let catalog = Arc::new(ModelCatalog::new());
754        let builder = GatewayBuilder::new().catalog(catalog);
755        assert!(builder.catalog.is_some());
756    }
757
758    #[test]
759    fn test_gateway_builder_with_health() {
760        let health = Arc::new(HealthChecker::default());
761        let builder = GatewayBuilder::new().health(health);
762        assert!(builder.health.is_some());
763    }
764
765    #[test]
766    fn test_gateway_builder_with_circuit_breaker() {
767        let cb = Arc::new(CircuitBreaker::default());
768        let builder = GatewayBuilder::new().circuit_breaker(cb);
769        assert!(builder.circuit_breaker.is_some());
770    }
771
772    #[test]
773    fn test_gateway_builder_with_router() {
774        let catalog = Arc::new(ModelCatalog::new());
775        let health = Arc::new(HealthChecker::default());
776        let cb = Arc::new(CircuitBreaker::default());
777        let router = Arc::new(Router::new(
778            super::super::routing::RouterConfig::default(),
779            catalog,
780            health,
781            cb,
782        ));
783        let builder = GatewayBuilder::new().router(router);
784        assert!(builder.router.is_some());
785    }
786
787    #[test]
788    fn test_gateway_builder_with_middleware() {
789        let builder = GatewayBuilder::new()
790            .middleware(LoggingMiddleware::new("test"))
791            .middleware(RateLimitMiddleware::new(100));
792        assert_eq!(builder.middlewares.len(), 2);
793    }
794
795    #[test]
796    fn test_gateway_builder_full_chain() {
797        let catalog = Arc::new(ModelCatalog::new());
798        let health = Arc::new(HealthChecker::default());
799        let cb = Arc::new(CircuitBreaker::default());
800
801        let gateway = GatewayBuilder::new()
802            .config(GatewayConfig {
803                max_retries: 5,
804                inference_timeout: Duration::from_secs(120),
805                enable_tracing: false,
806            })
807            .catalog(Arc::clone(&catalog))
808            .health(Arc::clone(&health))
809            .circuit_breaker(Arc::clone(&cb))
810            .middleware(LoggingMiddleware::new("gw"))
811            .build();
812
813        assert_eq!(gateway.config.max_retries, 5);
814        assert_eq!(gateway.middlewares.len(), 1);
815    }
816
817    // =========================================================================
818    // LoggingMiddleware tests
819    // =========================================================================
820
821    #[test]
822    fn test_logging_middleware_creation() {
823        let middleware = LoggingMiddleware::new("test-prefix");
824        assert_eq!(middleware.prefix, "test-prefix");
825    }
826
827    #[test]
828    fn test_logging_middleware_before_route() {
829        let middleware = LoggingMiddleware::new("test");
830        let mut request = InferenceRequest {
831            capability: Capability::Generate,
832            input: vec![],
833            qos: QoSRequirements::default(),
834            request_id: "req-1".to_string(),
835            tenant_id: None,
836        };
837        let result = middleware.before_route(&mut request);
838        assert!(result.is_ok());
839    }
840
841    #[test]
842    fn test_logging_middleware_after_infer() {
843        let middleware = LoggingMiddleware::new("test");
844        let request = InferenceRequest {
845            capability: Capability::Generate,
846            input: vec![],
847            qos: QoSRequirements::default(),
848            request_id: "req-1".to_string(),
849            tenant_id: None,
850        };
851        let mut response = InferenceResponse {
852            output: b"output".to_vec(),
853            served_by: NodeId("n1".to_string()),
854            latency: Duration::from_millis(50),
855            tokens: Some(5),
856        };
857        let result = middleware.after_infer(&request, &mut response);
858        assert!(result.is_ok());
859    }
860
861    #[test]
862    fn test_logging_middleware_on_error() {
863        let middleware = LoggingMiddleware::new("test");
864        let request = InferenceRequest {
865            capability: Capability::Generate,
866            input: vec![],
867            qos: QoSRequirements::default(),
868            request_id: "req-1".to_string(),
869            tenant_id: None,
870        };
871        let error = FederationError::Internal("test error".to_string());
872        // Should not panic
873        middleware.on_error(&request, &error);
874    }
875
876    // =========================================================================
877    // RateLimitMiddleware tests
878    // =========================================================================
879
880    #[test]
881    fn test_rate_limit_middleware_creation() {
882        let _middleware = RateLimitMiddleware::new(1000);
883    }
884
885    #[test]
886    fn test_rate_limit_middleware_before_route() {
887        let middleware = RateLimitMiddleware::new(100);
888        let mut request = InferenceRequest {
889            capability: Capability::Embed,
890            input: vec![],
891            qos: QoSRequirements::default(),
892            request_id: "req-1".to_string(),
893            tenant_id: None,
894        };
895        assert!(middleware.before_route(&mut request).is_ok());
896    }
897
898    #[test]
899    fn test_rate_limit_middleware_after_infer() {
900        let middleware = RateLimitMiddleware::new(100);
901        let request = InferenceRequest {
902            capability: Capability::Embed,
903            input: vec![],
904            qos: QoSRequirements::default(),
905            request_id: "req-1".to_string(),
906            tenant_id: None,
907        };
908        let mut response = InferenceResponse {
909            output: vec![],
910            served_by: NodeId("n1".to_string()),
911            latency: Duration::from_millis(10),
912            tokens: None,
913        };
914        assert!(middleware.after_infer(&request, &mut response).is_ok());
915    }
916
917    #[test]
918    fn test_rate_limit_middleware_on_error() {
919        let middleware = RateLimitMiddleware::new(100);
920        let request = InferenceRequest {
921            capability: Capability::Embed,
922            input: vec![],
923            qos: QoSRequirements::default(),
924            request_id: "req-1".to_string(),
925            tenant_id: None,
926        };
927        let error = FederationError::Internal("err".to_string());
928        middleware.on_error(&request, &error); // Should not panic
929    }
930
931    // =========================================================================
932    // Gateway with middleware integration test
933    // =========================================================================
934
935    #[tokio::test]
936    async fn test_gateway_with_logging_middleware() {
937        let catalog = Arc::new(ModelCatalog::new());
938        let health = Arc::new(HealthChecker::default());
939        let circuit_breaker = Arc::new(CircuitBreaker::default());
940
941        let router = Arc::new(Router::new(
942            super::super::routing::RouterConfig::default(),
943            Arc::clone(&catalog),
944            Arc::clone(&health),
945            Arc::clone(&circuit_breaker),
946        ));
947
948        let gateway = FederationGateway::new(
949            GatewayConfig::default(),
950            router,
951            Arc::clone(&health),
952            circuit_breaker,
953        )
954        .with_middleware(LoggingMiddleware::new("test-gw"));
955
956        assert_eq!(gateway.middlewares.len(), 1);
957
958        // Register a node so inference works
959        catalog
960            .register(
961                ModelId("m1".to_string()),
962                NodeId("n1".to_string()),
963                RegionId("us-west".to_string()),
964                vec![Capability::Generate],
965            )
966            .await
967            .expect("registration failed");
968
969        health.register_node(NodeId("n1".to_string()));
970        for _ in 0..3 {
971            health.report_success(&NodeId("n1".to_string()), Duration::from_millis(10));
972        }
973
974        let request = InferenceRequest {
975            capability: Capability::Generate,
976            input: b"test".to_vec(),
977            qos: QoSRequirements::default(),
978            request_id: "mw-test".to_string(),
979            tenant_id: None,
980        };
981
982        let result = gateway.infer(request).await;
983        assert!(result.is_ok());
984    }
985
986    // =========================================================================
987    // Stats tracking tests
988    // =========================================================================
989
990    #[test]
991    fn test_gateway_initial_stats() {
992        let gateway = GatewayBuilder::new().build();
993        let stats = gateway.stats();
994        assert_eq!(stats.total_requests, 0);
995        assert_eq!(stats.successful_requests, 0);
996        assert_eq!(stats.failed_requests, 0);
997        assert_eq!(stats.total_tokens, 0);
998        assert_eq!(stats.active_streams, 0);
999        assert_eq!(stats.avg_latency, Duration::ZERO);
1000    }
1001
1002    #[tokio::test]
1003    async fn test_gateway_stats_after_failures() {
1004        let gateway = GatewayBuilder::new()
1005            .config(GatewayConfig {
1006                max_retries: 0, // No retries
1007                ..Default::default()
1008            })
1009            .build();
1010
1011        // No nodes registered, so inference will fail
1012        let request = InferenceRequest {
1013            capability: Capability::Generate,
1014            input: b"test".to_vec(),
1015            qos: QoSRequirements::default(),
1016            request_id: "fail-test".to_string(),
1017            tenant_id: None,
1018        };
1019
1020        let _ = gateway.infer(request).await;
1021
1022        let stats = gateway.stats();
1023        assert_eq!(stats.total_requests, 1);
1024        assert_eq!(stats.failed_requests, 1);
1025        assert_eq!(stats.successful_requests, 0);
1026    }
1027
1028    // =========================================================================
1029    // Stream cancel test
1030    // =========================================================================
1031
1032    #[tokio::test]
1033    async fn test_stream_cancel() {
1034        let (gateway, catalog, health) = setup_test_gateway();
1035
1036        catalog
1037            .register(
1038                ModelId("stream-model".to_string()),
1039                NodeId("n1".to_string()),
1040                RegionId("us-west".to_string()),
1041                vec![Capability::Generate],
1042            )
1043            .await
1044            .expect("registration failed");
1045
1046        health.register_node(NodeId("n1".to_string()));
1047        health.report_success(&NodeId("n1".to_string()), Duration::from_millis(10));
1048
1049        let request = InferenceRequest {
1050            capability: Capability::Generate,
1051            input: b"stream".to_vec(),
1052            qos: QoSRequirements::default(),
1053            request_id: "cancel-test".to_string(),
1054            tenant_id: None,
1055        };
1056
1057        let mut stream = gateway.infer_stream(request).await.expect("stream failed");
1058
1059        // Read a few tokens
1060        let _ = stream.next_token().await;
1061        let _ = stream.next_token().await;
1062
1063        // Cancel the stream
1064        stream.cancel().await;
1065
1066        // After cancel, next_token should return None
1067        let result = stream.next_token().await;
1068        assert!(result.is_none());
1069    }
1070
1071    /// Comprehensive integration test demonstrating full federation flow
1072    #[tokio::test]
1073    async fn test_full_federation_flow() {
1074        use super::super::policy::CompositePolicy;
1075
1076        // =====================================================================
1077        // Setup: Create multi-region deployment
1078        // =====================================================================
1079        let catalog = Arc::new(ModelCatalog::new());
1080        let health = Arc::new(HealthChecker::default());
1081        let circuit_breaker = Arc::new(CircuitBreaker::default());
1082
1083        // Register Whisper model in US-West (primary, fast)
1084        catalog
1085            .register(
1086                ModelId("whisper-v3".to_string()),
1087                NodeId("us-west-gpu".to_string()),
1088                RegionId("us-west".to_string()),
1089                vec![Capability::Transcribe],
1090            )
1091            .await
1092            .expect("failed to register us-west");
1093
1094        // Register Whisper model in EU-West (GDPR compliant)
1095        catalog
1096            .register(
1097                ModelId("whisper-v3".to_string()),
1098                NodeId("eu-west-gpu".to_string()),
1099                RegionId("eu-west".to_string()),
1100                vec![Capability::Transcribe],
1101            )
1102            .await
1103            .expect("failed to register eu-west");
1104
1105        // Register LLaMA in US-East
1106        catalog
1107            .register(
1108                ModelId("llama-70b".to_string()),
1109                NodeId("us-east-gpu".to_string()),
1110                RegionId("us-east".to_string()),
1111                vec![Capability::Generate, Capability::Code],
1112            )
1113            .await
1114            .expect("failed to register llama");
1115
1116        // Register embedding model in multiple regions
1117        for (node, region) in [("embed-us", "us-west"), ("embed-eu", "eu-west")] {
1118            catalog
1119                .register(
1120                    ModelId("bge-large".to_string()),
1121                    NodeId(node.to_string()),
1122                    RegionId(region.to_string()),
1123                    vec![Capability::Embed],
1124                )
1125                .await
1126                .expect("failed to register embedding");
1127        }
1128
1129        // =====================================================================
1130        // Setup health states
1131        // =====================================================================
1132
1133        // US-West: Healthy, fast (45ms)
1134        health.register_node(NodeId("us-west-gpu".to_string()));
1135        for _ in 0..3 {
1136            health.report_success(
1137                &NodeId("us-west-gpu".to_string()),
1138                Duration::from_millis(45),
1139            );
1140        }
1141
1142        // EU-West: Healthy, slower (120ms)
1143        health.register_node(NodeId("eu-west-gpu".to_string()));
1144        for _ in 0..3 {
1145            health.report_success(
1146                &NodeId("eu-west-gpu".to_string()),
1147                Duration::from_millis(120),
1148            );
1149        }
1150
1151        // US-East: Will be unknown/degraded (only 1 success)
1152        health.register_node(NodeId("us-east-gpu".to_string()));
1153        // Just one success keeps it in Unknown state (needs 2 for Healthy)
1154        health.report_success(
1155            &NodeId("us-east-gpu".to_string()),
1156            Duration::from_millis(100),
1157        );
1158
1159        // Embedding nodes: Healthy
1160        for node in ["embed-us", "embed-eu"] {
1161            health.register_node(NodeId(node.to_string()));
1162            for _ in 0..3 {
1163                health.report_success(&NodeId(node.to_string()), Duration::from_millis(15));
1164            }
1165        }
1166
1167        // =====================================================================
1168        // Create Router with enterprise policies
1169        // =====================================================================
1170        let router = Arc::new(
1171            Router::new(
1172                super::super::routing::RouterConfig {
1173                    max_candidates: 10,
1174                    min_score: 0.1,
1175                    strategy: LoadBalanceStrategy::LeastLatency,
1176                },
1177                Arc::clone(&catalog),
1178                Arc::clone(&health),
1179                Arc::clone(&circuit_breaker),
1180            )
1181            .with_policy(CompositePolicy::enterprise_default()),
1182        );
1183
1184        // =====================================================================
1185        // Create Gateway
1186        // =====================================================================
1187        let gateway = FederationGateway::new(
1188            GatewayConfig {
1189                max_retries: 3,
1190                inference_timeout: Duration::from_secs(30),
1191                enable_tracing: true,
1192            },
1193            Arc::clone(&router),
1194            Arc::clone(&health),
1195            Arc::clone(&circuit_breaker),
1196        );
1197
1198        // =====================================================================
1199        // Test 1: Transcribe routes to fastest healthy node (us-west)
1200        // =====================================================================
1201        let request = InferenceRequest {
1202            capability: Capability::Transcribe,
1203            input: b"audio data".to_vec(),
1204            qos: QoSRequirements::default(),
1205            request_id: "test-transcribe".to_string(),
1206            tenant_id: Some("acme".to_string()),
1207        };
1208
1209        let candidates = router
1210            .get_candidates(&request)
1211            .await
1212            .expect("get_candidates failed");
1213        assert_eq!(candidates.len(), 2, "Should have 2 Transcribe candidates");
1214
1215        let target = router.route(&request).await.expect("route failed");
1216        // US-West should be selected (lower latency = higher score)
1217        assert_eq!(target.node_id, NodeId("us-west-gpu".to_string()));
1218
1219        // =====================================================================
1220        // Test 2: Generate routes to only available node
1221        // =====================================================================
1222        let request = InferenceRequest {
1223            capability: Capability::Generate,
1224            input: b"prompt".to_vec(),
1225            qos: QoSRequirements::default(),
1226            request_id: "test-generate".to_string(),
1227            tenant_id: None,
1228        };
1229
1230        let target = router.route(&request).await.expect("route failed");
1231        assert_eq!(target.node_id, NodeId("us-east-gpu".to_string()));
1232
1233        // =====================================================================
1234        // Test 3: Embed has multiple candidates
1235        // =====================================================================
1236        let request = InferenceRequest {
1237            capability: Capability::Embed,
1238            input: b"text".to_vec(),
1239            qos: QoSRequirements::default(),
1240            request_id: "test-embed".to_string(),
1241            tenant_id: None,
1242        };
1243
1244        let candidates = router
1245            .get_candidates(&request)
1246            .await
1247            .expect("get_candidates failed");
1248        assert_eq!(candidates.len(), 2, "Should have 2 Embed candidates");
1249
1250        // =====================================================================
1251        // Test 4: Gateway inference with stats
1252        // =====================================================================
1253        let request = InferenceRequest {
1254            capability: Capability::Transcribe,
1255            input: b"audio".to_vec(),
1256            qos: QoSRequirements::default(),
1257            request_id: "test-infer".to_string(),
1258            tenant_id: None,
1259        };
1260
1261        let response = gateway.infer(request).await.expect("inference failed");
1262        assert_eq!(response.served_by, NodeId("us-west-gpu".to_string()));
1263        assert!(!response.output.is_empty());
1264
1265        let stats = gateway.stats();
1266        assert_eq!(stats.total_requests, 1);
1267        assert_eq!(stats.successful_requests, 1);
1268        assert_eq!(stats.failed_requests, 0);
1269
1270        // =====================================================================
1271        // Test 5: Streaming inference
1272        // =====================================================================
1273        let request = InferenceRequest {
1274            capability: Capability::Generate,
1275            input: b"stream prompt".to_vec(),
1276            qos: QoSRequirements::default(),
1277            request_id: "test-stream".to_string(),
1278            tenant_id: None,
1279        };
1280
1281        let mut stream = gateway.infer_stream(request).await.expect("stream failed");
1282        let mut tokens = 0;
1283        while let Some(result) = stream.next_token().await {
1284            result.expect("token error");
1285            tokens += 1;
1286        }
1287        assert_eq!(tokens, 10, "Should receive 10 tokens");
1288
1289        // =====================================================================
1290        // Test 6: Circuit breaker
1291        // =====================================================================
1292        let bad_node = NodeId("failing-node".to_string());
1293
1294        // Initially closed
1295        assert_eq!(circuit_breaker.state(&bad_node), CircuitState::Closed);
1296
1297        // Simulate failures
1298        for _ in 0..5 {
1299            circuit_breaker.record_failure(&bad_node);
1300        }
1301        assert_eq!(circuit_breaker.state(&bad_node), CircuitState::Open);
1302        assert!(circuit_breaker.is_open(&bad_node));
1303
1304        // =====================================================================
1305        // Verify catalog state
1306        // =====================================================================
1307        let all_models = catalog.list_all().await.expect("list failed");
1308        assert_eq!(all_models.len(), 3); // whisper-v3, llama-70b, bge-large
1309
1310        // =====================================================================
1311        // Verify health states are tracked (all nodes have cached health)
1312        // =====================================================================
1313        let nodes_with_health = [
1314            "us-west-gpu",
1315            "eu-west-gpu",
1316            "us-east-gpu",
1317            "embed-us",
1318            "embed-eu",
1319        ];
1320        for node in nodes_with_health {
1321            let h = health.get_cached_health(&NodeId(node.to_string()));
1322            assert!(h.is_some(), "Health should be tracked for {}", node);
1323        }
1324
1325        // Verify us-west has good health (3 successes)
1326        let us_west_health = health
1327            .get_cached_health(&NodeId("us-west-gpu".to_string()))
1328            .unwrap();
1329        assert_eq!(
1330            us_west_health.status,
1331            HealthState::Healthy,
1332            "US-West should be healthy"
1333        );
1334
1335        // =====================================================================
1336        // Final summary - print results for visibility
1337        // =====================================================================
1338        println!("\n✅ Full Federation Flow Test PASSED!");
1339        println!("   - 3 models registered across 5 nodes");
1340        println!("   - 6 health entries tracked");
1341        println!("   - Routing correctly prefers fastest healthy nodes");
1342        println!("   - Gateway inference succeeds with stats tracking");
1343        println!("   - Streaming returns expected token count");
1344        println!("   - Circuit breaker opens after failures");
1345    }
1346}