Skip to main content

ares/db/
traits.rs

//! Database abstraction traits
//!
//! This module provides the `DatabaseClient` trait, which abstracts over different
//! database backends (in-memory SQLite, file-based SQLite, remote Turso), and the
//! `DatabaseProvider` enum that selects a backend. `DatabaseProvider::from_env()`
//! picks a backend from environment variables.
//!
//! # Example
//!
//! ```rust,ignore
//! use ares::db::{DatabaseClient, DatabaseProvider};
//!
//! // Use an in-memory database (default for development/testing)
//! let db = DatabaseProvider::Memory.create_client().await?;
//!
//! // Use file-based SQLite
//! let db = DatabaseProvider::SQLite { path: "data.db".into() }.create_client().await?;
//!
//! // Use remote Turso (requires the `turso` feature)
//! let db = DatabaseProvider::Turso { url, auth_token }.create_client().await?;
//! ```
20
21use crate::types::{AppError, MemoryFact, Message, MessageRole, Preference, Result};
22use async_trait::async_trait;
23
/// Database provider configuration.
///
/// Selects which backend `DatabaseProvider::create_client` connects to. The default
/// is the ephemeral in-memory database.
///
/// `Debug` is implemented by hand so that the Turso authentication token is never
/// written to logs; every other field is printed exactly as `#[derive(Debug)]` would.
#[derive(Clone, Default)]
pub enum DatabaseProvider {
    /// In-memory SQLite database (ephemeral, lost on restart)
    #[default]
    Memory,
    /// File-based SQLite database
    SQLite {
        /// Path to the SQLite database file
        path: String,
    },
    /// Remote Turso database (requires network access)
    #[cfg(feature = "turso")]
    Turso {
        /// The Turso database URL (e.g., `libsql://your-db.turso.io`)
        url: String,
        /// Authentication token for the Turso database (redacted in `Debug` output)
        auth_token: String,
    },
}

impl std::fmt::Debug for DatabaseProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DatabaseProvider::Memory => f.write_str("Memory"),
            DatabaseProvider::SQLite { path } => {
                f.debug_struct("SQLite").field("path", path).finish()
            }
            #[cfg(feature = "turso")]
            DatabaseProvider::Turso { url, .. } => f
                .debug_struct("Turso")
                .field("url", url)
                // Never print the credential itself.
                .field("auth_token", &"<redacted>")
                .finish(),
        }
    }
}
44
45impl DatabaseProvider {
46    /// Create a database client from this provider configuration
47    pub async fn create_client(&self) -> Result<Box<dyn DatabaseClient>> {
48        match self {
49            DatabaseProvider::Memory => {
50                let client = super::turso::TursoClient::new_memory().await?;
51                Ok(Box::new(client))
52            }
53            DatabaseProvider::SQLite { path } => {
54                let client = super::turso::TursoClient::new_local(path).await?;
55                Ok(Box::new(client))
56            }
57            #[cfg(feature = "turso")]
58            DatabaseProvider::Turso { url, auth_token } => {
59                let client =
60                    super::turso::TursoClient::new_remote(url.clone(), auth_token.clone()).await?;
61                Ok(Box::new(client))
62            }
63        }
64    }
65
66    /// Create from environment variables or use defaults
67    pub fn from_env() -> Self {
68        // Check for Turso configuration first
69        #[cfg(feature = "turso")]
70        {
71            if let (Ok(url), Ok(token)) = (
72                std::env::var("TURSO_DATABASE_URL"),
73                std::env::var("TURSO_AUTH_TOKEN"),
74            ) {
75                if !url.is_empty() && !token.is_empty() {
76                    return DatabaseProvider::Turso {
77                        url,
78                        auth_token: token,
79                    };
80                }
81            }
82        }
83
84        // Check for SQLite file path
85        if let Ok(path) = std::env::var("DATABASE_PATH") {
86            if !path.is_empty() && path != ":memory:" {
87                return DatabaseProvider::SQLite { path };
88            }
89        }
90
91        // Default to in-memory
92        DatabaseProvider::Memory
93    }
94}
95
96/// User record from the database
97pub use super::turso::User;
98
/// Lightweight view of a conversation for listings.
///
/// Carries a message count rather than the messages themselves.
#[derive(Debug, Clone)]
pub struct ConversationSummary {
    /// Identifier of the conversation.
    pub id: String,
    /// Human-readable title of the conversation.
    pub title: String,
    /// Creation time, as a Unix timestamp.
    pub created_at: i64,
    /// Time of the most recent update, as a Unix timestamp.
    pub updated_at: i64,
    /// How many messages belong to the conversation.
    pub message_count: i64,
}
113
114/// Abstract trait for database operations
115///
116/// This trait defines all database operations needed by the application.
117/// Implementations can use different backends (SQLite, Turso, etc.)
118#[async_trait]
119pub trait DatabaseClient: Send + Sync {
120    // ============== User Operations ==============
121
122    /// Create a new user
123    async fn create_user(
124        &self,
125        id: &str,
126        email: &str,
127        password_hash: &str,
128        name: &str,
129    ) -> Result<()>;
130
131    /// Get a user by email
132    async fn get_user_by_email(&self, email: &str) -> Result<Option<User>>;
133
134    /// Get a user by ID
135    async fn get_user_by_id(&self, id: &str) -> Result<Option<User>>;
136
137    // ============== Session Operations ==============
138
139    /// Create a new session
140    async fn create_session(
141        &self,
142        id: &str,
143        user_id: &str,
144        token_hash: &str,
145        expires_at: i64,
146    ) -> Result<()>;
147
148    /// Validate and get session (returns user_id if valid)
149    async fn validate_session(&self, token_hash: &str) -> Result<Option<String>>;
150
151    /// Delete a session by ID
152    async fn delete_session(&self, id: &str) -> Result<()>;
153
154    /// Delete a session by token hash (for refresh token invalidation)
155    async fn delete_session_by_token_hash(&self, token_hash: &str) -> Result<()>;
156
157    // ============== Conversation Operations ==============
158
159    /// Create a new conversation
160    async fn create_conversation(&self, id: &str, user_id: &str, title: Option<&str>)
161        -> Result<()>;
162
163    /// Check if a conversation exists
164    async fn conversation_exists(&self, conversation_id: &str) -> Result<bool>;
165
166    /// Get conversations for a user
167    async fn get_user_conversations(&self, user_id: &str) -> Result<Vec<ConversationSummary>>;
168
169    /// Add a message to a conversation
170    async fn add_message(
171        &self,
172        id: &str,
173        conversation_id: &str,
174        role: MessageRole,
175        content: &str,
176    ) -> Result<()>;
177
178    /// Get conversation history
179    async fn get_conversation_history(&self, conversation_id: &str) -> Result<Vec<Message>>;
180
181    // ============== Memory Operations ==============
182
183    /// Store a memory fact
184    async fn store_memory_fact(&self, fact: &MemoryFact) -> Result<()>;
185
186    /// Get all memory facts for a user
187    async fn get_user_memory(&self, user_id: &str) -> Result<Vec<MemoryFact>>;
188
189    /// Get memory facts by category
190    async fn get_memory_by_category(
191        &self,
192        user_id: &str,
193        category: &str,
194    ) -> Result<Vec<MemoryFact>>;
195
196    // ============== Preference Operations ==============
197
198    /// Store a user preference
199    async fn store_preference(&self, user_id: &str, preference: &Preference) -> Result<()>;
200
201    /// Get all preferences for a user
202    async fn get_user_preferences(&self, user_id: &str) -> Result<Vec<Preference>>;
203
204    /// Get preference by category and key
205    async fn get_preference(
206        &self,
207        user_id: &str,
208        category: &str,
209        key: &str,
210    ) -> Result<Option<Preference>>;
211}
212
213// ============== Implement DatabaseClient for TursoClient ==============
214
215#[async_trait]
216impl DatabaseClient for super::turso::TursoClient {
217    async fn create_user(
218        &self,
219        id: &str,
220        email: &str,
221        password_hash: &str,
222        name: &str,
223    ) -> Result<()> {
224        super::turso::TursoClient::create_user(self, id, email, password_hash, name).await
225    }
226
227    async fn get_user_by_email(&self, email: &str) -> Result<Option<User>> {
228        super::turso::TursoClient::get_user_by_email(self, email).await
229    }
230
231    async fn get_user_by_id(&self, id: &str) -> Result<Option<User>> {
232        let conn = self.operation_conn().await?;
233
234        let mut rows = conn
235            .query(
236                "SELECT id, email, password_hash, name, created_at, updated_at
237                 FROM users WHERE id = ?",
238                [id],
239            )
240            .await
241            .map_err(|e| AppError::Database(format!("Failed to query user: {}", e)))?;
242
243        if let Some(row) = rows
244            .next()
245            .await
246            .map_err(|e| AppError::Database(e.to_string()))?
247        {
248            Ok(Some(User {
249                id: row.get(0).map_err(|e| AppError::Database(e.to_string()))?,
250                email: row.get(1).map_err(|e| AppError::Database(e.to_string()))?,
251                password_hash: row.get(2).map_err(|e| AppError::Database(e.to_string()))?,
252                name: row.get(3).map_err(|e| AppError::Database(e.to_string()))?,
253                created_at: row.get(4).map_err(|e| AppError::Database(e.to_string()))?,
254                updated_at: row.get(5).map_err(|e| AppError::Database(e.to_string()))?,
255            }))
256        } else {
257            Ok(None)
258        }
259    }
260
261    async fn create_session(
262        &self,
263        id: &str,
264        user_id: &str,
265        token_hash: &str,
266        expires_at: i64,
267    ) -> Result<()> {
268        super::turso::TursoClient::create_session(self, id, user_id, token_hash, expires_at).await
269    }
270
271    async fn validate_session(&self, token_hash: &str) -> Result<Option<String>> {
272        let conn = self.operation_conn().await?;
273        let now = chrono::Utc::now().timestamp();
274
275        let mut rows = conn
276            .query(
277                "SELECT user_id FROM sessions WHERE token_hash = ? AND expires_at > ?",
278                [token_hash, &now.to_string()],
279            )
280            .await
281            .map_err(|e| AppError::Database(format!("Failed to validate session: {}", e)))?;
282
283        if let Some(row) = rows
284            .next()
285            .await
286            .map_err(|e| AppError::Database(e.to_string()))?
287        {
288            Ok(Some(
289                row.get(0).map_err(|e| AppError::Database(e.to_string()))?,
290            ))
291        } else {
292            Ok(None)
293        }
294    }
295
296    async fn delete_session(&self, id: &str) -> Result<()> {
297        let conn = self.operation_conn().await?;
298
299        conn.execute("DELETE FROM sessions WHERE id = ?", [id])
300            .await
301            .map_err(|e| AppError::Database(format!("Failed to delete session: {}", e)))?;
302
303        Ok(())
304    }
305
306    async fn delete_session_by_token_hash(&self, token_hash: &str) -> Result<()> {
307        let conn = self.operation_conn().await?;
308
309        conn.execute("DELETE FROM sessions WHERE token_hash = ?", [token_hash])
310            .await
311            .map_err(|e| {
312                AppError::Database(format!("Failed to delete session by token hash: {}", e))
313            })?;
314
315        Ok(())
316    }
317
318    async fn create_conversation(
319        &self,
320        id: &str,
321        user_id: &str,
322        title: Option<&str>,
323    ) -> Result<()> {
324        super::turso::TursoClient::create_conversation(self, id, user_id, title).await
325    }
326
327    async fn conversation_exists(&self, conversation_id: &str) -> Result<bool> {
328        let conn = self.operation_conn().await?;
329
330        let mut rows = conn
331            .query(
332                "SELECT 1 FROM conversations WHERE id = ?",
333                [conversation_id],
334            )
335            .await
336            .map_err(|e| AppError::Database(format!("Failed to check conversation: {}", e)))?;
337
338        Ok(rows
339            .next()
340            .await
341            .map_err(|e| AppError::Database(e.to_string()))?
342            .is_some())
343    }
344
345    async fn get_user_conversations(&self, user_id: &str) -> Result<Vec<ConversationSummary>> {
346        let conn = self.operation_conn().await?;
347
348        let mut rows = conn
349            .query(
350                "SELECT c.id, c.title, c.created_at, c.updated_at,
351                        (SELECT COUNT(*) FROM messages WHERE conversation_id = c.id) as msg_count
352                 FROM conversations c
353                 WHERE c.user_id = ?
354                 ORDER BY c.updated_at DESC",
355                [user_id],
356            )
357            .await
358            .map_err(|e| AppError::Database(format!("Failed to query conversations: {}", e)))?;
359
360        let mut conversations = Vec::new();
361        while let Some(row) = rows
362            .next()
363            .await
364            .map_err(|e| AppError::Database(e.to_string()))?
365        {
366            conversations.push(ConversationSummary {
367                id: row.get(0).map_err(|e| AppError::Database(e.to_string()))?,
368                title: row.get::<String>(1).unwrap_or_default(),
369                created_at: row.get(2).map_err(|e| AppError::Database(e.to_string()))?,
370                updated_at: row.get(3).map_err(|e| AppError::Database(e.to_string()))?,
371                message_count: row.get(4).map_err(|e| AppError::Database(e.to_string()))?,
372            });
373        }
374
375        Ok(conversations)
376    }
377
378    async fn add_message(
379        &self,
380        id: &str,
381        conversation_id: &str,
382        role: MessageRole,
383        content: &str,
384    ) -> Result<()> {
385        super::turso::TursoClient::add_message(self, id, conversation_id, role, content).await
386    }
387
388    async fn get_conversation_history(&self, conversation_id: &str) -> Result<Vec<Message>> {
389        super::turso::TursoClient::get_conversation_history(self, conversation_id).await
390    }
391
392    async fn store_memory_fact(&self, fact: &MemoryFact) -> Result<()> {
393        super::turso::TursoClient::store_memory_fact(self, fact).await
394    }
395
396    async fn get_user_memory(&self, user_id: &str) -> Result<Vec<MemoryFact>> {
397        super::turso::TursoClient::get_user_memory(self, user_id).await
398    }
399
400    async fn get_memory_by_category(
401        &self,
402        user_id: &str,
403        category: &str,
404    ) -> Result<Vec<MemoryFact>> {
405        let conn = self.operation_conn().await?;
406
407        let mut rows = conn
408            .query(
409                "SELECT id, user_id, category, fact_key, fact_value, confidence, created_at, updated_at
410                 FROM memory_facts WHERE user_id = ? AND category = ?",
411                [user_id, category],
412            )
413            .await
414            .map_err(|e| AppError::Database(format!("Failed to query memory facts: {}", e)))?;
415
416        let mut facts = Vec::new();
417        while let Some(row) = rows
418            .next()
419            .await
420            .map_err(|e| AppError::Database(e.to_string()))?
421        {
422            facts.push(MemoryFact {
423                id: row.get(0).map_err(|e| AppError::Database(e.to_string()))?,
424                user_id: row.get(1).map_err(|e| AppError::Database(e.to_string()))?,
425                category: row.get(2).map_err(|e| AppError::Database(e.to_string()))?,
426                fact_key: row.get(3).map_err(|e| AppError::Database(e.to_string()))?,
427                fact_value: row.get(4).map_err(|e| AppError::Database(e.to_string()))?,
428                confidence: row
429                    .get::<f64>(5)
430                    .map_err(|e| AppError::Database(e.to_string()))?
431                    as f32,
432                created_at: chrono::DateTime::from_timestamp(
433                    row.get::<i64>(6)
434                        .map_err(|e| AppError::Database(e.to_string()))?,
435                    0,
436                )
437                .unwrap(),
438                updated_at: chrono::DateTime::from_timestamp(
439                    row.get::<i64>(7)
440                        .map_err(|e| AppError::Database(e.to_string()))?,
441                    0,
442                )
443                .unwrap(),
444            });
445        }
446
447        Ok(facts)
448    }
449
450    async fn store_preference(&self, user_id: &str, preference: &Preference) -> Result<()> {
451        super::turso::TursoClient::store_preference(self, user_id, preference).await
452    }
453
454    async fn get_user_preferences(&self, user_id: &str) -> Result<Vec<Preference>> {
455        super::turso::TursoClient::get_user_preferences(self, user_id).await
456    }
457
458    async fn get_preference(
459        &self,
460        user_id: &str,
461        category: &str,
462        key: &str,
463    ) -> Result<Option<Preference>> {
464        let conn = self.operation_conn().await?;
465
466        let mut rows = conn
467            .query(
468                "SELECT category, key, value, confidence FROM preferences
469                 WHERE user_id = ? AND category = ? AND key = ?",
470                [user_id, category, key],
471            )
472            .await
473            .map_err(|e| AppError::Database(format!("Failed to query preference: {}", e)))?;
474
475        if let Some(row) = rows
476            .next()
477            .await
478            .map_err(|e| AppError::Database(e.to_string()))?
479        {
480            Ok(Some(Preference {
481                category: row.get(0).map_err(|e| AppError::Database(e.to_string()))?,
482                key: row.get(1).map_err(|e| AppError::Database(e.to_string()))?,
483                value: row.get(2).map_err(|e| AppError::Database(e.to_string()))?,
484                confidence: row
485                    .get::<f64>(3)
486                    .map_err(|e| AppError::Database(e.to_string()))?
487                    as f32,
488            }))
489        } else {
490            Ok(None)
491        }
492    }
493}