// auth_framework/storage/postgres.rs

//! PostgreSQL storage implementation for auth-framework.
//! This module provides a production-ready PostgreSQL backend for storing
//! authentication tokens, sessions, and generic key/value entries.
4use crate::errors::{AuthError, Result};
5use crate::storage::{AuthStorage, SessionData};
6use crate::tokens::AuthToken;
7use async_trait::async_trait;
8use sqlx::PgPool;
9use sqlx::Row;
10// use std::time::Duration;
11
12/// PostgreSQL storage backend
13pub struct PostgresStorage {
14    pool: PgPool,
15}
16
17impl PostgresStorage {
18    /// Create a new PostgreSQL storage instance
19    pub fn new(pool: PgPool) -> Self {
20        Self { pool }
21    }
22
23    /// Initialize database tables
24    pub async fn migrate(&self) -> Result<()> {
25        sqlx::query(
26            r#"
27            CREATE TABLE IF NOT EXISTS auth_tokens (
28                token_id VARCHAR(255) PRIMARY KEY,
29                user_id VARCHAR(255) NOT NULL,
30                access_token TEXT NOT NULL UNIQUE,
31                refresh_token TEXT,
32                token_type VARCHAR(50),
33                expires_at TIMESTAMPTZ NOT NULL,
34                scopes TEXT[],
35                issued_at TIMESTAMPTZ NOT NULL,
36                auth_method VARCHAR(100) NOT NULL,
37                subject VARCHAR(255),
38                issuer VARCHAR(255),
39                client_id VARCHAR(255),
40                metadata JSONB,
41                created_at TIMESTAMPTZ DEFAULT NOW(),
42                INDEX idx_auth_tokens_user_id (user_id),
43                INDEX idx_auth_tokens_access_token (access_token),
44                INDEX idx_auth_tokens_expires_at (expires_at)
45            );
46
47            CREATE TABLE IF NOT EXISTS sessions (
48                session_id VARCHAR(255) PRIMARY KEY,
49                user_id VARCHAR(255) NOT NULL,
50                data JSONB NOT NULL,
51                expires_at TIMESTAMPTZ,
52                created_at TIMESTAMPTZ DEFAULT NOW(),
53                INDEX idx_sessions_user_id (user_id),
54                INDEX idx_sessions_expires_at (expires_at)
55            );
56
57            CREATE TABLE IF NOT EXISTS kv_store (
58                key VARCHAR(255) PRIMARY KEY,
59                value BYTEA NOT NULL,
60                expires_at TIMESTAMPTZ,
61                created_at TIMESTAMPTZ DEFAULT NOW(),
62                INDEX idx_kv_store_expires_at (expires_at)
63            );
64            "#,
65        )
66        .execute(&self.pool)
67        .await
68        .map_err(|e| {
69            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
70                "Migration failed: {}",
71                e
72            )))
73        })?;
74
75        Ok(())
76    }
77}
78
79#[async_trait]
80impl AuthStorage for PostgresStorage {
81    async fn store_token(&self, token: &AuthToken) -> Result<()> {
82        sqlx::query(
83            r#"
84            INSERT INTO auth_tokens (
85                token_id, user_id, access_token, refresh_token, token_type,
86                expires_at, scopes, issued_at, auth_method, subject, issuer,
87                client_id, metadata
88            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
89            ON CONFLICT (token_id) DO UPDATE SET
90                access_token = EXCLUDED.access_token,
91                refresh_token = EXCLUDED.refresh_token,
92                expires_at = EXCLUDED.expires_at
93            "#,
94        )
95        .bind(&token.token_id)
96        .bind(&token.user_id)
97        .bind(&token.access_token)
98        .bind(&token.refresh_token)
99        .bind(&token.token_type)
100        .bind(token.expires_at)
101        .bind(&token.scopes)
102        .bind(token.issued_at)
103        .bind(&token.auth_method)
104        .bind(&token.subject)
105        .bind(&token.issuer)
106        .bind(&token.client_id)
107        .bind(serde_json::to_value(&token.metadata).unwrap_or_default())
108        .execute(&self.pool)
109        .await
110        .map_err(|e| {
111            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
112                "Failed to store token: {}",
113                e
114            )))
115        })?;
116
117        Ok(())
118    }
119
120    // ... implement other AuthStorage methods
121    async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>> {
122        let row =
123            sqlx::query_as::<_, AuthToken>(r#"SELECT * FROM auth_tokens WHERE token_id = $1"#)
124                .bind(token_id)
125                .fetch_optional(&self.pool)
126                .await
127                .map_err(|e| {
128                    AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
129                        "Failed to fetch token: {}",
130                        e
131                    )))
132                })?;
133        Ok(row)
134    }
135
136    async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>> {
137        let row =
138            sqlx::query_as::<_, AuthToken>(r#"SELECT * FROM auth_tokens WHERE access_token = $1"#)
139                .bind(access_token)
140                .fetch_optional(&self.pool)
141                .await
142                .map_err(|e| {
143                    AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
144                        "Failed to fetch token by access_token: {}",
145                        e
146                    )))
147                })?;
148        Ok(row)
149    }
150
151    async fn update_token(&self, token: &AuthToken) -> Result<()> {
152        sqlx::query(
153            r#"
154            UPDATE auth_tokens SET
155                access_token = $1,
156                refresh_token = $2,
157                token_type = $3,
158                expires_at = $4,
159                scopes = $5,
160                issued_at = $6,
161                auth_method = $7,
162                subject = $8,
163                issuer = $9,
164                client_id = $10,
165                metadata = $11
166            WHERE token_id = $12
167            "#,
168        )
169        .bind(&token.access_token)
170        .bind(&token.refresh_token)
171        .bind(&token.token_type)
172        .bind(token.expires_at)
173        .bind(&token.scopes)
174        .bind(token.issued_at)
175        .bind(&token.auth_method)
176        .bind(&token.subject)
177        .bind(&token.issuer)
178        .bind(&token.client_id)
179        .bind(serde_json::to_value(&token.metadata).unwrap_or_default())
180        .bind(&token.token_id)
181        .execute(&self.pool)
182        .await
183        .map_err(|e| {
184            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
185                "Failed to update token: {}",
186                e
187            )))
188        })?;
189        Ok(())
190    }
191
192    async fn delete_token(&self, token_id: &str) -> Result<()> {
193        sqlx::query(r#"DELETE FROM auth_tokens WHERE token_id = $1"#)
194            .bind(token_id)
195            .execute(&self.pool)
196            .await
197            .map_err(|e| {
198                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
199                    "Failed to delete token: {}",
200                    e
201                )))
202            })?;
203        Ok(())
204    }
205
206    async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>> {
207        let tokens =
208            sqlx::query_as::<_, AuthToken>(r#"SELECT * FROM auth_tokens WHERE user_id = $1"#)
209                .bind(user_id)
210                .fetch_all(&self.pool)
211                .await
212                .map_err(|e| {
213                    AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
214                        "Failed to list user tokens: {}",
215                        e
216                    )))
217                })?;
218        Ok(tokens)
219    }
220
221    async fn store_session(
222        &self,
223        session_id: &str,
224        data: &crate::storage::core::SessionData,
225    ) -> Result<()> {
226        sqlx::query(
227            r#"
228            INSERT INTO sessions (session_id, user_id, data, expires_at)
229            VALUES ($1, $2, $3, $4)
230            ON CONFLICT (session_id) DO UPDATE SET
231                data = EXCLUDED.data,
232                expires_at = EXCLUDED.expires_at
233            "#,
234        )
235        .bind(session_id)
236        .bind(&data.user_id)
237        .bind(serde_json::to_value(data).unwrap_or_default())
238        .bind(data.expires_at)
239        .execute(&self.pool)
240        .await
241        .map_err(|e| {
242            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
243                "Failed to store session: {}",
244                e
245            )))
246        })?;
247        Ok(())
248    }
249
250    async fn get_session(
251        &self,
252        session_id: &str,
253    ) -> Result<Option<crate::storage::core::SessionData>> {
254        let row = sqlx::query(r#"SELECT data FROM sessions WHERE session_id = $1"#)
255            .bind(session_id)
256            .fetch_optional(&self.pool)
257            .await
258            .map_err(|e| {
259                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
260                    "Failed to fetch session: {}",
261                    e
262                )))
263            })?;
264        if let Some(row) = row {
265            let data: serde_json::Value = row.try_get("data").map_err(|e| {
266                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
267                    "Failed to deserialize session data: {}",
268                    e
269                )))
270            })?;
271            let session: crate::storage::core::SessionData =
272                serde_json::from_value(data).map_err(|e| {
273                    AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
274                        "Failed to parse session data: {}",
275                        e
276                    )))
277                })?;
278            Ok(Some(session))
279        } else {
280            Ok(None)
281        }
282    }
283
284    async fn delete_session(&self, session_id: &str) -> Result<()> {
285        sqlx::query(r#"DELETE FROM sessions WHERE session_id = $1"#)
286            .bind(session_id)
287            .execute(&self.pool)
288            .await
289            .map_err(|e| {
290                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
291                    "Failed to delete session: {}",
292                    e
293                )))
294            })?;
295        Ok(())
296    }
297
298    async fn store_kv(
299        &self,
300        key: &str,
301        value: &[u8],
302        ttl: Option<std::time::Duration>,
303    ) -> Result<()> {
304        let expires_at = ttl.map(|d| {
305            chrono::Utc::now()
306                + chrono::Duration::from_std(d).unwrap_or(chrono::Duration::seconds(0))
307        });
308        sqlx::query(
309            r#"
310            INSERT INTO kv_store (key, value, expires_at)
311            VALUES ($1, $2, $3)
312            ON CONFLICT (key) DO UPDATE SET
313                value = EXCLUDED.value,
314                expires_at = EXCLUDED.expires_at
315            "#,
316        )
317        .bind(key)
318        .bind(value)
319        .bind(expires_at)
320        .execute(&self.pool)
321        .await
322        .map_err(|e| {
323            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
324                "Failed to store kv: {}",
325                e
326            )))
327        })?;
328        Ok(())
329    }
330
331    async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>> {
332        let row = sqlx::query(
333            r#"SELECT value FROM kv_store WHERE key = $1 AND (expires_at IS NULL OR expires_at > NOW())"#
334        )
335        .bind(key)
336        .fetch_optional(&self.pool)
337        .await
338        .map_err(|e| {
339            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
340                "Failed to fetch kv: {}", e
341            )))
342        })?;
343        if let Some(row) = row {
344            let value: Vec<u8> = row.try_get("value").map_err(|e| {
345                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
346                    "Failed to deserialize kv value: {}",
347                    e
348                )))
349            })?;
350            Ok(Some(value))
351        } else {
352            Ok(None)
353        }
354    }
355
356    async fn delete_kv(&self, key: &str) -> Result<()> {
357        sqlx::query(r#"DELETE FROM kv_store WHERE key = $1"#)
358            .bind(key)
359            .execute(&self.pool)
360            .await
361            .map_err(|e| {
362                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
363                    "Failed to delete kv: {}",
364                    e
365                )))
366            })?;
367        Ok(())
368    }
369
370    async fn cleanup_expired(&self) -> Result<()> {
371        // Remove expired tokens
372        sqlx::query("DELETE FROM auth_tokens WHERE expires_at < NOW()")
373            .execute(&self.pool)
374            .await
375            .map_err(|e| {
376                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
377                    "Failed to cleanup expired tokens: {}",
378                    e
379                )))
380            })?;
381
382        // Remove expired sessions
383        sqlx::query("DELETE FROM sessions WHERE expires_at IS NOT NULL AND expires_at < NOW()")
384            .execute(&self.pool)
385            .await
386            .map_err(|e| {
387                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
388                    "Failed to cleanup expired sessions: {}",
389                    e
390                )))
391            })?;
392
393        // Remove expired kv entries
394        sqlx::query("DELETE FROM kv_store WHERE expires_at IS NOT NULL AND expires_at < NOW()")
395            .execute(&self.pool)
396            .await
397            .map_err(|e| {
398                AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
399                    "Failed to cleanup expired kv: {}",
400                    e
401                )))
402            })?;
403
404        Ok(())
405    }
406
407    async fn list_user_sessions(&self, user_id: &str) -> Result<Vec<SessionData>> {
408        let rows = sqlx::query(
409            r#"
410            SELECT session_id, user_id, data, expires_at, created_at, last_activity, ip_address, user_agent
411            FROM sessions
412            WHERE user_id = $1 AND (expires_at IS NULL OR expires_at > NOW())
413            "#,
414        )
415        .bind(user_id)
416        .fetch_all(&self.pool)
417        .await
418        .map_err(|e| {
419            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
420                "Failed to list user sessions: {}",
421                e
422            )))
423        })?;
424
425        let sessions = rows
426            .into_iter()
427            .map(|row| SessionData {
428                session_id: row.try_get("session_id").unwrap_or_default(),
429                user_id: row.try_get("user_id").unwrap_or_default(),
430                created_at: row.try_get("created_at").unwrap_or_default(),
431                data: {
432                    let json_value: serde_json::Value = row.try_get("data").unwrap_or_default();
433                    if let serde_json::Value::Object(map) = json_value {
434                        map.into_iter().collect()
435                    } else {
436                        std::collections::HashMap::new()
437                    }
438                },
439                expires_at: row.try_get("expires_at").unwrap_or_default(),
440                last_activity: row.try_get("last_activity").unwrap_or_default(),
441                ip_address: row.try_get("ip_address").ok(),
442                user_agent: row.try_get("user_agent").ok(),
443            })
444            .collect();
445
446        Ok(sessions)
447    }
448
449    async fn count_active_sessions(&self) -> Result<u64> {
450        let row = sqlx::query(
451            "SELECT COUNT(*) as count FROM sessions WHERE expires_at IS NULL OR expires_at > NOW()",
452        )
453        .fetch_one(&self.pool)
454        .await
455        .map_err(|e| {
456            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
457                "Failed to count active sessions: {}",
458                e
459            )))
460        })?;
461
462        let count: i64 = row.try_get("count").map_err(|e| {
463            AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
464                "Failed to parse session count: {}",
465                e
466            )))
467        })?;
468
469        Ok(count as u64)
470    }
471}
472
473