use crate::error::LlmError;
use crate::types::{LlmClient, LlmDoneOutcome, LlmEvent, LlmRequest};
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use meerkat_core::schema::{CompiledSchema, SchemaCompat, SchemaError, SchemaWarning};
use meerkat_core::{Message, OutputSchema, Provider, StopReason, Usage};
use serde::Deserialize;
use serde_json::{Value, json};
use std::collections::HashMap;
use std::pin::Pin;
/// Client for Google's Gemini generative-language HTTP API.
///
/// Holds the credentials and a reusable HTTP client; streaming requests are
/// issued through the [`LlmClient`] implementation below.
pub struct GeminiClient {
    // API key sent with every request (currently as a `key=` query parameter).
    api_key: String,
    // Base endpoint, e.g. "https://generativelanguage.googleapis.com".
    // Overridable for tests and proxies via `with_base_url`.
    base_url: String,
    // Reused connection pool; rebuilt when the base URL changes.
    http: reqwest::Client,
}
impl GeminiClient {
    /// Creates a client against the production Gemini endpoint.
    pub fn new(api_key: String) -> Self {
        Self::new_with_base_url(
            api_key,
            "https://generativelanguage.googleapis.com".to_string(),
        )
    }

    /// Creates a client against an arbitrary base URL (tests, proxies).
    ///
    /// If the URL-specific HTTP client cannot be built, this silently falls
    /// back to a default `reqwest::Client`.
    pub fn new_with_base_url(api_key: String, base_url: String) -> Self {
        let http =
            crate::http::build_http_client_for_base_url(reqwest::Client::builder(), &base_url)
                .unwrap_or_else(|_| reqwest::Client::new());
        Self {
            api_key,
            base_url,
            http,
        }
    }

    /// Builder-style override of the base URL.
    ///
    /// NOTE(review): unlike `new_with_base_url`, a failed client build here
    /// keeps the previously configured HTTP client instead of falling back to
    /// a default one — presumably intentional, but worth confirming.
    pub fn with_base_url(mut self, url: String) -> Self {
        if let Ok(http) =
            crate::http::build_http_client_for_base_url(reqwest::Client::builder(), &url)
        {
            self.http = http;
        }
        self.base_url = url;
        self
    }

    /// Builds a client from the environment, checking `RKAT_GEMINI_API_KEY`,
    /// then `GEMINI_API_KEY`, then `GOOGLE_API_KEY`.
    ///
    /// # Errors
    /// Returns [`LlmError::InvalidApiKey`] when none of the variables is set.
    pub fn from_env() -> Result<Self, LlmError> {
        let api_key = std::env::var("RKAT_GEMINI_API_KEY")
            .or_else(|_| {
                std::env::var("GEMINI_API_KEY").or_else(|_| std::env::var("GOOGLE_API_KEY"))
            })
            .map_err(|_| LlmError::InvalidApiKey)?;
        Ok(Self::new(api_key))
    }

    /// Translates an [`LlmRequest`] into the JSON body expected by Gemini's
    /// `generateContent` / `streamGenerateContent` endpoints.
    ///
    /// # Errors
    /// Returns [`LlmError::InvalidRequest`] for legacy `Message::Assistant`
    /// messages or for an invalid/unsupported `structured_output` schema.
    fn build_request_body(&self, request: &LlmRequest) -> Result<Value, LlmError> {
        let mut contents = Vec::new();
        let mut system_instruction = None;
        // Gemini's functionResponse parts are keyed by function *name*, not by
        // call id, so remember the id -> name mapping from earlier tool calls.
        let mut tool_name_by_id: HashMap<String, String> = HashMap::new();
        for msg in &request.messages {
            match msg {
                Message::System(s) => {
                    // Last system message wins; Gemini takes a single
                    // `systemInstruction` outside of `contents`.
                    system_instruction = Some(serde_json::json!({
                        "parts": [{"text": s.content}]
                    }));
                }
                Message::User(u) => {
                    contents.push(serde_json::json!({
                        "role": "user",
                        "parts": [{"text": u.content}]
                    }));
                }
                Message::Assistant(_) => {
                    return Err(LlmError::InvalidRequest {
                        message: "Legacy Message::Assistant is not supported by Gemini adapter; use BlockAssistant".to_string(),
                    });
                }
                Message::BlockAssistant(a) => {
                    let mut parts = Vec::new();
                    for block in &a.blocks {
                        match block {
                            meerkat_core::AssistantBlock::Text { text, meta } => {
                                if !text.is_empty() {
                                    let mut part = serde_json::json!({"text": text});
                                    // Round-trip the Gemini thought signature so
                                    // reasoning continuity is preserved server-side.
                                    if let Some(meerkat_core::ProviderMeta::Gemini {
                                        thought_signature,
                                    }) = meta.as_deref()
                                    {
                                        part["thoughtSignature"] =
                                            serde_json::json!(thought_signature);
                                    }
                                    parts.push(part);
                                }
                            }
                            meerkat_core::AssistantBlock::Reasoning { text, .. } => {
                                // Gemini has no replayable reasoning part; flatten
                                // it into visible text so context is not lost.
                                if !text.is_empty() {
                                    parts.push(serde_json::json!({"text": format!("[Reasoning: {}]", text)}));
                                }
                            }
                            meerkat_core::AssistantBlock::ToolUse {
                                id,
                                name,
                                args,
                                meta,
                            } => {
                                tool_name_by_id.insert(id.clone(), name.clone());
                                // `args` is raw JSON; fall back to an empty object
                                // rather than failing the whole request.
                                let args_value: Value = serde_json::from_str(args.get())
                                    .unwrap_or_else(|_| serde_json::json!({}));
                                let mut part = serde_json::json!({"functionCall": {"name": name, "args": args_value}});
                                if let Some(meerkat_core::ProviderMeta::Gemini {
                                    thought_signature,
                                }) = meta.as_deref()
                                {
                                    part["thoughtSignature"] = serde_json::json!(thought_signature);
                                }
                                parts.push(part);
                            }
                            // Other block kinds have no Gemini representation.
                            _ => {}
                        }
                    }
                    contents.push(serde_json::json!({
                        "role": "model",
                        "parts": parts
                    }));
                }
                Message::ToolResults { results } => {
                    let parts: Vec<Value> = results
                        .iter()
                        .map(|r| {
                            // Resolve the function name recorded for this call id;
                            // if unknown, use the id itself as a best effort.
                            let function_name = tool_name_by_id
                                .get(&r.tool_use_id)
                                .cloned()
                                .unwrap_or_else(|| r.tool_use_id.clone());
                            serde_json::json!({
                                "functionResponse": {
                                    "name": function_name,
                                    "response": {
                                        "content": r.content,
                                        "error": r.is_error
                                    }
                                }
                            })
                        })
                        .collect();
                    // Tool results are delivered back as a "user" turn.
                    contents.push(serde_json::json!({
                        "role": "user",
                        "parts": parts
                    }));
                }
            }
        }
        let mut body = serde_json::json!({
            "contents": contents,
            "generationConfig": {
                "maxOutputTokens": request.max_tokens,
            }
        });
        if let Some(system) = system_instruction {
            body["systemInstruction"] = system;
        }
        if let Some(temp) = request.temperature {
            // from_f64 returns None for NaN/inf; silently skip in that case.
            if let Some(num) = serde_json::Number::from_f64(temp as f64) {
                body["generationConfig"]["temperature"] = Value::Number(num);
            }
        }
        if let Some(ref params) = request.provider_params {
            // Accept either a flat "thinking_budget" or the nested
            // "thinking.thinking_budget" form.
            let thinking_budget = params.get("thinking_budget").or_else(|| {
                params
                    .get("thinking")
                    .and_then(|t| t.get("thinking_budget"))
            });
            if let Some(budget) = thinking_budget {
                body["generationConfig"]["thinkingConfig"] = serde_json::json!({
                    "thinkingBudget": budget
                });
            }
            if let Some(top_k) = params.get("top_k") {
                body["generationConfig"]["topK"] = top_k.clone();
            }
            if let Some(top_p) = params.get("top_p") {
                body["generationConfig"]["topP"] = top_p.clone();
            }
            if let Some(structured) = params.get("structured_output") {
                let output_schema: OutputSchema = serde_json::from_value(structured.clone())
                    .map_err(|e| LlmError::InvalidRequest {
                        message: format!("Invalid structured_output schema: {e}"),
                    })?;
                // Strict schemas fail here if sanitization had to drop features.
                let compiled = Self::compile_schema_for_gemini(&output_schema).map_err(|e| {
                    LlmError::InvalidRequest {
                        message: e.to_string(),
                    }
                })?;
                body["generationConfig"]["responseMimeType"] =
                    Value::String("application/json".to_string());
                body["generationConfig"]["responseSchema"] = compiled.schema;
            }
        }
        if !request.tools.is_empty() {
            let function_declarations: Vec<Value> = request
                .tools
                .iter()
                .map(|t| {
                    serde_json::json!({
                        "name": t.name,
                        "description": t.description,
                        "parameters": Self::sanitize_schema_for_gemini(&t.input_schema)
                    })
                })
                .collect();
            body["tools"] = serde_json::json!([{
                "functionDeclarations": function_declarations
            }]);
        }
        Ok(body)
    }

    /// Recursively strips JSON Schema keywords Gemini's function-declaration
    /// endpoint rejects, and collapses array `type`s (e.g. `["string","null"]`)
    /// to their first non-null entry.
    ///
    /// NOTE(review): this duplicates the logic of `sanitize_gemini_value`
    /// below, minus warning collection — consider unifying.
    fn sanitize_schema_for_gemini(schema: &Value) -> Value {
        match schema {
            Value::Object(map) => {
                let mut sanitized = serde_json::Map::new();
                for (key, value) in map {
                    if key == "$defs"
                        || key == "$ref"
                        || key == "$schema"
                        || key == "additionalProperties"
                        || key == "oneOf"
                        || key == "anyOf"
                        || key == "allOf"
                    {
                        continue;
                    }
                    if key == "type" {
                        if let Value::Array(types) = value {
                            let primary_type = types
                                .iter()
                                .find(|t| t.as_str() != Some("null"))
                                .cloned()
                                .unwrap_or_else(|| Value::String("string".to_string()));
                            sanitized.insert(key.clone(), primary_type);
                            continue;
                        }
                    }
                    sanitized.insert(key.clone(), Self::sanitize_schema_for_gemini(value));
                }
                Value::Object(sanitized)
            }
            Value::Array(arr) => {
                Value::Array(arr.iter().map(Self::sanitize_schema_for_gemini).collect())
            }
            other => other.clone(),
        }
    }

    /// Parses one SSE `data:` payload into a response chunk; malformed or
    /// empty lines yield `None` and are skipped by the stream loop.
    fn parse_stream_line(line: &str) -> Option<GenerateContentResponse> {
        serde_json::from_str(line).ok()
    }

    /// Compiles an [`OutputSchema`] for Gemini structured output.
    ///
    /// # Errors
    /// In [`SchemaCompat::Strict`] mode, any sanitization warning is promoted
    /// to [`SchemaError::UnsupportedFeatures`].
    fn compile_schema_for_gemini(
        output_schema: &OutputSchema,
    ) -> Result<CompiledSchema, SchemaError> {
        let (schema, warnings) =
            sanitize_for_gemini(output_schema.schema.as_value(), Provider::Gemini);
        if output_schema.compat == SchemaCompat::Strict && !warnings.is_empty() {
            return Err(SchemaError::UnsupportedFeatures {
                provider: Provider::Gemini,
                warnings,
            });
        }
        Ok(CompiledSchema { schema, warnings })
    }
}
/// Sanitizes a JSON schema for Gemini, returning the cleaned schema together
/// with one warning per feature that had to be dropped or rewritten.
fn sanitize_for_gemini(schema: &Value, provider: Provider) -> (Value, Vec<SchemaWarning>) {
    let mut collected = Vec::new();
    let cleaned = sanitize_gemini_value(schema, provider, "", &mut collected);
    (cleaned, collected)
}
/// Recursive worker for [`sanitize_for_gemini`].
///
/// Walks the schema tree, removing keywords Gemini does not accept and
/// collapsing array `type`s to a single non-null type, recording a
/// [`SchemaWarning`] (with a JSON-pointer-style path) for each change.
fn sanitize_gemini_value(
    value: &Value,
    provider: Provider,
    path: &str,
    warnings: &mut Vec<SchemaWarning>,
) -> Value {
    match value {
        Value::Object(map) => {
            let mut cleaned = serde_json::Map::new();
            for (key, child) in map {
                let child_path = join_path(path, key);
                // Drop keywords Gemini rejects outright, noting the loss.
                if is_gemini_unsupported_key(key) {
                    warnings.push(SchemaWarning {
                        provider,
                        path: child_path,
                        message: format!("Removed unsupported keyword '{key}'"),
                    });
                    continue;
                }
                // `"type": ["integer", "null"]` -> `"type": "integer"`.
                if key == "type" {
                    if let Value::Array(candidates) = child {
                        let collapsed = candidates
                            .iter()
                            .find(|t| t.as_str() != Some("null"))
                            .cloned()
                            .unwrap_or_else(|| Value::String("string".to_string()));
                        warnings.push(SchemaWarning {
                            provider,
                            path: child_path,
                            message: "Collapsed array type to a single type; nullable/union semantics may be lost for Gemini".to_string(),
                        });
                        cleaned.insert(key.clone(), collapsed);
                        continue;
                    }
                }
                cleaned.insert(
                    key.clone(),
                    sanitize_gemini_value(child, provider, &child_path, warnings),
                );
            }
            Value::Object(cleaned)
        }
        Value::Array(items) => {
            let mut cleaned = Vec::with_capacity(items.len());
            for (idx, item) in items.iter().enumerate() {
                let child_path = join_index(path, idx);
                cleaned.push(sanitize_gemini_value(item, provider, &child_path, warnings));
            }
            Value::Array(cleaned)
        }
        scalar => scalar.clone(),
    }
}
/// Returns true for JSON Schema keywords that Gemini's schema dialect rejects.
fn is_gemini_unsupported_key(key: &str) -> bool {
    const UNSUPPORTED: [&str; 7] = [
        "$defs",
        "$ref",
        "$schema",
        "additionalProperties",
        "oneOf",
        "anyOf",
        "allOf",
    ];
    UNSUPPORTED.contains(&key)
}
/// Appends an object key to a JSON-pointer-style path ("" + "a" -> "/a").
fn join_path(prefix: &str, key: &str) -> String {
    match prefix {
        "" => format!("/{key}"),
        _ => format!("{prefix}/{key}"),
    }
}
/// Appends an array index to a JSON-pointer-style path ("" + 0 -> "/0").
fn join_index(prefix: &str, index: usize) -> String {
    match prefix {
        "" => format!("/{index}"),
        _ => format!("{prefix}/{index}"),
    }
}
#[async_trait]
impl LlmClient for GeminiClient {
    /// Streams a Gemini `streamGenerateContent` response as [`LlmEvent`]s.
    ///
    /// Events are emitted in arrival order: usage updates, text deltas,
    /// complete tool calls, and finally `Done` when a finish reason arrives.
    /// The stream is wrapped by `ensure_terminal_done` so callers always see
    /// a terminal event.
    fn stream<'a>(
        &'a self,
        request: &'a LlmRequest,
    ) -> Pin<Box<dyn Stream<Item = Result<LlmEvent, LlmError>> + Send + 'a>> {
        let inner: Pin<Box<dyn Stream<Item = Result<LlmEvent, LlmError>> + Send + 'a>> = Box::pin(
            async_stream::try_stream! {
                let body = self.build_request_body(request)?;
                // NOTE(review): the API key is embedded in the URL query
                // string, where it can leak into logs/proxies; Gemini also
                // accepts an `x-goog-api-key` header — consider switching.
                let url = format!(
                    "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
                    self.base_url, request.model, self.api_key
                );
                let response = self.http
                    .post(url)
                    .header("Content-Type", "application/json")
                    .json(&body)
                    .send()
                    .await
                    // Any send error is reported as a timeout with a nominal
                    // 30s duration (the actual elapsed time is not measured).
                    .map_err(|_| LlmError::NetworkTimeout {
                        duration_ms: 30000,
                    })?;
                let status_code = response.status().as_u16();
                let stream_result = if (200..=299).contains(&status_code) {
                    Ok(response.bytes_stream())
                } else {
                    // Non-2xx: consume the body for the error message.
                    let text = response.text().await.unwrap_or_default();
                    Err(LlmError::from_http_status(status_code, text))
                };
                let mut stream = stream_result?;
                // Accumulates partial SSE lines across network chunks.
                let mut buffer = String::with_capacity(512);
                // Gemini does not supply tool-call ids; synthesize unique
                // "fc_N" ids in arrival order.
                let mut tool_call_index: u32 = 0;
                while let Some(chunk) = stream.next().await {
                    let chunk = chunk.map_err(|_| LlmError::ConnectionReset)?;
                    buffer.push_str(&String::from_utf8_lossy(&chunk));
                    // Process every complete line currently buffered.
                    while let Some(newline_pos) = buffer.find('\n') {
                        let line = buffer[..newline_pos].trim();
                        // Only SSE "data: " lines carry JSON payloads.
                        let data = line.strip_prefix("data: ");
                        let parsed_response = if let Some(d) = data {
                            Self::parse_stream_line(d)
                        } else {
                            None
                        };
                        buffer.drain(..=newline_pos);
                        if let Some(resp) = parsed_response {
                            if let Some(usage) = resp.usage_metadata {
                                yield LlmEvent::UsageUpdate {
                                    usage: Usage {
                                        input_tokens: usage.prompt_token_count.unwrap_or(0),
                                        output_tokens: usage.candidates_token_count.unwrap_or(0),
                                        // Gemini does not report cache token counts here.
                                        cache_creation_tokens: None,
                                        cache_read_tokens: None,
                                    }
                                };
                            }
                            if let Some(candidates) = resp.candidates {
                                for cand in candidates {
                                    if let Some(content) = cand.content {
                                        if let Some(parts) = content.parts {
                                            for part in parts {
                                                // Carry the thought signature through
                                                // as provider metadata on each event.
                                                let meta = part.thought_signature.as_ref().map(|sig| {
                                                    Box::new(meerkat_core::ProviderMeta::Gemini {
                                                        thought_signature: sig.clone(),
                                                    })
                                                });
                                                if let Some(text) = part.text {
                                                    yield LlmEvent::TextDelta { delta: text, meta: meta.clone() };
                                                }
                                                if let Some(fc) = part.function_call {
                                                    let id = format!("fc_{}", tool_call_index);
                                                    tool_call_index += 1;
                                                    yield LlmEvent::ToolCallComplete {
                                                        id,
                                                        name: fc.name,
                                                        args: fc.args.unwrap_or(json!({})),
                                                        meta,
                                                    };
                                                }
                                            }
                                        }
                                    }
                                    if let Some(reason) = cand.finish_reason {
                                        // Map Gemini finish reasons onto the
                                        // provider-neutral StopReason enum;
                                        // unknown reasons default to EndTurn.
                                        let stop = match reason.as_str() {
                                            "STOP" => StopReason::EndTurn,
                                            "MAX_TOKENS" => StopReason::MaxTokens,
                                            "SAFETY" | "RECITATION" => StopReason::ContentFilter,
                                            "TOOL_CALL" | "FUNCTION_CALL" => StopReason::ToolUse,
                                            _ => StopReason::EndTurn,
                                        };
                                        yield LlmEvent::Done {
                                            outcome: LlmDoneOutcome::Success { stop_reason: stop },
                                        };
                                    }
                                }
                            }
                        }
                    }
                }
            },
        );
        // Guarantee a terminal Done event even if the stream ends abruptly.
        crate::streaming::ensure_terminal_done(inner)
    }

    /// Provider identifier used for routing/telemetry.
    fn provider(&self) -> &'static str {
        "gemini"
    }

    /// No-op health check: no probe request is issued.
    async fn health_check(&self) -> Result<(), LlmError> {
        Ok(())
    }

    /// Delegates to the Gemini-specific schema compiler.
    fn compile_schema(&self, output_schema: &OutputSchema) -> Result<CompiledSchema, SchemaError> {
        GeminiClient::compile_schema_for_gemini(output_schema)
    }
}
/// One chunk of a `streamGenerateContent` SSE response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GenerateContentResponse {
    // Absent on usage-only chunks.
    candidates: Option<Vec<Candidate>>,
    // Token accounting; may arrive on any chunk.
    usage_metadata: Option<GeminiUsage>,
}
/// A single response candidate within a stream chunk.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct Candidate {
    content: Option<CandidateContent>,
    // e.g. "STOP", "MAX_TOKENS", "SAFETY" — present only on the final chunk.
    finish_reason: Option<String>,
}
/// Content payload of a candidate: an ordered list of parts.
#[derive(Debug, Deserialize)]
struct CandidateContent {
    parts: Option<Vec<Part>>,
}
/// A single content part: text and/or a function call, optionally tagged
/// with a Gemini thought signature for reasoning continuity.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct Part {
    text: Option<String>,
    function_call: Option<FunctionCall>,
    thought_signature: Option<String>,
}
/// A model-initiated tool invocation.
#[derive(Debug, Deserialize)]
struct FunctionCall {
    name: String,
    // Arbitrary JSON arguments; may be omitted by the API.
    args: Option<Value>,
}
/// Token usage as reported by Gemini's `usageMetadata` object.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiUsage {
    prompt_token_count: Option<u64>,
    candidates_token_count: Option<u64>,
}
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::explicit_counter_loop
)]
// Unit tests for the Gemini adapter: request-body construction (provider
// params, structured output, tool calls), SSE line parsing, finish-reason
// mapping, tool-call id generation, schema sanitization, and thought-signature
// round-tripping.
mod tests {
    use super::*;
    use meerkat_core::{AssistantBlock, BlockAssistantMessage, ProviderMeta, UserMessage};

    // --- provider params -> generationConfig ---

    #[test]
    fn test_build_request_body_with_thinking_budget() -> Result<(), Box<dyn std::error::Error>> {
        let client = GeminiClient::new("test-key".to_string());
        let request = LlmRequest::new(
            "gemini-1.5-pro",
            vec![Message::User(UserMessage {
                content: "test".to_string(),
            })],
        )
        .with_provider_param("thinking_budget", 10000);
        let body = client.build_request_body(&request)?;
        let generation_config = body.get("generationConfig").ok_or("missing config")?;
        let thinking_config = generation_config
            .get("thinkingConfig")
            .ok_or("missing thinking")?;
        let thinking_budget = thinking_config
            .get("thinkingBudget")
            .ok_or("missing budget")?;
        assert_eq!(thinking_budget.as_i64(), Some(10000));
        Ok(())
    }

    #[test]
    fn test_build_request_body_with_top_k() -> Result<(), Box<dyn std::error::Error>> {
        let client = GeminiClient::new("test-key".to_string());
        let request = LlmRequest::new(
            "gemini-1.5-pro",
            vec![Message::User(UserMessage {
                content: "test".to_string(),
            })],
        )
        .with_provider_param("top_k", 40);
        let body = client.build_request_body(&request)?;
        let generation_config = body.get("generationConfig").ok_or("missing config")?;
        let top_k = generation_config.get("topK").ok_or("missing top_k")?;
        assert_eq!(top_k.as_i64(), Some(40));
        Ok(())
    }

    #[test]
    fn test_build_request_body_with_multiple_provider_params()
    -> Result<(), Box<dyn std::error::Error>> {
        let client = GeminiClient::new("test-key".to_string());
        let request = LlmRequest::new(
            "gemini-1.5-pro",
            vec![Message::User(UserMessage {
                content: "test".to_string(),
            })],
        )
        .with_provider_param("top_k", 50)
        .with_provider_param("thinking_budget", 5000);
        let body = client.build_request_body(&request)?;
        let generation_config = body.get("generationConfig").ok_or("missing config")?;
        let top_k = generation_config.get("topK").ok_or("missing top_k")?;
        assert_eq!(top_k.as_i64(), Some(50));
        let thinking_config = generation_config
            .get("thinkingConfig")
            .ok_or("missing thinking")?;
        let thinking_budget = thinking_config
            .get("thinkingBudget")
            .ok_or("missing budget")?;
        assert_eq!(thinking_budget.as_i64(), Some(5000));
        Ok(())
    }

    #[test]
    fn test_build_request_body_no_provider_params() -> Result<(), Box<dyn std::error::Error>> {
        let client = GeminiClient::new("test-key".to_string());
        let request = LlmRequest::new(
            "gemini-1.5-pro",
            vec![Message::User(UserMessage {
                content: "test".to_string(),
            })],
        );
        let body = client.build_request_body(&request)?;
        let generation_config = body.get("generationConfig").ok_or("missing config")?;
        assert!(generation_config.get("thinkingConfig").is_none());
        assert!(generation_config.get("topK").is_none());
        Ok(())
    }

    // --- tool-call round-trips: functionCall / functionResponse ---

    #[test]
    fn test_tool_response_uses_function_name_no_signature() -> Result<(), Box<dyn std::error::Error>>
    {
        use serde_json::value::RawValue;
        let client = GeminiClient::new("test-key".to_string());
        let args_raw = RawValue::from_string(json!({"city": "Tokyo"}).to_string()).unwrap();
        let request = LlmRequest::new(
            "gemini-1.5-pro",
            vec![
                Message::User(UserMessage {
                    content: "test".to_string(),
                }),
                Message::BlockAssistant(BlockAssistantMessage {
                    blocks: vec![AssistantBlock::ToolUse {
                        id: "call_1".to_string(),
                        name: "get_weather".to_string(),
                        args: args_raw,
                        meta: Some(Box::new(ProviderMeta::Gemini {
                            thought_signature: "sig_123".to_string(),
                        })),
                    }],
                    stop_reason: StopReason::ToolUse,
                }),
                Message::ToolResults {
                    results: vec![meerkat_core::ToolResult::new(
                        "call_1".to_string(),
                        "Sunny".to_string(),
                        false,
                    )],
                },
            ],
        );
        let body = client.build_request_body(&request)?;
        let contents = body
            .get("contents")
            .and_then(|c| c.as_array())
            .ok_or("missing contents")?;
        let model_content = contents
            .iter()
            .find(|c| c.get("role").and_then(|r| r.as_str()) == Some("model"))
            .ok_or("missing model content")?;
        let model_parts = model_content
            .get("parts")
            .and_then(|p| p.as_array())
            .ok_or("missing model parts")?;
        let fc_part = model_parts
            .iter()
            .find(|p| p.get("functionCall").is_some())
            .ok_or("missing functionCall part")?;
        assert_eq!(
            fc_part["thoughtSignature"], "sig_123",
            "functionCall SHOULD have signature"
        );
        let tool_result_parts = contents
            .last()
            .and_then(|c| c.get("parts"))
            .and_then(|p| p.as_array())
            .ok_or("missing parts")?;
        let function_response = &tool_result_parts[0]["functionResponse"];
        assert_eq!(function_response["name"], "get_weather");
        assert!(
            tool_result_parts[0].get("thoughtSignature").is_none(),
            "functionResponse MUST NOT have thoughtSignature"
        );
        Ok(())
    }

    // --- SSE line parsing ---

    #[test]
    fn test_parse_stream_line_valid_response() -> Result<(), Box<dyn std::error::Error>> {
        let line =
            r#"{"candidates":[{"content":{"parts":[{"text":"Hello"}]},"finishReason":"STOP"}]}"#;
        let response = GeminiClient::parse_stream_line(line);
        assert!(response.is_some());
        let response = response.ok_or("missing response")?;
        assert!(response.candidates.is_some());
        let candidates = response.candidates.ok_or("missing candidates")?;
        assert_eq!(candidates.len(), 1);
        Ok(())
    }

    #[test]
    fn test_parse_stream_line_with_usage() -> Result<(), Box<dyn std::error::Error>> {
        let line = r#"{"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}"#;
        let response = GeminiClient::parse_stream_line(line);
        assert!(response.is_some());
        let response = response.ok_or("missing response")?;
        assert!(response.usage_metadata.is_some());
        let usage = response.usage_metadata.ok_or("missing usage")?;
        assert_eq!(usage.prompt_token_count, Some(10));
        Ok(())
    }

    #[test]
    fn test_parse_stream_line_function_call() -> Result<(), Box<dyn std::error::Error>> {
        let line = r#"{"candidates":[{"content":{"parts":[{"functionCall":{"name":"get_weather","args":{"city":"Tokyo"}}}]}}]}"#;
        let response = GeminiClient::parse_stream_line(line);
        assert!(response.is_some());
        let response = response.ok_or("missing response")?;
        let candidates = response.candidates.as_ref().ok_or("missing candidates")?;
        let parts = candidates[0]
            .content
            .as_ref()
            .ok_or("missing content")?
            .parts
            .as_ref()
            .ok_or("missing parts")?;
        let fc = parts[0].function_call.as_ref().ok_or("missing fc")?;
        assert_eq!(fc.name, "get_weather");
        assert_eq!(fc.args.as_ref().ok_or("missing args")?["city"], "Tokyo");
        Ok(())
    }

    #[test]
    fn test_parse_stream_line_empty() {
        let line = "";
        let response = GeminiClient::parse_stream_line(line);
        assert!(response.is_none());
    }

    #[test]
    fn test_parse_stream_line_invalid_json() {
        let line = "{invalid}";
        let response = GeminiClient::parse_stream_line(line);
        assert!(response.is_none());
    }

    // --- finish-reason mapping (mirrors the match in `stream`) ---

    #[test]
    fn test_regression_gemini_finish_reason_tool_call_maps_to_tool_use() {
        let finish_reasons = ["TOOL_CALL", "FUNCTION_CALL"];
        for reason in finish_reasons {
            let stop = match reason {
                "STOP" => StopReason::EndTurn,
                "MAX_TOKENS" => StopReason::MaxTokens,
                "SAFETY" | "RECITATION" => StopReason::ContentFilter,
                "TOOL_CALL" | "FUNCTION_CALL" => StopReason::ToolUse,
                _ => StopReason::EndTurn,
            };
            assert_eq!(
                stop,
                StopReason::ToolUse,
                "finish_reason '{}' should map to ToolUse",
                reason
            );
        }
    }

    #[test]
    fn test_regression_gemini_finish_reason_recitation_maps_to_content_filter() {
        let reason = "RECITATION";
        let stop = match reason {
            "STOP" => StopReason::EndTurn,
            "MAX_TOKENS" => StopReason::MaxTokens,
            "SAFETY" | "RECITATION" => StopReason::ContentFilter,
            "TOOL_CALL" | "FUNCTION_CALL" => StopReason::ToolUse,
            _ => StopReason::EndTurn,
        };
        assert_eq!(stop, StopReason::ContentFilter);
    }

    // --- synthesized tool-call id uniqueness (mirrors the loop in `stream`) ---

    #[test]
    fn test_regression_gemini_tool_call_ids_must_be_unique() {
        let mut tool_call_index: u32 = 0;
        let tool_names = ["search", "search", "search"];
        let mut generated_ids = Vec::new();
        for _name in tool_names {
            let id = format!("fc_{}", tool_call_index);
            tool_call_index += 1;
            generated_ids.push(id);
        }
        assert_eq!(generated_ids[0], "fc_0");
        assert_eq!(generated_ids[1], "fc_1");
        assert_eq!(generated_ids[2], "fc_2");
        let mut seen = std::collections::HashSet::new();
        for id in &generated_ids {
            assert!(
                seen.insert(id.clone()),
                "Duplicate tool call ID found: {}",
                id
            );
        }
    }

    #[test]
    fn test_regression_gemini_tool_call_ids_unique_across_different_tools() {
        let mut tool_call_index: u32 = 0;
        let tool_names = ["search", "write_file", "search", "read_file"];
        let mut id_to_name = Vec::new();
        for name in tool_names {
            let id = format!("fc_{}", tool_call_index);
            tool_call_index += 1;
            id_to_name.push((id, name));
        }
        assert_eq!(id_to_name[0], ("fc_0".to_string(), "search"));
        assert_eq!(id_to_name[1], ("fc_1".to_string(), "write_file"));
        assert_eq!(id_to_name[2], ("fc_2".to_string(), "search"));
        assert_eq!(id_to_name[3], ("fc_3".to_string(), "read_file"));
    }

    // --- structured output / responseSchema ---

    #[test]
    fn test_build_request_body_with_structured_output() -> Result<(), Box<dyn std::error::Error>> {
        let client = GeminiClient::new("test-key".to_string());
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"}
            },
            "required": ["name", "age"]
        });
        let request = LlmRequest::new(
            "gemini-3-pro-preview",
            vec![Message::User(UserMessage {
                content: "test".to_string(),
            })],
        )
        .with_provider_param(
            "structured_output",
            serde_json::json!({
                "schema": schema,
                "name": "person",
                "strict": true
            }),
        );
        let body = client.build_request_body(&request)?;
        let gen_config = body
            .get("generationConfig")
            .ok_or("missing generationConfig")?;
        assert_eq!(gen_config["responseMimeType"], "application/json");
        assert!(gen_config.get("responseSchema").is_some());
        let response_schema = &gen_config["responseSchema"];
        assert!(response_schema.get("$defs").is_none());
        assert!(response_schema.get("$ref").is_none());
        Ok(())
    }

    #[test]
    fn test_build_request_body_with_structured_output_sanitizes_schema()
    -> Result<(), Box<dyn std::error::Error>> {
        let client = GeminiClient::new("test-key".to_string());
        let schema = serde_json::json!({
            "type": "object",
            "$defs": {
                "Address": {"type": "object"}
            },
            "$ref": "#/$defs/Address",
            "properties": {
                "name": {"type": "string"}
            },
            "additionalProperties": false
        });
        let request = LlmRequest::new(
            "gemini-3-pro-preview",
            vec![Message::User(UserMessage {
                content: "test".to_string(),
            })],
        )
        .with_provider_param("structured_output", serde_json::json!({"schema": schema}));
        let body = client.build_request_body(&request)?;
        let gen_config = body
            .get("generationConfig")
            .ok_or("missing generationConfig")?;
        let response_schema = &gen_config["responseSchema"];
        assert!(
            response_schema.get("$defs").is_none(),
            "$defs should be removed"
        );
        assert!(
            response_schema.get("$ref").is_none(),
            "$ref should be removed"
        );
        assert!(
            response_schema.get("additionalProperties").is_none(),
            "additionalProperties should be removed"
        );
        assert_eq!(response_schema["type"], "object");
        assert!(response_schema.get("properties").is_some());
        Ok(())
    }

    #[test]
    fn test_build_request_body_without_structured_output() -> Result<(), Box<dyn std::error::Error>>
    {
        let client = GeminiClient::new("test-key".to_string());
        let request = LlmRequest::new(
            "gemini-3-pro-preview",
            vec![Message::User(UserMessage {
                content: "test".to_string(),
            })],
        );
        let body = client.build_request_body(&request)?;
        let gen_config = body
            .get("generationConfig")
            .ok_or("missing generationConfig")?;
        assert!(
            gen_config.get("responseMimeType").is_none(),
            "responseMimeType should not be present"
        );
        assert!(
            gen_config.get("responseSchema").is_none(),
            "responseSchema should not be present"
        );
        Ok(())
    }

    // --- tool-schema sanitization ---

    #[test]
    fn test_sanitize_schema_converts_array_type_to_string() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": ["integer", "null"]},
                "email": {"type": ["string", "null"]}
            }
        });
        let sanitized = GeminiClient::sanitize_schema_for_gemini(&schema);
        assert_eq!(
            sanitized["properties"]["age"]["type"], "integer",
            "['integer', 'null'] should become 'integer'"
        );
        assert_eq!(
            sanitized["properties"]["email"]["type"], "string",
            "['string', 'null'] should become 'string'"
        );
        assert_eq!(
            sanitized["properties"]["name"]["type"], "string",
            "'string' should remain 'string'"
        );
    }

    #[test]
    fn test_sanitize_schema_removes_oneof_anyof_allof() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "status": {
                    "oneOf": [
                        {"const": "active"},
                        {"const": "inactive"}
                    ]
                },
                "value": {
                    "anyOf": [
                        {"type": "string"},
                        {"type": "number"}
                    ]
                }
            },
            "allOf": [
                {"required": ["status"]}
            ]
        });
        let sanitized = GeminiClient::sanitize_schema_for_gemini(&schema);
        assert!(
            sanitized["properties"]["status"].get("oneOf").is_none(),
            "oneOf should be removed"
        );
        assert!(
            sanitized["properties"]["value"].get("anyOf").is_none(),
            "anyOf should be removed"
        );
        assert!(sanitized.get("allOf").is_none(), "allOf should be removed");
    }

    // --- thought-signature parsing and propagation ---

    #[test]
    fn test_parse_function_call_with_thought_signature() -> Result<(), Box<dyn std::error::Error>> {
        let line = r#"{"candidates":[{"content":{"parts":[{"functionCall":{"name":"get_weather","args":{"city":"Tokyo"}},"thoughtSignature":"sig_abc123"}]}}]}"#;
        let response = GeminiClient::parse_stream_line(line).ok_or("missing response")?;
        let candidates = response.candidates.as_ref().ok_or("missing candidates")?;
        let parts = candidates[0]
            .content
            .as_ref()
            .ok_or("missing content")?
            .parts
            .as_ref()
            .ok_or("missing parts")?;
        assert!(
            parts[0].function_call.is_some(),
            "should have function_call"
        );
        assert_eq!(
            parts[0].thought_signature.as_deref(),
            Some("sig_abc123"),
            "should have thoughtSignature"
        );
        Ok(())
    }

    #[test]
    fn test_parse_text_with_thought_signature() -> Result<(), Box<dyn std::error::Error>> {
        let line = r#"{"candidates":[{"content":{"parts":[{"text":"Hello world","thoughtSignature":"sig_text_456"}]}}]}"#;
        let response = GeminiClient::parse_stream_line(line).ok_or("missing response")?;
        let candidates = response.candidates.as_ref().ok_or("missing candidates")?;
        let parts = candidates[0]
            .content
            .as_ref()
            .ok_or("missing content")?
            .parts
            .as_ref()
            .ok_or("missing parts")?;
        assert_eq!(parts[0].text.as_deref(), Some("Hello world"));
        assert_eq!(
            parts[0].thought_signature.as_deref(),
            Some("sig_text_456"),
            "text parts can have thoughtSignature for continuity"
        );
        Ok(())
    }

    #[test]
    fn test_parallel_calls_only_first_has_signature() -> Result<(), Box<dyn std::error::Error>> {
        let line = r#"{"candidates":[{"content":{"parts":[
{"functionCall":{"name":"get_weather","args":{"city":"Tokyo"}},"thoughtSignature":"sig_first"},
{"functionCall":{"name":"get_time","args":{"tz":"JST"}}},
{"functionCall":{"name":"get_population","args":{"city":"Tokyo"}}}
]}}]}"#;
        let response = GeminiClient::parse_stream_line(line).ok_or("missing response")?;
        let candidates = response.candidates.ok_or("missing candidates")?;
        let parts = candidates[0]
            .content
            .as_ref()
            .ok_or("missing content")?
            .parts
            .as_ref()
            .ok_or("missing parts")?;
        assert_eq!(parts.len(), 3);
        assert_eq!(
            parts[0].thought_signature.as_deref(),
            Some("sig_first"),
            "first parallel call MUST have signature"
        );
        assert!(
            parts[1].thought_signature.is_none(),
            "second parallel call must NOT have signature"
        );
        assert!(
            parts[2].thought_signature.is_none(),
            "third parallel call must NOT have signature"
        );
        Ok(())
    }

    #[test]
    fn test_request_building_no_signature_on_function_response()
    -> Result<(), Box<dyn std::error::Error>> {
        use serde_json::value::RawValue;
        let client = GeminiClient::new("test-key".to_string());
        let args_raw = RawValue::from_string(json!({"city": "Tokyo"}).to_string()).unwrap();
        let request = LlmRequest::new(
            "gemini-3-pro-preview",
            vec![
                Message::User(UserMessage {
                    content: "What's the weather?".to_string(),
                }),
                Message::BlockAssistant(BlockAssistantMessage {
                    blocks: vec![AssistantBlock::ToolUse {
                        id: "call_1".to_string(),
                        name: "get_weather".to_string(),
                        args: args_raw,
                        meta: Some(Box::new(ProviderMeta::Gemini {
                            thought_signature: "sig_123".to_string(),
                        })),
                    }],
                    stop_reason: StopReason::ToolUse,
                }),
                Message::ToolResults {
                    results: vec![meerkat_core::ToolResult::new(
                        "call_1".to_string(),
                        "Sunny, 25C".to_string(),
                        false,
                    )],
                },
            ],
        );
        let body = client.build_request_body(&request)?;
        let contents = body
            .get("contents")
            .and_then(|c| c.as_array())
            .ok_or("missing contents")?;
        let assistant_content = contents
            .iter()
            .find(|c| c.get("role").and_then(|r| r.as_str()) == Some("model"))
            .ok_or("missing model content")?;
        let assistant_parts = assistant_content
            .get("parts")
            .and_then(|p| p.as_array())
            .ok_or("missing parts")?;
        let fc_part = assistant_parts
            .iter()
            .find(|p| p.get("functionCall").is_some())
            .ok_or("missing functionCall part")?;
        assert!(
            fc_part.get("thoughtSignature").is_some(),
            "functionCall part SHOULD have thoughtSignature"
        );
        let tool_results_content = contents.last().ok_or("missing last content")?;
        let tool_result_parts = tool_results_content
            .get("parts")
            .and_then(|p| p.as_array())
            .ok_or("missing tool result parts")?;
        let fr_part = tool_result_parts
            .iter()
            .find(|p| p.get("functionResponse").is_some())
            .ok_or("missing functionResponse part")?;
        assert!(
            fr_part.get("thoughtSignature").is_none(),
            "functionResponse MUST NOT have thoughtSignature"
        );
        Ok(())
    }

    // --- LlmEvent metadata plumbing ---

    #[test]
    fn test_tool_call_complete_uses_provider_meta() {
        use meerkat_core::ProviderMeta;
        let meta = Some(Box::new(ProviderMeta::Gemini {
            thought_signature: "sig_test".to_string(),
        }));
        let event = LlmEvent::ToolCallComplete {
            id: "fc_0".to_string(),
            name: "test_tool".to_string(),
            args: json!({}),
            meta,
        };
        if let LlmEvent::ToolCallComplete { meta: m, .. } = event {
            assert!(m.is_some(), "meta should be Some");
            match *m.unwrap() {
                ProviderMeta::Gemini { thought_signature } => {
                    assert_eq!(thought_signature, "sig_test");
                }
                _ => panic!("expected Gemini variant"),
            }
        }
    }

    #[test]
    fn test_text_delta_uses_provider_meta() {
        use meerkat_core::ProviderMeta;
        let meta = Some(Box::new(ProviderMeta::Gemini {
            thought_signature: "sig_text".to_string(),
        }));
        let event = LlmEvent::TextDelta {
            delta: "Hello".to_string(),
            meta,
        };
        if let LlmEvent::TextDelta { meta: m, .. } = event {
            assert!(m.is_some());
            match *m.unwrap() {
                ProviderMeta::Gemini { thought_signature } => {
                    assert_eq!(thought_signature, "sig_text");
                }
                _ => panic!("expected Gemini variant"),
            }
        }
    }
}