use crate::llm::attachments::decode_attachment_bytes;
use crate::llm::{Content, ContentBlock, StopReason, StreamBox, StreamDelta, Usage};
use base64::Engine;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ApiGenerateContentRequest<'a> {
pub contents: &'a [ApiContent],
#[serde(skip_serializing_if = "Option::is_none")]
pub system_instruction: Option<&'a ApiContent>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<&'a [ApiToolConfig]>,
#[serde(skip_serializing_if = "Option::is_none")]
pub generation_config: Option<ApiGenerationConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cached_content: Option<&'a str>,
}
#[derive(Serialize, Deserialize)]
pub struct ApiContent {
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(default)]
pub parts: Vec<ApiPart>,
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ApiPart {
Text {
text: String,
#[serde(rename = "thoughtSignature", skip_serializing_if = "Option::is_none")]
thought_signature: Option<String>,
},
InlineData {
#[serde(rename = "inlineData")]
inline_data: ApiBlob,
},
FunctionCall {
#[serde(rename = "functionCall")]
function_call: ApiFunctionCall,
#[serde(rename = "thoughtSignature", skip_serializing_if = "Option::is_none")]
thought_signature: Option<String>,
},
FunctionResponse {
#[serde(rename = "functionResponse")]
function_response: ApiFunctionResponse,
},
Unknown(serde_json::Value),
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct ApiBlob {
pub mime_type: String,
pub data: String,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct ApiFunctionCall {
pub name: String,
pub args: serde_json::Value,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct ApiFunctionResponse {
pub name: String,
pub response: serde_json::Value,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ApiToolConfig {
pub function_declarations: Vec<ApiFunctionDeclaration>,
}
#[derive(Serialize)]
pub struct ApiFunctionDeclaration {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ApiGenerationConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking_config: Option<ApiThinkingConfig>,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ApiThinkingConfig {
pub thinking_level: &'static str,
}
pub const fn map_thinking_config(config: &crate::llm::ThinkingConfig) -> ApiThinkingConfig {
use crate::llm::{Effort, ThinkingMode};
if let Some(effort) = config.effort {
let level = match effort {
Effort::Low => "LOW",
Effort::Medium => "MEDIUM",
Effort::High | Effort::Max => "HIGH",
};
return ApiThinkingConfig {
thinking_level: level,
};
}
let level = match &config.mode {
ThinkingMode::Adaptive => "HIGH",
ThinkingMode::Enabled { budget_tokens } => {
if *budget_tokens <= 4_096 {
"LOW"
} else if *budget_tokens <= 16_384 {
"MEDIUM"
} else {
"HIGH"
}
}
};
ApiThinkingConfig {
thinking_level: level,
}
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ApiGenerateContentResponse {
pub candidates: Vec<ApiCandidate>,
pub usage_metadata: Option<ApiUsageMetadata>,
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ApiCandidate {
#[serde(default = "ApiCandidate::empty_content")]
pub content: ApiContent,
pub finish_reason: Option<ApiFinishReason>,
}
impl ApiCandidate {
const fn empty_content() -> ApiContent {
ApiContent {
role: None,
parts: Vec::new(),
}
}
}
#[derive(Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ApiFinishReason {
Stop,
MaxTokens,
Safety,
Recitation,
#[serde(other)]
Other,
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ApiUsageMetadata {
#[serde(default, rename = "promptTokenCount")]
pub prompt: u32,
#[serde(default, rename = "candidatesTokenCount")]
pub candidates: u32,
#[serde(default, rename = "cachedContentTokenCount")]
pub cached_content: u32,
}
impl ApiUsageMetadata {
pub const fn into_usage(self) -> Usage {
Usage {
input_tokens: self.prompt,
output_tokens: self.candidates,
cached_input_tokens: self.cached_content,
}
}
}
pub fn build_api_contents(messages: &[crate::llm::Message]) -> Vec<ApiContent> {
let mut tool_names: std::collections::HashMap<String, String> =
std::collections::HashMap::new();
for msg in messages {
if let Content::Blocks(blocks) = &msg.content {
for block in blocks {
if let ContentBlock::ToolUse { id, name, .. } = block {
tool_names.insert(id.clone(), name.clone());
}
}
}
}
let mut contents = Vec::new();
for msg in messages {
let role = match msg.role {
crate::llm::Role::User => "user",
crate::llm::Role::Assistant => "model",
};
let parts = match &msg.content {
Content::Text(text) => vec![ApiPart::Text {
text: text.clone(),
thought_signature: None,
}],
Content::Blocks(blocks) => {
let mut parts = Vec::new();
for block in blocks {
match block {
ContentBlock::Text { text } => {
parts.push(ApiPart::Text {
text: text.clone(),
thought_signature: None,
});
}
ContentBlock::Thinking { .. } | ContentBlock::RedactedThinking { .. } => {}
ContentBlock::Image { source } | ContentBlock::Document { source } => {
let bytes = decode_attachment_bytes(&source.data)
.unwrap_or_else(|_| Vec::new());
parts.push(ApiPart::InlineData {
inline_data: ApiBlob {
mime_type: source.media_type.clone(),
data: base64::engine::general_purpose::STANDARD.encode(bytes),
},
});
}
ContentBlock::ToolUse {
id: _,
name,
input,
thought_signature,
} => {
parts.push(ApiPart::FunctionCall {
function_call: ApiFunctionCall {
name: name.clone(),
args: input.clone(),
},
thought_signature: thought_signature.clone(),
});
}
ContentBlock::ToolResult {
tool_use_id,
content,
is_error,
} => {
let func_name = tool_names
.get(tool_use_id)
.cloned()
.unwrap_or_else(|| "unknown_function".to_owned());
let response = if is_error.unwrap_or(false) {
serde_json::json!({ "error": content })
} else {
serde_json::json!({ "result": content })
};
parts.push(ApiPart::FunctionResponse {
function_response: ApiFunctionResponse {
name: func_name,
response,
},
});
}
}
}
parts
}
};
contents.push(ApiContent {
role: Some(role.to_owned()),
parts,
});
}
contents
}
pub fn convert_tools_to_config(tools: Vec<crate::llm::Tool>) -> ApiToolConfig {
ApiToolConfig {
function_declarations: tools
.into_iter()
.map(|t| ApiFunctionDeclaration {
name: t.name,
description: t.description,
parameters: t.input_schema,
})
.collect(),
}
}
pub fn build_content_blocks(content: &ApiContent) -> Vec<ContentBlock> {
let mut blocks = Vec::new();
for part in &content.parts {
match part {
ApiPart::Text { text, .. } => {
if !text.is_empty() {
blocks.push(ContentBlock::Text { text: text.clone() });
}
}
ApiPart::FunctionCall {
function_call,
thought_signature,
} => {
let id = format!("call_{}", uuid_simple());
blocks.push(ContentBlock::ToolUse {
id,
name: function_call.name.clone(),
input: function_call.args.clone(),
thought_signature: thought_signature.clone(),
});
}
ApiPart::InlineData { .. } | ApiPart::FunctionResponse { .. } => {
}
ApiPart::Unknown(value) => {
log::warn!("Unknown API part type in Gemini response, skipping part={value:?}");
}
}
}
blocks
}
pub fn uuid_simple() -> String {
uuid::Uuid::new_v4().to_string()
}
pub const fn map_finish_reason(reason: &ApiFinishReason, has_tool_calls: bool) -> StopReason {
if has_tool_calls {
StopReason::ToolUse
} else {
match reason {
ApiFinishReason::Stop | ApiFinishReason::Other => StopReason::EndTurn,
ApiFinishReason::MaxTokens => StopReason::MaxTokens,
ApiFinishReason::Safety | ApiFinishReason::Recitation => StopReason::StopSequence,
}
}
}
pub fn stream_gemini_response(response: reqwest::Response) -> StreamBox<'static> {
Box::pin(async_stream::stream! {
let mut block_index: usize = 0;
let mut in_text_block = false;
let mut saw_function_call = false;
let mut usage: Option<Usage> = None;
let mut stop_reason: Option<StopReason> = None;
let mut buffer = String::new();
let mut stream = response.bytes_stream();
while let Some(chunk_result) = stream.next().await {
let Ok(chunk) = chunk_result else {
yield Err(anyhow::anyhow!("stream error"));
return;
};
buffer.push_str(&String::from_utf8_lossy(&chunk));
while let Some(pos) = buffer.find('\n') {
let line = buffer[..pos].trim().to_string();
buffer = buffer[pos + 1..].to_string();
if line.is_empty() {
continue;
}
let Some(data) = line.strip_prefix("data: ") else {
continue;
};
let Ok(resp) =
serde_json::from_str::<ApiGenerateContentResponse>(data)
else {
continue;
};
if let Some(u) = resp.usage_metadata {
usage = Some(u.into_usage());
}
if let Some(candidate) = resp.candidates.into_iter().next() {
if let Some(reason) = &candidate.finish_reason {
stop_reason = Some(map_finish_reason(reason, false));
}
for part in &candidate.content.parts {
match part {
ApiPart::Text { text, .. } => {
if !text.is_empty() {
if !in_text_block {
in_text_block = true;
}
yield Ok(StreamDelta::TextDelta {
delta: text.clone(),
block_index,
});
}
}
ApiPart::FunctionCall { function_call, thought_signature } => {
if in_text_block {
in_text_block = false;
block_index += 1;
}
saw_function_call = true;
let id = format!("call_{}", uuid_simple());
yield Ok(StreamDelta::ToolUseStart {
id: id.clone(),
name: function_call.name.clone(),
block_index,
thought_signature: thought_signature.clone(),
});
yield Ok(StreamDelta::ToolInputDelta {
id,
delta: serde_json::to_string(&function_call.args)
.unwrap_or_default(),
block_index,
});
block_index += 1;
}
_ => {}
}
}
}
}
}
if saw_function_call {
stop_reason = Some(StopReason::ToolUse);
}
if let Some(u) = usage {
yield Ok(StreamDelta::Usage(u));
}
yield Ok(StreamDelta::Done { stop_reason });
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_api_content_serialization() {
let content = ApiContent {
role: Some("user".to_string()),
parts: vec![ApiPart::Text {
text: "Hello!".to_string(),
thought_signature: None,
}],
};
let json = serde_json::to_string(&content).unwrap_or_default();
assert!(json.contains("\"role\":\"user\""));
assert!(json.contains("\"text\":\"Hello!\""));
}
#[test]
fn test_api_part_text_serialization() {
let part = ApiPart::Text {
text: "Hello, world!".to_string(),
thought_signature: None,
};
let json = serde_json::to_string(&part).unwrap_or_default();
assert!(json.contains("\"text\":\"Hello, world!\""));
}
#[test]
fn test_api_part_function_call_serialization() {
let part = ApiPart::FunctionCall {
function_call: ApiFunctionCall {
name: "read_file".to_string(),
args: serde_json::json!({"path": "/test.txt"}),
},
thought_signature: None,
};
let json = serde_json::to_string(&part).unwrap_or_default();
assert!(json.contains("\"functionCall\""));
assert!(json.contains("\"name\":\"read_file\""));
assert!(json.contains("\"args\""));
}
#[test]
fn test_api_part_function_response_serialization() {
let part = ApiPart::FunctionResponse {
function_response: ApiFunctionResponse {
name: "read_file".to_string(),
response: serde_json::json!({"result": "file contents"}),
},
};
let json = serde_json::to_string(&part).unwrap_or_default();
assert!(json.contains("\"functionResponse\""));
assert!(json.contains("\"name\":\"read_file\""));
assert!(json.contains("\"response\""));
}
#[test]
fn test_api_tool_config_serialization() {
let config = ApiToolConfig {
function_declarations: vec![ApiFunctionDeclaration {
name: "test_tool".to_string(),
description: "A test tool".to_string(),
parameters: serde_json::json!({"type": "object"}),
}],
};
let json = serde_json::to_string(&config).unwrap_or_default();
assert!(json.contains("\"functionDeclarations\""));
assert!(json.contains("\"name\":\"test_tool\""));
assert!(json.contains("\"description\":\"A test tool\""));
}
#[test]
fn test_api_generation_config_serialization() {
let config = ApiGenerationConfig {
max_output_tokens: Some(1024),
thinking_config: None,
};
let json = serde_json::to_string(&config).unwrap_or_default();
assert!(json.contains("\"maxOutputTokens\":1024"));
assert!(!json.contains("thinkingConfig"));
}
#[test]
fn test_api_generation_config_with_thinking() {
let config = ApiGenerationConfig {
max_output_tokens: Some(65536),
thinking_config: Some(ApiThinkingConfig {
thinking_level: "HIGH",
}),
};
let json = serde_json::to_string(&config).unwrap_or_default();
assert!(json.contains("\"thinkingConfig\""));
assert!(json.contains("\"thinkingLevel\":\"HIGH\""));
}
#[test]
fn test_generate_content_request_serializes_cached_content() {
let contents = vec![ApiContent {
role: Some("user".to_string()),
parts: vec![ApiPart::Text {
text: "Hello".to_string(),
thought_signature: None,
}],
}];
let request = ApiGenerateContentRequest {
contents: &contents,
system_instruction: None,
tools: None,
generation_config: None,
cached_content: Some("cachedContents/example"),
};
let json = serde_json::to_string(&request).unwrap_or_default();
assert!(json.contains("\"cachedContent\":\"cachedContents/example\""));
}
#[test]
fn test_api_response_deserialization() {
let json = r#"{
"candidates": [
{
"content": {
"role": "model",
"parts": [{"text": "Hello!"}]
},
"finishReason": "STOP"
}
],
"usageMetadata": {
"promptTokenCount": 100,
"candidatesTokenCount": 50
}
}"#;
let response: ApiGenerateContentResponse =
serde_json::from_str(json).unwrap_or_else(|e| panic!("parse failed: {e}"));
assert_eq!(response.candidates.len(), 1);
assert!(response.usage_metadata.is_some());
let usage = response.usage_metadata.unwrap_or(ApiUsageMetadata {
prompt: 0,
candidates: 0,
cached_content: 0,
});
assert_eq!(usage.prompt, 100);
assert_eq!(usage.candidates, 50);
}
#[test]
fn test_api_response_with_function_call_deserialization() {
let json = r#"{
"candidates": [
{
"content": {
"role": "model",
"parts": [
{
"functionCall": {
"name": "read_file",
"args": {"path": "test.txt"}
}
}
]
},
"finishReason": "STOP"
}
]
}"#;
let response: ApiGenerateContentResponse =
serde_json::from_str(json).unwrap_or_else(|e| panic!("parse failed: {e}"));
let content = &response.candidates[0].content;
assert_eq!(content.parts.len(), 1);
match &content.parts[0] {
ApiPart::FunctionCall { function_call, .. } => {
assert_eq!(function_call.name, "read_file");
}
_ => panic!("Expected FunctionCall part"),
}
}
#[test]
fn test_api_finish_reason_deserialization() {
let stop: ApiFinishReason =
serde_json::from_str("\"STOP\"").unwrap_or_else(|e| panic!("parse failed: {e}"));
let max_tokens: ApiFinishReason =
serde_json::from_str("\"MAX_TOKENS\"").unwrap_or_else(|e| panic!("parse failed: {e}"));
let safety: ApiFinishReason =
serde_json::from_str("\"SAFETY\"").unwrap_or_else(|e| panic!("parse failed: {e}"));
assert!(matches!(stop, ApiFinishReason::Stop));
assert!(matches!(max_tokens, ApiFinishReason::MaxTokens));
assert!(matches!(safety, ApiFinishReason::Safety));
}
#[test]
fn test_api_finish_reason_unknown_variants_map_to_other() {
let malformed: ApiFinishReason = serde_json::from_str("\"MALFORMED_FUNCTION_CALL\"")
.unwrap_or_else(|e| panic!("parse failed: {e}"));
let blocklist: ApiFinishReason =
serde_json::from_str("\"BLOCKLIST\"").unwrap_or_else(|e| panic!("parse failed: {e}"));
let prohibited: ApiFinishReason = serde_json::from_str("\"PROHIBITED_CONTENT\"")
.unwrap_or_else(|e| panic!("parse failed: {e}"));
let spii: ApiFinishReason =
serde_json::from_str("\"SPII\"").unwrap_or_else(|e| panic!("parse failed: {e}"));
assert!(matches!(malformed, ApiFinishReason::Other));
assert!(matches!(blocklist, ApiFinishReason::Other));
assert!(matches!(prohibited, ApiFinishReason::Other));
assert!(matches!(spii, ApiFinishReason::Other));
}
#[test]
fn test_api_candidate_missing_content_defaults_to_empty() {
let json = r#"{
"finishReason": "SAFETY"
}"#;
let candidate: ApiCandidate =
serde_json::from_str(json).unwrap_or_else(|e| panic!("parse failed: {e}"));
assert!(candidate.content.parts.is_empty());
assert!(matches!(
candidate.finish_reason,
Some(ApiFinishReason::Safety)
));
}
#[test]
fn test_api_response_with_unknown_finish_reason_parses() {
let json = r#"{
"candidates": [
{
"content": {
"role": "model",
"parts": [{"text": "I could not call that function."}]
},
"finishReason": "MALFORMED_FUNCTION_CALL"
}
],
"usageMetadata": {
"promptTokenCount": 50,
"candidatesTokenCount": 20
}
}"#;
let response: ApiGenerateContentResponse =
serde_json::from_str(json).unwrap_or_else(|e| panic!("parse failed: {e}"));
assert_eq!(response.candidates.len(), 1);
assert!(matches!(
response.candidates[0].finish_reason,
Some(ApiFinishReason::Other)
));
}
#[test]
fn test_build_api_contents_simple() {
let messages = vec![crate::llm::Message::user("Hello")];
let contents = build_api_contents(&messages);
assert_eq!(contents.len(), 1);
assert_eq!(contents[0].role, Some("user".to_string()));
assert_eq!(contents[0].parts.len(), 1);
}
#[test]
fn test_build_api_contents_assistant() {
let messages = vec![crate::llm::Message::assistant("Hi there!")];
let contents = build_api_contents(&messages);
assert_eq!(contents.len(), 1);
assert_eq!(contents[0].role, Some("model".to_string()));
}
#[test]
fn test_convert_tools_to_config() {
let tools = vec![crate::llm::Tool {
name: "test_tool".to_string(),
description: "A test tool".to_string(),
input_schema: serde_json::json!({"type": "object"}),
}];
let api_tools = convert_tools_to_config(tools);
assert_eq!(api_tools.function_declarations.len(), 1);
assert_eq!(api_tools.function_declarations[0].name, "test_tool");
}
#[test]
fn test_build_content_blocks_text_only() {
let content = ApiContent {
role: Some("model".to_string()),
parts: vec![ApiPart::Text {
text: "Hello!".to_string(),
thought_signature: None,
}],
};
let blocks = build_content_blocks(&content);
assert_eq!(blocks.len(), 1);
assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello!"));
}
#[test]
fn test_build_content_blocks_with_function_call() {
let content = ApiContent {
role: Some("model".to_string()),
parts: vec![ApiPart::FunctionCall {
function_call: ApiFunctionCall {
name: "read_file".to_string(),
args: serde_json::json!({"path": "test.txt"}),
},
thought_signature: None,
}],
};
let blocks = build_content_blocks(&content);
assert_eq!(blocks.len(), 1);
assert!(matches!(&blocks[0], ContentBlock::ToolUse { name, .. } if name == "read_file"));
}
#[test]
fn test_uuid_simple_generates_unique_ids() {
let mut ids = std::collections::HashSet::new();
for _ in 0..1000 {
let id = uuid_simple();
assert!(!id.is_empty());
assert!(ids.insert(id), "Duplicate ID generated");
}
assert_eq!(ids.len(), 1000);
}
#[test]
fn test_streaming_response_text_deserialization() {
let json = r#"{
"candidates": [
{
"content": {
"role": "model",
"parts": [{"text": "Hello"}]
}
}
]
}"#;
let response: ApiGenerateContentResponse =
serde_json::from_str(json).unwrap_or_else(|e| panic!("parse failed: {e}"));
assert_eq!(response.candidates.len(), 1);
match &response.candidates[0].content.parts[0] {
ApiPart::Text { text, .. } => assert_eq!(text, "Hello"),
_ => panic!("Expected Text part"),
}
}
#[test]
fn test_streaming_response_with_usage_deserialization() {
let json = r#"{
"candidates": [
{
"content": {
"role": "model",
"parts": [{"text": "Hello"}]
},
"finishReason": "STOP"
}
],
"usageMetadata": {
"promptTokenCount": 10,
"candidatesTokenCount": 5
}
}"#;
let response: ApiGenerateContentResponse =
serde_json::from_str(json).unwrap_or_else(|e| panic!("parse failed: {e}"));
let usage = response.usage_metadata.unwrap_or(ApiUsageMetadata {
prompt: 0,
candidates: 0,
cached_content: 0,
});
assert_eq!(usage.prompt, 10);
assert_eq!(usage.candidates, 5);
assert!(matches!(
response.candidates[0].finish_reason,
Some(ApiFinishReason::Stop)
));
}
#[test]
fn test_streaming_response_function_call_deserialization() {
let json = r#"{
"candidates": [
{
"content": {
"role": "model",
"parts": [{
"functionCall": {
"name": "get_weather",
"args": {"location": "NYC"}
}
}]
}
}
]
}"#;
let response: ApiGenerateContentResponse =
serde_json::from_str(json).unwrap_or_else(|e| panic!("parse failed: {e}"));
match &response.candidates[0].content.parts[0] {
ApiPart::FunctionCall { function_call, .. } => {
assert_eq!(function_call.name, "get_weather");
assert_eq!(function_call.args["location"], "NYC");
}
_ => panic!("Expected FunctionCall part"),
}
}
#[test]
fn test_map_finish_reason_stop() {
assert_eq!(
map_finish_reason(&ApiFinishReason::Stop, false),
StopReason::EndTurn
);
}
#[test]
fn test_map_finish_reason_overrides_to_tool_use() {
assert_eq!(
map_finish_reason(&ApiFinishReason::Stop, true),
StopReason::ToolUse
);
}
#[test]
fn test_map_finish_reason_max_tokens() {
assert_eq!(
map_finish_reason(&ApiFinishReason::MaxTokens, false),
StopReason::MaxTokens
);
}
}