use super::config::RoutingStrategy;
use super::deployment::{Deployment, DeploymentId};
use super::error::RouterError;
use super::strategy_impl;
use super::unified::Router;
use crate::core::types::model::ProviderCapability;
use std::sync::atomic::Ordering::Relaxed;
impl Router {
pub fn select_deployment(&self, model_name: &str) -> Result<DeploymentId, RouterError> {
self.select_deployment_matching(model_name, |_| true)
}
pub fn select_deployment_for_capability(
&self,
model_name: &str,
capability: &ProviderCapability,
) -> Result<DeploymentId, RouterError> {
self.select_deployment_matching(model_name, |deployment| {
deployment
.provider
.capabilities()
.iter()
.any(|cap| cap == capability)
})
}
fn select_deployment_matching<F>(
&self,
model_name: &str,
is_candidate: F,
) -> Result<DeploymentId, RouterError>
where
F: Fn(&Deployment) -> bool,
{
let alias_guard = self.maybe_model_alias(model_name);
let resolved_name = alias_guard
.as_ref()
.map(|alias| alias.value().as_str())
.unwrap_or(model_name);
let deployment_ids_ref = self
.model_index
.get(resolved_name)
.ok_or_else(|| RouterError::ModelNotFound(model_name.to_string()))?;
if deployment_ids_ref.is_empty() {
return Err(RouterError::ModelNotFound(model_name.to_string()));
}
let total_deployments = deployment_ids_ref.len();
let mut routing_contexts = Vec::with_capacity(total_deployments);
for id in deployment_ids_ref.iter() {
let Some(deployment) = self.deployments.get(id.as_str()) else {
continue;
};
if !is_candidate(&deployment) {
tracing::trace!(
deployment_id = id.as_str(),
model = %resolved_name,
reason = "capability_mismatch",
"deployment excluded from routing candidates"
);
continue;
}
if deployment.is_in_cooldown() {
tracing::trace!(
deployment_id = id.as_str(),
model = %resolved_name,
reason = "in_cooldown",
"deployment excluded from routing candidates"
);
continue;
}
if !deployment.is_healthy() {
tracing::trace!(
deployment_id = id.as_str(),
model = %resolved_name,
reason = "unhealthy",
"deployment excluded from routing candidates"
);
continue;
}
let active_requests = deployment.state.active_requests.load(Relaxed);
if let Some(limit) = deployment.config.max_parallel_requests
&& active_requests >= limit
{
tracing::trace!(
deployment_id = id.as_str(),
model = %resolved_name,
reason = "parallel_limit_reached",
"deployment excluded from routing candidates"
);
continue;
}
let rpm_current = deployment.state.rpm_current.load(Relaxed);
if let Some(limit) = deployment.config.rpm_limit
&& rpm_current >= limit
{
tracing::trace!(
deployment_id = id.as_str(),
model = %resolved_name,
reason = "rate_limited",
"deployment excluded from routing candidates"
);
continue;
}
let tpm_current = deployment.state.tpm_current.load(Relaxed);
if let Some(limit) = deployment.config.tpm_limit
&& tpm_current >= limit
{
tracing::trace!(
deployment_id = id.as_str(),
model = %resolved_name,
reason = "rate_limited",
"deployment excluded from routing candidates"
);
continue;
}
routing_contexts.push(strategy_impl::RoutingContext {
deployment_id: id,
weight: deployment.config.weight,
priority: deployment.config.priority,
active_requests,
tpm_current,
tpm_limit: deployment.config.tpm_limit,
rpm_current,
rpm_limit: deployment.config.rpm_limit,
avg_latency_us: deployment.state.avg_latency_us.load(Relaxed),
});
}
if routing_contexts.is_empty() {
tracing::warn!(
model = %model_name,
total_deployments = total_deployments,
"no available deployments after filtering"
);
return Err(RouterError::NoAvailableDeployment(model_name.to_string()));
}
let selected_id = match self.config.routing_strategy {
RoutingStrategy::SimpleShuffle => {
strategy_impl::weighted_random_from_context(&routing_contexts)
}
RoutingStrategy::LeastBusy => strategy_impl::least_busy_from_context(&routing_contexts),
RoutingStrategy::UsageBased => {
strategy_impl::lowest_usage_from_context(&routing_contexts)
}
RoutingStrategy::LatencyBased => {
strategy_impl::lowest_latency_from_context(&routing_contexts)
}
RoutingStrategy::PriorityBased => {
strategy_impl::lowest_priority_from_context(&routing_contexts)
}
RoutingStrategy::RateLimitAware => {
strategy_impl::rate_limit_aware_from_context(&routing_contexts)
}
RoutingStrategy::RoundRobin => strategy_impl::round_robin_from_context(
resolved_name,
&routing_contexts,
&self.round_robin_counters,
),
}
.ok_or_else(|| RouterError::NoAvailableDeployment(model_name.to_string()))?
.clone();
if let Some(deployment) = self.deployments.get(&selected_id) {
deployment.state.active_requests.fetch_add(1, Relaxed);
}
self.provider_selected_count.fetch_add(1, Relaxed);
self.strategy_used_count.fetch_add(1, Relaxed);
tracing::debug!(
model = %model_name,
strategy = ?self.config.routing_strategy,
candidate_count = routing_contexts.len(),
selected_id = %selected_id,
"deployment selected for routing"
);
Ok(selected_id)
}
pub fn release_deployment(&self, deployment_id: &str) {
if let Some(deployment) = self.deployments.get(deployment_id) {
let _ = deployment
.state
.active_requests
.fetch_update(Relaxed, Relaxed, |v| Some(v.saturating_sub(1)));
}
}
}