1use std::collections::HashMap;
46use std::sync::Arc;
47
48use chrono::{DateTime, Utc};
49use serde::{Deserialize, Serialize};
50use smallvec::SmallVec;
51use typed_builder::TypedBuilder;
52use uuid::Uuid;
53
54use crate::tools::ToolCall;
55
56#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)]
60#[non_exhaustive]
61pub enum MessageRole {
62 #[serde(rename = "system")]
64 System,
65
66 #[serde(rename = "user")]
68 User,
69
70 #[serde(rename = "assistant")]
72 Assistant,
73
74 #[serde(rename = "tool")]
76 Tool,
77}
78
79#[derive(Debug, Serialize, Deserialize, Clone, TypedBuilder)]
84pub struct Message {
85 #[builder(default = Uuid::new_v4())]
87 pub id: Uuid,
88
89 pub conversation_id: Uuid,
91
92 pub role: MessageRole,
94
95 pub content: String,
97
98 #[builder(default)]
100 pub metadata: HashMap<String, serde_json::Value>,
101
102 #[builder(default = Utc::now())]
104 pub timestamp: DateTime<Utc>,
105
106 #[builder(default)]
108 pub tool_calls: SmallVec<[ToolCall; 2]>,
109
110 #[builder(default)]
112 pub tool_call_id: Option<String>,
113
114 #[builder(default)]
116 pub name: Option<String>,
117}
118
119impl Message {
120 pub fn new(conversation_id: Uuid, role: MessageRole, content: impl Into<String>) -> Self {
122 Self {
123 id: Uuid::new_v4(),
124 conversation_id,
125 role,
126 content: content.into(),
127 metadata: HashMap::new(),
128 timestamp: Utc::now(),
129 tool_calls: SmallVec::new(),
130 tool_call_id: None,
131 name: None,
132 }
133 }
134
135 pub fn system(conversation_id: Uuid, content: impl Into<String>) -> Self {
137 Self::new(conversation_id, MessageRole::System, content)
138 }
139
140 pub fn user(conversation_id: Uuid, content: impl Into<String>) -> Self {
142 Self::new(conversation_id, MessageRole::User, content)
143 }
144
145 pub fn assistant(conversation_id: Uuid, content: impl Into<String>) -> Self {
147 Self::new(conversation_id, MessageRole::Assistant, content)
148 }
149
150 pub fn tool(
163 conversation_id: Uuid,
164 content: impl Into<String>,
165 tool_call_id: String,
166 function_name: String,
167 ) -> anyhow::Result<Self> {
168 if tool_call_id.is_empty() {
169 anyhow::bail!("Tool call ID cannot be empty");
170 }
171 if function_name.is_empty() {
172 anyhow::bail!("Function name cannot be empty for tool messages");
173 }
174 let mut msg = Self::new(conversation_id, MessageRole::Tool, content);
175 msg.tool_call_id = Some(tool_call_id);
176 msg.name = Some(function_name);
177 Ok(msg)
178 }
179
180 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
182 self.metadata.insert(key.into(), value);
183 self
184 }
185
186 pub fn with_metadata_typed<T: serde::Serialize>(
190 mut self,
191 key: impl Into<String>,
192 value: T,
193 ) -> anyhow::Result<Self> {
194 let json_value = serde_json::to_value(value)?;
195 self.metadata.insert(key.into(), json_value);
196 Ok(self)
197 }
198
199 pub fn with_tool_calls(
205 mut self,
206 tool_calls: impl Into<SmallVec<[ToolCall; 2]>>,
207 ) -> anyhow::Result<Self> {
208 if self.role != MessageRole::Assistant {
209 anyhow::bail!(
210 "Tool calls can only be added to assistant messages, found {:?}",
211 self.role
212 );
213 }
214 self.tool_calls = tool_calls.into();
215 Ok(self)
216 }
217
218 pub fn add_tool_call(&mut self, tool_call: ToolCall) -> anyhow::Result<()> {
224 if self.role != MessageRole::Assistant {
225 anyhow::bail!(
226 "Tool calls can only be added to assistant messages, found {:?}",
227 self.role
228 );
229 }
230 self.tool_calls.push(tool_call);
231 Ok(())
232 }
233}
234
235#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
239#[non_exhaustive]
240pub enum ConversationStatus {
241 #[serde(rename = "active")]
243 Active,
244
245 #[serde(rename = "paused")]
247 Paused,
248
249 #[serde(rename = "archived")]
251 Archived,
252
253 #[serde(rename = "deleted")]
255 Deleted,
256}
257
258#[derive(Debug, Serialize, Deserialize, Clone)]
263pub struct Conversation {
264 pub id: Uuid,
266
267 pub title: Option<String>,
269
270 pub description: Option<String>,
272
273 pub created_at: DateTime<Utc>,
275
276 pub updated_at: DateTime<Utc>,
278
279 pub metadata: HashMap<String, serde_json::Value>,
281
282 pub status: ConversationStatus,
284
285 pub messages: Arc<Vec<Message>>,
287}
288
289impl Conversation {
290 pub fn new() -> Self {
292 let now = Utc::now();
293 Self {
294 id: Uuid::new_v4(),
295 title: None,
296 description: None,
297 created_at: now,
298 updated_at: now,
299 metadata: HashMap::new(),
300 status: ConversationStatus::Active,
301 messages: Arc::new(Vec::new()),
302 }
303 }
304
305 pub fn with_title(mut self, title: impl Into<String>) -> Self {
307 self.title = Some(title.into());
308 self
309 }
310
311 pub fn with_description(mut self, description: impl Into<String>) -> Self {
313 self.description = Some(description.into());
314 self
315 }
316
317 pub fn set_status(&mut self, status: ConversationStatus) {
319 self.status = status;
320 self.updated_at = Utc::now();
321 }
322
323 pub fn touch(&mut self) {
325 self.updated_at = Utc::now();
326 }
327
328 pub fn add_message(&mut self, message: Message) -> anyhow::Result<()> {
330 if message.conversation_id != self.id {
331 anyhow::bail!(
332 "Message conversation_id {} does not match conversation id {}",
333 message.conversation_id,
334 self.id
335 );
336 }
337 Arc::make_mut(&mut self.messages).push(message);
338 self.touch();
339 Ok(())
340 }
341
342 pub fn get_messages(&self) -> &[Message] {
344 &self.messages
345 }
346
347 pub fn user_message(&self, content: impl Into<String>) -> Message {
349 Message::user(self.id, content)
350 }
351
352 pub fn assistant_message(&self, content: impl Into<String>) -> Message {
354 Message::assistant(self.id, content)
355 }
356
357 pub fn system_message(&self, content: impl Into<String>) -> Message {
359 Message::system(self.id, content)
360 }
361
362 pub fn tool_message(
364 &self,
365 content: impl Into<String>,
366 tool_call_id: String,
367 function_name: String,
368 ) -> anyhow::Result<Message> {
369 Message::tool(self.id, content, tool_call_id, function_name)
370 }
371}
372
373impl Default for Conversation {
374 fn default() -> Self {
375 Self::new()
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// JSON argument payload shared by the weather tool-call tests.
    const WEATHER_ARGS: &str = r#"{"location": "New York"}"#;

    /// Builds the `get_weather` tool call used by several tests.
    fn weather_call() -> ToolCall {
        ToolCall::new("get_weather", [WEATHER_ARGS])
    }

    #[test]
    fn test_message_creation() {
        let conversation = Uuid::new_v4();
        let message = Message::user(conversation, "Hello, world!");

        assert_eq!(message.conversation_id, conversation);
        assert_eq!(message.role, MessageRole::User);
        assert_eq!(message.content, "Hello, world!");
        assert!(message.tool_calls.is_empty());
    }

    #[test]
    fn test_conversation_creation() {
        let conversation = Conversation::new()
            .with_title("Test Conversation")
            .with_description("A test conversation");

        assert_eq!(conversation.title.as_deref(), Some("Test Conversation"));
        assert_eq!(
            conversation.description.as_deref(),
            Some("A test conversation")
        );
        assert_eq!(conversation.status, ConversationStatus::Active);
    }

    #[test]
    fn test_tool_call_creation() {
        let call = ToolCall::new("test_function", [r#"{"param": "value"}"#]);

        assert_eq!(call.function.name, "test_function");
        assert_eq!(call.function.arguments, vec![r#"{"param": "value"}"#]);
        assert_eq!(call.call_type, "function");
        assert!(!call.id.is_empty());
    }

    #[test]
    fn test_message_with_tool_calls() {
        let message = Message::assistant(Uuid::new_v4(), "I'll check the weather for you.")
            .with_tool_calls(vec![weather_call()])
            .expect("Failed to add tool calls");

        assert_eq!(message.tool_calls.len(), 1);
        assert_eq!(message.tool_calls[0].function.name, "get_weather");
        assert_eq!(message.tool_calls[0].function.arguments, vec![WEATHER_ARGS]);
    }

    #[test]
    fn test_message_tool_call_validation() {
        let conversation = Uuid::new_v4();

        // Only assistant messages may carry tool calls.
        let rejected =
            Message::user(conversation, "What's the weather?").with_tool_calls(vec![weather_call()]);
        assert!(rejected.is_err());

        let accepted =
            Message::assistant(conversation, "Let me check.").with_tool_calls(vec![weather_call()]);
        assert!(accepted.is_ok());
    }

    #[test]
    fn test_tool_message_validation() {
        let conversation = Uuid::new_v4();

        // An empty tool call id is rejected.
        let missing_id = Message::tool(
            conversation,
            "Result",
            String::new(),
            "test_func".to_string(),
        );
        assert!(missing_id.is_err());

        // An empty function name is rejected.
        let missing_name = Message::tool(
            conversation,
            "Result",
            "call_123".to_string(),
            String::new(),
        );
        assert!(missing_name.is_err());

        // Both values present: the message is built and the name is recorded.
        let message = Message::tool(
            conversation,
            "Result",
            "call_123".to_string(),
            "test_func".to_string(),
        )
        .expect("valid tool message should be accepted");
        assert_eq!(message.name.as_deref(), Some("test_func"));
    }

    #[test]
    fn test_conversation_add_message() {
        let mut conversation = Conversation::new();
        let greeting = conversation.user_message("Hello");

        conversation
            .add_message(greeting)
            .expect("Failed to add message");

        assert_eq!(conversation.messages.len(), 1);
        assert_eq!(conversation.messages[0].content, "Hello");
    }

    #[test]
    fn test_conversation_add_message_wrong_id() {
        let mut conversation = Conversation::new();
        let foreign = Message::user(Uuid::new_v4(), "Hello");

        assert!(conversation.add_message(foreign).is_err());
    }

    #[test]
    fn test_message_with_metadata_typed() {
        let message = Message::user(Uuid::new_v4(), "Hello")
            .with_metadata_typed("count", 42)
            .expect("Failed to add metadata");

        assert_eq!(message.metadata.get("count"), Some(&serde_json::json!(42)));
    }

    #[test]
    fn test_tool_call_with_multiple_args() {
        let expected = ["arg1", "arg2", r#"{"key": "value"}"#];
        let owned: Vec<String> = expected.iter().map(|arg| arg.to_string()).collect();

        let call = ToolCall::new("complex_function", owned);

        assert_eq!(call.function.name, "complex_function");
        assert_eq!(call.function.arguments.len(), 3);
        for (index, wanted) in expected.iter().enumerate() {
            assert_eq!(call.function.arguments[index], *wanted);
        }
    }
}
519
/// Property-based and fuzz-style tests for [`Message`] and [`Conversation`].
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // `content` may be supplied as `&str` or `String` to every constructor.
        #[test]
        fn message_accepts_string_types(content in ".*") {
            let conv_id = Uuid::new_v4();

            let msg1 = Message::new(conv_id, MessageRole::User, content.as_str());
            assert_eq!(msg1.content, content);

            let msg2 = Message::new(conv_id, MessageRole::User, content.clone());
            assert_eq!(msg2.content, content);

            let msg3 = Message::user(conv_id, content.as_str());
            assert_eq!(msg3.role, MessageRole::User);
            assert_eq!(msg3.content, content);
        }

        // A message of any role survives a JSON round trip.
        #[test]
        fn message_serialization_roundtrip(
            content in ".*",
            role_idx in 0usize..4,
        ) {
            let conv_id = Uuid::new_v4();
            let role = match role_idx {
                0 => MessageRole::User,
                1 => MessageRole::Assistant,
                2 => MessageRole::System,
                _ => MessageRole::Tool,
            };

            let msg = Message::new(conv_id, role, content);
            let serialized = serde_json::to_string(&msg).expect("Failed to serialize");
            let deserialized: Message = serde_json::from_str(&serialized)
                .expect("Failed to deserialize");

            assert_eq!(msg.id, deserialized.id);
            assert_eq!(msg.conversation_id, deserialized.conversation_id);
            assert_eq!(msg.role, deserialized.role);
            assert_eq!(msg.content, deserialized.content);
        }

        // Title and description builders accept both `&str` and `String`.
        #[test]
        fn conversation_builder_with_strings(
            title in ".*",
            description in ".*",
        ) {
            let conv1 = Conversation::new()
                .with_title(title.as_str())
                .with_description(description.as_str());
            assert_eq!(conv1.title, Some(title.clone()));
            assert_eq!(conv1.description, Some(description.clone()));

            let conv2 = Conversation::new()
                .with_title(title.clone())
                .with_description(description.clone());
            assert_eq!(conv2.title, Some(title));
            assert_eq!(conv2.description, Some(description));
        }

        // `ToolCall::new` accepts owned and borrowed argument lists alike.
        #[test]
        fn tool_call_accepts_various_argument_types(
            func_name in ".*",
            args in prop::collection::vec(".*", 0..10),
        ) {
            let tc1 = ToolCall::new(func_name.as_str(), args.clone());
            assert_eq!(tc1.function.name, func_name);
            assert_eq!(tc1.function.arguments, args);

            let str_refs: Vec<&str> = args.iter().map(|s| s.as_str()).collect();
            let tc2 = ToolCall::new(func_name.as_str(), str_refs);
            assert_eq!(tc2.function.name, func_name);
            assert_eq!(tc2.function.arguments, args);
        }

        // Metadata inserted under an arbitrary key can be read back.
        #[test]
        fn message_metadata_operations(
            key in ".*",
            value_num in 0i64..1000000,
        ) {
            let conv_id = Uuid::new_v4();
            let msg = Message::user(conv_id, "test")
                .with_metadata(key.as_str(), serde_json::json!(value_num));

            assert!(msg.metadata.contains_key(&key));
            assert_eq!(msg.metadata[&key], serde_json::json!(value_num));
        }

        // `set_status` stores whichever status it is given. Each of the four
        // variants is mapped to exactly one index so that none is skipped.
        #[test]
        fn conversation_status_transitions(
            status_idx in 0usize..4,
        ) {
            let status = match status_idx {
                0 => ConversationStatus::Active,
                1 => ConversationStatus::Paused,
                2 => ConversationStatus::Archived,
                _ => ConversationStatus::Deleted,
            };

            let mut conv = Conversation::new();
            conv.set_status(status.clone());

            assert_eq!(conv.status, status);
        }

        // Cloning a message copies every identifying field unchanged.
        #[test]
        fn message_clone_preserves_data(content in ".*") {
            let conv_id = Uuid::new_v4();
            let original = Message::user(conv_id, content.as_str());
            let cloned = original.clone();

            assert_eq!(original.id, cloned.id);
            assert_eq!(original.conversation_id, cloned.conversation_id);
            assert_eq!(original.role, cloned.role);
            assert_eq!(original.content, cloned.content);
            assert_eq!(original.timestamp, cloned.timestamp);
        }

        // Arbitrary bytes must never panic the deserializer; the result itself is
        // deliberately ignored.
        #[test]
        fn fuzz_message_deserialization(data in prop::collection::vec(any::<u8>(), 0..1000)) {
            let _ = serde_json::from_slice::<Message>(&data);
        }

        // Well-formed JSON with an arbitrary role string must never panic; whether
        // it parses depends on the generated role, so the result is ignored.
        #[test]
        fn fuzz_message_json_with_invalid_roles(
            content in "[\\p{L}\\p{N}\\p{P}\\p{S} ]{0,100}",
            role_str in "[a-z]{1,20}",
        ) {
            let conv_id = Uuid::new_v4();
            let msg_id = Uuid::new_v4();
            // Escape backslashes first so the quote escapes are not escaped again.
            let escaped_content = content.replace('\\', "\\\\").replace('"', "\\\"");
            let json = format!(
                r#"{{"id":"{}","conversation_id":"{}","role":"{}","content":"{}","metadata":{{}},"timestamp":"2024-01-01T00:00:00Z","tool_calls":[],"tool_call_id":null,"name":null}}"#,
                msg_id, conv_id, role_str, escaped_content
            );
            let _ = serde_json::from_str::<Message>(&json);
        }

        // Large contents survive a JSON round trip.
        #[test]
        fn fuzz_message_with_extreme_lengths(
            content_len in 10000usize..20000,
        ) {
            let conv_id = Uuid::new_v4();
            let content: String = "a".repeat(content_len);
            let msg = Message::user(conv_id, content.clone());

            let json = serde_json::to_string(&msg).unwrap();
            let deserialized: Message = serde_json::from_str(&json).unwrap();
            assert_eq!(msg.content, deserialized.content);
        }

        // `Message::tool` rejects exactly the inputs with an empty id or name.
        #[test]
        fn fuzz_tool_message_with_invalid_ids(
            content in ".*",
            tool_call_id in ".*",
            func_name in ".*",
        ) {
            let conv_id = Uuid::new_v4();
            let result = Message::tool(conv_id, content.clone(), tool_call_id.clone(), func_name.clone());

            if tool_call_id.is_empty() || func_name.is_empty() {
                assert!(result.is_err());
            } else {
                let msg = result.expect("non-empty id and name must be accepted");
                assert_eq!(msg.tool_call_id, Some(tool_call_id));
                assert_eq!(msg.name, Some(func_name));
                assert_eq!(msg.content, content);
            }
        }

        // Control characters, quotes and brackets must round-trip through JSON.
        // A deserialization failure is a test failure, not a silent pass.
        #[test]
        fn fuzz_message_with_special_characters(
            content in r#"[\x00-\x1F\x7F\n\r\t"'`{}\[\]]*"#,
        ) {
            let conv_id = Uuid::new_v4();
            let msg = Message::user(conv_id, content.clone());

            let json = serde_json::to_string(&msg)
                .expect("serializing special characters must succeed");
            let parsed_msg: Message = serde_json::from_str(&json)
                .expect("serialized message must deserialize again");

            assert_eq!(parsed_msg.content, content);
        }

        // A conversation with optional fields and several messages round-trips
        // through JSON.
        #[test]
        fn fuzz_conversation_serialization(
            title in prop::option::of(".*"),
            description in prop::option::of(".*"),
            num_messages in 0usize..20,
        ) {
            let mut conv = Conversation::new();
            if let Some(t) = title.clone() {
                conv = conv.with_title(t);
            }
            if let Some(d) = description.clone() {
                conv = conv.with_description(d);
            }

            for i in 0..num_messages {
                let msg = conv.user_message(format!("Message {}", i));
                // The message was built from `conv`, so the id check cannot fail.
                conv.add_message(msg)
                    .expect("message created by the conversation must be accepted");
            }
            assert_eq!(conv.messages.len(), num_messages);

            let json = serde_json::to_string(&conv).unwrap();
            let parsed: Conversation = serde_json::from_str(&json).unwrap();

            assert_eq!(conv.id, parsed.id);
            assert_eq!(conv.title, parsed.title);
            assert_eq!(conv.description, parsed.description);
            assert_eq!(conv.messages.len(), parsed.messages.len());
        }
    }
}