1use crate::error::{Error, Result};
2use crate::message::{Content, ContentPart, Message};
3use crate::provider::HTTPProvider;
4use crate::{Chat, Gemini, LlmToolInfo};
5use reqwest::{Method, Request, Url};
6use serde::{Deserialize, Serialize};
7use std::env;
8use tracing::{debug, error, info, instrument, trace, warn};
9
/// Connection settings for the Gemini REST API.
///
/// `Debug` is implemented by hand so that the API key can never end up in
/// logs or panic messages through `{:?}` formatting of this struct (or of any
/// type that embeds it and derives `Debug`).
#[derive(Clone)]
pub struct GeminiConfig {
    /// API key appended to every request URL as the `key` query parameter.
    pub api_key: String,
    /// Base URL that model paths are appended to.
    pub base_url: String,
}

impl std::fmt::Debug for GeminiConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a key is configured, never its value.
        let api_key = if self.api_key.is_empty() {
            "<unset>"
        } else {
            "<redacted>"
        };
        f.debug_struct("GeminiConfig")
            .field("api_key", &api_key)
            .field("base_url", &self.base_url)
            .finish()
    }
}
18
19impl Default for GeminiConfig {
20 fn default() -> Self {
21 Self {
22 api_key: env::var("GEMINI_API_KEY").unwrap_or_default(),
23 base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
24 }
25 }
26}
27
/// Provider that builds `generateContent` requests for Gemini models and
/// converts the raw JSON responses back into [`Message`] values.
#[derive(Debug, Clone)]
pub struct GeminiProvider {
    /// API key and base URL used when building request URLs.
    config: GeminiConfig,
}
34
35impl GeminiProvider {
36 #[instrument(level = "debug")]
48 pub fn new() -> Self {
49 info!("Creating new GeminiProvider with default configuration");
50 let config = GeminiConfig::default();
51 debug!("API key set: {}", !config.api_key.is_empty());
52 debug!("Base URL: {}", config.base_url);
53
54 Self { config }
55 }
56
57 #[instrument(skip(config), level = "debug")]
72 pub fn with_config(config: GeminiConfig) -> Self {
73 info!("Creating new GeminiProvider with custom configuration");
74 debug!("API key set: {}", !config.api_key.is_empty());
75 debug!("Base URL: {}", config.base_url);
76
77 Self { config }
78 }
79}
80
81impl Default for GeminiProvider {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87impl HTTPProvider<Gemini> for GeminiProvider {
88 fn accept(&self, model: Gemini, chat: &Chat) -> Result<Request> {
89 info!("Creating request for Gemini model: {:?}", model);
90 debug!("Messages in chat history: {}", chat.history.len());
91
92 let model_id = model.gemini_model_id();
93 let url_str = format!(
94 "{}/models/{}:generateContent?key={}",
95 self.config.base_url, model_id, self.config.api_key
96 );
97
98 debug!("Parsing URL: {}", url_str);
99 let url = match Url::parse(&url_str) {
100 Ok(url) => {
101 debug!("URL parsed successfully: {}", url);
102 url
103 }
104 Err(e) => {
105 error!("Failed to parse URL '{}': {}", url_str, e);
106 return Err(e.into());
107 }
108 };
109
110 let mut request = Request::new(Method::POST, url);
111 debug!("Created request: {} {}", request.method(), request.url());
112
113 debug!("Setting request headers");
115 let content_type_header = match "application/json".parse() {
116 Ok(header) => header,
117 Err(e) => {
118 error!("Failed to set content type: {}", e);
119 return Err(Error::Other("Failed to set content type".into()));
120 }
121 };
122
123 request
124 .headers_mut()
125 .insert("Content-Type", content_type_header);
126
127 trace!("Request headers set: {:#?}", request.headers());
128
129 debug!("Creating request payload");
131 let payload = match self.create_request_payload(model, chat) {
132 Ok(payload) => {
133 debug!("Request payload created successfully");
134 trace!("Number of contents: {}", payload.contents.len());
135 trace!(
136 "System instruction present: {}",
137 payload.system_instruction.is_some()
138 );
139 trace!(
140 "Generation config present: {}",
141 payload.generation_config.is_some()
142 );
143 payload
144 }
145 Err(e) => {
146 error!("Failed to create request payload: {}", e);
147 return Err(e);
148 }
149 };
150
151 debug!("Serializing request payload");
153 let body_bytes = match serde_json::to_vec(&payload) {
154 Ok(bytes) => {
155 debug!("Payload serialized successfully ({} bytes)", bytes.len());
156 bytes
157 }
158 Err(e) => {
159 error!("Failed to serialize payload: {}", e);
160 return Err(Error::Serialization(e));
161 }
162 };
163
164 *request.body_mut() = Some(body_bytes.into());
165 info!("Request created successfully");
166
167 Ok(request)
168 }
169
170 fn parse(&self, raw_response_text: String) -> Result<Message> {
171 info!("Parsing response from Gemini API");
172 trace!("Raw response: {}", raw_response_text);
173
174 if let Ok(error_response) = serde_json::from_str::<GeminiErrorResponse>(&raw_response_text)
176 {
177 if let Some(error) = error_response.error {
178 error!("Gemini API returned an error: {}", error.message);
179 return Err(Error::ProviderUnavailable(error.message));
180 }
181 }
182
183 debug!("Deserializing response JSON");
185 let gemini_response = match serde_json::from_str::<GeminiResponse>(&raw_response_text) {
186 Ok(response) => {
187 debug!("Response deserialized successfully");
188 if !response.candidates.is_empty() {
189 debug!(
190 "Content parts: {}",
191 response.candidates[0].content.parts.len()
192 );
193 }
194 response
195 }
196 Err(e) => {
197 error!("Failed to deserialize response: {}", e);
198 error!("Raw response: {}", raw_response_text);
199 return Err(Error::Serialization(e));
200 }
201 };
202
203 debug!("Converting Gemini response to Message");
205 let message = Message::from(&gemini_response);
206
207 info!("Response parsed successfully");
208 trace!("Response message processed");
209
210 Ok(message)
211 }
212}
213
/// Implemented by model identifiers that can name themselves for the Gemini
/// API.
pub trait GeminiModelInfo {
    /// Returns the model ID as a string. The request builder in this file
    /// calls a method of this name on the `Gemini` model and interpolates the
    /// result into the path `{base_url}/models/{id}:generateContent`.
    fn gemini_model_id(&self) -> String;
}
218
219impl GeminiProvider {
220 #[instrument(skip(self, chat), level = "debug")]
225 fn create_request_payload(&self, model: Gemini, chat: &Chat) -> Result<GeminiRequest> {
226 info!("Creating request payload for chat with Gemini model");
227 debug!("System prompt length: {}", chat.system_prompt.len());
228 debug!("Messages in history: {}", chat.history.len());
229 debug!("Max output tokens: {}", chat.max_output_tokens);
230
231 let system_instruction = if !chat.system_prompt.is_empty() {
233 debug!("Including system prompt in request");
234 trace!("System prompt: {}", chat.system_prompt);
235 Some(GeminiContent {
236 parts: vec![GeminiPart::text(chat.system_prompt.clone())],
237 role: None,
238 })
239 } else {
240 debug!("No system prompt provided");
241 None
242 };
243
244 debug!("Converting messages to Gemini format");
246 let mut contents: Vec<GeminiContent> = Vec::new();
247 let mut current_role_str: Option<&'static str> = None;
248 let mut current_parts: Vec<GeminiPart> = Vec::new();
249
250 for msg in &chat.history {
251 let msg_role_str = msg.role_str();
253
254 if current_role_str.is_some()
256 && current_role_str != Some(msg_role_str)
257 && !current_parts.is_empty()
258 {
259 let role = match current_role_str {
260 Some("user") => Some("user".to_string()),
261 Some("assistant") => Some("model".to_string()),
262 _ => None,
263 };
264
265 contents.push(GeminiContent {
266 parts: std::mem::take(&mut current_parts),
267 role,
268 });
269 }
270
271 current_role_str = Some(msg_role_str);
272
273 match msg {
275 Message::System { content, .. } => {
276 current_parts.push(GeminiPart::text(content.clone()));
277 }
278 Message::User { content, .. } => match content {
279 Content::Text(text) => {
280 current_parts.push(GeminiPart::text(text.clone()));
281 }
282 Content::Parts(parts) => {
283 for part in parts {
284 match part {
285 ContentPart::Text { text } => {
286 current_parts.push(GeminiPart::text(text.clone()));
287 }
288 ContentPart::ImageUrl { image_url } => {
289 current_parts.push(GeminiPart::inline_data(
290 image_url.url.clone(),
291 "image/jpeg".to_string(),
292 ));
293 }
294 }
295 }
296 }
297 },
298 Message::Assistant { content, .. } => {
299 if let Some(content_data) = content {
300 match content_data {
301 Content::Text(text) => {
302 current_parts.push(GeminiPart::text(text.clone()));
303 }
304 Content::Parts(parts) => {
305 for part in parts {
306 match part {
307 ContentPart::Text { text } => {
308 current_parts.push(GeminiPart::text(text.clone()));
309 }
310 ContentPart::ImageUrl { image_url } => {
311 current_parts.push(GeminiPart::inline_data(
312 image_url.url.clone(),
313 "image/jpeg".to_string(),
314 ));
315 }
316 }
317 }
318 }
319 }
320 }
321 }
322 Message::Tool {
323 tool_call_id,
324 content,
325 ..
326 } => {
327 current_parts.push(GeminiPart::text(format!(
329 "Tool result for call {}: {}",
330 tool_call_id, content
331 )));
332 }
333 }
334 }
335
336 if !current_parts.is_empty() {
338 let role = match current_role_str {
339 Some("user") => Some("user".to_string()),
340 Some("assistant") => Some("model".to_string()),
341 _ => None,
342 };
343
344 contents.push(GeminiContent {
345 parts: current_parts,
346 role,
347 });
348 }
349
350 debug!("Converted {} contents for the request", contents.len());
351
352 let generation_config = Some(GeminiGenerationConfig {
354 max_output_tokens: Some(chat.max_output_tokens),
355 temperature: None,
356 top_p: None,
357 top_k: None,
358 stop_sequences: None,
359 });
360
361 let tools = chat.tools.as_ref().map(|tools| {
363 vec![GeminiTool {
364 function_declarations: tools.iter().map(GeminiFunctionDeclaration::from).collect(),
365 }]
366 });
367
368 let tool_config = if let Some(choice) = &chat.tool_choice {
373 match choice {
374 crate::tool::ToolChoice::Auto => Some(GeminiToolConfig {
375 function_calling_config: GeminiFunctionCallingConfig {
376 mode: "auto".to_string(),
377 allowed_function_names: None,
378 },
379 }),
380 crate::tool::ToolChoice::Any => Some(GeminiToolConfig {
381 function_calling_config: GeminiFunctionCallingConfig {
382 mode: "any".to_string(),
383 allowed_function_names: None,
384 },
385 }),
386 crate::tool::ToolChoice::None => Some(GeminiToolConfig {
387 function_calling_config: GeminiFunctionCallingConfig {
388 mode: "none".to_string(),
389 allowed_function_names: None,
390 },
391 }),
392 crate::tool::ToolChoice::Specific(name) => Some(GeminiToolConfig {
393 function_calling_config: GeminiFunctionCallingConfig {
394 mode: "auto".to_string(), allowed_function_names: Some(vec![name.clone()]),
396 },
397 }),
398 }
399 } else if tools.is_some() {
400 Some(GeminiToolConfig {
402 function_calling_config: GeminiFunctionCallingConfig {
403 mode: "auto".to_string(),
404 allowed_function_names: None,
405 },
406 })
407 } else {
408 None
409 };
410
411 debug!("Creating GeminiRequest");
413 let request = GeminiRequest {
414 contents,
415 system_instruction,
416 generation_config,
417 tools,
418 tool_config,
419 };
420
421 info!("Request payload created successfully");
422 Ok(request)
423 }
424}
425
426#[derive(Debug, Clone, Serialize, Deserialize)]
428pub(crate) struct GeminiPart {
429 #[serde(skip_serializing_if = "Option::is_none")]
431 pub text: Option<String>,
432
433 #[serde(skip_serializing_if = "Option::is_none")]
435 pub inline_data: Option<GeminiInlineData>,
436
437 #[serde(skip_serializing_if = "Option::is_none", rename = "functionCall")]
439 pub function_call: Option<GeminiFunctionCall>,
440}
441
442#[derive(Debug, Clone, Serialize, Deserialize)]
444pub(crate) struct GeminiFunctionCall {
445 pub name: String,
447 pub args: serde_json::Value,
449}
450
451impl GeminiPart {
452 fn text(text: String) -> Self {
454 GeminiPart {
455 text: Some(text),
456 inline_data: None,
457 function_call: None,
458 }
459 }
460
461 fn inline_data(data: String, mime_type: String) -> Self {
463 GeminiPart {
464 text: None,
465 inline_data: Some(GeminiInlineData { data, mime_type }),
466 function_call: None,
467 }
468 }
469}
470
471#[derive(Debug, Clone, Serialize, Deserialize)]
473pub(crate) struct GeminiInlineData {
474 pub data: String,
476 pub mime_type: String,
478}
479
480#[derive(Debug, Clone, Serialize, Deserialize)]
482pub(crate) struct GeminiContent {
483 pub parts: Vec<GeminiPart>,
485 #[serde(skip_serializing_if = "Option::is_none")]
487 pub role: Option<String>,
488}
489
/// Optional generation parameters sent with a request. Every field that is
/// `None` is left out of the serialized JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct GeminiGenerationConfig {
    /// Output token limit; the request builder in this file fills it from
    /// `chat.max_output_tokens`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<usize>,
    /// Sampling temperature; the request builder in this file leaves it `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// `top_p` sampling parameter; left `None` by the request builder.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// `top_k` sampling parameter; left `None` by the request builder.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    /// Stop sequences; left `None` by the request builder.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
}
509
/// Declaration of one callable function, as listed inside a [`GeminiTool`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct GeminiFunctionDeclaration {
    /// Function name.
    pub name: String,
    /// Human-readable description of the function.
    pub description: String,
    /// Parameter description as arbitrary JSON (presumably a JSON-schema
    /// object; the shape is not validated here).
    pub parameters: serde_json::Value,
}
520
521impl From<&LlmToolInfo> for GeminiFunctionDeclaration {
522 fn from(value: &LlmToolInfo) -> Self {
523 GeminiFunctionDeclaration {
524 name: value.name.clone(),
525 description: value.description.clone(),
526 parameters: value.parameters.clone(),
527 }
528 }
529}
530
/// Function description with the same fields as [`GeminiFunctionDeclaration`].
///
/// NOTE(review): nothing in the visible part of this file constructs or reads
/// this type; confirm whether it is still needed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct GeminiFunction {
    /// Function name.
    pub name: String,
    /// Human-readable description of the function.
    pub description: String,
    /// Parameter description as arbitrary JSON.
    pub parameters: serde_json::Value,
}
541
/// A tool entry in the request: a group of function declarations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct GeminiTool {
    /// Functions the model may call; serialized as `functionDeclarations`.
    #[serde(rename = "functionDeclarations")]
    pub function_declarations: Vec<GeminiFunctionDeclaration>,
}
549
550#[derive(Debug, Clone, Serialize, Deserialize)]
554pub(crate) struct GeminiToolConfig {
555 #[serde(rename = "function_calling_config")]
557 pub function_calling_config: GeminiFunctionCallingConfig,
558}
559
/// Controls whether and which functions the model may call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct GeminiFunctionCallingConfig {
    /// Calling mode; the request builder in this file sends `"auto"`, `"any"`,
    /// or `"none"`.
    pub mode: String,
    /// Optional allow-list of function names; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allowed_function_names: Option<Vec<String>>,
}
569
/// Body of a `generateContent` request. Optional sections are omitted from
/// the JSON when `None`.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct GeminiRequest {
    /// Conversation history, grouped by role.
    pub contents: Vec<GeminiContent>,
    /// System prompt, when the chat has a non-empty one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_instruction: Option<GeminiContent>,
    /// Generation parameters.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_config: Option<GeminiGenerationConfig>,
    /// Tools (function declarations) available to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<GeminiTool>>,
    /// Function-calling configuration for `tools`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_config: Option<GeminiToolConfig>,
}
588
589#[derive(Debug, Serialize, Deserialize)]
591pub(crate) struct GeminiResponse {
592 pub candidates: Vec<GeminiCandidate>,
594 #[serde(rename = "usageMetadata", skip_serializing_if = "Option::is_none")]
596 pub usage_metadata: Option<GeminiUsageMetadata>,
597 #[serde(rename = "modelVersion", skip_serializing_if = "Option::is_none")]
599 pub model_version: Option<String>,
600}
601
602#[derive(Debug, Serialize, Deserialize)]
604pub(crate) struct GeminiCandidate {
605 pub content: GeminiContent,
607 #[serde(skip_serializing_if = "Option::is_none", rename = "finishReason")]
609 pub finish_reason: Option<String>,
610 #[serde(skip_serializing_if = "Option::is_none")]
612 pub index: Option<i32>,
613 #[serde(skip_serializing_if = "Option::is_none", rename = "avgLogprobs")]
615 pub avg_logprobs: Option<f64>,
616}
617
/// Token count for one modality, as listed in the usage metadata.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct GeminiTokenDetails {
    /// Modality name as reported by the API.
    pub modality: String,
    /// Number of tokens for this modality (`tokenCount` on the wire).
    #[serde(rename = "tokenCount")]
    pub token_count: u32,
}
627
628#[derive(Debug, Serialize, Deserialize)]
630pub(crate) struct GeminiUsageMetadata {
631 #[serde(rename = "promptTokenCount")]
633 pub prompt_token_count: u32,
634 #[serde(rename = "candidatesTokenCount", default)]
636 pub candidates_token_count: u32,
637 #[serde(rename = "totalTokenCount", default)]
639 pub total_token_count: u32,
640 #[serde(
642 rename = "promptTokensDetails",
643 skip_serializing_if = "Option::is_none"
644 )]
645 pub prompt_tokens_details: Option<Vec<GeminiTokenDetails>>,
646 #[serde(
648 rename = "candidatesTokensDetails",
649 skip_serializing_if = "Option::is_none"
650 )]
651 pub candidates_tokens_details: Option<Vec<GeminiTokenDetails>>,
652}
653
/// Envelope used to detect API-level errors in a response body.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct GeminiErrorResponse {
    /// The error object, if the body contains one. Because this is optional, a
    /// body without an `error` key also deserializes successfully (with
    /// `None`).
    pub error: Option<GeminiError>,
}
660
/// Error object returned by the API. All three fields are required for
/// deserialization to succeed.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct GeminiError {
    /// Numeric error code (the test in this file uses `400`).
    pub code: i32,
    /// Human-readable message; surfaced to callers through
    /// `Error::ProviderUnavailable`.
    pub message: String,
    /// Status string (the test in this file uses `"INVALID_ARGUMENT"`).
    pub status: String,
}
671
672impl From<&GeminiResponse> for Message {
674 fn from(response: &GeminiResponse) -> Self {
675 if response.candidates.is_empty() {
677 return Message::assistant("No response generated");
678 }
679
680 let candidate = &response.candidates[0];
682
683 let mut text_content_parts = Vec::new();
685 let mut tool_calls = Vec::new();
686 let mut tool_call_id_counter = 0;
687
688 for part in &candidate.content.parts {
690 if let Some(function_call) = &part.function_call {
692 tool_call_id_counter += 1;
693 let tool_id = format!("gemini_call_{}", tool_call_id_counter);
694
695 let args_str =
696 serde_json::to_string(&function_call.args).unwrap_or_else(|_| "{}".to_string());
697
698 let tool_call = crate::message::ToolCall {
699 id: tool_id,
700 tool_type: "function".to_string(),
701 function: crate::message::Function {
702 name: function_call.name.clone(),
703 arguments: args_str,
704 },
705 };
706
707 tool_calls.push(tool_call);
708 }
709
710 if let Some(text) = &part.text {
712 text_content_parts.push(ContentPart::text(text.clone()));
713 } else if let Some(inline_data) = &part.inline_data {
714 text_content_parts.push(ContentPart::text(format!(
716 "[Image: {} ({})]",
717 inline_data.data, inline_data.mime_type
718 )));
719 }
720 }
721
722 let content = if text_content_parts.len() == 1 {
724 match &text_content_parts[0] {
726 ContentPart::Text { text } => Some(Content::Text(text.clone())),
727 _ => Some(Content::Parts(text_content_parts)),
728 }
729 } else if !text_content_parts.is_empty() {
730 Some(Content::Parts(text_content_parts))
732 } else {
733 None
735 };
736
737 let mut msg = if !tool_calls.is_empty() {
739 Message::Assistant {
741 content,
742 tool_calls,
743 metadata: Default::default(),
744 }
745 } else if let Some(Content::Text(text)) = content {
746 Message::assistant(text)
748 } else {
749 Message::Assistant {
751 content,
752 tool_calls: Vec::new(),
753 metadata: Default::default(),
754 }
755 };
756
757 if let Some(usage) = &response.usage_metadata {
759 msg = msg.with_metadata(
760 "prompt_tokens",
761 serde_json::Value::Number(usage.prompt_token_count.into()),
762 );
763 msg = msg.with_metadata(
764 "completion_tokens",
765 serde_json::Value::Number(usage.candidates_token_count.into()),
766 );
767 msg = msg.with_metadata(
768 "total_tokens",
769 serde_json::Value::Number(usage.total_token_count.into()),
770 );
771 }
772
773 msg
774 }
775}
776
#[cfg(test)]
mod tests {
    use super::*;

    /// A part must serialize to exactly the one key that is populated.
    #[test]
    fn test_gemini_part_serialization() {
        let text_json =
            serde_json::to_string(&GeminiPart::text("Hello, world!".to_string())).unwrap();
        assert_eq!(text_json, r#"{"text":"Hello, world!"}"#);

        let inline_json = serde_json::to_string(&GeminiPart::inline_data(
            "base64data".to_string(),
            "image/jpeg".to_string(),
        ))
        .unwrap();
        assert_eq!(
            inline_json,
            r#"{"inline_data":{"data":"base64data","mime_type":"image/jpeg"}}"#
        );
    }

    /// An API error body must deserialize into a populated error object.
    #[test]
    fn test_error_response_parsing() {
        let payload = r#"{
            "error": {
                "code": 400,
                "message": "Invalid JSON payload received.",
                "status": "INVALID_ARGUMENT"
            }
        }"#;

        let parsed: GeminiErrorResponse = serde_json::from_str(payload).unwrap();
        let api_error = parsed.error.expect("error object should be present");
        assert_eq!(api_error.code, 400);
        assert_eq!(api_error.status, "INVALID_ARGUMENT");
    }
}