use agents_core::llm::{LanguageModel, LlmRequest, LlmResponse};
use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
use agents_core::tools::ToolSchema;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Connection settings for the Gemini chat model.
///
/// `Debug` is implemented by hand (not derived) so that the API key and custom
/// header values never end up in logs or panic messages.
#[derive(Clone)]
pub struct GeminiConfig {
    /// API key used to authenticate requests to the Gemini API.
    pub api_key: String,
    /// Model identifier placed in the request path, e.g. `gemini-pro`.
    pub model: String,
    /// Base URL override; `None` selects the public `v1beta` endpoint.
    pub api_url: Option<String>,
    /// Extra HTTP headers attached to every request, as `(name, value)` pairs.
    pub custom_headers: Vec<(String, String)>,
}

impl std::fmt::Debug for GeminiConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Header values may carry credentials (e.g. proxy auth), so only names are shown.
        let header_names: Vec<&str> = self
            .custom_headers
            .iter()
            .map(|(name, _)| name.as_str())
            .collect();
        f.debug_struct("GeminiConfig")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("api_url", &self.api_url)
            .field("custom_headers", &header_names)
            .finish()
    }
}

impl GeminiConfig {
    /// Creates a config for `model` authenticated with `api_key`, using the default
    /// endpoint and no extra headers.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            api_url: None,
            custom_headers: Vec::new(),
        }
    }

    /// Replaces the set of custom headers sent with every request.
    #[must_use]
    pub fn with_custom_headers(mut self, headers: Vec<(String, String)>) -> Self {
        self.custom_headers = headers;
        self
    }

    /// Overrides the base API URL (e.g. a proxy or a different API version).
    ///
    /// The value should be the prefix that precedes `/models/...`, such as
    /// `https://generativelanguage.googleapis.com/v1beta`.
    #[must_use]
    pub fn with_api_url(mut self, api_url: impl Into<String>) -> Self {
        self.api_url = Some(api_url.into());
        self
    }
}
/// `LanguageModel` implementation backed by the Gemini `generateContent` API.
pub struct GeminiChatModel {
    /// HTTP client reused by every `generate` call.
    client: Client,
    /// Key, model, endpoint and header settings used to build each request.
    config: GeminiConfig,
}

impl GeminiChatModel {
    /// Creates a model that sends requests according to `config`.
    ///
    /// # Errors
    /// Returns an error if the underlying HTTP client cannot be built.
    pub fn new(config: GeminiConfig) -> anyhow::Result<Self> {
        let client = Client::builder()
            .user_agent("rust-deep-agents-sdk/0.1")
            .build()?;
        Ok(Self { client, config })
    }
}
/// JSON body POSTed to the `models/{model}:generateContent` endpoint.
///
/// Field names are serialized exactly as written (snake_case).
/// NOTE(review): the Gemini REST reference spells these `systemInstruction` /
/// `functionDeclarations`; confirm the endpoint keeps accepting the snake_case form.
#[derive(Serialize)]
struct GeminiRequest {
    /// Conversation turns, in the same order as the request's messages.
    contents: Vec<GeminiContent>,
    /// System prompt; left out of the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GeminiContent>,
    /// Tool declarations; left out of the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GeminiToolDeclaration>>,
}
/// One `tools` entry grouping the function declarations offered to the model.
#[derive(Clone, Serialize)]
struct GeminiToolDeclaration {
    function_declarations: Vec<GeminiFunctionDeclaration>,
}
/// A single callable function advertised to the model.
#[derive(Clone, Serialize)]
struct GeminiFunctionDeclaration {
    name: String,
    description: String,
    /// Parameter schema as arbitrary JSON (serialized from the tool's schema).
    parameters: Value,
}
/// One conversation turn: a role plus its parts.
#[derive(Serialize)]
struct GeminiContent {
    /// Role string; this file produces "user", "model" and (for the system
    /// instruction) "system".
    role: String,
    parts: Vec<GeminiPart>,
}
/// A text-only part of a turn.
#[derive(Serialize)]
struct GeminiPart {
    text: String,
}
#[derive(Deserialize)]
struct GeminiResponse {
candidates: Vec<GeminiCandidate>,
}
#[derive(Deserialize)]
struct GeminiCandidate {
content: Option<GeminiContentResponse>,
}
#[derive(Deserialize)]
struct GeminiContentResponse {
parts: Vec<GeminiPartResponse>,
}
#[derive(Deserialize)]
struct GeminiPartResponse {
text: Option<String>,
#[serde(rename = "functionCall")]
function_call: Option<GeminiFunctionCall>,
}
#[derive(Deserialize)]
struct GeminiFunctionCall {
name: String,
args: Value,
}
/// Converts the request's message history into Gemini `contents`, plus an optional
/// system instruction built from the system prompt.
///
/// Agent messages become `"model"` turns; user, tool and system messages are all
/// sent as `"user"` turns. JSON message content is sent as its serialized string.
/// A system prompt that is empty or whitespace-only produces no system instruction.
fn to_gemini_contents(request: &LlmRequest) -> (Vec<GeminiContent>, Option<GeminiContent>) {
    let contents: Vec<GeminiContent> = request
        .messages
        .iter()
        .map(|msg| {
            let role = match msg.role {
                MessageRole::Agent => "model",
                MessageRole::User | MessageRole::Tool | MessageRole::System => "user",
            };
            let text = match &msg.content {
                MessageContent::Text(body) => body.clone(),
                MessageContent::Json(json) => json.to_string(),
            };
            GeminiContent {
                role: role.to_string(),
                parts: vec![GeminiPart { text }],
            }
        })
        .collect();

    let has_system_prompt = !request.system_prompt.trim().is_empty();
    let system_instruction = has_system_prompt.then(|| GeminiContent {
        role: "system".to_string(),
        parts: vec![GeminiPart {
            text: request.system_prompt.clone(),
        }],
    });

    (contents, system_instruction)
}
fn to_gemini_tools(tools: &[ToolSchema]) -> Option<Vec<GeminiToolDeclaration>> {
if tools.is_empty() {
return None;
}
Some(vec![GeminiToolDeclaration {
function_declarations: tools
.iter()
.map(|tool| GeminiFunctionDeclaration {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: serde_json::to_value(&tool.parameters)
.unwrap_or_else(|_| serde_json::json!({})),
})
.collect(),
}])
}
#[async_trait]
impl LanguageModel for GeminiChatModel {
async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
let (contents, system_instruction) = to_gemini_contents(&request);
let tools = to_gemini_tools(&request.tools);
tracing::debug!(
"Gemini request: model={}, contents={}, tools={}",
self.config.model,
contents.len(),
tools
.as_ref()
.map(|t| t
.iter()
.map(|td| td.function_declarations.len())
.sum::<usize>())
.unwrap_or(0)
);
let body = GeminiRequest {
contents,
system_instruction,
tools,
};
let base_url = self
.config
.api_url
.clone()
.unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".into());
let url = format!(
"{}/models/{}:generateContent?key={}",
base_url, self.config.model, self.config.api_key
);
let mut request = self.client.post(&url);
for (key, value) in &self.config.custom_headers {
request = request.header(key, value);
}
let response = request.json(&body).send().await?.error_for_status()?;
let data: GeminiResponse = response.json().await?;
let function_calls: Vec<_> = data
.candidates
.iter()
.filter_map(|candidate| candidate.content.as_ref())
.flat_map(|content| &content.parts)
.filter_map(|part| part.function_call.as_ref())
.collect();
if !function_calls.is_empty() {
let tool_calls: Vec<_> = function_calls
.iter()
.map(|fc| {
serde_json::json!({
"name": fc.name,
"args": fc.args
})
})
.collect();
tracing::debug!(
"Gemini response contains {} function calls",
tool_calls.len()
);
return Ok(LlmResponse {
message: AgentMessage {
role: MessageRole::Agent,
content: MessageContent::Json(serde_json::json!({
"tool_calls": tool_calls
})),
metadata: None,
},
});
}
let text = data
.candidates
.into_iter()
.filter_map(|candidate| candidate.content)
.flat_map(|content| content.parts)
.find_map(|part| part.text)
.unwrap_or_default();
Ok(LlmResponse {
message: AgentMessage {
role: MessageRole::Agent,
content: MessageContent::Text(text),
metadata: None,
},
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A non-empty system prompt becomes the system instruction, and a single user
    /// message becomes a single "user" turn.
    #[test]
    fn gemini_conversion_handles_system_prompt() {
        let greeting = AgentMessage {
            role: MessageRole::User,
            content: MessageContent::Text("Hello".into()),
            metadata: None,
        };
        let request = LlmRequest::new("You are concise", vec![greeting]);

        let (contents, system) = to_gemini_contents(&request);

        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0].role, "user");
        let system_text = system
            .as_ref()
            .map(|instruction| instruction.parts[0].text.as_str());
        assert_eq!(system_text, Some("You are concise"));
    }

    /// `GeminiConfig::new` stores key and model and leaves optional settings empty.
    #[test]
    fn gemini_config_new_initializes_empty_custom_headers() {
        let GeminiConfig {
            api_key,
            model,
            api_url,
            custom_headers,
        } = GeminiConfig::new("test-key", "gemini-pro");

        assert_eq!(api_key, "test-key");
        assert_eq!(model, "gemini-pro");
        assert!(custom_headers.is_empty());
        assert!(api_url.is_none());
    }

    /// `with_custom_headers` stores the given headers verbatim and in order.
    #[test]
    fn gemini_config_with_custom_headers_sets_headers() {
        let expected: Vec<(String, String)> = [
            ("X-Custom-Header", "value1"),
            ("X-Another-Header", "value2"),
        ]
        .iter()
        .map(|(name, value)| (name.to_string(), value.to_string()))
        .collect();

        let config =
            GeminiConfig::new("test-key", "gemini-pro").with_custom_headers(expected.clone());

        assert_eq!(config.custom_headers, expected);
    }
}