//! litellm-rs 0.1.1
//!
//! A high-performance AI Gateway written in Rust, providing OpenAI-compatible
//! APIs with intelligent routing, load balancing, and enterprise features.
//! Core router for AI provider management and request routing
//!
//! This module provides intelligent routing, load balancing, and failover
//! across multiple AI providers.

pub mod health;
pub mod load_balancer;
pub mod metrics;
pub mod strategy;

use crate::config::ProviderConfig;
use crate::core::models::{RequestContext, openai::*};
use crate::core::providers::Provider;
use crate::core::router::health::RouterHealthStatus;
use crate::core::router::metrics::RouterMetricsSnapshot;
use crate::storage::StorageLayer;
use crate::utils::async_utils::{ConcurrentRunner, RetryPolicy, default_retry_policy};
use crate::utils::error::Result;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tracing::{info, warn};

pub use health::HealthChecker;
pub use load_balancer::LoadBalancer;
pub use metrics::RouterMetrics;
pub use strategy::RoutingStrategy;

/// Core router for managing AI providers and routing requests
///
/// Cloning a `Router` shares the underlying provider registry, health
/// checker, load balancer, and metrics (all held behind `Arc`s).
#[derive(Clone)]
pub struct Router {
    /// Registry of live provider instances, keyed by provider name
    providers: Arc<RwLock<HashMap<String, Arc<dyn Provider>>>>,
    /// Provider configurations supplied at construction time
    configs: Arc<Vec<ProviderConfig>>,
    /// Storage layer for metrics and caching
    storage: Arc<StorageLayer>,
    /// Routing strategy used to select a provider per request
    strategy: RoutingStrategy,
    /// Background health checker for registered providers
    health_checker: Arc<HealthChecker>,
    /// Load balancer that picks a provider for each request
    load_balancer: Arc<LoadBalancer>,
    /// Aggregated per-provider / per-model request metrics
    metrics: Arc<RouterMetrics>,
    /// Concurrent runner for parallel operations (bounded concurrency + timeout)
    concurrent_runner: ConcurrentRunner,
    /// Retry policy applied to failed chat-completion requests
    retry_policy: RetryPolicy,
}

impl Router {
    /// Create a new router.
    ///
    /// Builds an empty provider registry plus the health checker, load
    /// balancer, and metrics collector, and wires in a bounded concurrent
    /// runner and the crate's default retry policy. Providers are added
    /// afterwards via [`Router::add_provider`].
    ///
    /// # Errors
    /// Returns an error if the health checker, load balancer, or metrics
    /// subsystem fails to initialize.
    pub async fn new(
        configs: Vec<ProviderConfig>,
        storage: Arc<StorageLayer>,
        strategy: RoutingStrategy,
    ) -> Result<Self> {
        info!("Initializing router with {} providers", configs.len());

        let providers = Arc::new(RwLock::new(HashMap::new()));
        let health_checker = Arc::new(HealthChecker::new(providers.clone()).await?);
        let load_balancer = Arc::new(LoadBalancer::new(strategy.clone()).await?);
        let metrics = Arc::new(RouterMetrics::new().await?);

        // Bounded parallelism (10 concurrent tasks) with a 30s per-operation timeout.
        let concurrent_runner =
            ConcurrentRunner::new(10).with_timeout(std::time::Duration::from_secs(30));

        // Retry policy for failed requests using the crate's default settings.
        let retry_policy = default_retry_policy();

        Ok(Self {
            providers,
            configs: Arc::new(configs),
            storage,
            strategy,
            health_checker,
            load_balancer,
            metrics,
            concurrent_runner,
            retry_policy,
        })
    }

    /// Route a chat completion request with retry logic.
    ///
    /// Each retry attempt re-selects a provider via the load balancer, so a
    /// failing provider can be routed around. Request metrics (provider,
    /// model, latency, success) are recorded on a spawned task so they stay
    /// off the request's critical path.
    ///
    /// # Errors
    /// Returns the final error if all retry attempts fail.
    pub async fn route_chat_completion(
        &self,
        request: ChatCompletionRequest,
        context: RequestContext,
    ) -> Result<ChatCompletionResponse> {
        let start_time = Instant::now();

        // Track the provider that actually served the (possibly retried)
        // request so metrics are attributed correctly. The previous
        // implementation re-ran provider selection *after* the request
        // finished, which could name a different provider than the one used.
        let selected_name = Arc::new(std::sync::Mutex::new(String::from("unknown")));

        // Execute request with retry policy.
        let result = self
            .retry_policy
            .execute(|| {
                let request = request.clone();
                let context = context.clone();
                let load_balancer = self.load_balancer.clone();
                let selected_name = selected_name.clone();

                async move {
                    // Select provider using load balancer.
                    let provider = load_balancer
                        .select_provider(&request.model, &context)
                        .await?;

                    // Record which provider handled this attempt (last
                    // attempt wins). The guard is dropped before the await
                    // below, so the std mutex is never held across an await.
                    if let Ok(mut name) = selected_name.lock() {
                        *name = provider.name().to_string();
                    }

                    // Execute request.
                    provider.chat_completion(request, context).await
                }
            })
            .await;

        let duration = start_time.elapsed();
        let provider_name = selected_name
            .lock()
            .map(|name| name.clone())
            .unwrap_or_else(|_| "unknown".to_string());

        // Record metrics asynchronously, off the hot path.
        let metrics = self.metrics.clone();
        let model = request.model.clone();
        let success = result.is_ok();
        tokio::spawn(async move {
            let _ = metrics
                .record_request(&provider_name, &model, duration, success)
                .await;
        });

        result
    }

    /// Route a completion request to a provider chosen by the load balancer.
    ///
    /// # Errors
    /// Fails if no provider can be selected or the provider call fails.
    pub async fn route_completion(
        &self,
        request: CompletionRequest,
        context: RequestContext,
    ) -> Result<CompletionResponse> {
        let start_time = Instant::now();

        // Select provider using load balancer.
        let provider = self
            .load_balancer
            .select_provider(&request.model, &context)
            .await?;

        // Keep only the model name for metrics instead of cloning the whole
        // request and context, then hand ownership to the provider.
        let model = request.model.clone();

        // Execute request.
        let result = provider.completion(request, context).await;

        // Record metrics.
        let duration = start_time.elapsed();
        self.metrics
            .record_request(&provider.name(), &model, duration, result.is_ok())
            .await;

        result
    }

    /// Route an embedding request to a provider chosen by the load balancer.
    ///
    /// # Errors
    /// Fails if no provider can be selected or the provider call fails.
    pub async fn route_embedding(
        &self,
        request: EmbeddingRequest,
        context: RequestContext,
    ) -> Result<EmbeddingResponse> {
        let start_time = Instant::now();

        // Select provider using load balancer.
        let provider = self
            .load_balancer
            .select_provider(&request.model, &context)
            .await?;

        // Keep only the model name for metrics; avoid cloning the request.
        let model = request.model.clone();

        // Execute request.
        let result = provider.embedding(request, context).await;

        // Record metrics.
        let duration = start_time.elapsed();
        self.metrics
            .record_request(&provider.name(), &model, duration, result.is_ok())
            .await;

        result
    }

    /// Get the router's aggregated health status.
    pub async fn health_status(&self) -> Result<RouterHealthStatus> {
        self.health_checker.get_status().await
    }

    /// Get a snapshot of the router's request metrics.
    pub async fn get_metrics(&self) -> Result<RouterMetricsSnapshot> {
        self.metrics.get_snapshot().await
    }

    /// Route a chat completion request with streaming.
    ///
    /// Selects a provider for the request's model and delegates the stream
    /// to it. No retry is applied to streaming requests.
    ///
    /// # Errors
    /// Fails if no provider can be selected or the provider cannot open
    /// the stream.
    pub async fn route_chat_completion_stream(
        &self,
        request: crate::core::models::openai::ChatCompletionRequest,
        context: crate::core::models::RequestContext,
    ) -> Result<impl futures::Stream<Item = Result<String>>> {
        // Find the best provider for this request's model.
        let provider = self
            .load_balancer
            .select_provider(&request.model, &context)
            .await?;

        // Route to the selected provider for streaming.
        provider.chat_completion_stream(request, context).await
    }

    /// Route an image generation request.
    ///
    /// Falls back to `"dall-e-3"` when the request does not name a model.
    ///
    /// # Errors
    /// Fails if no provider can be selected or the provider call fails.
    pub async fn route_image_generation(
        &self,
        request: crate::core::models::openai::ImageGenerationRequest,
        context: crate::core::models::RequestContext,
    ) -> Result<crate::core::models::openai::ImageGenerationResponse> {
        // Find the best provider for this request; the model is optional.
        let default_model = "dall-e-3".to_string();
        let model = request.model.as_ref().unwrap_or(&default_model);
        let provider = self.load_balancer.select_provider(model, &context).await?;

        // Route to the selected provider.
        provider.image_generation(request, context).await
    }

    /// List all models available across every registered provider.
    ///
    /// Providers that fail to report their models are logged and skipped,
    /// so one failing provider does not hide the others' models.
    pub async fn list_models(&self) -> Result<Vec<crate::core::models::openai::Model>> {
        let mut all_models = Vec::new();

        // Collect models from all providers (best effort per provider).
        let providers = self.providers.read().await;
        for provider in providers.values() {
            match provider.list_models().await {
                Ok(models) => all_models.extend(models),
                Err(e) => warn!(
                    "Failed to get models from provider {}: {}",
                    provider.name(),
                    e
                ),
            }
        }

        Ok(all_models)
    }

    /// Get information about a specific model.
    ///
    /// Scans providers in registry order and returns the first match;
    /// provider lookup errors are logged and skipped. Returns `Ok(None)`
    /// when no provider knows the model.
    pub async fn get_model(
        &self,
        model_id: &str,
    ) -> Result<Option<crate::core::models::openai::Model>> {
        // Try to find the model in any provider.
        let providers = self.providers.read().await;
        for provider in providers.values() {
            match provider.get_model(model_id).await {
                Ok(Some(model)) => return Ok(Some(model)),
                Ok(None) => continue,
                Err(e) => warn!(
                    "Failed to get model {} from provider {}: {}",
                    model_id,
                    provider.name(),
                    e
                ),
            }
        }

        Ok(None)
    }

    /// Add a new provider and register it with the health checker and
    /// load balancer.
    ///
    /// # Errors
    /// Fails if the provider cannot be created or a subsystem rejects it.
    pub async fn add_provider(&self, config: ProviderConfig) -> Result<()> {
        info!("Adding provider: {}", config.name);

        // Create provider instance.
        let provider = crate::core::providers::create_provider(config.clone()).await?;

        // Add to the registry; scope the write guard so it is released
        // before the subsystem awaits below instead of blocking readers.
        {
            let mut providers = self.providers.write().await;
            providers.insert(config.name.clone(), provider.clone());
        }

        // Update health checker.
        self.health_checker.add_provider(&config.name).await?;

        // Update load balancer.
        self.load_balancer
            .add_provider(&config.name, provider)
            .await?;

        info!("Provider {} added successfully", config.name);
        Ok(())
    }

    /// Remove a provider and deregister it from the health checker and
    /// load balancer.
    ///
    /// # Errors
    /// Fails if a subsystem cannot deregister the provider.
    pub async fn remove_provider(&self, name: &str) -> Result<()> {
        info!("Removing provider: {}", name);

        // Remove from the registry; scope the write guard so it is released
        // before the subsystem awaits below.
        {
            let mut providers = self.providers.write().await;
            providers.remove(name);
        }

        // Update health checker.
        self.health_checker.remove_provider(name).await?;

        // Update load balancer.
        self.load_balancer.remove_provider(name).await?;

        info!("Provider {} removed successfully", name);
        Ok(())
    }

    /// List the names of all registered providers.
    pub async fn list_providers(&self) -> Result<Vec<String>> {
        let providers = self.providers.read().await;
        Ok(providers.keys().cloned().collect())
    }
}