Skip to main content

better_auth_core/adapters/
database.rs

1pub use super::traits::{
2    AccountOps, ApiKeyOps, InvitationOps, MemberOps, OrganizationOps, PasskeyOps, SessionOps,
3    TwoFactorOps, UserOps, VerificationOps,
4};
5
6/// Database adapter trait for persistence.
7///
8/// Combines all entity-specific operation traits. Any type that implements
9/// all sub-traits (`UserOps`, `SessionOps`, etc.) automatically implements
10/// `DatabaseAdapter` via the blanket impl.
11///
12/// Use the sub-traits directly when you only need a subset of operations
13/// (e.g., a plugin that only accesses users and sessions).
14pub trait DatabaseAdapter:
15    UserOps
16    + SessionOps
17    + AccountOps
18    + VerificationOps
19    + OrganizationOps
20    + MemberOps
21    + InvitationOps
22    + TwoFactorOps
23    + ApiKeyOps
24    + PasskeyOps
25{
26}
27
28impl<T> DatabaseAdapter for T where
29    T: UserOps
30        + SessionOps
31        + AccountOps
32        + VerificationOps
33        + OrganizationOps
34        + MemberOps
35        + InvitationOps
36        + TwoFactorOps
37        + ApiKeyOps
38        + PasskeyOps
39{
40}
41
42#[cfg(feature = "sqlx-postgres")]
43pub mod sqlx_adapter {
44    use super::*;
45    use async_trait::async_trait;
46    use chrono::{DateTime, Utc};
47
48    use crate::entity::{
49        AuthAccount, AuthApiKey, AuthInvitation, AuthMember, AuthOrganization, AuthPasskey,
50        AuthSession, AuthTwoFactor, AuthUser, AuthVerification,
51    };
52    use crate::error::{AuthError, AuthResult};
53    use crate::types::{
54        Account, ApiKey, CreateAccount, CreateApiKey, CreateInvitation, CreateMember,
55        CreateOrganization, CreatePasskey, CreateSession, CreateTwoFactor, CreateUser,
56        CreateVerification, Invitation, InvitationStatus, Member, Organization, Passkey, Session,
57        TwoFactor, UpdateAccount, UpdateApiKey, UpdateOrganization, UpdateUser, User, Verification,
58    };
59    use sqlx::PgPool;
60    use sqlx::postgres::PgRow;
61    use std::marker::PhantomData;
62    use uuid::Uuid;
63
64    /// Blanket trait combining all bounds needed for SQLx-based entity types.
65    ///
66    /// Any type that implements `sqlx::FromRow` plus the standard marker traits
67    /// automatically satisfies this bound. Custom entity types just need
68    /// `#[derive(sqlx::FromRow)]` (or a manual `FromRow` impl) alongside
69    /// their `Auth*` derive.
70    pub trait SqlxEntity:
71        for<'r> sqlx::FromRow<'r, PgRow> + Send + Sync + Unpin + Clone + 'static
72    {
73    }
74
75    impl<T> SqlxEntity for T where
76        T: for<'r> sqlx::FromRow<'r, PgRow> + Send + Sync + Unpin + Clone + 'static
77    {
78    }
79
80    /// PostgreSQL database adapter via SQLx.
81    ///
82    /// Generic over entity types — use default type parameters for the built-in
83    /// types, or supply your own custom structs that implement `Auth*` + `sqlx::FromRow`.
84    pub struct SqlxAdapter<
85        U = User,
86        S = Session,
87        A = Account,
88        O = Organization,
89        M = Member,
90        I = Invitation,
91        V = Verification,
92        TF = TwoFactor,
93        AK = ApiKey,
94        PK = Passkey,
95    > {
96        pool: PgPool,
97        _phantom: PhantomData<(U, S, A, O, M, I, V, TF, AK, PK)>,
98    }
99
100    /// Constructors for the default (built-in) entity types.
101    impl SqlxAdapter {
102        pub async fn new(database_url: &str) -> Result<Self, sqlx::Error> {
103            let pool = PgPool::connect(database_url).await?;
104            Ok(Self {
105                pool,
106                _phantom: PhantomData,
107            })
108        }
109
110        pub async fn with_config(
111            database_url: &str,
112            config: PoolConfig,
113        ) -> Result<Self, sqlx::Error> {
114            let pool = sqlx::postgres::PgPoolOptions::new()
115                .max_connections(config.max_connections)
116                .min_connections(config.min_connections)
117                .acquire_timeout(config.acquire_timeout)
118                .idle_timeout(config.idle_timeout)
119                .max_lifetime(config.max_lifetime)
120                .connect(database_url)
121                .await?;
122            Ok(Self {
123                pool,
124                _phantom: PhantomData,
125            })
126        }
127    }
128
129    /// Methods available for all type parameterizations.
130    impl<U, S, A, O, M, I, V, TF, AK, PK> SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK> {
131        pub fn from_pool(pool: PgPool) -> Self {
132            Self {
133                pool,
134                _phantom: PhantomData,
135            }
136        }
137
138        pub async fn test_connection(&self) -> Result<(), sqlx::Error> {
139            sqlx::query("SELECT 1").execute(&self.pool).await?;
140            Ok(())
141        }
142
143        pub fn pool_stats(&self) -> PoolStats {
144            PoolStats {
145                size: self.pool.size(),
146                idle: self.pool.num_idle(),
147            }
148        }
149
150        pub async fn close(&self) {
151            self.pool.close().await;
152        }
153    }
154
155    #[derive(Debug, Clone)]
156    pub struct PoolConfig {
157        pub max_connections: u32,
158        pub min_connections: u32,
159        pub acquire_timeout: std::time::Duration,
160        pub idle_timeout: Option<std::time::Duration>,
161        pub max_lifetime: Option<std::time::Duration>,
162    }
163
164    impl Default for PoolConfig {
165        fn default() -> Self {
166            Self {
167                max_connections: 10,
168                min_connections: 0,
169                acquire_timeout: std::time::Duration::from_secs(30),
170                idle_timeout: Some(std::time::Duration::from_secs(600)),
171                max_lifetime: Some(std::time::Duration::from_secs(1800)),
172            }
173        }
174    }
175
176    #[derive(Debug, Clone)]
177    pub struct PoolStats {
178        pub size: u32,
179        pub idle: usize,
180    }
181
182    // -- UserOps --
183
184    #[async_trait]
185    impl<U, S, A, O, M, I, V, TF, AK, PK> UserOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
186    where
187        U: AuthUser + SqlxEntity,
188        S: AuthSession + SqlxEntity,
189        A: AuthAccount + SqlxEntity,
190        O: AuthOrganization + SqlxEntity,
191        M: AuthMember + SqlxEntity,
192        I: AuthInvitation + SqlxEntity,
193        V: AuthVerification + SqlxEntity,
194        TF: AuthTwoFactor + SqlxEntity,
195        AK: AuthApiKey + SqlxEntity,
196        PK: AuthPasskey + SqlxEntity,
197    {
198        type User = U;
199
200        async fn create_user(&self, create_user: CreateUser) -> AuthResult<U> {
201            let id = create_user.id.unwrap_or_else(|| Uuid::new_v4().to_string());
202            let now = Utc::now();
203
204            let user = sqlx::query_as::<_, U>(
205                r#"
206                INSERT INTO users (id, email, name, image, email_verified, created_at, updated_at, metadata)
207                VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
208                RETURNING *
209                "#,
210            )
211            .bind(&id)
212            .bind(&create_user.email)
213            .bind(&create_user.name)
214            .bind(&create_user.image)
215            .bind(false)
216            .bind(&now)
217            .bind(&now)
218            .bind(sqlx::types::Json(create_user.metadata.unwrap_or(serde_json::json!({}))))
219            .fetch_one(&self.pool)
220            .await?;
221
222            Ok(user)
223        }
224
225        async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<U>> {
226            let user = sqlx::query_as::<_, U>("SELECT * FROM users WHERE id = $1")
227                .bind(id)
228                .fetch_optional(&self.pool)
229                .await?;
230            Ok(user)
231        }
232
233        async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<U>> {
234            let user = sqlx::query_as::<_, U>("SELECT * FROM users WHERE email = $1")
235                .bind(email)
236                .fetch_optional(&self.pool)
237                .await?;
238            Ok(user)
239        }
240
241        async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<U>> {
242            let user = sqlx::query_as::<_, U>("SELECT * FROM users WHERE username = $1")
243                .bind(username)
244                .fetch_optional(&self.pool)
245                .await?;
246            Ok(user)
247        }
248
249        async fn update_user(&self, id: &str, update: UpdateUser) -> AuthResult<U> {
250            let mut query = sqlx::QueryBuilder::new("UPDATE users SET updated_at = NOW()");
251            let mut has_updates = false;
252
253            if let Some(email) = &update.email {
254                query.push(", email = ");
255                query.push_bind(email);
256                has_updates = true;
257            }
258            if let Some(name) = &update.name {
259                query.push(", name = ");
260                query.push_bind(name);
261                has_updates = true;
262            }
263            if let Some(image) = &update.image {
264                query.push(", image = ");
265                query.push_bind(image);
266                has_updates = true;
267            }
268            if let Some(email_verified) = update.email_verified {
269                query.push(", email_verified = ");
270                query.push_bind(email_verified);
271                has_updates = true;
272            }
273            if let Some(metadata) = &update.metadata {
274                query.push(", metadata = ");
275                query.push_bind(sqlx::types::Json(metadata.clone()));
276                has_updates = true;
277            }
278
279            if !has_updates {
280                return self
281                    .get_user_by_id(id)
282                    .await?
283                    .ok_or(AuthError::UserNotFound);
284            }
285
286            query.push(" WHERE id = ");
287            query.push_bind(id);
288            query.push(" RETURNING *");
289
290            let user = query.build_query_as::<U>().fetch_one(&self.pool).await?;
291            Ok(user)
292        }
293
294        async fn delete_user(&self, id: &str) -> AuthResult<()> {
295            sqlx::query("DELETE FROM users WHERE id = $1")
296                .bind(id)
297                .execute(&self.pool)
298                .await?;
299            Ok(())
300        }
301    }
302
303    // -- SessionOps --
304
305    #[async_trait]
306    impl<U, S, A, O, M, I, V, TF, AK, PK> SessionOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
307    where
308        U: AuthUser + SqlxEntity,
309        S: AuthSession + SqlxEntity,
310        A: AuthAccount + SqlxEntity,
311        O: AuthOrganization + SqlxEntity,
312        M: AuthMember + SqlxEntity,
313        I: AuthInvitation + SqlxEntity,
314        V: AuthVerification + SqlxEntity,
315        TF: AuthTwoFactor + SqlxEntity,
316        AK: AuthApiKey + SqlxEntity,
317        PK: AuthPasskey + SqlxEntity,
318    {
319        type Session = S;
320
321        async fn create_session(&self, create_session: CreateSession) -> AuthResult<S> {
322            let id = Uuid::new_v4().to_string();
323            let token = format!("session_{}", Uuid::new_v4());
324            let now = Utc::now();
325
326            let session = sqlx::query_as::<_, S>(
327                r#"
328                INSERT INTO sessions (id, user_id, token, expires_at, created_at, ip_address, user_agent, active)
329                VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
330                RETURNING *
331                "#,
332            )
333            .bind(&id)
334            .bind(&create_session.user_id)
335            .bind(&token)
336            .bind(&create_session.expires_at)
337            .bind(&now)
338            .bind(&create_session.ip_address)
339            .bind(&create_session.user_agent)
340            .bind(true)
341            .fetch_one(&self.pool)
342            .await?;
343
344            Ok(session)
345        }
346
347        async fn get_session(&self, token: &str) -> AuthResult<Option<S>> {
348            let session =
349                sqlx::query_as::<_, S>("SELECT * FROM sessions WHERE token = $1 AND active = true")
350                    .bind(token)
351                    .fetch_optional(&self.pool)
352                    .await?;
353            Ok(session)
354        }
355
356        async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<S>> {
357            let sessions = sqlx::query_as::<_, S>(
358                "SELECT * FROM sessions WHERE user_id = $1 AND active = true ORDER BY created_at DESC",
359            )
360            .bind(user_id)
361            .fetch_all(&self.pool)
362            .await?;
363            Ok(sessions)
364        }
365
366        async fn update_session_expiry(
367            &self,
368            token: &str,
369            expires_at: DateTime<Utc>,
370        ) -> AuthResult<()> {
371            sqlx::query("UPDATE sessions SET expires_at = $1 WHERE token = $2 AND active = true")
372                .bind(&expires_at)
373                .bind(token)
374                .execute(&self.pool)
375                .await?;
376            Ok(())
377        }
378
379        async fn delete_session(&self, token: &str) -> AuthResult<()> {
380            sqlx::query("DELETE FROM sessions WHERE token = $1")
381                .bind(token)
382                .execute(&self.pool)
383                .await?;
384            Ok(())
385        }
386
387        async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
388            sqlx::query("DELETE FROM sessions WHERE user_id = $1")
389                .bind(user_id)
390                .execute(&self.pool)
391                .await?;
392            Ok(())
393        }
394
395        async fn delete_expired_sessions(&self) -> AuthResult<usize> {
396            let result =
397                sqlx::query("DELETE FROM sessions WHERE expires_at < NOW() OR active = false")
398                    .execute(&self.pool)
399                    .await?;
400            Ok(result.rows_affected() as usize)
401        }
402
403        async fn update_session_active_organization(
404            &self,
405            token: &str,
406            organization_id: Option<&str>,
407        ) -> AuthResult<S> {
408            let session = sqlx::query_as::<_, S>(
409                "UPDATE sessions SET active_organization_id = $1, updated_at = NOW() WHERE token = $2 AND active = true RETURNING *",
410            )
411            .bind(organization_id)
412            .bind(token)
413            .fetch_one(&self.pool)
414            .await?;
415            Ok(session)
416        }
417    }
418
419    // -- AccountOps --
420
421    #[async_trait]
422    impl<U, S, A, O, M, I, V, TF, AK, PK> AccountOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
423    where
424        U: AuthUser + SqlxEntity,
425        S: AuthSession + SqlxEntity,
426        A: AuthAccount + SqlxEntity,
427        O: AuthOrganization + SqlxEntity,
428        M: AuthMember + SqlxEntity,
429        I: AuthInvitation + SqlxEntity,
430        V: AuthVerification + SqlxEntity,
431        TF: AuthTwoFactor + SqlxEntity,
432        AK: AuthApiKey + SqlxEntity,
433        PK: AuthPasskey + SqlxEntity,
434    {
435        type Account = A;
436
437        async fn create_account(&self, create_account: CreateAccount) -> AuthResult<A> {
438            let id = Uuid::new_v4().to_string();
439            let now = Utc::now();
440
441            let account = sqlx::query_as::<_, A>(
442                r#"
443                INSERT INTO accounts (id, account_id, provider_id, user_id, access_token, refresh_token, id_token, access_token_expires_at, refresh_token_expires_at, scope, password, created_at, updated_at)
444                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
445                RETURNING *
446                "#,
447            )
448            .bind(&id)
449            .bind(&create_account.account_id)
450            .bind(&create_account.provider_id)
451            .bind(&create_account.user_id)
452            .bind(&create_account.access_token)
453            .bind(&create_account.refresh_token)
454            .bind(&create_account.id_token)
455            .bind(&create_account.access_token_expires_at)
456            .bind(&create_account.refresh_token_expires_at)
457            .bind(&create_account.scope)
458            .bind(&create_account.password)
459            .bind(&now)
460            .bind(&now)
461            .fetch_one(&self.pool)
462            .await?;
463
464            Ok(account)
465        }
466
467        async fn get_account(
468            &self,
469            provider: &str,
470            provider_account_id: &str,
471        ) -> AuthResult<Option<A>> {
472            let account = sqlx::query_as::<_, A>(
473                "SELECT * FROM accounts WHERE provider_id = $1 AND account_id = $2",
474            )
475            .bind(provider)
476            .bind(provider_account_id)
477            .fetch_optional(&self.pool)
478            .await?;
479            Ok(account)
480        }
481
482        async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<A>> {
483            let accounts = sqlx::query_as::<_, A>(
484                "SELECT * FROM accounts WHERE user_id = $1 ORDER BY created_at DESC",
485            )
486            .bind(user_id)
487            .fetch_all(&self.pool)
488            .await?;
489            Ok(accounts)
490        }
491
492        async fn update_account(&self, id: &str, update: UpdateAccount) -> AuthResult<A> {
493            let mut query = sqlx::QueryBuilder::new("UPDATE accounts SET updated_at = NOW()");
494
495            if let Some(access_token) = &update.access_token {
496                query.push(", access_token = ");
497                query.push_bind(access_token);
498            }
499            if let Some(refresh_token) = &update.refresh_token {
500                query.push(", refresh_token = ");
501                query.push_bind(refresh_token);
502            }
503            if let Some(id_token) = &update.id_token {
504                query.push(", id_token = ");
505                query.push_bind(id_token);
506            }
507            if let Some(access_token_expires_at) = &update.access_token_expires_at {
508                query.push(", access_token_expires_at = ");
509                query.push_bind(access_token_expires_at);
510            }
511            if let Some(refresh_token_expires_at) = &update.refresh_token_expires_at {
512                query.push(", refresh_token_expires_at = ");
513                query.push_bind(refresh_token_expires_at);
514            }
515            if let Some(scope) = &update.scope {
516                query.push(", scope = ");
517                query.push_bind(scope);
518            }
519
520            query.push(" WHERE id = ");
521            query.push_bind(id);
522            query.push(" RETURNING *");
523
524            let account = query.build_query_as::<A>().fetch_one(&self.pool).await?;
525            Ok(account)
526        }
527
528        async fn delete_account(&self, id: &str) -> AuthResult<()> {
529            sqlx::query("DELETE FROM accounts WHERE id = $1")
530                .bind(id)
531                .execute(&self.pool)
532                .await?;
533            Ok(())
534        }
535    }
536
537    // -- VerificationOps --
538
539    #[async_trait]
540    impl<U, S, A, O, M, I, V, TF, AK, PK> VerificationOps
541        for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
542    where
543        U: AuthUser + SqlxEntity,
544        S: AuthSession + SqlxEntity,
545        A: AuthAccount + SqlxEntity,
546        O: AuthOrganization + SqlxEntity,
547        M: AuthMember + SqlxEntity,
548        I: AuthInvitation + SqlxEntity,
549        V: AuthVerification + SqlxEntity,
550        TF: AuthTwoFactor + SqlxEntity,
551        AK: AuthApiKey + SqlxEntity,
552        PK: AuthPasskey + SqlxEntity,
553    {
554        type Verification = V;
555
556        async fn create_verification(
557            &self,
558            create_verification: CreateVerification,
559        ) -> AuthResult<V> {
560            let id = Uuid::new_v4().to_string();
561            let now = Utc::now();
562
563            let verification = sqlx::query_as::<_, V>(
564                r#"
565                INSERT INTO verifications (id, identifier, value, expires_at, created_at, updated_at)
566                VALUES ($1, $2, $3, $4, $5, $6)
567                RETURNING *
568                "#,
569            )
570            .bind(&id)
571            .bind(&create_verification.identifier)
572            .bind(&create_verification.value)
573            .bind(&create_verification.expires_at)
574            .bind(&now)
575            .bind(&now)
576            .fetch_one(&self.pool)
577            .await?;
578
579            Ok(verification)
580        }
581
582        async fn get_verification(&self, identifier: &str, value: &str) -> AuthResult<Option<V>> {
583            let verification = sqlx::query_as::<_, V>(
584                "SELECT * FROM verifications WHERE identifier = $1 AND value = $2 AND expires_at > NOW()",
585            )
586            .bind(identifier)
587            .bind(value)
588            .fetch_optional(&self.pool)
589            .await?;
590            Ok(verification)
591        }
592
593        async fn get_verification_by_value(&self, value: &str) -> AuthResult<Option<V>> {
594            let verification = sqlx::query_as::<_, V>(
595                "SELECT * FROM verifications WHERE value = $1 AND expires_at > NOW()",
596            )
597            .bind(value)
598            .fetch_optional(&self.pool)
599            .await?;
600            Ok(verification)
601        }
602
603        async fn get_verification_by_identifier(&self, identifier: &str) -> AuthResult<Option<V>> {
604            let verification = sqlx::query_as::<_, V>(
605                "SELECT * FROM verifications WHERE identifier = $1 AND expires_at > NOW()",
606            )
607            .bind(identifier)
608            .fetch_optional(&self.pool)
609            .await?;
610            Ok(verification)
611        }
612
613        async fn consume_verification(
614            &self,
615            identifier: &str,
616            value: &str,
617        ) -> AuthResult<Option<V>> {
618            let verification = sqlx::query_as::<_, V>(
619                "DELETE FROM verifications WHERE id IN (
620                    SELECT id FROM verifications
621                    WHERE identifier = $1 AND value = $2 AND expires_at > NOW()
622                    ORDER BY created_at DESC
623                    LIMIT 1
624                ) RETURNING *",
625            )
626            .bind(identifier)
627            .bind(value)
628            .fetch_optional(&self.pool)
629            .await?;
630            Ok(verification)
631        }
632
633        async fn delete_verification(&self, id: &str) -> AuthResult<()> {
634            sqlx::query("DELETE FROM verifications WHERE id = $1")
635                .bind(id)
636                .execute(&self.pool)
637                .await?;
638            Ok(())
639        }
640
641        async fn delete_expired_verifications(&self) -> AuthResult<usize> {
642            let result = sqlx::query("DELETE FROM verifications WHERE expires_at < NOW()")
643                .execute(&self.pool)
644                .await?;
645            Ok(result.rows_affected() as usize)
646        }
647    }
648
649    // -- OrganizationOps --
650
651    #[async_trait]
652    impl<U, S, A, O, M, I, V, TF, AK, PK> OrganizationOps
653        for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
654    where
655        U: AuthUser + SqlxEntity,
656        S: AuthSession + SqlxEntity,
657        A: AuthAccount + SqlxEntity,
658        O: AuthOrganization + SqlxEntity,
659        M: AuthMember + SqlxEntity,
660        I: AuthInvitation + SqlxEntity,
661        V: AuthVerification + SqlxEntity,
662        TF: AuthTwoFactor + SqlxEntity,
663        AK: AuthApiKey + SqlxEntity,
664        PK: AuthPasskey + SqlxEntity,
665    {
666        type Organization = O;
667
668        async fn create_organization(&self, create_org: CreateOrganization) -> AuthResult<O> {
669            let id = create_org.id.unwrap_or_else(|| Uuid::new_v4().to_string());
670            let now = Utc::now();
671
672            let organization = sqlx::query_as::<_, O>(
673                r#"
674                INSERT INTO organization (id, name, slug, logo, metadata, created_at, updated_at)
675                VALUES ($1, $2, $3, $4, $5, $6, $7)
676                RETURNING *
677                "#,
678            )
679            .bind(&id)
680            .bind(&create_org.name)
681            .bind(&create_org.slug)
682            .bind(&create_org.logo)
683            .bind(sqlx::types::Json(
684                create_org.metadata.unwrap_or(serde_json::json!({})),
685            ))
686            .bind(&now)
687            .bind(&now)
688            .fetch_one(&self.pool)
689            .await?;
690
691            Ok(organization)
692        }
693
694        async fn get_organization_by_id(&self, id: &str) -> AuthResult<Option<O>> {
695            let organization = sqlx::query_as::<_, O>("SELECT * FROM organization WHERE id = $1")
696                .bind(id)
697                .fetch_optional(&self.pool)
698                .await?;
699            Ok(organization)
700        }
701
702        async fn get_organization_by_slug(&self, slug: &str) -> AuthResult<Option<O>> {
703            let organization = sqlx::query_as::<_, O>("SELECT * FROM organization WHERE slug = $1")
704                .bind(slug)
705                .fetch_optional(&self.pool)
706                .await?;
707            Ok(organization)
708        }
709
710        async fn update_organization(&self, id: &str, update: UpdateOrganization) -> AuthResult<O> {
711            let mut query = sqlx::QueryBuilder::new("UPDATE organization SET updated_at = NOW()");
712
713            if let Some(name) = &update.name {
714                query.push(", name = ");
715                query.push_bind(name);
716            }
717            if let Some(slug) = &update.slug {
718                query.push(", slug = ");
719                query.push_bind(slug);
720            }
721            if let Some(logo) = &update.logo {
722                query.push(", logo = ");
723                query.push_bind(logo);
724            }
725            if let Some(metadata) = &update.metadata {
726                query.push(", metadata = ");
727                query.push_bind(sqlx::types::Json(metadata.clone()));
728            }
729
730            query.push(" WHERE id = ");
731            query.push_bind(id);
732            query.push(" RETURNING *");
733
734            let organization = query.build_query_as::<O>().fetch_one(&self.pool).await?;
735            Ok(organization)
736        }
737
738        async fn delete_organization(&self, id: &str) -> AuthResult<()> {
739            sqlx::query("DELETE FROM organization WHERE id = $1")
740                .bind(id)
741                .execute(&self.pool)
742                .await?;
743            Ok(())
744        }
745
746        async fn list_user_organizations(&self, user_id: &str) -> AuthResult<Vec<O>> {
747            let organizations = sqlx::query_as::<_, O>(
748                r#"
749                SELECT o.*
750                FROM organization o
751                INNER JOIN member m ON o.id = m.organization_id
752                WHERE m.user_id = $1
753                ORDER BY o.created_at DESC
754                "#,
755            )
756            .bind(user_id)
757            .fetch_all(&self.pool)
758            .await?;
759            Ok(organizations)
760        }
761    }
762
763    // -- MemberOps --
764
765    #[async_trait]
766    impl<U, S, A, O, M, I, V, TF, AK, PK> MemberOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
767    where
768        U: AuthUser + SqlxEntity,
769        S: AuthSession + SqlxEntity,
770        A: AuthAccount + SqlxEntity,
771        O: AuthOrganization + SqlxEntity,
772        M: AuthMember + SqlxEntity,
773        I: AuthInvitation + SqlxEntity,
774        V: AuthVerification + SqlxEntity,
775        TF: AuthTwoFactor + SqlxEntity,
776        AK: AuthApiKey + SqlxEntity,
777        PK: AuthPasskey + SqlxEntity,
778    {
779        type Member = M;
780
781        async fn create_member(&self, create_member: CreateMember) -> AuthResult<M> {
782            let id = Uuid::new_v4().to_string();
783            let now = Utc::now();
784
785            let member = sqlx::query_as::<_, M>(
786                r#"
787                INSERT INTO member (id, organization_id, user_id, role, created_at)
788                VALUES ($1, $2, $3, $4, $5)
789                RETURNING *
790                "#,
791            )
792            .bind(&id)
793            .bind(&create_member.organization_id)
794            .bind(&create_member.user_id)
795            .bind(&create_member.role)
796            .bind(&now)
797            .fetch_one(&self.pool)
798            .await?;
799
800            Ok(member)
801        }
802
803        async fn get_member(&self, organization_id: &str, user_id: &str) -> AuthResult<Option<M>> {
804            let member = sqlx::query_as::<_, M>(
805                "SELECT * FROM member WHERE organization_id = $1 AND user_id = $2",
806            )
807            .bind(organization_id)
808            .bind(user_id)
809            .fetch_optional(&self.pool)
810            .await?;
811            Ok(member)
812        }
813
814        async fn get_member_by_id(&self, id: &str) -> AuthResult<Option<M>> {
815            let member = sqlx::query_as::<_, M>("SELECT * FROM member WHERE id = $1")
816                .bind(id)
817                .fetch_optional(&self.pool)
818                .await?;
819            Ok(member)
820        }
821
822        async fn update_member_role(&self, member_id: &str, role: &str) -> AuthResult<M> {
823            let member =
824                sqlx::query_as::<_, M>("UPDATE member SET role = $1 WHERE id = $2 RETURNING *")
825                    .bind(role)
826                    .bind(member_id)
827                    .fetch_one(&self.pool)
828                    .await?;
829            Ok(member)
830        }
831
832        async fn delete_member(&self, member_id: &str) -> AuthResult<()> {
833            sqlx::query("DELETE FROM member WHERE id = $1")
834                .bind(member_id)
835                .execute(&self.pool)
836                .await?;
837            Ok(())
838        }
839
840        async fn list_organization_members(&self, organization_id: &str) -> AuthResult<Vec<M>> {
841            let members = sqlx::query_as::<_, M>(
842                "SELECT * FROM member WHERE organization_id = $1 ORDER BY created_at ASC",
843            )
844            .bind(organization_id)
845            .fetch_all(&self.pool)
846            .await?;
847            Ok(members)
848        }
849
850        async fn count_organization_members(&self, organization_id: &str) -> AuthResult<usize> {
851            let count: (i64,) =
852                sqlx::query_as("SELECT COUNT(*) FROM member WHERE organization_id = $1")
853                    .bind(organization_id)
854                    .fetch_one(&self.pool)
855                    .await?;
856            Ok(count.0 as usize)
857        }
858
859        async fn count_organization_owners(&self, organization_id: &str) -> AuthResult<usize> {
860            let count: (i64,) = sqlx::query_as(
861                "SELECT COUNT(*) FROM member WHERE organization_id = $1 AND role = 'owner'",
862            )
863            .bind(organization_id)
864            .fetch_one(&self.pool)
865            .await?;
866            Ok(count.0 as usize)
867        }
868    }
869
870    // -- InvitationOps --
871
872    #[async_trait]
873    impl<U, S, A, O, M, I, V, TF, AK, PK> InvitationOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
874    where
875        U: AuthUser + SqlxEntity,
876        S: AuthSession + SqlxEntity,
877        A: AuthAccount + SqlxEntity,
878        O: AuthOrganization + SqlxEntity,
879        M: AuthMember + SqlxEntity,
880        I: AuthInvitation + SqlxEntity,
881        V: AuthVerification + SqlxEntity,
882        TF: AuthTwoFactor + SqlxEntity,
883        AK: AuthApiKey + SqlxEntity,
884        PK: AuthPasskey + SqlxEntity,
885    {
886        type Invitation = I;
887
888        async fn create_invitation(&self, create_inv: CreateInvitation) -> AuthResult<I> {
889            let id = Uuid::new_v4().to_string();
890            let now = Utc::now();
891
892            let invitation = sqlx::query_as::<_, I>(
893                r#"
894                INSERT INTO invitation (id, organization_id, email, role, status, inviter_id, expires_at, created_at)
895                VALUES ($1, $2, $3, $4, 'pending', $5, $6, $7)
896                RETURNING *
897                "#,
898            )
899            .bind(&id)
900            .bind(&create_inv.organization_id)
901            .bind(&create_inv.email)
902            .bind(&create_inv.role)
903            .bind(&create_inv.inviter_id)
904            .bind(&create_inv.expires_at)
905            .bind(&now)
906            .fetch_one(&self.pool)
907            .await?;
908
909            Ok(invitation)
910        }
911
912        async fn get_invitation_by_id(&self, id: &str) -> AuthResult<Option<I>> {
913            let invitation = sqlx::query_as::<_, I>("SELECT * FROM invitation WHERE id = $1")
914                .bind(id)
915                .fetch_optional(&self.pool)
916                .await?;
917            Ok(invitation)
918        }
919
920        async fn get_pending_invitation(
921            &self,
922            organization_id: &str,
923            email: &str,
924        ) -> AuthResult<Option<I>> {
925            let invitation = sqlx::query_as::<_, I>(
926                "SELECT * FROM invitation WHERE organization_id = $1 AND LOWER(email) = LOWER($2) AND status = 'pending'",
927            )
928            .bind(organization_id)
929            .bind(email)
930            .fetch_optional(&self.pool)
931            .await?;
932            Ok(invitation)
933        }
934
935        async fn update_invitation_status(
936            &self,
937            id: &str,
938            status: InvitationStatus,
939        ) -> AuthResult<I> {
940            let invitation = sqlx::query_as::<_, I>(
941                "UPDATE invitation SET status = $1 WHERE id = $2 RETURNING *",
942            )
943            .bind(status.to_string())
944            .bind(id)
945            .fetch_one(&self.pool)
946            .await?;
947            Ok(invitation)
948        }
949
950        async fn list_organization_invitations(&self, organization_id: &str) -> AuthResult<Vec<I>> {
951            let invitations = sqlx::query_as::<_, I>(
952                "SELECT * FROM invitation WHERE organization_id = $1 ORDER BY created_at DESC",
953            )
954            .bind(organization_id)
955            .fetch_all(&self.pool)
956            .await?;
957            Ok(invitations)
958        }
959
960        async fn list_user_invitations(&self, email: &str) -> AuthResult<Vec<I>> {
961            let invitations = sqlx::query_as::<_, I>(
962                "SELECT * FROM invitation WHERE LOWER(email) = LOWER($1) AND status = 'pending' AND expires_at > NOW() ORDER BY created_at DESC",
963            )
964            .bind(email)
965            .fetch_all(&self.pool)
966            .await?;
967            Ok(invitations)
968        }
969    }
970
971    // -- TwoFactorOps --
972
973    #[async_trait]
974    impl<U, S, A, O, M, I, V, TF, AK, PK> TwoFactorOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
975    where
976        U: AuthUser + SqlxEntity,
977        S: AuthSession + SqlxEntity,
978        A: AuthAccount + SqlxEntity,
979        O: AuthOrganization + SqlxEntity,
980        M: AuthMember + SqlxEntity,
981        I: AuthInvitation + SqlxEntity,
982        V: AuthVerification + SqlxEntity,
983        TF: AuthTwoFactor + SqlxEntity,
984        AK: AuthApiKey + SqlxEntity,
985        PK: AuthPasskey + SqlxEntity,
986    {
987        type TwoFactor = TF;
988
989        async fn create_two_factor(&self, create: CreateTwoFactor) -> AuthResult<TF> {
990            let id = Uuid::new_v4().to_string();
991            let now = Utc::now();
992
993            let two_factor = sqlx::query_as::<_, TF>(
994                r#"
995                INSERT INTO two_factor (id, secret, backup_codes, user_id, created_at, updated_at)
996                VALUES ($1, $2, $3, $4, $5, $6)
997                RETURNING *
998                "#,
999            )
1000            .bind(&id)
1001            .bind(&create.secret)
1002            .bind(&create.backup_codes)
1003            .bind(&create.user_id)
1004            .bind(&now)
1005            .bind(&now)
1006            .fetch_one(&self.pool)
1007            .await?;
1008
1009            Ok(two_factor)
1010        }
1011
1012        async fn get_two_factor_by_user_id(&self, user_id: &str) -> AuthResult<Option<TF>> {
1013            let two_factor = sqlx::query_as::<_, TF>("SELECT * FROM two_factor WHERE user_id = $1")
1014                .bind(user_id)
1015                .fetch_optional(&self.pool)
1016                .await?;
1017            Ok(two_factor)
1018        }
1019
1020        async fn update_two_factor_backup_codes(
1021            &self,
1022            user_id: &str,
1023            backup_codes: &str,
1024        ) -> AuthResult<TF> {
1025            let two_factor = sqlx::query_as::<_, TF>(
1026                "UPDATE two_factor SET backup_codes = $1, updated_at = NOW() WHERE user_id = $2 RETURNING *",
1027            )
1028            .bind(backup_codes)
1029            .bind(user_id)
1030            .fetch_one(&self.pool)
1031            .await?;
1032            Ok(two_factor)
1033        }
1034
1035        async fn delete_two_factor(&self, user_id: &str) -> AuthResult<()> {
1036            sqlx::query("DELETE FROM two_factor WHERE user_id = $1")
1037                .bind(user_id)
1038                .execute(&self.pool)
1039                .await?;
1040            Ok(())
1041        }
1042    }
1043
1044    // -- ApiKeyOps --
1045
1046    #[async_trait]
1047    impl<U, S, A, O, M, I, V, TF, AK, PK> ApiKeyOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
1048    where
1049        U: AuthUser + SqlxEntity,
1050        S: AuthSession + SqlxEntity,
1051        A: AuthAccount + SqlxEntity,
1052        O: AuthOrganization + SqlxEntity,
1053        M: AuthMember + SqlxEntity,
1054        I: AuthInvitation + SqlxEntity,
1055        V: AuthVerification + SqlxEntity,
1056        TF: AuthTwoFactor + SqlxEntity,
1057        AK: AuthApiKey + SqlxEntity,
1058        PK: AuthPasskey + SqlxEntity,
1059    {
1060        type ApiKey = AK;
1061
1062        async fn create_api_key(&self, input: CreateApiKey) -> AuthResult<AK> {
1063            let id = Uuid::new_v4().to_string();
1064            let now = Utc::now();
1065
1066            let api_key = sqlx::query_as::<_, AK>(
1067                r#"
1068                INSERT INTO api_keys (id, name, start, prefix, key, user_id, refill_interval, refill_amount,
1069                    enabled, rate_limit_enabled, rate_limit_time_window, rate_limit_max, remaining,
1070                    expires_at, created_at, updated_at, permissions, metadata)
1071                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13,
1072                    $14::timestamptz, $15, $16, $17, $18)
1073                RETURNING *
1074                "#,
1075            )
1076            .bind(&id)
1077            .bind(&input.name)
1078            .bind(&input.start)
1079            .bind(&input.prefix)
1080            .bind(&input.key_hash)
1081            .bind(&input.user_id)
1082            .bind(&input.refill_interval)
1083            .bind(&input.refill_amount)
1084            .bind(input.enabled)
1085            .bind(input.rate_limit_enabled)
1086            .bind(&input.rate_limit_time_window)
1087            .bind(&input.rate_limit_max)
1088            .bind(&input.remaining)
1089            .bind(&input.expires_at)
1090            .bind(&now)
1091            .bind(&now)
1092            .bind(&input.permissions)
1093            .bind(&input.metadata)
1094            .fetch_one(&self.pool)
1095            .await?;
1096
1097            Ok(api_key)
1098        }
1099
1100        async fn get_api_key_by_id(&self, id: &str) -> AuthResult<Option<AK>> {
1101            let api_key = sqlx::query_as::<_, AK>("SELECT * FROM api_keys WHERE id = $1")
1102                .bind(id)
1103                .fetch_optional(&self.pool)
1104                .await?;
1105            Ok(api_key)
1106        }
1107
1108        async fn get_api_key_by_hash(&self, hash: &str) -> AuthResult<Option<AK>> {
1109            let api_key = sqlx::query_as::<_, AK>("SELECT * FROM api_keys WHERE key = $1")
1110                .bind(hash)
1111                .fetch_optional(&self.pool)
1112                .await?;
1113            Ok(api_key)
1114        }
1115
1116        async fn list_api_keys_by_user(&self, user_id: &str) -> AuthResult<Vec<AK>> {
1117            let keys = sqlx::query_as::<_, AK>(
1118                "SELECT * FROM api_keys WHERE user_id = $1 ORDER BY created_at DESC",
1119            )
1120            .bind(user_id)
1121            .fetch_all(&self.pool)
1122            .await?;
1123            Ok(keys)
1124        }
1125
1126        async fn update_api_key(&self, id: &str, update: UpdateApiKey) -> AuthResult<AK> {
1127            let mut query = sqlx::QueryBuilder::new("UPDATE api_keys SET updated_at = NOW()");
1128
1129            if let Some(name) = &update.name {
1130                query.push(", name = ");
1131                query.push_bind(name);
1132            }
1133            if let Some(enabled) = update.enabled {
1134                query.push(", enabled = ");
1135                query.push_bind(enabled);
1136            }
1137            if let Some(remaining) = update.remaining {
1138                query.push(", remaining = ");
1139                query.push_bind(remaining);
1140            }
1141            if let Some(rate_limit_enabled) = update.rate_limit_enabled {
1142                query.push(", rate_limit_enabled = ");
1143                query.push_bind(rate_limit_enabled);
1144            }
1145            if let Some(rate_limit_time_window) = update.rate_limit_time_window {
1146                query.push(", rate_limit_time_window = ");
1147                query.push_bind(rate_limit_time_window);
1148            }
1149            if let Some(rate_limit_max) = update.rate_limit_max {
1150                query.push(", rate_limit_max = ");
1151                query.push_bind(rate_limit_max);
1152            }
1153            if let Some(refill_interval) = update.refill_interval {
1154                query.push(", refill_interval = ");
1155                query.push_bind(refill_interval);
1156            }
1157            if let Some(refill_amount) = update.refill_amount {
1158                query.push(", refill_amount = ");
1159                query.push_bind(refill_amount);
1160            }
1161            if let Some(permissions) = &update.permissions {
1162                query.push(", permissions = ");
1163                query.push_bind(permissions);
1164            }
1165            if let Some(metadata) = &update.metadata {
1166                query.push(", metadata = ");
1167                query.push_bind(metadata);
1168            }
1169
1170            query.push(" WHERE id = ");
1171            query.push_bind(id);
1172            query.push(" RETURNING *");
1173
1174            let api_key = query
1175                .build_query_as::<AK>()
1176                .fetch_one(&self.pool)
1177                .await
1178                .map_err(|err| match err {
1179                    sqlx::Error::RowNotFound => AuthError::not_found("API key not found"),
1180                    other => AuthError::from(other),
1181                })?;
1182            Ok(api_key)
1183        }
1184
1185        async fn delete_api_key(&self, id: &str) -> AuthResult<()> {
1186            sqlx::query("DELETE FROM api_keys WHERE id = $1")
1187                .bind(id)
1188                .execute(&self.pool)
1189                .await?;
1190            Ok(())
1191        }
1192    }
1193
1194    // -- PasskeyOps --
1195
1196    #[async_trait]
1197    impl<U, S, A, O, M, I, V, TF, AK, PK> PasskeyOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
1198    where
1199        U: AuthUser + SqlxEntity,
1200        S: AuthSession + SqlxEntity,
1201        A: AuthAccount + SqlxEntity,
1202        O: AuthOrganization + SqlxEntity,
1203        M: AuthMember + SqlxEntity,
1204        I: AuthInvitation + SqlxEntity,
1205        V: AuthVerification + SqlxEntity,
1206        TF: AuthTwoFactor + SqlxEntity,
1207        AK: AuthApiKey + SqlxEntity,
1208        PK: AuthPasskey + SqlxEntity,
1209    {
1210        type Passkey = PK;
1211
1212        async fn create_passkey(&self, input: CreatePasskey) -> AuthResult<PK> {
1213            let id = Uuid::new_v4().to_string();
1214            let now = Utc::now();
1215            let counter = i64::try_from(input.counter)
1216                .map_err(|_| AuthError::bad_request("Passkey counter exceeds i64 range"))?;
1217
1218            let passkey = sqlx::query_as::<_, PK>(
1219                r#"
1220                INSERT INTO passkeys (id, name, public_key, user_id, credential_id, counter, device_type, backed_up, transports, created_at)
1221                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
1222                RETURNING *
1223                "#,
1224            )
1225            .bind(&id)
1226            .bind(&input.name)
1227            .bind(&input.public_key)
1228            .bind(&input.user_id)
1229            .bind(&input.credential_id)
1230            .bind(counter)
1231            .bind(&input.device_type)
1232            .bind(input.backed_up)
1233            .bind(&input.transports)
1234            .bind(now)
1235            .fetch_one(&self.pool)
1236            .await
1237            .map_err(|e| match e {
1238                sqlx::Error::Database(ref db_err) if db_err.is_unique_violation() => {
1239                    AuthError::conflict("A passkey with this credential ID already exists")
1240                }
1241                other => AuthError::from(other),
1242            })?;
1243
1244            Ok(passkey)
1245        }
1246
1247        async fn get_passkey_by_id(&self, id: &str) -> AuthResult<Option<PK>> {
1248            let passkey = sqlx::query_as::<_, PK>("SELECT * FROM passkeys WHERE id = $1")
1249                .bind(id)
1250                .fetch_optional(&self.pool)
1251                .await?;
1252            Ok(passkey)
1253        }
1254
1255        async fn get_passkey_by_credential_id(
1256            &self,
1257            credential_id: &str,
1258        ) -> AuthResult<Option<PK>> {
1259            let passkey =
1260                sqlx::query_as::<_, PK>("SELECT * FROM passkeys WHERE credential_id = $1")
1261                    .bind(credential_id)
1262                    .fetch_optional(&self.pool)
1263                    .await?;
1264            Ok(passkey)
1265        }
1266
1267        async fn list_passkeys_by_user(&self, user_id: &str) -> AuthResult<Vec<PK>> {
1268            let passkeys = sqlx::query_as::<_, PK>(
1269                "SELECT * FROM passkeys WHERE user_id = $1 ORDER BY created_at DESC",
1270            )
1271            .bind(user_id)
1272            .fetch_all(&self.pool)
1273            .await?;
1274            Ok(passkeys)
1275        }
1276
1277        async fn update_passkey_counter(&self, id: &str, counter: u64) -> AuthResult<PK> {
1278            let counter = i64::try_from(counter)
1279                .map_err(|_| AuthError::bad_request("Passkey counter exceeds i64 range"))?;
1280            let passkey = sqlx::query_as::<_, PK>(
1281                "UPDATE passkeys SET counter = $2 WHERE id = $1 RETURNING *",
1282            )
1283            .bind(id)
1284            .bind(counter)
1285            .fetch_one(&self.pool)
1286            .await
1287            .map_err(|err| match err {
1288                sqlx::Error::RowNotFound => AuthError::not_found("Passkey not found"),
1289                other => AuthError::from(other),
1290            })?;
1291            Ok(passkey)
1292        }
1293
1294        async fn update_passkey_name(&self, id: &str, name: &str) -> AuthResult<PK> {
1295            let passkey =
1296                sqlx::query_as::<_, PK>("UPDATE passkeys SET name = $2 WHERE id = $1 RETURNING *")
1297                    .bind(id)
1298                    .bind(name)
1299                    .fetch_one(&self.pool)
1300                    .await
1301                    .map_err(|err| match err {
1302                        sqlx::Error::RowNotFound => AuthError::not_found("Passkey not found"),
1303                        other => AuthError::from(other),
1304                    })?;
1305            Ok(passkey)
1306        }
1307
1308        async fn delete_passkey(&self, id: &str) -> AuthResult<()> {
1309            sqlx::query("DELETE FROM passkeys WHERE id = $1")
1310                .bind(id)
1311                .execute(&self.pool)
1312                .await?;
1313            Ok(())
1314        }
1315    }
1316}
1317
1318#[cfg(feature = "sqlx-postgres")]
1319pub use sqlx_adapter::{SqlxAdapter, SqlxEntity};