use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "role")]
pub enum Message {
#[serde(rename = "system")]
System { content: String },
#[serde(rename = "user")]
User { content: Vec<ContentBlock> },
#[serde(rename = "assistant")]
Assistant { content: Vec<ContentBlock> },
}
impl Message {
pub fn user(text: impl Into<String>) -> Self {
Self::User {
content: vec![ContentBlock::Text { text: text.into() }],
}
}
pub fn system(text: impl Into<String>) -> Self {
Self::System {
content: text.into(),
}
}
pub fn assistant(text: impl Into<String>) -> Self {
Self::Assistant {
content: vec![ContentBlock::Text { text: text.into() }],
}
}
}
pub trait AsUserMessage {
fn as_user_message(&self) -> Message;
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentBlock {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "tool_use")]
ToolUse {
id: String,
name: String,
input: serde_json::Value,
},
#[serde(rename = "tool_result")]
ToolResult {
tool_use_id: String,
content: String,
#[serde(default = "default_true")]
succeeded: bool,
},
}
fn default_true() -> bool {
true
}
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseStatus {
#[default]
EndTurn,
StopSequence,
ToolUse,
OutputTruncated,
ContextWindowExceeded,
Refused,
PauseTurn,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
pub input_tokens: u64,
pub output_tokens: u64,
}
impl std::ops::AddAssign<&TokenUsage> for TokenUsage {
fn add_assign(&mut self, other: &TokenUsage) {
self.input_tokens += other.input_tokens;
self.output_tokens += other.output_tokens;
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelResponse {
pub content: Vec<ContentBlock>,
pub status: ResponseStatus,
pub usage: TokenUsage,
pub model: String,
}
#[derive(Debug, Clone)]
pub enum StreamEvent {
TextDelta { index: usize, text: String },
InputJsonDelta { index: usize, partial_json: String },
ContentBlockStop { index: usize },
MessageDelta {
status: ResponseStatus,
usage: TokenUsage,
},
MessageDone,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn message_serde_round_trip() {
let msg = Message::User {
content: vec![ContentBlock::Text {
text: "hello".into(),
}],
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains("\"role\":\"user\""));
let deserialized: Message = serde_json::from_str(&json).unwrap();
match deserialized {
Message::User { content } => {
assert_eq!(content.len(), 1);
match &content[0] {
ContentBlock::Text { text } => assert_eq!(text, "hello"),
other => panic!("Expected Text, got {other:?}"),
}
}
other => panic!("Expected User, got {other:?}"),
}
}
#[test]
fn tool_use_block_serde() {
let block = ContentBlock::ToolUse {
id: "call_123".into(),
name: "read_file".into(),
input: serde_json::json!({"path": "/tmp/test.txt"}),
};
let json = serde_json::to_string(&block).unwrap();
assert!(json.contains("\"type\":\"tool_use\""));
let deserialized: ContentBlock = serde_json::from_str(&json).unwrap();
match deserialized {
ContentBlock::ToolUse { id, name, input } => {
assert_eq!(id, "call_123");
assert_eq!(name, "read_file");
assert_eq!(input["path"], "/tmp/test.txt");
}
other => panic!("Expected ToolUse, got {other:?}"),
}
}
#[test]
fn tool_result_succeeded_defaults_true() {
let json = r#"{"type":"tool_result","tool_use_id":"id1","content":"ok"}"#;
let block: ContentBlock = serde_json::from_str(json).unwrap();
match block {
ContentBlock::ToolResult { succeeded, .. } => assert!(succeeded),
other => panic!("Expected ToolResult, got {other:?}"),
}
}
#[test]
fn usage_add_accumulates() {
let mut usage = TokenUsage {
input_tokens: 100,
output_tokens: 50,
};
let other = TokenUsage {
input_tokens: 200,
output_tokens: 100,
};
usage += &other;
assert_eq!(usage.input_tokens, 300);
assert_eq!(usage.output_tokens, 150);
}
}