use std::future::Future;
use std::pin::Pin;
use std::time::Duration;
/// Newtype wrapper around the string that identifies a model.
///
/// The inner `String` is public, so the wrapper adds type safety only; it
/// performs no validation of the identifier.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct ModelId(pub String);

impl std::fmt::Display for ModelId {
    /// Writes the raw identifier.
    ///
    /// Formatting is delegated to the inner `String`, so width, fill,
    /// alignment and precision requested by the caller (e.g. `{:>12}`) are
    /// honoured instead of being silently dropped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
/// Newtype wrapper around the string that identifies a region.
///
/// The inner `String` is public, so the wrapper adds type safety only; it
/// performs no validation of the identifier.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct RegionId(pub String);

impl std::fmt::Display for RegionId {
    /// Writes the raw identifier.
    ///
    /// Formatting is delegated to the inner `String`, so width, fill,
    /// alignment and precision requested by the caller (e.g. `{:>12}`) are
    /// honoured instead of being silently dropped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
/// Newtype wrapper around the string that identifies a node.
///
/// The inner `String` is public, so the wrapper adds type safety only; it
/// performs no validation of the identifier.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct NodeId(pub String);

impl std::fmt::Display for NodeId {
    /// Writes the raw identifier.
    ///
    /// Formatting is delegated to the inner `String`, so width, fill,
    /// alignment and precision requested by the caller (e.g. `{:>12}`) are
    /// honoured instead of being silently dropped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
/// Kind of work an [`InferenceRequest`] asks for.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so that a capability can
/// be used directly as a `HashMap`/`HashSet` key (every variant, including the
/// `String` payload of [`Capability::Custom`], supports both).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Capability {
    Transcribe,
    Synthesize,
    Generate,
    Code,
    Embed,
    ImageGen,
    /// A capability outside the built-in set, identified by an arbitrary name.
    /// Two `Custom` values are equal only when their names are equal.
    Custom(String),
}
/// Privacy classification carried in [`QoSRequirements::privacy`].
///
/// Variants have explicit discriminants 0 through 3, and the derived `Ord`
/// follows declaration order, so
/// `Public < Internal < Confidential < Restricted`.
// NOTE(review): how routing uses this ordering is not visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum PrivacyLevel {
Public = 0,
Internal = 1,
Confidential = 2,
Restricted = 3,
}
/// Quality-of-service constraints attached to an inference request.
#[derive(Debug, Clone)]
pub struct QoSRequirements {
    /// Upper bound on latency, or `None` for no bound.
    pub max_latency: Option<Duration>,
    /// Lower bound on throughput, or `None` for no bound.
    // NOTE(review): the unit of throughput is not specified in this file.
    pub min_throughput: Option<u32>,
    /// Privacy classification of the request.
    pub privacy: PrivacyLevel,
    /// Whether GPU-backed execution is preferred.
    pub prefer_gpu: bool,
    /// Cost tolerance knob; defaults to 50.
    // NOTE(review): scale and direction (presumably 0-100) are not defined here; confirm.
    pub cost_tolerance: u8,
}

impl Default for QoSRequirements {
    /// Returns requirements with no latency or throughput bound,
    /// `PrivacyLevel::Internal`, GPU preferred, and a cost tolerance of 50.
    fn default() -> Self {
        // No hard bounds unless the caller asks for them.
        let (max_latency, min_throughput) = (None, None);
        Self {
            privacy: PrivacyLevel::Internal,
            prefer_gpu: true,
            cost_tolerance: 50,
            max_latency,
            min_throughput,
        }
    }
}
/// A single inference request: the capability wanted, the raw input bytes,
/// QoS constraints, and identifiers for the request and (optionally) a tenant.
#[derive(Debug, Clone)]
pub struct InferenceRequest {
/// Capability being requested.
pub capability: Capability,
/// Opaque input payload.
// NOTE(review): the encoding of `input` is not specified in this file.
pub input: Vec<u8>,
/// Quality-of-service constraints for this request.
pub qos: QoSRequirements,
/// Caller-supplied request identifier.
// NOTE(review): uniqueness is not enforced by this type.
pub request_id: String,
/// Tenant identifier, if any.
pub tenant_id: Option<String>,
}
/// Result of a completed inference call.
///
/// Only `Debug` is derived, so a response is moved rather than cloned.
#[derive(Debug)]
pub struct InferenceResponse {
/// Opaque output payload.
// NOTE(review): the encoding of `output` is not specified in this file.
pub output: Vec<u8>,
/// Node that produced the response.
pub served_by: NodeId,
/// Latency recorded for the request.
// NOTE(review): the measurement boundaries are not defined here.
pub latency: Duration,
/// Token count, when available.
// NOTE(review): unclear whether this counts input, output, or both; confirm.
pub tokens: Option<u32>,
}
/// Errors returned by the federation traits in this file.
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute (derive provided by `thiserror`).
#[derive(Debug, thiserror::Error)]
pub enum FederationError {
/// No node is available for the given capability.
#[error("No nodes available for capability: {0:?}")]
NoCapacity(Capability),
/// Nodes exist for the given capability, but all are unhealthy.
#[error("All nodes unhealthy for capability: {0:?}")]
AllNodesUnhealthy(Capability),
/// The request's QoS requirements cannot be met; the payload is a description.
#[error("QoS requirements cannot be met: {0}")]
QoSViolation(String),
/// A privacy policy was violated; the payload is a description.
#[error("Privacy policy violation: {0}")]
PrivacyViolation(String),
/// The given node could not be reached.
#[error("Node unreachable: {0}")]
NodeUnreachable(NodeId),
/// An operation timed out after the given duration.
#[error("Timeout after {0:?}")]
Timeout(Duration),
/// The circuit breaker for the given node is open.
#[error("Circuit breaker open for node: {0}")]
CircuitOpen(NodeId),
/// Any other failure; the payload is a free-form description.
#[error("Internal error: {0}")]
Internal(String),
}
/// `Result` alias whose error type is fixed to [`FederationError`].
pub type FederationResult<T> = Result<T, FederationError>;
/// Pinned, heap-allocated, `Send` future valid for lifetime `'a`.
///
/// Used as the return type of the asynchronous trait methods in this file,
/// which keeps those traits usable as `dyn Trait` objects.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Asynchronous catalog interface relating models, nodes, regions and
/// capabilities.
///
/// Every method returns a `BoxFuture<'_, _>`. By lifetime elision `'_` is the
/// lifetime of `&self`, so the returned future may borrow `self` but cannot
/// borrow reference arguments such as `capability` or `model_id`;
/// implementations must copy whatever they need from them before building the
/// future.
// NOTE(review): the method descriptions below are inferred from names and
// signatures; no implementation is visible in this file.
pub trait ModelCatalogTrait: Send + Sync {
/// Records an entry for `model_id` on `node_id` in `region_id` with the
/// given capabilities. Resolves to `Ok(())` or a [`FederationError`].
fn register(
&self,
model_id: ModelId,
node_id: NodeId,
region_id: RegionId,
capabilities: Vec<Capability>,
) -> BoxFuture<'_, FederationResult<()>>;
/// Removes the entry for `model_id` on `node_id`.
// NOTE(review): behavior for an unknown pair is implementation-defined.
fn deregister(&self, model_id: ModelId, node_id: NodeId)
-> BoxFuture<'_, FederationResult<()>>;
/// Resolves to the `(node, region)` pairs associated with `capability`.
fn find_by_capability(
&self,
capability: &Capability,
) -> BoxFuture<'_, FederationResult<Vec<(NodeId, RegionId)>>>;
/// Resolves to the ids of all models known to the catalog.
fn list_all(&self) -> BoxFuture<'_, FederationResult<Vec<ModelId>>>;
/// Resolves to the [`ModelMetadata`] for `model_id`.
// NOTE(review): which error is returned for an unknown id is not defined here.
fn get_metadata(&self, model_id: &ModelId) -> BoxFuture<'_, FederationResult<ModelMetadata>>;
}
/// Descriptive metadata for a model, as returned by
/// [`ModelCatalogTrait::get_metadata`].
#[derive(Debug, Clone)]
pub struct ModelMetadata {
/// Identifier of the model this metadata describes.
pub model_id: ModelId,
/// Model name.
pub name: String,
/// Version string.
// NOTE(review): no version format is enforced by this type.
pub version: String,
/// Capabilities the model provides.
pub capabilities: Vec<Capability>,
/// Parameter count.
// NOTE(review): presumably the raw number of parameters, not millions/billions; confirm.
pub parameters: u64,
/// Quantization label, if any.
pub quantization: Option<String>,
}
/// Interface for probing node health and controlling periodic monitoring.
// NOTE(review): the method descriptions below are inferred from names and
// signatures; no implementation is visible in this file.
pub trait HealthCheckerTrait: Send + Sync {
/// Asynchronously checks `node_id` and resolves to its [`NodeHealth`].
/// The returned future is tied to `&self` and cannot borrow `node_id`.
fn check_node(&self, node_id: &NodeId) -> BoxFuture<'_, FederationResult<NodeHealth>>;
/// Synchronously returns a previously recorded health value for `node_id`,
/// or `None` when there is none.
fn get_cached_health(&self, node_id: &NodeId) -> Option<NodeHealth>;
/// Starts monitoring with the given `interval`.
// NOTE(review): unclear whether the future completes once monitoring has
// started or only when it ends; confirm with implementations.
fn start_monitoring(&self, interval: Duration) -> BoxFuture<'_, ()>;
/// Stops monitoring.
fn stop_monitoring(&self) -> BoxFuture<'_, ()>;
}
/// Health snapshot for a single node.
#[derive(Debug, Clone)]
pub struct NodeHealth {
/// Node this snapshot describes.
pub node_id: NodeId,
/// Overall health classification.
pub status: HealthState,
/// Presumably the 50th-percentile latency; confirm with the producer.
pub latency_p50: Duration,
/// Presumably the 99th-percentile latency; confirm with the producer.
pub latency_p99: Duration,
/// Observed throughput.
// NOTE(review): the unit is not specified in this file.
pub throughput: u32,
/// GPU utilization, when available.
// NOTE(review): the range (0.0-1.0 vs 0-100) is not specified here.
pub gpu_utilization: Option<f32>,
/// Queue depth reported for the node.
pub queue_depth: u32,
/// Monotonic timestamp of the last check. `Instant` values are only
/// meaningful within the current process.
pub last_check: std::time::Instant,
}
/// Coarse health classification used in [`NodeHealth::status`].
// NOTE(review): the thresholds separating these states are defined by
// implementations, not in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HealthState {
Healthy,
Degraded,
Unhealthy,
Unknown,
}
/// Interface for choosing where an [`InferenceRequest`] should be sent.
///
/// The returned futures are tied to `&self` (lifetime elision), so they
/// cannot borrow `request`.
// NOTE(review): the method descriptions below are inferred from names and
// signatures; no implementation is visible in this file.
pub trait RouterTrait: Send + Sync {
/// Resolves to the single [`RouteTarget`] chosen for `request`, or a
/// [`FederationError`].
fn route(&self, request: &InferenceRequest) -> BoxFuture<'_, FederationResult<RouteTarget>>;
/// Resolves to the candidates considered for `request`. Each
/// [`RouteCandidate`] carries its own `eligible` flag and optional
/// rejection reason.
fn get_candidates(
&self,
request: &InferenceRequest,
) -> BoxFuture<'_, FederationResult<Vec<RouteCandidate>>>;
}
/// A concrete routing destination.
#[derive(Debug, Clone)]
pub struct RouteTarget {
/// Destination node.
pub node_id: NodeId,
/// Region of the destination node.
pub region_id: RegionId,
/// Endpoint string for the node.
// NOTE(review): the format (URL, host:port, ...) is not specified here.
pub endpoint: String,
/// Estimated latency for this target.
pub estimated_latency: Duration,
/// Score assigned to this target.
// NOTE(review): scale and direction (presumably higher is better) are not defined here.
pub score: f64,
}
/// A [`RouteTarget`] together with its score breakdown and eligibility.
#[derive(Debug, Clone)]
pub struct RouteCandidate {
/// The destination being evaluated.
pub target: RouteTarget,
/// Per-dimension scores for the target.
pub scores: RouteScores,
/// Whether the candidate may be selected.
pub eligible: bool,
/// Explanation for a rejection, if any.
// NOTE(review): presumably set only when `eligible` is false; the type does
// not enforce this.
pub rejection_reason: Option<String>,
}
/// Score breakdown for a routing candidate, one value per dimension plus a
/// combined `total`.
// NOTE(review): the scale (presumably 0.0-1.0) and how `total` is derived from
// the components are not defined in this file.
#[derive(Debug, Clone)]
pub struct RouteScores {
    pub latency_score: f64,
    pub throughput_score: f64,
    pub cost_score: f64,
    pub locality_score: f64,
    pub health_score: f64,
    pub total: f64,
}

impl Default for RouteScores {
    /// Returns 0.5 for every score except `health_score`, which starts at 1.0.
    fn default() -> Self {
        // Shared starting value for every dimension other than health.
        const MIDPOINT: f64 = 0.5;
        Self {
            health_score: 1.0,
            latency_score: MIDPOINT,
            throughput_score: MIDPOINT,
            cost_score: MIDPOINT,
            locality_score: MIDPOINT,
            total: MIDPOINT,
        }
    }
}
/// Entry point for executing inference requests.
// NOTE(review): the method descriptions below are inferred from names and
// signatures; no implementation is visible in this file.
pub trait GatewayTrait: Send + Sync {
/// Executes `request` and resolves to the complete [`InferenceResponse`].
/// The request is taken by value.
fn infer(
&self,
request: InferenceRequest,
) -> BoxFuture<'_, FederationResult<InferenceResponse>>;
/// Executes `request` and resolves to a boxed [`TokenStream`] from which
/// output chunks are pulled.
fn infer_stream(
&self,
request: InferenceRequest,
) -> BoxFuture<'_, FederationResult<Box<dyn TokenStream>>>;
/// Synchronously returns gateway counters as an owned [`GatewayStats`].
fn stats(&self) -> GatewayStats;
}
/// Pull-based stream of output chunks for a streaming inference call.
///
/// Both methods take `&mut self`, and the returned future borrows the stream,
/// so only one operation can be in flight at a time.
pub trait TokenStream: Send {
/// Resolves to the next chunk as bytes, or an error.
// NOTE(review): `None` presumably signals end of stream; confirm with
// implementations.
fn next_token(&mut self) -> BoxFuture<'_, Option<FederationResult<Vec<u8>>>>;
/// Requests cancellation of the stream.
// NOTE(review): behavior of `next_token` after `cancel` is not defined here.
fn cancel(&mut self) -> BoxFuture<'_, ()>;
}
/// Counters reported by [`GatewayTrait::stats`].
///
/// The derived `Default` sets every counter to zero and `avg_latency` to a
/// zero `Duration`.
#[derive(Debug, Clone, Default)]
pub struct GatewayStats {
/// Total number of requests.
pub total_requests: u64,
/// Number of requests that succeeded.
pub successful_requests: u64,
/// Number of requests that failed.
pub failed_requests: u64,
/// Total number of tokens.
pub total_tokens: u64,
/// Average latency.
// NOTE(review): the averaging window and method are not defined in this file.
pub avg_latency: Duration,
/// Number of currently active streams.
pub active_streams: u32,
}
/// Synchronous hooks around request handling.
// NOTE(review): when the gateway calls each hook, and what happens when a
// hook returns `Err`, are inferred from names; the caller is not visible here.
pub trait GatewayMiddleware: Send + Sync {
/// Hook with mutable access to the request, so it may rewrite it.
/// Presumably invoked before routing.
fn before_route(&self, request: &mut InferenceRequest) -> FederationResult<()>;
/// Hook with read access to the request and mutable access to the
/// response. Presumably invoked after a successful inference.
fn after_infer(
&self,
request: &InferenceRequest,
response: &mut InferenceResponse,
) -> FederationResult<()>;
/// Hook that observes a failure. It returns nothing, so it cannot alter
/// the error.
fn on_error(&self, request: &InferenceRequest, error: &FederationError);
}
/// Per-node circuit breaker interface.
///
/// All methods take `&self`, so implementations that track state need
/// interior mutability.
// NOTE(review): thresholds and state-transition rules are defined by
// implementations, not in this file.
pub trait CircuitBreakerTrait: Send + Sync {
/// Returns whether the breaker for `node_id` is open.
fn is_open(&self, node_id: &NodeId) -> bool;
/// Records a successful call to `node_id`.
fn record_success(&self, node_id: &NodeId);
/// Records a failed call to `node_id`.
fn record_failure(&self, node_id: &NodeId);
/// Returns the current [`CircuitState`] for `node_id`.
fn state(&self, node_id: &NodeId) -> CircuitState;
}
/// State reported by [`CircuitBreakerTrait::state`].
// NOTE(review): presumably the conventional circuit-breaker states (Closed =
// calls allowed, Open = calls rejected, HalfOpen = probing); the transitions
// are not defined in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
Closed,
HalfOpen,
Open,
}
/// Pluggable policy that scores and filters routing candidates.
// NOTE(review): how scores from multiple policies are combined is not visible
// in this file.
pub trait RoutingPolicyTrait: Send + Sync {
/// Returns this policy's score for `candidate` with respect to `request`.
// NOTE(review): scale and direction of the score are not defined here.
fn score(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> f64;
/// Returns whether `candidate` may serve `request` under this policy.
fn is_eligible(&self, candidate: &RouteCandidate, request: &InferenceRequest) -> bool;
/// Static name identifying the policy.
fn name(&self) -> &'static str;
}
/// Load-balancing strategy selectable on the federation builder.
///
/// `LeastLatency` is the `Default`. `PartialEq`/`Eq` are derived, matching
/// the other field-less enums in this file, so strategies can be compared and
/// asserted on directly.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum LoadBalanceStrategy {
    RoundRobin,
    LeastConnections,
    #[default]
    LeastLatency,
    WeightedRandom,
    ConsistentHash,
}
/// Collects the pluggable components of a federation setup.
///
/// The derived `Default` leaves every optional component as `None`, both
/// lists empty, and `load_balance` at `LoadBalanceStrategy::default()`.
// NOTE(review): no `build` step is visible in this file; the fields are public
// and are presumably consumed elsewhere.
#[derive(Default)]
pub struct FederationBuilder {
/// Model catalog, if configured.
pub catalog: Option<Box<dyn ModelCatalogTrait>>,
/// Health checker, if configured.
pub health_checker: Option<Box<dyn HealthCheckerTrait>>,
/// Router, if configured.
pub router: Option<Box<dyn RouterTrait>>,
/// Routing policies, in the order they were added.
pub policies: Vec<Box<dyn RoutingPolicyTrait>>,
/// Middlewares, in the order they were added.
pub middlewares: Vec<Box<dyn GatewayMiddleware>>,
/// Selected load-balancing strategy.
pub load_balance: LoadBalanceStrategy,
}
impl FederationBuilder {
    /// Creates an empty builder: no catalog, health checker or router, no
    /// policies or middlewares, and the default load-balancing strategy
    /// (`LoadBalanceStrategy::LeastLatency`).
    ///
    /// Equivalent to `FederationBuilder::default()`; the strategy default is
    /// taken from the enum's `#[default]` variant rather than repeated here.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the model catalog, replacing any previously configured one.
    #[must_use]
    pub fn with_catalog(mut self, catalog: impl ModelCatalogTrait + 'static) -> Self {
        self.catalog = Some(Box::new(catalog));
        self
    }

    /// Sets the health checker, replacing any previously configured one.
    #[must_use]
    pub fn with_health_checker(mut self, checker: impl HealthCheckerTrait + 'static) -> Self {
        self.health_checker = Some(Box::new(checker));
        self
    }

    /// Sets the router, replacing any previously configured one.
    #[must_use]
    pub fn with_router(mut self, router: impl RouterTrait + 'static) -> Self {
        self.router = Some(Box::new(router));
        self
    }

    /// Appends a routing policy; policies are kept in insertion order.
    #[must_use]
    pub fn with_policy(mut self, policy: impl RoutingPolicyTrait + 'static) -> Self {
        self.policies.push(Box::new(policy));
        self
    }

    /// Appends a middleware; middlewares are kept in insertion order.
    #[must_use]
    pub fn with_middleware(mut self, middleware: impl GatewayMiddleware + 'static) -> Self {
        self.middlewares.push(Box::new(middleware));
        self
    }

    /// Selects the load-balancing strategy, overriding the current one.
    #[must_use]
    pub fn with_load_balance(mut self, strategy: LoadBalanceStrategy) -> Self {
        self.load_balance = strategy;
        self
    }
}
// Unit tests for this module live in `traits_tests.rs` (located via the
// `#[path]` attribute) and are compiled only in test builds.
#[cfg(test)]
#[path = "traits_tests.rs"]
mod tests;