//! litellm-rs 0.1.1
//!
//! A high-performance AI Gateway written in Rust, providing OpenAI-compatible APIs with intelligent routing, load balancing, and enterprise features
//! Unified LLM client implementation

use std::sync::Arc;
use crate::sdk::{config::ClientConfig, types::*, errors::*, providers::ProviderRegistry};
use crate::core::function_calling::FunctionDefinition;

/// Unified LLM client
///
/// Cheap to clone: the provider registry is shared behind an `Arc`, so clones
/// of the client reference the same set of providers.
#[derive(Debug, Clone)]
pub struct LLMClient {
    // Shared registry used to resolve providers by id.
    registry: Arc<ProviderRegistry>,
    // Client configuration, consulted for default-provider and model lookup.
    config: ClientConfig,
}

impl LLMClient {
    /// Create a new LLM client from the given configuration.
    ///
    /// # Errors
    /// Fails if the provider registry cannot be constructed from `config`.
    pub fn new(config: ClientConfig) -> Result<Self> {
        let registry = Arc::new(ProviderRegistry::new(&config)?);

        Ok(Self { registry, config })
    }

    /// Send a chat message using the default provider.
    pub async fn chat(&self, messages: Vec<Message>) -> Result<ChatResponse> {
        self.chat_with_provider(None, messages).await
    }

    /// Send a chat message using a specific provider, or the default provider
    /// when `provider_id` is `None`.
    pub async fn chat_with_provider(&self, provider_id: Option<&str>, messages: Vec<Message>) -> Result<ChatResponse> {
        let provider = self.registry.get_provider(provider_id)?;

        let request = ChatRequest {
            model: self.get_default_model(provider_id)?,
            messages,
            options: ChatOptions::default(),
        };

        // Use the existing provider implementation from the main codebase;
        // every call is tagged with a fresh UUID request id for tracing.
        let context = crate::core::models::RequestContext::new(
            uuid::Uuid::new_v4().to_string(),
            None,
        );

        provider.chat_completion(request.into(), context).await.map(Into::into)
    }

    /// Get a streaming chat response from the default provider.
    pub async fn chat_stream(&self, messages: Vec<Message>) -> Result<impl tokio_stream::Stream<Item = Result<ChatChunk>>> {
        self.chat_stream_with_provider(None, messages).await
    }

    /// Get a streaming chat response from a specific provider, or the default
    /// provider when `provider_id` is `None`.
    pub async fn chat_stream_with_provider(&self, provider_id: Option<&str>, messages: Vec<Message>) -> Result<impl tokio_stream::Stream<Item = Result<ChatChunk>>> {
        let provider = self.registry.get_provider(provider_id)?;

        let request = ChatRequest {
            model: self.get_default_model(provider_id)?,
            messages,
            options: ChatOptions {
                stream: true,
                ..Default::default()
            },
        };

        let context = crate::core::models::RequestContext::new(
            uuid::Uuid::new_v4().to_string(),
            None,
        );

        let stream = provider.chat_completion_stream(request.into(), context).await?;

        // Convert the stream to our SDK format. The upstream stream yields
        // bare content strings, so a chunk id is synthesized per chunk and
        // the model name is not recoverable here (reported as "unknown").
        Ok(tokio_stream::StreamExt::map(stream, |chunk| {
            chunk.map(|s| ChatChunk {
                id: uuid::Uuid::new_v4().to_string(),
                model: "unknown".to_string(),
                choices: vec![ChunkChoice {
                    index: 0,
                    delta: MessageDelta {
                        role: None,
                        content: Some(s),
                        tool_calls: None,
                    },
                    finish_reason: None,
                }],
            }).map_err(SDKError::from)
        }))
    }

    /// List available providers by id.
    pub fn list_providers(&self) -> Vec<String> {
        self.registry.list_providers()
    }

    /// Get provider health status.
    ///
    /// Registry lookup failures propagate as `Err`; a failing health probe is
    /// reported as `Ok(false)` rather than an error.
    pub async fn health_check(&self, provider_id: Option<&str>) -> Result<bool> {
        let provider = self.registry.get_provider(provider_id)?;
        Ok(provider.health_check().await.is_ok())
    }

    /// Resolve the model for a request: the first model configured for the
    /// selected provider, falling back to "gpt-3.5-turbo" when the provider
    /// lists no models.
    ///
    /// # Errors
    /// `ProviderNotFound` when `provider_id` names an unknown provider;
    /// `NoDefaultProvider` when no provider is configured at all.
    fn get_default_model(&self, provider_id: Option<&str>) -> Result<String> {
        let provider_config = if let Some(id) = provider_id {
            self.config.providers.iter()
                .find(|p| p.id == id)
                .ok_or_else(|| SDKError::ProviderNotFound(id.to_string()))?
        } else {
            // No explicit provider: the first configured one acts as default.
            self.config.providers.first()
                .ok_or(SDKError::NoDefaultProvider)?
        };

        Ok(provider_config.models.first()
            .cloned()
            .unwrap_or_else(|| "gpt-3.5-turbo".to_string()))
    }
}

// Conversion implementations to bridge SDK types with core types
/// Bridge an SDK chat request into the core OpenAI-compatible request.
impl From<ChatRequest> for crate::core::models::openai::ChatCompletionRequest {
    fn from(req: ChatRequest) -> Self {
        use crate::core::models::openai::Stop;

        // A single stop sequence maps to the scalar wire form; any other
        // length (including empty) maps to the array form.
        let stop = req.options.stop.map(|stops| match stops.len() {
            1 => Stop::String(stops[0].clone()),
            _ => Stop::Array(stops),
        });

        Self {
            model: req.model,
            messages: req.messages.into_iter().map(Into::into).collect(),
            max_tokens: req.options.max_tokens,
            temperature: req.options.temperature,
            top_p: req.options.top_p,
            n: Some(1),
            stream: Some(req.options.stream),
            stop,
            presence_penalty: req.options.presence_penalty,
            frequency_penalty: req.options.frequency_penalty,
            logit_bias: None,
            user: None,
            function_call: None,
            functions: None,
            tools: req.options.tools.map(|ts| ts.into_iter().map(Into::into).collect()),
            tool_choice: req.options.tool_choice.map(Into::into),
            response_format: None,
            seed: None,
            logprobs: None,
            top_logprobs: None,
            parallel_tool_calls: None,
        }
    }
}

/// Bridge an SDK message into the core OpenAI-compatible chat message.
impl From<Message> for crate::core::models::openai::ChatMessage {
    fn from(msg: Message) -> Self {
        use crate::core::models::openai::{MessageContent, MessageRole};

        let role = match msg.role {
            Role::System => MessageRole::System,
            Role::User => MessageRole::User,
            Role::Assistant => MessageRole::Assistant,
            Role::Tool => MessageRole::Tool,
        };

        // Plain text maps 1:1; multimodal parts convert element-wise.
        let content = msg.content.map(|c| match c {
            Content::Text(text) => MessageContent::Text(text),
            Content::Multimodal(parts) => {
                MessageContent::Parts(parts.into_iter().map(Into::into).collect())
            }
        });

        Self {
            role,
            content,
            name: msg.name,
            function_call: None,
            tool_calls: msg.tool_calls.map(|calls| calls.into_iter().map(Into::into).collect()),
            // NOTE(review): the SDK `Message` appears to carry no tool_call_id,
            // so tool-role messages lose it here — confirm this is acceptable.
            tool_call_id: None,
            audio: None,
        }
    }
}

/// Bridge an SDK content part into the core OpenAI-compatible content part.
impl From<ContentPart> for crate::core::models::openai::ContentPart {
    fn from(part: ContentPart) -> Self {
        match part {
            ContentPart::Text { text } => Self::Text { text },
            ContentPart::Image { image_url } => {
                let image_url = crate::core::models::openai::ImageUrl {
                    url: image_url.url,
                    detail: image_url.detail,
                };
                Self::ImageUrl { image_url }
            }
            // Audio is not supported in core types yet; degrade to a textual
            // placeholder instead of dropping the part.
            ContentPart::Audio { .. } => Self::Text {
                text: "[Audio content]".to_string(),
            },
        }
    }
}

/// Bridge an SDK tool definition into the core OpenAI-compatible tool.
impl From<Tool> for crate::core::models::openai::Tool {
    fn from(tool: Tool) -> Self {
        let function = crate::core::models::openai::FunctionDefinition {
            name: tool.function.name,
            description: tool.function.description,
            // The core type wraps the JSON schema in an Option.
            parameters: Some(tool.function.parameters),
        };
        Self {
            r#type: tool.tool_type,
            function,
        }
    }
}

/// Bridge an SDK tool-choice selector into the core OpenAI-compatible form.
impl From<ToolChoice> for crate::core::models::openai::ToolChoice {
    fn from(choice: ToolChoice) -> Self {
        match choice {
            ToolChoice::None => Self::None,
            ToolChoice::Auto => Self::Auto,
            ToolChoice::Required => Self::Required,
            // A forced function choice carries an explicit "function" type tag
            // in the wire format.
            // NOTE(review): `FunctionCall` is built here with only `name`,
            // while the `ToolCall` conversion in this file also supplies
            // `arguments` — confirm the core struct permits both forms.
            ToolChoice::Function { name } => Self::Function { 
                r#type: "function".to_string(),
                function: crate::core::models::openai::FunctionCall { name }
            },
        }
    }
}

/// Bridge an SDK tool call into the core OpenAI-compatible tool call.
impl From<ToolCall> for crate::core::models::openai::ToolCall {
    fn from(call: ToolCall) -> Self {
        let function = crate::core::models::openai::FunctionCall {
            name: call.function.name,
            arguments: call.function.arguments,
        };
        Self {
            id: call.id,
            r#type: call.tool_type,
            function,
        }
    }
}

impl From<crate::core::models::openai::ChatCompletionResponse> for ChatResponse {
    fn from(resp: crate::core::models::openai::ChatCompletionResponse) -> Self {
        Self {
            id: resp.id,
            model: resp.model,
            choices: resp.choices.into_iter().map(|c| c.into()).collect(),
            usage: resp.usage.map(|u| u.into()).unwrap_or_default(),
            created: resp.created,
        }
    }
}

impl From<crate::core::models::openai::ChatChoice> for ChatChoice {
    fn from(choice: crate::core::models::openai::ChatChoice) -> Self {
        Self {
            index: choice.index,
            message: choice.message.into(),
            finish_reason: choice.finish_reason,
        }
    }
}

/// Bridge a core chat message back into the SDK message type.
impl From<crate::core::models::openai::ChatMessage> for Message {
    fn from(msg: crate::core::models::openai::ChatMessage) -> Self {
        Self {
            role: match msg.role {
                crate::core::models::openai::MessageRole::System => Role::System,
                crate::core::models::openai::MessageRole::User => Role::User,
                crate::core::models::openai::MessageRole::Assistant => Role::Assistant,
                crate::core::models::openai::MessageRole::Tool => Role::Tool,
                // Any role the SDK cannot represent is coerced to User; the
                // catch-all suggests the core enum has more variants than the
                // four mapped here.
                _ => Role::User, // fallback
            },
            // Text maps 1:1; multimodal parts are converted element-wise.
            content: msg.content.map(|c| match c {
                crate::core::models::openai::MessageContent::Text(text) => Content::Text(text),
                crate::core::models::openai::MessageContent::Parts(parts) => {
                    Content::Multimodal(parts.into_iter().map(|p| p.into()).collect())
                }
            }),
            name: msg.name,
            tool_calls: msg.tool_calls.map(|calls| calls.into_iter().map(|c| c.into()).collect()),
        }
    }
}

/// Bridge a core content part back into the SDK representation.
impl From<crate::core::models::openai::ContentPart> for ContentPart {
    fn from(part: crate::core::models::openai::ContentPart) -> Self {
        use crate::core::models::openai::ContentPart as Core;
        match part {
            Core::Text { text } => Self::Text { text },
            Core::ImageUrl { image_url } => {
                let image_url = ImageUrl {
                    url: image_url.url,
                    detail: image_url.detail,
                };
                Self::Image { image_url }
            }
            // Audio cannot be represented on the SDK side yet; substitute a
            // textual placeholder.
            Core::Audio { .. } => Self::Text {
                text: "[Audio content]".to_string(),
            },
        }
    }
}

/// Bridge a core tool call back into the SDK representation.
///
/// Lossy: the core type carries no function description or parameter schema,
/// so those SDK fields are filled with `None` / `Value::Null` placeholders.
impl From<crate::core::models::openai::ToolCall> for ToolCall {
    fn from(call: crate::core::models::openai::ToolCall) -> Self {
        Self {
            id: call.id,
            tool_type: call.r#type,
            function: Function {
                name: call.function.name,
                // Not recoverable from the core type — placeholders only.
                description: None,
                parameters: serde_json::Value::Null,
                arguments: call.function.arguments,
            },
        }
    }
}

impl From<crate::core::models::openai::Usage> for Usage {
    fn from(usage: crate::core::models::openai::Usage) -> Self {
        Self {
            prompt_tokens: usage.prompt_tokens,
            completion_tokens: usage.completion_tokens,
            total_tokens: usage.total_tokens,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::sdk::config::{ConfigBuilder, ProviderType};

    /// Smoke test: a client can be constructed from a minimal single-provider
    /// configuration. No network I/O happens here — construction only.
    #[tokio::test]
    async fn test_client_creation() {
        let config = ConfigBuilder::new()
            .add_provider(crate::sdk::config::ProviderConfig {
                id: "test".to_string(),
                provider_type: ProviderType::OpenAI,
                name: "Test Provider".to_string(),
                // Dummy credentials — never sent anywhere by this test.
                api_key: "test-key".to_string(),
                base_url: None,
                models: vec!["gpt-3.5-turbo".to_string()],
                enabled: true,
                weight: 1.0,
                rate_limit_rpm: Some(1000),
                rate_limit_tpm: Some(10000),
                settings: std::collections::HashMap::new(),
            })
            .build();

        let _client = LLMClient::new(config);
        // Note: This test will pass even with invalid credentials as we're only testing construction
    }
}