use super::{
FunctionDefinition, LlmProvider, LlmResponse, Message, Role, ToolCall, ToolDefinition, Usage,
};
use anyhow::{bail, Result};
use async_trait::async_trait;
use futures_util::StreamExt;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{json, Value};
use std::io::Write;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
/// Root of the Gemini REST endpoint. Requests append `/{model}:{method}` to
/// this prefix when building the final URL.
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
/// Short identifier for this provider.
// NOTE(review): presumably compared against a configured "api base" value
// elsewhere to select this provider — confirm against callers.
pub const GEMINI_API_BASE: &str = "gemini";
/// LLM provider that talks to the Google Gemini streaming HTTP API.
pub struct GeminiProvider {
/// API key sent with every request.
api_key: String,
/// Model name interpolated into the request URL.
pub model: String,
/// HTTP client built once in `new` and reused for all requests.
client: Client,
/// When `true`, streamed text parts are echoed to stdout as they arrive.
/// Stored atomically so it can be toggled through a shared reference.
pub stream_print: AtomicBool,
}
impl GeminiProvider {
pub fn new(api_key: String, model: String) -> Self {
let client = Client::builder()
.timeout(Duration::from_secs(300))
.build()
.expect("Failed to create HTTP client");
GeminiProvider {
api_key,
model,
client,
stream_print: AtomicBool::new(true),
}
}
/// Turns echoing of streamed text to stdout on (`true`) or off (`false`).
///
/// Takes `&self` because the flag is an atomic; the write uses relaxed
/// ordering.
pub fn set_stream_print(&self, enabled: bool) {
    let flag = &self.stream_print;
    flag.store(enabled, Ordering::Relaxed);
}
/// Translates the provider-neutral conversation into a Gemini request body.
///
/// * `Role::System` messages are concatenated (newline-separated) into a
///   single `system_instruction`.
/// * `Role::User` / `Role::Assistant` messages become `contents` entries
///   with role `"user"` / `"model"`.
/// * `Role::Tool` messages become a `"user"` entry holding a
///   `function_response` part.
/// * `tools`, when non-empty, are emitted as one `function_declarations`
///   group.
///
/// Returns the JSON value to be sent as the request payload.
fn build_request_body(&self, messages: &[Message], tools: &[ToolDefinition]) -> Value {
    let mut system_text: Option<String> = None;
    let mut gemini_messages: Vec<Value> = Vec::with_capacity(messages.len());

    for (idx, msg) in messages.iter().enumerate() {
        match msg.role {
            Role::System => {
                let text = msg.text_content().unwrap_or_default();
                match system_text.as_mut() {
                    Some(existing) => {
                        existing.push('\n');
                        existing.push_str(&text);
                    }
                    None => system_text = Some(text),
                }
            }
            Role::Tool => {
                let call_id = msg.tool_call_id.as_deref().unwrap_or("unknown");
                // The response must be labelled with the *function name*.
                // Tool-call ids minted by this provider have the shape
                // `gemini_call_{index}_{name}`, so the id alone is not the
                // name. Look up the earlier tool call carrying this id and
                // use its function name; fall back to the raw id when no
                // such call is present in the history.
                let name = messages[..idx]
                    .iter()
                    .rev()
                    .filter_map(|earlier| earlier.tool_calls.as_deref())
                    .flatten()
                    .find(|call| call.id == call_id)
                    .map(|call| call.function.name.as_str())
                    .unwrap_or(call_id);
                let result_text = msg.text_content().unwrap_or_default();
                gemini_messages.push(json!({
                    "role": "user",
                    "parts": [{
                        "function_response": {
                            "name": name,
                            "response": {
                                "content": result_text
                            }
                        }
                    }]
                }));
            }
            Role::User | Role::Assistant => {
                let parts = build_gemini_content_parts(&msg.content, msg.tool_calls.as_deref());
                // Gemini calls the assistant side of the dialogue "model".
                let role_str = if matches!(msg.role, Role::Assistant) {
                    "model"
                } else {
                    "user"
                };
                gemini_messages.push(json!({
                    "role": role_str,
                    "parts": parts,
                }));
            }
        }
    }

    let function_declarations: Vec<Value> = tools
        .iter()
        .map(|t| build_gemini_tool(&t.function))
        .collect();

    let mut body = json!({
        "contents": gemini_messages,
    });
    if let Some(sys) = system_text {
        body["system_instruction"] = json!({
            "parts": [{ "text": sys }]
        });
    }
    // Only attach `tools` when there is something to declare.
    if !function_declarations.is_empty() {
        body["tools"] = json!([{
            "function_declarations": function_declarations
        }]);
    }
    body
}
/// Sends one streaming `streamGenerateContent` request and folds the
/// server-sent-event stream into a single [`LlmResponse`].
///
/// Text parts are concatenated (and echoed to stdout while
/// `stream_print` is set), function-call parts are collected as tool
/// calls, and the last reported token counts are returned as usage.
///
/// # Errors
///
/// Returns an error if the request cannot be sent, if the server answers
/// with a non-success status (the response body is included in the
/// message), or if reading the response stream fails.
async fn try_once(
    &self,
    messages: &[Message],
    tools: &[ToolDefinition],
) -> Result<LlmResponse> {
    let body = self.build_request_body(messages, tools);
    // The API key travels in a header rather than the query string so it
    // cannot end up in error messages or logs that quote the request URL.
    let url = format!(
        "{}/{}:streamGenerateContent?alt=sse",
        GEMINI_API_BASE_URL, self.model
    );
    let response = self
        .client
        .post(&url)
        .header("content-type", "application/json")
        .header("x-goog-api-key", self.api_key.as_str())
        .json(&body)
        .send()
        .await?;

    let status = response.status();
    if !status.is_success() {
        let err_body = response.text().await.unwrap_or_default();
        bail!("Gemini API error {}: {}", status.as_u16(), err_body);
    }

    let mut text_content = String::new();
    let mut function_calls: Vec<(String, String)> = Vec::new();
    let mut prompt_tokens: u32 = 0;
    let mut completion_tokens: u32 = 0;

    let mut byte_stream = response.bytes_stream();
    // Raw bytes are buffered and only decoded once a full line is present:
    // a network chunk may end in the middle of a multi-byte UTF-8
    // character, and decoding chunk-by-chunk would corrupt it. Splitting
    // on the `\n` byte is safe because that byte never occurs inside a
    // multi-byte UTF-8 sequence.
    let mut pending: Vec<u8> = Vec::new();
    while let Some(chunk_result) = byte_stream.next().await {
        let chunk = chunk_result?;
        pending.extend_from_slice(&chunk);

        let mut consumed = 0;
        while let Some(offset) = pending[consumed..].iter().position(|&b| b == b'\n') {
            let line_end = consumed + offset;
            let line = String::from_utf8_lossy(&pending[consumed..line_end]);
            self.consume_sse_line(
                &line,
                &mut text_content,
                &mut function_calls,
                &mut prompt_tokens,
                &mut completion_tokens,
            );
            consumed = line_end + 1;
        }
        // Drop all processed lines in one pass, keeping the partial tail.
        pending.drain(..consumed);
    }
    // The stream may end without a trailing newline; do not lose that line.
    if !pending.is_empty() {
        let line = String::from_utf8_lossy(&pending);
        self.consume_sse_line(
            &line,
            &mut text_content,
            &mut function_calls,
            &mut prompt_tokens,
            &mut completion_tokens,
        );
    }

    // The response model used here carries no call id, so a synthetic one
    // is built from the call's position and function name.
    let tool_calls: Vec<ToolCall> = function_calls
        .into_iter()
        .enumerate()
        .map(|(i, (name, args))| ToolCall {
            id: format!("gemini_call_{}_{}", i, name),
            call_type: "function".to_string(),
            function: super::FunctionCall {
                name,
                arguments: args,
            },
        })
        .collect();

    Ok(LlmResponse {
        content: if text_content.is_empty() {
            None
        } else {
            Some(text_content)
        },
        tool_calls: if tool_calls.is_empty() {
            None
        } else {
            Some(tool_calls)
        },
        usage: Some(Usage {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens.saturating_add(completion_tokens),
        }),
    })
}

/// Applies a single line of the SSE stream to the running accumulators.
///
/// Only `data:` lines are considered. The `[DONE]` marker, empty payloads
/// and payloads that do not parse as a Gemini response are skipped
/// (best-effort: one bad event must not abort the whole stream).
fn consume_sse_line(
    &self,
    line: &str,
    text_content: &mut String,
    function_calls: &mut Vec<(String, String)>,
    prompt_tokens: &mut u32,
    completion_tokens: &mut u32,
) {
    // SSE allows both `data: value` and `data:value`.
    let payload = match line.trim_end_matches('\r').strip_prefix("data:") {
        Some(rest) => rest.trim(),
        None => return,
    };
    if payload.is_empty() || payload == "[DONE]" {
        return;
    }
    let event: GeminiResponse = match serde_json::from_str(payload) {
        Ok(parsed) => parsed,
        Err(_) => return,
    };

    let parts = event
        .candidates
        .iter()
        .filter_map(|candidate| candidate.content.as_ref())
        .flat_map(|content| content.parts.iter());
    for part in parts {
        if let Some(text) = &part.text {
            if self.stream_print.load(Ordering::Relaxed) {
                print!("{}", text);
                // Flush so partial output shows up immediately; a failed
                // flush only affects the live echo, so it is ignored.
                std::io::stdout().flush().ok();
            }
            text_content.push_str(text);
        }
        if let Some(fc) = &part.function_call {
            // Arguments are carried onward as a JSON string; a call
            // without arguments is represented as an empty object.
            let args_json = fc
                .args
                .as_ref()
                .map(|a| serde_json::to_string(a).unwrap_or_default())
                .unwrap_or_else(|| "{}".to_string());
            function_calls.push((fc.name.clone(), args_json));
        }
    }

    // Later events overwrite earlier counts, so the last reported values win.
    if let Some(usage) = &event.usage_metadata {
        if let Some(n) = usage.prompt_token_count {
            *prompt_tokens = n;
        }
        if let Some(n) = usage.candidates_token_count {
            *completion_tokens = n;
        }
    }
}
}
/// [`LlmProvider`] implementation that forwards to the inherent methods of
/// [`GeminiProvider`].
#[async_trait]
impl LlmProvider for GeminiProvider {
    /// Runs a single Gemini request for `messages` with the given `tools`
    /// and returns the aggregated response.
    async fn chat_completion(
        &self,
        messages: &[Message],
        tools: &[ToolDefinition],
    ) -> Result<LlmResponse> {
        let outcome = self.try_once(messages, tools).await;
        outcome
    }

    /// Forwards to the inherent `set_stream_print`. The call is written
    /// with the explicit type path so it cannot resolve back to this trait
    /// method.
    fn set_stream_print(&self, enabled: bool) {
        GeminiProvider::set_stream_print(self, enabled)
    }
}
/// One decoded event of the Gemini response stream.
#[derive(Debug, Deserialize)]
struct GeminiResponse {
    /// Candidate completions; treated as empty when the key is absent.
    #[serde(default)]
    candidates: Vec<GeminiCandidate>,
    /// Token accounting, read from the `usageMetadata` key when present.
    #[serde(rename = "usageMetadata")]
    usage_metadata: Option<GeminiUsageMetadata>,
}
/// A single candidate inside a `GeminiResponse`.
#[derive(Debug, Deserialize)]
struct GeminiCandidate {
/// Content payload of the candidate; `None` when the key is missing.
content: Option<GeminiContent>,
}
/// Content of a candidate: an ordered list of parts.
#[derive(Debug, Deserialize)]
struct GeminiContent {
/// Parts of this content; deserializes to an empty list when absent.
#[serde(default)]
parts: Vec<GeminiPart>,
}
/// One part of a candidate's content. Each field is optional and is read
/// independently of the other.
#[derive(Debug, Deserialize)]
struct GeminiPart {
    /// Text fragment, if this part carries text.
    text: Option<String>,
    /// Function call, read from the `functionCall` key.
    #[serde(rename = "functionCall")]
    function_call: Option<GeminiFunctionCall>,
}
/// A function call emitted by the model.
#[derive(Debug, Deserialize)]
struct GeminiFunctionCall {
/// Name of the function being called.
name: String,
/// Call arguments as arbitrary JSON; `None` when the key is missing.
args: Option<Value>,
}
/// Token counts reported by the API. Each count is optional because an
/// event may carry only some of them.
#[derive(Debug, Deserialize)]
struct GeminiUsageMetadata {
    /// Value of the `promptTokenCount` key.
    #[serde(rename = "promptTokenCount")]
    prompt_token_count: Option<u32>,
    /// Value of the `candidatesTokenCount` key.
    #[serde(rename = "candidatesTokenCount")]
    candidates_token_count: Option<u32>,
}
/// Guesses an image MIME type from the file extension of `url`.
///
/// The match is case-insensitive and ignores any query string (`?...`) or
/// fragment (`#...`), so `https://host/a.PNG?w=100` is recognised as PNG.
/// Recognised extensions are `.png`, `.gif` and `.webp`; everything else,
/// including URLs without an extension, is reported as `image/jpeg`.
fn infer_mime_from_url(url: &str) -> &'static str {
    // Only the part before the query/fragment can carry the extension.
    let path_end = url.find(&['?', '#'][..]).unwrap_or(url.len());
    let path = url[..path_end].to_ascii_lowercase();

    if path.ends_with(".png") {
        "image/png"
    } else if path.ends_with(".gif") {
        "image/gif"
    } else if path.ends_with(".webp") {
        "image/webp"
    } else {
        "image/jpeg"
    }
}
/// Converts message content and optional tool calls into Gemini `parts`.
///
/// * Non-empty text becomes a `text` part; empty text is skipped.
/// * `data:` image URLs become `inline_data` with the MIME type taken from
///   the URL header (defaulting to `image/jpeg` when the header does not
///   end in `;base64`); a `data:` URL without a comma is skipped.
/// * Other image URLs become `file_data` with a MIME type inferred from
///   the URL.
/// * Each tool call becomes a `functionCall` part; arguments that are not
///   valid JSON are replaced by an empty object.
///
/// The result is never empty: a single empty `text` part is added when
/// nothing else was produced.
fn build_gemini_content_parts(
    content: &[super::ContentPart],
    tool_calls: Option<&[super::ToolCall]>,
) -> Vec<Value> {
    let mut out: Vec<Value> = Vec::new();

    for item in content {
        match item {
            super::ContentPart::Text { text } if !text.is_empty() => {
                out.push(json!({ "text": text }));
            }
            super::ContentPart::ImageUrl { image_url } => {
                let url = &image_url.url;
                match url.strip_prefix("data:") {
                    Some(rest) => {
                        // `rest` looks like `<mime>;base64,<payload>`.
                        if let Some((meta, payload)) = rest.split_once(',') {
                            let mime = meta.strip_suffix(";base64").unwrap_or("image/jpeg");
                            out.push(json!({
                                "inline_data": {
                                    "mime_type": mime,
                                    "data": payload,
                                }
                            }));
                        }
                    }
                    None => {
                        let mime = infer_mime_from_url(url);
                        out.push(json!({
                            "file_data": {
                                "mime_type": mime,
                                "file_uri": url,
                            }
                        }));
                    }
                }
            }
            // Empty text and any other content kinds contribute nothing.
            _ => {}
        }
    }

    out.extend(tool_calls.unwrap_or(&[]).iter().map(|call| {
        let args: Value =
            serde_json::from_str(&call.function.arguments).unwrap_or_else(|_| json!({}));
        json!({
            "functionCall": {
                "name": call.function.name,
                "args": args
            }
        })
    }));

    if out.is_empty() {
        out.push(json!({ "text": "" }));
    }
    out
}
fn build_gemini_tool(func: &FunctionDefinition) -> Value {
json!({
"name": func.name,
"description": func.description,
"parameters": func.parameters,
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::{
        ContentPart, FunctionCall, FunctionDefinition, ImageUrl, Message, Role, ToolCall,
        ToolDefinition,
    };

    /// Builds a provider with placeholder credentials for tests that do
    /// not care about the key or model.
    fn dummy_provider() -> GeminiProvider {
        GeminiProvider::new("k".to_string(), "m".to_string())
    }

    // ---- construction and flags ----

    #[test]
    fn test_new_stores_fields() {
        let provider = GeminiProvider::new("key-abc".to_string(), "gemini-2.0-flash".to_string());
        assert_eq!(provider.api_key, "key-abc");
        assert_eq!(provider.model, "gemini-2.0-flash");
        assert!(provider.stream_print.load(Ordering::Relaxed));
    }

    #[test]
    fn test_set_stream_print() {
        let provider = dummy_provider();
        assert!(provider.stream_print.load(Ordering::Relaxed));

        provider.set_stream_print(false);
        assert!(!provider.stream_print.load(Ordering::Relaxed));

        provider.set_stream_print(true);
        assert!(provider.stream_print.load(Ordering::Relaxed));
    }

    // ---- request body ----

    #[test]
    fn test_build_request_body_simple_message() {
        let provider = GeminiProvider::new("k".to_string(), "gemini-2.0-flash".to_string());
        let messages = vec![Message::user("Hello!")];
        let body = provider.build_request_body(&messages, &[]);

        assert!(body["contents"].is_array());
        let contents = body["contents"].as_array().unwrap();
        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0]["role"], "user");

        let parts = contents[0]["parts"].as_array().unwrap();
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0]["text"], "Hello!");
    }

    #[test]
    fn test_build_request_body_system_prompt() {
        let provider = dummy_provider();
        let messages = vec![
            Message::system("You are a helpful assistant."),
            Message::user("Hi"),
        ];
        let body = provider.build_request_body(&messages, &[]);

        assert!(!body["system_instruction"].is_null());
        assert_eq!(
            body["system_instruction"]["parts"][0]["text"],
            "You are a helpful assistant."
        );

        // The system message must not show up as a regular content entry.
        let contents = body["contents"].as_array().unwrap();
        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0]["role"], "user");
    }

    #[test]
    fn test_build_request_body_assistant_role_is_model() {
        let provider = dummy_provider();
        let messages = vec![
            Message::user("What is 2+2?"),
            Message::assistant(Some("4".to_string()), None),
        ];
        let body = provider.build_request_body(&messages, &[]);

        let contents = body["contents"].as_array().unwrap();
        assert_eq!(contents.len(), 2);
        assert_eq!(contents[0]["role"], "user");
        assert_eq!(contents[1]["role"], "model");
    }

    #[test]
    fn test_build_request_body_tool_result() {
        let provider = dummy_provider();
        let tool_msg = Message {
            role: Role::Tool,
            content: vec![ContentPart::text("File created successfully")],
            tool_calls: None,
            tool_call_id: Some("file_write".to_string()),
            name: None,
        };
        let body = provider.build_request_body(&[tool_msg], &[]);

        let contents = body["contents"].as_array().unwrap();
        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0]["role"], "user");

        let parts = contents[0]["parts"].as_array().unwrap();
        let function_response = &parts[0]["function_response"];
        assert_eq!(function_response["name"], "file_write");
        assert_eq!(
            function_response["response"]["content"],
            "File created successfully"
        );
    }

    #[test]
    fn test_build_request_body_with_tools() {
        let provider = dummy_provider();
        let tools = vec![ToolDefinition {
            def_type: "function".to_string(),
            function: FunctionDefinition {
                name: "bash".to_string(),
                description: "Run a shell command".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {"command": {"type": "string"}}
                }),
            },
        }];
        let body = provider.build_request_body(&[Message::user("run ls")], &tools);

        let tools_json = body["tools"].as_array().unwrap();
        assert_eq!(tools_json.len(), 1);

        let decls = tools_json[0]["function_declarations"].as_array().unwrap();
        assert_eq!(decls.len(), 1);
        assert_eq!(decls[0]["name"], "bash");
        assert_eq!(decls[0]["description"], "Run a shell command");
    }

    #[test]
    fn test_build_request_body_no_tools_when_empty() {
        let provider = dummy_provider();
        let body = provider.build_request_body(&[Message::user("hello")], &[]);
        assert!(body.get("tools").is_none() || body["tools"].is_null());
    }

    // ---- content parts and tool declarations ----

    #[test]
    fn test_build_gemini_content_parts_text() {
        let parts = build_gemini_content_parts(&[ContentPart::text("hello")], None);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0]["text"], "hello");
    }

    #[test]
    fn test_build_gemini_content_parts_tool_calls() {
        let tool_calls = vec![ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "file_read".to_string(),
                arguments: r#"{"path": "src/main.rs"}"#.to_string(),
            },
        }];
        let parts = build_gemini_content_parts(&[], Some(&tool_calls));

        let fc_part = parts.iter().find(|p| !p["functionCall"].is_null());
        assert!(fc_part.is_some());

        let fc = &fc_part.unwrap()["functionCall"];
        assert_eq!(fc["name"], "file_read");
        assert_eq!(fc["args"]["path"], "src/main.rs");
    }

    #[test]
    fn test_build_gemini_content_parts_empty_gets_placeholder() {
        let parts = build_gemini_content_parts(&[], None);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0]["text"], "");
    }

    #[test]
    fn test_build_gemini_tool() {
        let func = FunctionDefinition {
            name: "grep_search".to_string(),
            description: "Search file contents".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        };
        let tool_json = build_gemini_tool(&func);
        assert_eq!(tool_json["name"], "grep_search");
        assert_eq!(tool_json["description"], "Search file contents");
        assert!(!tool_json["parameters"].is_null());
    }

    // ---- constants ----

    #[test]
    fn test_gemini_api_base_sentinel() {
        assert_eq!(GEMINI_API_BASE, "gemini");
    }

    #[test]
    fn test_is_not_copilot() {
        let provider = dummy_provider();
        assert_ne!(GEMINI_API_BASE, "copilot");
        assert!(provider.stream_print.load(Ordering::Relaxed));
    }

    // ---- response deserialization ----

    #[test]
    fn test_gemini_response_deserializes_text_chunk() {
        let json_str = r#"{
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [{ "text": "Hello " }]
                }
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        }"#;
        let resp: GeminiResponse = serde_json::from_str(json_str).unwrap();
        assert_eq!(resp.candidates.len(), 1);

        let content = resp.candidates[0].content.as_ref().unwrap();
        assert_eq!(content.parts.len(), 1);
        assert_eq!(content.parts[0].text.as_deref(), Some("Hello "));

        let usage = resp.usage_metadata.as_ref().unwrap();
        assert_eq!(usage.prompt_token_count, Some(10));
        assert_eq!(usage.candidates_token_count, Some(5));
    }

    #[test]
    fn test_gemini_response_deserializes_function_call() {
        let json_str = r#"{
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [{
                        "functionCall": {
                            "name": "bash",
                            "args": { "command": "ls -la" }
                        }
                    }]
                }
            }]
        }"#;
        let resp: GeminiResponse = serde_json::from_str(json_str).unwrap();
        let content = resp.candidates[0].content.as_ref().unwrap();
        let fc = content.parts[0].function_call.as_ref().unwrap();
        assert_eq!(fc.name, "bash");
        assert_eq!(fc.args.as_ref().unwrap()["command"], "ls -la");
    }

    #[test]
    fn test_gemini_usage_metadata_partial() {
        let json_str = r#"{ "promptTokenCount": 42 }"#;
        let usage: GeminiUsageMetadata = serde_json::from_str(json_str).unwrap();
        assert_eq!(usage.prompt_token_count, Some(42));
        assert_eq!(usage.candidates_token_count, None);
    }

    #[test]
    fn test_gemini_response_empty_candidates() {
        let json_str = r#"{ "candidates": [] }"#;
        let resp: GeminiResponse = serde_json::from_str(json_str).unwrap();
        assert!(resp.candidates.is_empty());
    }

    // ---- images ----

    #[test]
    fn test_build_gemini_content_image_base64() {
        let image = ContentPart::ImageUrl {
            image_url: ImageUrl {
                url: "data:image/png;base64,iVBORw0KGgo=".to_string(),
                detail: None,
            },
        };
        let parts = build_gemini_content_parts(&[image], None);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0]["inline_data"]["mime_type"], "image/png");
        assert_eq!(parts[0]["inline_data"]["data"], "iVBORw0KGgo=");
    }

    #[test]
    fn test_build_gemini_content_image_url() {
        let image = ContentPart::ImageUrl {
            image_url: ImageUrl {
                url: "https://example.com/photo.jpg".to_string(),
                detail: None,
            },
        };
        let parts = build_gemini_content_parts(&[image], None);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0]["file_data"]["mime_type"], "image/jpeg");
        assert_eq!(
            parts[0]["file_data"]["file_uri"],
            "https://example.com/photo.jpg"
        );
    }

    #[test]
    fn test_infer_mime_from_url() {
        assert_eq!(infer_mime_from_url("photo.png"), "image/png");
        assert_eq!(infer_mime_from_url("anim.gif"), "image/gif");
        assert_eq!(infer_mime_from_url("img.webp"), "image/webp");
        assert_eq!(infer_mime_from_url("pic.jpg"), "image/jpeg");
        assert_eq!(
            infer_mime_from_url("https://cdn.example.com/img.PNG"),
            "image/png"
        );
        assert_eq!(infer_mime_from_url("no-extension"), "image/jpeg");
    }
}