1pub use super::traits::{
2 AccountOps, ApiKeyOps, InvitationOps, MemberOps, OrganizationOps, PasskeyOps, SessionOps,
3 TwoFactorOps, UserOps, VerificationOps,
4};
5
6pub trait DatabaseAdapter:
15 UserOps
16 + SessionOps
17 + AccountOps
18 + VerificationOps
19 + OrganizationOps
20 + MemberOps
21 + InvitationOps
22 + TwoFactorOps
23 + ApiKeyOps
24 + PasskeyOps
25{
26}
27
28impl<T> DatabaseAdapter for T where
29 T: UserOps
30 + SessionOps
31 + AccountOps
32 + VerificationOps
33 + OrganizationOps
34 + MemberOps
35 + InvitationOps
36 + TwoFactorOps
37 + ApiKeyOps
38 + PasskeyOps
39{
40}
41
42#[cfg(feature = "sqlx-postgres")]
43pub mod sqlx_adapter {
44 use super::*;
45 use async_trait::async_trait;
46 use chrono::{DateTime, Utc};
47
48 use crate::entity::{
49 AuthAccount, AuthApiKey, AuthInvitation, AuthMember, AuthOrganization, AuthPasskey,
50 AuthSession, AuthTwoFactor, AuthUser, AuthVerification,
51 };
52 use crate::error::{AuthError, AuthResult};
53 use crate::types::{
54 Account, ApiKey, CreateAccount, CreateApiKey, CreateInvitation, CreateMember,
55 CreateOrganization, CreatePasskey, CreateSession, CreateTwoFactor, CreateUser,
56 CreateVerification, Invitation, InvitationStatus, Member, Organization, Passkey, Session,
57 TwoFactor, UpdateAccount, UpdateApiKey, UpdateOrganization, UpdateUser, User, Verification,
58 };
59 use sqlx::PgPool;
60 use sqlx::postgres::PgRow;
61 use std::marker::PhantomData;
62 use uuid::Uuid;
63
64 pub trait SqlxEntity:
71 for<'r> sqlx::FromRow<'r, PgRow> + Send + Sync + Unpin + Clone + 'static
72 {
73 }
74
75 impl<T> SqlxEntity for T where
76 T: for<'r> sqlx::FromRow<'r, PgRow> + Send + Sync + Unpin + Clone + 'static
77 {
78 }
79
80 pub struct SqlxAdapter<
85 U = User,
86 S = Session,
87 A = Account,
88 O = Organization,
89 M = Member,
90 I = Invitation,
91 V = Verification,
92 TF = TwoFactor,
93 AK = ApiKey,
94 PK = Passkey,
95 > {
96 pool: PgPool,
97 _phantom: PhantomData<(U, S, A, O, M, I, V, TF, AK, PK)>,
98 }
99
100 impl SqlxAdapter {
102 pub async fn new(database_url: &str) -> Result<Self, sqlx::Error> {
103 let pool = PgPool::connect(database_url).await?;
104 Ok(Self {
105 pool,
106 _phantom: PhantomData,
107 })
108 }
109
110 pub async fn with_config(
111 database_url: &str,
112 config: PoolConfig,
113 ) -> Result<Self, sqlx::Error> {
114 let pool = sqlx::postgres::PgPoolOptions::new()
115 .max_connections(config.max_connections)
116 .min_connections(config.min_connections)
117 .acquire_timeout(config.acquire_timeout)
118 .idle_timeout(config.idle_timeout)
119 .max_lifetime(config.max_lifetime)
120 .connect(database_url)
121 .await?;
122 Ok(Self {
123 pool,
124 _phantom: PhantomData,
125 })
126 }
127 }
128
129 impl<U, S, A, O, M, I, V, TF, AK, PK> SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK> {
131 pub fn from_pool(pool: PgPool) -> Self {
132 Self {
133 pool,
134 _phantom: PhantomData,
135 }
136 }
137
138 pub async fn test_connection(&self) -> Result<(), sqlx::Error> {
139 sqlx::query("SELECT 1").execute(&self.pool).await?;
140 Ok(())
141 }
142
143 pub fn pool_stats(&self) -> PoolStats {
144 PoolStats {
145 size: self.pool.size(),
146 idle: self.pool.num_idle(),
147 }
148 }
149
150 pub async fn close(&self) {
151 self.pool.close().await;
152 }
153 }
154
155 #[derive(Debug, Clone)]
156 pub struct PoolConfig {
157 pub max_connections: u32,
158 pub min_connections: u32,
159 pub acquire_timeout: std::time::Duration,
160 pub idle_timeout: Option<std::time::Duration>,
161 pub max_lifetime: Option<std::time::Duration>,
162 }
163
164 impl Default for PoolConfig {
165 fn default() -> Self {
166 Self {
167 max_connections: 10,
168 min_connections: 0,
169 acquire_timeout: std::time::Duration::from_secs(30),
170 idle_timeout: Some(std::time::Duration::from_secs(600)),
171 max_lifetime: Some(std::time::Duration::from_secs(1800)),
172 }
173 }
174 }
175
176 #[derive(Debug, Clone)]
177 pub struct PoolStats {
178 pub size: u32,
179 pub idle: usize,
180 }
181
182 #[async_trait]
185 impl<U, S, A, O, M, I, V, TF, AK, PK> UserOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
186 where
187 U: AuthUser + SqlxEntity,
188 S: AuthSession + SqlxEntity,
189 A: AuthAccount + SqlxEntity,
190 O: AuthOrganization + SqlxEntity,
191 M: AuthMember + SqlxEntity,
192 I: AuthInvitation + SqlxEntity,
193 V: AuthVerification + SqlxEntity,
194 TF: AuthTwoFactor + SqlxEntity,
195 AK: AuthApiKey + SqlxEntity,
196 PK: AuthPasskey + SqlxEntity,
197 {
198 type User = U;
199
200 async fn create_user(&self, create_user: CreateUser) -> AuthResult<U> {
201 let id = create_user.id.unwrap_or_else(|| Uuid::new_v4().to_string());
202 let now = Utc::now();
203
204 let user = sqlx::query_as::<_, U>(
205 r#"
206 INSERT INTO users (id, email, name, image, email_verified, created_at, updated_at, metadata)
207 VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
208 RETURNING *
209 "#,
210 )
211 .bind(&id)
212 .bind(&create_user.email)
213 .bind(&create_user.name)
214 .bind(&create_user.image)
215 .bind(false)
216 .bind(&now)
217 .bind(&now)
218 .bind(sqlx::types::Json(create_user.metadata.unwrap_or(serde_json::json!({}))))
219 .fetch_one(&self.pool)
220 .await?;
221
222 Ok(user)
223 }
224
225 async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<U>> {
226 let user = sqlx::query_as::<_, U>("SELECT * FROM users WHERE id = $1")
227 .bind(id)
228 .fetch_optional(&self.pool)
229 .await?;
230 Ok(user)
231 }
232
233 async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<U>> {
234 let user = sqlx::query_as::<_, U>("SELECT * FROM users WHERE email = $1")
235 .bind(email)
236 .fetch_optional(&self.pool)
237 .await?;
238 Ok(user)
239 }
240
241 async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<U>> {
242 let user = sqlx::query_as::<_, U>("SELECT * FROM users WHERE username = $1")
243 .bind(username)
244 .fetch_optional(&self.pool)
245 .await?;
246 Ok(user)
247 }
248
249 async fn update_user(&self, id: &str, update: UpdateUser) -> AuthResult<U> {
250 let mut query = sqlx::QueryBuilder::new("UPDATE users SET updated_at = NOW()");
251 let mut has_updates = false;
252
253 if let Some(email) = &update.email {
254 query.push(", email = ");
255 query.push_bind(email);
256 has_updates = true;
257 }
258 if let Some(name) = &update.name {
259 query.push(", name = ");
260 query.push_bind(name);
261 has_updates = true;
262 }
263 if let Some(image) = &update.image {
264 query.push(", image = ");
265 query.push_bind(image);
266 has_updates = true;
267 }
268 if let Some(email_verified) = update.email_verified {
269 query.push(", email_verified = ");
270 query.push_bind(email_verified);
271 has_updates = true;
272 }
273 if let Some(metadata) = &update.metadata {
274 query.push(", metadata = ");
275 query.push_bind(sqlx::types::Json(metadata.clone()));
276 has_updates = true;
277 }
278
279 if !has_updates {
280 return self
281 .get_user_by_id(id)
282 .await?
283 .ok_or(AuthError::UserNotFound);
284 }
285
286 query.push(" WHERE id = ");
287 query.push_bind(id);
288 query.push(" RETURNING *");
289
290 let user = query.build_query_as::<U>().fetch_one(&self.pool).await?;
291 Ok(user)
292 }
293
294 async fn delete_user(&self, id: &str) -> AuthResult<()> {
295 sqlx::query("DELETE FROM users WHERE id = $1")
296 .bind(id)
297 .execute(&self.pool)
298 .await?;
299 Ok(())
300 }
301 }
302
303 #[async_trait]
306 impl<U, S, A, O, M, I, V, TF, AK, PK> SessionOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
307 where
308 U: AuthUser + SqlxEntity,
309 S: AuthSession + SqlxEntity,
310 A: AuthAccount + SqlxEntity,
311 O: AuthOrganization + SqlxEntity,
312 M: AuthMember + SqlxEntity,
313 I: AuthInvitation + SqlxEntity,
314 V: AuthVerification + SqlxEntity,
315 TF: AuthTwoFactor + SqlxEntity,
316 AK: AuthApiKey + SqlxEntity,
317 PK: AuthPasskey + SqlxEntity,
318 {
319 type Session = S;
320
321 async fn create_session(&self, create_session: CreateSession) -> AuthResult<S> {
322 let id = Uuid::new_v4().to_string();
323 let token = format!("session_{}", Uuid::new_v4());
324 let now = Utc::now();
325
326 let session = sqlx::query_as::<_, S>(
327 r#"
328 INSERT INTO sessions (id, user_id, token, expires_at, created_at, ip_address, user_agent, active)
329 VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
330 RETURNING *
331 "#,
332 )
333 .bind(&id)
334 .bind(&create_session.user_id)
335 .bind(&token)
336 .bind(&create_session.expires_at)
337 .bind(&now)
338 .bind(&create_session.ip_address)
339 .bind(&create_session.user_agent)
340 .bind(true)
341 .fetch_one(&self.pool)
342 .await?;
343
344 Ok(session)
345 }
346
347 async fn get_session(&self, token: &str) -> AuthResult<Option<S>> {
348 let session =
349 sqlx::query_as::<_, S>("SELECT * FROM sessions WHERE token = $1 AND active = true")
350 .bind(token)
351 .fetch_optional(&self.pool)
352 .await?;
353 Ok(session)
354 }
355
356 async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<S>> {
357 let sessions = sqlx::query_as::<_, S>(
358 "SELECT * FROM sessions WHERE user_id = $1 AND active = true ORDER BY created_at DESC",
359 )
360 .bind(user_id)
361 .fetch_all(&self.pool)
362 .await?;
363 Ok(sessions)
364 }
365
366 async fn update_session_expiry(
367 &self,
368 token: &str,
369 expires_at: DateTime<Utc>,
370 ) -> AuthResult<()> {
371 sqlx::query("UPDATE sessions SET expires_at = $1 WHERE token = $2 AND active = true")
372 .bind(&expires_at)
373 .bind(token)
374 .execute(&self.pool)
375 .await?;
376 Ok(())
377 }
378
379 async fn delete_session(&self, token: &str) -> AuthResult<()> {
380 sqlx::query("DELETE FROM sessions WHERE token = $1")
381 .bind(token)
382 .execute(&self.pool)
383 .await?;
384 Ok(())
385 }
386
387 async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
388 sqlx::query("DELETE FROM sessions WHERE user_id = $1")
389 .bind(user_id)
390 .execute(&self.pool)
391 .await?;
392 Ok(())
393 }
394
395 async fn delete_expired_sessions(&self) -> AuthResult<usize> {
396 let result =
397 sqlx::query("DELETE FROM sessions WHERE expires_at < NOW() OR active = false")
398 .execute(&self.pool)
399 .await?;
400 Ok(result.rows_affected() as usize)
401 }
402
403 async fn update_session_active_organization(
404 &self,
405 token: &str,
406 organization_id: Option<&str>,
407 ) -> AuthResult<S> {
408 let session = sqlx::query_as::<_, S>(
409 "UPDATE sessions SET active_organization_id = $1, updated_at = NOW() WHERE token = $2 AND active = true RETURNING *",
410 )
411 .bind(organization_id)
412 .bind(token)
413 .fetch_one(&self.pool)
414 .await?;
415 Ok(session)
416 }
417 }
418
419 #[async_trait]
422 impl<U, S, A, O, M, I, V, TF, AK, PK> AccountOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
423 where
424 U: AuthUser + SqlxEntity,
425 S: AuthSession + SqlxEntity,
426 A: AuthAccount + SqlxEntity,
427 O: AuthOrganization + SqlxEntity,
428 M: AuthMember + SqlxEntity,
429 I: AuthInvitation + SqlxEntity,
430 V: AuthVerification + SqlxEntity,
431 TF: AuthTwoFactor + SqlxEntity,
432 AK: AuthApiKey + SqlxEntity,
433 PK: AuthPasskey + SqlxEntity,
434 {
435 type Account = A;
436
437 async fn create_account(&self, create_account: CreateAccount) -> AuthResult<A> {
438 let id = Uuid::new_v4().to_string();
439 let now = Utc::now();
440
441 let account = sqlx::query_as::<_, A>(
442 r#"
443 INSERT INTO accounts (id, account_id, provider_id, user_id, access_token, refresh_token, id_token, access_token_expires_at, refresh_token_expires_at, scope, password, created_at, updated_at)
444 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
445 RETURNING *
446 "#,
447 )
448 .bind(&id)
449 .bind(&create_account.account_id)
450 .bind(&create_account.provider_id)
451 .bind(&create_account.user_id)
452 .bind(&create_account.access_token)
453 .bind(&create_account.refresh_token)
454 .bind(&create_account.id_token)
455 .bind(&create_account.access_token_expires_at)
456 .bind(&create_account.refresh_token_expires_at)
457 .bind(&create_account.scope)
458 .bind(&create_account.password)
459 .bind(&now)
460 .bind(&now)
461 .fetch_one(&self.pool)
462 .await?;
463
464 Ok(account)
465 }
466
467 async fn get_account(
468 &self,
469 provider: &str,
470 provider_account_id: &str,
471 ) -> AuthResult<Option<A>> {
472 let account = sqlx::query_as::<_, A>(
473 "SELECT * FROM accounts WHERE provider_id = $1 AND account_id = $2",
474 )
475 .bind(provider)
476 .bind(provider_account_id)
477 .fetch_optional(&self.pool)
478 .await?;
479 Ok(account)
480 }
481
482 async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<A>> {
483 let accounts = sqlx::query_as::<_, A>(
484 "SELECT * FROM accounts WHERE user_id = $1 ORDER BY created_at DESC",
485 )
486 .bind(user_id)
487 .fetch_all(&self.pool)
488 .await?;
489 Ok(accounts)
490 }
491
492 async fn update_account(&self, id: &str, update: UpdateAccount) -> AuthResult<A> {
493 let mut query = sqlx::QueryBuilder::new("UPDATE accounts SET updated_at = NOW()");
494
495 if let Some(access_token) = &update.access_token {
496 query.push(", access_token = ");
497 query.push_bind(access_token);
498 }
499 if let Some(refresh_token) = &update.refresh_token {
500 query.push(", refresh_token = ");
501 query.push_bind(refresh_token);
502 }
503 if let Some(id_token) = &update.id_token {
504 query.push(", id_token = ");
505 query.push_bind(id_token);
506 }
507 if let Some(access_token_expires_at) = &update.access_token_expires_at {
508 query.push(", access_token_expires_at = ");
509 query.push_bind(access_token_expires_at);
510 }
511 if let Some(refresh_token_expires_at) = &update.refresh_token_expires_at {
512 query.push(", refresh_token_expires_at = ");
513 query.push_bind(refresh_token_expires_at);
514 }
515 if let Some(scope) = &update.scope {
516 query.push(", scope = ");
517 query.push_bind(scope);
518 }
519
520 query.push(" WHERE id = ");
521 query.push_bind(id);
522 query.push(" RETURNING *");
523
524 let account = query.build_query_as::<A>().fetch_one(&self.pool).await?;
525 Ok(account)
526 }
527
528 async fn delete_account(&self, id: &str) -> AuthResult<()> {
529 sqlx::query("DELETE FROM accounts WHERE id = $1")
530 .bind(id)
531 .execute(&self.pool)
532 .await?;
533 Ok(())
534 }
535 }
536
537 #[async_trait]
540 impl<U, S, A, O, M, I, V, TF, AK, PK> VerificationOps
541 for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
542 where
543 U: AuthUser + SqlxEntity,
544 S: AuthSession + SqlxEntity,
545 A: AuthAccount + SqlxEntity,
546 O: AuthOrganization + SqlxEntity,
547 M: AuthMember + SqlxEntity,
548 I: AuthInvitation + SqlxEntity,
549 V: AuthVerification + SqlxEntity,
550 TF: AuthTwoFactor + SqlxEntity,
551 AK: AuthApiKey + SqlxEntity,
552 PK: AuthPasskey + SqlxEntity,
553 {
554 type Verification = V;
555
556 async fn create_verification(
557 &self,
558 create_verification: CreateVerification,
559 ) -> AuthResult<V> {
560 let id = Uuid::new_v4().to_string();
561 let now = Utc::now();
562
563 let verification = sqlx::query_as::<_, V>(
564 r#"
565 INSERT INTO verifications (id, identifier, value, expires_at, created_at, updated_at)
566 VALUES ($1, $2, $3, $4, $5, $6)
567 RETURNING *
568 "#,
569 )
570 .bind(&id)
571 .bind(&create_verification.identifier)
572 .bind(&create_verification.value)
573 .bind(&create_verification.expires_at)
574 .bind(&now)
575 .bind(&now)
576 .fetch_one(&self.pool)
577 .await?;
578
579 Ok(verification)
580 }
581
582 async fn get_verification(&self, identifier: &str, value: &str) -> AuthResult<Option<V>> {
583 let verification = sqlx::query_as::<_, V>(
584 "SELECT * FROM verifications WHERE identifier = $1 AND value = $2 AND expires_at > NOW()",
585 )
586 .bind(identifier)
587 .bind(value)
588 .fetch_optional(&self.pool)
589 .await?;
590 Ok(verification)
591 }
592
593 async fn get_verification_by_value(&self, value: &str) -> AuthResult<Option<V>> {
594 let verification = sqlx::query_as::<_, V>(
595 "SELECT * FROM verifications WHERE value = $1 AND expires_at > NOW()",
596 )
597 .bind(value)
598 .fetch_optional(&self.pool)
599 .await?;
600 Ok(verification)
601 }
602
603 async fn get_verification_by_identifier(&self, identifier: &str) -> AuthResult<Option<V>> {
604 let verification = sqlx::query_as::<_, V>(
605 "SELECT * FROM verifications WHERE identifier = $1 AND expires_at > NOW()",
606 )
607 .bind(identifier)
608 .fetch_optional(&self.pool)
609 .await?;
610 Ok(verification)
611 }
612
613 async fn consume_verification(
614 &self,
615 identifier: &str,
616 value: &str,
617 ) -> AuthResult<Option<V>> {
618 let verification = sqlx::query_as::<_, V>(
619 "DELETE FROM verifications WHERE id IN (
620 SELECT id FROM verifications
621 WHERE identifier = $1 AND value = $2 AND expires_at > NOW()
622 ORDER BY created_at DESC
623 LIMIT 1
624 ) RETURNING *",
625 )
626 .bind(identifier)
627 .bind(value)
628 .fetch_optional(&self.pool)
629 .await?;
630 Ok(verification)
631 }
632
633 async fn delete_verification(&self, id: &str) -> AuthResult<()> {
634 sqlx::query("DELETE FROM verifications WHERE id = $1")
635 .bind(id)
636 .execute(&self.pool)
637 .await?;
638 Ok(())
639 }
640
641 async fn delete_expired_verifications(&self) -> AuthResult<usize> {
642 let result = sqlx::query("DELETE FROM verifications WHERE expires_at < NOW()")
643 .execute(&self.pool)
644 .await?;
645 Ok(result.rows_affected() as usize)
646 }
647 }
648
649 #[async_trait]
652 impl<U, S, A, O, M, I, V, TF, AK, PK> OrganizationOps
653 for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
654 where
655 U: AuthUser + SqlxEntity,
656 S: AuthSession + SqlxEntity,
657 A: AuthAccount + SqlxEntity,
658 O: AuthOrganization + SqlxEntity,
659 M: AuthMember + SqlxEntity,
660 I: AuthInvitation + SqlxEntity,
661 V: AuthVerification + SqlxEntity,
662 TF: AuthTwoFactor + SqlxEntity,
663 AK: AuthApiKey + SqlxEntity,
664 PK: AuthPasskey + SqlxEntity,
665 {
666 type Organization = O;
667
668 async fn create_organization(&self, create_org: CreateOrganization) -> AuthResult<O> {
669 let id = create_org.id.unwrap_or_else(|| Uuid::new_v4().to_string());
670 let now = Utc::now();
671
672 let organization = sqlx::query_as::<_, O>(
673 r#"
674 INSERT INTO organization (id, name, slug, logo, metadata, created_at, updated_at)
675 VALUES ($1, $2, $3, $4, $5, $6, $7)
676 RETURNING *
677 "#,
678 )
679 .bind(&id)
680 .bind(&create_org.name)
681 .bind(&create_org.slug)
682 .bind(&create_org.logo)
683 .bind(sqlx::types::Json(
684 create_org.metadata.unwrap_or(serde_json::json!({})),
685 ))
686 .bind(&now)
687 .bind(&now)
688 .fetch_one(&self.pool)
689 .await?;
690
691 Ok(organization)
692 }
693
694 async fn get_organization_by_id(&self, id: &str) -> AuthResult<Option<O>> {
695 let organization = sqlx::query_as::<_, O>("SELECT * FROM organization WHERE id = $1")
696 .bind(id)
697 .fetch_optional(&self.pool)
698 .await?;
699 Ok(organization)
700 }
701
702 async fn get_organization_by_slug(&self, slug: &str) -> AuthResult<Option<O>> {
703 let organization = sqlx::query_as::<_, O>("SELECT * FROM organization WHERE slug = $1")
704 .bind(slug)
705 .fetch_optional(&self.pool)
706 .await?;
707 Ok(organization)
708 }
709
710 async fn update_organization(&self, id: &str, update: UpdateOrganization) -> AuthResult<O> {
711 let mut query = sqlx::QueryBuilder::new("UPDATE organization SET updated_at = NOW()");
712
713 if let Some(name) = &update.name {
714 query.push(", name = ");
715 query.push_bind(name);
716 }
717 if let Some(slug) = &update.slug {
718 query.push(", slug = ");
719 query.push_bind(slug);
720 }
721 if let Some(logo) = &update.logo {
722 query.push(", logo = ");
723 query.push_bind(logo);
724 }
725 if let Some(metadata) = &update.metadata {
726 query.push(", metadata = ");
727 query.push_bind(sqlx::types::Json(metadata.clone()));
728 }
729
730 query.push(" WHERE id = ");
731 query.push_bind(id);
732 query.push(" RETURNING *");
733
734 let organization = query.build_query_as::<O>().fetch_one(&self.pool).await?;
735 Ok(organization)
736 }
737
738 async fn delete_organization(&self, id: &str) -> AuthResult<()> {
739 sqlx::query("DELETE FROM organization WHERE id = $1")
740 .bind(id)
741 .execute(&self.pool)
742 .await?;
743 Ok(())
744 }
745
746 async fn list_user_organizations(&self, user_id: &str) -> AuthResult<Vec<O>> {
747 let organizations = sqlx::query_as::<_, O>(
748 r#"
749 SELECT o.*
750 FROM organization o
751 INNER JOIN member m ON o.id = m.organization_id
752 WHERE m.user_id = $1
753 ORDER BY o.created_at DESC
754 "#,
755 )
756 .bind(user_id)
757 .fetch_all(&self.pool)
758 .await?;
759 Ok(organizations)
760 }
761 }
762
763 #[async_trait]
766 impl<U, S, A, O, M, I, V, TF, AK, PK> MemberOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
767 where
768 U: AuthUser + SqlxEntity,
769 S: AuthSession + SqlxEntity,
770 A: AuthAccount + SqlxEntity,
771 O: AuthOrganization + SqlxEntity,
772 M: AuthMember + SqlxEntity,
773 I: AuthInvitation + SqlxEntity,
774 V: AuthVerification + SqlxEntity,
775 TF: AuthTwoFactor + SqlxEntity,
776 AK: AuthApiKey + SqlxEntity,
777 PK: AuthPasskey + SqlxEntity,
778 {
779 type Member = M;
780
781 async fn create_member(&self, create_member: CreateMember) -> AuthResult<M> {
782 let id = Uuid::new_v4().to_string();
783 let now = Utc::now();
784
785 let member = sqlx::query_as::<_, M>(
786 r#"
787 INSERT INTO member (id, organization_id, user_id, role, created_at)
788 VALUES ($1, $2, $3, $4, $5)
789 RETURNING *
790 "#,
791 )
792 .bind(&id)
793 .bind(&create_member.organization_id)
794 .bind(&create_member.user_id)
795 .bind(&create_member.role)
796 .bind(&now)
797 .fetch_one(&self.pool)
798 .await?;
799
800 Ok(member)
801 }
802
803 async fn get_member(&self, organization_id: &str, user_id: &str) -> AuthResult<Option<M>> {
804 let member = sqlx::query_as::<_, M>(
805 "SELECT * FROM member WHERE organization_id = $1 AND user_id = $2",
806 )
807 .bind(organization_id)
808 .bind(user_id)
809 .fetch_optional(&self.pool)
810 .await?;
811 Ok(member)
812 }
813
814 async fn get_member_by_id(&self, id: &str) -> AuthResult<Option<M>> {
815 let member = sqlx::query_as::<_, M>("SELECT * FROM member WHERE id = $1")
816 .bind(id)
817 .fetch_optional(&self.pool)
818 .await?;
819 Ok(member)
820 }
821
822 async fn update_member_role(&self, member_id: &str, role: &str) -> AuthResult<M> {
823 let member =
824 sqlx::query_as::<_, M>("UPDATE member SET role = $1 WHERE id = $2 RETURNING *")
825 .bind(role)
826 .bind(member_id)
827 .fetch_one(&self.pool)
828 .await?;
829 Ok(member)
830 }
831
832 async fn delete_member(&self, member_id: &str) -> AuthResult<()> {
833 sqlx::query("DELETE FROM member WHERE id = $1")
834 .bind(member_id)
835 .execute(&self.pool)
836 .await?;
837 Ok(())
838 }
839
840 async fn list_organization_members(&self, organization_id: &str) -> AuthResult<Vec<M>> {
841 let members = sqlx::query_as::<_, M>(
842 "SELECT * FROM member WHERE organization_id = $1 ORDER BY created_at ASC",
843 )
844 .bind(organization_id)
845 .fetch_all(&self.pool)
846 .await?;
847 Ok(members)
848 }
849
850 async fn count_organization_members(&self, organization_id: &str) -> AuthResult<usize> {
851 let count: (i64,) =
852 sqlx::query_as("SELECT COUNT(*) FROM member WHERE organization_id = $1")
853 .bind(organization_id)
854 .fetch_one(&self.pool)
855 .await?;
856 Ok(count.0 as usize)
857 }
858
859 async fn count_organization_owners(&self, organization_id: &str) -> AuthResult<usize> {
860 let count: (i64,) = sqlx::query_as(
861 "SELECT COUNT(*) FROM member WHERE organization_id = $1 AND role = 'owner'",
862 )
863 .bind(organization_id)
864 .fetch_one(&self.pool)
865 .await?;
866 Ok(count.0 as usize)
867 }
868 }
869
870 #[async_trait]
873 impl<U, S, A, O, M, I, V, TF, AK, PK> InvitationOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
874 where
875 U: AuthUser + SqlxEntity,
876 S: AuthSession + SqlxEntity,
877 A: AuthAccount + SqlxEntity,
878 O: AuthOrganization + SqlxEntity,
879 M: AuthMember + SqlxEntity,
880 I: AuthInvitation + SqlxEntity,
881 V: AuthVerification + SqlxEntity,
882 TF: AuthTwoFactor + SqlxEntity,
883 AK: AuthApiKey + SqlxEntity,
884 PK: AuthPasskey + SqlxEntity,
885 {
886 type Invitation = I;
887
888 async fn create_invitation(&self, create_inv: CreateInvitation) -> AuthResult<I> {
889 let id = Uuid::new_v4().to_string();
890 let now = Utc::now();
891
892 let invitation = sqlx::query_as::<_, I>(
893 r#"
894 INSERT INTO invitation (id, organization_id, email, role, status, inviter_id, expires_at, created_at)
895 VALUES ($1, $2, $3, $4, 'pending', $5, $6, $7)
896 RETURNING *
897 "#,
898 )
899 .bind(&id)
900 .bind(&create_inv.organization_id)
901 .bind(&create_inv.email)
902 .bind(&create_inv.role)
903 .bind(&create_inv.inviter_id)
904 .bind(&create_inv.expires_at)
905 .bind(&now)
906 .fetch_one(&self.pool)
907 .await?;
908
909 Ok(invitation)
910 }
911
912 async fn get_invitation_by_id(&self, id: &str) -> AuthResult<Option<I>> {
913 let invitation = sqlx::query_as::<_, I>("SELECT * FROM invitation WHERE id = $1")
914 .bind(id)
915 .fetch_optional(&self.pool)
916 .await?;
917 Ok(invitation)
918 }
919
920 async fn get_pending_invitation(
921 &self,
922 organization_id: &str,
923 email: &str,
924 ) -> AuthResult<Option<I>> {
925 let invitation = sqlx::query_as::<_, I>(
926 "SELECT * FROM invitation WHERE organization_id = $1 AND LOWER(email) = LOWER($2) AND status = 'pending'",
927 )
928 .bind(organization_id)
929 .bind(email)
930 .fetch_optional(&self.pool)
931 .await?;
932 Ok(invitation)
933 }
934
935 async fn update_invitation_status(
936 &self,
937 id: &str,
938 status: InvitationStatus,
939 ) -> AuthResult<I> {
940 let invitation = sqlx::query_as::<_, I>(
941 "UPDATE invitation SET status = $1 WHERE id = $2 RETURNING *",
942 )
943 .bind(status.to_string())
944 .bind(id)
945 .fetch_one(&self.pool)
946 .await?;
947 Ok(invitation)
948 }
949
950 async fn list_organization_invitations(&self, organization_id: &str) -> AuthResult<Vec<I>> {
951 let invitations = sqlx::query_as::<_, I>(
952 "SELECT * FROM invitation WHERE organization_id = $1 ORDER BY created_at DESC",
953 )
954 .bind(organization_id)
955 .fetch_all(&self.pool)
956 .await?;
957 Ok(invitations)
958 }
959
960 async fn list_user_invitations(&self, email: &str) -> AuthResult<Vec<I>> {
961 let invitations = sqlx::query_as::<_, I>(
962 "SELECT * FROM invitation WHERE LOWER(email) = LOWER($1) AND status = 'pending' AND expires_at > NOW() ORDER BY created_at DESC",
963 )
964 .bind(email)
965 .fetch_all(&self.pool)
966 .await?;
967 Ok(invitations)
968 }
969 }
970
971 #[async_trait]
974 impl<U, S, A, O, M, I, V, TF, AK, PK> TwoFactorOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
975 where
976 U: AuthUser + SqlxEntity,
977 S: AuthSession + SqlxEntity,
978 A: AuthAccount + SqlxEntity,
979 O: AuthOrganization + SqlxEntity,
980 M: AuthMember + SqlxEntity,
981 I: AuthInvitation + SqlxEntity,
982 V: AuthVerification + SqlxEntity,
983 TF: AuthTwoFactor + SqlxEntity,
984 AK: AuthApiKey + SqlxEntity,
985 PK: AuthPasskey + SqlxEntity,
986 {
987 type TwoFactor = TF;
988
989 async fn create_two_factor(&self, create: CreateTwoFactor) -> AuthResult<TF> {
990 let id = Uuid::new_v4().to_string();
991 let now = Utc::now();
992
993 let two_factor = sqlx::query_as::<_, TF>(
994 r#"
995 INSERT INTO two_factor (id, secret, backup_codes, user_id, created_at, updated_at)
996 VALUES ($1, $2, $3, $4, $5, $6)
997 RETURNING *
998 "#,
999 )
1000 .bind(&id)
1001 .bind(&create.secret)
1002 .bind(&create.backup_codes)
1003 .bind(&create.user_id)
1004 .bind(&now)
1005 .bind(&now)
1006 .fetch_one(&self.pool)
1007 .await?;
1008
1009 Ok(two_factor)
1010 }
1011
1012 async fn get_two_factor_by_user_id(&self, user_id: &str) -> AuthResult<Option<TF>> {
1013 let two_factor = sqlx::query_as::<_, TF>("SELECT * FROM two_factor WHERE user_id = $1")
1014 .bind(user_id)
1015 .fetch_optional(&self.pool)
1016 .await?;
1017 Ok(two_factor)
1018 }
1019
1020 async fn update_two_factor_backup_codes(
1021 &self,
1022 user_id: &str,
1023 backup_codes: &str,
1024 ) -> AuthResult<TF> {
1025 let two_factor = sqlx::query_as::<_, TF>(
1026 "UPDATE two_factor SET backup_codes = $1, updated_at = NOW() WHERE user_id = $2 RETURNING *",
1027 )
1028 .bind(backup_codes)
1029 .bind(user_id)
1030 .fetch_one(&self.pool)
1031 .await?;
1032 Ok(two_factor)
1033 }
1034
1035 async fn delete_two_factor(&self, user_id: &str) -> AuthResult<()> {
1036 sqlx::query("DELETE FROM two_factor WHERE user_id = $1")
1037 .bind(user_id)
1038 .execute(&self.pool)
1039 .await?;
1040 Ok(())
1041 }
1042 }
1043
1044 #[async_trait]
1047 impl<U, S, A, O, M, I, V, TF, AK, PK> ApiKeyOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
1048 where
1049 U: AuthUser + SqlxEntity,
1050 S: AuthSession + SqlxEntity,
1051 A: AuthAccount + SqlxEntity,
1052 O: AuthOrganization + SqlxEntity,
1053 M: AuthMember + SqlxEntity,
1054 I: AuthInvitation + SqlxEntity,
1055 V: AuthVerification + SqlxEntity,
1056 TF: AuthTwoFactor + SqlxEntity,
1057 AK: AuthApiKey + SqlxEntity,
1058 PK: AuthPasskey + SqlxEntity,
1059 {
1060 type ApiKey = AK;
1061
1062 async fn create_api_key(&self, input: CreateApiKey) -> AuthResult<AK> {
1063 let id = Uuid::new_v4().to_string();
1064 let now = Utc::now();
1065
1066 let api_key = sqlx::query_as::<_, AK>(
1067 r#"
1068 INSERT INTO api_keys (id, name, start, prefix, key, user_id, refill_interval, refill_amount,
1069 enabled, rate_limit_enabled, rate_limit_time_window, rate_limit_max, remaining,
1070 expires_at, created_at, updated_at, permissions, metadata)
1071 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13,
1072 $14::timestamptz, $15, $16, $17, $18)
1073 RETURNING *
1074 "#,
1075 )
1076 .bind(&id)
1077 .bind(&input.name)
1078 .bind(&input.start)
1079 .bind(&input.prefix)
1080 .bind(&input.key_hash)
1081 .bind(&input.user_id)
1082 .bind(&input.refill_interval)
1083 .bind(&input.refill_amount)
1084 .bind(input.enabled)
1085 .bind(input.rate_limit_enabled)
1086 .bind(&input.rate_limit_time_window)
1087 .bind(&input.rate_limit_max)
1088 .bind(&input.remaining)
1089 .bind(&input.expires_at)
1090 .bind(&now)
1091 .bind(&now)
1092 .bind(&input.permissions)
1093 .bind(&input.metadata)
1094 .fetch_one(&self.pool)
1095 .await?;
1096
1097 Ok(api_key)
1098 }
1099
1100 async fn get_api_key_by_id(&self, id: &str) -> AuthResult<Option<AK>> {
1101 let api_key = sqlx::query_as::<_, AK>("SELECT * FROM api_keys WHERE id = $1")
1102 .bind(id)
1103 .fetch_optional(&self.pool)
1104 .await?;
1105 Ok(api_key)
1106 }
1107
1108 async fn get_api_key_by_hash(&self, hash: &str) -> AuthResult<Option<AK>> {
1109 let api_key = sqlx::query_as::<_, AK>("SELECT * FROM api_keys WHERE key = $1")
1110 .bind(hash)
1111 .fetch_optional(&self.pool)
1112 .await?;
1113 Ok(api_key)
1114 }
1115
1116 async fn list_api_keys_by_user(&self, user_id: &str) -> AuthResult<Vec<AK>> {
1117 let keys = sqlx::query_as::<_, AK>(
1118 "SELECT * FROM api_keys WHERE user_id = $1 ORDER BY created_at DESC",
1119 )
1120 .bind(user_id)
1121 .fetch_all(&self.pool)
1122 .await?;
1123 Ok(keys)
1124 }
1125
1126 async fn update_api_key(&self, id: &str, update: UpdateApiKey) -> AuthResult<AK> {
1127 let mut query = sqlx::QueryBuilder::new("UPDATE api_keys SET updated_at = NOW()");
1128
1129 if let Some(name) = &update.name {
1130 query.push(", name = ");
1131 query.push_bind(name);
1132 }
1133 if let Some(enabled) = update.enabled {
1134 query.push(", enabled = ");
1135 query.push_bind(enabled);
1136 }
1137 if let Some(remaining) = update.remaining {
1138 query.push(", remaining = ");
1139 query.push_bind(remaining);
1140 }
1141 if let Some(rate_limit_enabled) = update.rate_limit_enabled {
1142 query.push(", rate_limit_enabled = ");
1143 query.push_bind(rate_limit_enabled);
1144 }
1145 if let Some(rate_limit_time_window) = update.rate_limit_time_window {
1146 query.push(", rate_limit_time_window = ");
1147 query.push_bind(rate_limit_time_window);
1148 }
1149 if let Some(rate_limit_max) = update.rate_limit_max {
1150 query.push(", rate_limit_max = ");
1151 query.push_bind(rate_limit_max);
1152 }
1153 if let Some(refill_interval) = update.refill_interval {
1154 query.push(", refill_interval = ");
1155 query.push_bind(refill_interval);
1156 }
1157 if let Some(refill_amount) = update.refill_amount {
1158 query.push(", refill_amount = ");
1159 query.push_bind(refill_amount);
1160 }
1161 if let Some(permissions) = &update.permissions {
1162 query.push(", permissions = ");
1163 query.push_bind(permissions);
1164 }
1165 if let Some(metadata) = &update.metadata {
1166 query.push(", metadata = ");
1167 query.push_bind(metadata);
1168 }
1169
1170 query.push(" WHERE id = ");
1171 query.push_bind(id);
1172 query.push(" RETURNING *");
1173
1174 let api_key = query
1175 .build_query_as::<AK>()
1176 .fetch_one(&self.pool)
1177 .await
1178 .map_err(|err| match err {
1179 sqlx::Error::RowNotFound => AuthError::not_found("API key not found"),
1180 other => AuthError::from(other),
1181 })?;
1182 Ok(api_key)
1183 }
1184
1185 async fn delete_api_key(&self, id: &str) -> AuthResult<()> {
1186 sqlx::query("DELETE FROM api_keys WHERE id = $1")
1187 .bind(id)
1188 .execute(&self.pool)
1189 .await?;
1190 Ok(())
1191 }
1192 }
1193
1194 #[async_trait]
1197 impl<U, S, A, O, M, I, V, TF, AK, PK> PasskeyOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
1198 where
1199 U: AuthUser + SqlxEntity,
1200 S: AuthSession + SqlxEntity,
1201 A: AuthAccount + SqlxEntity,
1202 O: AuthOrganization + SqlxEntity,
1203 M: AuthMember + SqlxEntity,
1204 I: AuthInvitation + SqlxEntity,
1205 V: AuthVerification + SqlxEntity,
1206 TF: AuthTwoFactor + SqlxEntity,
1207 AK: AuthApiKey + SqlxEntity,
1208 PK: AuthPasskey + SqlxEntity,
1209 {
1210 type Passkey = PK;
1211
1212 async fn create_passkey(&self, input: CreatePasskey) -> AuthResult<PK> {
1213 let id = Uuid::new_v4().to_string();
1214 let now = Utc::now();
1215 let counter = i64::try_from(input.counter)
1216 .map_err(|_| AuthError::bad_request("Passkey counter exceeds i64 range"))?;
1217
1218 let passkey = sqlx::query_as::<_, PK>(
1219 r#"
1220 INSERT INTO passkeys (id, name, public_key, user_id, credential_id, counter, device_type, backed_up, transports, created_at)
1221 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
1222 RETURNING *
1223 "#,
1224 )
1225 .bind(&id)
1226 .bind(&input.name)
1227 .bind(&input.public_key)
1228 .bind(&input.user_id)
1229 .bind(&input.credential_id)
1230 .bind(counter)
1231 .bind(&input.device_type)
1232 .bind(input.backed_up)
1233 .bind(&input.transports)
1234 .bind(now)
1235 .fetch_one(&self.pool)
1236 .await
1237 .map_err(|e| match e {
1238 sqlx::Error::Database(ref db_err) if db_err.is_unique_violation() => {
1239 AuthError::conflict("A passkey with this credential ID already exists")
1240 }
1241 other => AuthError::from(other),
1242 })?;
1243
1244 Ok(passkey)
1245 }
1246
1247 async fn get_passkey_by_id(&self, id: &str) -> AuthResult<Option<PK>> {
1248 let passkey = sqlx::query_as::<_, PK>("SELECT * FROM passkeys WHERE id = $1")
1249 .bind(id)
1250 .fetch_optional(&self.pool)
1251 .await?;
1252 Ok(passkey)
1253 }
1254
1255 async fn get_passkey_by_credential_id(
1256 &self,
1257 credential_id: &str,
1258 ) -> AuthResult<Option<PK>> {
1259 let passkey =
1260 sqlx::query_as::<_, PK>("SELECT * FROM passkeys WHERE credential_id = $1")
1261 .bind(credential_id)
1262 .fetch_optional(&self.pool)
1263 .await?;
1264 Ok(passkey)
1265 }
1266
1267 async fn list_passkeys_by_user(&self, user_id: &str) -> AuthResult<Vec<PK>> {
1268 let passkeys = sqlx::query_as::<_, PK>(
1269 "SELECT * FROM passkeys WHERE user_id = $1 ORDER BY created_at DESC",
1270 )
1271 .bind(user_id)
1272 .fetch_all(&self.pool)
1273 .await?;
1274 Ok(passkeys)
1275 }
1276
1277 async fn update_passkey_counter(&self, id: &str, counter: u64) -> AuthResult<PK> {
1278 let counter = i64::try_from(counter)
1279 .map_err(|_| AuthError::bad_request("Passkey counter exceeds i64 range"))?;
1280 let passkey = sqlx::query_as::<_, PK>(
1281 "UPDATE passkeys SET counter = $2 WHERE id = $1 RETURNING *",
1282 )
1283 .bind(id)
1284 .bind(counter)
1285 .fetch_one(&self.pool)
1286 .await
1287 .map_err(|err| match err {
1288 sqlx::Error::RowNotFound => AuthError::not_found("Passkey not found"),
1289 other => AuthError::from(other),
1290 })?;
1291 Ok(passkey)
1292 }
1293
1294 async fn update_passkey_name(&self, id: &str, name: &str) -> AuthResult<PK> {
1295 let passkey =
1296 sqlx::query_as::<_, PK>("UPDATE passkeys SET name = $2 WHERE id = $1 RETURNING *")
1297 .bind(id)
1298 .bind(name)
1299 .fetch_one(&self.pool)
1300 .await
1301 .map_err(|err| match err {
1302 sqlx::Error::RowNotFound => AuthError::not_found("Passkey not found"),
1303 other => AuthError::from(other),
1304 })?;
1305 Ok(passkey)
1306 }
1307
1308 async fn delete_passkey(&self, id: &str) -> AuthResult<()> {
1309 sqlx::query("DELETE FROM passkeys WHERE id = $1")
1310 .bind(id)
1311 .execute(&self.pool)
1312 .await?;
1313 Ok(())
1314 }
1315 }
1316}
1317
1318#[cfg(feature = "sqlx-postgres")]
1319pub use sqlx_adapter::{SqlxAdapter, SqlxEntity};