// ares/db/traits.rs

//! Database abstraction traits
//!
//! This module provides the `DatabaseClient` trait, which abstracts over the supported
//! database backends (in-memory SQLite, file-based SQLite, and remote Turso).
//!
//! # Example
//!
//! ```rust,ignore
//! use ares::db::{DatabaseClient, DatabaseProvider};
//!
//! // Use an in-memory database (the default; suited to development/testing)
//! let db = DatabaseProvider::Memory.create_client().await?;
//!
//! // Use file-based SQLite
//! let db = DatabaseProvider::SQLite { path: "data.db".into() }.create_client().await?;
//!
//! // Use remote Turso (requires the `turso` feature)
//! let db = DatabaseProvider::Turso { url, auth_token }.create_client().await?;
//! ```
20
21use crate::types::{AppError, MemoryFact, Message, MessageRole, Preference, Result};
22use async_trait::async_trait;
23
/// Database provider configuration.
///
/// Selects which backend [`DatabaseProvider::create_client`] opens. `Debug` is
/// implemented by hand (rather than derived) so that the Turso auth token is never
/// written to logs; every other variant formats exactly as a derived impl would.
#[derive(Clone, Default)]
pub enum DatabaseProvider {
    /// In-memory SQLite database (ephemeral, lost on restart)
    #[default]
    Memory,
    /// File-based SQLite database stored at `path`
    SQLite { path: String },
    /// Remote Turso database (requires network access)
    #[cfg(feature = "turso")]
    Turso { url: String, auth_token: String },
}

impl std::fmt::Debug for DatabaseProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DatabaseProvider::Memory => f.write_str("Memory"),
            DatabaseProvider::SQLite { path } => {
                f.debug_struct("SQLite").field("path", path).finish()
            }
            // The token is a credential: show that it is present, never its value.
            #[cfg(feature = "turso")]
            DatabaseProvider::Turso { url, .. } => f
                .debug_struct("Turso")
                .field("url", url)
                .field("auth_token", &"<redacted>")
                .finish(),
        }
    }
}
36
37impl DatabaseProvider {
38    /// Create a database client from this provider configuration
39    pub async fn create_client(&self) -> Result<Box<dyn DatabaseClient>> {
40        match self {
41            DatabaseProvider::Memory => {
42                let client = super::turso::TursoClient::new_memory().await?;
43                Ok(Box::new(client))
44            }
45            DatabaseProvider::SQLite { path } => {
46                let client = super::turso::TursoClient::new_local(path).await?;
47                Ok(Box::new(client))
48            }
49            #[cfg(feature = "turso")]
50            DatabaseProvider::Turso { url, auth_token } => {
51                let client =
52                    super::turso::TursoClient::new_remote(url.clone(), auth_token.clone()).await?;
53                Ok(Box::new(client))
54            }
55        }
56    }
57
58    /// Create from environment variables or use defaults
59    pub fn from_env() -> Self {
60        // Check for Turso configuration first
61        #[cfg(feature = "turso")]
62        {
63            if let (Ok(url), Ok(token)) = (
64                std::env::var("TURSO_DATABASE_URL"),
65                std::env::var("TURSO_AUTH_TOKEN"),
66            ) {
67                if !url.is_empty() && !token.is_empty() {
68                    return DatabaseProvider::Turso {
69                        url,
70                        auth_token: token,
71                    };
72                }
73            }
74        }
75
76        // Check for SQLite file path
77        if let Ok(path) = std::env::var("DATABASE_PATH") {
78            if !path.is_empty() && path != ":memory:" {
79                return DatabaseProvider::SQLite { path };
80            }
81        }
82
83        // Default to in-memory
84        DatabaseProvider::Memory
85    }
86}
87
88/// User record from the database
89pub use super::turso::User;
90
/// Summary of a conversation, without its full message history.
///
/// Produced by [`DatabaseClient::get_user_conversations`]. Derives `PartialEq`/`Eq`
/// so summaries can be compared by value (e.g. in tests or change detection).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConversationSummary {
    /// Conversation identifier.
    pub id: String,
    /// Conversation title. The Turso implementation yields an empty string when no
    /// title can be read for the row (titles are optional at creation).
    pub title: String,
    /// Creation timestamp as stored in the database (presumably Unix seconds — TODO confirm).
    pub created_at: i64,
    /// Last-update timestamp as stored in the database (presumably Unix seconds — TODO confirm).
    pub updated_at: i64,
    /// Number of messages recorded for this conversation.
    pub message_count: i64,
}
100
101/// Abstract trait for database operations
102///
103/// This trait defines all database operations needed by the application.
104/// Implementations can use different backends (SQLite, Turso, etc.)
105#[async_trait]
106pub trait DatabaseClient: Send + Sync {
107    // ============== User Operations ==============
108
109    /// Create a new user
110    async fn create_user(
111        &self,
112        id: &str,
113        email: &str,
114        password_hash: &str,
115        name: &str,
116    ) -> Result<()>;
117
118    /// Get a user by email
119    async fn get_user_by_email(&self, email: &str) -> Result<Option<User>>;
120
121    /// Get a user by ID
122    async fn get_user_by_id(&self, id: &str) -> Result<Option<User>>;
123
124    // ============== Session Operations ==============
125
126    /// Create a new session
127    async fn create_session(
128        &self,
129        id: &str,
130        user_id: &str,
131        token_hash: &str,
132        expires_at: i64,
133    ) -> Result<()>;
134
135    /// Validate and get session (returns user_id if valid)
136    async fn validate_session(&self, token_hash: &str) -> Result<Option<String>>;
137
138    /// Delete a session
139    async fn delete_session(&self, id: &str) -> Result<()>;
140
141    // ============== Conversation Operations ==============
142
143    /// Create a new conversation
144    async fn create_conversation(&self, id: &str, user_id: &str, title: Option<&str>)
145        -> Result<()>;
146
147    /// Check if a conversation exists
148    async fn conversation_exists(&self, conversation_id: &str) -> Result<bool>;
149
150    /// Get conversations for a user
151    async fn get_user_conversations(&self, user_id: &str) -> Result<Vec<ConversationSummary>>;
152
153    /// Add a message to a conversation
154    async fn add_message(
155        &self,
156        id: &str,
157        conversation_id: &str,
158        role: MessageRole,
159        content: &str,
160    ) -> Result<()>;
161
162    /// Get conversation history
163    async fn get_conversation_history(&self, conversation_id: &str) -> Result<Vec<Message>>;
164
165    // ============== Memory Operations ==============
166
167    /// Store a memory fact
168    async fn store_memory_fact(&self, fact: &MemoryFact) -> Result<()>;
169
170    /// Get all memory facts for a user
171    async fn get_user_memory(&self, user_id: &str) -> Result<Vec<MemoryFact>>;
172
173    /// Get memory facts by category
174    async fn get_memory_by_category(
175        &self,
176        user_id: &str,
177        category: &str,
178    ) -> Result<Vec<MemoryFact>>;
179
180    // ============== Preference Operations ==============
181
182    /// Store a user preference
183    async fn store_preference(&self, user_id: &str, preference: &Preference) -> Result<()>;
184
185    /// Get all preferences for a user
186    async fn get_user_preferences(&self, user_id: &str) -> Result<Vec<Preference>>;
187
188    /// Get preference by category and key
189    async fn get_preference(
190        &self,
191        user_id: &str,
192        category: &str,
193        key: &str,
194    ) -> Result<Option<Preference>>;
195}
196
197// ============== Implement DatabaseClient for TursoClient ==============
198
199#[async_trait]
200impl DatabaseClient for super::turso::TursoClient {
201    async fn create_user(
202        &self,
203        id: &str,
204        email: &str,
205        password_hash: &str,
206        name: &str,
207    ) -> Result<()> {
208        super::turso::TursoClient::create_user(self, id, email, password_hash, name).await
209    }
210
211    async fn get_user_by_email(&self, email: &str) -> Result<Option<User>> {
212        super::turso::TursoClient::get_user_by_email(self, email).await
213    }
214
215    async fn get_user_by_id(&self, id: &str) -> Result<Option<User>> {
216        let conn = self.operation_conn().await?;
217
218        let mut rows = conn
219            .query(
220                "SELECT id, email, password_hash, name, created_at, updated_at
221                 FROM users WHERE id = ?",
222                [id],
223            )
224            .await
225            .map_err(|e| AppError::Database(format!("Failed to query user: {}", e)))?;
226
227        if let Some(row) = rows
228            .next()
229            .await
230            .map_err(|e| AppError::Database(e.to_string()))?
231        {
232            Ok(Some(User {
233                id: row.get(0).map_err(|e| AppError::Database(e.to_string()))?,
234                email: row.get(1).map_err(|e| AppError::Database(e.to_string()))?,
235                password_hash: row.get(2).map_err(|e| AppError::Database(e.to_string()))?,
236                name: row.get(3).map_err(|e| AppError::Database(e.to_string()))?,
237                created_at: row.get(4).map_err(|e| AppError::Database(e.to_string()))?,
238                updated_at: row.get(5).map_err(|e| AppError::Database(e.to_string()))?,
239            }))
240        } else {
241            Ok(None)
242        }
243    }
244
245    async fn create_session(
246        &self,
247        id: &str,
248        user_id: &str,
249        token_hash: &str,
250        expires_at: i64,
251    ) -> Result<()> {
252        super::turso::TursoClient::create_session(self, id, user_id, token_hash, expires_at).await
253    }
254
255    async fn validate_session(&self, token_hash: &str) -> Result<Option<String>> {
256        let conn = self.operation_conn().await?;
257        let now = chrono::Utc::now().timestamp();
258
259        let mut rows = conn
260            .query(
261                "SELECT user_id FROM sessions WHERE token_hash = ? AND expires_at > ?",
262                [token_hash, &now.to_string()],
263            )
264            .await
265            .map_err(|e| AppError::Database(format!("Failed to validate session: {}", e)))?;
266
267        if let Some(row) = rows
268            .next()
269            .await
270            .map_err(|e| AppError::Database(e.to_string()))?
271        {
272            Ok(Some(
273                row.get(0).map_err(|e| AppError::Database(e.to_string()))?,
274            ))
275        } else {
276            Ok(None)
277        }
278    }
279
280    async fn delete_session(&self, id: &str) -> Result<()> {
281        let conn = self.operation_conn().await?;
282
283        conn.execute("DELETE FROM sessions WHERE id = ?", [id])
284            .await
285            .map_err(|e| AppError::Database(format!("Failed to delete session: {}", e)))?;
286
287        Ok(())
288    }
289
290    async fn create_conversation(
291        &self,
292        id: &str,
293        user_id: &str,
294        title: Option<&str>,
295    ) -> Result<()> {
296        super::turso::TursoClient::create_conversation(self, id, user_id, title).await
297    }
298
299    async fn conversation_exists(&self, conversation_id: &str) -> Result<bool> {
300        let conn = self.operation_conn().await?;
301
302        let mut rows = conn
303            .query(
304                "SELECT 1 FROM conversations WHERE id = ?",
305                [conversation_id],
306            )
307            .await
308            .map_err(|e| AppError::Database(format!("Failed to check conversation: {}", e)))?;
309
310        Ok(rows
311            .next()
312            .await
313            .map_err(|e| AppError::Database(e.to_string()))?
314            .is_some())
315    }
316
317    async fn get_user_conversations(&self, user_id: &str) -> Result<Vec<ConversationSummary>> {
318        let conn = self.operation_conn().await?;
319
320        let mut rows = conn
321            .query(
322                "SELECT c.id, c.title, c.created_at, c.updated_at,
323                        (SELECT COUNT(*) FROM messages WHERE conversation_id = c.id) as msg_count
324                 FROM conversations c
325                 WHERE c.user_id = ?
326                 ORDER BY c.updated_at DESC",
327                [user_id],
328            )
329            .await
330            .map_err(|e| AppError::Database(format!("Failed to query conversations: {}", e)))?;
331
332        let mut conversations = Vec::new();
333        while let Some(row) = rows
334            .next()
335            .await
336            .map_err(|e| AppError::Database(e.to_string()))?
337        {
338            conversations.push(ConversationSummary {
339                id: row.get(0).map_err(|e| AppError::Database(e.to_string()))?,
340                title: row.get::<String>(1).unwrap_or_default(),
341                created_at: row.get(2).map_err(|e| AppError::Database(e.to_string()))?,
342                updated_at: row.get(3).map_err(|e| AppError::Database(e.to_string()))?,
343                message_count: row.get(4).map_err(|e| AppError::Database(e.to_string()))?,
344            });
345        }
346
347        Ok(conversations)
348    }
349
350    async fn add_message(
351        &self,
352        id: &str,
353        conversation_id: &str,
354        role: MessageRole,
355        content: &str,
356    ) -> Result<()> {
357        super::turso::TursoClient::add_message(self, id, conversation_id, role, content).await
358    }
359
360    async fn get_conversation_history(&self, conversation_id: &str) -> Result<Vec<Message>> {
361        super::turso::TursoClient::get_conversation_history(self, conversation_id).await
362    }
363
364    async fn store_memory_fact(&self, fact: &MemoryFact) -> Result<()> {
365        super::turso::TursoClient::store_memory_fact(self, fact).await
366    }
367
368    async fn get_user_memory(&self, user_id: &str) -> Result<Vec<MemoryFact>> {
369        super::turso::TursoClient::get_user_memory(self, user_id).await
370    }
371
372    async fn get_memory_by_category(
373        &self,
374        user_id: &str,
375        category: &str,
376    ) -> Result<Vec<MemoryFact>> {
377        let conn = self.operation_conn().await?;
378
379        let mut rows = conn
380            .query(
381                "SELECT id, user_id, category, fact_key, fact_value, confidence, created_at, updated_at
382                 FROM memory_facts WHERE user_id = ? AND category = ?",
383                [user_id, category],
384            )
385            .await
386            .map_err(|e| AppError::Database(format!("Failed to query memory facts: {}", e)))?;
387
388        let mut facts = Vec::new();
389        while let Some(row) = rows
390            .next()
391            .await
392            .map_err(|e| AppError::Database(e.to_string()))?
393        {
394            facts.push(MemoryFact {
395                id: row.get(0).map_err(|e| AppError::Database(e.to_string()))?,
396                user_id: row.get(1).map_err(|e| AppError::Database(e.to_string()))?,
397                category: row.get(2).map_err(|e| AppError::Database(e.to_string()))?,
398                fact_key: row.get(3).map_err(|e| AppError::Database(e.to_string()))?,
399                fact_value: row.get(4).map_err(|e| AppError::Database(e.to_string()))?,
400                confidence: row
401                    .get::<f64>(5)
402                    .map_err(|e| AppError::Database(e.to_string()))?
403                    as f32,
404                created_at: chrono::DateTime::from_timestamp(
405                    row.get::<i64>(6)
406                        .map_err(|e| AppError::Database(e.to_string()))?,
407                    0,
408                )
409                .unwrap(),
410                updated_at: chrono::DateTime::from_timestamp(
411                    row.get::<i64>(7)
412                        .map_err(|e| AppError::Database(e.to_string()))?,
413                    0,
414                )
415                .unwrap(),
416            });
417        }
418
419        Ok(facts)
420    }
421
422    async fn store_preference(&self, user_id: &str, preference: &Preference) -> Result<()> {
423        super::turso::TursoClient::store_preference(self, user_id, preference).await
424    }
425
426    async fn get_user_preferences(&self, user_id: &str) -> Result<Vec<Preference>> {
427        super::turso::TursoClient::get_user_preferences(self, user_id).await
428    }
429
430    async fn get_preference(
431        &self,
432        user_id: &str,
433        category: &str,
434        key: &str,
435    ) -> Result<Option<Preference>> {
436        let conn = self.operation_conn().await?;
437
438        let mut rows = conn
439            .query(
440                "SELECT category, key, value, confidence FROM preferences
441                 WHERE user_id = ? AND category = ? AND key = ?",
442                [user_id, category, key],
443            )
444            .await
445            .map_err(|e| AppError::Database(format!("Failed to query preference: {}", e)))?;
446
447        if let Some(row) = rows
448            .next()
449            .await
450            .map_err(|e| AppError::Database(e.to_string()))?
451        {
452            Ok(Some(Preference {
453                category: row.get(0).map_err(|e| AppError::Database(e.to_string()))?,
454                key: row.get(1).map_err(|e| AppError::Database(e.to_string()))?,
455                value: row.get(2).map_err(|e| AppError::Database(e.to_string()))?,
456                confidence: row
457                    .get::<f64>(3)
458                    .map_err(|e| AppError::Database(e.to_string()))?
459                    as f32,
460            }))
461        } else {
462            Ok(None)
463        }
464    }
465}