use super::types::{
GeminiContent, GeminiFunctionCall, GeminiFunctionDeclaration, GeminiFunctionResponse,
GeminiGenerationConfig, GeminiInlineData, GeminiPart, GeminiRequest, GeminiResponse,
GeminiSystemInstruction, GeminiThinkingConfig, GeminiTool,
};
use crate::error::{Error, Result};
use crate::types::{
ContentPart, FinishReason, FinishReasonKind, GenerateRequest, GenerateResponse,
InputTokenDetails, Message, OutputTokenDetails, ProviderOptions, ResponseContent, Role, Usage,
};
use serde_json::json;
use std::collections::HashMap;
pub fn to_gemini_request(req: &GenerateRequest) -> Result<GeminiRequest> {
use serde_json::json;
let (system_instruction, contents) = convert_messages(&req.messages)?;
let google_opts = if let Some(ProviderOptions::Google(opts)) = &req.provider_options {
Some(opts)
} else {
None
};
let generation_config = Some(GeminiGenerationConfig {
temperature: req.options.temperature,
top_p: req.options.top_p,
top_k: None, max_output_tokens: req.options.max_tokens,
stop_sequences: req.options.stop_sequences.clone(),
response_mime_type: None,
candidate_count: None,
seed: None,
presence_penalty: None,
frequency_penalty: None,
response_logprobs: None,
logprobs: None,
enable_enhanced_civic_answers: None,
thinking_config: google_opts.and_then(|opts| {
opts.thinking_budget.map(|budget| GeminiThinkingConfig {
include_thoughts: Some(true),
thinking_budget: Some(budget),
})
}),
speech_config: None,
media_resolution: None,
response_modalities: None,
});
let tools = req.options.tools.as_ref().map(|tools| {
vec![GeminiTool {
function_declarations: tools
.iter()
.map(|tool| GeminiFunctionDeclaration {
name: tool.function.name.clone(),
description: tool.function.description.clone(),
parameters_json_schema: Some(tool.function.parameters.clone()),
})
.collect::<Vec<_>>(),
}]
});
let tool_config = req.options.tool_choice.as_ref().map(|choice| {
let mode = match choice {
crate::types::ToolChoice::Auto => "AUTO",
crate::types::ToolChoice::None => "NONE",
crate::types::ToolChoice::Required { .. } => "ANY",
};
json!({
"function_calling_config": {
"mode": mode
}
})
});
let cached_content = google_opts.and_then(|opts| opts.cached_content.clone());
Ok(GeminiRequest {
contents,
generation_config,
safety_settings: None, tools,
system_instruction,
tool_config,
cached_content,
})
}
/// Splits the neutral message history into Gemini's system instruction and
/// conversation turns.
///
/// System messages, wherever they appear, have their parts concatenated in
/// order into a single [`GeminiSystemInstruction`]. `None` is returned when
/// there are no system messages.
///
/// Every other message becomes one [`GeminiContent`], except that consecutive
/// tool messages are merged into a single `"function"` turn. Gemini expects
/// all function responses answering one model turn (e.g. a batch of parallel
/// function calls) to arrive together. Emitting one turn per tool message
/// would break histories that report each tool result in its own message.
///
/// # Errors
///
/// Propagates any error from [`to_gemini_content`].
fn convert_messages(
    messages: &[Message],
) -> Result<(Option<GeminiSystemInstruction>, Vec<GeminiContent>)> {
    // Tool results only carry a call id; this index recovers the tool name.
    let tool_call_names = build_tool_call_name_map(messages);
    let mut system_parts = Vec::new();
    let mut contents: Vec<GeminiContent> = Vec::new();

    for msg in messages {
        let content = to_gemini_content(msg, &tool_call_names)?;
        match msg.role {
            Role::System => system_parts.extend(content.parts),
            Role::Tool => {
                // System messages never enter `contents`, so a tool result that
                // follows another tool result (even across a system message) is
                // folded into the same function-response turn.
                if let Some(prev) = contents
                    .last_mut()
                    .filter(|prev| prev.role == content.role)
                {
                    prev.parts.extend(content.parts);
                } else {
                    contents.push(content);
                }
            }
            Role::User | Role::Assistant => contents.push(content),
        }
    }

    let system_instruction = (!system_parts.is_empty()).then(|| GeminiSystemInstruction {
        parts: system_parts,
    });
    Ok((system_instruction, contents))
}
/// Indexes every tool call in the history by id, mapping it to the invoked
/// tool's name.
///
/// Gemini function responses carry the function name, while neutral tool
/// results only reference the call id; this map bridges the two. If an id
/// occurs more than once, the last occurrence wins.
fn build_tool_call_name_map(messages: &[Message]) -> HashMap<String, String> {
    messages
        .iter()
        .flat_map(|msg| msg.parts())
        .filter_map(|part| match part {
            ContentPart::ToolCall { id, name, .. } => Some((id.clone(), name.clone())),
            _ => None,
        })
        .collect()
}
/// Converts one [`Message`] into a [`GeminiContent`] turn.
///
/// Roles map as: assistant -> `"model"`, tool -> `"function"`, user and
/// system -> `"user"`. The caller strips system turns and keeps only their
/// parts.
///
/// Each content part becomes exactly one `GeminiPart`:
///
/// * text is copied verbatim;
/// * images become inline data when `url` is a base64 data URL; otherwise they
///   become an `[Image: <url>]` text placeholder;
/// * tool calls become `function_call` parts, re-attaching any string
///   `thought_signature` stored in the call's metadata;
/// * tool results become `function_response` parts. The function name is
///   looked up in `tool_call_names` (falling back to `"unknown"`), and
///   non-object payloads are wrapped as `{"result": ...}`.
///
/// # Errors
///
/// Currently always returns `Ok`.
fn to_gemini_content(
    msg: &Message,
    tool_call_names: &HashMap<String, String>,
) -> Result<GeminiContent> {
    // Every part starts with no payload; each arm below sets only what it needs.
    fn empty_part() -> GeminiPart {
        GeminiPart {
            text: None,
            inline_data: None,
            function_call: None,
            function_response: None,
            thought_signature: None,
        }
    }

    let role = match msg.role {
        Role::Assistant => "model",
        Role::Tool => "function",
        Role::User | Role::System => "user",
    };

    let source_parts = msg.parts();
    let mut parts: Vec<GeminiPart> = Vec::new();
    for part in source_parts.iter() {
        let converted = match part {
            ContentPart::Text { text, .. } => GeminiPart {
                text: Some(text.clone()),
                ..empty_part()
            },
            ContentPart::Image { url, .. } => match parse_image_data(url) {
                Ok(inline_data) => GeminiPart {
                    inline_data: Some(inline_data),
                    ..empty_part()
                },
                // Remote URLs and malformed data URLs degrade to a text placeholder.
                Err(_) => GeminiPart {
                    text: Some(format!("[Image: {}]", url)),
                    ..empty_part()
                },
            },
            ContentPart::ToolCall {
                id,
                name,
                arguments,
                metadata,
                ..
            } => GeminiPart {
                function_call: Some(GeminiFunctionCall {
                    id: Some(id.clone()),
                    name: name.clone(),
                    args: arguments.clone(),
                }),
                // Echo back the signature recorded when this call was received.
                thought_signature: metadata
                    .as_ref()
                    .and_then(|meta| meta.get("thought_signature"))
                    .and_then(|value| value.as_str())
                    .map(|sig| sig.to_owned()),
                ..empty_part()
            },
            ContentPart::ToolResult {
                tool_call_id,
                content,
                ..
            } => {
                let name = tool_call_names
                    .get(tool_call_id)
                    .map_or_else(|| "unknown".to_string(), String::clone);
                // The function response payload must be a JSON object.
                let response = if content.is_object() {
                    content.clone()
                } else {
                    json!({ "result": content })
                };
                GeminiPart {
                    function_response: Some(GeminiFunctionResponse {
                        id: tool_call_id.clone(),
                        name,
                        response,
                    }),
                    ..empty_part()
                }
            }
        };
        parts.push(converted);
    }

    Ok(GeminiContent {
        role: role.to_string(),
        parts,
    })
}
fn parse_image_data(url: &str) -> Result<GeminiInlineData> {
if url.starts_with("data:") {
let parts: Vec<&str> = url.splitn(2, ',').collect();
if parts.len() != 2 {
return Err(Error::invalid_response("Invalid data URL format"));
}
let mime_type = parts[0]
.strip_prefix("data:")
.and_then(|s| s.strip_suffix(";base64"))
.ok_or_else(|| Error::invalid_response("Invalid data URL media type"))?;
Ok(GeminiInlineData {
mime_type: mime_type.to_string(),
data: parts[1].to_string(),
})
} else {
Err(Error::invalid_response(
"Gemini requires base64-encoded images, not URLs",
))
}
}
/// Converts a Gemini response into the provider-neutral [`GenerateResponse`].
///
/// Only the first candidate is read:
///
/// * Text parts become `ResponseContent::Text`.
/// * Function calls become `ResponseContent::ToolCall`. A `call_<uuid>` id is
///   synthesized when Gemini omits one, and any `thought_signature` is stored
///   in the call's metadata so it can be sent back on a later turn.
/// * If nothing was produced, a single empty text item is returned.
///
/// When any tool call is present, the finish reason is reported as tool calls
/// regardless of Gemini's own value.
///
/// Usage is built from `usage_metadata`. Gemini counts thinking tokens
/// separately from the visible answer, so output total = candidates +
/// thoughts.
///
/// # Errors
///
/// Currently always returns `Ok`.
pub fn from_gemini_response(resp: GeminiResponse) -> Result<GenerateResponse> {
    use crate::types::ToolCall;

    let candidate = resp.candidates.as_ref().and_then(|c| c.first());
    let mut content: Vec<ResponseContent> = Vec::new();
    let mut has_tool_calls = false;

    if let Some(candidate) = candidate
        && let Some(c) = &candidate.content
    {
        for part in &c.parts {
            if let Some(text) = &part.text {
                content.push(ResponseContent::Text { text: text.clone() });
            }
            if let Some(function_call) = &part.function_call {
                has_tool_calls = true;
                // Keep the signature so it can be re-attached when this call is
                // replayed in a follow-up request.
                let metadata = part
                    .thought_signature
                    .as_ref()
                    .map(|sig| json!({ "thought_signature": sig }));
                content.push(ResponseContent::ToolCall(ToolCall {
                    id: function_call
                        .id
                        .clone()
                        .unwrap_or_else(|| format!("call_{}", uuid::Uuid::new_v4())),
                    name: function_call.name.clone(),
                    arguments: function_call.args.clone(),
                    metadata,
                }));
            }
        }
    }

    if content.is_empty() {
        content.push(ResponseContent::Text {
            text: String::new(),
        });
    }

    let usage = resp
        .usage_metadata
        .map(|u| {
            let prompt_tokens = u.prompt_token_count.unwrap_or(0);
            let cached_tokens = u.cached_content_token_count.unwrap_or(0);
            // Gemini's `candidatesTokenCount` covers only the visible response;
            // `thoughtsTokenCount` is additional output. The output total is
            // their sum, and text tokens are the candidates count as-is (not
            // candidates minus thoughts).
            let text_tokens = u.candidates_token_count.unwrap_or(0);
            let reasoning_tokens = u.thoughts_token_count;
            let output_tokens =
                reasoning_tokens.map_or(text_tokens, |r| text_tokens.saturating_add(r));
            Usage::with_details(
                InputTokenDetails {
                    total: Some(prompt_tokens),
                    // Gemini's prompt count includes cached-content tokens.
                    no_cache: Some(prompt_tokens.saturating_sub(cached_tokens)),
                    cache_read: (cached_tokens > 0).then_some(cached_tokens),
                    cache_write: None,
                },
                OutputTokenDetails {
                    total: Some(output_tokens),
                    text: Some(text_tokens),
                    reasoning: reasoning_tokens,
                },
                Some(serde_json::to_value(&u).unwrap_or_default()),
            )
        })
        .unwrap_or_default();

    let finish_reason = if has_tool_calls {
        FinishReason::with_raw(FinishReasonKind::ToolCalls, "TOOL_CALLS")
    } else {
        candidate
            .and_then(|c| c.finish_reason.as_deref())
            .map(parse_finish_reason)
            .unwrap_or_else(FinishReason::other)
    };

    Ok(GenerateResponse {
        content,
        usage,
        finish_reason,
        metadata: Some(json!({
            "model_version": resp.model_version,
            "response_id": resp.response_id,
        })),
        warnings: None,
    })
}
/// Maps a Gemini `finishReason` string onto a [`FinishReason`], always
/// preserving the raw value.
///
/// * `STOP` -> [`FinishReasonKind::Stop`]
/// * `MAX_TOKENS` -> [`FinishReasonKind::Length`]
/// * every Gemini safety-style stop (`SAFETY`, `RECITATION`, `BLOCKLIST`,
///   `PROHIBITED_CONTENT`, `SPII`, `IMAGE_SAFETY`) ->
///   [`FinishReasonKind::ContentFilter`]
/// * anything else, including `OTHER` and unrecognized values ->
///   [`FinishReasonKind::Other`]
pub(super) fn parse_finish_reason(reason: &str) -> FinishReason {
    let kind = match reason {
        "STOP" => FinishReasonKind::Stop,
        "MAX_TOKENS" => FinishReasonKind::Length,
        "SAFETY" | "RECITATION" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "SPII"
        | "IMAGE_SAFETY" => FinishReasonKind::ContentFilter,
        _ => FinishReasonKind::Other,
    };
    FinishReason::with_raw(kind, reason)
}
#[cfg(test)]
mod tests {
// Unit tests for the Gemini conversion helpers: message -> content mapping,
// tool-result name resolution, system-instruction hoisting, and parsing of
// function calls out of a Gemini response.
use super::super::types::GeminiCandidate;
use super::*;
// An object-valued tool result is forwarded unchanged as the function
// response; the turn role is "function" and the name comes from the map.
#[test]
fn test_to_gemini_content_tool_result() {
let mut tool_call_names = HashMap::new();
tool_call_names.insert("call_123".to_string(), "get_weather".to_string());
let msg = Message {
role: Role::Tool,
content: crate::types::MessageContent::Parts(vec![ContentPart::ToolResult {
tool_call_id: "call_123".to_string(),
content: serde_json::json!({"temp": 22}),
provider_options: None,
}]),
name: None,
provider_options: None,
};
let result = to_gemini_content(&msg, &tool_call_names).unwrap();
assert_eq!(result.role, "function");
assert_eq!(result.parts.len(), 1);
let part = &result.parts[0];
assert!(part.function_response.is_some());
let resp = part.function_response.as_ref().unwrap();
assert_eq!(resp.id, "call_123");
assert_eq!(resp.name, "get_weather");
assert_eq!(resp.response["temp"], 22);
}
// A non-object tool result (here a plain string) must be wrapped as
// {"result": ...} so the response sent to Gemini is a JSON object.
#[test]
fn test_to_gemini_content_tool_result_string_wrapped() {
let mut tool_call_names = HashMap::new();
tool_call_names.insert("call_456".to_string(), "read_file".to_string());
let msg = Message {
role: Role::Tool,
content: crate::types::MessageContent::Parts(vec![ContentPart::ToolResult {
tool_call_id: "call_456".to_string(),
content: serde_json::json!("File: README.md\n 1: # Hello"),
provider_options: None,
}]),
name: None,
provider_options: None,
};
let result = to_gemini_content(&msg, &tool_call_names).unwrap();
let part = &result.parts[0];
let resp = part.function_response.as_ref().unwrap();
assert_eq!(resp.id, "call_456");
assert_eq!(resp.name, "read_file");
assert!(
resp.response.is_object(),
"response must be a JSON object for Gemini, got: {:?}",
resp.response
);
assert_eq!(resp.response["result"], "File: README.md\n 1: # Hello");
}
// convert_messages must resolve a tool result's function name from the
// assistant tool call earlier in the same history.
#[test]
fn test_convert_messages_resolves_tool_result_names() {
let messages = vec![
Message::new(Role::User, "List files"),
Message {
role: Role::Assistant,
content: crate::types::MessageContent::Parts(vec![ContentPart::ToolCall {
id: "call_abc".to_string(),
name: "run_command".to_string(),
arguments: serde_json::json!({"command": "ls"}),
metadata: None,
provider_options: None,
}]),
name: None,
provider_options: None,
},
Message {
role: Role::Tool,
content: crate::types::MessageContent::Parts(vec![ContentPart::ToolResult {
tool_call_id: "call_abc".to_string(),
content: serde_json::Value::String("README.md\nsrc/\nCargo.toml".to_string()),
provider_options: None,
}]),
name: None,
provider_options: None,
},
];
let (_system, contents) = convert_messages(&messages).unwrap();
assert_eq!(contents.len(), 3);
let tool_result_content = &contents[2];
assert_eq!(tool_result_content.role, "function");
let resp = tool_result_content.parts[0]
.function_response
.as_ref()
.expect("should have function_response");
assert_eq!(
resp.name, "run_command",
"name should be resolved from the preceding tool call, not 'unknown'"
);
assert_eq!(resp.id, "call_abc");
}
// With no tool call carrying a matching id, the name falls back to "unknown".
#[test]
fn test_convert_messages_tool_result_fallback_when_no_matching_call() {
let messages = vec![Message {
role: Role::Tool,
content: crate::types::MessageContent::Parts(vec![ContentPart::ToolResult {
tool_call_id: "call_orphan".to_string(),
content: serde_json::Value::String("some result".to_string()),
provider_options: None,
}]),
name: None,
provider_options: None,
}];
let (_system, contents) = convert_messages(&messages).unwrap();
let resp = contents[0].parts[0]
.function_response
.as_ref()
.expect("should have function_response");
assert_eq!(resp.name, "unknown");
}
// A function-call part in the first candidate becomes a single ToolCall that
// keeps Gemini's id, name and arguments.
#[test]
fn test_from_gemini_response_tool_call() {
let resp = GeminiResponse {
candidates: Some(vec![GeminiCandidate {
content: Some(GeminiContent {
role: "model".to_string(),
parts: vec![GeminiPart {
text: None,
inline_data: None,
function_call: Some(GeminiFunctionCall {
id: Some("call_123".to_string()),
name: "get_weather".to_string(),
args: serde_json::json!({"location": "London"}),
}),
function_response: None,
thought_signature: None,
}],
}),
finish_reason: Some("STOP".to_string()),
safety_ratings: None,
}]),
usage_metadata: None,
model_version: None,
response_id: None,
};
let result = from_gemini_response(resp).unwrap();
assert_eq!(result.content.len(), 1);
if let ResponseContent::ToolCall(call) = &result.content[0] {
assert_eq!(call.id, "call_123");
assert_eq!(call.name, "get_weather");
assert_eq!(call.arguments["location"], "London");
} else {
panic!("Expected ToolCall");
}
}
// System messages are lifted into the system instruction and do not appear
// among the conversation contents.
#[test]
fn test_convert_messages_system_instruction() {
let messages = vec![
Message::new(Role::System, "You are a helpful assistant."),
Message::new(Role::User, "Hello!"),
];
let (system_instruction, contents) = convert_messages(&messages).unwrap();
assert!(system_instruction.is_some());
let si = system_instruction.unwrap();
assert_eq!(si.parts.len(), 1);
assert_eq!(
si.parts[0].text,
Some("You are a helpful assistant.".to_string())
);
assert_eq!(contents.len(), 1);
assert_eq!(contents[0].role, "user");
assert_eq!(contents[0].parts[0].text, Some("Hello!".to_string()));
}
}