1use std::future::Future;
7use std::pin::Pin;
8use std::time::Duration;
9
/// Identifier of a model known to the federation.
///
/// Used as the key when registering models with a [`ModelCatalogTrait`] and
/// when fetching their [`ModelMetadata`].
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct ModelId(pub String);

impl std::fmt::Display for ModelId {
    /// Writes the raw identifier string.
    ///
    /// Delegates to `String`'s `Display` so formatter flags (width, fill,
    /// alignment, precision — e.g. `{:>12}` or `{:.8}`) apply to the id.
    /// `write!(f, "{}", ..)` would silently discard them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
23
/// Identifier of a region.
///
/// Nodes are registered together with the region they belong to (see
/// [`ModelCatalogTrait::register`]), and routing results carry it
/// ([`RouteTarget::region_id`]).
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct RegionId(pub String);

impl std::fmt::Display for RegionId {
    /// Writes the raw identifier string.
    ///
    /// Delegates to `String`'s `Display` so width/fill/alignment/precision
    /// flags are honoured rather than dropped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
33
/// Identifier of a node that serves inference requests.
///
/// Appears in routing results, health reports, circuit-breaker calls and as
/// [`InferenceResponse::served_by`].
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct NodeId(pub String);

impl std::fmt::Display for NodeId {
    /// Writes the raw identifier string.
    ///
    /// Delegates to `String`'s `Display` so width/fill/alignment/precision
    /// flags are honoured rather than dropped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
43
/// A kind of work a model can perform and a request can ask for.
///
/// Requests name one capability ([`InferenceRequest::capability`]). Models are
/// registered with a list of capabilities and looked up via
/// [`ModelCatalogTrait::find_by_capability`].
///
/// Derives `Eq` + `Hash` in addition to `PartialEq`, so capabilities can be
/// used as `HashMap`/`HashSet` keys (e.g. to index nodes by capability).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Capability {
    /// Transcription (presumably speech-to-text — confirm).
    Transcribe,
    /// Synthesis (presumably text-to-speech — confirm).
    Synthesize,
    /// General generation (presumably text generation — confirm).
    Generate,
    /// Code-oriented generation.
    Code,
    /// Embedding computation.
    Embed,
    /// Image generation.
    ImageGen,
    /// An application-defined capability identified by name.
    Custom(String),
}
62
/// Data-sensitivity level attached to a request's QoS requirements.
///
/// The derived `Ord` follows declaration order:
/// `Public < Internal < Confidential < Restricted`. This matches the explicit
/// discriminants 0..=3, so levels can be compared (e.g. `level >= Confidential`).
/// Do not reorder the variants: that would silently change comparisons and
/// `as` casts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum PrivacyLevel {
    /// Lowest level (discriminant 0).
    Public = 0,
    /// Discriminant 1; the level used by `QoSRequirements::default()`.
    Internal = 1,
    /// Discriminant 2.
    Confidential = 2,
    /// Highest level (discriminant 3).
    Restricted = 3,
}
75
/// Quality-of-service constraints carried by an [`InferenceRequest`].
///
/// See the `Default` impl for baseline values.
#[derive(Debug, Clone)]
pub struct QoSRequirements {
    /// Maximum acceptable latency; `None` means no latency requirement.
    pub max_latency: Option<Duration>,
    /// Minimum required throughput; `None` means no requirement.
    /// NOTE(review): units are not defined here (tokens/s? req/s?) — confirm.
    pub min_throughput: Option<u32>,
    /// Sensitivity of the request's data; presumably nodes must satisfy at
    /// least this level — confirm in the router/policies.
    pub privacy: PrivacyLevel,
    /// Whether GPU-backed nodes should be preferred (presumably a preference
    /// rather than a hard requirement — confirm).
    pub prefer_gpu: bool,
    /// How much cost the caller tolerates. Presumably a 0–100 scale, given the
    /// default of 50 — TODO confirm.
    pub cost_tolerance: u8,
}
90
91impl Default for QoSRequirements {
92 fn default() -> Self {
93 Self {
94 max_latency: None,
95 min_throughput: None,
96 privacy: PrivacyLevel::Internal,
97 prefer_gpu: true,
98 cost_tolerance: 50,
99 }
100 }
101}
102
/// A single inference request submitted to the federation.
#[derive(Debug, Clone)]
pub struct InferenceRequest {
    /// What kind of work is requested; presumably matched against catalog
    /// registrations to find candidate nodes.
    pub capability: Capability,
    /// Raw request payload. Its encoding is not defined here and presumably
    /// depends on `capability` — confirm.
    pub input: Vec<u8>,
    /// Quality-of-service constraints for this request.
    pub qos: QoSRequirements,
    /// Identifier for this request (presumably for tracing/correlation).
    pub request_id: String,
    /// Owning tenant, if any.
    pub tenant_id: Option<String>,
}
117
/// Result of a completed (non-streaming) inference.
#[derive(Debug)]
pub struct InferenceResponse {
    /// Raw response payload; encoding depends on the request (not defined here).
    pub output: Vec<u8>,
    /// Node that produced this response.
    pub served_by: NodeId,
    /// Latency recorded for the request.
    /// NOTE(review): end-to-end vs. node-side timing is not defined here — confirm.
    pub latency: Duration,
    /// Token count, when applicable (presumably `None` for non-token outputs).
    pub tokens: Option<u32>,
}
130
/// Errors returned by federation components (every trait here reports
/// failures as [`FederationResult`]).
///
/// `Display` text comes from the `#[error(...)]` attributes (via `thiserror`).
/// `Capability` payloads are rendered with `{:?}` because `Capability` has no
/// `Display` impl.
#[derive(Debug, thiserror::Error)]
pub enum FederationError {
    /// No node is available for the requested capability.
    #[error("No nodes available for capability: {0:?}")]
    NoCapacity(Capability),

    /// Nodes exist for the capability, but all are unhealthy.
    #[error("All nodes unhealthy for capability: {0:?}")]
    AllNodesUnhealthy(Capability),

    /// The request's QoS requirements cannot be satisfied; the payload explains why.
    #[error("QoS requirements cannot be met: {0}")]
    QoSViolation(String),

    /// Serving the request would violate a privacy policy; the payload explains why.
    #[error("Privacy policy violation: {0}")]
    PrivacyViolation(String),

    /// The given node could not be reached.
    #[error("Node unreachable: {0}")]
    NodeUnreachable(NodeId),

    /// An operation exceeded the given duration.
    #[error("Timeout after {0:?}")]
    Timeout(Duration),

    /// The node's circuit breaker is open (see [`CircuitState::Open`]).
    #[error("Circuit breaker open for node: {0}")]
    CircuitOpen(NodeId),

    /// Catch-all for unexpected internal failures.
    #[error("Internal error: {0}")]
    Internal(String),
}
158
/// Shorthand for `Result<T, FederationError>`.
pub type FederationResult<T> = Result<T, FederationError>;
165
/// A heap-allocated, type-erased, `Send` future that may borrow data for `'a`.
///
/// Returning this instead of using `async fn` keeps the traits in this module
/// object-safe, so they can be stored as `Box<dyn Trait>` (as
/// `FederationBuilder` does).
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
168
/// Registry of which nodes (and regions) serve which models and capabilities.
///
/// Lifetime note: each method returns `BoxFuture<'_, _>`, and by elision `'_`
/// is the lifetime of `&self` only. The returned future therefore cannot
/// borrow reference arguments such as `capability` or `model_id`.
/// Implementations must copy or clone what they need from those before
/// building the future.
pub trait ModelCatalogTrait: Send + Sync {
    /// Registers `model_id` as served by `node_id` in `region_id`, advertising
    /// `capabilities`.
    fn register(
        &self,
        model_id: ModelId,
        node_id: NodeId,
        region_id: RegionId,
        capabilities: Vec<Capability>,
    ) -> BoxFuture<'_, FederationResult<()>>;

    /// Removes the registration of `model_id` on `node_id`.
    fn deregister(&self, model_id: ModelId, node_id: NodeId)
        -> BoxFuture<'_, FederationResult<()>>;

    /// Returns the `(node, region)` pairs registered for `capability`.
    /// NOTE(review): whether "none found" is `Ok(vec![])` or
    /// `Err(FederationError::NoCapacity(..))` is implementation-defined — confirm.
    fn find_by_capability(
        &self,
        capability: &Capability,
    ) -> BoxFuture<'_, FederationResult<Vec<(NodeId, RegionId)>>>;

    /// Lists all registered model ids.
    fn list_all(&self) -> BoxFuture<'_, FederationResult<Vec<ModelId>>>;

    /// Fetches descriptive metadata for `model_id`.
    fn get_metadata(&self, model_id: &ModelId) -> BoxFuture<'_, FederationResult<ModelMetadata>>;
}
196
/// Descriptive information about a model, returned by
/// [`ModelCatalogTrait::get_metadata`].
#[derive(Debug, Clone)]
pub struct ModelMetadata {
    /// The model this metadata describes.
    pub model_id: ModelId,
    /// Human-readable model name.
    pub name: String,
    /// Model version string (format not defined here).
    pub version: String,
    /// Capabilities the model offers.
    pub capabilities: Vec<Capability>,
    /// Parameter count (presumably the total number of weights — confirm).
    pub parameters: u64,
    /// Quantization label, if the model is quantized (format not defined here).
    pub quantization: Option<String>,
}
207
/// Tracks node health, on demand (`check_node`) and optionally periodically
/// (`start_monitoring`).
pub trait HealthCheckerTrait: Send + Sync {
    /// Checks `node_id` and returns its health. The future borrows only
    /// `&self`, so implementations must clone `node_id` if they need it inside
    /// the future.
    fn check_node(&self, node_id: &NodeId) -> BoxFuture<'_, FederationResult<NodeHealth>>;

    /// Returns the most recently cached health for `node_id`, if any,
    /// without performing a check (synchronous).
    fn get_cached_health(&self, node_id: &NodeId) -> Option<NodeHealth>;

    /// Starts monitoring with the given `interval`.
    /// NOTE(review): whether the returned future completes immediately or runs
    /// for the monitor's lifetime is implementation-defined — confirm.
    fn start_monitoring(&self, interval: Duration) -> BoxFuture<'_, ()>;

    /// Stops monitoring started by `start_monitoring`.
    fn stop_monitoring(&self) -> BoxFuture<'_, ()>;
}
222
/// Health snapshot for a single node, as produced by a [`HealthCheckerTrait`].
#[derive(Debug, Clone)]
pub struct NodeHealth {
    /// Node this snapshot describes.
    pub node_id: NodeId,
    /// Overall health classification.
    pub status: HealthState,
    /// 50th-percentile (median) latency.
    pub latency_p50: Duration,
    /// 99th-percentile latency.
    pub latency_p99: Duration,
    /// Observed throughput. NOTE(review): units not defined here — confirm.
    pub throughput: u32,
    /// GPU utilization, if reported. NOTE(review): 0.0–1.0 vs. 0–100 scale
    /// is not defined here — confirm.
    pub gpu_utilization: Option<f32>,
    /// Number of requests waiting on the node.
    pub queue_depth: u32,
    /// When this snapshot was taken.
    pub last_check: std::time::Instant,
}
235
/// Coarse health classification of a node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HealthState {
    /// Operating normally.
    Healthy,
    /// Operating, but with reduced performance.
    Degraded,
    /// Not fit to serve traffic.
    Unhealthy,
    /// No health information available (presumably not yet checked).
    Unknown,
}
248
/// Chooses which node should serve a request.
///
/// The returned futures borrow only `&self`, not `request`; implementations
/// must clone any request fields they need inside the future.
pub trait RouterTrait: Send + Sync {
    /// Picks the target that should serve `request`.
    fn route(&self, request: &InferenceRequest) -> BoxFuture<'_, FederationResult<RouteTarget>>;

    /// Returns the candidates considered for `request`. These may include
    /// ineligible ones, flagged via `RouteCandidate::eligible` and
    /// `rejection_reason`, which is useful for explaining routing decisions.
    fn get_candidates(
        &self,
        request: &InferenceRequest,
    ) -> BoxFuture<'_, FederationResult<Vec<RouteCandidate>>>;
}
260
/// A concrete destination chosen (or considered) by the router.
#[derive(Debug, Clone)]
pub struct RouteTarget {
    /// Node that would serve the request.
    pub node_id: NodeId,
    /// Region of that node.
    pub region_id: RegionId,
    /// Address used to reach the node (format not defined here).
    pub endpoint: String,
    /// The router's latency estimate for this target.
    pub estimated_latency: Duration,
    /// Overall routing score; presumably higher is better — confirm.
    pub score: f64,
}
270
/// A target evaluated during routing, with its score breakdown and eligibility.
#[derive(Debug, Clone)]
pub struct RouteCandidate {
    /// The target being evaluated.
    pub target: RouteTarget,
    /// Per-dimension scores for this target.
    pub scores: RouteScores,
    /// Whether the target may serve the request.
    pub eligible: bool,
    /// Why the target was rejected; presumably `Some` only when `eligible`
    /// is `false` — confirm.
    pub rejection_reason: Option<String>,
}
279
/// Per-dimension breakdown of a candidate's routing score.
///
/// NOTE(review): value ranges and how `total` is derived from the components
/// (presumably a weighted combination) are not defined here — confirm in the
/// router.
#[derive(Debug, Clone)]
pub struct RouteScores {
    /// Score contribution from latency.
    pub latency_score: f64,
    /// Score contribution from throughput.
    pub throughput_score: f64,
    /// Score contribution from cost.
    pub cost_score: f64,
    /// Score contribution from locality (region proximity).
    pub locality_score: f64,
    /// Score contribution from node health.
    pub health_score: f64,
    /// Combined score.
    pub total: f64,
}
290
/// Entry point through which clients submit inference requests.
pub trait GatewayTrait: Send + Sync {
    /// Runs `request` (taken by value) and returns the complete response.
    fn infer(
        &self,
        request: InferenceRequest,
    ) -> BoxFuture<'_, FederationResult<InferenceResponse>>;

    /// Starts `request` and returns a [`TokenStream`] that yields output
    /// incrementally.
    fn infer_stream(
        &self,
        request: InferenceRequest,
    ) -> BoxFuture<'_, FederationResult<Box<dyn TokenStream>>>;

    /// Returns a snapshot of the gateway's counters (synchronous).
    fn stats(&self) -> GatewayStats;
}
308
/// Incremental output of a streaming inference (see
/// [`GatewayTrait::infer_stream`]).
///
/// Only `Send`, not `Sync`: both methods take `&mut self`, so a stream is
/// driven by one owner at a time.
pub trait TokenStream: Send {
    /// Yields the next output chunk. `Some(Err(_))` reports a failure;
    /// `None` presumably marks end of stream — confirm.
    fn next_token(&mut self) -> BoxFuture<'_, Option<FederationResult<Vec<u8>>>>;

    /// Cancels the stream. Behaviour of `next_token` afterwards is
    /// implementation-defined.
    fn cancel(&mut self) -> BoxFuture<'_, ()>;
}
317
/// Snapshot of gateway counters, returned by [`GatewayTrait::stats`].
///
/// `Default` yields all-zero counters and a zero `avg_latency`.
#[derive(Debug, Clone, Default)]
pub struct GatewayStats {
    /// Requests received.
    pub total_requests: u64,
    /// Requests that completed successfully.
    pub successful_requests: u64,
    /// Requests that failed.
    pub failed_requests: u64,
    /// Tokens produced across requests.
    pub total_tokens: u64,
    /// Average request latency (averaging method is implementation-defined).
    pub avg_latency: Duration,
    /// Streams currently open.
    pub active_streams: u32,
}
328
/// Hooks invoked around gateway request handling.
///
/// Middlewares are registered via `FederationBuilder::with_middleware`. The
/// order in which several middlewares run is not defined here.
pub trait GatewayMiddleware: Send + Sync {
    /// Called before routing and may mutate `request` (e.g. adjust its QoS).
    /// Returning `Err` presumably aborts the request — confirm in the gateway.
    fn before_route(&self, request: &mut InferenceRequest) -> FederationResult<()>;

    /// Called once a response exists, and may mutate `response`.
    fn after_infer(
        &self,
        request: &InferenceRequest,
        response: &mut InferenceResponse,
    ) -> FederationResult<()>;

    /// Notified when handling `request` fails with `error`. Observation only:
    /// it cannot change the outcome.
    fn on_error(&self, request: &InferenceRequest, error: &FederationError);
}
348
/// Per-node circuit breaker.
///
/// All methods take `&self` and the trait requires `Send + Sync`, so
/// implementations need interior mutability (locks or atomics) to record
/// outcomes.
pub trait CircuitBreakerTrait: Send + Sync {
    /// Returns `true` when the breaker for `node_id` is open, i.e. the node
    /// should not receive traffic.
    fn is_open(&self, node_id: &NodeId) -> bool;

    /// Records a successful call to `node_id`.
    fn record_success(&self, node_id: &NodeId);

    /// Records a failed call to `node_id` (presumably enough failures trip
    /// the breaker).
    fn record_failure(&self, node_id: &NodeId);

    /// Returns the current breaker state for `node_id`.
    fn state(&self, node_id: &NodeId) -> CircuitState;
}
363
/// State of a node's circuit breaker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation: calls to the node are allowed.
    Closed,
    /// Trial state: presumably limited probe calls are allowed to test
    /// recovery — confirm.
    HalfOpen,
    /// Tripped: calls to the node are rejected (see
    /// `FederationError::CircuitOpen`).
    Open,
}
374
/// A pluggable rule that filters and scores routing candidates.
///
/// Several policies can be registered on a `FederationBuilder`. How their
/// scores are combined is up to the router and is not defined here.
pub trait RoutingPolicyTrait: Send + Sync {
    /// Scores `candidate` for `request`; presumably higher is better — confirm.
    fn score(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> f64;

    /// Returns whether `candidate` may serve `request` under this policy.
    fn is_eligible(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> bool;

    /// Static, human-readable policy name (e.g. for logs or metrics).
    fn name(&self) -> &'static str;
}
390
/// Strategy used to spread requests across eligible nodes.
///
/// `LeastLatency` is the `Default`. `PartialEq`/`Eq`/`Hash` are derived so a
/// configured strategy can be compared with `==`/`assert_eq!` and used as a
/// map key, rather than only inspected with `matches!`.
///
/// Variant docs give the conventional meaning of each strategy name. The code
/// that applies them is not shown here, so confirm against the router.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum LoadBalanceStrategy {
    /// Cycle through nodes in turn.
    RoundRobin,
    /// Prefer the node with the fewest in-flight requests.
    LeastConnections,
    /// Prefer the node with the lowest latency. This is the default.
    #[default]
    LeastLatency,
    /// Random choice weighted per node (weights not defined here).
    WeightedRandom,
    /// Map requests to nodes by hashing a key (key not defined here).
    ConsistentHash,
}
406
/// Collects the pluggable components of a federation.
///
/// Every component starts unset (`None` or an empty `Vec`), and
/// `load_balance` starts as `LoadBalanceStrategy::LeastLatency` (the enum's
/// `#[default]`). Components are set via the `with_*` methods or, since the
/// fields are public, directly.
///
/// Components are stored as `Box<dyn Trait>` so different implementations
/// can be plugged in. `Debug` is not derived because those trait objects do
/// not implement it.
#[derive(Default)]
pub struct FederationBuilder {
    /// Model catalog, if configured.
    pub catalog: Option<Box<dyn ModelCatalogTrait>>,
    /// Health checker, if configured.
    pub health_checker: Option<Box<dyn HealthCheckerTrait>>,
    /// Router, if configured.
    pub router: Option<Box<dyn RouterTrait>>,
    /// Routing policies, in registration order.
    pub policies: Vec<Box<dyn RoutingPolicyTrait>>,
    /// Gateway middlewares, in registration order.
    pub middlewares: Vec<Box<dyn GatewayMiddleware>>,
    /// Load-balancing strategy.
    pub load_balance: LoadBalanceStrategy,
}
421
422impl FederationBuilder {
423 pub fn new() -> Self {
424 Self {
425 load_balance: LoadBalanceStrategy::LeastLatency,
426 ..Default::default()
427 }
428 }
429
430 #[must_use]
431 pub fn with_catalog(mut self, catalog: impl ModelCatalogTrait + 'static) -> Self {
432 self.catalog = Some(Box::new(catalog));
433 self
434 }
435
436 #[must_use]
437 pub fn with_health_checker(mut self, checker: impl HealthCheckerTrait + 'static) -> Self {
438 self.health_checker = Some(Box::new(checker));
439 self
440 }
441
442 #[must_use]
443 pub fn with_router(mut self, router: impl RouterTrait + 'static) -> Self {
444 self.router = Some(Box::new(router));
445 self
446 }
447
448 #[must_use]
449 pub fn with_policy(mut self, policy: impl RoutingPolicyTrait + 'static) -> Self {
450 self.policies.push(Box::new(policy));
451 self
452 }
453
454 #[must_use]
455 pub fn with_middleware(mut self, middleware: impl GatewayMiddleware + 'static) -> Self {
456 self.middlewares.push(Box::new(middleware));
457 self
458 }
459
460 #[must_use]
461 pub fn with_load_balance(mut self, strategy: LoadBalanceStrategy) -> Self {
462 self.load_balance = strategy;
463 self
464 }
465}
466
#[cfg(test)]
mod tests {
    //! Unit tests for the federation types: defaults, orderings, equality,
    //! `Display` output and error messages.
    use super::*;

    #[test]
    fn test_qos_default() {
        let qos = QoSRequirements::default();
        assert_eq!(qos.privacy, PrivacyLevel::Internal);
        assert!(qos.prefer_gpu);
        assert_eq!(qos.cost_tolerance, 50);
        assert!(qos.max_latency.is_none());
        assert!(qos.min_throughput.is_none());
    }

    #[test]
    fn test_privacy_ordering() {
        assert!(PrivacyLevel::Public < PrivacyLevel::Internal);
        assert!(PrivacyLevel::Internal < PrivacyLevel::Confidential);
        assert!(PrivacyLevel::Confidential < PrivacyLevel::Restricted);
    }

    #[test]
    fn test_privacy_discriminants() {
        // The explicit discriminants must stay aligned with the declaration order.
        assert_eq!(PrivacyLevel::Public as u8, 0);
        assert_eq!(PrivacyLevel::Internal as u8, 1);
        assert_eq!(PrivacyLevel::Confidential as u8, 2);
        assert_eq!(PrivacyLevel::Restricted as u8, 3);
        assert_eq!(
            PrivacyLevel::Internal.max(PrivacyLevel::Confidential),
            PrivacyLevel::Confidential
        );
    }

    #[test]
    fn test_health_state() {
        let healthy = HealthState::Healthy;
        let degraded = HealthState::Degraded;
        assert_ne!(healthy, degraded);
        assert_ne!(HealthState::Unhealthy, HealthState::Unknown);
    }

    #[test]
    fn test_circuit_state() {
        let closed = CircuitState::Closed;
        let open = CircuitState::Open;
        assert_ne!(closed, open);
        assert_ne!(CircuitState::HalfOpen, open);
    }

    #[test]
    fn test_federation_builder() {
        let builder =
            FederationBuilder::new().with_load_balance(LoadBalanceStrategy::LeastConnections);

        assert!(matches!(
            builder.load_balance,
            LoadBalanceStrategy::LeastConnections
        ));
    }

    #[test]
    fn test_federation_builder_defaults() {
        let builder = FederationBuilder::new();
        assert!(builder.catalog.is_none());
        assert!(builder.health_checker.is_none());
        assert!(builder.router.is_none());
        assert!(builder.policies.is_empty());
        assert!(builder.middlewares.is_empty());
        assert!(matches!(
            builder.load_balance,
            LoadBalanceStrategy::LeastLatency
        ));
        // `new()` and `Default` must agree on the strategy.
        assert!(matches!(
            FederationBuilder::default().load_balance,
            LoadBalanceStrategy::LeastLatency
        ));
    }

    #[test]
    fn test_gateway_stats_default_is_zeroed() {
        let stats = GatewayStats::default();
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.successful_requests, 0);
        assert_eq!(stats.failed_requests, 0);
        assert_eq!(stats.total_tokens, 0);
        assert_eq!(stats.avg_latency, std::time::Duration::from_secs(0));
        assert_eq!(stats.active_streams, 0);
    }

    #[test]
    fn test_model_id_equality() {
        let id1 = ModelId("whisper-v3".to_string());
        let id2 = ModelId("whisper-v3".to_string());
        let id3 = ModelId("llama-7b".to_string());

        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_id_display() {
        assert_eq!(ModelId("whisper-v3".to_string()).to_string(), "whisper-v3");
        assert_eq!(RegionId("eu-west".to_string()).to_string(), "eu-west");
        assert_eq!(NodeId("node-1".to_string()).to_string(), "node-1");
    }

    #[test]
    fn test_capability_variants() {
        let cap1 = Capability::Transcribe;
        let cap2 = Capability::Custom("sentiment".to_string());

        assert_ne!(cap1, cap2);
        assert_eq!(cap1, Capability::Transcribe);
    }

    #[test]
    fn test_error_messages() {
        assert_eq!(
            FederationError::NoCapacity(Capability::Embed).to_string(),
            "No nodes available for capability: Embed"
        );
        assert_eq!(
            FederationError::AllNodesUnhealthy(Capability::Custom("ocr".to_string())).to_string(),
            "All nodes unhealthy for capability: Custom(\"ocr\")"
        );
        assert_eq!(
            FederationError::CircuitOpen(NodeId("node-1".to_string())).to_string(),
            "Circuit breaker open for node: node-1"
        );
        assert_eq!(
            FederationError::Timeout(std::time::Duration::from_millis(250)).to_string(),
            "Timeout after 250ms"
        );
    }
}