Skip to main content

ares/db/
postgres.rs

1use crate::types::{AppError, MemoryFact, Message, MessageRole, Preference, Result};
2use chrono::{DateTime, Utc};
3use sqlx::{postgres::PgPoolOptions, PgPool};
4use uuid::Uuid;
5
6#[derive(Debug, Clone, sqlx::FromRow)]
7pub struct Conversation {
8    pub id: String,
9    pub user_id: String,
10    pub title: Option<String>,
11    #[sqlx(default)]
12    pub message_count: i32,
13    pub created_at: String,
14    pub updated_at: String,
15}
16
17pub struct PostgresClient {
18    pub pool: PgPool,
19}
20
21impl PostgresClient {
22    pub async fn new_remote(url: String, _auth_token: String) -> Result<Self> {
23        let pool = PgPoolOptions::new().max_connections(5).connect(&url).await.map_err(|e| AppError::Database(format!("Failed to connect to Postgres: {}", e)))?;
24        let client = Self { pool };
25        Ok(client)
26    }
27
28    pub async fn new_local(_path: &str) -> Result<Self> {
29        let url = std::env::var("DATABASE_URL").unwrap_or_else(|_| "postgres://postgres:postgres@localhost:5432/ares".to_string());
30        Self::new_remote(url, "".to_string()).await
31    }
32
33    pub async fn new_memory() -> Result<Self> {
34        Self::new_local("").await
35    }
36
37    pub async fn new(url: String, auth_token: String) -> Result<Self> {
38        Self::new_remote(url, auth_token).await
39    }
40
41    pub async fn operation_conn(&self) -> Result<&PgPool> {
42        Ok(&self.pool)
43    }
44    
45    pub async fn create_user(&self, id: &str, email: &str, password_hash: &str, name: &str) -> Result<()> {
46        let now = Utc::now().timestamp();
47        sqlx::query("INSERT INTO users (id, email, password_hash, name, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6)")
48            .bind(id).bind(email).bind(password_hash).bind(name).bind(now).bind(now).execute(&self.pool).await
49            .map_err(|e| AppError::Database(format!("Failed to create user: {}", e)))?;
50        Ok(())
51    }
52
53    pub async fn get_user_by_email(&self, email: &str) -> Result<Option<User>> {
54        sqlx::query_as::<_, User>("SELECT id, email, password_hash, name, created_at, updated_at FROM users WHERE email = $1")
55            .bind(email).fetch_optional(&self.pool).await
56            .map_err(|e| AppError::Database(format!("Failed to query user: {}", e)))
57    }
58    
59    pub async fn get_user_by_id(&self, id: &str) -> Result<Option<User>> {
60        sqlx::query_as::<_, User>("SELECT id, email, password_hash, name, created_at, updated_at FROM users WHERE id = $1")
61            .bind(id).fetch_optional(&self.pool).await
62            .map_err(|e| AppError::Database(format!("Failed to query user: {}", e)))
63    }
64
65    pub async fn create_session(&self, id: &str, user_id: &str, token_hash: &str, expires_at: i64) -> Result<()> {
66        let now = Utc::now().timestamp();
67        sqlx::query("INSERT INTO sessions (id, user_id, token_hash, expires_at, created_at) VALUES ($1, $2, $3, $4, $5)")
68            .bind(id).bind(user_id).bind(token_hash).bind(expires_at).bind(now).execute(&self.pool).await
69            .map_err(|e| AppError::Database(format!("Failed to create session: {}", e)))?;
70        Ok(())
71    }
72
73    pub async fn validate_session(&self, token_hash: &str) -> Result<Option<String>> {
74        let now = Utc::now().timestamp();
75        let row: Option<(String,)> = sqlx::query_as("SELECT user_id FROM sessions WHERE token_hash = $1 AND expires_at > $2")
76            .bind(token_hash).bind(now).fetch_optional(&self.pool).await
77            .map_err(|e| AppError::Database(format!("Failed to validate session: {}", e)))?;
78        Ok(row.map(|(id,)| id))
79    }
80
81    pub async fn delete_session(&self, id: &str) -> Result<()> {
82        sqlx::query("DELETE FROM sessions WHERE id = $1").bind(id).execute(&self.pool).await
83            .map_err(|e| AppError::Database(format!("Failed to delete session: {}", e)))?;
84        Ok(())
85    }
86
87    pub async fn delete_session_by_token_hash(&self, token_hash: &str) -> Result<()> {
88        sqlx::query("DELETE FROM sessions WHERE token_hash = $1").bind(token_hash).execute(&self.pool).await
89            .map_err(|e| AppError::Database(format!("Failed to delete session: {}", e)))?;
90        Ok(())
91    }
92
93    pub async fn create_conversation(&self, id: &str, user_id: &str, title: Option<&str>) -> Result<()> {
94        let now = Utc::now().timestamp();
95        sqlx::query("INSERT INTO conversations (id, user_id, title, created_at, updated_at) VALUES ($1, $2, $3, $4, $5)")
96            .bind(id).bind(user_id).bind(title).bind(now).bind(now).execute(&self.pool).await
97            .map_err(|e| AppError::Database(format!("Failed to create conversation: {}", e)))?;
98        Ok(())
99    }
100
101    pub async fn conversation_exists(&self, conversation_id: &str) -> Result<bool> {
102        let row: Option<(i32,)> = sqlx::query_as("SELECT 1 FROM conversations WHERE id = $1")
103            .bind(conversation_id).fetch_optional(&self.pool).await
104            .map_err(|e| AppError::Database(format!("Failed to check conversation: {}", e)))?;
105        Ok(row.is_some())
106    }
107
108    pub async fn get_user_conversations(&self, user_id: &str) -> Result<Vec<crate::db::traits::ConversationSummary>> {
109        let rows = sqlx::query_as::<_, crate::db::traits::ConversationSummary>(
110            "SELECT c.id, COALESCE(c.title, '') as title, c.created_at, c.updated_at, (SELECT COUNT(*) FROM messages WHERE conversation_id = c.id) as message_count FROM conversations c WHERE c.user_id = $1 ORDER BY c.updated_at DESC"
111        )
112        .bind(user_id).fetch_all(&self.pool).await
113        .map_err(|e| AppError::Database(format!("Failed to query conversations: {}", e)))?;
114        Ok(rows)
115    }
116
117    pub async fn add_message(&self, id: &str, conversation_id: &str, role: MessageRole, content: &str) -> Result<()> {
118        let now = Utc::now().timestamp();
119        let role_str = match role { MessageRole::System => "system", MessageRole::User => "user", MessageRole::Assistant => "assistant" };
120        sqlx::query("INSERT INTO messages (id, conversation_id, role, content, timestamp) VALUES ($1, $2, $3, $4, $5)")
121            .bind(id).bind(conversation_id).bind(role_str).bind(content).bind(now).execute(&self.pool).await
122            .map_err(|e| AppError::Database(format!("Failed to add message: {}", e)))?;
123        Ok(())
124    }
125
126    pub async fn get_conversation_history(&self, conversation_id: &str) -> Result<Vec<Message>> {
127        #[derive(sqlx::FromRow)] struct MessageRow { role: String, content: String, timestamp: i64 }
128        let rows = sqlx::query_as::<_, MessageRow>("SELECT role, content, timestamp FROM messages WHERE conversation_id = $1 ORDER BY timestamp ASC")
129            .bind(conversation_id).fetch_all(&self.pool).await.map_err(|e| AppError::Database(e.to_string()))?;
130        Ok(rows.into_iter().map(|row| Message {
131            role: match row.role.as_str() { "system" => MessageRole::System, "assistant" => MessageRole::Assistant, _ => MessageRole::User },
132            content: row.content,
133            timestamp: DateTime::from_timestamp(row.timestamp, 0).unwrap_or_default(),
134        }).collect())
135    }
136
137    pub async fn store_memory_fact(&self, fact: &MemoryFact) -> Result<()> {
138        sqlx::query("INSERT INTO memory_facts (id, user_id, category, fact_key, fact_value, confidence, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT(id) DO UPDATE SET fact_value = $5")
139            .bind(&fact.id).bind(&fact.user_id).bind(&fact.category).bind(&fact.fact_key).bind(&fact.fact_value).bind(fact.confidence as f64).bind(fact.created_at.timestamp()).bind(fact.updated_at.timestamp()).execute(&self.pool).await
140            .map_err(|e| AppError::Database(e.to_string()))?;
141        Ok(())
142    }
143
144    pub async fn get_user_memory(&self, user_id: &str) -> Result<Vec<MemoryFact>> {
145        #[derive(sqlx::FromRow)] struct MemRow { id: String, user_id: String, category: String, fact_key: String, fact_value: String, confidence: f64, created_at: i64, updated_at: i64 }
146        let rows = sqlx::query_as::<_, MemRow>("SELECT * FROM memory_facts WHERE user_id = $1").bind(user_id).fetch_all(&self.pool).await.map_err(|e| AppError::Database(e.to_string()))?;
147        Ok(rows.into_iter().map(|row| MemoryFact {
148            id: row.id, user_id: row.user_id, category: row.category, fact_key: row.fact_key, fact_value: row.fact_value, confidence: row.confidence as f32, created_at: DateTime::from_timestamp(row.created_at, 0).unwrap_or_default(), updated_at: DateTime::from_timestamp(row.updated_at, 0).unwrap_or_default(),
149        }).collect())
150    }
151
152    pub async fn store_preference(&self, user_id: &str, preference: &Preference) -> Result<()> {
153        let now = Utc::now().timestamp();
154        let id = Uuid::new_v4().to_string();
155        sqlx::query("INSERT INTO preferences (id, user_id, category, key, value, confidence, created_at) VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT(user_id, category, key) DO UPDATE SET value = $5")
156            .bind(id).bind(user_id).bind(&preference.category).bind(&preference.key).bind(&preference.value).bind(preference.confidence as f64).bind(now).execute(&self.pool).await
157            .map_err(|e| AppError::Database(e.to_string()))?;
158        Ok(())
159    }
160
161    pub async fn get_user_preferences(&self, user_id: &str) -> Result<Vec<Preference>> {
162        #[derive(sqlx::FromRow)] struct PrefRow { category: String, key: String, value: String, confidence: f64 }
163        let rows = sqlx::query_as::<_, PrefRow>("SELECT category, key, value, confidence FROM preferences WHERE user_id = $1").bind(user_id).fetch_all(&self.pool).await.map_err(|e| AppError::Database(e.to_string()))?;
164        Ok(rows.into_iter().map(|r| Preference { category: r.category, key: r.key, value: r.value, confidence: r.confidence as f32 }).collect())
165    }
166
167    pub async fn get_user_agent_by_name(&self, user_id: &str, name: &str) -> Result<Option<UserAgent>> {
168        sqlx::query_as::<_, UserAgent>("SELECT * FROM user_agents WHERE user_id = $1 AND name = $2").bind(user_id).bind(name).fetch_optional(&self.pool).await.map_err(|e| AppError::Database(e.to_string()))
169    }
170}
171
172#[derive(Debug, Clone, sqlx::FromRow)]
173pub struct User {
174    pub id: String,
175    pub email: String,
176    pub password_hash: String,
177    pub name: String,
178    pub created_at: i64,
179    pub updated_at: i64,
180}
181
182#[derive(Debug, Clone, sqlx::FromRow)]
183pub struct UserAgent {
184    pub id: String,
185    pub user_id: String,
186    pub name: String,
187    pub display_name: Option<String>,
188    pub description: Option<String>,
189    pub model: String,
190    pub system_prompt: Option<String>,
191    pub tools: String,
192    pub max_tool_iterations: i32,
193    pub parallel_tools: bool,
194    pub extra: String,
195    pub is_public: bool,
196    pub usage_count: i32,
197    pub rating_sum: i32,
198    pub rating_count: i32,
199    pub created_at: i64,
200    pub updated_at: i64,
201}
202
203impl UserAgent {
204    pub fn tools_vec(&self) -> Vec<String> {
205        serde_json::from_str(&self.tools).unwrap_or_default()
206    }
207    pub fn average_rating(&self) -> Option<f32> {
208        if self.rating_count > 0 { Some(self.rating_sum as f32 / self.rating_count as f32) } else { None }
209    }
210}