1use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7
8#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
14pub struct Message {
15 pub role: Role,
17
18 pub parts: Vec<MessagePart>,
20
21 #[serde(rename = "messageId", skip_serializing_if = "Option::is_none")]
23 pub message_id: Option<String>,
24
25 #[serde(rename = "taskId", skip_serializing_if = "Option::is_none")]
27 pub task_id: Option<String>,
28
29 #[serde(rename = "contextId", skip_serializing_if = "Option::is_none")]
31 pub context_id: Option<String>,
32
33 #[serde(skip_serializing_if = "Option::is_none")]
35 pub metadata: Option<HashMap<String, Value>>,
36
37 #[serde(skip_serializing_if = "Option::is_none")]
39 pub extensions: Option<HashMap<String, Value>>,
40}
41
42impl Message {
43 pub fn new(role: Role, text: impl Into<String>) -> Self {
45 Self {
46 role,
47 parts: vec![MessagePart::Text { text: text.into() }],
48 message_id: None,
49 task_id: None,
50 context_id: None,
51 metadata: None,
52 extensions: None,
53 }
54 }
55
56 pub fn user(text: impl Into<String>) -> Self {
58 Self::new(Role::User, text)
59 }
60
61 pub fn agent(text: impl Into<String>) -> Self {
63 Self::new(Role::Agent, text)
64 }
65
66 pub fn builder() -> MessageBuilder {
68 MessageBuilder::new()
69 }
70
71 pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
73 self.metadata
74 .get_or_insert_with(HashMap::new)
75 .insert(key.into(), value);
76 self
77 }
78
79 pub fn with_extension(mut self, key: impl Into<String>, value: Value) -> Self {
81 self.extensions
82 .get_or_insert_with(HashMap::new)
83 .insert(key.into(), value);
84 self
85 }
86
87 pub fn with_part(mut self, part: MessagePart) -> Self {
89 self.parts.push(part);
90 self
91 }
92}
93
94#[derive(Debug, Default)]
96pub struct MessageBuilder {
97 role: Option<Role>,
98 parts: Vec<MessagePart>,
99 message_id: Option<String>,
100 task_id: Option<String>,
101 context_id: Option<String>,
102 metadata: Option<HashMap<String, Value>>,
103 extensions: Option<HashMap<String, Value>>,
104}
105
106impl MessageBuilder {
107 pub fn new() -> Self {
109 Self::default()
110 }
111
112 pub fn role(mut self, role: Role) -> Self {
114 self.role = Some(role);
115 self
116 }
117
118 pub fn parts(mut self, parts: Vec<MessagePart>) -> Self {
120 self.parts = parts;
121 self
122 }
123
124 pub fn part(mut self, part: MessagePart) -> Self {
126 self.parts.push(part);
127 self
128 }
129
130 pub fn message_id(mut self, id: impl Into<String>) -> Self {
132 self.message_id = Some(id.into());
133 self
134 }
135
136 pub fn task_id(mut self, id: impl Into<String>) -> Self {
138 self.task_id = Some(id.into());
139 self
140 }
141
142 pub fn context_id(mut self, id: impl Into<String>) -> Self {
144 self.context_id = Some(id.into());
145 self
146 }
147
148 pub fn metadata(mut self, key: impl Into<String>, value: Value) -> Self {
150 self.metadata
151 .get_or_insert_with(HashMap::new)
152 .insert(key.into(), value);
153 self
154 }
155
156 pub fn extension(mut self, key: impl Into<String>, value: Value) -> Self {
158 self.extensions
159 .get_or_insert_with(HashMap::new)
160 .insert(key.into(), value);
161 self
162 }
163
164 pub fn build(self) -> Message {
170 let role = self.role.expect("Message role is required");
171 assert!(
172 !self.parts.is_empty(),
173 "Message must have at least one part"
174 );
175
176 Message {
177 role,
178 parts: self.parts,
179 message_id: self.message_id,
180 task_id: self.task_id,
181 context_id: self.context_id,
182 metadata: self.metadata,
183 extensions: self.extensions,
184 }
185 }
186}
187
188#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
190#[serde(rename_all = "lowercase")]
191pub enum Role {
192 User,
194
195 Agent,
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
201#[serde(rename_all = "camelCase")]
202pub struct FileContent {
203 #[serde(skip_serializing_if = "Option::is_none")]
205 pub media_type: Option<String>,
206
207 pub name: String,
209
210 #[serde(skip_serializing_if = "Option::is_none")]
212 pub file_with_uri: Option<String>,
213
214 #[serde(skip_serializing_if = "Option::is_none")]
216 pub file_with_bytes: Option<String>,
217}
218
219#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
223#[serde(untagged)]
224pub enum MessagePart {
225 Text {
227 text: String,
229 },
230
231 File {
233 file: FileContent,
235 },
236
237 Data {
239 data: Value,
241 },
242}
243
244impl MessagePart {
245 pub fn text(text: impl Into<String>) -> Self {
247 Self::Text { text: text.into() }
248 }
249
250 pub fn file(name: impl Into<String>, file_uri: impl Into<String>) -> Self {
252 Self::File {
253 file: FileContent {
254 media_type: None,
255 name: name.into(),
256 file_with_uri: Some(file_uri.into()),
257 file_with_bytes: None,
258 },
259 }
260 }
261
262 pub fn file_with_type(
264 name: impl Into<String>,
265 file_uri: impl Into<String>,
266 media_type: impl Into<String>,
267 ) -> Self {
268 Self::File {
269 file: FileContent {
270 media_type: Some(media_type.into()),
271 name: name.into(),
272 file_with_uri: Some(file_uri.into()),
273 file_with_bytes: None,
274 },
275 }
276 }
277
278 pub fn file_with_bytes(
280 name: impl Into<String>,
281 file_bytes: impl Into<String>,
282 media_type: Option<String>,
283 ) -> Self {
284 Self::File {
285 file: FileContent {
286 media_type,
287 name: name.into(),
288 file_with_uri: None,
289 file_with_bytes: Some(file_bytes.into()),
290 },
291 }
292 }
293
294 pub fn data(data: Value) -> Self {
296 Self::Data { data }
297 }
298}
299
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    #[test]
    fn test_message_creation() {
        let msg = Message::user("Hello, agent!");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.parts.len(), 1);

        match &msg.parts[0] {
            MessagePart::Text { text } => assert_eq!(text, "Hello, agent!"),
            _ => panic!("Expected text part"),
        }
    }

    #[test]
    fn test_message_with_metadata() {
        let msg = Message::user("Test")
            .with_metadata("key", json!("value"))
            .with_extension("ext", json!({"enabled": true}));

        assert!(msg.metadata.is_some());
        assert!(msg.extensions.is_some());
    }

    #[test]
    fn test_message_serialization() {
        let msg = Message::user("Test message");
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"role\":\"user\""));
        assert!(json.contains("\"text\":\"Test message\""));

        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(msg, deserialized);
    }

    #[test]
    fn test_message_part_types() {
        let text = MessagePart::text("Hello");
        let file = MessagePart::file("myfile.txt", "file://path/to/file");
        let data = MessagePart::data(json!({"key": "value"}));

        assert!(matches!(text, MessagePart::Text { .. }));
        assert!(matches!(file, MessagePart::File { .. }));
        assert!(matches!(data, MessagePart::Data { .. }));
    }

    #[test]
    fn test_message_builder() {
        let msg = Message::builder()
            .role(Role::Agent)
            .parts(vec![MessagePart::text("Hello")])
            .message_id("msg-123")
            .task_id("task-456")
            .context_id("ctx-789")
            .build();

        assert_eq!(msg.role, Role::Agent);
        assert_eq!(msg.parts.len(), 1);
        assert_eq!(msg.message_id, Some("msg-123".to_string()));
        assert_eq!(msg.task_id, Some("task-456".to_string()));
        assert_eq!(msg.context_id, Some("ctx-789".to_string()));
    }

    #[test]
    fn test_message_builder_with_part() {
        let msg = Message::builder()
            .role(Role::Agent)
            .part(MessagePart::text("First"))
            .part(MessagePart::text("Second"))
            .build();

        assert_eq!(msg.parts.len(), 2);
    }

    #[test]
    #[should_panic(expected = "Message role is required")]
    fn test_message_builder_missing_role() {
        Message::builder()
            .parts(vec![MessagePart::text("Hello")])
            .build();
    }

    #[test]
    #[should_panic(expected = "Message must have at least one part")]
    fn test_message_builder_no_parts() {
        Message::builder().role(Role::User).build();
    }

    #[test]
    fn test_message_serialization_with_ids() {
        let msg = Message::builder()
            .role(Role::User)
            .parts(vec![MessagePart::text("Test")])
            .message_id("msg-123")
            .task_id("task-456")
            .build();

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"messageId\":\"msg-123\""));
        assert!(json.contains("\"taskId\":\"task-456\""));

        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(msg, deserialized);
    }

    /// Roles travel as lowercase strings and the match is case-sensitive.
    #[test]
    fn test_role_wire_format() {
        assert_eq!(serde_json::to_string(&Role::User).unwrap(), "\"user\"");
        assert_eq!(serde_json::to_string(&Role::Agent).unwrap(), "\"agent\"");
        assert_eq!(
            serde_json::from_str::<Role>("\"agent\"").unwrap(),
            Role::Agent
        );
        assert!(serde_json::from_str::<Role>("\"Agent\"").is_err());
    }

    /// `None` optional fields must not appear in the serialized message at all.
    #[test]
    fn test_optional_fields_are_omitted() {
        let value = serde_json::to_value(Message::agent("Hi")).unwrap();
        assert_eq!(value, json!({"role": "agent", "parts": [{"text": "Hi"}]}));
    }

    /// File parts use camelCase keys, omit unset fields, and survive an untagged round-trip.
    #[test]
    fn test_file_part_wire_format_and_roundtrip() {
        let by_uri = MessagePart::file_with_type("a.txt", "file:///a.txt", "text/plain");
        let value = serde_json::to_value(&by_uri).unwrap();
        assert_eq!(
            value,
            json!({"file": {"mediaType": "text/plain", "name": "a.txt", "fileWithUri": "file:///a.txt"}})
        );
        let back: MessagePart = serde_json::from_value(value).unwrap();
        assert_eq!(back, by_uri);

        let inline = MessagePart::file_with_bytes("b.bin", "AAEC", None);
        let value = serde_json::to_value(&inline).unwrap();
        assert_eq!(value, json!({"file": {"name": "b.bin", "fileWithBytes": "AAEC"}}));
        let back: MessagePart = serde_json::from_value(value).unwrap();
        assert_eq!(back, inline);
    }

    /// Data parts are not mistaken for text or file parts when deserialized.
    #[test]
    fn test_data_part_roundtrip() {
        let part = MessagePart::data(json!({"key": "value"}));
        let value = serde_json::to_value(&part).unwrap();
        assert_eq!(value, json!({"data": {"key": "value"}}));

        let back: MessagePart = serde_json::from_value(value).unwrap();
        assert_eq!(back, part);
    }

    /// A part matching none of the untagged variants makes the whole message fail to parse.
    #[test]
    fn test_unrecognized_part_is_rejected() {
        let raw = r#"{"role":"user","parts":[{"unknown":1}]}"#;
        assert!(serde_json::from_str::<Message>(raw).is_err());
    }
}