Skip to main content

authx_storage/sqlx/
postgres.rs

1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use sqlx::{PgPool, Row, postgres::PgPoolOptions};
4use uuid::Uuid;
5
6use authx_core::{
7    error::{AuthError, Result, StorageError},
8    models::{
9        ApiKey, AuditLog, AuthorizationCode, CreateApiKey, CreateAuditLog, CreateAuthorizationCode,
10        CreateCredential, CreateDeviceCode, CreateInvite, CreateOidcClient,
11        CreateOidcFederationProvider, CreateOidcToken, CreateOrg, CreateSession, CreateUser,
12        Credential, CredentialKind, DeviceCode, Invite, Membership, OAuthAccount, OidcClient,
13        OidcFederationProvider, OidcToken, OidcTokenType, Organization, Role, Session, UpdateUser,
14        UpsertOAuthAccount, User,
15    },
16};
17
18use crate::ports::{
19    ApiKeyRepository, AuditLogRepository, AuthorizationCodeRepository, CredentialRepository,
20    DeviceCodeRepository, InviteRepository, OAuthAccountRepository, OidcClientRepository,
21    OidcFederationProviderRepository, OidcTokenRepository, OrgRepository, SessionRepository,
22    UserRepository,
23};
24
25// ── Store ─────────────────────────────────────────────────────────────────────
26
/// Postgres-backed storage adapter.
///
/// Wrap a [`PgPool`] and pass this to [`AuthxState::new`].
#[derive(Clone)]
pub struct PostgresStore {
    /// Connection pool on which the repository implementations below run their queries.
    pub pool: PgPool,
}
34
35impl PostgresStore {
36    pub async fn connect(database_url: &str) -> std::result::Result<Self, sqlx::Error> {
37        let pool = PgPoolOptions::new()
38            .max_connections(10)
39            .connect(database_url)
40            .await?;
41        tracing::info!("postgres pool connected");
42        Ok(Self { pool })
43    }
44
45    pub fn from_pool(pool: PgPool) -> Self {
46        Self { pool }
47    }
48
49    pub async fn migrate(pool: &PgPool) -> std::result::Result<(), sqlx::migrate::MigrateError> {
50        sqlx::migrate!("src/sqlx/migrations").run(pool).await?;
51        tracing::info!("database migrations applied");
52        Ok(())
53    }
54}
55
56// ── Helpers ───────────────────────────────────────────────────────────────────
57
58fn db_err(e: sqlx::Error) -> AuthError {
59    match e {
60        sqlx::Error::RowNotFound => AuthError::Storage(StorageError::NotFound),
61        sqlx::Error::Database(ref dbe) if dbe.constraint().is_some() => {
62            AuthError::Storage(StorageError::Conflict(dbe.message().to_owned()))
63        }
64        other => AuthError::Storage(StorageError::Database(other.to_string())),
65    }
66}
67
68fn credential_kind_str(k: &CredentialKind) -> &'static str {
69    match k {
70        CredentialKind::Password => "password",
71        CredentialKind::Passkey => "passkey",
72        CredentialKind::Webauthn => "webauthn",
73        CredentialKind::OauthToken => "oauth_token",
74    }
75}
76
77fn credential_kind_from_str(s: &str) -> CredentialKind {
78    match s {
79        "passkey" => CredentialKind::Passkey,
80        "webauthn" => CredentialKind::Webauthn,
81        "oauth_token" => CredentialKind::OauthToken,
82        _ => CredentialKind::Password,
83    }
84}
85
86fn map_user(r: &sqlx::postgres::PgRow) -> User {
87    User {
88        id: r.get("id"),
89        email: r.get("email"),
90        email_verified: r.get("email_verified"),
91        username: r.get("username"),
92        created_at: r.get("created_at"),
93        updated_at: r.get("updated_at"),
94        metadata: r.get::<serde_json::Value, _>("metadata"),
95    }
96}
97
98fn map_session(r: &sqlx::postgres::PgRow) -> Session {
99    Session {
100        id: r.get("id"),
101        user_id: r.get("user_id"),
102        token_hash: r.get("token_hash"),
103        device_info: r.get::<serde_json::Value, _>("device_info"),
104        ip_address: r.get("ip_address"),
105        org_id: r.get("org_id"),
106        expires_at: r.get("expires_at"),
107        created_at: r.get("created_at"),
108    }
109}
110
111fn map_audit_log(r: &sqlx::postgres::PgRow) -> AuditLog {
112    let ip_address: String = r.get("ip_address");
113    AuditLog {
114        id: r.get("id"),
115        user_id: r.get("user_id"),
116        org_id: r.get("org_id"),
117        action: r.get("action"),
118        resource_type: r.get("resource_type"),
119        resource_id: r.get("resource_id"),
120        ip_address: if ip_address.is_empty() {
121            None
122        } else {
123            Some(ip_address)
124        },
125        metadata: r.get::<serde_json::Value, _>("metadata"),
126        created_at: r.get("created_at"),
127    }
128}
129
130// ── UserRepository ────────────────────────────────────────────────────────────
131
132#[async_trait]
133impl UserRepository for PostgresStore {
134    async fn find_by_id(&self, id: Uuid) -> Result<Option<User>> {
135        let row = sqlx::query(
136            "SELECT id, email, email_verified, username, created_at, updated_at, metadata \
137             FROM authx_users WHERE id = $1",
138        )
139        .bind(id)
140        .fetch_optional(&self.pool)
141        .await
142        .map_err(db_err)?;
143        Ok(row.as_ref().map(map_user))
144    }
145
146    async fn find_by_email(&self, email: &str) -> Result<Option<User>> {
147        let row = sqlx::query(
148            "SELECT id, email, email_verified, username, created_at, updated_at, metadata \
149             FROM authx_users WHERE email = $1",
150        )
151        .bind(email)
152        .fetch_optional(&self.pool)
153        .await
154        .map_err(db_err)?;
155        Ok(row.as_ref().map(map_user))
156    }
157
158    async fn find_by_username(&self, username: &str) -> Result<Option<User>> {
159        let row = sqlx::query(
160            "SELECT id, email, email_verified, username, created_at, updated_at, metadata \
161             FROM authx_users WHERE username = $1",
162        )
163        .bind(username)
164        .fetch_optional(&self.pool)
165        .await
166        .map_err(db_err)?;
167        Ok(row.as_ref().map(map_user))
168    }
169
170    async fn list(&self, offset: u32, limit: u32) -> Result<Vec<User>> {
171        let rows = sqlx::query(
172            "SELECT id, email, email_verified, username, created_at, updated_at, metadata \
173             FROM authx_users ORDER BY created_at ASC LIMIT $1 OFFSET $2",
174        )
175        .bind(limit as i64)
176        .bind(offset as i64)
177        .fetch_all(&self.pool)
178        .await
179        .map_err(db_err)?;
180        Ok(rows.iter().map(map_user).collect())
181    }
182
183    async fn create(&self, data: CreateUser) -> Result<User> {
184        let meta = data.metadata.unwrap_or(serde_json::json!({}));
185        let row = sqlx::query(
186            "INSERT INTO authx_users (id, email, email_verified, username, metadata) \
187             VALUES ($1, $2, false, $3, $4) \
188             RETURNING id, email, email_verified, username, created_at, updated_at, metadata",
189        )
190        .bind(Uuid::new_v4())
191        .bind(&data.email)
192        .bind(data.username.as_deref())
193        .bind(&meta)
194        .fetch_one(&self.pool)
195        .await
196        .map_err(|e| {
197            if let sqlx::Error::Database(ref dbe) = e {
198                if dbe.constraint() == Some("authx_users_email_key") {
199                    return AuthError::EmailTaken;
200                }
201                if dbe.constraint() == Some("authx_users_username_key") {
202                    return AuthError::Storage(StorageError::Conflict(
203                        "username already taken".into(),
204                    ));
205                }
206            }
207            db_err(e)
208        })?;
209
210        tracing::debug!(email = %data.email, "user row inserted");
211        Ok(map_user(&row))
212    }
213
214    async fn update(&self, id: Uuid, data: UpdateUser) -> Result<User> {
215        let row = sqlx::query(
216            "UPDATE authx_users \
217             SET \
218               email          = COALESCE($2, email), \
219               email_verified = COALESCE($3, email_verified), \
220               username       = COALESCE($4, username), \
221               metadata       = COALESCE($5, metadata), \
222               updated_at     = NOW() \
223             WHERE id = $1 \
224             RETURNING id, email, email_verified, username, created_at, updated_at, metadata",
225        )
226        .bind(id)
227        .bind(data.email.as_deref())
228        .bind(data.email_verified)
229        .bind(data.username.as_deref())
230        .bind(data.metadata.as_ref())
231        .fetch_optional(&self.pool)
232        .await
233        .map_err(db_err)?
234        .ok_or(AuthError::UserNotFound)?;
235
236        tracing::debug!(user_id = %id, "user row updated");
237        Ok(map_user(&row))
238    }
239
240    async fn delete(&self, id: Uuid) -> Result<()> {
241        let result = sqlx::query("DELETE FROM authx_users WHERE id = $1")
242            .bind(id)
243            .execute(&self.pool)
244            .await
245            .map_err(db_err)?;
246
247        if result.rows_affected() == 0 {
248            return Err(AuthError::UserNotFound);
249        }
250        tracing::debug!(user_id = %id, "user row deleted");
251        Ok(())
252    }
253}
254
255// ── SessionRepository ─────────────────────────────────────────────────────────
256
257#[async_trait]
258impl SessionRepository for PostgresStore {
259    async fn create(&self, data: CreateSession) -> Result<Session> {
260        let row = sqlx::query(
261            "INSERT INTO authx_sessions \
262               (id, user_id, token_hash, device_info, ip_address, org_id, expires_at) \
263             VALUES ($1, $2, $3, $4, $5, $6, $7) \
264             RETURNING id, user_id, token_hash, device_info, ip_address, org_id, expires_at, created_at",
265        )
266        .bind(Uuid::new_v4())
267        .bind(data.user_id)
268        .bind(&data.token_hash)
269        .bind(&data.device_info)
270        .bind(&data.ip_address)
271        .bind(data.org_id)
272        .bind(data.expires_at)
273        .fetch_one(&self.pool)
274        .await
275        .map_err(db_err)?;
276
277        tracing::debug!(user_id = %data.user_id, "session row inserted");
278        Ok(map_session(&row))
279    }
280
281    async fn find_by_token_hash(&self, hash: &str) -> Result<Option<Session>> {
282        let row = sqlx::query(
283            "SELECT id, user_id, token_hash, device_info, ip_address, org_id, expires_at, created_at \
284             FROM authx_sessions WHERE token_hash = $1 AND expires_at > NOW()",
285        )
286        .bind(hash)
287        .fetch_optional(&self.pool)
288        .await
289        .map_err(db_err)?;
290        Ok(row.as_ref().map(map_session))
291    }
292
293    async fn find_by_user(&self, user_id: Uuid) -> Result<Vec<Session>> {
294        let rows = sqlx::query(
295            "SELECT id, user_id, token_hash, device_info, ip_address, org_id, expires_at, created_at \
296             FROM authx_sessions WHERE user_id = $1",
297        )
298        .bind(user_id)
299        .fetch_all(&self.pool)
300        .await
301        .map_err(db_err)?;
302        Ok(rows.iter().map(map_session).collect())
303    }
304
305    async fn invalidate(&self, session_id: Uuid) -> Result<()> {
306        let result = sqlx::query("DELETE FROM authx_sessions WHERE id = $1")
307            .bind(session_id)
308            .execute(&self.pool)
309            .await
310            .map_err(db_err)?;
311
312        if result.rows_affected() == 0 {
313            return Err(AuthError::SessionNotFound);
314        }
315        tracing::debug!(session_id = %session_id, "session invalidated");
316        Ok(())
317    }
318
319    async fn invalidate_all_for_user(&self, user_id: Uuid) -> Result<()> {
320        sqlx::query("DELETE FROM authx_sessions WHERE user_id = $1")
321            .bind(user_id)
322            .execute(&self.pool)
323            .await
324            .map_err(db_err)?;
325        tracing::debug!(user_id = %user_id, "all user sessions invalidated");
326        Ok(())
327    }
328
329    async fn set_org(&self, session_id: Uuid, org_id: Option<Uuid>) -> Result<Session> {
330        let row = sqlx::query(
331            "UPDATE authx_sessions SET org_id = $2 WHERE id = $1 \
332             RETURNING id, user_id, token_hash, device_info, ip_address, org_id, expires_at, created_at",
333        )
334        .bind(session_id)
335        .bind(org_id)
336        .fetch_optional(&self.pool)
337        .await
338        .map_err(db_err)?
339        .ok_or(AuthError::SessionNotFound)?;
340
341        tracing::debug!(session_id = %session_id, "session org updated");
342        Ok(map_session(&row))
343    }
344}
345
346// ── CredentialRepository ──────────────────────────────────────────────────────
347
348#[async_trait]
349impl CredentialRepository for PostgresStore {
350    async fn create(&self, data: CreateCredential) -> Result<Credential> {
351        let kind_str = credential_kind_str(&data.kind);
352        let meta = data.metadata.unwrap_or(serde_json::json!({}));
353
354        let row = sqlx::query(
355            "INSERT INTO authx_credentials (id, user_id, kind, credential_hash, metadata) \
356             VALUES ($1, $2, $3::authx_credential_kind, $4, $5) \
357             RETURNING id, user_id, kind::text, credential_hash, metadata",
358        )
359        .bind(Uuid::new_v4())
360        .bind(data.user_id)
361        .bind(kind_str)
362        .bind(&data.credential_hash)
363        .bind(&meta)
364        .fetch_one(&self.pool)
365        .await
366        .map_err(db_err)?;
367
368        tracing::debug!(user_id = %data.user_id, kind = kind_str, "credential inserted");
369        Ok(Credential {
370            id: row.get("id"),
371            user_id: row.get("user_id"),
372            kind: credential_kind_from_str(row.get("kind")),
373            credential_hash: row.get("credential_hash"),
374            metadata: row.get::<serde_json::Value, _>("metadata"),
375        })
376    }
377
378    async fn find_password_hash(&self, user_id: Uuid) -> Result<Option<String>> {
379        let row = sqlx::query(
380            "SELECT credential_hash FROM authx_credentials \
381             WHERE user_id = $1 AND kind = 'password'::authx_credential_kind",
382        )
383        .bind(user_id)
384        .fetch_optional(&self.pool)
385        .await
386        .map_err(db_err)?;
387        Ok(row.map(|r| r.get("credential_hash")))
388    }
389
390    async fn find_by_user_and_kind(
391        &self,
392        user_id: Uuid,
393        kind: CredentialKind,
394    ) -> Result<Option<Credential>> {
395        let row = sqlx::query(
396            "SELECT id, user_id, kind::text, credential_hash, metadata \
397             FROM authx_credentials WHERE user_id = $1 AND kind = $2::authx_credential_kind",
398        )
399        .bind(user_id)
400        .bind(credential_kind_str(&kind))
401        .fetch_optional(&self.pool)
402        .await
403        .map_err(db_err)?;
404
405        Ok(row.map(|r| Credential {
406            id: r.get("id"),
407            user_id: r.get("user_id"),
408            kind: credential_kind_from_str(r.get("kind")),
409            credential_hash: r.get("credential_hash"),
410            metadata: r.get::<serde_json::Value, _>("metadata"),
411        }))
412    }
413
414    async fn delete_by_user_and_kind(&self, user_id: Uuid, kind: CredentialKind) -> Result<()> {
415        let result = sqlx::query(
416            "DELETE FROM authx_credentials WHERE user_id = $1 AND kind = $2::authx_credential_kind",
417        )
418        .bind(user_id)
419        .bind(credential_kind_str(&kind))
420        .execute(&self.pool)
421        .await
422        .map_err(db_err)?;
423
424        if result.rows_affected() == 0 {
425            return Err(AuthError::Storage(StorageError::NotFound));
426        }
427        Ok(())
428    }
429}
430
431// ── OrgRepository ─────────────────────────────────────────────────────────────
432
433fn map_org(r: &sqlx::postgres::PgRow) -> Organization {
434    Organization {
435        id: r.get("id"),
436        name: r.get("name"),
437        slug: r.get("slug"),
438        metadata: r.get::<serde_json::Value, _>("metadata"),
439        created_at: r.get("created_at"),
440    }
441}
442
443fn map_membership(r: &sqlx::postgres::PgRow) -> Membership {
444    Membership {
445        id: r.get("id"),
446        user_id: r.get("user_id"),
447        org_id: r.get("org_id"),
448        role: Role {
449            id: r.get("role_id"),
450            org_id: r.get("role_org_id"),
451            name: r.get("role_name"),
452            permissions: r.get::<Vec<String>, _>("permissions"),
453        },
454        created_at: r.get("created_at"),
455    }
456}
457
458#[async_trait]
459impl OrgRepository for PostgresStore {
460    async fn create(&self, data: CreateOrg) -> Result<Organization> {
461        let meta = data.metadata.unwrap_or(serde_json::json!({}));
462        let row = sqlx::query(
463            "INSERT INTO authx_orgs (id, name, slug, metadata) \
464             VALUES ($1, $2, $3, $4) \
465             RETURNING id, name, slug, metadata, created_at",
466        )
467        .bind(Uuid::new_v4())
468        .bind(&data.name)
469        .bind(&data.slug)
470        .bind(&meta)
471        .fetch_one(&self.pool)
472        .await
473        .map_err(|e| {
474            if let sqlx::Error::Database(ref dbe) = e
475                && dbe.constraint() == Some("authx_orgs_slug_key")
476            {
477                return AuthError::Storage(StorageError::Conflict(format!(
478                    "slug '{}' already taken",
479                    data.slug
480                )));
481            }
482            db_err(e)
483        })?;
484
485        tracing::debug!(slug = %data.slug, "org row inserted");
486        Ok(map_org(&row))
487    }
488
489    async fn find_by_id(&self, id: Uuid) -> Result<Option<Organization>> {
490        let row = sqlx::query(
491            "SELECT id, name, slug, metadata, created_at FROM authx_orgs WHERE id = $1",
492        )
493        .bind(id)
494        .fetch_optional(&self.pool)
495        .await
496        .map_err(db_err)?;
497        Ok(row.as_ref().map(map_org))
498    }
499
500    async fn find_by_slug(&self, slug: &str) -> Result<Option<Organization>> {
501        let row = sqlx::query(
502            "SELECT id, name, slug, metadata, created_at FROM authx_orgs WHERE slug = $1",
503        )
504        .bind(slug)
505        .fetch_optional(&self.pool)
506        .await
507        .map_err(db_err)?;
508        Ok(row.as_ref().map(map_org))
509    }
510
511    async fn add_member(&self, org_id: Uuid, user_id: Uuid, role_id: Uuid) -> Result<Membership> {
512        let role_row =
513            sqlx::query("SELECT id, org_id, name, permissions FROM authx_roles WHERE id = $1")
514                .bind(role_id)
515                .fetch_optional(&self.pool)
516                .await
517                .map_err(db_err)?
518                .ok_or(AuthError::Storage(StorageError::NotFound))?;
519
520        let role = Role {
521            id: role_row.get("id"),
522            org_id: role_row.get("org_id"),
523            name: role_row.get("name"),
524            permissions: role_row.get::<Vec<String>, _>("permissions"),
525        };
526
527        let row = sqlx::query(
528            "INSERT INTO authx_memberships (id, user_id, org_id, role_id) \
529             VALUES ($1, $2, $3, $4) \
530             RETURNING id, user_id, org_id, created_at",
531        )
532        .bind(Uuid::new_v4())
533        .bind(user_id)
534        .bind(org_id)
535        .bind(role_id)
536        .fetch_one(&self.pool)
537        .await
538        .map_err(db_err)?;
539
540        tracing::debug!(org_id = %org_id, user_id = %user_id, "member added");
541        Ok(Membership {
542            id: row.get("id"),
543            user_id: row.get("user_id"),
544            org_id: row.get("org_id"),
545            role,
546            created_at: row.get("created_at"),
547        })
548    }
549
550    async fn remove_member(&self, org_id: Uuid, user_id: Uuid) -> Result<()> {
551        let result =
552            sqlx::query("DELETE FROM authx_memberships WHERE org_id = $1 AND user_id = $2")
553                .bind(org_id)
554                .bind(user_id)
555                .execute(&self.pool)
556                .await
557                .map_err(db_err)?;
558
559        if result.rows_affected() == 0 {
560            return Err(AuthError::Storage(StorageError::NotFound));
561        }
562        Ok(())
563    }
564
565    async fn get_members(&self, org_id: Uuid) -> Result<Vec<Membership>> {
566        let rows = sqlx::query(
567            "SELECT m.id, m.user_id, m.org_id, m.created_at, \
568                    r.id AS role_id, r.org_id AS role_org_id, r.name AS role_name, r.permissions \
569             FROM authx_memberships m \
570             JOIN authx_roles r ON r.id = m.role_id \
571             WHERE m.org_id = $1",
572        )
573        .bind(org_id)
574        .fetch_all(&self.pool)
575        .await
576        .map_err(db_err)?;
577        Ok(rows.iter().map(map_membership).collect())
578    }
579
580    async fn find_roles(&self, org_id: Uuid) -> Result<Vec<Role>> {
581        let rows =
582            sqlx::query("SELECT id, org_id, name, permissions FROM authx_roles WHERE org_id = $1")
583                .bind(org_id)
584                .fetch_all(&self.pool)
585                .await
586                .map_err(db_err)?;
587        Ok(rows
588            .iter()
589            .map(|r| Role {
590                id: r.get("id"),
591                org_id: r.get("org_id"),
592                name: r.get("name"),
593                permissions: r.get::<Vec<String>, _>("permissions"),
594            })
595            .collect())
596    }
597
598    async fn create_role(
599        &self,
600        org_id: Uuid,
601        name: String,
602        permissions: Vec<String>,
603    ) -> Result<Role> {
604        let row = sqlx::query(
605            "INSERT INTO authx_roles (id, org_id, name, permissions) \
606             VALUES ($1, $2, $3, $4) \
607             RETURNING id, org_id, name, permissions",
608        )
609        .bind(Uuid::new_v4())
610        .bind(org_id)
611        .bind(&name)
612        .bind(&permissions)
613        .fetch_one(&self.pool)
614        .await
615        .map_err(db_err)?;
616
617        tracing::debug!(org_id = %org_id, name = %name, "role created");
618        Ok(Role {
619            id: row.get("id"),
620            org_id: row.get("org_id"),
621            name: row.get("name"),
622            permissions: row.get::<Vec<String>, _>("permissions"),
623        })
624    }
625
626    async fn update_member_role(
627        &self,
628        org_id: Uuid,
629        user_id: Uuid,
630        role_id: Uuid,
631    ) -> Result<Membership> {
632        sqlx::query("UPDATE authx_memberships SET role_id = $3 WHERE org_id = $1 AND user_id = $2")
633            .bind(org_id)
634            .bind(user_id)
635            .bind(role_id)
636            .execute(&self.pool)
637            .await
638            .map_err(db_err)?;
639
640        let rows = sqlx::query(
641            "SELECT m.id, m.user_id, m.org_id, m.created_at, \
642                    r.id AS role_id, r.org_id AS role_org_id, r.name AS role_name, r.permissions \
643             FROM authx_memberships m \
644             JOIN authx_roles r ON r.id = m.role_id \
645             WHERE m.org_id = $1 AND m.user_id = $2",
646        )
647        .bind(org_id)
648        .bind(user_id)
649        .fetch_optional(&self.pool)
650        .await
651        .map_err(db_err)?
652        .ok_or(AuthError::Storage(StorageError::NotFound))?;
653
654        Ok(map_membership(&rows))
655    }
656}
657
658// ── AuditLogRepository ────────────────────────────────────────────────────────
659
660#[async_trait]
661impl AuditLogRepository for PostgresStore {
662    async fn append(&self, entry: CreateAuditLog) -> Result<AuditLog> {
663        let meta = entry.metadata.unwrap_or(serde_json::json!({}));
664        let ip_address = entry.ip_address.unwrap_or_default();
665        let row = sqlx::query(
666            "INSERT INTO authx_audit_logs \
667               (id, user_id, org_id, action, resource_type, resource_id, ip_address, metadata) \
668             VALUES ($1, $2, $3, $4, $5, $6, $7, $8) \
669             RETURNING id, user_id, org_id, action, resource_type, resource_id, ip_address, metadata, created_at",
670        )
671        .bind(Uuid::new_v4())
672        .bind(entry.user_id)
673        .bind(entry.org_id)
674        .bind(&entry.action)
675        .bind(&entry.resource_type)
676        .bind(entry.resource_id.as_deref())
677        .bind(&ip_address)
678        .bind(&meta)
679        .fetch_one(&self.pool)
680        .await
681        .map_err(db_err)?;
682
683        tracing::debug!(action = %entry.action, "audit log appended");
684        Ok(map_audit_log(&row))
685    }
686
687    async fn find_by_user(&self, user_id: Uuid, limit: u32) -> Result<Vec<AuditLog>> {
688        let rows = sqlx::query(
689            "SELECT id, user_id, org_id, action, resource_type, resource_id, ip_address, metadata, created_at \
690             FROM authx_audit_logs WHERE user_id = $1 ORDER BY created_at DESC LIMIT $2",
691        )
692        .bind(user_id)
693        .bind(limit as i64)
694        .fetch_all(&self.pool)
695        .await
696        .map_err(db_err)?;
697        Ok(rows.iter().map(map_audit_log).collect())
698    }
699
700    async fn find_by_org(&self, org_id: Uuid, limit: u32) -> Result<Vec<AuditLog>> {
701        let rows = sqlx::query(
702            "SELECT id, user_id, org_id, action, resource_type, resource_id, ip_address, metadata, created_at \
703             FROM authx_audit_logs WHERE org_id = $1 ORDER BY created_at DESC LIMIT $2",
704        )
705        .bind(org_id)
706        .bind(limit as i64)
707        .fetch_all(&self.pool)
708        .await
709        .map_err(db_err)?;
710        Ok(rows.iter().map(map_audit_log).collect())
711    }
712}
713
714// ── ApiKeyRepository ──────────────────────────────────────────────────────────
715
716fn map_api_key(r: &sqlx::postgres::PgRow) -> ApiKey {
717    ApiKey {
718        id: r.get("id"),
719        user_id: r.get("user_id"),
720        org_id: r.get("org_id"),
721        key_hash: r.get("key_hash"),
722        prefix: r.get("prefix"),
723        name: r.get("name"),
724        scopes: r.get::<Vec<String>, _>("scopes"),
725        expires_at: r.get("expires_at"),
726        last_used_at: r.get("last_used_at"),
727    }
728}
729
730#[async_trait]
731impl ApiKeyRepository for PostgresStore {
732    async fn create(&self, data: CreateApiKey) -> Result<ApiKey> {
733        let row = sqlx::query(
734            "INSERT INTO authx_api_keys \
735               (id, user_id, org_id, key_hash, prefix, name, scopes, expires_at) \
736             VALUES ($1, $2, $3, $4, $5, $6, $7, $8) \
737             RETURNING id, user_id, org_id, key_hash, prefix, name, scopes, expires_at, last_used_at",
738        )
739        .bind(Uuid::new_v4())
740        .bind(data.user_id)
741        .bind(data.org_id)
742        .bind(&data.key_hash)
743        .bind(&data.prefix)
744        .bind(&data.name)
745        .bind(&data.scopes)
746        .bind(data.expires_at)
747        .fetch_one(&self.pool)
748        .await
749        .map_err(db_err)?;
750
751        tracing::debug!(user_id = %data.user_id, "api key created");
752        Ok(map_api_key(&row))
753    }
754
755    async fn find_by_hash(&self, key_hash: &str) -> Result<Option<ApiKey>> {
756        let row = sqlx::query(
757            "SELECT id, user_id, org_id, key_hash, prefix, name, scopes, expires_at, last_used_at \
758             FROM authx_api_keys WHERE key_hash = $1",
759        )
760        .bind(key_hash)
761        .fetch_optional(&self.pool)
762        .await
763        .map_err(db_err)?;
764        Ok(row.as_ref().map(map_api_key))
765    }
766
767    async fn find_by_user(&self, user_id: Uuid) -> Result<Vec<ApiKey>> {
768        let rows = sqlx::query(
769            "SELECT id, user_id, org_id, key_hash, prefix, name, scopes, expires_at, last_used_at \
770             FROM authx_api_keys WHERE user_id = $1 ORDER BY id",
771        )
772        .bind(user_id)
773        .fetch_all(&self.pool)
774        .await
775        .map_err(db_err)?;
776        Ok(rows.iter().map(map_api_key).collect())
777    }
778
779    async fn revoke(&self, key_id: Uuid, user_id: Uuid) -> Result<()> {
780        let result = sqlx::query("DELETE FROM authx_api_keys WHERE id = $1 AND user_id = $2")
781            .bind(key_id)
782            .bind(user_id)
783            .execute(&self.pool)
784            .await
785            .map_err(db_err)?;
786
787        if result.rows_affected() == 0 {
788            return Err(AuthError::Storage(StorageError::NotFound));
789        }
790        tracing::debug!(key_id = %key_id, "api key revoked");
791        Ok(())
792    }
793
794    async fn touch_last_used(&self, key_id: Uuid, at: DateTime<Utc>) -> Result<()> {
795        sqlx::query("UPDATE authx_api_keys SET last_used_at = $2 WHERE id = $1")
796            .bind(key_id)
797            .bind(at)
798            .execute(&self.pool)
799            .await
800            .map_err(db_err)?;
801        Ok(())
802    }
803}
804
805// ── OAuthAccountRepository ────────────────────────────────────────────────────
806
807fn map_oauth_account(r: &sqlx::postgres::PgRow) -> OAuthAccount {
808    OAuthAccount {
809        id: r.get("id"),
810        user_id: r.get("user_id"),
811        provider: r.get("provider"),
812        provider_user_id: r.get("provider_user_id"),
813        access_token_enc: r.get("access_token_enc"),
814        refresh_token_enc: r.get("refresh_token_enc"),
815        expires_at: r.get("expires_at"),
816    }
817}
818
819#[async_trait]
820impl OAuthAccountRepository for PostgresStore {
821    async fn upsert(&self, data: UpsertOAuthAccount) -> Result<OAuthAccount> {
822        let row = sqlx::query(
823            "INSERT INTO authx_oauth_accounts \
824               (id, user_id, provider, provider_user_id, access_token_enc, refresh_token_enc, expires_at) \
825             VALUES ($1, $2, $3, $4, $5, $6, $7) \
826             ON CONFLICT (provider, provider_user_id) DO UPDATE SET \
827               access_token_enc  = EXCLUDED.access_token_enc, \
828               refresh_token_enc = EXCLUDED.refresh_token_enc, \
829               expires_at        = EXCLUDED.expires_at \
830             RETURNING id, user_id, provider, provider_user_id, access_token_enc, refresh_token_enc, expires_at",
831        )
832        .bind(Uuid::new_v4())
833        .bind(data.user_id)
834        .bind(&data.provider)
835        .bind(&data.provider_user_id)
836        .bind(&data.access_token_enc)
837        .bind(data.refresh_token_enc.as_deref())
838        .bind(data.expires_at)
839        .fetch_one(&self.pool)
840        .await
841        .map_err(db_err)?;
842
843        tracing::debug!(provider = %data.provider, user_id = %data.user_id, "oauth account upserted");
844        Ok(map_oauth_account(&row))
845    }
846
847    async fn find_by_provider(
848        &self,
849        provider: &str,
850        provider_user_id: &str,
851    ) -> Result<Option<OAuthAccount>> {
852        let row = sqlx::query(
853            "SELECT id, user_id, provider, provider_user_id, access_token_enc, refresh_token_enc, expires_at \
854             FROM authx_oauth_accounts WHERE provider = $1 AND provider_user_id = $2",
855        )
856        .bind(provider)
857        .bind(provider_user_id)
858        .fetch_optional(&self.pool)
859        .await
860        .map_err(db_err)?;
861        Ok(row.as_ref().map(map_oauth_account))
862    }
863
864    async fn find_by_user(&self, user_id: Uuid) -> Result<Vec<OAuthAccount>> {
865        let rows = sqlx::query(
866            "SELECT id, user_id, provider, provider_user_id, access_token_enc, refresh_token_enc, expires_at \
867             FROM authx_oauth_accounts WHERE user_id = $1",
868        )
869        .bind(user_id)
870        .fetch_all(&self.pool)
871        .await
872        .map_err(db_err)?;
873        Ok(rows.iter().map(map_oauth_account).collect())
874    }
875
876    async fn delete(&self, id: Uuid) -> Result<()> {
877        let result = sqlx::query("DELETE FROM authx_oauth_accounts WHERE id = $1")
878            .bind(id)
879            .execute(&self.pool)
880            .await
881            .map_err(db_err)?;
882
883        if result.rows_affected() == 0 {
884            return Err(AuthError::Storage(StorageError::NotFound));
885        }
886        Ok(())
887    }
888}
889
890// ── InviteRepository ──────────────────────────────────────────────────────────
891
892fn map_invite(r: &sqlx::postgres::PgRow) -> Invite {
893    Invite {
894        id: r.get("id"),
895        org_id: r.get("org_id"),
896        email: r.get("email"),
897        role_id: r.get("role_id"),
898        token_hash: r.get("token_hash"),
899        expires_at: r.get("expires_at"),
900        accepted_at: r.get("accepted_at"),
901    }
902}
903
904#[async_trait]
905impl InviteRepository for PostgresStore {
906    async fn create(&self, data: CreateInvite) -> Result<Invite> {
907        let row = sqlx::query(
908            "INSERT INTO authx_invites (id, org_id, email, role_id, token_hash, expires_at) \
909             VALUES ($1, $2, $3, $4, $5, $6) \
910             RETURNING id, org_id, email, role_id, token_hash, expires_at, accepted_at",
911        )
912        .bind(Uuid::new_v4())
913        .bind(data.org_id)
914        .bind(&data.email)
915        .bind(data.role_id)
916        .bind(&data.token_hash)
917        .bind(data.expires_at)
918        .fetch_one(&self.pool)
919        .await
920        .map_err(db_err)?;
921
922        tracing::debug!(org_id = %data.org_id, email = %data.email, "invite created");
923        Ok(map_invite(&row))
924    }
925
926    async fn find_by_token_hash(&self, hash: &str) -> Result<Option<Invite>> {
927        let row = sqlx::query(
928            "SELECT id, org_id, email, role_id, token_hash, expires_at, accepted_at \
929             FROM authx_invites WHERE token_hash = $1",
930        )
931        .bind(hash)
932        .fetch_optional(&self.pool)
933        .await
934        .map_err(db_err)?;
935        Ok(row.as_ref().map(map_invite))
936    }
937
938    async fn accept(&self, invite_id: Uuid) -> Result<Invite> {
939        let row = sqlx::query(
940            "UPDATE authx_invites SET accepted_at = NOW() WHERE id = $1 \
941             RETURNING id, org_id, email, role_id, token_hash, expires_at, accepted_at",
942        )
943        .bind(invite_id)
944        .fetch_optional(&self.pool)
945        .await
946        .map_err(db_err)?
947        .ok_or(AuthError::Storage(StorageError::NotFound))?;
948
949        Ok(map_invite(&row))
950    }
951
952    async fn delete_expired(&self) -> Result<u64> {
953        let result = sqlx::query(
954            "DELETE FROM authx_invites WHERE accepted_at IS NULL AND expires_at < NOW()",
955        )
956        .execute(&self.pool)
957        .await
958        .map_err(db_err)?;
959        Ok(result.rows_affected())
960    }
961}
962
963// ── OidcClientRepository ───────────────────────────────────────────────────────
964
965fn oidc_token_type_str(t: &OidcTokenType) -> &'static str {
966    match t {
967        OidcTokenType::Access => "access",
968        OidcTokenType::Refresh => "refresh",
969        OidcTokenType::DeviceAccess => "device_access",
970    }
971}
972
973fn oidc_token_type_from_str(s: &str) -> OidcTokenType {
974    match s {
975        "refresh" => OidcTokenType::Refresh,
976        "device_access" => OidcTokenType::DeviceAccess,
977        _ => OidcTokenType::Access,
978    }
979}
980
981fn map_oidc_client(r: &sqlx::postgres::PgRow) -> OidcClient {
982    OidcClient {
983        id: r.get("id"),
984        client_id: r.get("client_id"),
985        secret_hash: r.get("secret_hash"),
986        name: r.get("name"),
987        redirect_uris: r.get::<Vec<String>, _>("redirect_uris"),
988        grant_types: r.get::<Vec<String>, _>("grant_types"),
989        response_types: r.get::<Vec<String>, _>("response_types"),
990        allowed_scopes: r.get("allowed_scopes"),
991        created_at: r.get("created_at"),
992    }
993}
994
995#[async_trait]
996impl OidcClientRepository for PostgresStore {
997    async fn create(&self, data: CreateOidcClient) -> Result<OidcClient> {
998        let client_id = Uuid::new_v4().to_string();
999        let row = sqlx::query(
1000            "INSERT INTO authx_oidc_clients \
1001               (id, client_id, secret_hash, name, redirect_uris, grant_types, response_types, allowed_scopes) \
1002             VALUES ($1, $2, $3, $4, $5, $6, $7, $8) \
1003             RETURNING id, client_id, secret_hash, name, redirect_uris, grant_types, response_types, allowed_scopes, created_at",
1004        )
1005        .bind(Uuid::new_v4())
1006        .bind(&client_id)
1007        .bind(&data.secret_hash)
1008        .bind(&data.name)
1009        .bind(&data.redirect_uris)
1010        .bind(&data.grant_types)
1011        .bind(&data.response_types)
1012        .bind(&data.allowed_scopes)
1013        .fetch_one(&self.pool)
1014        .await
1015        .map_err(db_err)?;
1016
1017        tracing::debug!(client_id = %client_id, "oidc client created");
1018        Ok(map_oidc_client(&row))
1019    }
1020
1021    async fn find_by_client_id(&self, client_id: &str) -> Result<Option<OidcClient>> {
1022        let row = sqlx::query(
1023            "SELECT id, client_id, secret_hash, name, redirect_uris, grant_types, response_types, allowed_scopes, created_at \
1024             FROM authx_oidc_clients WHERE client_id = $1",
1025        )
1026        .bind(client_id)
1027        .fetch_optional(&self.pool)
1028        .await
1029        .map_err(db_err)?;
1030        Ok(row.as_ref().map(map_oidc_client))
1031    }
1032
1033    async fn list(&self, offset: u32, limit: u32) -> Result<Vec<OidcClient>> {
1034        let rows = sqlx::query(
1035            "SELECT id, client_id, secret_hash, name, redirect_uris, grant_types, response_types, allowed_scopes, created_at \
1036             FROM authx_oidc_clients ORDER BY created_at ASC LIMIT $1 OFFSET $2",
1037        )
1038        .bind(limit as i64)
1039        .bind(offset as i64)
1040        .fetch_all(&self.pool)
1041        .await
1042        .map_err(db_err)?;
1043        Ok(rows.iter().map(map_oidc_client).collect())
1044    }
1045}
1046
1047// ── AuthorizationCodeRepository ────────────────────────────────────────────────
1048
1049fn map_authorization_code(r: &sqlx::postgres::PgRow) -> AuthorizationCode {
1050    AuthorizationCode {
1051        id: r.get("id"),
1052        code_hash: r.get("code_hash"),
1053        client_id: r.get("client_id"),
1054        user_id: r.get("user_id"),
1055        redirect_uri: r.get("redirect_uri"),
1056        scope: r.get("scope"),
1057        nonce: r.get("nonce"),
1058        code_challenge: r.get("code_challenge"),
1059        expires_at: r.get("expires_at"),
1060        used: r.get("used"),
1061    }
1062}
1063
1064#[async_trait]
1065impl AuthorizationCodeRepository for PostgresStore {
1066    async fn create(&self, data: CreateAuthorizationCode) -> Result<AuthorizationCode> {
1067        let row = sqlx::query(
1068            "INSERT INTO authx_oidc_authorization_codes \
1069               (id, code_hash, client_id, user_id, redirect_uri, scope, nonce, code_challenge, expires_at) \
1070             VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) \
1071             RETURNING id, code_hash, client_id, user_id, redirect_uri, scope, nonce, code_challenge, expires_at, used",
1072        )
1073        .bind(Uuid::new_v4())
1074        .bind(&data.code_hash)
1075        .bind(&data.client_id)
1076        .bind(data.user_id)
1077        .bind(&data.redirect_uri)
1078        .bind(&data.scope)
1079        .bind(data.nonce.as_deref())
1080        .bind(data.code_challenge.as_deref())
1081        .bind(data.expires_at)
1082        .fetch_one(&self.pool)
1083        .await
1084        .map_err(db_err)?;
1085
1086        tracing::debug!(client_id = %data.client_id, "authorization code created");
1087        Ok(map_authorization_code(&row))
1088    }
1089
1090    async fn find_by_code_hash(&self, hash: &str) -> Result<Option<AuthorizationCode>> {
1091        let row = sqlx::query(
1092            "SELECT id, code_hash, client_id, user_id, redirect_uri, scope, nonce, code_challenge, expires_at, used \
1093             FROM authx_oidc_authorization_codes WHERE code_hash = $1 AND expires_at > NOW() AND used = false",
1094        )
1095        .bind(hash)
1096        .fetch_optional(&self.pool)
1097        .await
1098        .map_err(db_err)?;
1099        Ok(row.as_ref().map(map_authorization_code))
1100    }
1101
1102    async fn mark_used(&self, id: Uuid) -> Result<()> {
1103        let result =
1104            sqlx::query("UPDATE authx_oidc_authorization_codes SET used = true WHERE id = $1")
1105                .bind(id)
1106                .execute(&self.pool)
1107                .await
1108                .map_err(db_err)?;
1109
1110        if result.rows_affected() == 0 {
1111            return Err(AuthError::Storage(StorageError::NotFound));
1112        }
1113        Ok(())
1114    }
1115
1116    async fn delete_expired(&self) -> Result<u64> {
1117        let result =
1118            sqlx::query("DELETE FROM authx_oidc_authorization_codes WHERE expires_at < NOW()")
1119                .execute(&self.pool)
1120                .await
1121                .map_err(db_err)?;
1122        Ok(result.rows_affected())
1123    }
1124}
1125
1126// ── OidcTokenRepository ────────────────────────────────────────────────────────
1127
1128fn map_oidc_token(r: &sqlx::postgres::PgRow) -> OidcToken {
1129    OidcToken {
1130        id: r.get("id"),
1131        token_hash: r.get("token_hash"),
1132        client_id: r.get("client_id"),
1133        user_id: r.get("user_id"),
1134        scope: r.get("scope"),
1135        token_type: oidc_token_type_from_str(r.get::<&str, _>("token_type")),
1136        expires_at: r.get("expires_at"),
1137        revoked: r.get("revoked"),
1138        created_at: r.get("created_at"),
1139    }
1140}
1141
1142#[async_trait]
1143impl OidcTokenRepository for PostgresStore {
1144    async fn create(&self, data: CreateOidcToken) -> Result<OidcToken> {
1145        let row = sqlx::query(
1146            "INSERT INTO authx_oidc_tokens \
1147               (id, token_hash, client_id, user_id, scope, token_type, expires_at) \
1148             VALUES ($1, $2, $3, $4, $5, $6, $7) \
1149             RETURNING id, token_hash, client_id, user_id, scope, token_type, expires_at, revoked, created_at",
1150        )
1151        .bind(Uuid::new_v4())
1152        .bind(&data.token_hash)
1153        .bind(&data.client_id)
1154        .bind(data.user_id)
1155        .bind(&data.scope)
1156        .bind(oidc_token_type_str(&data.token_type))
1157        .bind(data.expires_at)
1158        .fetch_one(&self.pool)
1159        .await
1160        .map_err(db_err)?;
1161
1162        tracing::debug!(client_id = %data.client_id, "oidc token created");
1163        Ok(map_oidc_token(&row))
1164    }
1165
1166    async fn find_by_token_hash(&self, hash: &str) -> Result<Option<OidcToken>> {
1167        let row = sqlx::query(
1168            "SELECT id, token_hash, client_id, user_id, scope, token_type, expires_at, revoked, created_at \
1169             FROM authx_oidc_tokens WHERE token_hash = $1 AND revoked = false",
1170        )
1171        .bind(hash)
1172        .fetch_optional(&self.pool)
1173        .await
1174        .map_err(db_err)?;
1175
1176        if let Some(ref r) = row {
1177            let tok = map_oidc_token(r);
1178            if let Some(exp) = tok.expires_at
1179                && exp < Utc::now()
1180            {
1181                return Ok(None);
1182            }
1183        }
1184        Ok(row.as_ref().map(map_oidc_token))
1185    }
1186
1187    async fn revoke(&self, id: Uuid) -> Result<()> {
1188        let result = sqlx::query("UPDATE authx_oidc_tokens SET revoked = true WHERE id = $1")
1189            .bind(id)
1190            .execute(&self.pool)
1191            .await
1192            .map_err(db_err)?;
1193
1194        if result.rows_affected() == 0 {
1195            return Err(AuthError::Storage(StorageError::NotFound));
1196        }
1197        Ok(())
1198    }
1199
1200    async fn revoke_all_for_user_client(&self, user_id: Uuid, client_id: &str) -> Result<()> {
1201        sqlx::query(
1202            "UPDATE authx_oidc_tokens SET revoked = true WHERE user_id = $1 AND client_id = $2",
1203        )
1204        .bind(user_id)
1205        .bind(client_id)
1206        .execute(&self.pool)
1207        .await
1208        .map_err(db_err)?;
1209        Ok(())
1210    }
1211}
1212
1213// ── OidcFederationProviderRepository ──────────────────────────────────────────
1214
1215fn map_oidc_federation_provider(r: &sqlx::postgres::PgRow) -> OidcFederationProvider {
1216    let claim_mapping_json: serde_json::Value =
1217        r.try_get("claim_mapping").unwrap_or(serde_json::json!([]));
1218    let claim_mapping = serde_json::from_value(claim_mapping_json).unwrap_or_default();
1219    OidcFederationProvider {
1220        id: r.get("id"),
1221        name: r.get("name"),
1222        issuer: r.get("issuer"),
1223        client_id: r.get("client_id"),
1224        secret_enc: r.get("secret_enc"),
1225        scopes: r.get("scopes"),
1226        org_id: r.try_get("org_id").ok(),
1227        enabled: r.get("enabled"),
1228        created_at: r.get("created_at"),
1229        claim_mapping,
1230    }
1231}
1232
1233#[async_trait]
1234impl OidcFederationProviderRepository for PostgresStore {
1235    async fn create(&self, data: CreateOidcFederationProvider) -> Result<OidcFederationProvider> {
1236        let claim_mapping_json =
1237            serde_json::to_value(&data.claim_mapping).unwrap_or(serde_json::json!([]));
1238        let row = sqlx::query(
1239            "INSERT INTO authx_oidc_federation_providers \
1240               (id, name, issuer, client_id, secret_enc, scopes, org_id, claim_mapping, enabled) \
1241             VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true) \
1242             RETURNING id, name, issuer, client_id, secret_enc, scopes, org_id, claim_mapping, enabled, created_at",
1243        )
1244        .bind(Uuid::new_v4())
1245        .bind(&data.name)
1246        .bind(&data.issuer)
1247        .bind(&data.client_id)
1248        .bind(&data.secret_enc)
1249        .bind(&data.scopes)
1250        .bind(data.org_id)
1251        .bind(claim_mapping_json)
1252        .fetch_one(&self.pool)
1253        .await
1254        .map_err(db_err)?;
1255
1256        tracing::debug!(name = %data.name, "oidc federation provider created");
1257        Ok(map_oidc_federation_provider(&row))
1258    }
1259
1260    async fn find_by_id(&self, id: Uuid) -> Result<Option<OidcFederationProvider>> {
1261        let row = sqlx::query(
1262            "SELECT id, name, issuer, client_id, secret_enc, scopes, org_id, claim_mapping, enabled, created_at \
1263             FROM authx_oidc_federation_providers WHERE id = $1",
1264        )
1265        .bind(id)
1266        .fetch_optional(&self.pool)
1267        .await
1268        .map_err(db_err)?;
1269        Ok(row.as_ref().map(map_oidc_federation_provider))
1270    }
1271
1272    async fn find_by_name(&self, name: &str) -> Result<Option<OidcFederationProvider>> {
1273        let row = sqlx::query(
1274            "SELECT id, name, issuer, client_id, secret_enc, scopes, org_id, claim_mapping, enabled, created_at \
1275             FROM authx_oidc_federation_providers WHERE name = $1",
1276        )
1277        .bind(name)
1278        .fetch_optional(&self.pool)
1279        .await
1280        .map_err(db_err)?;
1281        Ok(row.as_ref().map(map_oidc_federation_provider))
1282    }
1283
1284    async fn list_enabled(&self) -> Result<Vec<OidcFederationProvider>> {
1285        let rows = sqlx::query(
1286            "SELECT id, name, issuer, client_id, secret_enc, scopes, org_id, claim_mapping, enabled, created_at \
1287             FROM authx_oidc_federation_providers WHERE enabled = true ORDER BY name",
1288        )
1289        .fetch_all(&self.pool)
1290        .await
1291        .map_err(db_err)?;
1292        Ok(rows.iter().map(map_oidc_federation_provider).collect())
1293    }
1294}
1295
1296// ── DeviceCodeRepository ─────────────────────────────────────────────────────
1297
/// Comma-separated column list shared by every device-code `SELECT` / `RETURNING`
/// clause, in the order `map_device_code` reads them.
const DEVICE_CODE_COLS: &str = concat!(
    "id, device_code_hash, user_code_hash, user_code, ",
    "client_id, scope, expires_at, interval_secs, ",
    "authorized, denied, user_id, last_polled_at"
);
1301
1302fn map_device_code(r: &sqlx::postgres::PgRow) -> DeviceCode {
1303    DeviceCode {
1304        id: r.get("id"),
1305        device_code_hash: r.get("device_code_hash"),
1306        user_code_hash: r.get("user_code_hash"),
1307        user_code: r.get("user_code"),
1308        client_id: r.get("client_id"),
1309        scope: r.get("scope"),
1310        expires_at: r.get("expires_at"),
1311        interval_secs: r.get::<i32, _>("interval_secs") as u32,
1312        authorized: r.get("authorized"),
1313        denied: r.get("denied"),
1314        user_id: r.get("user_id"),
1315        last_polled_at: r.get("last_polled_at"),
1316    }
1317}
1318
1319#[async_trait]
1320impl DeviceCodeRepository for PostgresStore {
1321    async fn create(&self, data: CreateDeviceCode) -> Result<DeviceCode> {
1322        let row = sqlx::query(&format!(
1323            "INSERT INTO authx_device_codes \
1324               (id, device_code_hash, user_code_hash, user_code, client_id, scope, expires_at, interval_secs) \
1325             VALUES ($1, $2, $3, $4, $5, $6, $7, $8) \
1326             RETURNING {DEVICE_CODE_COLS}"
1327        ))
1328        .bind(Uuid::new_v4())
1329        .bind(&data.device_code_hash)
1330        .bind(&data.user_code_hash)
1331        .bind(&data.user_code)
1332        .bind(&data.client_id)
1333        .bind(&data.scope)
1334        .bind(data.expires_at)
1335        .bind(data.interval_secs as i32)
1336        .fetch_one(&self.pool)
1337        .await
1338        .map_err(db_err)?;
1339
1340        tracing::debug!(client_id = %data.client_id, "device code created");
1341        Ok(map_device_code(&row))
1342    }
1343
1344    async fn find_by_device_code_hash(&self, hash: &str) -> Result<Option<DeviceCode>> {
1345        let row = sqlx::query(&format!(
1346            "SELECT {DEVICE_CODE_COLS} FROM authx_device_codes \
1347             WHERE device_code_hash = $1 AND expires_at > NOW()"
1348        ))
1349        .bind(hash)
1350        .fetch_optional(&self.pool)
1351        .await
1352        .map_err(db_err)?;
1353        Ok(row.as_ref().map(map_device_code))
1354    }
1355
1356    async fn find_by_user_code_hash(&self, hash: &str) -> Result<Option<DeviceCode>> {
1357        let row = sqlx::query(&format!(
1358            "SELECT {DEVICE_CODE_COLS} FROM authx_device_codes \
1359             WHERE user_code_hash = $1 AND expires_at > NOW() \
1360             AND authorized = false AND denied = false"
1361        ))
1362        .bind(hash)
1363        .fetch_optional(&self.pool)
1364        .await
1365        .map_err(db_err)?;
1366        Ok(row.as_ref().map(map_device_code))
1367    }
1368
1369    async fn authorize(&self, id: Uuid, user_id: Uuid) -> Result<()> {
1370        let result = sqlx::query(
1371            "UPDATE authx_device_codes SET authorized = true, user_id = $2 WHERE id = $1",
1372        )
1373        .bind(id)
1374        .bind(user_id)
1375        .execute(&self.pool)
1376        .await
1377        .map_err(db_err)?;
1378
1379        if result.rows_affected() == 0 {
1380            return Err(AuthError::Storage(StorageError::NotFound));
1381        }
1382        tracing::debug!(id = %id, "device code authorized");
1383        Ok(())
1384    }
1385
1386    async fn deny(&self, id: Uuid) -> Result<()> {
1387        let result = sqlx::query("UPDATE authx_device_codes SET denied = true WHERE id = $1")
1388            .bind(id)
1389            .execute(&self.pool)
1390            .await
1391            .map_err(db_err)?;
1392
1393        if result.rows_affected() == 0 {
1394            return Err(AuthError::Storage(StorageError::NotFound));
1395        }
1396        tracing::debug!(id = %id, "device code denied");
1397        Ok(())
1398    }
1399
1400    async fn update_last_polled(&self, id: Uuid, interval_secs: u32) -> Result<()> {
1401        sqlx::query(
1402            "UPDATE authx_device_codes SET last_polled_at = NOW(), interval_secs = $2 WHERE id = $1",
1403        )
1404        .bind(id)
1405        .bind(interval_secs as i32)
1406        .execute(&self.pool)
1407        .await
1408        .map_err(db_err)?;
1409        Ok(())
1410    }
1411
1412    async fn delete(&self, id: Uuid) -> Result<()> {
1413        sqlx::query("DELETE FROM authx_device_codes WHERE id = $1")
1414            .bind(id)
1415            .execute(&self.pool)
1416            .await
1417            .map_err(db_err)?;
1418        Ok(())
1419    }
1420
1421    async fn delete_expired(&self) -> Result<u64> {
1422        let result = sqlx::query("DELETE FROM authx_device_codes WHERE expires_at < NOW()")
1423            .execute(&self.pool)
1424            .await
1425            .map_err(db_err)?;
1426        Ok(result.rows_affected())
1427    }
1428
1429    async fn list_by_client(
1430        &self,
1431        client_id: &str,
1432        offset: u32,
1433        limit: u32,
1434    ) -> Result<Vec<DeviceCode>> {
1435        let rows = sqlx::query(&format!(
1436            "SELECT {DEVICE_CODE_COLS} FROM authx_device_codes \
1437             WHERE client_id = $1 ORDER BY expires_at DESC LIMIT $2 OFFSET $3"
1438        ))
1439        .bind(client_id)
1440        .bind(limit as i64)
1441        .bind(offset as i64)
1442        .fetch_all(&self.pool)
1443        .await
1444        .map_err(db_err)?;
1445        Ok(rows.iter().map(map_device_code).collect())
1446    }
1447}