1use super::catalog::ModelCatalog;
7use super::health::{CircuitBreaker, HealthChecker};
8use super::routing::Router;
9use super::traits::*;
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13
/// Tunables for a federation gateway.
#[derive(Debug, Clone)]
pub struct GatewayConfig {
    /// Retries after the first attempt; the gateway makes at most
    /// `max_retries + 1` attempts per request.
    pub max_retries: u32,
    /// Per-attempt inference time budget.
    pub inference_timeout: Duration,
    /// Tracing toggle.
    /// NOTE(review): not read anywhere in this module — confirm where it is consumed.
    pub enable_tracing: bool,
}

impl Default for GatewayConfig {
    /// Three retries, a 30-second per-attempt budget, tracing on.
    fn default() -> Self {
        let inference_timeout = Duration::from_secs(30);
        GatewayConfig {
            enable_tracing: true,
            max_retries: 3,
            inference_timeout,
        }
    }
}
38
39struct StatsTracker {
45 total_requests: AtomicU64,
46 successful_requests: AtomicU64,
47 failed_requests: AtomicU64,
48 total_tokens: AtomicU64,
49 total_latency_ms: AtomicU64,
50 active_streams: AtomicU64,
51}
52
53impl StatsTracker {
54 fn new() -> Self {
55 Self {
56 total_requests: AtomicU64::new(0),
57 successful_requests: AtomicU64::new(0),
58 failed_requests: AtomicU64::new(0),
59 total_tokens: AtomicU64::new(0),
60 total_latency_ms: AtomicU64::new(0),
61 active_streams: AtomicU64::new(0),
62 }
63 }
64
65 fn record_request(&self) {
66 self.total_requests.fetch_add(1, Ordering::SeqCst);
67 }
68
69 fn record_success(&self, latency: Duration, tokens: Option<u32>) {
70 self.successful_requests.fetch_add(1, Ordering::SeqCst);
71 self.total_latency_ms
72 .fetch_add(latency.as_millis() as u64, Ordering::SeqCst);
73 if let Some(t) = tokens {
74 self.total_tokens.fetch_add(t as u64, Ordering::SeqCst);
75 }
76 }
77
78 fn record_failure(&self) {
79 self.failed_requests.fetch_add(1, Ordering::SeqCst);
80 }
81
82 #[allow(dead_code)]
83 fn increment_streams(&self) {
84 self.active_streams.fetch_add(1, Ordering::SeqCst);
85 }
86
87 #[allow(dead_code)]
88 fn decrement_streams(&self) {
89 self.active_streams.fetch_sub(1, Ordering::SeqCst);
90 }
91
92 fn snapshot(&self) -> GatewayStats {
93 let total = self.total_requests.load(Ordering::SeqCst);
94 let successful = self.successful_requests.load(Ordering::SeqCst);
95 let total_latency = self.total_latency_ms.load(Ordering::SeqCst);
96
97 let avg_latency = if successful > 0 {
98 Duration::from_millis(total_latency / successful)
99 } else {
100 Duration::ZERO
101 };
102
103 GatewayStats {
104 total_requests: total,
105 successful_requests: successful,
106 failed_requests: self.failed_requests.load(Ordering::SeqCst),
107 total_tokens: self.total_tokens.load(Ordering::SeqCst),
108 avg_latency,
109 active_streams: self.active_streams.load(Ordering::SeqCst) as u32,
110 }
111 }
112}
113
114impl Default for StatsTracker {
115 fn default() -> Self {
116 Self::new()
117 }
118}
119
120pub struct FederationGateway {
126 config: GatewayConfig,
127 router: Arc<Router>,
128 health: Arc<HealthChecker>,
129 circuit_breaker: Arc<CircuitBreaker>,
130 middlewares: Vec<Box<dyn GatewayMiddleware>>,
131 stats: StatsTracker,
132}
133
134impl FederationGateway {
135 pub fn new(
136 config: GatewayConfig,
137 router: Arc<Router>,
138 health: Arc<HealthChecker>,
139 circuit_breaker: Arc<CircuitBreaker>,
140 ) -> Self {
141 Self {
142 config,
143 router,
144 health,
145 circuit_breaker,
146 middlewares: Vec::new(),
147 stats: StatsTracker::new(),
148 }
149 }
150
151 #[must_use]
153 pub fn with_middleware(mut self, middleware: impl GatewayMiddleware + 'static) -> Self {
154 self.middlewares.push(Box::new(middleware));
155 self
156 }
157
158 async fn execute_with_retries(
160 &self,
161 mut request: InferenceRequest,
162 ) -> FederationResult<InferenceResponse> {
163 for middleware in &self.middlewares {
165 middleware.before_route(&mut request)?;
166 }
167
168 let mut last_error = None;
169 let mut tried_nodes = Vec::new();
170
171 for attempt in 0..=self.config.max_retries {
172 let target = match self.router.route(&request).await {
176 Ok(t) => t,
177 Err(e) => {
178 last_error = Some(e);
179 continue;
180 }
181 };
182
183 if self.circuit_breaker.is_open(&target.node_id) {
185 last_error = Some(FederationError::CircuitOpen(target.node_id.clone()));
186 tried_nodes.push(target.node_id);
187 continue;
188 }
189
190 let start = Instant::now();
192 match self.execute_on_node(&target, &request).await {
193 Ok(mut response) => {
194 let latency = start.elapsed();
195
196 self.health.report_success(&target.node_id, latency);
198 self.circuit_breaker.record_success(&target.node_id);
199 self.stats.record_success(latency, response.tokens);
200
201 for middleware in &self.middlewares {
203 middleware.after_infer(&request, &mut response)?;
204 }
205
206 return Ok(response);
207 }
208 Err(e) => {
209 self.health.report_failure(&target.node_id);
211 self.circuit_breaker.record_failure(&target.node_id);
212
213 for middleware in &self.middlewares {
215 middleware.on_error(&request, &e);
216 }
217
218 last_error = Some(e);
219 tried_nodes.push(target.node_id);
220
221 if attempt < self.config.max_retries {
222 tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
224 }
225 }
226 }
227 }
228
229 self.stats.record_failure();
230 Err(last_error
231 .unwrap_or_else(|| FederationError::Internal("All retries exhausted".to_string())))
232 }
233
234 #[allow(clippy::unused_async)] async fn execute_on_node(
237 &self,
238 target: &RouteTarget,
239 _request: &InferenceRequest,
240 ) -> FederationResult<InferenceResponse> {
241 if target.endpoint.is_empty() {
245 Ok(InferenceResponse {
247 output: b"simulated output".to_vec(),
248 served_by: target.node_id.clone(),
249 latency: Duration::from_millis(50),
250 tokens: Some(10),
251 })
252 } else {
253 Ok(InferenceResponse {
256 output: b"simulated output".to_vec(),
257 served_by: target.node_id.clone(),
258 latency: Duration::from_millis(50),
259 tokens: Some(10),
260 })
261 }
262 }
263}
264
265impl GatewayTrait for FederationGateway {
266 fn infer(
267 &self,
268 request: InferenceRequest,
269 ) -> BoxFuture<'_, FederationResult<InferenceResponse>> {
270 Box::pin(async move {
271 self.stats.record_request();
272 self.execute_with_retries(request).await
273 })
274 }
275
276 fn infer_stream(
277 &self,
278 request: InferenceRequest,
279 ) -> BoxFuture<'_, FederationResult<Box<dyn TokenStream>>> {
280 Box::pin(async move {
281 self.stats.record_request();
282 self.stats.increment_streams();
283
284 let target = self.router.route(&request).await?;
286
287 let stream = FederationTokenStream::new(
289 target,
290 request,
291 Arc::clone(&self.health),
292 Arc::clone(&self.circuit_breaker),
293 );
294
295 let stream: Box<dyn TokenStream> = Box::new(stream);
296 Ok(stream)
297 })
298 }
299
300 fn stats(&self) -> GatewayStats {
301 self.stats.snapshot()
302 }
303}
304
305struct FederationTokenStream {
311 target: RouteTarget,
312 _request: InferenceRequest,
313 health: Arc<HealthChecker>,
314 circuit_breaker: Arc<CircuitBreaker>,
315 tokens_generated: u32,
316 finished: bool,
317}
318
319impl FederationTokenStream {
320 fn new(
321 target: RouteTarget,
322 request: InferenceRequest,
323 health: Arc<HealthChecker>,
324 circuit_breaker: Arc<CircuitBreaker>,
325 ) -> Self {
326 Self {
327 target,
328 _request: request,
329 health,
330 circuit_breaker,
331 tokens_generated: 0,
332 finished: false,
333 }
334 }
335}
336
337impl TokenStream for FederationTokenStream {
338 fn next_token(&mut self) -> BoxFuture<'_, Option<FederationResult<Vec<u8>>>> {
339 Box::pin(async move {
340 if self.finished {
341 return None;
342 }
343
344 self.tokens_generated += 1;
346
347 if self.tokens_generated > 10 {
348 self.finished = true;
349 self.health
350 .report_success(&self.target.node_id, Duration::from_millis(50));
351 self.circuit_breaker.record_success(&self.target.node_id);
352 return None;
353 }
354
355 Some(Ok(format!("token_{}", self.tokens_generated).into_bytes()))
356 })
357 }
358
359 fn cancel(&mut self) -> BoxFuture<'_, ()> {
360 Box::pin(async move {
361 self.finished = true;
362 })
363 }
364}
365
366pub struct GatewayBuilder {
372 config: GatewayConfig,
373 catalog: Option<Arc<ModelCatalog>>,
374 health: Option<Arc<HealthChecker>>,
375 circuit_breaker: Option<Arc<CircuitBreaker>>,
376 router: Option<Arc<Router>>,
377 middlewares: Vec<Box<dyn GatewayMiddleware>>,
378}
379
380impl GatewayBuilder {
381 pub fn new() -> Self {
382 Self {
383 config: GatewayConfig::default(),
384 catalog: None,
385 health: None,
386 circuit_breaker: None,
387 router: None,
388 middlewares: Vec::new(),
389 }
390 }
391
392 #[must_use]
393 pub fn config(mut self, config: GatewayConfig) -> Self {
394 self.config = config;
395 self
396 }
397
398 #[must_use]
399 pub fn catalog(mut self, catalog: Arc<ModelCatalog>) -> Self {
400 self.catalog = Some(catalog);
401 self
402 }
403
404 #[must_use]
405 pub fn health(mut self, health: Arc<HealthChecker>) -> Self {
406 self.health = Some(health);
407 self
408 }
409
410 #[must_use]
411 pub fn circuit_breaker(mut self, cb: Arc<CircuitBreaker>) -> Self {
412 self.circuit_breaker = Some(cb);
413 self
414 }
415
416 #[must_use]
417 pub fn router(mut self, router: Arc<Router>) -> Self {
418 self.router = Some(router);
419 self
420 }
421
422 #[must_use]
423 pub fn middleware(mut self, middleware: impl GatewayMiddleware + 'static) -> Self {
424 self.middlewares.push(Box::new(middleware));
425 self
426 }
427
428 pub fn build(self) -> FederationGateway {
429 let catalog = self
430 .catalog
431 .unwrap_or_else(|| Arc::new(ModelCatalog::new()));
432 let health = self
433 .health
434 .unwrap_or_else(|| Arc::new(HealthChecker::default()));
435 let circuit_breaker = self
436 .circuit_breaker
437 .unwrap_or_else(|| Arc::new(CircuitBreaker::default()));
438
439 let router = self.router.unwrap_or_else(|| {
440 Arc::new(Router::new(
441 super::routing::RouterConfig::default(),
442 Arc::clone(&catalog),
443 Arc::clone(&health),
444 Arc::clone(&circuit_breaker),
445 ))
446 });
447
448 let mut gateway = FederationGateway::new(self.config, router, health, circuit_breaker);
449
450 for middleware in self.middlewares {
451 gateway.middlewares.push(middleware);
452 }
453
454 gateway
455 }
456}
457
458impl Default for GatewayBuilder {
459 fn default() -> Self {
460 Self::new()
461 }
462}
463
464pub struct LoggingMiddleware {
470 prefix: String,
471}
472
473impl LoggingMiddleware {
474 pub fn new(prefix: impl Into<String>) -> Self {
475 Self {
476 prefix: prefix.into(),
477 }
478 }
479}
480
481impl GatewayMiddleware for LoggingMiddleware {
482 fn before_route(&self, request: &mut InferenceRequest) -> FederationResult<()> {
483 eprintln!(
484 "[{}] Routing request {} for {:?}",
485 self.prefix, request.request_id, request.capability
486 );
487 Ok(())
488 }
489
490 fn after_infer(
491 &self,
492 request: &InferenceRequest,
493 response: &mut InferenceResponse,
494 ) -> FederationResult<()> {
495 eprintln!(
496 "[{}] Request {} served by {:?} in {:?}",
497 self.prefix, request.request_id, response.served_by, response.latency
498 );
499 Ok(())
500 }
501
502 fn on_error(&self, request: &InferenceRequest, error: &FederationError) {
503 eprintln!(
504 "[{}] Request {} failed: {}",
505 self.prefix, request.request_id, error
506 );
507 }
508}
509
510pub struct RateLimitMiddleware {
512 #[allow(dead_code)]
513 requests_per_second: u32,
514 }
516
517impl RateLimitMiddleware {
518 pub fn new(requests_per_second: u32) -> Self {
519 Self {
520 requests_per_second,
521 }
522 }
523}
524
525impl GatewayMiddleware for RateLimitMiddleware {
526 fn before_route(&self, _request: &mut InferenceRequest) -> FederationResult<()> {
527 Ok(())
530 }
531
532 fn after_infer(
533 &self,
534 _request: &InferenceRequest,
535 _response: &mut InferenceResponse,
536 ) -> FederationResult<()> {
537 Ok(())
538 }
539
540 fn on_error(&self, _request: &InferenceRequest, _error: &FederationError) {}
541}
542
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a default-configured gateway over an empty catalog. Returns the
    /// catalog and health checker so tests can populate them.
    fn setup_test_gateway() -> (FederationGateway, Arc<ModelCatalog>, Arc<HealthChecker>) {
        let catalog = Arc::new(ModelCatalog::new());
        let health = Arc::new(HealthChecker::default());
        let circuit_breaker = Arc::new(CircuitBreaker::default());

        let router = Arc::new(Router::new(
            super::super::routing::RouterConfig::default(),
            Arc::clone(&catalog),
            Arc::clone(&health),
            Arc::clone(&circuit_breaker),
        ));

        let gateway = FederationGateway::new(
            GatewayConfig::default(),
            router,
            Arc::clone(&health),
            circuit_breaker,
        );

        (gateway, catalog, health)
    }

    /// Builds a request with default QoS and no tenant.
    fn make_request(capability: Capability, input: &[u8], request_id: &str) -> InferenceRequest {
        InferenceRequest {
            capability,
            input: input.to_vec(),
            qos: QoSRequirements::default(),
            request_id: request_id.to_string(),
            tenant_id: None,
        }
    }

    /// Registers `model` on `node` in region `us-west`. Then registers the node
    /// with the health checker and reports one 10 ms success for it.
    async fn register_healthy_node(
        catalog: &Arc<ModelCatalog>,
        health: &Arc<HealthChecker>,
        model: &str,
        node: &str,
        capabilities: Vec<Capability>,
    ) {
        catalog
            .register(
                ModelId(model.to_string()),
                NodeId(node.to_string()),
                RegionId("us-west".to_string()),
                capabilities,
            )
            .await
            .expect("registration failed");

        health.register_node(NodeId(node.to_string()));
        health.report_success(&NodeId(node.to_string()), Duration::from_millis(10));
    }

    #[tokio::test]
    async fn test_infer_no_nodes() {
        let (gateway, _, _) = setup_test_gateway();

        let result = gateway
            .infer(make_request(Capability::Generate, b"hello", "test-1"))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_infer_with_node() {
        let (gateway, catalog, health) = setup_test_gateway();
        register_healthy_node(&catalog, &health, "test-model", "node-1", vec![Capability::Generate])
            .await;

        let response = gateway
            .infer(make_request(Capability::Generate, b"hello", "test-2"))
            .await
            .expect("inference failed");
        assert_eq!(response.served_by, NodeId("node-1".to_string()));
    }

    #[tokio::test]
    async fn test_stats_tracking() {
        let (gateway, catalog, health) = setup_test_gateway();
        register_healthy_node(&catalog, &health, "test", "node-1", vec![Capability::Embed]).await;

        for i in 0..3u8 {
            let request_id = format!("test-{}", i);
            let _ = gateway
                .infer(make_request(Capability::Embed, &[i], &request_id))
                .await;
        }

        let stats = gateway.stats();
        assert_eq!(stats.total_requests, 3);
        assert_eq!(stats.successful_requests, 3);
        assert_eq!(stats.failed_requests, 0);
    }

    #[tokio::test]
    async fn test_streaming() {
        let (gateway, catalog, health) = setup_test_gateway();
        register_healthy_node(&catalog, &health, "stream-model", "node-1", vec![Capability::Generate])
            .await;

        let mut stream = gateway
            .infer_stream(make_request(Capability::Generate, b"stream test", "stream-1"))
            .await
            .expect("stream creation failed");

        let mut token_count = 0;
        while let Some(result) = stream.next_token().await {
            assert!(result.is_ok());
            token_count += 1;
        }

        assert_eq!(token_count, 10);
    }

    #[test]
    fn test_gateway_builder() {
        let gateway = GatewayBuilder::new()
            .config(GatewayConfig {
                max_retries: 5,
                inference_timeout: Duration::from_secs(60),
                enable_tracing: false,
            })
            .middleware(LoggingMiddleware::new("test"))
            .build();

        assert_eq!(gateway.config.max_retries, 5);
        assert_eq!(gateway.middlewares.len(), 1);
    }

    #[test]
    fn test_gateway_config_default() {
        let config = GatewayConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.inference_timeout, Duration::from_secs(30));
        assert!(config.enable_tracing);
    }

    #[test]
    fn test_gateway_config_clone() {
        let config = GatewayConfig {
            max_retries: 5,
            inference_timeout: Duration::from_secs(60),
            enable_tracing: false,
        };
        let cloned = config.clone();
        assert_eq!(cloned.max_retries, 5);
        assert!(!cloned.enable_tracing);
    }

    #[test]
    fn test_gateway_builder_default() {
        let builder = GatewayBuilder::default();
        assert!(builder.catalog.is_none());
        assert!(builder.health.is_none());
        assert!(builder.circuit_breaker.is_none());
        assert!(builder.router.is_none());
        assert!(builder.middlewares.is_empty());
    }

    #[test]
    fn test_gateway_builder_with_catalog() {
        let builder = GatewayBuilder::new().catalog(Arc::new(ModelCatalog::new()));
        assert!(builder.catalog.is_some());
    }

    #[test]
    fn test_gateway_builder_with_health() {
        let builder = GatewayBuilder::new().health(Arc::new(HealthChecker::default()));
        assert!(builder.health.is_some());
    }

    #[test]
    fn test_gateway_builder_with_circuit_breaker() {
        let builder = GatewayBuilder::new().circuit_breaker(Arc::new(CircuitBreaker::default()));
        assert!(builder.circuit_breaker.is_some());
    }

    #[test]
    fn test_gateway_builder_with_router() {
        let router = Arc::new(Router::new(
            super::super::routing::RouterConfig::default(),
            Arc::new(ModelCatalog::new()),
            Arc::new(HealthChecker::default()),
            Arc::new(CircuitBreaker::default()),
        ));
        let builder = GatewayBuilder::new().router(router);
        assert!(builder.router.is_some());
    }

    #[test]
    fn test_gateway_builder_with_middleware() {
        let builder = GatewayBuilder::new()
            .middleware(LoggingMiddleware::new("test"))
            .middleware(RateLimitMiddleware::new(100));
        assert_eq!(builder.middlewares.len(), 2);
    }

    #[test]
    fn test_gateway_builder_full_chain() {
        let catalog = Arc::new(ModelCatalog::new());
        let health = Arc::new(HealthChecker::default());
        let cb = Arc::new(CircuitBreaker::default());

        let gateway = GatewayBuilder::new()
            .config(GatewayConfig {
                max_retries: 5,
                inference_timeout: Duration::from_secs(120),
                enable_tracing: false,
            })
            .catalog(Arc::clone(&catalog))
            .health(Arc::clone(&health))
            .circuit_breaker(Arc::clone(&cb))
            .middleware(LoggingMiddleware::new("gw"))
            .build();

        assert_eq!(gateway.config.max_retries, 5);
        assert_eq!(gateway.middlewares.len(), 1);
    }

    #[test]
    fn test_logging_middleware_creation() {
        let middleware = LoggingMiddleware::new("test-prefix");
        assert_eq!(middleware.prefix, "test-prefix");
    }

    #[test]
    fn test_logging_middleware_before_route() {
        let middleware = LoggingMiddleware::new("test");
        let mut request = make_request(Capability::Generate, b"", "req-1");
        assert!(middleware.before_route(&mut request).is_ok());
    }

    #[test]
    fn test_logging_middleware_after_infer() {
        let middleware = LoggingMiddleware::new("test");
        let request = make_request(Capability::Generate, b"", "req-1");
        let mut response = InferenceResponse {
            output: b"output".to_vec(),
            served_by: NodeId("n1".to_string()),
            latency: Duration::from_millis(50),
            tokens: Some(5),
        };
        assert!(middleware.after_infer(&request, &mut response).is_ok());
    }

    #[test]
    fn test_logging_middleware_on_error() {
        let middleware = LoggingMiddleware::new("test");
        let request = make_request(Capability::Generate, b"", "req-1");
        let error = FederationError::Internal("test error".to_string());
        // Must not panic.
        middleware.on_error(&request, &error);
    }

    #[test]
    fn test_rate_limit_middleware_creation() {
        let _middleware = RateLimitMiddleware::new(1000);
    }

    #[test]
    fn test_rate_limit_middleware_before_route() {
        let middleware = RateLimitMiddleware::new(100);
        let mut request = make_request(Capability::Embed, b"", "req-1");
        assert!(middleware.before_route(&mut request).is_ok());
    }

    #[test]
    fn test_rate_limit_middleware_after_infer() {
        let middleware = RateLimitMiddleware::new(100);
        let request = make_request(Capability::Embed, b"", "req-1");
        let mut response = InferenceResponse {
            output: vec![],
            served_by: NodeId("n1".to_string()),
            latency: Duration::from_millis(10),
            tokens: None,
        };
        assert!(middleware.after_infer(&request, &mut response).is_ok());
    }

    #[test]
    fn test_rate_limit_middleware_on_error() {
        let middleware = RateLimitMiddleware::new(100);
        let request = make_request(Capability::Embed, b"", "req-1");
        let error = FederationError::Internal("err".to_string());
        // Must not panic.
        middleware.on_error(&request, &error);
    }

    #[tokio::test]
    async fn test_gateway_with_logging_middleware() {
        let (gateway, catalog, health) = setup_test_gateway();
        let gateway = gateway.with_middleware(LoggingMiddleware::new("test-gw"));
        assert_eq!(gateway.middlewares.len(), 1);

        register_healthy_node(&catalog, &health, "m1", "n1", vec![Capability::Generate]).await;
        // Two more successes bring the node to three successful reports.
        for _ in 0..2 {
            health.report_success(&NodeId("n1".to_string()), Duration::from_millis(10));
        }

        let result = gateway
            .infer(make_request(Capability::Generate, b"test", "mw-test"))
            .await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_gateway_initial_stats() {
        let stats = GatewayBuilder::new().build().stats();
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.successful_requests, 0);
        assert_eq!(stats.failed_requests, 0);
        assert_eq!(stats.total_tokens, 0);
        assert_eq!(stats.active_streams, 0);
        assert_eq!(stats.avg_latency, Duration::ZERO);
    }

    #[tokio::test]
    async fn test_gateway_stats_after_failures() {
        let gateway = GatewayBuilder::new()
            .config(GatewayConfig {
                max_retries: 0,
                ..Default::default()
            })
            .build();

        let _ = gateway
            .infer(make_request(Capability::Generate, b"test", "fail-test"))
            .await;

        let stats = gateway.stats();
        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.failed_requests, 1);
        assert_eq!(stats.successful_requests, 0);
    }

    #[tokio::test]
    async fn test_stream_cancel() {
        let (gateway, catalog, health) = setup_test_gateway();
        register_healthy_node(&catalog, &health, "stream-model", "n1", vec![Capability::Generate])
            .await;

        let mut stream = gateway
            .infer_stream(make_request(Capability::Generate, b"stream", "cancel-test"))
            .await
            .expect("stream failed");

        let _ = stream.next_token().await;
        let _ = stream.next_token().await;

        stream.cancel().await;

        // A cancelled stream yields nothing further.
        assert!(stream.next_token().await.is_none());
    }

    #[tokio::test]
    async fn test_full_federation_flow() {
        use super::super::policy::CompositePolicy;

        let catalog = Arc::new(ModelCatalog::new());
        let health = Arc::new(HealthChecker::default());
        let circuit_breaker = Arc::new(CircuitBreaker::default());

        // (model, node, region, capabilities): 3 models spread over 5 nodes.
        let registrations = [
            ("whisper-v3", "us-west-gpu", "us-west", vec![Capability::Transcribe]),
            ("whisper-v3", "eu-west-gpu", "eu-west", vec![Capability::Transcribe]),
            (
                "llama-70b",
                "us-east-gpu",
                "us-east",
                vec![Capability::Generate, Capability::Code],
            ),
            ("bge-large", "embed-us", "us-west", vec![Capability::Embed]),
            ("bge-large", "embed-eu", "eu-west", vec![Capability::Embed]),
        ];
        for (model, node, region, capabilities) in registrations {
            catalog
                .register(
                    ModelId(model.to_string()),
                    NodeId(node.to_string()),
                    RegionId(region.to_string()),
                    capabilities,
                )
                .await
                .unwrap_or_else(|e| panic!("failed to register {} on {}: {:?}", model, node, e));
        }

        // (node, successful reports, reported latency in ms).
        let health_reports = [
            ("us-west-gpu", 3, 45),
            ("eu-west-gpu", 3, 120),
            ("us-east-gpu", 1, 100),
            ("embed-us", 3, 15),
            ("embed-eu", 3, 15),
        ];
        for (node, reports, latency_ms) in health_reports {
            let node_id = NodeId(node.to_string());
            health.register_node(node_id.clone());
            for _ in 0..reports {
                health.report_success(&node_id, Duration::from_millis(latency_ms));
            }
        }

        let router = Arc::new(
            Router::new(
                super::super::routing::RouterConfig {
                    max_candidates: 10,
                    min_score: 0.1,
                    strategy: LoadBalanceStrategy::LeastLatency,
                },
                Arc::clone(&catalog),
                Arc::clone(&health),
                Arc::clone(&circuit_breaker),
            )
            .with_policy(CompositePolicy::enterprise_default()),
        );

        let gateway = FederationGateway::new(
            GatewayConfig {
                max_retries: 3,
                inference_timeout: Duration::from_secs(30),
                enable_tracing: true,
            },
            Arc::clone(&router),
            Arc::clone(&health),
            Arc::clone(&circuit_breaker),
        );

        // Transcription: two candidates, and the lower-latency us-west node wins.
        let request = InferenceRequest {
            tenant_id: Some("acme".to_string()),
            ..make_request(Capability::Transcribe, b"audio data", "test-transcribe")
        };
        let candidates = router
            .get_candidates(&request)
            .await
            .expect("get_candidates failed");
        assert_eq!(candidates.len(), 2, "Should have 2 Transcribe candidates");

        let target = router.route(&request).await.expect("route failed");
        assert_eq!(target.node_id, NodeId("us-west-gpu".to_string()));

        // Generation: only us-east-gpu serves it.
        let request = make_request(Capability::Generate, b"prompt", "test-generate");
        let target = router.route(&request).await.expect("route failed");
        assert_eq!(target.node_id, NodeId("us-east-gpu".to_string()));

        // Embedding: both embedding nodes are candidates.
        let request = make_request(Capability::Embed, b"text", "test-embed");
        let candidates = router
            .get_candidates(&request)
            .await
            .expect("get_candidates failed");
        assert_eq!(candidates.len(), 2, "Should have 2 Embed candidates");

        // End-to-end inference through the gateway, with stats.
        let response = gateway
            .infer(make_request(Capability::Transcribe, b"audio", "test-infer"))
            .await
            .expect("inference failed");
        assert_eq!(response.served_by, NodeId("us-west-gpu".to_string()));
        assert!(!response.output.is_empty());

        let stats = gateway.stats();
        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.successful_requests, 1);
        assert_eq!(stats.failed_requests, 0);

        // Streaming.
        let mut stream = gateway
            .infer_stream(make_request(Capability::Generate, b"stream prompt", "test-stream"))
            .await
            .expect("stream failed");
        let mut tokens = 0;
        while let Some(result) = stream.next_token().await {
            result.expect("token error");
            tokens += 1;
        }
        assert_eq!(tokens, 10, "Should receive 10 tokens");

        // Circuit breaker opens after repeated failures.
        let bad_node = NodeId("failing-node".to_string());
        assert_eq!(circuit_breaker.state(&bad_node), CircuitState::Closed);
        for _ in 0..5 {
            circuit_breaker.record_failure(&bad_node);
        }
        assert_eq!(circuit_breaker.state(&bad_node), CircuitState::Open);
        assert!(circuit_breaker.is_open(&bad_node));

        let all_models = catalog.list_all().await.expect("list failed");
        assert_eq!(all_models.len(), 3);

        let nodes_with_health = [
            "us-west-gpu",
            "eu-west-gpu",
            "us-east-gpu",
            "embed-us",
            "embed-eu",
        ];
        for node in nodes_with_health {
            let h = health.get_cached_health(&NodeId(node.to_string()));
            assert!(h.is_some(), "Health should be tracked for {}", node);
        }

        let us_west_health = health
            .get_cached_health(&NodeId("us-west-gpu".to_string()))
            .expect("us-west-gpu health should be cached");
        assert_eq!(
            us_west_health.status,
            HealthState::Healthy,
            "US-West should be healthy"
        );

        println!("\n✅ Full Federation Flow Test PASSED!");
        println!("   - 3 models registered across 5 nodes");
        println!("   - 5 health entries tracked");
        println!("   - Routing correctly prefers fastest healthy nodes");
        println!("   - Gateway inference succeeds with stats tracking");
        println!("   - Streaming returns expected token count");
        println!("   - Circuit breaker opens after failures");
    }
}