use super::config::RouterConfig;
use super::deployment::{Deployment, DeploymentId};
use super::error::CooldownReason;
use super::execution::infer_cooldown_reason;
use super::fallback::{FallbackConfig, FallbackType};
use crate::core::providers::unified_provider::ProviderError;
use dashmap::DashMap;
use dashmap::mapref::one::Ref;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering::Relaxed};
use std::time::Duration;
/// Point-in-time copy of the router's routing counters.
///
/// Produced by `Router::routing_metrics`; each field holds the value that was
/// loaded from the matching atomic counter on the router.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RoutingMetrics {
    /// Value of the router's `provider_selected_count` counter.
    pub provider_selected: u64,
    /// Value of the router's `strategy_used_count` counter.
    pub strategy_used: u64,
    /// Value of the router's `fallback_triggered_count` counter.
    pub fallback_triggered: u64,
}
/// Registry of deployments plus the lookup tables and counters used for routing.
///
/// All collections are `DashMap`s and the counters are atomics, so the methods in
/// this file that mutate them only need `&self`.
#[derive(Debug)]
pub struct Router {
/// Every registered deployment, keyed by its id.
pub(crate) deployments: DashMap<DeploymentId, Deployment>,
/// Model name -> ids of the deployments registered under that model name.
pub(crate) model_index: DashMap<String, Vec<DeploymentId>>,
/// Alias -> model name, as registered through `add_model_alias`.
pub(crate) model_aliases: DashMap<String, String>,
/// Router settings; this file reads `success_threshold`, `allowed_fails`,
/// `min_requests` and `cooldown_time_secs` from it.
pub(crate) config: RouterConfig,
/// Fallback rules consulted by `get_fallbacks`.
pub(crate) fallback_config: FallbackConfig,
/// NOTE(review): only initialised in this file; presumably per-model round-robin
/// cursors used by the selection code elsewhere — confirm against that code.
pub(crate) round_robin_counters: DashMap<String, AtomicUsize>,
/// Counter surfaced as `RoutingMetrics::provider_selected` (not incremented in this file).
pub(crate) provider_selected_count: AtomicU64,
/// Counter surfaced as `RoutingMetrics::strategy_used` (not incremented in this file).
pub(crate) strategy_used_count: AtomicU64,
/// Counter surfaced as `RoutingMetrics::fallback_triggered` (not incremented in this file).
pub(crate) fallback_triggered_count: AtomicU64,
}
impl Router {
/// Creates an empty router that uses `config`.
///
/// The router starts with no deployments, no aliases, the default
/// `FallbackConfig`, and all metric counters at zero.
pub fn new(config: RouterConfig) -> Self {
    let zeroed = || AtomicU64::new(0);
    Self {
        config,
        fallback_config: FallbackConfig::default(),
        deployments: DashMap::new(),
        model_index: DashMap::new(),
        model_aliases: DashMap::new(),
        round_robin_counters: DashMap::new(),
        provider_selected_count: zeroed(),
        strategy_used_count: zeroed(),
        fallback_triggered_count: zeroed(),
    }
}
/// Builder-style setter: consumes the router, replaces its fallback
/// configuration with `config`, and returns the router.
pub fn with_fallback_config(self, config: FallbackConfig) -> Self {
    let mut router = self;
    router.fallback_config = config;
    router
}
/// Replaces the router's fallback configuration in place with `config`.
pub fn set_fallback_config(&mut self, config: FallbackConfig) {
    // The previous configuration is dropped right here.
    drop(std::mem::replace(&mut self.fallback_config, config));
}
/// Returns a shared reference to the router's configuration.
pub fn config(&self) -> &RouterConfig {
&self.config
}
/// Reads the three routing counters and returns them as a `RoutingMetrics` value.
///
/// Each counter is loaded independently with `Relaxed` ordering, so the three
/// values are not captured as one atomic snapshot.
pub fn routing_metrics(&self) -> RoutingMetrics {
    let provider_selected = self.provider_selected_count.load(Relaxed);
    let strategy_used = self.strategy_used_count.load(Relaxed);
    let fallback_triggered = self.fallback_triggered_count.load(Relaxed);
    RoutingMetrics {
        provider_selected,
        strategy_used,
        fallback_triggered,
    }
}
/// Registers `deployment`, replacing any existing deployment with the same id.
///
/// The deployment is stored in `deployments` and its id is listed under its
/// model name in `model_index`. Re-adding an id never produces a duplicate
/// index entry, and if the replaced deployment belonged to a different model
/// name its id is removed from that model's list.
pub fn add_deployment(&self, deployment: Deployment) {
    let model_name = deployment.model_name.clone();
    let deployment_id = deployment.id.clone();

    // `insert` hands back the deployment that was previously stored under this id.
    let replaced = self.deployments.insert(deployment_id.clone(), deployment);
    if let Some(previous) = replaced {
        if previous.model_name != model_name {
            // The id moved to another model: drop the stale entry under the old one.
            if let Some(mut stale_ids) = self.model_index.get_mut(&previous.model_name) {
                stale_ids.retain(|did| *did != deployment_id);
            }
        }
    }

    let mut ids = self.model_index.entry(model_name).or_default();
    // Guard against listing the same id twice for one model.
    if !ids.contains(&deployment_id) {
        ids.push(deployment_id);
    }
}
pub fn remove_deployment(&self, id: &str) -> Option<Deployment> {
let removed = self.deployments.remove(id).map(|(_, v)| v);
if let Some(ref deployment) = removed
&& let Some(mut entry) = self.model_index.get_mut(&deployment.model_name)
{
entry.retain(|did| did != id);
}
removed
}
/// Looks up a deployment by id in `deployments`.
///
/// Returns the map's `Ref` guard (not a copy of the deployment), or `None` when
/// the id is not registered.
/// NOTE(review): per the dashmap docs, holding a guard while mutating the same
/// map can deadlock — callers should presumably drop it promptly; confirm.
pub fn get_deployment(&self, id: &str) -> Option<Ref<'_, DeploymentId, Deployment>> {
self.deployments.get(id)
}
/// Replaces the registered deployments with `deployments`.
///
/// The new set and its model index are staged first. Then, for each live map,
/// entries whose key is absent from the staged set are removed and every staged
/// entry is inserted (overwriting entries with the same key). The update is done
/// map by map and entry by entry, not as a single atomic swap.
pub fn set_model_list(&self, deployments: Vec<Deployment>) {
    let staged_deployments: DashMap<DeploymentId, Deployment> = DashMap::new();
    let staged_index: DashMap<String, Vec<DeploymentId>> = DashMap::new();
    for deployment in deployments {
        let id = deployment.id.clone();
        staged_index
            .entry(deployment.model_name.clone())
            .or_default()
            .push(id.clone());
        staged_deployments.insert(id, deployment);
    }

    // Deployments: prune ids that are gone, then write the staged ones.
    self.deployments
        .retain(|id, _| staged_deployments.contains_key(id));
    staged_deployments.into_iter().for_each(|(id, deployment)| {
        self.deployments.insert(id, deployment);
    });

    // Model index: same prune-then-write sequence.
    self.model_index
        .retain(|model, _| staged_index.contains_key(model));
    staged_index.into_iter().for_each(|(model, ids)| {
        self.model_index.insert(model, ids);
    });
}
pub fn add_model_alias(
&self,
alias: &str,
model_name: &str,
) -> Result<(), super::error::RouterError> {
if alias == model_name {
return Err(super::error::RouterError::AliasCycle(format!(
"'{alias}' -> '{model_name}' would create a cycle"
)));
}
let mut current = model_name.to_string();
let mut visited = std::collections::HashSet::new();
visited.insert(alias.to_string());
while let Some(next) = self.model_aliases.get(¤t) {
let next_val = next.value().clone();
if !visited.insert(next_val.clone()) {
return Err(super::error::RouterError::AliasCycle(format!(
"'{alias}' -> '{model_name}' would create a cycle"
)));
}
current = next_val;
}
self.model_aliases
.insert(alias.to_string(), model_name.to_string());
Ok(())
}
/// Translates `name` through the alias table.
///
/// Returns the alias target when `name` is a registered alias, otherwise a copy
/// of `name`. Only a single lookup is performed: if the target is itself an
/// alias, it is returned as-is rather than being resolved further.
pub fn resolve_model_name(&self, name: &str) -> String {
    match self.model_aliases.get(name) {
        Some(target) => target.value().clone(),
        None => name.to_string(),
    }
}
/// Returns the ids of all deployments indexed under `model_name`.
///
/// `model_name` is first passed through `resolve_model_name`. An empty vector
/// is returned when the resolved name has no entry in `model_index`.
pub fn get_deployments_for_model(&self, model_name: &str) -> Vec<DeploymentId> {
    let resolved_name = self.resolve_model_name(model_name);
    match self.model_index.get(&resolved_name) {
        Some(ids) => ids.value().clone(),
        None => Vec::new(),
    }
}
/// Returns the ids of the deployments for `model_name` that can currently be used.
///
/// `model_name` is first passed through `resolve_model_name`. An indexed id is
/// kept only if it is still present in `deployments` and that deployment reports
/// `is_healthy()` and not `is_in_cooldown()`. The order of `model_index` is
/// preserved; an unknown model yields an empty vector.
pub fn get_healthy_deployments(&self, model_name: &str) -> Vec<DeploymentId> {
    let resolved_name = self.resolve_model_name(model_name);
    let mut healthy = Vec::new();
    if let Some(candidates) = self.model_index.get(&resolved_name) {
        for id in candidates.iter() {
            let usable = match self.deployments.get(id.as_str()) {
                Some(deployment) => deployment.is_healthy() && !deployment.is_in_cooldown(),
                // Indexed but no longer registered: treat as unusable.
                None => false,
            };
            if usable {
                healthy.push(id.clone());
            }
        }
    }
    healthy
}
/// Returns a copy of every model name that has an entry in `model_index`.
///
/// The order is whatever the map's iterator yields.
pub fn list_models(&self) -> Vec<String> {
    let mut names = Vec::new();
    for entry in self.model_index.iter() {
        names.push(entry.key().clone());
    }
    names
}
/// Returns a copy of the id of every registered deployment.
///
/// The order is whatever the map's iterator yields.
pub fn list_deployments(&self) -> Vec<DeploymentId> {
    let mut ids = Vec::new();
    for entry in self.deployments.iter() {
        ids.push(entry.key().clone());
    }
    ids
}
/// Records a successful request on the deployment identified by `deployment_id`.
///
/// `tokens` and `latency_us` are forwarded to the deployment's own
/// `record_success`. If the deployment's health is currently `Degraded` and its
/// consecutive-success count has reached `config.success_threshold`, the health
/// is set back to `Healthy`. Unknown ids are ignored.
pub fn record_success(&self, deployment_id: &str, tokens: u64, latency_us: u64) {
    let Some(deployment) = self.deployments.get(deployment_id) else {
        return;
    };
    deployment.record_success(tokens, latency_us);

    let state = &deployment.state;
    // Only a degraded deployment can be promoted here.
    if state.health.load(Relaxed) != super::deployment::HealthStatus::Degraded as u8 {
        return;
    }
    if state.consecutive_successes.load(Relaxed) >= self.config.success_threshold {
        state
            .health
            .store(super::deployment::HealthStatus::Healthy as u8, Relaxed);
    }
}
/// Records a failed request on the deployment identified by `deployment_id`.
///
/// After the deployment's own `record_failure`, the deployment is put into
/// cooldown for `config.cooldown_time_secs` when both of these hold:
/// `fails_this_minute >= config.allowed_fails`, and
/// `rpm_current + fails_this_minute >= config.min_requests`.
/// Unknown ids are ignored.
pub fn record_failure(&self, deployment_id: &str) {
    let Some(deployment) = self.deployments.get(deployment_id) else {
        return;
    };
    deployment.record_failure();

    let fails = deployment.state.fails_this_minute.load(Relaxed);
    let successes_this_minute = deployment.state.rpm_current.load(Relaxed);
    let total_this_minute = successes_this_minute + fails as u64;

    let under_fail_limit = fails < self.config.allowed_fails;
    let too_little_traffic = total_this_minute < (self.config.min_requests as u64);
    if under_fail_limit || too_little_traffic {
        return;
    }

    tracing::info!(
        deployment_id = %deployment_id,
        model = %deployment.model_name,
        reason = "consecutive_failures",
        cooldown_secs = self.config.cooldown_time_secs,
        fails_this_minute = fails,
        "deployment entering cooldown"
    );
    deployment.enter_cooldown(self.config.cooldown_time_secs);
}
/// Records a failed request on `deployment_id` and decides, based on `reason`,
/// whether the deployment must enter cooldown.
///
/// * `RateLimit`, `AuthError`, `NotFound`, `Timeout`, `Manual`: always cool down.
/// * `ConsecutiveFailures`: cool down when `fails_this_minute >= config.allowed_fails`
///   and `rpm_current + fails_this_minute >= config.min_requests`.
/// * `HighFailureRate`: cool down when `total_requests >= config.min_requests` and
///   the integer percentage `fail_requests * 100 / total_requests` exceeds 50.
///
/// Cooldown lasts `config.cooldown_time_secs`. Unknown ids are ignored.
pub fn record_failure_with_reason(&self, deployment_id: &str, reason: CooldownReason) {
    let Some(d) = self.deployments.get(deployment_id) else {
        return;
    };
    d.record_failure();

    let should_cooldown = match reason {
        CooldownReason::RateLimit
        | CooldownReason::AuthError
        | CooldownReason::NotFound
        | CooldownReason::Timeout
        | CooldownReason::Manual => true,
        CooldownReason::ConsecutiveFailures => {
            let fails = d.state.fails_this_minute.load(Relaxed);
            let successes_this_minute = d.state.rpm_current.load(Relaxed);
            let total_this_minute = successes_this_minute + fails as u64;
            fails >= self.config.allowed_fails
                && total_this_minute >= self.config.min_requests as u64
        }
        CooldownReason::HighFailureRate => {
            let total = d.state.total_requests.load(Relaxed);
            let fails = d.state.fail_requests.load(Relaxed);
            // `total > 0` keeps the division from panicking when `min_requests`
            // is 0 and nothing has been counted yet.
            total > 0
                && total >= self.config.min_requests as u64
                && (fails * 100 / total) > 50
        }
    };

    if should_cooldown {
        tracing::info!(
            deployment_id = %deployment_id,
            model = %d.model_name,
            reason = ?reason,
            cooldown_secs = self.config.cooldown_time_secs,
            "deployment entering cooldown"
        );
        d.enter_cooldown(self.config.cooldown_time_secs);
    }
}
/// Maps a provider error to the `FallbackType` to use for it.
///
/// Associated-function wrapper that forwards unchanged to
/// `super::execution::infer_fallback_type`.
pub fn infer_fallback_type(error: &ProviderError) -> FallbackType {
super::execution::infer_fallback_type(error)
}
/// Returns the fallback model names configured for `model_name`.
///
/// `model_name` is first passed through `resolve_model_name`. The list for
/// `fallback_type` is used when it is non-empty; otherwise, for any type other
/// than `FallbackType::General`, the `General` list is returned instead.
pub fn get_fallbacks(&self, model_name: &str, fallback_type: FallbackType) -> Vec<String> {
    let resolved_name = self.resolve_model_name(model_name);
    let specific = self
        .fallback_config
        .get_fallbacks_for_type(&resolved_name, fallback_type);
    if !specific.is_empty() || fallback_type == FallbackType::General {
        return specific;
    }
    // Nothing configured for the specific type: fall back to the general list.
    self.fallback_config
        .get_fallbacks_for_type(&resolved_name, FallbackType::General)
}
/// Returns the resolved name of `model_name` followed by its fallbacks.
///
/// The first element is always `resolve_model_name(model_name)`; the rest is
/// the result of `get_fallbacks(model_name, fallback_type)` in the same order.
/// No de-duplication is performed.
pub fn get_models_with_fallbacks(
    &self,
    model_name: &str,
    fallback_type: FallbackType,
) -> Vec<String> {
    let primary = self.resolve_model_name(model_name);
    let fallbacks = self.get_fallbacks(model_name, fallback_type);
    let mut chain = Vec::with_capacity(fallbacks.len() + 1);
    chain.push(primary);
    chain.extend(fallbacks);
    chain
}
/// Maps a provider error to the `CooldownReason` to record for it.
///
/// Associated-function wrapper that forwards unchanged to the free function
/// `infer_cooldown_reason` imported from `super::execution`.
pub fn infer_cooldown_reason(error: &ProviderError) -> CooldownReason {
infer_cooldown_reason(error)
}
/// Calls `state.reset_minute()` on every registered deployment.
pub fn reset_minute_counters(&self) {
    self.deployments
        .iter()
        .for_each(|entry| entry.value().state.reset_minute());
}
/// Spawns a Tokio task that calls `reset_minute_counters` on a 60-second interval.
///
/// The task loops forever and owns the `Arc<Self>` it was given, so the router
/// stays alive for as long as the task runs. The returned `JoinHandle` is the
/// only way this function offers to stop it.
/// NOTE(review): per the tokio docs the first `tick()` of an `interval`
/// completes immediately, so one reset presumably happens right after spawning.
pub fn start_minute_reset_task(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
    let router = self;
    let worker = async move {
        let mut ticker = tokio::time::interval(Duration::from_secs(60));
        loop {
            ticker.tick().await;
            router.reset_minute_counters();
        }
    };
    tokio::spawn(worker)
}
}
impl Default for Router {
fn default() -> Self {
Self::new(RouterConfig::default())
}
}