use crate::content::ContentBlock;
use crate::types::{CacheUsage, StopReason, Usage};
use chrono::Utc;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A complete message: identifier, role, content blocks, model, stop reason,
/// creation timestamp and token usage.
///
/// `message_type` is serialized under the JSON key `"type"`. Like
/// `AssistantMessage`, a missing `created_at` deserializes to the current time
/// instead of failing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Message {
    /// Identifier; `Message::new` generates `msg_<uuid v4>`.
    pub id: String,
    /// Value of the JSON `"type"` field; `Message::new` sets `"message"`.
    #[serde(rename = "type")]
    pub message_type: String,
    pub role: MessageRole,
    /// Ordered content blocks.
    pub content: Vec<ContentBlock>,
    pub model: String,
    /// `Message::new` sets `StopReason::EndTurn`.
    pub stop_reason: StopReason,
    /// Optional stop sequence; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequence: Option<String>,
    /// RFC 3339 timestamp with second precision and a `Z` suffix when produced
    /// by this module. Defaults to the current time when absent from the input,
    /// consistent with `UserMessage` and `AssistantMessage`.
    #[serde(default = "default_timestamp")]
    pub created_at: String,
    pub usage: Usage,
    /// Defaults when absent; omitted from JSON when both cache token counters
    /// are zero.
    #[serde(default, skip_serializing_if = "is_zero_cache")]
    pub cache_usage: CacheUsage,
}
/// A user-authored message.
///
/// When deserializing, `message_type` (JSON `"type"`) defaults to `"message"`
/// and `created_at` defaults to the current time if absent.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct UserMessage {
    /// Optional identifier. Omitted from JSON when `None`, like every other
    /// optional field in this module; a missing key deserializes as `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    #[serde(rename = "type", default = "default_user_type")]
    pub message_type: String,
    /// `UserMessage::new` always sets `MessageRole::User`; deserialization does
    /// not enforce it.
    pub role: MessageRole,
    pub content: Vec<ContentBlock>,
    #[serde(default = "default_timestamp")]
    pub created_at: String,
}
/// An assistant-authored message carrying model, stop reason and token usage.
///
/// When deserializing, `message_type` (JSON `"type"`) defaults to `"message"`
/// and `created_at` defaults to the current time if absent.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantMessage {
    /// Identifier; `AssistantMessage::new` generates `msg_<uuid v4>`.
    pub id: String,
    #[serde(rename = "type", default = "default_assistant_type")]
    pub message_type: String,
    /// `AssistantMessage::new` always sets `MessageRole::Assistant`;
    /// deserialization does not enforce it.
    pub role: MessageRole,
    pub content: Vec<ContentBlock>,
    pub model: String,
    pub stop_reason: StopReason,
    #[serde(default = "default_timestamp")]
    pub created_at: String,
    pub usage: Usage,
    /// Defaults when absent; omitted from JSON when both cache token counters
    /// are zero.
    #[serde(default, skip_serializing_if = "is_zero_cache")]
    pub cache_usage: CacheUsage,
}
/// A system message: a `subtype` tag plus arbitrary JSON `data`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct SystemMessage {
    pub subtype: String,
    /// Untyped payload; its schema is not checked here.
    pub data: serde_json::Value,
}
/// Summary record with timing, turn count, error flag and session id, plus
/// optional cost, usage and result text.
///
/// Each optional field is omitted from the serialized JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ResultMessage {
    pub subtype: String,
    /// Milliseconds per the `_ms` suffix; presumably total elapsed time — confirm
    /// with producers.
    pub duration_ms: u64,
    /// Milliseconds per the `_ms` suffix; presumably time spent in API calls —
    /// confirm with producers.
    pub duration_api_ms: u64,
    pub is_error: bool,
    pub num_turns: u32,
    pub session_id: String,
    /// Cost (US dollars, judging by the name); set via `ResultMessage::with_cost`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_cost_usd: Option<f64>,
    /// Free-form usage JSON; set via `ResultMessage::with_usage`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<serde_json::Value>,
    /// Optional result text; set via `ResultMessage::with_result`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<String>,
}
/// A raw streaming event payload tagged with a uuid and a session id.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct StreamEvent {
    pub uuid: String,
    pub session_id: String,
    /// Untyped event JSON; its schema is not checked here.
    pub event: serde_json::Value,
    /// Optional parent tool-use id; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parent_tool_use_id: Option<String>,
}
/// Author of a message. Serialized in lowercase: `"user"` / `"assistant"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    User,
    Assistant,
}
/// Request body listing the model, token limit, conversation messages and
/// optional generation settings.
///
/// Every `Option` field is omitted from the serialized JSON when `None`, and
/// `metadata` is omitted when empty (and defaults to empty when absent).
/// Derives `PartialEq` like the other message types in this module, so requests
/// can be compared directly (e.g. in tests).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct MessageRequest {
    pub model: String,
    pub max_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    pub messages: Vec<MessageParameter>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<serde_json::Value>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<serde_json::Value>,
    #[serde(default, skip_serializing_if = "is_empty_metadata")]
    pub metadata: serde_json::Map<String, serde_json::Value>,
}

/// One conversation entry inside a `MessageRequest`: a role and its content.
///
/// Must implement `PartialEq` for `MessageRequest`'s derive to compile.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct MessageParameter {
    pub role: MessageRole,
    pub content: Vec<ContentBlock>,
}
impl Message {
pub fn new(model: impl Into<String>, role: MessageRole, content: Vec<ContentBlock>) -> Self {
Self {
id: format!("msg_{}", Uuid::new_v4()),
message_type: "message".to_string(),
role,
content,
model: model.into(),
stop_reason: StopReason::EndTurn,
stop_sequence: None,
created_at: Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Secs, true),
usage: Usage::new(0, 0),
cache_usage: CacheUsage::default(),
}
}
pub fn get_text_content(&self) -> String {
self.content
.iter()
.filter_map(|block| block.as_text())
.collect::<Vec<_>>()
.join("\n")
}
pub fn get_tool_uses(&self) -> Vec<(&str, &str, &serde_json::Value)> {
self.content
.iter()
.filter_map(|block| block.as_tool_use())
.collect()
}
pub fn used_tools(&self) -> bool {
self.stop_reason == StopReason::ToolUse
}
pub fn is_complete(&self) -> bool {
self.stop_reason == StopReason::EndTurn
}
}
impl UserMessage {
pub fn new(content: Vec<ContentBlock>) -> Self {
Self {
id: Some(format!("msg_{}", Uuid::new_v4())),
message_type: "message".to_string(),
role: MessageRole::User,
content,
created_at: Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Secs, true),
}
}
pub fn text(text: impl Into<String>) -> Self {
Self::new(vec![ContentBlock::text(text)])
}
}
impl AssistantMessage {
pub fn new(model: impl Into<String>, content: Vec<ContentBlock>, usage: Usage) -> Self {
Self {
id: format!("msg_{}", Uuid::new_v4()),
message_type: "message".to_string(),
role: MessageRole::Assistant,
content,
model: model.into(),
stop_reason: StopReason::EndTurn,
created_at: Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Secs, true),
usage,
cache_usage: CacheUsage::default(),
}
}
}
impl SystemMessage {
    /// Pairs a `subtype` tag with its JSON payload.
    pub fn new(subtype: impl Into<String>, data: serde_json::Value) -> Self {
        let subtype = subtype.into();
        Self { subtype, data }
    }
}
impl ResultMessage {
    /// Builds a result from its required fields.
    ///
    /// `total_cost_usd`, `usage` and `result` start as `None`; set them with the
    /// `with_*` builder methods.
    // No `#[allow(clippy::too_many_arguments)]` needed: six parameters is below
    // clippy's default threshold of seven.
    pub fn new(
        subtype: impl Into<String>,
        duration_ms: u64,
        duration_api_ms: u64,
        is_error: bool,
        num_turns: u32,
        session_id: impl Into<String>,
    ) -> Self {
        Self {
            subtype: subtype.into(),
            duration_ms,
            duration_api_ms,
            is_error,
            num_turns,
            session_id: session_id.into(),
            total_cost_usd: None,
            usage: None,
            result: None,
        }
    }

    /// Builder: sets `total_cost_usd`.
    pub fn with_cost(mut self, cost: f64) -> Self {
        self.total_cost_usd = Some(cost);
        self
    }

    /// Builder: sets the free-form `usage` JSON.
    pub fn with_usage(mut self, usage: serde_json::Value) -> Self {
        self.usage = Some(usage);
        self
    }

    /// Builder: sets the `result` text.
    pub fn with_result(mut self, result: impl Into<String>) -> Self {
        self.result = Some(result.into());
        self
    }
}
impl StreamEvent {
    /// Wraps `event` with its `uuid` and `session_id`; `parent_tool_use_id`
    /// starts as `None`.
    pub fn new(
        uuid: impl Into<String>,
        session_id: impl Into<String>,
        event: serde_json::Value,
    ) -> Self {
        let (uuid, session_id) = (uuid.into(), session_id.into());
        Self {
            uuid,
            session_id,
            event,
            parent_tool_use_id: None,
        }
    }

    /// Builder: sets `parent_tool_use_id`, keeping every other field.
    pub fn with_parent_tool_use_id(self, parent_tool_use_id: impl Into<String>) -> Self {
        Self {
            parent_tool_use_id: Some(parent_tool_use_id.into()),
            ..self
        }
    }
}
/// Serde default for `UserMessage::message_type`: the literal `"message"`.
fn default_user_type() -> String {
    String::from("message")
}
/// Serde default for `AssistantMessage::message_type`: the literal `"message"`.
fn default_assistant_type() -> String {
    "message".to_owned()
}
/// Serde default for `created_at`: the current UTC time formatted as RFC 3339
/// with whole-second precision and a `Z` offset (e.g. `2024-01-01T00:00:00Z`).
fn default_timestamp() -> String {
    let now = Utc::now();
    now.to_rfc3339_opts(chrono::SecondsFormat::Secs, true)
}
/// `skip_serializing_if` predicate: `true` when both cache token counters are 0.
fn is_zero_cache(cache: &CacheUsage) -> bool {
    let no_creation = cache.cache_creation_input_tokens == 0;
    let no_reads = cache.cache_read_input_tokens == 0;
    no_creation && no_reads
}
/// `skip_serializing_if` predicate: `true` when the metadata map has no entries.
fn is_empty_metadata(meta: &serde_json::Map<String, serde_json::Value>) -> bool {
    let entries = meta.len();
    entries == 0
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_user_message_creation() {
        let msg = UserMessage::text("Hello");
        assert_eq!(msg.role, MessageRole::User);
        assert_eq!(msg.content.len(), 1);
        assert_eq!(msg.message_type, "message");
        let id = msg.id.as_deref().expect("UserMessage::new always assigns an id");
        assert!(id.starts_with("msg_"));
    }

    #[test]
    fn test_message_get_text() {
        let msg = Message::new(
            "claude-3-5-sonnet",
            MessageRole::Assistant,
            vec![ContentBlock::text("Hello world")],
        );
        assert_eq!(msg.get_text_content(), "Hello world");
    }

    #[test]
    fn test_message_new_defaults() {
        let msg = Message::new(
            "claude-3-5-sonnet",
            MessageRole::Assistant,
            vec![ContentBlock::text("Hi")],
        );
        assert!(msg.id.starts_with("msg_"));
        assert_eq!(msg.message_type, "message");
        assert!(msg.is_complete());
        assert!(!msg.used_tools());
        assert_eq!(msg.stop_sequence, None);
        // `to_rfc3339_opts(.., true)` renders the UTC offset as `Z`.
        assert!(msg.created_at.ends_with('Z'));
    }

    #[test]
    fn test_message_serialization() {
        let msg = Message::new(
            "claude-3-5-sonnet",
            MessageRole::Assistant,
            vec![ContentBlock::text("Test")],
        );
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(msg, deserialized);
    }

    #[test]
    fn test_message_json_shape() {
        let msg = Message::new(
            "claude-3-5-sonnet",
            MessageRole::Assistant,
            vec![ContentBlock::text("Test")],
        );
        let value = serde_json::to_value(&msg).unwrap();
        // `message_type` is renamed to `type`.
        assert_eq!(value["type"].as_str(), Some("message"));
        assert!(value.get("message_type").is_none());
        // Roles serialize in lowercase.
        assert_eq!(value["role"].as_str(), Some("assistant"));
        // `None` stop sequences are skipped entirely.
        assert!(value.get("stop_sequence").is_none());
    }

    #[test]
    fn test_message_role_lowercase() {
        assert_eq!(serde_json::to_string(&MessageRole::User).unwrap(), r#""user""#);
        let role: MessageRole = serde_json::from_str(r#""assistant""#).unwrap();
        assert_eq!(role, MessageRole::Assistant);
    }

    #[test]
    fn test_user_message_deserialize_defaults() {
        let msg: UserMessage = serde_json::from_str(r#"{"role":"user","content":[]}"#).unwrap();
        assert_eq!(msg.id, None);
        assert_eq!(msg.message_type, "message");
        assert_eq!(msg.role, MessageRole::User);
        assert!(msg.content.is_empty());
        assert!(!msg.created_at.is_empty());
    }

    #[test]
    fn test_result_message_builders() {
        let result = ResultMessage::new("success", 1500, 900, false, 3, "session-1")
            .with_cost(0.25)
            .with_result("done");
        assert_eq!(result.subtype, "success");
        assert_eq!(result.session_id, "session-1");
        assert_eq!(result.total_cost_usd, Some(0.25));
        assert_eq!(result.result.as_deref(), Some("done"));
        assert_eq!(result.usage, None);

        let value = serde_json::to_value(&result).unwrap();
        assert!(value.get("usage").is_none());
        assert_eq!(value["num_turns"].as_u64(), Some(3));
    }

    #[test]
    fn test_stream_event_parent_tool_use_id() {
        let event = StreamEvent::new("u1", "s1", serde_json::Value::Null);
        assert_eq!(event.parent_tool_use_id, None);
        let value = serde_json::to_value(&event).unwrap();
        assert!(value.get("parent_tool_use_id").is_none());

        let event = event.with_parent_tool_use_id("toolu_1");
        assert_eq!(event.parent_tool_use_id.as_deref(), Some("toolu_1"));
        assert_eq!(event.uuid, "u1");
        assert_eq!(event.session_id, "s1");
    }
}