1pub use super::traits::{
2 AccountOps, ApiKeyOps, InvitationOps, MemberOps, OrganizationOps, PasskeyOps, SessionOps,
3 TwoFactorOps, UserOps, VerificationOps,
4};
5
/// Umbrella trait for a complete database backend.
///
/// It declares no items of its own: it only requires that the implementor
/// also implements every operations trait listed as a supertrait below, so
/// a single `DatabaseAdapter` bound gives access to all of them.
pub trait DatabaseAdapter:
    UserOps
    + SessionOps
    + AccountOps
    + VerificationOps
    + OrganizationOps
    + MemberOps
    + InvitationOps
    + TwoFactorOps
    + ApiKeyOps
    + PasskeyOps
{
}
27
// Blanket implementation: any type that implements all ten operations traits
// is a `DatabaseAdapter` automatically, so backends never implement it by hand.
impl<T> DatabaseAdapter for T where
    T: UserOps
        + SessionOps
        + AccountOps
        + VerificationOps
        + OrganizationOps
        + MemberOps
        + InvitationOps
        + TwoFactorOps
        + ApiKeyOps
        + PasskeyOps
{
}
41
42#[cfg(feature = "sqlx-postgres")]
43pub mod sqlx_adapter {
44 use super::*;
45 use async_trait::async_trait;
46 use chrono::{DateTime, Utc};
47
48 use crate::entity::{
49 AuthAccount, AuthApiKey, AuthInvitation, AuthMember, AuthOrganization, AuthPasskey,
50 AuthSession, AuthTwoFactor, AuthUser, AuthVerification,
51 };
52 use crate::error::{AuthError, AuthResult};
53 use crate::types::{
54 Account, ApiKey, CreateAccount, CreateApiKey, CreateInvitation, CreateMember,
55 CreateOrganization, CreatePasskey, CreateSession, CreateTwoFactor, CreateUser,
56 CreateVerification, Invitation, InvitationStatus, Member, Organization, Passkey, Session,
57 TwoFactor, UpdateAccount, UpdateApiKey, UpdateOrganization, UpdateUser, User, Verification,
58 };
59 use sqlx::PgPool;
60 use sqlx::postgres::PgRow;
61 use std::marker::PhantomData;
62 use uuid::Uuid;
63
    /// Bound alias for row types the SQLx adapter can return.
    ///
    /// Collects the requirements placed on every entity type parameter of the
    /// adapter: decodable from a [`PgRow`] through [`sqlx::FromRow`] for any
    /// row lifetime, plus `Send + Sync + Unpin + Clone + 'static`.
    pub trait SqlxEntity:
        for<'r> sqlx::FromRow<'r, PgRow> + Send + Sync + Unpin + Clone + 'static
    {
    }
74
    // Blanket implementation: every type that already satisfies the bounds is
    // a `SqlxEntity`, so the trait never has to be implemented manually.
    impl<T> SqlxEntity for T where
        T: for<'r> sqlx::FromRow<'r, PgRow> + Send + Sync + Unpin + Clone + 'static
    {
    }
79
    /// Tuple of the ten entity type parameters.
    ///
    /// Used as the `PhantomData` payload of `SqlxAdapter` so the field's type
    /// can be written with a single name.
    type SqlxAdapterEntities<U, S, A, O, M, I, V, TF, AK, PK> = (U, S, A, O, M, I, V, TF, AK, PK);
81
    /// PostgreSQL adapter built on an SQLx connection pool.
    ///
    /// Each type parameter selects the row type used for one kind of entity.
    /// Every parameter has a default taken from `crate::types`, so a plain
    /// `SqlxAdapter` uses `User`, `Session`, `Account`, `Organization`,
    /// `Member`, `Invitation`, `Verification`, `TwoFactor`, `ApiKey` and
    /// `Passkey`.
    pub struct SqlxAdapter<
        U = User,
        S = Session,
        A = Account,
        O = Organization,
        M = Member,
        I = Invitation,
        V = Verification,
        TF = TwoFactor,
        AK = ApiKey,
        PK = Passkey,
    > {
        /// Connection pool that every query of the adapter runs against.
        pool: PgPool,
        // Zero-sized marker: no entity value is stored, it only ties the ten
        // type parameters to the struct.
        #[allow(clippy::type_complexity)]
        _phantom: PhantomData<SqlxAdapterEntities<U, S, A, O, M, I, V, TF, AK, PK>>,
    }
102
103 impl SqlxAdapter {
105 pub async fn new(database_url: &str) -> Result<Self, sqlx::Error> {
106 let pool = PgPool::connect(database_url).await?;
107 Ok(Self {
108 pool,
109 _phantom: PhantomData,
110 })
111 }
112
113 pub async fn with_config(
114 database_url: &str,
115 config: PoolConfig,
116 ) -> Result<Self, sqlx::Error> {
117 let pool = sqlx::postgres::PgPoolOptions::new()
118 .max_connections(config.max_connections)
119 .min_connections(config.min_connections)
120 .acquire_timeout(config.acquire_timeout)
121 .idle_timeout(config.idle_timeout)
122 .max_lifetime(config.max_lifetime)
123 .connect(database_url)
124 .await?;
125 Ok(Self {
126 pool,
127 _phantom: PhantomData,
128 })
129 }
130 }
131
132 impl<U, S, A, O, M, I, V, TF, AK, PK> SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK> {
134 pub fn from_pool(pool: PgPool) -> Self {
135 Self {
136 pool,
137 _phantom: PhantomData,
138 }
139 }
140
141 pub async fn test_connection(&self) -> Result<(), sqlx::Error> {
142 sqlx::query("SELECT 1").execute(&self.pool).await?;
143 Ok(())
144 }
145
146 pub fn pool_stats(&self) -> PoolStats {
147 PoolStats {
148 size: self.pool.size(),
149 idle: self.pool.num_idle(),
150 }
151 }
152
153 pub async fn close(&self) {
154 self.pool.close().await;
155 }
156 }
157
158 #[derive(Debug, Clone)]
159 pub struct PoolConfig {
160 pub max_connections: u32,
161 pub min_connections: u32,
162 pub acquire_timeout: std::time::Duration,
163 pub idle_timeout: Option<std::time::Duration>,
164 pub max_lifetime: Option<std::time::Duration>,
165 }
166
167 impl Default for PoolConfig {
168 fn default() -> Self {
169 Self {
170 max_connections: 10,
171 min_connections: 0,
172 acquire_timeout: std::time::Duration::from_secs(30),
173 idle_timeout: Some(std::time::Duration::from_secs(600)),
174 max_lifetime: Some(std::time::Duration::from_secs(1800)),
175 }
176 }
177 }
178
179 #[derive(Debug, Clone)]
180 pub struct PoolStats {
181 pub size: u32,
182 pub idle: usize,
183 }
184
185 #[async_trait]
188 impl<U, S, A, O, M, I, V, TF, AK, PK> UserOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
189 where
190 U: AuthUser + SqlxEntity,
191 S: AuthSession + SqlxEntity,
192 A: AuthAccount + SqlxEntity,
193 O: AuthOrganization + SqlxEntity,
194 M: AuthMember + SqlxEntity,
195 I: AuthInvitation + SqlxEntity,
196 V: AuthVerification + SqlxEntity,
197 TF: AuthTwoFactor + SqlxEntity,
198 AK: AuthApiKey + SqlxEntity,
199 PK: AuthPasskey + SqlxEntity,
200 {
201 type User = U;
202
203 async fn create_user(&self, create_user: CreateUser) -> AuthResult<U> {
204 let id = create_user.id.unwrap_or_else(|| Uuid::new_v4().to_string());
205 let now = Utc::now();
206
207 let user = sqlx::query_as::<_, U>(
208 r#"
209 INSERT INTO users (id, email, name, image, email_verified, created_at, updated_at, metadata)
210 VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
211 RETURNING *
212 "#,
213 )
214 .bind(&id)
215 .bind(&create_user.email)
216 .bind(&create_user.name)
217 .bind(&create_user.image)
218 .bind(false)
219 .bind(now)
220 .bind(now)
221 .bind(sqlx::types::Json(create_user.metadata.unwrap_or(serde_json::json!({}))))
222 .fetch_one(&self.pool)
223 .await?;
224
225 Ok(user)
226 }
227
228 async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<U>> {
229 let user = sqlx::query_as::<_, U>("SELECT * FROM users WHERE id = $1")
230 .bind(id)
231 .fetch_optional(&self.pool)
232 .await?;
233 Ok(user)
234 }
235
236 async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<U>> {
237 let user = sqlx::query_as::<_, U>("SELECT * FROM users WHERE email = $1")
238 .bind(email)
239 .fetch_optional(&self.pool)
240 .await?;
241 Ok(user)
242 }
243
244 async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<U>> {
245 let user = sqlx::query_as::<_, U>("SELECT * FROM users WHERE username = $1")
246 .bind(username)
247 .fetch_optional(&self.pool)
248 .await?;
249 Ok(user)
250 }
251
252 async fn update_user(&self, id: &str, update: UpdateUser) -> AuthResult<U> {
253 let mut query = sqlx::QueryBuilder::new("UPDATE users SET updated_at = NOW()");
254 let mut has_updates = false;
255
256 if let Some(email) = &update.email {
257 query.push(", email = ");
258 query.push_bind(email);
259 has_updates = true;
260 }
261 if let Some(name) = &update.name {
262 query.push(", name = ");
263 query.push_bind(name);
264 has_updates = true;
265 }
266 if let Some(image) = &update.image {
267 query.push(", image = ");
268 query.push_bind(image);
269 has_updates = true;
270 }
271 if let Some(email_verified) = update.email_verified {
272 query.push(", email_verified = ");
273 query.push_bind(email_verified);
274 has_updates = true;
275 }
276 if let Some(metadata) = &update.metadata {
277 query.push(", metadata = ");
278 query.push_bind(sqlx::types::Json(metadata.clone()));
279 has_updates = true;
280 }
281
282 if !has_updates {
283 return self
284 .get_user_by_id(id)
285 .await?
286 .ok_or(AuthError::UserNotFound);
287 }
288
289 query.push(" WHERE id = ");
290 query.push_bind(id);
291 query.push(" RETURNING *");
292
293 let user = query.build_query_as::<U>().fetch_one(&self.pool).await?;
294 Ok(user)
295 }
296
297 async fn delete_user(&self, id: &str) -> AuthResult<()> {
298 sqlx::query("DELETE FROM users WHERE id = $1")
299 .bind(id)
300 .execute(&self.pool)
301 .await?;
302 Ok(())
303 }
304 }
305
306 #[async_trait]
309 impl<U, S, A, O, M, I, V, TF, AK, PK> SessionOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
310 where
311 U: AuthUser + SqlxEntity,
312 S: AuthSession + SqlxEntity,
313 A: AuthAccount + SqlxEntity,
314 O: AuthOrganization + SqlxEntity,
315 M: AuthMember + SqlxEntity,
316 I: AuthInvitation + SqlxEntity,
317 V: AuthVerification + SqlxEntity,
318 TF: AuthTwoFactor + SqlxEntity,
319 AK: AuthApiKey + SqlxEntity,
320 PK: AuthPasskey + SqlxEntity,
321 {
322 type Session = S;
323
324 async fn create_session(&self, create_session: CreateSession) -> AuthResult<S> {
325 let id = Uuid::new_v4().to_string();
326 let token = format!("session_{}", Uuid::new_v4());
327 let now = Utc::now();
328
329 let session = sqlx::query_as::<_, S>(
330 r#"
331 INSERT INTO sessions (id, user_id, token, expires_at, created_at, ip_address, user_agent, active)
332 VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
333 RETURNING *
334 "#,
335 )
336 .bind(&id)
337 .bind(&create_session.user_id)
338 .bind(&token)
339 .bind(create_session.expires_at)
340 .bind(now)
341 .bind(&create_session.ip_address)
342 .bind(&create_session.user_agent)
343 .bind(true)
344 .fetch_one(&self.pool)
345 .await?;
346
347 Ok(session)
348 }
349
350 async fn get_session(&self, token: &str) -> AuthResult<Option<S>> {
351 let session =
352 sqlx::query_as::<_, S>("SELECT * FROM sessions WHERE token = $1 AND active = true")
353 .bind(token)
354 .fetch_optional(&self.pool)
355 .await?;
356 Ok(session)
357 }
358
359 async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<S>> {
360 let sessions = sqlx::query_as::<_, S>(
361 "SELECT * FROM sessions WHERE user_id = $1 AND active = true ORDER BY created_at DESC",
362 )
363 .bind(user_id)
364 .fetch_all(&self.pool)
365 .await?;
366 Ok(sessions)
367 }
368
369 async fn update_session_expiry(
370 &self,
371 token: &str,
372 expires_at: DateTime<Utc>,
373 ) -> AuthResult<()> {
374 sqlx::query("UPDATE sessions SET expires_at = $1 WHERE token = $2 AND active = true")
375 .bind(expires_at)
376 .bind(token)
377 .execute(&self.pool)
378 .await?;
379 Ok(())
380 }
381
382 async fn delete_session(&self, token: &str) -> AuthResult<()> {
383 sqlx::query("DELETE FROM sessions WHERE token = $1")
384 .bind(token)
385 .execute(&self.pool)
386 .await?;
387 Ok(())
388 }
389
390 async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
391 sqlx::query("DELETE FROM sessions WHERE user_id = $1")
392 .bind(user_id)
393 .execute(&self.pool)
394 .await?;
395 Ok(())
396 }
397
398 async fn delete_expired_sessions(&self) -> AuthResult<usize> {
399 let result =
400 sqlx::query("DELETE FROM sessions WHERE expires_at < NOW() OR active = false")
401 .execute(&self.pool)
402 .await?;
403 Ok(result.rows_affected() as usize)
404 }
405
406 async fn update_session_active_organization(
407 &self,
408 token: &str,
409 organization_id: Option<&str>,
410 ) -> AuthResult<S> {
411 let session = sqlx::query_as::<_, S>(
412 "UPDATE sessions SET active_organization_id = $1, updated_at = NOW() WHERE token = $2 AND active = true RETURNING *",
413 )
414 .bind(organization_id)
415 .bind(token)
416 .fetch_one(&self.pool)
417 .await?;
418 Ok(session)
419 }
420 }
421
422 #[async_trait]
425 impl<U, S, A, O, M, I, V, TF, AK, PK> AccountOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
426 where
427 U: AuthUser + SqlxEntity,
428 S: AuthSession + SqlxEntity,
429 A: AuthAccount + SqlxEntity,
430 O: AuthOrganization + SqlxEntity,
431 M: AuthMember + SqlxEntity,
432 I: AuthInvitation + SqlxEntity,
433 V: AuthVerification + SqlxEntity,
434 TF: AuthTwoFactor + SqlxEntity,
435 AK: AuthApiKey + SqlxEntity,
436 PK: AuthPasskey + SqlxEntity,
437 {
438 type Account = A;
439
440 async fn create_account(&self, create_account: CreateAccount) -> AuthResult<A> {
441 let id = Uuid::new_v4().to_string();
442 let now = Utc::now();
443
444 let account = sqlx::query_as::<_, A>(
445 r#"
446 INSERT INTO accounts (id, account_id, provider_id, user_id, access_token, refresh_token, id_token, access_token_expires_at, refresh_token_expires_at, scope, password, created_at, updated_at)
447 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
448 RETURNING *
449 "#,
450 )
451 .bind(&id)
452 .bind(&create_account.account_id)
453 .bind(&create_account.provider_id)
454 .bind(&create_account.user_id)
455 .bind(&create_account.access_token)
456 .bind(&create_account.refresh_token)
457 .bind(&create_account.id_token)
458 .bind(create_account.access_token_expires_at)
459 .bind(create_account.refresh_token_expires_at)
460 .bind(&create_account.scope)
461 .bind(&create_account.password)
462 .bind(now)
463 .bind(now)
464 .fetch_one(&self.pool)
465 .await?;
466
467 Ok(account)
468 }
469
470 async fn get_account(
471 &self,
472 provider: &str,
473 provider_account_id: &str,
474 ) -> AuthResult<Option<A>> {
475 let account = sqlx::query_as::<_, A>(
476 "SELECT * FROM accounts WHERE provider_id = $1 AND account_id = $2",
477 )
478 .bind(provider)
479 .bind(provider_account_id)
480 .fetch_optional(&self.pool)
481 .await?;
482 Ok(account)
483 }
484
485 async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<A>> {
486 let accounts = sqlx::query_as::<_, A>(
487 "SELECT * FROM accounts WHERE user_id = $1 ORDER BY created_at DESC",
488 )
489 .bind(user_id)
490 .fetch_all(&self.pool)
491 .await?;
492 Ok(accounts)
493 }
494
495 async fn update_account(&self, id: &str, update: UpdateAccount) -> AuthResult<A> {
496 let mut query = sqlx::QueryBuilder::new("UPDATE accounts SET updated_at = NOW()");
497
498 if let Some(access_token) = &update.access_token {
499 query.push(", access_token = ");
500 query.push_bind(access_token);
501 }
502 if let Some(refresh_token) = &update.refresh_token {
503 query.push(", refresh_token = ");
504 query.push_bind(refresh_token);
505 }
506 if let Some(id_token) = &update.id_token {
507 query.push(", id_token = ");
508 query.push_bind(id_token);
509 }
510 if let Some(access_token_expires_at) = &update.access_token_expires_at {
511 query.push(", access_token_expires_at = ");
512 query.push_bind(access_token_expires_at);
513 }
514 if let Some(refresh_token_expires_at) = &update.refresh_token_expires_at {
515 query.push(", refresh_token_expires_at = ");
516 query.push_bind(refresh_token_expires_at);
517 }
518 if let Some(scope) = &update.scope {
519 query.push(", scope = ");
520 query.push_bind(scope);
521 }
522
523 query.push(" WHERE id = ");
524 query.push_bind(id);
525 query.push(" RETURNING *");
526
527 let account = query.build_query_as::<A>().fetch_one(&self.pool).await?;
528 Ok(account)
529 }
530
531 async fn delete_account(&self, id: &str) -> AuthResult<()> {
532 sqlx::query("DELETE FROM accounts WHERE id = $1")
533 .bind(id)
534 .execute(&self.pool)
535 .await?;
536 Ok(())
537 }
538 }
539
540 #[async_trait]
543 impl<U, S, A, O, M, I, V, TF, AK, PK> VerificationOps
544 for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
545 where
546 U: AuthUser + SqlxEntity,
547 S: AuthSession + SqlxEntity,
548 A: AuthAccount + SqlxEntity,
549 O: AuthOrganization + SqlxEntity,
550 M: AuthMember + SqlxEntity,
551 I: AuthInvitation + SqlxEntity,
552 V: AuthVerification + SqlxEntity,
553 TF: AuthTwoFactor + SqlxEntity,
554 AK: AuthApiKey + SqlxEntity,
555 PK: AuthPasskey + SqlxEntity,
556 {
557 type Verification = V;
558
559 async fn create_verification(
560 &self,
561 create_verification: CreateVerification,
562 ) -> AuthResult<V> {
563 let id = Uuid::new_v4().to_string();
564 let now = Utc::now();
565
566 let verification = sqlx::query_as::<_, V>(
567 r#"
568 INSERT INTO verifications (id, identifier, value, expires_at, created_at, updated_at)
569 VALUES ($1, $2, $3, $4, $5, $6)
570 RETURNING *
571 "#,
572 )
573 .bind(&id)
574 .bind(&create_verification.identifier)
575 .bind(&create_verification.value)
576 .bind(create_verification.expires_at)
577 .bind(now)
578 .bind(now)
579 .fetch_one(&self.pool)
580 .await?;
581
582 Ok(verification)
583 }
584
585 async fn get_verification(&self, identifier: &str, value: &str) -> AuthResult<Option<V>> {
586 let verification = sqlx::query_as::<_, V>(
587 "SELECT * FROM verifications WHERE identifier = $1 AND value = $2 AND expires_at > NOW()",
588 )
589 .bind(identifier)
590 .bind(value)
591 .fetch_optional(&self.pool)
592 .await?;
593 Ok(verification)
594 }
595
596 async fn get_verification_by_value(&self, value: &str) -> AuthResult<Option<V>> {
597 let verification = sqlx::query_as::<_, V>(
598 "SELECT * FROM verifications WHERE value = $1 AND expires_at > NOW()",
599 )
600 .bind(value)
601 .fetch_optional(&self.pool)
602 .await?;
603 Ok(verification)
604 }
605
606 async fn get_verification_by_identifier(&self, identifier: &str) -> AuthResult<Option<V>> {
607 let verification = sqlx::query_as::<_, V>(
608 "SELECT * FROM verifications WHERE identifier = $1 AND expires_at > NOW()",
609 )
610 .bind(identifier)
611 .fetch_optional(&self.pool)
612 .await?;
613 Ok(verification)
614 }
615
616 async fn consume_verification(
617 &self,
618 identifier: &str,
619 value: &str,
620 ) -> AuthResult<Option<V>> {
621 let verification = sqlx::query_as::<_, V>(
622 "DELETE FROM verifications WHERE id IN (
623 SELECT id FROM verifications
624 WHERE identifier = $1 AND value = $2 AND expires_at > NOW()
625 ORDER BY created_at DESC
626 LIMIT 1
627 ) RETURNING *",
628 )
629 .bind(identifier)
630 .bind(value)
631 .fetch_optional(&self.pool)
632 .await?;
633 Ok(verification)
634 }
635
636 async fn delete_verification(&self, id: &str) -> AuthResult<()> {
637 sqlx::query("DELETE FROM verifications WHERE id = $1")
638 .bind(id)
639 .execute(&self.pool)
640 .await?;
641 Ok(())
642 }
643
644 async fn delete_expired_verifications(&self) -> AuthResult<usize> {
645 let result = sqlx::query("DELETE FROM verifications WHERE expires_at < NOW()")
646 .execute(&self.pool)
647 .await?;
648 Ok(result.rows_affected() as usize)
649 }
650 }
651
652 #[async_trait]
655 impl<U, S, A, O, M, I, V, TF, AK, PK> OrganizationOps
656 for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
657 where
658 U: AuthUser + SqlxEntity,
659 S: AuthSession + SqlxEntity,
660 A: AuthAccount + SqlxEntity,
661 O: AuthOrganization + SqlxEntity,
662 M: AuthMember + SqlxEntity,
663 I: AuthInvitation + SqlxEntity,
664 V: AuthVerification + SqlxEntity,
665 TF: AuthTwoFactor + SqlxEntity,
666 AK: AuthApiKey + SqlxEntity,
667 PK: AuthPasskey + SqlxEntity,
668 {
669 type Organization = O;
670
671 async fn create_organization(&self, create_org: CreateOrganization) -> AuthResult<O> {
672 let id = create_org.id.unwrap_or_else(|| Uuid::new_v4().to_string());
673 let now = Utc::now();
674
675 let organization = sqlx::query_as::<_, O>(
676 r#"
677 INSERT INTO organization (id, name, slug, logo, metadata, created_at, updated_at)
678 VALUES ($1, $2, $3, $4, $5, $6, $7)
679 RETURNING *
680 "#,
681 )
682 .bind(&id)
683 .bind(&create_org.name)
684 .bind(&create_org.slug)
685 .bind(&create_org.logo)
686 .bind(sqlx::types::Json(
687 create_org.metadata.unwrap_or(serde_json::json!({})),
688 ))
689 .bind(now)
690 .bind(now)
691 .fetch_one(&self.pool)
692 .await?;
693
694 Ok(organization)
695 }
696
697 async fn get_organization_by_id(&self, id: &str) -> AuthResult<Option<O>> {
698 let organization = sqlx::query_as::<_, O>("SELECT * FROM organization WHERE id = $1")
699 .bind(id)
700 .fetch_optional(&self.pool)
701 .await?;
702 Ok(organization)
703 }
704
705 async fn get_organization_by_slug(&self, slug: &str) -> AuthResult<Option<O>> {
706 let organization = sqlx::query_as::<_, O>("SELECT * FROM organization WHERE slug = $1")
707 .bind(slug)
708 .fetch_optional(&self.pool)
709 .await?;
710 Ok(organization)
711 }
712
713 async fn update_organization(&self, id: &str, update: UpdateOrganization) -> AuthResult<O> {
714 let mut query = sqlx::QueryBuilder::new("UPDATE organization SET updated_at = NOW()");
715
716 if let Some(name) = &update.name {
717 query.push(", name = ");
718 query.push_bind(name);
719 }
720 if let Some(slug) = &update.slug {
721 query.push(", slug = ");
722 query.push_bind(slug);
723 }
724 if let Some(logo) = &update.logo {
725 query.push(", logo = ");
726 query.push_bind(logo);
727 }
728 if let Some(metadata) = &update.metadata {
729 query.push(", metadata = ");
730 query.push_bind(sqlx::types::Json(metadata.clone()));
731 }
732
733 query.push(" WHERE id = ");
734 query.push_bind(id);
735 query.push(" RETURNING *");
736
737 let organization = query.build_query_as::<O>().fetch_one(&self.pool).await?;
738 Ok(organization)
739 }
740
741 async fn delete_organization(&self, id: &str) -> AuthResult<()> {
742 sqlx::query("DELETE FROM organization WHERE id = $1")
743 .bind(id)
744 .execute(&self.pool)
745 .await?;
746 Ok(())
747 }
748
749 async fn list_user_organizations(&self, user_id: &str) -> AuthResult<Vec<O>> {
750 let organizations = sqlx::query_as::<_, O>(
751 r#"
752 SELECT o.*
753 FROM organization o
754 INNER JOIN member m ON o.id = m.organization_id
755 WHERE m.user_id = $1
756 ORDER BY o.created_at DESC
757 "#,
758 )
759 .bind(user_id)
760 .fetch_all(&self.pool)
761 .await?;
762 Ok(organizations)
763 }
764 }
765
766 #[async_trait]
769 impl<U, S, A, O, M, I, V, TF, AK, PK> MemberOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
770 where
771 U: AuthUser + SqlxEntity,
772 S: AuthSession + SqlxEntity,
773 A: AuthAccount + SqlxEntity,
774 O: AuthOrganization + SqlxEntity,
775 M: AuthMember + SqlxEntity,
776 I: AuthInvitation + SqlxEntity,
777 V: AuthVerification + SqlxEntity,
778 TF: AuthTwoFactor + SqlxEntity,
779 AK: AuthApiKey + SqlxEntity,
780 PK: AuthPasskey + SqlxEntity,
781 {
782 type Member = M;
783
784 async fn create_member(&self, create_member: CreateMember) -> AuthResult<M> {
785 let id = Uuid::new_v4().to_string();
786 let now = Utc::now();
787
788 let member = sqlx::query_as::<_, M>(
789 r#"
790 INSERT INTO member (id, organization_id, user_id, role, created_at)
791 VALUES ($1, $2, $3, $4, $5)
792 RETURNING *
793 "#,
794 )
795 .bind(&id)
796 .bind(&create_member.organization_id)
797 .bind(&create_member.user_id)
798 .bind(&create_member.role)
799 .bind(now)
800 .fetch_one(&self.pool)
801 .await?;
802
803 Ok(member)
804 }
805
806 async fn get_member(&self, organization_id: &str, user_id: &str) -> AuthResult<Option<M>> {
807 let member = sqlx::query_as::<_, M>(
808 "SELECT * FROM member WHERE organization_id = $1 AND user_id = $2",
809 )
810 .bind(organization_id)
811 .bind(user_id)
812 .fetch_optional(&self.pool)
813 .await?;
814 Ok(member)
815 }
816
817 async fn get_member_by_id(&self, id: &str) -> AuthResult<Option<M>> {
818 let member = sqlx::query_as::<_, M>("SELECT * FROM member WHERE id = $1")
819 .bind(id)
820 .fetch_optional(&self.pool)
821 .await?;
822 Ok(member)
823 }
824
825 async fn update_member_role(&self, member_id: &str, role: &str) -> AuthResult<M> {
826 let member =
827 sqlx::query_as::<_, M>("UPDATE member SET role = $1 WHERE id = $2 RETURNING *")
828 .bind(role)
829 .bind(member_id)
830 .fetch_one(&self.pool)
831 .await?;
832 Ok(member)
833 }
834
835 async fn delete_member(&self, member_id: &str) -> AuthResult<()> {
836 sqlx::query("DELETE FROM member WHERE id = $1")
837 .bind(member_id)
838 .execute(&self.pool)
839 .await?;
840 Ok(())
841 }
842
843 async fn list_organization_members(&self, organization_id: &str) -> AuthResult<Vec<M>> {
844 let members = sqlx::query_as::<_, M>(
845 "SELECT * FROM member WHERE organization_id = $1 ORDER BY created_at ASC",
846 )
847 .bind(organization_id)
848 .fetch_all(&self.pool)
849 .await?;
850 Ok(members)
851 }
852
853 async fn count_organization_members(&self, organization_id: &str) -> AuthResult<usize> {
854 let count: (i64,) =
855 sqlx::query_as("SELECT COUNT(*) FROM member WHERE organization_id = $1")
856 .bind(organization_id)
857 .fetch_one(&self.pool)
858 .await?;
859 Ok(count.0 as usize)
860 }
861
862 async fn count_organization_owners(&self, organization_id: &str) -> AuthResult<usize> {
863 let count: (i64,) = sqlx::query_as(
864 "SELECT COUNT(*) FROM member WHERE organization_id = $1 AND role = 'owner'",
865 )
866 .bind(organization_id)
867 .fetch_one(&self.pool)
868 .await?;
869 Ok(count.0 as usize)
870 }
871 }
872
873 #[async_trait]
876 impl<U, S, A, O, M, I, V, TF, AK, PK> InvitationOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
877 where
878 U: AuthUser + SqlxEntity,
879 S: AuthSession + SqlxEntity,
880 A: AuthAccount + SqlxEntity,
881 O: AuthOrganization + SqlxEntity,
882 M: AuthMember + SqlxEntity,
883 I: AuthInvitation + SqlxEntity,
884 V: AuthVerification + SqlxEntity,
885 TF: AuthTwoFactor + SqlxEntity,
886 AK: AuthApiKey + SqlxEntity,
887 PK: AuthPasskey + SqlxEntity,
888 {
889 type Invitation = I;
890
891 async fn create_invitation(&self, create_inv: CreateInvitation) -> AuthResult<I> {
892 let id = Uuid::new_v4().to_string();
893 let now = Utc::now();
894
895 let invitation = sqlx::query_as::<_, I>(
896 r#"
897 INSERT INTO invitation (id, organization_id, email, role, status, inviter_id, expires_at, created_at)
898 VALUES ($1, $2, $3, $4, 'pending', $5, $6, $7)
899 RETURNING *
900 "#,
901 )
902 .bind(&id)
903 .bind(&create_inv.organization_id)
904 .bind(&create_inv.email)
905 .bind(&create_inv.role)
906 .bind(&create_inv.inviter_id)
907 .bind(create_inv.expires_at)
908 .bind(now)
909 .fetch_one(&self.pool)
910 .await?;
911
912 Ok(invitation)
913 }
914
915 async fn get_invitation_by_id(&self, id: &str) -> AuthResult<Option<I>> {
916 let invitation = sqlx::query_as::<_, I>("SELECT * FROM invitation WHERE id = $1")
917 .bind(id)
918 .fetch_optional(&self.pool)
919 .await?;
920 Ok(invitation)
921 }
922
923 async fn get_pending_invitation(
924 &self,
925 organization_id: &str,
926 email: &str,
927 ) -> AuthResult<Option<I>> {
928 let invitation = sqlx::query_as::<_, I>(
929 "SELECT * FROM invitation WHERE organization_id = $1 AND LOWER(email) = LOWER($2) AND status = 'pending'",
930 )
931 .bind(organization_id)
932 .bind(email)
933 .fetch_optional(&self.pool)
934 .await?;
935 Ok(invitation)
936 }
937
938 async fn update_invitation_status(
939 &self,
940 id: &str,
941 status: InvitationStatus,
942 ) -> AuthResult<I> {
943 let invitation = sqlx::query_as::<_, I>(
944 "UPDATE invitation SET status = $1 WHERE id = $2 RETURNING *",
945 )
946 .bind(status.to_string())
947 .bind(id)
948 .fetch_one(&self.pool)
949 .await?;
950 Ok(invitation)
951 }
952
953 async fn list_organization_invitations(&self, organization_id: &str) -> AuthResult<Vec<I>> {
954 let invitations = sqlx::query_as::<_, I>(
955 "SELECT * FROM invitation WHERE organization_id = $1 ORDER BY created_at DESC",
956 )
957 .bind(organization_id)
958 .fetch_all(&self.pool)
959 .await?;
960 Ok(invitations)
961 }
962
963 async fn list_user_invitations(&self, email: &str) -> AuthResult<Vec<I>> {
964 let invitations = sqlx::query_as::<_, I>(
965 "SELECT * FROM invitation WHERE LOWER(email) = LOWER($1) AND status = 'pending' AND expires_at > NOW() ORDER BY created_at DESC",
966 )
967 .bind(email)
968 .fetch_all(&self.pool)
969 .await?;
970 Ok(invitations)
971 }
972 }
973
974 #[async_trait]
977 impl<U, S, A, O, M, I, V, TF, AK, PK> TwoFactorOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
978 where
979 U: AuthUser + SqlxEntity,
980 S: AuthSession + SqlxEntity,
981 A: AuthAccount + SqlxEntity,
982 O: AuthOrganization + SqlxEntity,
983 M: AuthMember + SqlxEntity,
984 I: AuthInvitation + SqlxEntity,
985 V: AuthVerification + SqlxEntity,
986 TF: AuthTwoFactor + SqlxEntity,
987 AK: AuthApiKey + SqlxEntity,
988 PK: AuthPasskey + SqlxEntity,
989 {
990 type TwoFactor = TF;
991
992 async fn create_two_factor(&self, create: CreateTwoFactor) -> AuthResult<TF> {
993 let id = Uuid::new_v4().to_string();
994 let now = Utc::now();
995
996 let two_factor = sqlx::query_as::<_, TF>(
997 r#"
998 INSERT INTO two_factor (id, secret, backup_codes, user_id, created_at, updated_at)
999 VALUES ($1, $2, $3, $4, $5, $6)
1000 RETURNING *
1001 "#,
1002 )
1003 .bind(&id)
1004 .bind(&create.secret)
1005 .bind(&create.backup_codes)
1006 .bind(&create.user_id)
1007 .bind(now)
1008 .bind(now)
1009 .fetch_one(&self.pool)
1010 .await?;
1011
1012 Ok(two_factor)
1013 }
1014
1015 async fn get_two_factor_by_user_id(&self, user_id: &str) -> AuthResult<Option<TF>> {
1016 let two_factor = sqlx::query_as::<_, TF>("SELECT * FROM two_factor WHERE user_id = $1")
1017 .bind(user_id)
1018 .fetch_optional(&self.pool)
1019 .await?;
1020 Ok(two_factor)
1021 }
1022
1023 async fn update_two_factor_backup_codes(
1024 &self,
1025 user_id: &str,
1026 backup_codes: &str,
1027 ) -> AuthResult<TF> {
1028 let two_factor = sqlx::query_as::<_, TF>(
1029 "UPDATE two_factor SET backup_codes = $1, updated_at = NOW() WHERE user_id = $2 RETURNING *",
1030 )
1031 .bind(backup_codes)
1032 .bind(user_id)
1033 .fetch_one(&self.pool)
1034 .await?;
1035 Ok(two_factor)
1036 }
1037
1038 async fn delete_two_factor(&self, user_id: &str) -> AuthResult<()> {
1039 sqlx::query("DELETE FROM two_factor WHERE user_id = $1")
1040 .bind(user_id)
1041 .execute(&self.pool)
1042 .await?;
1043 Ok(())
1044 }
1045 }
1046
1047 #[async_trait]
1050 impl<U, S, A, O, M, I, V, TF, AK, PK> ApiKeyOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
1051 where
1052 U: AuthUser + SqlxEntity,
1053 S: AuthSession + SqlxEntity,
1054 A: AuthAccount + SqlxEntity,
1055 O: AuthOrganization + SqlxEntity,
1056 M: AuthMember + SqlxEntity,
1057 I: AuthInvitation + SqlxEntity,
1058 V: AuthVerification + SqlxEntity,
1059 TF: AuthTwoFactor + SqlxEntity,
1060 AK: AuthApiKey + SqlxEntity,
1061 PK: AuthPasskey + SqlxEntity,
1062 {
1063 type ApiKey = AK;
1064
1065 async fn create_api_key(&self, input: CreateApiKey) -> AuthResult<AK> {
1066 let id = Uuid::new_v4().to_string();
1067 let now = Utc::now();
1068
1069 let api_key = sqlx::query_as::<_, AK>(
1070 r#"
1071 INSERT INTO api_keys (id, name, start, prefix, key, user_id, refill_interval, refill_amount,
1072 enabled, rate_limit_enabled, rate_limit_time_window, rate_limit_max, remaining,
1073 expires_at, created_at, updated_at, permissions, metadata)
1074 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13,
1075 $14::timestamptz, $15, $16, $17, $18)
1076 RETURNING *
1077 "#,
1078 )
1079 .bind(&id)
1080 .bind(&input.name)
1081 .bind(&input.start)
1082 .bind(&input.prefix)
1083 .bind(&input.key_hash)
1084 .bind(&input.user_id)
1085 .bind(input.refill_interval)
1086 .bind(input.refill_amount)
1087 .bind(input.enabled)
1088 .bind(input.rate_limit_enabled)
1089 .bind(input.rate_limit_time_window)
1090 .bind(input.rate_limit_max)
1091 .bind(input.remaining)
1092 .bind(&input.expires_at)
1093 .bind(now)
1094 .bind(now)
1095 .bind(&input.permissions)
1096 .bind(&input.metadata)
1097 .fetch_one(&self.pool)
1098 .await?;
1099
1100 Ok(api_key)
1101 }
1102
1103 async fn get_api_key_by_id(&self, id: &str) -> AuthResult<Option<AK>> {
1104 let api_key = sqlx::query_as::<_, AK>("SELECT * FROM api_keys WHERE id = $1")
1105 .bind(id)
1106 .fetch_optional(&self.pool)
1107 .await?;
1108 Ok(api_key)
1109 }
1110
1111 async fn get_api_key_by_hash(&self, hash: &str) -> AuthResult<Option<AK>> {
1112 let api_key = sqlx::query_as::<_, AK>("SELECT * FROM api_keys WHERE key = $1")
1113 .bind(hash)
1114 .fetch_optional(&self.pool)
1115 .await?;
1116 Ok(api_key)
1117 }
1118
1119 async fn list_api_keys_by_user(&self, user_id: &str) -> AuthResult<Vec<AK>> {
1120 let keys = sqlx::query_as::<_, AK>(
1121 "SELECT * FROM api_keys WHERE user_id = $1 ORDER BY created_at DESC",
1122 )
1123 .bind(user_id)
1124 .fetch_all(&self.pool)
1125 .await?;
1126 Ok(keys)
1127 }
1128
1129 async fn update_api_key(&self, id: &str, update: UpdateApiKey) -> AuthResult<AK> {
1130 let mut query = sqlx::QueryBuilder::new("UPDATE api_keys SET updated_at = NOW()");
1131
1132 if let Some(name) = &update.name {
1133 query.push(", name = ");
1134 query.push_bind(name);
1135 }
1136 if let Some(enabled) = update.enabled {
1137 query.push(", enabled = ");
1138 query.push_bind(enabled);
1139 }
1140 if let Some(remaining) = update.remaining {
1141 query.push(", remaining = ");
1142 query.push_bind(remaining);
1143 }
1144 if let Some(rate_limit_enabled) = update.rate_limit_enabled {
1145 query.push(", rate_limit_enabled = ");
1146 query.push_bind(rate_limit_enabled);
1147 }
1148 if let Some(rate_limit_time_window) = update.rate_limit_time_window {
1149 query.push(", rate_limit_time_window = ");
1150 query.push_bind(rate_limit_time_window);
1151 }
1152 if let Some(rate_limit_max) = update.rate_limit_max {
1153 query.push(", rate_limit_max = ");
1154 query.push_bind(rate_limit_max);
1155 }
1156 if let Some(refill_interval) = update.refill_interval {
1157 query.push(", refill_interval = ");
1158 query.push_bind(refill_interval);
1159 }
1160 if let Some(refill_amount) = update.refill_amount {
1161 query.push(", refill_amount = ");
1162 query.push_bind(refill_amount);
1163 }
1164 if let Some(permissions) = &update.permissions {
1165 query.push(", permissions = ");
1166 query.push_bind(permissions);
1167 }
1168 if let Some(metadata) = &update.metadata {
1169 query.push(", metadata = ");
1170 query.push_bind(metadata);
1171 }
1172
1173 query.push(" WHERE id = ");
1174 query.push_bind(id);
1175 query.push(" RETURNING *");
1176
1177 let api_key = query
1178 .build_query_as::<AK>()
1179 .fetch_one(&self.pool)
1180 .await
1181 .map_err(|err| match err {
1182 sqlx::Error::RowNotFound => AuthError::not_found("API key not found"),
1183 other => AuthError::from(other),
1184 })?;
1185 Ok(api_key)
1186 }
1187
1188 async fn delete_api_key(&self, id: &str) -> AuthResult<()> {
1189 sqlx::query("DELETE FROM api_keys WHERE id = $1")
1190 .bind(id)
1191 .execute(&self.pool)
1192 .await?;
1193 Ok(())
1194 }
1195 }
1196
1197 #[async_trait]
1200 impl<U, S, A, O, M, I, V, TF, AK, PK> PasskeyOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
1201 where
1202 U: AuthUser + SqlxEntity,
1203 S: AuthSession + SqlxEntity,
1204 A: AuthAccount + SqlxEntity,
1205 O: AuthOrganization + SqlxEntity,
1206 M: AuthMember + SqlxEntity,
1207 I: AuthInvitation + SqlxEntity,
1208 V: AuthVerification + SqlxEntity,
1209 TF: AuthTwoFactor + SqlxEntity,
1210 AK: AuthApiKey + SqlxEntity,
1211 PK: AuthPasskey + SqlxEntity,
1212 {
1213 type Passkey = PK;
1214
1215 async fn create_passkey(&self, input: CreatePasskey) -> AuthResult<PK> {
1216 let id = Uuid::new_v4().to_string();
1217 let now = Utc::now();
1218 let counter = i64::try_from(input.counter)
1219 .map_err(|_| AuthError::bad_request("Passkey counter exceeds i64 range"))?;
1220
1221 let passkey = sqlx::query_as::<_, PK>(
1222 r#"
1223 INSERT INTO passkeys (id, name, public_key, user_id, credential_id, counter, device_type, backed_up, transports, created_at)
1224 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
1225 RETURNING *
1226 "#,
1227 )
1228 .bind(&id)
1229 .bind(&input.name)
1230 .bind(&input.public_key)
1231 .bind(&input.user_id)
1232 .bind(&input.credential_id)
1233 .bind(counter)
1234 .bind(&input.device_type)
1235 .bind(input.backed_up)
1236 .bind(&input.transports)
1237 .bind(now)
1238 .fetch_one(&self.pool)
1239 .await
1240 .map_err(|e| match e {
1241 sqlx::Error::Database(ref db_err) if db_err.is_unique_violation() => {
1242 AuthError::conflict("A passkey with this credential ID already exists")
1243 }
1244 other => AuthError::from(other),
1245 })?;
1246
1247 Ok(passkey)
1248 }
1249
1250 async fn get_passkey_by_id(&self, id: &str) -> AuthResult<Option<PK>> {
1251 let passkey = sqlx::query_as::<_, PK>("SELECT * FROM passkeys WHERE id = $1")
1252 .bind(id)
1253 .fetch_optional(&self.pool)
1254 .await?;
1255 Ok(passkey)
1256 }
1257
1258 async fn get_passkey_by_credential_id(
1259 &self,
1260 credential_id: &str,
1261 ) -> AuthResult<Option<PK>> {
1262 let passkey =
1263 sqlx::query_as::<_, PK>("SELECT * FROM passkeys WHERE credential_id = $1")
1264 .bind(credential_id)
1265 .fetch_optional(&self.pool)
1266 .await?;
1267 Ok(passkey)
1268 }
1269
1270 async fn list_passkeys_by_user(&self, user_id: &str) -> AuthResult<Vec<PK>> {
1271 let passkeys = sqlx::query_as::<_, PK>(
1272 "SELECT * FROM passkeys WHERE user_id = $1 ORDER BY created_at DESC",
1273 )
1274 .bind(user_id)
1275 .fetch_all(&self.pool)
1276 .await?;
1277 Ok(passkeys)
1278 }
1279
1280 async fn update_passkey_counter(&self, id: &str, counter: u64) -> AuthResult<PK> {
1281 let counter = i64::try_from(counter)
1282 .map_err(|_| AuthError::bad_request("Passkey counter exceeds i64 range"))?;
1283 let passkey = sqlx::query_as::<_, PK>(
1284 "UPDATE passkeys SET counter = $2 WHERE id = $1 RETURNING *",
1285 )
1286 .bind(id)
1287 .bind(counter)
1288 .fetch_one(&self.pool)
1289 .await
1290 .map_err(|err| match err {
1291 sqlx::Error::RowNotFound => AuthError::not_found("Passkey not found"),
1292 other => AuthError::from(other),
1293 })?;
1294 Ok(passkey)
1295 }
1296
1297 async fn update_passkey_name(&self, id: &str, name: &str) -> AuthResult<PK> {
1298 let passkey =
1299 sqlx::query_as::<_, PK>("UPDATE passkeys SET name = $2 WHERE id = $1 RETURNING *")
1300 .bind(id)
1301 .bind(name)
1302 .fetch_one(&self.pool)
1303 .await
1304 .map_err(|err| match err {
1305 sqlx::Error::RowNotFound => AuthError::not_found("Passkey not found"),
1306 other => AuthError::from(other),
1307 })?;
1308 Ok(passkey)
1309 }
1310
1311 async fn delete_passkey(&self, id: &str) -> AuthResult<()> {
1312 sqlx::query("DELETE FROM passkeys WHERE id = $1")
1313 .bind(id)
1314 .execute(&self.pool)
1315 .await?;
1316 Ok(())
1317 }
1318 }
1319}
1320
1321#[cfg(feature = "sqlx-postgres")]
1322pub use sqlx_adapter::{SqlxAdapter, SqlxEntity};