litellm-rs 0.4.16

A high-performance AI Gateway written in Rust, providing OpenAI-compatible APIs with intelligent routing, load balancing, and enterprise features
Documentation
//! Provider handle for routing system
//!
//! Provides a type-erased wrapper for LLMProvider instances used by the router

use crate::core::types::health::HealthStatus;
use crate::core::types::{chat::ChatRequest, context::RequestContext, responses::ChatResponse};
use crate::utils::error::gateway_error::GatewayError;

use super::llm_provider::trait_definition::LLMProvider;

/// Provider handle for routing system
///
/// This struct wraps a concrete provider implementation and provides
/// a uniform interface for the routing system. It uses type erasure
/// to allow heterogeneous provider collections.
///
/// # Design Principles
/// - Type erasure for flexible provider management
/// - Weight-based routing support
/// - Enable/disable functionality for health management
/// - Simplified interface for routing decisions
///
/// # Example
///
/// The `ProviderHandle` wraps any type implementing `LLMProvider` with type erasure
/// for flexible provider management in routing systems. See the routing module
/// for usage examples.
pub struct ProviderHandle {
    name: String,
    _provider: std::sync::Arc<dyn std::any::Any + Send + Sync>,
    weight: f64,
    enabled: bool,
}

impl ProviderHandle {
    /// Create a new provider handle
    ///
    /// # Parameters
    /// * `provider` - The provider instance to wrap
    /// * `weight` - Routing weight (higher values = more traffic)
    ///
    /// # Returns
    /// A new ProviderHandle with the provider enabled by default
    ///
    /// # Type Parameters
    /// * `P` - Any type implementing LLMProvider
    pub fn new<P>(provider: P, weight: f64) -> Self
    where
        P: LLMProvider + Send + Sync + 'static,
    {
        Self {
            name: provider.name().to_string(),
            _provider: std::sync::Arc::new(provider)
                as std::sync::Arc<dyn std::any::Any + Send + Sync>,
            weight,
            enabled: true,
        }
    }

    /// Get provider name
    ///
    /// # Returns
    /// The provider's identifier string
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Get routing weight
    ///
    /// # Returns
    /// The weight used for weighted routing strategies
    pub fn weight(&self) -> f64 {
        self.weight
    }

    /// Check if provider is enabled
    ///
    /// # Returns
    /// `true` if the provider can receive traffic, `false` otherwise
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Set enabled state
    ///
    /// # Parameters
    /// * `enabled` - Whether to enable or disable this provider
    ///
    /// # Use Cases
    /// - Disable unhealthy providers automatically
    /// - Manual traffic control
    /// - Gradual rollout/rollback
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Execute chat completion request
    ///
    /// # Parameters
    /// * `request` - Chat completion request
    /// * `context` - Request context with metadata
    ///
    /// # Returns
    /// Chat completion response
    ///
    /// # Note
    /// This is a simplified implementation. In a real system, you would need
    /// to properly downcast the provider and call its chat_completion method.
    pub async fn chat_completion(
        &self,
        _request: ChatRequest,
        _context: RequestContext,
    ) -> Result<ChatResponse, GatewayError> {
        // This is a simplified implementation - in a real system,
        // you'd need to properly downcast and handle the provider
        Err(GatewayError::Internal(
            "Provider chat_completion not implemented".to_string(),
        ))
    }

    /// Check if model is supported
    ///
    /// # Parameters
    /// * `model` - Model name to check
    ///
    /// # Returns
    /// `true` if the model is supported
    ///
    /// # Note
    /// Simplified implementation - returns true for all models
    pub fn supports_model(&self, _model: &str) -> bool {
        // Simplified implementation
        true
    }

    /// Check if tools are supported
    ///
    /// # Returns
    /// `true` if tool calling is supported
    ///
    /// # Note
    /// Simplified implementation - returns true
    pub fn supports_tools(&self) -> bool {
        // Simplified implementation
        true
    }

    /// Check provider health status
    ///
    /// # Returns
    /// Health status of the provider
    ///
    /// # Note
    /// Simplified implementation - always returns Healthy
    pub async fn health_check(&self) -> HealthStatus {
        // Simplified implementation
        HealthStatus::Healthy
    }

    /// Calculate request cost
    ///
    /// # Parameters
    /// * `model` - Model name used
    /// * `input` - Number of input tokens
    /// * `output` - Number of output tokens
    ///
    /// # Returns
    /// Estimated cost in USD
    ///
    /// # Note
    /// Simplified implementation - returns 0.0
    pub async fn calculate_cost(
        &self,
        _model: &str,
        _input: u32,
        _output: u32,
    ) -> Result<f64, GatewayError> {
        // Simplified implementation
        Ok(0.0)
    }

    /// Get average response latency
    ///
    /// # Returns
    /// Average latency for this provider
    ///
    /// # Note
    /// Simplified implementation - returns 100ms
    pub async fn get_average_latency(&self) -> Result<std::time::Duration, GatewayError> {
        // Simplified implementation
        Ok(std::time::Duration::from_millis(100))
    }

    /// Get success rate
    ///
    /// # Returns
    /// Success rate between 0.0 and 1.0
    ///
    /// # Note
    /// Simplified implementation - returns 1.0 (100%)
    pub async fn get_success_rate(&self) -> Result<f32, GatewayError> {
        // Simplified implementation
        Ok(1.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Test directly on the struct methods that don't require provider creation
    // Since ProviderHandle::new requires a concrete LLMProvider implementation,
    // we test the simpler functions that can be tested in isolation

    #[test]
    fn test_provider_handle_struct_exists() {
        // Test that ProviderHandle struct is properly defined
        // and can be referenced (compilation test)
        let _type_check: fn() -> bool = || {
            // Just a compilation test to ensure the type exists
            true
        };
        assert!(_type_check());
    }

    #[tokio::test]
    async fn test_health_status_healthy() {
        // Test HealthStatus enum
        let status = HealthStatus::Healthy;
        assert!(matches!(status, HealthStatus::Healthy));
    }

    #[tokio::test]
    async fn test_health_status_unhealthy() {
        let status = HealthStatus::Unhealthy;
        assert!(matches!(status, HealthStatus::Unhealthy));
    }

    #[tokio::test]
    async fn test_health_status_degraded() {
        let status = HealthStatus::Degraded;
        assert!(matches!(status, HealthStatus::Degraded));
    }

    #[test]
    fn test_request_context_default() {
        let context = RequestContext::default();
        // Verify default context is created with a request_id
        assert!(!context.request_id.is_empty());
    }

    #[test]
    fn test_chat_request_default() {
        let request = ChatRequest {
            model: "test-model".to_string(),
            messages: vec![],
            ..Default::default()
        };

        assert_eq!(request.model, "test-model");
        assert!(request.messages.is_empty());
    }

    // Note: Full ProviderHandle tests would require a mock LLMProvider implementation
    // which is complex due to the trait bounds. These basic tests verify the
    // foundational types work correctly.
}