use std::collections::HashMap;
use std::fmt;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::usage::Usage;
/// Role carried by a [`ChatMessage`].
///
/// No serde rename is applied, so variants serialize under their Rust names
/// (for example `"User"`). The `Display` impl renders lowercase labels instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ChatRole {
    /// Displayed as `system`.
    System,
    /// Displayed as `user`.
    User,
    /// Displayed as `assistant`.
    Assistant,
    /// Displayed as `tool`; used by the tool-result constructors.
    Tool,
}
/// Renders the role as its lowercase label (`system`, `user`, `assistant`, `tool`).
impl fmt::Display for ChatRole {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the label first so there is a single write site.
        let label = match *self {
            Self::System => "system",
            Self::User => "user",
            Self::Assistant => "assistant",
            Self::Tool => "tool",
        };
        f.write_str(label)
    }
}
/// A single chat message: a role plus an ordered list of content blocks.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Role this message is attributed to.
    pub role: ChatRole,
    /// Content blocks in order; may be empty (see `is_empty`).
    pub content: Vec<ContentBlock>,
}
impl ChatMessage {
    /// Builds a message with the given `role` whose content is a single
    /// [`ContentBlock::Text`] holding `text`.
    pub fn text(role: ChatRole, text: impl Into<String>) -> Self {
        Self {
            role,
            content: vec![ContentBlock::Text(text.into())],
        }
    }

    /// Shorthand for [`ChatMessage::text`] with [`ChatRole::User`].
    pub fn user(text: impl Into<String>) -> Self {
        Self::text(ChatRole::User, text)
    }

    /// Shorthand for [`ChatMessage::text`] with [`ChatRole::Assistant`].
    pub fn assistant(text: impl Into<String>) -> Self {
        Self::text(ChatRole::Assistant, text)
    }

    /// Shorthand for [`ChatMessage::text`] with [`ChatRole::System`].
    pub fn system(text: impl Into<String>) -> Self {
        Self::text(ChatRole::System, text)
    }

    /// Builds a [`ChatRole::Tool`] message carrying one successful
    /// [`ToolResult`] (`is_error == false`) for `tool_call_id`.
    pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
        Self::tool_outcome(tool_call_id.into(), content.into(), false)
    }

    /// Builds a [`ChatRole::Tool`] message carrying one failed
    /// [`ToolResult`] (`is_error == true`) for `tool_call_id`.
    pub fn tool_error(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
        Self::tool_outcome(tool_call_id.into(), content.into(), true)
    }

    /// Shared constructor behind `tool_result` and `tool_error`.
    fn tool_outcome(tool_call_id: String, content: String, is_error: bool) -> Self {
        Self {
            role: ChatRole::Tool,
            content: vec![ContentBlock::ToolResult(ToolResult {
                tool_call_id,
                content,
                is_error,
            })],
        }
    }

    /// Returns `true` when the message has no content blocks at all.
    ///
    /// A message whose only block is an empty text string is *not* empty by
    /// this definition.
    pub fn is_empty(&self) -> bool {
        self.content.is_empty()
    }

    /// Serializes the message into a [`serde_json::Value`].
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `serde_json::to_value`.
    pub fn to_json(&self) -> Result<Value, serde_json::Error> {
        serde_json::to_value(self)
    }

    /// Deserializes a message from a borrowed JSON value.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if `value` does not match the shape
    /// produced by the derived `Deserialize` impl.
    pub fn from_json(value: &Value) -> Result<Self, serde_json::Error> {
        // `&Value` is itself a serde `Deserializer`, so the message can be
        // built straight from the borrow without deep-cloning the JSON tree.
        Self::deserialize(value)
    }

    /// Deserializes a message from an owned JSON value, consuming it.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if `value` does not match the shape
    /// produced by the derived `Deserialize` impl.
    pub fn from_json_owned(value: Value) -> Result<Self, serde_json::Error> {
        serde_json::from_value(value)
    }
}
/// One piece of message or response content.
///
/// Variant names are serialized in `snake_case` (`text`, `image`,
/// `tool_call`, `tool_result`, `reasoning`) via `rename_all`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ContentBlock {
    /// Plain text.
    Text(String),
    /// An image together with its declared media type.
    Image {
        /// Media type string, e.g. `image/png`.
        media_type: String,
        /// Where the image bytes come from.
        data: ImageSource,
    },
    /// A tool invocation.
    ToolCall(ToolCall),
    /// The outcome of a tool invocation.
    ToolResult(ToolResult),
    /// Reasoning text, kept separate from [`ContentBlock::Text`].
    Reasoning {
        /// The reasoning text itself.
        content: String,
    },
}
/// Source of the data for a [`ContentBlock::Image`].
///
/// No serde rename is applied, so variants serialize as `Base64` / `Url`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ImageSource {
    /// Image payload held as a string (named for base64; the string is not
    /// validated by this type).
    Base64(String),
    /// Image referenced by an already-parsed URL.
    Url(url::Url),
}
impl ImageSource {
    /// Parses `url` and wraps the result in [`ImageSource::Url`].
    ///
    /// # Errors
    ///
    /// Returns the [`url::ParseError`] from `url::Url::parse` when the input
    /// is not a valid URL.
    pub fn from_url(url: impl AsRef<str>) -> Result<Self, url::ParseError> {
        url::Url::parse(url.as_ref()).map(Self::Url)
    }
}
/// A tool invocation: an id, the tool name, and its JSON arguments.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of this call; [`ToolResult::tool_call_id`] carries the
    /// same kind of id.
    pub id: String,
    /// Name of the tool being invoked.
    pub name: String,
    /// Arguments as an arbitrary JSON value.
    pub arguments: Value,
}
impl std::fmt::Display for ToolCall {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}({})", self.name, self.id)
}
}
/// Outcome of a tool invocation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolResult {
    /// Id of the tool call this result belongs to.
    pub tool_call_id: String,
    /// Result payload (or error text when `is_error` is set).
    pub content: String,
    /// `true` when the tool invocation failed.
    pub is_error: bool,
}
impl std::fmt::Display for ToolResult {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.is_error {
write!(f, "err:{} ({})", self.tool_call_id, self.content)
} else {
write!(f, "ok:{}", self.tool_call_id)
}
}
}
/// A chat response: content blocks plus usage, stop reason, model and metadata.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatResponse {
    /// Content blocks of the response, in order.
    pub content: Vec<ContentBlock>,
    /// Usage accounting for this response (see [`Usage`]).
    pub usage: Usage,
    /// Why the response stopped.
    pub stop_reason: StopReason,
    /// Model identifier string.
    pub model: String,
    /// Free-form JSON metadata keyed by name.
    pub metadata: HashMap<String, Value>,
}
impl ChatResponse {
    /// Creates a response with no content, default [`Usage`],
    /// [`StopReason::EndTurn`], an empty model string and no metadata.
    pub fn empty() -> Self {
        Self {
            content: Vec::new(),
            usage: Usage::default(),
            stop_reason: StopReason::EndTurn,
            model: String::new(),
            metadata: HashMap::new(),
        }
    }

    /// Returns the first [`ContentBlock::Text`] in `content`, if any.
    ///
    /// Later text blocks and all non-text blocks are ignored.
    pub fn text(&self) -> Option<&str> {
        self.content.iter().find_map(|block| {
            if let ContentBlock::Text(body) = block {
                Some(body.as_str())
            } else {
                None
            }
        })
    }

    /// Collects references to every [`ContentBlock::ToolCall`], in order.
    pub fn tool_calls(&self) -> Vec<&ToolCall> {
        self.tool_calls_iter().collect()
    }

    /// Lazily iterates over every [`ContentBlock::ToolCall`], in order,
    /// without allocating.
    pub fn tool_calls_iter(&self) -> impl Iterator<Item = &ToolCall> {
        self.content.iter().filter_map(|block| {
            if let ContentBlock::ToolCall(call) = block {
                Some(call)
            } else {
                None
            }
        })
    }

    /// Consumes the response and returns its tool calls by value, in order.
    /// All other blocks are dropped.
    pub fn into_tool_calls(self) -> Vec<ToolCall> {
        self.content
            .into_iter()
            .filter_map(|block| {
                if let ContentBlock::ToolCall(call) = block {
                    Some(call)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Consumes the response and splits its content into
    /// `(tool_calls, other_blocks)`, preserving relative order.
    ///
    /// [`ContentBlock::ToolResult`] blocks land in neither half: they are
    /// discarded.
    pub fn partition_content(self) -> (Vec<ToolCall>, Vec<ContentBlock>) {
        let mut calls = Vec::new();
        let mut rest = Vec::new();
        self.content.into_iter().for_each(|block| match block {
            ContentBlock::ToolCall(call) => calls.push(call),
            // Tool results are dropped on purpose rather than kept in `rest`.
            ContentBlock::ToolResult(_) => {}
            kept => rest.push(kept),
        });
        (calls, rest)
    }
}
/// Reason recorded in [`ChatResponse::stop_reason`].
///
/// No serde rename is applied, so variants serialize under their Rust names.
/// The `Display` impl renders `snake_case` labels instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum StopReason {
    /// Displayed as `end_turn`.
    EndTurn,
    /// Displayed as `tool_use`.
    ToolUse,
    /// Displayed as `max_tokens`.
    MaxTokens,
    /// Displayed as `stop_sequence`.
    StopSequence,
}
/// Renders the stop reason as its `snake_case` label.
impl fmt::Display for StopReason {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the label first so there is a single write site.
        let label = match *self {
            Self::EndTurn => "end_turn",
            Self::ToolUse => "tool_use",
            Self::MaxTokens => "max_tokens",
            Self::StopSequence => "stop_sequence",
        };
        f.write_str(label)
    }
}
/// Unit tests for the chat types: constructors, trait derives, `Display`
/// output, serde wire shapes and the `ChatResponse` accessors.
#[cfg(test)]
mod tests {
    // `super::*` also brings the parent's `HashMap`, `Value` and `Usage` imports into scope.
    use super::*;

    // --- ChatRole ---------------------------------------------------------

    #[test]
    fn test_chat_role_copy_hash() {
        let mut map = HashMap::new();
        let role = ChatRole::User;
        let role_copy = role;
        map.insert(role, "user");
        // Same key inserted twice: the copy must hash/compare equal.
        map.insert(role_copy, "user_copy");
        assert_eq!(map.len(), 1);
    }

    #[test]
    fn test_chat_role_all_variants() {
        let variants = [
            ChatRole::System,
            ChatRole::User,
            ChatRole::Assistant,
            ChatRole::Tool,
        ];
        for v in &variants {
            let debug = format!("{v:?}");
            assert!(!debug.is_empty());
        }
    }

    #[test]
    fn test_chat_role_serde_roundtrip() {
        let role = ChatRole::Assistant;
        let json = serde_json::to_string(&role).unwrap();
        let back: ChatRole = serde_json::from_str(&json).unwrap();
        assert_eq!(role, back);
    }

    // --- ChatMessage constructors ------------------------------------------

    #[test]
    fn test_user_constructor() {
        let msg = ChatMessage::user("hello");
        assert_eq!(msg.role, ChatRole::User);
        assert_eq!(msg.content, vec![ContentBlock::Text("hello".into())]);
    }

    #[test]
    fn test_assistant_constructor() {
        let msg = ChatMessage::assistant("hi");
        assert_eq!(msg.role, ChatRole::Assistant);
        assert_eq!(msg.content, vec![ContentBlock::Text("hi".into())]);
    }

    #[test]
    fn test_system_constructor() {
        let msg = ChatMessage::system("be nice");
        assert_eq!(msg.role, ChatRole::System);
        assert_eq!(msg.content, vec![ContentBlock::Text("be nice".into())]);
    }

    #[test]
    fn test_tool_result_constructor() {
        let msg = ChatMessage::tool_result("tc_1", "42");
        assert_eq!(msg.role, ChatRole::Tool);
        assert!(matches!(
            &msg.content[0],
            ContentBlock::ToolResult(tr)
                if tr.tool_call_id == "tc_1" && tr.content == "42" && !tr.is_error
        ));
    }

    #[test]
    fn test_tool_error_constructor() {
        let msg = ChatMessage::tool_error("tc_1", "something broke");
        assert!(matches!(
            &msg.content[0],
            ContentBlock::ToolResult(tr) if tr.is_error
        ));
    }

    #[test]
    fn test_message_text_clone_eq() {
        let msg = ChatMessage::user("hello");
        assert_eq!(msg, msg.clone());
    }

    #[test]
    fn test_message_serde_roundtrip() {
        let msg = ChatMessage::user("hello");
        let json = serde_json::to_string(&msg).unwrap();
        let back: ChatMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(msg, back);
    }

    #[test]
    fn test_message_tool_use() {
        let msg = ChatMessage {
            role: ChatRole::Assistant,
            content: vec![
                ContentBlock::ToolCall(ToolCall {
                    id: "1".into(),
                    name: "calc".into(),
                    arguments: serde_json::json!({"a": 1}),
                }),
                ContentBlock::ToolCall(ToolCall {
                    id: "2".into(),
                    name: "search".into(),
                    arguments: serde_json::json!({"q": "rust"}),
                }),
            ],
        };
        assert_eq!(msg.content.len(), 2);
        assert_eq!(msg, msg.clone());
    }

    #[test]
    fn test_message_tool_result() {
        let msg = ChatMessage::tool_result("1", "42");
        assert!(matches!(
            &msg.content[0],
            ContentBlock::ToolResult(tr) if tr.content == "42" && !tr.is_error
        ));
    }

    #[test]
    fn test_message_mixed_content() {
        let msg = ChatMessage {
            role: ChatRole::User,
            content: vec![
                ContentBlock::Text("look at this".into()),
                ContentBlock::Image {
                    media_type: "image/png".into(),
                    data: ImageSource::Base64("abc123".into()),
                },
                ContentBlock::ToolCall(ToolCall {
                    id: "1".into(),
                    name: "analyze".into(),
                    arguments: serde_json::json!({}),
                }),
            ],
        };
        assert_eq!(msg.content.len(), 3);
    }

    // --- ContentBlock / ImageSource ----------------------------------------

    #[test]
    fn test_content_block_image_base64() {
        let block = ContentBlock::Image {
            media_type: "image/jpeg".into(),
            data: ImageSource::Base64("data...".into()),
        };
        assert_eq!(block, block.clone());
    }

    #[test]
    fn test_content_block_image_url() {
        let block = ContentBlock::Image {
            media_type: "image/png".into(),
            data: ImageSource::from_url("https://example.com/img.png").unwrap(),
        };
        assert_eq!(block, block.clone());
    }

    #[test]
    fn test_image_source_from_url_valid() {
        let src = ImageSource::from_url("https://example.com/img.png");
        assert!(src.is_ok());
        let url = url::Url::parse("https://example.com/img.png").unwrap();
        assert_eq!(src.unwrap(), ImageSource::Url(url));
    }

    #[test]
    fn test_image_source_from_url_normalizes() {
        // Parsing lowercases scheme and host and adds the root path.
        let src = ImageSource::from_url("HTTP://EXAMPLE.COM").unwrap();
        assert!(matches!(
            &src,
            ImageSource::Url(u) if u.as_str() == "http://example.com/"
        ));
    }

    #[test]
    fn test_image_source_from_url_invalid() {
        let err = ImageSource::from_url("not a url");
        assert!(err.is_err());
        // Type ascription pins the error type of the public API.
        let _parse_err: url::ParseError = err.unwrap_err();
        assert!(ImageSource::from_url("").is_err());
    }

    #[test]
    fn test_content_block_reasoning() {
        let block = ContentBlock::Reasoning {
            content: "thinking step by step".into(),
        };
        assert_eq!(block, block.clone());
    }

    // --- ToolCall / ToolResult ---------------------------------------------

    #[test]
    fn test_tool_call_json_arguments() {
        let call = ToolCall {
            id: "tc_1".into(),
            name: "search".into(),
            arguments: serde_json::json!({
                "query": "rust async",
                "filters": {"lang": "en", "limit": 10}
            }),
        };
        assert_eq!(call, call.clone());
    }

    #[test]
    fn test_tool_result_error_flag() {
        let ok = ToolResult {
            tool_call_id: "1".into(),
            content: "result".into(),
            is_error: false,
        };
        let err = ToolResult {
            tool_call_id: "1".into(),
            content: "result".into(),
            is_error: true,
        };
        // Only `is_error` differs, so it must take part in equality.
        assert_ne!(ok, err);
    }

    // --- ChatResponse basics -----------------------------------------------

    #[test]
    fn test_chat_response_metadata() {
        let mut metadata = HashMap::new();
        metadata.insert("cost".into(), serde_json::json!({"usd": 0.01}));
        let resp = ChatResponse {
            content: vec![ContentBlock::Text("hi".into())],
            usage: Usage::default(),
            stop_reason: StopReason::EndTurn,
            model: "test-model".into(),
            metadata,
        };
        assert!(resp.metadata.contains_key("cost"));
    }

    #[test]
    fn test_chat_response_serde_roundtrip() {
        let resp = ChatResponse {
            content: vec![ContentBlock::Text("hi".into())],
            usage: Usage::default(),
            stop_reason: StopReason::EndTurn,
            model: "test-model".into(),
            metadata: HashMap::new(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let back: ChatResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(resp, back);
    }

    #[test]
    fn test_chat_response_empty_content() {
        let resp = ChatResponse {
            content: vec![],
            usage: Usage::default(),
            stop_reason: StopReason::EndTurn,
            model: "test".into(),
            metadata: HashMap::new(),
        };
        assert!(resp.content.is_empty());
    }

    // --- StopReason ---------------------------------------------------------

    #[test]
    fn test_stop_reason_all_variants() {
        let variants = [
            StopReason::EndTurn,
            StopReason::ToolUse,
            StopReason::MaxTokens,
            StopReason::StopSequence,
        ];
        // Each variant equals itself and nothing else (avoids the
        // tautological `assert_eq!(x, x)` that clippy's `eq_op` rejects).
        for (i, a) in variants.iter().enumerate() {
            for (j, b) in variants.iter().enumerate() {
                assert_eq!(a == b, i == j, "{a:?} vs {b:?}");
            }
        }
    }

    #[test]
    fn test_stop_reason_serde_roundtrip() {
        let sr = StopReason::MaxTokens;
        let json = serde_json::to_string(&sr).unwrap();
        let back: StopReason = serde_json::from_str(&json).unwrap();
        assert_eq!(sr, back);
    }

    #[test]
    fn test_stop_reason_eq_hash() {
        let mut map = HashMap::new();
        map.insert(StopReason::EndTurn, "end");
        map.insert(StopReason::ToolUse, "tool");
        assert_eq!(map[&StopReason::EndTurn], "end");
        assert_eq!(map[&StopReason::ToolUse], "tool");
    }

    // --- Display ------------------------------------------------------------

    #[test]
    fn test_chat_role_display() {
        assert_eq!(ChatRole::System.to_string(), "system");
        assert_eq!(ChatRole::User.to_string(), "user");
        assert_eq!(ChatRole::Assistant.to_string(), "assistant");
        assert_eq!(ChatRole::Tool.to_string(), "tool");
    }

    #[test]
    fn test_stop_reason_display() {
        assert_eq!(StopReason::EndTurn.to_string(), "end_turn");
        assert_eq!(StopReason::ToolUse.to_string(), "tool_use");
        assert_eq!(StopReason::MaxTokens.to_string(), "max_tokens");
        assert_eq!(StopReason::StopSequence.to_string(), "stop_sequence");
    }

    // --- ChatResponse accessors ----------------------------------------------

    #[test]
    fn test_chat_response_text_returns_first() {
        let resp = ChatResponse {
            content: vec![
                ContentBlock::Reasoning {
                    content: "thinking...".into(),
                },
                ContentBlock::Text("first".into()),
                ContentBlock::Text("second".into()),
            ],
            usage: Usage::default(),
            stop_reason: StopReason::EndTurn,
            model: "test".into(),
            metadata: HashMap::new(),
        };
        assert_eq!(resp.text(), Some("first"));
    }

    #[test]
    fn test_chat_response_text_none_when_no_text_blocks() {
        let resp = ChatResponse {
            content: vec![ContentBlock::Reasoning {
                content: "thinking".into(),
            }],
            usage: Usage::default(),
            stop_reason: StopReason::EndTurn,
            model: "test".into(),
            metadata: HashMap::new(),
        };
        assert_eq!(resp.text(), None);
    }

    #[test]
    fn test_chat_response_text_none_when_empty() {
        let resp = ChatResponse {
            content: vec![],
            usage: Usage::default(),
            stop_reason: StopReason::EndTurn,
            model: "test".into(),
            metadata: HashMap::new(),
        };
        assert_eq!(resp.text(), None);
    }

    #[test]
    fn test_chat_response_tool_calls() {
        let resp = ChatResponse {
            content: vec![
                ContentBlock::Text("Let me search.".into()),
                ContentBlock::ToolCall(ToolCall {
                    id: "1".into(),
                    name: "search".into(),
                    arguments: serde_json::json!({"q": "rust"}),
                }),
                ContentBlock::ToolCall(ToolCall {
                    id: "2".into(),
                    name: "calc".into(),
                    arguments: serde_json::json!({"expr": "2+2"}),
                }),
            ],
            usage: Usage::default(),
            stop_reason: StopReason::ToolUse,
            model: "test".into(),
            metadata: HashMap::new(),
        };
        let calls = resp.tool_calls();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "search");
        assert_eq!(calls[1].name, "calc");
    }

    #[test]
    fn test_chat_response_tool_calls_empty_when_text_only() {
        let resp = ChatResponse {
            content: vec![ContentBlock::Text("hello".into())],
            usage: Usage::default(),
            stop_reason: StopReason::EndTurn,
            model: "test".into(),
            metadata: HashMap::new(),
        };
        assert!(resp.tool_calls().is_empty());
    }

    #[test]
    fn test_message_is_empty() {
        let empty = ChatMessage {
            role: ChatRole::User,
            content: vec![],
        };
        assert!(empty.is_empty());
        assert!(!ChatMessage::user("hi").is_empty());
    }

    // --- ContentBlock wire format --------------------------------------------

    #[test]
    fn test_content_block_serde_text() {
        let block = ContentBlock::Text("hello".into());
        let val = serde_json::to_value(&block).unwrap();
        assert_eq!(val, serde_json::json!({"text": "hello"}));
        let back: ContentBlock = serde_json::from_value(val).unwrap();
        assert_eq!(back, block);
    }

    #[test]
    fn test_content_block_serde_image() {
        let block = ContentBlock::Image {
            media_type: "image/png".into(),
            data: ImageSource::Base64("abc".into()),
        };
        let val = serde_json::to_value(&block).unwrap();
        assert_eq!(
            val,
            serde_json::json!({
                "image": {"media_type": "image/png", "data": {"Base64": "abc"}}
            })
        );
        let back: ContentBlock = serde_json::from_value(val).unwrap();
        assert_eq!(back, block);
    }

    #[test]
    fn test_content_block_serde_tool_call() {
        let block = ContentBlock::ToolCall(ToolCall {
            id: "tc_1".into(),
            name: "search".into(),
            arguments: serde_json::json!({"q": "rust"}),
        });
        let val = serde_json::to_value(&block).unwrap();
        assert_eq!(
            val,
            serde_json::json!({
                "tool_call": {"id": "tc_1", "name": "search", "arguments": {"q": "rust"}}
            })
        );
        let back: ContentBlock = serde_json::from_value(val).unwrap();
        assert_eq!(back, block);
    }

    #[test]
    fn test_content_block_serde_tool_result() {
        let block = ContentBlock::ToolResult(ToolResult {
            tool_call_id: "tc_1".into(),
            content: "42".into(),
            is_error: false,
        });
        let val = serde_json::to_value(&block).unwrap();
        assert_eq!(
            val,
            serde_json::json!({
                "tool_result": {"tool_call_id": "tc_1", "content": "42", "is_error": false}
            })
        );
        let back: ContentBlock = serde_json::from_value(val).unwrap();
        assert_eq!(back, block);
    }

    #[test]
    fn test_content_block_serde_reasoning() {
        let block = ContentBlock::Reasoning {
            content: "thinking".into(),
        };
        let val = serde_json::to_value(&block).unwrap();
        assert_eq!(
            val,
            serde_json::json!({"reasoning": {"content": "thinking"}})
        );
        let back: ContentBlock = serde_json::from_value(val).unwrap();
        assert_eq!(back, block);
    }

    // --- Constructor output shape --------------------------------------------

    #[test]
    fn test_user_constructor_produces_text_only() {
        let msg = ChatMessage::user("hello");
        assert_eq!(msg.role, ChatRole::User);
        assert!(
            msg.content
                .iter()
                .all(|b| matches!(b, ContentBlock::Text(_)))
        );
    }

    #[test]
    fn test_assistant_constructor_produces_text_only() {
        let msg = ChatMessage::assistant("hi");
        assert_eq!(msg.role, ChatRole::Assistant);
        assert!(
            msg.content
                .iter()
                .all(|b| matches!(b, ContentBlock::Text(_)))
        );
    }

    #[test]
    fn test_system_constructor_produces_text_only() {
        let msg = ChatMessage::system("be nice");
        assert_eq!(msg.role, ChatRole::System);
        assert!(
            msg.content
                .iter()
                .all(|b| matches!(b, ContentBlock::Text(_)))
        );
    }

    #[test]
    fn test_tool_result_constructor_produces_tool_result_only() {
        let msg = ChatMessage::tool_result("tc_1", "42");
        assert_eq!(msg.role, ChatRole::Tool);
        assert!(
            msg.content
                .iter()
                .all(|b| matches!(b, ContentBlock::ToolResult(_)))
        );
    }

    #[test]
    fn test_tool_error_constructor_produces_tool_result_only() {
        let msg = ChatMessage::tool_error("tc_1", "boom");
        assert_eq!(msg.role, ChatRole::Tool);
        assert!(
            msg.content
                .iter()
                .all(|b| matches!(b, ContentBlock::ToolResult(r) if r.is_error))
        );
    }

    // --- Role / content combinations -----------------------------------------

    #[test]
    fn test_assistant_tool_calls_is_valid_combination() {
        let msg = ChatMessage {
            role: ChatRole::Assistant,
            content: vec![
                ContentBlock::Text("Let me search for that.".into()),
                ContentBlock::ToolCall(ToolCall {
                    id: "1".into(),
                    name: "search".into(),
                    arguments: serde_json::json!({"q": "rust"}),
                }),
            ],
        };
        assert_eq!(msg.role, ChatRole::Assistant);
        assert_eq!(msg.content.len(), 2);
    }

    #[test]
    fn test_user_with_image_is_valid_combination() {
        let msg = ChatMessage {
            role: ChatRole::User,
            content: vec![
                ContentBlock::Text("What's this?".into()),
                ContentBlock::Image {
                    media_type: "image/png".into(),
                    data: ImageSource::Base64("...".into()),
                },
            ],
        };
        assert_eq!(msg.role, ChatRole::User);
        assert_eq!(msg.content.len(), 2);
    }

    #[test]
    fn test_assistant_with_reasoning_is_valid_combination() {
        let msg = ChatMessage {
            role: ChatRole::Assistant,
            content: vec![
                ContentBlock::Reasoning {
                    content: "step 1: think about it".into(),
                },
                ContentBlock::Text("The answer is 42.".into()),
            ],
        };
        assert_eq!(msg.role, ChatRole::Assistant);
        assert_eq!(msg.content.len(), 2);
    }

    // --- ChatMessage JSON helpers ---------------------------------------------

    #[test]
    fn test_chat_message_to_json() {
        let msg = ChatMessage::user("Hello, world!");
        let json = msg.to_json().unwrap();
        // Roles serialize under the Rust variant name, not the Display label.
        assert_eq!(json["role"], "User");
        assert_eq!(json["content"][0]["text"], "Hello, world!");
    }

    #[test]
    fn test_chat_message_from_json() {
        let json = serde_json::json!({
            "role": "Assistant",
            "content": [{"text": "Hello!"}]
        });
        let msg = ChatMessage::from_json(&json).unwrap();
        assert_eq!(msg.role, ChatRole::Assistant);
        assert!(matches!(&msg.content[0], ContentBlock::Text(t) if t == "Hello!"));
    }

    #[test]
    fn test_chat_message_json_roundtrip() {
        let original = ChatMessage {
            role: ChatRole::User,
            content: vec![
                ContentBlock::Text("What's this?".into()),
                ContentBlock::Image {
                    media_type: "image/png".into(),
                    data: ImageSource::Base64("abc123".into()),
                },
            ],
        };
        let json = original.to_json().unwrap();
        let restored = ChatMessage::from_json(&json).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn test_chat_message_json_roundtrip_with_tool_result() {
        let original = ChatMessage::tool_result("tc_1", "success");
        let json = original.to_json().unwrap();
        let restored = ChatMessage::from_json(&json).unwrap();
        assert_eq!(original, restored);
    }
}