language_barrier_core/provider/gemini.rs

1use crate::error::{Error, Result};
2use crate::message::{Content, ContentPart, Message};
3use crate::provider::HTTPProvider;
4use crate::{Chat, Gemini, LlmToolInfo};
5use reqwest::{Method, Request, Url};
6use serde::{Deserialize, Serialize};
7use std::env;
8use tracing::{debug, error, info, instrument, trace, warn};
9
/// Configuration for the Gemini provider.
///
/// `Debug` is implemented by hand so that formatting a config (or any value that
/// contains one) never prints the API key — only whether one is set.
#[derive(Clone)]
pub struct GeminiConfig {
    /// API key for authentication
    pub api_key: String,
    /// Base URL for the API (e.g. `https://generativelanguage.googleapis.com/v1beta`)
    pub base_url: String,
}

impl std::fmt::Debug for GeminiConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never echo the secret itself; report only whether it is present.
        let api_key = if self.api_key.is_empty() {
            "<unset>"
        } else {
            "<redacted>"
        };
        f.debug_struct("GeminiConfig")
            .field("api_key", &api_key)
            .field("base_url", &self.base_url)
            .finish()
    }
}
18
19impl Default for GeminiConfig {
20    fn default() -> Self {
21        Self {
22            api_key: env::var("GEMINI_API_KEY").unwrap_or_default(),
23            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
24        }
25    }
26}
27
28/// Implementation of the Gemini provider
29#[derive(Debug, Clone)]
30pub struct GeminiProvider {
31    /// Configuration for the provider
32    config: GeminiConfig,
33}
34
35impl GeminiProvider {
36    /// Creates a new GeminiProvider with default configuration
37    ///
38    /// This method will use the GEMINI_API_KEY environment variable for authentication.
39    ///
40    /// # Examples
41    ///
42    /// ```
43    /// use language_barrier_core::provider::gemini::GeminiProvider;
44    ///
45    /// let provider = GeminiProvider::new();
46    /// ```
47    #[instrument(level = "debug")]
48    pub fn new() -> Self {
49        info!("Creating new GeminiProvider with default configuration");
50        let config = GeminiConfig::default();
51        debug!("API key set: {}", !config.api_key.is_empty());
52        debug!("Base URL: {}", config.base_url);
53
54        Self { config }
55    }
56
57    /// Creates a new GeminiProvider with custom configuration
58    ///
59    /// # Examples
60    ///
61    /// ```
62    /// use language_barrier_core::provider::gemini::{GeminiProvider, GeminiConfig};
63    ///
64    /// let config = GeminiConfig {
65    ///     api_key: "your-api-key".to_string(),
66    ///     base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
67    /// };
68    ///
69    /// let provider = GeminiProvider::with_config(config);
70    /// ```
71    #[instrument(skip(config), level = "debug")]
72    pub fn with_config(config: GeminiConfig) -> Self {
73        info!("Creating new GeminiProvider with custom configuration");
74        debug!("API key set: {}", !config.api_key.is_empty());
75        debug!("Base URL: {}", config.base_url);
76
77        Self { config }
78    }
79}
80
81impl Default for GeminiProvider {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
87impl HTTPProvider<Gemini> for GeminiProvider {
88    fn accept(&self, model: Gemini, chat: &Chat) -> Result<Request> {
89        info!("Creating request for Gemini model: {:?}", model);
90        debug!("Messages in chat history: {}", chat.history.len());
91
92        let model_id = model.gemini_model_id();
93        let url_str = format!(
94            "{}/models/{}:generateContent?key={}",
95            self.config.base_url, model_id, self.config.api_key
96        );
97
98        debug!("Parsing URL: {}", url_str);
99        let url = match Url::parse(&url_str) {
100            Ok(url) => {
101                debug!("URL parsed successfully: {}", url);
102                url
103            }
104            Err(e) => {
105                error!("Failed to parse URL '{}': {}", url_str, e);
106                return Err(e.into());
107            }
108        };
109
110        let mut request = Request::new(Method::POST, url);
111        debug!("Created request: {} {}", request.method(), request.url());
112
113        // Set headers
114        debug!("Setting request headers");
115        let content_type_header = match "application/json".parse() {
116            Ok(header) => header,
117            Err(e) => {
118                error!("Failed to set content type: {}", e);
119                return Err(Error::Other("Failed to set content type".into()));
120            }
121        };
122
123        request
124            .headers_mut()
125            .insert("Content-Type", content_type_header);
126
127        trace!("Request headers set: {:#?}", request.headers());
128
129        // Create the request payload
130        debug!("Creating request payload");
131        let payload = match self.create_request_payload(model, chat) {
132            Ok(payload) => {
133                debug!("Request payload created successfully");
134                trace!("Number of contents: {}", payload.contents.len());
135                trace!(
136                    "System instruction present: {}",
137                    payload.system_instruction.is_some()
138                );
139                trace!(
140                    "Generation config present: {}",
141                    payload.generation_config.is_some()
142                );
143                payload
144            }
145            Err(e) => {
146                error!("Failed to create request payload: {}", e);
147                return Err(e);
148            }
149        };
150
151        // Set the request body
152        debug!("Serializing request payload");
153        let body_bytes = match serde_json::to_vec(&payload) {
154            Ok(bytes) => {
155                debug!("Payload serialized successfully ({} bytes)", bytes.len());
156                bytes
157            }
158            Err(e) => {
159                error!("Failed to serialize payload: {}", e);
160                return Err(Error::Serialization(e));
161            }
162        };
163
164        *request.body_mut() = Some(body_bytes.into());
165        info!("Request created successfully");
166
167        Ok(request)
168    }
169
170    fn parse(&self, raw_response_text: String) -> Result<Message> {
171        info!("Parsing response from Gemini API");
172        trace!("Raw response: {}", raw_response_text);
173
174        // First try to parse as an error response
175        if let Ok(error_response) = serde_json::from_str::<GeminiErrorResponse>(&raw_response_text)
176        {
177            if let Some(error) = error_response.error {
178                error!("Gemini API returned an error: {}", error.message);
179                return Err(Error::ProviderUnavailable(error.message));
180            }
181        }
182
183        // If not an error, parse as a successful response
184        debug!("Deserializing response JSON");
185        let gemini_response = match serde_json::from_str::<GeminiResponse>(&raw_response_text) {
186            Ok(response) => {
187                debug!("Response deserialized successfully");
188                if !response.candidates.is_empty() {
189                    debug!(
190                        "Content parts: {}",
191                        response.candidates[0].content.parts.len()
192                    );
193                }
194                response
195            }
196            Err(e) => {
197                error!("Failed to deserialize response: {}", e);
198                error!("Raw response: {}", raw_response_text);
199                return Err(Error::Serialization(e));
200            }
201        };
202
203        // Convert to our message format
204        debug!("Converting Gemini response to Message");
205        let message = Message::from(&gemini_response);
206
207        info!("Response parsed successfully");
208        trace!("Response message processed");
209
210        Ok(message)
211    }
212}
213
/// Maps a model value onto the identifier Gemini expects in the request path
/// (`{base_url}/models/{id}:generateContent`).
pub trait GeminiModelInfo {
    /// Returns the Gemini model identifier for this model.
    fn gemini_model_id(&self) -> String;
}
218
219impl GeminiProvider {
220    /// Creates a request payload from a Chat object
221    ///
222    /// This method converts the Chat's messages and settings into a Gemini-specific
223    /// format for the API request.
224    #[instrument(skip(self, chat), level = "debug")]
225    fn create_request_payload(&self, model: Gemini, chat: &Chat) -> Result<GeminiRequest> {
226        info!("Creating request payload for chat with Gemini model");
227        debug!("System prompt length: {}", chat.system_prompt.len());
228        debug!("Messages in history: {}", chat.history.len());
229        debug!("Max output tokens: {}", chat.max_output_tokens);
230
231        // Convert system prompt if present
232        let system_instruction = if !chat.system_prompt.is_empty() {
233            debug!("Including system prompt in request");
234            trace!("System prompt: {}", chat.system_prompt);
235            Some(GeminiContent {
236                parts: vec![GeminiPart::text(chat.system_prompt.clone())],
237                role: None,
238            })
239        } else {
240            debug!("No system prompt provided");
241            None
242        };
243
244        // Convert messages to contents
245        debug!("Converting messages to Gemini format");
246        let mut contents: Vec<GeminiContent> = Vec::new();
247        let mut current_role_str: Option<&'static str> = None;
248        let mut current_parts: Vec<GeminiPart> = Vec::new();
249
250        for msg in &chat.history {
251            // Get the current role string
252            let msg_role_str = msg.role_str();
253
254            // If role changes, finish the current content and start a new one
255            if current_role_str.is_some()
256                && current_role_str != Some(msg_role_str)
257                && !current_parts.is_empty()
258            {
259                let role = match current_role_str {
260                    Some("user") => Some("user".to_string()),
261                    Some("assistant") => Some("model".to_string()),
262                    _ => None,
263                };
264
265                contents.push(GeminiContent {
266                    parts: std::mem::take(&mut current_parts),
267                    role,
268                });
269            }
270
271            current_role_str = Some(msg_role_str);
272
273            // Convert message content to parts based on the message variant
274            match msg {
275                Message::System { content, .. } => {
276                    current_parts.push(GeminiPart::text(content.clone()));
277                }
278                Message::User { content, .. } => match content {
279                    Content::Text(text) => {
280                        current_parts.push(GeminiPart::text(text.clone()));
281                    }
282                    Content::Parts(parts) => {
283                        for part in parts {
284                            match part {
285                                ContentPart::Text { text } => {
286                                    current_parts.push(GeminiPart::text(text.clone()));
287                                }
288                                ContentPart::ImageUrl { image_url } => {
289                                    current_parts.push(GeminiPart::inline_data(
290                                        image_url.url.clone(),
291                                        "image/jpeg".to_string(),
292                                    ));
293                                }
294                            }
295                        }
296                    }
297                },
298                Message::Assistant { content, .. } => {
299                    if let Some(content_data) = content {
300                        match content_data {
301                            Content::Text(text) => {
302                                current_parts.push(GeminiPart::text(text.clone()));
303                            }
304                            Content::Parts(parts) => {
305                                for part in parts {
306                                    match part {
307                                        ContentPart::Text { text } => {
308                                            current_parts.push(GeminiPart::text(text.clone()));
309                                        }
310                                        ContentPart::ImageUrl { image_url } => {
311                                            current_parts.push(GeminiPart::inline_data(
312                                                image_url.url.clone(),
313                                                "image/jpeg".to_string(),
314                                            ));
315                                        }
316                                    }
317                                }
318                            }
319                        }
320                    }
321                }
322                Message::Tool {
323                    tool_call_id,
324                    content,
325                    ..
326                } => {
327                    // For Gemini, include both the tool call ID and the content
328                    current_parts.push(GeminiPart::text(format!(
329                        "Tool result for call {}: {}",
330                        tool_call_id, content
331                    )));
332                }
333            }
334        }
335
336        // Add any remaining parts
337        if !current_parts.is_empty() {
338            let role = match current_role_str {
339                Some("user") => Some("user".to_string()),
340                Some("assistant") => Some("model".to_string()),
341                _ => None,
342            };
343
344            contents.push(GeminiContent {
345                parts: current_parts,
346                role,
347            });
348        }
349
350        debug!("Converted {} contents for the request", contents.len());
351
352        // Create generation config
353        let generation_config = Some(GeminiGenerationConfig {
354            max_output_tokens: Some(chat.max_output_tokens),
355            temperature: None,
356            top_p: None,
357            top_k: None,
358            stop_sequences: None,
359        });
360
361        // Convert tool descriptions if a tool registry is provided
362        let tools = chat.tools.as_ref().map(|tools| {
363            vec![GeminiTool {
364                function_declarations: tools.iter().map(GeminiFunctionDeclaration::from).collect(),
365            }]
366        });
367
368        // Note: For Gemini, tool_choice is handled through the API's behavior
369        // We don't modify the tools list based on the choice
370
371        // Create the tool_config setting based on Google's specific format
372        let tool_config = if let Some(choice) = &chat.tool_choice {
373            match choice {
374                crate::tool::ToolChoice::Auto => Some(GeminiToolConfig {
375                    function_calling_config: GeminiFunctionCallingConfig {
376                        mode: "auto".to_string(),
377                        allowed_function_names: None,
378                    },
379                }),
380                crate::tool::ToolChoice::Any => Some(GeminiToolConfig {
381                    function_calling_config: GeminiFunctionCallingConfig {
382                        mode: "any".to_string(),
383                        allowed_function_names: None,
384                    },
385                }),
386                crate::tool::ToolChoice::None => Some(GeminiToolConfig {
387                    function_calling_config: GeminiFunctionCallingConfig {
388                        mode: "none".to_string(),
389                        allowed_function_names: None,
390                    },
391                }),
392                crate::tool::ToolChoice::Specific(name) => Some(GeminiToolConfig {
393                    function_calling_config: GeminiFunctionCallingConfig {
394                        mode: "auto".to_string(), // Use mode auto with specific allowed function
395                        allowed_function_names: Some(vec![name.clone()]),
396                    },
397                }),
398            }
399        } else if tools.is_some() {
400            // Default to auto if tools are present but no choice specified
401            Some(GeminiToolConfig {
402                function_calling_config: GeminiFunctionCallingConfig {
403                    mode: "auto".to_string(),
404                    allowed_function_names: None,
405                },
406            })
407        } else {
408            None
409        };
410
411        // Create the request
412        debug!("Creating GeminiRequest");
413        let request = GeminiRequest {
414            contents,
415            system_instruction,
416            generation_config,
417            tools,
418            tool_config,
419        };
420
421        info!("Request payload created successfully");
422        Ok(request)
423    }
424}
425
426/// Represents a content part in Gemini API format
427#[derive(Debug, Clone, Serialize, Deserialize)]
428pub(crate) struct GeminiPart {
429    /// The text content (optional)
430    #[serde(skip_serializing_if = "Option::is_none")]
431    pub text: Option<String>,
432
433    /// The inline data (optional)
434    #[serde(skip_serializing_if = "Option::is_none")]
435    pub inline_data: Option<GeminiInlineData>,
436
437    /// The function call (optional)
438    #[serde(skip_serializing_if = "Option::is_none", rename = "functionCall")]
439    pub function_call: Option<GeminiFunctionCall>,
440}
441
442/// Represents a function call in the Gemini API format
443#[derive(Debug, Clone, Serialize, Deserialize)]
444pub(crate) struct GeminiFunctionCall {
445    /// The name of the function
446    pub name: String,
447    /// The arguments as a JSON Value
448    pub args: serde_json::Value,
449}
450
451impl GeminiPart {
452    /// Create a new text part
453    fn text(text: String) -> Self {
454        GeminiPart {
455            text: Some(text),
456            inline_data: None,
457            function_call: None,
458        }
459    }
460
461    /// Create a new inline data part
462    fn inline_data(data: String, mime_type: String) -> Self {
463        GeminiPart {
464            text: None,
465            inline_data: Some(GeminiInlineData { data, mime_type }),
466            function_call: None,
467        }
468    }
469}
470
471/// Represents inline data in Gemini API format
472#[derive(Debug, Clone, Serialize, Deserialize)]
473pub(crate) struct GeminiInlineData {
474    /// The data (base64 encoded)
475    pub data: String,
476    /// The MIME type
477    pub mime_type: String,
478}
479
480/// Represents a content object in Gemini API format
481#[derive(Debug, Clone, Serialize, Deserialize)]
482pub(crate) struct GeminiContent {
483    /// The parts of the content
484    pub parts: Vec<GeminiPart>,
485    /// The role of the content (user, model, etc.)
486    #[serde(skip_serializing_if = "Option::is_none")]
487    pub role: Option<String>,
488}
489
490/// Represents a generation config in Gemini API format
491#[derive(Debug, Clone, Serialize, Deserialize)]
492pub(crate) struct GeminiGenerationConfig {
493    /// The maximum number of tokens to generate
494    #[serde(skip_serializing_if = "Option::is_none")]
495    pub max_output_tokens: Option<usize>,
496    /// The temperature (randomness) of the generation
497    #[serde(skip_serializing_if = "Option::is_none")]
498    pub temperature: Option<f32>,
499    /// The top-p sampling parameter
500    #[serde(skip_serializing_if = "Option::is_none")]
501    pub top_p: Option<f32>,
502    /// The top-k sampling parameter
503    #[serde(skip_serializing_if = "Option::is_none")]
504    pub top_k: Option<u32>,
505    /// Sequences that will stop generation
506    #[serde(skip_serializing_if = "Option::is_none")]
507    pub stop_sequences: Option<Vec<String>>,
508}
509
510/// Represents a function declaration in the Gemini API format
511#[derive(Debug, Clone, Serialize, Deserialize)]
512pub(crate) struct GeminiFunctionDeclaration {
513    /// The name of the function
514    pub name: String,
515    /// The description of the function
516    pub description: String,
517    /// The parameters schema
518    pub parameters: serde_json::Value,
519}
520
521impl From<&LlmToolInfo> for GeminiFunctionDeclaration {
522    fn from(value: &LlmToolInfo) -> Self {
523        GeminiFunctionDeclaration {
524            name: value.name.clone(),
525            description: value.description.clone(),
526            parameters: value.parameters.clone(),
527        }
528    }
529}
530
531/// Represents a function in the Gemini API format (tools are called functions in Gemini)
532#[derive(Debug, Clone, Serialize, Deserialize)]
533pub(crate) struct GeminiFunction {
534    /// The name of the function
535    pub name: String,
536    /// The description of the function
537    pub description: String,
538    /// The parameters definition
539    pub parameters: serde_json::Value,
540}
541
542/// Represents a tool in the Gemini API format
543#[derive(Debug, Clone, Serialize, Deserialize)]
544pub(crate) struct GeminiTool {
545    /// The function declaration
546    #[serde(rename = "functionDeclarations")]
547    pub function_declarations: Vec<GeminiFunctionDeclaration>,
548}
549
550// Gemini uses GeminiToolConfig instead of a direct tool_choice field
551
552/// Tool config for Gemini API
553#[derive(Debug, Clone, Serialize, Deserialize)]
554pub(crate) struct GeminiToolConfig {
555    /// Function calling configuration
556    #[serde(rename = "function_calling_config")]
557    pub function_calling_config: GeminiFunctionCallingConfig,
558}
559
560/// Function calling config for Gemini API
561#[derive(Debug, Clone, Serialize, Deserialize)]
562pub(crate) struct GeminiFunctionCallingConfig {
563    /// The mode (auto, any, none)
564    pub mode: String,
565    /// List of specific function names that are allowed (optional)
566    #[serde(skip_serializing_if = "Option::is_none")]
567    pub allowed_function_names: Option<Vec<String>>,
568}
569
570/// Represents a request to the Gemini API
571#[derive(Debug, Serialize, Deserialize)]
572pub(crate) struct GeminiRequest {
573    /// The contents to send to the model
574    pub contents: Vec<GeminiContent>,
575    /// The system instruction (optional)
576    #[serde(skip_serializing_if = "Option::is_none")]
577    pub system_instruction: Option<GeminiContent>,
578    /// The generation config (optional)
579    #[serde(skip_serializing_if = "Option::is_none")]
580    pub generation_config: Option<GeminiGenerationConfig>,
581    /// The tools (functions) available to the model
582    #[serde(skip_serializing_if = "Option::is_none")]
583    pub tools: Option<Vec<GeminiTool>>,
584    /// Tool configuration
585    #[serde(skip_serializing_if = "Option::is_none")]
586    pub tool_config: Option<GeminiToolConfig>,
587}
588
589/// Represents a response from the Gemini API
590#[derive(Debug, Serialize, Deserialize)]
591pub(crate) struct GeminiResponse {
592    /// The candidates (typically one)
593    pub candidates: Vec<GeminiCandidate>,
594    /// Usage information (may not be present in all responses)
595    #[serde(rename = "usageMetadata", skip_serializing_if = "Option::is_none")]
596    pub usage_metadata: Option<GeminiUsageMetadata>,
597    /// The model version
598    #[serde(rename = "modelVersion", skip_serializing_if = "Option::is_none")]
599    pub model_version: Option<String>,
600}
601
602/// Represents a candidate in a Gemini response
603#[derive(Debug, Serialize, Deserialize)]
604pub(crate) struct GeminiCandidate {
605    /// The content of the candidate
606    pub content: GeminiContent,
607    /// The finish reason (using camelCase as in the API)
608    #[serde(skip_serializing_if = "Option::is_none", rename = "finishReason")]
609    pub finish_reason: Option<String>,
610    /// The index of the candidate (optional)
611    #[serde(skip_serializing_if = "Option::is_none")]
612    pub index: Option<i32>,
613    /// The average log probability (optional)
614    #[serde(skip_serializing_if = "Option::is_none", rename = "avgLogprobs")]
615    pub avg_logprobs: Option<f64>,
616}
617
618/// Represents token details for a specific modality
619#[derive(Debug, Serialize, Deserialize)]
620pub(crate) struct GeminiTokenDetails {
621    /// The modality (TEXT, IMAGE, etc.)
622    pub modality: String,
623    /// The token count for this modality
624    #[serde(rename = "tokenCount")]
625    pub token_count: u32,
626}
627
628/// Represents usage metadata in a Gemini response
629#[derive(Debug, Serialize, Deserialize)]
630pub(crate) struct GeminiUsageMetadata {
631    /// Token count in the prompt
632    #[serde(rename = "promptTokenCount")]
633    pub prompt_token_count: u32,
634    /// Token count in the response
635    #[serde(rename = "candidatesTokenCount", default)]
636    pub candidates_token_count: u32,
637    /// Total token count
638    #[serde(rename = "totalTokenCount", default)]
639    pub total_token_count: u32,
640    /// Detailed token breakdown for the prompt
641    #[serde(
642        rename = "promptTokensDetails",
643        skip_serializing_if = "Option::is_none"
644    )]
645    pub prompt_tokens_details: Option<Vec<GeminiTokenDetails>>,
646    /// Detailed token breakdown for the candidates
647    #[serde(
648        rename = "candidatesTokensDetails",
649        skip_serializing_if = "Option::is_none"
650    )]
651    pub candidates_tokens_details: Option<Vec<GeminiTokenDetails>>,
652}
653
654/// Represents an error response from the Gemini API
655#[derive(Debug, Serialize, Deserialize)]
656pub(crate) struct GeminiErrorResponse {
657    /// The error details
658    pub error: Option<GeminiError>,
659}
660
661/// Represents an error from the Gemini API
662#[derive(Debug, Serialize, Deserialize)]
663pub(crate) struct GeminiError {
664    /// The error code
665    pub code: i32,
666    /// The error message
667    pub message: String,
668    /// The error status
669    pub status: String,
670}
671
672/// Convert from Gemini's response to our message format
673impl From<&GeminiResponse> for Message {
674    fn from(response: &GeminiResponse) -> Self {
675        // Check if we have candidates
676        if response.candidates.is_empty() {
677            return Message::assistant("No response generated");
678        }
679
680        // Get the first candidate
681        let candidate = &response.candidates[0];
682
683        // Extract text content and tool calls separately
684        let mut text_content_parts = Vec::new();
685        let mut tool_calls = Vec::new();
686        let mut tool_call_id_counter = 0;
687
688        // Process each part of the response
689        for part in &candidate.content.parts {
690            // Handle function calls
691            if let Some(function_call) = &part.function_call {
692                tool_call_id_counter += 1;
693                let tool_id = format!("gemini_call_{}", tool_call_id_counter);
694
695                let args_str =
696                    serde_json::to_string(&function_call.args).unwrap_or_else(|_| "{}".to_string());
697
698                let tool_call = crate::message::ToolCall {
699                    id: tool_id,
700                    tool_type: "function".to_string(),
701                    function: crate::message::Function {
702                        name: function_call.name.clone(),
703                        arguments: args_str,
704                    },
705                };
706
707                tool_calls.push(tool_call);
708            }
709
710            // Handle text content
711            if let Some(text) = &part.text {
712                text_content_parts.push(ContentPart::text(text.clone()));
713            } else if let Some(inline_data) = &part.inline_data {
714                // Just convert to text representation for now
715                text_content_parts.push(ContentPart::text(format!(
716                    "[Image: {} ({})]",
717                    inline_data.data, inline_data.mime_type
718                )));
719            }
720        }
721
722        // Create the content
723        let content = if text_content_parts.len() == 1 {
724            // If there's only one text part, use simple Text content
725            match &text_content_parts[0] {
726                ContentPart::Text { text } => Some(Content::Text(text.clone())),
727                _ => Some(Content::Parts(text_content_parts)),
728            }
729        } else if !text_content_parts.is_empty() {
730            // Multiple content parts
731            Some(Content::Parts(text_content_parts))
732        } else {
733            // No text content, may have only function calls
734            None
735        };
736
737        // Create a new assistant message with appropriate content and tool calls
738        let mut msg = if !tool_calls.is_empty() {
739            // If we have tool calls
740            Message::Assistant {
741                content,
742                tool_calls,
743                metadata: Default::default(),
744            }
745        } else if let Some(Content::Text(text)) = content {
746            // Simple text response
747            Message::assistant(text)
748        } else {
749            // Other content types (multipart or none)
750            Message::Assistant {
751                content,
752                tool_calls: Vec::new(),
753                metadata: Default::default(),
754            }
755        };
756
757        // Add usage info if available
758        if let Some(usage) = &response.usage_metadata {
759            msg = msg.with_metadata(
760                "prompt_tokens",
761                serde_json::Value::Number(usage.prompt_token_count.into()),
762            );
763            msg = msg.with_metadata(
764                "completion_tokens",
765                serde_json::Value::Number(usage.candidates_token_count.into()),
766            );
767            msg = msg.with_metadata(
768                "total_tokens",
769                serde_json::Value::Number(usage.total_token_count.into()),
770            );
771        }
772
773        msg
774    }
775}
776
#[cfg(test)]
mod tests {
    use super::*;

    /// Text and inline-data parts serialize with only their populated field.
    #[test]
    fn test_gemini_part_serialization() {
        let text_part = GeminiPart::text("Hello, world!".to_string());
        let serialized = serde_json::to_string(&text_part).unwrap();
        let expected = r#"{"text":"Hello, world!"}"#;
        assert_eq!(serialized, expected);

        let inline_data_part =
            GeminiPart::inline_data("base64data".to_string(), "image/jpeg".to_string());
        let serialized = serde_json::to_string(&inline_data_part).unwrap();
        let expected = r#"{"inline_data":{"data":"base64data","mime_type":"image/jpeg"}}"#;
        assert_eq!(serialized, expected);
    }

    /// A full `{"error": {...}}` envelope is recognised.
    #[test]
    fn test_error_response_parsing() {
        let error_json = r#"{
            "error": {
                "code": 400,
                "message": "Invalid JSON payload received.",
                "status": "INVALID_ARGUMENT"
            }
        }"#;

        let error_response: GeminiErrorResponse = serde_json::from_str(error_json).unwrap();
        assert!(error_response.error.is_some());
        let error = error_response.error.unwrap();
        assert_eq!(error.code, 400);
        assert_eq!(error.message, "Invalid JSON payload received.");
        assert_eq!(error.status, "INVALID_ARGUMENT");
    }

    /// A typical successful response deserializes, including usage metadata.
    #[test]
    fn test_success_response_parsing() {
        let json = r#"{
            "candidates": [{
                "content": {"role": "model", "parts": [{"text": "Hello!"}]},
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 5,
                "candidatesTokenCount": 2,
                "totalTokenCount": 7
            },
            "modelVersion": "gemini-1.5-flash"
        }"#;

        let response: GeminiResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.candidates.len(), 1);
        let candidate = &response.candidates[0];
        assert_eq!(candidate.finish_reason.as_deref(), Some("STOP"));
        assert_eq!(candidate.content.parts[0].text.as_deref(), Some("Hello!"));
        assert_eq!(response.model_version.as_deref(), Some("gemini-1.5-flash"));

        let usage = response.usage_metadata.as_ref().unwrap();
        assert_eq!(usage.prompt_token_count, 5);
        assert_eq!(usage.candidates_token_count, 2);
        assert_eq!(usage.total_token_count, 7);
    }

    /// A `functionCall` part becomes a tool call on an assistant message.
    #[test]
    fn test_function_call_response_to_message() {
        let json = r#"{
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [{"functionCall": {"name": "get_weather", "args": {"city": "Paris"}}}]
                },
                "finishReason": "STOP"
            }]
        }"#;

        let response: GeminiResponse = serde_json::from_str(json).unwrap();
        match Message::from(&response) {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert!(content.is_none());
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, "gemini_call_1");
                assert_eq!(tool_calls[0].function.name, "get_weather");
                let args: serde_json::Value =
                    serde_json::from_str(&tool_calls[0].function.arguments).unwrap();
                assert_eq!(args["city"], "Paris");
            }
            _ => panic!("expected an assistant message"),
        }
    }
}