Skip to main content

better_auth_core/adapters/
database.rs

1pub use super::traits::{
2    AccountOps, ApiKeyOps, InvitationOps, MemberOps, OrganizationOps, PasskeyOps, SessionOps,
3    TwoFactorOps, UserOps, VerificationOps,
4};
5
6/// Database adapter trait for persistence.
7///
8/// Combines all entity-specific operation traits. Any type that implements
9/// all sub-traits (`UserOps`, `SessionOps`, etc.) automatically implements
10/// `DatabaseAdapter` via the blanket impl.
11///
12/// Use the sub-traits directly when you only need a subset of operations
13/// (e.g., a plugin that only accesses users and sessions).
14pub trait DatabaseAdapter:
15    UserOps
16    + SessionOps
17    + AccountOps
18    + VerificationOps
19    + OrganizationOps
20    + MemberOps
21    + InvitationOps
22    + TwoFactorOps
23    + ApiKeyOps
24    + PasskeyOps
25{
26}
27
28impl<T> DatabaseAdapter for T where
29    T: UserOps
30        + SessionOps
31        + AccountOps
32        + VerificationOps
33        + OrganizationOps
34        + MemberOps
35        + InvitationOps
36        + TwoFactorOps
37        + ApiKeyOps
38        + PasskeyOps
39{
40}
41
42#[cfg(feature = "sqlx-postgres")]
43pub mod sqlx_adapter {
44    use super::*;
45    use async_trait::async_trait;
46    use chrono::{DateTime, Utc};
47
48    use crate::entity::{
49        AuthAccount, AuthApiKey, AuthInvitation, AuthMember, AuthOrganization, AuthPasskey,
50        AuthSession, AuthTwoFactor, AuthUser, AuthVerification,
51    };
52    use crate::error::{AuthError, AuthResult};
53    use crate::types::{
54        Account, ApiKey, CreateAccount, CreateApiKey, CreateInvitation, CreateMember,
55        CreateOrganization, CreatePasskey, CreateSession, CreateTwoFactor, CreateUser,
56        CreateVerification, Invitation, InvitationStatus, Member, Organization, Passkey, Session,
57        TwoFactor, UpdateAccount, UpdateApiKey, UpdateOrganization, UpdateUser, User, Verification,
58    };
59    use sqlx::PgPool;
60    use sqlx::postgres::PgRow;
61    use std::marker::PhantomData;
62    use uuid::Uuid;
63
64    /// Blanket trait combining all bounds needed for SQLx-based entity types.
65    ///
66    /// Any type that implements `sqlx::FromRow` plus the standard marker traits
67    /// automatically satisfies this bound. Custom entity types just need
68    /// `#[derive(sqlx::FromRow)]` (or a manual `FromRow` impl) alongside
69    /// their `Auth*` derive.
70    pub trait SqlxEntity:
71        for<'r> sqlx::FromRow<'r, PgRow> + Send + Sync + Unpin + Clone + 'static
72    {
73    }
74
75    impl<T> SqlxEntity for T where
76        T: for<'r> sqlx::FromRow<'r, PgRow> + Send + Sync + Unpin + Clone + 'static
77    {
78    }
79
    /// Tuple of every entity type parameter of `SqlxAdapter`.
    ///
    /// It only exists to give the adapter's `PhantomData` field a single
    /// nameable type, so the struct uses all ten generics without storing
    /// any values of them.
    type SqlxAdapterEntities<U, S, A, O, M, I, V, TF, AK, PK> = (U, S, A, O, M, I, V, TF, AK, PK);

    /// PostgreSQL database adapter via SQLx.
    ///
    /// Generic over entity types — use default type parameters for the built-in
    /// types, or supply your own custom structs that implement `Auth*` + `sqlx::FromRow`.
    ///
    /// The struct itself only owns a `PgPool`; the entity types are tracked at
    /// the type level through `PhantomData`.
    pub struct SqlxAdapter<
        U = User,
        S = Session,
        A = Account,
        O = Organization,
        M = Member,
        I = Invitation,
        V = Verification,
        TF = TwoFactor,
        AK = ApiKey,
        PK = Passkey,
    > {
        // Connection pool every query in this adapter is executed against.
        pool: PgPool,
        // Silences clippy's `type_complexity` lint on this ten-parameter
        // `PhantomData` field.
        #[allow(clippy::type_complexity)]
        _phantom: PhantomData<SqlxAdapterEntities<U, S, A, O, M, I, V, TF, AK, PK>>,
    }
102
103    /// Constructors for the default (built-in) entity types.
104    impl SqlxAdapter {
105        pub async fn new(database_url: &str) -> Result<Self, sqlx::Error> {
106            let pool = PgPool::connect(database_url).await?;
107            Ok(Self {
108                pool,
109                _phantom: PhantomData,
110            })
111        }
112
113        pub async fn with_config(
114            database_url: &str,
115            config: PoolConfig,
116        ) -> Result<Self, sqlx::Error> {
117            let pool = sqlx::postgres::PgPoolOptions::new()
118                .max_connections(config.max_connections)
119                .min_connections(config.min_connections)
120                .acquire_timeout(config.acquire_timeout)
121                .idle_timeout(config.idle_timeout)
122                .max_lifetime(config.max_lifetime)
123                .connect(database_url)
124                .await?;
125            Ok(Self {
126                pool,
127                _phantom: PhantomData,
128            })
129        }
130    }
131
132    /// Methods available for all type parameterizations.
133    impl<U, S, A, O, M, I, V, TF, AK, PK> SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK> {
134        pub fn from_pool(pool: PgPool) -> Self {
135            Self {
136                pool,
137                _phantom: PhantomData,
138            }
139        }
140
141        pub async fn test_connection(&self) -> Result<(), sqlx::Error> {
142            sqlx::query("SELECT 1").execute(&self.pool).await?;
143            Ok(())
144        }
145
146        pub fn pool_stats(&self) -> PoolStats {
147            PoolStats {
148                size: self.pool.size(),
149                idle: self.pool.num_idle(),
150            }
151        }
152
153        pub async fn close(&self) {
154            self.pool.close().await;
155        }
156    }
157
158    #[derive(Debug, Clone)]
159    pub struct PoolConfig {
160        pub max_connections: u32,
161        pub min_connections: u32,
162        pub acquire_timeout: std::time::Duration,
163        pub idle_timeout: Option<std::time::Duration>,
164        pub max_lifetime: Option<std::time::Duration>,
165    }
166
167    impl Default for PoolConfig {
168        fn default() -> Self {
169            Self {
170                max_connections: 10,
171                min_connections: 0,
172                acquire_timeout: std::time::Duration::from_secs(30),
173                idle_timeout: Some(std::time::Duration::from_secs(600)),
174                max_lifetime: Some(std::time::Duration::from_secs(1800)),
175            }
176        }
177    }
178
179    #[derive(Debug, Clone)]
180    pub struct PoolStats {
181        pub size: u32,
182        pub idle: usize,
183    }
184
185    // -- UserOps --
186
187    #[async_trait]
188    impl<U, S, A, O, M, I, V, TF, AK, PK> UserOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
189    where
190        U: AuthUser + SqlxEntity,
191        S: AuthSession + SqlxEntity,
192        A: AuthAccount + SqlxEntity,
193        O: AuthOrganization + SqlxEntity,
194        M: AuthMember + SqlxEntity,
195        I: AuthInvitation + SqlxEntity,
196        V: AuthVerification + SqlxEntity,
197        TF: AuthTwoFactor + SqlxEntity,
198        AK: AuthApiKey + SqlxEntity,
199        PK: AuthPasskey + SqlxEntity,
200    {
201        type User = U;
202
203        async fn create_user(&self, create_user: CreateUser) -> AuthResult<U> {
204            let id = create_user.id.unwrap_or_else(|| Uuid::new_v4().to_string());
205            let now = Utc::now();
206
207            let user = sqlx::query_as::<_, U>(
208                r#"
209                INSERT INTO users (id, email, name, image, email_verified, created_at, updated_at, metadata)
210                VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
211                RETURNING *
212                "#,
213            )
214            .bind(&id)
215            .bind(&create_user.email)
216            .bind(&create_user.name)
217            .bind(&create_user.image)
218            .bind(false)
219            .bind(now)
220            .bind(now)
221            .bind(sqlx::types::Json(create_user.metadata.unwrap_or(serde_json::json!({}))))
222            .fetch_one(&self.pool)
223            .await?;
224
225            Ok(user)
226        }
227
228        async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<U>> {
229            let user = sqlx::query_as::<_, U>("SELECT * FROM users WHERE id = $1")
230                .bind(id)
231                .fetch_optional(&self.pool)
232                .await?;
233            Ok(user)
234        }
235
236        async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<U>> {
237            let user = sqlx::query_as::<_, U>("SELECT * FROM users WHERE email = $1")
238                .bind(email)
239                .fetch_optional(&self.pool)
240                .await?;
241            Ok(user)
242        }
243
244        async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<U>> {
245            let user = sqlx::query_as::<_, U>("SELECT * FROM users WHERE username = $1")
246                .bind(username)
247                .fetch_optional(&self.pool)
248                .await?;
249            Ok(user)
250        }
251
252        async fn update_user(&self, id: &str, update: UpdateUser) -> AuthResult<U> {
253            let mut query = sqlx::QueryBuilder::new("UPDATE users SET updated_at = NOW()");
254            let mut has_updates = false;
255
256            if let Some(email) = &update.email {
257                query.push(", email = ");
258                query.push_bind(email);
259                has_updates = true;
260            }
261            if let Some(name) = &update.name {
262                query.push(", name = ");
263                query.push_bind(name);
264                has_updates = true;
265            }
266            if let Some(image) = &update.image {
267                query.push(", image = ");
268                query.push_bind(image);
269                has_updates = true;
270            }
271            if let Some(email_verified) = update.email_verified {
272                query.push(", email_verified = ");
273                query.push_bind(email_verified);
274                has_updates = true;
275            }
276            if let Some(metadata) = &update.metadata {
277                query.push(", metadata = ");
278                query.push_bind(sqlx::types::Json(metadata.clone()));
279                has_updates = true;
280            }
281
282            if !has_updates {
283                return self
284                    .get_user_by_id(id)
285                    .await?
286                    .ok_or(AuthError::UserNotFound);
287            }
288
289            query.push(" WHERE id = ");
290            query.push_bind(id);
291            query.push(" RETURNING *");
292
293            let user = query.build_query_as::<U>().fetch_one(&self.pool).await?;
294            Ok(user)
295        }
296
297        async fn delete_user(&self, id: &str) -> AuthResult<()> {
298            sqlx::query("DELETE FROM users WHERE id = $1")
299                .bind(id)
300                .execute(&self.pool)
301                .await?;
302            Ok(())
303        }
304    }
305
306    // -- SessionOps --
307
308    #[async_trait]
309    impl<U, S, A, O, M, I, V, TF, AK, PK> SessionOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
310    where
311        U: AuthUser + SqlxEntity,
312        S: AuthSession + SqlxEntity,
313        A: AuthAccount + SqlxEntity,
314        O: AuthOrganization + SqlxEntity,
315        M: AuthMember + SqlxEntity,
316        I: AuthInvitation + SqlxEntity,
317        V: AuthVerification + SqlxEntity,
318        TF: AuthTwoFactor + SqlxEntity,
319        AK: AuthApiKey + SqlxEntity,
320        PK: AuthPasskey + SqlxEntity,
321    {
322        type Session = S;
323
324        async fn create_session(&self, create_session: CreateSession) -> AuthResult<S> {
325            let id = Uuid::new_v4().to_string();
326            let token = format!("session_{}", Uuid::new_v4());
327            let now = Utc::now();
328
329            let session = sqlx::query_as::<_, S>(
330                r#"
331                INSERT INTO sessions (id, user_id, token, expires_at, created_at, ip_address, user_agent, active)
332                VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
333                RETURNING *
334                "#,
335            )
336            .bind(&id)
337            .bind(&create_session.user_id)
338            .bind(&token)
339            .bind(create_session.expires_at)
340            .bind(now)
341            .bind(&create_session.ip_address)
342            .bind(&create_session.user_agent)
343            .bind(true)
344            .fetch_one(&self.pool)
345            .await?;
346
347            Ok(session)
348        }
349
350        async fn get_session(&self, token: &str) -> AuthResult<Option<S>> {
351            let session =
352                sqlx::query_as::<_, S>("SELECT * FROM sessions WHERE token = $1 AND active = true")
353                    .bind(token)
354                    .fetch_optional(&self.pool)
355                    .await?;
356            Ok(session)
357        }
358
359        async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<S>> {
360            let sessions = sqlx::query_as::<_, S>(
361                "SELECT * FROM sessions WHERE user_id = $1 AND active = true ORDER BY created_at DESC",
362            )
363            .bind(user_id)
364            .fetch_all(&self.pool)
365            .await?;
366            Ok(sessions)
367        }
368
369        async fn update_session_expiry(
370            &self,
371            token: &str,
372            expires_at: DateTime<Utc>,
373        ) -> AuthResult<()> {
374            sqlx::query("UPDATE sessions SET expires_at = $1 WHERE token = $2 AND active = true")
375                .bind(expires_at)
376                .bind(token)
377                .execute(&self.pool)
378                .await?;
379            Ok(())
380        }
381
382        async fn delete_session(&self, token: &str) -> AuthResult<()> {
383            sqlx::query("DELETE FROM sessions WHERE token = $1")
384                .bind(token)
385                .execute(&self.pool)
386                .await?;
387            Ok(())
388        }
389
390        async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
391            sqlx::query("DELETE FROM sessions WHERE user_id = $1")
392                .bind(user_id)
393                .execute(&self.pool)
394                .await?;
395            Ok(())
396        }
397
398        async fn delete_expired_sessions(&self) -> AuthResult<usize> {
399            let result =
400                sqlx::query("DELETE FROM sessions WHERE expires_at < NOW() OR active = false")
401                    .execute(&self.pool)
402                    .await?;
403            Ok(result.rows_affected() as usize)
404        }
405
406        async fn update_session_active_organization(
407            &self,
408            token: &str,
409            organization_id: Option<&str>,
410        ) -> AuthResult<S> {
411            let session = sqlx::query_as::<_, S>(
412                "UPDATE sessions SET active_organization_id = $1, updated_at = NOW() WHERE token = $2 AND active = true RETURNING *",
413            )
414            .bind(organization_id)
415            .bind(token)
416            .fetch_one(&self.pool)
417            .await?;
418            Ok(session)
419        }
420    }
421
422    // -- AccountOps --
423
424    #[async_trait]
425    impl<U, S, A, O, M, I, V, TF, AK, PK> AccountOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
426    where
427        U: AuthUser + SqlxEntity,
428        S: AuthSession + SqlxEntity,
429        A: AuthAccount + SqlxEntity,
430        O: AuthOrganization + SqlxEntity,
431        M: AuthMember + SqlxEntity,
432        I: AuthInvitation + SqlxEntity,
433        V: AuthVerification + SqlxEntity,
434        TF: AuthTwoFactor + SqlxEntity,
435        AK: AuthApiKey + SqlxEntity,
436        PK: AuthPasskey + SqlxEntity,
437    {
438        type Account = A;
439
440        async fn create_account(&self, create_account: CreateAccount) -> AuthResult<A> {
441            let id = Uuid::new_v4().to_string();
442            let now = Utc::now();
443
444            let account = sqlx::query_as::<_, A>(
445                r#"
446                INSERT INTO accounts (id, account_id, provider_id, user_id, access_token, refresh_token, id_token, access_token_expires_at, refresh_token_expires_at, scope, password, created_at, updated_at)
447                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
448                RETURNING *
449                "#,
450            )
451            .bind(&id)
452            .bind(&create_account.account_id)
453            .bind(&create_account.provider_id)
454            .bind(&create_account.user_id)
455            .bind(&create_account.access_token)
456            .bind(&create_account.refresh_token)
457            .bind(&create_account.id_token)
458            .bind(create_account.access_token_expires_at)
459            .bind(create_account.refresh_token_expires_at)
460            .bind(&create_account.scope)
461            .bind(&create_account.password)
462            .bind(now)
463            .bind(now)
464            .fetch_one(&self.pool)
465            .await?;
466
467            Ok(account)
468        }
469
470        async fn get_account(
471            &self,
472            provider: &str,
473            provider_account_id: &str,
474        ) -> AuthResult<Option<A>> {
475            let account = sqlx::query_as::<_, A>(
476                "SELECT * FROM accounts WHERE provider_id = $1 AND account_id = $2",
477            )
478            .bind(provider)
479            .bind(provider_account_id)
480            .fetch_optional(&self.pool)
481            .await?;
482            Ok(account)
483        }
484
485        async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<A>> {
486            let accounts = sqlx::query_as::<_, A>(
487                "SELECT * FROM accounts WHERE user_id = $1 ORDER BY created_at DESC",
488            )
489            .bind(user_id)
490            .fetch_all(&self.pool)
491            .await?;
492            Ok(accounts)
493        }
494
495        async fn update_account(&self, id: &str, update: UpdateAccount) -> AuthResult<A> {
496            let mut query = sqlx::QueryBuilder::new("UPDATE accounts SET updated_at = NOW()");
497
498            if let Some(access_token) = &update.access_token {
499                query.push(", access_token = ");
500                query.push_bind(access_token);
501            }
502            if let Some(refresh_token) = &update.refresh_token {
503                query.push(", refresh_token = ");
504                query.push_bind(refresh_token);
505            }
506            if let Some(id_token) = &update.id_token {
507                query.push(", id_token = ");
508                query.push_bind(id_token);
509            }
510            if let Some(access_token_expires_at) = &update.access_token_expires_at {
511                query.push(", access_token_expires_at = ");
512                query.push_bind(access_token_expires_at);
513            }
514            if let Some(refresh_token_expires_at) = &update.refresh_token_expires_at {
515                query.push(", refresh_token_expires_at = ");
516                query.push_bind(refresh_token_expires_at);
517            }
518            if let Some(scope) = &update.scope {
519                query.push(", scope = ");
520                query.push_bind(scope);
521            }
522
523            query.push(" WHERE id = ");
524            query.push_bind(id);
525            query.push(" RETURNING *");
526
527            let account = query.build_query_as::<A>().fetch_one(&self.pool).await?;
528            Ok(account)
529        }
530
531        async fn delete_account(&self, id: &str) -> AuthResult<()> {
532            sqlx::query("DELETE FROM accounts WHERE id = $1")
533                .bind(id)
534                .execute(&self.pool)
535                .await?;
536            Ok(())
537        }
538    }
539
540    // -- VerificationOps --
541
542    #[async_trait]
543    impl<U, S, A, O, M, I, V, TF, AK, PK> VerificationOps
544        for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
545    where
546        U: AuthUser + SqlxEntity,
547        S: AuthSession + SqlxEntity,
548        A: AuthAccount + SqlxEntity,
549        O: AuthOrganization + SqlxEntity,
550        M: AuthMember + SqlxEntity,
551        I: AuthInvitation + SqlxEntity,
552        V: AuthVerification + SqlxEntity,
553        TF: AuthTwoFactor + SqlxEntity,
554        AK: AuthApiKey + SqlxEntity,
555        PK: AuthPasskey + SqlxEntity,
556    {
557        type Verification = V;
558
559        async fn create_verification(
560            &self,
561            create_verification: CreateVerification,
562        ) -> AuthResult<V> {
563            let id = Uuid::new_v4().to_string();
564            let now = Utc::now();
565
566            let verification = sqlx::query_as::<_, V>(
567                r#"
568                INSERT INTO verifications (id, identifier, value, expires_at, created_at, updated_at)
569                VALUES ($1, $2, $3, $4, $5, $6)
570                RETURNING *
571                "#,
572            )
573            .bind(&id)
574            .bind(&create_verification.identifier)
575            .bind(&create_verification.value)
576            .bind(create_verification.expires_at)
577            .bind(now)
578            .bind(now)
579            .fetch_one(&self.pool)
580            .await?;
581
582            Ok(verification)
583        }
584
585        async fn get_verification(&self, identifier: &str, value: &str) -> AuthResult<Option<V>> {
586            let verification = sqlx::query_as::<_, V>(
587                "SELECT * FROM verifications WHERE identifier = $1 AND value = $2 AND expires_at > NOW()",
588            )
589            .bind(identifier)
590            .bind(value)
591            .fetch_optional(&self.pool)
592            .await?;
593            Ok(verification)
594        }
595
596        async fn get_verification_by_value(&self, value: &str) -> AuthResult<Option<V>> {
597            let verification = sqlx::query_as::<_, V>(
598                "SELECT * FROM verifications WHERE value = $1 AND expires_at > NOW()",
599            )
600            .bind(value)
601            .fetch_optional(&self.pool)
602            .await?;
603            Ok(verification)
604        }
605
606        async fn get_verification_by_identifier(&self, identifier: &str) -> AuthResult<Option<V>> {
607            let verification = sqlx::query_as::<_, V>(
608                "SELECT * FROM verifications WHERE identifier = $1 AND expires_at > NOW()",
609            )
610            .bind(identifier)
611            .fetch_optional(&self.pool)
612            .await?;
613            Ok(verification)
614        }
615
616        async fn consume_verification(
617            &self,
618            identifier: &str,
619            value: &str,
620        ) -> AuthResult<Option<V>> {
621            let verification = sqlx::query_as::<_, V>(
622                "DELETE FROM verifications WHERE id IN (
623                    SELECT id FROM verifications
624                    WHERE identifier = $1 AND value = $2 AND expires_at > NOW()
625                    ORDER BY created_at DESC
626                    LIMIT 1
627                ) RETURNING *",
628            )
629            .bind(identifier)
630            .bind(value)
631            .fetch_optional(&self.pool)
632            .await?;
633            Ok(verification)
634        }
635
636        async fn delete_verification(&self, id: &str) -> AuthResult<()> {
637            sqlx::query("DELETE FROM verifications WHERE id = $1")
638                .bind(id)
639                .execute(&self.pool)
640                .await?;
641            Ok(())
642        }
643
644        async fn delete_expired_verifications(&self) -> AuthResult<usize> {
645            let result = sqlx::query("DELETE FROM verifications WHERE expires_at < NOW()")
646                .execute(&self.pool)
647                .await?;
648            Ok(result.rows_affected() as usize)
649        }
650    }
651
652    // -- OrganizationOps --
653
654    #[async_trait]
655    impl<U, S, A, O, M, I, V, TF, AK, PK> OrganizationOps
656        for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
657    where
658        U: AuthUser + SqlxEntity,
659        S: AuthSession + SqlxEntity,
660        A: AuthAccount + SqlxEntity,
661        O: AuthOrganization + SqlxEntity,
662        M: AuthMember + SqlxEntity,
663        I: AuthInvitation + SqlxEntity,
664        V: AuthVerification + SqlxEntity,
665        TF: AuthTwoFactor + SqlxEntity,
666        AK: AuthApiKey + SqlxEntity,
667        PK: AuthPasskey + SqlxEntity,
668    {
669        type Organization = O;
670
671        async fn create_organization(&self, create_org: CreateOrganization) -> AuthResult<O> {
672            let id = create_org.id.unwrap_or_else(|| Uuid::new_v4().to_string());
673            let now = Utc::now();
674
675            let organization = sqlx::query_as::<_, O>(
676                r#"
677                INSERT INTO organization (id, name, slug, logo, metadata, created_at, updated_at)
678                VALUES ($1, $2, $3, $4, $5, $6, $7)
679                RETURNING *
680                "#,
681            )
682            .bind(&id)
683            .bind(&create_org.name)
684            .bind(&create_org.slug)
685            .bind(&create_org.logo)
686            .bind(sqlx::types::Json(
687                create_org.metadata.unwrap_or(serde_json::json!({})),
688            ))
689            .bind(now)
690            .bind(now)
691            .fetch_one(&self.pool)
692            .await?;
693
694            Ok(organization)
695        }
696
697        async fn get_organization_by_id(&self, id: &str) -> AuthResult<Option<O>> {
698            let organization = sqlx::query_as::<_, O>("SELECT * FROM organization WHERE id = $1")
699                .bind(id)
700                .fetch_optional(&self.pool)
701                .await?;
702            Ok(organization)
703        }
704
705        async fn get_organization_by_slug(&self, slug: &str) -> AuthResult<Option<O>> {
706            let organization = sqlx::query_as::<_, O>("SELECT * FROM organization WHERE slug = $1")
707                .bind(slug)
708                .fetch_optional(&self.pool)
709                .await?;
710            Ok(organization)
711        }
712
713        async fn update_organization(&self, id: &str, update: UpdateOrganization) -> AuthResult<O> {
714            let mut query = sqlx::QueryBuilder::new("UPDATE organization SET updated_at = NOW()");
715
716            if let Some(name) = &update.name {
717                query.push(", name = ");
718                query.push_bind(name);
719            }
720            if let Some(slug) = &update.slug {
721                query.push(", slug = ");
722                query.push_bind(slug);
723            }
724            if let Some(logo) = &update.logo {
725                query.push(", logo = ");
726                query.push_bind(logo);
727            }
728            if let Some(metadata) = &update.metadata {
729                query.push(", metadata = ");
730                query.push_bind(sqlx::types::Json(metadata.clone()));
731            }
732
733            query.push(" WHERE id = ");
734            query.push_bind(id);
735            query.push(" RETURNING *");
736
737            let organization = query.build_query_as::<O>().fetch_one(&self.pool).await?;
738            Ok(organization)
739        }
740
741        async fn delete_organization(&self, id: &str) -> AuthResult<()> {
742            sqlx::query("DELETE FROM organization WHERE id = $1")
743                .bind(id)
744                .execute(&self.pool)
745                .await?;
746            Ok(())
747        }
748
749        async fn list_user_organizations(&self, user_id: &str) -> AuthResult<Vec<O>> {
750            let organizations = sqlx::query_as::<_, O>(
751                r#"
752                SELECT o.*
753                FROM organization o
754                INNER JOIN member m ON o.id = m.organization_id
755                WHERE m.user_id = $1
756                ORDER BY o.created_at DESC
757                "#,
758            )
759            .bind(user_id)
760            .fetch_all(&self.pool)
761            .await?;
762            Ok(organizations)
763        }
764    }
765
766    // -- MemberOps --
767
768    #[async_trait]
769    impl<U, S, A, O, M, I, V, TF, AK, PK> MemberOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
770    where
771        U: AuthUser + SqlxEntity,
772        S: AuthSession + SqlxEntity,
773        A: AuthAccount + SqlxEntity,
774        O: AuthOrganization + SqlxEntity,
775        M: AuthMember + SqlxEntity,
776        I: AuthInvitation + SqlxEntity,
777        V: AuthVerification + SqlxEntity,
778        TF: AuthTwoFactor + SqlxEntity,
779        AK: AuthApiKey + SqlxEntity,
780        PK: AuthPasskey + SqlxEntity,
781    {
782        type Member = M;
783
784        async fn create_member(&self, create_member: CreateMember) -> AuthResult<M> {
785            let id = Uuid::new_v4().to_string();
786            let now = Utc::now();
787
788            let member = sqlx::query_as::<_, M>(
789                r#"
790                INSERT INTO member (id, organization_id, user_id, role, created_at)
791                VALUES ($1, $2, $3, $4, $5)
792                RETURNING *
793                "#,
794            )
795            .bind(&id)
796            .bind(&create_member.organization_id)
797            .bind(&create_member.user_id)
798            .bind(&create_member.role)
799            .bind(now)
800            .fetch_one(&self.pool)
801            .await?;
802
803            Ok(member)
804        }
805
806        async fn get_member(&self, organization_id: &str, user_id: &str) -> AuthResult<Option<M>> {
807            let member = sqlx::query_as::<_, M>(
808                "SELECT * FROM member WHERE organization_id = $1 AND user_id = $2",
809            )
810            .bind(organization_id)
811            .bind(user_id)
812            .fetch_optional(&self.pool)
813            .await?;
814            Ok(member)
815        }
816
817        async fn get_member_by_id(&self, id: &str) -> AuthResult<Option<M>> {
818            let member = sqlx::query_as::<_, M>("SELECT * FROM member WHERE id = $1")
819                .bind(id)
820                .fetch_optional(&self.pool)
821                .await?;
822            Ok(member)
823        }
824
825        async fn update_member_role(&self, member_id: &str, role: &str) -> AuthResult<M> {
826            let member =
827                sqlx::query_as::<_, M>("UPDATE member SET role = $1 WHERE id = $2 RETURNING *")
828                    .bind(role)
829                    .bind(member_id)
830                    .fetch_one(&self.pool)
831                    .await?;
832            Ok(member)
833        }
834
835        async fn delete_member(&self, member_id: &str) -> AuthResult<()> {
836            sqlx::query("DELETE FROM member WHERE id = $1")
837                .bind(member_id)
838                .execute(&self.pool)
839                .await?;
840            Ok(())
841        }
842
843        async fn list_organization_members(&self, organization_id: &str) -> AuthResult<Vec<M>> {
844            let members = sqlx::query_as::<_, M>(
845                "SELECT * FROM member WHERE organization_id = $1 ORDER BY created_at ASC",
846            )
847            .bind(organization_id)
848            .fetch_all(&self.pool)
849            .await?;
850            Ok(members)
851        }
852
853        async fn count_organization_members(&self, organization_id: &str) -> AuthResult<usize> {
854            let count: (i64,) =
855                sqlx::query_as("SELECT COUNT(*) FROM member WHERE organization_id = $1")
856                    .bind(organization_id)
857                    .fetch_one(&self.pool)
858                    .await?;
859            Ok(count.0 as usize)
860        }
861
862        async fn count_organization_owners(&self, organization_id: &str) -> AuthResult<usize> {
863            let count: (i64,) = sqlx::query_as(
864                "SELECT COUNT(*) FROM member WHERE organization_id = $1 AND role = 'owner'",
865            )
866            .bind(organization_id)
867            .fetch_one(&self.pool)
868            .await?;
869            Ok(count.0 as usize)
870        }
871    }
872
873    // -- InvitationOps --
874
875    #[async_trait]
876    impl<U, S, A, O, M, I, V, TF, AK, PK> InvitationOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
877    where
878        U: AuthUser + SqlxEntity,
879        S: AuthSession + SqlxEntity,
880        A: AuthAccount + SqlxEntity,
881        O: AuthOrganization + SqlxEntity,
882        M: AuthMember + SqlxEntity,
883        I: AuthInvitation + SqlxEntity,
884        V: AuthVerification + SqlxEntity,
885        TF: AuthTwoFactor + SqlxEntity,
886        AK: AuthApiKey + SqlxEntity,
887        PK: AuthPasskey + SqlxEntity,
888    {
889        type Invitation = I;
890
891        async fn create_invitation(&self, create_inv: CreateInvitation) -> AuthResult<I> {
892            let id = Uuid::new_v4().to_string();
893            let now = Utc::now();
894
895            let invitation = sqlx::query_as::<_, I>(
896                r#"
897                INSERT INTO invitation (id, organization_id, email, role, status, inviter_id, expires_at, created_at)
898                VALUES ($1, $2, $3, $4, 'pending', $5, $6, $7)
899                RETURNING *
900                "#,
901            )
902            .bind(&id)
903            .bind(&create_inv.organization_id)
904            .bind(&create_inv.email)
905            .bind(&create_inv.role)
906            .bind(&create_inv.inviter_id)
907            .bind(create_inv.expires_at)
908            .bind(now)
909            .fetch_one(&self.pool)
910            .await?;
911
912            Ok(invitation)
913        }
914
915        async fn get_invitation_by_id(&self, id: &str) -> AuthResult<Option<I>> {
916            let invitation = sqlx::query_as::<_, I>("SELECT * FROM invitation WHERE id = $1")
917                .bind(id)
918                .fetch_optional(&self.pool)
919                .await?;
920            Ok(invitation)
921        }
922
923        async fn get_pending_invitation(
924            &self,
925            organization_id: &str,
926            email: &str,
927        ) -> AuthResult<Option<I>> {
928            let invitation = sqlx::query_as::<_, I>(
929                "SELECT * FROM invitation WHERE organization_id = $1 AND LOWER(email) = LOWER($2) AND status = 'pending'",
930            )
931            .bind(organization_id)
932            .bind(email)
933            .fetch_optional(&self.pool)
934            .await?;
935            Ok(invitation)
936        }
937
938        async fn update_invitation_status(
939            &self,
940            id: &str,
941            status: InvitationStatus,
942        ) -> AuthResult<I> {
943            let invitation = sqlx::query_as::<_, I>(
944                "UPDATE invitation SET status = $1 WHERE id = $2 RETURNING *",
945            )
946            .bind(status.to_string())
947            .bind(id)
948            .fetch_one(&self.pool)
949            .await?;
950            Ok(invitation)
951        }
952
953        async fn list_organization_invitations(&self, organization_id: &str) -> AuthResult<Vec<I>> {
954            let invitations = sqlx::query_as::<_, I>(
955                "SELECT * FROM invitation WHERE organization_id = $1 ORDER BY created_at DESC",
956            )
957            .bind(organization_id)
958            .fetch_all(&self.pool)
959            .await?;
960            Ok(invitations)
961        }
962
963        async fn list_user_invitations(&self, email: &str) -> AuthResult<Vec<I>> {
964            let invitations = sqlx::query_as::<_, I>(
965                "SELECT * FROM invitation WHERE LOWER(email) = LOWER($1) AND status = 'pending' AND expires_at > NOW() ORDER BY created_at DESC",
966            )
967            .bind(email)
968            .fetch_all(&self.pool)
969            .await?;
970            Ok(invitations)
971        }
972    }
973
974    // -- TwoFactorOps --
975
976    #[async_trait]
977    impl<U, S, A, O, M, I, V, TF, AK, PK> TwoFactorOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
978    where
979        U: AuthUser + SqlxEntity,
980        S: AuthSession + SqlxEntity,
981        A: AuthAccount + SqlxEntity,
982        O: AuthOrganization + SqlxEntity,
983        M: AuthMember + SqlxEntity,
984        I: AuthInvitation + SqlxEntity,
985        V: AuthVerification + SqlxEntity,
986        TF: AuthTwoFactor + SqlxEntity,
987        AK: AuthApiKey + SqlxEntity,
988        PK: AuthPasskey + SqlxEntity,
989    {
990        type TwoFactor = TF;
991
992        async fn create_two_factor(&self, create: CreateTwoFactor) -> AuthResult<TF> {
993            let id = Uuid::new_v4().to_string();
994            let now = Utc::now();
995
996            let two_factor = sqlx::query_as::<_, TF>(
997                r#"
998                INSERT INTO two_factor (id, secret, backup_codes, user_id, created_at, updated_at)
999                VALUES ($1, $2, $3, $4, $5, $6)
1000                RETURNING *
1001                "#,
1002            )
1003            .bind(&id)
1004            .bind(&create.secret)
1005            .bind(&create.backup_codes)
1006            .bind(&create.user_id)
1007            .bind(now)
1008            .bind(now)
1009            .fetch_one(&self.pool)
1010            .await?;
1011
1012            Ok(two_factor)
1013        }
1014
1015        async fn get_two_factor_by_user_id(&self, user_id: &str) -> AuthResult<Option<TF>> {
1016            let two_factor = sqlx::query_as::<_, TF>("SELECT * FROM two_factor WHERE user_id = $1")
1017                .bind(user_id)
1018                .fetch_optional(&self.pool)
1019                .await?;
1020            Ok(two_factor)
1021        }
1022
1023        async fn update_two_factor_backup_codes(
1024            &self,
1025            user_id: &str,
1026            backup_codes: &str,
1027        ) -> AuthResult<TF> {
1028            let two_factor = sqlx::query_as::<_, TF>(
1029                "UPDATE two_factor SET backup_codes = $1, updated_at = NOW() WHERE user_id = $2 RETURNING *",
1030            )
1031            .bind(backup_codes)
1032            .bind(user_id)
1033            .fetch_one(&self.pool)
1034            .await?;
1035            Ok(two_factor)
1036        }
1037
1038        async fn delete_two_factor(&self, user_id: &str) -> AuthResult<()> {
1039            sqlx::query("DELETE FROM two_factor WHERE user_id = $1")
1040                .bind(user_id)
1041                .execute(&self.pool)
1042                .await?;
1043            Ok(())
1044        }
1045    }
1046
1047    // -- ApiKeyOps --
1048
1049    #[async_trait]
1050    impl<U, S, A, O, M, I, V, TF, AK, PK> ApiKeyOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
1051    where
1052        U: AuthUser + SqlxEntity,
1053        S: AuthSession + SqlxEntity,
1054        A: AuthAccount + SqlxEntity,
1055        O: AuthOrganization + SqlxEntity,
1056        M: AuthMember + SqlxEntity,
1057        I: AuthInvitation + SqlxEntity,
1058        V: AuthVerification + SqlxEntity,
1059        TF: AuthTwoFactor + SqlxEntity,
1060        AK: AuthApiKey + SqlxEntity,
1061        PK: AuthPasskey + SqlxEntity,
1062    {
1063        type ApiKey = AK;
1064
1065        async fn create_api_key(&self, input: CreateApiKey) -> AuthResult<AK> {
1066            let id = Uuid::new_v4().to_string();
1067            let now = Utc::now();
1068
1069            let api_key = sqlx::query_as::<_, AK>(
1070                r#"
1071                INSERT INTO api_keys (id, name, start, prefix, key, user_id, refill_interval, refill_amount,
1072                    enabled, rate_limit_enabled, rate_limit_time_window, rate_limit_max, remaining,
1073                    expires_at, created_at, updated_at, permissions, metadata)
1074                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13,
1075                    $14::timestamptz, $15, $16, $17, $18)
1076                RETURNING *
1077                "#,
1078            )
1079            .bind(&id)
1080            .bind(&input.name)
1081            .bind(&input.start)
1082            .bind(&input.prefix)
1083            .bind(&input.key_hash)
1084            .bind(&input.user_id)
1085            .bind(input.refill_interval)
1086            .bind(input.refill_amount)
1087            .bind(input.enabled)
1088            .bind(input.rate_limit_enabled)
1089            .bind(input.rate_limit_time_window)
1090            .bind(input.rate_limit_max)
1091            .bind(input.remaining)
1092            .bind(&input.expires_at)
1093            .bind(now)
1094            .bind(now)
1095            .bind(&input.permissions)
1096            .bind(&input.metadata)
1097            .fetch_one(&self.pool)
1098            .await?;
1099
1100            Ok(api_key)
1101        }
1102
1103        async fn get_api_key_by_id(&self, id: &str) -> AuthResult<Option<AK>> {
1104            let api_key = sqlx::query_as::<_, AK>("SELECT * FROM api_keys WHERE id = $1")
1105                .bind(id)
1106                .fetch_optional(&self.pool)
1107                .await?;
1108            Ok(api_key)
1109        }
1110
1111        async fn get_api_key_by_hash(&self, hash: &str) -> AuthResult<Option<AK>> {
1112            let api_key = sqlx::query_as::<_, AK>("SELECT * FROM api_keys WHERE key = $1")
1113                .bind(hash)
1114                .fetch_optional(&self.pool)
1115                .await?;
1116            Ok(api_key)
1117        }
1118
1119        async fn list_api_keys_by_user(&self, user_id: &str) -> AuthResult<Vec<AK>> {
1120            let keys = sqlx::query_as::<_, AK>(
1121                "SELECT * FROM api_keys WHERE user_id = $1 ORDER BY created_at DESC",
1122            )
1123            .bind(user_id)
1124            .fetch_all(&self.pool)
1125            .await?;
1126            Ok(keys)
1127        }
1128
1129        async fn update_api_key(&self, id: &str, update: UpdateApiKey) -> AuthResult<AK> {
1130            let mut query = sqlx::QueryBuilder::new("UPDATE api_keys SET updated_at = NOW()");
1131
1132            if let Some(name) = &update.name {
1133                query.push(", name = ");
1134                query.push_bind(name);
1135            }
1136            if let Some(enabled) = update.enabled {
1137                query.push(", enabled = ");
1138                query.push_bind(enabled);
1139            }
1140            if let Some(remaining) = update.remaining {
1141                query.push(", remaining = ");
1142                query.push_bind(remaining);
1143            }
1144            if let Some(rate_limit_enabled) = update.rate_limit_enabled {
1145                query.push(", rate_limit_enabled = ");
1146                query.push_bind(rate_limit_enabled);
1147            }
1148            if let Some(rate_limit_time_window) = update.rate_limit_time_window {
1149                query.push(", rate_limit_time_window = ");
1150                query.push_bind(rate_limit_time_window);
1151            }
1152            if let Some(rate_limit_max) = update.rate_limit_max {
1153                query.push(", rate_limit_max = ");
1154                query.push_bind(rate_limit_max);
1155            }
1156            if let Some(refill_interval) = update.refill_interval {
1157                query.push(", refill_interval = ");
1158                query.push_bind(refill_interval);
1159            }
1160            if let Some(refill_amount) = update.refill_amount {
1161                query.push(", refill_amount = ");
1162                query.push_bind(refill_amount);
1163            }
1164            if let Some(permissions) = &update.permissions {
1165                query.push(", permissions = ");
1166                query.push_bind(permissions);
1167            }
1168            if let Some(metadata) = &update.metadata {
1169                query.push(", metadata = ");
1170                query.push_bind(metadata);
1171            }
1172
1173            query.push(" WHERE id = ");
1174            query.push_bind(id);
1175            query.push(" RETURNING *");
1176
1177            let api_key = query
1178                .build_query_as::<AK>()
1179                .fetch_one(&self.pool)
1180                .await
1181                .map_err(|err| match err {
1182                    sqlx::Error::RowNotFound => AuthError::not_found("API key not found"),
1183                    other => AuthError::from(other),
1184                })?;
1185            Ok(api_key)
1186        }
1187
1188        async fn delete_api_key(&self, id: &str) -> AuthResult<()> {
1189            sqlx::query("DELETE FROM api_keys WHERE id = $1")
1190                .bind(id)
1191                .execute(&self.pool)
1192                .await?;
1193            Ok(())
1194        }
1195    }
1196
1197    // -- PasskeyOps --
1198
1199    #[async_trait]
1200    impl<U, S, A, O, M, I, V, TF, AK, PK> PasskeyOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
1201    where
1202        U: AuthUser + SqlxEntity,
1203        S: AuthSession + SqlxEntity,
1204        A: AuthAccount + SqlxEntity,
1205        O: AuthOrganization + SqlxEntity,
1206        M: AuthMember + SqlxEntity,
1207        I: AuthInvitation + SqlxEntity,
1208        V: AuthVerification + SqlxEntity,
1209        TF: AuthTwoFactor + SqlxEntity,
1210        AK: AuthApiKey + SqlxEntity,
1211        PK: AuthPasskey + SqlxEntity,
1212    {
1213        type Passkey = PK;
1214
1215        async fn create_passkey(&self, input: CreatePasskey) -> AuthResult<PK> {
1216            let id = Uuid::new_v4().to_string();
1217            let now = Utc::now();
1218            let counter = i64::try_from(input.counter)
1219                .map_err(|_| AuthError::bad_request("Passkey counter exceeds i64 range"))?;
1220
1221            let passkey = sqlx::query_as::<_, PK>(
1222                r#"
1223                INSERT INTO passkeys (id, name, public_key, user_id, credential_id, counter, device_type, backed_up, transports, created_at)
1224                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
1225                RETURNING *
1226                "#,
1227            )
1228            .bind(&id)
1229            .bind(&input.name)
1230            .bind(&input.public_key)
1231            .bind(&input.user_id)
1232            .bind(&input.credential_id)
1233            .bind(counter)
1234            .bind(&input.device_type)
1235            .bind(input.backed_up)
1236            .bind(&input.transports)
1237            .bind(now)
1238            .fetch_one(&self.pool)
1239            .await
1240            .map_err(|e| match e {
1241                sqlx::Error::Database(ref db_err) if db_err.is_unique_violation() => {
1242                    AuthError::conflict("A passkey with this credential ID already exists")
1243                }
1244                other => AuthError::from(other),
1245            })?;
1246
1247            Ok(passkey)
1248        }
1249
1250        async fn get_passkey_by_id(&self, id: &str) -> AuthResult<Option<PK>> {
1251            let passkey = sqlx::query_as::<_, PK>("SELECT * FROM passkeys WHERE id = $1")
1252                .bind(id)
1253                .fetch_optional(&self.pool)
1254                .await?;
1255            Ok(passkey)
1256        }
1257
1258        async fn get_passkey_by_credential_id(
1259            &self,
1260            credential_id: &str,
1261        ) -> AuthResult<Option<PK>> {
1262            let passkey =
1263                sqlx::query_as::<_, PK>("SELECT * FROM passkeys WHERE credential_id = $1")
1264                    .bind(credential_id)
1265                    .fetch_optional(&self.pool)
1266                    .await?;
1267            Ok(passkey)
1268        }
1269
1270        async fn list_passkeys_by_user(&self, user_id: &str) -> AuthResult<Vec<PK>> {
1271            let passkeys = sqlx::query_as::<_, PK>(
1272                "SELECT * FROM passkeys WHERE user_id = $1 ORDER BY created_at DESC",
1273            )
1274            .bind(user_id)
1275            .fetch_all(&self.pool)
1276            .await?;
1277            Ok(passkeys)
1278        }
1279
1280        async fn update_passkey_counter(&self, id: &str, counter: u64) -> AuthResult<PK> {
1281            let counter = i64::try_from(counter)
1282                .map_err(|_| AuthError::bad_request("Passkey counter exceeds i64 range"))?;
1283            let passkey = sqlx::query_as::<_, PK>(
1284                "UPDATE passkeys SET counter = $2 WHERE id = $1 RETURNING *",
1285            )
1286            .bind(id)
1287            .bind(counter)
1288            .fetch_one(&self.pool)
1289            .await
1290            .map_err(|err| match err {
1291                sqlx::Error::RowNotFound => AuthError::not_found("Passkey not found"),
1292                other => AuthError::from(other),
1293            })?;
1294            Ok(passkey)
1295        }
1296
1297        async fn update_passkey_name(&self, id: &str, name: &str) -> AuthResult<PK> {
1298            let passkey =
1299                sqlx::query_as::<_, PK>("UPDATE passkeys SET name = $2 WHERE id = $1 RETURNING *")
1300                    .bind(id)
1301                    .bind(name)
1302                    .fetch_one(&self.pool)
1303                    .await
1304                    .map_err(|err| match err {
1305                        sqlx::Error::RowNotFound => AuthError::not_found("Passkey not found"),
1306                        other => AuthError::from(other),
1307                    })?;
1308            Ok(passkey)
1309        }
1310
1311        async fn delete_passkey(&self, id: &str) -> AuthResult<()> {
1312            sqlx::query("DELETE FROM passkeys WHERE id = $1")
1313                .bind(id)
1314                .execute(&self.pool)
1315                .await?;
1316            Ok(())
1317        }
1318    }
1319}
1320
1321#[cfg(feature = "sqlx-postgres")]
1322pub use sqlx_adapter::{SqlxAdapter, SqlxEntity};