1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use std::sync::Arc;
4
5use crate::adapters::DatabaseAdapter;
6use crate::adapters::database::{
7 AccountOps, ApiKeyOps, InvitationOps, MemberOps, OrganizationOps, SessionOps, TwoFactorOps,
8 UserOps, VerificationOps,
9};
10use crate::error::AuthResult;
11use crate::types::{
12 CreateAccount, CreateApiKey, CreateInvitation, CreateMember, CreateOrganization, CreateSession,
13 CreateTwoFactor, CreateUser, CreateVerification, InvitationStatus, ListUsersParams,
14 UpdateAccount, UpdateApiKey, UpdateOrganization, UpdateUser,
15};
16
17#[async_trait]
25pub trait DatabaseHooks<DB: DatabaseAdapter>: Send + Sync {
26 async fn before_create_user(&self, user: &mut CreateUser) -> AuthResult<()> {
27 let _ = user;
28 Ok(())
29 }
30
31 async fn after_create_user(&self, user: &DB::User) -> AuthResult<()> {
32 let _ = user;
33 Ok(())
34 }
35
36 async fn before_update_user(&self, id: &str, update: &mut UpdateUser) -> AuthResult<()> {
37 let _ = (id, update);
38 Ok(())
39 }
40
41 async fn after_update_user(&self, user: &DB::User) -> AuthResult<()> {
42 let _ = user;
43 Ok(())
44 }
45
46 async fn before_delete_user(&self, id: &str) -> AuthResult<()> {
47 let _ = id;
48 Ok(())
49 }
50
51 async fn after_delete_user(&self, id: &str) -> AuthResult<()> {
52 let _ = id;
53 Ok(())
54 }
55
56 async fn before_create_session(&self, session: &mut CreateSession) -> AuthResult<()> {
57 let _ = session;
58 Ok(())
59 }
60
61 async fn after_create_session(&self, session: &DB::Session) -> AuthResult<()> {
62 let _ = session;
63 Ok(())
64 }
65
66 async fn before_delete_session(&self, token: &str) -> AuthResult<()> {
67 let _ = token;
68 Ok(())
69 }
70
71 async fn after_delete_session(&self, token: &str) -> AuthResult<()> {
72 let _ = token;
73 Ok(())
74 }
75
76 async fn before_create_account(&self, account: &mut CreateAccount) -> AuthResult<()> {
79 let _ = account;
80 Ok(())
81 }
82
83 async fn after_create_account(&self, account: &DB::Account) -> AuthResult<()> {
84 let _ = account;
85 Ok(())
86 }
87
88 async fn before_update_account(&self, id: &str, update: &mut UpdateAccount) -> AuthResult<()> {
89 let _ = (id, update);
90 Ok(())
91 }
92
93 async fn after_update_account(&self, account: &DB::Account) -> AuthResult<()> {
94 let _ = account;
95 Ok(())
96 }
97
98 async fn before_delete_account(&self, id: &str) -> AuthResult<()> {
99 let _ = id;
100 Ok(())
101 }
102
103 async fn after_delete_account(&self, id: &str) -> AuthResult<()> {
104 let _ = id;
105 Ok(())
106 }
107
108 async fn before_create_verification(
111 &self,
112 verification: &mut CreateVerification,
113 ) -> AuthResult<()> {
114 let _ = verification;
115 Ok(())
116 }
117
118 async fn after_create_verification(&self, verification: &DB::Verification) -> AuthResult<()> {
119 let _ = verification;
120 Ok(())
121 }
122
123 async fn before_delete_verification(&self, id: &str) -> AuthResult<()> {
124 let _ = id;
125 Ok(())
126 }
127
128 async fn after_delete_verification(&self, id: &str) -> AuthResult<()> {
129 let _ = id;
130 Ok(())
131 }
132}
133
134pub struct HookedDatabaseAdapter<DB: DatabaseAdapter> {
136 inner: Arc<DB>,
137 hooks: Vec<Arc<dyn DatabaseHooks<DB>>>,
138}
139
140impl<DB: DatabaseAdapter> HookedDatabaseAdapter<DB> {
141 pub fn new(inner: Arc<DB>) -> Self {
142 Self {
143 inner,
144 hooks: Vec::new(),
145 }
146 }
147
148 pub fn with_hook(mut self, hook: Arc<dyn DatabaseHooks<DB>>) -> Self {
149 self.hooks.push(hook);
150 self
151 }
152
153 pub fn add_hook(&mut self, hook: Arc<dyn DatabaseHooks<DB>>) {
154 self.hooks.push(hook);
155 }
156}
157
// Expands to the body of a hooked "create" operation:
//   1. each hook's `$before` callback, in registration order, with `&mut $input`
//      (a hook may rewrite the payload or abort the whole operation with `Err`);
//   2. `$self.inner.$inner_method($input)`, which consumes the payload;
//   3. each hook's `$after` callback, in registration order, with the created
//      record.
// Errors short-circuit through `?`. An `$after` failure is still returned even
// though step 2 has already completed. `$input` must name a `mut` binding.
macro_rules! hooked_create {
    ($self:ident, $before:ident, $after:ident, $inner_method:ident, $input:ident) => {{
        for h in $self.hooks.iter() {
            h.$before(&mut $input).await?;
        }

        let created = $self.inner.$inner_method($input).await?;

        for h in $self.hooks.iter() {
            h.$after(&created).await?;
        }
        Ok(created)
    }};
}
176
// Expands to the body of a hooked "update" operation:
//   1. each hook's `$before` callback, in registration order, with the key and
//      `&mut $input` (a hook may rewrite the payload or abort with `Err`);
//   2. `$self.inner.$inner_method(key, $input)`, which consumes the payload;
//   3. each hook's `$after` callback, in registration order, with the updated
//      record.
// Errors short-circuit through `?`. An `$after` failure is still returned even
// though step 2 has already completed.
//
// `$id` is evaluated exactly once and the resulting key is reused for every
// call. Previously the fragment was re-expanded per hook and again for the
// inner call, which repeated any side effects or cost of the expression. The
// key must therefore be `Copy` (e.g. `&str`). `$input` must name a `mut`
// binding.
macro_rules! hooked_update {
    ($self:ident, $before:ident, $after:ident, $inner_method:ident, $id:expr, $input:ident) => {{
        let key = $id;
        for hook in &$self.hooks {
            hook.$before(key, &mut $input).await?;
        }
        let result = $self.inner.$inner_method(key, $input).await?;
        for hook in &$self.hooks {
            hook.$after(&result).await?;
        }
        Ok(result)
    }};
}
190
// Expands to the body of a hooked "delete" operation:
//   1. each hook's `$before` callback, in registration order, with the key
//      (an `Err` aborts before anything is deleted);
//   2. `$self.inner.$inner_method(key)`;
//   3. each hook's `$after` callback, in registration order, with the same key.
// Errors short-circuit through `?`. An `$after` failure is still returned even
// though step 2 has already completed.
//
// `$key` is evaluated exactly once and reused. Previously the fragment was
// re-expanded for every hook in both loops and again for the inner call, which
// repeated any side effects or cost of the expression. The key must therefore
// be `Copy` (e.g. `&str`).
macro_rules! hooked_delete {
    ($self:ident, $before:ident, $after:ident, $inner_method:ident, $key:expr) => {{
        let key = $key;
        for hook in &$self.hooks {
            hook.$before(key).await?;
        }
        $self.inner.$inner_method(key).await?;
        for hook in &$self.hooks {
            hook.$after(key).await?;
        }
        Ok(())
    }};
}
205
206#[async_trait]
207impl<DB: DatabaseAdapter> UserOps for HookedDatabaseAdapter<DB> {
208 type User = DB::User;
209
210 async fn create_user(&self, mut user: CreateUser) -> AuthResult<Self::User> {
211 hooked_create!(
212 self,
213 before_create_user,
214 after_create_user,
215 create_user,
216 user
217 )
218 }
219
220 async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<Self::User>> {
221 self.inner.get_user_by_id(id).await
222 }
223
224 async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<Self::User>> {
225 self.inner.get_user_by_email(email).await
226 }
227
228 async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<Self::User>> {
229 self.inner.get_user_by_username(username).await
230 }
231
232 async fn update_user(&self, id: &str, mut update: UpdateUser) -> AuthResult<Self::User> {
233 hooked_update!(
234 self,
235 before_update_user,
236 after_update_user,
237 update_user,
238 id,
239 update
240 )
241 }
242
243 async fn delete_user(&self, id: &str) -> AuthResult<()> {
244 hooked_delete!(self, before_delete_user, after_delete_user, delete_user, id)
245 }
246
247 async fn list_users(&self, params: ListUsersParams) -> AuthResult<(Vec<Self::User>, usize)> {
248 self.inner.list_users(params).await
249 }
250}
251
252#[async_trait]
253impl<DB: DatabaseAdapter> SessionOps for HookedDatabaseAdapter<DB> {
254 type Session = DB::Session;
255
256 async fn create_session(&self, mut session: CreateSession) -> AuthResult<Self::Session> {
257 hooked_create!(
258 self,
259 before_create_session,
260 after_create_session,
261 create_session,
262 session
263 )
264 }
265
266 async fn get_session(&self, token: &str) -> AuthResult<Option<Self::Session>> {
267 self.inner.get_session(token).await
268 }
269
270 async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<Self::Session>> {
271 self.inner.get_user_sessions(user_id).await
272 }
273
274 async fn update_session_expiry(
275 &self,
276 token: &str,
277 expires_at: DateTime<Utc>,
278 ) -> AuthResult<()> {
279 self.inner.update_session_expiry(token, expires_at).await
280 }
281
282 async fn delete_session(&self, token: &str) -> AuthResult<()> {
283 hooked_delete!(
284 self,
285 before_delete_session,
286 after_delete_session,
287 delete_session,
288 token
289 )
290 }
291
292 async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
293 self.inner.delete_user_sessions(user_id).await
294 }
295
296 async fn delete_expired_sessions(&self) -> AuthResult<usize> {
297 self.inner.delete_expired_sessions().await
298 }
299
300 async fn update_session_active_organization(
301 &self,
302 token: &str,
303 organization_id: Option<&str>,
304 ) -> AuthResult<Self::Session> {
305 self.inner
306 .update_session_active_organization(token, organization_id)
307 .await
308 }
309}
310
311#[async_trait]
312impl<DB: DatabaseAdapter> AccountOps for HookedDatabaseAdapter<DB> {
313 type Account = DB::Account;
314
315 async fn create_account(&self, mut account: CreateAccount) -> AuthResult<Self::Account> {
316 hooked_create!(
317 self,
318 before_create_account,
319 after_create_account,
320 create_account,
321 account
322 )
323 }
324
325 async fn get_account(
326 &self,
327 provider: &str,
328 provider_account_id: &str,
329 ) -> AuthResult<Option<Self::Account>> {
330 self.inner.get_account(provider, provider_account_id).await
331 }
332
333 async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<Self::Account>> {
334 self.inner.get_user_accounts(user_id).await
335 }
336
337 async fn update_account(
338 &self,
339 id: &str,
340 mut update: UpdateAccount,
341 ) -> AuthResult<Self::Account> {
342 hooked_update!(
343 self,
344 before_update_account,
345 after_update_account,
346 update_account,
347 id,
348 update
349 )
350 }
351
352 async fn delete_account(&self, id: &str) -> AuthResult<()> {
353 hooked_delete!(
354 self,
355 before_delete_account,
356 after_delete_account,
357 delete_account,
358 id
359 )
360 }
361}
362
363#[async_trait]
364impl<DB: DatabaseAdapter> VerificationOps for HookedDatabaseAdapter<DB> {
365 type Verification = DB::Verification;
366
367 async fn create_verification(
368 &self,
369 mut verification: CreateVerification,
370 ) -> AuthResult<Self::Verification> {
371 hooked_create!(
372 self,
373 before_create_verification,
374 after_create_verification,
375 create_verification,
376 verification
377 )
378 }
379
380 async fn get_verification(
381 &self,
382 identifier: &str,
383 value: &str,
384 ) -> AuthResult<Option<Self::Verification>> {
385 self.inner.get_verification(identifier, value).await
386 }
387
388 async fn get_verification_by_value(
389 &self,
390 value: &str,
391 ) -> AuthResult<Option<Self::Verification>> {
392 self.inner.get_verification_by_value(value).await
393 }
394
395 async fn get_verification_by_identifier(
396 &self,
397 identifier: &str,
398 ) -> AuthResult<Option<Self::Verification>> {
399 self.inner.get_verification_by_identifier(identifier).await
400 }
401
402 async fn consume_verification(
403 &self,
404 identifier: &str,
405 value: &str,
406 ) -> AuthResult<Option<Self::Verification>> {
407 self.inner.consume_verification(identifier, value).await
408 }
409
410 async fn delete_verification(&self, id: &str) -> AuthResult<()> {
411 hooked_delete!(
412 self,
413 before_delete_verification,
414 after_delete_verification,
415 delete_verification,
416 id
417 )
418 }
419
420 async fn delete_expired_verifications(&self) -> AuthResult<usize> {
421 self.inner.delete_expired_verifications().await
422 }
423}
424
425#[async_trait]
426impl<DB: DatabaseAdapter> OrganizationOps for HookedDatabaseAdapter<DB> {
427 type Organization = DB::Organization;
428
429 async fn create_organization(&self, org: CreateOrganization) -> AuthResult<Self::Organization> {
430 self.inner.create_organization(org).await
431 }
432
433 async fn get_organization_by_id(&self, id: &str) -> AuthResult<Option<Self::Organization>> {
434 self.inner.get_organization_by_id(id).await
435 }
436
437 async fn get_organization_by_slug(&self, slug: &str) -> AuthResult<Option<Self::Organization>> {
438 self.inner.get_organization_by_slug(slug).await
439 }
440
441 async fn update_organization(
442 &self,
443 id: &str,
444 update: UpdateOrganization,
445 ) -> AuthResult<Self::Organization> {
446 self.inner.update_organization(id, update).await
447 }
448
449 async fn delete_organization(&self, id: &str) -> AuthResult<()> {
450 self.inner.delete_organization(id).await
451 }
452
453 async fn list_user_organizations(&self, user_id: &str) -> AuthResult<Vec<Self::Organization>> {
454 self.inner.list_user_organizations(user_id).await
455 }
456}
457
458#[async_trait]
459impl<DB: DatabaseAdapter> MemberOps for HookedDatabaseAdapter<DB> {
460 type Member = DB::Member;
461
462 async fn create_member(&self, member: CreateMember) -> AuthResult<Self::Member> {
463 self.inner.create_member(member).await
464 }
465
466 async fn get_member(
467 &self,
468 organization_id: &str,
469 user_id: &str,
470 ) -> AuthResult<Option<Self::Member>> {
471 self.inner.get_member(organization_id, user_id).await
472 }
473
474 async fn get_member_by_id(&self, id: &str) -> AuthResult<Option<Self::Member>> {
475 self.inner.get_member_by_id(id).await
476 }
477
478 async fn update_member_role(&self, member_id: &str, role: &str) -> AuthResult<Self::Member> {
479 self.inner.update_member_role(member_id, role).await
480 }
481
482 async fn delete_member(&self, member_id: &str) -> AuthResult<()> {
483 self.inner.delete_member(member_id).await
484 }
485
486 async fn list_organization_members(
487 &self,
488 organization_id: &str,
489 ) -> AuthResult<Vec<Self::Member>> {
490 self.inner.list_organization_members(organization_id).await
491 }
492
493 async fn count_organization_members(&self, organization_id: &str) -> AuthResult<usize> {
494 self.inner.count_organization_members(organization_id).await
495 }
496
497 async fn count_organization_owners(&self, organization_id: &str) -> AuthResult<usize> {
498 self.inner.count_organization_owners(organization_id).await
499 }
500}
501
502#[async_trait]
503impl<DB: DatabaseAdapter> InvitationOps for HookedDatabaseAdapter<DB> {
504 type Invitation = DB::Invitation;
505
506 async fn create_invitation(
507 &self,
508 invitation: CreateInvitation,
509 ) -> AuthResult<Self::Invitation> {
510 self.inner.create_invitation(invitation).await
511 }
512
513 async fn get_invitation_by_id(&self, id: &str) -> AuthResult<Option<Self::Invitation>> {
514 self.inner.get_invitation_by_id(id).await
515 }
516
517 async fn get_pending_invitation(
518 &self,
519 organization_id: &str,
520 email: &str,
521 ) -> AuthResult<Option<Self::Invitation>> {
522 self.inner
523 .get_pending_invitation(organization_id, email)
524 .await
525 }
526
527 async fn update_invitation_status(
528 &self,
529 id: &str,
530 status: InvitationStatus,
531 ) -> AuthResult<Self::Invitation> {
532 self.inner.update_invitation_status(id, status).await
533 }
534
535 async fn list_organization_invitations(
536 &self,
537 organization_id: &str,
538 ) -> AuthResult<Vec<Self::Invitation>> {
539 self.inner
540 .list_organization_invitations(organization_id)
541 .await
542 }
543
544 async fn list_user_invitations(&self, email: &str) -> AuthResult<Vec<Self::Invitation>> {
545 self.inner.list_user_invitations(email).await
546 }
547}
548
549#[async_trait]
550impl<DB: DatabaseAdapter> TwoFactorOps for HookedDatabaseAdapter<DB> {
551 type TwoFactor = DB::TwoFactor;
552
553 async fn create_two_factor(&self, two_factor: CreateTwoFactor) -> AuthResult<Self::TwoFactor> {
554 self.inner.create_two_factor(two_factor).await
555 }
556
557 async fn get_two_factor_by_user_id(
558 &self,
559 user_id: &str,
560 ) -> AuthResult<Option<Self::TwoFactor>> {
561 self.inner.get_two_factor_by_user_id(user_id).await
562 }
563
564 async fn update_two_factor_backup_codes(
565 &self,
566 user_id: &str,
567 backup_codes: &str,
568 ) -> AuthResult<Self::TwoFactor> {
569 self.inner
570 .update_two_factor_backup_codes(user_id, backup_codes)
571 .await
572 }
573
574 async fn delete_two_factor(&self, user_id: &str) -> AuthResult<()> {
575 self.inner.delete_two_factor(user_id).await
576 }
577}
578
579#[async_trait]
580impl<DB: DatabaseAdapter> ApiKeyOps for HookedDatabaseAdapter<DB> {
581 type ApiKey = DB::ApiKey;
582
583 async fn create_api_key(&self, input: CreateApiKey) -> AuthResult<Self::ApiKey> {
584 self.inner.create_api_key(input).await
585 }
586
587 async fn get_api_key_by_id(&self, id: &str) -> AuthResult<Option<Self::ApiKey>> {
588 self.inner.get_api_key_by_id(id).await
589 }
590
591 async fn get_api_key_by_hash(&self, hash: &str) -> AuthResult<Option<Self::ApiKey>> {
592 self.inner.get_api_key_by_hash(hash).await
593 }
594
595 async fn list_api_keys_by_user(&self, user_id: &str) -> AuthResult<Vec<Self::ApiKey>> {
596 self.inner.list_api_keys_by_user(user_id).await
597 }
598
599 async fn update_api_key(&self, id: &str, update: UpdateApiKey) -> AuthResult<Self::ApiKey> {
600 self.inner.update_api_key(id, update).await
601 }
602
603 async fn delete_api_key(&self, id: &str) -> AuthResult<()> {
604 self.inner.delete_api_key(id).await
605 }
606
607 async fn delete_expired_api_keys(&self) -> AuthResult<usize> {
608 self.inner.delete_expired_api_keys().await
609 }
610}
611
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::MemoryDatabaseAdapter;
    use crate::types::{
        Account, CreateAccount, CreateUser, CreateVerification, UpdateAccount, UpdateUser, User,
        Verification,
    };
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Hook that never vetoes and only records how often each callback fired.
    ///
    /// `AtomicU32::default()` is zero, so the derived `Default` yields a hook
    /// with every counter cleared.
    #[derive(Default)]
    struct CountingHook {
        before_create_user_count: AtomicU32,
        after_create_user_count: AtomicU32,
        before_update_user_count: AtomicU32,
        after_update_user_count: AtomicU32,
        before_delete_user_count: AtomicU32,
        after_delete_user_count: AtomicU32,
        before_create_account_count: AtomicU32,
        after_create_account_count: AtomicU32,
        before_update_account_count: AtomicU32,
        after_update_account_count: AtomicU32,
        before_delete_account_count: AtomicU32,
        after_delete_account_count: AtomicU32,
        before_create_verification_count: AtomicU32,
        after_create_verification_count: AtomicU32,
        before_delete_verification_count: AtomicU32,
        after_delete_verification_count: AtomicU32,
    }

    impl CountingHook {
        fn new() -> Self {
            Self::default()
        }
    }

    /// Records one invocation on `counter` and lets the operation proceed.
    fn bump(counter: &AtomicU32) -> AuthResult<()> {
        counter.fetch_add(1, Ordering::SeqCst);
        Ok(())
    }

    /// Reads `counter` with the same ordering the hooks use to write it.
    fn hits(counter: &AtomicU32) -> u32 {
        counter.load(Ordering::SeqCst)
    }

    #[async_trait]
    impl DatabaseHooks<MemoryDatabaseAdapter> for CountingHook {
        async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
            bump(&self.before_create_user_count)
        }
        async fn after_create_user(&self, _user: &User) -> AuthResult<()> {
            bump(&self.after_create_user_count)
        }
        async fn before_update_user(&self, _id: &str, _update: &mut UpdateUser) -> AuthResult<()> {
            bump(&self.before_update_user_count)
        }
        async fn after_update_user(&self, _user: &User) -> AuthResult<()> {
            bump(&self.after_update_user_count)
        }
        async fn before_delete_user(&self, _id: &str) -> AuthResult<()> {
            bump(&self.before_delete_user_count)
        }
        async fn after_delete_user(&self, _id: &str) -> AuthResult<()> {
            bump(&self.after_delete_user_count)
        }
        async fn before_create_account(&self, _account: &mut CreateAccount) -> AuthResult<()> {
            bump(&self.before_create_account_count)
        }
        async fn after_create_account(&self, _account: &Account) -> AuthResult<()> {
            bump(&self.after_create_account_count)
        }
        async fn before_update_account(
            &self,
            _id: &str,
            _update: &mut UpdateAccount,
        ) -> AuthResult<()> {
            bump(&self.before_update_account_count)
        }
        async fn after_update_account(&self, _account: &Account) -> AuthResult<()> {
            bump(&self.after_update_account_count)
        }
        async fn before_delete_account(&self, _id: &str) -> AuthResult<()> {
            bump(&self.before_delete_account_count)
        }
        async fn after_delete_account(&self, _id: &str) -> AuthResult<()> {
            bump(&self.after_delete_account_count)
        }
        async fn before_create_verification(&self, _v: &mut CreateVerification) -> AuthResult<()> {
            bump(&self.before_create_verification_count)
        }
        async fn after_create_verification(&self, _v: &Verification) -> AuthResult<()> {
            bump(&self.after_create_verification_count)
        }
        async fn before_delete_verification(&self, _id: &str) -> AuthResult<()> {
            bump(&self.before_delete_verification_count)
        }
        async fn after_delete_verification(&self, _id: &str) -> AuthResult<()> {
            bump(&self.after_delete_verification_count)
        }
    }

    // ---------------------------------------------------------------- fixtures

    /// Hook-less adapter over a fresh in-memory database.
    fn bare_db() -> HookedDatabaseAdapter<MemoryDatabaseAdapter> {
        HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
    }

    /// Fresh in-memory adapter with exactly one hook registered.
    fn hooked_db(
        hook: Arc<dyn DatabaseHooks<MemoryDatabaseAdapter>>,
    ) -> HookedDatabaseAdapter<MemoryDatabaseAdapter> {
        bare_db().with_hook(hook)
    }

    /// User input with both an email and a display name.
    fn named_user() -> CreateUser {
        CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test")
    }

    /// User input carrying only an email address.
    fn email_only_user() -> CreateUser {
        CreateUser::new().with_email("test@example.com")
    }

    /// A "google" account linked to `user_id`, with every optional field empty.
    fn google_account(user_id: &str) -> CreateAccount {
        CreateAccount {
            user_id: user_id.to_string(),
            account_id: "provider_123".to_string(),
            provider_id: "google".to_string(),
            access_token: None,
            refresh_token: None,
            id_token: None,
            access_token_expires_at: None,
            refresh_token_expires_at: None,
            scope: None,
            password: None,
        }
    }

    /// Verification input that expires one hour from now.
    fn hour_long_verification(identifier: &str, value: &str) -> CreateVerification {
        CreateVerification {
            identifier: identifier.to_string(),
            value: value.to_string(),
            expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
        }
    }

    // ------------------------------------------------------------------ users

    #[tokio::test]
    async fn test_hooks_called_on_create_user() {
        let hook = Arc::new(CountingHook::new());
        let db = hooked_db(hook.clone());

        db.create_user(named_user()).await.unwrap();

        assert_eq!(hits(&hook.before_create_user_count), 1);
        assert_eq!(hits(&hook.after_create_user_count), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_update_user() {
        let hook = Arc::new(CountingHook::new());
        let db = hooked_db(hook.clone());
        let user = db.create_user(named_user()).await.unwrap();

        let rename = UpdateUser {
            name: Some("Updated".to_string()),
            email: None,
            image: None,
            email_verified: None,
            username: None,
            display_username: None,
            role: None,
            banned: None,
            ban_reason: None,
            ban_expires: None,
            two_factor_enabled: None,
            metadata: None,
        };
        db.update_user(&user.id, rename).await.unwrap();

        assert_eq!(hits(&hook.before_update_user_count), 1);
        assert_eq!(hits(&hook.after_update_user_count), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_delete_user() {
        let hook = Arc::new(CountingHook::new());
        let db = hooked_db(hook.clone());
        let user = db.create_user(named_user()).await.unwrap();

        db.delete_user(&user.id).await.unwrap();

        assert_eq!(hits(&hook.before_delete_user_count), 1);
        assert_eq!(hits(&hook.after_delete_user_count), 1);
    }

    #[tokio::test]
    async fn test_before_hook_can_reject() {
        struct RejectHook;

        #[async_trait]
        impl DatabaseHooks<MemoryDatabaseAdapter> for RejectHook {
            async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
                Err(crate::error::AuthError::forbidden("Hook rejected"))
            }
        }

        let db = hooked_db(Arc::new(RejectHook));
        let outcome = db.create_user(named_user()).await;

        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().status_code(), 403);
    }

    #[tokio::test]
    async fn test_multiple_hooks() {
        let first = Arc::new(CountingHook::new());
        let second = Arc::new(CountingHook::new());
        let db = hooked_db(first.clone()).with_hook(second.clone());

        db.create_user(named_user()).await.unwrap();

        for hook in [&first, &second] {
            assert_eq!(hits(&hook.before_create_user_count), 1);
            assert_eq!(hits(&hook.after_create_user_count), 1);
        }
    }

    #[tokio::test]
    async fn test_passthrough_operations() {
        let db = bare_db();

        let found = db.get_user_by_email("nonexistent@test.com").await.unwrap();
        assert!(found.is_none());
    }

    // --------------------------------------------------------------- accounts

    #[tokio::test]
    async fn test_account_hooks_create() {
        let hook = Arc::new(CountingHook::new());
        let db = hooked_db(hook.clone());
        let owner = db.create_user(email_only_user()).await.unwrap();

        db.create_account(google_account(&owner.id)).await.unwrap();

        assert_eq!(hits(&hook.before_create_account_count), 1);
        assert_eq!(hits(&hook.after_create_account_count), 1);
    }

    #[tokio::test]
    async fn test_account_hooks_update() {
        let hook = Arc::new(CountingHook::new());
        let db = hooked_db(hook.clone());
        let owner = db.create_user(email_only_user()).await.unwrap();
        let account = db.create_account(google_account(&owner.id)).await.unwrap();

        let new_token = UpdateAccount {
            access_token: Some("new_tok".to_string()),
            ..Default::default()
        };
        db.update_account(&account.id, new_token).await.unwrap();

        assert_eq!(hits(&hook.before_update_account_count), 1);
        assert_eq!(hits(&hook.after_update_account_count), 1);
    }

    #[tokio::test]
    async fn test_account_hooks_delete() {
        let hook = Arc::new(CountingHook::new());
        let db = hooked_db(hook.clone());
        let owner = db.create_user(email_only_user()).await.unwrap();
        let account = db.create_account(google_account(&owner.id)).await.unwrap();

        db.delete_account(&account.id).await.unwrap();

        assert_eq!(hits(&hook.before_delete_account_count), 1);
        assert_eq!(hits(&hook.after_delete_account_count), 1);
    }

    #[tokio::test]
    async fn test_account_before_hook_can_reject() {
        struct RejectAccountHook;

        #[async_trait]
        impl DatabaseHooks<MemoryDatabaseAdapter> for RejectAccountHook {
            async fn before_create_account(&self, _account: &mut CreateAccount) -> AuthResult<()> {
                Err(crate::error::AuthError::forbidden("Account hook rejected"))
            }
        }

        let db = hooked_db(Arc::new(RejectAccountHook));
        let owner = db.create_user(email_only_user()).await.unwrap();

        let outcome = db.create_account(google_account(&owner.id)).await;

        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().status_code(), 403);
    }

    // ---------------------------------------------------------- verifications

    #[tokio::test]
    async fn test_verification_hooks_create() {
        let hook = Arc::new(CountingHook::new());
        let db = hooked_db(hook.clone());

        let input = hour_long_verification("email_verification:test@example.com", "token_abc");
        db.create_verification(input).await.unwrap();

        assert_eq!(hits(&hook.before_create_verification_count), 1);
        assert_eq!(hits(&hook.after_create_verification_count), 1);
    }

    #[tokio::test]
    async fn test_verification_hooks_delete() {
        let hook = Arc::new(CountingHook::new());
        let db = hooked_db(hook.clone());

        let input = hour_long_verification("email_verification:test@example.com", "token_abc");
        let stored = db.create_verification(input).await.unwrap();
        db.delete_verification(&stored.id).await.unwrap();

        assert_eq!(hits(&hook.before_delete_verification_count), 1);
        assert_eq!(hits(&hook.after_delete_verification_count), 1);
    }

    #[tokio::test]
    async fn test_verification_before_hook_can_reject() {
        struct RejectVerificationHook;

        #[async_trait]
        impl DatabaseHooks<MemoryDatabaseAdapter> for RejectVerificationHook {
            async fn before_create_verification(
                &self,
                _v: &mut CreateVerification,
            ) -> AuthResult<()> {
                Err(crate::error::AuthError::forbidden(
                    "Verification hook rejected",
                ))
            }
        }

        let db = hooked_db(Arc::new(RejectVerificationHook));
        let outcome = db
            .create_verification(hour_long_verification("test", "val"))
            .await;

        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().status_code(), 403);
    }
}