use std::time::Duration;
use reqwest::{Client, ClientBuilder, Response};
use serde_json::{Value, json};
use tokio::time::timeout;
use crate::core::providers::base::{
HeaderPair, apply_headers, header, header_owned, header_static,
};
use crate::core::providers::unified_provider::ProviderError;
use crate::core::types::{
chat::ChatMessage,
chat::ChatRequest,
content::ContentPart,
message::MessageContent,
message::MessageRole,
responses::{ChatChoice, ChatResponse, Usage},
tools::{FunctionCall, ToolCall},
};
use super::config::GeminiConfig;
use super::error::{
GeminiErrorMapper, gemini_multimodal_error, gemini_network_error, gemini_parse_error,
};
/// HTTP client for Google's Gemini API (Google AI Studio or Vertex AI).
///
/// Wraps a reusable `reqwest::Client` (built once in [`GeminiClient::new`]
/// with the configured timeouts/proxy) together with the provider
/// configuration used to build endpoints and headers.
#[derive(Debug, Clone)]
pub struct GeminiClient {
/// Endpoint, credential, timeout, proxy, safety and debug settings.
config: GeminiConfig,
/// Shared HTTP client; cloning the struct reuses its connection pool.
http_client: Client,
}
impl GeminiClient {
/// Builds a `GeminiClient` from the given configuration.
///
/// Applies the configured request/connect timeouts to the underlying
/// `reqwest` client and, when `config.proxy_url` is set, routes all
/// traffic through that proxy.
///
/// # Errors
/// Returns a network `ProviderError` when the proxy URL is invalid or
/// the HTTP client cannot be constructed.
pub fn new(config: GeminiConfig) -> Result<Self, ProviderError> {
    let request_timeout = Duration::from_secs(config.request_timeout);
    let connect_timeout = Duration::from_secs(config.connect_timeout);
    let builder = ClientBuilder::new()
        .timeout(request_timeout)
        .connect_timeout(connect_timeout);
    // Attach the proxy only when one is configured.
    let builder = match &config.proxy_url {
        Some(proxy_url) => {
            let proxy = reqwest::Proxy::all(proxy_url)
                .map_err(|e| gemini_network_error(format!("Invalid proxy URL: {}", e)))?;
            builder.proxy(proxy)
        }
        None => builder,
    };
    let http_client = builder
        .build()
        .map_err(|e| gemini_network_error(format!("Failed to create HTTP client: {}", e)))?;
    Ok(Self { config, http_client })
}
/// Performs a non-streaming chat completion round trip:
/// request translation → `generateContent` call → response translation.
///
/// # Errors
/// Propagates transformation, network, and API errors as `ProviderError`.
pub async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
    let body = self.transform_chat_request(&request)?;
    let raw = self
        .send_request(&request.model, "generateContent", body)
        .await?;
    self.transform_chat_response(raw, &request)
}
/// Starts a streaming chat completion via `streamGenerateContent` and
/// returns the raw HTTP response so the caller can consume the body
/// incrementally.
///
/// # Errors
/// Propagates transformation and network errors as `ProviderError`.
pub async fn chat_stream(
    &self,
    request: ChatRequest,
) -> Result<reqwest::Response, ProviderError> {
    let body = self.transform_chat_request(&request)?;
    self.send_stream_request(&request.model, "streamGenerateContent", body)
        .await
}
/// POSTs `body` to the model/operation endpoint and returns the parsed
/// JSON payload via `handle_response`.
///
/// Wraps the send in an explicit `tokio` timeout in addition to the
/// client-level timeout configured in `new`.
async fn send_request(
    &self,
    model: &str,
    operation: &str,
    body: Value,
) -> Result<Value, ProviderError> {
    let url = self.config.get_endpoint(model, operation);
    if self.config.debug {
        tracing::debug!("Gemini request URL: {}", url);
        tracing::debug!(
            "Gemini request body: {}",
            serde_json::to_string_pretty(&body).unwrap_or_default()
        );
    }
    let pending = apply_headers(
        self.http_client.post(&url).json(&body),
        self.get_request_headers(),
    )
    .send();
    let response = timeout(Duration::from_secs(self.config.request_timeout), pending)
        .await
        .map_err(|_| gemini_network_error("Request timeout"))?
        .map_err(|e| gemini_network_error(format!("Network error: {}", e)))?;
    self.handle_response(response).await
}
async fn send_stream_request(
&self,
model: &str,
operation: &str,
body: Value,
) -> Result<Response, ProviderError> {
let url = self.config.get_endpoint(model, operation);
let headers = self.get_request_headers();
if self.config.debug {
tracing::debug!("Gemini stream request URL: {}", url);
tracing::debug!(
"Gemini stream request body: {}",
serde_json::to_string_pretty(&body).unwrap_or_default()
);
}
let response = timeout(
Duration::from_secs(self.config.request_timeout),
apply_headers(self.http_client.post(&url).json(&body), headers).send(),
)
.await
.map_err(|_| gemini_network_error("Request timeout"))?
.map_err(|e| gemini_network_error(format!("Network error: {}", e)))?;
let status = response.status();
if !status.is_success() {
let error_text = response.text().await.map_err(|e| {
gemini_network_error(format!("Failed to read error response: {}", e))
})?;
return Err(GeminiErrorMapper::from_http_status(
status.as_u16(),
&error_text,
));
}
Ok(response)
}
/// Assembles the per-request header list: JSON content type, an
/// optional Vertex AI bearer token, and any user-configured custom
/// headers (appended last, in config order).
fn get_request_headers(&self) -> Vec<HeaderPair> {
    let mut headers = vec![header_static("Content-Type", "application/json")];
    // Only the Vertex AI path authenticates via an Authorization header;
    // presumably the Google AI endpoint carries the key elsewhere
    // (e.g. in the URL built by `get_endpoint`) — confirm in config.rs.
    if self.config.use_vertex_ai {
        if let Some(api_key) = &self.config.api_key {
            headers.push(header("Authorization", format!("Bearer {}", api_key)));
        }
    }
    headers.extend(
        self.config
            .custom_headers
            .iter()
            .map(|(key, value)| header_owned(key.clone(), value.clone())),
    );
    headers
}
/// Reads the full response body, logs it in debug mode, and converts it
/// into parsed JSON.
///
/// # Errors
/// Maps read failures to network errors, non-2xx statuses and body-level
/// `error` objects through `GeminiErrorMapper`, and malformed JSON to a
/// parse error.
async fn handle_response(&self, response: Response) -> Result<Value, ProviderError> {
    let status = response.status();
    let response_text = response
        .text()
        .await
        .map_err(|e| gemini_network_error(format!("Failed to read response: {}", e)))?;
    if self.config.debug {
        tracing::debug!("Gemini response status: {}", status);
        tracing::debug!("Gemini response body: {}", response_text);
    }
    if !status.is_success() {
        return Err(GeminiErrorMapper::from_http_status(
            status.as_u16(),
            &response_text,
        ));
    }
    let json_response = serde_json::from_str::<Value>(&response_text)
        .map_err(|e| gemini_parse_error(format!("Failed to parse response JSON: {}", e)))?;
    // Some failures come back as HTTP 200 with an `error` object in the body.
    match json_response.get("error") {
        Some(_) => Err(GeminiErrorMapper::from_api_response(&json_response)),
        None => Ok(json_response),
    }
}
/// Converts an OpenAI-style `ChatRequest` into a Gemini
/// `generateContent` JSON body.
///
/// System *and* Developer messages are folded into Gemini's
/// `systemInstruction` block (previously Developer messages were
/// silently dropped — their content was even transformed and then
/// discarded). Remaining messages go into `contents` with roles mapped
/// user→"user", assistant→"model", tool/function→"function". Sampling
/// parameters map to `generationConfig` (only attached when non-empty),
/// and configured safety settings are forwarded as `safetySettings`.
///
/// # Errors
/// Propagates content-transformation errors (e.g. unsupported
/// multimodal parts).
pub fn transform_chat_request(&self, request: &ChatRequest) -> Result<Value, ProviderError> {
    let mut contents = Vec::new();
    let mut system_parts: Vec<Value> = Vec::new();
    for message in &request.messages {
        // Resolve the Gemini role first so no content work is done for
        // messages that are routed to systemInstruction.
        let role = match message.role {
            MessageRole::System | MessageRole::Developer => {
                if let Some(text) = message.content.as_ref() {
                    system_parts.push(json!({"text": text.to_string()}));
                }
                continue;
            }
            MessageRole::User => "user",
            MessageRole::Assistant => "model",
            MessageRole::Tool | MessageRole::Function => "function",
        };
        let parts = self.transform_message_content(message)?;
        contents.push(json!({
            "role": role,
            "parts": parts
        }));
    }
    let mut gemini_request = json!({
        "contents": contents
    });
    if !system_parts.is_empty() {
        gemini_request["systemInstruction"] = json!({"parts": system_parts});
    }
    let mut generation_config = json!({});
    if let Some(max_tokens) = request.max_tokens {
        generation_config["maxOutputTokens"] = json!(max_tokens);
    }
    if let Some(temperature) = request.temperature {
        generation_config["temperature"] = json!(temperature);
    }
    if let Some(top_p) = request.top_p {
        generation_config["topP"] = json!(top_p);
    }
    if let Some(stop) = &request.stop {
        // Serialize by reference; no need to clone the sequences.
        if !stop.is_empty() {
            generation_config["stopSequences"] = json!(stop);
        }
    }
    // Only attach generationConfig when at least one knob was set.
    if generation_config
        .as_object()
        .is_some_and(|obj| !obj.is_empty())
    {
        gemini_request["generationConfig"] = generation_config;
    }
    if let Some(safety_settings) = &self.config.safety_settings {
        let gemini_safety: Vec<Value> = safety_settings
            .iter()
            .map(|setting| {
                json!({
                    "category": setting.category,
                    "threshold": setting.threshold
                })
            })
            .collect();
        gemini_request["safetySettings"] = json!(gemini_safety);
    }
    Ok(gemini_request)
}
/// Converts a single chat message's content into Gemini `parts`.
///
/// Text becomes `{"text": ...}` parts; inline images (base64 data URLs
/// or explicit `Image` sources) become `inlineData` parts. External
/// image URLs, audio, documents, and tool-related parts are rejected
/// with a multimodal error. A message with no renderable content yields
/// a single empty-text part so the API always receives at least one.
///
/// # Errors
/// Returns a multimodal error for unsupported part kinds and a parse
/// error for malformed data URLs.
fn transform_message_content(
    &self,
    message: &ChatMessage,
) -> Result<Vec<Value>, ProviderError> {
    let mut parts = Vec::new();
    match &message.content {
        Some(MessageContent::Text(text)) => {
            parts.push(json!({
                "text": text
            }));
        }
        Some(MessageContent::Parts(content_parts)) => {
            for part in content_parts {
                match part {
                    ContentPart::Text { text } => {
                        parts.push(json!({
                            "text": text
                        }));
                    }
                    ContentPart::ImageUrl { image_url } => {
                        if image_url.url.starts_with("data:") {
                            // parse_data_url returns Ok(None) only for
                            // non-data URLs, so this branch always pushes.
                            if let Some((mime_type, data)) =
                                self.parse_data_url(&image_url.url)?
                            {
                                parts.push(json!({
                                    "inlineData": {
                                        "mimeType": mime_type,
                                        "data": data
                                    }
                                }));
                            }
                        } else {
                            return Err(gemini_multimodal_error(
                                "External image URLs not supported directly. Please convert to base64 data URL",
                            ));
                        }
                    }
                    ContentPart::Audio { .. } => {
                        return Err(gemini_multimodal_error(
                            "Audio content not yet implemented",
                        ));
                    }
                    ContentPart::Image { source, .. } => {
                        parts.push(json!({
                            "inlineData": {
                                "mimeType": source.media_type,
                                "data": source.data
                            }
                        }));
                    }
                    ContentPart::Document { .. } => {
                        return Err(gemini_multimodal_error(
                            "Document content not yet supported in Gemini",
                        ));
                    }
                    ContentPart::ToolResult { .. } => {
                        return Err(gemini_multimodal_error(
                            "Tool result content should be handled separately",
                        ));
                    }
                    ContentPart::ToolUse { .. } => {
                        return Err(gemini_multimodal_error(
                            "Tool use content should be handled separately",
                        ));
                    }
                }
            }
        }
        // No content: fall through to the empty-text fallback below.
        // (The previous code re-matched `message.content` inside this
        // arm, which can never be `Some` here — dead code, removed.)
        None => {}
    }
    if parts.is_empty() {
        parts.push(json!({
            "text": ""
        }));
    }
    Ok(parts)
}
/// Splits a `data:` URL into `(mime_type, payload)`.
///
/// Returns `Ok(None)` for non-data URLs. The payload is returned
/// verbatim (not decoded); the mime type is taken from the portion of
/// the header before the first `;` with the `data:` prefix stripped.
///
/// # Errors
/// Returns a parse error when the URL lacks the `,` separating header
/// from payload.
fn parse_data_url(&self, data_url: &str) -> Result<Option<(String, String)>, ProviderError> {
    if !data_url.starts_with("data:") {
        return Ok(None);
    }
    let Some((header, data)) = data_url.split_once(',') else {
        return Err(gemini_parse_error("Invalid data URL format"));
    };
    // Header looks like `data:image/png;base64` — keep only the mime part.
    let mime_type = header
        .split(';')
        .next()
        .unwrap_or("")
        .strip_prefix("data:")
        .unwrap_or("application/octet-stream");
    Ok(Some((mime_type.to_string(), data.to_string())))
}
pub fn transform_chat_response(
&self,
response: Value,
request: &ChatRequest,
) -> Result<ChatResponse, ProviderError> {
let candidates = response
.get("candidates")
.and_then(|c| c.as_array())
.ok_or_else(|| gemini_parse_error("No candidates in response"))?;
let mut choices = Vec::new();
for (index, candidate) in candidates.iter().enumerate() {
let content = candidate
.get("content")
.and_then(|c| c.get("parts"))
.and_then(|p| p.as_array())
.ok_or_else(|| gemini_parse_error("Invalid candidate content structure"))?;
let mut text_parts = Vec::new();
let mut tool_calls: Vec<ToolCall> = Vec::new();
for (part_index, part) in content.iter().enumerate() {
if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
text_parts.push(text);
}
if let Some(fc) = part.get("functionCall") {
let name = fc
.get("name")
.and_then(|n| n.as_str())
.unwrap_or("")
.to_string();
let args = fc
.get("args")
.map(|a| a.to_string())
.unwrap_or_else(|| "{}".to_string());
tool_calls.push(ToolCall {
id: format!("call_{}_{}", index, part_index),
tool_type: "function".to_string(),
function: FunctionCall {
name,
arguments: args,
},
});
}
}
let message_content = text_parts.join("");
let finish_reason = candidate
.get("finishReason")
.and_then(|r| r.as_str())
.map(|r| match r {
"STOP" => "stop",
"MAX_TOKENS" => "length",
"SAFETY" => "content_filter",
"RECITATION" => "content_filter",
_ => "stop",
})
.unwrap_or("stop");
let msg_content = if message_content.is_empty() && !tool_calls.is_empty() {
None
} else {
Some(MessageContent::Text(message_content))
};
choices.push(ChatChoice {
index: index as u32,
message: crate::core::types::chat::ChatMessage {
role: MessageRole::Assistant,
content: msg_content,
thinking: None,
name: None,
tool_calls: if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
},
tool_call_id: None,
function_call: None,
},
finish_reason: Some(match finish_reason {
"stop" => crate::core::types::responses::FinishReason::Stop,
"length" => crate::core::types::responses::FinishReason::Length,
"content_filter" => crate::core::types::responses::FinishReason::ContentFilter,
_ => crate::core::types::responses::FinishReason::Stop,
}),
logprobs: None,
});
}
let usage = response.get("usageMetadata").map(|usage_metadata| Usage {
prompt_tokens: usage_metadata
.get("promptTokenCount")
.and_then(|v| v.as_u64())
.unwrap_or(0) as u32,
completion_tokens: usage_metadata
.get("candidatesTokenCount")
.and_then(|v| v.as_u64())
.unwrap_or(0) as u32,
total_tokens: usage_metadata
.get("totalTokenCount")
.and_then(|v| v.as_u64())
.unwrap_or(0) as u32,
prompt_tokens_details: None,
completion_tokens_details: None,
thinking_usage: None,
});
let now = std::time::SystemTime::now();
let nanos = now
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_nanos())
.unwrap_or(0);
let secs = now
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs() as i64)
.unwrap_or(0);
Ok(ChatResponse {
id: format!("gemini-{}", nanos),
object: "chat.completion".to_string(),
created: secs,
model: request.model.clone(),
choices,
usage,
system_fingerprint: None,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
// Smoke test: a client builds cleanly from a minimal Google AI config.
#[test]
fn test_client_creation() {
let config = GeminiConfig::new_google_ai("test-key");
let client = GeminiClient::new(config);
assert!(client.is_ok());
}
// A well-formed base64 data URL yields its mime type and payload.
#[test]
fn test_data_url_parsing() {
let config = GeminiConfig::new_google_ai("test-key");
let client = GeminiClient::new(config).unwrap();
let data_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==";
let result = client.parse_data_url(data_url).unwrap();
assert!(result.is_some());
let (mime_type, _data) = result.unwrap();
assert_eq!(mime_type, "image/png");
}
// Plain text content becomes a single Gemini `{"text": ...}` part.
#[test]
fn test_message_transformation() {
let config = GeminiConfig::new_google_ai("test-key");
let client = GeminiClient::new(config).unwrap();
let message = ChatMessage {
role: MessageRole::User,
content: Some(MessageContent::Text("Hello, world!".to_string())),
thinking: None,
name: None,
tool_calls: None,
tool_call_id: None,
function_call: None,
};
let parts = client.transform_message_content(&message).unwrap();
assert_eq!(parts.len(), 1);
assert_eq!(parts[0]["text"], "Hello, world!");
}
// Mixed text + inline image content maps to a text part followed by
// an `inlineData` part carrying the image source.
#[test]
fn test_multimodal_message() {
let config = GeminiConfig::new_google_ai("test-key");
let client = GeminiClient::new(config).unwrap();
let message = ChatMessage {
role: MessageRole::User,
content: Some(MessageContent::Parts(vec![
ContentPart::Text {
text: "What's in this image?".to_string(),
},
ContentPart::Image {
source: crate::core::types::content::ImageSource {
data: "test".to_string(),
media_type: "image/png".to_string(),
},
image_url: None,
detail: None,
},
])),
thinking: None,
name: None,
tool_calls: None,
tool_call_id: None,
function_call: None,
};
let parts = client.transform_message_content(&message).unwrap();
assert_eq!(parts.len(), 2);
assert_eq!(parts[0]["text"], "What's in this image?");
assert!(parts[1].get("inlineData").is_some());
}
}