apr_cli/federation/
traits.rs

1//! Core trait definitions for APR Federation
2//!
3//! These traits define the contract for federation components.
4//! Implementations can be swapped for different backends (NATS, Redis, etcd, etc.)
5
6use std::future::Future;
7use std::pin::Pin;
8use std::time::Duration;
9
10// ============================================================================
11// Core Types
12// ============================================================================
13
/// Identifier of a single model instance registered with the federation.
///
/// The wrapped string is used verbatim; `Display` prints exactly that string.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ModelId(pub String);

impl std::fmt::Display for ModelId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Emit the raw identifier with no decoration.
        f.write_str(&self.0)
    }
}
23
/// Identifier of a region or cluster that hosts federation nodes.
///
/// The wrapped string is used verbatim; `Display` prints exactly that string.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RegionId(pub String);

impl std::fmt::Display for RegionId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Emit the raw identifier with no decoration.
        f.write_str(&self.0)
    }
}
33
/// Identifier of a node within a region.
///
/// The wrapped string is used verbatim; `Display` prints exactly that string.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct NodeId(pub String);

impl std::fmt::Display for NodeId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Emit the raw identifier with no decoration.
        f.write_str(&self.0)
    }
}
43
/// A capability that a model instance can advertise and a request can ask for.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq`, like `ModelId`/`NodeId`.
/// Capabilities can therefore serve as `HashMap`/`HashSet` keys, e.g. to index nodes
/// by the capabilities they offer.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Capability {
    /// Automatic speech recognition
    Transcribe,
    /// Text-to-speech
    Synthesize,
    /// Text generation (LLM)
    Generate,
    /// Code generation
    Code,
    /// Embeddings
    Embed,
    /// Image generation
    ImageGen,
    /// Custom capability identified by a free-form name
    Custom(String),
}
62
/// Data-sensitivity tier that constrains where a request may be routed.
///
/// Variants run from least to most restrictive. The explicit discriminants (0..=3)
/// make that ranking visible, and the derived `PartialOrd`/`Ord` follow it.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub enum PrivacyLevel {
    /// May be routed to any node.
    Public = 0,
    /// Must stay within the organisation.
    Internal = 1,
    /// Limited to specific approved regions.
    Confidential = 2,
    /// On-premises nodes only.
    Restricted = 3,
}
75
76/// Quality of Service requirements
77#[derive(Debug, Clone)]
78pub struct QoSRequirements {
79    /// Maximum acceptable latency
80    pub max_latency: Option<Duration>,
81    /// Minimum throughput (tokens/sec)
82    pub min_throughput: Option<u32>,
83    /// Privacy level required
84    pub privacy: PrivacyLevel,
85    /// Prefer GPU acceleration
86    pub prefer_gpu: bool,
87    /// Cost tier (0 = cheapest, 100 = fastest)
88    pub cost_tolerance: u8,
89}
90
91impl Default for QoSRequirements {
92    fn default() -> Self {
93        Self {
94            max_latency: None,
95            min_throughput: None,
96            privacy: PrivacyLevel::Internal,
97            prefer_gpu: true,
98            cost_tolerance: 50,
99        }
100    }
101}
102
103/// Inference request metadata
104#[derive(Debug, Clone)]
105pub struct InferenceRequest {
106    /// Requested capability
107    pub capability: Capability,
108    /// Input data (opaque bytes)
109    pub input: Vec<u8>,
110    /// QoS requirements
111    pub qos: QoSRequirements,
112    /// Request ID for tracing
113    pub request_id: String,
114    /// User/tenant ID
115    pub tenant_id: Option<String>,
116}
117
118/// Inference response
119#[derive(Debug)]
120pub struct InferenceResponse {
121    /// Output data
122    pub output: Vec<u8>,
123    /// Which node handled the request
124    pub served_by: NodeId,
125    /// Actual latency
126    pub latency: Duration,
127    /// Tokens generated (if applicable)
128    pub tokens: Option<u32>,
129}
130
131/// Error types for federation operations
132#[derive(Debug, thiserror::Error)]
133pub enum FederationError {
134    #[error("No nodes available for capability: {0:?}")]
135    NoCapacity(Capability),
136
137    #[error("All nodes unhealthy for capability: {0:?}")]
138    AllNodesUnhealthy(Capability),
139
140    #[error("QoS requirements cannot be met: {0}")]
141    QoSViolation(String),
142
143    #[error("Privacy policy violation: {0}")]
144    PrivacyViolation(String),
145
146    #[error("Node unreachable: {0}")]
147    NodeUnreachable(NodeId),
148
149    #[error("Timeout after {0:?}")]
150    Timeout(Duration),
151
152    #[error("Circuit breaker open for node: {0}")]
153    CircuitOpen(NodeId),
154
155    #[error("Internal error: {0}")]
156    Internal(String),
157}
158
159// ============================================================================
160// Core Traits
161// ============================================================================
162
163/// Result type alias for federation operations
164pub type FederationResult<T> = Result<T, FederationError>;
165
166/// Boxed future for async trait methods
167pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
168
169/// Model catalog - tracks available models across the federation
170pub trait ModelCatalogTrait: Send + Sync {
171    /// Register a model instance
172    fn register(
173        &self,
174        model_id: ModelId,
175        node_id: NodeId,
176        region_id: RegionId,
177        capabilities: Vec<Capability>,
178    ) -> BoxFuture<'_, FederationResult<()>>;
179
180    /// Deregister a model instance
181    fn deregister(&self, model_id: ModelId, node_id: NodeId)
182        -> BoxFuture<'_, FederationResult<()>>;
183
184    /// Find nodes with a specific capability
185    fn find_by_capability(
186        &self,
187        capability: &Capability,
188    ) -> BoxFuture<'_, FederationResult<Vec<(NodeId, RegionId)>>>;
189
190    /// List all registered models
191    fn list_all(&self) -> BoxFuture<'_, FederationResult<Vec<ModelId>>>;
192
193    /// Get model metadata
194    fn get_metadata(&self, model_id: &ModelId) -> BoxFuture<'_, FederationResult<ModelMetadata>>;
195}
196
197/// Model metadata stored in catalog
198#[derive(Debug, Clone)]
199pub struct ModelMetadata {
200    pub model_id: ModelId,
201    pub name: String,
202    pub version: String,
203    pub capabilities: Vec<Capability>,
204    pub parameters: u64,
205    pub quantization: Option<String>,
206}
207
208/// Health checker - monitors node health across federation
209pub trait HealthCheckerTrait: Send + Sync {
210    /// Check health of a specific node
211    fn check_node(&self, node_id: &NodeId) -> BoxFuture<'_, FederationResult<NodeHealth>>;
212
213    /// Get cached health status (non-blocking)
214    fn get_cached_health(&self, node_id: &NodeId) -> Option<NodeHealth>;
215
216    /// Start background health monitoring
217    fn start_monitoring(&self, interval: Duration) -> BoxFuture<'_, ()>;
218
219    /// Stop health monitoring
220    fn stop_monitoring(&self) -> BoxFuture<'_, ()>;
221}
222
223/// Node health information
224#[derive(Debug, Clone)]
225pub struct NodeHealth {
226    pub node_id: NodeId,
227    pub status: HealthState,
228    pub latency_p50: Duration,
229    pub latency_p99: Duration,
230    pub throughput: u32,
231    pub gpu_utilization: Option<f32>,
232    pub queue_depth: u32,
233    pub last_check: std::time::Instant,
234}
235
/// Coarse health classification of a node.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HealthState {
    /// Accepting requests normally.
    Healthy,
    /// Still functional, but impaired.
    Degraded,
    /// Not fit for traffic; routing should avoid it.
    Unhealthy,
    /// No health information is available.
    Unknown,
}
248
249/// Router - selects the best node for a request
250pub trait RouterTrait: Send + Sync {
251    /// Route a request to the best available node
252    fn route(&self, request: &InferenceRequest) -> BoxFuture<'_, FederationResult<RouteTarget>>;
253
254    /// Get all possible routes for a request (for debugging/transparency)
255    fn get_candidates(
256        &self,
257        request: &InferenceRequest,
258    ) -> BoxFuture<'_, FederationResult<Vec<RouteCandidate>>>;
259}
260
261/// Selected route target
262#[derive(Debug, Clone)]
263pub struct RouteTarget {
264    pub node_id: NodeId,
265    pub region_id: RegionId,
266    pub endpoint: String,
267    pub estimated_latency: Duration,
268    pub score: f64,
269}
270
271/// Route candidate with scoring details
272#[derive(Debug, Clone)]
273pub struct RouteCandidate {
274    pub target: RouteTarget,
275    pub scores: RouteScores,
276    pub eligible: bool,
277    pub rejection_reason: Option<String>,
278}
279
/// Per-factor breakdown of a candidate's routing score.
#[derive(Clone, Debug)]
pub struct RouteScores {
    /// Latency component.
    pub latency_score: f64,
    /// Throughput component.
    pub throughput_score: f64,
    /// Cost component.
    pub cost_score: f64,
    /// Locality (proximity) component.
    pub locality_score: f64,
    /// Node-health component.
    pub health_score: f64,
    /// Combined score; how components are combined is up to the router.
    pub total: f64,
}
290
291/// Gateway - the main entry point for federation requests
292pub trait GatewayTrait: Send + Sync {
293    /// Execute an inference request through the federation
294    fn infer(
295        &self,
296        request: InferenceRequest,
297    ) -> BoxFuture<'_, FederationResult<InferenceResponse>>;
298
299    /// Execute with streaming response
300    fn infer_stream(
301        &self,
302        request: InferenceRequest,
303    ) -> BoxFuture<'_, FederationResult<Box<dyn TokenStream>>>;
304
305    /// Get gateway statistics
306    fn stats(&self) -> GatewayStats;
307}
308
309/// Streaming token interface
310pub trait TokenStream: Send {
311    /// Get next token (None = stream complete)
312    fn next_token(&mut self) -> BoxFuture<'_, Option<FederationResult<Vec<u8>>>>;
313
314    /// Cancel the stream
315    fn cancel(&mut self) -> BoxFuture<'_, ()>;
316}
317
/// Aggregate counters reported by a gateway.
///
/// `Default` yields all-zero counters and a zero average latency.
#[derive(Clone, Debug, Default)]
pub struct GatewayStats {
    /// Requests received.
    pub total_requests: u64,
    /// Requests that completed successfully.
    pub successful_requests: u64,
    /// Requests that failed.
    pub failed_requests: u64,
    /// Tokens produced across all requests.
    pub total_tokens: u64,
    /// Mean request latency.
    pub avg_latency: Duration,
    /// Streams currently open.
    pub active_streams: u32,
}
328
329// ============================================================================
330// Middleware Traits (Tower-style composability)
331// ============================================================================
332
333/// Middleware that can wrap a gateway
334pub trait GatewayMiddleware: Send + Sync {
335    /// Process request before routing
336    fn before_route(&self, request: &mut InferenceRequest) -> FederationResult<()>;
337
338    /// Process response after inference
339    fn after_infer(
340        &self,
341        request: &InferenceRequest,
342        response: &mut InferenceResponse,
343    ) -> FederationResult<()>;
344
345    /// Handle errors
346    fn on_error(&self, request: &InferenceRequest, error: &FederationError);
347}
348
349/// Circuit breaker for fault tolerance
350pub trait CircuitBreakerTrait: Send + Sync {
351    /// Check if circuit is open (should skip this node)
352    fn is_open(&self, node_id: &NodeId) -> bool;
353
354    /// Record a success
355    fn record_success(&self, node_id: &NodeId);
356
357    /// Record a failure
358    fn record_failure(&self, node_id: &NodeId);
359
360    /// Get circuit state
361    fn state(&self, node_id: &NodeId) -> CircuitState;
362}
363
/// State of one node's circuit breaker.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Normal operation; requests flow through.
    Closed,
    /// Recovering from failures; probe requests are allowed.
    HalfOpen,
    /// Tripped by failures; all requests are blocked.
    Open,
}
374
375// ============================================================================
376// Policy Traits
377// ============================================================================
378
379/// Routing policy that influences node selection
380pub trait RoutingPolicyTrait: Send + Sync {
381    /// Score a candidate node (higher = better)
382    fn score(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> f64;
383
384    /// Check if a candidate is eligible
385    fn is_eligible(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> bool;
386
387    /// Get policy name for logging
388    fn name(&self) -> &'static str;
389}
390
/// Strategy for spreading requests across eligible nodes.
///
/// Derives `PartialEq`/`Eq`/`Hash`, matching `HealthState` and `CircuitState`. A
/// configured strategy can then be compared with `==` (not only via `matches!`) and
/// used as a map key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum LoadBalanceStrategy {
    /// Round-robin across healthy nodes
    RoundRobin,
    /// Route to least loaded node
    LeastConnections,
    /// Route based on latency (the default)
    #[default]
    LeastLatency,
    /// Weighted random
    WeightedRandom,
    /// Consistent hashing (sticky sessions)
    ConsistentHash,
}
406
407// ============================================================================
408// Builder Pattern for Configuration
409// ============================================================================
410
411/// Builder for creating federation gateways
412#[derive(Default)]
413pub struct FederationBuilder {
414    pub catalog: Option<Box<dyn ModelCatalogTrait>>,
415    pub health_checker: Option<Box<dyn HealthCheckerTrait>>,
416    pub router: Option<Box<dyn RouterTrait>>,
417    pub policies: Vec<Box<dyn RoutingPolicyTrait>>,
418    pub middlewares: Vec<Box<dyn GatewayMiddleware>>,
419    pub load_balance: LoadBalanceStrategy,
420}
421
422impl FederationBuilder {
423    pub fn new() -> Self {
424        Self {
425            load_balance: LoadBalanceStrategy::LeastLatency,
426            ..Default::default()
427        }
428    }
429
430    #[must_use]
431    pub fn with_catalog(mut self, catalog: impl ModelCatalogTrait + 'static) -> Self {
432        self.catalog = Some(Box::new(catalog));
433        self
434    }
435
436    #[must_use]
437    pub fn with_health_checker(mut self, checker: impl HealthCheckerTrait + 'static) -> Self {
438        self.health_checker = Some(Box::new(checker));
439        self
440    }
441
442    #[must_use]
443    pub fn with_router(mut self, router: impl RouterTrait + 'static) -> Self {
444        self.router = Some(Box::new(router));
445        self
446    }
447
448    #[must_use]
449    pub fn with_policy(mut self, policy: impl RoutingPolicyTrait + 'static) -> Self {
450        self.policies.push(Box::new(policy));
451        self
452    }
453
454    #[must_use]
455    pub fn with_middleware(mut self, middleware: impl GatewayMiddleware + 'static) -> Self {
456        self.middlewares.push(Box::new(middleware));
457        self
458    }
459
460    #[must_use]
461    pub fn with_load_balance(mut self, strategy: LoadBalanceStrategy) -> Self {
462        self.load_balance = strategy;
463        self
464    }
465}
466
467// ============================================================================
468// Tests
469// ============================================================================
470
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_qos_default() {
        let QoSRequirements {
            privacy,
            prefer_gpu,
            cost_tolerance,
            ..
        } = QoSRequirements::default();
        assert_eq!(privacy, PrivacyLevel::Internal);
        assert!(prefer_gpu);
        assert_eq!(cost_tolerance, 50);
    }

    #[test]
    fn test_privacy_ordering() {
        let ranked = [
            PrivacyLevel::Public,
            PrivacyLevel::Internal,
            PrivacyLevel::Confidential,
            PrivacyLevel::Restricted,
        ];
        for pair in ranked.windows(2) {
            assert!(pair[0] < pair[1], "{:?} should rank below {:?}", pair[0], pair[1]);
        }
    }

    #[test]
    fn test_health_state() {
        assert_ne!(HealthState::Healthy, HealthState::Degraded);
    }

    #[test]
    fn test_circuit_state() {
        assert_ne!(CircuitState::Closed, CircuitState::Open);
    }

    #[test]
    fn test_federation_builder() {
        let strategy = FederationBuilder::new()
            .with_load_balance(LoadBalanceStrategy::LeastConnections)
            .load_balance;
        assert!(matches!(strategy, LoadBalanceStrategy::LeastConnections));
    }

    #[test]
    fn test_model_id_equality() {
        let make = |name: &str| ModelId(name.to_string());
        assert_eq!(make("whisper-v3"), make("whisper-v3"));
        assert_ne!(make("whisper-v3"), make("llama-7b"));
    }

    #[test]
    fn test_capability_variants() {
        let transcribe = Capability::Transcribe;
        let custom = Capability::Custom("sentiment".to_string());
        assert_ne!(transcribe, custom);
        assert_eq!(transcribe, Capability::Transcribe);
    }
}