language_barrier_core/provider/
gemini.rs

1use crate::error::{Error, Result};
2use crate::message::{Content, ContentPart, Message};
3use crate::provider::HTTPProvider;
4use crate::{Chat, LlmToolInfo, ModelInfo};
5use reqwest::{Method, Request, Url};
6use serde::{Deserialize, Serialize};
7use std::env;
8use tracing::{debug, error, info, instrument, trace, warn};
9
/// Connection settings used by the Gemini provider.
#[derive(Clone, Debug)]
pub struct GeminiConfig {
    /// Key sent to the API to authenticate requests.
    pub api_key: String,
    /// Root URL that request paths are appended to.
    pub base_url: String,
}

impl Default for GeminiConfig {
    /// Takes the key from the `GEMINI_API_KEY` environment variable (empty
    /// when the variable cannot be read) and targets the `v1beta` endpoint.
    fn default() -> Self {
        let api_key = env::var("GEMINI_API_KEY").unwrap_or_else(|_| String::new());
        let base_url = String::from("https://generativelanguage.googleapis.com/v1beta");

        GeminiConfig { api_key, base_url }
    }
}
27
28/// Implementation of the Gemini provider
29#[derive(Debug, Clone)]
30pub struct GeminiProvider {
31    /// Configuration for the provider
32    config: GeminiConfig,
33}
34
35impl GeminiProvider {
36    /// Creates a new GeminiProvider with default configuration
37    ///
38    /// This method will use the GEMINI_API_KEY environment variable for authentication.
39    ///
40    /// # Examples
41    ///
42    /// ```
43    /// use language_barrier_core::provider::gemini::GeminiProvider;
44    ///
45    /// let provider = GeminiProvider::new();
46    /// ```
47    #[instrument(level = "debug")]
48    pub fn new() -> Self {
49        info!("Creating new GeminiProvider with default configuration");
50        let config = GeminiConfig::default();
51        debug!("API key set: {}", !config.api_key.is_empty());
52        debug!("Base URL: {}", config.base_url);
53
54        Self { config }
55    }
56
57    /// Creates a new GeminiProvider with custom configuration
58    ///
59    /// # Examples
60    ///
61    /// ```
62    /// use language_barrier_core::provider::gemini::{GeminiProvider, GeminiConfig};
63    ///
64    /// let config = GeminiConfig {
65    ///     api_key: "your-api-key".to_string(),
66    ///     base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
67    /// };
68    ///
69    /// let provider = GeminiProvider::with_config(config);
70    /// ```
71    #[instrument(skip(config), level = "debug")]
72    pub fn with_config(config: GeminiConfig) -> Self {
73        info!("Creating new GeminiProvider with custom configuration");
74        debug!("API key set: {}", !config.api_key.is_empty());
75        debug!("Base URL: {}", config.base_url);
76
77        Self { config }
78    }
79}
80
81impl Default for GeminiProvider {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
87impl<M: ModelInfo + GeminiModelInfo> HTTPProvider<M> for GeminiProvider {
88    fn accept(&self, chat: Chat<M>) -> Result<Request> {
89        info!("Creating request for Gemini model: {:?}", chat.model);
90        debug!("Messages in chat history: {}", chat.history.len());
91
92        let model_id = chat.model.gemini_model_id();
93        let url_str = format!(
94            "{}/models/{}:generateContent?key={}",
95            self.config.base_url, model_id, self.config.api_key
96        );
97
98        debug!("Parsing URL: {}", url_str);
99        let url = match Url::parse(&url_str) {
100            Ok(url) => {
101                debug!("URL parsed successfully: {}", url);
102                url
103            }
104            Err(e) => {
105                error!("Failed to parse URL '{}': {}", url_str, e);
106                return Err(e.into());
107            }
108        };
109
110        let mut request = Request::new(Method::POST, url);
111        debug!("Created request: {} {}", request.method(), request.url());
112
113        // Set headers
114        debug!("Setting request headers");
115        let content_type_header = match "application/json".parse() {
116            Ok(header) => header,
117            Err(e) => {
118                error!("Failed to set content type: {}", e);
119                return Err(Error::Other("Failed to set content type".into()));
120            }
121        };
122
123        request
124            .headers_mut()
125            .insert("Content-Type", content_type_header);
126
127        trace!("Request headers set: {:#?}", request.headers());
128
129        // Create the request payload
130        debug!("Creating request payload");
131        let payload = match self.create_request_payload(&chat) {
132            Ok(payload) => {
133                debug!("Request payload created successfully");
134                trace!("Number of contents: {}", payload.contents.len());
135                trace!(
136                    "System instruction present: {}",
137                    payload.system_instruction.is_some()
138                );
139                trace!(
140                    "Generation config present: {}",
141                    payload.generation_config.is_some()
142                );
143                payload
144            }
145            Err(e) => {
146                error!("Failed to create request payload: {}", e);
147                return Err(e);
148            }
149        };
150
151        // Set the request body
152        debug!("Serializing request payload");
153        let body_bytes = match serde_json::to_vec(&payload) {
154            Ok(bytes) => {
155                debug!("Payload serialized successfully ({} bytes)", bytes.len());
156                bytes
157            }
158            Err(e) => {
159                error!("Failed to serialize payload: {}", e);
160                return Err(Error::Serialization(e));
161            }
162        };
163
164        *request.body_mut() = Some(body_bytes.into());
165        info!("Request created successfully");
166
167        Ok(request)
168    }
169
170    fn parse(&self, raw_response_text: String) -> Result<Message> {
171        info!("Parsing response from Gemini API");
172        trace!("Raw response: {}", raw_response_text);
173
174        // First try to parse as an error response
175        if let Ok(error_response) = serde_json::from_str::<GeminiErrorResponse>(&raw_response_text)
176        {
177            if let Some(error) = error_response.error {
178                error!("Gemini API returned an error: {}", error.message);
179                return Err(Error::ProviderUnavailable(error.message));
180            }
181        }
182
183        // If not an error, parse as a successful response
184        debug!("Deserializing response JSON");
185        let gemini_response = match serde_json::from_str::<GeminiResponse>(&raw_response_text) {
186            Ok(response) => {
187                debug!("Response deserialized successfully");
188                if !response.candidates.is_empty() {
189                    debug!(
190                        "Content parts: {}",
191                        response.candidates[0].content.parts.len()
192                    );
193                }
194                response
195            }
196            Err(e) => {
197                error!("Failed to deserialize response: {}", e);
198                error!("Raw response: {}", raw_response_text);
199                return Err(Error::Serialization(e));
200            }
201        };
202
203        // Convert to our message format
204        debug!("Converting Gemini response to Message");
205        let message = Message::from(&gemini_response);
206
207        info!("Response parsed successfully");
208        trace!("Response message processed");
209
210        Ok(message)
211    }
212}
213
/// Implemented by model types that can be addressed through the Gemini API.
pub trait GeminiModelInfo {
    /// Returns the Gemini-specific identifier of this model.
    fn gemini_model_id(&self) -> String;
}
218
219impl GeminiProvider {
220    /// Creates a request payload from a Chat object
221    ///
222    /// This method converts the Chat's messages and settings into a Gemini-specific
223    /// format for the API request.
224    #[instrument(skip(self, chat), level = "debug")]
225    fn create_request_payload<M: ModelInfo + GeminiModelInfo>(
226        &self,
227        chat: &Chat<M>,
228    ) -> Result<GeminiRequest> {
229        info!("Creating request payload for chat with Gemini model");
230        debug!("System prompt length: {}", chat.system_prompt.len());
231        debug!("Messages in history: {}", chat.history.len());
232        debug!("Max output tokens: {}", chat.max_output_tokens);
233
234        // Convert system prompt if present
235        let system_instruction = if !chat.system_prompt.is_empty() {
236            debug!("Including system prompt in request");
237            trace!("System prompt: {}", chat.system_prompt);
238            Some(GeminiContent {
239                parts: vec![GeminiPart::text(chat.system_prompt.clone())],
240                role: None,
241            })
242        } else {
243            debug!("No system prompt provided");
244            None
245        };
246
247        // Convert messages to contents
248        debug!("Converting messages to Gemini format");
249        let mut contents: Vec<GeminiContent> = Vec::new();
250        let mut current_role_str: Option<&'static str> = None;
251        let mut current_parts: Vec<GeminiPart> = Vec::new();
252
253        for msg in &chat.history {
254            // Get the current role string
255            let msg_role_str = msg.role_str();
256
257            // If role changes, finish the current content and start a new one
258            if current_role_str.is_some()
259                && current_role_str != Some(msg_role_str)
260                && !current_parts.is_empty()
261            {
262                let role = match current_role_str {
263                    Some("user") => Some("user".to_string()),
264                    Some("assistant") => Some("model".to_string()),
265                    _ => None,
266                };
267
268                contents.push(GeminiContent {
269                    parts: std::mem::take(&mut current_parts),
270                    role,
271                });
272            }
273
274            current_role_str = Some(msg_role_str);
275
276            // Convert message content to parts based on the message variant
277            match msg {
278                Message::System { content, .. } => {
279                    current_parts.push(GeminiPart::text(content.clone()));
280                }
281                Message::User { content, .. } => match content {
282                    Content::Text(text) => {
283                        current_parts.push(GeminiPart::text(text.clone()));
284                    }
285                    Content::Parts(parts) => {
286                        for part in parts {
287                            match part {
288                                ContentPart::Text { text } => {
289                                    current_parts.push(GeminiPart::text(text.clone()));
290                                }
291                                ContentPart::ImageUrl { image_url } => {
292                                    current_parts.push(GeminiPart::inline_data(
293                                        image_url.url.clone(),
294                                        "image/jpeg".to_string(),
295                                    ));
296                                }
297                            }
298                        }
299                    }
300                },
301                Message::Assistant { content, .. } => {
302                    if let Some(content_data) = content {
303                        match content_data {
304                            Content::Text(text) => {
305                                current_parts.push(GeminiPart::text(text.clone()));
306                            }
307                            Content::Parts(parts) => {
308                                for part in parts {
309                                    match part {
310                                        ContentPart::Text { text } => {
311                                            current_parts.push(GeminiPart::text(text.clone()));
312                                        }
313                                        ContentPart::ImageUrl { image_url } => {
314                                            current_parts.push(GeminiPart::inline_data(
315                                                image_url.url.clone(),
316                                                "image/jpeg".to_string(),
317                                            ));
318                                        }
319                                    }
320                                }
321                            }
322                        }
323                    }
324                }
325                Message::Tool {
326                    tool_call_id,
327                    content,
328                    ..
329                } => {
330                    // For Gemini, include both the tool call ID and the content
331                    current_parts.push(GeminiPart::text(format!(
332                        "Tool result for call {}: {}",
333                        tool_call_id, content
334                    )));
335                }
336            }
337        }
338
339        // Add any remaining parts
340        if !current_parts.is_empty() {
341            let role = match current_role_str {
342                Some("user") => Some("user".to_string()),
343                Some("assistant") => Some("model".to_string()),
344                _ => None,
345            };
346
347            contents.push(GeminiContent {
348                parts: current_parts,
349                role,
350            });
351        }
352
353        debug!("Converted {} contents for the request", contents.len());
354
355        // Create generation config
356        let generation_config = Some(GeminiGenerationConfig {
357            max_output_tokens: Some(chat.max_output_tokens),
358            temperature: None,
359            top_p: None,
360            top_k: None,
361            stop_sequences: None,
362        });
363
364        // Convert tool descriptions if a tool registry is provided
365        let tools = chat.tools.as_ref().map(|tools| {
366            vec![GeminiTool {
367                function_declarations: tools
368                    .iter()
369                    .map(GeminiFunctionDeclaration::from)
370                    .collect(),
371            }]
372        });
373
374        // Create the request
375        debug!("Creating GeminiRequest");
376        let request = GeminiRequest {
377            contents,
378            system_instruction,
379            generation_config,
380            tools,
381        };
382
383        info!("Request payload created successfully");
384        Ok(request)
385    }
386}
387
388/// Represents a content part in Gemini API format
389#[derive(Debug, Clone, Serialize, Deserialize)]
390pub(crate) struct GeminiPart {
391    /// The text content (optional)
392    #[serde(skip_serializing_if = "Option::is_none")]
393    pub text: Option<String>,
394
395    /// The inline data (optional)
396    #[serde(skip_serializing_if = "Option::is_none")]
397    pub inline_data: Option<GeminiInlineData>,
398
399    /// The function call (optional)
400    #[serde(skip_serializing_if = "Option::is_none", rename = "functionCall")]
401    pub function_call: Option<GeminiFunctionCall>,
402}
403
404/// Represents a function call in the Gemini API format
405#[derive(Debug, Clone, Serialize, Deserialize)]
406pub(crate) struct GeminiFunctionCall {
407    /// The name of the function
408    pub name: String,
409    /// The arguments as a JSON Value
410    pub args: serde_json::Value,
411}
412
413impl GeminiPart {
414    /// Create a new text part
415    fn text(text: String) -> Self {
416        GeminiPart {
417            text: Some(text),
418            inline_data: None,
419            function_call: None,
420        }
421    }
422
423    /// Create a new inline data part
424    fn inline_data(data: String, mime_type: String) -> Self {
425        GeminiPart {
426            text: None,
427            inline_data: Some(GeminiInlineData { data, mime_type }),
428            function_call: None,
429        }
430    }
431}
432
433/// Represents inline data in Gemini API format
434#[derive(Debug, Clone, Serialize, Deserialize)]
435pub(crate) struct GeminiInlineData {
436    /// The data (base64 encoded)
437    pub data: String,
438    /// The MIME type
439    pub mime_type: String,
440}
441
442/// Represents a content object in Gemini API format
443#[derive(Debug, Clone, Serialize, Deserialize)]
444pub(crate) struct GeminiContent {
445    /// The parts of the content
446    pub parts: Vec<GeminiPart>,
447    /// The role of the content (user, model, etc.)
448    #[serde(skip_serializing_if = "Option::is_none")]
449    pub role: Option<String>,
450}
451
452/// Represents a generation config in Gemini API format
453#[derive(Debug, Clone, Serialize, Deserialize)]
454pub(crate) struct GeminiGenerationConfig {
455    /// The maximum number of tokens to generate
456    #[serde(skip_serializing_if = "Option::is_none")]
457    pub max_output_tokens: Option<usize>,
458    /// The temperature (randomness) of the generation
459    #[serde(skip_serializing_if = "Option::is_none")]
460    pub temperature: Option<f32>,
461    /// The top-p sampling parameter
462    #[serde(skip_serializing_if = "Option::is_none")]
463    pub top_p: Option<f32>,
464    /// The top-k sampling parameter
465    #[serde(skip_serializing_if = "Option::is_none")]
466    pub top_k: Option<u32>,
467    /// Sequences that will stop generation
468    #[serde(skip_serializing_if = "Option::is_none")]
469    pub stop_sequences: Option<Vec<String>>,
470}
471
472/// Represents a function declaration in the Gemini API format
473#[derive(Debug, Clone, Serialize, Deserialize)]
474pub(crate) struct GeminiFunctionDeclaration {
475    /// The name of the function
476    pub name: String,
477    /// The description of the function
478    pub description: String,
479    /// The parameters schema
480    pub parameters: serde_json::Value,
481}
482
483impl From<&LlmToolInfo> for GeminiFunctionDeclaration {
484    fn from(value: &LlmToolInfo) -> Self {
485        GeminiFunctionDeclaration {
486            name: value.name.clone(),
487            description: value.description.clone(),
488            parameters: value.parameters.clone(),
489        }
490    }
491}
492
493/// Represents a function in the Gemini API format (tools are called functions in Gemini)
494#[derive(Debug, Clone, Serialize, Deserialize)]
495pub(crate) struct GeminiFunction {
496    /// The name of the function
497    pub name: String,
498    /// The description of the function
499    pub description: String,
500    /// The parameters definition
501    pub parameters: serde_json::Value,
502}
503
504/// Represents a tool in the Gemini API format
505#[derive(Debug, Clone, Serialize, Deserialize)]
506pub(crate) struct GeminiTool {
507    /// The function declaration
508    #[serde(rename = "functionDeclarations")]
509    pub function_declarations: Vec<GeminiFunctionDeclaration>,
510}
511
512/// Represents a request to the Gemini API
513#[derive(Debug, Serialize, Deserialize)]
514pub(crate) struct GeminiRequest {
515    /// The contents to send to the model
516    pub contents: Vec<GeminiContent>,
517    /// The system instruction (optional)
518    #[serde(skip_serializing_if = "Option::is_none")]
519    pub system_instruction: Option<GeminiContent>,
520    /// The generation config (optional)
521    #[serde(skip_serializing_if = "Option::is_none")]
522    pub generation_config: Option<GeminiGenerationConfig>,
523    /// The tools (functions) available to the model
524    #[serde(skip_serializing_if = "Option::is_none")]
525    pub tools: Option<Vec<GeminiTool>>,
526}
527
528/// Represents a response from the Gemini API
529#[derive(Debug, Serialize, Deserialize)]
530pub(crate) struct GeminiResponse {
531    /// The candidates (typically one)
532    pub candidates: Vec<GeminiCandidate>,
533    /// Usage information (may not be present in all responses)
534    #[serde(rename = "usageMetadata", skip_serializing_if = "Option::is_none")]
535    pub usage_metadata: Option<GeminiUsageMetadata>,
536    /// The model version
537    #[serde(rename = "modelVersion", skip_serializing_if = "Option::is_none")]
538    pub model_version: Option<String>,
539}
540
541/// Represents a candidate in a Gemini response
542#[derive(Debug, Serialize, Deserialize)]
543pub(crate) struct GeminiCandidate {
544    /// The content of the candidate
545    pub content: GeminiContent,
546    /// The finish reason (using camelCase as in the API)
547    #[serde(skip_serializing_if = "Option::is_none", rename = "finishReason")]
548    pub finish_reason: Option<String>,
549    /// The index of the candidate (optional)
550    #[serde(skip_serializing_if = "Option::is_none")]
551    pub index: Option<i32>,
552    /// The average log probability (optional)
553    #[serde(skip_serializing_if = "Option::is_none", rename = "avgLogprobs")]
554    pub avg_logprobs: Option<f64>,
555}
556
557/// Represents token details for a specific modality
558#[derive(Debug, Serialize, Deserialize)]
559pub(crate) struct GeminiTokenDetails {
560    /// The modality (TEXT, IMAGE, etc.)
561    pub modality: String,
562    /// The token count for this modality
563    #[serde(rename = "tokenCount")]
564    pub token_count: u32,
565}
566
567/// Represents usage metadata in a Gemini response
568#[derive(Debug, Serialize, Deserialize)]
569pub(crate) struct GeminiUsageMetadata {
570    /// Token count in the prompt
571    #[serde(rename = "promptTokenCount")]
572    pub prompt_token_count: u32,
573    /// Token count in the response
574    #[serde(rename = "candidatesTokenCount", default)]
575    pub candidates_token_count: u32,
576    /// Total token count
577    #[serde(rename = "totalTokenCount", default)]
578    pub total_token_count: u32,
579    /// Detailed token breakdown for the prompt
580    #[serde(
581        rename = "promptTokensDetails",
582        skip_serializing_if = "Option::is_none"
583    )]
584    pub prompt_tokens_details: Option<Vec<GeminiTokenDetails>>,
585    /// Detailed token breakdown for the candidates
586    #[serde(
587        rename = "candidatesTokensDetails",
588        skip_serializing_if = "Option::is_none"
589    )]
590    pub candidates_tokens_details: Option<Vec<GeminiTokenDetails>>,
591}
592
593/// Represents an error response from the Gemini API
594#[derive(Debug, Serialize, Deserialize)]
595pub(crate) struct GeminiErrorResponse {
596    /// The error details
597    pub error: Option<GeminiError>,
598}
599
600/// Represents an error from the Gemini API
601#[derive(Debug, Serialize, Deserialize)]
602pub(crate) struct GeminiError {
603    /// The error code
604    pub code: i32,
605    /// The error message
606    pub message: String,
607    /// The error status
608    pub status: String,
609}
610
611/// Convert from Gemini's response to our message format
612impl From<&GeminiResponse> for Message {
613    fn from(response: &GeminiResponse) -> Self {
614        // Check if we have candidates
615        if response.candidates.is_empty() {
616            return Message::assistant("No response generated");
617        }
618
619        // Get the first candidate
620        let candidate = &response.candidates[0];
621
622        // Extract text content and tool calls separately
623        let mut text_content_parts = Vec::new();
624        let mut tool_calls = Vec::new();
625        let mut tool_call_id_counter = 0;
626
627        // Process each part of the response
628        for part in &candidate.content.parts {
629            // Handle function calls
630            if let Some(function_call) = &part.function_call {
631                tool_call_id_counter += 1;
632                let tool_id = format!("gemini_call_{}", tool_call_id_counter);
633
634                let args_str =
635                    serde_json::to_string(&function_call.args).unwrap_or_else(|_| "{}".to_string());
636
637                let tool_call = crate::message::ToolCall {
638                    id: tool_id,
639                    tool_type: "function".to_string(),
640                    function: crate::message::Function {
641                        name: function_call.name.clone(),
642                        arguments: args_str,
643                    },
644                };
645
646                tool_calls.push(tool_call);
647            }
648
649            // Handle text content
650            if let Some(text) = &part.text {
651                text_content_parts.push(ContentPart::text(text.clone()));
652            } else if let Some(inline_data) = &part.inline_data {
653                // Just convert to text representation for now
654                text_content_parts.push(ContentPart::text(format!(
655                    "[Image: {} ({})]",
656                    inline_data.data, inline_data.mime_type
657                )));
658            }
659        }
660
661        // Create the content
662        let content = if text_content_parts.len() == 1 {
663            // If there's only one text part, use simple Text content
664            match &text_content_parts[0] {
665                ContentPart::Text { text } => Some(Content::Text(text.clone())),
666                _ => Some(Content::Parts(text_content_parts)),
667            }
668        } else if !text_content_parts.is_empty() {
669            // Multiple content parts
670            Some(Content::Parts(text_content_parts))
671        } else {
672            // No text content, may have only function calls
673            None
674        };
675
676        // Create a new assistant message with appropriate content and tool calls
677        let mut msg = if !tool_calls.is_empty() {
678            // If we have tool calls
679            Message::Assistant {
680                content,
681                tool_calls,
682                metadata: Default::default(),
683            }
684        } else if let Some(Content::Text(text)) = content {
685            // Simple text response
686            Message::assistant(text)
687        } else {
688            // Other content types (multipart or none)
689            Message::Assistant {
690                content,
691                tool_calls: Vec::new(),
692                metadata: Default::default(),
693            }
694        };
695
696        // Add usage info if available
697        if let Some(usage) = &response.usage_metadata {
698            msg = msg.with_metadata(
699                "prompt_tokens",
700                serde_json::Value::Number(usage.prompt_token_count.into()),
701            );
702            msg = msg.with_metadata(
703                "completion_tokens",
704                serde_json::Value::Number(usage.candidates_token_count.into()),
705            );
706            msg = msg.with_metadata(
707                "total_tokens",
708                serde_json::Value::Number(usage.total_token_count.into()),
709            );
710        }
711
712        msg
713    }
714}
715
#[cfg(test)]
mod tests {
    use super::*;

    // More tests will follow as more information about the API becomes available.

    /// Text and inline-data parts serialize to the expected JSON.
    #[test]
    fn test_gemini_part_serialization() {
        let text_json =
            serde_json::to_string(&GeminiPart::text("Hello, world!".to_string())).unwrap();
        assert_eq!(text_json, r#"{"text":"Hello, world!"}"#);

        let inline_json = serde_json::to_string(&GeminiPart::inline_data(
            "base64data".to_string(),
            "image/jpeg".to_string(),
        ))
        .unwrap();
        assert_eq!(
            inline_json,
            r#"{"inline_data":{"data":"base64data","mime_type":"image/jpeg"}}"#
        );
    }

    /// An error body deserializes into its code and status.
    #[test]
    fn test_error_response_parsing() {
        let error_json = r#"{
            "error": {
                "code": 400,
                "message": "Invalid JSON payload received.",
                "status": "INVALID_ARGUMENT"
            }
        }"#;

        let parsed: GeminiErrorResponse = serde_json::from_str(error_json).unwrap();
        let error = parsed.error.expect("error object should be present");
        assert_eq!(error.code, 400);
        assert_eq!(error.status, "INVALID_ARGUMENT");
    }
}