1use std::future::Future;
7use std::pin::Pin;
8use std::time::Duration;
9
/// Identifier of a model, stored as its raw string form.
///
/// The wrapped `String` is public, so an id can be built directly with
/// `ModelId(name)`. Equality and hashing are those of the wrapped string.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct ModelId(pub String);

impl std::fmt::Display for ModelId {
    /// Writes the wrapped string.
    ///
    /// Delegates to the inner `String`'s `Display` impl instead of going
    /// through `write!(f, "{}", ..)`: the macro starts a fresh format spec, which
    /// silently discards any width, fill, alignment or precision the caller
    /// asked for (e.g. `{:>12}`). Delegating applies those flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
23
/// Identifier of a region, stored as its raw string form.
///
/// The wrapped `String` is public, so an id can be built directly with
/// `RegionId(name)`. Equality and hashing are those of the wrapped string.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct RegionId(pub String);

impl std::fmt::Display for RegionId {
    /// Writes the wrapped string.
    ///
    /// Delegates to the inner `String`'s `Display` impl instead of going
    /// through `write!(f, "{}", ..)`: the macro starts a fresh format spec, which
    /// silently discards any width, fill, alignment or precision the caller
    /// asked for (e.g. `{:<16}`). Delegating applies those flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
33
/// Identifier of a node, stored as its raw string form.
///
/// The wrapped `String` is public, so an id can be built directly with
/// `NodeId(name)`. Equality and hashing are those of the wrapped string.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct NodeId(pub String);

impl std::fmt::Display for NodeId {
    /// Writes the wrapped string.
    ///
    /// Delegates to the inner `String`'s `Display` impl instead of going
    /// through `write!(f, "{}", ..)`: the macro starts a fresh format spec, which
    /// silently discards any width, fill, alignment or precision the caller
    /// asked for (e.g. `{:>10}`). Delegating applies those flags, which also
    /// makes padded output work in messages that embed a `NodeId`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
43
/// Kind of work a model can perform and a request can ask for.
///
/// Used as the `capability` of an `InferenceRequest`, in the capability lists of
/// catalog registrations and model metadata, and as the payload of the
/// capacity-related `FederationError` variants.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so a capability can key a
/// `HashMap` / `HashSet` (for example an index from capability to nodes); every
/// variant is either a unit or carries a `String`, so both derives are total.
///
/// NOTE(review): the exact meaning of each variant is not defined in this file;
/// the one-line notes below are read off the variant names and should be
/// confirmed against the providers that register them.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Capability {
    /// Transcription.
    Transcribe,
    /// Synthesis.
    Synthesize,
    /// Generation.
    Generate,
    /// Code-related work.
    Code,
    /// Embedding.
    Embed,
    /// Image generation.
    ImageGen,
    /// Any capability not covered above; the string is presumably its name.
    /// Two `Custom` values are equal only when their strings are equal.
    Custom(String),
}
62
/// Privacy classification attached to a request's QoS requirements.
///
/// Levels are totally ordered: the derived `Ord` follows the explicit
/// discriminants, so `Public < Internal < Confidential < Restricted`, and a
/// level can be cast to its numeric value with `as`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum PrivacyLevel {
    /// Lowest level (numeric value 0).
    Public = 0,
    /// Numeric value 1.
    Internal = 1,
    /// Numeric value 2.
    Confidential = 2,
    /// Highest level (numeric value 3).
    Restricted = 3,
}
75
76#[derive(Debug, Clone)]
78pub struct QoSRequirements {
79 pub max_latency: Option<Duration>,
81 pub min_throughput: Option<u32>,
83 pub privacy: PrivacyLevel,
85 pub prefer_gpu: bool,
87 pub cost_tolerance: u8,
89}
90
91impl Default for QoSRequirements {
92 fn default() -> Self {
93 Self {
94 max_latency: None,
95 min_throughput: None,
96 privacy: PrivacyLevel::Internal,
97 prefer_gpu: true,
98 cost_tolerance: 50,
99 }
100 }
101}
102
103#[derive(Debug, Clone)]
105pub struct InferenceRequest {
106 pub capability: Capability,
108 pub input: Vec<u8>,
110 pub qos: QoSRequirements,
112 pub request_id: String,
114 pub tenant_id: Option<String>,
116}
117
118#[derive(Debug)]
120pub struct InferenceResponse {
121 pub output: Vec<u8>,
123 pub served_by: NodeId,
125 pub latency: Duration,
127 pub tokens: Option<u32>,
129}
130
131#[derive(Debug, thiserror::Error)]
133pub enum FederationError {
134 #[error("No nodes available for capability: {0:?}")]
135 NoCapacity(Capability),
136
137 #[error("All nodes unhealthy for capability: {0:?}")]
138 AllNodesUnhealthy(Capability),
139
140 #[error("QoS requirements cannot be met: {0}")]
141 QoSViolation(String),
142
143 #[error("Privacy policy violation: {0}")]
144 PrivacyViolation(String),
145
146 #[error("Node unreachable: {0}")]
147 NodeUnreachable(NodeId),
148
149 #[error("Timeout after {0:?}")]
150 Timeout(Duration),
151
152 #[error("Circuit breaker open for node: {0}")]
153 CircuitOpen(NodeId),
154
155 #[error("Internal error: {0}")]
156 Internal(String),
157}
158
159pub type FederationResult<T> = Result<T, FederationError>;
165
/// A pinned, heap-allocated, `Send` future yielding `T`.
///
/// The future may borrow data that lives for `'a`. The traits in this module
/// return this alias so they remain usable as trait objects.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
168
169pub trait ModelCatalogTrait: Send + Sync {
171 fn register(
173 &self,
174 model_id: ModelId,
175 node_id: NodeId,
176 region_id: RegionId,
177 capabilities: Vec<Capability>,
178 ) -> BoxFuture<'_, FederationResult<()>>;
179
180 fn deregister(&self, model_id: ModelId, node_id: NodeId)
182 -> BoxFuture<'_, FederationResult<()>>;
183
184 fn find_by_capability(
186 &self,
187 capability: &Capability,
188 ) -> BoxFuture<'_, FederationResult<Vec<(NodeId, RegionId)>>>;
189
190 fn list_all(&self) -> BoxFuture<'_, FederationResult<Vec<ModelId>>>;
192
193 fn get_metadata(&self, model_id: &ModelId) -> BoxFuture<'_, FederationResult<ModelMetadata>>;
195}
196
197#[derive(Debug, Clone)]
199pub struct ModelMetadata {
200 pub model_id: ModelId,
201 pub name: String,
202 pub version: String,
203 pub capabilities: Vec<Capability>,
204 pub parameters: u64,
205 pub quantization: Option<String>,
206}
207
208pub trait HealthCheckerTrait: Send + Sync {
210 fn check_node(&self, node_id: &NodeId) -> BoxFuture<'_, FederationResult<NodeHealth>>;
212
213 fn get_cached_health(&self, node_id: &NodeId) -> Option<NodeHealth>;
215
216 fn start_monitoring(&self, interval: Duration) -> BoxFuture<'_, ()>;
218
219 fn stop_monitoring(&self) -> BoxFuture<'_, ()>;
221}
222
223#[derive(Debug, Clone)]
225pub struct NodeHealth {
226 pub node_id: NodeId,
227 pub status: HealthState,
228 pub latency_p50: Duration,
229 pub latency_p99: Duration,
230 pub throughput: u32,
231 pub gpu_utilization: Option<f32>,
232 pub queue_depth: u32,
233 pub last_check: std::time::Instant,
234}
235
/// Coarse health classification of a node, as stored in `NodeHealth::status`.
///
/// Comparable with `==` and cheap to copy; no ordering is defined.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HealthState {
    /// The node is healthy.
    Healthy,
    /// The node is degraded.
    Degraded,
    /// The node is unhealthy.
    Unhealthy,
    /// The node's health is not known.
    Unknown,
}
248
249pub trait RouterTrait: Send + Sync {
251 fn route(&self, request: &InferenceRequest) -> BoxFuture<'_, FederationResult<RouteTarget>>;
253
254 fn get_candidates(
256 &self,
257 request: &InferenceRequest,
258 ) -> BoxFuture<'_, FederationResult<Vec<RouteCandidate>>>;
259}
260
261#[derive(Debug, Clone)]
263pub struct RouteTarget {
264 pub node_id: NodeId,
265 pub region_id: RegionId,
266 pub endpoint: String,
267 pub estimated_latency: Duration,
268 pub score: f64,
269}
270
271#[derive(Debug, Clone)]
273pub struct RouteCandidate {
274 pub target: RouteTarget,
275 pub scores: RouteScores,
276 pub eligible: bool,
277 pub rejection_reason: Option<String>,
278}
279
/// Per-dimension scores of a routing candidate.
///
/// NOTE(review): the score range and how `total` is derived from the parts are
/// not defined in this file — confirm with the routing policy implementations.
#[derive(Clone, Debug)]
pub struct RouteScores {
    /// Score for latency.
    pub latency_score: f64,
    /// Score for throughput.
    pub throughput_score: f64,
    /// Score for cost.
    pub cost_score: f64,
    /// Score for locality.
    pub locality_score: f64,
    /// Score for node health.
    pub health_score: f64,
    /// Combined score.
    pub total: f64,
}

impl Default for RouteScores {
    /// Returns scores of `0.5` for every dimension and for `total`, except
    /// `health_score`, which starts at `1.0`.
    fn default() -> Self {
        const HALF: f64 = 0.5;
        RouteScores {
            health_score: 1.0,
            latency_score: HALF,
            throughput_score: HALF,
            cost_score: HALF,
            locality_score: HALF,
            total: HALF,
        }
    }
}
303
304pub trait GatewayTrait: Send + Sync {
306 fn infer(
308 &self,
309 request: InferenceRequest,
310 ) -> BoxFuture<'_, FederationResult<InferenceResponse>>;
311
312 fn infer_stream(
314 &self,
315 request: InferenceRequest,
316 ) -> BoxFuture<'_, FederationResult<Box<dyn TokenStream>>>;
317
318 fn stats(&self) -> GatewayStats;
320}
321
322pub trait TokenStream: Send {
324 fn next_token(&mut self) -> BoxFuture<'_, Option<FederationResult<Vec<u8>>>>;
326
327 fn cancel(&mut self) -> BoxFuture<'_, ()>;
329}
330
/// Counters reported by `GatewayTrait::stats`.
///
/// The derived `Default` yields all-zero counters and a zero `avg_latency`.
#[derive(Clone, Debug, Default)]
pub struct GatewayStats {
    /// Number of requests seen.
    pub total_requests: u64,
    /// Number of requests that succeeded.
    pub successful_requests: u64,
    /// Number of requests that failed.
    pub failed_requests: u64,
    /// Number of tokens counted.
    pub total_tokens: u64,
    /// Average latency.
    pub avg_latency: Duration,
    /// Number of currently active streams.
    pub active_streams: u32,
}
341
342pub trait GatewayMiddleware: Send + Sync {
348 fn before_route(&self, request: &mut InferenceRequest) -> FederationResult<()>;
350
351 fn after_infer(
353 &self,
354 request: &InferenceRequest,
355 response: &mut InferenceResponse,
356 ) -> FederationResult<()>;
357
358 fn on_error(&self, request: &InferenceRequest, error: &FederationError);
360}
361
362pub trait CircuitBreakerTrait: Send + Sync {
364 fn is_open(&self, node_id: &NodeId) -> bool;
366
367 fn record_success(&self, node_id: &NodeId);
369
370 fn record_failure(&self, node_id: &NodeId);
372
373 fn state(&self, node_id: &NodeId) -> CircuitState;
375}
376
/// State of a node's circuit breaker, as returned by `CircuitBreakerTrait::state`.
///
/// Comparable with `==` and cheap to copy; no ordering is defined.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CircuitState {
    /// The circuit is closed.
    Closed,
    /// The circuit is half-open.
    HalfOpen,
    /// The circuit is open.
    Open,
}
387
388pub trait RoutingPolicyTrait: Send + Sync {
394 fn score(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> f64;
396
397 fn is_eligible(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> bool;
399
400 fn name(&self) -> &'static str;
402}
403
/// Strategy selector stored in `FederationBuilder::load_balance`.
///
/// The default strategy is [`LoadBalanceStrategy::LeastLatency`].
///
/// `PartialEq` and `Eq` are derived, matching the other fieldless state enums in
/// this module (`HealthState`, `CircuitState`), so strategies can be compared
/// and asserted on directly.
///
/// NOTE(review): how each strategy is applied is not defined in this file; the
/// variant notes below only restate the names.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum LoadBalanceStrategy {
    /// Round-robin selection.
    RoundRobin,
    /// Least-connections selection.
    LeastConnections,
    /// Least-latency selection (the default).
    #[default]
    LeastLatency,
    /// Weighted-random selection.
    WeightedRandom,
    /// Consistent-hash selection.
    ConsistentHash,
}
419
420#[derive(Default)]
426pub struct FederationBuilder {
427 pub catalog: Option<Box<dyn ModelCatalogTrait>>,
428 pub health_checker: Option<Box<dyn HealthCheckerTrait>>,
429 pub router: Option<Box<dyn RouterTrait>>,
430 pub policies: Vec<Box<dyn RoutingPolicyTrait>>,
431 pub middlewares: Vec<Box<dyn GatewayMiddleware>>,
432 pub load_balance: LoadBalanceStrategy,
433}
434
435impl FederationBuilder {
436 pub fn new() -> Self {
437 Self {
438 load_balance: LoadBalanceStrategy::LeastLatency,
439 ..Default::default()
440 }
441 }
442
443 #[must_use]
444 pub fn with_catalog(mut self, catalog: impl ModelCatalogTrait + 'static) -> Self {
445 self.catalog = Some(Box::new(catalog));
446 self
447 }
448
449 #[must_use]
450 pub fn with_health_checker(mut self, checker: impl HealthCheckerTrait + 'static) -> Self {
451 self.health_checker = Some(Box::new(checker));
452 self
453 }
454
455 #[must_use]
456 pub fn with_router(mut self, router: impl RouterTrait + 'static) -> Self {
457 self.router = Some(Box::new(router));
458 self
459 }
460
461 #[must_use]
462 pub fn with_policy(mut self, policy: impl RoutingPolicyTrait + 'static) -> Self {
463 self.policies.push(Box::new(policy));
464 self
465 }
466
467 #[must_use]
468 pub fn with_middleware(mut self, middleware: impl GatewayMiddleware + 'static) -> Self {
469 self.middlewares.push(Box::new(middleware));
470 self
471 }
472
473 #[must_use]
474 pub fn with_load_balance(mut self, strategy: LoadBalanceStrategy) -> Self {
475 self.load_balance = strategy;
476 self
477 }
478}
479
// Unit tests live in `traits_tests.rs` (loaded through the explicit `#[path]`)
// and are compiled only for test builds.
#[cfg(test)]
#[path = "traits_tests.rs"]
mod tests;