1use crate::types::{AppError, MemoryFact, Message, MessageRole, Preference, Result};
2use chrono::{DateTime, Utc};
3use sqlx::{postgres::PgPoolOptions, PgPool};
4use uuid::Uuid;
5
6#[derive(Debug, Clone, sqlx::FromRow)]
7pub struct Conversation {
8 pub id: String,
9 pub user_id: String,
10 pub title: Option<String>,
11 #[sqlx(default)]
12 pub message_count: i32,
13 pub created_at: String,
14 pub updated_at: String,
15}
16
17pub struct PostgresClient {
18 pub pool: PgPool,
19}
20
21impl PostgresClient {
22 pub async fn new_remote(url: String, _auth_token: String) -> Result<Self> {
23 let pool = PgPoolOptions::new()
24 .max_connections(5)
25 .connect(&url)
26 .await
27 .map_err(|e| AppError::Database(format!("Failed to connect to Postgres: {}", e)))?;
28 let client = Self { pool };
29 Ok(client)
30 }
31
32 pub async fn new_local(_path: &str) -> Result<Self> {
33 let url = std::env::var("DATABASE_URL")
34 .unwrap_or_else(|_| "postgres://postgres:postgres@localhost:5432/ares".to_string());
35 Self::new_remote(url, "".to_string()).await
36 }
37
38 pub async fn new_memory() -> Result<Self> {
39 Self::new_local("").await
40 }
41
42 #[doc(hidden)]
45 pub fn new_test() -> Self {
46 let url = "postgres://test:test@localhost:5432/test";
47 let pool = PgPoolOptions::new()
48 .max_connections(1)
49 .connect_lazy(url)
50 .expect("connect_lazy should never fail");
51 Self { pool }
52 }
53
54 pub async fn new(url: String, auth_token: String) -> Result<Self> {
55 Self::new_remote(url, auth_token).await
56 }
57
58 pub async fn operation_conn(&self) -> Result<&PgPool> {
59 Ok(&self.pool)
60 }
61
62 pub async fn create_user(
63 &self,
64 id: &str,
65 email: &str,
66 password_hash: &str,
67 name: &str,
68 ) -> Result<()> {
69 let now = Utc::now().timestamp();
70 sqlx::query("INSERT INTO users (id, email, password_hash, name, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6)")
71 .bind(id).bind(email).bind(password_hash).bind(name).bind(now).bind(now).execute(&self.pool).await
72 .map_err(|e| AppError::Database(format!("Failed to create user: {}", e)))?;
73 Ok(())
74 }
75
76 pub async fn get_user_by_email(&self, email: &str) -> Result<Option<User>> {
77 sqlx::query_as::<_, User>("SELECT id, email, password_hash, name, created_at, updated_at FROM users WHERE email = $1")
78 .bind(email).fetch_optional(&self.pool).await
79 .map_err(|e| AppError::Database(format!("Failed to query user: {}", e)))
80 }
81
82 pub async fn get_user_by_id(&self, id: &str) -> Result<Option<User>> {
83 sqlx::query_as::<_, User>("SELECT id, email, password_hash, name, created_at, updated_at FROM users WHERE id = $1")
84 .bind(id).fetch_optional(&self.pool).await
85 .map_err(|e| AppError::Database(format!("Failed to query user: {}", e)))
86 }
87
88 pub async fn create_session(
89 &self,
90 id: &str,
91 user_id: &str,
92 token_hash: &str,
93 expires_at: i64,
94 ) -> Result<()> {
95 let now = Utc::now().timestamp();
96 sqlx::query("INSERT INTO sessions (id, user_id, token_hash, expires_at, created_at) VALUES ($1, $2, $3, $4, $5)")
97 .bind(id).bind(user_id).bind(token_hash).bind(expires_at).bind(now).execute(&self.pool).await
98 .map_err(|e| AppError::Database(format!("Failed to create session: {}", e)))?;
99 Ok(())
100 }
101
102 pub async fn validate_session(&self, token_hash: &str) -> Result<Option<String>> {
103 let now = Utc::now().timestamp();
104 let row: Option<(String,)> = sqlx::query_as(
105 "SELECT user_id FROM sessions WHERE token_hash = $1 AND expires_at > $2",
106 )
107 .bind(token_hash)
108 .bind(now)
109 .fetch_optional(&self.pool)
110 .await
111 .map_err(|e| AppError::Database(format!("Failed to validate session: {}", e)))?;
112 Ok(row.map(|(id,)| id))
113 }
114
115 pub async fn delete_session(&self, id: &str) -> Result<()> {
116 sqlx::query("DELETE FROM sessions WHERE id = $1")
117 .bind(id)
118 .execute(&self.pool)
119 .await
120 .map_err(|e| AppError::Database(format!("Failed to delete session: {}", e)))?;
121 Ok(())
122 }
123
124 pub async fn delete_session_by_token_hash(&self, token_hash: &str) -> Result<()> {
125 sqlx::query("DELETE FROM sessions WHERE token_hash = $1")
126 .bind(token_hash)
127 .execute(&self.pool)
128 .await
129 .map_err(|e| AppError::Database(format!("Failed to delete session: {}", e)))?;
130 Ok(())
131 }
132
133 pub async fn create_conversation(
134 &self,
135 id: &str,
136 user_id: &str,
137 title: Option<&str>,
138 ) -> Result<()> {
139 let now = Utc::now().timestamp();
140 sqlx::query("INSERT INTO conversations (id, user_id, title, created_at, updated_at) VALUES ($1, $2, $3, $4, $5)")
141 .bind(id).bind(user_id).bind(title).bind(now).bind(now).execute(&self.pool).await
142 .map_err(|e| AppError::Database(format!("Failed to create conversation: {}", e)))?;
143 Ok(())
144 }
145
146 pub async fn conversation_exists(&self, conversation_id: &str) -> Result<bool> {
147 let row: Option<(i32,)> = sqlx::query_as("SELECT 1 FROM conversations WHERE id = $1")
148 .bind(conversation_id)
149 .fetch_optional(&self.pool)
150 .await
151 .map_err(|e| AppError::Database(format!("Failed to check conversation: {}", e)))?;
152 Ok(row.is_some())
153 }
154
155 pub async fn get_user_conversations(
156 &self,
157 user_id: &str,
158 ) -> Result<Vec<crate::db::traits::ConversationSummary>> {
159 let rows = sqlx::query_as::<_, crate::db::traits::ConversationSummary>(
160 "SELECT c.id, COALESCE(c.title, '') as title, c.created_at, c.updated_at, (SELECT COUNT(*) FROM messages WHERE conversation_id = c.id) as message_count FROM conversations c WHERE c.user_id = $1 ORDER BY c.updated_at DESC"
161 )
162 .bind(user_id).fetch_all(&self.pool).await
163 .map_err(|e| AppError::Database(format!("Failed to query conversations: {}", e)))?;
164 Ok(rows)
165 }
166
167 pub async fn add_message(
168 &self,
169 id: &str,
170 conversation_id: &str,
171 role: MessageRole,
172 content: &str,
173 ) -> Result<()> {
174 let now = Utc::now().timestamp();
175 let role_str = match role {
176 MessageRole::System => "system",
177 MessageRole::User => "user",
178 MessageRole::Assistant => "assistant",
179 };
180 sqlx::query("INSERT INTO messages (id, conversation_id, role, content, timestamp) VALUES ($1, $2, $3, $4, $5)")
181 .bind(id).bind(conversation_id).bind(role_str).bind(content).bind(now).execute(&self.pool).await
182 .map_err(|e| AppError::Database(format!("Failed to add message: {}", e)))?;
183 Ok(())
184 }
185
186 pub async fn get_conversation_history(&self, conversation_id: &str) -> Result<Vec<Message>> {
187 #[derive(sqlx::FromRow)]
188 struct MessageRow {
189 role: String,
190 content: String,
191 timestamp: i64,
192 }
193 let rows = sqlx::query_as::<_, MessageRow>("SELECT role, content, timestamp FROM messages WHERE conversation_id = $1 ORDER BY timestamp ASC")
194 .bind(conversation_id).fetch_all(&self.pool).await.map_err(|e| AppError::Database(e.to_string()))?;
195 Ok(rows
196 .into_iter()
197 .map(|row| Message {
198 role: match row.role.as_str() {
199 "system" => MessageRole::System,
200 "assistant" => MessageRole::Assistant,
201 _ => MessageRole::User,
202 },
203 content: row.content,
204 timestamp: DateTime::from_timestamp(row.timestamp, 0).unwrap_or_default(),
205 })
206 .collect())
207 }
208
209 pub async fn store_memory_fact(&self, fact: &MemoryFact) -> Result<()> {
210 sqlx::query("INSERT INTO memory_facts (id, user_id, category, fact_key, fact_value, confidence, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT(id) DO UPDATE SET fact_value = $5")
211 .bind(&fact.id).bind(&fact.user_id).bind(&fact.category).bind(&fact.fact_key).bind(&fact.fact_value).bind(fact.confidence as f64).bind(fact.created_at.timestamp()).bind(fact.updated_at.timestamp()).execute(&self.pool).await
212 .map_err(|e| AppError::Database(e.to_string()))?;
213 Ok(())
214 }
215
216 pub async fn get_user_memory(&self, user_id: &str) -> Result<Vec<MemoryFact>> {
217 #[derive(sqlx::FromRow)]
218 struct MemRow {
219 id: String,
220 user_id: String,
221 category: String,
222 fact_key: String,
223 fact_value: String,
224 confidence: f64,
225 created_at: i64,
226 updated_at: i64,
227 }
228 let rows = sqlx::query_as::<_, MemRow>("SELECT * FROM memory_facts WHERE user_id = $1")
229 .bind(user_id)
230 .fetch_all(&self.pool)
231 .await
232 .map_err(|e| AppError::Database(e.to_string()))?;
233 Ok(rows
234 .into_iter()
235 .map(|row| MemoryFact {
236 id: row.id,
237 user_id: row.user_id,
238 category: row.category,
239 fact_key: row.fact_key,
240 fact_value: row.fact_value,
241 confidence: row.confidence as f32,
242 created_at: DateTime::from_timestamp(row.created_at, 0).unwrap_or_default(),
243 updated_at: DateTime::from_timestamp(row.updated_at, 0).unwrap_or_default(),
244 })
245 .collect())
246 }
247
248 pub async fn store_preference(&self, user_id: &str, preference: &Preference) -> Result<()> {
249 let now = Utc::now().timestamp();
250 let id = Uuid::new_v4().to_string();
251 sqlx::query("INSERT INTO preferences (id, user_id, category, key, value, confidence, created_at) VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT(user_id, category, key) DO UPDATE SET value = $5")
252 .bind(id).bind(user_id).bind(&preference.category).bind(&preference.key).bind(&preference.value).bind(preference.confidence as f64).bind(now).execute(&self.pool).await
253 .map_err(|e| AppError::Database(e.to_string()))?;
254 Ok(())
255 }
256
257 pub async fn get_user_preferences(&self, user_id: &str) -> Result<Vec<Preference>> {
258 #[derive(sqlx::FromRow)]
259 struct PrefRow {
260 category: String,
261 key: String,
262 value: String,
263 confidence: f64,
264 }
265 let rows = sqlx::query_as::<_, PrefRow>(
266 "SELECT category, key, value, confidence FROM preferences WHERE user_id = $1",
267 )
268 .bind(user_id)
269 .fetch_all(&self.pool)
270 .await
271 .map_err(|e| AppError::Database(e.to_string()))?;
272 Ok(rows
273 .into_iter()
274 .map(|r| Preference {
275 category: r.category,
276 key: r.key,
277 value: r.value,
278 confidence: r.confidence as f32,
279 })
280 .collect())
281 }
282
283 pub async fn get_user_agent_by_name(
284 &self,
285 user_id: &str,
286 name: &str,
287 ) -> Result<Option<UserAgent>> {
288 sqlx::query_as::<_, UserAgent>("SELECT * FROM user_agents WHERE user_id = $1 AND name = $2")
289 .bind(user_id)
290 .bind(name)
291 .fetch_optional(&self.pool)
292 .await
293 .map_err(|e| AppError::Database(e.to_string()))
294 }
295}
296
297#[derive(Debug, Clone, sqlx::FromRow)]
298pub struct User {
299 pub id: String,
300 pub email: String,
301 pub password_hash: String,
302 pub name: String,
303 pub created_at: i64,
304 pub updated_at: i64,
305}
306
307#[derive(Debug, Clone, sqlx::FromRow)]
308pub struct UserAgent {
309 pub id: String,
310 pub user_id: String,
311 pub name: String,
312 pub display_name: Option<String>,
313 pub description: Option<String>,
314 pub model: String,
315 pub system_prompt: Option<String>,
316 pub tools: String,
317 pub max_tool_iterations: i32,
318 pub parallel_tools: bool,
319 pub extra: String,
320 pub is_public: bool,
321 pub usage_count: i32,
322 pub rating_sum: i32,
323 pub rating_count: i32,
324 pub created_at: i64,
325 pub updated_at: i64,
326}
327
328impl UserAgent {
329 pub fn new(id: String, user_id: String, name: String, model: String) -> Self {
330 let now = chrono::Utc::now().timestamp();
331 Self {
332 id,
333 user_id,
334 name,
335 display_name: None,
336 description: None,
337 model,
338 system_prompt: None,
339 tools: "[]".to_string(),
340 max_tool_iterations: 10,
341 parallel_tools: false,
342 extra: "{}".to_string(),
343 is_public: false,
344 usage_count: 0,
345 rating_sum: 0,
346 rating_count: 0,
347 created_at: now,
348 updated_at: now,
349 }
350 }
351 pub fn tools_vec(&self) -> Vec<String> {
352 serde_json::from_str(&self.tools).unwrap_or_default()
353 }
354 pub fn set_tools(&mut self, tools: Vec<String>) {
355 self.tools = serde_json::to_string(&tools).unwrap_or_else(|_| "[]".to_string());
356 }
357 pub fn average_rating(&self) -> Option<f32> {
358 if self.rating_count > 0 {
359 Some(self.rating_sum as f32 / self.rating_count as f32)
360 } else {
361 None
362 }
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// `new_test()` builds a lazy pool, so it must return without waiting on
    /// a database connection.
    #[tokio::test]
    async fn test_new_test_does_not_block() {
        let began = std::time::Instant::now();
        let _client = PostgresClient::new_test();
        let took_ms = began.elapsed().as_millis();

        assert!(
            took_ms < 100,
            "new_test() should complete instantly, took {}ms",
            took_ms
        );
    }

    /// Smoke test: constructing the test client inside a Tokio runtime must
    /// not panic.
    #[tokio::test]
    async fn test_new_test_in_tokio_context() {
        let _client = PostgresClient::new_test();
    }
}