use crate::core::models::RequestContext;
use crate::core::providers::Provider;
use crate::core::router::health::HealthChecker;
use crate::core::router::strategy::{RoutingStrategy, StrategyExecutor};
use crate::utils::error::{GatewayError, Result};
use dashmap::DashMap;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, info};
/// Routes requests for a model to one of a set of registered providers.
///
/// Choosing among eligible providers is delegated to a [`StrategyExecutor`].
/// An optional [`HealthChecker`] narrows the eligible set. The names of
/// providers supporting each model are memoised per model name.
pub struct LoadBalancer {
    /// Executor for the configured routing strategy.
    strategy: Arc<StrategyExecutor>,
    /// Registered providers, keyed by provider name.
    providers: Arc<DashMap<String, Arc<dyn Provider>>>,
    /// Model name -> names of the registered providers that reported supporting it.
    model_support_cache: Arc<DashMap<String, Arc<Vec<String>>>>,
    /// Optional health checker; `None` means no health-based filtering is applied.
    health_checker: Option<Arc<HealthChecker>>,
}
impl LoadBalancer {
    /// Creates a load balancer that routes using `strategy`.
    ///
    /// The balancer starts with no providers, no health checker, and an empty
    /// model-support cache.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `StrategyExecutor::new`.
    pub async fn new(strategy: RoutingStrategy) -> Result<Self> {
        info!("Creating load balancer with strategy: {:?}", strategy);
        let strategy_executor = Arc::new(StrategyExecutor::new(strategy).await?);
        Ok(Self {
            providers: Arc::new(DashMap::new()),
            strategy: strategy_executor,
            health_checker: None,
            model_support_cache: Arc::new(DashMap::new()),
        })
    }

    /// Attaches a health checker, replacing any previously attached one.
    ///
    /// Once set, `select_provider` only considers providers that appear in the
    /// checker's healthy list.
    pub async fn set_health_checker(&mut self, health_checker: Arc<HealthChecker>) {
        self.health_checker = Some(health_checker);
        info!("Health checker attached to load balancer");
    }

    /// Picks a provider for `model` using the configured routing strategy.
    ///
    /// Candidates are the registered providers that report support for `model`.
    /// When a health checker is attached, candidates missing from its healthy
    /// list are dropped before the strategy runs.
    ///
    /// # Errors
    ///
    /// - `GatewayError::NoProvidersForModel` if no registered provider supports `model`.
    /// - `GatewayError::NoHealthyProviders` if no supporting provider is healthy.
    /// - `GatewayError::ProviderNotFound` if the strategy returns a name that is
    ///   not (or no longer) registered.
    /// - Any error from the health checker or the strategy executor.
    pub async fn select_provider(
        &self,
        model: &str,
        context: &RequestContext,
    ) -> Result<Arc<dyn Provider>> {
        let supporting_providers = self.get_supporting_providers(model).await?;
        if supporting_providers.is_empty() {
            return Err(GatewayError::NoProvidersForModel(model.to_string()));
        }

        let healthy_providers = if let Some(health_checker) = &self.health_checker {
            let healthy_list = health_checker.get_healthy_providers().await?;
            supporting_providers
                .into_iter()
                .filter(|p| healthy_list.contains(p))
                .collect()
        } else {
            supporting_providers
        };
        if healthy_providers.is_empty() {
            return Err(GatewayError::NoHealthyProviders(
                "No healthy providers available".to_string(),
            ));
        }

        let selected_name = self
            .strategy
            .select_provider(&healthy_providers, model, context)
            .await?;

        // The map guard returned by `get` is released at the end of this
        // statement; only the cloned `Arc` escapes.
        let provider = self
            .providers
            .get(&selected_name)
            .ok_or_else(|| GatewayError::ProviderNotFound(selected_name.clone()))?
            .clone();
        debug!("Selected provider {} for model {}", selected_name, model);
        Ok(provider)
    }

    /// Returns the names of registered providers that support `model`.
    ///
    /// A cached result is returned when present. Otherwise every registered
    /// provider is asked via `supports_model`, and the result is cached under
    /// `model`.
    ///
    /// NOTE(review): a lookup that is in flight while `add_provider` clears the
    /// cache may re-insert a result computed before the new provider was added.
    /// It stays until the next cache clear/eviction.
    async fn get_supporting_providers(&self, model: &str) -> Result<Vec<String>> {
        if let Some(cached_providers) = self.model_support_cache.get(model) {
            debug!(
                "Found cached providers for model {}: {:?}",
                model, cached_providers
            );
            return Ok((**cached_providers).clone());
        }

        // Snapshot the registry before awaiting. A DashMap iterator holds shard
        // read locks, and keeping them alive across `.await` can block (or
        // deadlock) concurrent `add_provider` / `remove_provider` writers.
        let candidates: Vec<(String, Arc<dyn Provider>)> = self
            .providers
            .iter()
            .map(|entry| (entry.key().clone(), Arc::clone(entry.value())))
            .collect();

        let mut supporting_providers = Vec::with_capacity(candidates.len());
        for (name, provider) in candidates {
            if provider.supports_model(model).await {
                supporting_providers.push(name);
            }
        }

        self.model_support_cache
            .insert(model.to_string(), Arc::new(supporting_providers.clone()));
        debug!(
            "Providers supporting model {}: {:?}",
            model, supporting_providers
        );
        Ok(supporting_providers)
    }

    /// Registers `provider` under `name`, replacing any provider already
    /// registered under that name.
    ///
    /// The whole model-support cache is cleared, because the new provider may
    /// support models whose cached lists omit it.
    pub async fn add_provider(&self, name: &str, provider: Arc<dyn Provider>) -> Result<()> {
        self.providers.insert(name.to_string(), provider);
        self.model_support_cache.clear();
        info!("Added provider {} to load balancer", name);
        Ok(())
    }

    /// Unregisters the provider called `name` (a no-op if it is not registered).
    ///
    /// Only cache entries that list the removed provider are evicted; entries
    /// that never included it remain valid.
    pub async fn remove_provider(&self, name: &str) -> Result<()> {
        self.providers.remove(name);
        self.model_support_cache
            .retain(|_, providers| !providers.iter().any(|p| p == name));
        info!("Removed provider {} from load balancer", name);
        Ok(())
    }

    /// Forwards a weight update for `provider` to the strategy executor.
    pub async fn update_provider_weight(&self, provider: &str, weight: f64) -> Result<()> {
        self.strategy.update_weight(provider, weight).await
    }

    /// Forwards a latency update for `provider` to the strategy executor.
    pub async fn update_provider_latency(&self, provider: &str, latency: f64) -> Result<()> {
        self.strategy.update_latency(provider, latency).await
    }

    /// Forwards a per-model cost update for `provider` to the strategy executor.
    pub async fn update_provider_cost(&self, provider: &str, model: &str, cost: f64) -> Result<()> {
        self.strategy.update_cost(provider, model, cost).await
    }

    /// Forwards a priority update for `provider` to the strategy executor.
    pub async fn update_provider_priority(&self, provider: &str, priority: u32) -> Result<()> {
        self.strategy.update_priority(provider, priority).await
    }

    /// Returns a snapshot of provider and cache counts.
    ///
    /// `healthy_providers` counts only providers that are both registered here
    /// and reported healthy by the health checker. This keeps it consistent with
    /// `select_provider` and never greater than `total_providers`. Without a
    /// health checker, every registered provider counts as healthy.
    ///
    /// # Errors
    ///
    /// Propagates any error from the health checker.
    pub async fn get_stats(&self) -> Result<LoadBalancerStats> {
        let provider_count = self.providers.len();
        let healthy_count = match &self.health_checker {
            Some(health_checker) => health_checker
                .get_healthy_providers()
                .await?
                .iter()
                .filter(|name| self.providers.contains_key(*name))
                .count(),
            None => provider_count,
        };
        let cached_models = self.model_support_cache.len();
        Ok(LoadBalancerStats {
            total_providers: provider_count,
            healthy_providers: healthy_count,
            cached_models,
        })
    }

    /// Drops every cached model -> providers entry.
    pub async fn clear_cache(&self) -> Result<()> {
        self.model_support_cache.clear();
        info!("Cleared model support cache");
        Ok(())
    }

    /// Returns an owned copy of the model-support cache.
    pub async fn get_model_cache(&self) -> Result<HashMap<String, Vec<String>>> {
        Ok(self
            .model_support_cache
            .iter()
            .map(|entry| (entry.key().clone(), (**entry.value()).clone()))
            .collect())
    }

    /// Populates the model-support cache for each name in `models`.
    ///
    /// # Errors
    ///
    /// Stops at and propagates the first error from the per-model lookup.
    pub async fn preload_cache(&self, models: &[String]) -> Result<()> {
        info!("Preloading model support cache for {} models", models.len());
        for model in models {
            self.get_supporting_providers(model).await?;
        }
        info!("Model support cache preloaded successfully");
        Ok(())
    }
}
/// Point-in-time counters describing a load balancer.
///
/// Plain data: derives `PartialEq`/`Eq` so snapshots can be compared, and
/// `Default` for an all-zero value.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct LoadBalancerStats {
    /// Number of registered providers.
    pub total_providers: usize,
    /// Number of providers considered healthy.
    pub healthy_providers: usize,
    /// Number of models with a cached list of supporting providers.
    pub cached_models: usize,
}