Skip to main content

apr_cli/federation/
traits.rs

1//! Core trait definitions for APR Federation
2//!
3//! These traits define the contract for federation components.
4//! Implementations can be swapped for different backends (NATS, Redis, etcd, etc.)
5
6use std::future::Future;
7use std::pin::Pin;
8use std::time::Duration;
9
10// ============================================================================
11// Core Types
12// ============================================================================
13
/// Unique identifier for a model instance in the federation.
///
/// A transparent newtype over `String`: equality, hashing and cloning are
/// those of the wrapped string.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct ModelId(pub String);

impl std::fmt::Display for ModelId {
    /// Writes the raw identifier text.
    ///
    /// Delegates to the inner `String` instead of `write!(f, "{}", ..)`: the
    /// macro starts a fresh format spec and would silently discard any width,
    /// fill, alignment or precision the caller requested (e.g. `{:>12}`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
23
/// Unique identifier for a region/cluster.
///
/// A transparent newtype over `String`: equality, hashing and cloning are
/// those of the wrapped string.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct RegionId(pub String);

impl std::fmt::Display for RegionId {
    /// Writes the raw identifier text.
    ///
    /// Delegates to the inner `String` instead of `write!(f, "{}", ..)`: the
    /// macro starts a fresh format spec and would silently discard any width,
    /// fill, alignment or precision the caller requested (e.g. `{:<16}`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
33
/// Unique identifier for a node within a region.
///
/// A transparent newtype over `String`: equality, hashing and cloning are
/// those of the wrapped string.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct NodeId(pub String);

impl std::fmt::Display for NodeId {
    /// Writes the raw identifier text.
    ///
    /// Delegates to the inner `String` instead of `write!(f, "{}", ..)`: the
    /// macro starts a fresh format spec and would silently discard any width,
    /// fill, alignment or precision the caller requested (e.g. `{:^20}`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
43
/// Model capabilities that can be queried.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` (every payload is a
/// plain `String`), so a capability can be used directly as a key in
/// `HashMap`/`HashSet` based indexes.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Capability {
    /// Automatic speech recognition
    Transcribe,
    /// Text-to-speech
    Synthesize,
    /// Text generation (LLM)
    Generate,
    /// Code generation
    Code,
    /// Embeddings
    Embed,
    /// Image generation
    ImageGen,
    /// Custom capability, identified by a free-form name
    Custom(String),
}
62
/// Privacy/compliance tier used when deciding where data may be routed.
///
/// Variants are declared from least to most restrictive, so the derived
/// ordering satisfies `Public < Internal < Confidential < Restricted`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum PrivacyLevel {
    /// Public data; may be routed anywhere.
    Public = 0,
    /// Internal data; stays inside the organisation.
    Internal = 1,
    /// Confidential data; only specific regions are allowed.
    Confidential = 2,
    /// Restricted data; on-prem only.
    Restricted = 3,
}
75
76/// Quality of Service requirements
77#[derive(Debug, Clone)]
78pub struct QoSRequirements {
79    /// Maximum acceptable latency
80    pub max_latency: Option<Duration>,
81    /// Minimum throughput (tokens/sec)
82    pub min_throughput: Option<u32>,
83    /// Privacy level required
84    pub privacy: PrivacyLevel,
85    /// Prefer GPU acceleration
86    pub prefer_gpu: bool,
87    /// Cost tier (0 = cheapest, 100 = fastest)
88    pub cost_tolerance: u8,
89}
90
91impl Default for QoSRequirements {
92    fn default() -> Self {
93        Self {
94            max_latency: None,
95            min_throughput: None,
96            privacy: PrivacyLevel::Internal,
97            prefer_gpu: true,
98            cost_tolerance: 50,
99        }
100    }
101}
102
103/// Inference request metadata
104#[derive(Debug, Clone)]
105pub struct InferenceRequest {
106    /// Requested capability
107    pub capability: Capability,
108    /// Input data (opaque bytes)
109    pub input: Vec<u8>,
110    /// QoS requirements
111    pub qos: QoSRequirements,
112    /// Request ID for tracing
113    pub request_id: String,
114    /// User/tenant ID
115    pub tenant_id: Option<String>,
116}
117
118/// Inference response
119#[derive(Debug)]
120pub struct InferenceResponse {
121    /// Output data
122    pub output: Vec<u8>,
123    /// Which node handled the request
124    pub served_by: NodeId,
125    /// Actual latency
126    pub latency: Duration,
127    /// Tokens generated (if applicable)
128    pub tokens: Option<u32>,
129}
130
131/// Error types for federation operations
132#[derive(Debug, thiserror::Error)]
133pub enum FederationError {
134    #[error("No nodes available for capability: {0:?}")]
135    NoCapacity(Capability),
136
137    #[error("All nodes unhealthy for capability: {0:?}")]
138    AllNodesUnhealthy(Capability),
139
140    #[error("QoS requirements cannot be met: {0}")]
141    QoSViolation(String),
142
143    #[error("Privacy policy violation: {0}")]
144    PrivacyViolation(String),
145
146    #[error("Node unreachable: {0}")]
147    NodeUnreachable(NodeId),
148
149    #[error("Timeout after {0:?}")]
150    Timeout(Duration),
151
152    #[error("Circuit breaker open for node: {0}")]
153    CircuitOpen(NodeId),
154
155    #[error("Internal error: {0}")]
156    Internal(String),
157}
158
159// ============================================================================
160// Core Traits
161// ============================================================================
162
163/// Result type alias for federation operations
164pub type FederationResult<T> = Result<T, FederationError>;
165
/// Heap-allocated, pinned, `Send` future yielding `T`.
///
/// Used as the return type of the async-style trait methods in this module;
/// `'a` bounds how long the future may borrow from its creator.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
168
169/// Model catalog - tracks available models across the federation
170pub trait ModelCatalogTrait: Send + Sync {
171    /// Register a model instance
172    fn register(
173        &self,
174        model_id: ModelId,
175        node_id: NodeId,
176        region_id: RegionId,
177        capabilities: Vec<Capability>,
178    ) -> BoxFuture<'_, FederationResult<()>>;
179
180    /// Deregister a model instance
181    fn deregister(&self, model_id: ModelId, node_id: NodeId)
182        -> BoxFuture<'_, FederationResult<()>>;
183
184    /// Find nodes with a specific capability
185    fn find_by_capability(
186        &self,
187        capability: &Capability,
188    ) -> BoxFuture<'_, FederationResult<Vec<(NodeId, RegionId)>>>;
189
190    /// List all registered models
191    fn list_all(&self) -> BoxFuture<'_, FederationResult<Vec<ModelId>>>;
192
193    /// Get model metadata
194    fn get_metadata(&self, model_id: &ModelId) -> BoxFuture<'_, FederationResult<ModelMetadata>>;
195}
196
197/// Model metadata stored in catalog
198#[derive(Debug, Clone)]
199pub struct ModelMetadata {
200    pub model_id: ModelId,
201    pub name: String,
202    pub version: String,
203    pub capabilities: Vec<Capability>,
204    pub parameters: u64,
205    pub quantization: Option<String>,
206}
207
208/// Health checker - monitors node health across federation
209pub trait HealthCheckerTrait: Send + Sync {
210    /// Check health of a specific node
211    fn check_node(&self, node_id: &NodeId) -> BoxFuture<'_, FederationResult<NodeHealth>>;
212
213    /// Get cached health status (non-blocking)
214    fn get_cached_health(&self, node_id: &NodeId) -> Option<NodeHealth>;
215
216    /// Start background health monitoring
217    fn start_monitoring(&self, interval: Duration) -> BoxFuture<'_, ()>;
218
219    /// Stop health monitoring
220    fn stop_monitoring(&self) -> BoxFuture<'_, ()>;
221}
222
223/// Node health information
224#[derive(Debug, Clone)]
225pub struct NodeHealth {
226    pub node_id: NodeId,
227    pub status: HealthState,
228    pub latency_p50: Duration,
229    pub latency_p99: Duration,
230    pub throughput: u32,
231    pub gpu_utilization: Option<f32>,
232    pub queue_depth: u32,
233    pub last_check: std::time::Instant,
234}
235
/// Coarse health classification of a node.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HealthState {
    /// Healthy and accepting requests.
    Healthy,
    /// Degraded, but still functional.
    Degraded,
    /// Unhealthy; routing to it should be avoided.
    Unhealthy,
    /// Health has not been determined.
    Unknown,
}
248
249/// Router - selects the best node for a request
250pub trait RouterTrait: Send + Sync {
251    /// Route a request to the best available node
252    fn route(&self, request: &InferenceRequest) -> BoxFuture<'_, FederationResult<RouteTarget>>;
253
254    /// Get all possible routes for a request (for debugging/transparency)
255    fn get_candidates(
256        &self,
257        request: &InferenceRequest,
258    ) -> BoxFuture<'_, FederationResult<Vec<RouteCandidate>>>;
259}
260
261/// Selected route target
262#[derive(Debug, Clone)]
263pub struct RouteTarget {
264    pub node_id: NodeId,
265    pub region_id: RegionId,
266    pub endpoint: String,
267    pub estimated_latency: Duration,
268    pub score: f64,
269}
270
271/// Route candidate with scoring details
272#[derive(Debug, Clone)]
273pub struct RouteCandidate {
274    pub target: RouteTarget,
275    pub scores: RouteScores,
276    pub eligible: bool,
277    pub rejection_reason: Option<String>,
278}
279
/// Per-dimension breakdown of a routing score.
// NOTE(review): how `total` is derived from the components is decided by the
// router implementation, not by this type.
#[derive(Clone, Debug)]
pub struct RouteScores {
    /// Latency component.
    pub latency_score: f64,
    /// Throughput component.
    pub throughput_score: f64,
    /// Cost component.
    pub cost_score: f64,
    /// Locality component.
    pub locality_score: f64,
    /// Health component.
    pub health_score: f64,
    /// Overall score.
    pub total: f64,
}
290
291/// Gateway - the main entry point for federation requests
292pub trait GatewayTrait: Send + Sync {
293    /// Execute an inference request through the federation
294    fn infer(
295        &self,
296        request: InferenceRequest,
297    ) -> BoxFuture<'_, FederationResult<InferenceResponse>>;
298
299    /// Execute with streaming response
300    fn infer_stream(
301        &self,
302        request: InferenceRequest,
303    ) -> BoxFuture<'_, FederationResult<Box<dyn TokenStream>>>;
304
305    /// Get gateway statistics
306    fn stats(&self) -> GatewayStats;
307}
308
309/// Streaming token interface
310pub trait TokenStream: Send {
311    /// Get next token (None = stream complete)
312    fn next_token(&mut self) -> BoxFuture<'_, Option<FederationResult<Vec<u8>>>>;
313
314    /// Cancel the stream
315    fn cancel(&mut self) -> BoxFuture<'_, ()>;
316}
317
/// Counters describing gateway activity; `Default` yields all-zero values.
#[derive(Clone, Debug, Default)]
pub struct GatewayStats {
    /// Requests seen in total.
    pub total_requests: u64,
    /// Requests that succeeded.
    pub successful_requests: u64,
    /// Requests that failed.
    pub failed_requests: u64,
    /// Tokens counted in total.
    pub total_tokens: u64,
    /// Average latency.
    pub avg_latency: Duration,
    /// Streams currently active.
    pub active_streams: u32,
}
328
329// ============================================================================
330// Middleware Traits (Tower-style composability)
331// ============================================================================
332
333/// Middleware that can wrap a gateway
334pub trait GatewayMiddleware: Send + Sync {
335    /// Process request before routing
336    fn before_route(&self, request: &mut InferenceRequest) -> FederationResult<()>;
337
338    /// Process response after inference
339    fn after_infer(
340        &self,
341        request: &InferenceRequest,
342        response: &mut InferenceResponse,
343    ) -> FederationResult<()>;
344
345    /// Handle errors
346    fn on_error(&self, request: &InferenceRequest, error: &FederationError);
347}
348
349/// Circuit breaker for fault tolerance
350pub trait CircuitBreakerTrait: Send + Sync {
351    /// Check if circuit is open (should skip this node)
352    fn is_open(&self, node_id: &NodeId) -> bool;
353
354    /// Record a success
355    fn record_success(&self, node_id: &NodeId);
356
357    /// Record a failure
358    fn record_failure(&self, node_id: &NodeId);
359
360    /// Get circuit state
361    fn state(&self, node_id: &NodeId) -> CircuitState;
362}
363
/// State of a node's circuit breaker.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation.
    Closed,
    /// Failing, but probe requests are allowed through.
    HalfOpen,
    /// Failing; all requests are blocked.
    Open,
}
374
375// ============================================================================
376// Policy Traits
377// ============================================================================
378
379/// Routing policy that influences node selection
380pub trait RoutingPolicyTrait: Send + Sync {
381    /// Score a candidate node (higher = better)
382    fn score(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> f64;
383
384    /// Check if a candidate is eligible
385    fn is_eligible(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> bool;
386
387    /// Get policy name for logging
388    fn name(&self) -> &'static str;
389}
390
/// Load balancing strategy.
///
/// Defaults to [`LoadBalanceStrategy::LeastLatency`]. Derives `PartialEq`/`Eq`
/// like the sibling `HealthState` and `CircuitState` enums, so strategies can
/// be compared with `==` / `assert_eq!` rather than only via `matches!`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum LoadBalanceStrategy {
    /// Round-robin across healthy nodes
    RoundRobin,
    /// Route to least loaded node
    LeastConnections,
    /// Route based on latency
    #[default]
    LeastLatency,
    /// Weighted random
    WeightedRandom,
    /// Consistent hashing (sticky sessions)
    ConsistentHash,
}
406
407// ============================================================================
408// Builder Pattern for Configuration
409// ============================================================================
410
411/// Builder for creating federation gateways
412#[derive(Default)]
413pub struct FederationBuilder {
414    pub catalog: Option<Box<dyn ModelCatalogTrait>>,
415    pub health_checker: Option<Box<dyn HealthCheckerTrait>>,
416    pub router: Option<Box<dyn RouterTrait>>,
417    pub policies: Vec<Box<dyn RoutingPolicyTrait>>,
418    pub middlewares: Vec<Box<dyn GatewayMiddleware>>,
419    pub load_balance: LoadBalanceStrategy,
420}
421
422impl FederationBuilder {
423    pub fn new() -> Self {
424        Self {
425            load_balance: LoadBalanceStrategy::LeastLatency,
426            ..Default::default()
427        }
428    }
429
430    #[must_use]
431    pub fn with_catalog(mut self, catalog: impl ModelCatalogTrait + 'static) -> Self {
432        self.catalog = Some(Box::new(catalog));
433        self
434    }
435
436    #[must_use]
437    pub fn with_health_checker(mut self, checker: impl HealthCheckerTrait + 'static) -> Self {
438        self.health_checker = Some(Box::new(checker));
439        self
440    }
441
442    #[must_use]
443    pub fn with_router(mut self, router: impl RouterTrait + 'static) -> Self {
444        self.router = Some(Box::new(router));
445        self
446    }
447
448    #[must_use]
449    pub fn with_policy(mut self, policy: impl RoutingPolicyTrait + 'static) -> Self {
450        self.policies.push(Box::new(policy));
451        self
452    }
453
454    #[must_use]
455    pub fn with_middleware(mut self, middleware: impl GatewayMiddleware + 'static) -> Self {
456        self.middlewares.push(Box::new(middleware));
457        self
458    }
459
460    #[must_use]
461    pub fn with_load_balance(mut self, strategy: LoadBalanceStrategy) -> Self {
462        self.load_balance = strategy;
463        self
464    }
465}
466
467// ============================================================================
468// Tests
469// ============================================================================
470
471#[cfg(test)]
472mod tests {
473    use super::*;
474
475    #[test]
476    fn test_qos_default() {
477        let qos = QoSRequirements::default();
478        assert_eq!(qos.privacy, PrivacyLevel::Internal);
479        assert!(qos.prefer_gpu);
480        assert_eq!(qos.cost_tolerance, 50);
481    }
482
483    #[test]
484    fn test_privacy_ordering() {
485        assert!(PrivacyLevel::Public < PrivacyLevel::Internal);
486        assert!(PrivacyLevel::Internal < PrivacyLevel::Confidential);
487        assert!(PrivacyLevel::Confidential < PrivacyLevel::Restricted);
488    }
489
490    #[test]
491    fn test_health_state() {
492        let healthy = HealthState::Healthy;
493        let degraded = HealthState::Degraded;
494        assert_ne!(healthy, degraded);
495    }
496
497    #[test]
498    fn test_circuit_state() {
499        let closed = CircuitState::Closed;
500        let open = CircuitState::Open;
501        assert_ne!(closed, open);
502    }
503
504    #[test]
505    fn test_federation_builder() {
506        let builder =
507            FederationBuilder::new().with_load_balance(LoadBalanceStrategy::LeastConnections);
508
509        assert!(matches!(
510            builder.load_balance,
511            LoadBalanceStrategy::LeastConnections
512        ));
513    }
514
515    #[test]
516    fn test_model_id_equality() {
517        let id1 = ModelId("whisper-v3".to_string());
518        let id2 = ModelId("whisper-v3".to_string());
519        let id3 = ModelId("llama-7b".to_string());
520
521        assert_eq!(id1, id2);
522        assert_ne!(id1, id3);
523    }
524
525    #[test]
526    fn test_capability_variants() {
527        let cap1 = Capability::Transcribe;
528        let cap2 = Capability::Custom("sentiment".to_string());
529
530        assert_ne!(cap1, cap2);
531        assert_eq!(cap1, Capability::Transcribe);
532    }
533
534    // =========================================================================
535    // Display trait tests
536    // =========================================================================
537
538    #[test]
539    fn test_model_id_display() {
540        let id = ModelId("whisper-v3".to_string());
541        assert_eq!(format!("{}", id), "whisper-v3");
542        assert_eq!(id.to_string(), "whisper-v3");
543    }
544
545    #[test]
546    fn test_model_id_display_empty() {
547        let id = ModelId(String::new());
548        assert_eq!(format!("{}", id), "");
549    }
550
551    #[test]
552    fn test_region_id_display() {
553        let id = RegionId("us-west-2".to_string());
554        assert_eq!(format!("{}", id), "us-west-2");
555        assert_eq!(id.to_string(), "us-west-2");
556    }
557
558    #[test]
559    fn test_region_id_display_empty() {
560        let id = RegionId(String::new());
561        assert_eq!(format!("{}", id), "");
562    }
563
564    #[test]
565    fn test_node_id_display() {
566        let id = NodeId("gpu-node-01".to_string());
567        assert_eq!(format!("{}", id), "gpu-node-01");
568        assert_eq!(id.to_string(), "gpu-node-01");
569    }
570
571    #[test]
572    fn test_node_id_display_empty() {
573        let id = NodeId(String::new());
574        assert_eq!(format!("{}", id), "");
575    }
576
577    // =========================================================================
578    // Hash trait tests
579    // =========================================================================
580
581    #[test]
582    fn test_model_id_hash_consistency() {
583        use std::collections::HashSet;
584        let mut set = HashSet::new();
585        set.insert(ModelId("a".to_string()));
586        set.insert(ModelId("b".to_string()));
587        set.insert(ModelId("a".to_string())); // duplicate
588        assert_eq!(set.len(), 2);
589    }
590
591    #[test]
592    fn test_region_id_hash_consistency() {
593        use std::collections::HashSet;
594        let mut set = HashSet::new();
595        set.insert(RegionId("us-west".to_string()));
596        set.insert(RegionId("eu-west".to_string()));
597        set.insert(RegionId("us-west".to_string())); // duplicate
598        assert_eq!(set.len(), 2);
599    }
600
601    #[test]
602    fn test_node_id_hash_consistency() {
603        use std::collections::HashSet;
604        let mut set = HashSet::new();
605        set.insert(NodeId("node1".to_string()));
606        set.insert(NodeId("node2".to_string()));
607        set.insert(NodeId("node1".to_string())); // duplicate
608        assert_eq!(set.len(), 2);
609    }
610
611    // =========================================================================
612    // Clone/Eq trait tests
613    // =========================================================================
614
615    #[test]
616    fn test_model_id_clone() {
617        let id = ModelId("test".to_string());
618        let cloned = id.clone();
619        assert_eq!(id, cloned);
620    }
621
622    #[test]
623    fn test_region_id_equality() {
624        let a = RegionId("us-west".to_string());
625        let b = RegionId("us-west".to_string());
626        let c = RegionId("eu-west".to_string());
627        assert_eq!(a, b);
628        assert_ne!(a, c);
629    }
630
631    #[test]
632    fn test_node_id_equality() {
633        let a = NodeId("node-1".to_string());
634        let b = NodeId("node-1".to_string());
635        let c = NodeId("node-2".to_string());
636        assert_eq!(a, b);
637        assert_ne!(a, c);
638    }
639
640    // =========================================================================
641    // Capability exhaustive tests
642    // =========================================================================
643
644    #[test]
645    fn test_all_capability_variants() {
646        let caps = vec![
647            Capability::Transcribe,
648            Capability::Synthesize,
649            Capability::Generate,
650            Capability::Code,
651            Capability::Embed,
652            Capability::ImageGen,
653            Capability::Custom("my_cap".to_string()),
654        ];
655
656        // All should be distinct from each other
657        for (i, a) in caps.iter().enumerate() {
658            for (j, b) in caps.iter().enumerate() {
659                if i == j {
660                    assert_eq!(a, b);
661                } else {
662                    assert_ne!(a, b);
663                }
664            }
665        }
666    }
667
668    #[test]
669    fn test_capability_custom_equality() {
670        let a = Capability::Custom("sentiment".to_string());
671        let b = Capability::Custom("sentiment".to_string());
672        let c = Capability::Custom("other".to_string());
673        assert_eq!(a, b);
674        assert_ne!(a, c);
675    }
676
677    #[test]
678    fn test_capability_debug_format() {
679        let cap = Capability::Transcribe;
680        let debug = format!("{:?}", cap);
681        assert_eq!(debug, "Transcribe");
682
683        let custom = Capability::Custom("test".to_string());
684        let debug = format!("{:?}", custom);
685        assert!(debug.contains("Custom"));
686        assert!(debug.contains("test"));
687    }
688
689    // =========================================================================
690    // PrivacyLevel tests
691    // =========================================================================
692
693    #[test]
694    fn test_privacy_level_all_orderings() {
695        let levels = [
696            PrivacyLevel::Public,
697            PrivacyLevel::Internal,
698            PrivacyLevel::Confidential,
699            PrivacyLevel::Restricted,
700        ];
701
702        // Verify strictly increasing
703        for i in 0..levels.len() - 1 {
704            assert!(levels[i] < levels[i + 1]);
705            assert!(levels[i + 1] > levels[i]);
706        }
707    }
708
709    #[test]
710    fn test_privacy_level_copy() {
711        let level = PrivacyLevel::Confidential;
712        let copied = level;
713        assert_eq!(level, copied);
714    }
715
716    #[test]
717    fn test_privacy_level_eq() {
718        assert_eq!(PrivacyLevel::Public, PrivacyLevel::Public);
719        assert_ne!(PrivacyLevel::Public, PrivacyLevel::Internal);
720    }
721
722    // =========================================================================
723    // QoSRequirements tests
724    // =========================================================================
725
726    #[test]
727    fn test_qos_default_none_fields() {
728        let qos = QoSRequirements::default();
729        assert!(qos.max_latency.is_none());
730        assert!(qos.min_throughput.is_none());
731    }
732
733    #[test]
734    fn test_qos_custom_values() {
735        let qos = QoSRequirements {
736            max_latency: Some(Duration::from_secs(2)),
737            min_throughput: Some(100),
738            privacy: PrivacyLevel::Restricted,
739            prefer_gpu: false,
740            cost_tolerance: 10,
741        };
742        assert_eq!(qos.max_latency, Some(Duration::from_secs(2)));
743        assert_eq!(qos.min_throughput, Some(100));
744        assert_eq!(qos.privacy, PrivacyLevel::Restricted);
745        assert!(!qos.prefer_gpu);
746        assert_eq!(qos.cost_tolerance, 10);
747    }
748
749    // =========================================================================
750    // InferenceRequest/Response tests
751    // =========================================================================
752
753    #[test]
754    fn test_inference_request_construction() {
755        let req = InferenceRequest {
756            capability: Capability::Generate,
757            input: b"hello world".to_vec(),
758            qos: QoSRequirements::default(),
759            request_id: "req-123".to_string(),
760            tenant_id: Some("tenant-1".to_string()),
761        };
762        assert_eq!(req.request_id, "req-123");
763        assert_eq!(req.tenant_id, Some("tenant-1".to_string()));
764        assert_eq!(req.input, b"hello world");
765    }
766
767    #[test]
768    fn test_inference_request_no_tenant() {
769        let req = InferenceRequest {
770            capability: Capability::Embed,
771            input: vec![],
772            qos: QoSRequirements::default(),
773            request_id: "req-456".to_string(),
774            tenant_id: None,
775        };
776        assert!(req.tenant_id.is_none());
777    }
778
779    #[test]
780    fn test_inference_request_clone() {
781        let req = InferenceRequest {
782            capability: Capability::Code,
783            input: b"fn main()".to_vec(),
784            qos: QoSRequirements::default(),
785            request_id: "req-789".to_string(),
786            tenant_id: None,
787        };
788        let cloned = req.clone();
789        assert_eq!(cloned.request_id, "req-789");
790        assert_eq!(cloned.input, b"fn main()");
791    }
792
793    #[test]
794    fn test_inference_response_construction() {
795        let resp = InferenceResponse {
796            output: b"generated text".to_vec(),
797            served_by: NodeId("node-42".to_string()),
798            latency: Duration::from_millis(150),
799            tokens: Some(25),
800        };
801        assert_eq!(resp.output, b"generated text");
802        assert_eq!(resp.served_by, NodeId("node-42".to_string()));
803        assert_eq!(resp.latency, Duration::from_millis(150));
804        assert_eq!(resp.tokens, Some(25));
805    }
806
807    #[test]
808    fn test_inference_response_no_tokens() {
809        let resp = InferenceResponse {
810            output: vec![],
811            served_by: NodeId("node-1".to_string()),
812            latency: Duration::from_millis(10),
813            tokens: None,
814        };
815        assert!(resp.tokens.is_none());
816    }
817
818    // =========================================================================
819    // FederationError tests
820    // =========================================================================
821
822    #[test]
823    fn test_federation_error_no_capacity() {
824        let err = FederationError::NoCapacity(Capability::Transcribe);
825        let msg = format!("{}", err);
826        assert!(msg.contains("No nodes available"));
827        assert!(msg.contains("Transcribe"));
828    }
829
830    #[test]
831    fn test_federation_error_all_nodes_unhealthy() {
832        let err = FederationError::AllNodesUnhealthy(Capability::Generate);
833        let msg = format!("{}", err);
834        assert!(msg.contains("All nodes unhealthy"));
835        assert!(msg.contains("Generate"));
836    }
837
838    #[test]
839    fn test_federation_error_qos_violation() {
840        let err = FederationError::QoSViolation("latency too high".to_string());
841        let msg = format!("{}", err);
842        assert!(msg.contains("QoS requirements cannot be met"));
843        assert!(msg.contains("latency too high"));
844    }
845
846    #[test]
847    fn test_federation_error_privacy_violation() {
848        let err = FederationError::PrivacyViolation("data must stay in EU".to_string());
849        let msg = format!("{}", err);
850        assert!(msg.contains("Privacy policy violation"));
851        assert!(msg.contains("data must stay in EU"));
852    }
853
854    #[test]
855    fn test_federation_error_node_unreachable() {
856        let err = FederationError::NodeUnreachable(NodeId("dead-node".to_string()));
857        let msg = format!("{}", err);
858        assert!(msg.contains("Node unreachable"));
859        assert!(msg.contains("dead-node"));
860    }
861
862    #[test]
863    fn test_federation_error_timeout() {
864        let err = FederationError::Timeout(Duration::from_secs(30));
865        let msg = format!("{}", err);
866        assert!(msg.contains("Timeout"));
867        assert!(msg.contains("30"));
868    }
869
870    #[test]
871    fn test_federation_error_circuit_open() {
872        let err = FederationError::CircuitOpen(NodeId("overloaded".to_string()));
873        let msg = format!("{}", err);
874        assert!(msg.contains("Circuit breaker open"));
875        assert!(msg.contains("overloaded"));
876    }
877
878    #[test]
879    fn test_federation_error_internal() {
880        let err = FederationError::Internal("unexpected state".to_string());
881        let msg = format!("{}", err);
882        assert!(msg.contains("Internal error"));
883        assert!(msg.contains("unexpected state"));
884    }
885
886    #[test]
887    fn test_federation_error_debug() {
888        let err = FederationError::NoCapacity(Capability::Embed);
889        let debug = format!("{:?}", err);
890        assert!(debug.contains("NoCapacity"));
891    }
892
893    // =========================================================================
894    // HealthState tests
895    // =========================================================================
896
897    #[test]
898    fn test_health_state_all_variants() {
899        let states = [
900            HealthState::Healthy,
901            HealthState::Degraded,
902            HealthState::Unhealthy,
903            HealthState::Unknown,
904        ];
905        // All distinct
906        for (i, a) in states.iter().enumerate() {
907            for (j, b) in states.iter().enumerate() {
908                if i == j {
909                    assert_eq!(a, b);
910                } else {
911                    assert_ne!(a, b);
912                }
913            }
914        }
915    }
916
917    #[test]
918    fn test_health_state_copy() {
919        let state = HealthState::Healthy;
920        let copied = state;
921        assert_eq!(state, copied);
922    }
923
924    // =========================================================================
925    // CircuitState tests
926    // =========================================================================
927
928    #[test]
929    fn test_circuit_state_all_variants() {
930        let states = [
931            CircuitState::Closed,
932            CircuitState::HalfOpen,
933            CircuitState::Open,
934        ];
935        for (i, a) in states.iter().enumerate() {
936            for (j, b) in states.iter().enumerate() {
937                if i == j {
938                    assert_eq!(a, b);
939                } else {
940                    assert_ne!(a, b);
941                }
942            }
943        }
944    }
945
946    #[test]
947    fn test_circuit_state_copy() {
948        let state = CircuitState::HalfOpen;
949        let copied = state;
950        assert_eq!(state, copied);
951    }
952
953    // =========================================================================
954    // LoadBalanceStrategy tests
955    // =========================================================================
956
957    #[test]
958    fn test_load_balance_default() {
959        let strategy = LoadBalanceStrategy::default();
960        assert!(matches!(strategy, LoadBalanceStrategy::LeastLatency));
961    }
962
963    #[test]
964    fn test_load_balance_all_variants() {
965        let strategies = [
966            LoadBalanceStrategy::RoundRobin,
967            LoadBalanceStrategy::LeastConnections,
968            LoadBalanceStrategy::LeastLatency,
969            LoadBalanceStrategy::WeightedRandom,
970            LoadBalanceStrategy::ConsistentHash,
971        ];
972        // Verify debug output
973        for s in &strategies {
974            let debug = format!("{:?}", s);
975            assert!(!debug.is_empty());
976        }
977    }
978
979    #[test]
980    fn test_load_balance_clone() {
981        let strategy = LoadBalanceStrategy::WeightedRandom;
982        let cloned = strategy;
983        assert!(matches!(cloned, LoadBalanceStrategy::WeightedRandom));
984    }
985
986    // =========================================================================
987    // RouteTarget/RouteCandidate/RouteScores tests
988    // =========================================================================
989
990    #[test]
991    fn test_route_target_construction() {
992        let target = RouteTarget {
993            node_id: NodeId("n1".to_string()),
994            region_id: RegionId("r1".to_string()),
995            endpoint: "http://n1:8080".to_string(),
996            estimated_latency: Duration::from_millis(50),
997            score: 0.95,
998        };
999        assert_eq!(target.node_id, NodeId("n1".to_string()));
1000        assert_eq!(target.endpoint, "http://n1:8080");
1001        assert_eq!(target.estimated_latency, Duration::from_millis(50));
1002    }
1003
1004    #[test]
1005    fn test_route_target_clone() {
1006        let target = RouteTarget {
1007            node_id: NodeId("n1".to_string()),
1008            region_id: RegionId("r1".to_string()),
1009            endpoint: "http://n1:8080".to_string(),
1010            estimated_latency: Duration::from_millis(50),
1011            score: 0.5,
1012        };
1013        let cloned = target.clone();
1014        assert_eq!(cloned.node_id, NodeId("n1".to_string()));
1015        assert_eq!(cloned.score, 0.5);
1016    }
1017
1018    #[test]
1019    fn test_route_scores_construction() {
1020        let scores = RouteScores {
1021            latency_score: 0.9,
1022            throughput_score: 0.8,
1023            cost_score: 0.7,
1024            locality_score: 0.6,
1025            health_score: 1.0,
1026            total: 0.85,
1027        };
1028        assert_eq!(scores.latency_score, 0.9);
1029        assert_eq!(scores.total, 0.85);
1030    }
1031
1032    #[test]
1033    fn test_route_candidate_eligible() {
1034        let candidate = RouteCandidate {
1035            target: RouteTarget {
1036                node_id: NodeId("n1".to_string()),
1037                region_id: RegionId("r1".to_string()),
1038                endpoint: String::new(),
1039                estimated_latency: Duration::from_millis(100),
1040                score: 0.8,
1041            },
1042            scores: RouteScores {
1043                latency_score: 0.9,
1044                throughput_score: 0.8,
1045                cost_score: 0.5,
1046                locality_score: 0.7,
1047                health_score: 1.0,
1048                total: 0.8,
1049            },
1050            eligible: true,
1051            rejection_reason: None,
1052        };
1053        assert!(candidate.eligible);
1054        assert!(candidate.rejection_reason.is_none());
1055    }
1056
1057    #[test]
1058    fn test_route_candidate_rejected() {
1059        let candidate = RouteCandidate {
1060            target: RouteTarget {
1061                node_id: NodeId("n1".to_string()),
1062                region_id: RegionId("r1".to_string()),
1063                endpoint: String::new(),
1064                estimated_latency: Duration::from_millis(100),
1065                score: 0.0,
1066            },
1067            scores: RouteScores {
1068                latency_score: 0.0,
1069                throughput_score: 0.0,
1070                cost_score: 0.0,
1071                locality_score: 0.0,
1072                health_score: 0.0,
1073                total: 0.0,
1074            },
1075            eligible: false,
1076            rejection_reason: Some("Policy rejected".to_string()),
1077        };
1078        assert!(!candidate.eligible);
1079        assert_eq!(
1080            candidate.rejection_reason,
1081            Some("Policy rejected".to_string())
1082        );
1083    }
1084
1085    // =========================================================================
1086    // ModelMetadata tests
1087    // =========================================================================
1088
1089    #[test]
1090    fn test_model_metadata_construction() {
1091        let meta = ModelMetadata {
1092            model_id: ModelId("llama-7b".to_string()),
1093            name: "LLaMA 7B".to_string(),
1094            version: "2.0".to_string(),
1095            capabilities: vec![Capability::Generate, Capability::Code],
1096            parameters: 7_000_000_000,
1097            quantization: Some("Q4_K".to_string()),
1098        };
1099        assert_eq!(meta.name, "LLaMA 7B");
1100        assert_eq!(meta.parameters, 7_000_000_000);
1101        assert_eq!(meta.quantization, Some("Q4_K".to_string()));
1102        assert_eq!(meta.capabilities.len(), 2);
1103    }
1104
1105    #[test]
1106    fn test_model_metadata_no_quantization() {
1107        let meta = ModelMetadata {
1108            model_id: ModelId("whisper".to_string()),
1109            name: "Whisper".to_string(),
1110            version: "1.0".to_string(),
1111            capabilities: vec![Capability::Transcribe],
1112            parameters: 1_500_000_000,
1113            quantization: None,
1114        };
1115        assert!(meta.quantization.is_none());
1116    }
1117
1118    #[test]
1119    fn test_model_metadata_clone() {
1120        let meta = ModelMetadata {
1121            model_id: ModelId("test".to_string()),
1122            name: "Test".to_string(),
1123            version: "1.0".to_string(),
1124            capabilities: vec![Capability::Embed],
1125            parameters: 100,
1126            quantization: None,
1127        };
1128        let cloned = meta.clone();
1129        assert_eq!(cloned.model_id, ModelId("test".to_string()));
1130    }
1131
1132    // =========================================================================
1133    // NodeHealth tests
1134    // =========================================================================
1135
1136    #[test]
1137    fn test_node_health_construction() {
1138        let health = NodeHealth {
1139            node_id: NodeId("test-node".to_string()),
1140            status: HealthState::Healthy,
1141            latency_p50: Duration::from_millis(25),
1142            latency_p99: Duration::from_millis(100),
1143            throughput: 500,
1144            gpu_utilization: Some(0.75),
1145            queue_depth: 3,
1146            last_check: std::time::Instant::now(),
1147        };
1148        assert_eq!(health.status, HealthState::Healthy);
1149        assert_eq!(health.throughput, 500);
1150        assert_eq!(health.gpu_utilization, Some(0.75));
1151        assert_eq!(health.queue_depth, 3);
1152    }
1153
1154    #[test]
1155    fn test_node_health_no_gpu() {
1156        let health = NodeHealth {
1157            node_id: NodeId("cpu-node".to_string()),
1158            status: HealthState::Healthy,
1159            latency_p50: Duration::from_millis(50),
1160            latency_p99: Duration::from_millis(200),
1161            throughput: 100,
1162            gpu_utilization: None,
1163            queue_depth: 0,
1164            last_check: std::time::Instant::now(),
1165        };
1166        assert!(health.gpu_utilization.is_none());
1167    }
1168
1169    // =========================================================================
1170    // GatewayStats tests
1171    // =========================================================================
1172
1173    #[test]
1174    fn test_gateway_stats_default() {
1175        let stats = GatewayStats::default();
1176        assert_eq!(stats.total_requests, 0);
1177        assert_eq!(stats.successful_requests, 0);
1178        assert_eq!(stats.failed_requests, 0);
1179        assert_eq!(stats.total_tokens, 0);
1180        assert_eq!(stats.avg_latency, Duration::ZERO);
1181        assert_eq!(stats.active_streams, 0);
1182    }
1183
1184    #[test]
1185    fn test_gateway_stats_clone() {
1186        let stats = GatewayStats {
1187            total_requests: 100,
1188            successful_requests: 95,
1189            failed_requests: 5,
1190            total_tokens: 5000,
1191            avg_latency: Duration::from_millis(50),
1192            active_streams: 2,
1193        };
1194        let cloned = stats.clone();
1195        assert_eq!(cloned.total_requests, 100);
1196        assert_eq!(cloned.active_streams, 2);
1197    }
1198
1199    // =========================================================================
1200    // FederationBuilder tests
1201    // =========================================================================
1202
1203    #[test]
1204    fn test_federation_builder_default() {
1205        let builder = FederationBuilder::default();
1206        assert!(builder.catalog.is_none());
1207        assert!(builder.health_checker.is_none());
1208        assert!(builder.router.is_none());
1209        assert!(builder.policies.is_empty());
1210        assert!(builder.middlewares.is_empty());
1211    }
1212
1213    #[test]
1214    fn test_federation_builder_new_defaults() {
1215        let builder = FederationBuilder::new();
1216        assert!(matches!(
1217            builder.load_balance,
1218            LoadBalanceStrategy::LeastLatency
1219        ));
1220    }
1221
1222    #[test]
1223    fn test_federation_builder_with_load_balance_all_strategies() {
1224        for strategy in [
1225            LoadBalanceStrategy::RoundRobin,
1226            LoadBalanceStrategy::LeastConnections,
1227            LoadBalanceStrategy::LeastLatency,
1228            LoadBalanceStrategy::WeightedRandom,
1229            LoadBalanceStrategy::ConsistentHash,
1230        ] {
1231            let builder = FederationBuilder::new().with_load_balance(strategy);
1232            let debug = format!("{:?}", builder.load_balance);
1233            assert!(!debug.is_empty());
1234        }
1235    }
1236
1237    #[test]
1238    fn test_federation_builder_with_policy() {
1239        use super::*;
1240
1241        struct MockPolicy;
1242        impl RoutingPolicyTrait for MockPolicy {
1243            fn score(&self, _: &RouteCandidate, _: &InferenceRequest) -> f64 {
1244                1.0
1245            }
1246            fn is_eligible(&self, _: &RouteCandidate, _: &InferenceRequest) -> bool {
1247                true
1248            }
1249            fn name(&self) -> &'static str {
1250                "mock"
1251            }
1252        }
1253
1254        let builder = FederationBuilder::new()
1255            .with_policy(MockPolicy)
1256            .with_policy(MockPolicy);
1257        assert_eq!(builder.policies.len(), 2);
1258    }
1259
1260    #[test]
1261    fn test_federation_builder_with_middleware() {
1262        use super::*;
1263
1264        struct MockMiddleware;
1265        impl GatewayMiddleware for MockMiddleware {
1266            fn before_route(&self, _: &mut InferenceRequest) -> FederationResult<()> {
1267                Ok(())
1268            }
1269            fn after_infer(
1270                &self,
1271                _: &InferenceRequest,
1272                _: &mut InferenceResponse,
1273            ) -> FederationResult<()> {
1274                Ok(())
1275            }
1276            fn on_error(&self, _: &InferenceRequest, _: &FederationError) {}
1277        }
1278
1279        let builder = FederationBuilder::new().with_middleware(MockMiddleware);
1280        assert_eq!(builder.middlewares.len(), 1);
1281    }
1282}