use crate::error::{ParseError, ToolResult};
use serde::{Deserialize, Serialize};
/// Cache-control marker that can be attached to a [`TextBlock`].
///
/// NOTE(review): no `rename_all` is applied here, so serde emits the variant
/// name verbatim (`"Breakpoint"`), unlike the snake_case tags used by the
/// tagged enums in this module — confirm this matches the consumer's format.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum CacheControl {
/// The only marker kind defined; presumably marks a prompt-cache breakpoint — TODO confirm semantics.
Breakpoint,
}
/// Payload of a [`ContentBlock::Text`] block.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TextBlock {
/// The text itself.
pub text: String,
/// Optional cache-control marker; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub cache_control: Option<CacheControl>,
}
/// Payload of a [`ContentBlock::Thinking`] block.
///
/// `Message::validate` rejects this block inside `User` and `ToolResult`
/// messages.
///
/// NOTE(review): unlike `TextBlock::cache_control`, `redacted` carries no
/// `skip_serializing_if`, so a `None` value is serialized as `"redacted": null`.
/// Confirm whether that asymmetry is intentional before relying on either form.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ThinkingBlock {
/// The thinking text.
pub thinking: String,
/// Optional redacted form of the thinking content — exact semantics not visible here; TODO confirm.
pub redacted: Option<String>,
}
/// Inline image payload carried by [`ContentBlock::Image`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ImageSource {
/// Encoded image bytes as a string — presumably base64; TODO confirm with producers.
pub data: String,
/// MIME type of the image, e.g. `"image/png"` or `"image/jpeg"`.
pub media_type: String,
}
/// A tool invocation, carried by [`ContentBlock::ToolCall`].
///
/// `Message::validate` requires a non-empty `id` when this appears in an
/// `Assistant` message and rejects it entirely inside `ToolResult` messages.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
/// Identifier echoed back as `tool_call_id` by `Message::tool_result`.
pub id: String,
/// Name of the tool being invoked.
pub name: String,
/// Call arguments as an arbitrary JSON value.
pub arguments: serde_json::Value,
}
/// One piece of message content.
///
/// Serialized as an internally tagged enum: the variant's fields are written
/// alongside a `"type"` field holding the snake_case variant name
/// (`"text"`, `"thinking"`, `"image"`, `"tool_call"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
/// Plain text, optionally carrying a cache-control marker.
Text(TextBlock),
/// Thinking text; rejected by `Message::validate` in `User` and `ToolResult` messages.
Thinking(ThinkingBlock),
/// An inline image.
Image { source: ImageSource },
/// A tool invocation; rejected by `Message::validate` in `ToolResult` messages.
ToolCall(ToolCall),
}
impl ContentBlock {
    /// Creates a [`ContentBlock::Text`] from anything convertible into a
    /// `String`, with no cache-control marker.
    pub fn text(s: impl Into<String>) -> Self {
        let text: String = s.into();
        Self::Text(TextBlock {
            text,
            cache_control: None,
        })
    }

    /// Creates a [`ContentBlock::Text`] that carries the given cache-control
    /// marker.
    pub fn text_with_cache(s: String, cache: CacheControl) -> Self {
        let cache_control = Some(cache);
        Self::Text(TextBlock {
            text: s,
            cache_control,
        })
    }

    /// Borrows the text of a [`ContentBlock::Text`] block.
    ///
    /// Every other variant yields `None`.
    pub fn as_text(&self) -> Option<&str> {
        match self {
            Self::Text(TextBlock { text, .. }) => Some(text.as_str()),
            Self::Thinking(_) | Self::Image { .. } | Self::ToolCall(_) => None,
        }
    }
}
/// A single conversation message.
///
/// Serialized as an internally tagged enum: the variant's fields are written
/// alongside a `"type"` field holding the snake_case variant name
/// (`"system"`, `"user"`, `"assistant"`, `"tool_result"`).
///
/// Derives `PartialEq` like every content type it contains, so whole messages
/// can be compared directly.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Message {
    /// System-level instructions.
    System {
        /// Content blocks of the message.
        content: Vec<ContentBlock>,
    },
    /// Input from the user.
    User {
        /// Content blocks of the message.
        content: Vec<ContentBlock>,
    },
    /// Output from the assistant, possibly including tool calls.
    Assistant {
        /// Content blocks of the message.
        content: Vec<ContentBlock>,
    },
    /// Outcome of executing a tool call.
    ToolResult {
        /// The `id` of the [`ToolCall`] this result answers.
        tool_call_id: String,
        /// `true` when the tool reported a failure.
        is_error: bool,
        /// Result payload.
        content: Vec<ContentBlock>,
    },
}
impl Message {
pub fn system_text(s: &str) -> Self {
Message::System {
content: text_block(s.to_string()),
}
}
pub fn user_text(s: &str) -> Self {
Message::User {
content: text_block(s.to_string()),
}
}
pub fn assistant_text(s: &str) -> Self {
Message::Assistant {
content: text_block(s.to_string()),
}
}
pub fn user_text_image(text: &str, media_type: String, data: String) -> Self {
Message::User {
content: vec![
ContentBlock::text(text),
ContentBlock::Image {
source: ImageSource { data, media_type },
},
],
}
}
pub fn user_image(media_type: String, data: String) -> Self {
Message::User {
content: vec![ContentBlock::Image {
source: ImageSource { data, media_type },
}],
}
}
pub fn system(content: Vec<ContentBlock>) -> Self {
Message::System { content }
}
pub fn user(content: Vec<ContentBlock>) -> Self {
Message::User { content }
}
pub fn assistant(content: Vec<ContentBlock>) -> Self {
Message::Assistant { content }
}
pub fn tool_result_ok(call_id: impl Into<String>, content: String) -> Self {
Message::ToolResult {
tool_call_id: call_id.into(),
is_error: false,
content: text_block(content),
}
}
pub fn tool_error(call_id: impl Into<String>, error: String) -> Self {
Message::ToolResult {
tool_call_id: call_id.into(),
is_error: true,
content: text_block(error),
}
}
pub fn content(&self) -> &Vec<ContentBlock> {
match self {
Message::System { content }
| Message::User { content }
| Message::Assistant { content }
| Message::ToolResult { content, .. } => content,
}
}
pub fn tool_call_id(&self) -> String {
match self {
Message::ToolResult { tool_call_id, .. } => tool_call_id.clone(),
_ => String::new(),
}
}
pub fn is_tool_error(&self) -> bool {
matches!(self, Message::ToolResult { is_error: true, .. })
}
pub fn tool_result(call: &ToolCall, result: &ToolResult) -> Self {
let (content_str, is_error) = match result {
Ok(v) => (
serde_json::to_string(v).unwrap_or_else(|_| v.to_string()),
false,
),
Err(e) => (format!("tool error: {e}"), true),
};
Message::ToolResult {
tool_call_id: call.id.clone(),
is_error,
content: text_block(content_str),
}
}
pub fn validate(&self) -> Result<(), ParseError> {
match self {
Message::ToolResult {
tool_call_id,
is_error: _,
content,
} => {
if tool_call_id.is_empty() {
return Err(ParseError {
detail: "ToolResult.tool_call_id must not be empty".into(),
});
}
for block in content {
match block {
ContentBlock::ToolCall(_) => {
return Err(ParseError {
detail: "ToolResult must not contain ToolCall blocks".into(),
});
}
ContentBlock::Thinking(_) => {
return Err(ParseError {
detail: "ToolResult must not contain Thinking blocks".into(),
});
}
_ => {}
}
}
}
Message::Assistant { content } => {
for block in content {
if let ContentBlock::ToolCall(tc) = block
&& tc.id.is_empty()
{
return Err(ParseError {
detail: "Assistant ToolCall.id must not be empty".into(),
});
}
}
}
Message::User { content } => {
for block in content {
if let ContentBlock::Thinking(_) = block {
return Err(ParseError {
detail: "User must not contain Thinking blocks".into(),
});
}
}
}
Message::System { .. } => {}
}
Ok(())
}
pub fn extract_tool_calls(&self) -> Vec<ToolCall> {
match self {
Message::Assistant { content } => content
.iter()
.filter_map(|b| {
if let ContentBlock::ToolCall(tc) = b {
Some(tc.clone())
} else {
None
}
})
.collect(),
_ => Vec::new(),
}
}
}
/// Wraps `s` in a one-element content vector.
///
/// The vector holds a single [`ContentBlock::Text`] built by
/// [`ContentBlock::text`], so it has no cache-control marker.
pub fn text_block(s: impl Into<String>) -> Vec<ContentBlock> {
    std::iter::once(ContentBlock::text(s)).collect()
}
#[cfg(test)]
mod tests {
    // Unit tests covering constructors, accessors, `validate` rules and the
    // serde wire format of content blocks and messages.
    use super::*;

    #[test]
    fn test_content_block_text() {
        let block = ContentBlock::text("hello");
        assert_eq!(block.as_text(), Some("hello"));
    }

    #[test]
    fn test_content_block_tool_call_no_as_text() {
        let block = ContentBlock::ToolCall(ToolCall {
            id: "1".into(),
            name: "test".into(),
            arguments: serde_json::json!({}),
        });
        assert_eq!(block.as_text(), None);
    }

    #[test]
    fn test_content_block_text_has_no_cache() {
        match ContentBlock::text("plain") {
            ContentBlock::Text(TextBlock { cache_control, .. }) => {
                assert_eq!(cache_control, None);
            }
            _ => panic!("expected Text block"),
        }
    }

    #[test]
    fn test_content_block_text_with_cache() {
        let block = ContentBlock::text_with_cache("cached".to_string(), CacheControl::Breakpoint);
        assert_eq!(block.as_text(), Some("cached"));
        match block {
            ContentBlock::Text(TextBlock { cache_control, .. }) => {
                assert_eq!(cache_control, Some(CacheControl::Breakpoint));
            }
            _ => panic!("expected Text block"),
        }
    }

    #[test]
    fn test_message_content() {
        let msg = Message::user_text("hello world");
        assert_eq!(msg.content().len(), 1);
        assert_eq!(msg.content()[0].as_text(), Some("hello world"));
    }

    #[test]
    fn test_message_extract_tool_calls() {
        let tc = ToolCall {
            id: "1".into(),
            name: "test".into(),
            arguments: serde_json::json!({}),
        };
        let msg = Message::Assistant {
            content: vec![ContentBlock::ToolCall(tc.clone())],
        };
        let calls = msg.extract_tool_calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "test");
    }

    #[test]
    fn test_extract_tool_calls_skips_other_blocks_and_keeps_order() {
        let msg = Message::assistant(vec![
            ContentBlock::text("let me check"),
            ContentBlock::ToolCall(ToolCall {
                id: "a".into(),
                name: "first".into(),
                arguments: serde_json::json!({"x": 1}),
            }),
            ContentBlock::ToolCall(ToolCall {
                id: "b".into(),
                name: "second".into(),
                arguments: serde_json::json!(null),
            }),
        ]);
        let names: Vec<String> = msg
            .extract_tool_calls()
            .into_iter()
            .map(|call| call.name)
            .collect();
        assert_eq!(names, vec!["first".to_string(), "second".to_string()]);
    }

    #[test]
    fn test_extract_tool_calls_non_assistant_is_empty() {
        assert!(Message::user_text("hi").extract_tool_calls().is_empty());
        assert!(Message::tool_result_ok("c", "x".to_string())
            .extract_tool_calls()
            .is_empty());
    }

    #[test]
    fn test_non_tool_result_accessors() {
        let msg = Message::assistant_text("hi");
        assert_eq!(msg.tool_call_id(), "");
        assert!(!msg.is_tool_error());
    }

    #[test]
    fn test_validate_user_ok() {
        let msg = Message::User {
            content: text_block("hello".to_string()),
        };
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_validate_user_allows_image() {
        let msg = Message::user_image("image/png".into(), "data".into());
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_validate_user_reject_thinking() {
        let msg = Message::User {
            content: vec![ContentBlock::Thinking(ThinkingBlock {
                thinking: "hmm".into(),
                redacted: None,
            })],
        };
        assert!(matches!(msg.validate(), Err(ParseError { .. })));
    }

    #[test]
    fn test_validate_assistant_ok() {
        let msg = Message::Assistant {
            content: text_block("hi".to_string()),
        };
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_validate_assistant_tool_call_with_id_ok() {
        let msg = Message::Assistant {
            content: vec![ContentBlock::ToolCall(ToolCall {
                id: "call_1".into(),
                name: "test".into(),
                arguments: serde_json::json!({}),
            })],
        };
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_validate_assistant_tool_call_empty_id() {
        let msg = Message::Assistant {
            content: vec![ContentBlock::ToolCall(ToolCall {
                id: String::new(),
                name: "test".into(),
                arguments: serde_json::json!({}),
            })],
        };
        assert!(matches!(msg.validate(), Err(ParseError { .. })));
    }

    #[test]
    fn test_validate_tool_result_ok() {
        let msg = Message::ToolResult {
            tool_call_id: "call_1".to_string(),
            is_error: false,
            content: text_block("ok".to_string()),
        };
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_validate_tool_result_allows_image() {
        let msg = Message::ToolResult {
            tool_call_id: "call_1".to_string(),
            is_error: false,
            content: vec![ContentBlock::Image {
                source: ImageSource {
                    data: "d".into(),
                    media_type: "image/png".into(),
                },
            }],
        };
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_validate_tool_result_empty_id() {
        let msg = Message::ToolResult {
            tool_call_id: String::new(),
            is_error: false,
            content: text_block("ok".to_string()),
        };
        assert!(matches!(msg.validate(), Err(ParseError { .. })));
    }

    #[test]
    fn test_validate_tool_result_reject_tool_call() {
        let msg = Message::ToolResult {
            tool_call_id: "call_1".to_string(),
            is_error: false,
            content: vec![ContentBlock::ToolCall(ToolCall {
                id: "x".into(),
                name: "y".into(),
                arguments: serde_json::json!({}),
            })],
        };
        assert!(matches!(msg.validate(), Err(ParseError { .. })));
    }

    #[test]
    fn test_validate_tool_result_reject_thinking() {
        let msg = Message::ToolResult {
            tool_call_id: "call_1".to_string(),
            is_error: false,
            content: vec![ContentBlock::Thinking(ThinkingBlock {
                thinking: "hmm".into(),
                redacted: None,
            })],
        };
        assert!(matches!(msg.validate(), Err(ParseError { .. })));
    }

    #[test]
    fn test_validate_system_ok() {
        let msg = Message::System {
            content: text_block("you are helpful".to_string()),
        };
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_convenience_system_text() {
        let msg = Message::system_text("you are helpful");
        assert!(matches!(msg, Message::System { .. }));
        assert_eq!(msg.content()[0].as_text(), Some("you are helpful"));
    }

    #[test]
    fn test_convenience_user_text() {
        let msg = Message::user_text("hello");
        assert!(matches!(msg, Message::User { .. }));
        assert_eq!(msg.content()[0].as_text(), Some("hello"));
    }

    #[test]
    fn test_convenience_assistant_text() {
        let msg = Message::assistant_text("the answer is 42");
        assert!(matches!(msg, Message::Assistant { .. }));
        assert_eq!(msg.content()[0].as_text(), Some("the answer is 42"));
    }

    #[test]
    fn test_convenience_system_content() {
        let msg = Message::system(vec![ContentBlock::text("prompt")]);
        assert!(matches!(msg, Message::System { .. }));
        assert_eq!(msg.content()[0].as_text(), Some("prompt"));
    }

    #[test]
    fn test_convenience_user_content() {
        let msg = Message::user(vec![ContentBlock::text("question")]);
        assert!(matches!(msg, Message::User { .. }));
        assert_eq!(msg.content()[0].as_text(), Some("question"));
    }

    #[test]
    fn test_convenience_tool_result_ok() {
        let msg = Message::tool_result_ok("call_1", "result data".to_string());
        assert!(matches!(msg, Message::ToolResult { .. }));
        assert!(!msg.is_tool_error());
        assert_eq!(msg.tool_call_id(), "call_1");
    }

    #[test]
    fn test_convenience_tool_error() {
        let msg = Message::tool_error("call_2", "something failed".to_string());
        assert!(matches!(msg, Message::ToolResult { .. }));
        assert!(msg.is_tool_error());
        assert_eq!(msg.tool_call_id(), "call_2");
    }

    #[test]
    fn test_content_block_text_with_string() {
        let s = String::from("dynamic");
        let block = ContentBlock::text(s);
        assert_eq!(block.as_text(), Some("dynamic"));
    }

    #[test]
    fn test_text_block_with_str() {
        let blocks = text_block("hello");
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].as_text(), Some("hello"));
    }

    #[test]
    fn test_text_block_with_string() {
        let blocks = text_block(String::from("hello"));
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].as_text(), Some("hello"));
    }

    #[test]
    fn test_convenience_user_text_image() {
        let msg = Message::user_text_image("what's this?", "image/png".into(), "base64data".into());
        assert!(matches!(msg, Message::User { .. }));
        assert_eq!(msg.content().len(), 2);
        assert_eq!(msg.content()[0].as_text(), Some("what's this?"));
        match &msg.content()[1] {
            ContentBlock::Image { source } => {
                assert_eq!(source.media_type, "image/png");
                assert_eq!(source.data, "base64data");
            }
            _ => panic!("expected Image block"),
        }
    }

    #[test]
    fn test_convenience_user_image() {
        let msg = Message::user_image("image/jpeg".into(), "jpgdata".into());
        assert!(matches!(msg, Message::User { .. }));
        assert_eq!(msg.content().len(), 1);
        match &msg.content()[0] {
            ContentBlock::Image { source } => {
                assert_eq!(source.media_type, "image/jpeg");
                assert_eq!(source.data, "jpgdata");
            }
            _ => panic!("expected Image block"),
        }
    }

    #[test]
    fn test_serde_user_text_wire_format() {
        // Pins the internally tagged, snake_case representation and the
        // omission of a `None` cache_control.
        let value = serde_json::to_value(Message::user_text("hi")).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "type": "user",
                "content": [{"type": "text", "text": "hi"}]
            })
        );
        let back: Message = serde_json::from_value(value).unwrap();
        assert!(matches!(back, Message::User { .. }));
        assert_eq!(back.content()[0].as_text(), Some("hi"));
    }

    #[test]
    fn test_serde_tool_result_wire_format() {
        let value = serde_json::to_value(Message::tool_error("call_9", "boom".to_string())).unwrap();
        assert_eq!(value["type"], "tool_result");
        assert_eq!(value["tool_call_id"], "call_9");
        assert_eq!(value["is_error"], true);
        let back: Message = serde_json::from_value(value).unwrap();
        assert!(back.is_tool_error());
        assert_eq!(back.tool_call_id(), "call_9");
        assert_eq!(back.content()[0].as_text(), Some("boom"));
    }
}