Skip to main content

apr_cli/federation/
traits.rs

//! Core trait definitions for APR Federation.
//!
//! These traits define the contracts that federation components must satisfy.
//! Implementations can be swapped to target different backends (NATS, Redis, etcd, etc.).
5
6use std::future::Future;
7use std::pin::Pin;
8use std::time::Duration;
9
// ============================================================================
// Core Types
// ============================================================================
13
/// Unique identifier for a model instance in the federation.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct ModelId(pub String);

impl std::fmt::Display for ModelId {
    /// Writes the raw identifier string.
    ///
    /// Delegates to `String`'s `Display` (instead of `write!(f, "{}", ..)`) so
    /// formatter flags such as width, fill/alignment and precision
    /// (`{:>12}`, `{:.8}`) are honoured when ids are laid out in tables/logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
23
/// Unique identifier for a region/cluster.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct RegionId(pub String);

impl std::fmt::Display for RegionId {
    /// Writes the raw region identifier.
    ///
    /// Delegates to `String`'s `Display` so width, fill/alignment and
    /// precision flags are respected rather than silently dropped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
33
/// Unique identifier for a node within a region.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct NodeId(pub String);

impl std::fmt::Display for NodeId {
    /// Writes the raw node identifier.
    ///
    /// Delegates to `String`'s `Display` so width, fill/alignment and
    /// precision flags are respected rather than silently dropped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
43
/// Model capabilities that can be queried.
///
/// Derives `Eq` and `Hash` (in addition to `PartialEq`) so a capability can be
/// used as a `HashMap`/`HashSet` key, e.g. for a capability-to-nodes index in a
/// catalog implementation. Every variant's payload (`String`) is itself
/// `Eq + Hash`, so the derives are consistent with the existing equality.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Capability {
    /// Automatic speech recognition.
    Transcribe,
    /// Text-to-speech.
    Synthesize,
    /// Text generation (LLM).
    Generate,
    /// Code generation.
    Code,
    /// Embeddings.
    Embed,
    /// Image generation.
    ImageGen,
    /// Custom capability identified by name.
    Custom(String),
}
62
/// Privacy/compliance classification used when deciding where data may be routed.
///
/// Variants carry explicit discriminants and derive `Ord`, so levels compare
/// in declaration order: `Public < Internal < Confidential < Restricted`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum PrivacyLevel {
    /// Public data; may be routed anywhere.
    Public = 0,
    /// Internal data; must stay within the organisation.
    Internal = 1,
    /// Confidential data; limited to specific regions.
    Confidential = 2,
    /// Restricted data; on-prem only.
    Restricted = 3,
}
75
76/// Quality of Service requirements
77#[derive(Debug, Clone)]
78pub struct QoSRequirements {
79    /// Maximum acceptable latency
80    pub max_latency: Option<Duration>,
81    /// Minimum throughput (tokens/sec)
82    pub min_throughput: Option<u32>,
83    /// Privacy level required
84    pub privacy: PrivacyLevel,
85    /// Prefer GPU acceleration
86    pub prefer_gpu: bool,
87    /// Cost tier (0 = cheapest, 100 = fastest)
88    pub cost_tolerance: u8,
89}
90
91impl Default for QoSRequirements {
92    fn default() -> Self {
93        Self {
94            max_latency: None,
95            min_throughput: None,
96            privacy: PrivacyLevel::Internal,
97            prefer_gpu: true,
98            cost_tolerance: 50,
99        }
100    }
101}
102
103/// Inference request metadata
104#[derive(Debug, Clone)]
105pub struct InferenceRequest {
106    /// Requested capability
107    pub capability: Capability,
108    /// Input data (opaque bytes)
109    pub input: Vec<u8>,
110    /// QoS requirements
111    pub qos: QoSRequirements,
112    /// Request ID for tracing
113    pub request_id: String,
114    /// User/tenant ID
115    pub tenant_id: Option<String>,
116}
117
118/// Inference response
119#[derive(Debug)]
120pub struct InferenceResponse {
121    /// Output data
122    pub output: Vec<u8>,
123    /// Which node handled the request
124    pub served_by: NodeId,
125    /// Actual latency
126    pub latency: Duration,
127    /// Tokens generated (if applicable)
128    pub tokens: Option<u32>,
129}
130
131/// Error types for federation operations
132#[derive(Debug, thiserror::Error)]
133pub enum FederationError {
134    #[error("No nodes available for capability: {0:?}")]
135    NoCapacity(Capability),
136
137    #[error("All nodes unhealthy for capability: {0:?}")]
138    AllNodesUnhealthy(Capability),
139
140    #[error("QoS requirements cannot be met: {0}")]
141    QoSViolation(String),
142
143    #[error("Privacy policy violation: {0}")]
144    PrivacyViolation(String),
145
146    #[error("Node unreachable: {0}")]
147    NodeUnreachable(NodeId),
148
149    #[error("Timeout after {0:?}")]
150    Timeout(Duration),
151
152    #[error("Circuit breaker open for node: {0}")]
153    CircuitOpen(NodeId),
154
155    #[error("Internal error: {0}")]
156    Internal(String),
157}
158
// ============================================================================
// Core Traits
// ============================================================================
162
163/// Result type alias for federation operations
164pub type FederationResult<T> = Result<T, FederationError>;
165
/// Pinned, heap-allocated, `Send` future returned by the async trait methods
/// in this module.
///
/// Returning a boxed future (rather than using `async fn` in traits) keeps
/// those traits usable as `dyn` trait objects.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
168
169/// Model catalog - tracks available models across the federation
170pub trait ModelCatalogTrait: Send + Sync {
171    /// Register a model instance
172    fn register(
173        &self,
174        model_id: ModelId,
175        node_id: NodeId,
176        region_id: RegionId,
177        capabilities: Vec<Capability>,
178    ) -> BoxFuture<'_, FederationResult<()>>;
179
180    /// Deregister a model instance
181    fn deregister(&self, model_id: ModelId, node_id: NodeId)
182        -> BoxFuture<'_, FederationResult<()>>;
183
184    /// Find nodes with a specific capability
185    fn find_by_capability(
186        &self,
187        capability: &Capability,
188    ) -> BoxFuture<'_, FederationResult<Vec<(NodeId, RegionId)>>>;
189
190    /// List all registered models
191    fn list_all(&self) -> BoxFuture<'_, FederationResult<Vec<ModelId>>>;
192
193    /// Get model metadata
194    fn get_metadata(&self, model_id: &ModelId) -> BoxFuture<'_, FederationResult<ModelMetadata>>;
195}
196
197/// Model metadata stored in catalog
198#[derive(Debug, Clone)]
199pub struct ModelMetadata {
200    pub model_id: ModelId,
201    pub name: String,
202    pub version: String,
203    pub capabilities: Vec<Capability>,
204    pub parameters: u64,
205    pub quantization: Option<String>,
206}
207
208/// Health checker - monitors node health across federation
209pub trait HealthCheckerTrait: Send + Sync {
210    /// Check health of a specific node
211    fn check_node(&self, node_id: &NodeId) -> BoxFuture<'_, FederationResult<NodeHealth>>;
212
213    /// Get cached health status (non-blocking)
214    fn get_cached_health(&self, node_id: &NodeId) -> Option<NodeHealth>;
215
216    /// Start background health monitoring
217    fn start_monitoring(&self, interval: Duration) -> BoxFuture<'_, ()>;
218
219    /// Stop health monitoring
220    fn stop_monitoring(&self) -> BoxFuture<'_, ()>;
221}
222
223/// Node health information
224#[derive(Debug, Clone)]
225pub struct NodeHealth {
226    pub node_id: NodeId,
227    pub status: HealthState,
228    pub latency_p50: Duration,
229    pub latency_p99: Duration,
230    pub throughput: u32,
231    pub gpu_utilization: Option<f32>,
232    pub queue_depth: u32,
233    pub last_check: std::time::Instant,
234}
235
/// Coarse health classification of a node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HealthState {
    /// Healthy and accepting requests.
    Healthy,
    /// Degraded but still functional.
    Degraded,
    /// Unhealthy; routing should avoid it.
    Unhealthy,
    /// Health not yet known.
    Unknown,
}
248
249/// Router - selects the best node for a request
250pub trait RouterTrait: Send + Sync {
251    /// Route a request to the best available node
252    fn route(&self, request: &InferenceRequest) -> BoxFuture<'_, FederationResult<RouteTarget>>;
253
254    /// Get all possible routes for a request (for debugging/transparency)
255    fn get_candidates(
256        &self,
257        request: &InferenceRequest,
258    ) -> BoxFuture<'_, FederationResult<Vec<RouteCandidate>>>;
259}
260
261/// Selected route target
262#[derive(Debug, Clone)]
263pub struct RouteTarget {
264    pub node_id: NodeId,
265    pub region_id: RegionId,
266    pub endpoint: String,
267    pub estimated_latency: Duration,
268    pub score: f64,
269}
270
271/// Route candidate with scoring details
272#[derive(Debug, Clone)]
273pub struct RouteCandidate {
274    pub target: RouteTarget,
275    pub scores: RouteScores,
276    pub eligible: bool,
277    pub rejection_reason: Option<String>,
278}
279
/// Per-factor breakdown of a routing score.
#[derive(Debug, Clone)]
pub struct RouteScores {
    /// Latency component.
    pub latency_score: f64,
    /// Throughput component.
    pub throughput_score: f64,
    /// Cost component.
    pub cost_score: f64,
    /// Locality component.
    pub locality_score: f64,
    /// Health component.
    pub health_score: f64,
    /// Combined score.
    pub total: f64,
}

impl Default for RouteScores {
    /// Neutral scores: every factor and the total start at 0.5, except
    /// health, which starts at 1.0.
    fn default() -> Self {
        const NEUTRAL: f64 = 0.5;
        RouteScores {
            latency_score: NEUTRAL,
            throughput_score: NEUTRAL,
            cost_score: NEUTRAL,
            locality_score: NEUTRAL,
            health_score: 1.0,
            total: NEUTRAL,
        }
    }
}
303
304/// Gateway - the main entry point for federation requests
305pub trait GatewayTrait: Send + Sync {
306    /// Execute an inference request through the federation
307    fn infer(
308        &self,
309        request: InferenceRequest,
310    ) -> BoxFuture<'_, FederationResult<InferenceResponse>>;
311
312    /// Execute with streaming response
313    fn infer_stream(
314        &self,
315        request: InferenceRequest,
316    ) -> BoxFuture<'_, FederationResult<Box<dyn TokenStream>>>;
317
318    /// Get gateway statistics
319    fn stats(&self) -> GatewayStats;
320}
321
322/// Streaming token interface
323pub trait TokenStream: Send {
324    /// Get next token (None = stream complete)
325    fn next_token(&mut self) -> BoxFuture<'_, Option<FederationResult<Vec<u8>>>>;
326
327    /// Cancel the stream
328    fn cancel(&mut self) -> BoxFuture<'_, ()>;
329}
330
/// Aggregate gateway statistics.
///
/// `Default` yields zero for every counter and a zero `avg_latency`.
#[derive(Debug, Clone, Default)]
pub struct GatewayStats {
    /// Total requests.
    pub total_requests: u64,
    /// Requests that succeeded.
    pub successful_requests: u64,
    /// Requests that failed.
    pub failed_requests: u64,
    /// Total tokens.
    pub total_tokens: u64,
    /// Average request latency.
    pub avg_latency: Duration,
    /// Number of active streams.
    pub active_streams: u32,
}
341
// ============================================================================
// Middleware Traits (Tower-style composability)
// ============================================================================
345
346/// Middleware that can wrap a gateway
347pub trait GatewayMiddleware: Send + Sync {
348    /// Process request before routing
349    fn before_route(&self, request: &mut InferenceRequest) -> FederationResult<()>;
350
351    /// Process response after inference
352    fn after_infer(
353        &self,
354        request: &InferenceRequest,
355        response: &mut InferenceResponse,
356    ) -> FederationResult<()>;
357
358    /// Handle errors
359    fn on_error(&self, request: &InferenceRequest, error: &FederationError);
360}
361
362/// Circuit breaker for fault tolerance
363pub trait CircuitBreakerTrait: Send + Sync {
364    /// Check if circuit is open (should skip this node)
365    fn is_open(&self, node_id: &NodeId) -> bool;
366
367    /// Record a success
368    fn record_success(&self, node_id: &NodeId);
369
370    /// Record a failure
371    fn record_failure(&self, node_id: &NodeId);
372
373    /// Get circuit state
374    fn state(&self, node_id: &NodeId) -> CircuitState;
375}
376
/// State of a circuit breaker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation; requests flow.
    Closed,
    /// Recovering from failures; probe requests are allowed.
    HalfOpen,
    /// Failing; all requests are blocked.
    Open,
}
387
// ============================================================================
// Policy Traits
// ============================================================================
391
392/// Routing policy that influences node selection
393pub trait RoutingPolicyTrait: Send + Sync {
394    /// Score a candidate node (higher = better)
395    fn score(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> f64;
396
397    /// Check if a candidate is eligible
398    fn is_eligible(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> bool;
399
400    /// Get policy name for logging
401    fn name(&self) -> &'static str;
402}
403
/// Load-balancing strategy used to choose among eligible nodes.
///
/// `LeastLatency` is the default. `PartialEq`/`Eq`/`Hash` are derived so a
/// configured strategy can be compared (e.g. in configuration checks and
/// tests) and used as a map key. All variants are fieldless, so these derives
/// are trivially consistent.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum LoadBalanceStrategy {
    /// Round-robin across healthy nodes.
    RoundRobin,
    /// Route to the least loaded node.
    LeastConnections,
    /// Route based on latency.
    #[default]
    LeastLatency,
    /// Weighted random selection.
    WeightedRandom,
    /// Consistent hashing (sticky sessions).
    ConsistentHash,
}
419
// ============================================================================
// Builder Pattern for Configuration
// ============================================================================
423
424/// Builder for creating federation gateways
425#[derive(Default)]
426pub struct FederationBuilder {
427    pub catalog: Option<Box<dyn ModelCatalogTrait>>,
428    pub health_checker: Option<Box<dyn HealthCheckerTrait>>,
429    pub router: Option<Box<dyn RouterTrait>>,
430    pub policies: Vec<Box<dyn RoutingPolicyTrait>>,
431    pub middlewares: Vec<Box<dyn GatewayMiddleware>>,
432    pub load_balance: LoadBalanceStrategy,
433}
434
435impl FederationBuilder {
436    pub fn new() -> Self {
437        Self {
438            load_balance: LoadBalanceStrategy::LeastLatency,
439            ..Default::default()
440        }
441    }
442
443    #[must_use]
444    pub fn with_catalog(mut self, catalog: impl ModelCatalogTrait + 'static) -> Self {
445        self.catalog = Some(Box::new(catalog));
446        self
447    }
448
449    #[must_use]
450    pub fn with_health_checker(mut self, checker: impl HealthCheckerTrait + 'static) -> Self {
451        self.health_checker = Some(Box::new(checker));
452        self
453    }
454
455    #[must_use]
456    pub fn with_router(mut self, router: impl RouterTrait + 'static) -> Self {
457        self.router = Some(Box::new(router));
458        self
459    }
460
461    #[must_use]
462    pub fn with_policy(mut self, policy: impl RoutingPolicyTrait + 'static) -> Self {
463        self.policies.push(Box::new(policy));
464        self
465    }
466
467    #[must_use]
468    pub fn with_middleware(mut self, middleware: impl GatewayMiddleware + 'static) -> Self {
469        self.middlewares.push(Box::new(middleware));
470        self
471    }
472
473    #[must_use]
474    pub fn with_load_balance(mut self, strategy: LoadBalanceStrategy) -> Self {
475        self.load_balance = strategy;
476        self
477    }
478}
479
// ============================================================================
// Tests
// ============================================================================
483
// Unit tests for this module live in the sibling file `traits_tests.rs`.
#[cfg(test)]
#[path = "traits_tests.rs"]
mod tests;