1use std::future::Future;
7use std::pin::Pin;
8use std::time::Duration;
9
/// Newtype around the string that identifies a model.
///
/// Equality and hashing are derived from the wrapped string, so two ids built
/// from the same text are interchangeable as map or set keys.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct ModelId(pub String);

impl std::fmt::Display for ModelId {
    /// Writes the wrapped string.
    ///
    /// Formatting is delegated to `str`'s own `Display` implementation so
    /// that width, fill, alignment and precision supplied by the caller
    /// (e.g. `{:>12}`) are honoured; going through `write!(f, "{}", ..)`
    /// would start a fresh format spec and silently drop them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(self.0.as_str(), f)
    }
}
23
/// Newtype around the string that identifies a region.
///
/// Equality and hashing are derived from the wrapped string, so two ids built
/// from the same text are interchangeable as map or set keys.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct RegionId(pub String);

impl std::fmt::Display for RegionId {
    /// Writes the wrapped string.
    ///
    /// Formatting is delegated to `str`'s own `Display` implementation so
    /// that width, fill, alignment and precision supplied by the caller
    /// (e.g. `{:<16}`) are honoured; going through `write!(f, "{}", ..)`
    /// would start a fresh format spec and silently drop them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(self.0.as_str(), f)
    }
}
33
/// Newtype around the string that identifies a node.
///
/// Equality and hashing are derived from the wrapped string, so two ids built
/// from the same text are interchangeable as map or set keys.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct NodeId(pub String);

impl std::fmt::Display for NodeId {
    /// Writes the wrapped string.
    ///
    /// Formatting is delegated to `str`'s own `Display` implementation so
    /// that width, fill, alignment and precision supplied by the caller
    /// (e.g. `{:>12}`) are honoured; going through `write!(f, "{}", ..)`
    /// would start a fresh format spec and silently drop them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(self.0.as_str(), f)
    }
}
43
/// Tag describing the kind of work an inference request asks for.
///
/// The built-in variants are plain tags; this file does not define what work
/// each one implies beyond its name. `Eq` and `Hash` are derived alongside
/// `PartialEq` so a capability can be used directly as a `HashMap` /
/// `HashSet` key (for example to index nodes by capability).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Capability {
    /// Built-in `Transcribe` capability.
    Transcribe,
    /// Built-in `Synthesize` capability.
    Synthesize,
    /// Built-in `Generate` capability.
    Generate,
    /// Built-in `Code` capability.
    Code,
    /// Built-in `Embed` capability.
    Embed,
    /// Built-in `ImageGen` capability.
    ImageGen,
    /// Capability identified by an arbitrary string. Two `Custom` values are
    /// equal only when their strings are equal.
    Custom(String),
}
62
/// Privacy classification attached to a request's QoS requirements.
///
/// The derived ordering follows the explicit discriminants, so
/// `Public < Internal < Confidential < Restricted`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum PrivacyLevel {
    /// Lowest level (discriminant 0).
    Public = 0,
    /// Discriminant 1.
    Internal = 1,
    /// Discriminant 2.
    Confidential = 2,
    /// Highest level (discriminant 3).
    Restricted = 3,
}
75
76#[derive(Debug, Clone)]
78pub struct QoSRequirements {
79 pub max_latency: Option<Duration>,
81 pub min_throughput: Option<u32>,
83 pub privacy: PrivacyLevel,
85 pub prefer_gpu: bool,
87 pub cost_tolerance: u8,
89}
90
91impl Default for QoSRequirements {
92 fn default() -> Self {
93 Self {
94 max_latency: None,
95 min_throughput: None,
96 privacy: PrivacyLevel::Internal,
97 prefer_gpu: true,
98 cost_tolerance: 50,
99 }
100 }
101}
102
103#[derive(Debug, Clone)]
105pub struct InferenceRequest {
106 pub capability: Capability,
108 pub input: Vec<u8>,
110 pub qos: QoSRequirements,
112 pub request_id: String,
114 pub tenant_id: Option<String>,
116}
117
118#[derive(Debug)]
120pub struct InferenceResponse {
121 pub output: Vec<u8>,
123 pub served_by: NodeId,
125 pub latency: Duration,
127 pub tokens: Option<u32>,
129}
130
131#[derive(Debug, thiserror::Error)]
133pub enum FederationError {
134 #[error("No nodes available for capability: {0:?}")]
135 NoCapacity(Capability),
136
137 #[error("All nodes unhealthy for capability: {0:?}")]
138 AllNodesUnhealthy(Capability),
139
140 #[error("QoS requirements cannot be met: {0}")]
141 QoSViolation(String),
142
143 #[error("Privacy policy violation: {0}")]
144 PrivacyViolation(String),
145
146 #[error("Node unreachable: {0}")]
147 NodeUnreachable(NodeId),
148
149 #[error("Timeout after {0:?}")]
150 Timeout(Duration),
151
152 #[error("Circuit breaker open for node: {0}")]
153 CircuitOpen(NodeId),
154
155 #[error("Internal error: {0}")]
156 Internal(String),
157}
158
159pub type FederationResult<T> = Result<T, FederationError>;
165
/// Heap-allocated, pinned, `Send` future yielding `T` that may borrow data
/// for the lifetime `'a`. The traits in this file return it from their
/// asynchronous methods.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
168
169pub trait ModelCatalogTrait: Send + Sync {
171 fn register(
173 &self,
174 model_id: ModelId,
175 node_id: NodeId,
176 region_id: RegionId,
177 capabilities: Vec<Capability>,
178 ) -> BoxFuture<'_, FederationResult<()>>;
179
180 fn deregister(&self, model_id: ModelId, node_id: NodeId)
182 -> BoxFuture<'_, FederationResult<()>>;
183
184 fn find_by_capability(
186 &self,
187 capability: &Capability,
188 ) -> BoxFuture<'_, FederationResult<Vec<(NodeId, RegionId)>>>;
189
190 fn list_all(&self) -> BoxFuture<'_, FederationResult<Vec<ModelId>>>;
192
193 fn get_metadata(&self, model_id: &ModelId) -> BoxFuture<'_, FederationResult<ModelMetadata>>;
195}
196
197#[derive(Debug, Clone)]
199pub struct ModelMetadata {
200 pub model_id: ModelId,
201 pub name: String,
202 pub version: String,
203 pub capabilities: Vec<Capability>,
204 pub parameters: u64,
205 pub quantization: Option<String>,
206}
207
208pub trait HealthCheckerTrait: Send + Sync {
210 fn check_node(&self, node_id: &NodeId) -> BoxFuture<'_, FederationResult<NodeHealth>>;
212
213 fn get_cached_health(&self, node_id: &NodeId) -> Option<NodeHealth>;
215
216 fn start_monitoring(&self, interval: Duration) -> BoxFuture<'_, ()>;
218
219 fn stop_monitoring(&self) -> BoxFuture<'_, ()>;
221}
222
223#[derive(Debug, Clone)]
225pub struct NodeHealth {
226 pub node_id: NodeId,
227 pub status: HealthState,
228 pub latency_p50: Duration,
229 pub latency_p99: Duration,
230 pub throughput: u32,
231 pub gpu_utilization: Option<f32>,
232 pub queue_depth: u32,
233 pub last_check: std::time::Instant,
234}
235
/// Coarse health classification of a node.
///
/// Variants compare equal only to themselves; no ordering is defined.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HealthState {
    /// `Healthy` state.
    Healthy,
    /// `Degraded` state.
    Degraded,
    /// `Unhealthy` state.
    Unhealthy,
    /// `Unknown` state.
    Unknown,
}
248
249pub trait RouterTrait: Send + Sync {
251 fn route(&self, request: &InferenceRequest) -> BoxFuture<'_, FederationResult<RouteTarget>>;
253
254 fn get_candidates(
256 &self,
257 request: &InferenceRequest,
258 ) -> BoxFuture<'_, FederationResult<Vec<RouteCandidate>>>;
259}
260
261#[derive(Debug, Clone)]
263pub struct RouteTarget {
264 pub node_id: NodeId,
265 pub region_id: RegionId,
266 pub endpoint: String,
267 pub estimated_latency: Duration,
268 pub score: f64,
269}
270
271#[derive(Debug, Clone)]
273pub struct RouteCandidate {
274 pub target: RouteTarget,
275 pub scores: RouteScores,
276 pub eligible: bool,
277 pub rejection_reason: Option<String>,
278}
279
/// Per-dimension scores for a route candidate.
///
/// This file does not define the scale of the scores or how `total` relates
/// to the individual components.
#[derive(Clone, Debug)]
pub struct RouteScores {
    /// Latency component.
    pub latency_score: f64,
    /// Throughput component.
    pub throughput_score: f64,
    /// Cost component.
    pub cost_score: f64,
    /// Locality component.
    pub locality_score: f64,
    /// Health component.
    pub health_score: f64,
    /// Overall score.
    pub total: f64,
}
290
291pub trait GatewayTrait: Send + Sync {
293 fn infer(
295 &self,
296 request: InferenceRequest,
297 ) -> BoxFuture<'_, FederationResult<InferenceResponse>>;
298
299 fn infer_stream(
301 &self,
302 request: InferenceRequest,
303 ) -> BoxFuture<'_, FederationResult<Box<dyn TokenStream>>>;
304
305 fn stats(&self) -> GatewayStats;
307}
308
309pub trait TokenStream: Send {
311 fn next_token(&mut self) -> BoxFuture<'_, Option<FederationResult<Vec<u8>>>>;
313
314 fn cancel(&mut self) -> BoxFuture<'_, ()>;
316}
317
/// Counters describing gateway activity.
///
/// The derived `Default` yields all-zero counters and a zero `avg_latency`.
#[derive(Clone, Debug, Default)]
pub struct GatewayStats {
    /// Total number of requests.
    pub total_requests: u64,
    /// Number of successful requests.
    pub successful_requests: u64,
    /// Number of failed requests.
    pub failed_requests: u64,
    /// Total number of tokens.
    pub total_tokens: u64,
    /// Average latency.
    pub avg_latency: Duration,
    /// Number of currently active streams.
    pub active_streams: u32,
}
328
329pub trait GatewayMiddleware: Send + Sync {
335 fn before_route(&self, request: &mut InferenceRequest) -> FederationResult<()>;
337
338 fn after_infer(
340 &self,
341 request: &InferenceRequest,
342 response: &mut InferenceResponse,
343 ) -> FederationResult<()>;
344
345 fn on_error(&self, request: &InferenceRequest, error: &FederationError);
347}
348
349pub trait CircuitBreakerTrait: Send + Sync {
351 fn is_open(&self, node_id: &NodeId) -> bool;
353
354 fn record_success(&self, node_id: &NodeId);
356
357 fn record_failure(&self, node_id: &NodeId);
359
360 fn state(&self, node_id: &NodeId) -> CircuitState;
362}
363
/// State of a circuit breaker for one node.
///
/// Variants compare equal only to themselves; no ordering is defined.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CircuitState {
    /// `Closed` state.
    Closed,
    /// `HalfOpen` state.
    HalfOpen,
    /// `Open` state.
    Open,
}
374
375pub trait RoutingPolicyTrait: Send + Sync {
381 fn score(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> f64;
383
384 fn is_eligible(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> bool;
386
387 fn name(&self) -> &'static str;
389}
390
/// Strategy selector for load balancing.
///
/// `LeastLatency` is the `Default`. `PartialEq`/`Eq` are derived, matching
/// the other plain state enums in this file (`HealthState`, `CircuitState`),
/// so strategies can be compared with `==` and `assert_eq!` instead of
/// `matches!`. The algorithms behind the variants are not defined here.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum LoadBalanceStrategy {
    /// `RoundRobin` strategy.
    RoundRobin,
    /// `LeastConnections` strategy.
    LeastConnections,
    /// `LeastLatency` strategy; the default.
    #[default]
    LeastLatency,
    /// `WeightedRandom` strategy.
    WeightedRandom,
    /// `ConsistentHash` strategy.
    ConsistentHash,
}
406
407#[derive(Default)]
413pub struct FederationBuilder {
414 pub catalog: Option<Box<dyn ModelCatalogTrait>>,
415 pub health_checker: Option<Box<dyn HealthCheckerTrait>>,
416 pub router: Option<Box<dyn RouterTrait>>,
417 pub policies: Vec<Box<dyn RoutingPolicyTrait>>,
418 pub middlewares: Vec<Box<dyn GatewayMiddleware>>,
419 pub load_balance: LoadBalanceStrategy,
420}
421
422impl FederationBuilder {
423 pub fn new() -> Self {
424 Self {
425 load_balance: LoadBalanceStrategy::LeastLatency,
426 ..Default::default()
427 }
428 }
429
430 #[must_use]
431 pub fn with_catalog(mut self, catalog: impl ModelCatalogTrait + 'static) -> Self {
432 self.catalog = Some(Box::new(catalog));
433 self
434 }
435
436 #[must_use]
437 pub fn with_health_checker(mut self, checker: impl HealthCheckerTrait + 'static) -> Self {
438 self.health_checker = Some(Box::new(checker));
439 self
440 }
441
442 #[must_use]
443 pub fn with_router(mut self, router: impl RouterTrait + 'static) -> Self {
444 self.router = Some(Box::new(router));
445 self
446 }
447
448 #[must_use]
449 pub fn with_policy(mut self, policy: impl RoutingPolicyTrait + 'static) -> Self {
450 self.policies.push(Box::new(policy));
451 self
452 }
453
454 #[must_use]
455 pub fn with_middleware(mut self, middleware: impl GatewayMiddleware + 'static) -> Self {
456 self.middlewares.push(Box::new(middleware));
457 self
458 }
459
460 #[must_use]
461 pub fn with_load_balance(mut self, strategy: LoadBalanceStrategy) -> Self {
462 self.load_balance = strategy;
463 self
464 }
465}
466
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------
    // Smoke tests
    // ---------------------------------------------------------------

    #[test]
    fn test_qos_default() {
        let defaults = QoSRequirements::default();
        assert_eq!(defaults.privacy, PrivacyLevel::Internal);
        assert!(defaults.prefer_gpu);
        assert_eq!(defaults.cost_tolerance, 50);
    }

    #[test]
    fn test_privacy_ordering() {
        assert!(PrivacyLevel::Public < PrivacyLevel::Internal);
        assert!(PrivacyLevel::Internal < PrivacyLevel::Confidential);
        assert!(PrivacyLevel::Confidential < PrivacyLevel::Restricted);
    }

    #[test]
    fn test_health_state() {
        assert_ne!(HealthState::Healthy, HealthState::Degraded);
    }

    #[test]
    fn test_circuit_state() {
        assert_ne!(CircuitState::Closed, CircuitState::Open);
    }

    #[test]
    fn test_federation_builder() {
        let chosen = FederationBuilder::new()
            .with_load_balance(LoadBalanceStrategy::LeastConnections)
            .load_balance;

        assert!(matches!(chosen, LoadBalanceStrategy::LeastConnections));
    }

    #[test]
    fn test_model_id_equality() {
        let first = ModelId("whisper-v3".to_string());
        let same = ModelId("whisper-v3".to_string());
        let other = ModelId("llama-7b".to_string());

        assert_eq!(first, same);
        assert_ne!(first, other);
    }

    #[test]
    fn test_capability_variants() {
        let builtin = Capability::Transcribe;
        let custom = Capability::Custom("sentiment".to_string());

        assert_ne!(builtin, custom);
        assert_eq!(builtin, Capability::Transcribe);
    }

    // ---------------------------------------------------------------
    // Display for the id newtypes
    // ---------------------------------------------------------------

    #[test]
    fn test_model_id_display() {
        let model = ModelId("whisper-v3".to_string());
        assert_eq!(format!("{}", model), "whisper-v3");
        assert_eq!(model.to_string(), "whisper-v3");
    }

    #[test]
    fn test_model_id_display_empty() {
        let empty = ModelId(String::new());
        assert_eq!(format!("{}", empty), "");
    }

    #[test]
    fn test_region_id_display() {
        let region = RegionId("us-west-2".to_string());
        assert_eq!(format!("{}", region), "us-west-2");
        assert_eq!(region.to_string(), "us-west-2");
    }

    #[test]
    fn test_region_id_display_empty() {
        let empty = RegionId(String::new());
        assert_eq!(format!("{}", empty), "");
    }

    #[test]
    fn test_node_id_display() {
        let node = NodeId("gpu-node-01".to_string());
        assert_eq!(format!("{}", node), "gpu-node-01");
        assert_eq!(node.to_string(), "gpu-node-01");
    }

    #[test]
    fn test_node_id_display_empty() {
        let empty = NodeId(String::new());
        assert_eq!(format!("{}", empty), "");
    }

    // ---------------------------------------------------------------
    // Hashing of the id newtypes
    // ---------------------------------------------------------------

    #[test]
    fn test_model_id_hash_consistency() {
        use std::collections::HashSet;
        // The repeated "a" must collapse into a single entry.
        let unique: HashSet<ModelId> = ["a", "b", "a"]
            .iter()
            .map(|raw| ModelId(raw.to_string()))
            .collect();
        assert_eq!(unique.len(), 2);
    }

    #[test]
    fn test_region_id_hash_consistency() {
        use std::collections::HashSet;
        // The repeated "us-west" must collapse into a single entry.
        let unique: HashSet<RegionId> = ["us-west", "eu-west", "us-west"]
            .iter()
            .map(|raw| RegionId(raw.to_string()))
            .collect();
        assert_eq!(unique.len(), 2);
    }

    #[test]
    fn test_node_id_hash_consistency() {
        use std::collections::HashSet;
        // The repeated "node1" must collapse into a single entry.
        let unique: HashSet<NodeId> = ["node1", "node2", "node1"]
            .iter()
            .map(|raw| NodeId(raw.to_string()))
            .collect();
        assert_eq!(unique.len(), 2);
    }

    // ---------------------------------------------------------------
    // Clone / equality of the id newtypes
    // ---------------------------------------------------------------

    #[test]
    fn test_model_id_clone() {
        let original = ModelId("test".to_string());
        let duplicate = original.clone();
        assert_eq!(original, duplicate);
    }

    #[test]
    fn test_region_id_equality() {
        let first = RegionId("us-west".to_string());
        let same = RegionId("us-west".to_string());
        let other = RegionId("eu-west".to_string());
        assert_eq!(first, same);
        assert_ne!(first, other);
    }

    #[test]
    fn test_node_id_equality() {
        let first = NodeId("node-1".to_string());
        let same = NodeId("node-1".to_string());
        let other = NodeId("node-2".to_string());
        assert_eq!(first, same);
        assert_ne!(first, other);
    }

    // ---------------------------------------------------------------
    // Capability
    // ---------------------------------------------------------------

    #[test]
    fn test_all_capability_variants() {
        let caps = vec![
            Capability::Transcribe,
            Capability::Synthesize,
            Capability::Generate,
            Capability::Code,
            Capability::Embed,
            Capability::ImageGen,
            Capability::Custom("my_cap".to_string()),
        ];

        // Each variant equals itself and nothing else.
        for (i, lhs) in caps.iter().enumerate() {
            for (j, rhs) in caps.iter().enumerate() {
                assert_eq!(i == j, lhs == rhs);
            }
        }
    }

    #[test]
    fn test_capability_custom_equality() {
        let first = Capability::Custom("sentiment".to_string());
        let same = Capability::Custom("sentiment".to_string());
        let other = Capability::Custom("other".to_string());
        assert_eq!(first, same);
        assert_ne!(first, other);
    }

    #[test]
    fn test_capability_debug_format() {
        let builtin_debug = format!("{:?}", Capability::Transcribe);
        assert_eq!(builtin_debug, "Transcribe");

        let custom_debug = format!("{:?}", Capability::Custom("test".to_string()));
        assert!(custom_debug.contains("Custom"));
        assert!(custom_debug.contains("test"));
    }

    // ---------------------------------------------------------------
    // PrivacyLevel
    // ---------------------------------------------------------------

    #[test]
    fn test_privacy_level_all_orderings() {
        let levels = [
            PrivacyLevel::Public,
            PrivacyLevel::Internal,
            PrivacyLevel::Confidential,
            PrivacyLevel::Restricted,
        ];

        // Every adjacent pair must be strictly increasing.
        for pair in levels.windows(2) {
            assert!(pair[0] < pair[1]);
            assert!(pair[1] > pair[0]);
        }
    }

    #[test]
    fn test_privacy_level_copy() {
        let level = PrivacyLevel::Confidential;
        let copied = level;
        assert_eq!(level, copied);
    }

    #[test]
    fn test_privacy_level_eq() {
        assert_eq!(PrivacyLevel::Public, PrivacyLevel::Public);
        assert_ne!(PrivacyLevel::Public, PrivacyLevel::Internal);
    }

    // ---------------------------------------------------------------
    // QoSRequirements
    // ---------------------------------------------------------------

    #[test]
    fn test_qos_default_none_fields() {
        let defaults = QoSRequirements::default();
        assert!(defaults.max_latency.is_none());
        assert!(defaults.min_throughput.is_none());
    }

    #[test]
    fn test_qos_custom_values() {
        let custom = QoSRequirements {
            max_latency: Some(Duration::from_secs(2)),
            min_throughput: Some(100),
            privacy: PrivacyLevel::Restricted,
            prefer_gpu: false,
            cost_tolerance: 10,
        };
        assert_eq!(custom.max_latency, Some(Duration::from_secs(2)));
        assert_eq!(custom.min_throughput, Some(100));
        assert_eq!(custom.privacy, PrivacyLevel::Restricted);
        assert!(!custom.prefer_gpu);
        assert_eq!(custom.cost_tolerance, 10);
    }

    // ---------------------------------------------------------------
    // InferenceRequest / InferenceResponse
    // ---------------------------------------------------------------

    #[test]
    fn test_inference_request_construction() {
        let request = InferenceRequest {
            capability: Capability::Generate,
            input: b"hello world".to_vec(),
            qos: QoSRequirements::default(),
            request_id: "req-123".to_string(),
            tenant_id: Some("tenant-1".to_string()),
        };
        assert_eq!(request.request_id, "req-123");
        assert_eq!(request.tenant_id, Some("tenant-1".to_string()));
        assert_eq!(request.input, b"hello world");
    }

    #[test]
    fn test_inference_request_no_tenant() {
        let request = InferenceRequest {
            capability: Capability::Embed,
            input: vec![],
            qos: QoSRequirements::default(),
            request_id: "req-456".to_string(),
            tenant_id: None,
        };
        assert!(request.tenant_id.is_none());
    }

    #[test]
    fn test_inference_request_clone() {
        let request = InferenceRequest {
            capability: Capability::Code,
            input: b"fn main()".to_vec(),
            qos: QoSRequirements::default(),
            request_id: "req-789".to_string(),
            tenant_id: None,
        };
        let duplicate = request.clone();
        assert_eq!(duplicate.request_id, "req-789");
        assert_eq!(duplicate.input, b"fn main()");
    }

    #[test]
    fn test_inference_response_construction() {
        let response = InferenceResponse {
            output: b"generated text".to_vec(),
            served_by: NodeId("node-42".to_string()),
            latency: Duration::from_millis(150),
            tokens: Some(25),
        };
        assert_eq!(response.output, b"generated text");
        assert_eq!(response.served_by, NodeId("node-42".to_string()));
        assert_eq!(response.latency, Duration::from_millis(150));
        assert_eq!(response.tokens, Some(25));
    }

    #[test]
    fn test_inference_response_no_tokens() {
        let response = InferenceResponse {
            output: vec![],
            served_by: NodeId("node-1".to_string()),
            latency: Duration::from_millis(10),
            tokens: None,
        };
        assert!(response.tokens.is_none());
    }

    // ---------------------------------------------------------------
    // FederationError messages
    // ---------------------------------------------------------------

    #[test]
    fn test_federation_error_no_capacity() {
        let rendered = format!("{}", FederationError::NoCapacity(Capability::Transcribe));
        assert!(rendered.contains("No nodes available"));
        assert!(rendered.contains("Transcribe"));
    }

    #[test]
    fn test_federation_error_all_nodes_unhealthy() {
        let rendered = format!(
            "{}",
            FederationError::AllNodesUnhealthy(Capability::Generate)
        );
        assert!(rendered.contains("All nodes unhealthy"));
        assert!(rendered.contains("Generate"));
    }

    #[test]
    fn test_federation_error_qos_violation() {
        let rendered = format!(
            "{}",
            FederationError::QoSViolation("latency too high".to_string())
        );
        assert!(rendered.contains("QoS requirements cannot be met"));
        assert!(rendered.contains("latency too high"));
    }

    #[test]
    fn test_federation_error_privacy_violation() {
        let rendered = format!(
            "{}",
            FederationError::PrivacyViolation("data must stay in EU".to_string())
        );
        assert!(rendered.contains("Privacy policy violation"));
        assert!(rendered.contains("data must stay in EU"));
    }

    #[test]
    fn test_federation_error_node_unreachable() {
        let rendered = format!(
            "{}",
            FederationError::NodeUnreachable(NodeId("dead-node".to_string()))
        );
        assert!(rendered.contains("Node unreachable"));
        assert!(rendered.contains("dead-node"));
    }

    #[test]
    fn test_federation_error_timeout() {
        let rendered = format!("{}", FederationError::Timeout(Duration::from_secs(30)));
        assert!(rendered.contains("Timeout"));
        assert!(rendered.contains("30"));
    }

    #[test]
    fn test_federation_error_circuit_open() {
        let rendered = format!(
            "{}",
            FederationError::CircuitOpen(NodeId("overloaded".to_string()))
        );
        assert!(rendered.contains("Circuit breaker open"));
        assert!(rendered.contains("overloaded"));
    }

    #[test]
    fn test_federation_error_internal() {
        let rendered = format!(
            "{}",
            FederationError::Internal("unexpected state".to_string())
        );
        assert!(rendered.contains("Internal error"));
        assert!(rendered.contains("unexpected state"));
    }

    #[test]
    fn test_federation_error_debug() {
        let rendered = format!("{:?}", FederationError::NoCapacity(Capability::Embed));
        assert!(rendered.contains("NoCapacity"));
    }

    // ---------------------------------------------------------------
    // HealthState
    // ---------------------------------------------------------------

    #[test]
    fn test_health_state_all_variants() {
        let states = [
            HealthState::Healthy,
            HealthState::Degraded,
            HealthState::Unhealthy,
            HealthState::Unknown,
        ];
        // Each variant equals itself and nothing else.
        for (i, lhs) in states.iter().enumerate() {
            for (j, rhs) in states.iter().enumerate() {
                assert_eq!(i == j, lhs == rhs);
            }
        }
    }

    #[test]
    fn test_health_state_copy() {
        let state = HealthState::Healthy;
        let copied = state;
        assert_eq!(state, copied);
    }

    // ---------------------------------------------------------------
    // CircuitState
    // ---------------------------------------------------------------

    #[test]
    fn test_circuit_state_all_variants() {
        let states = [
            CircuitState::Closed,
            CircuitState::HalfOpen,
            CircuitState::Open,
        ];
        // Each variant equals itself and nothing else.
        for (i, lhs) in states.iter().enumerate() {
            for (j, rhs) in states.iter().enumerate() {
                assert_eq!(i == j, lhs == rhs);
            }
        }
    }

    #[test]
    fn test_circuit_state_copy() {
        let state = CircuitState::HalfOpen;
        let copied = state;
        assert_eq!(state, copied);
    }

    // ---------------------------------------------------------------
    // LoadBalanceStrategy
    // ---------------------------------------------------------------

    #[test]
    fn test_load_balance_default() {
        assert!(matches!(
            LoadBalanceStrategy::default(),
            LoadBalanceStrategy::LeastLatency
        ));
    }

    #[test]
    fn test_load_balance_all_variants() {
        let strategies = [
            LoadBalanceStrategy::RoundRobin,
            LoadBalanceStrategy::LeastConnections,
            LoadBalanceStrategy::LeastLatency,
            LoadBalanceStrategy::WeightedRandom,
            LoadBalanceStrategy::ConsistentHash,
        ];
        // Every variant must produce a non-empty Debug rendering.
        assert!(strategies
            .iter()
            .all(|strategy| !format!("{:?}", strategy).is_empty()));
    }

    #[test]
    fn test_load_balance_clone() {
        let strategy = LoadBalanceStrategy::WeightedRandom;
        let cloned = strategy;
        assert!(matches!(cloned, LoadBalanceStrategy::WeightedRandom));
    }

    // ---------------------------------------------------------------
    // Routing data types
    // ---------------------------------------------------------------

    #[test]
    fn test_route_target_construction() {
        let target = RouteTarget {
            node_id: NodeId("n1".to_string()),
            region_id: RegionId("r1".to_string()),
            endpoint: "http://n1:8080".to_string(),
            estimated_latency: Duration::from_millis(50),
            score: 0.95,
        };
        assert_eq!(target.node_id, NodeId("n1".to_string()));
        assert_eq!(target.endpoint, "http://n1:8080");
        assert_eq!(target.estimated_latency, Duration::from_millis(50));
    }

    #[test]
    fn test_route_target_clone() {
        let target = RouteTarget {
            node_id: NodeId("n1".to_string()),
            region_id: RegionId("r1".to_string()),
            endpoint: "http://n1:8080".to_string(),
            estimated_latency: Duration::from_millis(50),
            score: 0.5,
        };
        let duplicate = target.clone();
        assert_eq!(duplicate.node_id, NodeId("n1".to_string()));
        assert_eq!(duplicate.score, 0.5);
    }

    #[test]
    fn test_route_scores_construction() {
        let scores = RouteScores {
            latency_score: 0.9,
            throughput_score: 0.8,
            cost_score: 0.7,
            locality_score: 0.6,
            health_score: 1.0,
            total: 0.85,
        };
        assert_eq!(scores.latency_score, 0.9);
        assert_eq!(scores.total, 0.85);
    }

    #[test]
    fn test_route_candidate_eligible() {
        let target = RouteTarget {
            node_id: NodeId("n1".to_string()),
            region_id: RegionId("r1".to_string()),
            endpoint: String::new(),
            estimated_latency: Duration::from_millis(100),
            score: 0.8,
        };
        let scores = RouteScores {
            latency_score: 0.9,
            throughput_score: 0.8,
            cost_score: 0.5,
            locality_score: 0.7,
            health_score: 1.0,
            total: 0.8,
        };
        let candidate = RouteCandidate {
            target,
            scores,
            eligible: true,
            rejection_reason: None,
        };
        assert!(candidate.eligible);
        assert!(candidate.rejection_reason.is_none());
    }

    #[test]
    fn test_route_candidate_rejected() {
        let target = RouteTarget {
            node_id: NodeId("n1".to_string()),
            region_id: RegionId("r1".to_string()),
            endpoint: String::new(),
            estimated_latency: Duration::from_millis(100),
            score: 0.0,
        };
        let scores = RouteScores {
            latency_score: 0.0,
            throughput_score: 0.0,
            cost_score: 0.0,
            locality_score: 0.0,
            health_score: 0.0,
            total: 0.0,
        };
        let candidate = RouteCandidate {
            target,
            scores,
            eligible: false,
            rejection_reason: Some("Policy rejected".to_string()),
        };
        assert!(!candidate.eligible);
        assert_eq!(
            candidate.rejection_reason,
            Some("Policy rejected".to_string())
        );
    }

    // ---------------------------------------------------------------
    // ModelMetadata
    // ---------------------------------------------------------------

    #[test]
    fn test_model_metadata_construction() {
        let metadata = ModelMetadata {
            model_id: ModelId("llama-7b".to_string()),
            name: "LLaMA 7B".to_string(),
            version: "2.0".to_string(),
            capabilities: vec![Capability::Generate, Capability::Code],
            parameters: 7_000_000_000,
            quantization: Some("Q4_K".to_string()),
        };
        assert_eq!(metadata.name, "LLaMA 7B");
        assert_eq!(metadata.parameters, 7_000_000_000);
        assert_eq!(metadata.quantization, Some("Q4_K".to_string()));
        assert_eq!(metadata.capabilities.len(), 2);
    }

    #[test]
    fn test_model_metadata_no_quantization() {
        let metadata = ModelMetadata {
            model_id: ModelId("whisper".to_string()),
            name: "Whisper".to_string(),
            version: "1.0".to_string(),
            capabilities: vec![Capability::Transcribe],
            parameters: 1_500_000_000,
            quantization: None,
        };
        assert!(metadata.quantization.is_none());
    }

    #[test]
    fn test_model_metadata_clone() {
        let metadata = ModelMetadata {
            model_id: ModelId("test".to_string()),
            name: "Test".to_string(),
            version: "1.0".to_string(),
            capabilities: vec![Capability::Embed],
            parameters: 100,
            quantization: None,
        };
        let duplicate = metadata.clone();
        assert_eq!(duplicate.model_id, ModelId("test".to_string()));
    }

    // ---------------------------------------------------------------
    // NodeHealth
    // ---------------------------------------------------------------

    #[test]
    fn test_node_health_construction() {
        let health = NodeHealth {
            node_id: NodeId("test-node".to_string()),
            status: HealthState::Healthy,
            latency_p50: Duration::from_millis(25),
            latency_p99: Duration::from_millis(100),
            throughput: 500,
            gpu_utilization: Some(0.75),
            queue_depth: 3,
            last_check: std::time::Instant::now(),
        };
        assert_eq!(health.status, HealthState::Healthy);
        assert_eq!(health.throughput, 500);
        assert_eq!(health.gpu_utilization, Some(0.75));
        assert_eq!(health.queue_depth, 3);
    }

    #[test]
    fn test_node_health_no_gpu() {
        let health = NodeHealth {
            node_id: NodeId("cpu-node".to_string()),
            status: HealthState::Healthy,
            latency_p50: Duration::from_millis(50),
            latency_p99: Duration::from_millis(200),
            throughput: 100,
            gpu_utilization: None,
            queue_depth: 0,
            last_check: std::time::Instant::now(),
        };
        assert!(health.gpu_utilization.is_none());
    }

    // ---------------------------------------------------------------
    // GatewayStats
    // ---------------------------------------------------------------

    #[test]
    fn test_gateway_stats_default() {
        let zeroed = GatewayStats::default();
        assert_eq!(zeroed.total_requests, 0);
        assert_eq!(zeroed.successful_requests, 0);
        assert_eq!(zeroed.failed_requests, 0);
        assert_eq!(zeroed.total_tokens, 0);
        assert_eq!(zeroed.avg_latency, Duration::ZERO);
        assert_eq!(zeroed.active_streams, 0);
    }

    #[test]
    fn test_gateway_stats_clone() {
        let stats = GatewayStats {
            total_requests: 100,
            successful_requests: 95,
            failed_requests: 5,
            total_tokens: 5000,
            avg_latency: Duration::from_millis(50),
            active_streams: 2,
        };
        let duplicate = stats.clone();
        assert_eq!(duplicate.total_requests, 100);
        assert_eq!(duplicate.active_streams, 2);
    }

    // ---------------------------------------------------------------
    // FederationBuilder
    // ---------------------------------------------------------------

    #[test]
    fn test_federation_builder_default() {
        let builder = FederationBuilder::default();
        assert!(builder.catalog.is_none());
        assert!(builder.health_checker.is_none());
        assert!(builder.router.is_none());
        assert!(builder.policies.is_empty());
        assert!(builder.middlewares.is_empty());
    }

    #[test]
    fn test_federation_builder_new_defaults() {
        let strategy = FederationBuilder::new().load_balance;
        assert!(matches!(strategy, LoadBalanceStrategy::LeastLatency));
    }

    #[test]
    fn test_federation_builder_with_load_balance_all_strategies() {
        for strategy in [
            LoadBalanceStrategy::RoundRobin,
            LoadBalanceStrategy::LeastConnections,
            LoadBalanceStrategy::LeastLatency,
            LoadBalanceStrategy::WeightedRandom,
            LoadBalanceStrategy::ConsistentHash,
        ] {
            let builder = FederationBuilder::new().with_load_balance(strategy);
            let rendered = format!("{:?}", builder.load_balance);
            assert!(!rendered.is_empty());
        }
    }

    #[test]
    fn test_federation_builder_with_policy() {
        // Minimal policy: constant score, always eligible.
        struct MockPolicy;
        impl RoutingPolicyTrait for MockPolicy {
            fn score(&self, _: &RouteCandidate, _: &InferenceRequest) -> f64 {
                1.0
            }
            fn is_eligible(&self, _: &RouteCandidate, _: &InferenceRequest) -> bool {
                true
            }
            fn name(&self) -> &'static str {
                "mock"
            }
        }

        let builder = FederationBuilder::new()
            .with_policy(MockPolicy)
            .with_policy(MockPolicy);
        assert_eq!(builder.policies.len(), 2);
    }

    #[test]
    fn test_federation_builder_with_middleware() {
        // Minimal middleware: every hook is a no-op.
        struct MockMiddleware;
        impl GatewayMiddleware for MockMiddleware {
            fn before_route(&self, _: &mut InferenceRequest) -> FederationResult<()> {
                Ok(())
            }
            fn after_infer(
                &self,
                _: &InferenceRequest,
                _: &mut InferenceResponse,
            ) -> FederationResult<()> {
                Ok(())
            }
            fn on_error(&self, _: &InferenceRequest, _: &FederationError) {}
        }

        let builder = FederationBuilder::new().with_middleware(MockMiddleware);
        assert_eq!(builder.middlewares.len(), 1);
    }
}