use crate::error::{ParseError, ToolResult};
use serde::{Deserialize, Serialize};
/// Cache-control marker that can be attached to a [`TextBlock`].
///
/// Serialized as the bare string `"Breakpoint"` (no `rename_all` is applied).
/// NOTE(review): presumably marks a provider prompt-cache breakpoint — confirm.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum CacheControl {
    /// The only marker currently supported.
    Breakpoint,
}
/// Plain text content, optionally carrying a cache-control marker.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct TextBlock {
    /// The text itself.
    pub text: String,
    /// Optional cache marker; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_control: Option<CacheControl>,
}
/// Model reasoning text, with an optional redacted payload.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ThinkingBlock {
    /// The reasoning text.
    pub thinking: String,
    /// Optional redacted form of the reasoning.
    ///
    /// Unlike `TextBlock::cache_control`, this field is always serialized
    /// (as `null` when `None`).
    /// NOTE(review): exact semantics of the redacted payload are not visible here — confirm.
    pub redacted: Option<String>,
}
/// Inline image payload referenced by `ContentBlock::Image`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ImageSource {
    /// Encoded image bytes. NOTE(review): presumably base64 — confirm with producers.
    pub data: String,
    /// Media type of `data`. NOTE(review): presumably a MIME type such as `image/png` — confirm.
    pub media_type: String,
}
/// A tool invocation requested by the assistant.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ToolCall {
    /// Identifier echoed back as `tool_call_id` by the matching tool result.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Arguments for the tool as an arbitrary JSON value.
    pub arguments: serde_json::Value,
}
/// One piece of message content.
///
/// Serialized as an internally tagged object whose `"type"` field is the
/// snake_case variant name (`"text"`, `"thinking"`, `"image"`, `"tool_call"`).
/// Newtype variants have their inner struct's fields merged next to the tag.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain (optionally cache-marked) text.
    Text(TextBlock),
    /// Model reasoning.
    Thinking(ThinkingBlock),
    /// An inline image, serialized under a nested `source` object.
    Image { source: ImageSource },
    /// A tool invocation.
    ToolCall(ToolCall),
}
impl ContentBlock {
    /// Builds a text block with no cache-control marker.
    pub fn text(s: String) -> Self {
        Self::Text(TextBlock {
            text: s,
            cache_control: None,
        })
    }

    /// Builds a text block tagged with the given cache-control marker.
    pub fn text_with_cache(s: String, cache: CacheControl) -> Self {
        let block = TextBlock {
            text: s,
            cache_control: Some(cache),
        };
        Self::Text(block)
    }

    /// Borrows the text of a `Text` block; every other variant yields `None`.
    pub fn as_text(&self) -> Option<&str> {
        if let Self::Text(TextBlock { text, .. }) = self {
            Some(text.as_str())
        } else {
            None
        }
    }
}
/// A single conversation turn.
///
/// Serialized as an internally tagged object whose `"type"` field is the
/// snake_case variant name (`"system"`, `"user"`, `"assistant"`, `"tool_result"`).
/// Derives `PartialEq` like every other type in this model, so whole messages can
/// be compared directly.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Message {
    /// System-role message.
    System {
        content: Vec<ContentBlock>,
    },
    /// User-role message.
    User {
        content: Vec<ContentBlock>,
    },
    /// Assistant-role message; the only role whose `ToolCall` blocks are
    /// extracted for execution.
    Assistant {
        content: Vec<ContentBlock>,
    },
    /// The outcome of executing one tool call.
    ToolResult {
        /// `id` of the `ToolCall` this result answers.
        tool_call_id: String,
        /// `true` when the tool reported an error rather than a value.
        is_error: bool,
        /// The rendered result (or error text).
        content: Vec<ContentBlock>,
    },
}
impl Message {
    /// Returns the content blocks carried by this message, whatever its role.
    pub fn content(&self) -> &Vec<ContentBlock> {
        match self {
            Message::System { content }
            | Message::User { content }
            | Message::Assistant { content }
            | Message::ToolResult { content, .. } => content,
        }
    }

    /// Returns the id of the tool call this message answers.
    ///
    /// Only `ToolResult` carries an id; every other variant yields an empty
    /// string.
    pub fn tool_call_id(&self) -> String {
        match self {
            Message::ToolResult { tool_call_id, .. } => tool_call_id.clone(),
            _ => String::new(),
        }
    }

    /// Returns `true` only for a `ToolResult` whose `is_error` flag is set.
    pub fn is_tool_error(&self) -> bool {
        matches!(self, Message::ToolResult { is_error: true, .. })
    }

    /// Builds the `ToolResult` message answering `call` with `result`.
    ///
    /// A success value is rendered with `serde_json::to_string` (falling back to
    /// its `to_string()` form if that fails) and flagged `is_error: false`. An
    /// error is rendered as `"tool error: {e}"` and flagged `is_error: true`.
    pub fn tool_result(call: &ToolCall, result: &ToolResult) -> Self {
        let (content_str, is_error) = match result {
            Ok(v) => (
                serde_json::to_string(v).unwrap_or_else(|_| v.to_string()),
                false,
            ),
            Err(e) => (format!("tool error: {e}"), true),
        };
        Message::ToolResult {
            tool_call_id: call.id.clone(),
            is_error,
            content: text_block(content_str),
        }
    }

    /// Checks the structural invariants of this message.
    ///
    /// * `ToolResult`: `tool_call_id` must be non-empty, and content must not
    ///   contain `ToolCall` or `Thinking` blocks.
    /// * `Assistant`: every `ToolCall` block must have a non-empty `id`.
    /// * `User` and `System`: content must not contain `ToolCall` or `Thinking`
    ///   blocks. Tool calls are only read from assistant messages (see
    ///   `extract_tool_calls`), so a tool call in any other role would be
    ///   silently dropped if it were accepted here.
    ///
    /// # Errors
    ///
    /// Returns a [`ParseError`] describing the first violated invariant.
    pub fn validate(&self) -> Result<(), ParseError> {
        match self {
            Message::ToolResult {
                tool_call_id,
                content,
                ..
            } => {
                if tool_call_id.is_empty() {
                    return Err(ParseError {
                        detail: "ToolResult.tool_call_id must not be empty".into(),
                    });
                }
                Self::reject_assistant_only_blocks(
                    content,
                    "ToolResult must not contain ToolCall blocks",
                    "ToolResult must not contain Thinking blocks",
                )
            }
            Message::Assistant { content } => {
                let has_call_without_id = content
                    .iter()
                    .any(|block| matches!(block, ContentBlock::ToolCall(tc) if tc.id.is_empty()));
                if has_call_without_id {
                    return Err(ParseError {
                        detail: "Assistant ToolCall.id must not be empty".into(),
                    });
                }
                Ok(())
            }
            Message::User { content } => Self::reject_assistant_only_blocks(
                content,
                "User must not contain ToolCall blocks",
                "User must not contain Thinking blocks",
            ),
            Message::System { content } => Self::reject_assistant_only_blocks(
                content,
                "System must not contain ToolCall blocks",
                "System must not contain Thinking blocks",
            ),
        }
    }

    /// Fails on the first block in `content` that only an assistant message may
    /// carry (`ToolCall` or `Thinking`), using the matching error detail.
    ///
    /// The match is exhaustive on purpose: adding a `ContentBlock` variant forces
    /// a decision here about which roles may carry it.
    fn reject_assistant_only_blocks(
        content: &[ContentBlock],
        tool_call_detail: &'static str,
        thinking_detail: &'static str,
    ) -> Result<(), ParseError> {
        let offending = content.iter().find_map(|block| match block {
            ContentBlock::ToolCall(_) => Some(tool_call_detail),
            ContentBlock::Thinking(_) => Some(thinking_detail),
            ContentBlock::Text(_) | ContentBlock::Image { .. } => None,
        });
        match offending {
            Some(detail) => Err(ParseError {
                detail: detail.into(),
            }),
            None => Ok(()),
        }
    }

    /// Returns clones of the `ToolCall` blocks of an assistant message, in order.
    ///
    /// Every other role yields an empty vector.
    pub fn extract_tool_calls(&self) -> Vec<ToolCall> {
        match self {
            Message::Assistant { content } => content
                .iter()
                .filter_map(|block| match block {
                    ContentBlock::ToolCall(tc) => Some(tc.clone()),
                    _ => None,
                })
                .collect(),
            _ => Vec::new(),
        }
    }
}
/// Wraps `s` in a one-element content vector holding a single plain text block
/// (no cache-control marker).
pub fn text_block(s: String) -> Vec<ContentBlock> {
    Vec::from([ContentBlock::text(s)])
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ToolCall` with empty JSON arguments.
    fn call(id: &str, name: &str) -> ToolCall {
        ToolCall {
            id: id.into(),
            name: name.into(),
            arguments: serde_json::json!({}),
        }
    }

    #[test]
    fn test_content_block_text() {
        let block = ContentBlock::text("hello".to_string());
        assert_eq!(block.as_text(), Some("hello"));
    }

    #[test]
    fn test_content_block_text_with_cache() {
        let block =
            ContentBlock::text_with_cache("cached".to_string(), CacheControl::Breakpoint);
        assert_eq!(block.as_text(), Some("cached"));
        match block {
            ContentBlock::Text(tb) => {
                assert_eq!(tb.cache_control, Some(CacheControl::Breakpoint))
            }
            other => panic!("expected Text block, got {other:?}"),
        }
    }

    #[test]
    fn test_content_block_tool_call_no_as_text() {
        let block = ContentBlock::ToolCall(call("1", "test"));
        assert_eq!(block.as_text(), None);
    }

    #[test]
    fn test_content_block_serializes_with_type_tag() {
        let plain = serde_json::to_value(ContentBlock::text("hi".to_string())).unwrap();
        assert_eq!(plain, serde_json::json!({"type": "text", "text": "hi"}));

        let cached = serde_json::to_value(ContentBlock::text_with_cache(
            "hi".to_string(),
            CacheControl::Breakpoint,
        ))
        .unwrap();
        assert_eq!(
            cached,
            serde_json::json!({"type": "text", "text": "hi", "cache_control": "Breakpoint"})
        );
    }

    #[test]
    fn test_content_block_round_trip() {
        let original = ContentBlock::ToolCall(ToolCall {
            id: "call_1".into(),
            name: "search".into(),
            arguments: serde_json::json!({"q": "rust"}),
        });
        let json = serde_json::to_string(&original).unwrap();
        let back: ContentBlock = serde_json::from_str(&json).unwrap();
        assert_eq!(back, original);
    }

    #[test]
    fn test_message_content() {
        let msg = Message::User {
            content: text_block("hello world".to_string()),
        };
        assert_eq!(msg.content().len(), 1);
        assert_eq!(msg.content()[0].as_text(), Some("hello world"));
    }

    #[test]
    fn test_message_serializes_with_snake_case_tag() {
        let msg = Message::ToolResult {
            tool_call_id: "call_1".into(),
            is_error: false,
            content: text_block("ok".to_string()),
        };
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(value["type"], "tool_result");
        assert_eq!(value["tool_call_id"], "call_1");
    }

    #[test]
    fn test_tool_result_accessors() {
        let failed = Message::ToolResult {
            tool_call_id: "call_9".into(),
            is_error: true,
            content: text_block("boom".to_string()),
        };
        assert!(failed.is_tool_error());
        assert_eq!(failed.tool_call_id(), "call_9");

        let user = Message::User {
            content: text_block("hi".to_string()),
        };
        assert!(!user.is_tool_error());
        assert_eq!(user.tool_call_id(), "");
    }

    #[test]
    fn test_message_extract_tool_calls() {
        let tc = call("1", "test");
        let msg = Message::Assistant {
            content: vec![ContentBlock::ToolCall(tc.clone())],
        };
        let calls = msg.extract_tool_calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "test");
    }

    #[test]
    fn test_extract_tool_calls_ignores_non_assistant() {
        let msg = Message::User {
            content: vec![ContentBlock::ToolCall(call("1", "test"))],
        };
        assert!(msg.extract_tool_calls().is_empty());
    }

    #[test]
    fn test_validate_user_ok() {
        let msg = Message::User {
            content: text_block("hello".to_string()),
        };
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_validate_user_reject_thinking() {
        let msg = Message::User {
            content: vec![ContentBlock::Thinking(ThinkingBlock {
                thinking: "hmm".into(),
                redacted: None,
            })],
        };
        assert!(matches!(msg.validate(), Err(ParseError { .. })));
    }

    #[test]
    fn test_validate_assistant_ok() {
        let msg = Message::Assistant {
            content: text_block("hi".to_string()),
        };
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_validate_assistant_tool_call_empty_id() {
        let msg = Message::Assistant {
            content: vec![ContentBlock::ToolCall(call("", "test"))],
        };
        assert!(matches!(msg.validate(), Err(ParseError { .. })));
    }

    #[test]
    fn test_validate_tool_result_ok() {
        let msg = Message::ToolResult {
            tool_call_id: "call_1".to_string(),
            is_error: false,
            content: text_block("ok".to_string()),
        };
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_validate_tool_result_empty_id() {
        let msg = Message::ToolResult {
            tool_call_id: String::new(),
            is_error: false,
            content: text_block("ok".to_string()),
        };
        assert!(matches!(msg.validate(), Err(ParseError { .. })));
    }

    #[test]
    fn test_validate_tool_result_reject_tool_call() {
        let msg = Message::ToolResult {
            tool_call_id: "call_1".to_string(),
            is_error: false,
            content: vec![ContentBlock::ToolCall(call("x", "y"))],
        };
        assert!(matches!(msg.validate(), Err(ParseError { .. })));
    }

    #[test]
    fn test_validate_tool_result_reject_thinking() {
        let msg = Message::ToolResult {
            tool_call_id: "call_1".to_string(),
            is_error: false,
            content: vec![ContentBlock::Thinking(ThinkingBlock {
                thinking: "hmm".into(),
                redacted: None,
            })],
        };
        assert!(matches!(msg.validate(), Err(ParseError { .. })));
    }

    #[test]
    fn test_validate_system_ok() {
        let msg = Message::System {
            content: text_block("you are helpful".to_string()),
        };
        assert!(msg.validate().is_ok());
    }
}