hehe_core/message/
message.rs1use super::content::{ContentBlock, ImageContent, ToolUse};
2use super::role::Role;
3use crate::types::{MessageId, Metadata, Timestamp};
4use serde::{Deserialize, Serialize};
5
6#[derive(Clone, Debug, Serialize, Deserialize)]
7pub struct Message {
8 pub id: MessageId,
9 pub role: Role,
10 pub content: Vec<ContentBlock>,
11 pub created_at: Timestamp,
12 #[serde(default, skip_serializing_if = "Metadata::is_empty")]
13 pub metadata: Metadata,
14}
15
16impl Message {
17 pub fn new(role: Role, content: Vec<ContentBlock>) -> Self {
18 Self {
19 id: MessageId::new(),
20 role,
21 content,
22 created_at: Timestamp::now(),
23 metadata: Metadata::new(),
24 }
25 }
26
27 pub fn with_id(mut self, id: MessageId) -> Self {
28 self.id = id;
29 self
30 }
31
32 pub fn with_metadata(mut self, metadata: Metadata) -> Self {
33 self.metadata = metadata;
34 self
35 }
36
37 pub fn system(text: impl Into<String>) -> Self {
38 Self::new(Role::System, vec![ContentBlock::text(text)])
39 }
40
41 pub fn user(text: impl Into<String>) -> Self {
42 Self::new(Role::User, vec![ContentBlock::text(text)])
43 }
44
45 pub fn assistant(text: impl Into<String>) -> Self {
46 Self::new(Role::Assistant, vec![ContentBlock::text(text)])
47 }
48
49 pub fn tool(content: Vec<ContentBlock>) -> Self {
50 Self::new(Role::Tool, content)
51 }
52
53 pub fn text_content(&self) -> String {
54 self.content
55 .iter()
56 .filter_map(|b| b.as_text())
57 .collect::<Vec<_>>()
58 .join("")
59 }
60
61 pub fn has_tool_use(&self) -> bool {
62 self.content.iter().any(|b| b.is_tool_use())
63 }
64
65 pub fn has_tool_result(&self) -> bool {
66 self.content.iter().any(|b| b.is_tool_result())
67 }
68
69 pub fn tool_uses(&self) -> Vec<&ToolUse> {
70 self.content
71 .iter()
72 .filter_map(|b| b.as_tool_use())
73 .collect()
74 }
75
76 pub fn is_empty(&self) -> bool {
77 self.content.is_empty()
78 }
79
80 pub fn push(&mut self, block: ContentBlock) {
81 self.content.push(block);
82 }
83}
84
85#[derive(Default)]
86pub struct MessageBuilder {
87 id: Option<MessageId>,
88 role: Option<Role>,
89 content: Vec<ContentBlock>,
90 metadata: Metadata,
91}
92
93impl MessageBuilder {
94 pub fn new() -> Self {
95 Self::default()
96 }
97
98 pub fn id(mut self, id: MessageId) -> Self {
99 self.id = Some(id);
100 self
101 }
102
103 pub fn role(mut self, role: Role) -> Self {
104 self.role = Some(role);
105 self
106 }
107
108 pub fn system(self) -> Self {
109 self.role(Role::System)
110 }
111
112 pub fn user(self) -> Self {
113 self.role(Role::User)
114 }
115
116 pub fn assistant(self) -> Self {
117 self.role(Role::Assistant)
118 }
119
120 pub fn text(mut self, text: impl Into<String>) -> Self {
121 self.content.push(ContentBlock::text(text));
122 self
123 }
124
125 pub fn image(mut self, image: ImageContent) -> Self {
126 self.content.push(ContentBlock::Image(image));
127 self
128 }
129
130 pub fn content(mut self, block: ContentBlock) -> Self {
131 self.content.push(block);
132 self
133 }
134
135 pub fn contents(mut self, blocks: Vec<ContentBlock>) -> Self {
136 self.content.extend(blocks);
137 self
138 }
139
140 pub fn metadata<K: Into<String>, V: Serialize>(mut self, key: K, value: V) -> Self {
141 self.metadata.insert(key, value);
142 self
143 }
144
145 pub fn build(self) -> Result<Message, &'static str> {
146 let role = self.role.ok_or("role is required")?;
147 if self.content.is_empty() {
148 return Err("content is required");
149 }
150 let mut msg = Message::new(role, self.content);
151 if let Some(id) = self.id {
152 msg.id = id;
153 }
154 msg.metadata = self.metadata;
155 Ok(msg)
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_creation() {
        let msg = Message::user("Hello");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.text_content(), "Hello");
    }

    #[test]
    fn test_role_shortcuts() {
        assert_eq!(Message::system("s").role, Role::System);
        assert_eq!(Message::assistant("a").role, Role::Assistant);
        assert_eq!(Message::tool(vec![ContentBlock::text("r")]).role, Role::Tool);
    }

    #[test]
    fn test_empty_message_and_push() {
        let mut msg = Message::tool(Vec::new());
        assert!(msg.is_empty());
        assert_eq!(msg.text_content(), "");

        msg.push(ContentBlock::text("later"));
        assert!(!msg.is_empty());
        assert_eq!(msg.text_content(), "later");
    }

    #[test]
    fn test_message_builder() {
        let msg = MessageBuilder::new()
            .user()
            .text("Hello")
            .text(" World")
            .build()
            .unwrap();

        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.text_content(), "Hello World");
        assert_eq!(msg.content.len(), 2);
    }

    #[test]
    fn test_builder_requires_role() {
        let err = MessageBuilder::new().text("Hello").build().unwrap_err();
        assert_eq!(err, "role is required");
    }

    #[test]
    fn test_builder_requires_content() {
        let err = MessageBuilder::new().user().build().unwrap_err();
        assert_eq!(err, "content is required");
    }

    #[test]
    fn test_builder_checks_role_before_content() {
        let err = MessageBuilder::new().build().unwrap_err();
        assert_eq!(err, "role is required");
    }
}