use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::warn;
use crate::error::{Result, ZeptoError};
use crate::session::{ContentPart, ImageSource, Message, Role, ToolCall};
use super::{
parse_provider_error, ChatOptions, LLMProvider, LLMResponse, LLMToolCall, ToolDefinition, Usage,
};
/// Anthropic Messages API endpoint, used for both blocking and streaming calls.
const CLAUDE_API_URL: &str = "https://api.anthropic.com/v1/messages";
/// Model used when the caller does not name one.
///
/// Overridable at *build* time through the `ZEPTOCLAW_CLAUDE_DEFAULT_MODEL`
/// environment variable (read with `option_env!`).
const DEFAULT_MODEL: &str = if let Some(model) = option_env!("ZEPTOCLAW_CLAUDE_DEFAULT_MODEL") {
    model
} else {
    "claude-sonnet-4-6"
};
/// Value sent in the `anthropic-version` request header.
const ANTHROPIC_VERSION: &str = "2023-06-01";
/// LLM provider backed by the Anthropic Messages API.
///
/// Requests are authenticated either with an API key (`x-api-key` header) or an
/// OAuth bearer token (`Authorization: Bearer …` plus the `anthropic-beta` flags
/// sent alongside it).
pub struct ClaudeProvider {
    /// Credential used to build the per-request auth headers.
    credential: crate::auth::ResolvedCredential,
    /// HTTP client used for every request made by this provider.
    client: Client,
}

impl ClaudeProvider {
    /// Overall request timeout for clients built by this provider.
    const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120);

    /// Creates a provider that authenticates with the given API key.
    pub fn new(api_key: &str) -> Self {
        Self::with_credential(crate::auth::ResolvedCredential::ApiKey(api_key.to_string()))
    }

    /// Creates a provider from an already-resolved credential (API key or OAuth token).
    pub fn with_credential(credential: crate::auth::ResolvedCredential) -> Self {
        Self {
            credential,
            client: Self::default_client(),
        }
    }

    /// Creates an API-key provider that uses a caller-supplied HTTP client
    /// (the caller is responsible for its timeout/proxy configuration).
    pub fn with_client(api_key: &str, client: Client) -> Self {
        Self {
            credential: crate::auth::ResolvedCredential::ApiKey(api_key.to_string()),
            client,
        }
    }

    /// Builds the default HTTP client with [`Self::REQUEST_TIMEOUT`].
    ///
    /// If the builder fails, falls back to `Client::new()` — which has no
    /// request timeout — and logs a warning so the degradation is visible.
    fn default_client() -> Client {
        Client::builder()
            .timeout(Self::REQUEST_TIMEOUT)
            .build()
            .unwrap_or_else(|e| {
                warn!(
                    error = %e,
                    "Failed to build HTTP client with timeout; falling back to default client"
                );
                Client::new()
            })
    }

    /// Returns the authentication headers for the configured credential.
    ///
    /// Header values that are not valid HTTP header bytes are omitted with a
    /// warning rather than failing here; the API will then reject the request.
    fn auth_headers(&self) -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();
        match &self.credential {
            crate::auth::ResolvedCredential::ApiKey(key) => {
                match reqwest::header::HeaderValue::from_str(key) {
                    Ok(v) => {
                        headers.insert("x-api-key", v);
                    }
                    Err(e) => {
                        warn!(error = %e, "Invalid API key header value; omitting header");
                    }
                }
            }
            crate::auth::ResolvedCredential::BearerToken { access_token, .. } => {
                let value = format!("Bearer {}", access_token);
                match reqwest::header::HeaderValue::from_str(&value) {
                    Ok(v) => {
                        headers.insert(reqwest::header::AUTHORIZATION, v);
                    }
                    Err(e) => {
                        warn!(error = %e, "Invalid Authorization header value; omitting header");
                    }
                }
                // Beta flags sent with every bearer-token (OAuth) request.
                headers.insert(
                    "anthropic-beta",
                    reqwest::header::HeaderValue::from_static(
                        "claude-code-20250219,oauth-2025-04-20",
                    ),
                );
            }
        }
        headers
    }
}
#[async_trait]
impl LLMProvider for ClaudeProvider {
async fn chat(
&self,
messages: Vec<Message>,
tools: Vec<ToolDefinition>,
model: Option<&str>,
options: ChatOptions,
) -> Result<LLMResponse> {
let model = model.unwrap_or(DEFAULT_MODEL);
let (mut system, claude_messages) = convert_messages(messages)?;
if let Some(suffix) = options.output_format.to_claude_system_suffix() {
let base = system.unwrap_or_default();
system = Some(format!("{}{}", base, suffix));
}
let request = ClaudeRequest {
model: model.to_string(),
max_tokens: options.max_tokens.unwrap_or(8192),
messages: claude_messages,
system,
tools: if tools.is_empty() {
None
} else {
Some(convert_tools(tools))
},
temperature: options.temperature,
top_p: options.top_p,
stop_sequences: options.stop,
stream: None,
};
let response = self
.client
.post(CLAUDE_API_URL)
.headers(self.auth_headers())
.header("anthropic-version", ANTHROPIC_VERSION)
.header("content-type", "application/json")
.json(&request)
.send()
.await?;
if !response.status().is_success() {
let status = response.status().as_u16();
let error_text = response.text().await.unwrap_or_default();
let body = if let Ok(error_response) =
serde_json::from_str::<ClaudeErrorResponse>(&error_text)
{
format!(
"Claude API error: {} - {}",
error_response.error.r#type, error_response.error.message
)
} else {
format!("Claude API error: {}", error_text)
};
return Err(ZeptoError::from(parse_provider_error(status, &body)));
}
let claude_response: ClaudeResponse = response.json().await?;
Ok(convert_response(claude_response))
}
async fn chat_stream(
&self,
messages: Vec<Message>,
tools: Vec<ToolDefinition>,
model: Option<&str>,
options: ChatOptions,
) -> crate::error::Result<tokio::sync::mpsc::Receiver<super::StreamEvent>> {
use super::StreamEvent;
use futures::StreamExt;
let model = model.unwrap_or(DEFAULT_MODEL);
let (mut system, claude_messages) = convert_messages(messages)?;
if let Some(suffix) = options.output_format.to_claude_system_suffix() {
let base = system.unwrap_or_default();
system = Some(format!("{}{}", base, suffix));
}
let request = ClaudeRequest {
model: model.to_string(),
max_tokens: options.max_tokens.unwrap_or(8192),
messages: claude_messages,
system,
tools: if tools.is_empty() {
None
} else {
Some(convert_tools(tools))
},
temperature: options.temperature,
top_p: options.top_p,
stop_sequences: options.stop,
stream: Some(true),
};
let response = self
.client
.post(CLAUDE_API_URL)
.headers(self.auth_headers())
.header("anthropic-version", ANTHROPIC_VERSION)
.header("content-type", "application/json")
.json(&request)
.send()
.await?;
if !response.status().is_success() {
let status = response.status().as_u16();
let error_text = response.text().await.unwrap_or_default();
let body = if let Ok(error_response) =
serde_json::from_str::<ClaudeErrorResponse>(&error_text)
{
format!(
"Claude API error: {} - {}",
error_response.error.r#type, error_response.error.message
)
} else {
format!("Claude API error: {}", error_text)
};
return Err(ZeptoError::from(parse_provider_error(status, &body)));
}
let (tx, rx) = tokio::sync::mpsc::channel::<StreamEvent>(32);
let byte_stream = response.bytes_stream();
tokio::spawn(async move {
let mut assembled_content = String::new();
let mut tool_calls: Vec<super::LLMToolCall> = Vec::new();
let mut current_tool_id: Option<String> = None;
let mut current_tool_name: Option<String> = None;
let mut current_tool_json = String::new();
let mut input_tokens: u32 = 0;
let mut output_tokens: u32 = 0;
let mut line_buffer = String::new();
tokio::pin!(byte_stream);
while let Some(chunk_result) = byte_stream.next().await {
let chunk = match chunk_result {
Ok(bytes) => bytes,
Err(e) => {
let _ = tx
.send(StreamEvent::Error(ZeptoError::Provider(format!(
"Stream read error: {}",
e
))))
.await;
return;
}
};
let chunk_str = String::from_utf8_lossy(&chunk);
line_buffer.push_str(&chunk_str);
while let Some(newline_pos) = line_buffer.find('\n') {
let line = line_buffer[..newline_pos].trim().to_string();
line_buffer = line_buffer[newline_pos + 1..].to_string();
if line.is_empty() || line.starts_with("event:") {
continue;
}
let data = if let Some(stripped) = line.strip_prefix("data: ") {
stripped
} else if let Some(stripped) = line.strip_prefix("data:") {
stripped
} else {
continue;
};
if data == "[DONE]" {
break;
}
let sse: SseEvent = match serde_json::from_str(data) {
Ok(v) => v,
Err(_) => continue,
};
match sse.event_type.as_str() {
"message_start" => {
if let Some(msg) = &sse.message {
if let Some(usage) = &msg.usage {
input_tokens = usage.input_tokens.unwrap_or(0);
}
}
}
"content_block_start" => {
if let Some(block) = &sse.content_block {
if block.block_type == "tool_use" {
current_tool_id = block.id.clone();
current_tool_name = block.name.clone();
current_tool_json.clear();
}
}
}
"content_block_delta" => {
if let Some(delta) = &sse.delta {
match delta.delta_type.as_deref() {
Some("text_delta") => {
if let Some(text) = &delta.text {
assembled_content.push_str(text);
if tx
.send(StreamEvent::Delta(text.clone()))
.await
.is_err()
{
return;
}
}
}
Some("input_json_delta") => {
if let Some(json_chunk) = &delta.partial_json {
current_tool_json.push_str(json_chunk);
}
}
_ => {}
}
}
}
"content_block_stop" => {
if let (Some(id), Some(name)) =
(current_tool_id.take(), current_tool_name.take())
{
let args = if current_tool_json.is_empty() {
"{}".to_string()
} else {
std::mem::take(&mut current_tool_json)
};
tool_calls.push(super::LLMToolCall::new(&id, &name, &args));
}
}
"message_delta" => {
if let Some(usage) = &sse.usage {
output_tokens = usage.output_tokens.unwrap_or(0);
}
}
"message_stop" => {
if !tool_calls.is_empty() {
let _ = tx
.send(StreamEvent::ToolCalls(std::mem::take(&mut tool_calls)))
.await;
}
let usage = super::Usage::new(input_tokens, output_tokens);
let _ = tx
.send(StreamEvent::Done {
content: assembled_content.clone(),
usage: Some(usage),
})
.await;
return;
}
_ => {}
}
}
}
if !tool_calls.is_empty() {
let _ = tx
.send(StreamEvent::ToolCalls(std::mem::take(&mut tool_calls)))
.await;
}
let usage = super::Usage::new(input_tokens, output_tokens);
let _ = tx
.send(StreamEvent::Done {
content: assembled_content,
usage: Some(usage),
})
.await;
});
Ok(rx)
}
fn default_model(&self) -> &str {
DEFAULT_MODEL
}
fn name(&self) -> &str {
"claude"
}
}
/// JSON body for a `POST /v1/messages` request.
///
/// Optional fields are omitted from the JSON entirely when `None`.
#[derive(Debug, Serialize)]
struct ClaudeRequest {
    /// Model identifier.
    model: String,
    /// Upper bound on generated tokens (the API requires this field).
    max_tokens: u32,
    /// Conversation turns; system messages are carried in `system` instead.
    messages: Vec<ClaudeMessage>,
    /// Top-level system prompt.
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
    /// Tool definitions offered to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<ClaudeTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequences: Option<Vec<String>>,
    /// `Some(true)` requests an SSE stream; `None` omits the field.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}
/// One conversation turn in the Messages API (`role` is `"user"` or `"assistant"`
/// as produced by `convert_messages`).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ClaudeMessage {
    role: String,
    content: ClaudeContent,
}
/// Message content: serialized untagged, i.e. either a bare JSON string or an
/// array of typed content blocks.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum ClaudeContent {
    Text(String),
    Blocks(Vec<ClaudeContentBlock>),
}
/// A typed content block, internally tagged by its `"type"` field
/// (`text`, `tool_use`, `tool_result`, `image`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
enum ClaudeContentBlock {
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool invocation requested by the assistant; `input` is a JSON value, not a string.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool call, referencing the originating `tool_use` id.
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
    },
    #[serde(rename = "image")]
    Image { source: ClaudeImageSource },
}
/// Inline image payload; this module only builds `source_type = "base64"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ClaudeImageSource {
    #[serde(rename = "type")]
    source_type: String,
    media_type: String,
    data: String,
}
/// Tool definition in the Messages API shape; `input_schema` is a JSON Schema value.
#[derive(Debug, Serialize)]
struct ClaudeTool {
    name: String,
    description: String,
    input_schema: serde_json::Value,
}
/// Body of a successful non-streaming Messages API response (only the fields used here).
#[derive(Debug, Deserialize)]
struct ClaudeResponse {
    content: Vec<ClaudeContentBlock>,
    usage: ClaudeUsage,
    // Deserialized for completeness/debug output but not read by `convert_response`.
    #[allow(dead_code)]
    stop_reason: Option<String>,
}
/// Error envelope `{"error": {"type": ..., "message": ...}}` returned on failures.
#[derive(Debug, Deserialize)]
struct ClaudeErrorResponse {
    error: ClaudeError,
}
/// Error details: an error-type string and a human-readable message.
#[derive(Debug, Deserialize)]
struct ClaudeError {
    r#type: String,
    message: String,
}
/// Token accounting for a non-streaming response.
#[derive(Debug, Deserialize)]
struct ClaudeUsage {
    input_tokens: u32,
    output_tokens: u32,
}
/// One `data:` payload from the streaming response, discriminated by `"type"`.
///
/// All payload fields are optional because each event type carries a different subset.
#[derive(Debug, Deserialize)]
// `index` is deserialized but not read by the stream loop.
#[allow(dead_code)]
struct SseEvent {
    #[serde(rename = "type")]
    event_type: String,
    /// Present on `content_block_delta` and `message_delta` events.
    #[serde(default)]
    delta: Option<SseDelta>,
    /// Present on `content_block_start` events.
    #[serde(default)]
    content_block: Option<SseContentBlock>,
    /// Present on `message_delta` events (output token count).
    #[serde(default)]
    usage: Option<SseUsage>,
    #[serde(default)]
    index: Option<u32>,
    /// Present on `message_start` events (carries input token usage).
    #[serde(default)]
    message: Option<SseMessage>,
}
/// Incremental update: text (`text_delta`), tool-input JSON (`input_json_delta`),
/// or, on `message_delta`, the stop reason.
#[derive(Debug, Deserialize)]
// `stop_reason` is deserialized but not read by the stream loop.
#[allow(dead_code)]
struct SseDelta {
    #[serde(rename = "type")]
    #[serde(default)]
    delta_type: Option<String>,
    #[serde(default)]
    text: Option<String>,
    #[serde(default)]
    partial_json: Option<String>,
    #[serde(default)]
    stop_reason: Option<String>,
}
/// Header of a content block announced by `content_block_start`; `id`/`name` are
/// set for `tool_use` blocks.
#[derive(Debug, Deserialize)]
// `text` is deserialized but not read by the stream loop.
#[allow(dead_code)]
struct SseContentBlock {
    #[serde(rename = "type")]
    block_type: String,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    text: Option<String>,
}
/// Token counts as reported by streaming events; either may be absent.
#[derive(Debug, Deserialize)]
struct SseUsage {
    #[serde(default)]
    input_tokens: Option<u32>,
    #[serde(default)]
    output_tokens: Option<u32>,
}
/// The `message` object of a `message_start` event (only usage is read).
#[derive(Debug, Deserialize)]
struct SseMessage {
    #[serde(default)]
    usage: Option<SseUsage>,
}
fn convert_messages(messages: Vec<Message>) -> Result<(Option<String>, Vec<ClaudeMessage>)> {
let mut system: Option<String> = None;
let mut claude_messages: Vec<ClaudeMessage> = Vec::new();
let mut pending_tool_results: Vec<ClaudeContentBlock> = Vec::new();
for msg in messages {
match msg.role {
Role::System => {
system = Some(msg.content);
}
Role::User => {
if !pending_tool_results.is_empty() {
claude_messages.push(ClaudeMessage {
role: "user".to_string(),
content: ClaudeContent::Blocks(std::mem::take(&mut pending_tool_results)),
});
}
if msg.has_images() {
let blocks: Vec<ClaudeContentBlock> = msg
.content_parts
.iter()
.filter_map(|p| match p {
ContentPart::Text { text } => {
Some(ClaudeContentBlock::Text { text: text.clone() })
}
ContentPart::Image { source, media_type } => {
if let ImageSource::Base64 { data } = source {
Some(ClaudeContentBlock::Image {
source: ClaudeImageSource {
source_type: "base64".to_string(),
media_type: media_type.clone(),
data: data.clone(),
},
})
} else {
None
}
}
})
.collect();
claude_messages.push(ClaudeMessage {
role: "user".to_string(),
content: ClaudeContent::Blocks(blocks),
});
} else {
claude_messages.push(ClaudeMessage {
role: "user".to_string(),
content: ClaudeContent::Text(msg.content),
});
}
}
Role::Assistant => {
if !pending_tool_results.is_empty() {
claude_messages.push(ClaudeMessage {
role: "user".to_string(),
content: ClaudeContent::Blocks(std::mem::take(&mut pending_tool_results)),
});
}
if let Some(tool_calls) = msg.tool_calls {
let mut blocks: Vec<ClaudeContentBlock> = Vec::new();
if !msg.content.is_empty() {
blocks.push(ClaudeContentBlock::Text { text: msg.content });
}
for tc in tool_calls {
let input: serde_json::Value =
serde_json::from_str(&tc.arguments).unwrap_or(serde_json::json!({}));
blocks.push(ClaudeContentBlock::ToolUse {
id: tc.id,
name: tc.name,
input,
});
}
claude_messages.push(ClaudeMessage {
role: "assistant".to_string(),
content: ClaudeContent::Blocks(blocks),
});
} else {
claude_messages.push(ClaudeMessage {
role: "assistant".to_string(),
content: ClaudeContent::Text(msg.content),
});
}
}
Role::Tool => {
if let Some(tool_call_id) = msg.tool_call_id {
pending_tool_results.push(ClaudeContentBlock::ToolResult {
tool_use_id: tool_call_id,
content: msg.content,
is_error: None,
});
}
}
}
}
if !pending_tool_results.is_empty() {
claude_messages.push(ClaudeMessage {
role: "user".to_string(),
content: ClaudeContent::Blocks(pending_tool_results),
});
}
Ok((system, claude_messages))
}
/// Maps generic tool definitions onto Anthropic `tools` entries, carrying
/// `parameters` over unchanged as `input_schema`. Order is preserved.
fn convert_tools(tools: Vec<ToolDefinition>) -> Vec<ClaudeTool> {
    let mut converted = Vec::with_capacity(tools.len());
    for definition in tools {
        converted.push(ClaudeTool {
            name: definition.name,
            description: definition.description,
            input_schema: definition.parameters,
        });
    }
    converted
}
fn convert_response(response: ClaudeResponse) -> LLMResponse {
let mut content = String::new();
let mut tool_calls: Vec<LLMToolCall> = Vec::new();
for block in response.content {
match block {
ClaudeContentBlock::Text { text } => {
if !content.is_empty() {
content.push('\n');
}
content.push_str(&text);
}
ClaudeContentBlock::ToolUse { id, name, input } => {
let arguments = serde_json::to_string(&input).unwrap_or_else(|_| "{}".to_string());
tool_calls.push(LLMToolCall::new(&id, &name, &arguments));
}
ClaudeContentBlock::ToolResult { .. } => {
}
ClaudeContentBlock::Image { .. } => {
}
}
}
let usage = Usage::new(response.usage.input_tokens, response.usage.output_tokens);
LLMResponse {
content,
tool_calls,
usage: Some(usage),
}
}
#[allow(dead_code)]
fn tool_call_to_llm_tool_call(tc: &ToolCall) -> LLMToolCall {
LLMToolCall::new(&tc.id, &tc.name, &tc.arguments)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::{ContentPart, ImageSource, Message};

    #[test]
    fn test_claude_provider_creation() {
        let provider = ClaudeProvider::new("test-key");
        assert_eq!(provider.name(), "claude");
        // DEFAULT_MODEL can be overridden at build time; only pin the fallback
        // value when no override was supplied.
        assert_eq!(provider.default_model(), DEFAULT_MODEL);
        if option_env!("ZEPTOCLAW_CLAUDE_DEFAULT_MODEL").is_none() {
            assert_eq!(provider.default_model(), "claude-sonnet-4-6");
        }
    }

    #[test]
    fn test_claude_provider_with_client() {
        let client = Client::new();
        let provider = ClaudeProvider::with_client("test-key", client);
        assert_eq!(provider.name(), "claude");
    }

    #[test]
    fn test_message_conversion_simple() {
        let messages = vec![Message::user("Hello"), Message::assistant("Hi there!")];
        let (system, claude_messages) = convert_messages(messages).unwrap();
        assert!(system.is_none());
        assert_eq!(claude_messages.len(), 2);
        assert_eq!(claude_messages[0].role, "user");
        assert_eq!(claude_messages[1].role, "assistant");
    }

    #[test]
    fn test_message_conversion_with_system() {
        let messages = vec![
            Message::system("You are a helpful assistant"),
            Message::user("Hello"),
            Message::assistant("Hi there!"),
        ];
        let (system, claude_messages) = convert_messages(messages).unwrap();
        assert_eq!(system, Some("You are a helpful assistant".to_string()));
        assert_eq!(claude_messages.len(), 2);
        assert_eq!(claude_messages[0].role, "user");
        assert_eq!(claude_messages[1].role, "assistant");
    }

    #[test]
    fn test_message_conversion_with_tool_calls() {
        let tool_call = ToolCall::new("call_1", "web_search", r#"{"query": "rust"}"#);
        let messages = vec![
            Message::user("Search for Rust"),
            Message::assistant_with_tools("Let me search for that.", vec![tool_call]),
            Message::tool_result("call_1", "Found 100 results"),
            Message::assistant("I found 100 results about Rust."),
        ];
        let (system, claude_messages) = convert_messages(messages).unwrap();
        assert!(system.is_none());
        assert_eq!(claude_messages.len(), 4);
        assert_eq!(claude_messages[0].role, "user");
        assert_eq!(claude_messages[1].role, "assistant");
        if let ClaudeContent::Blocks(blocks) = &claude_messages[1].content {
            assert_eq!(blocks.len(), 2);
            assert!(matches!(blocks[0], ClaudeContentBlock::Text { .. }));
            assert!(matches!(blocks[1], ClaudeContentBlock::ToolUse { .. }));
        } else {
            panic!("Expected blocks content for tool call message");
        }
        assert_eq!(claude_messages[2].role, "user");
        if let ClaudeContent::Blocks(blocks) = &claude_messages[2].content {
            assert_eq!(blocks.len(), 1);
            assert!(matches!(blocks[0], ClaudeContentBlock::ToolResult { .. }));
        } else {
            panic!("Expected blocks content for tool result");
        }
        assert_eq!(claude_messages[3].role, "assistant");
    }

    #[test]
    fn test_message_conversion_multiple_tool_results() {
        let tc1 = ToolCall::new("call_1", "tool_a", "{}");
        let tc2 = ToolCall::new("call_2", "tool_b", "{}");
        let messages = vec![
            Message::user("Do both"),
            Message::assistant_with_tools("Running both tools.", vec![tc1, tc2]),
            Message::tool_result("call_1", "Result A"),
            Message::tool_result("call_2", "Result B"),
            Message::assistant("Both completed."),
        ];
        let (_, claude_messages) = convert_messages(messages).unwrap();
        assert_eq!(claude_messages.len(), 4);
        assert_eq!(claude_messages[2].role, "user");
        if let ClaudeContent::Blocks(blocks) = &claude_messages[2].content {
            assert_eq!(blocks.len(), 2);
        } else {
            panic!("Expected grouped tool results");
        }
    }

    #[test]
    fn test_convert_tools() {
        let tools = vec![
            ToolDefinition::new(
                "web_search",
                "Search the web",
                serde_json::json!({
                    "type": "object",
                    "properties": {
                        "query": { "type": "string" }
                    },
                    "required": ["query"]
                }),
            ),
            ToolDefinition::new(
                "calculator",
                "Perform calculations",
                serde_json::json!({
                    "type": "object",
                    "properties": {
                        "expression": { "type": "string" }
                    }
                }),
            ),
        ];
        let claude_tools = convert_tools(tools);
        assert_eq!(claude_tools.len(), 2);
        assert_eq!(claude_tools[0].name, "web_search");
        assert_eq!(claude_tools[0].description, "Search the web");
        assert_eq!(claude_tools[1].name, "calculator");
    }

    #[test]
    fn test_convert_response_text_only() {
        let response = ClaudeResponse {
            content: vec![ClaudeContentBlock::Text {
                text: "Hello, world!".to_string(),
            }],
            usage: ClaudeUsage {
                input_tokens: 10,
                output_tokens: 5,
            },
            stop_reason: Some("end_turn".to_string()),
        };
        let llm_response = convert_response(response);
        assert_eq!(llm_response.content, "Hello, world!");
        assert!(!llm_response.has_tool_calls());
        assert!(llm_response.usage.is_some());
        let usage = llm_response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 10);
        assert_eq!(usage.completion_tokens, 5);
        assert_eq!(usage.total_tokens, 15);
    }

    #[test]
    fn test_convert_response_with_tool_calls() {
        let response = ClaudeResponse {
            content: vec![
                ClaudeContentBlock::Text {
                    text: "Let me search for that.".to_string(),
                },
                ClaudeContentBlock::ToolUse {
                    id: "toolu_01".to_string(),
                    name: "web_search".to_string(),
                    input: serde_json::json!({"query": "rust programming"}),
                },
            ],
            usage: ClaudeUsage {
                input_tokens: 20,
                output_tokens: 30,
            },
            stop_reason: Some("tool_use".to_string()),
        };
        let llm_response = convert_response(response);
        assert_eq!(llm_response.content, "Let me search for that.");
        assert!(llm_response.has_tool_calls());
        assert_eq!(llm_response.tool_calls.len(), 1);
        let tc = &llm_response.tool_calls[0];
        assert_eq!(tc.id, "toolu_01");
        assert_eq!(tc.name, "web_search");
        assert!(tc.arguments.contains("rust programming"));
    }

    #[test]
    fn test_convert_response_multiple_text_blocks() {
        let response = ClaudeResponse {
            content: vec![
                ClaudeContentBlock::Text {
                    text: "First part.".to_string(),
                },
                ClaudeContentBlock::Text {
                    text: "Second part.".to_string(),
                },
            ],
            usage: ClaudeUsage {
                input_tokens: 10,
                output_tokens: 10,
            },
            stop_reason: Some("end_turn".to_string()),
        };
        let llm_response = convert_response(response);
        assert_eq!(llm_response.content, "First part.\nSecond part.");
    }

    #[test]
    fn test_claude_request_serialization() {
        let request = ClaudeRequest {
            model: "claude-sonnet-4-5-20250929".to_string(),
            max_tokens: 1000,
            messages: vec![ClaudeMessage {
                role: "user".to_string(),
                content: ClaudeContent::Text("Hello".to_string()),
            }],
            system: Some("You are helpful.".to_string()),
            tools: None,
            temperature: Some(0.7),
            top_p: None,
            stop_sequences: None,
            stream: None,
        };
        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("claude-sonnet-4-5-20250929"));
        assert!(json.contains("max_tokens"));
        assert!(json.contains("Hello"));
        assert!(json.contains("You are helpful"));
        assert!(json.contains("temperature"));
        assert!(!json.contains("top_p"));
    }

    #[test]
    fn test_claude_request_without_optional_fields() {
        let request = ClaudeRequest {
            model: "claude-sonnet-4-5-20250929".to_string(),
            max_tokens: 1000,
            messages: vec![],
            system: None,
            tools: None,
            temperature: None,
            top_p: None,
            stop_sequences: None,
            stream: None,
        };
        let json = serde_json::to_string(&request).unwrap();
        assert!(!json.contains("system"));
        assert!(!json.contains("tools"));
        assert!(!json.contains("temperature"));
        assert!(!json.contains("top_p"));
        assert!(!json.contains("stop_sequences"));
    }

    #[test]
    fn test_content_block_serialization() {
        let text_block = ClaudeContentBlock::Text {
            text: "Hello".to_string(),
        };
        let json = serde_json::to_string(&text_block).unwrap();
        assert!(json.contains(r#""type":"text""#));
        assert!(json.contains(r#""text":"Hello""#));
        let tool_use = ClaudeContentBlock::ToolUse {
            id: "call_1".to_string(),
            name: "search".to_string(),
            input: serde_json::json!({"q": "test"}),
        };
        let json = serde_json::to_string(&tool_use).unwrap();
        assert!(json.contains(r#""type":"tool_use""#));
        assert!(json.contains(r#""id":"call_1""#));
        assert!(json.contains(r#""name":"search""#));
        let tool_result = ClaudeContentBlock::ToolResult {
            tool_use_id: "call_1".to_string(),
            content: "Result".to_string(),
            is_error: None,
        };
        let json = serde_json::to_string(&tool_result).unwrap();
        assert!(json.contains(r#""type":"tool_result""#));
        assert!(json.contains(r#""tool_use_id":"call_1""#));
    }

    #[test]
    fn test_empty_messages() {
        let messages: Vec<Message> = vec![];
        let (system, claude_messages) = convert_messages(messages).unwrap();
        assert!(system.is_none());
        assert!(claude_messages.is_empty());
    }

    #[test]
    fn test_only_system_message() {
        let messages = vec![Message::system("You are helpful.")];
        let (system, claude_messages) = convert_messages(messages).unwrap();
        assert_eq!(system, Some("You are helpful.".to_string()));
        assert!(claude_messages.is_empty());
    }

    #[test]
    fn test_tool_call_to_llm_tool_call() {
        let tc = ToolCall::new("call_123", "web_search", r#"{"query": "test"}"#);
        let llm_tc = tool_call_to_llm_tool_call(&tc);
        assert_eq!(llm_tc.id, "call_123");
        assert_eq!(llm_tc.name, "web_search");
        assert_eq!(llm_tc.arguments, r#"{"query": "test"}"#);
    }

    #[test]
    fn test_claude_request_with_stream_flag() {
        let request = ClaudeRequest {
            model: "claude-sonnet-4-5-20250929".to_string(),
            max_tokens: 1000,
            messages: vec![],
            system: None,
            tools: None,
            temperature: None,
            top_p: None,
            stop_sequences: None,
            stream: Some(true),
        };
        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains(r#""stream":true"#));
    }

    #[test]
    fn test_claude_request_without_stream_flag() {
        let request = ClaudeRequest {
            model: "claude-sonnet-4-5-20250929".to_string(),
            max_tokens: 1000,
            messages: vec![],
            system: None,
            tools: None,
            temperature: None,
            top_p: None,
            stop_sequences: None,
            stream: None,
        };
        let json = serde_json::to_string(&request).unwrap();
        assert!(!json.contains("stream"));
    }

    // The SSE tests deserialize into the real `SseEvent` types used by
    // `chat_stream`, so a field rename or attribute mistake is caught here.

    #[test]
    fn test_parse_sse_content_block_delta() {
        let line = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"#;
        let sse: SseEvent = serde_json::from_str(line).unwrap();
        assert_eq!(sse.event_type, "content_block_delta");
        assert_eq!(sse.index, Some(0));
        let delta = sse.delta.expect("delta should be present");
        assert_eq!(delta.delta_type.as_deref(), Some("text_delta"));
        assert_eq!(delta.text.as_deref(), Some("Hello"));
        assert!(delta.partial_json.is_none());
    }

    #[test]
    fn test_parse_sse_input_json_delta() {
        let line = r#"{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"{\"q\":"}}"#;
        let sse: SseEvent = serde_json::from_str(line).unwrap();
        let delta = sse.delta.expect("delta should be present");
        assert_eq!(delta.delta_type.as_deref(), Some("input_json_delta"));
        assert_eq!(delta.partial_json.as_deref(), Some(r#"{"q":"#));
    }

    #[test]
    fn test_parse_sse_message_stop() {
        let line = r#"{"type":"message_stop"}"#;
        let sse: SseEvent = serde_json::from_str(line).unwrap();
        assert_eq!(sse.event_type, "message_stop");
        assert!(sse.delta.is_none());
        assert!(sse.usage.is_none());
    }

    #[test]
    fn test_parse_sse_message_delta_with_usage() {
        let line = r#"{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":42}}"#;
        let sse: SseEvent = serde_json::from_str(line).unwrap();
        assert_eq!(sse.event_type, "message_delta");
        let delta = sse.delta.expect("delta should be present");
        assert_eq!(delta.stop_reason.as_deref(), Some("end_turn"));
        assert!(delta.delta_type.is_none());
        let usage = sse.usage.expect("usage should be present");
        assert_eq!(usage.output_tokens, Some(42));
        assert!(usage.input_tokens.is_none());
    }

    #[test]
    fn test_parse_sse_message_start_usage() {
        let line = r#"{"type":"message_start","message":{"id":"msg_1","usage":{"input_tokens":25,"output_tokens":1}}}"#;
        let sse: SseEvent = serde_json::from_str(line).unwrap();
        assert_eq!(sse.event_type, "message_start");
        let input_tokens = sse
            .message
            .and_then(|m| m.usage)
            .and_then(|u| u.input_tokens);
        assert_eq!(input_tokens, Some(25));
    }

    #[test]
    fn test_parse_sse_content_block_start_tool_use() {
        let line = r#"{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_01","name":"web_search","input":{}}}"#;
        let sse: SseEvent = serde_json::from_str(line).unwrap();
        assert_eq!(sse.event_type, "content_block_start");
        let block = sse.content_block.expect("content_block should be present");
        assert_eq!(block.block_type, "tool_use");
        assert_eq!(block.id.as_deref(), Some("toolu_01"));
        assert_eq!(block.name.as_deref(), Some("web_search"));
    }

    #[test]
    fn test_parse_error_payload() {
        let body = r#"{"type":"error","error":{"type":"overloaded_error","message":"Overloaded"}}"#;
        let parsed: ClaudeErrorResponse = serde_json::from_str(body).unwrap();
        assert_eq!(parsed.error.r#type, "overloaded_error");
        assert_eq!(parsed.error.message, "Overloaded");
    }

    #[test]
    fn test_convert_user_message_with_image() {
        let images = vec![ContentPart::Image {
            source: ImageSource::Base64 {
                data: "abc123".to_string(),
            },
            media_type: "image/jpeg".to_string(),
        }];
        let msg = Message::user_with_images("What is this?", images);
        let (_, claude_msgs) = convert_messages(vec![msg]).unwrap();
        assert_eq!(claude_msgs.len(), 1);
        assert_eq!(claude_msgs[0].role, "user");
        if let ClaudeContent::Blocks(blocks) = &claude_msgs[0].content {
            assert_eq!(blocks.len(), 2);
            assert!(matches!(&blocks[0], ClaudeContentBlock::Text { .. }));
            assert!(matches!(&blocks[1], ClaudeContentBlock::Image { .. }));
        } else {
            panic!("Expected Blocks content for image message");
        }
    }

    #[test]
    fn test_convert_text_only_message_unchanged() {
        let msg = Message::user("Hello");
        let (_, claude_msgs) = convert_messages(vec![msg]).unwrap();
        assert!(matches!(&claude_msgs[0].content, ClaudeContent::Text(_)));
    }

    #[test]
    fn test_claude_image_json_matches_api_spec() {
        let images = vec![ContentPart::Image {
            source: ImageSource::Base64 {
                data: "iVBOR".to_string(),
            },
            media_type: "image/png".to_string(),
        }];
        let msg = Message::user_with_images("Describe this", images);
        let (_, claude_msgs) = convert_messages(vec![msg]).unwrap();
        let json = serde_json::to_value(&claude_msgs[0]).unwrap();
        let blocks = json["content"].as_array().unwrap();
        assert_eq!(blocks[0]["type"], "text");
        assert_eq!(blocks[0]["text"], "Describe this");
        assert_eq!(blocks[1]["type"], "image");
        assert_eq!(blocks[1]["source"]["type"], "base64");
        assert_eq!(blocks[1]["source"]["media_type"], "image/png");
        assert_eq!(blocks[1]["source"]["data"], "iVBOR");
    }

    #[test]
    fn test_convert_message_with_multiple_images() {
        let images = vec![
            ContentPart::Image {
                source: ImageSource::Base64 {
                    data: "img1".to_string(),
                },
                media_type: "image/jpeg".to_string(),
            },
            ContentPart::Image {
                source: ImageSource::Base64 {
                    data: "img2".to_string(),
                },
                media_type: "image/png".to_string(),
            },
        ];
        let msg = Message::user_with_images("Compare these two images", images);
        let (_, claude_msgs) = convert_messages(vec![msg]).unwrap();
        if let ClaudeContent::Blocks(blocks) = &claude_msgs[0].content {
            assert_eq!(blocks.len(), 3);
            assert!(matches!(&blocks[0], ClaudeContentBlock::Text { .. }));
            assert!(matches!(&blocks[1], ClaudeContentBlock::Image { .. }));
            assert!(matches!(&blocks[2], ClaudeContentBlock::Image { .. }));
        } else {
            panic!("Expected Blocks content");
        }
    }

    #[test]
    fn test_convert_message_skips_filepath_images() {
        let parts = vec![ContentPart::Image {
            source: ImageSource::FilePath {
                path: "media/abc.jpg".to_string(),
            },
            media_type: "image/jpeg".to_string(),
        }];
        let msg = Message::user_with_images("What is this?", parts);
        let (_, claude_msgs) = convert_messages(vec![msg]).unwrap();
        if let ClaudeContent::Blocks(blocks) = &claude_msgs[0].content {
            assert_eq!(blocks.len(), 1);
            assert!(matches!(&blocks[0], ClaudeContentBlock::Text { .. }));
        } else {
            panic!("Expected Blocks content");
        }
    }
}