Skip to main content

auth_framework/storage/
postgres.rs

//! PostgreSQL storage implementation for auth-framework.
//!
//! This module provides a production-ready PostgreSQL backend ([`PostgresStorage`]) for
//! storing authentication tokens, sessions, and key-value entries.
4use crate::errors::{AuthError, Result};
5use crate::storage::{AuthStorage, SessionData};
6use crate::tokens::AuthToken;
7use async_trait::async_trait;
8use sqlx::PgPool;
9use sqlx::Row;
10// use std::time::Duration;
11
12/// PostgreSQL storage backend
13pub struct PostgresStorage {
14    pool: PgPool,
15}
16
17impl PostgresStorage {
18    /// Create a new PostgreSQL storage instance.
19    ///
20    /// # Example
21    /// ```rust,ignore
22    /// let pool = PgPool::connect("postgres://user:pass@localhost/db").await?;
23    /// let storage = PostgresStorage::new(pool);
24    /// ```
25    pub fn new(pool: PgPool) -> Self {
26        Self { pool }
27    }
28
29    /// Initialize database tables.
30    ///
31    /// Creates the `auth_tokens`, `sessions`, and `kv_store` tables together with
32    /// their secondary indexes if they do not already exist.  Safe to call on every
33    /// application startup (`IF NOT EXISTS` / `IF NOT EXISTS` guards are idempotent).
34    ///
35    /// # Errors
36    /// Returns an error if any DDL statement fails (e.g. insufficient privileges).
37    pub async fn migrate(&self) -> Result<()> {
38        // Each statement is executed separately because sqlx::query() accepts
39        // exactly one SQL statement per call.
40
41        // --- auth_tokens table ---
42        sqlx::query(
43            r#"
44            CREATE TABLE IF NOT EXISTS auth_tokens (
45                token_id    VARCHAR(255) PRIMARY KEY,
46                user_id     VARCHAR(255) NOT NULL,
47                access_token TEXT        NOT NULL UNIQUE,
48                refresh_token TEXT,
49                token_type  VARCHAR(50),
50                expires_at  TIMESTAMPTZ  NOT NULL,
51                scopes      TEXT[],
52                issued_at   TIMESTAMPTZ  NOT NULL,
53                auth_method VARCHAR(100) NOT NULL,
54                subject     VARCHAR(255),
55                issuer      VARCHAR(255),
56                client_id   VARCHAR(255),
57                metadata    JSONB,
58                created_at  TIMESTAMPTZ  DEFAULT NOW()
59            )
60            "#,
61        )
62        .execute(&self.pool)
63        .await
64        .map_err(|e| {
65            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
66                "Migration failed (auth_tokens): {e}"
67            )))
68        })?;
69
70        // Secondary indexes for auth_tokens (PostgreSQL requires separate CREATE INDEX statements)
71        for stmt in [
72            "CREATE INDEX IF NOT EXISTS idx_auth_tokens_user_id ON auth_tokens (user_id)",
73            "CREATE INDEX IF NOT EXISTS idx_auth_tokens_access_token ON auth_tokens (access_token)",
74            "CREATE INDEX IF NOT EXISTS idx_auth_tokens_expires_at ON auth_tokens (expires_at)",
75        ] {
76            sqlx::query(stmt).execute(&self.pool).await.map_err(|e| {
77                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
78                    "Migration failed (index): {e}"
79                )))
80            })?;
81        }
82
83        // --- sessions table ---
84        sqlx::query(
85            r#"
86            CREATE TABLE IF NOT EXISTS sessions (
87                session_id VARCHAR(255) PRIMARY KEY,
88                user_id    VARCHAR(255) NOT NULL,
89                data       JSONB        NOT NULL,
90                expires_at TIMESTAMPTZ,
91                created_at TIMESTAMPTZ  DEFAULT NOW()
92            )
93            "#,
94        )
95        .execute(&self.pool)
96        .await
97        .map_err(|e| {
98            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
99                "Migration failed (sessions): {e}"
100            )))
101        })?;
102
103        for stmt in [
104            "CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions (user_id)",
105            "CREATE INDEX IF NOT EXISTS idx_sessions_expires_at ON sessions (expires_at)",
106        ] {
107            sqlx::query(stmt).execute(&self.pool).await.map_err(|e| {
108                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
109                    "Migration failed (index): {e}"
110                )))
111            })?;
112        }
113
114        // --- kv_store table ---
115        sqlx::query(
116            r#"
117            CREATE TABLE IF NOT EXISTS kv_store (
118                key        VARCHAR(512) PRIMARY KEY,
119                value      BYTEA        NOT NULL,
120                expires_at TIMESTAMPTZ,
121                created_at TIMESTAMPTZ  DEFAULT NOW()
122            )
123            "#,
124        )
125        .execute(&self.pool)
126        .await
127        .map_err(|e| {
128            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
129                "Migration failed (kv_store): {e}"
130            )))
131        })?;
132
133        sqlx::query("CREATE INDEX IF NOT EXISTS idx_kv_store_expires_at ON kv_store (expires_at)")
134            .execute(&self.pool)
135            .await
136            .map_err(|e| {
137                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
138                    "Migration failed (index idx_kv_store_expires_at): {e}"
139                )))
140            })?;
141
142        Ok(())
143    }
144}
145
146#[async_trait]
147impl AuthStorage for PostgresStorage {
148    async fn store_token(&self, token: &AuthToken) -> Result<()> {
149        sqlx::query(
150            r#"
151            INSERT INTO auth_tokens (
152                token_id, user_id, access_token, refresh_token, token_type,
153                expires_at, scopes, issued_at, auth_method, subject, issuer,
154                client_id, metadata
155            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
156            ON CONFLICT (token_id) DO UPDATE SET
157                access_token = EXCLUDED.access_token,
158                refresh_token = EXCLUDED.refresh_token,
159                expires_at = EXCLUDED.expires_at
160            "#,
161        )
162        .bind(&token.token_id)
163        .bind(&token.user_id)
164        .bind(&token.access_token)
165        .bind(&token.refresh_token)
166        .bind(&token.token_type)
167        .bind(token.expires_at)
168        .bind(&token.scopes)
169        .bind(token.issued_at)
170        .bind(&token.auth_method)
171        .bind(&token.subject)
172        .bind(&token.issuer)
173        .bind(&token.client_id)
174        .bind(serde_json::to_value(&token.metadata).unwrap_or_default())
175        .execute(&self.pool)
176        .await
177        .map_err(|e| {
178            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
179                "Failed to store token: {}",
180                e
181            )))
182        })?;
183
184        Ok(())
185    }
186
187    // ... implement other AuthStorage methods
188    async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>> {
189        let row =
190            sqlx::query_as::<_, AuthToken>(r#"SELECT * FROM auth_tokens WHERE token_id = $1"#)
191                .bind(token_id)
192                .fetch_optional(&self.pool)
193                .await
194                .map_err(|e| {
195                    AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
196                        "Failed to fetch token: {}",
197                        e
198                    )))
199                })?;
200        Ok(row)
201    }
202
203    async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>> {
204        let row =
205            sqlx::query_as::<_, AuthToken>(r#"SELECT * FROM auth_tokens WHERE access_token = $1"#)
206                .bind(access_token)
207                .fetch_optional(&self.pool)
208                .await
209                .map_err(|e| {
210                    AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
211                        "Failed to fetch token by access_token: {}",
212                        e
213                    )))
214                })?;
215        Ok(row)
216    }
217
218    async fn update_token(&self, token: &AuthToken) -> Result<()> {
219        sqlx::query(
220            r#"
221            UPDATE auth_tokens SET
222                access_token = $1,
223                refresh_token = $2,
224                token_type = $3,
225                expires_at = $4,
226                scopes = $5,
227                issued_at = $6,
228                auth_method = $7,
229                subject = $8,
230                issuer = $9,
231                client_id = $10,
232                metadata = $11
233            WHERE token_id = $12
234            "#,
235        )
236        .bind(&token.access_token)
237        .bind(&token.refresh_token)
238        .bind(&token.token_type)
239        .bind(token.expires_at)
240        .bind(&token.scopes)
241        .bind(token.issued_at)
242        .bind(&token.auth_method)
243        .bind(&token.subject)
244        .bind(&token.issuer)
245        .bind(&token.client_id)
246        .bind(serde_json::to_value(&token.metadata).unwrap_or_default())
247        .bind(&token.token_id)
248        .execute(&self.pool)
249        .await
250        .map_err(|e| {
251            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
252                "Failed to update token: {}",
253                e
254            )))
255        })?;
256        Ok(())
257    }
258
259    async fn delete_token(&self, token_id: &str) -> Result<()> {
260        sqlx::query(r#"DELETE FROM auth_tokens WHERE token_id = $1"#)
261            .bind(token_id)
262            .execute(&self.pool)
263            .await
264            .map_err(|e| {
265                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
266                    "Failed to delete token: {}",
267                    e
268                )))
269            })?;
270        Ok(())
271    }
272
273    async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>> {
274        let tokens =
275            sqlx::query_as::<_, AuthToken>(r#"SELECT * FROM auth_tokens WHERE user_id = $1"#)
276                .bind(user_id)
277                .fetch_all(&self.pool)
278                .await
279                .map_err(|e| {
280                    AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
281                        "Failed to list user tokens: {}",
282                        e
283                    )))
284                })?;
285        Ok(tokens)
286    }
287
288    async fn store_session(
289        &self,
290        session_id: &str,
291        data: &crate::storage::core::SessionData,
292    ) -> Result<()> {
293        sqlx::query(
294            r#"
295            INSERT INTO sessions (session_id, user_id, data, expires_at)
296            VALUES ($1, $2, $3, $4)
297            ON CONFLICT (session_id) DO UPDATE SET
298                data = EXCLUDED.data,
299                expires_at = EXCLUDED.expires_at
300            "#,
301        )
302        .bind(session_id)
303        .bind(&data.user_id)
304        .bind(serde_json::to_value(data).unwrap_or_default())
305        .bind(data.expires_at)
306        .execute(&self.pool)
307        .await
308        .map_err(|e| {
309            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
310                "Failed to store session: {}",
311                e
312            )))
313        })?;
314        Ok(())
315    }
316
317    async fn get_session(
318        &self,
319        session_id: &str,
320    ) -> Result<Option<crate::storage::core::SessionData>> {
321        let row = sqlx::query(r#"SELECT data FROM sessions WHERE session_id = $1"#)
322            .bind(session_id)
323            .fetch_optional(&self.pool)
324            .await
325            .map_err(|e| {
326                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
327                    "Failed to fetch session: {}",
328                    e
329                )))
330            })?;
331        if let Some(row) = row {
332            let data: serde_json::Value = row.try_get("data").map_err(|e| {
333                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
334                    "Failed to deserialize session data: {}",
335                    e
336                )))
337            })?;
338            let session: crate::storage::core::SessionData =
339                serde_json::from_value(data).map_err(|e| {
340                    AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
341                        "Failed to parse session data: {}",
342                        e
343                    )))
344                })?;
345            Ok(Some(session))
346        } else {
347            Ok(None)
348        }
349    }
350
351    async fn delete_session(&self, session_id: &str) -> Result<()> {
352        sqlx::query(r#"DELETE FROM sessions WHERE session_id = $1"#)
353            .bind(session_id)
354            .execute(&self.pool)
355            .await
356            .map_err(|e| {
357                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
358                    "Failed to delete session: {}",
359                    e
360                )))
361            })?;
362        Ok(())
363    }
364
365    async fn store_kv(
366        &self,
367        key: &str,
368        value: &[u8],
369        ttl: Option<std::time::Duration>,
370    ) -> Result<()> {
371        let expires_at = ttl.map(|d| {
372            chrono::Utc::now()
373                + chrono::Duration::from_std(d).unwrap_or(chrono::Duration::seconds(0))
374        });
375        sqlx::query(
376            r#"
377            INSERT INTO kv_store (key, value, expires_at)
378            VALUES ($1, $2, $3)
379            ON CONFLICT (key) DO UPDATE SET
380                value = EXCLUDED.value,
381                expires_at = EXCLUDED.expires_at
382            "#,
383        )
384        .bind(key)
385        .bind(value)
386        .bind(expires_at)
387        .execute(&self.pool)
388        .await
389        .map_err(|e| {
390            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
391                "Failed to store kv: {}",
392                e
393            )))
394        })?;
395        Ok(())
396    }
397
398    async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>> {
399        let row = sqlx::query(
400            r#"SELECT value FROM kv_store WHERE key = $1 AND (expires_at IS NULL OR expires_at > NOW())"#
401        )
402        .bind(key)
403        .fetch_optional(&self.pool)
404        .await
405        .map_err(|e| {
406            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
407                "Failed to fetch kv: {}", e
408            )))
409        })?;
410        if let Some(row) = row {
411            let value: Vec<u8> = row.try_get("value").map_err(|e| {
412                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
413                    "Failed to deserialize kv value: {}",
414                    e
415                )))
416            })?;
417            Ok(Some(value))
418        } else {
419            Ok(None)
420        }
421    }
422
423    async fn delete_kv(&self, key: &str) -> Result<()> {
424        sqlx::query(r#"DELETE FROM kv_store WHERE key = $1"#)
425            .bind(key)
426            .execute(&self.pool)
427            .await
428            .map_err(|e| {
429                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
430                    "Failed to delete kv: {}",
431                    e
432                )))
433            })?;
434        Ok(())
435    }
436
437    async fn list_kv_keys(&self, prefix: &str) -> Result<Vec<String>> {
438        let rows = sqlx::query(
439            r#"
440            SELECT key
441            FROM kv_store
442            WHERE key LIKE $1 AND (expires_at IS NULL OR expires_at > NOW())
443            ORDER BY key
444            "#,
445        )
446        .bind(format!("{prefix}%"))
447        .fetch_all(&self.pool)
448        .await
449        .map_err(|e| {
450            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
451                "Failed to list kv keys: {}",
452                e
453            )))
454        })?;
455
456        rows.into_iter()
457            .map(|row| {
458                row.try_get("key").map_err(|e| {
459                    AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
460                        "Failed to decode kv key: {}",
461                        e
462                    )))
463                })
464            })
465            .collect()
466    }
467
468    async fn cleanup_expired(&self) -> Result<()> {
469        // Remove expired tokens
470        sqlx::query("DELETE FROM auth_tokens WHERE expires_at < NOW()")
471            .execute(&self.pool)
472            .await
473            .map_err(|e| {
474                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
475                    "Failed to cleanup expired tokens: {}",
476                    e
477                )))
478            })?;
479
480        // Remove expired sessions
481        sqlx::query("DELETE FROM sessions WHERE expires_at IS NOT NULL AND expires_at < NOW()")
482            .execute(&self.pool)
483            .await
484            .map_err(|e| {
485                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
486                    "Failed to cleanup expired sessions: {}",
487                    e
488                )))
489            })?;
490
491        // Remove expired kv entries
492        sqlx::query("DELETE FROM kv_store WHERE expires_at IS NOT NULL AND expires_at < NOW()")
493            .execute(&self.pool)
494            .await
495            .map_err(|e| {
496                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
497                    "Failed to cleanup expired kv: {}",
498                    e
499                )))
500            })?;
501
502        Ok(())
503    }
504
505    async fn list_user_sessions(&self, user_id: &str) -> Result<Vec<SessionData>> {
506        let rows = sqlx::query(
507            r#"
508            SELECT session_id, user_id, data, expires_at, created_at, last_activity, ip_address, user_agent
509            FROM sessions
510            WHERE user_id = $1 AND (expires_at IS NULL OR expires_at > NOW())
511            "#,
512        )
513        .bind(user_id)
514        .fetch_all(&self.pool)
515        .await
516        .map_err(|e| {
517            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
518                "Failed to list user sessions: {}",
519                e
520            )))
521        })?;
522
523        let sessions = rows
524            .into_iter()
525            .map(|row| SessionData {
526                session_id: row.try_get("session_id").unwrap_or_default(),
527                user_id: row.try_get("user_id").unwrap_or_default(),
528                created_at: row.try_get("created_at").unwrap_or_default(),
529                data: {
530                    let json_value: serde_json::Value = row.try_get("data").unwrap_or_default();
531                    if let serde_json::Value::Object(map) = json_value {
532                        map.into_iter().collect()
533                    } else {
534                        std::collections::HashMap::new()
535                    }
536                },
537                expires_at: row.try_get("expires_at").unwrap_or_default(),
538                last_activity: row.try_get("last_activity").unwrap_or_default(),
539                ip_address: row.try_get("ip_address").ok(),
540                user_agent: row.try_get("user_agent").ok(),
541            })
542            .collect();
543
544        Ok(sessions)
545    }
546
547    async fn count_active_sessions(&self) -> Result<u64> {
548        let row = sqlx::query(
549            "SELECT COUNT(*) as count FROM sessions WHERE expires_at IS NULL OR expires_at > NOW()",
550        )
551        .fetch_one(&self.pool)
552        .await
553        .map_err(|e| {
554            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
555                "Failed to count active sessions: {}",
556                e
557            )))
558        })?;
559
560        let count: i64 = row.try_get("count").map_err(|e| {
561            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
562                "Failed to parse session count: {}",
563                e
564            )))
565        })?;
566
567        Ok(count as u64)
568    }
569}