1pub use super::traits::{
2 AccountOps, InvitationOps, MemberOps, OrganizationOps, SessionOps, UserOps, VerificationOps,
3};
4
5pub trait DatabaseAdapter:
14 UserOps + SessionOps + AccountOps + VerificationOps + OrganizationOps + MemberOps + InvitationOps
15{
16}
17
18impl<T> DatabaseAdapter for T where
19 T: UserOps
20 + SessionOps
21 + AccountOps
22 + VerificationOps
23 + OrganizationOps
24 + MemberOps
25 + InvitationOps
26{
27}
28
29#[cfg(feature = "sqlx-postgres")]
30pub mod sqlx_adapter {
31 use super::*;
32 use async_trait::async_trait;
33 use chrono::{DateTime, Utc};
34
35 use crate::entity::{
36 AuthAccount, AuthInvitation, AuthMember, AuthOrganization, AuthSession, AuthUser,
37 AuthVerification,
38 };
39 use crate::error::{AuthError, AuthResult};
40 use crate::types::{
41 Account, CreateAccount, CreateInvitation, CreateMember, CreateOrganization, CreateSession,
42 CreateUser, CreateVerification, Invitation, InvitationStatus, Member, Organization,
43 Session, UpdateOrganization, UpdateUser, User, Verification,
44 };
45 use sqlx::PgPool;
46 use sqlx::postgres::PgRow;
47 use std::marker::PhantomData;
48 use uuid::Uuid;
49
50 pub trait SqlxEntity:
57 for<'r> sqlx::FromRow<'r, PgRow> + Send + Sync + Unpin + Clone + 'static
58 {
59 }
60
61 impl<T> SqlxEntity for T where
62 T: for<'r> sqlx::FromRow<'r, PgRow> + Send + Sync + Unpin + Clone + 'static
63 {
64 }
65
66 pub struct SqlxAdapter<
71 U = User,
72 S = Session,
73 A = Account,
74 O = Organization,
75 M = Member,
76 I = Invitation,
77 V = Verification,
78 > {
79 pool: PgPool,
80 _phantom: PhantomData<(U, S, A, O, M, I, V)>,
81 }
82
83 impl SqlxAdapter {
85 pub async fn new(database_url: &str) -> Result<Self, sqlx::Error> {
86 let pool = PgPool::connect(database_url).await?;
87 Ok(Self {
88 pool,
89 _phantom: PhantomData,
90 })
91 }
92
93 pub async fn with_config(
94 database_url: &str,
95 config: PoolConfig,
96 ) -> Result<Self, sqlx::Error> {
97 let pool = sqlx::postgres::PgPoolOptions::new()
98 .max_connections(config.max_connections)
99 .min_connections(config.min_connections)
100 .acquire_timeout(config.acquire_timeout)
101 .idle_timeout(config.idle_timeout)
102 .max_lifetime(config.max_lifetime)
103 .connect(database_url)
104 .await?;
105 Ok(Self {
106 pool,
107 _phantom: PhantomData,
108 })
109 }
110 }
111
112 impl<U, S, A, O, M, I, V> SqlxAdapter<U, S, A, O, M, I, V> {
114 pub fn from_pool(pool: PgPool) -> Self {
115 Self {
116 pool,
117 _phantom: PhantomData,
118 }
119 }
120
121 pub async fn test_connection(&self) -> Result<(), sqlx::Error> {
122 sqlx::query("SELECT 1").execute(&self.pool).await?;
123 Ok(())
124 }
125
126 pub fn pool_stats(&self) -> PoolStats {
127 PoolStats {
128 size: self.pool.size(),
129 idle: self.pool.num_idle(),
130 }
131 }
132
133 pub async fn close(&self) {
134 self.pool.close().await;
135 }
136 }
137
138 #[derive(Debug, Clone)]
139 pub struct PoolConfig {
140 pub max_connections: u32,
141 pub min_connections: u32,
142 pub acquire_timeout: std::time::Duration,
143 pub idle_timeout: Option<std::time::Duration>,
144 pub max_lifetime: Option<std::time::Duration>,
145 }
146
147 impl Default for PoolConfig {
148 fn default() -> Self {
149 Self {
150 max_connections: 10,
151 min_connections: 0,
152 acquire_timeout: std::time::Duration::from_secs(30),
153 idle_timeout: Some(std::time::Duration::from_secs(600)),
154 max_lifetime: Some(std::time::Duration::from_secs(1800)),
155 }
156 }
157 }
158
159 #[derive(Debug, Clone)]
160 pub struct PoolStats {
161 pub size: u32,
162 pub idle: usize,
163 }
164
165 #[async_trait]
168 impl<U, S, A, O, M, I, V> UserOps for SqlxAdapter<U, S, A, O, M, I, V>
169 where
170 U: AuthUser + SqlxEntity,
171 S: AuthSession + SqlxEntity,
172 A: AuthAccount + SqlxEntity,
173 O: AuthOrganization + SqlxEntity,
174 M: AuthMember + SqlxEntity,
175 I: AuthInvitation + SqlxEntity,
176 V: AuthVerification + SqlxEntity,
177 {
178 type User = U;
179
180 async fn create_user(&self, create_user: CreateUser) -> AuthResult<U> {
181 let id = create_user.id.unwrap_or_else(|| Uuid::new_v4().to_string());
182 let now = Utc::now();
183
184 let user = sqlx::query_as::<_, U>(
185 r#"
186 INSERT INTO users (id, email, name, image, email_verified, created_at, updated_at, metadata)
187 VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
188 RETURNING *
189 "#,
190 )
191 .bind(&id)
192 .bind(&create_user.email)
193 .bind(&create_user.name)
194 .bind(&create_user.image)
195 .bind(false)
196 .bind(&now)
197 .bind(&now)
198 .bind(sqlx::types::Json(create_user.metadata.unwrap_or(serde_json::json!({}))))
199 .fetch_one(&self.pool)
200 .await?;
201
202 Ok(user)
203 }
204
205 async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<U>> {
206 let user = sqlx::query_as::<_, U>("SELECT * FROM users WHERE id = $1")
207 .bind(id)
208 .fetch_optional(&self.pool)
209 .await?;
210 Ok(user)
211 }
212
213 async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<U>> {
214 let user = sqlx::query_as::<_, U>("SELECT * FROM users WHERE email = $1")
215 .bind(email)
216 .fetch_optional(&self.pool)
217 .await?;
218 Ok(user)
219 }
220
221 async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<U>> {
222 let user = sqlx::query_as::<_, U>("SELECT * FROM users WHERE username = $1")
223 .bind(username)
224 .fetch_optional(&self.pool)
225 .await?;
226 Ok(user)
227 }
228
229 async fn update_user(&self, id: &str, update: UpdateUser) -> AuthResult<U> {
230 let mut query = sqlx::QueryBuilder::new("UPDATE users SET updated_at = NOW()");
231 let mut has_updates = false;
232
233 if let Some(email) = &update.email {
234 query.push(", email = ");
235 query.push_bind(email);
236 has_updates = true;
237 }
238 if let Some(name) = &update.name {
239 query.push(", name = ");
240 query.push_bind(name);
241 has_updates = true;
242 }
243 if let Some(image) = &update.image {
244 query.push(", image = ");
245 query.push_bind(image);
246 has_updates = true;
247 }
248 if let Some(email_verified) = update.email_verified {
249 query.push(", email_verified = ");
250 query.push_bind(email_verified);
251 has_updates = true;
252 }
253 if let Some(metadata) = &update.metadata {
254 query.push(", metadata = ");
255 query.push_bind(sqlx::types::Json(metadata.clone()));
256 has_updates = true;
257 }
258
259 if !has_updates {
260 return self
261 .get_user_by_id(id)
262 .await?
263 .ok_or(AuthError::UserNotFound);
264 }
265
266 query.push(" WHERE id = ");
267 query.push_bind(id);
268 query.push(" RETURNING *");
269
270 let user = query.build_query_as::<U>().fetch_one(&self.pool).await?;
271 Ok(user)
272 }
273
274 async fn delete_user(&self, id: &str) -> AuthResult<()> {
275 sqlx::query("DELETE FROM users WHERE id = $1")
276 .bind(id)
277 .execute(&self.pool)
278 .await?;
279 Ok(())
280 }
281 }
282
283 #[async_trait]
286 impl<U, S, A, O, M, I, V> SessionOps for SqlxAdapter<U, S, A, O, M, I, V>
287 where
288 U: AuthUser + SqlxEntity,
289 S: AuthSession + SqlxEntity,
290 A: AuthAccount + SqlxEntity,
291 O: AuthOrganization + SqlxEntity,
292 M: AuthMember + SqlxEntity,
293 I: AuthInvitation + SqlxEntity,
294 V: AuthVerification + SqlxEntity,
295 {
296 type Session = S;
297
298 async fn create_session(&self, create_session: CreateSession) -> AuthResult<S> {
299 let id = Uuid::new_v4().to_string();
300 let token = format!("session_{}", Uuid::new_v4());
301 let now = Utc::now();
302
303 let session = sqlx::query_as::<_, S>(
304 r#"
305 INSERT INTO sessions (id, user_id, token, expires_at, created_at, ip_address, user_agent, active)
306 VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
307 RETURNING *
308 "#,
309 )
310 .bind(&id)
311 .bind(&create_session.user_id)
312 .bind(&token)
313 .bind(&create_session.expires_at)
314 .bind(&now)
315 .bind(&create_session.ip_address)
316 .bind(&create_session.user_agent)
317 .bind(true)
318 .fetch_one(&self.pool)
319 .await?;
320
321 Ok(session)
322 }
323
324 async fn get_session(&self, token: &str) -> AuthResult<Option<S>> {
325 let session =
326 sqlx::query_as::<_, S>("SELECT * FROM sessions WHERE token = $1 AND active = true")
327 .bind(token)
328 .fetch_optional(&self.pool)
329 .await?;
330 Ok(session)
331 }
332
333 async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<S>> {
334 let sessions = sqlx::query_as::<_, S>(
335 "SELECT * FROM sessions WHERE user_id = $1 AND active = true ORDER BY created_at DESC",
336 )
337 .bind(user_id)
338 .fetch_all(&self.pool)
339 .await?;
340 Ok(sessions)
341 }
342
343 async fn update_session_expiry(
344 &self,
345 token: &str,
346 expires_at: DateTime<Utc>,
347 ) -> AuthResult<()> {
348 sqlx::query("UPDATE sessions SET expires_at = $1 WHERE token = $2 AND active = true")
349 .bind(&expires_at)
350 .bind(token)
351 .execute(&self.pool)
352 .await?;
353 Ok(())
354 }
355
356 async fn delete_session(&self, token: &str) -> AuthResult<()> {
357 sqlx::query("DELETE FROM sessions WHERE token = $1")
358 .bind(token)
359 .execute(&self.pool)
360 .await?;
361 Ok(())
362 }
363
364 async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
365 sqlx::query("DELETE FROM sessions WHERE user_id = $1")
366 .bind(user_id)
367 .execute(&self.pool)
368 .await?;
369 Ok(())
370 }
371
372 async fn delete_expired_sessions(&self) -> AuthResult<usize> {
373 let result =
374 sqlx::query("DELETE FROM sessions WHERE expires_at < NOW() OR active = false")
375 .execute(&self.pool)
376 .await?;
377 Ok(result.rows_affected() as usize)
378 }
379
380 async fn update_session_active_organization(
381 &self,
382 token: &str,
383 organization_id: Option<&str>,
384 ) -> AuthResult<S> {
385 let session = sqlx::query_as::<_, S>(
386 "UPDATE sessions SET active_organization_id = $1, updated_at = NOW() WHERE token = $2 AND active = true RETURNING *",
387 )
388 .bind(organization_id)
389 .bind(token)
390 .fetch_one(&self.pool)
391 .await?;
392 Ok(session)
393 }
394 }
395
396 #[async_trait]
399 impl<U, S, A, O, M, I, V> AccountOps for SqlxAdapter<U, S, A, O, M, I, V>
400 where
401 U: AuthUser + SqlxEntity,
402 S: AuthSession + SqlxEntity,
403 A: AuthAccount + SqlxEntity,
404 O: AuthOrganization + SqlxEntity,
405 M: AuthMember + SqlxEntity,
406 I: AuthInvitation + SqlxEntity,
407 V: AuthVerification + SqlxEntity,
408 {
409 type Account = A;
410
411 async fn create_account(&self, create_account: CreateAccount) -> AuthResult<A> {
412 let id = Uuid::new_v4().to_string();
413 let now = Utc::now();
414
415 let account = sqlx::query_as::<_, A>(
416 r#"
417 INSERT INTO accounts (id, account_id, provider_id, user_id, access_token, refresh_token, id_token, access_token_expires_at, refresh_token_expires_at, scope, password, created_at, updated_at)
418 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
419 RETURNING *
420 "#,
421 )
422 .bind(&id)
423 .bind(&create_account.account_id)
424 .bind(&create_account.provider_id)
425 .bind(&create_account.user_id)
426 .bind(&create_account.access_token)
427 .bind(&create_account.refresh_token)
428 .bind(&create_account.id_token)
429 .bind(&create_account.access_token_expires_at)
430 .bind(&create_account.refresh_token_expires_at)
431 .bind(&create_account.scope)
432 .bind(&create_account.password)
433 .bind(&now)
434 .bind(&now)
435 .fetch_one(&self.pool)
436 .await?;
437
438 Ok(account)
439 }
440
441 async fn get_account(
442 &self,
443 provider: &str,
444 provider_account_id: &str,
445 ) -> AuthResult<Option<A>> {
446 let account = sqlx::query_as::<_, A>(
447 "SELECT * FROM accounts WHERE provider_id = $1 AND account_id = $2",
448 )
449 .bind(provider)
450 .bind(provider_account_id)
451 .fetch_optional(&self.pool)
452 .await?;
453 Ok(account)
454 }
455
456 async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<A>> {
457 let accounts = sqlx::query_as::<_, A>(
458 "SELECT * FROM accounts WHERE user_id = $1 ORDER BY created_at DESC",
459 )
460 .bind(user_id)
461 .fetch_all(&self.pool)
462 .await?;
463 Ok(accounts)
464 }
465
466 async fn delete_account(&self, id: &str) -> AuthResult<()> {
467 sqlx::query("DELETE FROM accounts WHERE id = $1")
468 .bind(id)
469 .execute(&self.pool)
470 .await?;
471 Ok(())
472 }
473 }
474
475 #[async_trait]
478 impl<U, S, A, O, M, I, V> VerificationOps for SqlxAdapter<U, S, A, O, M, I, V>
479 where
480 U: AuthUser + SqlxEntity,
481 S: AuthSession + SqlxEntity,
482 A: AuthAccount + SqlxEntity,
483 O: AuthOrganization + SqlxEntity,
484 M: AuthMember + SqlxEntity,
485 I: AuthInvitation + SqlxEntity,
486 V: AuthVerification + SqlxEntity,
487 {
488 type Verification = V;
489
490 async fn create_verification(
491 &self,
492 create_verification: CreateVerification,
493 ) -> AuthResult<V> {
494 let id = Uuid::new_v4().to_string();
495 let now = Utc::now();
496
497 let verification = sqlx::query_as::<_, V>(
498 r#"
499 INSERT INTO verifications (id, identifier, value, expires_at, created_at, updated_at)
500 VALUES ($1, $2, $3, $4, $5, $6)
501 RETURNING *
502 "#,
503 )
504 .bind(&id)
505 .bind(&create_verification.identifier)
506 .bind(&create_verification.value)
507 .bind(&create_verification.expires_at)
508 .bind(&now)
509 .bind(&now)
510 .fetch_one(&self.pool)
511 .await?;
512
513 Ok(verification)
514 }
515
516 async fn get_verification(&self, identifier: &str, value: &str) -> AuthResult<Option<V>> {
517 let verification = sqlx::query_as::<_, V>(
518 "SELECT * FROM verifications WHERE identifier = $1 AND value = $2 AND expires_at > NOW()",
519 )
520 .bind(identifier)
521 .bind(value)
522 .fetch_optional(&self.pool)
523 .await?;
524 Ok(verification)
525 }
526
527 async fn get_verification_by_value(&self, value: &str) -> AuthResult<Option<V>> {
528 let verification = sqlx::query_as::<_, V>(
529 "SELECT * FROM verifications WHERE value = $1 AND expires_at > NOW()",
530 )
531 .bind(value)
532 .fetch_optional(&self.pool)
533 .await?;
534 Ok(verification)
535 }
536
537 async fn delete_verification(&self, id: &str) -> AuthResult<()> {
538 sqlx::query("DELETE FROM verifications WHERE id = $1")
539 .bind(id)
540 .execute(&self.pool)
541 .await?;
542 Ok(())
543 }
544
545 async fn delete_expired_verifications(&self) -> AuthResult<usize> {
546 let result = sqlx::query("DELETE FROM verifications WHERE expires_at < NOW()")
547 .execute(&self.pool)
548 .await?;
549 Ok(result.rows_affected() as usize)
550 }
551 }
552
553 #[async_trait]
556 impl<U, S, A, O, M, I, V> OrganizationOps for SqlxAdapter<U, S, A, O, M, I, V>
557 where
558 U: AuthUser + SqlxEntity,
559 S: AuthSession + SqlxEntity,
560 A: AuthAccount + SqlxEntity,
561 O: AuthOrganization + SqlxEntity,
562 M: AuthMember + SqlxEntity,
563 I: AuthInvitation + SqlxEntity,
564 V: AuthVerification + SqlxEntity,
565 {
566 type Organization = O;
567
568 async fn create_organization(&self, create_org: CreateOrganization) -> AuthResult<O> {
569 let id = create_org.id.unwrap_or_else(|| Uuid::new_v4().to_string());
570 let now = Utc::now();
571
572 let organization = sqlx::query_as::<_, O>(
573 r#"
574 INSERT INTO organization (id, name, slug, logo, metadata, created_at, updated_at)
575 VALUES ($1, $2, $3, $4, $5, $6, $7)
576 RETURNING *
577 "#,
578 )
579 .bind(&id)
580 .bind(&create_org.name)
581 .bind(&create_org.slug)
582 .bind(&create_org.logo)
583 .bind(sqlx::types::Json(
584 create_org.metadata.unwrap_or(serde_json::json!({})),
585 ))
586 .bind(&now)
587 .bind(&now)
588 .fetch_one(&self.pool)
589 .await?;
590
591 Ok(organization)
592 }
593
594 async fn get_organization_by_id(&self, id: &str) -> AuthResult<Option<O>> {
595 let organization = sqlx::query_as::<_, O>("SELECT * FROM organization WHERE id = $1")
596 .bind(id)
597 .fetch_optional(&self.pool)
598 .await?;
599 Ok(organization)
600 }
601
602 async fn get_organization_by_slug(&self, slug: &str) -> AuthResult<Option<O>> {
603 let organization = sqlx::query_as::<_, O>("SELECT * FROM organization WHERE slug = $1")
604 .bind(slug)
605 .fetch_optional(&self.pool)
606 .await?;
607 Ok(organization)
608 }
609
610 async fn update_organization(&self, id: &str, update: UpdateOrganization) -> AuthResult<O> {
611 let mut query = sqlx::QueryBuilder::new("UPDATE organization SET updated_at = NOW()");
612
613 if let Some(name) = &update.name {
614 query.push(", name = ");
615 query.push_bind(name);
616 }
617 if let Some(slug) = &update.slug {
618 query.push(", slug = ");
619 query.push_bind(slug);
620 }
621 if let Some(logo) = &update.logo {
622 query.push(", logo = ");
623 query.push_bind(logo);
624 }
625 if let Some(metadata) = &update.metadata {
626 query.push(", metadata = ");
627 query.push_bind(sqlx::types::Json(metadata.clone()));
628 }
629
630 query.push(" WHERE id = ");
631 query.push_bind(id);
632 query.push(" RETURNING *");
633
634 let organization = query.build_query_as::<O>().fetch_one(&self.pool).await?;
635 Ok(organization)
636 }
637
638 async fn delete_organization(&self, id: &str) -> AuthResult<()> {
639 sqlx::query("DELETE FROM organization WHERE id = $1")
640 .bind(id)
641 .execute(&self.pool)
642 .await?;
643 Ok(())
644 }
645
646 async fn list_user_organizations(&self, user_id: &str) -> AuthResult<Vec<O>> {
647 let organizations = sqlx::query_as::<_, O>(
648 r#"
649 SELECT o.*
650 FROM organization o
651 INNER JOIN member m ON o.id = m.organization_id
652 WHERE m.user_id = $1
653 ORDER BY o.created_at DESC
654 "#,
655 )
656 .bind(user_id)
657 .fetch_all(&self.pool)
658 .await?;
659 Ok(organizations)
660 }
661 }
662
663 #[async_trait]
666 impl<U, S, A, O, M, I, V> MemberOps for SqlxAdapter<U, S, A, O, M, I, V>
667 where
668 U: AuthUser + SqlxEntity,
669 S: AuthSession + SqlxEntity,
670 A: AuthAccount + SqlxEntity,
671 O: AuthOrganization + SqlxEntity,
672 M: AuthMember + SqlxEntity,
673 I: AuthInvitation + SqlxEntity,
674 V: AuthVerification + SqlxEntity,
675 {
676 type Member = M;
677
678 async fn create_member(&self, create_member: CreateMember) -> AuthResult<M> {
679 let id = Uuid::new_v4().to_string();
680 let now = Utc::now();
681
682 let member = sqlx::query_as::<_, M>(
683 r#"
684 INSERT INTO member (id, organization_id, user_id, role, created_at)
685 VALUES ($1, $2, $3, $4, $5)
686 RETURNING *
687 "#,
688 )
689 .bind(&id)
690 .bind(&create_member.organization_id)
691 .bind(&create_member.user_id)
692 .bind(&create_member.role)
693 .bind(&now)
694 .fetch_one(&self.pool)
695 .await?;
696
697 Ok(member)
698 }
699
700 async fn get_member(&self, organization_id: &str, user_id: &str) -> AuthResult<Option<M>> {
701 let member = sqlx::query_as::<_, M>(
702 "SELECT * FROM member WHERE organization_id = $1 AND user_id = $2",
703 )
704 .bind(organization_id)
705 .bind(user_id)
706 .fetch_optional(&self.pool)
707 .await?;
708 Ok(member)
709 }
710
711 async fn get_member_by_id(&self, id: &str) -> AuthResult<Option<M>> {
712 let member = sqlx::query_as::<_, M>("SELECT * FROM member WHERE id = $1")
713 .bind(id)
714 .fetch_optional(&self.pool)
715 .await?;
716 Ok(member)
717 }
718
719 async fn update_member_role(&self, member_id: &str, role: &str) -> AuthResult<M> {
720 let member =
721 sqlx::query_as::<_, M>("UPDATE member SET role = $1 WHERE id = $2 RETURNING *")
722 .bind(role)
723 .bind(member_id)
724 .fetch_one(&self.pool)
725 .await?;
726 Ok(member)
727 }
728
729 async fn delete_member(&self, member_id: &str) -> AuthResult<()> {
730 sqlx::query("DELETE FROM member WHERE id = $1")
731 .bind(member_id)
732 .execute(&self.pool)
733 .await?;
734 Ok(())
735 }
736
737 async fn list_organization_members(&self, organization_id: &str) -> AuthResult<Vec<M>> {
738 let members = sqlx::query_as::<_, M>(
739 "SELECT * FROM member WHERE organization_id = $1 ORDER BY created_at ASC",
740 )
741 .bind(organization_id)
742 .fetch_all(&self.pool)
743 .await?;
744 Ok(members)
745 }
746
747 async fn count_organization_members(&self, organization_id: &str) -> AuthResult<usize> {
748 let count: (i64,) =
749 sqlx::query_as("SELECT COUNT(*) FROM member WHERE organization_id = $1")
750 .bind(organization_id)
751 .fetch_one(&self.pool)
752 .await?;
753 Ok(count.0 as usize)
754 }
755
756 async fn count_organization_owners(&self, organization_id: &str) -> AuthResult<usize> {
757 let count: (i64,) = sqlx::query_as(
758 "SELECT COUNT(*) FROM member WHERE organization_id = $1 AND role = 'owner'",
759 )
760 .bind(organization_id)
761 .fetch_one(&self.pool)
762 .await?;
763 Ok(count.0 as usize)
764 }
765 }
766
767 #[async_trait]
770 impl<U, S, A, O, M, I, V> InvitationOps for SqlxAdapter<U, S, A, O, M, I, V>
771 where
772 U: AuthUser + SqlxEntity,
773 S: AuthSession + SqlxEntity,
774 A: AuthAccount + SqlxEntity,
775 O: AuthOrganization + SqlxEntity,
776 M: AuthMember + SqlxEntity,
777 I: AuthInvitation + SqlxEntity,
778 V: AuthVerification + SqlxEntity,
779 {
780 type Invitation = I;
781
782 async fn create_invitation(&self, create_inv: CreateInvitation) -> AuthResult<I> {
783 let id = Uuid::new_v4().to_string();
784 let now = Utc::now();
785
786 let invitation = sqlx::query_as::<_, I>(
787 r#"
788 INSERT INTO invitation (id, organization_id, email, role, status, inviter_id, expires_at, created_at)
789 VALUES ($1, $2, $3, $4, 'pending', $5, $6, $7)
790 RETURNING *
791 "#,
792 )
793 .bind(&id)
794 .bind(&create_inv.organization_id)
795 .bind(&create_inv.email)
796 .bind(&create_inv.role)
797 .bind(&create_inv.inviter_id)
798 .bind(&create_inv.expires_at)
799 .bind(&now)
800 .fetch_one(&self.pool)
801 .await?;
802
803 Ok(invitation)
804 }
805
806 async fn get_invitation_by_id(&self, id: &str) -> AuthResult<Option<I>> {
807 let invitation = sqlx::query_as::<_, I>("SELECT * FROM invitation WHERE id = $1")
808 .bind(id)
809 .fetch_optional(&self.pool)
810 .await?;
811 Ok(invitation)
812 }
813
814 async fn get_pending_invitation(
815 &self,
816 organization_id: &str,
817 email: &str,
818 ) -> AuthResult<Option<I>> {
819 let invitation = sqlx::query_as::<_, I>(
820 "SELECT * FROM invitation WHERE organization_id = $1 AND LOWER(email) = LOWER($2) AND status = 'pending'",
821 )
822 .bind(organization_id)
823 .bind(email)
824 .fetch_optional(&self.pool)
825 .await?;
826 Ok(invitation)
827 }
828
829 async fn update_invitation_status(
830 &self,
831 id: &str,
832 status: InvitationStatus,
833 ) -> AuthResult<I> {
834 let invitation = sqlx::query_as::<_, I>(
835 "UPDATE invitation SET status = $1 WHERE id = $2 RETURNING *",
836 )
837 .bind(status.to_string())
838 .bind(id)
839 .fetch_one(&self.pool)
840 .await?;
841 Ok(invitation)
842 }
843
844 async fn list_organization_invitations(&self, organization_id: &str) -> AuthResult<Vec<I>> {
845 let invitations = sqlx::query_as::<_, I>(
846 "SELECT * FROM invitation WHERE organization_id = $1 ORDER BY created_at DESC",
847 )
848 .bind(organization_id)
849 .fetch_all(&self.pool)
850 .await?;
851 Ok(invitations)
852 }
853
854 async fn list_user_invitations(&self, email: &str) -> AuthResult<Vec<I>> {
855 let invitations = sqlx::query_as::<_, I>(
856 "SELECT * FROM invitation WHERE LOWER(email) = LOWER($1) AND status = 'pending' AND expires_at > NOW() ORDER BY created_at DESC",
857 )
858 .bind(email)
859 .fetch_all(&self.pool)
860 .await?;
861 Ok(invitations)
862 }
863 }
864}
865
866#[cfg(feature = "sqlx-postgres")]
867pub use sqlx_adapter::{SqlxAdapter, SqlxEntity};