use serde::{Deserialize, Serialize};
/// One entry in a conversation.
///
/// Every optional field is omitted from the serialized JSON when it is `None`
/// (rather than being emitted as `null`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Message {
    /// Which participant produced this message.
    pub role: Role,

    /// Text body. May be `None`, e.g. for an assistant turn that only carries
    /// tool calls.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,

    /// Tool invocations requested by this message, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,

    /// Id of the [`ToolCall`] this message answers; presumably only set on
    /// [`Role::Tool`] messages.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,

    /// Optional caching marker attached to this message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_control: Option<CacheControl>,
}
/// Author of a [`Message`].
///
/// Each variant is serialized as its lowercase name: `"user"`, `"assistant"`,
/// `"system"` or `"tool"`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
    /// Serialized as `"system"`.
    System,
    /// Serialized as `"tool"`; used for messages carrying a tool result,
    /// paired with [`Message::tool_call_id`].
    Tool,
}
/// Caching marker attached to a [`Message`].
///
/// Serializes as `{"type": ..., "ttl": ...}`, with `ttl` left out when `None`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CacheControl {
    /// Cache kind, emitted under the JSON key `type` (e.g. `"ephemeral"`).
    #[serde(rename = "type")]
    pub cache_type: String,

    /// Optional time-to-live string such as `"1h"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ttl: Option<String>,
}
impl CacheControl {
    /// Returns an `"ephemeral"` marker with no explicit TTL.
    pub fn ephemeral() -> Self {
        Self::ephemeral_with_ttl(None)
    }

    /// Returns an `"ephemeral"` marker whose TTL is `"1h"`.
    pub fn ephemeral_long() -> Self {
        Self::ephemeral_with_ttl(Some("1h"))
    }

    /// Shared constructor: an `"ephemeral"` marker with the given optional TTL.
    fn ephemeral_with_ttl(ttl: Option<&str>) -> Self {
        Self {
            cache_type: String::from("ephemeral"),
            ttl: ttl.map(String::from),
        }
    }
}
/// A single tool invocation requested in a [`Message`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier that a later tool-result message can reference via
    /// [`Message::tool_call_id`].
    pub id: String,

    /// Invocation kind, emitted under the JSON key `type` (e.g. `"function"`).
    #[serde(rename = "type")]
    pub tool_type: String,

    /// Name and arguments of the function being called.
    pub function: FunctionCall,
}
/// Function name plus its arguments for a [`ToolCall`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionCall {
    /// Name of the function to invoke.
    pub name: String,

    /// Arguments kept as a raw string; presumably JSON-encoded (verify with
    /// producers of this type).
    pub arguments: String,
}
/// A tool definition that can be offered to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Tool {
    /// Tool kind, emitted under the JSON key `type` (e.g. `"function"`).
    #[serde(rename = "type")]
    pub tool_type: String,

    /// Name, description and parameter schema of the function.
    pub function: ToolFunction,
}
/// Function metadata inside a [`Tool`] definition.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolFunction {
    /// Function name.
    pub name: String,

    /// Human-readable description of what the function does.
    pub description: String,

    /// Arbitrary JSON describing the parameters; presumably a JSON Schema
    /// object such as `{"type": "object", ...}` — confirm with callers.
    pub parameters: serde_json::Value,
}
/// Result of a completion: text content, optional tool calls, and token
/// [`Usage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Response {
    /// Reply text.
    ///
    /// Defaults to an empty string when the `content` key is absent from the
    /// input JSON, so a reply that carries only tool calls still deserializes.
    /// This mirrors [`Message::content`], which is optional.
    #[serde(default)]
    pub content: String,

    /// Tool invocations requested by the reply; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,

    /// Token accounting for the request that produced this reply.
    pub usage: Usage,
}
/// Token counters for a single request.
///
/// The cache counters default to `0` when absent. They also accept the
/// Anthropic-style key names `cache_read_input_tokens` and
/// `cache_creation_input_tokens` on input, but are always serialized as
/// `cache_read_tokens` / `cache_write_tokens`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the input.
    pub input_tokens: u64,

    /// Tokens produced in the output.
    pub output_tokens: u64,

    /// Tokens served from the cache.
    #[serde(default, alias = "cache_read_input_tokens")]
    pub cache_read_tokens: u64,

    /// Tokens written to the cache.
    #[serde(default, alias = "cache_creation_input_tokens")]
    pub cache_write_tokens: u64,
}
impl Usage {
    /// Returns the sum of all four counters: input, output, cache-read and
    /// cache-write tokens.
    ///
    /// The sum saturates at `u64::MAX` instead of overflowing. The counters
    /// may come from deserialized external JSON, and a plain `+` would panic
    /// in debug builds (and silently wrap in release builds) on overflow.
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            .saturating_add(self.output_tokens)
            .saturating_add(self.cache_read_tokens)
            .saturating_add(self.cache_write_tokens)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `"function"`-type [`ToolCall`] whose arguments are the
    /// JSON-encoded form of `arguments`.
    fn sample_tool_call(id: &str, name: &str, arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: arguments.to_string(),
            },
        }
    }

    #[test]
    fn test_message_serialization() {
        let msg = Message {
            role: Role::User,
            content: Some("Hello".to_string()),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        // Every `None` field is skipped, leaving only role and content.
        assert_eq!(json, r#"{"role":"user","content":"Hello"}"#);

        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.role, Role::User);
        assert_eq!(msg.content, deserialized.content);
        assert!(deserialized.tool_calls.is_none());
        assert!(deserialized.tool_call_id.is_none());
        assert!(deserialized.cache_control.is_none());
    }

    #[test]
    fn test_message_with_tool_calls() {
        let msg = Message {
            role: Role::Assistant,
            content: Some("".to_string()),
            tool_calls: Some(vec![sample_tool_call(
                "call_123",
                "test_tool",
                serde_json::json!({"arg": "value"}),
            )]),
            tool_call_id: None,
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        // `tool_type` is renamed to `type` on the wire.
        assert!(json.contains(r#""type":"function""#));

        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.role, Role::Assistant);
        assert_eq!(deserialized.content.as_deref(), Some(""));

        let calls = deserialized
            .tool_calls
            .expect("tool_calls should survive a round trip");
        assert_eq!(calls.len(), 1);
        let call = &calls[0];
        assert_eq!(call.id, "call_123");
        assert_eq!(call.tool_type, "function");
        assert_eq!(call.function.name, "test_tool");
        let args: serde_json::Value = serde_json::from_str(&call.function.arguments).unwrap();
        assert_eq!(args, serde_json::json!({"arg": "value"}));
    }

    #[test]
    fn test_tool_result_message() {
        let msg = Message {
            role: Role::Tool,
            content: Some("result output".to_string()),
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""role":"tool""#));
        assert!(json.contains(r#""tool_call_id":"call_123""#));

        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.role, Role::Tool);
        assert_eq!(deserialized.tool_call_id, Some("call_123".to_string()));
        assert_eq!(deserialized.content.as_deref(), Some("result output"));
    }

    #[test]
    fn test_assistant_with_tool_calls_serialization() {
        let msg = Message {
            role: Role::Assistant,
            content: None,
            tool_calls: Some(vec![sample_tool_call(
                "call_123",
                "test_tool",
                serde_json::json!({}),
            )]),
            tool_call_id: None,
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        // A `None` content must be omitted entirely, not sent as `"content":null`.
        assert!(!json.contains("\"content\""));
        assert!(json.contains("tool_calls"));

        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert!(deserialized.content.is_none());
        assert_eq!(deserialized.tool_calls.map(|calls| calls.len()), Some(1));
    }

    #[test]
    fn test_role_serialization() {
        let cases = [
            (Role::User, "\"user\""),
            (Role::Assistant, "\"assistant\""),
            (Role::System, "\"system\""),
            (Role::Tool, "\"tool\""),
        ];
        for (role, expected) in cases {
            assert_eq!(serde_json::to_string(&role).unwrap(), expected);
            let parsed: Role = serde_json::from_str(expected).unwrap();
            assert_eq!(parsed, role);
        }
    }

    #[test]
    fn test_tool_serialization() {
        let tool = Tool {
            tool_type: "function".to_string(),
            function: ToolFunction {
                name: "test_tool".to_string(),
                description: "A test tool".to_string(),
                parameters: serde_json::json!({"type": "object"}),
            },
        };
        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains(r#""type":"function""#));

        let deserialized: Tool = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.tool_type, "function");
        assert_eq!(tool.function.name, deserialized.function.name);
        assert_eq!(deserialized.function.description, "A test tool");
        assert_eq!(
            deserialized.function.parameters,
            serde_json::json!({"type": "object"})
        );
    }

    #[test]
    fn test_response_serialization() {
        let response = Response {
            content: "Hello, world!".to_string(),
            tool_calls: None,
            usage: Usage {
                input_tokens: 10,
                output_tokens: 5,
                cache_read_tokens: 0,
                cache_write_tokens: 0,
            },
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(!json.contains("tool_calls"));

        let deserialized: Response = serde_json::from_str(&json).unwrap();
        assert_eq!(response.content, deserialized.content);
        assert!(deserialized.tool_calls.is_none());
        assert_eq!(response.usage.input_tokens, deserialized.usage.input_tokens);
        assert_eq!(response.usage.output_tokens, deserialized.usage.output_tokens);
    }

    #[test]
    fn test_usage_serialization() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 0,
            cache_write_tokens: 0,
        };
        let json = serde_json::to_string(&usage).unwrap();
        let deserialized: Usage = serde_json::from_str(&json).unwrap();
        assert_eq!(usage.input_tokens, deserialized.input_tokens);
        assert_eq!(usage.output_tokens, deserialized.output_tokens);
        assert_eq!(deserialized.cache_read_tokens, 0);
        assert_eq!(deserialized.cache_write_tokens, 0);
    }

    #[test]
    fn test_cache_control_serialization() {
        let json = serde_json::to_string(&CacheControl::ephemeral()).unwrap();
        assert_eq!(json, r#"{"type":"ephemeral"}"#);

        let json_long = serde_json::to_string(&CacheControl::ephemeral_long()).unwrap();
        assert_eq!(json_long, r#"{"type":"ephemeral","ttl":"1h"}"#);

        let parsed: CacheControl = serde_json::from_str(&json_long).unwrap();
        assert_eq!(parsed, CacheControl::ephemeral_long());
    }

    #[test]
    fn test_message_with_cache_control() {
        let msg = Message {
            role: Role::User,
            content: Some("Hello".to_string()),
            tool_calls: None,
            tool_call_id: None,
            cache_control: Some(CacheControl::ephemeral()),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));

        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.cache_control, Some(CacheControl::ephemeral()));
    }

    #[test]
    fn test_usage_with_cache_fields() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 80,
            cache_write_tokens: 20,
        };
        assert_eq!(usage.total_tokens(), 250);

        let json = serde_json::to_string(&usage).unwrap();
        // Output always uses the canonical names, never the input aliases.
        assert!(json.contains(r#""cache_read_tokens":80"#));
        assert!(json.contains(r#""cache_write_tokens":20"#));

        let deserialized: Usage = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.cache_read_tokens, 80);
        assert_eq!(deserialized.cache_write_tokens, 20);
    }

    #[test]
    fn test_usage_anthropic_aliases() {
        let json = r#"{
            "input_tokens": 100,
            "output_tokens": 50,
            "cache_read_input_tokens": 80,
            "cache_creation_input_tokens": 20
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
        assert_eq!(usage.cache_read_tokens, 80);
        assert_eq!(usage.cache_write_tokens, 20);
        assert_eq!(usage.total_tokens(), 250);
    }

    #[test]
    fn test_usage_cache_fields_default_to_zero() {
        let usage: Usage =
            serde_json::from_str(r#"{"input_tokens":7,"output_tokens":3}"#).unwrap();
        assert_eq!(usage.cache_read_tokens, 0);
        assert_eq!(usage.cache_write_tokens, 0);
        assert_eq!(usage.total_tokens(), 10);
    }
}