//! litellm-rs 0.1.1
//!
//! A high-performance AI Gateway written in Rust, providing OpenAI-compatible APIs with
//! intelligent routing, load balancing, and enterprise features.
//! Google AI provider implementation
//!
//! This module provides Google AI (Gemini) API integration.

use super::{BaseProvider, ModelPricing, Provider, ProviderError, ProviderType};
use crate::config::ProviderConfig;
use crate::core::models::{RequestContext, openai::*};
use crate::utils::error::Result;
use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;
use tracing::{debug, info};

/// Google AI provider implementation
#[derive(Debug, Clone)]
pub struct GoogleProvider {
    /// Shared provider plumbing (name, base URL, API key, HTTP helpers).
    base: BaseProvider,
    /// Static per-model pricing table, keyed by model id (e.g. "gemini-pro");
    /// populated once at construction, never refreshed at runtime.
    pricing_cache: HashMap<String, ModelPricing>,
}

impl GoogleProvider {
    /// Create a new Google provider
    pub async fn new(config: &ProviderConfig) -> Result<Self> {
        let base = BaseProvider::new(config)?;

        // Set default base URL if not provided
        let base_url = config
            .base_url
            .clone()
            .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1".to_string());

        let provider = Self {
            base: BaseProvider { base_url, ..base },
            pricing_cache: Self::initialize_pricing_cache(),
        };

        info!("Google provider '{}' initialized successfully", config.name);
        Ok(provider)
    }

    /// Initialize pricing cache with known Google model prices
    fn initialize_pricing_cache() -> HashMap<String, ModelPricing> {
        let mut cache = HashMap::new();

        // Gemini models
        cache.insert(
            "gemini-pro".to_string(),
            ModelPricing {
                model: "gemini-pro".to_string(),
                input_cost_per_1k: 0.00025,
                output_cost_per_1k: 0.0005,
                currency: "USD".to_string(),
                updated_at: chrono::Utc::now(),
            },
        );

        cache.insert(
            "gemini-ultra".to_string(),
            ModelPricing {
                model: "gemini-ultra".to_string(),
                input_cost_per_1k: 0.00125,
                output_cost_per_1k: 0.00375,
                currency: "USD".to_string(),
                updated_at: chrono::Utc::now(),
            },
        );

        cache
    }

    /// Convert OpenAI messages to Google format
    fn convert_messages_to_google(&self, messages: &[ChatMessage]) -> serde_json::Value {
        let mut google_messages = Vec::new();
        let mut system_message = None;

        for message in messages {
            match message.role {
                MessageRole::System => {
                    if let Some(MessageContent::Text(text)) = &message.content {
                        system_message = Some(text.clone());
                    }
                }
                MessageRole::User => {
                    google_messages.push(json!({
                        "role": "user",
                        "parts": [
                            {
                                "text": self.extract_text_content(message.content.as_ref())
                            }
                        ]
                    }));
                }
                MessageRole::Assistant => {
                    google_messages.push(json!({
                        "role": "model",
                        "parts": [
                            {
                                "text": self.extract_text_content(message.content.as_ref())
                            }
                        ]
                    }));
                }
                _ => {}
            }
        }

        // Add system message as a user message with a special prefix
        if let Some(system) = system_message {
            google_messages.insert(
                0,
                json!({
                    "role": "user",
                    "parts": [
                        {
                            "text": format!("System: {}", system)
                        }
                    ]
                }),
            );
        }

        json!({ "contents": google_messages })
    }

    /// Extract text content from MessageContent
    fn extract_text_content(&self, content: Option<&MessageContent>) -> String {
        match content {
            Some(MessageContent::Text(text)) => text.clone(),
            Some(MessageContent::Parts(parts)) => parts
                .iter()
                .filter_map(|part| match part {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<String>>()
                .join("\n"),
            None => String::new(),
        }
    }

    /// Convert Google response to OpenAI format
    fn convert_google_response_to_openai(
        &self,
        google_response: serde_json::Value,
        model: &str,
    ) -> Result<ChatCompletionResponse> {
        let content = google_response
            .get("candidates")
            .and_then(|c| c.as_array())
            .and_then(|arr| arr.first())
            .and_then(|candidate| candidate.get("content"))
            .and_then(|content| content.get("parts"))
            .and_then(|parts| parts.as_array())
            .and_then(|arr| arr.first())
            .and_then(|part| part.get("text"))
            .and_then(|text| text.as_str())
            .unwrap_or("")
            .to_string();

        let usage = Usage {
            prompt_tokens: 0, // Google doesn't provide token counts
            completion_tokens: 0,
            total_tokens: 0,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        };

        Ok(ChatCompletionResponse {
            id: format!("chatcmpl-google-{}", uuid::Uuid::new_v4()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: model.to_string(),
            choices: vec![ChatChoice {
                index: 0,
                message: ChatMessage {
                    role: MessageRole::Assistant,
                    content: Some(MessageContent::Text(content)),
                    name: None,
                    function_call: None,
                    tool_calls: None,
                    tool_call_id: None,
                    audio: None,
                },
                finish_reason: Some("stop".to_string()),
                logprobs: None,
            }],
            usage: Some(usage),
            system_fingerprint: None,
        })
    }
}

#[async_trait]
impl Provider for GoogleProvider {
    fn name(&self) -> &str {
        &self.base.name
    }

    fn provider_type(&self) -> ProviderType {
        ProviderType::Google
    }

    async fn supports_model(&self, model: &str) -> bool {
        self.base.is_model_supported(model) || model.starts_with("gemini-")
    }

    async fn supports_images(&self) -> bool {
        true // Gemini supports images
    }

    async fn supports_embeddings(&self) -> bool {
        true
    }

    async fn supports_streaming(&self) -> bool {
        true
    }

    async fn list_models(&self) -> Result<Vec<Model>> {
        // Return known models
        let known_models = vec!["gemini-pro", "gemini-ultra", "gemini-pro-vision"];

        let models = known_models
            .into_iter()
            .map(|model| Model {
                id: model.to_string(),
                object: "model".to_string(),
                created: chrono::Utc::now().timestamp() as u64,
                owned_by: "google".to_string(),
            })
            .collect();

        Ok(models)
    }

    async fn health_check(&self) -> Result<()> {
        debug!("Performing Google health check");
        // TODO: Implement Google-specific health check
        Ok(())
    }

    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
        _context: RequestContext,
    ) -> Result<ChatCompletionResponse> {
        debug!("Google chat completion for model: {}", request.model);

        // Extract model name without provider prefix
        let model_name = request
            .model
            .split('/')
            .next_back()
            .unwrap_or(&request.model);

        // Google API format: {base_url}/models/{model}:generateContent?key={api_key}
        let endpoint = format!(
            "models/{}:generateContent?key={}",
            model_name, self.base.api_key
        );

        let mut body = self.convert_messages_to_google(&request.messages);

        // Add generation parameters
        let mut generation_config = json!({});

        if let Some(max_tokens) = request.max_tokens {
            generation_config["maxOutputTokens"] = json!(max_tokens);
        }

        if let Some(temperature) = request.temperature {
            generation_config["temperature"] = json!(temperature);
        }

        if let Some(top_p) = request.top_p {
            generation_config["topP"] = json!(top_p);
        }

        if !generation_config.as_object().unwrap().is_empty() {
            body["generationConfig"] = generation_config;
        }

        let response = self
            .base
            .make_request(reqwest::Method::POST, &endpoint, Some(body))
            .await?;
        let google_response: serde_json::Value = self.base.parse_json_response(response).await?;

        self.convert_google_response_to_openai(google_response, &request.model)
    }

    async fn completion(
        &self,
        _request: CompletionRequest,
        _context: RequestContext,
    ) -> Result<CompletionResponse> {
        Err(ProviderError::InvalidRequest(
            "Google does not support legacy completion endpoint".to_string(),
        )
        .into())
    }

    async fn embedding(
        &self,
        _request: EmbeddingRequest,
        _context: RequestContext,
    ) -> Result<EmbeddingResponse> {
        Err(
            ProviderError::InvalidRequest("Google embedding not implemented yet".to_string())
                .into(),
        )
    }

    async fn image_generation(
        &self,
        _request: ImageGenerationRequest,
        _context: RequestContext,
    ) -> Result<ImageGenerationResponse> {
        Err(ProviderError::InvalidRequest(
            "Google image generation not implemented yet".to_string(),
        )
        .into())
    }

    async fn get_model_pricing(&self, model: &str) -> Result<ModelPricing> {
        if let Some(pricing) = self.pricing_cache.get(model) {
            Ok(pricing.clone())
        } else {
            // Return default pricing for unknown models
            Ok(ModelPricing {
                model: model.to_string(),
                input_cost_per_1k: 0.0005, // Default rate
                output_cost_per_1k: 0.0015,
                currency: "USD".to_string(),
                updated_at: chrono::Utc::now(),
            })
        }
    }

    async fn calculate_cost(
        &self,
        model: &str,
        input_tokens: u32,
        output_tokens: u32,
    ) -> Result<f64> {
        let pricing = self.get_model_pricing(model).await?;

        let input_cost = (input_tokens as f64 / 1000.0) * pricing.input_cost_per_1k;
        let output_cost = (output_tokens as f64 / 1000.0) * pricing.output_cost_per_1k;

        Ok(input_cost + output_cost)
    }
}