llm/backends/
google.rs

1// This file will be completely replaced to fix the syntax errors
2
3//! Google Gemini API client implementation for chat and completion functionality.
4//!
5//! This module provides integration with Google's Gemini models through their API.
6//! It implements chat, completion and embedding capabilities via the Gemini API.
7//!
8//! # Features
9//! - Chat conversations with system prompts and message history
10//! - Text completion requests
11//! - Configuration options for temperature, tokens, top_p, top_k etc.
12//! - Streaming support
13//!
14//! # Example
15//! ```no_run
16//! use llm::backends::google::Google;
17//! use llm::chat::{ChatMessage, ChatRole, ChatProvider};
18//!
19//! #[tokio::main]
20//! async fn main() {
21//! let client = Google::new(
22//!     "your-api-key",
23//!     None, // Use default model
24//!     Some(1000), // Max tokens
25//!     Some(0.7), // Temperature
26//!     None, // Default timeout
27//!     None, // No system prompt
28//!     None, // Default top_p
29//!     None, // Default top_k
30//!     None, // No JSON schema
31//!     None, // No tools
32//! );
33//!
34//! let messages = vec![
35//!     ChatMessage::user().content("Hello!").build()
36//! ];
37//!
38//! let response = client.chat(&messages).await.unwrap();
39//! println!("{response}");
40//! }
41//! ```
42
43use std::sync::Arc;
44
45use crate::{
46    builder::LLMBackend,
47    chat::{
48        ChatMessage, ChatProvider, ChatResponse, ChatRole, MessageType, StructuredOutputFormat,
49        Tool, Usage,
50    },
51    completion::{CompletionProvider, CompletionRequest, CompletionResponse},
52    embedding::EmbeddingProvider,
53    error::LLMError,
54    models::{ModelListRawEntry, ModelListRequest, ModelListResponse, ModelsProvider},
55    stt::SpeechToTextProvider,
56    tts::TextToSpeechProvider,
57    FunctionCall, LLMProvider, ToolCall,
58};
59use async_trait::async_trait;
60use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
61use chrono::{DateTime, Utc};
62use futures::{stream::Stream, StreamExt};
63use reqwest::Client;
64use serde::{Deserialize, Serialize};
65use serde_json::Value;
66
67/// Configuration for the Google Gemini client.
68#[derive(Debug)]
69pub struct GoogleConfig {
70    /// API key for authentication with Google.
71    pub api_key: String,
72    /// Model identifier (e.g., "gemini-pro").
73    pub model: String,
74    /// Maximum tokens to generate in responses.
75    pub max_tokens: Option<u32>,
76    /// Sampling temperature for response randomness.
77    pub temperature: Option<f32>,
78    /// System prompt to guide model behavior.
79    pub system: Option<String>,
80    /// Request timeout in seconds.
81    pub timeout_seconds: Option<u64>,
82    /// Top-p (nucleus) sampling parameter.
83    pub top_p: Option<f32>,
84    /// Top-k sampling parameter.
85    pub top_k: Option<u32>,
86    /// JSON schema for structured output.
87    pub json_schema: Option<StructuredOutputFormat>,
88    /// Available tools for the model to use.
89    pub tools: Option<Vec<Tool>>,
90}
91
92/// Client for interacting with Google's Gemini API.
93///
94/// This struct holds the configuration and state needed to make requests to the Gemini API.
95/// It implements the [`ChatProvider`], [`CompletionProvider`], and [`EmbeddingProvider`] traits.
96///
97/// The client uses `Arc` internally for configuration, making cloning cheap
98/// (only an atomic reference count increment).
99#[derive(Debug, Clone)]
100pub struct Google {
101    /// Shared configuration wrapped in Arc for cheap cloning.
102    pub config: Arc<GoogleConfig>,
103    /// HTTP client for making requests.
104    pub client: Client,
105}
106
107/// Request body for chat completions
108#[derive(Serialize)]
109struct GoogleChatRequest<'a> {
110    /// List of conversation messages
111    contents: Vec<GoogleChatContent<'a>>,
112    /// Optional generation parameters
113    #[serde(skip_serializing_if = "Option::is_none", rename = "generationConfig")]
114    generation_config: Option<GoogleGenerationConfig>,
115    /// Tools that the model can use
116    #[serde(skip_serializing_if = "Option::is_none")]
117    tools: Option<Vec<GoogleTool>>,
118}
119
120/// Individual message in a chat conversation
121#[derive(Serialize)]
122struct GoogleChatContent<'a> {
123    /// Role of the message sender ("user", "model", or "system")
124    role: &'a str,
125    /// Content parts of the message
126    parts: Vec<GoogleContentPart<'a>>,
127}
128
129/// Text content within a chat message
130#[derive(Serialize)]
131#[serde(rename_all = "camelCase")]
132enum GoogleContentPart<'a> {
133    /// The actual text content
134    #[serde(rename = "text")]
135    Text(&'a str),
136    InlineData(GoogleInlineData),
137    FunctionCall(GoogleFunctionCall),
138    #[serde(rename = "functionResponse")]
139    FunctionResponse(GoogleFunctionResponse),
140}
141
142#[derive(Serialize)]
143struct GoogleInlineData {
144    mime_type: String,
145    data: String,
146}
147
148/// Configuration parameters for text generation
149#[derive(Serialize)]
150struct GoogleGenerationConfig {
151    /// Maximum tokens to generate
152    #[serde(skip_serializing_if = "Option::is_none", rename = "maxOutputTokens")]
153    max_output_tokens: Option<u32>,
154    /// Sampling temperature
155    #[serde(skip_serializing_if = "Option::is_none")]
156    temperature: Option<f32>,
157    /// Top-p sampling parameter
158    #[serde(skip_serializing_if = "Option::is_none", rename = "topP")]
159    top_p: Option<f32>,
160    /// Top-k sampling parameter
161    #[serde(skip_serializing_if = "Option::is_none", rename = "topK")]
162    top_k: Option<u32>,
163    /// The MIME type of the response
164    #[serde(skip_serializing_if = "Option::is_none")]
165    response_mime_type: Option<GoogleResponseMimeType>,
166    /// A schema for structured output
167    #[serde(skip_serializing_if = "Option::is_none")]
168    response_schema: Option<Value>,
169}
170
171/// Response from the chat completion API
172#[derive(Deserialize, Debug)]
173struct GoogleChatResponse {
174    /// Generated completion candidates
175    candidates: Vec<GoogleCandidate>,
176    /// Usage metadata containing token counts
177    #[serde(rename = "usageMetadata")]
178    usage_metadata: Option<GoogleUsageMetadata>,
179}
180
181/// Usage metadata for token counts
182#[derive(Deserialize, Debug)]
183struct GoogleUsageMetadata {
184    /// Number of tokens in the prompt
185    #[serde(rename = "promptTokenCount")]
186    prompt_token_count: Option<u32>,
187    /// Number of tokens in the completion
188    #[serde(rename = "candidatesTokenCount")]
189    candidates_token_count: Option<u32>,
190    /// Total number of tokens used
191    #[serde(rename = "totalTokenCount")]
192    total_token_count: Option<u32>,
193}
194
195/// Response from the streaming chat completion API
196#[derive(Deserialize, Debug)]
197struct GoogleStreamResponse {
198    /// Generated completion candidates
199    candidates: Option<Vec<GoogleCandidate>>,
200    /// Usage metadata containing token counts (usually not present in streaming, but may be in final chunk)
201    #[serde(rename = "usageMetadata")]
202    usage_metadata: Option<GoogleUsageMetadata>,
203}
204
205impl std::fmt::Display for GoogleChatResponse {
206    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207        match (self.text(), self.tool_calls()) {
208            (Some(text), Some(tool_calls)) => {
209                for call in tool_calls {
210                    write!(f, "{call}")?;
211                }
212                write!(f, "{text}")
213            }
214            (Some(text), None) => write!(f, "{text}"),
215            (None, Some(tool_calls)) => {
216                for call in tool_calls {
217                    write!(f, "{call}")?;
218                }
219                Ok(())
220            }
221            (None, None) => write!(f, ""),
222        }
223    }
224}
225
226/// Individual completion candidate
227#[derive(Deserialize, Debug)]
228struct GoogleCandidate {
229    /// Content of the candidate response
230    content: GoogleResponseContent,
231}
232
233/// Content block within a response
234#[derive(Deserialize, Debug)]
235struct GoogleResponseContent {
236    /// Parts making up the content (might be absent when only function calls are present)
237    #[serde(default)]
238    parts: Vec<GoogleResponsePart>,
239    /// Function calls if any are used - can be a single object or array
240    #[serde(skip_serializing_if = "Option::is_none")]
241    function_call: Option<GoogleFunctionCall>,
242    /// Function calls as array (newer format in some responses)
243    #[serde(skip_serializing_if = "Option::is_none")]
244    function_calls: Option<Vec<GoogleFunctionCall>>,
245}
246
247impl ChatResponse for GoogleChatResponse {
248    fn text(&self) -> Option<String> {
249        self.candidates
250            .first()
251            .map(|c| c.content.parts.iter().map(|p| p.text.clone()).collect())
252    }
253
254    fn tool_calls(&self) -> Option<Vec<ToolCall>> {
255        self.candidates.first().and_then(|c| {
256            // First check for function calls at the part level (new API format)
257            let part_function_calls: Vec<ToolCall> = c
258                .content
259                .parts
260                .iter()
261                .filter_map(|part| {
262                    part.function_call.as_ref().map(|f| ToolCall {
263                        id: format!("call_{}", f.name),
264                        call_type: "function".to_string(),
265                        function: FunctionCall {
266                            name: f.name.clone(),
267                            arguments: serde_json::to_string(&f.args).unwrap_or_default(),
268                        },
269                    })
270                })
271                .collect();
272
273            if !part_function_calls.is_empty() {
274                return Some(part_function_calls);
275            }
276
277            // Otherwise check for function_calls/function_call at the content level (older format)
278            if let Some(fc) = &c.content.function_calls {
279                // Process array of function calls
280                Some(
281                    fc.iter()
282                        .map(|f| ToolCall {
283                            id: format!("call_{}", f.name),
284                            call_type: "function".to_string(),
285                            function: FunctionCall {
286                                name: f.name.clone(),
287                                arguments: serde_json::to_string(&f.args).unwrap_or_default(),
288                            },
289                        })
290                        .collect(),
291                )
292            } else {
293                c.content.function_call.as_ref().map(|f| {
294                    vec![ToolCall {
295                        id: format!("call_{}", f.name),
296                        call_type: "function".to_string(),
297                        function: FunctionCall {
298                            name: f.name.clone(),
299                            arguments: serde_json::to_string(&f.args).unwrap_or_default(),
300                        },
301                    }]
302                })
303            }
304        })
305    }
306
307    fn usage(&self) -> Option<Usage> {
308        self.usage_metadata.as_ref().and_then(|metadata| {
309            match (metadata.prompt_token_count, metadata.candidates_token_count) {
310                (Some(prompt_tokens), Some(completion_tokens)) => Some(Usage {
311                    prompt_tokens,
312                    completion_tokens,
313                    total_tokens: metadata
314                        .total_token_count
315                        .unwrap_or(prompt_tokens + completion_tokens),
316                    completion_tokens_details: None,
317                    prompt_tokens_details: None,
318                }),
319                _ => None,
320            }
321        })
322    }
323}
324
325/// Individual part of response content
326#[derive(Deserialize, Debug)]
327struct GoogleResponsePart {
328    /// Text content of this part (may be absent if functionCall is present)
329    #[serde(default)]
330    text: String,
331    /// Function call contained in this part
332    #[serde(rename = "functionCall")]
333    function_call: Option<GoogleFunctionCall>,
334}
335
336/// MIME type of the response
337#[derive(Deserialize, Debug, Serialize)]
338enum GoogleResponseMimeType {
339    /// Plain text response
340    #[serde(rename = "text/plain")]
341    PlainText,
342    /// JSON response
343    #[serde(rename = "application/json")]
344    Json,
345    /// ENUM as a string response in the response candidates.
346    #[serde(rename = "text/x.enum")]
347    Enum,
348}
349
350/// Google's function calling tool definition
351#[derive(Serialize, Debug)]
352struct GoogleTool {
353    /// The function declarations array
354    #[serde(rename = "functionDeclarations")]
355    function_declarations: Vec<GoogleFunctionDeclaration>,
356}
357
358/// Google function declaration, similar to OpenAI's function definition
359#[derive(Serialize, Debug)]
360struct GoogleFunctionDeclaration {
361    /// Name of the function
362    name: String,
363    /// Description of what the function does
364    description: String,
365    /// Parameters for the function
366    parameters: GoogleFunctionParameters,
367}
368
369impl From<&crate::chat::Tool> for GoogleFunctionDeclaration {
370    fn from(tool: &crate::chat::Tool) -> Self {
371        let properties_value = tool
372            .function
373            .parameters
374            .get("properties")
375            .cloned()
376            .unwrap_or_else(|| serde_json::Value::Object(serde_json::Map::new()));
377
378        GoogleFunctionDeclaration {
379            name: tool.function.name.clone(),
380            description: tool.function.description.clone(),
381            parameters: GoogleFunctionParameters {
382                schema_type: "object".to_string(),
383                properties: properties_value,
384                required: tool
385                    .function
386                    .parameters
387                    .get("required")
388                    .and_then(|v| v.as_array())
389                    .map(|arr| {
390                        arr.iter()
391                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
392                            .collect::<Vec<String>>()
393                    })
394                    .unwrap_or_default(),
395            },
396        }
397    }
398}
399
400/// Google function parameters schema
401#[derive(Serialize, Debug)]
402struct GoogleFunctionParameters {
403    /// The type of parameters object (usually "object")
404    #[serde(rename = "type")]
405    schema_type: String,
406    /// Map of parameter names to their properties
407    properties: Value,
408    /// List of required parameter names
409    required: Vec<String>,
410}
411
412/// Google function call object in response
413#[derive(Deserialize, Debug, Serialize)]
414struct GoogleFunctionCall {
415    /// Name of the function to call
416    name: String,
417    /// Arguments for the function call as structured JSON
418    #[serde(default)]
419    args: Value,
420}
421
422/// Google function response wrapper for function results
423///
424/// Format follows Google's Gemini API specification for function calling results:
425/// https://ai.google.dev/docs/function_calling
426///
427/// The expected format is:
428/// {
429///   "role": "function",
430///   "parts": [{
431///     "functionResponse": {
432///       "name": "function_name",
433///       "response": {
434///         "name": "function_name",
435///         "content": { ... } // JSON content returned by the function
436///       }
437///     }
438///   }]
439/// }
440#[derive(Deserialize, Debug, Serialize)]
441struct GoogleFunctionResponse {
442    /// Name of the function that was called
443    name: String,
444    /// Response from the function as structured JSON
445    response: GoogleFunctionResponseContent,
446}
447
448#[derive(Deserialize, Debug, Serialize)]
449struct GoogleFunctionResponseContent {
450    /// Name of the function that was called
451    name: String,
452    /// Content of the function response
453    content: Value,
454}
455
456/// Request body for embedding content
457#[derive(Serialize)]
458struct GoogleEmbeddingRequest<'a> {
459    model: &'a str,
460    content: GoogleEmbeddingContent<'a>,
461}
462
463#[derive(Serialize)]
464struct GoogleEmbeddingContent<'a> {
465    parts: Vec<GoogleContentPart<'a>>,
466}
467
468/// Response from the embedding API
469#[derive(Deserialize)]
470struct GoogleEmbeddingResponse {
471    embedding: GoogleEmbedding,
472}
473
474#[derive(Deserialize)]
475struct GoogleEmbedding {
476    values: Vec<f32>,
477}
478
479impl Google {
480    /// Creates a new Google Gemini client with the specified configuration.
481    ///
482    /// # Arguments
483    ///
484    /// * `api_key` - Google API key for authentication
485    /// * `model` - Model identifier (defaults to "gemini-1.5-flash")
486    /// * `max_tokens` - Maximum tokens in response
487    /// * `temperature` - Sampling temperature between 0.0 and 1.0
488    /// * `timeout_seconds` - Request timeout in seconds
489    /// * `system` - System prompt to set context
490    /// * `top_p` - Top-p sampling parameter
491    /// * `top_k` - Top-k sampling parameter
492    /// * `json_schema` - JSON schema for structured output
493    /// * `tools` - Function tools that the model can use
494    ///
495    /// # Returns
496    ///
497    /// A new `Google` client instance
498    #[allow(clippy::too_many_arguments)]
499    pub fn new(
500        api_key: impl Into<String>,
501        model: Option<String>,
502        max_tokens: Option<u32>,
503        temperature: Option<f32>,
504        timeout_seconds: Option<u64>,
505        system: Option<String>,
506        top_p: Option<f32>,
507        top_k: Option<u32>,
508        json_schema: Option<StructuredOutputFormat>,
509        tools: Option<Vec<Tool>>,
510    ) -> Self {
511        let mut builder = Client::builder();
512        if let Some(sec) = timeout_seconds {
513            builder = builder.timeout(std::time::Duration::from_secs(sec));
514        }
515        Self::with_client(
516            builder.build().expect("Failed to build reqwest Client"),
517            api_key,
518            model,
519            max_tokens,
520            temperature,
521            timeout_seconds,
522            system,
523            top_p,
524            top_k,
525            json_schema,
526            tools,
527        )
528    }
529
530    /// Creates a new Google Gemini client with a custom HTTP client.
531    #[allow(clippy::too_many_arguments)]
532    pub fn with_client(
533        client: Client,
534        api_key: impl Into<String>,
535        model: Option<String>,
536        max_tokens: Option<u32>,
537        temperature: Option<f32>,
538        timeout_seconds: Option<u64>,
539        system: Option<String>,
540        top_p: Option<f32>,
541        top_k: Option<u32>,
542        json_schema: Option<StructuredOutputFormat>,
543        tools: Option<Vec<Tool>>,
544    ) -> Self {
545        Self {
546            config: Arc::new(GoogleConfig {
547                api_key: api_key.into(),
548                model: model.unwrap_or_else(|| "gemini-1.5-flash".to_string()),
549                max_tokens,
550                temperature,
551                system,
552                timeout_seconds,
553                top_p,
554                top_k,
555                json_schema,
556                tools,
557            }),
558            client,
559        }
560    }
561
562    pub fn api_key(&self) -> &str {
563        &self.config.api_key
564    }
565
566    pub fn model(&self) -> &str {
567        &self.config.model
568    }
569
570    pub fn max_tokens(&self) -> Option<u32> {
571        self.config.max_tokens
572    }
573
574    pub fn temperature(&self) -> Option<f32> {
575        self.config.temperature
576    }
577
578    pub fn timeout_seconds(&self) -> Option<u64> {
579        self.config.timeout_seconds
580    }
581
582    pub fn system(&self) -> Option<&str> {
583        self.config.system.as_deref()
584    }
585
586    pub fn top_p(&self) -> Option<f32> {
587        self.config.top_p
588    }
589
590    pub fn top_k(&self) -> Option<u32> {
591        self.config.top_k
592    }
593
594    pub fn json_schema(&self) -> Option<&StructuredOutputFormat> {
595        self.config.json_schema.as_ref()
596    }
597
598    pub fn tools(&self) -> Option<&[Tool]> {
599        self.config.tools.as_deref()
600    }
601
602    pub fn client(&self) -> &Client {
603        &self.client
604    }
605}
606
607#[async_trait]
608impl ChatProvider for Google {
609    /// Sends a chat request to Google's Gemini API.
610    ///
611    /// # Arguments
612    ///
613    /// * `messages` - Slice of chat messages representing the conversation
614    ///
615    /// # Returns
616    ///
617    /// The model's response text or an error
618    async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
619        if self.config.api_key.is_empty() {
620            return Err(LLMError::AuthError("Missing Google API key".to_string()));
621        }
622
623        let mut chat_contents = Vec::with_capacity(messages.len());
624
625        // Add system message if present
626        if let Some(system) = &self.config.system {
627            chat_contents.push(GoogleChatContent {
628                role: "user",
629                parts: vec![GoogleContentPart::Text(system)],
630            });
631        }
632
633        // Add conversation messages in pairs to maintain context
634        for msg in messages {
635            // For tool results, we need to use "function" role
636            let role = match &msg.message_type {
637                MessageType::ToolResult(_) => "function",
638                _ => match msg.role {
639                    ChatRole::User => "user",
640                    ChatRole::Assistant => "model",
641                },
642            };
643
644            chat_contents.push(GoogleChatContent {
645                role,
646                parts: match &msg.message_type {
647                    MessageType::Text => vec![GoogleContentPart::Text(&msg.content)],
648                    MessageType::Image((image_mime, raw_bytes)) => {
649                        vec![GoogleContentPart::InlineData(GoogleInlineData {
650                            mime_type: image_mime.mime_type().to_string(),
651                            data: BASE64.encode(raw_bytes),
652                        })]
653                    }
654                    MessageType::ImageURL(_) => unimplemented!(),
655                    MessageType::Pdf(raw_bytes) => {
656                        vec![GoogleContentPart::InlineData(GoogleInlineData {
657                            mime_type: "application/pdf".to_string(),
658                            data: BASE64.encode(raw_bytes),
659                        })]
660                    }
661                    MessageType::ToolUse(calls) => calls
662                        .iter()
663                        .map(|call| {
664                            GoogleContentPart::FunctionCall(GoogleFunctionCall {
665                                name: call.function.name.clone(),
666                                args: serde_json::from_str(&call.function.arguments)
667                                    .unwrap_or(serde_json::Value::Null),
668                            })
669                        })
670                        .collect(),
671                    MessageType::ToolResult(result) => result
672                        .iter()
673                        .map(|result| {
674                            let parsed_args =
675                                serde_json::from_str::<Value>(&result.function.arguments)
676                                    .unwrap_or(serde_json::Value::Null);
677
678                            GoogleContentPart::FunctionResponse(GoogleFunctionResponse {
679                                name: result.function.name.clone(),
680                                response: GoogleFunctionResponseContent {
681                                    name: result.function.name.clone(),
682                                    content: parsed_args,
683                                },
684                            })
685                        })
686                        .collect(),
687                },
688            });
689        }
690
691        // Remove generation_config if empty to avoid validation errors
692        let generation_config = if self.config.max_tokens.is_none()
693            && self.config.temperature.is_none()
694            && self.config.top_p.is_none()
695            && self.config.top_k.is_none()
696            && self.config.json_schema.is_none()
697        {
698            None
699        } else {
700            // If json_schema and json_schema.schema are not None, use json_schema.schema as the response schema and set response_mime_type to JSON
701            // Google's API doesn't need the schema to have a "name" field, so we can just use the schema directly.
702            let (response_mime_type, response_schema) =
703                if let Some(json_schema) = &self.config.json_schema {
704                    if let Some(schema) = &json_schema.schema {
705                        // If the schema has an "additionalProperties" field (as required by OpenAI), remove it as Google's API doesn't support it
706                        let mut schema = schema.clone();
707                        if let Some(obj) = schema.as_object_mut() {
708                            obj.remove("additionalProperties");
709                        }
710                        (Some(GoogleResponseMimeType::Json), Some(schema))
711                    } else {
712                        (None, None)
713                    }
714                } else {
715                    (None, None)
716                };
717            Some(GoogleGenerationConfig {
718                max_output_tokens: self.config.max_tokens,
719                temperature: self.config.temperature,
720                top_p: self.config.top_p,
721                top_k: self.config.top_k,
722                response_mime_type,
723                response_schema,
724            })
725        };
726
727        let req_body = GoogleChatRequest {
728            contents: chat_contents,
729            generation_config,
730            tools: None,
731        };
732        if log::log_enabled!(log::Level::Trace) {
733            if let Ok(json) = serde_json::to_string(&req_body) {
734                log::trace!("Google Gemini request payload: {}", json);
735            }
736        }
737
738        let url = format!(
739            "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={key}",
740            model = self.config.model,
741            key = self.config.api_key
742        );
743
744        let mut request = self.client.post(&url).json(&req_body);
745        if let Some(timeout) = self.config.timeout_seconds {
746            request = request.timeout(std::time::Duration::from_secs(timeout));
747        }
748
749        let resp = request.send().await?;
750        log::debug!("Google Gemini HTTP status: {}", resp.status());
751        let resp = resp.error_for_status()?;
752
753        // Get the raw response text for debugging
754        let resp_text = resp.text().await?;
755
756        // Try to parse the response
757        let json_resp: Result<GoogleChatResponse, serde_json::Error> =
758            serde_json::from_str(&resp_text);
759
760        match json_resp {
761            Ok(response) => Ok(Box::new(response)),
762            Err(e) => {
763                // Return a more descriptive error with the raw response
764                Err(LLMError::ResponseFormatError {
765                    message: format!("Failed to decode Google API response: {e}"),
766                    raw_response: resp_text,
767                })
768            }
769        }
770    }
771
772    /// Sends a chat request to Google's Gemini API with tools.
773    ///
774    /// # Arguments
775    ///
776    /// * `messages` - The conversation history as a slice of chat messages
777    /// * `tools` - Optional slice of tools to use in the chat
778    ///
779    /// # Returns
780    ///
781    /// The provider's response text or an error
782    async fn chat_with_tools(
783        &self,
784        messages: &[ChatMessage],
785        tools: Option<&[Tool]>,
786    ) -> Result<Box<dyn ChatResponse>, LLMError> {
787        if self.config.api_key.is_empty() {
788            return Err(LLMError::AuthError("Missing Google API key".to_string()));
789        }
790
791        let mut chat_contents = Vec::with_capacity(messages.len());
792
793        // Add system message if present
794        if let Some(system) = &self.config.system {
795            chat_contents.push(GoogleChatContent {
796                role: "user",
797                parts: vec![GoogleContentPart::Text(system)],
798            });
799        }
800
801        // Add conversation messages in pairs to maintain context
802        for msg in messages {
803            // For tool results, we need to use "function" role
804            let role = match &msg.message_type {
805                MessageType::ToolResult(_) => "function",
806                _ => match msg.role {
807                    ChatRole::User => "user",
808                    ChatRole::Assistant => "model",
809                },
810            };
811
812            chat_contents.push(GoogleChatContent {
813                role,
814                parts: match &msg.message_type {
815                    MessageType::Text => vec![GoogleContentPart::Text(&msg.content)],
816                    MessageType::Image((image_mime, raw_bytes)) => {
817                        vec![GoogleContentPart::InlineData(GoogleInlineData {
818                            mime_type: image_mime.mime_type().to_string(),
819                            data: BASE64.encode(raw_bytes),
820                        })]
821                    }
822                    MessageType::ImageURL(_) => unimplemented!(),
823                    MessageType::Pdf(raw_bytes) => {
824                        vec![GoogleContentPart::InlineData(GoogleInlineData {
825                            mime_type: "application/pdf".to_string(),
826                            data: BASE64.encode(raw_bytes),
827                        })]
828                    }
829                    MessageType::ToolUse(calls) => calls
830                        .iter()
831                        .map(|call| {
832                            GoogleContentPart::FunctionCall(GoogleFunctionCall {
833                                name: call.function.name.clone(),
834                                args: serde_json::from_str(&call.function.arguments)
835                                    .unwrap_or(serde_json::Value::Null),
836                            })
837                        })
838                        .collect(),
839                    MessageType::ToolResult(result) => result
840                        .iter()
841                        .map(|result| {
842                            let parsed_args =
843                                serde_json::from_str::<Value>(&result.function.arguments)
844                                    .unwrap_or(serde_json::Value::Null);
845
846                            GoogleContentPart::FunctionResponse(GoogleFunctionResponse {
847                                name: result.function.name.clone(),
848                                response: GoogleFunctionResponseContent {
849                                    name: result.function.name.clone(),
850                                    content: parsed_args,
851                                },
852                            })
853                        })
854                        .collect(),
855                },
856            });
857        }
858
859        // Convert tools to Google's format if provided
860        let google_tools = tools.map(|t| {
861            vec![GoogleTool {
862                function_declarations: t.iter().map(GoogleFunctionDeclaration::from).collect(),
863            }]
864        });
865
866        // Build generation config
867        let generation_config = {
868            // If json_schema and json_schema.schema are not None, use json_schema.schema as the response schema and set response_mime_type to JSON
869            // Google's API doesn't need the schema to have a "name" field, so we can just use the schema directly.
870            let (response_mime_type, response_schema) =
871                if let Some(json_schema) = &self.config.json_schema {
872                    if let Some(schema) = &json_schema.schema {
873                        // If the schema has an "additionalProperties" field (as required by OpenAI), remove it as Google's API doesn't support it
874                        let mut schema = schema.clone();
875
876                        if let Some(obj) = schema.as_object_mut() {
877                            obj.remove("additionalProperties");
878                        }
879
880                        (Some(GoogleResponseMimeType::Json), Some(schema))
881                    } else {
882                        (None, None)
883                    }
884                } else {
885                    (None, None)
886                };
887
888            Some(GoogleGenerationConfig {
889                max_output_tokens: self.config.max_tokens,
890                temperature: self.config.temperature,
891                top_p: self.config.top_p,
892                top_k: self.config.top_k,
893                response_mime_type,
894                response_schema,
895            })
896        };
897
898        let req_body = GoogleChatRequest {
899            contents: chat_contents,
900            generation_config,
901            tools: google_tools,
902        };
903
904        if log::log_enabled!(log::Level::Trace) {
905            if let Ok(json) = serde_json::to_string(&req_body) {
906                log::trace!("Google Gemini request payload (tool): {}", json);
907            }
908        }
909
910        let url = format!(
911            "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={key}",
912            model = self.config.model,
913            key = self.config.api_key
914
915        );
916
917        let mut request = self.client.post(&url).json(&req_body);
918
919        if let Some(timeout) = self.config.timeout_seconds {
920            request = request.timeout(std::time::Duration::from_secs(timeout));
921        }
922
923        let resp = request.send().await?;
924
925        log::debug!("Google Gemini HTTP status (tool): {}", resp.status());
926
927        let resp = resp.error_for_status()?;
928
929        // Get the raw response text for debugging
930        let resp_text = resp.text().await?;
931
932        // Try to parse the response
933        let json_resp: Result<GoogleChatResponse, serde_json::Error> =
934            serde_json::from_str(&resp_text);
935
936        match json_resp {
937            Ok(response) => Ok(Box::new(response)),
938            Err(e) => {
939                // Return a more descriptive error with the raw response
940                Err(LLMError::ResponseFormatError {
941                    message: format!("Failed to decode Google API response: {e}"),
942                    raw_response: resp_text,
943                })
944            }
945        }
946    }
947
948    /// Sends a streaming chat request to Google's Gemini API.
949    ///
950    /// # Arguments
951    ///
952    /// * `messages` - Slice of chat messages representing the conversation
953    ///
954    /// # Returns
955    ///
956    /// A stream of text tokens or an error
957    async fn chat_stream(
958        &self,
959        messages: &[ChatMessage],
960    ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError>
961    {
962        let struct_stream = self.chat_stream_struct(messages).await?;
963        let content_stream = struct_stream.filter_map(|result| async move {
964            match result {
965                Ok(stream_response) => {
966                    if let Some(choice) = stream_response.choices.first() {
967                        if let Some(content) = &choice.delta.content {
968                            if !content.is_empty() {
969                                return Some(Ok(content.clone()));
970                            }
971                        }
972                    }
973                    None
974                }
975                Err(e) => Some(Err(e)),
976            }
977        });
978        Ok(Box::pin(content_stream))
979    }
980
981    /// Sends a streaming chat request to Google's Gemini API with structured responses.
982    ///
983    /// # Arguments
984    ///
985    /// * `messages` - Slice of chat messages representing the conversation
986    ///
987    /// # Returns
988    ///
989    /// A stream of structured response objects or an error
990    async fn chat_stream_struct(
991        &self,
992        messages: &[ChatMessage],
993    ) -> Result<
994        std::pin::Pin<Box<dyn Stream<Item = Result<crate::chat::StreamResponse, LLMError>> + Send>>,
995        LLMError,
996    > {
997        if self.config.api_key.is_empty() {
998            return Err(LLMError::AuthError("Missing Google API key".to_string()));
999        }
1000        let mut chat_contents = Vec::with_capacity(messages.len());
1001        if let Some(system) = &self.config.system {
1002            chat_contents.push(GoogleChatContent {
1003                role: "user",
1004                parts: vec![GoogleContentPart::Text(system)],
1005            });
1006        }
1007        for msg in messages {
1008            let role = match msg.role {
1009                ChatRole::User => "user",
1010                ChatRole::Assistant => "model",
1011            };
1012            chat_contents.push(GoogleChatContent {
1013                role,
1014                parts: match &msg.message_type {
1015                    MessageType::Text => vec![GoogleContentPart::Text(&msg.content)],
1016                    MessageType::Image((image_mime, raw_bytes)) => {
1017                        vec![GoogleContentPart::InlineData(GoogleInlineData {
1018                            mime_type: image_mime.mime_type().to_string(),
1019                            data: BASE64.encode(raw_bytes),
1020                        })]
1021                    }
1022                    MessageType::Pdf(raw_bytes) => {
1023                        vec![GoogleContentPart::InlineData(GoogleInlineData {
1024                            mime_type: "application/pdf".to_string(),
1025                            data: BASE64.encode(raw_bytes),
1026                        })]
1027                    }
1028                    _ => vec![GoogleContentPart::Text(&msg.content)],
1029                },
1030            });
1031        }
1032        let generation_config = if self.config.max_tokens.is_none()
1033            && self.config.temperature.is_none()
1034            && self.config.top_p.is_none()
1035            && self.config.top_k.is_none()
1036        {
1037            None
1038        } else {
1039            Some(GoogleGenerationConfig {
1040                max_output_tokens: self.config.max_tokens,
1041                temperature: self.config.temperature,
1042                top_p: self.config.top_p,
1043                top_k: self.config.top_k,
1044                response_mime_type: None,
1045                response_schema: None,
1046            })
1047        };
1048
1049        let req_body = GoogleChatRequest {
1050            contents: chat_contents,
1051            generation_config,
1052            tools: None,
1053        };
1054        let url = format!(
1055            "https://generativelanguage.googleapis.com/v1beta/models/{model}:streamGenerateContent?alt=sse&key={key}",
1056            model = self.config.model,
1057            key = self.config.api_key
1058        );
1059
1060        let mut request = self.client.post(&url).json(&req_body);
1061        if let Some(timeout) = self.config.timeout_seconds {
1062            request = request.timeout(std::time::Duration::from_secs(timeout));
1063        }
1064        let response = request.send().await?;
1065        if !response.status().is_success() {
1066            let status = response.status();
1067            let error_text = response.text().await?;
1068            return Err(LLMError::ResponseFormatError {
1069                message: format!("Google API returned error status: {status}"),
1070                raw_response: error_text,
1071            });
1072        }
1073        Ok(create_google_sse_stream(response))
1074    }
1075}
1076
1077#[async_trait]
1078impl CompletionProvider for Google {
1079    /// Performs a completion request using the chat endpoint.
1080    ///
1081    /// # Arguments
1082    ///
1083    /// * `req` - Completion request parameters
1084    ///
1085    /// # Returns
1086    ///
1087    /// The completion response or an error
1088    async fn complete(&self, req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
1089        let chat_message = ChatMessage::user().content(req.prompt.clone()).build();
1090        if let Some(text) = self.chat(&[chat_message]).await?.text() {
1091            Ok(CompletionResponse { text })
1092        } else {
1093            Err(LLMError::ProviderError(
1094                "No answer returned by Google".to_string(),
1095            ))
1096        }
1097    }
1098}
1099
1100#[async_trait]
1101impl EmbeddingProvider for Google {
1102    async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
1103        if self.config.api_key.is_empty() {
1104            return Err(LLMError::AuthError("Missing Google API key".to_string()));
1105        }
1106
1107        let mut embeddings = Vec::new();
1108
1109        // Process each text separately as Gemini API accepts one text at a time
1110        for text in texts {
1111            let req_body = GoogleEmbeddingRequest {
1112                model: "models/text-embedding-004",
1113                content: GoogleEmbeddingContent {
1114                    parts: vec![GoogleContentPart::Text(&text)],
1115                },
1116            };
1117
1118            let url = format!(
1119                "https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent?key={}",
1120                self.config.api_key
1121            );
1122
1123            let resp = self
1124                .client
1125                .post(&url)
1126                .json(&req_body)
1127                .send()
1128                .await?
1129                .error_for_status()?;
1130
1131            let embedding_resp: GoogleEmbeddingResponse = resp.json().await?;
1132            embeddings.push(embedding_resp.embedding.values);
1133        }
1134        Ok(embeddings)
1135    }
1136}
1137
1138#[async_trait]
1139impl SpeechToTextProvider for Google {
1140    async fn transcribe(&self, _audio: Vec<u8>) -> Result<String, LLMError> {
1141        Err(LLMError::ProviderError(
1142            "Google does not implement speech to text endpoint yet.".into(),
1143        ))
1144    }
1145}
1146
1147impl LLMProvider for Google {
1148    fn tools(&self) -> Option<&[Tool]> {
1149        self.config.tools.as_deref()
1150    }
1151}
1152
1153/// Creates a structured SSE stream for Google's streaming API responses.
1154///
1155/// # Arguments
1156///
1157/// * `response` - The HTTP response containing the SSE stream
1158///
1159/// # Returns
1160///
1161/// A stream of `StreamResponse` objects
1162fn create_google_sse_stream(
1163    response: reqwest::Response,
1164) -> std::pin::Pin<Box<dyn Stream<Item = Result<crate::chat::StreamResponse, LLMError>> + Send>> {
1165    let stream = response
1166        .bytes_stream()
1167        .map(move |chunk| match chunk {
1168            Ok(bytes) => {
1169                let text = String::from_utf8_lossy(&bytes);
1170                parse_google_sse_chunk(&text)
1171            }
1172            Err(e) => Err(LLMError::HttpError(e.to_string())),
1173        })
1174        .filter_map(|result| async move {
1175            match result {
1176                Ok(Some(response)) => Some(Ok(response)),
1177                Ok(None) => None,
1178                Err(e) => Some(Err(e)),
1179            }
1180        });
1181    Box::pin(stream)
1182}
1183
1184/// Parses a Google SSE chunk and converts it to StreamResponse format.
1185///
1186/// # Arguments
1187///
1188/// * `chunk` - The raw SSE chunk text
1189///
1190/// # Returns
1191///
1192/// * `Ok(Some(StreamResponse))` - Structured response if content found
1193/// * `Ok(None)` - If chunk should be skipped (e.g., ping, done signal)
1194/// * `Err(LLMError)` - If parsing fails
1195fn parse_google_sse_chunk(chunk: &str) -> Result<Option<crate::chat::StreamResponse>, LLMError> {
1196    for line in chunk.lines() {
1197        let line = line.trim();
1198        if let Some(data) = line.strip_prefix("data: ") {
1199            match serde_json::from_str::<GoogleStreamResponse>(data) {
1200                Ok(response) => {
1201                    let mut content = None;
1202                    let mut usage = None;
1203                    // Check for content chunks first
1204                    if let Some(candidates) = &response.candidates {
1205                        if let Some(candidate) = candidates.first() {
1206                            if let Some(part) = candidate.content.parts.first() {
1207                                if !part.text.is_empty() {
1208                                    content = Some(part.text.clone());
1209                                }
1210                            }
1211                        }
1212                    }
1213                    // Check for usage metadata
1214                    if let Some(usage_metadata) = &response.usage_metadata {
1215                        if let (Some(prompt_tokens), Some(completion_tokens)) = (
1216                            usage_metadata.prompt_token_count,
1217                            usage_metadata.candidates_token_count,
1218                        ) {
1219                            usage = Some(Usage {
1220                                prompt_tokens,
1221                                completion_tokens,
1222                                total_tokens: usage_metadata
1223                                    .total_token_count
1224                                    .unwrap_or(prompt_tokens + completion_tokens),
1225                                completion_tokens_details: None,
1226                                prompt_tokens_details: None,
1227                            });
1228                        }
1229                    }
1230                    // Return response if we have either content or usage
1231                    if content.is_some() || usage.is_some() {
1232                        return Ok(Some(crate::chat::StreamResponse {
1233                            choices: vec![crate::chat::StreamChoice {
1234                                delta: crate::chat::StreamDelta {
1235                                    content,
1236                                    tool_calls: None,
1237                                },
1238                            }],
1239                            usage,
1240                        }));
1241                    }
1242                    return Ok(None);
1243                }
1244                Err(_) => continue,
1245            }
1246        }
1247    }
1248    Ok(None)
1249}
1250
// Empty impl: Google adds no text-to-speech overrides, so every
// `TextToSpeechProvider` method uses the trait's default implementation.
#[async_trait]
impl TextToSpeechProvider for Google {}
1253
1254#[derive(Clone, Debug, Deserialize)]
1255pub struct GoogleModelEntry {
1256    pub name: String,
1257    pub version: String,
1258    pub display_name: String,
1259    pub description: String,
1260    pub input_token_limit: Option<u32>,
1261    pub output_token_limit: Option<u32>,
1262    pub supported_generation_methods: Vec<String>,
1263    pub temperature: Option<f32>,
1264    pub top_p: Option<f32>,
1265    pub top_k: Option<u32>,
1266    #[serde(flatten)]
1267    pub extra: Value,
1268}
1269
1270impl ModelListRawEntry for GoogleModelEntry {
1271    fn get_id(&self) -> String {
1272        self.name.clone()
1273    }
1274
1275    fn get_created_at(&self) -> DateTime<Utc> {
1276        // Google doesn't provide creation dates in their models API
1277        DateTime::<Utc>::UNIX_EPOCH
1278    }
1279
1280    fn get_raw(&self) -> Value {
1281        self.extra.clone()
1282    }
1283}
1284
1285#[derive(Clone, Debug, Deserialize)]
1286pub struct GoogleModelListResponse {
1287    pub models: Vec<GoogleModelEntry>,
1288}
1289
1290impl ModelListResponse for GoogleModelListResponse {
1291    fn get_models(&self) -> Vec<String> {
1292        self.models.iter().map(|m| m.name.clone()).collect()
1293    }
1294
1295    fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>> {
1296        self.models
1297            .iter()
1298            .map(|e| Box::new(e.clone()) as Box<dyn ModelListRawEntry>)
1299            .collect()
1300    }
1301
1302    fn get_backend(&self) -> LLMBackend {
1303        LLMBackend::Google
1304    }
1305}
1306
1307#[async_trait]
1308impl ModelsProvider for Google {
1309    async fn list_models(
1310        &self,
1311        _request: Option<&ModelListRequest>,
1312    ) -> Result<Box<dyn ModelListResponse>, LLMError> {
1313        if self.config.api_key.is_empty() {
1314            return Err(LLMError::AuthError("Missing Google API key".to_string()));
1315        }
1316
1317        let url = format!(
1318            "https://generativelanguage.googleapis.com/v1beta/models?key={}",
1319            self.config.api_key
1320        );
1321
1322        let resp = self.client.get(&url).send().await?.error_for_status()?;
1323
1324        let result: GoogleModelListResponse = resp.json().await?;
1325        Ok(Box::new(result))
1326    }
1327}