use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::RwLock;
use crate::kernel::{ExecutionId, MessageId, ParentType, TenantId, ThreadId, UserId};
use super::StorageBackend;
/// A conversation thread owned by one user within one tenant.
///
/// Deletion is tracked through `deleted_at` (a soft delete) rather than by removing
/// the value; use [`Thread::is_deleted`] to test it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Thread {
    /// Unique identifier of the thread.
    pub id: ThreadId,
    /// Tenant the thread belongs to.
    pub tenant_id: TenantId,
    /// User that owns the thread.
    pub user_id: UserId,
    /// Optional human-readable title; `None` when the thread is first constructed.
    pub title: Option<String>,
    /// UTC instant at which the thread value was constructed.
    pub created_at: DateTime<Utc>,
    /// Last-modified UTC instant; equal to `created_at` at construction.
    pub updated_at: DateTime<Utc>,
    /// Soft-deletion instant; `None` while the thread is live.
    pub deleted_at: Option<DateTime<Utc>>,
}
impl Thread {
pub fn new(tenant_id: TenantId, user_id: UserId) -> Self {
let now = Utc::now();
Self {
id: ThreadId::new(),
tenant_id,
user_id,
title: None,
created_at: now,
updated_at: now,
deleted_at: None,
}
}
pub fn with_id(id: ThreadId, tenant_id: TenantId, user_id: UserId) -> Self {
let now = Utc::now();
Self {
id,
tenant_id,
user_id,
title: None,
created_at: now,
updated_at: now,
deleted_at: None,
}
}
pub fn is_deleted(&self) -> bool {
self.deleted_at.is_some()
}
}
/// Author role of a [`Message`].
///
/// Serialized in lowercase: `"user"`, `"assistant"`, `"system"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// Message written by the end user.
    User,
    /// Message produced by the assistant (constructed with an execution id).
    Assistant,
    /// System-level message.
    System,
}
impl std::fmt::Display for MessageRole {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MessageRole::User => write!(f, "user"),
MessageRole::Assistant => write!(f, "assistant"),
MessageRole::System => write!(f, "system"),
}
}
}
/// One structured piece of a [`Message`]'s content.
///
/// Serialized as an internally tagged object: the variant name goes in a `"type"`
/// field in kebab-case (`"text"`, `"reasoning"`, `"tool-call"`, `"tool-result"`,
/// `"source"`, `"file"`, `"image"`, `"code"`). The variant fields keep their
/// snake_case names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum MessagePart {
    /// Plain text content.
    Text { text: String },
    /// Reasoning text emitted alongside the answer.
    Reasoning { text: String },
    /// A tool invocation request.
    ToolCall {
        /// Identifier linking this call to its `ToolResult`.
        tool_call_id: String,
        /// Name of the tool being invoked.
        tool_name: String,
        /// Arbitrary JSON arguments passed to the tool.
        args: serde_json::Value,
    },
    /// The outcome of a tool invocation.
    ToolResult {
        /// Identifier of the `ToolCall` this result answers.
        tool_call_id: String,
        /// Name of the tool that produced the result.
        tool_name: String,
        /// Arbitrary JSON result payload.
        result: serde_json::Value,
        /// Whether the tool reported an error.
        is_error: bool,
    },
    /// A cited source.
    Source {
        source_id: String,
        url: Option<String>,
        title: Option<String>,
    },
    /// An attached file, described by metadata only (no contents are stored here).
    File {
        file_id: String,
        filename: String,
        mime_type: String,
        /// File size in bytes.
        size_bytes: u64,
    },
    /// An image reference.
    Image {
        image_id: String,
        url: Option<String>,
        alt_text: Option<String>,
    },
    /// A code snippet with an optional language tag.
    Code {
        language: Option<String>,
        code: String,
    },
}
/// Token counts recorded for a model response.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt/input.
    pub prompt_tokens: u32,
    /// Tokens produced in the completion/output.
    pub completion_tokens: u32,
    /// Combined count; `TokenUsage::new` derives it from the two fields above.
    pub total_tokens: u32,
}
impl TokenUsage {
    /// Builds a usage record from prompt and completion token counts.
    ///
    /// `total_tokens` is the sum of the two. The addition saturates at `u32::MAX`
    /// instead of overflowing. A plain `+` would panic in debug builds and silently
    /// wrap to a tiny total in release builds when given extreme inputs.
    pub fn new(prompt: u32, completion: u32) -> Self {
        Self {
            prompt_tokens: prompt,
            completion_tokens: completion,
            total_tokens: prompt.saturating_add(completion),
        }
    }
}
/// Aggregate counters describing one execution; all fields default to zero.
///
/// NOTE(review): field meanings below are inferred from their names; confirm them
/// against whatever code populates this struct.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ExecutionStats {
    /// Presumably the number of LLM calls made.
    pub llm_calls: u32,
    /// Presumably the number of tool invocations.
    pub tool_calls: u32,
    /// Presumably the number of sub-agents spawned.
    pub sub_agents: u32,
    /// Presumably the number of execution steps.
    pub steps: u32,
    /// Presumably the number of decisions taken.
    pub decisions: u32,
    /// Presumably the number of artifacts produced.
    pub artifacts: u32,
}
/// Monetary cost of a response, split into input and output parts.
///
/// The `Default` value is zero cost with currency `"USD"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostInfo {
    /// Cost attributed to input; presumably expressed in `currency` units (confirm).
    pub input_cost: f64,
    /// Cost attributed to output; presumably expressed in `currency` units (confirm).
    pub output_cost: f64,
    /// Total cost. NOTE(review): nothing here enforces `input_cost + output_cost`.
    pub total_cost: f64,
    /// Currency code, e.g. `"USD"`.
    pub currency: String,
}
impl Default for CostInfo {
    /// A zero-cost record denominated in US dollars.
    fn default() -> Self {
        let zero = 0.0;
        Self {
            currency: String::from("USD"),
            input_cost: zero,
            output_cost: zero,
            total_cost: zero,
        }
    }
}
/// Why generation of a response stopped.
///
/// Serialized in snake_case: `"stop"`, `"length"`, `"tool_calls"`,
/// `"content_filter"`, `"error"`. The per-variant meanings below are inferred from
/// the names; confirm them against the provider mapping.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    /// Presumably a natural stop.
    Stop,
    /// Presumably a length/token limit was hit.
    Length,
    /// Presumably stopped in order to issue tool calls.
    ToolCalls,
    /// Presumably stopped by a content filter.
    ContentFilter,
    /// Generation ended with an error.
    Error,
}
/// Optional execution details attached to a [`Message`]; every field defaults to `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MessageMetadata {
    /// Completion time as an integer; presumably a Unix timestamp (unit unconfirmed).
    pub completed_at: Option<i64>,
    /// Duration in milliseconds (per the field name).
    pub duration_ms: Option<u64>,
    /// Model identifier, e.g. `"gpt-4o"`.
    pub model: Option<String>,
    /// Provider identifier, e.g. `"azure"`.
    pub provider: Option<String>,
    /// Token counts for the response.
    pub token_usage: Option<TokenUsage>,
    /// Execution counters.
    pub stats: Option<ExecutionStats>,
    /// Why generation stopped.
    pub finish_reason: Option<FinishReason>,
    /// Monetary cost of the response.
    pub cost: Option<CostInfo>,
}
/// A single message within a [`Thread`].
///
/// Deletion is tracked through `deleted_at` (a soft delete); use
/// [`Message::is_deleted`] to test it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Unique identifier of the message.
    pub id: MessageId,
    /// Thread this message belongs to.
    pub thread_id: ThreadId,
    /// Execution that produced the message. `Message::assistant` sets it; the user and
    /// system constructors leave it `None`.
    pub execution_id: Option<ExecutionId>,
    /// Message this one responds to, if any.
    pub parent_id: Option<MessageId>,
    /// Kind of parent referenced by `parent_id`.
    pub parent_type: ParentType,
    /// Author role.
    pub role: MessageRole,
    /// Plain-text content. The built-in constructors also copy it into a single
    /// `MessagePart::Text` in `parts`.
    pub content: String,
    /// Structured content parts. Replacing them with `with_parts` does not touch `content`.
    pub parts: Vec<MessagePart>,
    /// UTC instant at which the message value was constructed.
    pub created_at: DateTime<Utc>,
    /// Last-update instant; `None` until the message is updated.
    pub updated_at: Option<DateTime<Utc>>,
    /// Soft-deletion instant; `None` while the message is live.
    pub deleted_at: Option<DateTime<Utc>>,
    /// Optional execution details (model, usage, cost, ...).
    pub metadata: MessageMetadata,
}
impl Message {
pub fn user(thread_id: ThreadId, content: impl Into<String>) -> Self {
let content = content.into();
Self {
id: MessageId::new(),
thread_id,
execution_id: None,
parent_id: None,
parent_type: ParentType::UserMessage,
role: MessageRole::User,
content: content.clone(),
parts: vec![MessagePart::Text { text: content }],
created_at: Utc::now(),
updated_at: None,
deleted_at: None,
metadata: MessageMetadata::default(),
}
}
pub fn assistant(
thread_id: ThreadId,
execution_id: ExecutionId,
content: impl Into<String>,
parent_id: Option<MessageId>,
) -> Self {
let content = content.into();
Self {
id: MessageId::new(),
thread_id,
execution_id: Some(execution_id),
parent_id,
parent_type: ParentType::UserMessage,
role: MessageRole::Assistant,
content: content.clone(),
parts: vec![MessagePart::Text { text: content }],
created_at: Utc::now(),
updated_at: None,
deleted_at: None,
metadata: MessageMetadata::default(),
}
}
pub fn system(thread_id: ThreadId, content: impl Into<String>) -> Self {
let content = content.into();
Self {
id: MessageId::new(),
thread_id,
execution_id: None,
parent_id: None,
parent_type: ParentType::System,
role: MessageRole::System,
content: content.clone(),
parts: vec![MessagePart::Text { text: content }],
created_at: Utc::now(),
updated_at: None,
deleted_at: None,
metadata: MessageMetadata::default(),
}
}
pub fn is_deleted(&self) -> bool {
self.deleted_at.is_some()
}
pub fn with_parent(mut self, parent_id: MessageId, parent_type: ParentType) -> Self {
self.parent_id = Some(parent_id);
self.parent_type = parent_type;
self
}
pub fn with_parts(mut self, parts: Vec<MessagePart>) -> Self {
self.parts = parts;
self
}
pub fn with_metadata(mut self, metadata: MessageMetadata) -> Self {
self.metadata = metadata;
self
}
}
/// Asynchronous persistence interface for conversation threads and their messages.
///
/// Implementors must also be a [`StorageBackend`]. Whether deletes are hard or soft
/// is up to the implementation. NOTE(review): `InMemoryMessageStore` soft-deletes by
/// stamping `deleted_at`; confirm other backends follow suit.
#[async_trait]
pub trait MessageStore: StorageBackend {
    /// Persists `thread` and returns its id.
    async fn create_thread(&self, thread: Thread) -> anyhow::Result<ThreadId>;

    /// Fetches a thread by id, yielding `None` when it is not found.
    async fn get_thread(&self, thread_id: &ThreadId) -> anyhow::Result<Option<Thread>>;

    /// Stores the supplied thread state.
    async fn update_thread(&self, thread: Thread) -> anyhow::Result<()>;

    /// Deletes the thread with the given id.
    async fn delete_thread(&self, thread_id: &ThreadId) -> anyhow::Result<()>;

    /// Lists threads for a tenant/user pair, paginated by `limit` and `offset`.
    async fn list_threads(
        &self,
        tenant_id: &TenantId,
        user_id: &UserId,
        limit: usize,
        offset: usize,
    ) -> anyhow::Result<Vec<Thread>>;

    /// Persists `message` and returns its id.
    async fn create_message(&self, message: Message) -> anyhow::Result<MessageId>;

    /// Fetches a message by id, yielding `None` when it is not found.
    async fn get_message(&self, message_id: &MessageId) -> anyhow::Result<Option<Message>>;

    /// Stores the supplied message state.
    async fn update_message(&self, message: Message) -> anyhow::Result<()>;

    /// Deletes the message with the given id.
    async fn delete_message(&self, message_id: &MessageId) -> anyhow::Result<()>;

    /// Lists the messages of a thread; deleted ones are included only when
    /// `include_deleted` is `true`.
    async fn list_messages(
        &self,
        thread_id: &ThreadId,
        include_deleted: bool,
    ) -> anyhow::Result<Vec<Message>>;

    /// Returns every message tagged with `execution_id`.
    async fn get_messages_by_execution(
        &self,
        execution_id: &ExecutionId,
    ) -> anyhow::Result<Vec<Message>>;

    /// Default implementation: always builds a brand-new thread, persists it with
    /// `create_thread`, and returns it. It never looks up an existing thread.
    ///
    /// NOTE(review): the name suggests a lookup first; confirm this is intended.
    async fn get_or_create_thread(
        &self,
        tenant_id: TenantId,
        user_id: UserId,
    ) -> anyhow::Result<Thread> {
        let fresh = Thread::new(tenant_id, user_id);
        self.create_thread(fresh.clone()).await.map(|_| fresh)
    }

    /// Default implementation: counts the non-deleted messages of a thread by loading
    /// them all through `list_messages`. Backends may override it with a cheaper query.
    async fn count_messages(&self, thread_id: &ThreadId) -> anyhow::Result<u64> {
        self.list_messages(thread_id, false)
            .await
            .map(|visible| visible.len() as u64)
    }
}
/// Process-local [`MessageStore`] intended for tests and development.
///
/// Threads and messages are kept in `HashMap`s keyed by the string form of their ids,
/// each behind a `std::sync::RwLock`. Nothing is persisted beyond the process.
#[derive(Default)]
pub struct InMemoryMessageStore {
    /// Threads keyed by `ThreadId::to_string()`.
    threads: RwLock<HashMap<String, Thread>>,
    /// Messages keyed by `MessageId::to_string()`.
    messages: RwLock<HashMap<String, Message>>,
}
impl InMemoryMessageStore {
pub fn new() -> Self {
Self {
threads: RwLock::new(HashMap::new()),
messages: RwLock::new(HashMap::new()),
}
}
pub fn shared() -> std::sync::Arc<Self> {
std::sync::Arc::new(Self::new())
}
}
/// Backend identity for the in-memory store.
#[async_trait]
impl StorageBackend for InMemoryMessageStore {
    /// Always succeeds; there is no external resource to probe.
    async fn health_check(&self) -> anyhow::Result<()> {
        Ok(())
    }

    /// Fixed identifier for this backend.
    fn name(&self) -> &str {
        "in-memory-message-store"
    }

    /// All data lives in process memory, so no network access is needed.
    fn requires_network(&self) -> bool {
        false
    }
}
#[async_trait]
impl MessageStore for InMemoryMessageStore {
async fn create_thread(&self, thread: Thread) -> anyhow::Result<ThreadId> {
let id = thread.id.clone();
let mut guard = self.threads.write().expect("lock poisoned");
guard.insert(id.to_string(), thread);
Ok(id)
}
async fn get_thread(&self, thread_id: &ThreadId) -> anyhow::Result<Option<Thread>> {
let guard = self.threads.read().expect("lock poisoned");
Ok(guard.get(&thread_id.to_string()).cloned())
}
async fn update_thread(&self, thread: Thread) -> anyhow::Result<()> {
let mut guard = self.threads.write().expect("lock poisoned");
guard.insert(thread.id.to_string(), thread);
Ok(())
}
async fn delete_thread(&self, thread_id: &ThreadId) -> anyhow::Result<()> {
let mut guard = self.threads.write().expect("lock poisoned");
if let Some(thread) = guard.get_mut(&thread_id.to_string()) {
thread.deleted_at = Some(Utc::now());
}
Ok(())
}
async fn list_threads(
&self,
tenant_id: &TenantId,
user_id: &UserId,
limit: usize,
offset: usize,
) -> anyhow::Result<Vec<Thread>> {
let guard = self.threads.read().expect("lock poisoned");
let mut threads: Vec<_> = guard
.values()
.filter(|t| {
t.tenant_id == *tenant_id && t.user_id == *user_id && t.deleted_at.is_none()
})
.cloned()
.collect();
threads.sort_by(|a, b| b.created_at.cmp(&a.created_at));
Ok(threads.into_iter().skip(offset).take(limit).collect())
}
async fn create_message(&self, message: Message) -> anyhow::Result<MessageId> {
let id = message.id.clone();
let thread_id = message.thread_id.clone();
{
let mut thread_guard = self.threads.write().expect("lock poisoned");
if let Some(thread) = thread_guard.get_mut(&thread_id.to_string()) {
thread.updated_at = Utc::now();
}
}
let mut guard = self.messages.write().expect("lock poisoned");
guard.insert(id.to_string(), message);
Ok(id)
}
async fn get_message(&self, message_id: &MessageId) -> anyhow::Result<Option<Message>> {
let guard = self.messages.read().expect("lock poisoned");
Ok(guard.get(&message_id.to_string()).cloned())
}
async fn update_message(&self, mut message: Message) -> anyhow::Result<()> {
message.updated_at = Some(Utc::now());
let mut guard = self.messages.write().expect("lock poisoned");
guard.insert(message.id.to_string(), message);
Ok(())
}
async fn delete_message(&self, message_id: &MessageId) -> anyhow::Result<()> {
let mut guard = self.messages.write().expect("lock poisoned");
if let Some(message) = guard.get_mut(&message_id.to_string()) {
message.deleted_at = Some(Utc::now());
}
Ok(())
}
async fn list_messages(
&self,
thread_id: &ThreadId,
include_deleted: bool,
) -> anyhow::Result<Vec<Message>> {
let guard = self.messages.read().expect("lock poisoned");
let mut messages: Vec<_> = guard
.values()
.filter(|m| m.thread_id == *thread_id && (include_deleted || m.deleted_at.is_none()))
.cloned()
.collect();
messages.sort_by(|a, b| a.created_at.cmp(&b.created_at));
Ok(messages)
}
async fn get_messages_by_execution(
&self,
execution_id: &ExecutionId,
) -> anyhow::Result<Vec<Message>> {
let guard = self.messages.read().expect("lock poisoned");
let messages: Vec<_> = guard
.values()
.filter(|m| m.execution_id.as_ref() == Some(execution_id))
.cloned()
.collect();
Ok(messages)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    fn test_tenant() -> TenantId {
        TenantId::from_string("tenant_test")
    }

    fn test_user() -> UserId {
        UserId::from_string("user_test")
    }

    /// Returns a fresh store that already holds one thread owned by the test
    /// tenant/user, plus that thread's id.
    async fn store_with_thread() -> (InMemoryMessageStore, ThreadId) {
        let store = InMemoryMessageStore::new();
        let thread_id = store
            .create_thread(Thread::new(test_tenant(), test_user()))
            .await
            .unwrap();
        (store, thread_id)
    }

    #[tokio::test]
    async fn test_create_thread() {
        let store = InMemoryMessageStore::new();
        let thread = Thread::new(test_tenant(), test_user());
        let expected_id = thread.id.clone();
        store.create_thread(thread).await.unwrap();

        let fetched = store
            .get_thread(&expected_id)
            .await
            .unwrap()
            .expect("thread should exist after creation");
        assert_eq!(fetched.id, expected_id);
    }

    #[tokio::test]
    async fn test_soft_delete_thread() {
        let (store, thread_id) = store_with_thread().await;
        store.delete_thread(&thread_id).await.unwrap();

        let fetched = store.get_thread(&thread_id).await.unwrap().unwrap();
        assert!(fetched.is_deleted());
    }

    #[tokio::test]
    async fn test_list_threads_excludes_deleted() {
        let store = InMemoryMessageStore::new();
        let (tenant, user) = (test_tenant(), test_user());
        let kept = Thread::new(tenant.clone(), user.clone());
        let removed = Thread::new(tenant.clone(), user.clone());
        let removed_id = removed.id.clone();

        store.create_thread(kept).await.unwrap();
        store.create_thread(removed).await.unwrap();
        store.delete_thread(&removed_id).await.unwrap();

        let visible = store.list_threads(&tenant, &user, 100, 0).await.unwrap();
        assert_eq!(visible.len(), 1);
    }

    #[tokio::test]
    async fn test_create_message() {
        let (store, thread_id) = store_with_thread().await;
        let message = Message::user(thread_id, "Hello!");
        let message_id = message.id.clone();
        store.create_message(message).await.unwrap();

        let fetched = store
            .get_message(&message_id)
            .await
            .unwrap()
            .expect("message should exist after creation");
        assert_eq!(fetched.content, "Hello!");
    }

    #[tokio::test]
    async fn test_message_parent_chain() {
        let (store, thread_id) = store_with_thread().await;

        let question = Message::user(thread_id.clone(), "What's the weather?");
        let question_id = question.id.clone();
        store.create_message(question).await.unwrap();

        let answer = Message::assistant(
            thread_id.clone(),
            ExecutionId::new(),
            "The weather is sunny.",
            Some(question_id.clone()),
        );
        store.create_message(answer).await.unwrap();

        let history = store.list_messages(&thread_id, false).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].role, MessageRole::User);
        assert_eq!(history[1].role, MessageRole::Assistant);
        assert_eq!(history[1].parent_id, Some(question_id));
    }

    #[tokio::test]
    async fn test_soft_delete_message() {
        let (store, thread_id) = store_with_thread().await;
        let message = Message::user(thread_id.clone(), "Delete me");
        let message_id = message.id.clone();
        store.create_message(message).await.unwrap();
        store.delete_message(&message_id).await.unwrap();

        let live = store.list_messages(&thread_id, false).await.unwrap();
        assert!(live.is_empty());

        let everything = store.list_messages(&thread_id, true).await.unwrap();
        assert_eq!(everything.len(), 1);
        assert!(everything[0].is_deleted());
    }

    #[tokio::test]
    async fn test_get_messages_by_execution() {
        let (store, thread_id) = store_with_thread().await;
        let exec_id = ExecutionId::new();
        store
            .create_message(Message::assistant(thread_id, exec_id.clone(), "Response", None))
            .await
            .unwrap();

        let produced = store.get_messages_by_execution(&exec_id).await.unwrap();
        assert_eq!(produced.len(), 1);
        assert_eq!(produced[0].execution_id, Some(exec_id));
    }

    #[tokio::test]
    async fn test_message_with_parts() {
        let (store, thread_id) = store_with_thread().await;
        let parts = vec![
            MessagePart::Reasoning {
                text: "Let me think...".to_string(),
            },
            MessagePart::ToolCall {
                tool_call_id: "tc_123".to_string(),
                tool_name: "get_weather".to_string(),
                args: serde_json::json!({"city": "NYC"}),
            },
            MessagePart::Text {
                text: "The weather is sunny.".to_string(),
            },
        ];
        let message =
            Message::assistant(thread_id, ExecutionId::new(), "The weather is sunny.", None)
                .with_parts(parts);
        let message_id = message.id.clone();
        store.create_message(message).await.unwrap();

        let fetched = store.get_message(&message_id).await.unwrap().unwrap();
        assert_eq!(fetched.parts.len(), 3);
    }

    #[tokio::test]
    async fn test_message_metadata() {
        let (store, thread_id) = store_with_thread().await;
        let metadata = MessageMetadata {
            model: Some("gpt-4o".to_string()),
            provider: Some("azure".to_string()),
            duration_ms: Some(1500),
            token_usage: Some(TokenUsage::new(100, 200)),
            stats: Some(ExecutionStats {
                llm_calls: 2,
                tool_calls: 1,
                sub_agents: 0,
                steps: 3,
                decisions: 1,
                artifacts: 0,
            }),
            finish_reason: Some(FinishReason::Stop),
            ..Default::default()
        };
        let message = Message::assistant(thread_id, ExecutionId::new(), "Response", None)
            .with_metadata(metadata);
        let message_id = message.id.clone();
        store.create_message(message).await.unwrap();

        let fetched = store.get_message(&message_id).await.unwrap().unwrap();
        assert_eq!(fetched.metadata.model, Some("gpt-4o".to_string()));
        assert_eq!(fetched.metadata.stats.as_ref().unwrap().llm_calls, 2);
    }

    #[tokio::test]
    async fn test_thread_updated_on_message() {
        let store = InMemoryMessageStore::new();
        let thread = Thread::new(test_tenant(), test_user());
        let thread_id = thread.id.clone();
        let before = thread.updated_at;
        store.create_thread(thread).await.unwrap();

        // Let the clock advance so the bump to `updated_at` is observable.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        store
            .create_message(Message::user(thread_id.clone(), "Hello!"))
            .await
            .unwrap();

        let after = store.get_thread(&thread_id).await.unwrap().unwrap().updated_at;
        assert!(after > before);
    }
}