use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
/// Request body for creating a message.
///
/// `model`, `max_tokens` and `messages` are always serialized. Every `Option`
/// field is left out of the serialized output while it is `None`.
///
/// NOTE(review): the field names look like the Anthropic Messages API request
/// schema (the tests in this file use a `claude-*` model id) — confirm against
/// the API reference before relying on that.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageRequest {
/// Model identifier string.
pub model: String,
/// Value sent as `max_tokens`; presumably the output-token cap — confirm.
pub max_tokens: u32,
/// Conversation messages, serialized as a JSON array in `Vec` order.
pub messages: Vec<Message>,
/// Request-level cache-control setting; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub cache_control: Option<CacheControl>,
/// System prompt, either a plain string or a list of blocks; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub system: Option<SystemPrompt>,
/// Tool definitions offered with the request; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
/// Tool-selection setting; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
/// Streaming flag; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
/// Sampling temperature; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// `top_p` sampling parameter; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// `top_k` sampling parameter; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<u32>,
/// Stop-sequence strings; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_sequences: Option<Vec<String>>,
/// Thinking configuration; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking: Option<ThinkingParam>,
/// Request metadata; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<MetadataParam>,
}
/// System prompt in one of two shapes.
///
/// The enum is `untagged`: the `String` variant serializes as a bare JSON
/// string and the `Blocks` variant as a JSON array, with no discriminator
/// field. On deserialization serde tries the variants in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum SystemPrompt {
/// Prompt given as a single string.
String(String),
/// Prompt given as a list of [`SystemBlock`] values.
Blocks(Vec<SystemBlock>),
}
/// One element of a block-form system prompt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemBlock {
/// Block kind, carried under the JSON key `type`.
#[serde(rename = "type")]
pub block_type: String,
/// Text content of the block.
pub text: String,
/// Optional cache-control setting for this block; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub cache_control: Option<CacheControl>,
}
/// Cache-control marker that can be attached to requests, blocks and tools.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheControl {
/// Cache kind, carried under the JSON key `type`.
#[serde(rename = "type")]
pub cache_type: String,
/// Optional time-to-live string (the constructors in this file use `"5m"`
/// and `"1h"`); omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub ttl: Option<String>,
}
impl CacheControl {
pub fn ephemeral_5m() -> Self {
Self {
cache_type: "ephemeral".to_string(),
ttl: Some("5m".to_string()),
}
}
pub fn ephemeral_1h() -> Self {
Self {
cache_type: "ephemeral".to_string(),
ttl: Some("1h".to_string()),
}
}
}
/// A single conversation message: who sent it and its content blocks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Sender of the message.
pub role: MessageRole,
/// Content blocks making up the message, serialized as a JSON array.
pub content: Vec<ContentBlock>,
}
/// Sender role of a [`Message`].
///
/// Variant names are lower-cased on the wire: `"user"` and `"assistant"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
/// Serialized as `"user"`.
User,
/// Serialized as `"assistant"`.
Assistant,
}
/// One block of message content.
///
/// Internally tagged: each block is a JSON object whose `type` field holds the
/// snake_case variant name (`text`, `image`, `document`, `tool_use`,
/// `tool_result`, `thinking`, `redacted_thinking`), with the variant's fields
/// alongside it. Every `Option` field is omitted from the output when `None`.
///
/// NOTE(review): a `type` value not listed above fails deserialization, since
/// there is no catch-all variant — confirm that every block kind the server
/// can return is covered.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
/// Plain text block (`"type": "text"`).
Text {
/// The text itself.
text: String,
/// Optional cache-control setting for this block.
#[serde(skip_serializing_if = "Option::is_none")]
cache_control: Option<CacheControl>,
},
/// Image block (`"type": "image"`).
Image {
/// Where the image data comes from.
source: ImageSource,
/// Optional cache-control setting for this block.
#[serde(skip_serializing_if = "Option::is_none")]
cache_control: Option<CacheControl>,
},
/// Document block (`"type": "document"`).
Document {
/// Where the document data comes from.
source: DocumentSource,
/// Optional document title.
#[serde(skip_serializing_if = "Option::is_none")]
title: Option<String>,
/// Optional cache-control setting for this block.
#[serde(skip_serializing_if = "Option::is_none")]
cache_control: Option<CacheControl>,
},
/// Tool invocation block (`"type": "tool_use"`).
ToolUse {
/// Identifier of this tool call; presumably what a later
/// `tool_result` refers to via `tool_use_id` — confirm.
id: String,
/// Name of the tool being called.
name: String,
/// Tool arguments as arbitrary JSON.
input: JsonValue,
/// Optional cache-control setting for this block.
#[serde(skip_serializing_if = "Option::is_none")]
cache_control: Option<CacheControl>,
},
/// Tool result block (`"type": "tool_result"`).
ToolResult {
/// Identifier of the tool call this result belongs to.
tool_use_id: String,
/// Result payload, either a string or a list of blocks.
content: ToolResultContent,
/// Optional error flag for the result.
#[serde(skip_serializing_if = "Option::is_none")]
is_error: Option<bool>,
/// Optional cache-control setting for this block.
#[serde(skip_serializing_if = "Option::is_none")]
cache_control: Option<CacheControl>,
},
/// Thinking block (`"type": "thinking"`); both fields are required.
Thinking {
/// Thinking text.
thinking: String,
/// Signature string accompanying the thinking text.
signature: String,
},
/// Redacted thinking block (`"type": "redacted_thinking"`).
RedactedThinking {
/// Opaque data string.
data: String,
},
}
/// Payload of a tool result in one of two shapes.
///
/// The enum is `untagged`: `String` serializes as a bare JSON string and
/// `Blocks` as a JSON array. On deserialization serde tries the variants in
/// declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolResultContent {
/// Result given as a single string.
String(String),
/// Result given as a list of [`ToolResultBlock`] values.
Blocks(Vec<ToolResultBlock>),
}
/// One block inside a block-form tool result.
///
/// Internally tagged by a `type` field holding the snake_case variant name:
/// `text`, `image` or `document`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolResultBlock {
/// Text block (`"type": "text"`).
Text {
/// The text itself.
text: String,
},
/// Image block (`"type": "image"`).
Image {
/// Where the image data comes from.
source: ImageSource,
},
/// Document block (`"type": "document"`).
Document {
/// Where the document data comes from.
source: DocumentSource,
},
}
/// Source of an image's data.
///
/// Internally tagged by a `type` field holding `base64` or `url`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ImageSource {
/// Inline image data (`"type": "base64"`).
Base64 {
/// Media type string of the data.
media_type: String,
/// The data string; presumably base64-encoded, per the variant name — confirm.
data: String,
},
/// Image referenced by URL (`"type": "url"`).
Url {
/// The URL string.
url: String,
},
}
/// Source of a document's data.
///
/// Internally tagged by a `type` field holding `base64`, `url` or `text`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DocumentSource {
/// Inline document data (`"type": "base64"`).
Base64 {
/// Media type string of the data.
media_type: String,
/// The data string; presumably base64-encoded, per the variant name — confirm.
data: String,
},
/// Document referenced by URL (`"type": "url"`).
Url {
/// The URL string.
url: String,
},
/// Inline document data (`"type": "text"`).
Text {
/// Media type string of the data.
media_type: String,
/// The data string; presumably plain text, per the variant name — confirm.
data: String,
},
}
/// A tool definition sent with a request.
///
/// Only `name` is always serialized; every `Option` field is omitted from the
/// output when `None`. The constructors in this file fill either
/// `description` + `input_schema` (no `type`) or `type` + `name` only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
/// Tool kind, carried under the JSON key `type`; omitted when `None`.
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
pub tool_type: Option<String>,
/// Tool name.
pub name: String,
/// Optional human-readable description.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Optional schema for the tool's input, as arbitrary JSON.
#[serde(skip_serializing_if = "Option::is_none")]
pub input_schema: Option<JsonValue>,
/// Optional cache-control setting for this tool definition.
#[serde(skip_serializing_if = "Option::is_none")]
pub cache_control: Option<CacheControl>,
/// Optional `max_uses` value; presumably a cap on invocations — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_uses: Option<u32>,
/// Optional `allowed_domains` list; presumably a domain allow-list for
/// web tools — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub allowed_domains: Option<Vec<String>>,
/// Optional `blocked_domains` list; presumably a domain block-list for
/// web tools — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub blocked_domains: Option<Vec<String>>,
}
impl Tool {
pub fn client(name: String, description: String, input_schema: JsonValue) -> Self {
Self {
tool_type: None,
name,
description: Some(description),
input_schema: Some(input_schema),
cache_control: None,
max_uses: None,
allowed_domains: None,
blocked_domains: None,
}
}
pub fn web_search(max_uses: Option<u32>) -> Self {
Self {
tool_type: Some("web_search_20250305".to_string()),
name: "web_search".to_string(),
description: None,
input_schema: None,
cache_control: None,
max_uses,
allowed_domains: None,
blocked_domains: None,
}
}
pub fn web_fetch() -> Self {
Self {
tool_type: Some("web_fetch_20250910".to_string()),
name: "web_fetch".to_string(),
description: None,
input_schema: None,
cache_control: None,
max_uses: None,
allowed_domains: None,
blocked_domains: None,
}
}
}
/// Tool-selection setting for a request.
///
/// Internally tagged by a `type` field holding the lower-cased variant name:
/// `auto`, `any`, `tool` or `none`. `disable_parallel_tool_use` is omitted
/// from the output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolChoice {
/// Serialized with `"type": "auto"`.
Auto {
/// Optional flag carried as `disable_parallel_tool_use`.
#[serde(skip_serializing_if = "Option::is_none")]
disable_parallel_tool_use: Option<bool>,
},
/// Serialized with `"type": "any"`.
Any {
/// Optional flag carried as `disable_parallel_tool_use`.
#[serde(skip_serializing_if = "Option::is_none")]
disable_parallel_tool_use: Option<bool>,
},
/// Serialized with `"type": "tool"` plus the tool's `name`.
Tool {
/// Name of the selected tool.
name: String,
/// Optional flag carried as `disable_parallel_tool_use`.
#[serde(skip_serializing_if = "Option::is_none")]
disable_parallel_tool_use: Option<bool>,
},
/// Serialized as an object containing only `"type": "none"`.
None,
}
/// Thinking configuration for a request.
///
/// Internally tagged by a `type` field holding `enabled` or `disabled`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ThinkingParam {
/// Serialized with `"type": "enabled"` plus `budget_tokens`.
Enabled {
/// Token budget sent as `budget_tokens`.
budget_tokens: u32,
},
/// Serialized as an object containing only `"type": "disabled"`.
Disabled,
}
/// Metadata attached to a request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetadataParam {
/// Optional user identifier string; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub user_id: Option<String>,
}
/// Response body describing a produced message.
///
/// NOTE(review): `stop_reason` is a plain `String`, so a payload where that
/// field is `null` or missing fails to deserialize. Confirm the server always
/// sends it for every payload parsed into this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageResponse {
/// Message identifier.
pub id: String,
/// Object kind, carried under the JSON key `type`.
#[serde(rename = "type")]
pub object_type: String,
/// Role string of the message author (kept as a raw string here rather
/// than [`MessageRole`]).
pub role: String,
/// Content blocks of the message.
pub content: Vec<ContentBlock>,
/// Model identifier string.
pub model: String,
/// Stop reason string (required; see the review note on the struct).
pub stop_reason: String,
/// Stop sequence, if any. There is no `skip_serializing_if` here, so
/// `None` serializes as JSON `null`.
pub stop_sequence: Option<String>,
/// Token usage counters for the response.
pub usage: Usage,
}
/// Token usage counters.
///
/// Every field carries `#[serde(default)]`, so a counter missing from the
/// input deserializes to `0` (or `None` for the cache counters) instead of
/// failing. The type also derives `Default`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Usage {
/// Input token count; `0` when absent from the input.
#[serde(default)]
pub input_tokens: u32,
/// Output token count; `0` when absent from the input.
#[serde(default)]
pub output_tokens: u32,
/// Cache-creation input token count; `None` when absent from the input
/// and omitted from the output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub cache_creation_input_tokens: Option<u32>,
/// Cache-read input token count; `None` when absent from the input and
/// omitted from the output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub cache_read_input_tokens: Option<u32>,
}
/// Envelope of an error payload: an outer `type` plus the error detail.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorResponse {
/// Outer object kind, carried under the JSON key `type` (the test in this
/// file uses `"error"`).
#[serde(rename = "type")]
pub error_type: String,
/// The error itself.
pub error: ErrorDetail,
}
/// Details of a single error.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorDetail {
/// Error kind, carried under the JSON key `type` (for example
/// `"rate_limit_error"`).
#[serde(rename = "type")]
pub error_type: String,
/// Human-readable error message.
pub message: String,
/// Optional `retry_after` value; presumably a delay in seconds — confirm.
/// Omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub retry_after: Option<u32>,
}
impl ErrorDetail {
    /// Reports whether this error's type is one of the retryable kinds.
    ///
    /// Returns `true` exactly when `error_type` equals `"rate_limit_error"`,
    /// `"overloaded_error"` or `"api_error"`; any other value yields `false`.
    pub fn is_retryable(&self) -> bool {
        const RETRYABLE_TYPES: [&str; 3] = ["rate_limit_error", "overloaded_error", "api_error"];
        RETRYABLE_TYPES.contains(&self.error_type.as_str())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds an `ErrorDetail` from borrowed strings to keep the tests terse.
    fn error_detail(error_type: &str, message: &str, retry_after: Option<u32>) -> ErrorDetail {
        ErrorDetail {
            error_type: error_type.to_owned(),
            message: message.to_owned(),
            retry_after,
        }
    }

    /// A populated request serializes its required and set optional fields.
    #[test]
    fn test_message_request_serialization() {
        let greeting = Message {
            role: MessageRole::User,
            content: vec![ContentBlock::Text {
                text: String::from("Hello!"),
                cache_control: None,
            }],
        };
        let request = MessageRequest {
            model: String::from("claude-sonnet-4-5"),
            max_tokens: 1024,
            messages: vec![greeting],
            cache_control: Some(CacheControl::ephemeral_5m()),
            system: Some(SystemPrompt::String(String::from("You are helpful."))),
            tools: None,
            tool_choice: None,
            stream: Some(true),
            temperature: None,
            top_p: None,
            top_k: None,
            stop_sequences: None,
            thinking: None,
            metadata: None,
        };

        let encoded = serde_json::to_value(&request).unwrap();

        assert_eq!(encoded["model"], "claude-sonnet-4-5");
        assert_eq!(encoded["max_tokens"], 1024);
        assert!(encoded["stream"].as_bool().unwrap());
        assert_eq!(encoded["cache_control"]["type"], "ephemeral");
        assert_eq!(encoded["cache_control"]["ttl"], "5m");
    }

    /// `Tool::client` keeps the given name and stores the schema.
    #[test]
    fn test_tool_creation() {
        let schema = json!({
            "type": "object",
            "properties": {
                "location": {"type": "string"}
            },
            "required": ["location"]
        });

        let weather = Tool::client(
            String::from("get_weather"),
            String::from("Get weather"),
            schema,
        );

        assert_eq!(weather.name, "get_weather");
        assert!(weather.input_schema.is_some());
    }

    /// Both cache-control constructors produce the expected type and TTL.
    #[test]
    fn test_cache_control() {
        let five_minutes = CacheControl::ephemeral_5m();
        assert_eq!(five_minutes.cache_type, "ephemeral");
        assert_eq!(five_minutes.ttl.as_ref().unwrap(), "5m");

        let one_hour = CacheControl::ephemeral_1h();
        assert_eq!(one_hour.ttl.as_ref().unwrap(), "1h");
    }

    /// `ToolChoice` variants serialize with a lower-case `type` tag.
    #[test]
    fn test_tool_choice_serialization() {
        let auto_json = serde_json::to_value(&ToolChoice::Auto {
            disable_parallel_tool_use: Some(true),
        })
        .unwrap();
        assert_eq!(auto_json["type"], "auto");
        assert!(auto_json["disable_parallel_tool_use"].as_bool().unwrap());

        let tool_json = serde_json::to_value(&ToolChoice::Tool {
            name: String::from("get_weather"),
            disable_parallel_tool_use: None,
        })
        .unwrap();
        assert_eq!(tool_json["type"], "tool");
        assert_eq!(tool_json["name"], "get_weather");
    }

    /// Only rate-limit, overloaded and API errors count as retryable.
    #[test]
    fn test_error_detail_is_retryable() {
        let retryable = [
            ("rate_limit_error", "Rate limit exceeded", Some(60)),
            ("overloaded_error", "Service overloaded", None),
            ("api_error", "Internal server error", None),
        ];
        for (kind, message, retry_after) in retryable {
            let detail = error_detail(kind, message, retry_after);
            assert!(detail.is_retryable(), "{} should be retryable", kind);
        }

        let permanent = [
            ("invalid_request_error", "Invalid request"),
            ("authentication_error", "Invalid API key"),
            ("not_found_error", "Not found"),
        ];
        for (kind, message) in permanent {
            let detail = error_detail(kind, message, None);
            assert!(!detail.is_retryable(), "{} should not be retryable", kind);
        }
    }

    /// An error payload deserializes into `ErrorResponse` with all fields.
    #[test]
    fn test_error_response_serialization() {
        let payload = json!({
            "type": "error",
            "error": {
                "type": "rate_limit_error",
                "message": "Rate limit exceeded",
                "retry_after": 30
            }
        });

        let parsed: ErrorResponse = serde_json::from_value(payload).unwrap();
        let detail = &parsed.error;

        assert_eq!(detail.error_type, "rate_limit_error");
        assert_eq!(detail.message, "Rate limit exceeded");
        assert_eq!(detail.retry_after, Some(30));
        assert!(detail.is_retryable());
    }

    /// Cache counters present in the input are kept as `Some`.
    #[test]
    fn test_usage_with_cache_tokens() {
        let parsed: Usage = serde_json::from_value(json!({
            "input_tokens": 1000,
            "output_tokens": 500,
            "cache_creation_input_tokens": 200,
            "cache_read_input_tokens": 800
        }))
        .unwrap();

        assert_eq!(parsed.input_tokens, 1000);
        assert_eq!(parsed.output_tokens, 500);
        assert_eq!(parsed.cache_creation_input_tokens, Some(200));
        assert_eq!(parsed.cache_read_input_tokens, Some(800));
    }

    /// Cache counters missing from the input become `None`.
    #[test]
    fn test_usage_without_cache_tokens() {
        let parsed: Usage = serde_json::from_value(json!({
            "input_tokens": 1000,
            "output_tokens": 500
        }))
        .unwrap();

        assert_eq!(parsed.input_tokens, 1000);
        assert_eq!(parsed.output_tokens, 500);
        assert_eq!(parsed.cache_creation_input_tokens, None);
        assert_eq!(parsed.cache_read_input_tokens, None);
    }

    /// Copying each `Usage` counter into `UnifiedUsage` preserves its value.
    #[test]
    fn test_unified_usage_conversion_from_anthropic() {
        let Usage {
            input_tokens,
            output_tokens,
            cache_creation_input_tokens,
            cache_read_input_tokens,
        } = Usage {
            input_tokens: 1000,
            output_tokens: 500,
            cache_creation_input_tokens: Some(200),
            cache_read_input_tokens: Some(800),
        };

        let unified = crate::llm::unified::UnifiedUsage {
            input_tokens,
            output_tokens,
            cache_creation_input_tokens,
            cache_read_input_tokens,
            reasoning_tokens: None,
        };

        assert_eq!(unified.input_tokens, 1000);
        assert_eq!(unified.output_tokens, 500);
        assert_eq!(unified.cache_creation_input_tokens, Some(200));
        assert_eq!(unified.cache_read_input_tokens, Some(800));
        assert_eq!(unified.reasoning_tokens, None);
    }
}