use serde::{Deserialize, Serialize};
/// One element of a multi-part message body.
///
/// Serialized as an internally tagged JSON object: the variant's wire name is
/// stored under a `"type"` key next to the variant's own fields, e.g.
/// `{"type":"text","text":"..."}` or
/// `{"type":"image_url","image_url":{"url":"..."}}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ContentPart {
    /// A plain-text segment (`"type": "text"`).
    Text { text: String },
    /// An image referenced by URL (`"type": "image_url"`); see [`ImageUrl`].
    // Explicit rename: `rename_all = "lowercase"` alone would produce the tag
    // "imageurl" (no underscore).
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}
/// Payload of [`ContentPart::ImageUrl`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ImageUrl {
    /// Image location: a regular URL, or a `data:<media type>;base64,<data>`
    /// URL as built by `ContentPart::image_base64`.
    pub url: String,
    /// Optional free-form `detail` value, passed through unvalidated and left
    /// out of the JSON entirely when `None`. (Presumably a provider resolution
    /// hint such as "low"/"high" — TODO confirm against the target API.)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
impl ContentPart {
    /// Creates a [`ContentPart::Text`] segment from anything convertible to `String`.
    pub fn text(text: impl Into<String>) -> Self {
        Self::Text { text: text.into() }
    }

    /// Creates a [`ContentPart::ImageUrl`] pointing at `url`, with no `detail`.
    pub fn image_url(url: impl Into<String>) -> Self {
        let image_url = ImageUrl {
            url: url.into(),
            detail: None,
        };
        Self::ImageUrl { image_url }
    }

    /// Creates an inline image part by wrapping already-encoded base64 data in a
    /// `data:<media_type>;base64,<base64_data>` URL.
    ///
    /// Both arguments are inserted verbatim; nothing is validated or encoded here.
    pub fn image_base64(media_type: &str, base64_data: &str) -> Self {
        Self::image_url(format!("data:{};base64,{}", media_type, base64_data))
    }
}
/// A single chat message.
///
/// Every `Option` field is skipped during serialization when it is `None`, so
/// unset values are omitted from the JSON rather than written as `null`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Author of the message.
    pub role: Role,
    /// Message body. May be `None`, e.g. for an assistant turn that carries
    /// only `tool_calls`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<MessageContent>,
    /// Tool invocations attached to this message (normally an assistant turn).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// For tool-result messages (`Role::Tool`): the `id` of the [`ToolCall`]
    /// being answered.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Optional caching directive for this message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_control: Option<CacheControl>,
}
/// Body of a [`Message`]: either a plain string or a list of [`ContentPart`]s.
///
/// `#[serde(untagged)]` means no discriminator is written: `Text` serializes
/// as a bare JSON string and `Parts` as a JSON array. When deserializing,
/// serde tries the variants in declaration order, so keep `Text` first.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain-text body.
    Text(String),
    /// Multi-part body (e.g. text mixed with images).
    Parts(Vec<ContentPart>),
}
impl MessageContent {
    /// Creates a plain-text body.
    pub fn text(text: impl Into<String>) -> Self {
        MessageContent::Text(text.into())
    }

    /// Creates a multi-part body from `parts`.
    pub fn parts(parts: Vec<ContentPart>) -> Self {
        MessageContent::Parts(parts)
    }

    /// Borrows the body as `&str` when it is the plain [`MessageContent::Text`] variant.
    ///
    /// Returns `None` for [`MessageContent::Parts`], even when every part is text;
    /// use [`MessageContent::to_text`] to flatten a multi-part body.
    pub fn as_text(&self) -> Option<&str> {
        match self {
            MessageContent::Text(text) => Some(text),
            MessageContent::Parts(_) => None,
        }
    }

    /// Returns the textual content as an owned `String`.
    ///
    /// For `Parts`, the text segments are concatenated in order with no
    /// separator; non-text parts (images) are skipped. An empty or image-only
    /// part list yields an empty string.
    pub fn to_text(&self) -> String {
        match self {
            MessageContent::Text(text) => text.clone(),
            // Borrow each segment and collect into a single `String`, instead of
            // cloning every segment into a `Vec` and then joining it.
            MessageContent::Parts(parts) => parts
                .iter()
                .filter_map(|part| match part {
                    ContentPart::Text { text } => Some(text.as_str()),
                    // Explicit arm (no `_`) so adding a new text-bearing variant
                    // forces this match to be revisited.
                    ContentPart::ImageUrl { .. } => None,
                })
                .collect(),
        }
    }
}
/// Formats the textual content of a message body.
///
/// `Text` is written as-is; for `Parts`, the text segments are written in
/// order with no separator and non-text parts (images) are skipped.
/// Width/alignment flags on the outer format spec are not applied.
impl std::fmt::Display for MessageContent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MessageContent::Text(text) => f.write_str(text),
            // Stream each segment straight into the formatter instead of first
            // building a temporary `Vec` and a joined `String`.
            MessageContent::Parts(parts) => parts.iter().try_for_each(|part| match part {
                ContentPart::Text { text } => f.write_str(text),
                // Explicit arm (no `_`) so new variants must be handled here.
                ContentPart::ImageUrl { .. } => Ok(()),
            }),
        }
    }
}
impl From<String> for MessageContent {
fn from(text: String) -> Self {
MessageContent::Text(text)
}
}
impl From<&str> for MessageContent {
fn from(text: &str) -> Self {
MessageContent::Text(text.to_string())
}
}
/// Author of a [`Message`].
///
/// Serialized as the lowercase variant name: `"user"`, `"assistant"`,
/// `"system"` or `"tool"`. Derives `Eq` and `Hash` (in addition to
/// `PartialEq`) so roles can be used as map/set keys.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
    /// Serialized as `"system"`.
    System,
    /// Serialized as `"tool"`; used for tool-result messages that set
    /// `Message::tool_call_id`.
    Tool,
}
/// Caching directive attached to a [`Message`] via `Message::cache_control`.
///
/// Serializes as `{"type": <cache_type>}`, plus `"ttl"` only when it is set.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct CacheControl {
    /// Cache kind, written under the JSON key `"type"` (e.g. `"ephemeral"`).
    #[serde(rename = "type")]
    pub cache_type: String,
    /// Optional time-to-live string (e.g. `"1h"`); omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ttl: Option<String>,
}
impl CacheControl {
pub fn ephemeral() -> Self {
Self {
cache_type: "ephemeral".to_string(),
ttl: None,
}
}
pub fn ephemeral_long() -> Self {
Self {
cache_type: "ephemeral".to_string(),
ttl: Some("1h".to_string()),
}
}
}
/// A tool invocation attached to a message or response.
///
/// Derives `PartialEq`/`Eq` so calls can be compared directly instead of
/// field by field.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ToolCall {
    /// Identifier of this call; a tool-result message echoes it in
    /// `Message::tool_call_id`.
    pub id: String,
    /// Call kind, written under the JSON key `"type"` (e.g. `"function"`).
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function to invoke and its arguments.
    pub function: FunctionCall,
}

/// Function name and arguments of a [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct FunctionCall {
    /// Name of the function/tool to invoke.
    pub name: String,
    /// Raw argument string, kept unparsed here (typically JSON text such as
    /// `{"arg":"value"}`); callers must decode it themselves.
    pub arguments: String,
}
/// A tool definition: a kind plus the function's name, description and
/// parameter schema.
///
/// Derives `PartialEq` so definitions can be compared directly.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Tool {
    /// Tool kind, written under the JSON key `"type"` (e.g. `"function"`).
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function exposed by this tool.
    pub function: ToolFunction,
}

/// Function signature exposed by a [`Tool`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolFunction {
    /// Function name.
    pub name: String,
    /// Human-readable description of the function.
    pub description: String,
    /// Parameter schema as arbitrary JSON, not validated here (presumably a
    /// JSON Schema object such as `{"type": "object", ...}` — TODO confirm).
    pub parameters: serde_json::Value,
}
/// A response: its text content, any requested tool calls, and token usage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Response {
    /// Text content of the response (a plain `String`, not [`MessageContent`]).
    pub content: String,
    /// Tool calls carried by the response; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Token counts reported for this response.
    pub usage: Usage,
}
/// Token accounting for a response.
///
/// The two cache counters default to `0` when missing from the input JSON.
/// On deserialization they also accept the alternative keys
/// `cache_read_input_tokens` / `cache_creation_input_tokens`; serialization
/// always uses the field names below.
///
/// Derives `Default` (all counters zero) so callers can write
/// `Usage { input_tokens: 10, ..Default::default() }`, and `PartialEq`/`Eq`
/// for direct comparison.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub struct Usage {
    /// Input token count.
    pub input_tokens: u64,
    /// Output token count.
    pub output_tokens: u64,
    /// Cache-read token count (alias `cache_read_input_tokens`); `0` if absent.
    #[serde(default, alias = "cache_read_input_tokens")]
    pub cache_read_tokens: u64,
    /// Cache-write token count (alias `cache_creation_input_tokens`); `0` if absent.
    #[serde(default, alias = "cache_creation_input_tokens")]
    pub cache_write_tokens: u64,
}
impl Usage {
    /// Sum of input, output, cache-read and cache-write tokens.
    ///
    /// The counters usually come from deserialized external JSON, so the sum
    /// uses saturating arithmetic. A plain `+` would panic in debug builds, or
    /// silently wrap to a tiny number in release, if the total exceeded
    /// `u64::MAX`. The result is clamped to `u64::MAX` instead.
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            .saturating_add(self.output_tokens)
            .saturating_add(self.cache_read_tokens)
            .saturating_add(self.cache_write_tokens)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a message with only `role` and `content` set.
    fn plain_message(role: Role, content: MessageContent) -> Message {
        Message {
            role,
            content: Some(content),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        }
    }

    /// Builds a `"function"` tool call whose arguments are `arguments` encoded as JSON text.
    fn function_call(id: &str, name: &str, arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: arguments.to_string(),
            },
        }
    }

    #[test]
    fn test_message_serialization() {
        let msg = plain_message(Role::User, MessageContent::text("Hello"));
        let value = serde_json::to_value(&msg).unwrap();
        // `None` fields must be omitted entirely, not emitted as `null`.
        assert_eq!(value, serde_json::json!({"role": "user", "content": "Hello"}));
        let deserialized: Message = serde_json::from_value(value).unwrap();
        assert_eq!(deserialized.role, Role::User);
        assert_eq!(msg.content, deserialized.content);
        assert!(deserialized.tool_calls.is_none());
        assert!(deserialized.tool_call_id.is_none());
        assert!(deserialized.cache_control.is_none());
    }

    #[test]
    fn test_message_with_tool_calls() {
        let msg = Message {
            role: Role::Assistant,
            content: Some(MessageContent::text("")),
            tool_calls: Some(vec![function_call(
                "call_123",
                "test_tool",
                serde_json::json!({"arg": "value"}),
            )]),
            tool_call_id: None,
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""type":"function""#));
        let deserialized: Message = serde_json::from_str(&json).unwrap();
        // An empty body is still `Some(Text(""))`, not dropped.
        assert_eq!(deserialized.content, Some(MessageContent::text("")));
        let calls = deserialized.tool_calls.expect("tool_calls should round-trip");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "call_123");
        assert_eq!(calls[0].tool_type, "function");
        assert_eq!(calls[0].function.name, "test_tool");
        let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
        assert_eq!(args, serde_json::json!({"arg": "value"}));
    }

    #[test]
    fn test_tool_result_message() {
        let msg = Message {
            role: Role::Tool,
            content: Some(MessageContent::text("result output")),
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
            cache_control: None,
        };
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "role": "tool",
                "content": "result output",
                "tool_call_id": "call_123"
            })
        );
        let deserialized: Message = serde_json::from_value(value).unwrap();
        assert_eq!(deserialized.role, Role::Tool);
        assert_eq!(deserialized.tool_call_id.as_deref(), Some("call_123"));
    }

    #[test]
    fn test_assistant_with_tool_calls_serialization() {
        let msg = Message {
            role: Role::Assistant,
            content: None,
            tool_calls: Some(vec![function_call("call_123", "test_tool", serde_json::json!({}))]),
            tool_call_id: None,
            cache_control: None,
        };
        let value = serde_json::to_value(&msg).unwrap();
        // A missing body must be omitted, not sent as `"content": null`.
        assert!(value.get("content").is_none());
        assert_eq!(value["tool_calls"][0]["function"]["arguments"], "{}");
        let deserialized: Message = serde_json::from_value(value).unwrap();
        assert!(deserialized.content.is_none());
        assert_eq!(deserialized.tool_calls.map(|calls| calls.len()), Some(1));
    }

    #[test]
    fn test_role_serialization() {
        let cases = [
            (Role::User, "\"user\""),
            (Role::Assistant, "\"assistant\""),
            (Role::System, "\"system\""),
            (Role::Tool, "\"tool\""),
        ];
        for (role, expected) in cases.iter() {
            let json = serde_json::to_string(role).unwrap();
            assert_eq!(json, *expected);
            let back: Role = serde_json::from_str(&json).unwrap();
            assert_eq!(&back, role);
        }
    }

    #[test]
    fn test_content_part_wire_format() {
        let text = serde_json::to_value(ContentPart::text("hi")).unwrap();
        assert_eq!(text, serde_json::json!({"type": "text", "text": "hi"}));

        let image = serde_json::to_value(ContentPart::image_url("https://example.com/a.png")).unwrap();
        // The tag keeps its underscore and a `None` detail is omitted.
        assert_eq!(
            image,
            serde_json::json!({
                "type": "image_url",
                "image_url": {"url": "https://example.com/a.png"}
            })
        );

        let with_detail = serde_json::json!({
            "type": "image_url",
            "image_url": {"url": "u", "detail": "high"}
        });
        let part: ContentPart = serde_json::from_value(with_detail.clone()).unwrap();
        assert_eq!(
            part,
            ContentPart::ImageUrl {
                image_url: ImageUrl {
                    url: "u".to_string(),
                    detail: Some("high".to_string()),
                },
            }
        );
        assert_eq!(serde_json::to_value(&part).unwrap(), with_detail);
    }

    #[test]
    fn test_image_base64_builds_data_url() {
        match ContentPart::image_base64("image/png", "AAAA") {
            ContentPart::ImageUrl { image_url } => {
                assert_eq!(image_url.url, "data:image/png;base64,AAAA");
                assert!(image_url.detail.is_none());
            }
            other => panic!("expected ImageUrl, got {:?}", other),
        }
    }

    #[test]
    fn test_message_content_untagged() {
        let text: MessageContent = serde_json::from_str("\"plain\"").unwrap();
        assert_eq!(text, MessageContent::text("plain"));
        assert_eq!(text.as_text(), Some("plain"));

        let parts: MessageContent = serde_json::from_str(
            r#"[{"type":"text","text":"a"},{"type":"image_url","image_url":{"url":"u"}},{"type":"text","text":"b"}]"#,
        )
        .unwrap();
        assert_eq!(
            parts,
            MessageContent::parts(vec![
                ContentPart::text("a"),
                ContentPart::image_url("u"),
                ContentPart::text("b"),
            ])
        );
        assert_eq!(parts.as_text(), None);
        // Image parts are skipped; text parts are joined with no separator.
        assert_eq!(parts.to_text(), "ab");
        assert_eq!(parts.to_string(), "ab");
    }

    #[test]
    fn test_message_content_from_impls() {
        assert_eq!(MessageContent::from("x"), MessageContent::Text("x".to_string()));
        assert_eq!(MessageContent::from(String::from("y")), MessageContent::text("y"));
    }

    #[test]
    fn test_tool_serialization() {
        let tool = Tool {
            tool_type: "function".to_string(),
            function: ToolFunction {
                name: "test_tool".to_string(),
                description: "A test tool".to_string(),
                parameters: serde_json::json!({"type": "object"}),
            },
        };
        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains(r#""type":"function""#));
        let deserialized: Tool = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.tool_type, "function");
        assert_eq!(tool.function.name, deserialized.function.name);
        assert_eq!(tool.function.description, deserialized.function.description);
        assert_eq!(tool.function.parameters, deserialized.function.parameters);
    }

    #[test]
    fn test_response_serialization() {
        let response = Response {
            content: "Hello, world!".to_string(),
            tool_calls: None,
            usage: Usage {
                input_tokens: 10,
                output_tokens: 5,
                cache_read_tokens: 0,
                cache_write_tokens: 0,
            },
        };
        let value = serde_json::to_value(&response).unwrap();
        assert!(value.get("tool_calls").is_none());
        let deserialized: Response = serde_json::from_value(value).unwrap();
        assert_eq!(response.content, deserialized.content);
        assert!(deserialized.tool_calls.is_none());
        assert_eq!(response.usage.input_tokens, deserialized.usage.input_tokens);
        assert_eq!(response.usage.output_tokens, deserialized.usage.output_tokens);
    }

    #[test]
    fn test_usage_serialization() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 7,
            cache_write_tokens: 3,
        };
        let json = serde_json::to_string(&usage).unwrap();
        let deserialized: Usage = serde_json::from_str(&json).unwrap();
        assert_eq!(usage.input_tokens, deserialized.input_tokens);
        assert_eq!(usage.output_tokens, deserialized.output_tokens);
        assert_eq!(usage.cache_read_tokens, deserialized.cache_read_tokens);
        assert_eq!(usage.cache_write_tokens, deserialized.cache_write_tokens);
    }

    #[test]
    fn test_usage_cache_fields_default_to_zero() {
        let usage: Usage =
            serde_json::from_str(r#"{"input_tokens": 3, "output_tokens": 4}"#).unwrap();
        assert_eq!(usage.cache_read_tokens, 0);
        assert_eq!(usage.cache_write_tokens, 0);
        assert_eq!(usage.total_tokens(), 7);
    }

    #[test]
    fn test_cache_control_serialization() {
        let cache = CacheControl::ephemeral();
        let json = serde_json::to_string(&cache).unwrap();
        assert_eq!(json, r#"{"type":"ephemeral"}"#);
        assert_eq!(serde_json::from_str::<CacheControl>(&json).unwrap(), cache);

        let cache_long = CacheControl::ephemeral_long();
        let json_long = serde_json::to_string(&cache_long).unwrap();
        assert_eq!(json_long, r#"{"type":"ephemeral","ttl":"1h"}"#);
        assert_eq!(serde_json::from_str::<CacheControl>(&json_long).unwrap(), cache_long);
    }

    #[test]
    fn test_message_with_cache_control() {
        let msg = Message {
            cache_control: Some(CacheControl::ephemeral()),
            ..plain_message(Role::User, MessageContent::text("Hello"))
        };
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(value["cache_control"], serde_json::json!({"type": "ephemeral"}));
        let deserialized: Message = serde_json::from_value(value).unwrap();
        assert_eq!(deserialized.cache_control, Some(CacheControl::ephemeral()));
    }

    #[test]
    fn test_usage_with_cache_fields() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 80,
            cache_write_tokens: 20,
        };
        assert_eq!(usage.total_tokens(), 250);
        let json = serde_json::to_string(&usage).unwrap();
        // Aliases only apply on input; output always uses the field names.
        assert!(json.contains("cache_read_tokens"));
        assert!(json.contains("cache_write_tokens"));
        assert!(!json.contains("cache_read_input_tokens"));
        assert!(!json.contains("cache_creation_input_tokens"));
    }

    #[test]
    fn test_usage_anthropic_aliases() {
        let json = r#"{
            "input_tokens": 100,
            "output_tokens": 50,
            "cache_read_input_tokens": 80,
            "cache_creation_input_tokens": 20
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
        assert_eq!(usage.cache_read_tokens, 80);
        assert_eq!(usage.cache_write_tokens, 20);
        assert_eq!(usage.total_tokens(), 250);
    }
}