use super::catalog::ModelCatalog;
use super::health::{CircuitBreaker, HealthChecker};
use super::policy::CompositePolicy;
use super::traits::*;
use std::sync::Arc;
use std::time::Duration;
/// A routing outcome: one primary target, a list of other targets, and a
/// textual explanation.
///
/// NOTE(review): nothing in this file constructs this type, so the field
/// descriptions below are inferred from the field names — confirm against
/// the code that produces it.
#[derive(Clone, Debug)]
pub struct RouteDecision {
    /// Primary target (presumably the one chosen for the request).
    pub target: RouteTarget,
    /// Additional targets (presumably fallbacks should `target` fail).
    pub alternatives: Vec<RouteTarget>,
    /// Free-form text (presumably describing why `target` was chosen).
    pub reasoning: String,
}
/// Tunable parameters for [`Router`] candidate selection.
#[derive(Clone, Debug)]
pub struct RouterConfig {
    /// Upper bound on how many eligible candidates are considered when
    /// picking the final target.
    pub max_candidates: usize,
    /// Candidates whose total score is below this value are never selected.
    pub min_score: f64,
    /// Strategy used to pick one candidate out of the eligible set.
    pub strategy: LoadBalanceStrategy,
}
impl Default for RouterConfig {
fn default() -> Self {
Self {
max_candidates: 10,
min_score: 0.1,
strategy: LoadBalanceStrategy::LeastLatency,
}
}
}
/// Picks a deployment for an inference request by combining catalog data,
/// cached node health, circuit-breaker state and a scoring policy.
pub struct Router {
    /// Selection limits and load-balancing strategy.
    config: RouterConfig,
    /// Source of model entries and their deployments.
    catalog: Arc<ModelCatalog>,
    /// Provides cached per-node health snapshots.
    health: Arc<HealthChecker>,
    /// Nodes whose breaker is open are skipped when building candidates.
    circuit_breaker: Arc<CircuitBreaker>,
    /// Decides candidate eligibility and computes the candidate score.
    policy: CompositePolicy,
}
impl Router {
    /// Creates a router over the given catalog, health checker and circuit
    /// breaker.
    ///
    /// The policy is initialised with `CompositePolicy::enterprise_default()`;
    /// call [`Router::with_policy`] to replace it.
    pub fn new(
        config: RouterConfig,
        catalog: Arc<ModelCatalog>,
        health: Arc<HealthChecker>,
        circuit_breaker: Arc<CircuitBreaker>,
    ) -> Self {
        Self {
            config,
            catalog,
            health,
            circuit_breaker,
            policy: CompositePolicy::enterprise_default(),
        }
    }

    /// Replaces the routing policy and returns the updated router.
    #[must_use]
    pub fn with_policy(mut self, policy: CompositePolicy) -> Self {
        self.policy = policy;
        self
    }

    /// Collects one candidate per usable deployment of every catalog entry
    /// that advertises `capability`.
    ///
    /// A deployment is skipped when its node's circuit breaker is open or its
    /// cached health status is `Unhealthy`. Nodes without a cached health
    /// record are kept and given a placeholder `Unknown` record. Returned
    /// candidates are all marked eligible and carry a `total` score of 0.0;
    /// `rank_candidates` fills the final score in.
    fn build_candidates(&self, capability: &Capability) -> Vec<RouteCandidate> {
        let mut candidates = Vec::new();
        let entries = self.catalog.all_entries();
        for entry in entries {
            let has_capability = entry.metadata.capabilities.iter().any(|c| c == capability);
            if !has_capability {
                continue;
            }
            for deployment in &entry.deployments {
                if self.circuit_breaker.is_open(&deployment.node_id) {
                    continue;
                }
                // No cached health yet: assume an `Unknown` node with
                // placeholder latency figures (1s p50, 5s p99) and no
                // throughput.
                let health = self
                    .health
                    .get_cached_health(&deployment.node_id)
                    .unwrap_or_else(|| NodeHealth {
                        node_id: deployment.node_id.clone(),
                        status: HealthState::Unknown,
                        latency_p50: Duration::from_secs(1),
                        latency_p99: Duration::from_secs(5),
                        throughput: 0,
                        gpu_utilization: None,
                        queue_depth: 0,
                        last_check: std::time::Instant::now(),
                    });
                if health.status == HealthState::Unhealthy {
                    continue;
                }
                let target = RouteTarget {
                    node_id: deployment.node_id.clone(),
                    region_id: deployment.region_id.clone(),
                    endpoint: deployment.endpoint.clone(),
                    estimated_latency: health.latency_p50,
                    score: 0.0,
                };
                let health_score = match health.status {
                    HealthState::Healthy => 1.0,
                    HealthState::Degraded => 0.5,
                    HealthState::Unknown => 0.3,
                    HealthState::Unhealthy => 0.0,
                };
                let scores = RouteScores {
                    // Linear falloff: 0 ms -> 1.0, 5000 ms or more -> 0.0.
                    latency_score: 1.0 - (health.latency_p50.as_millis() as f64 / 5000.0).min(1.0),
                    // Saturates at 1.0 once throughput reaches 1000.
                    throughput_score: (health.throughput as f64 / 1000.0).min(1.0),
                    // Cost and locality are fixed at 0.5 here.
                    cost_score: 0.5,
                    locality_score: 0.5,
                    health_score,
                    total: 0.0,
                };
                candidates.push(RouteCandidate {
                    target,
                    scores,
                    eligible: true,
                    rejection_reason: None,
                });
            }
        }
        candidates
    }

    /// Applies the policy to every candidate and sorts the slice by total
    /// score, highest first.
    ///
    /// Candidates the policy rejects are flagged ineligible with the reason
    /// `"Policy rejected"` and keep their existing total. Accepted candidates
    /// get the policy score written to both `target.score` and
    /// `scores.total`.
    fn rank_candidates(&self, candidates: &mut [RouteCandidate], request: &InferenceRequest) {
        for candidate in candidates.iter_mut() {
            if !self.policy.is_eligible(candidate, request) {
                candidate.eligible = false;
                candidate.rejection_reason = Some("Policy rejected".to_string());
                continue;
            }
            let score = self.policy.score(candidate, request);
            candidate.target.score = score;
            candidate.scores.total = score;
        }
        // NaN is unordered under `partial_cmp`, and a comparator that is not
        // a total order makes `sort_by` produce an unspecified order (and may
        // panic). Rank NaN below every real score so the comparison is
        // always defined; ordering of non-NaN scores is unaffected.
        let sort_key = |c: &RouteCandidate| -> f64 {
            if c.scores.total.is_nan() {
                f64::NEG_INFINITY
            } else {
                c.scores.total
            }
        };
        candidates.sort_by(|a, b| {
            sort_key(b)
                .partial_cmp(&sort_key(a))
                .unwrap_or(std::cmp::Ordering::Equal)
        });
    }

    /// Picks one candidate from an already ranked slice.
    ///
    /// Only candidates that are eligible and whose total score is at least
    /// `config.min_score` are considered, capped at `config.max_candidates`.
    /// Returns `None` when no candidate qualifies.
    fn select_best(&self, candidates: &[RouteCandidate]) -> Option<RouteCandidate> {
        let eligible: Vec<&RouteCandidate> = candidates
            .iter()
            .filter(|c| c.eligible && c.scores.total >= self.config.min_score)
            .take(self.config.max_candidates)
            .collect();
        let best = *eligible.first()?;
        match self.config.strategy {
            // These strategies currently all take the top-ranked candidate.
            LoadBalanceStrategy::LeastLatency
            | LoadBalanceStrategy::LeastConnections
            | LoadBalanceStrategy::RoundRobin
            | LoadBalanceStrategy::ConsistentHash => Some(best.clone()),
            LoadBalanceStrategy::WeightedRandom => {
                use std::collections::hash_map::DefaultHasher;
                use std::hash::{Hash, Hasher};
                let total_weight: f64 = eligible.iter().map(|c| c.scores.total).sum();
                if total_weight <= 0.0 {
                    return Some(best.clone());
                }
                // Derive a value in [0, 1] by hashing the current wall-clock
                // time in nanoseconds.
                let mut hasher = DefaultHasher::new();
                std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_nanos()
                    .hash(&mut hasher);
                let random = (hasher.finish() as f64) / (u64::MAX as f64);
                let target = random * total_weight;
                // Walk the cumulative score distribution until it reaches
                // the drawn point.
                let mut cumulative = 0.0;
                for candidate in &eligible {
                    cumulative += candidate.scores.total;
                    if cumulative >= target {
                        return Some((*candidate).clone());
                    }
                }
                // Floating-point rounding can leave the sum just short of
                // `target`; fall back to the last eligible candidate.
                eligible.last().map(|c| (*c).clone())
            }
        }
    }
}
impl RouterTrait for Router {
    /// Resolves `request` to a single target.
    ///
    /// # Errors
    ///
    /// Returns `FederationError::NoCapacity` when no candidate could be
    /// built for the requested capability, and
    /// `FederationError::AllNodesUnhealthy` when candidates exist but none
    /// is selected after ranking.
    fn route(&self, request: &InferenceRequest) -> BoxFuture<'_, FederationResult<RouteTarget>> {
        let owned = request.clone();
        Box::pin(async move {
            let mut pool = self.build_candidates(&owned.capability);
            if pool.is_empty() {
                return Err(FederationError::NoCapacity(owned.capability.clone()));
            }
            self.rank_candidates(&mut pool, &owned);
            match self.select_best(&pool) {
                Some(winner) => Ok(winner.target),
                None => Err(FederationError::AllNodesUnhealthy(
                    owned.capability.clone(),
                )),
            }
        })
    }

    /// Returns every candidate for `request` after ranking, including those
    /// the policy marked ineligible. The list may be empty.
    fn get_candidates(
        &self,
        request: &InferenceRequest,
    ) -> BoxFuture<'_, FederationResult<Vec<RouteCandidate>>> {
        let owned = request.clone();
        Box::pin(async move {
            let mut ranked = self.build_candidates(&owned.capability);
            self.rank_candidates(&mut ranked, &owned);
            Ok(ranked)
        })
    }
}
/// Step-by-step constructor for [`Router`]; any component left unset is
/// filled with a default when `build` is called.
pub struct RouterBuilder {
    /// Router configuration; starts as `RouterConfig::default()`.
    config: RouterConfig,
    /// Model catalog, if one was supplied.
    catalog: Option<Arc<ModelCatalog>>,
    /// Health checker, if one was supplied.
    health: Option<Arc<HealthChecker>>,
    /// Circuit breaker, if one was supplied.
    circuit_breaker: Option<Arc<CircuitBreaker>>,
    /// Policy override, if one was supplied.
    policy: Option<CompositePolicy>,
}
impl RouterBuilder {
pub fn new() -> Self {
Self {
config: RouterConfig::default(),
catalog: None,
health: None,
circuit_breaker: None,
policy: None,
}
}
#[must_use]
pub fn config(mut self, config: RouterConfig) -> Self {
self.config = config;
self
}
#[must_use]
pub fn catalog(mut self, catalog: Arc<ModelCatalog>) -> Self {
self.catalog = Some(catalog);
self
}
#[must_use]
pub fn health(mut self, health: Arc<HealthChecker>) -> Self {
self.health = Some(health);
self
}
#[must_use]
pub fn circuit_breaker(mut self, cb: Arc<CircuitBreaker>) -> Self {
self.circuit_breaker = Some(cb);
self
}
#[must_use]
pub fn policy(mut self, policy: CompositePolicy) -> Self {
self.policy = Some(policy);
self
}
pub fn build(self) -> Router {
let catalog = self
.catalog
.unwrap_or_else(|| Arc::new(ModelCatalog::new()));
let health = self
.health
.unwrap_or_else(|| Arc::new(HealthChecker::default()));
let circuit_breaker = self
.circuit_breaker
.unwrap_or_else(|| Arc::new(CircuitBreaker::default()));
let router = Router::new(self.config, catalog, health, circuit_breaker);
if let Some(policy) = self.policy {
router.with_policy(policy)
} else {
router
}
}
}
impl Default for RouterBuilder {
fn default() -> Self {
Self::new()
}
}
// Test module for this file; compiled only for test builds and loaded from
// `routing_tests.rs` via the `#[path]` attribute.
#[cfg(test)]
#[path = "routing_tests.rs"]
mod tests;