1use crate::errors::{AuthError, Result};
5use crate::storage::{AuthStorage, SessionData};
6use crate::tokens::AuthToken;
7use async_trait::async_trait;
8use sqlx::PgPool;
9use sqlx::Row;
10pub struct PostgresStorage {
14 pool: PgPool,
15}
16
17impl PostgresStorage {
18 pub fn new(pool: PgPool) -> Self {
20 Self { pool }
21 }
22
23 pub async fn migrate(&self) -> Result<()> {
25 sqlx::query(
26 r#"
27 CREATE TABLE IF NOT EXISTS auth_tokens (
28 token_id VARCHAR(255) PRIMARY KEY,
29 user_id VARCHAR(255) NOT NULL,
30 access_token TEXT NOT NULL UNIQUE,
31 refresh_token TEXT,
32 token_type VARCHAR(50),
33 expires_at TIMESTAMPTZ NOT NULL,
34 scopes TEXT[],
35 issued_at TIMESTAMPTZ NOT NULL,
36 auth_method VARCHAR(100) NOT NULL,
37 subject VARCHAR(255),
38 issuer VARCHAR(255),
39 client_id VARCHAR(255),
40 metadata JSONB,
41 created_at TIMESTAMPTZ DEFAULT NOW(),
42 INDEX idx_auth_tokens_user_id (user_id),
43 INDEX idx_auth_tokens_access_token (access_token),
44 INDEX idx_auth_tokens_expires_at (expires_at)
45 );
46
47 CREATE TABLE IF NOT EXISTS sessions (
48 session_id VARCHAR(255) PRIMARY KEY,
49 user_id VARCHAR(255) NOT NULL,
50 data JSONB NOT NULL,
51 expires_at TIMESTAMPTZ,
52 created_at TIMESTAMPTZ DEFAULT NOW(),
53 INDEX idx_sessions_user_id (user_id),
54 INDEX idx_sessions_expires_at (expires_at)
55 );
56
57 CREATE TABLE IF NOT EXISTS kv_store (
58 key VARCHAR(255) PRIMARY KEY,
59 value BYTEA NOT NULL,
60 expires_at TIMESTAMPTZ,
61 created_at TIMESTAMPTZ DEFAULT NOW(),
62 INDEX idx_kv_store_expires_at (expires_at)
63 );
64 "#,
65 )
66 .execute(&self.pool)
67 .await
68 .map_err(|e| {
69 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
70 "Migration failed: {}",
71 e
72 )))
73 })?;
74
75 Ok(())
76 }
77}
78
79#[async_trait]
80impl AuthStorage for PostgresStorage {
81 async fn store_token(&self, token: &AuthToken) -> Result<()> {
82 sqlx::query(
83 r#"
84 INSERT INTO auth_tokens (
85 token_id, user_id, access_token, refresh_token, token_type,
86 expires_at, scopes, issued_at, auth_method, subject, issuer,
87 client_id, metadata
88 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
89 ON CONFLICT (token_id) DO UPDATE SET
90 access_token = EXCLUDED.access_token,
91 refresh_token = EXCLUDED.refresh_token,
92 expires_at = EXCLUDED.expires_at
93 "#,
94 )
95 .bind(&token.token_id)
96 .bind(&token.user_id)
97 .bind(&token.access_token)
98 .bind(&token.refresh_token)
99 .bind(&token.token_type)
100 .bind(token.expires_at)
101 .bind(&token.scopes)
102 .bind(token.issued_at)
103 .bind(&token.auth_method)
104 .bind(&token.subject)
105 .bind(&token.issuer)
106 .bind(&token.client_id)
107 .bind(serde_json::to_value(&token.metadata).unwrap_or_default())
108 .execute(&self.pool)
109 .await
110 .map_err(|e| {
111 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
112 "Failed to store token: {}",
113 e
114 )))
115 })?;
116
117 Ok(())
118 }
119
120 async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>> {
122 let row =
123 sqlx::query_as::<_, AuthToken>(r#"SELECT * FROM auth_tokens WHERE token_id = $1"#)
124 .bind(token_id)
125 .fetch_optional(&self.pool)
126 .await
127 .map_err(|e| {
128 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
129 "Failed to fetch token: {}",
130 e
131 )))
132 })?;
133 Ok(row)
134 }
135
136 async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>> {
137 let row =
138 sqlx::query_as::<_, AuthToken>(r#"SELECT * FROM auth_tokens WHERE access_token = $1"#)
139 .bind(access_token)
140 .fetch_optional(&self.pool)
141 .await
142 .map_err(|e| {
143 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
144 "Failed to fetch token by access_token: {}",
145 e
146 )))
147 })?;
148 Ok(row)
149 }
150
151 async fn update_token(&self, token: &AuthToken) -> Result<()> {
152 sqlx::query(
153 r#"
154 UPDATE auth_tokens SET
155 access_token = $1,
156 refresh_token = $2,
157 token_type = $3,
158 expires_at = $4,
159 scopes = $5,
160 issued_at = $6,
161 auth_method = $7,
162 subject = $8,
163 issuer = $9,
164 client_id = $10,
165 metadata = $11
166 WHERE token_id = $12
167 "#,
168 )
169 .bind(&token.access_token)
170 .bind(&token.refresh_token)
171 .bind(&token.token_type)
172 .bind(token.expires_at)
173 .bind(&token.scopes)
174 .bind(token.issued_at)
175 .bind(&token.auth_method)
176 .bind(&token.subject)
177 .bind(&token.issuer)
178 .bind(&token.client_id)
179 .bind(serde_json::to_value(&token.metadata).unwrap_or_default())
180 .bind(&token.token_id)
181 .execute(&self.pool)
182 .await
183 .map_err(|e| {
184 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
185 "Failed to update token: {}",
186 e
187 )))
188 })?;
189 Ok(())
190 }
191
192 async fn delete_token(&self, token_id: &str) -> Result<()> {
193 sqlx::query(r#"DELETE FROM auth_tokens WHERE token_id = $1"#)
194 .bind(token_id)
195 .execute(&self.pool)
196 .await
197 .map_err(|e| {
198 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
199 "Failed to delete token: {}",
200 e
201 )))
202 })?;
203 Ok(())
204 }
205
206 async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>> {
207 let tokens =
208 sqlx::query_as::<_, AuthToken>(r#"SELECT * FROM auth_tokens WHERE user_id = $1"#)
209 .bind(user_id)
210 .fetch_all(&self.pool)
211 .await
212 .map_err(|e| {
213 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
214 "Failed to list user tokens: {}",
215 e
216 )))
217 })?;
218 Ok(tokens)
219 }
220
221 async fn store_session(
222 &self,
223 session_id: &str,
224 data: &crate::storage::core::SessionData,
225 ) -> Result<()> {
226 sqlx::query(
227 r#"
228 INSERT INTO sessions (session_id, user_id, data, expires_at)
229 VALUES ($1, $2, $3, $4)
230 ON CONFLICT (session_id) DO UPDATE SET
231 data = EXCLUDED.data,
232 expires_at = EXCLUDED.expires_at
233 "#,
234 )
235 .bind(session_id)
236 .bind(&data.user_id)
237 .bind(serde_json::to_value(data).unwrap_or_default())
238 .bind(data.expires_at)
239 .execute(&self.pool)
240 .await
241 .map_err(|e| {
242 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
243 "Failed to store session: {}",
244 e
245 )))
246 })?;
247 Ok(())
248 }
249
250 async fn get_session(
251 &self,
252 session_id: &str,
253 ) -> Result<Option<crate::storage::core::SessionData>> {
254 let row = sqlx::query(r#"SELECT data FROM sessions WHERE session_id = $1"#)
255 .bind(session_id)
256 .fetch_optional(&self.pool)
257 .await
258 .map_err(|e| {
259 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
260 "Failed to fetch session: {}",
261 e
262 )))
263 })?;
264 if let Some(row) = row {
265 let data: serde_json::Value = row.try_get("data").map_err(|e| {
266 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
267 "Failed to deserialize session data: {}",
268 e
269 )))
270 })?;
271 let session: crate::storage::core::SessionData =
272 serde_json::from_value(data).map_err(|e| {
273 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
274 "Failed to parse session data: {}",
275 e
276 )))
277 })?;
278 Ok(Some(session))
279 } else {
280 Ok(None)
281 }
282 }
283
284 async fn delete_session(&self, session_id: &str) -> Result<()> {
285 sqlx::query(r#"DELETE FROM sessions WHERE session_id = $1"#)
286 .bind(session_id)
287 .execute(&self.pool)
288 .await
289 .map_err(|e| {
290 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
291 "Failed to delete session: {}",
292 e
293 )))
294 })?;
295 Ok(())
296 }
297
298 async fn store_kv(
299 &self,
300 key: &str,
301 value: &[u8],
302 ttl: Option<std::time::Duration>,
303 ) -> Result<()> {
304 let expires_at = ttl.map(|d| {
305 chrono::Utc::now()
306 + chrono::Duration::from_std(d).unwrap_or(chrono::Duration::seconds(0))
307 });
308 sqlx::query(
309 r#"
310 INSERT INTO kv_store (key, value, expires_at)
311 VALUES ($1, $2, $3)
312 ON CONFLICT (key) DO UPDATE SET
313 value = EXCLUDED.value,
314 expires_at = EXCLUDED.expires_at
315 "#,
316 )
317 .bind(key)
318 .bind(value)
319 .bind(expires_at)
320 .execute(&self.pool)
321 .await
322 .map_err(|e| {
323 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
324 "Failed to store kv: {}",
325 e
326 )))
327 })?;
328 Ok(())
329 }
330
331 async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>> {
332 let row = sqlx::query(
333 r#"SELECT value FROM kv_store WHERE key = $1 AND (expires_at IS NULL OR expires_at > NOW())"#
334 )
335 .bind(key)
336 .fetch_optional(&self.pool)
337 .await
338 .map_err(|e| {
339 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
340 "Failed to fetch kv: {}", e
341 )))
342 })?;
343 if let Some(row) = row {
344 let value: Vec<u8> = row.try_get("value").map_err(|e| {
345 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
346 "Failed to deserialize kv value: {}",
347 e
348 )))
349 })?;
350 Ok(Some(value))
351 } else {
352 Ok(None)
353 }
354 }
355
356 async fn delete_kv(&self, key: &str) -> Result<()> {
357 sqlx::query(r#"DELETE FROM kv_store WHERE key = $1"#)
358 .bind(key)
359 .execute(&self.pool)
360 .await
361 .map_err(|e| {
362 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
363 "Failed to delete kv: {}",
364 e
365 )))
366 })?;
367 Ok(())
368 }
369
370 async fn cleanup_expired(&self) -> Result<()> {
371 sqlx::query("DELETE FROM auth_tokens WHERE expires_at < NOW()")
373 .execute(&self.pool)
374 .await
375 .map_err(|e| {
376 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
377 "Failed to cleanup expired tokens: {}",
378 e
379 )))
380 })?;
381
382 sqlx::query("DELETE FROM sessions WHERE expires_at IS NOT NULL AND expires_at < NOW()")
384 .execute(&self.pool)
385 .await
386 .map_err(|e| {
387 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
388 "Failed to cleanup expired sessions: {}",
389 e
390 )))
391 })?;
392
393 sqlx::query("DELETE FROM kv_store WHERE expires_at IS NOT NULL AND expires_at < NOW()")
395 .execute(&self.pool)
396 .await
397 .map_err(|e| {
398 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
399 "Failed to cleanup expired kv: {}",
400 e
401 )))
402 })?;
403
404 Ok(())
405 }
406
407 async fn list_user_sessions(&self, user_id: &str) -> Result<Vec<SessionData>> {
408 let rows = sqlx::query(
409 r#"
410 SELECT session_id, user_id, data, expires_at, created_at, last_activity, ip_address, user_agent
411 FROM sessions
412 WHERE user_id = $1 AND (expires_at IS NULL OR expires_at > NOW())
413 "#,
414 )
415 .bind(user_id)
416 .fetch_all(&self.pool)
417 .await
418 .map_err(|e| {
419 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
420 "Failed to list user sessions: {}",
421 e
422 )))
423 })?;
424
425 let sessions = rows
426 .into_iter()
427 .map(|row| SessionData {
428 session_id: row.try_get("session_id").unwrap_or_default(),
429 user_id: row.try_get("user_id").unwrap_or_default(),
430 created_at: row.try_get("created_at").unwrap_or_default(),
431 data: {
432 let json_value: serde_json::Value = row.try_get("data").unwrap_or_default();
433 if let serde_json::Value::Object(map) = json_value {
434 map.into_iter().collect()
435 } else {
436 std::collections::HashMap::new()
437 }
438 },
439 expires_at: row.try_get("expires_at").unwrap_or_default(),
440 last_activity: row.try_get("last_activity").unwrap_or_default(),
441 ip_address: row.try_get("ip_address").ok(),
442 user_agent: row.try_get("user_agent").ok(),
443 })
444 .collect();
445
446 Ok(sessions)
447 }
448
449 async fn count_active_sessions(&self) -> Result<u64> {
450 let row = sqlx::query(
451 "SELECT COUNT(*) as count FROM sessions WHERE expires_at IS NULL OR expires_at > NOW()",
452 )
453 .fetch_one(&self.pool)
454 .await
455 .map_err(|e| {
456 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
457 "Failed to count active sessions: {}",
458 e
459 )))
460 })?;
461
462 let count: i64 = row.try_get("count").map_err(|e| {
463 AuthError::Storage(crate::errors::StorageError::operation_failed(format!(
464 "Failed to parse session count: {}",
465 e
466 )))
467 })?;
468
469 Ok(count as u64)
470 }
471}
472
473