use async_trait::async_trait;
use gemini_rust::{
Content, FunctionCall as GeminiFunctionCall, FunctionCallingMode, FunctionDeclaration,
FunctionResponse as GeminiFunctionResponse, Gemini, GenerationResponse,
Message as GeminiMessage, Part, Role as GeminiRole, Tool as GeminiTool,
client::Model as GeminiModel,
};
use serde_json::Value;
use super::super::tools::ToolDefinition;
use super::schema;
use super::{
Client, ClientError, ClientOptions, ClientOutput, ClientResponse, LlmUrl, Message, Provider,
Role, TokenUsage, ToolCall, ToolChoice, parse_json_output, validate_tools,
};
/// Construct a `Gemini` client for the given endpoint description.
///
/// The API key comes from `url.api_key` when present, otherwise from the
/// `GEMINI_API_KEY` environment variable; a missing key surfaces as
/// `ClientError::Llm`.
fn build_client(url: &LlmUrl) -> Result<Gemini, ClientError> {
    let api_key = match &url.api_key {
        Some(key) => key.clone(),
        None => std::env::var("GEMINI_API_KEY")
            .map_err(|_| ClientError::Llm("GEMINI_API_KEY is not set".into()))?,
    };
    // The configured model name is forwarded verbatim as a custom model id.
    Gemini::with_model(&api_key, GeminiModel::Custom(url.model.clone()))
        .map_err(|e| ClientError::Llm(e.to_string()))
}
/// Gemini-backed implementation of the generic [`Client`] trait.
struct GeminiClient {
    /// Underlying `gemini_rust` client, already bound to a concrete model.
    client: Gemini,
    /// Provider-agnostic call options (preamble, tools, schemas, thinking, ...).
    options: ClientOptions,
}
/// Translate provider-agnostic history into Gemini wire messages.
///
/// System entries are skipped here (the preamble is sent separately on the
/// request builder), and a run of consecutive tool results is folded into a
/// single user-role message.
fn build_gemini_messages(history: &[Message]) -> Vec<GeminiMessage> {
    let mut out = Vec::new();
    let mut idx = 0;
    while idx < history.len() {
        let entry = &history[idx];
        match &entry.role {
            // System content travels as the system prompt, not as a turn.
            Role::System => idx += 1,
            Role::User => {
                out.push(GeminiMessage::user(entry.content.clone()));
                idx += 1;
            }
            Role::Assistant => {
                out.push(GeminiMessage::model(&entry.content));
                idx += 1;
            }
            Role::AssistantToolCalls { calls } => {
                out.push(tool_calls_to_message(calls));
                idx += 1;
            }
            Role::Tool { .. } => {
                // Groups this result with any directly following tool results.
                let (grouped, consumed) = tool_responses_to_message(history, idx);
                out.push(grouped);
                idx += consumed;
            }
        }
    }
    out
}
/// Convert tool definitions into a single Gemini `Tool` spec.
///
/// Returns `Ok(None)` when there is nothing to declare; a declaration that
/// fails to serialize aborts with the corresponding error.
fn build_tools_spec(tools: &[ToolDefinition]) -> Result<Option<GeminiTool>, ClientError> {
    if tools.is_empty() {
        return Ok(None);
    }
    let declarations = tools
        .iter()
        .map(build_fn_decl)
        .collect::<Result<Vec<FunctionDeclaration>, _>>()?;
    // Defensive: an empty declaration list would produce an invalid spec.
    if declarations.is_empty() {
        return Ok(None);
    }
    Ok(Some(GeminiTool::with_functions(declarations)))
}
/// Build a model-role message carrying the assistant's function calls.
///
/// Only the first recorded thought signature of each call (if any) is
/// forwarded — the wire format takes at most one signature per call part.
fn tool_calls_to_message(calls: &[ToolCall]) -> GeminiMessage {
    let mut parts = Vec::with_capacity(calls.len());
    for call in calls {
        let signature = call
            .thought_signatures
            .as_ref()
            .and_then(|sigs| sigs.first().cloned());
        parts.push(Part::FunctionCall {
            function_call: GeminiFunctionCall::new(&call.name, call.args.clone()),
            thought_signature: signature,
        });
    }
    GeminiMessage {
        content: Content {
            parts: Some(parts),
            role: Some(GeminiRole::Model),
        },
        role: GeminiRole::Model,
    }
}
/// Fold the run of consecutive tool results beginning at `start` into one
/// user-role Gemini message.
///
/// Returns the message plus the number of history entries consumed (always
/// at least 1 when called on a `Role::Tool` entry).
fn tool_responses_to_message(history: &[Message], start: usize) -> (GeminiMessage, usize) {
    let mut parts = Vec::new();
    for entry in &history[start..] {
        // Stop at the first entry that is not a tool result.
        let Role::Tool { call_id } = &entry.role else {
            break;
        };
        // Gemini wants the function *name* on the response, not the call id.
        let name = resolve_call_name(history, call_id);
        // Non-JSON tool output is wrapped as a plain JSON string.
        let payload: Value = serde_json::from_str(&entry.content)
            .unwrap_or_else(|_| Value::String(entry.content.clone()));
        parts.push(Part::FunctionResponse {
            function_response: GeminiFunctionResponse::new(name, payload),
        });
    }
    let consumed = parts.len();
    let message = GeminiMessage {
        content: Content {
            parts: Some(parts),
            role: Some(GeminiRole::User),
        },
        role: GeminiRole::User,
    };
    (message, consumed)
}
/// Look up the function name behind a tool-call id by scanning the history
/// backwards for the matching assistant tool-call entry.
///
/// Falls back to the call id itself (logging an error) when no match exists.
fn resolve_call_name<'a>(history: &'a [Message], call_id: &'a str) -> &'a str {
    let resolved = history.iter().rev().find_map(|msg| match &msg.role {
        Role::AssistantToolCalls { calls } => calls
            .iter()
            .find(|c| c.id == call_id)
            .map(|c| c.name.as_str()),
        _ => None,
    });
    resolved.unwrap_or_else(|| {
        tracing::error!(
            call_id,
            "could not resolve tool call name from history; using call_id as fallback"
        );
        call_id
    })
}
/// Serialize one tool definition into a Gemini `FunctionDeclaration`,
/// sanitizing the parameter schema down to the subset Gemini accepts.
fn build_fn_decl(tool: &ToolDefinition) -> Result<FunctionDeclaration, ClientError> {
    let parameters = schema::sanitize_strict(tool.parameters.clone());
    serde_json::from_value(serde_json::json!({
        "name": tool.name,
        "description": tool.description,
        "parameters": parameters,
    }))
    .map_err(ClientError::Serialize)
}
/// Convert a raw Gemini response into the provider-agnostic `ClientResponse`.
///
/// When tool use is enabled, a reply without function calls is an error;
/// otherwise the text body is parsed as the structured JSON output.
fn map_response(
    response: GenerationResponse,
    tools_enabled: bool,
) -> Result<ClientResponse, ClientError> {
    let usage = response.usage_metadata.as_ref().map(|meta| TokenUsage {
        input: meta.prompt_token_count.map(|v| v as u32),
        output: meta.candidates_token_count.map(|v| v as u32),
    });
    let provider_model = response.model_version.clone();
    let raw_metadata = Some(serde_json::json!({
        "response_id": response.response_id.clone(),
    }));
    let function_calls = response.function_calls_with_thoughts();
    if !function_calls.is_empty() {
        // Any accompanying text is surfaced as the model's "thought".
        let text = response.text();
        let thought = (!text.is_empty()).then_some(text);
        let calls: Vec<ToolCall> = function_calls
            .iter()
            .enumerate()
            .map(|(idx, (fc, sig))| ToolCall {
                // Gemini does not assign call ids; synthesize a stable one.
                id: format!("{}_{}", fc.name, idx),
                name: fc.name.clone(),
                args: fc.args.clone(),
                thought_signatures: sig.map(|s| vec![s.to_string()]),
            })
            .collect();
        return Ok(ClientResponse::new(
            Provider::Gemini,
            ClientOutput::ToolCalls { thought, calls },
        )
        .with_usage(usage)
        .with_provider_model(provider_model)
        .with_raw_metadata(raw_metadata));
    }
    let text = response.text();
    if tools_enabled {
        // The caller required tool use but the model produced plain text.
        let content = (!text.is_empty()).then_some(text);
        tracing::warn!(model_output = ?content, "LLM response contained no tool calls");
        return Err(ClientError::MissingToolCalls(content));
    }
    if text.is_empty() {
        return Err(ClientError::EmptyResponse);
    }
    Ok(ClientResponse::new(
        Provider::Gemini,
        ClientOutput::Output(parse_json_output(&text)?),
    )
    .with_usage(usage)
    .with_provider_model(provider_model)
    .with_raw_metadata(raw_metadata))
}
impl GeminiClient {
    /// Issue one generateContent request.
    ///
    /// Tool calling and a JSON response schema are mutually exclusive: when
    /// `tools_enabled`, the tool spec and calling mode are attached;
    /// otherwise an optional `response_schema` forces structured JSON output.
    async fn call_api(
        &self,
        messages: Vec<GeminiMessage>,
        tools_enabled: bool,
        response_schema: Option<Value>,
    ) -> Result<GenerationResponse, ClientError> {
        // A zero budget disables thinking; i32::MAX leaves it effectively
        // uncapped when thinking is requested.
        let budget = if self.options.thinking { i32::MAX } else { 0 };
        let mut request = self
            .client
            .generate_content()
            .with_thinking_budget(budget);
        if let Some(preamble) = &self.options.preamble {
            request = request.with_system_prompt(preamble.clone());
        }
        request = request.with_messages(messages);
        if tools_enabled {
            if let Some(spec) = build_tools_spec(&self.options.tools)? {
                // `Required` maps to Any (the model must call a tool);
                // every other choice lets the model decide.
                let mode = if matches!(self.options.tool_choice, ToolChoice::Required) {
                    FunctionCallingMode::Any
                } else {
                    FunctionCallingMode::Auto
                };
                request = request.with_tool(spec).with_function_calling_mode(mode);
            }
        } else if let Some(schema) = response_schema {
            request = request
                .with_response_mime_type("application/json")
                .with_response_schema(schema);
        }
        request
            .execute()
            .await
            .map_err(|e| ClientError::Llm(e.to_string()))
    }
}
#[async_trait]
impl Client for GeminiClient {
    /// Validate the conversation, build the Gemini request, and map the reply.
    async fn execute(&self, messages: &[Message]) -> Result<ClientResponse, ClientError> {
        let Some(last) = messages.last() else {
            return Err(ClientError::Validation("messages must not be empty".into()));
        };
        // A trailing tool-call turn has no results for the model to consume.
        if matches!(last.role, Role::AssistantToolCalls { .. }) {
            return Err(ClientError::Validation(
                "history ends with assistant tool calls without tool results".into(),
            ));
        }
        let tools_enabled =
            !self.options.tools.is_empty() && self.options.tool_choice != ToolChoice::Disabled;
        validate_tools(Provider::Gemini, &self.options.tools)?;
        // A response schema is only meaningful when no tools are in play.
        let response_schema = if tools_enabled {
            None
        } else {
            self.options
                .output_schema
                .as_ref()
                .map(|s| schema::sanitize_strict(s.clone()))
        };
        let request_messages = build_gemini_messages(messages);
        let raw = self
            .call_api(request_messages, tools_enabled, response_schema)
            .await?;
        map_response(raw, tools_enabled)
    }
}
/// Create a boxed Gemini [`Client`] from an endpoint description and options.
pub fn new_client(url: &LlmUrl, options: ClientOptions) -> Result<Box<dyn Client>, ClientError> {
    build_client(url).map(|client| Box::new(GeminiClient { client, options }) as Box<dyn Client>)
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    /// Convenience constructor for a `ToolCall` with empty args.
    fn make_call(id: &str, name: &str) -> ToolCall {
        ToolCall {
            id: id.into(),
            name: name.into(),
            args: json!({}),
            thought_signatures: None,
        }
    }
    #[test]
    fn build_messages_user_only() {
        let history = vec![Message::user(r#"{"text":"hi"}"#)];
        let msgs = build_gemini_messages(&history);
        assert_eq!(msgs.len(), 1);
    }
    // Fixed: this test was a duplicate of `build_messages_user_only` and never
    // exercised system-message handling. It now verifies that a `Role::System`
    // entry is dropped from the wire messages (the preamble is sent separately
    // on the request builder, not as a conversation turn).
    #[test]
    fn build_messages_preamble_is_separate() {
        let history = vec![
            Message {
                role: Role::System,
                content: "you are helpful".into(),
                usage: None,
            },
            Message::user(r#"{"text":"hi"}"#),
        ];
        let msgs = build_gemini_messages(&history);
        // Only the user turn survives; the system entry is excluded.
        assert_eq!(msgs.len(), 1);
    }
    #[test]
    fn build_messages_history_in_order() {
        let history = vec![
            Message::user("prev question"),
            Message::assistant("prev answer"),
            Message::user("next question"),
        ];
        let msgs = build_gemini_messages(&history);
        assert_eq!(msgs.len(), 3);
        let debug = format!("{msgs:?}");
        assert!(debug.contains("prev question"));
        assert!(debug.contains("prev answer"));
    }
    #[test]
    fn build_messages_tool_role_included() {
        let history = vec![
            Message {
                role: Role::AssistantToolCalls {
                    calls: vec![make_call("call-42", "read_file")],
                },
                content: String::new(),
                usage: None,
            },
            Message {
                role: Role::Tool {
                    call_id: "call-42".into(),
                },
                content: r#"{"temp":22}"#.into(),
                usage: None,
            },
        ];
        let msgs = build_gemini_messages(&history);
        assert_eq!(msgs.len(), 2);
        let debug = format!("{msgs:?}");
        assert!(debug.contains("read_file"));
    }
    #[test]
    fn build_messages_continue_after_tool_result() {
        let history = vec![
            Message::user(r#"{"goal":"ship","known_context":[]}"#),
            Message {
                role: Role::AssistantToolCalls {
                    calls: vec![make_call("c1", "project_outline")],
                },
                content: String::new(),
                usage: None,
            },
            Message {
                role: Role::Tool {
                    call_id: "c1".into(),
                },
                content: r#"{"files":[]}"#.into(),
                usage: None,
            },
        ];
        let msgs = build_gemini_messages(&history);
        assert_eq!(msgs.len(), 3);
    }
}