1use crate::errors::{AuthError, Result};
5use crate::storage::{AuthStorage, SessionData};
6use crate::tokens::AuthToken;
7use async_trait::async_trait;
8use sqlx::PgPool;
9use sqlx::Row;
10pub struct PostgresStorage {
14 pool: PgPool,
15}
16
17impl PostgresStorage {
18 pub fn new(pool: PgPool) -> Self {
26 Self { pool }
27 }
28
29 pub async fn migrate(&self) -> Result<()> {
38 sqlx::query(
43 r#"
44 CREATE TABLE IF NOT EXISTS auth_tokens (
45 token_id VARCHAR(255) PRIMARY KEY,
46 user_id VARCHAR(255) NOT NULL,
47 access_token TEXT NOT NULL UNIQUE,
48 refresh_token TEXT,
49 token_type VARCHAR(50),
50 expires_at TIMESTAMPTZ NOT NULL,
51 scopes TEXT[],
52 issued_at TIMESTAMPTZ NOT NULL,
53 auth_method VARCHAR(100) NOT NULL,
54 subject VARCHAR(255),
55 issuer VARCHAR(255),
56 client_id VARCHAR(255),
57 metadata JSONB,
58 created_at TIMESTAMPTZ DEFAULT NOW()
59 )
60 "#,
61 )
62 .execute(&self.pool)
63 .await
64 .map_err(|e| {
65 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
66 "Migration failed (auth_tokens): {e}"
67 )))
68 })?;
69
70 for stmt in [
72 "CREATE INDEX IF NOT EXISTS idx_auth_tokens_user_id ON auth_tokens (user_id)",
73 "CREATE INDEX IF NOT EXISTS idx_auth_tokens_access_token ON auth_tokens (access_token)",
74 "CREATE INDEX IF NOT EXISTS idx_auth_tokens_expires_at ON auth_tokens (expires_at)",
75 ] {
76 sqlx::query(stmt).execute(&self.pool).await.map_err(|e| {
77 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
78 "Migration failed (index): {e}"
79 )))
80 })?;
81 }
82
83 sqlx::query(
85 r#"
86 CREATE TABLE IF NOT EXISTS sessions (
87 session_id VARCHAR(255) PRIMARY KEY,
88 user_id VARCHAR(255) NOT NULL,
89 data JSONB NOT NULL,
90 expires_at TIMESTAMPTZ,
91 created_at TIMESTAMPTZ DEFAULT NOW()
92 )
93 "#,
94 )
95 .execute(&self.pool)
96 .await
97 .map_err(|e| {
98 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
99 "Migration failed (sessions): {e}"
100 )))
101 })?;
102
103 for stmt in [
104 "CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions (user_id)",
105 "CREATE INDEX IF NOT EXISTS idx_sessions_expires_at ON sessions (expires_at)",
106 ] {
107 sqlx::query(stmt).execute(&self.pool).await.map_err(|e| {
108 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
109 "Migration failed (index): {e}"
110 )))
111 })?;
112 }
113
114 sqlx::query(
116 r#"
117 CREATE TABLE IF NOT EXISTS kv_store (
118 key VARCHAR(512) PRIMARY KEY,
119 value BYTEA NOT NULL,
120 expires_at TIMESTAMPTZ,
121 created_at TIMESTAMPTZ DEFAULT NOW()
122 )
123 "#,
124 )
125 .execute(&self.pool)
126 .await
127 .map_err(|e| {
128 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
129 "Migration failed (kv_store): {e}"
130 )))
131 })?;
132
133 sqlx::query("CREATE INDEX IF NOT EXISTS idx_kv_store_expires_at ON kv_store (expires_at)")
134 .execute(&self.pool)
135 .await
136 .map_err(|e| {
137 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
138 "Migration failed (index idx_kv_store_expires_at): {e}"
139 )))
140 })?;
141
142 Ok(())
143 }
144}
145
146#[async_trait]
147impl AuthStorage for PostgresStorage {
148 async fn store_token(&self, token: &AuthToken) -> Result<()> {
149 sqlx::query(
150 r#"
151 INSERT INTO auth_tokens (
152 token_id, user_id, access_token, refresh_token, token_type,
153 expires_at, scopes, issued_at, auth_method, subject, issuer,
154 client_id, metadata
155 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
156 ON CONFLICT (token_id) DO UPDATE SET
157 access_token = EXCLUDED.access_token,
158 refresh_token = EXCLUDED.refresh_token,
159 expires_at = EXCLUDED.expires_at
160 "#,
161 )
162 .bind(&token.token_id)
163 .bind(&token.user_id)
164 .bind(&token.access_token)
165 .bind(&token.refresh_token)
166 .bind(&token.token_type)
167 .bind(token.expires_at)
168 .bind(&token.scopes)
169 .bind(token.issued_at)
170 .bind(&token.auth_method)
171 .bind(&token.subject)
172 .bind(&token.issuer)
173 .bind(&token.client_id)
174 .bind(serde_json::to_value(&token.metadata).unwrap_or_default())
175 .execute(&self.pool)
176 .await
177 .map_err(|e| {
178 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
179 "Failed to store token: {}",
180 e
181 )))
182 })?;
183
184 Ok(())
185 }
186
187 async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>> {
189 let row =
190 sqlx::query_as::<_, AuthToken>(r#"SELECT * FROM auth_tokens WHERE token_id = $1"#)
191 .bind(token_id)
192 .fetch_optional(&self.pool)
193 .await
194 .map_err(|e| {
195 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
196 "Failed to fetch token: {}",
197 e
198 )))
199 })?;
200 Ok(row)
201 }
202
203 async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>> {
204 let row =
205 sqlx::query_as::<_, AuthToken>(r#"SELECT * FROM auth_tokens WHERE access_token = $1"#)
206 .bind(access_token)
207 .fetch_optional(&self.pool)
208 .await
209 .map_err(|e| {
210 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
211 "Failed to fetch token by access_token: {}",
212 e
213 )))
214 })?;
215 Ok(row)
216 }
217
218 async fn update_token(&self, token: &AuthToken) -> Result<()> {
219 sqlx::query(
220 r#"
221 UPDATE auth_tokens SET
222 access_token = $1,
223 refresh_token = $2,
224 token_type = $3,
225 expires_at = $4,
226 scopes = $5,
227 issued_at = $6,
228 auth_method = $7,
229 subject = $8,
230 issuer = $9,
231 client_id = $10,
232 metadata = $11
233 WHERE token_id = $12
234 "#,
235 )
236 .bind(&token.access_token)
237 .bind(&token.refresh_token)
238 .bind(&token.token_type)
239 .bind(token.expires_at)
240 .bind(&token.scopes)
241 .bind(token.issued_at)
242 .bind(&token.auth_method)
243 .bind(&token.subject)
244 .bind(&token.issuer)
245 .bind(&token.client_id)
246 .bind(serde_json::to_value(&token.metadata).unwrap_or_default())
247 .bind(&token.token_id)
248 .execute(&self.pool)
249 .await
250 .map_err(|e| {
251 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
252 "Failed to update token: {}",
253 e
254 )))
255 })?;
256 Ok(())
257 }
258
259 async fn delete_token(&self, token_id: &str) -> Result<()> {
260 sqlx::query(r#"DELETE FROM auth_tokens WHERE token_id = $1"#)
261 .bind(token_id)
262 .execute(&self.pool)
263 .await
264 .map_err(|e| {
265 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
266 "Failed to delete token: {}",
267 e
268 )))
269 })?;
270 Ok(())
271 }
272
273 async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>> {
274 let tokens =
275 sqlx::query_as::<_, AuthToken>(r#"SELECT * FROM auth_tokens WHERE user_id = $1"#)
276 .bind(user_id)
277 .fetch_all(&self.pool)
278 .await
279 .map_err(|e| {
280 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
281 "Failed to list user tokens: {}",
282 e
283 )))
284 })?;
285 Ok(tokens)
286 }
287
288 async fn store_session(
289 &self,
290 session_id: &str,
291 data: &crate::storage::core::SessionData,
292 ) -> Result<()> {
293 sqlx::query(
294 r#"
295 INSERT INTO sessions (session_id, user_id, data, expires_at)
296 VALUES ($1, $2, $3, $4)
297 ON CONFLICT (session_id) DO UPDATE SET
298 data = EXCLUDED.data,
299 expires_at = EXCLUDED.expires_at
300 "#,
301 )
302 .bind(session_id)
303 .bind(&data.user_id)
304 .bind(serde_json::to_value(data).unwrap_or_default())
305 .bind(data.expires_at)
306 .execute(&self.pool)
307 .await
308 .map_err(|e| {
309 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
310 "Failed to store session: {}",
311 e
312 )))
313 })?;
314 Ok(())
315 }
316
317 async fn get_session(
318 &self,
319 session_id: &str,
320 ) -> Result<Option<crate::storage::core::SessionData>> {
321 let row = sqlx::query(r#"SELECT data FROM sessions WHERE session_id = $1"#)
322 .bind(session_id)
323 .fetch_optional(&self.pool)
324 .await
325 .map_err(|e| {
326 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
327 "Failed to fetch session: {}",
328 e
329 )))
330 })?;
331 if let Some(row) = row {
332 let data: serde_json::Value = row.try_get("data").map_err(|e| {
333 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
334 "Failed to deserialize session data: {}",
335 e
336 )))
337 })?;
338 let session: crate::storage::core::SessionData =
339 serde_json::from_value(data).map_err(|e| {
340 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
341 "Failed to parse session data: {}",
342 e
343 )))
344 })?;
345 Ok(Some(session))
346 } else {
347 Ok(None)
348 }
349 }
350
351 async fn delete_session(&self, session_id: &str) -> Result<()> {
352 sqlx::query(r#"DELETE FROM sessions WHERE session_id = $1"#)
353 .bind(session_id)
354 .execute(&self.pool)
355 .await
356 .map_err(|e| {
357 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
358 "Failed to delete session: {}",
359 e
360 )))
361 })?;
362 Ok(())
363 }
364
365 async fn store_kv(
366 &self,
367 key: &str,
368 value: &[u8],
369 ttl: Option<std::time::Duration>,
370 ) -> Result<()> {
371 let expires_at = ttl.map(|d| {
372 chrono::Utc::now()
373 + chrono::Duration::from_std(d).unwrap_or(chrono::Duration::seconds(0))
374 });
375 sqlx::query(
376 r#"
377 INSERT INTO kv_store (key, value, expires_at)
378 VALUES ($1, $2, $3)
379 ON CONFLICT (key) DO UPDATE SET
380 value = EXCLUDED.value,
381 expires_at = EXCLUDED.expires_at
382 "#,
383 )
384 .bind(key)
385 .bind(value)
386 .bind(expires_at)
387 .execute(&self.pool)
388 .await
389 .map_err(|e| {
390 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
391 "Failed to store kv: {}",
392 e
393 )))
394 })?;
395 Ok(())
396 }
397
398 async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>> {
399 let row = sqlx::query(
400 r#"SELECT value FROM kv_store WHERE key = $1 AND (expires_at IS NULL OR expires_at > NOW())"#
401 )
402 .bind(key)
403 .fetch_optional(&self.pool)
404 .await
405 .map_err(|e| {
406 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
407 "Failed to fetch kv: {}", e
408 )))
409 })?;
410 if let Some(row) = row {
411 let value: Vec<u8> = row.try_get("value").map_err(|e| {
412 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
413 "Failed to deserialize kv value: {}",
414 e
415 )))
416 })?;
417 Ok(Some(value))
418 } else {
419 Ok(None)
420 }
421 }
422
423 async fn delete_kv(&self, key: &str) -> Result<()> {
424 sqlx::query(r#"DELETE FROM kv_store WHERE key = $1"#)
425 .bind(key)
426 .execute(&self.pool)
427 .await
428 .map_err(|e| {
429 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
430 "Failed to delete kv: {}",
431 e
432 )))
433 })?;
434 Ok(())
435 }
436
437 async fn list_kv_keys(&self, prefix: &str) -> Result<Vec<String>> {
438 let rows = sqlx::query(
439 r#"
440 SELECT key
441 FROM kv_store
442 WHERE key LIKE $1 AND (expires_at IS NULL OR expires_at > NOW())
443 ORDER BY key
444 "#,
445 )
446 .bind(format!("{prefix}%"))
447 .fetch_all(&self.pool)
448 .await
449 .map_err(|e| {
450 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
451 "Failed to list kv keys: {}",
452 e
453 )))
454 })?;
455
456 rows.into_iter()
457 .map(|row| {
458 row.try_get("key").map_err(|e| {
459 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
460 "Failed to decode kv key: {}",
461 e
462 )))
463 })
464 })
465 .collect()
466 }
467
468 async fn cleanup_expired(&self) -> Result<()> {
469 sqlx::query("DELETE FROM auth_tokens WHERE expires_at < NOW()")
471 .execute(&self.pool)
472 .await
473 .map_err(|e| {
474 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
475 "Failed to cleanup expired tokens: {}",
476 e
477 )))
478 })?;
479
480 sqlx::query("DELETE FROM sessions WHERE expires_at IS NOT NULL AND expires_at < NOW()")
482 .execute(&self.pool)
483 .await
484 .map_err(|e| {
485 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
486 "Failed to cleanup expired sessions: {}",
487 e
488 )))
489 })?;
490
491 sqlx::query("DELETE FROM kv_store WHERE expires_at IS NOT NULL AND expires_at < NOW()")
493 .execute(&self.pool)
494 .await
495 .map_err(|e| {
496 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
497 "Failed to cleanup expired kv: {}",
498 e
499 )))
500 })?;
501
502 Ok(())
503 }
504
505 async fn list_user_sessions(&self, user_id: &str) -> Result<Vec<SessionData>> {
506 let rows = sqlx::query(
507 r#"
508 SELECT session_id, user_id, data, expires_at, created_at, last_activity, ip_address, user_agent
509 FROM sessions
510 WHERE user_id = $1 AND (expires_at IS NULL OR expires_at > NOW())
511 "#,
512 )
513 .bind(user_id)
514 .fetch_all(&self.pool)
515 .await
516 .map_err(|e| {
517 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
518 "Failed to list user sessions: {}",
519 e
520 )))
521 })?;
522
523 let sessions = rows
524 .into_iter()
525 .map(|row| SessionData {
526 session_id: row.try_get("session_id").unwrap_or_default(),
527 user_id: row.try_get("user_id").unwrap_or_default(),
528 created_at: row.try_get("created_at").unwrap_or_default(),
529 data: {
530 let json_value: serde_json::Value = row.try_get("data").unwrap_or_default();
531 if let serde_json::Value::Object(map) = json_value {
532 map.into_iter().collect()
533 } else {
534 std::collections::HashMap::new()
535 }
536 },
537 expires_at: row.try_get("expires_at").unwrap_or_default(),
538 last_activity: row.try_get("last_activity").unwrap_or_default(),
539 ip_address: row.try_get("ip_address").ok(),
540 user_agent: row.try_get("user_agent").ok(),
541 })
542 .collect();
543
544 Ok(sessions)
545 }
546
547 async fn count_active_sessions(&self) -> Result<u64> {
548 let row = sqlx::query(
549 "SELECT COUNT(*) as count FROM sessions WHERE expires_at IS NULL OR expires_at > NOW()",
550 )
551 .fetch_one(&self.pool)
552 .await
553 .map_err(|e| {
554 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
555 "Failed to count active sessions: {}",
556 e
557 )))
558 })?;
559
560 let count: i64 = row.try_get("count").map_err(|e| {
561 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
562 "Failed to parse session count: {}",
563 e
564 )))
565 })?;
566
567 Ok(count as u64)
568 }
569}