use anyhow::{Context, Result};
use chrono::Utc;
use std::sync::Arc;
use crate::databases::{
FieldDef, FieldType, FieldValue, Filter, Record, StorageBackend, record_get,
};
/// Name of the backend table in which conversation metadata records are stored.
const TABLE_NAME: &str = "conversations";
/// Metadata describing a single conversation, as persisted in the
/// `conversations` table.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ConversationMetadata {
/// Identifier used to look the conversation up; stored as a required column.
pub conversation_id: String,
/// Optional title; stored as a nullable column.
pub title: Option<String>,
/// Optional model identifier; stored as a nullable column.
pub model_id: Option<String>,
/// Creation time, taken from `Utc::now().timestamp()` when the record is first created.
pub created_at: i64,
/// Last-modification time, refreshed from `Utc::now().timestamp()` on every update.
pub updated_at: i64,
/// Caller-supplied message counter; set to 0 on creation when no value is given.
pub message_count: i32,
}
/// Builds the column definitions for the conversations table.
///
/// The identifier, both timestamps and the message counter are required
/// columns; `title` and `model_id` are optional.
fn table_schema() -> Vec<FieldDef> {
    let conversation_id = FieldDef::required("conversation_id", FieldType::Utf8);
    let title = FieldDef::optional("title", FieldType::Utf8);
    let model_id = FieldDef::optional("model_id", FieldType::Utf8);
    let created_at = FieldDef::required("created_at", FieldType::Int64);
    let updated_at = FieldDef::required("updated_at", FieldType::Int64);
    let message_count = FieldDef::required("message_count", FieldType::Int32);

    // Column order matches the order used when records are built in `to_record`.
    vec![
        conversation_id,
        title,
        model_id,
        created_at,
        updated_at,
        message_count,
    ]
}
/// Converts conversation metadata into a backend record.
///
/// Each struct field becomes one `(column name, value)` pair; optional
/// fields are passed through as-is, so `None` is stored as a null value.
fn to_record(m: &ConversationMetadata) -> Record {
    let conversation_id = FieldValue::Utf8(Some(m.conversation_id.clone()));
    let title = FieldValue::Utf8(m.title.clone());
    let model_id = FieldValue::Utf8(m.model_id.clone());
    let created_at = FieldValue::Int64(Some(m.created_at));
    let updated_at = FieldValue::Int64(Some(m.updated_at));
    let message_count = FieldValue::Int32(Some(m.message_count));

    vec![
        ("conversation_id".into(), conversation_id),
        ("title".into(), title),
        ("model_id".into(), model_id),
        ("created_at".into(), created_at),
        ("updated_at".into(), updated_at),
        ("message_count".into(), message_count),
    ]
}
/// Rebuilds conversation metadata from a backend record.
///
/// # Errors
///
/// Returns an error naming the first required column (`conversation_id`,
/// `created_at`, `updated_at`, `message_count`) that is absent or cannot be
/// read as the expected type. `title` and `model_id` simply become `None`
/// when they cannot be read as strings.
fn from_record(r: &Record) -> Result<ConversationMetadata> {
    // Fields are extracted in declaration order so the first failing column
    // determines the reported error.
    let conversation_id = record_get(r, "conversation_id")
        .and_then(|value| value.as_str())
        .context("missing conversation_id")?
        .to_string();

    let title = record_get(r, "title")
        .and_then(|value| value.as_str())
        .map(|text| text.to_string());

    let model_id = record_get(r, "model_id")
        .and_then(|value| value.as_str())
        .map(|text| text.to_string());

    let created_at = record_get(r, "created_at")
        .and_then(|value| value.as_i64())
        .context("missing created_at")?;

    let updated_at = record_get(r, "updated_at")
        .and_then(|value| value.as_i64())
        .context("missing updated_at")?;

    let message_count = record_get(r, "message_count")
        .and_then(|value| value.as_i32())
        .context("missing message_count")?;

    Ok(ConversationMetadata {
        conversation_id,
        title,
        model_id,
        created_at,
        updated_at,
        message_count,
    })
}
/// Store for conversation metadata, layered on top of a [`StorageBackend`].
///
/// The backend type parameter defaults to the Lance database implementation.
pub struct ConversationStore<B: StorageBackend = crate::databases::lance::LanceDatabase> {
/// Shared handle to the backend that every operation is delegated to.
backend: Arc<B>,
}
impl<B: StorageBackend> ConversationStore<B> {
pub fn new(backend: Arc<B>) -> Self {
Self { backend }
}
pub async fn ensure_table(&self) -> Result<()> {
self.backend.ensure_table(TABLE_NAME, &table_schema()).await
}
pub async fn create(
&self,
conversation_id: String,
title: Option<String>,
model_id: Option<String>,
message_count: Option<i32>,
) -> Result<ConversationMetadata> {
if let Ok(Some(existing)) = self.get(&conversation_id).await {
self.update(
&conversation_id,
title.or(existing.title.clone()),
message_count,
)
.await?;
return self
.get(&conversation_id)
.await?
.context("Conversation should exist after update");
}
let now = Utc::now().timestamp();
let metadata = ConversationMetadata {
conversation_id,
title,
model_id,
created_at: now,
updated_at: now,
message_count: message_count.unwrap_or(0),
};
self.backend
.insert(TABLE_NAME, vec![to_record(&metadata)])
.await
.context("Failed to create conversation")?;
Ok(metadata)
}
pub async fn get(&self, conversation_id: &str) -> Result<Option<ConversationMetadata>> {
let filter = Filter::Eq(
"conversation_id".into(),
FieldValue::Utf8(Some(conversation_id.to_string())),
);
let records = self
.backend
.query(TABLE_NAME, Some(&filter), Some(1))
.await?;
match records.first() {
Some(r) => Ok(Some(from_record(r)?)),
None => Ok(None),
}
}
pub async fn list(&self, limit: Option<usize>) -> Result<Vec<ConversationMetadata>> {
let records = self.backend.query(TABLE_NAME, None, None).await?;
let mut conversations: Vec<ConversationMetadata> =
records.iter().filter_map(|r| from_record(r).ok()).collect();
conversations.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
if let Some(limit) = limit {
conversations.truncate(limit);
}
Ok(conversations)
}
pub async fn update(
&self,
conversation_id: &str,
title: Option<String>,
message_count: Option<i32>,
) -> Result<()> {
let current = self
.get(conversation_id)
.await?
.context("Conversation not found")?;
let filter = Filter::Eq(
"conversation_id".into(),
FieldValue::Utf8(Some(conversation_id.to_string())),
);
self.backend.delete(TABLE_NAME, &filter).await?;
let updated = ConversationMetadata {
conversation_id: conversation_id.to_string(),
title: title.or(current.title),
model_id: current.model_id,
created_at: current.created_at,
updated_at: Utc::now().timestamp(),
message_count: message_count.unwrap_or(current.message_count),
};
self.backend
.insert(TABLE_NAME, vec![to_record(&updated)])
.await?;
Ok(())
}
pub async fn delete(&self, conversation_id: &str) -> Result<()> {
let filter = Filter::Eq(
"conversation_id".into(),
FieldValue::Utf8(Some(conversation_id.to_string())),
);
self.backend.delete(TABLE_NAME, &filter).await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::databases::lance::LanceDatabase;
    use tempfile::TempDir;

    /// Opens a Lance database inside a fresh temporary directory and returns
    /// a store whose table has already been ensured.
    ///
    /// The `TempDir` is returned alongside the store so callers can keep it
    /// alive for the duration of the test.
    async fn setup() -> (TempDir, ConversationStore<LanceDatabase>) {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.lance");
        let database = LanceDatabase::new(path.to_str().unwrap()).await.unwrap();
        let store = ConversationStore::new(Arc::new(database));
        store.ensure_table().await.unwrap();
        (dir, store)
    }

    /// A new conversation echoes its inputs and starts with zero messages.
    #[tokio::test]
    async fn test_create_conversation() {
        let (_dir, store) = setup().await;

        let created = store
            .create(
                String::from("test-conv-1"),
                Some(String::from("Test Conversation")),
                Some(String::from("gpt-4")),
                None,
            )
            .await
            .unwrap();

        assert_eq!(created.conversation_id, "test-conv-1");
        assert_eq!(created.title.as_deref(), Some("Test Conversation"));
        assert_eq!(created.model_id.as_deref(), Some("gpt-4"));
        assert_eq!(created.message_count, 0);
    }

    /// An explicit message count is stored instead of the default.
    #[tokio::test]
    async fn test_create_conversation_with_message_count() {
        let (_dir, store) = setup().await;

        let created = store
            .create(
                String::from("test-conv-1"),
                Some(String::from("Test Conversation")),
                Some(String::from("gpt-4")),
                Some(5),
            )
            .await
            .unwrap();

        assert_eq!(created.conversation_id, "test-conv-1");
        assert_eq!(created.message_count, 5);
    }

    /// A created conversation can be fetched back by id.
    #[tokio::test]
    async fn test_get_conversation() {
        let (_dir, store) = setup().await;
        store
            .create(
                String::from("test-conv-2"),
                Some(String::from("Test")),
                None,
                None,
            )
            .await
            .unwrap();

        let fetched = store.get("test-conv-2").await.unwrap();

        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().conversation_id, "test-conv-2");
    }

    /// Looking up an unknown id yields `None` rather than an error.
    #[tokio::test]
    async fn test_get_nonexistent() {
        let (_dir, store) = setup().await;

        let fetched = store.get("nonexistent").await.unwrap();

        assert!(fetched.is_none());
    }

    /// Listing without a limit returns every created conversation.
    #[tokio::test]
    async fn test_list_conversations() {
        let (_dir, store) = setup().await;
        for n in 1..=3 {
            store
                .create(
                    format!("conv-{}", n),
                    Some(format!("Conv {}", n)),
                    None,
                    None,
                )
                .await
                .unwrap();
        }

        let listed = store.list(None).await.unwrap();

        assert_eq!(listed.len(), 3);
    }

    /// Updating replaces both the title and the message count.
    #[tokio::test]
    async fn test_update_conversation() {
        let (_dir, store) = setup().await;
        store
            .create(
                String::from("conv-update"),
                Some(String::from("Original")),
                None,
                None,
            )
            .await
            .unwrap();

        store
            .update("conv-update", Some(String::from("Updated")), Some(5))
            .await
            .unwrap();

        let fetched = store.get("conv-update").await.unwrap().unwrap();
        assert_eq!(fetched.title.as_deref(), Some("Updated"));
        assert_eq!(fetched.message_count, 5);
    }

    /// A deleted conversation can no longer be fetched.
    #[tokio::test]
    async fn test_delete_conversation() {
        let (_dir, store) = setup().await;
        store
            .create(String::from("conv-delete"), None, None, None)
            .await
            .unwrap();

        let before = store.get("conv-delete").await.unwrap();
        assert!(before.is_some());

        store.delete("conv-delete").await.unwrap();

        let after = store.get("conv-delete").await.unwrap();
        assert!(after.is_none());
    }
}