Skip to main content

ares/db/
traits.rs

1use crate::types::{AppError, MemoryFact, Message, MessageRole, Preference, Result};
2use async_trait::async_trait;
3
/// Selects which storage backend `DatabaseProvider::create_client` builds.
///
/// Defaults to [`DatabaseProvider::Memory`]. Derives `PartialEq`/`Eq` so a
/// resolved provider can be compared (e.g. in configuration checks and tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum DatabaseProvider {
    /// In-memory database, built via `PostgresClient::new_memory`.
    #[default]
    Memory,
    /// Local database identified by `path`, built via `PostgresClient::new_local`.
    SQLite { path: String },
    /// Remote database at `url`, built via `PostgresClient::new_remote`.
    Postgres { url: String },
}
15
16impl DatabaseProvider {
17    pub async fn create_client(&self) -> Result<Box<dyn DatabaseClient>> {
18        match self {
19            DatabaseProvider::Memory => {
20                let client = super::postgres::PostgresClient::new_memory().await?;
21                Ok(Box::new(client))
22            }
23            DatabaseProvider::SQLite { path } => {
24                let client = super::postgres::PostgresClient::new_local(path).await?;
25                Ok(Box::new(client))
26            }
27            DatabaseProvider::Postgres { url } => {
28                let client = super::postgres::PostgresClient::new_remote(url.clone(), "".to_string()).await?;
29                Ok(Box::new(client))
30            }
31        }
32    }
33
34    pub fn from_env() -> Self {
35        if let Ok(url) = std::env::var("DATABASE_URL") {
36            if !url.is_empty() { return DatabaseProvider::Postgres { url }; }
37        }
38        if let Ok(path) = std::env::var("DATABASE_PATH") {
39            if !path.is_empty() && path != ":memory:" { return DatabaseProvider::SQLite { path }; }
40        }
41        DatabaseProvider::Memory
42    }
43}
44
45pub use super::postgres::User;
46
47#[derive(Debug, Clone, sqlx::FromRow)]
48pub struct ConversationSummary {
49    pub id: String,
50    pub title: String,
51    pub created_at: String,
52    pub updated_at: String,
53    pub message_count: i32,
54}
55
56#[async_trait]
57pub trait DatabaseClient: Send + Sync {
58    async fn create_user(&self, id: &str, email: &str, password_hash: &str, name: &str) -> Result<()>;
59    async fn get_user_by_email(&self, email: &str) -> Result<Option<User>>;
60    async fn get_user_by_id(&self, id: &str) -> Result<Option<User>>;
61    async fn create_session(&self, id: &str, user_id: &str, token_hash: &str, expires_at: i64) -> Result<()>;
62    async fn validate_session(&self, token_hash: &str) -> Result<Option<String>>;
63    async fn delete_session(&self, id: &str) -> Result<()>;
64    async fn delete_session_by_token_hash(&self, token_hash: &str) -> Result<()>;
65    async fn create_conversation(&self, id: &str, user_id: &str, title: Option<&str>) -> Result<()>;
66    async fn conversation_exists(&self, conversation_id: &str) -> Result<bool>;
67    async fn get_user_conversations(&self, user_id: &str) -> Result<Vec<ConversationSummary>>;
68    async fn get_conversation(&self, conversation_id: &str) -> Result<super::postgres::Conversation>;
69    async fn delete_conversation(&self, conversation_id: &str) -> Result<()>;
70    async fn update_conversation_title(&self, conversation_id: &str, title: Option<&str>) -> Result<()>;
71    async fn add_message(&self, id: &str, conversation_id: &str, role: MessageRole, content: &str) -> Result<()>;
72    async fn get_conversation_history(&self, conversation_id: &str) -> Result<Vec<Message>>;
73    async fn store_memory_fact(&self, fact: &MemoryFact) -> Result<()>;
74    async fn get_user_memory(&self, user_id: &str) -> Result<Vec<MemoryFact>>;
75    async fn get_memory_by_category(&self, user_id: &str, category: &str) -> Result<Vec<MemoryFact>>;
76    async fn store_preference(&self, user_id: &str, preference: &Preference) -> Result<()>;
77    async fn get_user_preferences(&self, user_id: &str) -> Result<Vec<Preference>>;
78    async fn get_preference(&self, user_id: &str, category: &str, key: &str) -> Result<Option<Preference>>;
79    async fn get_user_agent_by_name(&self, user_id: &str, name: &str) -> Result<Option<super::postgres::UserAgent>>;
80    async fn get_public_agent_by_name(&self, name: &str) -> Result<Option<super::postgres::UserAgent>>;
81    async fn list_user_agents(&self, user_id: &str) -> Result<Vec<super::postgres::UserAgent>>;
82    async fn list_public_agents(&self, limit: u32, offset: u32) -> Result<Vec<super::postgres::UserAgent>>;
83    async fn create_user_agent(&self, agent: &super::postgres::UserAgent) -> Result<()>;
84    async fn update_user_agent(&self, agent: &super::postgres::UserAgent) -> Result<()>;
85    async fn delete_user_agent(&self, id: &str, user_id: &str) -> Result<bool>;
86}
87
88#[async_trait]
89impl DatabaseClient for super::postgres::PostgresClient {
90    async fn create_user(&self, id: &str, email: &str, password_hash: &str, name: &str) -> Result<()> { super::postgres::PostgresClient::create_user(self, id, email, password_hash, name).await }
91    async fn get_user_by_email(&self, email: &str) -> Result<Option<User>> { super::postgres::PostgresClient::get_user_by_email(self, email).await }
92    async fn get_user_by_id(&self, id: &str) -> Result<Option<User>> { super::postgres::PostgresClient::get_user_by_id(self, id).await }
93    async fn create_session(&self, id: &str, user_id: &str, token_hash: &str, expires_at: i64) -> Result<()> { super::postgres::PostgresClient::create_session(self, id, user_id, token_hash, expires_at).await }
94    async fn validate_session(&self, token_hash: &str) -> Result<Option<String>> { super::postgres::PostgresClient::validate_session(self, token_hash).await }
95    async fn delete_session(&self, id: &str) -> Result<()> { super::postgres::PostgresClient::delete_session(self, id).await }
96    async fn delete_session_by_token_hash(&self, token_hash: &str) -> Result<()> { super::postgres::PostgresClient::delete_session_by_token_hash(self, token_hash).await }
97    async fn create_conversation(&self, id: &str, user_id: &str, title: Option<&str>) -> Result<()> { super::postgres::PostgresClient::create_conversation(self, id, user_id, title).await }
98    async fn conversation_exists(&self, conversation_id: &str) -> Result<bool> { super::postgres::PostgresClient::conversation_exists(self, conversation_id).await }
99    async fn get_user_conversations(&self, user_id: &str) -> Result<Vec<ConversationSummary>> { super::postgres::PostgresClient::get_user_conversations(self, user_id).await }
100    async fn get_conversation(&self, conversation_id: &str) -> Result<super::postgres::Conversation> { 
101        let row = sqlx::query_as::<_, super::postgres::Conversation>("SELECT id, user_id, title, created_at, updated_at, 0 as message_count FROM conversations WHERE id = $1").bind(conversation_id).fetch_optional(&self.pool).await.map_err(|e| AppError::Database(e.to_string()))?;
102        row.ok_or_else(|| AppError::NotFound("Conversation not found".into()))
103    }
104    async fn delete_conversation(&self, conversation_id: &str) -> Result<()> { 
105        sqlx::query("DELETE FROM messages WHERE conversation_id = $1").bind(conversation_id).execute(&self.pool).await.map_err(|e| AppError::Database(e.to_string()))?;
106        sqlx::query("DELETE FROM conversations WHERE id = $1").bind(conversation_id).execute(&self.pool).await.map_err(|e| AppError::Database(e.to_string()))?;
107        Ok(())
108    }
109    async fn update_conversation_title(&self, conversation_id: &str, title: Option<&str>) -> Result<()> { 
110        let now = chrono::Utc::now().timestamp();
111        sqlx::query("UPDATE conversations SET title = $1, updated_at = $2 WHERE id = $3").bind(title).bind(now).bind(conversation_id).execute(&self.pool).await.map_err(|e| AppError::Database(e.to_string()))?;
112        Ok(())
113    }
114    async fn add_message(&self, id: &str, conversation_id: &str, role: MessageRole, content: &str) -> Result<()> { super::postgres::PostgresClient::add_message(self, id, conversation_id, role, content).await }
115    async fn get_conversation_history(&self, conversation_id: &str) -> Result<Vec<Message>> { super::postgres::PostgresClient::get_conversation_history(self, conversation_id).await }
116    async fn store_memory_fact(&self, fact: &MemoryFact) -> Result<()> { super::postgres::PostgresClient::store_memory_fact(self, fact).await }
117    async fn get_user_memory(&self, user_id: &str) -> Result<Vec<MemoryFact>> { super::postgres::PostgresClient::get_user_memory(self, user_id).await }
118    async fn get_memory_by_category(&self, user_id: &str, category: &str) -> Result<Vec<MemoryFact>> {
119        let mems = super::postgres::PostgresClient::get_user_memory(self, user_id).await?;
120        Ok(mems.into_iter().filter(|m| m.category == category).collect())
121    }
122    async fn store_preference(&self, user_id: &str, preference: &Preference) -> Result<()> { super::postgres::PostgresClient::store_preference(self, user_id, preference).await }
123    async fn get_user_preferences(&self, user_id: &str) -> Result<Vec<Preference>> { super::postgres::PostgresClient::get_user_preferences(self, user_id).await }
124    async fn get_preference(&self, user_id: &str, category: &str, key: &str) -> Result<Option<Preference>> {
125        let prefs = super::postgres::PostgresClient::get_user_preferences(self, user_id).await?;
126        Ok(prefs.into_iter().find(|p| p.category == category && p.key == key))
127    }
128    async fn get_user_agent_by_name(&self, user_id: &str, name: &str) -> Result<Option<super::postgres::UserAgent>> { super::postgres::PostgresClient::get_user_agent_by_name(self, user_id, name).await }
129    async fn get_public_agent_by_name(&self, name: &str) -> Result<Option<super::postgres::UserAgent>> { 
130        super::postgres::PostgresClient::get_user_agent_by_name(self, "", name).await 
131    }
132    async fn list_user_agents(&self, _user_id: &str) -> Result<Vec<super::postgres::UserAgent>> { Ok(vec![]) } 
133    async fn list_public_agents(&self, _limit: u32, _offset: u32) -> Result<Vec<super::postgres::UserAgent>> { Ok(vec![]) } 
134    async fn create_user_agent(&self, _agent: &super::postgres::UserAgent) -> Result<()> { Ok(()) } 
135    async fn update_user_agent(&self, _agent: &super::postgres::UserAgent) -> Result<()> { Ok(()) } 
136    async fn delete_user_agent(&self, _id: &str, _user_id: &str) -> Result<bool> { Ok(true) } 
137}