pub mod health;
pub mod load_balancer;
pub mod metrics;
pub mod strategy;
use crate::config::ProviderConfig;
use crate::core::models::{RequestContext, openai::*};
use crate::core::providers::Provider;
use crate::core::router::health::RouterHealthStatus;
use crate::core::router::metrics::RouterMetricsSnapshot;
use crate::storage::StorageLayer;
use crate::utils::async_utils::{ConcurrentRunner, RetryPolicy, default_retry_policy};
use crate::utils::error::Result;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tracing::{info, warn};
pub use health::HealthChecker;
pub use load_balancer::LoadBalancer;
pub use metrics::RouterMetrics;
pub use strategy::RoutingStrategy;
/// Request router: asks the load balancer for a provider and forwards the
/// request to it, recording the outcome in the router metrics.
///
/// `Clone` is derived; the `Arc`-wrapped fields are shared between clones,
/// the remaining fields are cloned by value.
#[derive(Clone)]
pub struct Router {
/// Registered providers keyed by provider name. A clone of this same
/// `Arc<RwLock<..>>` is handed to the `HealthChecker` in `new`.
providers: Arc<RwLock<HashMap<String, Arc<dyn Provider>>>>,
/// Provider configurations passed to `new`. Stored as-is; not read by the
/// methods visible in this file.
configs: Arc<Vec<ProviderConfig>>,
/// Storage layer handle passed to `new`; not read by the methods visible in
/// this file.
storage: Arc<StorageLayer>,
/// Routing strategy; a clone of it is given to the `LoadBalancer` in `new`.
strategy: RoutingStrategy,
/// Source of `health_status`; notified when providers are added or removed.
health_checker: Arc<HealthChecker>,
/// Chooses the provider for every routed request.
load_balancer: Arc<LoadBalancer>,
/// Per-request metrics sink and source of `get_metrics`.
metrics: Arc<RouterMetrics>,
/// Built in `new` as `ConcurrentRunner::new(10)` with a 30 second timeout.
/// NOTE(review): `10` is presumably a concurrency limit — confirm against
/// `ConcurrentRunner`. Not used by the methods visible in this file.
concurrent_runner: ConcurrentRunner,
/// Retry policy wrapped around chat-completion attempts.
retry_policy: RetryPolicy,
}
impl Router {
/// Creates a router with an empty provider map.
///
/// `configs` is only counted for the log line and stored; no provider is
/// instantiated from it here. Providers are registered through
/// `add_provider`.
///
/// # Errors
///
/// Propagates any error from constructing the `HealthChecker`, the
/// `LoadBalancer` or the `RouterMetrics`, in that order.
pub async fn new(
    configs: Vec<ProviderConfig>,
    storage: Arc<StorageLayer>,
    strategy: RoutingStrategy,
) -> Result<Self> {
    let provider_count = configs.len();
    info!("Initializing router with {} providers", provider_count);

    // The health checker shares this exact map with the router.
    let providers = Arc::new(RwLock::new(HashMap::new()));

    // Field initialisers run top to bottom, so the components are still
    // built in the order: health checker, load balancer, metrics, runner,
    // retry policy.
    Ok(Self {
        health_checker: Arc::new(HealthChecker::new(providers.clone()).await?),
        load_balancer: Arc::new(LoadBalancer::new(strategy.clone()).await?),
        metrics: Arc::new(RouterMetrics::new().await?),
        concurrent_runner: ConcurrentRunner::new(10)
            .with_timeout(std::time::Duration::from_secs(30)),
        retry_policy: default_retry_policy(),
        providers,
        configs: Arc::new(configs),
        storage,
        strategy,
    })
}
pub async fn route_chat_completion(
&self,
request: ChatCompletionRequest,
context: RequestContext,
) -> Result<ChatCompletionResponse> {
let start_time = Instant::now();
let result = self
.retry_policy
.execute(|| {
let request = request.clone();
let context = context.clone();
let load_balancer = self.load_balancer.clone();
async move {
let provider = load_balancer
.select_provider(&request.model, &context)
.await?;
provider.chat_completion(request, context).await
}
})
.await;
let duration = start_time.elapsed();
let provider_name = if let Ok(provider) = self
.load_balancer
.select_provider(&request.model, &context)
.await
{
provider.name().to_string()
} else {
"unknown".to_string()
};
let metrics = self.metrics.clone();
let model = request.model.clone();
let success = result.is_ok();
tokio::spawn(async move {
let _ = metrics
.record_request(&provider_name, &model, duration, success)
.await;
});
result
}
/// Routes a text-completion request to the provider chosen by the load
/// balancer and records the outcome in the metrics before returning.
///
/// There is no retry on this path: one provider is selected and called once.
///
/// # Errors
///
/// Returns the provider-selection error (in which case nothing is recorded)
/// or the error returned by the provider's `completion` call.
pub async fn route_completion(
    &self,
    request: CompletionRequest,
    context: RequestContext,
) -> Result<CompletionResponse> {
    let start_time = Instant::now();
    let provider = self
        .load_balancer
        .select_provider(&request.model, &context)
        .await?;

    // Only the model name is needed after the call, so keep a copy of it and
    // move `request` and `context` into the provider instead of cloning both.
    let model = request.model.clone();
    let result = provider.completion(request, context).await;
    let duration = start_time.elapsed();

    // Metrics are best-effort here, as in `route_chat_completion`: whatever
    // the recorder reports must not replace the provider's result.
    let _ = self
        .metrics
        .record_request(&provider.name(), &model, duration, result.is_ok())
        .await;
    result
}
/// Routes an embedding request to the provider chosen by the load balancer
/// and records the outcome in the metrics before returning.
///
/// There is no retry on this path: one provider is selected and called once.
///
/// # Errors
///
/// Returns the provider-selection error (in which case nothing is recorded)
/// or the error returned by the provider's `embedding` call.
pub async fn route_embedding(
    &self,
    request: EmbeddingRequest,
    context: RequestContext,
) -> Result<EmbeddingResponse> {
    let start_time = Instant::now();
    let provider = self
        .load_balancer
        .select_provider(&request.model, &context)
        .await?;

    // Only the model name is needed after the call, so keep a copy of it and
    // move `request` and `context` into the provider instead of cloning both.
    let model = request.model.clone();
    let result = provider.embedding(request, context).await;
    let duration = start_time.elapsed();

    // Metrics are best-effort here, as in `route_chat_completion`: whatever
    // the recorder reports must not replace the provider's result.
    let _ = self
        .metrics
        .record_request(&provider.name(), &model, duration, result.is_ok())
        .await;
    result
}
/// Returns the router health status as reported by the health checker.
///
/// Pure delegation to `HealthChecker::get_status`; its result, including
/// any error, is returned unchanged.
pub async fn health_status(&self) -> Result<RouterHealthStatus> {
self.health_checker.get_status().await
}
/// Returns a snapshot of the router metrics.
///
/// Pure delegation to `RouterMetrics::get_snapshot`; its result, including
/// any error, is returned unchanged.
pub async fn get_metrics(&self) -> Result<RouterMetricsSnapshot> {
self.metrics.get_snapshot().await
}
/// Selects a provider for `request.model` and returns that provider's
/// streaming chat-completion output.
///
/// This path performs a single attempt; it does not go through the retry
/// policy and records no metrics.
///
/// # Errors
///
/// Returns the provider-selection error or the error from the provider's
/// `chat_completion_stream` call.
pub async fn route_chat_completion_stream(
    &self,
    request: crate::core::models::openai::ChatCompletionRequest,
    context: crate::core::models::RequestContext,
) -> Result<impl futures::Stream<Item = Result<String>>> {
    let provider = self
        .load_balancer
        .select_provider(&request.model, &context)
        .await?;
    provider.chat_completion_stream(request, context).await
}
/// Routes an image-generation request to the provider selected for the
/// requested model, using `"dall-e-3"` when the request names no model.
///
/// The fallback only affects provider selection; `request` itself is
/// forwarded to the provider unchanged.
///
/// # Errors
///
/// Returns the provider-selection error or the error from the provider's
/// `image_generation` call.
pub async fn route_image_generation(
    &self,
    request: crate::core::models::openai::ImageGenerationRequest,
    context: crate::core::models::RequestContext,
) -> Result<crate::core::models::openai::ImageGenerationResponse> {
    let fallback_model = String::from("dall-e-3");
    let model_name = match request.model.as_ref() {
        Some(requested) => requested,
        None => &fallback_model,
    };
    let provider = self
        .load_balancer
        .select_provider(model_name, &context)
        .await?;
    provider.image_generation(request, context).await
}
/// Collects the models reported by every registered provider.
///
/// Providers are queried one after another. A provider whose `list_models`
/// call fails is logged at warn level and skipped, so this method itself
/// always returns `Ok`.
pub async fn list_models(&self) -> Result<Vec<crate::core::models::openai::Model>> {
    // Copy the provider handles and drop the read guard straight away, so the
    // provider calls below do not keep the map locked against
    // `add_provider` / `remove_provider`.
    let snapshot: Vec<_> = self.providers.read().await.values().cloned().collect();

    let mut all_models = Vec::new();
    for provider in &snapshot {
        match provider.list_models().await {
            Ok(models) => all_models.extend(models),
            Err(e) => warn!(
                "Failed to get models from provider {}: {}",
                provider.name(),
                e
            ),
        }
    }
    Ok(all_models)
}
/// Looks up `model_id` across the registered providers and returns the
/// first match.
///
/// Providers are asked one after another in map iteration order. A provider
/// that returns an error is logged at warn level and skipped.
///
/// Returns `Ok(None)` when no provider knows the model; this method never
/// returns `Err` itself.
pub async fn get_model(
    &self,
    model_id: &str,
) -> Result<Option<crate::core::models::openai::Model>> {
    // Copy the provider handles and drop the read guard straight away, so the
    // provider calls below do not keep the map locked against
    // `add_provider` / `remove_provider`.
    let snapshot: Vec<_> = self.providers.read().await.values().cloned().collect();

    for provider in &snapshot {
        match provider.get_model(model_id).await {
            Ok(Some(model)) => return Ok(Some(model)),
            Ok(None) => continue,
            Err(e) => warn!(
                "Failed to get model {} from provider {}: {}",
                model_id,
                provider.name(),
                e
            ),
        }
    }
    Ok(None)
}
/// Creates a provider from `config` and registers it with the provider map,
/// the health checker and the load balancer, in that order.
///
/// An existing map entry with the same name is replaced.
///
/// # Errors
///
/// Propagates errors from `create_provider`, the health checker or the load
/// balancer. There is no rollback: if a later step fails, the earlier
/// registrations (including the map entry) stay in place.
pub async fn add_provider(&self, config: ProviderConfig) -> Result<()> {
    info!("Adding provider: {}", config.name);

    // Only the name is needed once the provider exists, so copy it and hand
    // the configuration over by value instead of cloning all of it.
    let name = config.name.clone();
    let provider = crate::core::providers::create_provider(config).await?;

    // The write guard is intentionally kept until the function returns, so the
    // three registrations cannot interleave with a concurrent add or remove.
    // NOTE(review): the health checker was built from a clone of this same
    // lock (see `new`); if its `add_provider` ever locks the map, the call
    // below would deadlock — confirm.
    let mut providers = self.providers.write().await;
    providers.insert(name.clone(), provider.clone());
    self.health_checker.add_provider(&name).await?;
    self.load_balancer.add_provider(&name, provider).await?;

    info!("Provider {} added successfully", name);
    Ok(())
}
/// Unregisters the provider called `name` from the provider map, the health
/// checker and the load balancer, in that order.
///
/// # Errors
///
/// Propagates errors from the health checker or the load balancer. By then
/// the map entry has already been removed; nothing is rolled back.
pub async fn remove_provider(&self, name: &str) -> Result<()> {
info!("Removing provider: {}", name);
// The write guard lives until the function returns, so the health checker
// and load balancer updates below run while the map is write-locked.
let mut providers = self.providers.write().await;
// The removed value is discarded: an unknown name is not treated as an
// error at this step.
providers.remove(name);
self.health_checker.remove_provider(name).await?;
self.load_balancer.remove_provider(name).await?;
info!("Provider {} removed successfully", name);
Ok(())
}
/// Returns the names of all providers currently in the provider map.
///
/// The order follows `HashMap` iteration and is therefore unspecified. This
/// method always returns `Ok`.
pub async fn list_providers(&self) -> Result<Vec<String>> {
    let names: Vec<String> = self.providers.read().await.keys().cloned().collect();
    Ok(names)
}
}