Skip to main content

ares/db/
postgres.rs

1use crate::types::{AppError, MemoryFact, Message, MessageRole, Preference, Result};
2use chrono::{DateTime, Utc};
3use sqlx::{postgres::PgPoolOptions, PgPool};
4use uuid::Uuid;
5
6#[derive(Debug, Clone, sqlx::FromRow)]
7pub struct Conversation {
8    pub id: String,
9    pub user_id: String,
10    pub title: Option<String>,
11    #[sqlx(default)]
12    pub message_count: i32,
13    pub created_at: String,
14    pub updated_at: String,
15}
16
17pub struct PostgresClient {
18    pub pool: PgPool,
19}
20
21impl PostgresClient {
22    pub async fn new_remote(url: String, _auth_token: String) -> Result<Self> {
23        let pool = PgPoolOptions::new()
24            .max_connections(5)
25            .connect(&url)
26            .await
27            .map_err(|e| AppError::Database(format!("Failed to connect to Postgres: {}", e)))?;
28        let client = Self { pool };
29        Ok(client)
30    }
31
32    pub async fn new_local(_path: &str) -> Result<Self> {
33        let url = std::env::var("DATABASE_URL")
34            .unwrap_or_else(|_| "postgres://postgres:postgres@localhost:5432/ares".to_string());
35        Self::new_remote(url, "".to_string()).await
36    }
37
38    pub async fn new_memory() -> Result<Self> {
39        Self::new_local("").await
40    }
41
42    /// Create a test-only client with a lazy pool that doesn't actually connect.
43    /// Use this in unit tests that construct AppState but never execute queries.
44    #[doc(hidden)]
45    pub fn new_test() -> Self {
46        let url = "postgres://test:test@localhost:5432/test";
47        let pool = PgPoolOptions::new()
48            .max_connections(1)
49            .connect_lazy(url)
50            .expect("connect_lazy should never fail");
51        Self { pool }
52    }
53
54    pub async fn new(url: String, auth_token: String) -> Result<Self> {
55        Self::new_remote(url, auth_token).await
56    }
57
58    pub async fn operation_conn(&self) -> Result<&PgPool> {
59        Ok(&self.pool)
60    }
61
62    pub async fn create_user(
63        &self,
64        id: &str,
65        email: &str,
66        password_hash: &str,
67        name: &str,
68    ) -> Result<()> {
69        let now = Utc::now().timestamp();
70        sqlx::query("INSERT INTO users (id, email, password_hash, name, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6)")
71            .bind(id).bind(email).bind(password_hash).bind(name).bind(now).bind(now).execute(&self.pool).await
72            .map_err(|e| AppError::Database(format!("Failed to create user: {}", e)))?;
73        Ok(())
74    }
75
76    pub async fn get_user_by_email(&self, email: &str) -> Result<Option<User>> {
77        sqlx::query_as::<_, User>("SELECT id, email, password_hash, name, created_at, updated_at FROM users WHERE email = $1")
78            .bind(email).fetch_optional(&self.pool).await
79            .map_err(|e| AppError::Database(format!("Failed to query user: {}", e)))
80    }
81
82    pub async fn get_user_by_id(&self, id: &str) -> Result<Option<User>> {
83        sqlx::query_as::<_, User>("SELECT id, email, password_hash, name, created_at, updated_at FROM users WHERE id = $1")
84            .bind(id).fetch_optional(&self.pool).await
85            .map_err(|e| AppError::Database(format!("Failed to query user: {}", e)))
86    }
87
88    pub async fn create_session(
89        &self,
90        id: &str,
91        user_id: &str,
92        token_hash: &str,
93        expires_at: i64,
94    ) -> Result<()> {
95        let now = Utc::now().timestamp();
96        sqlx::query("INSERT INTO sessions (id, user_id, token_hash, expires_at, created_at) VALUES ($1, $2, $3, $4, $5)")
97            .bind(id).bind(user_id).bind(token_hash).bind(expires_at).bind(now).execute(&self.pool).await
98            .map_err(|e| AppError::Database(format!("Failed to create session: {}", e)))?;
99        Ok(())
100    }
101
102    pub async fn validate_session(&self, token_hash: &str) -> Result<Option<String>> {
103        let now = Utc::now().timestamp();
104        let row: Option<(String,)> = sqlx::query_as(
105            "SELECT user_id FROM sessions WHERE token_hash = $1 AND expires_at > $2",
106        )
107        .bind(token_hash)
108        .bind(now)
109        .fetch_optional(&self.pool)
110        .await
111        .map_err(|e| AppError::Database(format!("Failed to validate session: {}", e)))?;
112        Ok(row.map(|(id,)| id))
113    }
114
115    pub async fn delete_session(&self, id: &str) -> Result<()> {
116        sqlx::query("DELETE FROM sessions WHERE id = $1")
117            .bind(id)
118            .execute(&self.pool)
119            .await
120            .map_err(|e| AppError::Database(format!("Failed to delete session: {}", e)))?;
121        Ok(())
122    }
123
124    pub async fn delete_session_by_token_hash(&self, token_hash: &str) -> Result<()> {
125        sqlx::query("DELETE FROM sessions WHERE token_hash = $1")
126            .bind(token_hash)
127            .execute(&self.pool)
128            .await
129            .map_err(|e| AppError::Database(format!("Failed to delete session: {}", e)))?;
130        Ok(())
131    }
132
133    pub async fn create_conversation(
134        &self,
135        id: &str,
136        user_id: &str,
137        title: Option<&str>,
138    ) -> Result<()> {
139        let now = Utc::now().timestamp();
140        sqlx::query("INSERT INTO conversations (id, user_id, title, created_at, updated_at) VALUES ($1, $2, $3, $4, $5)")
141            .bind(id).bind(user_id).bind(title).bind(now).bind(now).execute(&self.pool).await
142            .map_err(|e| AppError::Database(format!("Failed to create conversation: {}", e)))?;
143        Ok(())
144    }
145
146    pub async fn conversation_exists(&self, conversation_id: &str) -> Result<bool> {
147        let row: Option<(i32,)> = sqlx::query_as("SELECT 1 FROM conversations WHERE id = $1")
148            .bind(conversation_id)
149            .fetch_optional(&self.pool)
150            .await
151            .map_err(|e| AppError::Database(format!("Failed to check conversation: {}", e)))?;
152        Ok(row.is_some())
153    }
154
155    pub async fn get_user_conversations(
156        &self,
157        user_id: &str,
158    ) -> Result<Vec<crate::db::traits::ConversationSummary>> {
159        let rows = sqlx::query_as::<_, crate::db::traits::ConversationSummary>(
160            "SELECT c.id, COALESCE(c.title, '') as title, c.created_at, c.updated_at, (SELECT COUNT(*) FROM messages WHERE conversation_id = c.id) as message_count FROM conversations c WHERE c.user_id = $1 ORDER BY c.updated_at DESC"
161        )
162        .bind(user_id).fetch_all(&self.pool).await
163        .map_err(|e| AppError::Database(format!("Failed to query conversations: {}", e)))?;
164        Ok(rows)
165    }
166
167    pub async fn add_message(
168        &self,
169        id: &str,
170        conversation_id: &str,
171        role: MessageRole,
172        content: &str,
173    ) -> Result<()> {
174        let now = Utc::now().timestamp();
175        let role_str = match role {
176            MessageRole::System => "system",
177            MessageRole::User => "user",
178            MessageRole::Assistant => "assistant",
179        };
180        sqlx::query("INSERT INTO messages (id, conversation_id, role, content, timestamp) VALUES ($1, $2, $3, $4, $5)")
181            .bind(id).bind(conversation_id).bind(role_str).bind(content).bind(now).execute(&self.pool).await
182            .map_err(|e| AppError::Database(format!("Failed to add message: {}", e)))?;
183        Ok(())
184    }
185
186    pub async fn get_conversation_history(&self, conversation_id: &str) -> Result<Vec<Message>> {
187        #[derive(sqlx::FromRow)]
188        struct MessageRow {
189            role: String,
190            content: String,
191            timestamp: i64,
192        }
193        let rows = sqlx::query_as::<_, MessageRow>("SELECT role, content, timestamp FROM messages WHERE conversation_id = $1 ORDER BY timestamp ASC")
194            .bind(conversation_id).fetch_all(&self.pool).await.map_err(|e| AppError::Database(e.to_string()))?;
195        Ok(rows
196            .into_iter()
197            .map(|row| Message {
198                role: match row.role.as_str() {
199                    "system" => MessageRole::System,
200                    "assistant" => MessageRole::Assistant,
201                    _ => MessageRole::User,
202                },
203                content: row.content,
204                timestamp: DateTime::from_timestamp(row.timestamp, 0).unwrap_or_default(),
205            })
206            .collect())
207    }
208
209    pub async fn store_memory_fact(&self, fact: &MemoryFact) -> Result<()> {
210        sqlx::query("INSERT INTO memory_facts (id, user_id, category, fact_key, fact_value, confidence, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT(id) DO UPDATE SET fact_value = $5")
211            .bind(&fact.id).bind(&fact.user_id).bind(&fact.category).bind(&fact.fact_key).bind(&fact.fact_value).bind(fact.confidence as f64).bind(fact.created_at.timestamp()).bind(fact.updated_at.timestamp()).execute(&self.pool).await
212            .map_err(|e| AppError::Database(e.to_string()))?;
213        Ok(())
214    }
215
216    pub async fn get_user_memory(&self, user_id: &str) -> Result<Vec<MemoryFact>> {
217        #[derive(sqlx::FromRow)]
218        struct MemRow {
219            id: String,
220            user_id: String,
221            category: String,
222            fact_key: String,
223            fact_value: String,
224            confidence: f64,
225            created_at: i64,
226            updated_at: i64,
227        }
228        let rows = sqlx::query_as::<_, MemRow>("SELECT * FROM memory_facts WHERE user_id = $1")
229            .bind(user_id)
230            .fetch_all(&self.pool)
231            .await
232            .map_err(|e| AppError::Database(e.to_string()))?;
233        Ok(rows
234            .into_iter()
235            .map(|row| MemoryFact {
236                id: row.id,
237                user_id: row.user_id,
238                category: row.category,
239                fact_key: row.fact_key,
240                fact_value: row.fact_value,
241                confidence: row.confidence as f32,
242                created_at: DateTime::from_timestamp(row.created_at, 0).unwrap_or_default(),
243                updated_at: DateTime::from_timestamp(row.updated_at, 0).unwrap_or_default(),
244            })
245            .collect())
246    }
247
248    pub async fn store_preference(&self, user_id: &str, preference: &Preference) -> Result<()> {
249        let now = Utc::now().timestamp();
250        let id = Uuid::new_v4().to_string();
251        sqlx::query("INSERT INTO preferences (id, user_id, category, key, value, confidence, created_at) VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT(user_id, category, key) DO UPDATE SET value = $5")
252            .bind(id).bind(user_id).bind(&preference.category).bind(&preference.key).bind(&preference.value).bind(preference.confidence as f64).bind(now).execute(&self.pool).await
253            .map_err(|e| AppError::Database(e.to_string()))?;
254        Ok(())
255    }
256
257    pub async fn get_user_preferences(&self, user_id: &str) -> Result<Vec<Preference>> {
258        #[derive(sqlx::FromRow)]
259        struct PrefRow {
260            category: String,
261            key: String,
262            value: String,
263            confidence: f64,
264        }
265        let rows = sqlx::query_as::<_, PrefRow>(
266            "SELECT category, key, value, confidence FROM preferences WHERE user_id = $1",
267        )
268        .bind(user_id)
269        .fetch_all(&self.pool)
270        .await
271        .map_err(|e| AppError::Database(e.to_string()))?;
272        Ok(rows
273            .into_iter()
274            .map(|r| Preference {
275                category: r.category,
276                key: r.key,
277                value: r.value,
278                confidence: r.confidence as f32,
279            })
280            .collect())
281    }
282
283    pub async fn get_user_agent_by_name(
284        &self,
285        user_id: &str,
286        name: &str,
287    ) -> Result<Option<UserAgent>> {
288        sqlx::query_as::<_, UserAgent>("SELECT * FROM user_agents WHERE user_id = $1 AND name = $2")
289            .bind(user_id)
290            .bind(name)
291            .fetch_optional(&self.pool)
292            .await
293            .map_err(|e| AppError::Database(e.to_string()))
294    }
295}
296
297#[derive(Debug, Clone, sqlx::FromRow)]
298pub struct User {
299    pub id: String,
300    pub email: String,
301    pub password_hash: String,
302    pub name: String,
303    pub created_at: i64,
304    pub updated_at: i64,
305}
306
307#[derive(Debug, Clone, sqlx::FromRow)]
308pub struct UserAgent {
309    pub id: String,
310    pub user_id: String,
311    pub name: String,
312    pub display_name: Option<String>,
313    pub description: Option<String>,
314    pub model: String,
315    pub system_prompt: Option<String>,
316    pub tools: String,
317    pub max_tool_iterations: i32,
318    pub parallel_tools: bool,
319    pub extra: String,
320    pub is_public: bool,
321    pub usage_count: i32,
322    pub rating_sum: i32,
323    pub rating_count: i32,
324    pub created_at: i64,
325    pub updated_at: i64,
326}
327
328impl UserAgent {
329    pub fn new(id: String, user_id: String, name: String, model: String) -> Self {
330        let now = chrono::Utc::now().timestamp();
331        Self {
332            id,
333            user_id,
334            name,
335            display_name: None,
336            description: None,
337            model,
338            system_prompt: None,
339            tools: "[]".to_string(),
340            max_tool_iterations: 10,
341            parallel_tools: false,
342            extra: "{}".to_string(),
343            is_public: false,
344            usage_count: 0,
345            rating_sum: 0,
346            rating_count: 0,
347            created_at: now,
348            updated_at: now,
349        }
350    }
351    pub fn tools_vec(&self) -> Vec<String> {
352        serde_json::from_str(&self.tools).unwrap_or_default()
353    }
354    pub fn set_tools(&mut self, tools: Vec<String>) {
355        self.tools = serde_json::to_string(&tools).unwrap_or_else(|_| "[]".to_string());
356    }
357    pub fn average_rating(&self) -> Option<f32> {
358        if self.rating_count > 0 {
359            Some(self.rating_sum as f32 / self.rating_count as f32)
360        } else {
361            None
362        }
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Regression test: building a client with `new_test()` must return immediately.
    ///
    /// Tests used to go through `new_memory()`, which performs a real `connect()` and stalled
    /// for 30s+ when no database was reachable. `connect_lazy()` opens no TCP connection up
    /// front, so construction should be effectively instant.
    #[tokio::test]
    async fn test_new_test_does_not_block() {
        let started = Instant::now();
        let _client = PostgresClient::new_test();
        let took_ms = started.elapsed().as_millis();
        assert!(
            took_ms < 100,
            "new_test() should complete instantly, took {}ms",
            took_ms
        );
    }

    /// Regression test: `new_test()` must be callable from inside `#[tokio::test]`.
    ///
    /// Combining `futures::executor::block_on` with `#[tokio::test]` once caused a nested
    /// runtime deadlock. Simply reaching the end of this function proves that is gone.
    #[tokio::test]
    async fn test_new_test_in_tokio_context() {
        let _client = PostgresClient::new_test();
    }
}