1use async_trait::async_trait;
18use chrono::{DateTime, Utc};
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use std::sync::RwLock;
22
23use crate::kernel::{ExecutionId, MessageId, ParentType, TenantId, ThreadId, UserId};
24
25use super::StorageBackend;
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct Thread {
34 pub id: ThreadId,
36 pub tenant_id: TenantId,
38 pub user_id: UserId,
40 pub title: Option<String>,
42 pub created_at: DateTime<Utc>,
44 pub updated_at: DateTime<Utc>,
46 pub deleted_at: Option<DateTime<Utc>>,
48}
49
50impl Thread {
51 pub fn new(tenant_id: TenantId, user_id: UserId) -> Self {
53 let now = Utc::now();
54 Self {
55 id: ThreadId::new(),
56 tenant_id,
57 user_id,
58 title: None,
59 created_at: now,
60 updated_at: now,
61 deleted_at: None,
62 }
63 }
64
65 pub fn with_id(id: ThreadId, tenant_id: TenantId, user_id: UserId) -> Self {
67 let now = Utc::now();
68 Self {
69 id,
70 tenant_id,
71 user_id,
72 title: None,
73 created_at: now,
74 updated_at: now,
75 deleted_at: None,
76 }
77 }
78
79 pub fn is_deleted(&self) -> bool {
81 self.deleted_at.is_some()
82 }
83}
84
85#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
91#[serde(rename_all = "lowercase")]
92pub enum MessageRole {
93 User,
95 Assistant,
97 System,
99}
100
101impl std::fmt::Display for MessageRole {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 match self {
104 MessageRole::User => write!(f, "user"),
105 MessageRole::Assistant => write!(f, "assistant"),
106 MessageRole::System => write!(f, "system"),
107 }
108 }
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
117#[serde(tag = "type", rename_all = "kebab-case")]
118pub enum MessagePart {
119 Text { text: String },
121
122 Reasoning { text: String },
124
125 ToolCall {
127 tool_call_id: String,
128 tool_name: String,
129 args: serde_json::Value,
130 },
131
132 ToolResult {
134 tool_call_id: String,
135 tool_name: String,
136 result: serde_json::Value,
137 is_error: bool,
138 },
139
140 Source {
142 source_id: String,
143 url: Option<String>,
144 title: Option<String>,
145 },
146
147 File {
149 file_id: String,
150 filename: String,
151 mime_type: String,
152 size_bytes: u64,
153 },
154
155 Image {
157 image_id: String,
158 url: Option<String>,
159 alt_text: Option<String>,
160 },
161
162 Code {
164 language: Option<String>,
165 code: String,
166 },
167}
168
169#[derive(Debug, Clone, Default, Serialize, Deserialize)]
175pub struct TokenUsage {
176 pub prompt_tokens: u32,
178 pub completion_tokens: u32,
180 pub total_tokens: u32,
182}
183
184impl TokenUsage {
185 pub fn new(prompt: u32, completion: u32) -> Self {
187 Self {
188 prompt_tokens: prompt,
189 completion_tokens: completion,
190 total_tokens: prompt + completion,
191 }
192 }
193}
194
195#[derive(Debug, Clone, Default, Serialize, Deserialize)]
201pub struct ExecutionStats {
202 pub llm_calls: u32,
204 pub tool_calls: u32,
206 pub sub_agents: u32,
208 pub steps: u32,
210 pub decisions: u32,
212 pub artifacts: u32,
214}
215
216#[derive(Debug, Clone, Serialize, Deserialize)]
222pub struct CostInfo {
223 pub input_cost: f64,
225 pub output_cost: f64,
227 pub total_cost: f64,
229 pub currency: String,
231}
232
233impl Default for CostInfo {
234 fn default() -> Self {
235 Self {
236 input_cost: 0.0,
237 output_cost: 0.0,
238 total_cost: 0.0,
239 currency: "USD".to_string(),
240 }
241 }
242}
243
244#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
250#[serde(rename_all = "snake_case")]
251pub enum FinishReason {
252 Stop,
254 Length,
256 ToolCalls,
258 ContentFilter,
260 Error,
262}
263
264#[derive(Debug, Clone, Default, Serialize, Deserialize)]
270pub struct MessageMetadata {
271 pub completed_at: Option<i64>,
273 pub duration_ms: Option<u64>,
275
276 pub model: Option<String>,
278 pub provider: Option<String>,
280
281 pub token_usage: Option<TokenUsage>,
283
284 pub stats: Option<ExecutionStats>,
286
287 pub finish_reason: Option<FinishReason>,
289
290 pub cost: Option<CostInfo>,
292}
293
294#[derive(Debug, Clone, Serialize, Deserialize)]
300pub struct Message {
301 pub id: MessageId,
303 pub thread_id: ThreadId,
305 pub execution_id: Option<ExecutionId>,
307
308 pub parent_id: Option<MessageId>,
311 pub parent_type: ParentType,
313
314 pub role: MessageRole,
317 pub content: String,
319 pub parts: Vec<MessagePart>,
321
322 pub created_at: DateTime<Utc>,
325 pub updated_at: Option<DateTime<Utc>>,
327 pub deleted_at: Option<DateTime<Utc>>,
329
330 pub metadata: MessageMetadata,
332}
333
334impl Message {
335 pub fn user(thread_id: ThreadId, content: impl Into<String>) -> Self {
337 let content = content.into();
338 Self {
339 id: MessageId::new(),
340 thread_id,
341 execution_id: None,
342 parent_id: None,
343 parent_type: ParentType::UserMessage,
344 role: MessageRole::User,
345 content: content.clone(),
346 parts: vec![MessagePart::Text { text: content }],
347 created_at: Utc::now(),
348 updated_at: None,
349 deleted_at: None,
350 metadata: MessageMetadata::default(),
351 }
352 }
353
354 pub fn assistant(
356 thread_id: ThreadId,
357 execution_id: ExecutionId,
358 content: impl Into<String>,
359 parent_id: Option<MessageId>,
360 ) -> Self {
361 let content = content.into();
362 Self {
363 id: MessageId::new(),
364 thread_id,
365 execution_id: Some(execution_id),
366 parent_id,
367 parent_type: ParentType::UserMessage,
368 role: MessageRole::Assistant,
369 content: content.clone(),
370 parts: vec![MessagePart::Text { text: content }],
371 created_at: Utc::now(),
372 updated_at: None,
373 deleted_at: None,
374 metadata: MessageMetadata::default(),
375 }
376 }
377
378 pub fn system(thread_id: ThreadId, content: impl Into<String>) -> Self {
380 let content = content.into();
381 Self {
382 id: MessageId::new(),
383 thread_id,
384 execution_id: None,
385 parent_id: None,
386 parent_type: ParentType::System,
387 role: MessageRole::System,
388 content: content.clone(),
389 parts: vec![MessagePart::Text { text: content }],
390 created_at: Utc::now(),
391 updated_at: None,
392 deleted_at: None,
393 metadata: MessageMetadata::default(),
394 }
395 }
396
397 pub fn is_deleted(&self) -> bool {
399 self.deleted_at.is_some()
400 }
401
402 pub fn with_parent(mut self, parent_id: MessageId, parent_type: ParentType) -> Self {
404 self.parent_id = Some(parent_id);
405 self.parent_type = parent_type;
406 self
407 }
408
409 pub fn with_parts(mut self, parts: Vec<MessagePart>) -> Self {
411 self.parts = parts;
412 self
413 }
414
415 pub fn with_metadata(mut self, metadata: MessageMetadata) -> Self {
417 self.metadata = metadata;
418 self
419 }
420}
421
422#[async_trait]
434pub trait MessageStore: StorageBackend {
435 async fn create_thread(&self, thread: Thread) -> anyhow::Result<ThreadId>;
441
442 async fn get_thread(&self, thread_id: &ThreadId) -> anyhow::Result<Option<Thread>>;
444
445 async fn update_thread(&self, thread: Thread) -> anyhow::Result<()>;
447
448 async fn delete_thread(&self, thread_id: &ThreadId) -> anyhow::Result<()>;
450
451 async fn list_threads(
453 &self,
454 tenant_id: &TenantId,
455 user_id: &UserId,
456 limit: usize,
457 offset: usize,
458 ) -> anyhow::Result<Vec<Thread>>;
459
460 async fn create_message(&self, message: Message) -> anyhow::Result<MessageId>;
466
467 async fn get_message(&self, message_id: &MessageId) -> anyhow::Result<Option<Message>>;
469
470 async fn update_message(&self, message: Message) -> anyhow::Result<()>;
472
473 async fn delete_message(&self, message_id: &MessageId) -> anyhow::Result<()>;
475
476 async fn list_messages(
478 &self,
479 thread_id: &ThreadId,
480 include_deleted: bool,
481 ) -> anyhow::Result<Vec<Message>>;
482
483 async fn get_messages_by_execution(
485 &self,
486 execution_id: &ExecutionId,
487 ) -> anyhow::Result<Vec<Message>>;
488
489 async fn get_or_create_thread(
495 &self,
496 tenant_id: TenantId,
497 user_id: UserId,
498 ) -> anyhow::Result<Thread> {
499 let thread = Thread::new(tenant_id, user_id);
501 self.create_thread(thread.clone()).await?;
502 Ok(thread)
503 }
504
505 async fn count_messages(&self, thread_id: &ThreadId) -> anyhow::Result<u64> {
507 let messages = self.list_messages(thread_id, false).await?;
508 Ok(messages.len() as u64)
509 }
510}
511
512#[derive(Default)]
521pub struct InMemoryMessageStore {
522 threads: RwLock<HashMap<String, Thread>>,
523 messages: RwLock<HashMap<String, Message>>,
524}
525
526impl InMemoryMessageStore {
527 pub fn new() -> Self {
529 Self {
530 threads: RwLock::new(HashMap::new()),
531 messages: RwLock::new(HashMap::new()),
532 }
533 }
534
535 pub fn shared() -> std::sync::Arc<Self> {
537 std::sync::Arc::new(Self::new())
538 }
539}
540
541#[async_trait]
542impl StorageBackend for InMemoryMessageStore {
543 fn name(&self) -> &str {
544 "in-memory-message-store"
545 }
546
547 fn requires_network(&self) -> bool {
548 false
549 }
550
551 async fn health_check(&self) -> anyhow::Result<()> {
552 Ok(())
553 }
554}
555
556#[async_trait]
557impl MessageStore for InMemoryMessageStore {
558 async fn create_thread(&self, thread: Thread) -> anyhow::Result<ThreadId> {
559 let id = thread.id.clone();
560 let mut guard = self.threads.write().expect("lock poisoned");
561 guard.insert(id.to_string(), thread);
562 Ok(id)
563 }
564
565 async fn get_thread(&self, thread_id: &ThreadId) -> anyhow::Result<Option<Thread>> {
566 let guard = self.threads.read().expect("lock poisoned");
567 Ok(guard.get(&thread_id.to_string()).cloned())
568 }
569
570 async fn update_thread(&self, thread: Thread) -> anyhow::Result<()> {
571 let mut guard = self.threads.write().expect("lock poisoned");
572 guard.insert(thread.id.to_string(), thread);
573 Ok(())
574 }
575
576 async fn delete_thread(&self, thread_id: &ThreadId) -> anyhow::Result<()> {
577 let mut guard = self.threads.write().expect("lock poisoned");
578 if let Some(thread) = guard.get_mut(&thread_id.to_string()) {
579 thread.deleted_at = Some(Utc::now());
580 }
581 Ok(())
582 }
583
584 async fn list_threads(
585 &self,
586 tenant_id: &TenantId,
587 user_id: &UserId,
588 limit: usize,
589 offset: usize,
590 ) -> anyhow::Result<Vec<Thread>> {
591 let guard = self.threads.read().expect("lock poisoned");
592 let mut threads: Vec<_> = guard
593 .values()
594 .filter(|t| {
595 t.tenant_id == *tenant_id && t.user_id == *user_id && t.deleted_at.is_none()
596 })
597 .cloned()
598 .collect();
599
600 threads.sort_by(|a, b| b.created_at.cmp(&a.created_at));
602
603 Ok(threads.into_iter().skip(offset).take(limit).collect())
604 }
605
606 async fn create_message(&self, message: Message) -> anyhow::Result<MessageId> {
607 let id = message.id.clone();
608 let thread_id = message.thread_id.clone();
609
610 {
612 let mut thread_guard = self.threads.write().expect("lock poisoned");
613 if let Some(thread) = thread_guard.get_mut(&thread_id.to_string()) {
614 thread.updated_at = Utc::now();
615 }
616 }
617
618 let mut guard = self.messages.write().expect("lock poisoned");
619 guard.insert(id.to_string(), message);
620 Ok(id)
621 }
622
623 async fn get_message(&self, message_id: &MessageId) -> anyhow::Result<Option<Message>> {
624 let guard = self.messages.read().expect("lock poisoned");
625 Ok(guard.get(&message_id.to_string()).cloned())
626 }
627
628 async fn update_message(&self, mut message: Message) -> anyhow::Result<()> {
629 message.updated_at = Some(Utc::now());
630 let mut guard = self.messages.write().expect("lock poisoned");
631 guard.insert(message.id.to_string(), message);
632 Ok(())
633 }
634
635 async fn delete_message(&self, message_id: &MessageId) -> anyhow::Result<()> {
636 let mut guard = self.messages.write().expect("lock poisoned");
637 if let Some(message) = guard.get_mut(&message_id.to_string()) {
638 message.deleted_at = Some(Utc::now());
639 }
640 Ok(())
641 }
642
643 async fn list_messages(
644 &self,
645 thread_id: &ThreadId,
646 include_deleted: bool,
647 ) -> anyhow::Result<Vec<Message>> {
648 let guard = self.messages.read().expect("lock poisoned");
649 let mut messages: Vec<_> = guard
650 .values()
651 .filter(|m| m.thread_id == *thread_id && (include_deleted || m.deleted_at.is_none()))
652 .cloned()
653 .collect();
654
655 messages.sort_by(|a, b| a.created_at.cmp(&b.created_at));
657
658 Ok(messages)
659 }
660
661 async fn get_messages_by_execution(
662 &self,
663 execution_id: &ExecutionId,
664 ) -> anyhow::Result<Vec<Message>> {
665 let guard = self.messages.read().expect("lock poisoned");
666 let messages: Vec<_> = guard
667 .values()
668 .filter(|m| m.execution_id.as_ref() == Some(execution_id))
669 .cloned()
670 .collect();
671 Ok(messages)
672 }
673}
674
#[cfg(test)]
mod tests {
    use super::*;

    fn test_tenant() -> TenantId {
        TenantId::from_string("tenant_test")
    }

    fn test_user() -> UserId {
        UserId::from_string("user_test")
    }

    /// Creates an empty store containing one fresh thread and returns both.
    async fn store_with_thread() -> (InMemoryMessageStore, ThreadId) {
        let store = InMemoryMessageStore::new();
        let thread = Thread::new(test_tenant(), test_user());
        let thread_id = store.create_thread(thread).await.unwrap();
        (store, thread_id)
    }

    #[tokio::test]
    async fn test_create_thread() {
        let store = InMemoryMessageStore::new();
        let thread = Thread::new(test_tenant(), test_user());
        let expected_id = thread.id.clone();

        store.create_thread(thread).await.unwrap();

        let loaded = store
            .get_thread(&expected_id)
            .await
            .unwrap()
            .expect("thread should exist after creation");
        assert_eq!(loaded.id, expected_id);
    }

    #[tokio::test]
    async fn test_soft_delete_thread() {
        let (store, thread_id) = store_with_thread().await;

        store.delete_thread(&thread_id).await.unwrap();

        let loaded = store.get_thread(&thread_id).await.unwrap().unwrap();
        assert!(loaded.is_deleted());
    }

    #[tokio::test]
    async fn test_list_threads_excludes_deleted() {
        let store = InMemoryMessageStore::new();
        let (tenant, user) = (test_tenant(), test_user());

        let kept = Thread::new(tenant.clone(), user.clone());
        let doomed = Thread::new(tenant.clone(), user.clone());

        store.create_thread(kept).await.unwrap();
        let doomed_id = store.create_thread(doomed).await.unwrap();

        // Remove the second thread; only the first should remain visible.
        store.delete_thread(&doomed_id).await.unwrap();

        let listed = store.list_threads(&tenant, &user, 100, 0).await.unwrap();
        assert_eq!(listed.len(), 1);
    }

    #[tokio::test]
    async fn test_create_message() {
        let (store, thread_id) = store_with_thread().await;

        let message = Message::user(thread_id, "Hello!");
        let msg_id = store.create_message(message).await.unwrap();

        let loaded = store.get_message(&msg_id).await.unwrap();
        assert_eq!(loaded.map(|m| m.content).as_deref(), Some("Hello!"));
    }

    #[tokio::test]
    async fn test_message_parent_chain() {
        let (store, thread_id) = store_with_thread().await;

        // The user asks a question...
        let question = Message::user(thread_id.clone(), "What's the weather?");
        let question_id = store.create_message(question).await.unwrap();

        // ...and the assistant answers, pointing back at the question.
        let answer = Message::assistant(
            thread_id.clone(),
            ExecutionId::new(),
            "The weather is sunny.",
            Some(question_id.clone()),
        );
        store.create_message(answer).await.unwrap();

        let history = store.list_messages(&thread_id, false).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].role, MessageRole::User);
        assert_eq!(history[1].role, MessageRole::Assistant);
        assert_eq!(history[1].parent_id, Some(question_id));
    }

    #[tokio::test]
    async fn test_soft_delete_message() {
        let (store, thread_id) = store_with_thread().await;

        let msg_id = store
            .create_message(Message::user(thread_id.clone(), "Delete me"))
            .await
            .unwrap();

        store.delete_message(&msg_id).await.unwrap();

        // Hidden by default...
        let visible = store.list_messages(&thread_id, false).await.unwrap();
        assert!(visible.is_empty());

        // ...but still present, flagged as deleted, when explicitly requested.
        let everything = store.list_messages(&thread_id, true).await.unwrap();
        assert_eq!(everything.len(), 1);
        assert!(everything[0].is_deleted());
    }

    #[tokio::test]
    async fn test_get_messages_by_execution() {
        let (store, thread_id) = store_with_thread().await;

        let exec_id = ExecutionId::new();
        store
            .create_message(Message::assistant(thread_id, exec_id.clone(), "Response", None))
            .await
            .unwrap();

        let found = store.get_messages_by_execution(&exec_id).await.unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].execution_id, Some(exec_id));
    }

    #[tokio::test]
    async fn test_message_with_parts() {
        let (store, thread_id) = store_with_thread().await;

        let parts = vec![
            MessagePart::Reasoning {
                text: "Let me think...".to_string(),
            },
            MessagePart::ToolCall {
                tool_call_id: "tc_123".to_string(),
                tool_name: "get_weather".to_string(),
                args: serde_json::json!({"city": "NYC"}),
            },
            MessagePart::Text {
                text: "The weather is sunny.".to_string(),
            },
        ];

        let msg = Message::assistant(thread_id, ExecutionId::new(), "The weather is sunny.", None)
            .with_parts(parts);
        let msg_id = store.create_message(msg).await.unwrap();

        let loaded = store.get_message(&msg_id).await.unwrap().unwrap();
        assert_eq!(loaded.parts.len(), 3);
    }

    #[tokio::test]
    async fn test_message_metadata() {
        let (store, thread_id) = store_with_thread().await;

        let metadata = MessageMetadata {
            model: Some("gpt-4o".to_string()),
            provider: Some("azure".to_string()),
            duration_ms: Some(1500),
            token_usage: Some(TokenUsage::new(100, 200)),
            stats: Some(ExecutionStats {
                llm_calls: 2,
                tool_calls: 1,
                sub_agents: 0,
                steps: 3,
                decisions: 1,
                artifacts: 0,
            }),
            finish_reason: Some(FinishReason::Stop),
            ..Default::default()
        };

        let msg = Message::assistant(thread_id, ExecutionId::new(), "Response", None)
            .with_metadata(metadata);
        let msg_id = store.create_message(msg).await.unwrap();

        let loaded = store.get_message(&msg_id).await.unwrap().unwrap();
        assert_eq!(loaded.metadata.model.as_deref(), Some("gpt-4o"));
        assert_eq!(loaded.metadata.stats.as_ref().map(|s| s.llm_calls), Some(2));
    }

    #[tokio::test]
    async fn test_thread_updated_on_message() {
        let store = InMemoryMessageStore::new();
        let thread = Thread::new(test_tenant(), test_user());
        let before = thread.updated_at;
        let thread_id = store.create_thread(thread).await.unwrap();

        // Let the clock advance so the bump is observable.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        store
            .create_message(Message::user(thread_id.clone(), "Hello!"))
            .await
            .unwrap();

        let after = store
            .get_thread(&thread_id)
            .await
            .unwrap()
            .unwrap()
            .updated_at;
        assert!(after > before);
    }
}