use crate::stream::StreamEvent;
use crate::tasks::generate::{ContentBlock, Message, ResponseFormat, ToolCall};
use crate::InferenceError;
use serde_json::Value;
use std::collections::HashMap;
/// Provider-agnostic request, converted into wire JSON by a
/// [`ProtocolHandler`] implementation.
#[derive(Debug, Clone)]
pub struct ApiRequest {
// Model identifier exactly as the provider expects it.
pub model: String,
// Pre-built message JSON, produced by `ProtocolHandler::build_messages`.
pub messages: Vec<Value>,
// System prompt, for protocols that carry it outside `messages`
// (Anthropic `system`, Google `systemInstruction`).
pub system: Option<String>,
// Sampling temperature; a negative value means "omit from the request"
// (handlers check `temperature >= 0.0`).
pub temperature: f64,
pub max_tokens: usize,
// Tool definitions already converted via `ProtocolHandler::build_tools`.
pub tools: Option<Vec<Value>>,
// OpenAI-style tool choice (e.g. "auto"/"required"); handlers default
// to "auto" when tools are present and this is None.
pub tool_choice: Option<String>,
pub parallel_tool_calls: Option<bool>,
// Request SSE streaming from the provider.
pub stream: bool,
// Anthropic extended-thinking token budget; 0 disables thinking.
pub budget_tokens: usize,
// Enable Anthropic prompt-caching markers on system prompt and tools.
pub cache_control: bool,
// Structured-output constraint, for protocols that support one.
pub response_format: Option<ResponseFormat>,
}
/// Normalized result extracted from a provider reply by
/// [`ProtocolHandler::parse_response`].
#[derive(Debug, Clone)]
pub struct ApiResponse {
// Concatenated text content (may be empty when the reply is pure tool calls).
pub text: String,
// Tool invocations requested by the model, if any.
pub tool_calls: Vec<ToolCall>,
// Token accounting, when the provider reported it.
pub usage: Option<crate::TokenUsage>,
}
/// Abstraction over one provider wire protocol: how to build the request
/// body, message array, and tool schema, and how to parse full responses
/// and SSE stream events.
pub trait ProtocolHandler: Send + Sync {
/// URL path appended to the provider's base endpoint.
fn endpoint_path(&self) -> &str;
/// Authentication (and content-type) headers derived from `api_key`.
fn auth_headers(&self, api_key: &str) -> Vec<(String, String)>;
/// Serializes `req` into the provider-specific JSON request body.
fn build_request_body(&self, req: &ApiRequest) -> Value;
/// Parses a complete (non-streaming) response body.
fn parse_response(&self, body: &str) -> Result<ApiResponse, InferenceError>;
/// Converts one SSE frame (event type + data payload) into zero or more
/// stream events.
fn parse_stream_event(&self, event_type: &str, data: &str) -> Vec<StreamEvent>;
/// Builds the message array from either a chat history (`messages`) or a
/// bare `prompt` (+ optional images). Returns `(messages_json, system)`;
/// the second element is only `Some` for protocols that carry the system
/// prompt outside the message array.
fn build_messages(
&self,
messages: &[Message],
prompt: &str,
context: Option<&str>,
images: Option<&[ContentBlock]>,
) -> (Vec<Value>, Option<String>);
/// Converts generic tool definitions into the provider's tool schema.
fn build_tools(&self, tools: &[Value]) -> Vec<Value>;
/// Whether SSE streaming is implemented for this protocol.
fn supports_streaming(&self) -> bool {
true
}
/// Whether extended thinking/reasoning is supported.
fn supports_thinking(&self) -> bool {
false
}
/// Whether video content blocks are accepted.
fn supports_video(&self) -> bool {
false
}
/// Whether audio content blocks are accepted.
fn supports_audio(&self) -> bool {
false
}
/// Short protocol identifier (e.g. "openai", "anthropic").
fn protocol_name(&self) -> &'static str {
"remote"
}
}
/// OpenAI Chat Completions wire protocol (`/v1/chat/completions`).
pub struct OpenAiHandler;

impl ProtocolHandler for OpenAiHandler {
    fn endpoint_path(&self) -> &str {
        "/v1/chat/completions"
    }

    fn auth_headers(&self, api_key: &str) -> Vec<(String, String)> {
        vec![
            ("Authorization".into(), format!("Bearer {}", api_key)),
            ("Content-Type".into(), "application/json".into()),
        ]
    }

    /// Assembles the JSON request body, applying per-model quirks from
    /// [`openai_quirks`] (token-limit field name, temperature support).
    fn build_request_body(&self, req: &ApiRequest) -> Value {
        let quirks = openai_quirks(&req.model);
        let mut body = serde_json::json!({
            "model": req.model,
            "messages": req.messages,
        });
        // o-series / newer GPT families renamed the token-limit field.
        if quirks.uses_max_completion_tokens {
            body["max_completion_tokens"] = serde_json::json!(req.max_tokens);
        } else {
            body["max_tokens"] = serde_json::json!(req.max_tokens);
        }
        // Negative temperature is the sentinel for "omit"; o-series models
        // reject an explicit temperature entirely.
        if req.temperature >= 0.0 && !quirks.rejects_temperature {
            body["temperature"] = serde_json::json!(req.temperature);
        }
        if let Some(ref tools) = req.tools {
            body["tools"] = serde_json::json!(tools);
            body["tool_choice"] = serde_json::json!(req.tool_choice.as_deref().unwrap_or("auto"));
            if let Some(parallel_tool_calls) = req.parallel_tool_calls {
                body["parallel_tool_calls"] = serde_json::json!(parallel_tool_calls);
            }
        }
        match &req.response_format {
            Some(ResponseFormat::JsonSchema {
                schema,
                strict,
                name,
            }) => {
                body["response_format"] = serde_json::json!({
                    "type": "json_schema",
                    "json_schema": {
                        "name": name.as_deref().unwrap_or("response"),
                        "schema": schema,
                        "strict": strict,
                    },
                });
            }
            Some(ResponseFormat::JsonObject) => {
                body["response_format"] = serde_json::json!({ "type": "json_object" });
            }
            None => {}
        }
        if req.stream {
            body["stream"] = serde_json::json!(true);
            // Without include_usage, OpenAI omits token usage from the SSE stream.
            body["stream_options"] = serde_json::json!({ "include_usage": true });
        }
        body
    }

    /// Extracts text, tool calls, and usage from a non-streaming Chat
    /// Completions response. Only the first choice is considered.
    fn parse_response(&self, body: &str) -> Result<ApiResponse, InferenceError> {
        let parsed: Value = serde_json::from_str(body)
            .map_err(|e| InferenceError::InferenceFailed(format!("parse response: {e}")))?;
        let choice = parsed
            .get("choices")
            .and_then(|c| c.as_array())
            .and_then(|a| a.first())
            .ok_or_else(|| InferenceError::InferenceFailed("empty response".into()))?;
        let message = choice
            .get("message")
            .ok_or_else(|| InferenceError::InferenceFailed("no message in choice".into()))?;
        // Content is null for pure tool-call replies; treat as empty text.
        let text = message
            .get("content")
            .and_then(|c| c.as_str())
            .unwrap_or("")
            .to_string();
        let mut tool_calls = Vec::new();
        if let Some(tcs) = message.get("tool_calls").and_then(|t| t.as_array()) {
            for tc in tcs {
                if let Some(func) = tc.get("function") {
                    let id = tc.get("id").and_then(|i| i.as_str()).map(|s| s.to_string());
                    let name = func
                        .get("name")
                        .and_then(|n| n.as_str())
                        .unwrap_or("")
                        .to_string();
                    // OpenAI encodes arguments as a JSON string; best-effort
                    // decode — malformed arguments become an empty map.
                    let args_str = func
                        .get("arguments")
                        .and_then(|a| a.as_str())
                        .unwrap_or("{}");
                    let arguments: HashMap<String, Value> =
                        serde_json::from_str(args_str).unwrap_or_default();
                    tool_calls.push(ToolCall {
                        id,
                        name,
                        arguments,
                    });
                }
            }
        }
        // `.map` (not `.and_then(Some(..))`): the closure is infallible.
        let usage = parsed.get("usage").map(|u| crate::TokenUsage {
            prompt_tokens: u.get("prompt_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
            completion_tokens: u
                .get("completion_tokens")
                .and_then(|v| v.as_u64())
                .unwrap_or(0),
            total_tokens: u.get("total_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
            // The usage payload carries no context-window info; default to 0.
            context_window: 0,
        });
        Ok(ApiResponse {
            text,
            tool_calls,
            usage,
        })
    }

    fn parse_stream_event(&self, _event_type: &str, data: &str) -> Vec<StreamEvent> {
        // OpenAI SSE frames are identified by payload, not event type; reattach
        // the "data: " prefix expected by the shared line parser.
        crate::stream::parse_openai_sse_line(&format!("data: {}", data))
    }

    /// Builds the Chat Completions message array. With a non-empty history,
    /// `prompt`/`images` are ignored; otherwise a single user turn is built
    /// from them. OpenAI carries the system prompt inside `messages`, so the
    /// returned system component is always `None`.
    fn build_messages(
        &self,
        messages: &[Message],
        prompt: &str,
        context: Option<&str>,
        images: Option<&[ContentBlock]>,
    ) -> (Vec<Value>, Option<String>) {
        if !messages.is_empty() {
            let mut result = Vec::new();
            if let Some(ctx) = context {
                result.push(serde_json::json!({"role": "system", "content": ctx}));
            }
            for msg in messages {
                match msg {
                    Message::System { content } => {
                        result.push(serde_json::json!({"role": "system", "content": content}));
                    }
                    Message::User { content } => {
                        result.push(serde_json::json!({"role": "user", "content": content}));
                    }
                    Message::UserMultimodal { content } => {
                        let blocks: Vec<Value> = content
                            .iter()
                            .map(|block| match block {
                                ContentBlock::Text { text } => {
                                    serde_json::json!({"type": "text", "text": text})
                                }
                                ContentBlock::ImageBase64 { data, media_type } => {
                                    serde_json::json!({
                                        "type": "image_url",
                                        "image_url": {
                                            "url": format!("data:{};base64,{}", media_type, data),
                                        }
                                    })
                                }
                                ContentBlock::ImageUrl { url, detail } => {
                                    serde_json::json!({
                                        "type": "image_url",
                                        "image_url": {
                                            "url": url,
                                            "detail": detail,
                                        }
                                    })
                                }
                                ContentBlock::VideoPath { .. }
                                | ContentBlock::VideoUrl { .. }
                                | ContentBlock::VideoBase64 { .. }
                                | ContentBlock::AudioPath { .. }
                                | ContentBlock::AudioUrl { .. }
                                | ContentBlock::AudioBase64 { .. } => {
                                    unreachable!(
                                        "video/audio ContentBlock reached OpenAI \
                                         build_messages — should have been rejected \
                                         by RemoteBackend::execute_request"
                                    )
                                }
                            })
                            .collect();
                        result.push(serde_json::json!({"role": "user", "content": blocks}));
                    }
                    Message::Assistant {
                        content,
                        tool_calls,
                    } => {
                        if tool_calls.is_empty() {
                            result
                                .push(serde_json::json!({"role": "assistant", "content": content}));
                        } else {
                            // Synthesize ids for tool calls that lack one so the
                            // matching ToolResult turns can still reference them.
                            let tc: Vec<Value> = tool_calls.iter().enumerate().map(|(i, tc)| {
                                let id = tc.id.clone().unwrap_or_else(|| format!("call_{}", i));
                                serde_json::json!({
                                    "id": id,
                                    "type": "function",
                                    "function": {
                                        "name": tc.name,
                                        // OpenAI expects arguments as a JSON string.
                                        "arguments": serde_json::to_string(&tc.arguments).unwrap_or_default(),
                                    }
                                })
                            }).collect();
                            let mut msg =
                                serde_json::json!({"role": "assistant", "tool_calls": tc});
                            if !content.is_empty() {
                                msg["content"] = serde_json::json!(content);
                            }
                            result.push(msg);
                        }
                    }
                    Message::ToolResult {
                        tool_use_id,
                        content,
                    } => {
                        result.push(serde_json::json!({
                            "role": "tool",
                            "tool_call_id": tool_use_id,
                            "content": content,
                        }));
                    }
                    // Not representable in the Chat Completions format; skipped.
                    Message::ProviderOutputItems { .. } => continue,
                }
            }
            (result, None)
        } else {
            let mut msgs = Vec::new();
            if let Some(ctx) = context {
                msgs.push(serde_json::json!({"role": "system", "content": ctx}));
            }
            if let Some(images) = images.filter(|images| !images.is_empty()) {
                let mut blocks = vec![serde_json::json!({"type": "text", "text": prompt})];
                for image in images {
                    let block = match image {
                        ContentBlock::Text { text } => {
                            serde_json::json!({"type": "text", "text": text})
                        }
                        ContentBlock::ImageBase64 { data, media_type } => {
                            serde_json::json!({
                                "type": "image_url",
                                "image_url": {
                                    "url": format!("data:{};base64,{}", media_type, data),
                                }
                            })
                        }
                        ContentBlock::ImageUrl { url, detail } => {
                            serde_json::json!({
                                "type": "image_url",
                                "image_url": {
                                    "url": url,
                                    "detail": detail,
                                }
                            })
                        }
                        ContentBlock::VideoPath { .. }
                        | ContentBlock::VideoUrl { .. }
                        | ContentBlock::VideoBase64 { .. }
                        | ContentBlock::AudioPath { .. }
                        | ContentBlock::AudioUrl { .. }
                        | ContentBlock::AudioBase64 { .. } => {
                            unreachable!(
                                "video/audio ContentBlock reached OpenAI build_messages \
                                 — should have been rejected by RemoteBackend::execute_request"
                            )
                        }
                    };
                    blocks.push(block);
                }
                msgs.push(serde_json::json!({"role": "user", "content": blocks}));
            } else {
                msgs.push(serde_json::json!({"role": "user", "content": prompt}));
            }
            (msgs, None)
        }
    }

    /// Wraps bare function definitions in the `{"type": "function", ...}`
    /// envelope; tools that already carry a `type` pass through unchanged.
    fn build_tools(&self, tools: &[Value]) -> Vec<Value> {
        tools
            .iter()
            .map(|t| {
                if t.get("type").is_some() {
                    t.clone()
                } else {
                    serde_json::json!({"type": "function", "function": t})
                }
            })
            .collect()
    }

    fn protocol_name(&self) -> &'static str {
        "openai"
    }
}
/// Request-shape differences between OpenAI model families; computed by
/// [`openai_quirks`] and consumed by `OpenAiHandler::build_request_body`.
struct OpenAiQuirks {
// Use `max_completion_tokens` instead of the legacy `max_tokens` field.
uses_max_completion_tokens: bool,
// Model rejects an explicit `temperature` (o-series reasoning models).
rejects_temperature: bool,
}
/// Derives per-model request quirks from the (case-insensitive) model id.
///
/// o-series reasoning models and newer GPT generations use the renamed
/// `max_completion_tokens` field; the o-series additionally rejects an
/// explicit `temperature`.
fn openai_quirks(model: &str) -> OpenAiQuirks {
    let id = model.to_lowercase();
    let o_series = ["o1", "o3", "o4"].iter().any(|p| id.starts_with(p));
    let newer_gpt = id.starts_with("gpt-5") || id.starts_with("gpt-4.1");
    OpenAiQuirks {
        uses_max_completion_tokens: o_series || newer_gpt,
        rejects_temperature: o_series,
    }
}
/// Anthropic Messages wire protocol (`/v1/messages`).
pub struct AnthropicHandler;

impl ProtocolHandler for AnthropicHandler {
    fn endpoint_path(&self) -> &str {
        "/v1/messages"
    }

    fn auth_headers(&self, api_key: &str) -> Vec<(String, String)> {
        vec![
            ("x-api-key".into(), api_key.to_string()),
            ("anthropic-version".into(), "2023-06-01".into()),
            // NOTE(review): prompt caching has since become GA; confirm this
            // beta header is still required for the pinned API version.
            ("anthropic-beta".into(), "prompt-caching-2024-07-31".into()),
            ("Content-Type".into(), "application/json".into()),
        ]
    }

    /// Assembles the Messages-API request body. Extended thinking takes
    /// precedence over temperature: when a thinking budget is set, no
    /// temperature is sent.
    fn build_request_body(&self, req: &ApiRequest) -> Value {
        let mut body = serde_json::json!({
            "model": req.model,
            "max_tokens": req.max_tokens,
            "messages": req.messages,
        });
        if req.budget_tokens > 0 {
            body["thinking"] = serde_json::json!({
                "type": "enabled",
                "budget_tokens": req.budget_tokens,
            });
        } else if req.temperature >= 0.0 {
            // Negative temperature is the sentinel for "omit".
            body["temperature"] = serde_json::json!(req.temperature);
        }
        if let Some(ref system) = req.system {
            if req.cache_control {
                // Block form is required to attach a cache_control marker.
                body["system"] = serde_json::json!([{
                    "type": "text",
                    "text": system,
                    "cache_control": {"type": "ephemeral"}
                }]);
            } else {
                body["system"] = Value::String(system.clone());
            }
        }
        if let Some(ref tools) = req.tools {
            body["tools"] = Value::Array(tools.clone());
            // Marking only the last tool caches the whole tools prefix.
            if req.cache_control && !tools.is_empty() {
                if let Some(arr) = body["tools"].as_array_mut() {
                    if let Some(last) = arr.last_mut() {
                        if let Some(obj) = last.as_object_mut() {
                            obj.insert(
                                "cache_control".to_string(),
                                serde_json::json!({"type": "ephemeral"}),
                            );
                        }
                    }
                }
            }
            body["tool_choice"] = serde_json::json!({"type": "auto"});
        }
        if req.response_format.is_some() {
            tracing::warn!(
                "response_format set on Anthropic request — Anthropic has no native \
                 JSON-schema field; the request will run unconstrained. Use tool_use \
                 with tool_choice=required to enforce a schema on Claude.",
            );
        }
        if req.stream {
            body["stream"] = serde_json::json!(true);
        }
        body
    }

    /// Walks the content-block array, accumulating text, thinking, and
    /// tool_use blocks. If the reply contained only thinking blocks, the
    /// thinking text is used as the text fallback.
    fn parse_response(&self, body: &str) -> Result<ApiResponse, InferenceError> {
        let parsed: Value = serde_json::from_str(body)
            .map_err(|e| InferenceError::InferenceFailed(format!("parse response: {e}")))?;
        let mut text = String::new();
        let mut tool_calls = Vec::new();
        let mut thinking_text = String::new();
        if let Some(content) = parsed.get("content").and_then(|c| c.as_array()) {
            for block in content {
                match block.get("type").and_then(|t| t.as_str()) {
                    Some("text") => {
                        if let Some(t) = block.get("text").and_then(|t| t.as_str()) {
                            text.push_str(t);
                        }
                    }
                    Some("thinking") => {
                        if let Some(t) = block.get("thinking").and_then(|t| t.as_str()) {
                            tracing::debug!(thinking_len = t.len(), "extended thinking block");
                            thinking_text.push_str(t);
                        }
                    }
                    Some("tool_use") => {
                        if let (Some(name), Some(input)) = (
                            block.get("name").and_then(|n| n.as_str()),
                            block.get("input"),
                        ) {
                            let id = block
                                .get("id")
                                .and_then(|i| i.as_str())
                                .map(|s| s.to_string());
                            // tool_use input is a JSON object (not a string as
                            // in OpenAI); best-effort decode to a map.
                            let arguments: HashMap<String, Value> =
                                serde_json::from_value(input.clone()).unwrap_or_default();
                            tool_calls.push(ToolCall {
                                id,
                                name: name.to_string(),
                                arguments,
                            });
                        }
                    }
                    _ => {}
                }
            }
        }
        if text.is_empty() && !thinking_text.is_empty() && tool_calls.is_empty() {
            tracing::warn!(
                thinking_len = thinking_text.len(),
                "response had only thinking blocks, using thinking content as text fallback"
            );
            text = thinking_text;
        }
        // `.map` (not `.and_then(Some(..))`): the closure is infallible.
        // Anthropic reports input/output separately; total is derived here.
        let usage = parsed.get("usage").map(|u| crate::TokenUsage {
            prompt_tokens: u.get("input_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
            completion_tokens: u.get("output_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
            total_tokens: u.get("input_tokens").and_then(|v| v.as_u64()).unwrap_or(0)
                + u.get("output_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
            // Not reported by the Messages API; default to 0.
            context_window: 0,
        });
        Ok(ApiResponse {
            text,
            tool_calls,
            usage,
        })
    }

    fn parse_stream_event(&self, event_type: &str, data: &str) -> Vec<StreamEvent> {
        crate::stream::parse_anthropic_sse_line(event_type, data)
    }

    /// Builds the Messages-API message array. System prompts (from `context`
    /// and any `Message::System` turns, joined with blank lines) are returned
    /// separately, since Anthropic carries them outside `messages`.
    fn build_messages(
        &self,
        messages: &[Message],
        prompt: &str,
        context: Option<&str>,
        images: Option<&[ContentBlock]>,
    ) -> (Vec<Value>, Option<String>) {
        let mut system = context.map(|c| c.to_string());
        if !messages.is_empty() {
            for msg in messages {
                if let Message::System { content } = msg {
                    system = Some(match system {
                        Some(existing) if !existing.is_empty() => {
                            format!("{existing}\n\n{content}")
                        }
                        _ => content.clone(),
                    });
                }
            }
        }
        if !messages.is_empty() {
            let mut result = Vec::new();
            for msg in messages {
                match msg {
                    // Already folded into `system` above.
                    Message::System { .. } => continue,
                    Message::User { content } => {
                        result.push(serde_json::json!({"role": "user", "content": content}));
                    }
                    Message::UserMultimodal { content } => {
                        let blocks: Vec<Value> = content
                            .iter()
                            .map(|block| {
                                match block {
                                    ContentBlock::Text { text } => {
                                        serde_json::json!({"type": "text", "text": text})
                                    }
                                    ContentBlock::ImageBase64 { data, media_type } => {
                                        serde_json::json!({
                                            "type": "image",
                                            "source": {
                                                "type": "base64",
                                                "media_type": media_type,
                                                "data": data,
                                            }
                                        })
                                    }
                                    // Anthropic has no per-image detail knob;
                                    // `detail` is intentionally dropped.
                                    ContentBlock::ImageUrl { url, .. } => {
                                        serde_json::json!({
                                            "type": "image",
                                            "source": {
                                                "type": "url",
                                                "url": url,
                                            }
                                        })
                                    }
                                    ContentBlock::VideoPath { .. }
                                    | ContentBlock::VideoUrl { .. }
                                    | ContentBlock::VideoBase64 { .. }
                                    | ContentBlock::AudioPath { .. }
                                    | ContentBlock::AudioUrl { .. }
                                    | ContentBlock::AudioBase64 { .. } => {
                                        unreachable!(
                                            "video/audio ContentBlock reached Anthropic \
                                             build_messages — should have been rejected \
                                             by RemoteBackend::execute_request"
                                        )
                                    }
                                }
                            })
                            .collect();
                        result.push(serde_json::json!({"role": "user", "content": blocks}));
                    }
                    Message::Assistant {
                        content,
                        tool_calls,
                    } => {
                        let mut blocks: Vec<Value> = Vec::new();
                        if !content.is_empty() {
                            blocks.push(serde_json::json!({"type": "text", "text": content}));
                        }
                        for (i, tc) in tool_calls.iter().enumerate() {
                            // Synthesize ids for tool calls that lack one.
                            let id = tc.id.clone().unwrap_or_else(|| format!("toolu_{}", i));
                            blocks.push(serde_json::json!({
                                "type": "tool_use",
                                "id": id,
                                "name": tc.name,
                                "input": tc.arguments,
                            }));
                        }
                        // Anthropic rejects assistant turns with no content blocks.
                        if blocks.is_empty() {
                            blocks.push(serde_json::json!({"type": "text", "text": ""}));
                        }
                        result.push(serde_json::json!({"role": "assistant", "content": blocks}));
                    }
                    Message::ToolResult {
                        tool_use_id,
                        content,
                    } => {
                        // Tool results travel as user-role tool_result blocks.
                        result.push(serde_json::json!({
                            "role": "user",
                            "content": [{
                                "type": "tool_result",
                                "tool_use_id": tool_use_id,
                                "content": content,
                            }]
                        }));
                    }
                    // Not representable in the Messages format; skipped.
                    Message::ProviderOutputItems { .. } => continue,
                }
            }
            (result, system)
        } else {
            let content = if let Some(images) = images.filter(|images| !images.is_empty()) {
                let mut blocks = vec![serde_json::json!({"type": "text", "text": prompt})];
                for image in images {
                    let block = match image {
                        ContentBlock::Text { text } => {
                            serde_json::json!({"type": "text", "text": text})
                        }
                        ContentBlock::ImageBase64 { data, media_type } => {
                            serde_json::json!({
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": media_type,
                                    "data": data,
                                }
                            })
                        }
                        ContentBlock::ImageUrl { url, .. } => {
                            serde_json::json!({
                                "type": "image",
                                "source": {
                                    "type": "url",
                                    "url": url,
                                }
                            })
                        }
                        ContentBlock::VideoPath { .. }
                        | ContentBlock::VideoUrl { .. }
                        | ContentBlock::VideoBase64 { .. }
                        | ContentBlock::AudioPath { .. }
                        | ContentBlock::AudioUrl { .. }
                        | ContentBlock::AudioBase64 { .. } => {
                            unreachable!(
                                "video/audio ContentBlock reached Anthropic build_messages \
                                 — should have been rejected by RemoteBackend::execute_request"
                            )
                        }
                    };
                    blocks.push(block);
                }
                Value::Array(blocks)
            } else {
                Value::String(prompt.to_string())
            };
            let msgs = vec![serde_json::json!({"role": "user", "content": content})];
            (msgs, system)
        }
    }

    /// Converts OpenAI-style tool definitions (bare or `function`-wrapped)
    /// into Anthropic's `{name, description, input_schema}` shape. Tools
    /// without a `name` are dropped.
    fn build_tools(&self, tools: &[Value]) -> Vec<Value> {
        tools.iter().filter_map(|t| {
            let func = t.get("function").unwrap_or(t);
            Some(serde_json::json!({
                "name": func.get("name")?,
                "description": func.get("description").and_then(|d| d.as_str()).unwrap_or(""),
                "input_schema": func.get("parameters").cloned().unwrap_or(serde_json::json!({"type": "object"})),
            }))
        }).collect()
    }

    fn supports_thinking(&self) -> bool {
        true
    }

    fn protocol_name(&self) -> &'static str {
        "anthropic"
    }
}
/// Google Gemini `generateContent` wire protocol. The full request URL
/// (including the API key as a query parameter) is built by [`google_url`],
/// which is why `endpoint_path` is empty and `auth_headers` adds no
/// authentication header.
pub struct GoogleHandler;
impl ProtocolHandler for GoogleHandler {
// Unused: see [`google_url`], which builds the complete URL instead.
fn endpoint_path(&self) -> &str {
""
}
// API-key auth travels in the URL query string, so only the content type
// is set here.
fn auth_headers(&self, _api_key: &str) -> Vec<(String, String)> {
vec![("Content-Type".into(), "application/json".into())]
}
// Assembles the generateContent request: contents, optional
// systemInstruction, generationConfig, and tools/toolConfig.
fn build_request_body(&self, req: &ApiRequest) -> Value {
let mut body = serde_json::json!({
"contents": req.messages,
});
if let Some(ref system) = req.system {
body["systemInstruction"] = serde_json::json!({
"parts": [{"text": system}],
});
}
let mut generation_config = serde_json::json!({
"maxOutputTokens": req.max_tokens,
});
// Negative temperature is the sentinel for "omit".
if req.temperature >= 0.0 {
generation_config["temperature"] = serde_json::json!(req.temperature);
}
match &req.response_format {
// Gemini enforces schemas natively via responseSchema; the `strict`
// and `name` fields have no Google equivalent and are ignored.
Some(ResponseFormat::JsonSchema { schema, .. }) => {
generation_config["responseMimeType"] = serde_json::json!("application/json");
generation_config["responseSchema"] = schema.clone();
}
Some(ResponseFormat::JsonObject) => {
generation_config["responseMimeType"] = serde_json::json!("application/json");
}
None => {}
}
body["generationConfig"] = generation_config;
if let Some(ref tools) = req.tools {
body["tools"] = serde_json::json!([{
"functionDeclarations": tools,
}]);
body["toolConfig"] = serde_json::json!({
"functionCallingConfig": {
"mode": "AUTO",
}
});
}
body
}
// Extracts text parts and function calls from the first candidate;
// text parts are joined with newlines.
fn parse_response(&self, body: &str) -> Result<ApiResponse, InferenceError> {
let parsed: Value = serde_json::from_str(body)
.map_err(|e| InferenceError::InferenceFailed(format!("parse response: {e}")))?;
// Missing candidates/parts yields an empty (not error) response.
let parts = parsed
.get("candidates")
.and_then(|c| c.as_array())
.and_then(|a| a.first())
.and_then(|c| c.get("content"))
.and_then(|c| c.get("parts"))
.and_then(|p| p.as_array())
.cloned()
.unwrap_or_default();
let mut text_chunks = Vec::new();
let mut tool_calls = Vec::new();
for part in parts {
if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
text_chunks.push(text.to_string());
}
// Accept both camelCase and snake_case spellings of functionCall/args.
if let Some(function_call) = part
.get("functionCall")
.or_else(|| part.get("function_call"))
{
let name = function_call
.get("name")
.and_then(|n| n.as_str())
.unwrap_or_default()
.to_string();
let arguments = function_call
.get("args")
.or_else(|| function_call.get("arguments"))
.and_then(|args| args.as_object())
.map(|map| {
map.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect::<HashMap<_, _>>()
})
.unwrap_or_default();
// Gemini function calls carry no id; name alone identifies them.
if !name.is_empty() {
tool_calls.push(ToolCall {
id: None,
name,
arguments,
});
}
}
}
let usage = parsed.get("usageMetadata").map(|usage| crate::TokenUsage {
prompt_tokens: usage
.get("promptTokenCount")
.and_then(|v| v.as_u64())
.unwrap_or(0),
completion_tokens: usage
.get("candidatesTokenCount")
.and_then(|v| v.as_u64())
.unwrap_or(0),
total_tokens: usage
.get("totalTokenCount")
.and_then(|v| v.as_u64())
.unwrap_or(0),
// Not reported by usageMetadata; default to 0.
context_window: 0,
});
Ok(ApiResponse {
text: text_chunks.join("\n"),
tool_calls,
usage,
})
}
// Streaming is not implemented for this protocol (supports_streaming is
// false); every frame is dropped.
fn parse_stream_event(&self, _event_type: &str, _data: &str) -> Vec<StreamEvent> {
Vec::new() }
// Builds the `contents` array. System prompts (from `context` plus any
// Message::System turns, joined with blank lines) are returned separately
// for the systemInstruction field.
fn build_messages(
&self,
messages: &[Message],
prompt: &str,
context: Option<&str>,
images: Option<&[ContentBlock]>,
) -> (Vec<Value>, Option<String>) {
let mut system_instruction: Option<String> = context.map(|c| c.to_string());
if !messages.is_empty() {
for msg in messages {
if let Message::System { content } = msg {
system_instruction = Some(match system_instruction {
Some(existing) if !existing.is_empty() => {
format!("{existing}\n\n{content}")
}
_ => content.clone(),
});
}
}
}
if !messages.is_empty() {
let contents = messages
.iter()
// System turns were folded into systemInstruction above;
// ProviderOutputItems have no Gemini representation.
.filter(|msg| {
!matches!(
msg,
Message::System { .. } | Message::ProviderOutputItems { .. }
)
})
.map(|msg| match msg {
Message::System { .. } => unreachable!("System filtered above"),
Message::ProviderOutputItems { .. } => {
unreachable!("ProviderOutputItems filtered above")
}
Message::User { content } => serde_json::json!({
"role": "user",
"parts": [{"text": content}],
}),
Message::UserMultimodal { content } => serde_json::json!({
"role": "user",
"parts": content.iter().map(google_part_from_block).collect::<Vec<_>>(),
}),
// Gemini's assistant role is "model".
Message::Assistant {
content,
tool_calls,
} => {
let mut parts = Vec::new();
if !content.is_empty() {
parts.push(serde_json::json!({"text": content}));
}
for tool_call in tool_calls {
parts.push(serde_json::json!({
"functionCall": {
"name": tool_call.name,
"args": tool_call.arguments,
}
}));
}
serde_json::json!({
"role": "model",
"parts": parts,
})
}
// NOTE(review): Gemini matches functionResponse by function
// name; this passes tool_use_id as the name — confirm callers
// store the function name in tool_use_id for this protocol.
Message::ToolResult {
tool_use_id,
content,
} => serde_json::json!({
"role": "tool",
"parts": [{
"functionResponse": {
"name": tool_use_id,
"response": {"content": content},
}
}],
}),
})
.collect();
(contents, system_instruction)
} else {
// Single-turn: one user content holding the prompt plus any media.
let mut parts = vec![serde_json::json!({"text": prompt})];
if let Some(images) = images.filter(|images| !images.is_empty()) {
parts.extend(images.iter().map(google_part_from_block));
}
(
vec![serde_json::json!({
"role": "user",
"parts": parts,
})],
system_instruction,
)
}
}
// Converts OpenAI-style tool definitions (bare or `function`-wrapped)
// into Gemini functionDeclarations; tools without a `name` are dropped.
fn build_tools(&self, tools: &[Value]) -> Vec<Value> {
tools
.iter()
.filter_map(|t| {
let func = t.get("function").unwrap_or(t);
Some(serde_json::json!({
"name": func.get("name")?,
"description": func.get("description").and_then(|d| d.as_str()).unwrap_or(""),
"parameters": func.get("parameters").cloned().unwrap_or(serde_json::json!({"type": "object"})),
}))
})
.collect()
}
fn supports_streaming(&self) -> bool {
false
}
fn supports_video(&self) -> bool {
true
}
fn supports_audio(&self) -> bool {
true
}
fn protocol_name(&self) -> &'static str {
"google-gemini"
}
}
/// Converts one [`ContentBlock`] into a Gemini `parts` entry: inline base64
/// payloads become `inlineData`, references become `fileData`.
///
/// NOTE(review): video/audio paths and URLs use hardcoded mime types
/// ("video/mp4", "audio/wav") regardless of the actual file format, and
/// local paths are emitted as `file://` URIs — confirm the Gemini fileData
/// endpoint accepts these for the deployments in use.
fn google_part_from_block(block: &ContentBlock) -> Value {
match block {
ContentBlock::Text { text } => serde_json::json!({"text": text}),
ContentBlock::ImageBase64 { data, media_type } => serde_json::json!({
"inlineData": {
"mimeType": media_type,
"data": data,
}
}),
// Image mime type is guessed from the URL extension (jpeg fallback).
ContentBlock::ImageUrl { url, .. } => serde_json::json!({
"fileData": {
"mimeType": infer_mime_type_from_url(url),
"fileUri": url,
}
}),
ContentBlock::VideoPath { path, .. } => serde_json::json!({
"fileData": {
"mimeType": "video/mp4",
"fileUri": format!("file://{path}"),
}
}),
ContentBlock::VideoUrl { url, .. } => serde_json::json!({
"fileData": {
"mimeType": "video/mp4",
"fileUri": url,
}
}),
ContentBlock::VideoBase64 {
data, media_type, ..
} => serde_json::json!({
"inlineData": {
"mimeType": media_type,
"data": data,
}
}),
ContentBlock::AudioPath { path, .. } => serde_json::json!({
"fileData": {
"mimeType": "audio/wav",
"fileUri": format!("file://{path}"),
}
}),
ContentBlock::AudioUrl { url, .. } => serde_json::json!({
"fileData": {
"mimeType": "audio/wav",
"fileUri": url,
}
}),
ContentBlock::AudioBase64 {
data, media_type, ..
} => serde_json::json!({
"inlineData": {
"mimeType": media_type,
"data": data,
}
}),
}
}
/// Guesses an image mime type from a URL's extension (case-insensitive),
/// defaulting to "image/jpeg" for unrecognized or missing extensions.
fn infer_mime_type_from_url(url: &str) -> &'static str {
    let normalized = url.to_ascii_lowercase();
    let extension = normalized.rsplit('.').next().unwrap_or("");
    match extension {
        "png" => "image/png",
        "webp" => "image/webp",
        "heic" => "image/heic",
        "heif" => "image/heif",
        // jpeg/jpg and anything unknown fall back to jpeg.
        _ => "image/jpeg",
    }
}
pub fn handler_for(protocol: crate::schema::ApiProtocol) -> Box<dyn ProtocolHandler> {
match protocol {
crate::schema::ApiProtocol::OpenAiCompat | crate::schema::ApiProtocol::OpenAiResponses => {
Box::new(OpenAiHandler)
}
crate::schema::ApiProtocol::Anthropic => Box::new(AnthropicHandler),
crate::schema::ApiProtocol::Google => Box::new(GoogleHandler),
crate::schema::ApiProtocol::AzureOpenAi => Box::new(AzureOpenAiHandler),
}
}
pub struct AzureOpenAiHandler;
impl ProtocolHandler for AzureOpenAiHandler {
fn endpoint_path(&self) -> &str {
"/openai/deployments"
}
fn auth_headers(&self, api_key: &str) -> Vec<(String, String)> {
vec![
("api-key".into(), api_key.to_string()),
("Content-Type".into(), "application/json".into()),
]
}
fn build_request_body(&self, req: &ApiRequest) -> Value {
OpenAiHandler.build_request_body(req)
}
fn parse_response(&self, body: &str) -> Result<ApiResponse, crate::InferenceError> {
OpenAiHandler.parse_response(body)
}
fn parse_stream_event(&self, event_type: &str, data: &str) -> Vec<crate::stream::StreamEvent> {
OpenAiHandler.parse_stream_event(event_type, data)
}
fn build_messages(
&self,
messages: &[Message],
prompt: &str,
context: Option<&str>,
images: Option<&[ContentBlock]>,
) -> (Vec<Value>, Option<String>) {
OpenAiHandler.build_messages(messages, prompt, context, images)
}
fn build_tools(&self, tools: &[Value]) -> Vec<Value> {
OpenAiHandler.build_tools(tools)
}
fn protocol_name(&self) -> &'static str {
"azure-openai"
}
}
/// Builds the full Google Generative Language API URL for a non-streaming
/// `generateContent` call.
///
/// The API key travels as a `key` query parameter per Google's API-key auth
/// scheme; trailing slashes on `endpoint` are stripped so callers may pass
/// either form.
pub fn google_url(endpoint: &str, model: &str, api_key: &str) -> String {
    let trimmed = endpoint.trim_end_matches('/');
    format!("{trimmed}/v1beta/models/{model}:generateContent?key={api_key}")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn openai_single_turn_messages() {
let handler = OpenAiHandler;
let (msgs, system) = handler.build_messages(&[], "Hello", Some("Be helpful"), None);
assert_eq!(msgs.len(), 2);
assert_eq!(msgs[0]["role"], "system");
assert_eq!(msgs[1]["content"], "Hello");
assert!(system.is_none()); }
#[test]
fn openai_multi_turn_messages() {
let handler = OpenAiHandler;
let messages = vec![
Message::User {
content: "Hi".into(),
},
Message::Assistant {
content: "Hello!".into(),
tool_calls: vec![],
},
Message::User {
content: "Search for X".into(),
},
];
let (msgs, _) = handler.build_messages(&messages, "", None, None);
assert_eq!(msgs.len(), 3);
assert_eq!(msgs[2]["content"], "Search for X");
}
#[test]
fn openai_tool_call_messages() {
let handler = OpenAiHandler;
let tc = ToolCall {
id: None,
name: "search".into(),
arguments: [("q".into(), Value::String("rust".into()))].into(),
};
let messages = vec![
Message::User {
content: "Search".into(),
},
Message::Assistant {
content: String::new(),
tool_calls: vec![tc],
},
Message::ToolResult {
tool_use_id: "call_0".into(),
content: "found it".into(),
},
];
let (msgs, _) = handler.build_messages(&messages, "", None, None);
assert_eq!(msgs.len(), 3);
assert!(msgs[1].get("tool_calls").is_some());
assert_eq!(msgs[2]["role"], "tool");
}
#[test]
fn anthropic_system_separate() {
let handler = AnthropicHandler;
let (msgs, system) = handler.build_messages(&[], "Hello", Some("Be helpful"), None);
assert_eq!(msgs.len(), 1); assert_eq!(system, Some("Be helpful".into()));
}
#[test]
fn anthropic_tool_use_format() {
let handler = AnthropicHandler;
let tc = ToolCall {
id: None,
name: "search".into(),
arguments: [("q".into(), Value::String("test".into()))].into(),
};
let messages = vec![
Message::User {
content: "Search".into(),
},
Message::Assistant {
content: String::new(),
tool_calls: vec![tc],
},
Message::ToolResult {
tool_use_id: "toolu_0".into(),
content: "result".into(),
},
];
let (msgs, _) = handler.build_messages(&messages, "", None, None);
assert_eq!(msgs.len(), 3);
let assistant_content = msgs[1].get("content").unwrap().as_array().unwrap();
assert_eq!(assistant_content[0]["type"], "tool_use");
let user_content = msgs[2].get("content").unwrap().as_array().unwrap();
assert_eq!(user_content[0]["type"], "tool_result");
}
#[test]
fn anthropic_thinking_in_request() {
let handler = AnthropicHandler;
let req = ApiRequest {
model: "claude".into(),
messages: vec![serde_json::json!({"role": "user", "content": "plan"})],
system: None,
temperature: 0.7,
max_tokens: 4096,
tools: None,
tool_choice: None,
parallel_tool_calls: None,
stream: false,
budget_tokens: 8000,
cache_control: false,
response_format: None,
};
let body = handler.build_request_body(&req);
assert!(body.get("thinking").is_some());
assert_eq!(body["thinking"]["budget_tokens"], 8000);
assert!(body.get("temperature").is_none()); }
fn empty_request(model: &str) -> ApiRequest {
ApiRequest {
model: model.into(),
messages: vec![serde_json::json!({"role": "user", "content": "hi"})],
system: None,
temperature: 0.7,
max_tokens: 256,
tools: None,
tool_choice: None,
parallel_tool_calls: None,
stream: false,
budget_tokens: 0,
cache_control: false,
response_format: None,
}
}
#[test]
fn openai_streaming_request_carries_stream_options_include_usage() {
let handler = OpenAiHandler;
let mut req = empty_request("gpt-5");
req.stream = true;
let body = handler.build_request_body(&req);
assert_eq!(body["stream"], serde_json::json!(true));
assert_eq!(
body["stream_options"]["include_usage"],
serde_json::json!(true),
"streaming bodies must include `stream_options.include_usage` so usage flows back"
);
}
#[test]
fn openai_non_streaming_request_omits_stream_options() {
let handler = OpenAiHandler;
let req = empty_request("gpt-5");
let body = handler.build_request_body(&req);
assert!(body.get("stream").is_none());
assert!(body.get("stream_options").is_none());
}
#[test]
fn anthropic_streaming_request_marks_stream_true() {
let handler = AnthropicHandler;
let mut req = empty_request("claude-opus-4-7");
req.stream = true;
let body = handler.build_request_body(&req);
assert_eq!(body["stream"], serde_json::json!(true));
assert!(
body.get("stream_options").is_none(),
"Anthropic has no stream_options field; usage flows via SSE frames"
);
}
#[test]
fn openai_emits_strict_json_schema_response_format() {
let mut req = empty_request("gpt-5");
req.response_format = Some(ResponseFormat::JsonSchema {
schema: serde_json::json!({
"type": "object",
"properties": {"answer": {"type": "string"}},
"required": ["answer"]
}),
strict: true,
name: Some("answer_schema".into()),
});
let body = OpenAiHandler.build_request_body(&req);
let rf = body.get("response_format").expect("response_format set");
assert_eq!(rf["type"], "json_schema");
assert_eq!(rf["json_schema"]["name"], "answer_schema");
assert_eq!(rf["json_schema"]["strict"], true);
assert_eq!(rf["json_schema"]["schema"]["required"][0], "answer");
}
#[test]
fn openai_emits_json_object_when_no_schema() {
let mut req = empty_request("gpt-4o");
req.response_format = Some(ResponseFormat::JsonObject);
let body = OpenAiHandler.build_request_body(&req);
assert_eq!(body["response_format"]["type"], "json_object");
assert!(body["response_format"].get("json_schema").is_none());
}
#[test]
fn openai_omits_response_format_when_none() {
let req = empty_request("gpt-4o");
let body = OpenAiHandler.build_request_body(&req);
assert!(body.get("response_format").is_none());
}
#[test]
fn google_emits_response_mime_and_schema() {
let mut req = empty_request("gemini-2.5-pro");
req.response_format = Some(ResponseFormat::JsonSchema {
schema: serde_json::json!({"type": "object"}),
strict: false,
name: None,
});
let body = GoogleHandler.build_request_body(&req);
let cfg = body.get("generationConfig").expect("generationConfig");
assert_eq!(cfg["responseMimeType"], "application/json");
assert_eq!(cfg["responseSchema"]["type"], "object");
}
#[test]
fn google_json_object_skips_schema() {
let mut req = empty_request("gemini-2.5-pro");
req.response_format = Some(ResponseFormat::JsonObject);
let body = GoogleHandler.build_request_body(&req);
let cfg = body.get("generationConfig").expect("generationConfig");
assert_eq!(cfg["responseMimeType"], "application/json");
assert!(cfg.get("responseSchema").is_none());
}
#[test]
fn anthropic_does_not_emit_response_format_field() {
    // Anthropic has no response_format concept; a requested schema must not leak
    // into the wire body under either the OpenAI or the Google field name.
    let mut request = empty_request("claude-opus-4-7");
    request.response_format = Some(ResponseFormat::JsonSchema {
        schema: serde_json::json!({"type": "object"}),
        strict: true,
        name: None,
    });
    let payload = AnthropicHandler.build_request_body(&request);
    assert!(payload.get("response_format").is_none());
    assert!(payload.get("responseSchema").is_none());
}
#[test]
fn openai_tools_wrapped() {
    // Each tool is wrapped as {"type": "function", "function": {...}}.
    let raw_tools = vec![serde_json::json!({"name": "search", "parameters": {}})];
    let wrapped = OpenAiHandler.build_tools(&raw_tools);
    assert_eq!(wrapped[0]["type"], "function");
    assert!(wrapped[0].get("function").is_some());
}
#[test]
fn openai_request_preserves_required_tool_choice_and_parallel_tool_calls() {
    // tool_choice and parallel_tool_calls must pass through to the wire body verbatim.
    let extract_tool = serde_json::json!({
        "type": "function",
        "function": {
            "name": "extract_action_items",
            "parameters": {"type": "object", "additionalProperties": false}
        }
    });
    let request = ApiRequest {
        model: "gpt-5.4-mini".into(),
        messages: vec![serde_json::json!({"role": "user", "content": "extract"})],
        system: None,
        temperature: 0.0,
        max_tokens: 1024,
        tools: Some(vec![extract_tool]),
        tool_choice: Some("required".into()),
        parallel_tool_calls: Some(false),
        stream: false,
        budget_tokens: 0,
        cache_control: false,
        response_format: None,
    };
    let payload = OpenAiHandler.build_request_body(&request);
    assert_eq!(payload["tool_choice"], "required");
    assert_eq!(payload["parallel_tool_calls"], false);
    assert_eq!(payload["tools"][0]["function"]["name"], "extract_action_items");
}
#[test]
fn anthropic_tools_format() {
    // OpenAI-style {"function": {...}} tools are flattened to Anthropic's
    // top-level name + input_schema layout.
    let raw_tools = vec![
        serde_json::json!({"function": {"name": "search", "description": "Search", "parameters": {"type": "object"}}}),
    ];
    let converted = AnthropicHandler.build_tools(&raw_tools);
    assert_eq!(converted[0]["name"], "search");
    assert!(converted[0].get("input_schema").is_some());
}
#[test]
fn google_no_streaming() {
    // The Google handler opts out of the default streaming support.
    assert!(!GoogleHandler.supports_streaming());
}
#[test]
fn anthropic_supports_thinking() {
    // Anthropic overrides the default `supports_thinking` of false.
    assert!(AnthropicHandler.supports_thinking());
}
#[test]
fn handler_factory() {
    use crate::schema::ApiProtocol;
    // Distinguish the two handlers via their thinking capability.
    let anthropic = handler_for(ApiProtocol::Anthropic);
    assert!(anthropic.supports_thinking());
    let openai_compat = handler_for(ApiProtocol::OpenAiCompat);
    assert!(!openai_compat.supports_thinking());
}
#[test]
fn openai_parse_text_response() {
    // A plain chat completion yields the message content and no tool calls.
    let raw = r#"{"choices":[{"message":{"content":"Hello world"}}]}"#;
    let parsed = OpenAiHandler.parse_response(raw).unwrap();
    assert_eq!(parsed.text, "Hello world");
    assert!(parsed.tool_calls.is_empty());
}
#[test]
fn openai_parse_usage() {
    // The usage object maps prompt/completion/total token counts straight through.
    let raw = r#"{"choices":[{"message":{"content":"Hi"}}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}"#;
    let parsed = OpenAiHandler.parse_response(raw).unwrap();
    let usage = parsed.usage.unwrap();
    assert_eq!(usage.prompt_tokens, 10);
    assert_eq!(usage.completion_tokens, 5);
    assert_eq!(usage.total_tokens, 15);
}
#[test]
fn openai_parse_multiple_tool_calls() {
    // All entries in the tool_calls array must survive parsing, in order.
    let raw = r#"{"choices":[{"message":{"content":"","tool_calls":[{"function":{"name":"read_file","arguments":"{\"path\":\"a.rs\"}"}},{"function":{"name":"read_file","arguments":"{\"path\":\"b.rs\"}"}}]}}]}"#;
    let parsed = OpenAiHandler.parse_response(raw).unwrap();
    let names: Vec<&str> = parsed.tool_calls.iter().map(|c| c.name.as_str()).collect();
    assert_eq!(names, vec!["read_file", "read_file"]);
}
#[test]
fn anthropic_parse_tool_response() {
    // Mixed content: the text block becomes `text`, the tool_use block a tool call.
    let raw = r#"{"content":[{"type":"text","text":"Let me search"},{"type":"tool_use","name":"search","id":"t1","input":{"q":"rust"}}]}"#;
    let parsed = AnthropicHandler.parse_response(raw).unwrap();
    assert_eq!(parsed.text, "Let me search");
    assert_eq!(parsed.tool_calls.len(), 1);
    assert_eq!(parsed.tool_calls[0].name, "search");
}
#[test]
fn anthropic_parse_usage() {
    // input/output token counts map to prompt/completion; total is their sum.
    let raw = r#"{"content":[{"type":"text","text":"Hi"}],"usage":{"input_tokens":12,"output_tokens":3}}"#;
    let parsed = AnthropicHandler.parse_response(raw).unwrap();
    let usage = parsed.usage.unwrap();
    assert_eq!(usage.prompt_tokens, 12);
    assert_eq!(usage.completion_tokens, 3);
    assert_eq!(usage.total_tokens, 15);
}
#[test]
fn anthropic_parse_multiple_tool_calls() {
    // Multiple tool_use blocks in one response all become tool calls, in order.
    let raw = r#"{"content":[{"type":"text","text":"I'll read both files"},{"type":"tool_use","name":"read","id":"t1","input":{"path":"a.rs"}},{"type":"tool_use","name":"read","id":"t2","input":{"path":"b.rs"}}]}"#;
    let parsed = AnthropicHandler.parse_response(raw).unwrap();
    assert_eq!(parsed.text, "I'll read both files");
    let names: Vec<&str> = parsed.tool_calls.iter().map(|c| c.name.as_str()).collect();
    assert_eq!(names, vec!["read", "read"]);
}
#[test]
fn anthropic_cache_control_system_prompt() {
    // With cache_control enabled, the system prompt is emitted as a one-element
    // block array whose block carries an ephemeral cache_control marker.
    let request = ApiRequest {
        model: "claude".into(),
        system: Some("You are helpful.".into()),
        messages: vec![serde_json::json!({"role": "user", "content": "hello"})],
        temperature: 0.7,
        max_tokens: 1024,
        tools: None,
        tool_choice: None,
        parallel_tool_calls: None,
        stream: false,
        budget_tokens: 0,
        cache_control: true,
        response_format: None,
    };
    let payload = AnthropicHandler.build_request_body(&request);
    let system = payload.get("system").expect("system field present");
    let blocks = system.as_array().expect("system is a block array");
    assert_eq!(blocks.len(), 1);
    let block = &blocks[0];
    assert_eq!(block["type"], "text");
    assert_eq!(block["text"], "You are helpful.");
    assert_eq!(block["cache_control"]["type"], "ephemeral");
}
#[test]
fn anthropic_cache_control_disabled() {
    // Without cache_control, the system prompt stays a plain string.
    let request = ApiRequest {
        model: "claude".into(),
        system: Some("You are helpful.".into()),
        messages: vec![serde_json::json!({"role": "user", "content": "hello"})],
        temperature: 0.7,
        max_tokens: 1024,
        tools: None,
        tool_choice: None,
        parallel_tool_calls: None,
        stream: false,
        budget_tokens: 0,
        cache_control: false,
        response_format: None,
    };
    let payload = AnthropicHandler.build_request_body(&request);
    let system = payload.get("system").expect("system field present");
    assert!(system.is_string());
}
#[test]
fn anthropic_cache_control_tools() {
    // Only the last tool gets the ephemeral cache_control marker (caching the
    // whole tool prefix); earlier tools must stay unmarked.
    let tool_defs = vec![
        serde_json::json!({"name": "search", "description": "Search", "input_schema": {"type": "object"}}),
        serde_json::json!({"name": "read", "description": "Read file", "input_schema": {"type": "object"}}),
    ];
    let request = ApiRequest {
        model: "claude".into(),
        system: None,
        messages: vec![serde_json::json!({"role": "user", "content": "hello"})],
        temperature: 0.7,
        max_tokens: 1024,
        tools: Some(tool_defs),
        tool_choice: None,
        parallel_tool_calls: None,
        stream: false,
        budget_tokens: 0,
        cache_control: true,
        response_format: None,
    };
    let payload = AnthropicHandler.build_request_body(&request);
    let tools = payload["tools"].as_array().unwrap();
    assert!(tools[0].get("cache_control").is_none());
    assert_eq!(tools[1]["cache_control"]["type"], "ephemeral");
}
#[test]
fn anthropic_beta_header_included() {
    // The prompt-caching beta header must be part of the auth headers.
    let headers = AnthropicHandler.auth_headers("test-key");
    let (_, value) = headers
        .iter()
        .find(|(name, _)| name == "anthropic-beta")
        .expect("anthropic-beta header present");
    assert_eq!(value, "prompt-caching-2024-07-31");
}
#[test]
fn google_parse_response() {
    // Text is extracted from the first candidate's content parts.
    let raw = r#"{"candidates":[{"content":{"parts":[{"text":"Hello from Gemini"}]}}]}"#;
    let parsed = GoogleHandler.parse_response(raw).unwrap();
    assert_eq!(parsed.text, "Hello from Gemini");
}
#[test]
fn google_tools_format() {
    // OpenAI-style tools are flattened to Google's function-declaration shape:
    // top-level name plus a `parameters` schema.
    let raw_tools = vec![serde_json::json!({
        "function": {
            "name": "search",
            "description": "Search docs",
            "parameters": {"type": "object"}
        }
    })];
    let converted = GoogleHandler.build_tools(&raw_tools);
    assert_eq!(converted[0]["name"], "search");
    assert!(converted[0].get("parameters").is_some());
}
#[test]
fn google_builds_multimodal_messages() {
    // A multimodal user turn becomes one `parts` message: a text part plus a
    // fileData part for the image URL; context is returned as the system string.
    let turns = vec![Message::UserMultimodal {
        content: vec![
            ContentBlock::Text {
                text: "Describe this image.".to_string(),
            },
            ContentBlock::ImageUrl {
                url: "https://example.com/cat.jpg".to_string(),
                detail: "auto".to_string(),
            },
        ],
    }];
    let (built, system) = GoogleHandler.build_messages(&turns, "", Some("Be concise"), None);
    assert_eq!(built.len(), 1);
    assert_eq!(built[0]["role"], "user");
    let parts = built[0]["parts"].as_array().unwrap();
    assert_eq!(parts[0]["text"], "Describe this image.");
    assert!(parts[1].get("fileData").is_some());
    assert_eq!(system, Some("Be concise".to_string()));
}
#[test]
fn google_request_body_includes_tools_and_system() {
    // System prompt, tools, AUTO function-calling mode, and maxOutputTokens
    // all land in the Gemini request body.
    let search_tool = serde_json::json!({
        "name": "search",
        "description": "Search files",
        "parameters": {"type": "object"}
    });
    let request = ApiRequest {
        model: "gemini-2.5-flash".into(),
        system: Some("Use tools when needed.".into()),
        messages: vec![serde_json::json!({
            "role": "user",
            "parts": [{"text": "Find the file and summarize it."}],
        })],
        temperature: 0.2,
        max_tokens: 512,
        tools: Some(vec![search_tool]),
        tool_choice: None,
        parallel_tool_calls: None,
        stream: false,
        budget_tokens: 0,
        cache_control: false,
        response_format: None,
    };
    let payload = GoogleHandler.build_request_body(&request);
    assert!(payload.get("systemInstruction").is_some());
    assert!(payload.get("tools").is_some());
    assert_eq!(payload["toolConfig"]["functionCallingConfig"]["mode"], "AUTO");
    assert_eq!(payload["generationConfig"]["maxOutputTokens"], 512);
}
#[test]
fn google_parse_multiple_tool_calls_and_usage() {
    // functionCall parts become tool calls in order; usageMetadata maps onto
    // the common TokenUsage fields.
    let raw = r#"{
        "candidates":[{"content":{"parts":[
            {"text":"Let me do that."},
            {"functionCall":{"name":"search","args":{"q":"rust"}}},
            {"functionCall":{"name":"read_file","args":{"path":"src/lib.rs"}}}
        ]}}],
        "usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":4,"totalTokenCount":14}
    }"#;
    let parsed = GoogleHandler.parse_response(raw).unwrap();
    assert_eq!(parsed.text, "Let me do that.");
    let names: Vec<&str> = parsed.tool_calls.iter().map(|c| c.name.as_str()).collect();
    assert_eq!(names, vec!["search", "read_file"]);
    let usage = parsed.usage.unwrap();
    assert_eq!(usage.prompt_tokens, 10);
    assert_eq!(usage.completion_tokens, 4);
    assert_eq!(usage.total_tokens, 14);
}
#[test]
fn openai_vision_message() {
    // A multimodal turn becomes one message whose content array holds a text
    // part followed by an image_url part.
    let turns = vec![Message::UserMultimodal {
        content: vec![
            ContentBlock::Text {
                text: "What is in this image?".to_string(),
            },
            ContentBlock::ImageUrl {
                url: "https://example.com/cat.jpg".to_string(),
                detail: "auto".to_string(),
            },
        ],
    }];
    let (built, _) = OpenAiHandler.build_messages(&turns, "", None, None);
    assert_eq!(built.len(), 1);
    let content = built[0]["content"].as_array().unwrap();
    assert_eq!(content.len(), 2);
    assert_eq!(content[0]["type"], "text");
    assert_eq!(content[1]["type"], "image_url");
}
#[test]
fn anthropic_vision_message() {
    // Base64 image blocks map to Anthropic's {"type":"image","source":{"type":"base64",...}}.
    let turns = vec![Message::UserMultimodal {
        content: vec![
            ContentBlock::Text {
                text: "Describe this.".to_string(),
            },
            ContentBlock::ImageBase64 {
                data: "iVBOR...".to_string(),
                media_type: "image/png".to_string(),
            },
        ],
    }];
    let (built, _) = AnthropicHandler.build_messages(&turns, "", None, None);
    assert_eq!(built.len(), 1);
    let content = built[0]["content"].as_array().unwrap();
    assert_eq!(content[0]["type"], "text");
    assert_eq!(content[1]["type"], "image");
    assert_eq!(content[1]["source"]["type"], "base64");
}
#[test]
fn openai_single_turn_images() {
    // Single-turn prompt + images: the prompt becomes a text part and each
    // image an image_url part in the same user message.
    let attachments = vec![ContentBlock::ImageUrl {
        url: "https://example.com/cat.jpg".to_string(),
        detail: "high".to_string(),
    }];
    let (built, _) =
        OpenAiHandler.build_messages(&[], "Describe this image", None, Some(&attachments));
    let content = built[0]["content"].as_array().unwrap();
    assert_eq!(content.len(), 2);
    assert_eq!(content[0]["type"], "text");
    assert_eq!(content[1]["type"], "image_url");
}
#[test]
fn anthropic_single_turn_images() {
    // Single-turn prompt + base64 image: text part plus an image part with a
    // base64 source, inside one user message.
    let attachments = vec![ContentBlock::ImageBase64 {
        data: "iVBOR...".to_string(),
        media_type: "image/png".to_string(),
    }];
    let (built, _) =
        AnthropicHandler.build_messages(&[], "Describe this image", None, Some(&attachments));
    let content = built[0]["content"].as_array().unwrap();
    assert_eq!(content.len(), 2);
    assert_eq!(content[0]["type"], "text");
    assert_eq!(content[1]["type"], "image");
    assert_eq!(content[1]["source"]["type"], "base64");
}
}