1pub use super::traits::{
2 AccountOps, ApiKeyOps, InvitationOps, MemberOps, OrganizationOps, PasskeyOps, SessionOps,
3 TwoFactorOps, UserOps, VerificationOps,
4};
5
6pub trait DatabaseAdapter:
15 UserOps
16 + SessionOps
17 + AccountOps
18 + VerificationOps
19 + OrganizationOps
20 + MemberOps
21 + InvitationOps
22 + TwoFactorOps
23 + ApiKeyOps
24 + PasskeyOps
25{
26}
27
28impl<T> DatabaseAdapter for T where
29 T: UserOps
30 + SessionOps
31 + AccountOps
32 + VerificationOps
33 + OrganizationOps
34 + MemberOps
35 + InvitationOps
36 + TwoFactorOps
37 + ApiKeyOps
38 + PasskeyOps
39{
40}
41
42#[cfg(feature = "sqlx-postgres")]
43pub mod sqlx_adapter {
44 use super::*;
45 use async_trait::async_trait;
46 use chrono::{DateTime, Utc};
47
48 use crate::entity::{
49 AuthAccount, AuthAccountMeta, AuthApiKey, AuthApiKeyMeta, AuthInvitation,
50 AuthInvitationMeta, AuthMember, AuthMemberMeta, AuthOrganization, AuthOrganizationMeta,
51 AuthPasskey, AuthPasskeyMeta, AuthSession, AuthSessionMeta, AuthTwoFactor,
52 AuthTwoFactorMeta, AuthUser, AuthUserMeta, AuthVerification, AuthVerificationMeta,
53 };
54 use crate::error::{AuthError, AuthResult};
55 use crate::types::{
56 Account, ApiKey, CreateAccount, CreateApiKey, CreateInvitation, CreateMember,
57 CreateOrganization, CreatePasskey, CreateSession, CreateTwoFactor, CreateUser,
58 CreateVerification, Invitation, InvitationStatus, ListUsersParams, Member, Organization,
59 Passkey, Session, TwoFactor, UpdateAccount, UpdateApiKey, UpdateOrganization, UpdateUser,
60 User, Verification,
61 };
62 use sqlx::PgPool;
63 use sqlx::postgres::PgRow;
64 use std::marker::PhantomData;
65 use uuid::Uuid;
66
67 #[inline]
73 fn qi(ident: &str) -> String {
74 format!("\"{}\"", ident.replace('"', "\"\""))
75 }
76
77 pub trait SqlxEntity:
84 for<'r> sqlx::FromRow<'r, PgRow> + Send + Sync + Unpin + Clone + 'static
85 {
86 }
87
88 impl<T> SqlxEntity for T where
89 T: for<'r> sqlx::FromRow<'r, PgRow> + Send + Sync + Unpin + Clone + 'static
90 {
91 }
92
93 type SqlxAdapterEntities<U, S, A, O, M, I, V, TF, AK, PK> = (U, S, A, O, M, I, V, TF, AK, PK);
94
95 pub struct SqlxAdapter<
100 U = User,
101 S = Session,
102 A = Account,
103 O = Organization,
104 M = Member,
105 I = Invitation,
106 V = Verification,
107 TF = TwoFactor,
108 AK = ApiKey,
109 PK = Passkey,
110 > {
111 pool: PgPool,
112 #[allow(clippy::type_complexity)]
113 _phantom: PhantomData<SqlxAdapterEntities<U, S, A, O, M, I, V, TF, AK, PK>>,
114 }
115
116 impl SqlxAdapter {
118 pub async fn new(database_url: &str) -> Result<Self, sqlx::Error> {
119 let pool = PgPool::connect(database_url).await?;
120 Ok(Self {
121 pool,
122 _phantom: PhantomData,
123 })
124 }
125
126 pub async fn with_config(
127 database_url: &str,
128 config: PoolConfig,
129 ) -> Result<Self, sqlx::Error> {
130 let pool = sqlx::postgres::PgPoolOptions::new()
131 .max_connections(config.max_connections)
132 .min_connections(config.min_connections)
133 .acquire_timeout(config.acquire_timeout)
134 .idle_timeout(config.idle_timeout)
135 .max_lifetime(config.max_lifetime)
136 .connect(database_url)
137 .await?;
138 Ok(Self {
139 pool,
140 _phantom: PhantomData,
141 })
142 }
143 }
144
145 impl<U, S, A, O, M, I, V, TF, AK, PK> SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK> {
147 pub fn from_pool(pool: PgPool) -> Self {
148 Self {
149 pool,
150 _phantom: PhantomData,
151 }
152 }
153
154 pub async fn test_connection(&self) -> Result<(), sqlx::Error> {
155 sqlx::query("SELECT 1").execute(&self.pool).await?;
156 Ok(())
157 }
158
159 pub fn pool_stats(&self) -> PoolStats {
160 PoolStats {
161 size: self.pool.size(),
162 idle: self.pool.num_idle(),
163 }
164 }
165
166 pub async fn close(&self) {
167 self.pool.close().await;
168 }
169 }
170
171 #[derive(Debug, Clone)]
172 pub struct PoolConfig {
173 pub max_connections: u32,
174 pub min_connections: u32,
175 pub acquire_timeout: std::time::Duration,
176 pub idle_timeout: Option<std::time::Duration>,
177 pub max_lifetime: Option<std::time::Duration>,
178 }
179
180 impl Default for PoolConfig {
181 fn default() -> Self {
182 Self {
183 max_connections: 10,
184 min_connections: 0,
185 acquire_timeout: std::time::Duration::from_secs(30),
186 idle_timeout: Some(std::time::Duration::from_secs(600)),
187 max_lifetime: Some(std::time::Duration::from_secs(1800)),
188 }
189 }
190 }
191
192 #[derive(Debug, Clone)]
193 pub struct PoolStats {
194 pub size: u32,
195 pub idle: usize,
196 }
197
198 #[async_trait]
201 impl<U, S, A, O, M, I, V, TF, AK, PK> UserOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
202 where
203 U: AuthUser + AuthUserMeta + SqlxEntity,
204 S: AuthSession + AuthSessionMeta + SqlxEntity,
205 A: AuthAccount + AuthAccountMeta + SqlxEntity,
206 O: AuthOrganization + AuthOrganizationMeta + SqlxEntity,
207 M: AuthMember + AuthMemberMeta + SqlxEntity,
208 I: AuthInvitation + AuthInvitationMeta + SqlxEntity,
209 V: AuthVerification + AuthVerificationMeta + SqlxEntity,
210 TF: AuthTwoFactor + AuthTwoFactorMeta + SqlxEntity,
211 AK: AuthApiKey + AuthApiKeyMeta + SqlxEntity,
212 PK: AuthPasskey + AuthPasskeyMeta + SqlxEntity,
213 {
214 type User = U;
215
216 async fn create_user(&self, create_user: CreateUser) -> AuthResult<U> {
217 let id = create_user.id.unwrap_or_else(|| Uuid::new_v4().to_string());
218 let now = Utc::now();
219
220 let sql = format!(
221 "INSERT INTO {} ({}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING *",
222 qi(U::table()),
223 qi(U::col_id()),
224 qi(U::col_email()),
225 qi(U::col_name()),
226 qi(U::col_image()),
227 qi(U::col_email_verified()),
228 qi(U::col_username()),
229 qi(U::col_display_username()),
230 qi(U::col_role()),
231 qi(U::col_created_at()),
232 qi(U::col_updated_at()),
233 qi(U::col_metadata()),
234 );
235 let user = sqlx::query_as::<_, U>(&sql)
236 .bind(&id)
237 .bind(&create_user.email)
238 .bind(&create_user.name)
239 .bind(&create_user.image)
240 .bind(create_user.email_verified.unwrap_or(false))
241 .bind(&create_user.username)
242 .bind(&create_user.display_username)
243 .bind(&create_user.role)
244 .bind(now)
245 .bind(now)
246 .bind(sqlx::types::Json(
247 create_user.metadata.unwrap_or(serde_json::json!({})),
248 ))
249 .fetch_one(&self.pool)
250 .await?;
251
252 Ok(user)
253 }
254
255 async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<U>> {
256 let sql = format!(
257 "SELECT * FROM {} WHERE {} = $1",
258 qi(U::table()),
259 qi(U::col_id())
260 );
261 let user = sqlx::query_as::<_, U>(&sql)
262 .bind(id)
263 .fetch_optional(&self.pool)
264 .await?;
265 Ok(user)
266 }
267
268 async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<U>> {
269 let sql = format!(
270 "SELECT * FROM {} WHERE {} = $1",
271 qi(U::table()),
272 qi(U::col_email())
273 );
274 let user = sqlx::query_as::<_, U>(&sql)
275 .bind(email)
276 .fetch_optional(&self.pool)
277 .await?;
278 Ok(user)
279 }
280
281 async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<U>> {
282 let sql = format!(
283 "SELECT * FROM {} WHERE {} = $1",
284 qi(U::table()),
285 qi(U::col_username())
286 );
287 let user = sqlx::query_as::<_, U>(&sql)
288 .bind(username)
289 .fetch_optional(&self.pool)
290 .await?;
291 Ok(user)
292 }
293
294 async fn update_user(&self, id: &str, update: UpdateUser) -> AuthResult<U> {
295 let mut query = sqlx::QueryBuilder::new(format!(
296 "UPDATE {} SET {} = NOW()",
297 qi(U::table()),
298 qi(U::col_updated_at())
299 ));
300 let mut has_updates = false;
301
302 if let Some(email) = &update.email {
303 query.push(format!(", {} = ", qi(U::col_email())));
304 query.push_bind(email);
305 has_updates = true;
306 }
307 if let Some(name) = &update.name {
308 query.push(format!(", {} = ", qi(U::col_name())));
309 query.push_bind(name);
310 has_updates = true;
311 }
312 if let Some(image) = &update.image {
313 query.push(format!(", {} = ", qi(U::col_image())));
314 query.push_bind(image);
315 has_updates = true;
316 }
317 if let Some(email_verified) = update.email_verified {
318 query.push(format!(", {} = ", qi(U::col_email_verified())));
319 query.push_bind(email_verified);
320 has_updates = true;
321 }
322 if let Some(username) = &update.username {
323 query.push(format!(", {} = ", qi(U::col_username())));
324 query.push_bind(username);
325 has_updates = true;
326 }
327 if let Some(display_username) = &update.display_username {
328 query.push(format!(", {} = ", qi(U::col_display_username())));
329 query.push_bind(display_username);
330 has_updates = true;
331 }
332 if let Some(role) = &update.role {
333 query.push(format!(", {} = ", qi(U::col_role())));
334 query.push_bind(role);
335 has_updates = true;
336 }
337 if let Some(banned) = update.banned {
338 query.push(format!(", {} = ", qi(U::col_banned())));
339 query.push_bind(banned);
340 has_updates = true;
341 if !banned {
343 query.push(format!(
344 ", {} = NULL, {} = NULL",
345 qi(U::col_ban_reason()),
346 qi(U::col_ban_expires())
347 ));
348 }
349 }
350 if update.banned != Some(false) {
355 if let Some(ban_reason) = &update.ban_reason {
356 query.push(format!(", {} = ", qi(U::col_ban_reason())));
357 query.push_bind(ban_reason);
358 has_updates = true;
359 }
360 if let Some(ban_expires) = update.ban_expires {
361 query.push(format!(", {} = ", qi(U::col_ban_expires())));
362 query.push_bind(ban_expires);
363 has_updates = true;
364 }
365 }
366 if let Some(two_factor_enabled) = update.two_factor_enabled {
367 query.push(format!(", {} = ", qi(U::col_two_factor_enabled())));
368 query.push_bind(two_factor_enabled);
369 has_updates = true;
370 }
371 if let Some(metadata) = &update.metadata {
372 query.push(format!(", {} = ", qi(U::col_metadata())));
373 query.push_bind(sqlx::types::Json(metadata.clone()));
374 has_updates = true;
375 }
376
377 if !has_updates {
378 return self
379 .get_user_by_id(id)
380 .await?
381 .ok_or(AuthError::UserNotFound);
382 }
383
384 query.push(format!(" WHERE {} = ", qi(U::col_id())));
385 query.push_bind(id);
386 query.push(" RETURNING *");
387
388 let user = query.build_query_as::<U>().fetch_one(&self.pool).await?;
389 Ok(user)
390 }
391
392 async fn delete_user(&self, id: &str) -> AuthResult<()> {
393 let sql = format!(
394 "DELETE FROM {} WHERE {} = $1",
395 qi(U::table()),
396 qi(U::col_id())
397 );
398 sqlx::query(&sql).bind(id).execute(&self.pool).await?;
399 Ok(())
400 }
401
402 async fn list_users(&self, params: ListUsersParams) -> AuthResult<(Vec<U>, usize)> {
403 let limit = params.limit.unwrap_or(100) as i64;
404 let offset = params.offset.unwrap_or(0) as i64;
405
406 let mut conditions: Vec<String> = Vec::new();
408 let mut bind_values: Vec<String> = Vec::new();
409
410 if let Some(search_value) = ¶ms.search_value {
411 let field = params.search_field.as_deref().unwrap_or("email");
412 let col = qi(match field {
413 "name" => U::col_name(),
414 _ => U::col_email(),
415 });
416 let op = params.search_operator.as_deref().unwrap_or("contains");
417 let escaped = search_value.replace('%', "\\%").replace('_', "\\_");
418 let pattern = match op {
419 "starts_with" => format!("{}%", escaped),
420 "ends_with" => format!("%{}", escaped),
421 _ => format!("%{}%", escaped),
422 };
423 let idx = bind_values.len() + 1;
424 conditions.push(format!("{} ILIKE ${}", col, idx));
425 bind_values.push(pattern);
426 }
427
428 if let Some(filter_value) = ¶ms.filter_value {
429 let field = params.filter_field.as_deref().unwrap_or("email");
430 let col = qi(match field {
431 "name" => U::col_name(),
432 "role" => U::col_role(),
433 _ => U::col_email(),
434 });
435 let op = params.filter_operator.as_deref().unwrap_or("eq");
436 let idx = bind_values.len() + 1;
437 match op {
438 "contains" => {
439 let escaped = filter_value.replace('%', "\\%").replace('_', "\\_");
440 conditions.push(format!("{} ILIKE ${}", col, idx));
441 bind_values.push(format!("%{}%", escaped));
442 }
443 "ne" => {
444 conditions.push(format!("{} != ${}", col, idx));
445 bind_values.push(filter_value.clone());
446 }
447 _ => {
448 conditions.push(format!("{} = ${}", col, idx));
449 bind_values.push(filter_value.clone());
450 }
451 }
452 }
453
454 let where_clause = if conditions.is_empty() {
455 String::new()
456 } else {
457 format!(" WHERE {}", conditions.join(" AND "))
458 };
459
460 let order_clause = if let Some(sort_by) = ¶ms.sort_by {
462 let col = qi(match sort_by.as_str() {
463 "name" => U::col_name(),
464 "createdAt" | "created_at" => U::col_created_at(),
465 _ => U::col_email(),
466 });
467 let dir = if params.sort_direction.as_deref() == Some("desc") {
468 "DESC"
469 } else {
470 "ASC"
471 };
472 format!(" ORDER BY {} {}", col, dir)
473 } else {
474 format!(" ORDER BY {} DESC", qi(U::col_created_at()))
475 };
476
477 let count_sql = format!(
479 "SELECT COUNT(*) as count FROM {}{}",
480 qi(U::table()),
481 where_clause
482 );
483 let mut count_query = sqlx::query_scalar::<_, i64>(&count_sql);
484 for v in &bind_values {
485 count_query = count_query.bind(v);
486 }
487 let total = count_query.fetch_one(&self.pool).await? as usize;
488
489 let limit_idx = bind_values.len() + 1;
491 let offset_idx = bind_values.len() + 2;
492 let data_sql = format!(
493 "SELECT * FROM {}{}{} LIMIT ${} OFFSET ${}",
494 qi(U::table()),
495 where_clause,
496 order_clause,
497 limit_idx,
498 offset_idx
499 );
500 let mut data_query = sqlx::query_as::<_, U>(&data_sql);
501 for v in &bind_values {
502 data_query = data_query.bind(v);
503 }
504 data_query = data_query.bind(limit).bind(offset);
505 let users = data_query.fetch_all(&self.pool).await?;
506
507 Ok((users, total))
508 }
509 }
510
511 #[async_trait]
514 impl<U, S, A, O, M, I, V, TF, AK, PK> SessionOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
515 where
516 U: AuthUser + AuthUserMeta + SqlxEntity,
517 S: AuthSession + AuthSessionMeta + SqlxEntity,
518 A: AuthAccount + AuthAccountMeta + SqlxEntity,
519 O: AuthOrganization + AuthOrganizationMeta + SqlxEntity,
520 M: AuthMember + AuthMemberMeta + SqlxEntity,
521 I: AuthInvitation + AuthInvitationMeta + SqlxEntity,
522 V: AuthVerification + AuthVerificationMeta + SqlxEntity,
523 TF: AuthTwoFactor + AuthTwoFactorMeta + SqlxEntity,
524 AK: AuthApiKey + AuthApiKeyMeta + SqlxEntity,
525 PK: AuthPasskey + AuthPasskeyMeta + SqlxEntity,
526 {
527 type Session = S;
528
529 async fn create_session(&self, create_session: CreateSession) -> AuthResult<S> {
530 let id = Uuid::new_v4().to_string();
531 let token = format!("session_{}", Uuid::new_v4());
532 let now = Utc::now();
533
534 let sql = format!(
535 "INSERT INTO {} ({}, {}, {}, {}, {}, {}, {}, {}, {}, {}) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING *",
536 qi(S::table()),
537 qi(S::col_id()),
538 qi(S::col_user_id()),
539 qi(S::col_token()),
540 qi(S::col_expires_at()),
541 qi(S::col_created_at()),
542 qi(S::col_ip_address()),
543 qi(S::col_user_agent()),
544 qi(S::col_impersonated_by()),
545 qi(S::col_active_organization_id()),
546 qi(S::col_active()),
547 );
548 let session = sqlx::query_as::<_, S>(&sql)
549 .bind(&id)
550 .bind(&create_session.user_id)
551 .bind(&token)
552 .bind(create_session.expires_at)
553 .bind(now)
554 .bind(&create_session.ip_address)
555 .bind(&create_session.user_agent)
556 .bind(&create_session.impersonated_by)
557 .bind(&create_session.active_organization_id)
558 .bind(true)
559 .fetch_one(&self.pool)
560 .await?;
561
562 Ok(session)
563 }
564
565 async fn get_session(&self, token: &str) -> AuthResult<Option<S>> {
566 let sql = format!(
567 "SELECT * FROM {} WHERE {} = $1 AND {} = true",
568 qi(S::table()),
569 qi(S::col_token()),
570 qi(S::col_active())
571 );
572 let session = sqlx::query_as::<_, S>(&sql)
573 .bind(token)
574 .fetch_optional(&self.pool)
575 .await?;
576 Ok(session)
577 }
578
579 async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<S>> {
580 let sql = format!(
581 "SELECT * FROM {} WHERE {} = $1 AND {} = true ORDER BY {} DESC",
582 qi(S::table()),
583 qi(S::col_user_id()),
584 qi(S::col_active()),
585 qi(S::col_created_at())
586 );
587 let sessions = sqlx::query_as::<_, S>(&sql)
588 .bind(user_id)
589 .fetch_all(&self.pool)
590 .await?;
591 Ok(sessions)
592 }
593
594 async fn update_session_expiry(
595 &self,
596 token: &str,
597 expires_at: DateTime<Utc>,
598 ) -> AuthResult<()> {
599 let sql = format!(
600 "UPDATE {} SET {} = $1, {} = $2 WHERE {} = $3 AND {} = true",
601 qi(S::table()),
602 qi(S::col_expires_at()),
603 qi(S::col_updated_at()),
604 qi(S::col_token()),
605 qi(S::col_active())
606 );
607 sqlx::query(&sql)
608 .bind(expires_at)
609 .bind(Utc::now())
610 .bind(token)
611 .execute(&self.pool)
612 .await?;
613 Ok(())
614 }
615
616 async fn delete_session(&self, token: &str) -> AuthResult<()> {
617 let sql = format!(
618 "DELETE FROM {} WHERE {} = $1",
619 qi(S::table()),
620 qi(S::col_token())
621 );
622 sqlx::query(&sql).bind(token).execute(&self.pool).await?;
623 Ok(())
624 }
625
626 async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
627 let sql = format!(
628 "DELETE FROM {} WHERE {} = $1",
629 qi(S::table()),
630 qi(S::col_user_id())
631 );
632 sqlx::query(&sql).bind(user_id).execute(&self.pool).await?;
633 Ok(())
634 }
635
636 async fn delete_expired_sessions(&self) -> AuthResult<usize> {
637 let sql = format!(
638 "DELETE FROM {} WHERE {} < NOW() OR {} = false",
639 qi(S::table()),
640 qi(S::col_expires_at()),
641 qi(S::col_active())
642 );
643 let result = sqlx::query(&sql).execute(&self.pool).await?;
644 Ok(result.rows_affected() as usize)
645 }
646
647 async fn update_session_active_organization(
648 &self,
649 token: &str,
650 organization_id: Option<&str>,
651 ) -> AuthResult<S> {
652 let sql = format!(
653 "UPDATE {} SET {} = $1, {} = NOW() WHERE {} = $2 AND {} = true RETURNING *",
654 qi(S::table()),
655 qi(S::col_active_organization_id()),
656 qi(S::col_updated_at()),
657 qi(S::col_token()),
658 qi(S::col_active())
659 );
660 let session = sqlx::query_as::<_, S>(&sql)
661 .bind(organization_id)
662 .bind(token)
663 .fetch_one(&self.pool)
664 .await?;
665 Ok(session)
666 }
667 }
668
669 #[async_trait]
672 impl<U, S, A, O, M, I, V, TF, AK, PK> AccountOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
673 where
674 U: AuthUser + AuthUserMeta + SqlxEntity,
675 S: AuthSession + AuthSessionMeta + SqlxEntity,
676 A: AuthAccount + AuthAccountMeta + SqlxEntity,
677 O: AuthOrganization + AuthOrganizationMeta + SqlxEntity,
678 M: AuthMember + AuthMemberMeta + SqlxEntity,
679 I: AuthInvitation + AuthInvitationMeta + SqlxEntity,
680 V: AuthVerification + AuthVerificationMeta + SqlxEntity,
681 TF: AuthTwoFactor + AuthTwoFactorMeta + SqlxEntity,
682 AK: AuthApiKey + AuthApiKeyMeta + SqlxEntity,
683 PK: AuthPasskey + AuthPasskeyMeta + SqlxEntity,
684 {
685 type Account = A;
686
687 async fn create_account(&self, create_account: CreateAccount) -> AuthResult<A> {
688 let id = Uuid::new_v4().to_string();
689 let now = Utc::now();
690
691 let sql = format!(
692 "INSERT INTO {} ({}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}) \
693 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) RETURNING *",
694 qi(A::table()),
695 qi(A::col_id()),
696 qi(A::col_account_id()),
697 qi(A::col_provider_id()),
698 qi(A::col_user_id()),
699 qi(A::col_access_token()),
700 qi(A::col_refresh_token()),
701 qi(A::col_id_token()),
702 qi(A::col_access_token_expires_at()),
703 qi(A::col_refresh_token_expires_at()),
704 qi(A::col_scope()),
705 qi(A::col_password()),
706 qi(A::col_created_at()),
707 qi(A::col_updated_at()),
708 );
709 let account = sqlx::query_as::<_, A>(&sql)
710 .bind(&id)
711 .bind(&create_account.account_id)
712 .bind(&create_account.provider_id)
713 .bind(&create_account.user_id)
714 .bind(&create_account.access_token)
715 .bind(&create_account.refresh_token)
716 .bind(&create_account.id_token)
717 .bind(create_account.access_token_expires_at)
718 .bind(create_account.refresh_token_expires_at)
719 .bind(&create_account.scope)
720 .bind(&create_account.password)
721 .bind(now)
722 .bind(now)
723 .fetch_one(&self.pool)
724 .await?;
725
726 Ok(account)
727 }
728
729 async fn get_account(
730 &self,
731 provider: &str,
732 provider_account_id: &str,
733 ) -> AuthResult<Option<A>> {
734 let sql = format!(
735 "SELECT * FROM {} WHERE {} = $1 AND {} = $2",
736 qi(A::table()),
737 qi(A::col_provider_id()),
738 qi(A::col_account_id())
739 );
740 let account = sqlx::query_as::<_, A>(&sql)
741 .bind(provider)
742 .bind(provider_account_id)
743 .fetch_optional(&self.pool)
744 .await?;
745 Ok(account)
746 }
747
748 async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<A>> {
749 let sql = format!(
750 "SELECT * FROM {} WHERE {} = $1 ORDER BY {} DESC",
751 qi(A::table()),
752 qi(A::col_user_id()),
753 qi(A::col_created_at())
754 );
755 let accounts = sqlx::query_as::<_, A>(&sql)
756 .bind(user_id)
757 .fetch_all(&self.pool)
758 .await?;
759 Ok(accounts)
760 }
761
762 async fn update_account(&self, id: &str, update: UpdateAccount) -> AuthResult<A> {
763 let mut query = sqlx::QueryBuilder::new(format!(
764 "UPDATE {} SET {} = NOW()",
765 qi(A::table()),
766 qi(A::col_updated_at())
767 ));
768
769 if let Some(access_token) = &update.access_token {
770 query.push(format!(", {} = ", qi(A::col_access_token())));
771 query.push_bind(access_token);
772 }
773 if let Some(refresh_token) = &update.refresh_token {
774 query.push(format!(", {} = ", qi(A::col_refresh_token())));
775 query.push_bind(refresh_token);
776 }
777 if let Some(id_token) = &update.id_token {
778 query.push(format!(", {} = ", qi(A::col_id_token())));
779 query.push_bind(id_token);
780 }
781 if let Some(access_token_expires_at) = &update.access_token_expires_at {
782 query.push(format!(", {} = ", qi(A::col_access_token_expires_at())));
783 query.push_bind(access_token_expires_at);
784 }
785 if let Some(refresh_token_expires_at) = &update.refresh_token_expires_at {
786 query.push(format!(", {} = ", qi(A::col_refresh_token_expires_at())));
787 query.push_bind(refresh_token_expires_at);
788 }
789 if let Some(scope) = &update.scope {
790 query.push(format!(", {} = ", qi(A::col_scope())));
791 query.push_bind(scope);
792 }
793 if let Some(password) = &update.password {
794 query.push(format!(", {} = ", qi(A::col_password())));
795 query.push_bind(password);
796 }
797
798 query.push(format!(" WHERE {} = ", qi(A::col_id())));
799 query.push_bind(id);
800 query.push(" RETURNING *");
801
802 let account = query.build_query_as::<A>().fetch_one(&self.pool).await?;
803 Ok(account)
804 }
805
806 async fn delete_account(&self, id: &str) -> AuthResult<()> {
807 let sql = format!(
808 "DELETE FROM {} WHERE {} = $1",
809 qi(A::table()),
810 qi(A::col_id())
811 );
812 sqlx::query(&sql).bind(id).execute(&self.pool).await?;
813 Ok(())
814 }
815 }
816
817 #[async_trait]
820 impl<U, S, A, O, M, I, V, TF, AK, PK> VerificationOps
821 for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
822 where
823 U: AuthUser + AuthUserMeta + SqlxEntity,
824 S: AuthSession + AuthSessionMeta + SqlxEntity,
825 A: AuthAccount + AuthAccountMeta + SqlxEntity,
826 O: AuthOrganization + AuthOrganizationMeta + SqlxEntity,
827 M: AuthMember + AuthMemberMeta + SqlxEntity,
828 I: AuthInvitation + AuthInvitationMeta + SqlxEntity,
829 V: AuthVerification + AuthVerificationMeta + SqlxEntity,
830 TF: AuthTwoFactor + AuthTwoFactorMeta + SqlxEntity,
831 AK: AuthApiKey + AuthApiKeyMeta + SqlxEntity,
832 PK: AuthPasskey + AuthPasskeyMeta + SqlxEntity,
833 {
834 type Verification = V;
835
836 async fn create_verification(
837 &self,
838 create_verification: CreateVerification,
839 ) -> AuthResult<V> {
840 let id = Uuid::new_v4().to_string();
841 let now = Utc::now();
842
843 let sql = format!(
844 "INSERT INTO {} ({}, {}, {}, {}, {}, {}) VALUES ($1, $2, $3, $4, $5, $6) RETURNING *",
845 qi(V::table()),
846 qi(V::col_id()),
847 qi(V::col_identifier()),
848 qi(V::col_value()),
849 qi(V::col_expires_at()),
850 qi(V::col_created_at()),
851 qi(V::col_updated_at()),
852 );
853 let verification = sqlx::query_as::<_, V>(&sql)
854 .bind(&id)
855 .bind(&create_verification.identifier)
856 .bind(&create_verification.value)
857 .bind(create_verification.expires_at)
858 .bind(now)
859 .bind(now)
860 .fetch_one(&self.pool)
861 .await?;
862
863 Ok(verification)
864 }
865
866 async fn get_verification(&self, identifier: &str, value: &str) -> AuthResult<Option<V>> {
867 let sql = format!(
868 "SELECT * FROM {} WHERE {} = $1 AND {} = $2 AND {} > NOW()",
869 qi(V::table()),
870 qi(V::col_identifier()),
871 qi(V::col_value()),
872 qi(V::col_expires_at())
873 );
874 let verification = sqlx::query_as::<_, V>(&sql)
875 .bind(identifier)
876 .bind(value)
877 .fetch_optional(&self.pool)
878 .await?;
879 Ok(verification)
880 }
881
882 async fn get_verification_by_value(&self, value: &str) -> AuthResult<Option<V>> {
883 let sql = format!(
884 "SELECT * FROM {} WHERE {} = $1 AND {} > NOW()",
885 qi(V::table()),
886 qi(V::col_value()),
887 qi(V::col_expires_at())
888 );
889 let verification = sqlx::query_as::<_, V>(&sql)
890 .bind(value)
891 .fetch_optional(&self.pool)
892 .await?;
893 Ok(verification)
894 }
895
896 async fn get_verification_by_identifier(&self, identifier: &str) -> AuthResult<Option<V>> {
897 let sql = format!(
898 "SELECT * FROM {} WHERE {} = $1 AND {} > NOW()",
899 qi(V::table()),
900 qi(V::col_identifier()),
901 qi(V::col_expires_at())
902 );
903 let verification = sqlx::query_as::<_, V>(&sql)
904 .bind(identifier)
905 .fetch_optional(&self.pool)
906 .await?;
907 Ok(verification)
908 }
909
910 async fn consume_verification(
911 &self,
912 identifier: &str,
913 value: &str,
914 ) -> AuthResult<Option<V>> {
915 let sql = format!(
916 "DELETE FROM {tbl} WHERE {id} IN (\
917 SELECT {id} FROM {tbl} \
918 WHERE {ident} = $1 AND {val} = $2 AND {exp} > NOW() \
919 ORDER BY {ca} DESC \
920 LIMIT 1\
921 ) RETURNING *",
922 tbl = qi(V::table()),
923 id = qi(V::col_id()),
924 ident = qi(V::col_identifier()),
925 val = qi(V::col_value()),
926 exp = qi(V::col_expires_at()),
927 ca = qi(V::col_created_at()),
928 );
929 let verification = sqlx::query_as::<_, V>(&sql)
930 .bind(identifier)
931 .bind(value)
932 .fetch_optional(&self.pool)
933 .await?;
934 Ok(verification)
935 }
936
937 async fn delete_verification(&self, id: &str) -> AuthResult<()> {
938 let sql = format!(
939 "DELETE FROM {} WHERE {} = $1",
940 qi(V::table()),
941 qi(V::col_id())
942 );
943 sqlx::query(&sql).bind(id).execute(&self.pool).await?;
944 Ok(())
945 }
946
947 async fn delete_expired_verifications(&self) -> AuthResult<usize> {
948 let sql = format!(
949 "DELETE FROM {} WHERE {} < NOW()",
950 qi(V::table()),
951 qi(V::col_expires_at())
952 );
953 let result = sqlx::query(&sql).execute(&self.pool).await?;
954 Ok(result.rows_affected() as usize)
955 }
956 }
957
958 #[async_trait]
961 impl<U, S, A, O, M, I, V, TF, AK, PK> OrganizationOps
962 for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
963 where
964 U: AuthUser + AuthUserMeta + SqlxEntity,
965 S: AuthSession + AuthSessionMeta + SqlxEntity,
966 A: AuthAccount + AuthAccountMeta + SqlxEntity,
967 O: AuthOrganization + AuthOrganizationMeta + SqlxEntity,
968 M: AuthMember + AuthMemberMeta + SqlxEntity,
969 I: AuthInvitation + AuthInvitationMeta + SqlxEntity,
970 V: AuthVerification + AuthVerificationMeta + SqlxEntity,
971 TF: AuthTwoFactor + AuthTwoFactorMeta + SqlxEntity,
972 AK: AuthApiKey + AuthApiKeyMeta + SqlxEntity,
973 PK: AuthPasskey + AuthPasskeyMeta + SqlxEntity,
974 {
975 type Organization = O;
976
977 async fn create_organization(&self, create_org: CreateOrganization) -> AuthResult<O> {
978 let id = create_org.id.unwrap_or_else(|| Uuid::new_v4().to_string());
979 let now = Utc::now();
980
981 let sql = format!(
982 "INSERT INTO {} ({}, {}, {}, {}, {}, {}, {}) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING *",
983 qi(O::table()),
984 qi(O::col_id()),
985 qi(O::col_name()),
986 qi(O::col_slug()),
987 qi(O::col_logo()),
988 qi(O::col_metadata()),
989 qi(O::col_created_at()),
990 qi(O::col_updated_at()),
991 );
992 let organization = sqlx::query_as::<_, O>(&sql)
993 .bind(&id)
994 .bind(&create_org.name)
995 .bind(&create_org.slug)
996 .bind(&create_org.logo)
997 .bind(sqlx::types::Json(
998 create_org.metadata.unwrap_or(serde_json::json!({})),
999 ))
1000 .bind(now)
1001 .bind(now)
1002 .fetch_one(&self.pool)
1003 .await?;
1004
1005 Ok(organization)
1006 }
1007
1008 async fn get_organization_by_id(&self, id: &str) -> AuthResult<Option<O>> {
1009 let sql = format!(
1010 "SELECT * FROM {} WHERE {} = $1",
1011 qi(O::table()),
1012 qi(O::col_id())
1013 );
1014 let organization = sqlx::query_as::<_, O>(&sql)
1015 .bind(id)
1016 .fetch_optional(&self.pool)
1017 .await?;
1018 Ok(organization)
1019 }
1020
1021 async fn get_organization_by_slug(&self, slug: &str) -> AuthResult<Option<O>> {
1022 let sql = format!(
1023 "SELECT * FROM {} WHERE {} = $1",
1024 qi(O::table()),
1025 qi(O::col_slug())
1026 );
1027 let organization = sqlx::query_as::<_, O>(&sql)
1028 .bind(slug)
1029 .fetch_optional(&self.pool)
1030 .await?;
1031 Ok(organization)
1032 }
1033
1034 async fn update_organization(&self, id: &str, update: UpdateOrganization) -> AuthResult<O> {
1035 let mut query = sqlx::QueryBuilder::new(format!(
1036 "UPDATE {} SET {} = NOW()",
1037 qi(O::table()),
1038 qi(O::col_updated_at())
1039 ));
1040
1041 if let Some(name) = &update.name {
1042 query.push(format!(", {} = ", qi(O::col_name())));
1043 query.push_bind(name);
1044 }
1045 if let Some(slug) = &update.slug {
1046 query.push(format!(", {} = ", qi(O::col_slug())));
1047 query.push_bind(slug);
1048 }
1049 if let Some(logo) = &update.logo {
1050 query.push(format!(", {} = ", qi(O::col_logo())));
1051 query.push_bind(logo);
1052 }
1053 if let Some(metadata) = &update.metadata {
1054 query.push(format!(", {} = ", qi(O::col_metadata())));
1055 query.push_bind(sqlx::types::Json(metadata.clone()));
1056 }
1057
1058 query.push(format!(" WHERE {} = ", qi(O::col_id())));
1059 query.push_bind(id);
1060 query.push(" RETURNING *");
1061
1062 let organization = query.build_query_as::<O>().fetch_one(&self.pool).await?;
1063 Ok(organization)
1064 }
1065
1066 async fn delete_organization(&self, id: &str) -> AuthResult<()> {
1067 let sql = format!(
1068 "DELETE FROM {} WHERE {} = $1",
1069 qi(O::table()),
1070 qi(O::col_id())
1071 );
1072 sqlx::query(&sql).bind(id).execute(&self.pool).await?;
1073 Ok(())
1074 }
1075
1076 async fn list_user_organizations(&self, user_id: &str) -> AuthResult<Vec<O>> {
1077 let sql = format!(
1078 "SELECT o.* FROM {} o INNER JOIN {} m ON o.{} = m.{} WHERE m.{} = $1 ORDER BY o.{} DESC",
1079 qi(O::table()),
1080 qi(M::table()),
1081 qi(O::col_id()),
1082 qi(M::col_organization_id()),
1083 qi(M::col_user_id()),
1084 qi(O::col_created_at()),
1085 );
1086 let organizations = sqlx::query_as::<_, O>(&sql)
1087 .bind(user_id)
1088 .fetch_all(&self.pool)
1089 .await?;
1090 Ok(organizations)
1091 }
1092 }
1093
1094 #[async_trait]
1097 impl<U, S, A, O, M, I, V, TF, AK, PK> MemberOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
1098 where
1099 U: AuthUser + AuthUserMeta + SqlxEntity,
1100 S: AuthSession + AuthSessionMeta + SqlxEntity,
1101 A: AuthAccount + AuthAccountMeta + SqlxEntity,
1102 O: AuthOrganization + AuthOrganizationMeta + SqlxEntity,
1103 M: AuthMember + AuthMemberMeta + SqlxEntity,
1104 I: AuthInvitation + AuthInvitationMeta + SqlxEntity,
1105 V: AuthVerification + AuthVerificationMeta + SqlxEntity,
1106 TF: AuthTwoFactor + AuthTwoFactorMeta + SqlxEntity,
1107 AK: AuthApiKey + AuthApiKeyMeta + SqlxEntity,
1108 PK: AuthPasskey + AuthPasskeyMeta + SqlxEntity,
1109 {
1110 type Member = M;
1111
1112 async fn create_member(&self, create_member: CreateMember) -> AuthResult<M> {
1113 let id = Uuid::new_v4().to_string();
1114 let now = Utc::now();
1115
1116 let sql = format!(
1117 "INSERT INTO {} ({}, {}, {}, {}, {}) VALUES ($1, $2, $3, $4, $5) RETURNING *",
1118 qi(M::table()),
1119 qi(M::col_id()),
1120 qi(M::col_organization_id()),
1121 qi(M::col_user_id()),
1122 qi(M::col_role()),
1123 qi(M::col_created_at()),
1124 );
1125 let member = sqlx::query_as::<_, M>(&sql)
1126 .bind(&id)
1127 .bind(&create_member.organization_id)
1128 .bind(&create_member.user_id)
1129 .bind(&create_member.role)
1130 .bind(now)
1131 .fetch_one(&self.pool)
1132 .await?;
1133
1134 Ok(member)
1135 }
1136
1137 async fn get_member(&self, organization_id: &str, user_id: &str) -> AuthResult<Option<M>> {
1138 let sql = format!(
1139 "SELECT * FROM {} WHERE {} = $1 AND {} = $2",
1140 qi(M::table()),
1141 qi(M::col_organization_id()),
1142 qi(M::col_user_id())
1143 );
1144 let member = sqlx::query_as::<_, M>(&sql)
1145 .bind(organization_id)
1146 .bind(user_id)
1147 .fetch_optional(&self.pool)
1148 .await?;
1149 Ok(member)
1150 }
1151
1152 async fn get_member_by_id(&self, id: &str) -> AuthResult<Option<M>> {
1153 let sql = format!(
1154 "SELECT * FROM {} WHERE {} = $1",
1155 qi(M::table()),
1156 qi(M::col_id())
1157 );
1158 let member = sqlx::query_as::<_, M>(&sql)
1159 .bind(id)
1160 .fetch_optional(&self.pool)
1161 .await?;
1162 Ok(member)
1163 }
1164
1165 async fn update_member_role(&self, member_id: &str, role: &str) -> AuthResult<M> {
1166 let sql = format!(
1167 "UPDATE {} SET {} = $1 WHERE {} = $2 RETURNING *",
1168 qi(M::table()),
1169 qi(M::col_role()),
1170 qi(M::col_id())
1171 );
1172 let member = sqlx::query_as::<_, M>(&sql)
1173 .bind(role)
1174 .bind(member_id)
1175 .fetch_one(&self.pool)
1176 .await?;
1177 Ok(member)
1178 }
1179
1180 async fn delete_member(&self, member_id: &str) -> AuthResult<()> {
1181 let sql = format!(
1182 "DELETE FROM {} WHERE {} = $1",
1183 qi(M::table()),
1184 qi(M::col_id())
1185 );
1186 sqlx::query(&sql)
1187 .bind(member_id)
1188 .execute(&self.pool)
1189 .await?;
1190 Ok(())
1191 }
1192
1193 async fn list_organization_members(&self, organization_id: &str) -> AuthResult<Vec<M>> {
1194 let sql = format!(
1195 "SELECT * FROM {} WHERE {} = $1 ORDER BY {} ASC",
1196 qi(M::table()),
1197 qi(M::col_organization_id()),
1198 qi(M::col_created_at())
1199 );
1200 let members = sqlx::query_as::<_, M>(&sql)
1201 .bind(organization_id)
1202 .fetch_all(&self.pool)
1203 .await?;
1204 Ok(members)
1205 }
1206
1207 async fn count_organization_members(&self, organization_id: &str) -> AuthResult<usize> {
1208 let sql = format!(
1209 "SELECT COUNT(*) FROM {} WHERE {} = $1",
1210 qi(M::table()),
1211 qi(M::col_organization_id())
1212 );
1213 let count: (i64,) = sqlx::query_as(&sql)
1214 .bind(organization_id)
1215 .fetch_one(&self.pool)
1216 .await?;
1217 Ok(count.0 as usize)
1218 }
1219
1220 async fn count_organization_owners(&self, organization_id: &str) -> AuthResult<usize> {
1221 let sql = format!(
1222 "SELECT COUNT(*) FROM {} WHERE {} = $1 AND {} = 'owner'",
1223 qi(M::table()),
1224 qi(M::col_organization_id()),
1225 qi(M::col_role())
1226 );
1227 let count: (i64,) = sqlx::query_as(&sql)
1228 .bind(organization_id)
1229 .fetch_one(&self.pool)
1230 .await?;
1231 Ok(count.0 as usize)
1232 }
1233 }
1234
1235 #[async_trait]
1238 impl<U, S, A, O, M, I, V, TF, AK, PK> InvitationOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
1239 where
1240 U: AuthUser + AuthUserMeta + SqlxEntity,
1241 S: AuthSession + AuthSessionMeta + SqlxEntity,
1242 A: AuthAccount + AuthAccountMeta + SqlxEntity,
1243 O: AuthOrganization + AuthOrganizationMeta + SqlxEntity,
1244 M: AuthMember + AuthMemberMeta + SqlxEntity,
1245 I: AuthInvitation + AuthInvitationMeta + SqlxEntity,
1246 V: AuthVerification + AuthVerificationMeta + SqlxEntity,
1247 TF: AuthTwoFactor + AuthTwoFactorMeta + SqlxEntity,
1248 AK: AuthApiKey + AuthApiKeyMeta + SqlxEntity,
1249 PK: AuthPasskey + AuthPasskeyMeta + SqlxEntity,
1250 {
1251 type Invitation = I;
1252
1253 async fn create_invitation(&self, create_inv: CreateInvitation) -> AuthResult<I> {
1254 let id = Uuid::new_v4().to_string();
1255 let now = Utc::now();
1256
1257 let sql = format!(
1258 "INSERT INTO {} ({}, {}, {}, {}, {}, {}, {}, {}) \
1259 VALUES ($1, $2, $3, $4, 'pending', $5, $6, $7) RETURNING *",
1260 qi(I::table()),
1261 qi(I::col_id()),
1262 qi(I::col_organization_id()),
1263 qi(I::col_email()),
1264 qi(I::col_role()),
1265 qi(I::col_status()),
1266 qi(I::col_inviter_id()),
1267 qi(I::col_expires_at()),
1268 qi(I::col_created_at()),
1269 );
1270 let invitation = sqlx::query_as::<_, I>(&sql)
1271 .bind(&id)
1272 .bind(&create_inv.organization_id)
1273 .bind(&create_inv.email)
1274 .bind(&create_inv.role)
1275 .bind(&create_inv.inviter_id)
1276 .bind(create_inv.expires_at)
1277 .bind(now)
1278 .fetch_one(&self.pool)
1279 .await?;
1280
1281 Ok(invitation)
1282 }
1283
1284 async fn get_invitation_by_id(&self, id: &str) -> AuthResult<Option<I>> {
1285 let sql = format!(
1286 "SELECT * FROM {} WHERE {} = $1",
1287 qi(I::table()),
1288 qi(I::col_id())
1289 );
1290 let invitation = sqlx::query_as::<_, I>(&sql)
1291 .bind(id)
1292 .fetch_optional(&self.pool)
1293 .await?;
1294 Ok(invitation)
1295 }
1296
1297 async fn get_pending_invitation(
1298 &self,
1299 organization_id: &str,
1300 email: &str,
1301 ) -> AuthResult<Option<I>> {
1302 let sql = format!(
1303 "SELECT * FROM {} WHERE {} = $1 AND LOWER({}) = LOWER($2) AND {} = 'pending'",
1304 qi(I::table()),
1305 qi(I::col_organization_id()),
1306 qi(I::col_email()),
1307 qi(I::col_status())
1308 );
1309 let invitation = sqlx::query_as::<_, I>(&sql)
1310 .bind(organization_id)
1311 .bind(email)
1312 .fetch_optional(&self.pool)
1313 .await?;
1314 Ok(invitation)
1315 }
1316
1317 async fn update_invitation_status(
1318 &self,
1319 id: &str,
1320 status: InvitationStatus,
1321 ) -> AuthResult<I> {
1322 let sql = format!(
1323 "UPDATE {} SET {} = $1 WHERE {} = $2 RETURNING *",
1324 qi(I::table()),
1325 qi(I::col_status()),
1326 qi(I::col_id())
1327 );
1328 let invitation = sqlx::query_as::<_, I>(&sql)
1329 .bind(status.to_string())
1330 .bind(id)
1331 .fetch_one(&self.pool)
1332 .await?;
1333 Ok(invitation)
1334 }
1335
1336 async fn list_organization_invitations(&self, organization_id: &str) -> AuthResult<Vec<I>> {
1337 let sql = format!(
1338 "SELECT * FROM {} WHERE {} = $1 ORDER BY {} DESC",
1339 qi(I::table()),
1340 qi(I::col_organization_id()),
1341 qi(I::col_created_at())
1342 );
1343 let invitations = sqlx::query_as::<_, I>(&sql)
1344 .bind(organization_id)
1345 .fetch_all(&self.pool)
1346 .await?;
1347 Ok(invitations)
1348 }
1349
1350 async fn list_user_invitations(&self, email: &str) -> AuthResult<Vec<I>> {
1351 let sql = format!(
1352 "SELECT * FROM {} WHERE LOWER({}) = LOWER($1) AND {} = 'pending' AND {} > NOW() ORDER BY {} DESC",
1353 qi(I::table()),
1354 qi(I::col_email()),
1355 qi(I::col_status()),
1356 qi(I::col_expires_at()),
1357 qi(I::col_created_at())
1358 );
1359 let invitations = sqlx::query_as::<_, I>(&sql)
1360 .bind(email)
1361 .fetch_all(&self.pool)
1362 .await?;
1363 Ok(invitations)
1364 }
1365 }
1366
1367 #[async_trait]
1370 impl<U, S, A, O, M, I, V, TF, AK, PK> TwoFactorOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
1371 where
1372 U: AuthUser + AuthUserMeta + SqlxEntity,
1373 S: AuthSession + AuthSessionMeta + SqlxEntity,
1374 A: AuthAccount + AuthAccountMeta + SqlxEntity,
1375 O: AuthOrganization + AuthOrganizationMeta + SqlxEntity,
1376 M: AuthMember + AuthMemberMeta + SqlxEntity,
1377 I: AuthInvitation + AuthInvitationMeta + SqlxEntity,
1378 V: AuthVerification + AuthVerificationMeta + SqlxEntity,
1379 TF: AuthTwoFactor + AuthTwoFactorMeta + SqlxEntity,
1380 AK: AuthApiKey + AuthApiKeyMeta + SqlxEntity,
1381 PK: AuthPasskey + AuthPasskeyMeta + SqlxEntity,
1382 {
1383 type TwoFactor = TF;
1384
1385 async fn create_two_factor(&self, create: CreateTwoFactor) -> AuthResult<TF> {
1386 let id = Uuid::new_v4().to_string();
1387 let now = Utc::now();
1388
1389 let sql = format!(
1390 "INSERT INTO {} ({}, {}, {}, {}, {}, {}) VALUES ($1, $2, $3, $4, $5, $6) RETURNING *",
1391 qi(TF::table()),
1392 qi(TF::col_id()),
1393 qi(TF::col_secret()),
1394 qi(TF::col_backup_codes()),
1395 qi(TF::col_user_id()),
1396 qi(TF::col_created_at()),
1397 qi(TF::col_updated_at()),
1398 );
1399 let two_factor = sqlx::query_as::<_, TF>(&sql)
1400 .bind(&id)
1401 .bind(&create.secret)
1402 .bind(&create.backup_codes)
1403 .bind(&create.user_id)
1404 .bind(now)
1405 .bind(now)
1406 .fetch_one(&self.pool)
1407 .await?;
1408
1409 Ok(two_factor)
1410 }
1411
1412 async fn get_two_factor_by_user_id(&self, user_id: &str) -> AuthResult<Option<TF>> {
1413 let sql = format!(
1414 "SELECT * FROM {} WHERE {} = $1",
1415 qi(TF::table()),
1416 qi(TF::col_user_id())
1417 );
1418 let two_factor = sqlx::query_as::<_, TF>(&sql)
1419 .bind(user_id)
1420 .fetch_optional(&self.pool)
1421 .await?;
1422 Ok(two_factor)
1423 }
1424
1425 async fn update_two_factor_backup_codes(
1426 &self,
1427 user_id: &str,
1428 backup_codes: &str,
1429 ) -> AuthResult<TF> {
1430 let sql = format!(
1431 "UPDATE {} SET {} = $1, {} = NOW() WHERE {} = $2 RETURNING *",
1432 qi(TF::table()),
1433 qi(TF::col_backup_codes()),
1434 qi(TF::col_updated_at()),
1435 qi(TF::col_user_id())
1436 );
1437 let two_factor = sqlx::query_as::<_, TF>(&sql)
1438 .bind(backup_codes)
1439 .bind(user_id)
1440 .fetch_one(&self.pool)
1441 .await?;
1442 Ok(two_factor)
1443 }
1444
1445 async fn delete_two_factor(&self, user_id: &str) -> AuthResult<()> {
1446 let sql = format!(
1447 "DELETE FROM {} WHERE {} = $1",
1448 qi(TF::table()),
1449 qi(TF::col_user_id())
1450 );
1451 sqlx::query(&sql).bind(user_id).execute(&self.pool).await?;
1452 Ok(())
1453 }
1454 }
1455
1456 #[async_trait]
1459 impl<U, S, A, O, M, I, V, TF, AK, PK> ApiKeyOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
1460 where
1461 U: AuthUser + AuthUserMeta + SqlxEntity,
1462 S: AuthSession + AuthSessionMeta + SqlxEntity,
1463 A: AuthAccount + AuthAccountMeta + SqlxEntity,
1464 O: AuthOrganization + AuthOrganizationMeta + SqlxEntity,
1465 M: AuthMember + AuthMemberMeta + SqlxEntity,
1466 I: AuthInvitation + AuthInvitationMeta + SqlxEntity,
1467 V: AuthVerification + AuthVerificationMeta + SqlxEntity,
1468 TF: AuthTwoFactor + AuthTwoFactorMeta + SqlxEntity,
1469 AK: AuthApiKey + AuthApiKeyMeta + SqlxEntity,
1470 PK: AuthPasskey + AuthPasskeyMeta + SqlxEntity,
1471 {
1472 type ApiKey = AK;
1473
1474 async fn create_api_key(&self, input: CreateApiKey) -> AuthResult<AK> {
1475 let id = Uuid::new_v4().to_string();
1476 let now = Utc::now();
1477
1478 let sql = format!(
1479 "INSERT INTO {} ({}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}) \
1480 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14::timestamptz, $15, $16, $17, $18) RETURNING *",
1481 qi(AK::table()),
1482 qi(AK::col_id()),
1483 qi(AK::col_name()),
1484 qi(AK::col_start()),
1485 qi(AK::col_prefix()),
1486 qi(AK::col_key_hash()),
1487 qi(AK::col_user_id()),
1488 qi(AK::col_refill_interval()),
1489 qi(AK::col_refill_amount()),
1490 qi(AK::col_enabled()),
1491 qi(AK::col_rate_limit_enabled()),
1492 qi(AK::col_rate_limit_time_window()),
1493 qi(AK::col_rate_limit_max()),
1494 qi(AK::col_remaining()),
1495 qi(AK::col_expires_at()),
1496 qi(AK::col_created_at()),
1497 qi(AK::col_updated_at()),
1498 qi(AK::col_permissions()),
1499 qi(AK::col_metadata()),
1500 );
1501 let api_key = sqlx::query_as::<_, AK>(&sql)
1502 .bind(&id)
1503 .bind(&input.name)
1504 .bind(&input.start)
1505 .bind(&input.prefix)
1506 .bind(&input.key_hash)
1507 .bind(&input.user_id)
1508 .bind(input.refill_interval)
1509 .bind(input.refill_amount)
1510 .bind(input.enabled)
1511 .bind(input.rate_limit_enabled)
1512 .bind(input.rate_limit_time_window)
1513 .bind(input.rate_limit_max)
1514 .bind(input.remaining)
1515 .bind(&input.expires_at)
1516 .bind(now)
1517 .bind(now)
1518 .bind(&input.permissions)
1519 .bind(&input.metadata)
1520 .fetch_one(&self.pool)
1521 .await?;
1522
1523 Ok(api_key)
1524 }
1525
1526 async fn get_api_key_by_id(&self, id: &str) -> AuthResult<Option<AK>> {
1527 let sql = format!(
1528 "SELECT * FROM {} WHERE {} = $1",
1529 qi(AK::table()),
1530 qi(AK::col_id())
1531 );
1532 let api_key = sqlx::query_as::<_, AK>(&sql)
1533 .bind(id)
1534 .fetch_optional(&self.pool)
1535 .await?;
1536 Ok(api_key)
1537 }
1538
1539 async fn get_api_key_by_hash(&self, hash: &str) -> AuthResult<Option<AK>> {
1540 let sql = format!(
1541 "SELECT * FROM {} WHERE {} = $1",
1542 qi(AK::table()),
1543 qi(AK::col_key_hash())
1544 );
1545 let api_key = sqlx::query_as::<_, AK>(&sql)
1546 .bind(hash)
1547 .fetch_optional(&self.pool)
1548 .await?;
1549 Ok(api_key)
1550 }
1551
1552 async fn list_api_keys_by_user(&self, user_id: &str) -> AuthResult<Vec<AK>> {
1553 let sql = format!(
1554 "SELECT * FROM {} WHERE {} = $1 ORDER BY {} DESC",
1555 qi(AK::table()),
1556 qi(AK::col_user_id()),
1557 qi(AK::col_created_at())
1558 );
1559 let keys = sqlx::query_as::<_, AK>(&sql)
1560 .bind(user_id)
1561 .fetch_all(&self.pool)
1562 .await?;
1563 Ok(keys)
1564 }
1565
1566 async fn update_api_key(&self, id: &str, update: UpdateApiKey) -> AuthResult<AK> {
1567 let mut query = sqlx::QueryBuilder::new(format!(
1568 "UPDATE {} SET {} = NOW()",
1569 qi(AK::table()),
1570 qi(AK::col_updated_at())
1571 ));
1572
1573 if let Some(name) = &update.name {
1574 query.push(format!(", {} = ", qi(AK::col_name())));
1575 query.push_bind(name);
1576 }
1577 if let Some(enabled) = update.enabled {
1578 query.push(format!(", {} = ", qi(AK::col_enabled())));
1579 query.push_bind(enabled);
1580 }
1581 if let Some(remaining) = update.remaining {
1582 query.push(format!(", {} = ", qi(AK::col_remaining())));
1583 query.push_bind(remaining);
1584 }
1585 if let Some(rate_limit_enabled) = update.rate_limit_enabled {
1586 query.push(format!(", {} = ", qi(AK::col_rate_limit_enabled())));
1587 query.push_bind(rate_limit_enabled);
1588 }
1589 if let Some(rate_limit_time_window) = update.rate_limit_time_window {
1590 query.push(format!(", {} = ", qi(AK::col_rate_limit_time_window())));
1591 query.push_bind(rate_limit_time_window);
1592 }
1593 if let Some(rate_limit_max) = update.rate_limit_max {
1594 query.push(format!(", {} = ", qi(AK::col_rate_limit_max())));
1595 query.push_bind(rate_limit_max);
1596 }
1597 if let Some(refill_interval) = update.refill_interval {
1598 query.push(format!(", {} = ", qi(AK::col_refill_interval())));
1599 query.push_bind(refill_interval);
1600 }
1601 if let Some(refill_amount) = update.refill_amount {
1602 query.push(format!(", {} = ", qi(AK::col_refill_amount())));
1603 query.push_bind(refill_amount);
1604 }
1605 if let Some(permissions) = &update.permissions {
1606 query.push(format!(", {} = ", qi(AK::col_permissions())));
1607 query.push_bind(permissions);
1608 }
1609 if let Some(metadata) = &update.metadata {
1610 query.push(format!(", {} = ", qi(AK::col_metadata())));
1611 query.push_bind(metadata);
1612 }
1613 if let Some(expires_at) = &update.expires_at {
1614 query.push(format!(", {} = ", qi(AK::col_expires_at())));
1615 query.push_bind(expires_at.as_deref().map(|s| s.to_string()));
1616 }
1617 if let Some(last_request) = &update.last_request {
1618 query.push(format!(", {} = ", qi(AK::col_last_request())));
1619 query.push_bind(last_request.as_deref().map(|s| s.to_string()));
1620 }
1621 if let Some(request_count) = update.request_count {
1622 query.push(format!(", {} = ", qi(AK::col_request_count())));
1623 query.push_bind(request_count);
1624 }
1625 if let Some(last_refill_at) = &update.last_refill_at {
1626 query.push(format!(", {} = ", qi(AK::col_last_refill_at())));
1627 query.push_bind(last_refill_at.as_deref().map(|s| s.to_string()));
1628 }
1629
1630 query.push(format!(" WHERE {} = ", qi(AK::col_id())));
1631 query.push_bind(id);
1632 query.push(" RETURNING *");
1633
1634 let api_key = query
1635 .build_query_as::<AK>()
1636 .fetch_one(&self.pool)
1637 .await
1638 .map_err(|err| match err {
1639 sqlx::Error::RowNotFound => AuthError::not_found("API key not found"),
1640 other => AuthError::from(other),
1641 })?;
1642 Ok(api_key)
1643 }
1644
1645 async fn delete_api_key(&self, id: &str) -> AuthResult<()> {
1646 let sql = format!(
1647 "DELETE FROM {} WHERE {} = $1",
1648 qi(AK::table()),
1649 qi(AK::col_id())
1650 );
1651 sqlx::query(&sql).bind(id).execute(&self.pool).await?;
1652 Ok(())
1653 }
1654
1655 async fn delete_expired_api_keys(&self) -> AuthResult<usize> {
1656 let now = Utc::now().to_rfc3339();
1660 let sql = format!(
1661 "DELETE FROM {} WHERE {} IS NOT NULL AND {} < $1",
1662 qi(AK::table()),
1663 qi(AK::col_expires_at()),
1664 qi(AK::col_expires_at()),
1665 );
1666 let result = sqlx::query(&sql).bind(&now).execute(&self.pool).await?;
1667 Ok(result.rows_affected() as usize)
1668 }
1669 }
1670
1671 #[async_trait]
1674 impl<U, S, A, O, M, I, V, TF, AK, PK> PasskeyOps for SqlxAdapter<U, S, A, O, M, I, V, TF, AK, PK>
1675 where
1676 U: AuthUser + AuthUserMeta + SqlxEntity,
1677 S: AuthSession + AuthSessionMeta + SqlxEntity,
1678 A: AuthAccount + AuthAccountMeta + SqlxEntity,
1679 O: AuthOrganization + AuthOrganizationMeta + SqlxEntity,
1680 M: AuthMember + AuthMemberMeta + SqlxEntity,
1681 I: AuthInvitation + AuthInvitationMeta + SqlxEntity,
1682 V: AuthVerification + AuthVerificationMeta + SqlxEntity,
1683 TF: AuthTwoFactor + AuthTwoFactorMeta + SqlxEntity,
1684 AK: AuthApiKey + AuthApiKeyMeta + SqlxEntity,
1685 PK: AuthPasskey + AuthPasskeyMeta + SqlxEntity,
1686 {
1687 type Passkey = PK;
1688
1689 async fn create_passkey(&self, input: CreatePasskey) -> AuthResult<PK> {
1690 let id = Uuid::new_v4().to_string();
1691 let now = Utc::now();
1692 let counter = i64::try_from(input.counter)
1693 .map_err(|_| AuthError::bad_request("Passkey counter exceeds i64 range"))?;
1694
1695 let sql = format!(
1696 "INSERT INTO {} ({}, {}, {}, {}, {}, {}, {}, {}, {}, {}) \
1697 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING *",
1698 qi(PK::table()),
1699 qi(PK::col_id()),
1700 qi(PK::col_name()),
1701 qi(PK::col_public_key()),
1702 qi(PK::col_user_id()),
1703 qi(PK::col_credential_id()),
1704 qi(PK::col_counter()),
1705 qi(PK::col_device_type()),
1706 qi(PK::col_backed_up()),
1707 qi(PK::col_transports()),
1708 qi(PK::col_created_at()),
1709 );
1710 let passkey = sqlx::query_as::<_, PK>(&sql)
1711 .bind(&id)
1712 .bind(&input.name)
1713 .bind(&input.public_key)
1714 .bind(&input.user_id)
1715 .bind(&input.credential_id)
1716 .bind(counter)
1717 .bind(&input.device_type)
1718 .bind(input.backed_up)
1719 .bind(&input.transports)
1720 .bind(now)
1721 .fetch_one(&self.pool)
1722 .await
1723 .map_err(|e| match e {
1724 sqlx::Error::Database(ref db_err) if db_err.is_unique_violation() => {
1725 AuthError::conflict("A passkey with this credential ID already exists")
1726 }
1727 other => AuthError::from(other),
1728 })?;
1729
1730 Ok(passkey)
1731 }
1732
1733 async fn get_passkey_by_id(&self, id: &str) -> AuthResult<Option<PK>> {
1734 let sql = format!(
1735 "SELECT * FROM {} WHERE {} = $1",
1736 qi(PK::table()),
1737 qi(PK::col_id())
1738 );
1739 let passkey = sqlx::query_as::<_, PK>(&sql)
1740 .bind(id)
1741 .fetch_optional(&self.pool)
1742 .await?;
1743 Ok(passkey)
1744 }
1745
1746 async fn get_passkey_by_credential_id(
1747 &self,
1748 credential_id: &str,
1749 ) -> AuthResult<Option<PK>> {
1750 let sql = format!(
1751 "SELECT * FROM {} WHERE {} = $1",
1752 qi(PK::table()),
1753 qi(PK::col_credential_id())
1754 );
1755 let passkey = sqlx::query_as::<_, PK>(&sql)
1756 .bind(credential_id)
1757 .fetch_optional(&self.pool)
1758 .await?;
1759 Ok(passkey)
1760 }
1761
1762 async fn list_passkeys_by_user(&self, user_id: &str) -> AuthResult<Vec<PK>> {
1763 let sql = format!(
1764 "SELECT * FROM {} WHERE {} = $1 ORDER BY {} DESC",
1765 qi(PK::table()),
1766 qi(PK::col_user_id()),
1767 qi(PK::col_created_at())
1768 );
1769 let passkeys = sqlx::query_as::<_, PK>(&sql)
1770 .bind(user_id)
1771 .fetch_all(&self.pool)
1772 .await?;
1773 Ok(passkeys)
1774 }
1775
1776 async fn update_passkey_counter(&self, id: &str, counter: u64) -> AuthResult<PK> {
1777 let counter = i64::try_from(counter)
1778 .map_err(|_| AuthError::bad_request("Passkey counter exceeds i64 range"))?;
1779 let sql = format!(
1780 "UPDATE {} SET {} = $2 WHERE {} = $1 RETURNING *",
1781 qi(PK::table()),
1782 qi(PK::col_counter()),
1783 qi(PK::col_id())
1784 );
1785 let passkey = sqlx::query_as::<_, PK>(&sql)
1786 .bind(id)
1787 .bind(counter)
1788 .fetch_one(&self.pool)
1789 .await
1790 .map_err(|err| match err {
1791 sqlx::Error::RowNotFound => AuthError::not_found("Passkey not found"),
1792 other => AuthError::from(other),
1793 })?;
1794 Ok(passkey)
1795 }
1796
1797 async fn update_passkey_name(&self, id: &str, name: &str) -> AuthResult<PK> {
1798 let sql = format!(
1799 "UPDATE {} SET {} = $2 WHERE {} = $1 RETURNING *",
1800 qi(PK::table()),
1801 qi(PK::col_name()),
1802 qi(PK::col_id())
1803 );
1804 let passkey = sqlx::query_as::<_, PK>(&sql)
1805 .bind(id)
1806 .bind(name)
1807 .fetch_one(&self.pool)
1808 .await
1809 .map_err(|err| match err {
1810 sqlx::Error::RowNotFound => AuthError::not_found("Passkey not found"),
1811 other => AuthError::from(other),
1812 })?;
1813 Ok(passkey)
1814 }
1815
1816 async fn delete_passkey(&self, id: &str) -> AuthResult<()> {
1817 let sql = format!(
1818 "DELETE FROM {} WHERE {} = $1",
1819 qi(PK::table()),
1820 qi(PK::col_id())
1821 );
1822 sqlx::query(&sql).bind(id).execute(&self.pool).await?;
1823 Ok(())
1824 }
1825 }
1826}
1827
1828#[cfg(feature = "sqlx-postgres")]
1829pub use sqlx_adapter::{SqlxAdapter, SqlxEntity};