1use crate::error::{Error, Result};
2use crate::message::{Content, ContentPart, Message};
3use crate::provider::HTTPProvider;
4use crate::{Chat, LlmToolInfo, ModelInfo};
5use reqwest::{Method, Request, Url};
6use serde::{Deserialize, Serialize};
7use std::env;
8use tracing::{debug, error, info, instrument, trace, warn};
9
/// Connection settings for the Gemini provider.
///
/// `Debug` is implemented by hand instead of derived so that the API key is
/// never written out when the configuration (or a type embedding it) is
/// formatted with `{:?}`, e.g. by logging.
#[derive(Clone)]
pub struct GeminiConfig {
    /// API key sent with every request.
    pub api_key: String,
    /// Base URL that request paths are appended to.
    pub base_url: String,
}

impl std::fmt::Debug for GeminiConfig {
    /// Formats the configuration with the API key replaced by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a key is configured, never its value.
        let api_key = if self.api_key.is_empty() {
            "<unset>"
        } else {
            "<redacted>"
        };

        f.debug_struct("GeminiConfig")
            .field("api_key", &api_key)
            .field("base_url", &self.base_url)
            .finish()
    }
}
18
19impl Default for GeminiConfig {
20 fn default() -> Self {
21 Self {
22 api_key: env::var("GEMINI_API_KEY").unwrap_or_default(),
23 base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
24 }
25 }
26}
27
28#[derive(Debug, Clone)]
30pub struct GeminiProvider {
31 config: GeminiConfig,
33}
34
35impl GeminiProvider {
36 #[instrument(level = "debug")]
48 pub fn new() -> Self {
49 info!("Creating new GeminiProvider with default configuration");
50 let config = GeminiConfig::default();
51 debug!("API key set: {}", !config.api_key.is_empty());
52 debug!("Base URL: {}", config.base_url);
53
54 Self { config }
55 }
56
57 #[instrument(skip(config), level = "debug")]
72 pub fn with_config(config: GeminiConfig) -> Self {
73 info!("Creating new GeminiProvider with custom configuration");
74 debug!("API key set: {}", !config.api_key.is_empty());
75 debug!("Base URL: {}", config.base_url);
76
77 Self { config }
78 }
79}
80
81impl Default for GeminiProvider {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87impl<M: ModelInfo + GeminiModelInfo> HTTPProvider<M> for GeminiProvider {
88 fn accept(&self, chat: Chat<M>) -> Result<Request> {
89 info!("Creating request for Gemini model: {:?}", chat.model);
90 debug!("Messages in chat history: {}", chat.history.len());
91
92 let model_id = chat.model.gemini_model_id();
93 let url_str = format!(
94 "{}/models/{}:generateContent?key={}",
95 self.config.base_url, model_id, self.config.api_key
96 );
97
98 debug!("Parsing URL: {}", url_str);
99 let url = match Url::parse(&url_str) {
100 Ok(url) => {
101 debug!("URL parsed successfully: {}", url);
102 url
103 }
104 Err(e) => {
105 error!("Failed to parse URL '{}': {}", url_str, e);
106 return Err(e.into());
107 }
108 };
109
110 let mut request = Request::new(Method::POST, url);
111 debug!("Created request: {} {}", request.method(), request.url());
112
113 debug!("Setting request headers");
115 let content_type_header = match "application/json".parse() {
116 Ok(header) => header,
117 Err(e) => {
118 error!("Failed to set content type: {}", e);
119 return Err(Error::Other("Failed to set content type".into()));
120 }
121 };
122
123 request
124 .headers_mut()
125 .insert("Content-Type", content_type_header);
126
127 trace!("Request headers set: {:#?}", request.headers());
128
129 debug!("Creating request payload");
131 let payload = match self.create_request_payload(&chat) {
132 Ok(payload) => {
133 debug!("Request payload created successfully");
134 trace!("Number of contents: {}", payload.contents.len());
135 trace!(
136 "System instruction present: {}",
137 payload.system_instruction.is_some()
138 );
139 trace!(
140 "Generation config present: {}",
141 payload.generation_config.is_some()
142 );
143 payload
144 }
145 Err(e) => {
146 error!("Failed to create request payload: {}", e);
147 return Err(e);
148 }
149 };
150
151 debug!("Serializing request payload");
153 let body_bytes = match serde_json::to_vec(&payload) {
154 Ok(bytes) => {
155 debug!("Payload serialized successfully ({} bytes)", bytes.len());
156 bytes
157 }
158 Err(e) => {
159 error!("Failed to serialize payload: {}", e);
160 return Err(Error::Serialization(e));
161 }
162 };
163
164 *request.body_mut() = Some(body_bytes.into());
165 info!("Request created successfully");
166
167 Ok(request)
168 }
169
170 fn parse(&self, raw_response_text: String) -> Result<Message> {
171 info!("Parsing response from Gemini API");
172 trace!("Raw response: {}", raw_response_text);
173
174 if let Ok(error_response) = serde_json::from_str::<GeminiErrorResponse>(&raw_response_text)
176 {
177 if let Some(error) = error_response.error {
178 error!("Gemini API returned an error: {}", error.message);
179 return Err(Error::ProviderUnavailable(error.message));
180 }
181 }
182
183 debug!("Deserializing response JSON");
185 let gemini_response = match serde_json::from_str::<GeminiResponse>(&raw_response_text) {
186 Ok(response) => {
187 debug!("Response deserialized successfully");
188 if !response.candidates.is_empty() {
189 debug!(
190 "Content parts: {}",
191 response.candidates[0].content.parts.len()
192 );
193 }
194 response
195 }
196 Err(e) => {
197 error!("Failed to deserialize response: {}", e);
198 error!("Raw response: {}", raw_response_text);
199 return Err(Error::Serialization(e));
200 }
201 };
202
203 debug!("Converting Gemini response to Message");
205 let message = Message::from(&gemini_response);
206
207 info!("Response parsed successfully");
208 trace!("Response message processed");
209
210 Ok(message)
211 }
212}
213
/// Implemented by model types that can name the Gemini model they refer to.
pub trait GeminiModelInfo {
    /// Returns the model identifier that is placed in the request URL
    /// (`{base_url}/models/{id}:generateContent`).
    fn gemini_model_id(&self) -> String;
}
218
219impl GeminiProvider {
220 #[instrument(skip(self, chat), level = "debug")]
225 fn create_request_payload<M: ModelInfo + GeminiModelInfo>(
226 &self,
227 chat: &Chat<M>,
228 ) -> Result<GeminiRequest> {
229 info!("Creating request payload for chat with Gemini model");
230 debug!("System prompt length: {}", chat.system_prompt.len());
231 debug!("Messages in history: {}", chat.history.len());
232 debug!("Max output tokens: {}", chat.max_output_tokens);
233
234 let system_instruction = if !chat.system_prompt.is_empty() {
236 debug!("Including system prompt in request");
237 trace!("System prompt: {}", chat.system_prompt);
238 Some(GeminiContent {
239 parts: vec![GeminiPart::text(chat.system_prompt.clone())],
240 role: None,
241 })
242 } else {
243 debug!("No system prompt provided");
244 None
245 };
246
247 debug!("Converting messages to Gemini format");
249 let mut contents: Vec<GeminiContent> = Vec::new();
250 let mut current_role_str: Option<&'static str> = None;
251 let mut current_parts: Vec<GeminiPart> = Vec::new();
252
253 for msg in &chat.history {
254 let msg_role_str = msg.role_str();
256
257 if current_role_str.is_some()
259 && current_role_str != Some(msg_role_str)
260 && !current_parts.is_empty()
261 {
262 let role = match current_role_str {
263 Some("user") => Some("user".to_string()),
264 Some("assistant") => Some("model".to_string()),
265 _ => None,
266 };
267
268 contents.push(GeminiContent {
269 parts: std::mem::take(&mut current_parts),
270 role,
271 });
272 }
273
274 current_role_str = Some(msg_role_str);
275
276 match msg {
278 Message::System { content, .. } => {
279 current_parts.push(GeminiPart::text(content.clone()));
280 }
281 Message::User { content, .. } => match content {
282 Content::Text(text) => {
283 current_parts.push(GeminiPart::text(text.clone()));
284 }
285 Content::Parts(parts) => {
286 for part in parts {
287 match part {
288 ContentPart::Text { text } => {
289 current_parts.push(GeminiPart::text(text.clone()));
290 }
291 ContentPart::ImageUrl { image_url } => {
292 current_parts.push(GeminiPart::inline_data(
293 image_url.url.clone(),
294 "image/jpeg".to_string(),
295 ));
296 }
297 }
298 }
299 }
300 },
301 Message::Assistant { content, .. } => {
302 if let Some(content_data) = content {
303 match content_data {
304 Content::Text(text) => {
305 current_parts.push(GeminiPart::text(text.clone()));
306 }
307 Content::Parts(parts) => {
308 for part in parts {
309 match part {
310 ContentPart::Text { text } => {
311 current_parts.push(GeminiPart::text(text.clone()));
312 }
313 ContentPart::ImageUrl { image_url } => {
314 current_parts.push(GeminiPart::inline_data(
315 image_url.url.clone(),
316 "image/jpeg".to_string(),
317 ));
318 }
319 }
320 }
321 }
322 }
323 }
324 }
325 Message::Tool {
326 tool_call_id,
327 content,
328 ..
329 } => {
330 current_parts.push(GeminiPart::text(format!(
332 "Tool result for call {}: {}",
333 tool_call_id, content
334 )));
335 }
336 }
337 }
338
339 if !current_parts.is_empty() {
341 let role = match current_role_str {
342 Some("user") => Some("user".to_string()),
343 Some("assistant") => Some("model".to_string()),
344 _ => None,
345 };
346
347 contents.push(GeminiContent {
348 parts: current_parts,
349 role,
350 });
351 }
352
353 debug!("Converted {} contents for the request", contents.len());
354
355 let generation_config = Some(GeminiGenerationConfig {
357 max_output_tokens: Some(chat.max_output_tokens),
358 temperature: None,
359 top_p: None,
360 top_k: None,
361 stop_sequences: None,
362 });
363
364 let tools = chat.tools.as_ref().map(|tools| {
366 vec![GeminiTool {
367 function_declarations: tools
368 .iter()
369 .map(GeminiFunctionDeclaration::from)
370 .collect(),
371 }]
372 });
373
374 debug!("Creating GeminiRequest");
376 let request = GeminiRequest {
377 contents,
378 system_instruction,
379 generation_config,
380 tools,
381 };
382
383 info!("Request payload created successfully");
384 Ok(request)
385 }
386}
387
388#[derive(Debug, Clone, Serialize, Deserialize)]
390pub(crate) struct GeminiPart {
391 #[serde(skip_serializing_if = "Option::is_none")]
393 pub text: Option<String>,
394
395 #[serde(skip_serializing_if = "Option::is_none")]
397 pub inline_data: Option<GeminiInlineData>,
398
399 #[serde(skip_serializing_if = "Option::is_none", rename = "functionCall")]
401 pub function_call: Option<GeminiFunctionCall>,
402}
403
404#[derive(Debug, Clone, Serialize, Deserialize)]
406pub(crate) struct GeminiFunctionCall {
407 pub name: String,
409 pub args: serde_json::Value,
411}
412
413impl GeminiPart {
414 fn text(text: String) -> Self {
416 GeminiPart {
417 text: Some(text),
418 inline_data: None,
419 function_call: None,
420 }
421 }
422
423 fn inline_data(data: String, mime_type: String) -> Self {
425 GeminiPart {
426 text: None,
427 inline_data: Some(GeminiInlineData { data, mime_type }),
428 function_call: None,
429 }
430 }
431}
432
433#[derive(Debug, Clone, Serialize, Deserialize)]
435pub(crate) struct GeminiInlineData {
436 pub data: String,
438 pub mime_type: String,
440}
441
442#[derive(Debug, Clone, Serialize, Deserialize)]
444pub(crate) struct GeminiContent {
445 pub parts: Vec<GeminiPart>,
447 #[serde(skip_serializing_if = "Option::is_none")]
449 pub role: Option<String>,
450}
451
452#[derive(Debug, Clone, Serialize, Deserialize)]
454pub(crate) struct GeminiGenerationConfig {
455 #[serde(skip_serializing_if = "Option::is_none")]
457 pub max_output_tokens: Option<usize>,
458 #[serde(skip_serializing_if = "Option::is_none")]
460 pub temperature: Option<f32>,
461 #[serde(skip_serializing_if = "Option::is_none")]
463 pub top_p: Option<f32>,
464 #[serde(skip_serializing_if = "Option::is_none")]
466 pub top_k: Option<u32>,
467 #[serde(skip_serializing_if = "Option::is_none")]
469 pub stop_sequences: Option<Vec<String>>,
470}
471
472#[derive(Debug, Clone, Serialize, Deserialize)]
474pub(crate) struct GeminiFunctionDeclaration {
475 pub name: String,
477 pub description: String,
479 pub parameters: serde_json::Value,
481}
482
483impl From<&LlmToolInfo> for GeminiFunctionDeclaration {
484 fn from(value: &LlmToolInfo) -> Self {
485 GeminiFunctionDeclaration {
486 name: value.name.clone(),
487 description: value.description.clone(),
488 parameters: value.parameters.clone(),
489 }
490 }
491}
492
493#[derive(Debug, Clone, Serialize, Deserialize)]
495pub(crate) struct GeminiFunction {
496 pub name: String,
498 pub description: String,
500 pub parameters: serde_json::Value,
502}
503
504#[derive(Debug, Clone, Serialize, Deserialize)]
506pub(crate) struct GeminiTool {
507 #[serde(rename = "functionDeclarations")]
509 pub function_declarations: Vec<GeminiFunctionDeclaration>,
510}
511
512#[derive(Debug, Serialize, Deserialize)]
514pub(crate) struct GeminiRequest {
515 pub contents: Vec<GeminiContent>,
517 #[serde(skip_serializing_if = "Option::is_none")]
519 pub system_instruction: Option<GeminiContent>,
520 #[serde(skip_serializing_if = "Option::is_none")]
522 pub generation_config: Option<GeminiGenerationConfig>,
523 #[serde(skip_serializing_if = "Option::is_none")]
525 pub tools: Option<Vec<GeminiTool>>,
526}
527
528#[derive(Debug, Serialize, Deserialize)]
530pub(crate) struct GeminiResponse {
531 pub candidates: Vec<GeminiCandidate>,
533 #[serde(rename = "usageMetadata", skip_serializing_if = "Option::is_none")]
535 pub usage_metadata: Option<GeminiUsageMetadata>,
536 #[serde(rename = "modelVersion", skip_serializing_if = "Option::is_none")]
538 pub model_version: Option<String>,
539}
540
541#[derive(Debug, Serialize, Deserialize)]
543pub(crate) struct GeminiCandidate {
544 pub content: GeminiContent,
546 #[serde(skip_serializing_if = "Option::is_none", rename = "finishReason")]
548 pub finish_reason: Option<String>,
549 #[serde(skip_serializing_if = "Option::is_none")]
551 pub index: Option<i32>,
552 #[serde(skip_serializing_if = "Option::is_none", rename = "avgLogprobs")]
554 pub avg_logprobs: Option<f64>,
555}
556
557#[derive(Debug, Serialize, Deserialize)]
559pub(crate) struct GeminiTokenDetails {
560 pub modality: String,
562 #[serde(rename = "tokenCount")]
564 pub token_count: u32,
565}
566
567#[derive(Debug, Serialize, Deserialize)]
569pub(crate) struct GeminiUsageMetadata {
570 #[serde(rename = "promptTokenCount")]
572 pub prompt_token_count: u32,
573 #[serde(rename = "candidatesTokenCount", default)]
575 pub candidates_token_count: u32,
576 #[serde(rename = "totalTokenCount", default)]
578 pub total_token_count: u32,
579 #[serde(
581 rename = "promptTokensDetails",
582 skip_serializing_if = "Option::is_none"
583 )]
584 pub prompt_tokens_details: Option<Vec<GeminiTokenDetails>>,
585 #[serde(
587 rename = "candidatesTokensDetails",
588 skip_serializing_if = "Option::is_none"
589 )]
590 pub candidates_tokens_details: Option<Vec<GeminiTokenDetails>>,
591}
592
593#[derive(Debug, Serialize, Deserialize)]
595pub(crate) struct GeminiErrorResponse {
596 pub error: Option<GeminiError>,
598}
599
600#[derive(Debug, Serialize, Deserialize)]
602pub(crate) struct GeminiError {
603 pub code: i32,
605 pub message: String,
607 pub status: String,
609}
610
611impl From<&GeminiResponse> for Message {
613 fn from(response: &GeminiResponse) -> Self {
614 if response.candidates.is_empty() {
616 return Message::assistant("No response generated");
617 }
618
619 let candidate = &response.candidates[0];
621
622 let mut text_content_parts = Vec::new();
624 let mut tool_calls = Vec::new();
625 let mut tool_call_id_counter = 0;
626
627 for part in &candidate.content.parts {
629 if let Some(function_call) = &part.function_call {
631 tool_call_id_counter += 1;
632 let tool_id = format!("gemini_call_{}", tool_call_id_counter);
633
634 let args_str =
635 serde_json::to_string(&function_call.args).unwrap_or_else(|_| "{}".to_string());
636
637 let tool_call = crate::message::ToolCall {
638 id: tool_id,
639 tool_type: "function".to_string(),
640 function: crate::message::Function {
641 name: function_call.name.clone(),
642 arguments: args_str,
643 },
644 };
645
646 tool_calls.push(tool_call);
647 }
648
649 if let Some(text) = &part.text {
651 text_content_parts.push(ContentPart::text(text.clone()));
652 } else if let Some(inline_data) = &part.inline_data {
653 text_content_parts.push(ContentPart::text(format!(
655 "[Image: {} ({})]",
656 inline_data.data, inline_data.mime_type
657 )));
658 }
659 }
660
661 let content = if text_content_parts.len() == 1 {
663 match &text_content_parts[0] {
665 ContentPart::Text { text } => Some(Content::Text(text.clone())),
666 _ => Some(Content::Parts(text_content_parts)),
667 }
668 } else if !text_content_parts.is_empty() {
669 Some(Content::Parts(text_content_parts))
671 } else {
672 None
674 };
675
676 let mut msg = if !tool_calls.is_empty() {
678 Message::Assistant {
680 content,
681 tool_calls,
682 metadata: Default::default(),
683 }
684 } else if let Some(Content::Text(text)) = content {
685 Message::assistant(text)
687 } else {
688 Message::Assistant {
690 content,
691 tool_calls: Vec::new(),
692 metadata: Default::default(),
693 }
694 };
695
696 if let Some(usage) = &response.usage_metadata {
698 msg = msg.with_metadata(
699 "prompt_tokens",
700 serde_json::Value::Number(usage.prompt_token_count.into()),
701 );
702 msg = msg.with_metadata(
703 "completion_tokens",
704 serde_json::Value::Number(usage.candidates_token_count.into()),
705 );
706 msg = msg.with_metadata(
707 "total_tokens",
708 serde_json::Value::Number(usage.total_token_count.into()),
709 );
710 }
711
712 msg
713 }
714}
715
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes a part to its JSON text.
    fn to_json(part: &GeminiPart) -> String {
        serde_json::to_string(part).unwrap()
    }

    /// Text and inline-data parts serialize with only their own field present.
    #[test]
    fn test_gemini_part_serialization() {
        let text_json = to_json(&GeminiPart::text("Hello, world!".to_string()));
        assert_eq!(text_json, r#"{"text":"Hello, world!"}"#);

        let image_json = to_json(&GeminiPart::inline_data(
            "base64data".to_string(),
            "image/jpeg".to_string(),
        ));
        assert_eq!(
            image_json,
            r#"{"inline_data":{"data":"base64data","mime_type":"image/jpeg"}}"#
        );
    }

    /// An API error body deserializes into a populated error object.
    #[test]
    fn test_error_response_parsing() {
        let error_json = r#"{
            "error": {
                "code": 400,
                "message": "Invalid JSON payload received.",
                "status": "INVALID_ARGUMENT"
            }
        }"#;

        let parsed: GeminiErrorResponse = serde_json::from_str(error_json).unwrap();
        let error = parsed.error.expect("error object should be present");
        assert_eq!(error.code, 400);
        assert_eq!(error.status, "INVALID_ARGUMENT");
    }
}