use std::time::Duration;
use async_trait::async_trait;
use futures::StreamExt;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{Value, json};
use crate::constants::MAX_RESPONSE_CHARS;
use crate::models::ModelCapabilities;
use crate::models::config::ModelConfig;
use crate::models::error::{BackendError, ModelError, Result};
use crate::models::reasoning::{
ReasoningCapability, ReasoningChunk, ReasoningLevel, nearest_effort,
};
use crate::models::stream::{StreamCallback, StreamEvent};
use crate::models::tool_call::{FunctionCall, ToolCall};
use crate::models::traits::Model;
use crate::models::types::{ChatMessage, MessageRole, ModelResponse, TokenUsage};
use crate::utils::drain_sse_events;
/// Suffix written once into a buffer that [`push_capped`] had to cut short.
const TRUNCATION_MARKER: &str = "\n\n[TRUNCATED: response exceeded size limit]";

/// Appends `chunk` to `buf`, enforcing a byte-length ceiling of `cap`.
///
/// Once the buffer grows beyond `cap` bytes, everything after the last UTF-8
/// character boundary at or below `cap` is replaced by [`TRUNCATION_MARKER`]
/// and `*truncated` is raised. While `*truncated` is set, calls do nothing, so
/// a single flag can be shared between several buffers to stop them all.
fn push_capped(buf: &mut String, chunk: &str, truncated: &mut bool, cap: usize) {
    if !*truncated {
        buf.push_str(chunk);
        let over_limit = buf.len() > cap;
        if over_limit {
            // Never split a multi-byte character: back off to a valid boundary.
            let keep = buf.floor_char_boundary(cap);
            buf.replace_range(keep.., TRUNCATION_MARKER);
            *truncated = true;
        }
    }
}
/// Translates a reasoning level into a Gemini 2.5 `thinkingBudget` (tokens).
///
/// The two top levels return `-1`, which the budget branch of
/// `GeminiAdapter::build_request_body` forwards unchanged; the Gemini API treats
/// `-1` as "dynamic" thinking. `None` yields `0`; whether that is actually sent
/// depends on the model (see `GeminiThinkingDispatch::Budget`).
fn thinking_budget_for(level: ReasoningLevel) -> i32 {
    match level {
        ReasoningLevel::Max | ReasoningLevel::XHigh => -1,
        ReasoningLevel::High => 24_576,
        ReasoningLevel::Medium => 8_192,
        ReasoningLevel::Low => 2_048,
        ReasoningLevel::Minimal => 512,
        ReasoningLevel::None => 0,
    }
}
/// Translates a reasoning level into a Gemini 3 `thinkingLevel` string.
///
/// The mapping is lossy at both ends: `None` shares `"minimal"` with `Minimal`,
/// and everything from `High` upward collapses to `"high"`.
fn thinking_level_for(level: ReasoningLevel) -> &'static str {
    match level {
        ReasoningLevel::High | ReasoningLevel::Max | ReasoningLevel::XHigh => "high",
        ReasoningLevel::Medium => "medium",
        ReasoningLevel::Low => "low",
        ReasoningLevel::None | ReasoningLevel::Minimal => "minimal",
    }
}
/// Which flavour of `thinkingConfig` a Gemini model family receives.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GeminiThinkingDispatch {
    /// No `thinkingConfig` is sent at all.
    Disabled,
    /// A discrete `thinkingLevel` string is sent.
    Level,
    /// An integer `thinkingBudget` is sent.
    Budget {
        /// Floor applied to non-negative budgets; also used in place of `0`
        /// when the model cannot disable thinking.
        min: i32,
        /// Whether a budget of `0` may be sent for `ReasoningLevel::None`.
        can_disable: bool,
    },
}

/// Chooses the thinking-config flavour from the (case-insensitive) model id.
///
/// Prefixes are tried in order, first match wins; unknown models fall back to
/// [`GeminiThinkingDispatch::Disabled`].
fn gemini_thinking_dispatch(model: &str) -> GeminiThinkingDispatch {
    // Order matters: "gemini-2.5-flash-lite" must be checked before "gemini-2.5-flash".
    const PREFIX_TABLE: [(&str, GeminiThinkingDispatch); 4] = [
        ("gemini-3", GeminiThinkingDispatch::Level),
        (
            "gemini-2.5-pro",
            GeminiThinkingDispatch::Budget {
                min: 128,
                can_disable: false,
            },
        ),
        (
            "gemini-2.5-flash-lite",
            GeminiThinkingDispatch::Budget {
                min: 512,
                can_disable: true,
            },
        ),
        (
            "gemini-2.5-flash",
            GeminiThinkingDispatch::Budget {
                min: 0,
                can_disable: true,
            },
        ),
    ];
    let id = model.to_lowercase();
    PREFIX_TABLE
        .iter()
        .find(|(prefix, _)| id.starts_with(*prefix))
        .map(|&(_, dispatch)| dispatch)
        .unwrap_or(GeminiThinkingDispatch::Disabled)
}
fn to_gemini_tools(openai_tools: &[&Value]) -> Vec<Value> {
let declarations: Vec<Value> = openai_tools
.iter()
.filter_map(|tool| {
let function = tool.get("function")?;
let name = function.get("name")?.as_str()?;
let description = function
.get("description")
.and_then(|d| d.as_str())
.unwrap_or("");
let parameters = function.get("parameters").cloned().unwrap_or(json!({
"type": "object",
"properties": {}
}));
Some(json!({
"name": name,
"description": description,
"parameters": parameters,
}))
})
.collect();
if declarations.is_empty() {
Vec::new()
} else {
vec![json!({"functionDeclarations": declarations})]
}
}
fn convert_messages(messages: &[ChatMessage]) -> (Option<Value>, Vec<Value>) {
let mut system: Option<Value> = None;
let mut out: Vec<Value> = Vec::new();
let mut i = 0;
while i < messages.len() {
let msg = &messages[i];
match msg.role {
MessageRole::System => {
if system.is_none() && !msg.content.is_empty() {
system = Some(json!({
"parts": [{"text": msg.content}],
}));
}
i += 1;
},
MessageRole::User => {
let mut parts: Vec<Value> = Vec::new();
if !msg.content.is_empty() {
parts.push(json!({"text": msg.content}));
}
if let Some(ref images) = msg.images {
for data in images {
parts.push(json!({
"inlineData": {
"mimeType": "image/png",
"data": data,
}
}));
}
}
if parts.is_empty() {
parts.push(json!({"text": ""}));
}
out.push(json!({"role": "user", "parts": parts}));
i += 1;
},
MessageRole::Assistant => {
let mut parts: Vec<Value> = Vec::new();
if let Some(ref thinking) = msg.thinking
&& !thinking.is_empty()
{
parts.push(json!({
"text": thinking,
"thought": true,
}));
}
if !msg.content.is_empty() {
parts.push(json!({"text": msg.content}));
}
if let Some(ref tool_calls) = msg.tool_calls {
for tc in tool_calls {
parts.push(json!({
"functionCall": {
"name": tc.function.name,
"args": tc.function.arguments,
}
}));
}
}
if parts.is_empty() {
i += 1;
continue;
}
out.push(json!({"role": "model", "parts": parts}));
i += 1;
},
MessageRole::Tool => {
let mut parts: Vec<Value> = Vec::new();
while i < messages.len() && messages[i].role == MessageRole::Tool {
let t = &messages[i];
let name = t
.tool_name
.clone()
.unwrap_or_else(|| "unknown_tool".to_string());
parts.push(json!({
"functionResponse": {
"name": name,
"response": {"result": t.content},
}
}));
i += 1;
}
out.push(json!({"role": "user", "parts": parts}));
},
}
}
(system, out)
}
pub struct GeminiAdapter {
client: Client,
api_key: String,
base_url: String,
model_name: String,
capabilities: ModelCapabilities,
}
impl GeminiAdapter {
pub fn new(api_key: String, model_name: String, base_url: String) -> Result<Self> {
let client = Client::builder()
.pool_max_idle_per_host(10)
.pool_idle_timeout(Duration::from_secs(90))
.tcp_keepalive(Duration::from_secs(60))
.connect_timeout(Duration::from_secs(10))
.build()
.map_err(|e| {
ModelError::Backend(BackendError::ConnectionFailed {
backend: "gemini".to_string(),
url: base_url.clone(),
reason: e.to_string(),
})
})?;
let capabilities = ModelCapabilities {
supports_tools: true,
supports_vision: true,
supports_reasoning: ReasoningCapability::Levels(vec![
ReasoningLevel::None,
ReasoningLevel::Minimal,
ReasoningLevel::Low,
ReasoningLevel::Medium,
ReasoningLevel::High,
ReasoningLevel::Max,
ReasoningLevel::XHigh,
]),
max_context_tokens: None,
};
Ok(Self {
client,
api_key,
base_url,
model_name,
capabilities,
})
}
fn build_request_body(&self, messages: &[ChatMessage], config: &ModelConfig) -> Value {
let (system_from_msgs, gemini_contents) = convert_messages(messages);
let system = match (config.combined_system_prompt(), system_from_msgs) {
(Some(s), _) if !s.is_empty() => Some(json!({
"parts": [{"text": s}],
})),
(_, Some(v)) => Some(v),
_ => None,
};
let mut body = json!({
"contents": gemini_contents,
});
if let Some(s) = system {
body["systemInstruction"] = s;
}
let mut gen_config = json!({});
gen_config["temperature"] = json!(config.temperature.clamp(0.0, 2.0));
if config.max_tokens > 0 {
gen_config["maxOutputTokens"] = json!(config.max_tokens);
}
let effective_reasoning = match &self.capabilities.supports_reasoning {
ReasoningCapability::Levels(supported) => {
nearest_effort(config.reasoning, supported).unwrap_or(ReasoningLevel::None)
},
_ => config.reasoning,
};
match gemini_thinking_dispatch(&self.model_name) {
GeminiThinkingDispatch::Disabled => {
},
GeminiThinkingDispatch::Level => {
let level_str = thinking_level_for(effective_reasoning);
gen_config["thinkingConfig"] = json!({
"thinkingLevel": level_str,
"includeThoughts": effective_reasoning != ReasoningLevel::None,
});
},
GeminiThinkingDispatch::Budget { min, can_disable } => {
let raw = thinking_budget_for(effective_reasoning);
let budget = if effective_reasoning == ReasoningLevel::None {
if can_disable { 0 } else { min }
} else if raw < 0 {
-1
} else {
raw.max(min)
};
gen_config["thinkingConfig"] = json!({
"thinkingBudget": budget,
"includeThoughts": budget != 0,
});
},
}
body["generationConfig"] = gen_config;
let no_cloud_key = crate::ollama::get_cloud_api_key().is_none();
let filtered: Vec<&Value> = config
.tools
.iter()
.filter(|t| {
let name = t
.pointer("/function/name")
.and_then(|n| n.as_str())
.unwrap_or("");
!(no_cloud_key && (name == "web_search" || name == "web_fetch"))
})
.collect();
let gemini_tools = to_gemini_tools(&filtered);
if !gemini_tools.is_empty() {
body["tools"] = json!(gemini_tools);
}
body
}
async fn send_chat(&self, body: &Value, stream: bool) -> Result<reqwest::Response> {
let method = if stream {
"streamGenerateContent?alt=sse"
} else {
"generateContent"
};
let url = format!(
"{}/models/{}:{}",
self.base_url.trim_end_matches('/'),
self.model_name,
method
);
crate::effect::retry_transient_http(|| async {
self.client
.post(&url)
.header("x-goog-api-key", &self.api_key)
.header("content-type", "application/json")
.json(body)
.send()
.await
.map_err(|e| {
ModelError::Backend(BackendError::ConnectionFailed {
backend: "gemini".to_string(),
url: url.clone(),
reason: e.to_string(),
})
})
})
.await
}
async fn decode_non_streaming(&self, response: reqwest::Response) -> Result<ModelResponse> {
if !response.status().is_success() {
return Err(http_error_from_response(response).await);
}
let json: GeminiResponse = response.json().await.map_err(|e| ModelError::ParseError {
message: format!("Failed to parse Gemini response: {}", e),
raw: None,
})?;
let mut text_acc = String::new();
let mut thinking_acc = String::new();
let mut tool_calls: Vec<ToolCall> = Vec::new();
if let Some(candidate) = json.candidates.into_iter().next() {
for part in candidate.content.parts {
if let Some(text) = part.text {
if part.thought.unwrap_or(false) {
thinking_acc.push_str(&text);
} else {
text_acc.push_str(&text);
}
} else if let Some(fc) = part.function_call {
let id = format!("call_{}", tool_calls.len());
tool_calls.push(ToolCall {
id: Some(id),
function: FunctionCall {
name: fc.name,
arguments: fc.args,
},
});
}
}
}
let prompt_tokens = json.usage_metadata.prompt_token_count.unwrap_or(0);
let completion_tokens = json.usage_metadata.candidates_token_count.unwrap_or(0);
let reasoning_tokens = json.usage_metadata.thoughts_token_count.unwrap_or(0);
let usage = TokenUsage::provider(
prompt_tokens,
completion_tokens,
json.usage_metadata.total_token_count.unwrap_or_else(|| {
prompt_tokens
.saturating_add(completion_tokens)
.saturating_add(reasoning_tokens)
}),
)
.with_cached_input(json.usage_metadata.cached_content_token_count.unwrap_or(0))
.with_reasoning_output(reasoning_tokens);
Ok(ModelResponse {
content: text_acc,
usage: Some(usage),
model_name: self.model_name.clone(),
thinking: if thinking_acc.is_empty() {
None
} else {
Some(thinking_acc)
},
tool_calls: if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
},
thinking_signature: None,
})
}
async fn handle_stream(
&self,
response: reqwest::Response,
callback: StreamCallback,
hide_reasoning_trace: bool,
) -> Result<ModelResponse> {
if !response.status().is_success() {
return Err(http_error_from_response(response).await);
}
let mut stream = response.bytes_stream();
let mut buf: Vec<u8> = Vec::new();
let mut state = StreamState::default();
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(|e| ModelError::StreamError(e.to_string()))?;
buf.extend_from_slice(&chunk);
for payload in drain_sse_events(&mut buf) {
process_chunk_payload(&payload, &mut state, &callback, hide_reasoning_trace)?;
}
}
let total_tokens = if state.total_tokens > 0 {
state.total_tokens
} else {
state
.prompt_tokens
.saturating_add(state.completion_tokens)
.saturating_add(state.reasoning_output_tokens)
};
Ok(ModelResponse {
content: state.text_acc,
usage: Some(
TokenUsage::provider(state.prompt_tokens, state.completion_tokens, total_tokens)
.with_cached_input(state.cached_input_tokens)
.with_reasoning_output(state.reasoning_output_tokens),
),
model_name: self.model_name.clone(),
thinking: if state.thinking_acc.is_empty() {
None
} else {
Some(state.thinking_acc)
},
tool_calls: if state.tool_calls_done.is_empty() {
None
} else {
Some(state.tool_calls_done)
},
thinking_signature: None,
})
}
}
#[derive(Debug, Default)]
struct StreamState {
text_acc: String,
thinking_acc: String,
tool_calls_done: Vec<ToolCall>,
truncated: bool,
prompt_tokens: usize,
completion_tokens: usize,
cached_input_tokens: usize,
reasoning_output_tokens: usize,
total_tokens: usize,
}
fn process_chunk_payload(
payload: &str,
state: &mut StreamState,
callback: &StreamCallback,
hide_reasoning_trace: bool,
) -> Result<()> {
let parsed: Value = serde_json::from_str(payload).map_err(|e| ModelError::ParseError {
message: format!("Failed to parse Gemini stream chunk: {}", e),
raw: Some(payload.to_string()),
})?;
if let Some(err) = parsed.get("error") {
let code = err
.get("status")
.and_then(|v| v.as_str())
.unwrap_or("UNKNOWN");
let msg = err
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("Gemini stream error");
return Err(ModelError::Backend(BackendError::ProviderError {
provider: "gemini".to_string(),
code: Some(code.to_string()),
message: msg.to_string(),
}));
}
if let Some(usage) = parsed.get("usageMetadata") {
if let Some(p) = usage.get("promptTokenCount").and_then(|v| v.as_u64()) {
state.prompt_tokens = p as usize;
}
if let Some(c) = usage.get("candidatesTokenCount").and_then(|v| v.as_u64()) {
state.completion_tokens = c as usize;
}
if let Some(t) = usage.get("totalTokenCount").and_then(|v| v.as_u64()) {
state.total_tokens = t as usize;
}
if let Some(cached) = usage
.get("cachedContentTokenCount")
.and_then(|v| v.as_u64())
{
state.cached_input_tokens = cached as usize;
}
if let Some(thoughts) = usage.get("thoughtsTokenCount").and_then(|v| v.as_u64()) {
state.reasoning_output_tokens = thoughts as usize;
}
}
let Some(parts_arr) = parsed
.pointer("/candidates/0/content/parts")
.and_then(|v| v.as_array())
else {
return Ok(());
};
for part in parts_arr {
if let Some(fc) = part.get("functionCall") {
let name = fc
.get("name")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let args = fc.get("args").cloned().unwrap_or_else(|| json!({}));
if name.is_empty() {
continue;
}
let id = format!("call_{}", state.tool_calls_done.len());
let tc = ToolCall {
id: Some(id),
function: FunctionCall {
name,
arguments: args,
},
};
callback(StreamEvent::ToolCall(tc.clone()));
state.tool_calls_done.push(tc);
continue;
}
let Some(text) = part.get("text").and_then(|v| v.as_str()) else {
continue;
};
if text.is_empty() || state.truncated {
continue;
}
let is_thought = part
.get("thought")
.and_then(|v| v.as_bool())
.unwrap_or(false);
if is_thought {
if !hide_reasoning_trace {
callback(StreamEvent::Reasoning(ReasoningChunk {
text: text.to_string(),
signature: None,
}));
}
push_capped(
&mut state.thinking_acc,
text,
&mut state.truncated,
MAX_RESPONSE_CHARS,
);
} else {
callback(StreamEvent::Text(text.to_string()));
push_capped(
&mut state.text_acc,
text,
&mut state.truncated,
MAX_RESPONSE_CHARS,
);
}
}
Ok(())
}
#[async_trait]
impl Model for GeminiAdapter {
    /// The configured Gemini model id.
    fn name(&self) -> &str {
        &self.model_name
    }

    /// Static capabilities fixed at construction time.
    fn capabilities(&self) -> &ModelCapabilities {
        &self.capabilities
    }

    /// Model discovery is not implemented for Gemini; always returns
    /// `ModelError::Unsupported`.
    async fn list_models(&self) -> Result<Vec<String>> {
        Err(ModelError::Unsupported {
            feature: "list_models (gemini)".to_string(),
        })
    }

    /// Sends one chat turn. Supplying a callback selects the SSE streaming
    /// endpoint; otherwise the buffered endpoint is used.
    async fn chat(
        &self,
        messages: &[ChatMessage],
        config: &ModelConfig,
        callback: Option<StreamCallback>,
    ) -> Result<ModelResponse> {
        let request = self.build_request_body(messages, config);
        match callback {
            Some(on_event) => {
                let response = self.send_chat(&request, true).await?;
                self.handle_stream(response, on_event, config.hide_reasoning_trace)
                    .await
            },
            None => {
                let response = self.send_chat(&request, false).await?;
                self.decode_non_streaming(response).await
            },
        }
    }
}
#[derive(Debug, Deserialize)]
struct GeminiResponse {
#[serde(default)]
candidates: Vec<Candidate>,
#[serde(default, rename = "usageMetadata")]
usage_metadata: UsageMetadata,
}
#[derive(Debug, Deserialize)]
struct Candidate {
content: CandidateContent,
}
#[derive(Debug, Deserialize)]
struct CandidateContent {
#[serde(default)]
parts: Vec<ResponsePart>,
}
#[derive(Debug, Deserialize)]
struct ResponsePart {
#[serde(default)]
text: Option<String>,
#[serde(default)]
thought: Option<bool>,
#[serde(default, rename = "functionCall")]
function_call: Option<FunctionCallOut>,
}
#[derive(Debug, Deserialize)]
struct FunctionCallOut {
name: String,
#[serde(default)]
args: Value,
}
#[derive(Debug, Default, Deserialize)]
struct UsageMetadata {
#[serde(default, rename = "promptTokenCount")]
prompt_token_count: Option<usize>,
#[serde(default, rename = "candidatesTokenCount")]
candidates_token_count: Option<usize>,
#[serde(default, rename = "cachedContentTokenCount")]
cached_content_token_count: Option<usize>,
#[serde(default, rename = "thoughtsTokenCount")]
thoughts_token_count: Option<usize>,
#[serde(default, rename = "totalTokenCount")]
total_token_count: Option<usize>,
}
/// Converts a non-2xx Gemini HTTP response into a [`ModelError`].
///
/// If the body is a Google API error envelope — `{"error": {...}}`, or the same
/// object wrapped in an array as `[{"error": {...}}]` — the error's `status`
/// becomes the code and its `message` the text. A hint is appended for
/// `PERMISSION_DENIED` and for `INVALID_ARGUMENT` errors mentioning
/// `thinkingBudget`. Any other body is returned verbatim as
/// `BackendError::HttpError` together with the HTTP status.
async fn http_error_from_response(response: reqwest::Response) -> ModelError {
    let status = response.status().as_u16();
    // Reading the body can itself fail; fall back to a placeholder message.
    let body = response
        .text()
        .await
        .unwrap_or_else(|_| "Unknown error".to_string());
    let parsed = serde_json::from_str::<Value>(&body).ok();
    // Accept both the plain envelope and the array-wrapped form.
    let envelope = parsed
        .as_ref()
        .and_then(|v| v.get("error").or_else(|| v.get(0)?.get("error")));
    if let Some(err) = envelope {
        let code = err.get("status").and_then(|v| v.as_str()).map(String::from);
        let msg = err
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(&body)
            .to_string();
        let suffix = if code.as_deref() == Some("PERMISSION_DENIED") {
            " (check that GOOGLE_API_KEY is valid and the Generative Language API is enabled)"
        } else if code.as_deref() == Some("INVALID_ARGUMENT")
            && msg.to_lowercase().contains("thinkingbudget")
        {
            " (thinkingBudget out of range for this model — file an issue at github.com/noahsabaj/mermaid)"
        } else {
            ""
        };
        return ModelError::Backend(BackendError::ProviderError {
            provider: "gemini".to_string(),
            code,
            message: format!("{}{}", msg, suffix),
        });
    }
    ModelError::Backend(BackendError::HttpError {
        status,
        message: body,
    })
}
#[cfg(test)]
mod tests {
use super::*;
use crate::models::tool_call::{FunctionCall, ToolCall};
fn test_adapter() -> GeminiAdapter {
GeminiAdapter::new(
"test-key".to_string(),
"gemini-3-pro".to_string(),
"https://generativelanguage.googleapis.com/v1beta".to_string(),
)
.expect("adapter constructs")
}
#[test]
fn thinking_budget_per_level() {
assert_eq!(thinking_budget_for(ReasoningLevel::None), 0);
assert_eq!(thinking_budget_for(ReasoningLevel::Minimal), 512);
assert_eq!(thinking_budget_for(ReasoningLevel::Low), 2048);
assert_eq!(thinking_budget_for(ReasoningLevel::Medium), 8192);
assert_eq!(thinking_budget_for(ReasoningLevel::High), 24576);
assert_eq!(thinking_budget_for(ReasoningLevel::Max), -1);
}
#[test]
fn tool_translation_groups_into_function_declarations() {
let openai = [
json!({
"type": "function",
"function": {
"name": "read_file",
"description": "Read a file",
"parameters": {"type": "object", "properties": {"path": {"type": "string"}}}
}
}),
json!({
"type": "function",
"function": {
"name": "write_file",
"description": "Write a file",
"parameters": {"type": "object", "properties": {}}
}
}),
];
let refs: Vec<&Value> = openai.iter().collect();
let gemini = to_gemini_tools(&refs);
assert_eq!(gemini.len(), 1);
let decls = gemini[0]["functionDeclarations"].as_array().unwrap();
assert_eq!(decls.len(), 2);
assert_eq!(decls[0]["name"], "read_file");
assert_eq!(decls[0]["description"], "Read a file");
assert!(decls[0].get("parameters").is_some());
assert!(decls[0].get("function").is_none());
assert!(decls[0].get("type").is_none());
}
#[test]
fn tool_translation_handles_missing_description() {
let openai = [json!({
"type": "function",
"function": {
"name": "no_desc",
"parameters": {"type": "object", "properties": {}}
}
})];
let refs: Vec<&Value> = openai.iter().collect();
let gemini = to_gemini_tools(&refs);
let decls = gemini[0]["functionDeclarations"].as_array().unwrap();
assert_eq!(decls[0]["description"], "");
}
#[test]
fn tool_translation_empty_returns_empty() {
let gemini = to_gemini_tools(&[]);
assert!(gemini.is_empty());
}
#[test]
fn convert_messages_extracts_system_first() {
let messages = vec![
ChatMessage::system("You are helpful."),
ChatMessage::user("hi"),
ChatMessage::system("ignored second system"),
];
let (system, contents) = convert_messages(&messages);
let sys = system.expect("system extracted");
assert_eq!(sys["parts"][0]["text"], "You are helpful.");
assert_eq!(contents.len(), 1);
assert_eq!(contents[0]["role"], "user");
}
#[test]
fn convert_messages_renames_assistant_to_model() {
let messages = vec![ChatMessage::user("hi"), ChatMessage::assistant("hello")];
let (_system, contents) = convert_messages(&messages);
assert_eq!(contents[0]["role"], "user");
assert_eq!(contents[1]["role"], "model");
assert_eq!(contents[1]["parts"][0]["text"], "hello");
}
#[test]
fn convert_messages_merges_consecutive_tool_messages() {
let messages = vec![
ChatMessage::user("read two files"),
ChatMessage::tool("call_0", "read_file", "contents of A"),
ChatMessage::tool("call_1", "read_file", "contents of B"),
ChatMessage::user("now compare them"),
];
let (_, contents) = convert_messages(&messages);
assert_eq!(contents.len(), 3);
assert_eq!(contents[1]["role"], "user");
let parts = contents[1]["parts"].as_array().unwrap();
assert_eq!(parts.len(), 2, "two functionResponse parts");
assert_eq!(parts[0]["functionResponse"]["name"], "read_file");
assert_eq!(
parts[0]["functionResponse"]["response"]["result"],
"contents of A"
);
assert_eq!(
parts[1]["functionResponse"]["response"]["result"],
"contents of B"
);
}
#[test]
fn convert_messages_emits_function_call_part_for_assistant_tool_call() {
let mut msg = ChatMessage::assistant("");
msg.tool_calls = Some(vec![ToolCall {
id: Some("call_0".to_string()),
function: FunctionCall {
name: "read_file".to_string(),
arguments: json!({"path": "Cargo.toml"}),
},
}]);
let messages = vec![ChatMessage::user("read it"), msg];
let (_, contents) = convert_messages(&messages);
assert_eq!(contents[1]["role"], "model");
let parts = contents[1]["parts"].as_array().unwrap();
assert_eq!(parts.len(), 1);
let fc = &parts[0]["functionCall"];
assert_eq!(fc["name"], "read_file");
assert_eq!(fc["args"]["path"], "Cargo.toml");
}
#[test]
fn convert_messages_emits_inline_data_for_user_images() {
let msg = ChatMessage::user("look at this").with_images(vec!["base64data".to_string()]);
let (_, contents) = convert_messages(&[msg]);
let parts = contents[0]["parts"].as_array().unwrap();
assert_eq!(parts.len(), 2);
assert_eq!(parts[0]["text"], "look at this");
assert_eq!(parts[1]["inlineData"]["mimeType"], "image/png");
assert_eq!(parts[1]["inlineData"]["data"], "base64data");
}
#[test]
fn convert_messages_emits_thought_part_for_assistant_thinking() {
let mut msg = ChatMessage::assistant("the answer is 42");
msg.thinking = Some("step 1: think hard".to_string());
let messages = vec![ChatMessage::user("compute"), msg];
let (_, contents) = convert_messages(&messages);
let parts = contents[1]["parts"].as_array().unwrap();
assert_eq!(parts.len(), 2);
assert_eq!(parts[0]["text"], "step 1: think hard");
assert_eq!(parts[0]["thought"], true);
assert_eq!(parts[1]["text"], "the answer is 42");
assert!(parts[1].get("thought").is_none());
}
#[test]
fn capabilities_advertise_full_reasoning_levels_and_vision() {
let adapter = test_adapter();
let caps = adapter.capabilities();
assert!(caps.supports_tools);
assert!(caps.supports_vision);
match &caps.supports_reasoning {
ReasoningCapability::Levels(levels) => {
assert!(levels.contains(&ReasoningLevel::None));
assert!(levels.contains(&ReasoningLevel::Minimal));
assert!(levels.contains(&ReasoningLevel::Max));
},
other => panic!("expected Levels, got {:?}", other),
}
}
#[test]
fn name_returns_model_id() {
let adapter = test_adapter();
assert_eq!(adapter.name(), "gemini-3-pro");
}
#[test]
fn build_request_body_includes_required_fields() {
let adapter = test_adapter();
let messages = vec![ChatMessage::user("hi")];
let config = ModelConfig::default();
let body = adapter.build_request_body(&messages, &config);
assert!(body.get("model").is_none());
assert!(body["contents"].is_array());
let contents = body["contents"].as_array().unwrap();
assert_eq!(contents[0]["role"], "user");
assert_eq!(contents[0]["parts"][0]["text"], "hi");
assert!(body["generationConfig"].is_object());
}
#[test]
fn build_request_body_wraps_system_in_content_object() {
let adapter = test_adapter();
let messages = vec![ChatMessage::user("hi")];
let config = ModelConfig {
system_prompt: Some("You are helpful.".to_string()),
..Default::default()
};
let body = adapter.build_request_body(&messages, &config);
let sys = &body["systemInstruction"];
assert!(sys.is_object());
assert_eq!(sys["parts"][0]["text"], "You are helpful.");
}
#[test]
fn build_request_body_concats_dynamic_suffix_to_system_instruction() {
let adapter = test_adapter();
let messages = vec![ChatMessage::user("hi")];
let config = ModelConfig {
system_prompt: Some("You are Mermaid.".to_string()),
dynamic_system_suffix: Some("Project rule: always snake_case.".to_string()),
..Default::default()
};
let body = adapter.build_request_body(&messages, &config);
let text = body["systemInstruction"]["parts"][0]["text"]
.as_str()
.expect("systemInstruction text");
assert!(text.contains("You are Mermaid."));
assert!(text.contains("Project rule: always snake_case."));
assert!(text.contains("---"));
}
#[test]
fn build_request_body_thinking_level_for_medium_on_gemini_3() {
let adapter = test_adapter(); let messages = vec![ChatMessage::user("hi")];
let config = ModelConfig {
reasoning: ReasoningLevel::Medium,
..Default::default()
};
let body = adapter.build_request_body(&messages, &config);
let tc = &body["generationConfig"]["thinkingConfig"];
assert_eq!(tc["thinkingLevel"], "medium");
assert_eq!(tc["includeThoughts"], true);
assert!(tc.get("thinkingBudget").is_none());
}
#[test]
fn build_request_body_thinking_level_for_max_collapses_to_high_on_gemini_3() {
let adapter = test_adapter(); let messages = vec![ChatMessage::user("hi")];
let config = ModelConfig {
reasoning: ReasoningLevel::Max,
..Default::default()
};
let body = adapter.build_request_body(&messages, &config);
assert_eq!(
body["generationConfig"]["thinkingConfig"]["thinkingLevel"],
"high"
);
}
#[test]
fn build_request_body_thinking_level_minimal_for_none_on_gemini_3() {
let adapter = test_adapter(); let messages = vec![ChatMessage::user("hi")];
let config = ModelConfig {
reasoning: ReasoningLevel::None,
..Default::default()
};
let body = adapter.build_request_body(&messages, &config);
let tc = &body["generationConfig"]["thinkingConfig"];
assert_eq!(tc["thinkingLevel"], "minimal");
assert_eq!(tc["includeThoughts"], false);
}
#[test]
fn build_request_body_thinking_budget_clamps_to_min_128_on_gemini_2_5_pro_for_none() {
let adapter = GeminiAdapter::new(
"test-key".to_string(),
"gemini-2.5-pro".to_string(),
"https://generativelanguage.googleapis.com/v1beta".to_string(),
)
.expect("adapter constructs");
let messages = vec![ChatMessage::user("hi")];
let config = ModelConfig {
reasoning: ReasoningLevel::None,
..Default::default()
};
let body = adapter.build_request_body(&messages, &config);
let tc = &body["generationConfig"]["thinkingConfig"];
assert_eq!(tc["thinkingBudget"], 128);
assert_eq!(tc["includeThoughts"], true);
}
#[test]
fn build_request_body_thinking_budget_zero_for_none_on_gemini_2_5_flash() {
let adapter = GeminiAdapter::new(
"test-key".to_string(),
"gemini-2.5-flash".to_string(),
"https://generativelanguage.googleapis.com/v1beta".to_string(),
)
.expect("adapter constructs");
let messages = vec![ChatMessage::user("hi")];
let config = ModelConfig {
reasoning: ReasoningLevel::None,
..Default::default()
};
let body = adapter.build_request_body(&messages, &config);
let tc = &body["generationConfig"]["thinkingConfig"];
assert_eq!(tc["thinkingBudget"], 0);
assert_eq!(tc["includeThoughts"], false);
}
#[test]
fn build_request_body_thinking_budget_adaptive_for_max_on_gemini_2_5_flash() {
let adapter = GeminiAdapter::new(
"test-key".to_string(),
"gemini-2.5-flash".to_string(),
"https://generativelanguage.googleapis.com/v1beta".to_string(),
)
.expect("adapter constructs");
let messages = vec![ChatMessage::user("hi")];
let config = ModelConfig {
reasoning: ReasoningLevel::Max,
..Default::default()
};
let body = adapter.build_request_body(&messages, &config);
assert_eq!(
body["generationConfig"]["thinkingConfig"]["thinkingBudget"],
-1
);
}
#[test]
fn build_request_body_omits_thinking_config_on_gemini_2_0() {
let adapter = GeminiAdapter::new(
"test-key".to_string(),
"gemini-2.0-flash".to_string(),
"https://generativelanguage.googleapis.com/v1beta".to_string(),
)
.expect("adapter constructs");
let messages = vec![ChatMessage::user("hi")];
let config = ModelConfig {
reasoning: ReasoningLevel::Medium,
..Default::default()
};
let body = adapter.build_request_body(&messages, &config);
assert!(
body["generationConfig"].get("thinkingConfig").is_none(),
"legacy Gemini models must NOT receive thinkingConfig"
);
}
#[test]
fn dispatch_is_level_for_gemini_3_models() {
assert_eq!(
gemini_thinking_dispatch("gemini-3-pro"),
GeminiThinkingDispatch::Level
);
assert_eq!(
gemini_thinking_dispatch("gemini-3-flash"),
GeminiThinkingDispatch::Level
);
assert_eq!(
gemini_thinking_dispatch("gemini-3-flash-lite"),
GeminiThinkingDispatch::Level
);
}
#[test]
fn dispatch_is_budget_with_min_128_no_disable_for_gemini_2_5_pro() {
assert_eq!(
gemini_thinking_dispatch("gemini-2.5-pro"),
GeminiThinkingDispatch::Budget {
min: 128,
can_disable: false
}
);
}
#[test]
fn dispatch_is_budget_with_min_512_can_disable_for_gemini_2_5_flash_lite() {
assert_eq!(
gemini_thinking_dispatch("gemini-2.5-flash-lite"),
GeminiThinkingDispatch::Budget {
min: 512,
can_disable: true
}
);
}
#[test]
fn dispatch_is_budget_with_min_0_can_disable_for_gemini_2_5_flash() {
assert_eq!(
gemini_thinking_dispatch("gemini-2.5-flash"),
GeminiThinkingDispatch::Budget {
min: 0,
can_disable: true
}
);
}
#[test]
fn dispatch_is_disabled_for_legacy_gemini_models() {
assert_eq!(
gemini_thinking_dispatch("gemini-2.0-flash"),
GeminiThinkingDispatch::Disabled
);
assert_eq!(
gemini_thinking_dispatch("gemini-1.5-pro"),
GeminiThinkingDispatch::Disabled
);
}
#[test]
fn thinking_level_per_reasoning_level() {
assert_eq!(thinking_level_for(ReasoningLevel::None), "minimal");
assert_eq!(thinking_level_for(ReasoningLevel::Minimal), "minimal");
assert_eq!(thinking_level_for(ReasoningLevel::Low), "low");
assert_eq!(thinking_level_for(ReasoningLevel::Medium), "medium");
assert_eq!(thinking_level_for(ReasoningLevel::High), "high");
assert_eq!(thinking_level_for(ReasoningLevel::Max), "high");
assert_eq!(thinking_level_for(ReasoningLevel::XHigh), "high");
}
#[test]
fn thinking_budget_for_xhigh_matches_max_adaptive_sentinel() {
assert_eq!(thinking_budget_for(ReasoningLevel::Max), -1);
assert_eq!(thinking_budget_for(ReasoningLevel::XHigh), -1);
}
#[test]
fn build_request_body_includes_tools_in_function_declarations_shape() {
let adapter = test_adapter();
let messages = vec![ChatMessage::user("hi")];
let config = ModelConfig {
tools: (0..5)
.map(|i| {
serde_json::json!({
"type": "function",
"function": {
"name": format!("tool_{}", i),
"description": "a test tool",
"parameters": {"type": "object"}
}
})
})
.collect(),
..Default::default()
};
let body = adapter.build_request_body(&messages, &config);
let tools = body["tools"].as_array().expect("tools array");
assert!(!tools.is_empty());
assert!(tools[0]["functionDeclarations"].is_array());
let decls = tools[0]["functionDeclarations"].as_array().unwrap();
assert_eq!(decls.len(), 5);
}
#[test]
fn build_request_body_clamps_temperature() {
let adapter = test_adapter();
let messages = vec![ChatMessage::user("hi")];
let config = ModelConfig {
temperature: 5.0, ..Default::default()
};
let body = adapter.build_request_body(&messages, &config);
let temp = body["generationConfig"]["temperature"].as_f64().unwrap();
assert!(temp <= 2.0);
}
use std::sync::Arc;
use std::sync::Mutex;
fn record_callback() -> (StreamCallback, Arc<Mutex<Vec<StreamEvent>>>) {
let events: Arc<Mutex<Vec<StreamEvent>>> = Arc::new(Mutex::new(Vec::new()));
let clone = Arc::clone(&events);
let cb: StreamCallback = Arc::new(move |evt| {
clone.lock().unwrap().push(evt);
});
(cb, events)
}
fn count_text(events: &[StreamEvent]) -> usize {
events
.iter()
.filter(|e| matches!(e, StreamEvent::Text(_)))
.count()
}
fn count_reasoning(events: &[StreamEvent]) -> usize {
events
.iter()
.filter(|e| matches!(e, StreamEvent::Reasoning(_)))
.count()
}
fn count_tool_calls(events: &[StreamEvent]) -> usize {
events
.iter()
.filter(|e| matches!(e, StreamEvent::ToolCall(_)))
.count()
}
/// Two plain-text frames are concatenated into `text_acc`, the trailing
/// `usageMetadata` populates the token counters, and one `Text` event is emitted
/// per frame.
#[test]
fn stream_text_only_multi_chunk() {
    let (callback, recorded) = record_callback();
    let mut state = StreamState::default();
    let payloads = [
        json!({
            "candidates": [{"content": {"parts": [{"text": "Hello, "}]}}]
        }),
        json!({
            "candidates": [{"content": {"parts": [{"text": "world!"}]}}],
            "usageMetadata": {
                "promptTokenCount": 5,
                "candidatesTokenCount": 3,
                "totalTokenCount": 8
            }
        }),
    ];
    for payload in &payloads {
        process_chunk_payload(&payload.to_string(), &mut state, &callback, false).unwrap();
    }

    assert_eq!(state.text_acc, "Hello, world!");
    assert_eq!(
        (state.prompt_tokens, state.completion_tokens, state.total_tokens),
        (5, 3, 8)
    );

    let seen = recorded.lock().unwrap();
    assert_eq!(count_text(&seen), 2);
    assert_eq!(count_reasoning(&seen), 0);
    assert_eq!(count_tool_calls(&seen), 0);
}
/// A part flagged `"thought": true` goes to `thinking_acc` (one `Reasoning`
/// event), and a following ordinary part goes to `text_acc` (one `Text` event).
#[test]
fn stream_thought_then_text() {
    let (callback, recorded) = record_callback();
    let mut state = StreamState::default();
    let thought_frame = json!({
        "candidates": [{"content": {"parts": [{"text": "let me think...", "thought": true}]}}]
    });
    let answer_frame = json!({
        "candidates": [{"content": {"parts": [{"text": "the answer is 42"}]}}]
    });
    for frame in [thought_frame, answer_frame] {
        process_chunk_payload(&frame.to_string(), &mut state, &callback, false).unwrap();
    }

    assert_eq!(state.thinking_acc, "let me think...");
    assert_eq!(state.text_acc, "the answer is 42");

    let seen = recorded.lock().unwrap();
    assert_eq!(count_reasoning(&seen), 1);
    assert_eq!(count_text(&seen), 1);
}
/// A `functionCall` part becomes a completed tool call with a synthesized
/// `call_0` id and its `args` preserved, and one `ToolCall` event is emitted.
#[test]
fn stream_function_call_emits_tool_call_event() {
    let (callback, recorded) = record_callback();
    let mut state = StreamState::default();
    let call_part = json!({"functionCall": {"name": "read_file", "args": {"path": "Cargo.toml"}}});
    let payload = json!({"candidates": [{"content": {"parts": [call_part]}}]}).to_string();

    process_chunk_payload(&payload, &mut state, &callback, false).unwrap();

    assert_eq!(state.tool_calls_done.len(), 1);
    let call = &state.tool_calls_done[0];
    assert_eq!(call.id.as_deref(), Some("call_0"));
    assert_eq!(call.function.name, "read_file");
    assert_eq!(call.function.arguments["path"], "Cargo.toml");

    assert_eq!(count_tool_calls(&recorded.lock().unwrap()), 1);
}
/// One frame that mixes a thought part, a visible text part and a
/// `functionCall` part routes each to its own accumulator and emits exactly one
/// event of each kind.
#[test]
fn stream_thought_text_and_tool_call_in_one_chunk() {
    let (callback, recorded) = record_callback();
    let mut state = StreamState::default();
    let mixed_parts = json!([
        {"text": "thinking...", "thought": true},
        {"text": "calling tool now"},
        {"functionCall": {"name": "list_dir", "args": {"path": "."}}}
    ]);
    let payload = json!({"candidates": [{"content": {"parts": mixed_parts}}]}).to_string();

    process_chunk_payload(&payload, &mut state, &callback, false).unwrap();

    assert_eq!(state.thinking_acc, "thinking...");
    assert_eq!(state.text_acc, "calling tool now");
    assert_eq!(state.tool_calls_done.len(), 1);

    let seen = recorded.lock().unwrap();
    let tallies = (
        count_reasoning(&seen),
        count_text(&seen),
        count_tool_calls(&seen),
    );
    assert_eq!(tallies, (1, 1, 1));
}
/// When `process_chunk_payload` is called with its final flag set to `true`,
/// thought text is still accumulated into `thinking_acc` but no `Reasoning`
/// event reaches the callback.
#[test]
fn stream_hide_reasoning_trace_suppresses_event_but_accumulates() {
    let (callback, recorded) = record_callback();
    let mut state = StreamState::default();
    let hide_reasoning = true;
    let payload = json!({
        "candidates": [{"content": {"parts": [{"text": "hidden thoughts", "thought": true}]}}]
    })
    .to_string();

    process_chunk_payload(&payload, &mut state, &callback, hide_reasoning).unwrap();

    assert_eq!(state.thinking_acc, "hidden thoughts");
    assert_eq!(count_reasoning(&recorded.lock().unwrap()), 0);
}
/// A frame that carries only an `error` object is surfaced as a
/// `BackendError::ProviderError` whose `code` is the error's `status` string.
#[test]
fn stream_mid_stream_error_returns_error() {
    let (callback, _recorded) = record_callback();
    let mut state = StreamState::default();
    let payload = json!({
        "error": {
            "code": 429,
            "message": "Resource exhausted",
            "status": "RESOURCE_EXHAUSTED"
        }
    })
    .to_string();

    let outcome = process_chunk_payload(&payload, &mut state, &callback, false);
    assert!(outcome.is_err());

    // Borrow the result so the fallback branch can still print it.
    if let Err(ModelError::Backend(BackendError::ProviderError { code, message, .. })) = &outcome {
        assert_eq!(code.as_deref(), Some("RESOURCE_EXHAUSTED"));
        assert!(message.contains("Resource exhausted"));
    } else {
        panic!("expected ProviderError, got {:?}", outcome);
    }
}
/// Two `functionCall` parts in one frame receive sequential synthesized ids
/// (`call_0`, `call_1`) in part order.
///
/// Besides the ids, this test checks two more things. Each id must stay paired
/// with the call it was minted for (`tool_a` / `tool_b`). The callback must also
/// observe one `ToolCall` event per call; previously the recorded events were
/// discarded, so a regression that dropped those events went unnoticed.
#[test]
fn stream_tool_call_ids_are_synthesized_in_sequence() {
    let (cb, events) = record_callback();
    let mut state = StreamState::default();
    let chunk = json!({
        "candidates": [{
            "content": {
                "parts": [
                    {"functionCall": {"name": "tool_a", "args": {}}},
                    {"functionCall": {"name": "tool_b", "args": {}}}
                ]
            }
        }]
    })
    .to_string();
    process_chunk_payload(&chunk, &mut state, &cb, false).unwrap();

    assert_eq!(state.tool_calls_done.len(), 2);
    assert_eq!(state.tool_calls_done[0].id.as_deref(), Some("call_0"));
    assert_eq!(state.tool_calls_done[0].function.name, "tool_a");
    assert_eq!(state.tool_calls_done[1].id.as_deref(), Some("call_1"));
    assert_eq!(state.tool_calls_done[1].function.name, "tool_b");

    let evts = events.lock().unwrap();
    assert_eq!(count_tool_calls(&evts), 2);
}
}