1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::collections::HashMap;
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct Message {
7 pub role: Role,
8 pub content: Vec<ContentBlock>,
9 #[serde(default)]
10 pub metadata: HashMap<String, Value>,
11}
12
13impl Message {
14 pub fn new(role: Role, content: Vec<ContentBlock>) -> Self {
15 Self {
16 role,
17 content,
18 metadata: HashMap::new(),
19 }
20 }
21
22 pub fn system(text: impl Into<String>) -> Self {
23 Self::new(Role::System, vec![ContentBlock::Text(text.into())])
24 }
25
26 pub fn user(text: impl Into<String>) -> Self {
27 Self::new(Role::User, vec![ContentBlock::Text(text.into())])
28 }
29
30 pub fn assistant(text: impl Into<String>) -> Self {
31 Self::new(Role::Assistant, vec![ContentBlock::Text(text.into())])
32 }
33
34 pub fn tool_result(tool_use_id: impl Into<String>, content: Value, is_error: bool) -> Self {
35 Self::new(
36 Role::Tool,
37 vec![ContentBlock::ToolResult {
38 tool_use_id: tool_use_id.into(),
39 content,
40 is_error,
41 }],
42 )
43 }
44
45 pub fn text_content(&self) -> Option<String> {
47 let texts: Vec<&str> = self
48 .content
49 .iter()
50 .filter_map(|c| match c {
51 ContentBlock::Text(t) => Some(t.as_str()),
52 _ => None,
53 })
54 .collect();
55 if texts.is_empty() {
56 None
57 } else {
58 Some(texts.join(""))
59 }
60 }
61
62 pub fn tool_uses(&self) -> Vec<ToolCall> {
64 self
65 .content
66 .iter()
67 .filter_map(|c| match c {
68 ContentBlock::ToolUse { id, name, input } => Some(ToolCall {
69 id: id.clone(),
70 name: name.clone(),
71 input: input.clone(),
72 }),
73 _ => None,
74 })
75 .collect()
76 }
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub enum Role {
81 System,
82 User,
83 Assistant,
84 Tool,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
88#[serde(tag = "type", rename_all = "snake_case")]
89pub enum ContentBlock {
90 Text(String),
91 Image { url: String, media_type: String },
92 ToolUse { id: String, name: String, input: Value },
93 ToolResult { tool_use_id: String, content: Value, is_error: bool },
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct ToolCall {
98 pub id: String,
99 pub name: String,
100 pub input: Value,
101}