Skip to main content

ares/db/
traits.rs

1use crate::types::{AppError, MemoryFact, Message, MessageRole, Preference, Result};
2use async_trait::async_trait;
3
/// Selects which storage backend [`DatabaseProvider::create_client`] builds.
///
/// The default variant is [`DatabaseProvider::Memory`].
#[derive(Debug, Clone, Default)]
pub enum DatabaseProvider {
    /// In-memory database; chosen when no database environment variable is set.
    #[default]
    Memory,
    /// Local database located at `path`.
    SQLite { path: String },
    /// Remote database reachable through the connection `url`.
    Postgres { url: String },
    /// Turso database at `url`, authenticated with `auth_token`.
    /// Only compiled in with the `turso` cargo feature.
    #[cfg(feature = "turso")]
    Turso { url: String, auth_token: String },
}
20
21impl DatabaseProvider {
22    pub async fn create_client(&self) -> Result<Box<dyn DatabaseClient>> {
23        match self {
24            DatabaseProvider::Memory => {
25                let client = super::postgres::PostgresClient::new_memory().await?;
26                Ok(Box::new(client))
27            }
28            DatabaseProvider::SQLite { path } => {
29                let client = super::postgres::PostgresClient::new_local(path).await?;
30                Ok(Box::new(client))
31            }
32            DatabaseProvider::Postgres { url } => {
33                let client =
34                    super::postgres::PostgresClient::new_remote(url.clone(), "".to_string())
35                        .await?;
36                Ok(Box::new(client))
37            }
38            #[cfg(feature = "turso")]
39            DatabaseProvider::Turso { url, auth_token } => {
40                let client = super::turso::TursoClient::new(url.clone(), auth_token.clone()).await?;
41                Ok(Box::new(client))
42            }
43        }
44    }
45
46    pub fn from_env() -> Self {
47        // Turso takes priority if TURSO_URL is set
48        #[cfg(feature = "turso")]
49        if let Ok(url) = std::env::var("TURSO_URL") {
50            if !url.is_empty() {
51                let token = std::env::var("TURSO_AUTH_TOKEN").unwrap_or_default();
52                return DatabaseProvider::Turso { url, auth_token: token };
53            }
54        }
55        if let Ok(url) = std::env::var("DATABASE_URL") {
56            if !url.is_empty() {
57                return DatabaseProvider::Postgres { url };
58            }
59        }
60        if let Ok(path) = std::env::var("DATABASE_PATH") {
61            if !path.is_empty() && path != ":memory:" {
62                return DatabaseProvider::SQLite { path };
63            }
64        }
65        DatabaseProvider::Memory
66    }
67}
68
69pub use super::postgres::User;
70
71#[derive(Debug, Clone, sqlx::FromRow)]
72pub struct ConversationSummary {
73    pub id: String,
74    pub title: String,
75    pub created_at: String,
76    pub updated_at: String,
77    pub message_count: i32,
78}
79
80#[async_trait]
81pub trait DatabaseClient: Send + Sync {
82    async fn create_user(
83        &self,
84        id: &str,
85        email: &str,
86        password_hash: &str,
87        name: &str,
88    ) -> Result<()>;
89    async fn get_user_by_email(&self, email: &str) -> Result<Option<User>>;
90    async fn get_user_by_id(&self, id: &str) -> Result<Option<User>>;
91    async fn create_session(
92        &self,
93        id: &str,
94        user_id: &str,
95        token_hash: &str,
96        expires_at: i64,
97    ) -> Result<()>;
98    async fn validate_session(&self, token_hash: &str) -> Result<Option<String>>;
99    async fn delete_session(&self, id: &str) -> Result<()>;
100    async fn delete_session_by_token_hash(&self, token_hash: &str) -> Result<()>;
101    async fn create_conversation(&self, id: &str, user_id: &str, title: Option<&str>)
102        -> Result<()>;
103    async fn conversation_exists(&self, conversation_id: &str) -> Result<bool>;
104    async fn get_user_conversations(&self, user_id: &str) -> Result<Vec<ConversationSummary>>;
105    async fn get_conversation(
106        &self,
107        conversation_id: &str,
108    ) -> Result<super::postgres::Conversation>;
109    async fn delete_conversation(&self, conversation_id: &str) -> Result<()>;
110    async fn update_conversation_title(
111        &self,
112        conversation_id: &str,
113        title: Option<&str>,
114    ) -> Result<()>;
115    async fn add_message(
116        &self,
117        id: &str,
118        conversation_id: &str,
119        role: MessageRole,
120        content: &str,
121    ) -> Result<()>;
122    async fn get_conversation_history(&self, conversation_id: &str) -> Result<Vec<Message>>;
123    async fn store_memory_fact(&self, fact: &MemoryFact) -> Result<()>;
124    async fn get_user_memory(&self, user_id: &str) -> Result<Vec<MemoryFact>>;
125    async fn get_memory_by_category(
126        &self,
127        user_id: &str,
128        category: &str,
129    ) -> Result<Vec<MemoryFact>>;
130    async fn store_preference(&self, user_id: &str, preference: &Preference) -> Result<()>;
131    async fn get_user_preferences(&self, user_id: &str) -> Result<Vec<Preference>>;
132    async fn get_preference(
133        &self,
134        user_id: &str,
135        category: &str,
136        key: &str,
137    ) -> Result<Option<Preference>>;
138    async fn get_user_agent_by_name(
139        &self,
140        user_id: &str,
141        name: &str,
142    ) -> Result<Option<super::postgres::UserAgent>>;
143    async fn get_public_agent_by_name(
144        &self,
145        name: &str,
146    ) -> Result<Option<super::postgres::UserAgent>>;
147    async fn list_user_agents(&self, user_id: &str) -> Result<Vec<super::postgres::UserAgent>>;
148    async fn list_public_agents(
149        &self,
150        limit: u32,
151        offset: u32,
152    ) -> Result<Vec<super::postgres::UserAgent>>;
153    async fn create_user_agent(&self, agent: &super::postgres::UserAgent) -> Result<()>;
154    async fn update_user_agent(&self, agent: &super::postgres::UserAgent) -> Result<()>;
155    async fn delete_user_agent(&self, id: &str, user_id: &str) -> Result<bool>;
156}
157
158#[async_trait]
159impl DatabaseClient for super::postgres::PostgresClient {
160    async fn create_user(
161        &self,
162        id: &str,
163        email: &str,
164        password_hash: &str,
165        name: &str,
166    ) -> Result<()> {
167        super::postgres::PostgresClient::create_user(self, id, email, password_hash, name).await
168    }
169    async fn get_user_by_email(&self, email: &str) -> Result<Option<User>> {
170        super::postgres::PostgresClient::get_user_by_email(self, email).await
171    }
172    async fn get_user_by_id(&self, id: &str) -> Result<Option<User>> {
173        super::postgres::PostgresClient::get_user_by_id(self, id).await
174    }
175    async fn create_session(
176        &self,
177        id: &str,
178        user_id: &str,
179        token_hash: &str,
180        expires_at: i64,
181    ) -> Result<()> {
182        super::postgres::PostgresClient::create_session(self, id, user_id, token_hash, expires_at)
183            .await
184    }
185    async fn validate_session(&self, token_hash: &str) -> Result<Option<String>> {
186        super::postgres::PostgresClient::validate_session(self, token_hash).await
187    }
188    async fn delete_session(&self, id: &str) -> Result<()> {
189        super::postgres::PostgresClient::delete_session(self, id).await
190    }
191    async fn delete_session_by_token_hash(&self, token_hash: &str) -> Result<()> {
192        super::postgres::PostgresClient::delete_session_by_token_hash(self, token_hash).await
193    }
194    async fn create_conversation(
195        &self,
196        id: &str,
197        user_id: &str,
198        title: Option<&str>,
199    ) -> Result<()> {
200        super::postgres::PostgresClient::create_conversation(self, id, user_id, title).await
201    }
202    async fn conversation_exists(&self, conversation_id: &str) -> Result<bool> {
203        super::postgres::PostgresClient::conversation_exists(self, conversation_id).await
204    }
205    async fn get_user_conversations(&self, user_id: &str) -> Result<Vec<ConversationSummary>> {
206        super::postgres::PostgresClient::get_user_conversations(self, user_id).await
207    }
208    async fn get_conversation(
209        &self,
210        conversation_id: &str,
211    ) -> Result<super::postgres::Conversation> {
212        let row = sqlx::query_as::<_, super::postgres::Conversation>("SELECT id, user_id, title, created_at, updated_at, 0 as message_count FROM conversations WHERE id = $1").bind(conversation_id).fetch_optional(&self.pool).await.map_err(|e| AppError::Database(e.to_string()))?;
213        row.ok_or_else(|| AppError::NotFound("Conversation not found".into()))
214    }
215    async fn delete_conversation(&self, conversation_id: &str) -> Result<()> {
216        sqlx::query("DELETE FROM messages WHERE conversation_id = $1")
217            .bind(conversation_id)
218            .execute(&self.pool)
219            .await
220            .map_err(|e| AppError::Database(e.to_string()))?;
221        sqlx::query("DELETE FROM conversations WHERE id = $1")
222            .bind(conversation_id)
223            .execute(&self.pool)
224            .await
225            .map_err(|e| AppError::Database(e.to_string()))?;
226        Ok(())
227    }
228    async fn update_conversation_title(
229        &self,
230        conversation_id: &str,
231        title: Option<&str>,
232    ) -> Result<()> {
233        let now = chrono::Utc::now().timestamp();
234        sqlx::query("UPDATE conversations SET title = $1, updated_at = $2 WHERE id = $3")
235            .bind(title)
236            .bind(now)
237            .bind(conversation_id)
238            .execute(&self.pool)
239            .await
240            .map_err(|e| AppError::Database(e.to_string()))?;
241        Ok(())
242    }
243    async fn add_message(
244        &self,
245        id: &str,
246        conversation_id: &str,
247        role: MessageRole,
248        content: &str,
249    ) -> Result<()> {
250        super::postgres::PostgresClient::add_message(self, id, conversation_id, role, content).await
251    }
252    async fn get_conversation_history(&self, conversation_id: &str) -> Result<Vec<Message>> {
253        super::postgres::PostgresClient::get_conversation_history(self, conversation_id).await
254    }
255    async fn store_memory_fact(&self, fact: &MemoryFact) -> Result<()> {
256        super::postgres::PostgresClient::store_memory_fact(self, fact).await
257    }
258    async fn get_user_memory(&self, user_id: &str) -> Result<Vec<MemoryFact>> {
259        super::postgres::PostgresClient::get_user_memory(self, user_id).await
260    }
261    async fn get_memory_by_category(
262        &self,
263        user_id: &str,
264        category: &str,
265    ) -> Result<Vec<MemoryFact>> {
266        let mems = super::postgres::PostgresClient::get_user_memory(self, user_id).await?;
267        Ok(mems
268            .into_iter()
269            .filter(|m| m.category == category)
270            .collect())
271    }
272    async fn store_preference(&self, user_id: &str, preference: &Preference) -> Result<()> {
273        super::postgres::PostgresClient::store_preference(self, user_id, preference).await
274    }
275    async fn get_user_preferences(&self, user_id: &str) -> Result<Vec<Preference>> {
276        super::postgres::PostgresClient::get_user_preferences(self, user_id).await
277    }
278    async fn get_preference(
279        &self,
280        user_id: &str,
281        category: &str,
282        key: &str,
283    ) -> Result<Option<Preference>> {
284        let prefs = super::postgres::PostgresClient::get_user_preferences(self, user_id).await?;
285        Ok(prefs
286            .into_iter()
287            .find(|p| p.category == category && p.key == key))
288    }
289    async fn get_user_agent_by_name(
290        &self,
291        user_id: &str,
292        name: &str,
293    ) -> Result<Option<super::postgres::UserAgent>> {
294        super::postgres::PostgresClient::get_user_agent_by_name(self, user_id, name).await
295    }
296    async fn get_public_agent_by_name(
297        &self,
298        name: &str,
299    ) -> Result<Option<super::postgres::UserAgent>> {
300        super::postgres::PostgresClient::get_user_agent_by_name(self, "", name).await
301    }
302    async fn list_user_agents(&self, _user_id: &str) -> Result<Vec<super::postgres::UserAgent>> {
303        Ok(vec![])
304    }
305    async fn list_public_agents(
306        &self,
307        _limit: u32,
308        _offset: u32,
309    ) -> Result<Vec<super::postgres::UserAgent>> {
310        Ok(vec![])
311    }
312    async fn create_user_agent(&self, _agent: &super::postgres::UserAgent) -> Result<()> {
313        Ok(())
314    }
315    async fn update_user_agent(&self, _agent: &super::postgres::UserAgent) -> Result<()> {
316        Ok(())
317    }
318    async fn delete_user_agent(&self, _id: &str, _user_id: &str) -> Result<bool> {
319        Ok(true)
320    }
321}