1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use std::sync::Arc;
4
5use crate::adapters::DatabaseAdapter;
6use crate::adapters::database::{
7 AccountOps, ApiKeyOps, InvitationOps, MemberOps, OrganizationOps, SessionOps, TwoFactorOps,
8 UserOps, VerificationOps,
9};
10use crate::error::AuthResult;
11use crate::types::{
12 CreateAccount, CreateApiKey, CreateInvitation, CreateMember, CreateOrganization, CreateSession,
13 CreateTwoFactor, CreateUser, CreateVerification, InvitationStatus, ListUsersParams,
14 UpdateAccount, UpdateApiKey, UpdateOrganization, UpdateUser,
15};
16
/// Callbacks that [`HookedDatabaseAdapter`] invokes around user and session
/// writes.
///
/// Every method has a default implementation that ignores its arguments and
/// returns `Ok(())`, so an implementor overrides only the events it cares
/// about.
///
/// How `HookedDatabaseAdapter` (in this file) drives these hooks:
///
/// * `before_*` hooks run in registration order before the inner adapter is
///   called. `before_create_user`, `before_update_user` and
///   `before_create_session` receive the input as `&mut` and may rewrite it.
///   Returning `Err` aborts the call: the inner adapter is not invoked and
///   hooks registered later do not run.
/// * `after_*` hooks run in registration order only once the inner adapter has
///   returned `Ok`. An `Err` from an `after_*` hook is returned to the caller
///   and skips the remaining `after_*` hooks, but by then the inner adapter
///   has already performed the write.
/// * Only `create_user`, `update_user`, `delete_user`, `create_session` and
///   `delete_session` fire hooks; in particular `delete_user_sessions` and
///   `delete_expired_sessions` do not call the session-deletion hooks.
#[async_trait]
pub trait DatabaseHooks<DB: DatabaseAdapter>: Send + Sync {
    /// Runs before a user is created; may modify `user` or veto with `Err`.
    async fn before_create_user(&self, user: &mut CreateUser) -> AuthResult<()> {
        // `let _ =` silences the unused-argument lint without underscore-prefixing
        // the parameter name in the public signature.
        let _ = user;
        Ok(())
    }

    /// Runs after a user was created, with the value returned by the inner adapter.
    async fn after_create_user(&self, user: &DB::User) -> AuthResult<()> {
        let _ = user;
        Ok(())
    }

    /// Runs before the user `id` is updated; may modify `update` or veto with `Err`.
    async fn before_update_user(&self, id: &str, update: &mut UpdateUser) -> AuthResult<()> {
        let _ = (id, update);
        Ok(())
    }

    /// Runs after a user was updated, with the value returned by the inner adapter.
    async fn after_update_user(&self, user: &DB::User) -> AuthResult<()> {
        let _ = user;
        Ok(())
    }

    /// Runs before the user `id` is deleted; may veto with `Err`.
    async fn before_delete_user(&self, id: &str) -> AuthResult<()> {
        let _ = id;
        Ok(())
    }

    /// Runs after the inner adapter reported the user `id` as deleted.
    async fn after_delete_user(&self, id: &str) -> AuthResult<()> {
        let _ = id;
        Ok(())
    }

    /// Runs before a session is created; may modify `session` or veto with `Err`.
    async fn before_create_session(&self, session: &mut CreateSession) -> AuthResult<()> {
        let _ = session;
        Ok(())
    }

    /// Runs after a session was created, with the value returned by the inner adapter.
    async fn after_create_session(&self, session: &DB::Session) -> AuthResult<()> {
        let _ = session;
        Ok(())
    }

    /// Runs before the session `token` is deleted through `delete_session`; may veto.
    async fn before_delete_session(&self, token: &str) -> AuthResult<()> {
        let _ = token;
        Ok(())
    }

    /// Runs after the inner adapter reported the session `token` as deleted.
    async fn after_delete_session(&self, token: &str) -> AuthResult<()> {
        let _ = token;
        Ok(())
    }
}
76
77pub struct HookedDatabaseAdapter<DB: DatabaseAdapter> {
79 inner: Arc<DB>,
80 hooks: Vec<Arc<dyn DatabaseHooks<DB>>>,
81}
82
83impl<DB: DatabaseAdapter> HookedDatabaseAdapter<DB> {
84 pub fn new(inner: Arc<DB>) -> Self {
85 Self {
86 inner,
87 hooks: Vec::new(),
88 }
89 }
90
91 pub fn with_hook(mut self, hook: Arc<dyn DatabaseHooks<DB>>) -> Self {
92 self.hooks.push(hook);
93 self
94 }
95
96 pub fn add_hook(&mut self, hook: Arc<dyn DatabaseHooks<DB>>) {
97 self.hooks.push(hook);
98 }
99}
100
101#[async_trait]
102impl<DB: DatabaseAdapter> UserOps for HookedDatabaseAdapter<DB> {
103 type User = DB::User;
104
105 async fn create_user(&self, mut user: CreateUser) -> AuthResult<Self::User> {
106 for hook in &self.hooks {
107 hook.before_create_user(&mut user).await?;
108 }
109 let result = self.inner.create_user(user).await?;
110 for hook in &self.hooks {
111 hook.after_create_user(&result).await?;
112 }
113 Ok(result)
114 }
115
116 async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<Self::User>> {
117 self.inner.get_user_by_id(id).await
118 }
119
120 async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<Self::User>> {
121 self.inner.get_user_by_email(email).await
122 }
123
124 async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<Self::User>> {
125 self.inner.get_user_by_username(username).await
126 }
127
128 async fn update_user(&self, id: &str, mut update: UpdateUser) -> AuthResult<Self::User> {
129 for hook in &self.hooks {
130 hook.before_update_user(id, &mut update).await?;
131 }
132 let result = self.inner.update_user(id, update).await?;
133 for hook in &self.hooks {
134 hook.after_update_user(&result).await?;
135 }
136 Ok(result)
137 }
138
139 async fn delete_user(&self, id: &str) -> AuthResult<()> {
140 for hook in &self.hooks {
141 hook.before_delete_user(id).await?;
142 }
143 self.inner.delete_user(id).await?;
144 for hook in &self.hooks {
145 hook.after_delete_user(id).await?;
146 }
147 Ok(())
148 }
149
150 async fn list_users(&self, params: ListUsersParams) -> AuthResult<(Vec<Self::User>, usize)> {
151 self.inner.list_users(params).await
152 }
153}
154
155#[async_trait]
156impl<DB: DatabaseAdapter> SessionOps for HookedDatabaseAdapter<DB> {
157 type Session = DB::Session;
158
159 async fn create_session(&self, mut session: CreateSession) -> AuthResult<Self::Session> {
160 for hook in &self.hooks {
161 hook.before_create_session(&mut session).await?;
162 }
163 let result = self.inner.create_session(session).await?;
164 for hook in &self.hooks {
165 hook.after_create_session(&result).await?;
166 }
167 Ok(result)
168 }
169
170 async fn get_session(&self, token: &str) -> AuthResult<Option<Self::Session>> {
171 self.inner.get_session(token).await
172 }
173
174 async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<Self::Session>> {
175 self.inner.get_user_sessions(user_id).await
176 }
177
178 async fn update_session_expiry(
179 &self,
180 token: &str,
181 expires_at: DateTime<Utc>,
182 ) -> AuthResult<()> {
183 self.inner.update_session_expiry(token, expires_at).await
184 }
185
186 async fn delete_session(&self, token: &str) -> AuthResult<()> {
187 for hook in &self.hooks {
188 hook.before_delete_session(token).await?;
189 }
190 self.inner.delete_session(token).await?;
191 for hook in &self.hooks {
192 hook.after_delete_session(token).await?;
193 }
194 Ok(())
195 }
196
197 async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
198 self.inner.delete_user_sessions(user_id).await
199 }
200
201 async fn delete_expired_sessions(&self) -> AuthResult<usize> {
202 self.inner.delete_expired_sessions().await
203 }
204
205 async fn update_session_active_organization(
206 &self,
207 token: &str,
208 organization_id: Option<&str>,
209 ) -> AuthResult<Self::Session> {
210 self.inner
211 .update_session_active_organization(token, organization_id)
212 .await
213 }
214}
215
216#[async_trait]
217impl<DB: DatabaseAdapter> AccountOps for HookedDatabaseAdapter<DB> {
218 type Account = DB::Account;
219
220 async fn create_account(&self, account: CreateAccount) -> AuthResult<Self::Account> {
221 self.inner.create_account(account).await
222 }
223
224 async fn get_account(
225 &self,
226 provider: &str,
227 provider_account_id: &str,
228 ) -> AuthResult<Option<Self::Account>> {
229 self.inner.get_account(provider, provider_account_id).await
230 }
231
232 async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<Self::Account>> {
233 self.inner.get_user_accounts(user_id).await
234 }
235
236 async fn update_account(&self, id: &str, update: UpdateAccount) -> AuthResult<Self::Account> {
237 self.inner.update_account(id, update).await
238 }
239
240 async fn delete_account(&self, id: &str) -> AuthResult<()> {
241 self.inner.delete_account(id).await
242 }
243}
244
245#[async_trait]
246impl<DB: DatabaseAdapter> VerificationOps for HookedDatabaseAdapter<DB> {
247 type Verification = DB::Verification;
248
249 async fn create_verification(
250 &self,
251 verification: CreateVerification,
252 ) -> AuthResult<Self::Verification> {
253 self.inner.create_verification(verification).await
254 }
255
256 async fn get_verification(
257 &self,
258 identifier: &str,
259 value: &str,
260 ) -> AuthResult<Option<Self::Verification>> {
261 self.inner.get_verification(identifier, value).await
262 }
263
264 async fn get_verification_by_value(
265 &self,
266 value: &str,
267 ) -> AuthResult<Option<Self::Verification>> {
268 self.inner.get_verification_by_value(value).await
269 }
270
271 async fn get_verification_by_identifier(
272 &self,
273 identifier: &str,
274 ) -> AuthResult<Option<Self::Verification>> {
275 self.inner.get_verification_by_identifier(identifier).await
276 }
277
278 async fn consume_verification(
279 &self,
280 identifier: &str,
281 value: &str,
282 ) -> AuthResult<Option<Self::Verification>> {
283 self.inner.consume_verification(identifier, value).await
284 }
285
286 async fn delete_verification(&self, id: &str) -> AuthResult<()> {
287 self.inner.delete_verification(id).await
288 }
289
290 async fn delete_expired_verifications(&self) -> AuthResult<usize> {
291 self.inner.delete_expired_verifications().await
292 }
293}
294
295#[async_trait]
296impl<DB: DatabaseAdapter> OrganizationOps for HookedDatabaseAdapter<DB> {
297 type Organization = DB::Organization;
298
299 async fn create_organization(&self, org: CreateOrganization) -> AuthResult<Self::Organization> {
300 self.inner.create_organization(org).await
301 }
302
303 async fn get_organization_by_id(&self, id: &str) -> AuthResult<Option<Self::Organization>> {
304 self.inner.get_organization_by_id(id).await
305 }
306
307 async fn get_organization_by_slug(&self, slug: &str) -> AuthResult<Option<Self::Organization>> {
308 self.inner.get_organization_by_slug(slug).await
309 }
310
311 async fn update_organization(
312 &self,
313 id: &str,
314 update: UpdateOrganization,
315 ) -> AuthResult<Self::Organization> {
316 self.inner.update_organization(id, update).await
317 }
318
319 async fn delete_organization(&self, id: &str) -> AuthResult<()> {
320 self.inner.delete_organization(id).await
321 }
322
323 async fn list_user_organizations(&self, user_id: &str) -> AuthResult<Vec<Self::Organization>> {
324 self.inner.list_user_organizations(user_id).await
325 }
326}
327
328#[async_trait]
329impl<DB: DatabaseAdapter> MemberOps for HookedDatabaseAdapter<DB> {
330 type Member = DB::Member;
331
332 async fn create_member(&self, member: CreateMember) -> AuthResult<Self::Member> {
333 self.inner.create_member(member).await
334 }
335
336 async fn get_member(
337 &self,
338 organization_id: &str,
339 user_id: &str,
340 ) -> AuthResult<Option<Self::Member>> {
341 self.inner.get_member(organization_id, user_id).await
342 }
343
344 async fn get_member_by_id(&self, id: &str) -> AuthResult<Option<Self::Member>> {
345 self.inner.get_member_by_id(id).await
346 }
347
348 async fn update_member_role(&self, member_id: &str, role: &str) -> AuthResult<Self::Member> {
349 self.inner.update_member_role(member_id, role).await
350 }
351
352 async fn delete_member(&self, member_id: &str) -> AuthResult<()> {
353 self.inner.delete_member(member_id).await
354 }
355
356 async fn list_organization_members(
357 &self,
358 organization_id: &str,
359 ) -> AuthResult<Vec<Self::Member>> {
360 self.inner.list_organization_members(organization_id).await
361 }
362
363 async fn count_organization_members(&self, organization_id: &str) -> AuthResult<usize> {
364 self.inner.count_organization_members(organization_id).await
365 }
366
367 async fn count_organization_owners(&self, organization_id: &str) -> AuthResult<usize> {
368 self.inner.count_organization_owners(organization_id).await
369 }
370}
371
372#[async_trait]
373impl<DB: DatabaseAdapter> InvitationOps for HookedDatabaseAdapter<DB> {
374 type Invitation = DB::Invitation;
375
376 async fn create_invitation(
377 &self,
378 invitation: CreateInvitation,
379 ) -> AuthResult<Self::Invitation> {
380 self.inner.create_invitation(invitation).await
381 }
382
383 async fn get_invitation_by_id(&self, id: &str) -> AuthResult<Option<Self::Invitation>> {
384 self.inner.get_invitation_by_id(id).await
385 }
386
387 async fn get_pending_invitation(
388 &self,
389 organization_id: &str,
390 email: &str,
391 ) -> AuthResult<Option<Self::Invitation>> {
392 self.inner
393 .get_pending_invitation(organization_id, email)
394 .await
395 }
396
397 async fn update_invitation_status(
398 &self,
399 id: &str,
400 status: InvitationStatus,
401 ) -> AuthResult<Self::Invitation> {
402 self.inner.update_invitation_status(id, status).await
403 }
404
405 async fn list_organization_invitations(
406 &self,
407 organization_id: &str,
408 ) -> AuthResult<Vec<Self::Invitation>> {
409 self.inner
410 .list_organization_invitations(organization_id)
411 .await
412 }
413
414 async fn list_user_invitations(&self, email: &str) -> AuthResult<Vec<Self::Invitation>> {
415 self.inner.list_user_invitations(email).await
416 }
417}
418
419#[async_trait]
420impl<DB: DatabaseAdapter> TwoFactorOps for HookedDatabaseAdapter<DB> {
421 type TwoFactor = DB::TwoFactor;
422
423 async fn create_two_factor(&self, two_factor: CreateTwoFactor) -> AuthResult<Self::TwoFactor> {
424 self.inner.create_two_factor(two_factor).await
425 }
426
427 async fn get_two_factor_by_user_id(
428 &self,
429 user_id: &str,
430 ) -> AuthResult<Option<Self::TwoFactor>> {
431 self.inner.get_two_factor_by_user_id(user_id).await
432 }
433
434 async fn update_two_factor_backup_codes(
435 &self,
436 user_id: &str,
437 backup_codes: &str,
438 ) -> AuthResult<Self::TwoFactor> {
439 self.inner
440 .update_two_factor_backup_codes(user_id, backup_codes)
441 .await
442 }
443
444 async fn delete_two_factor(&self, user_id: &str) -> AuthResult<()> {
445 self.inner.delete_two_factor(user_id).await
446 }
447}
448
449#[async_trait]
450impl<DB: DatabaseAdapter> ApiKeyOps for HookedDatabaseAdapter<DB> {
451 type ApiKey = DB::ApiKey;
452
453 async fn create_api_key(&self, input: CreateApiKey) -> AuthResult<Self::ApiKey> {
454 self.inner.create_api_key(input).await
455 }
456
457 async fn get_api_key_by_id(&self, id: &str) -> AuthResult<Option<Self::ApiKey>> {
458 self.inner.get_api_key_by_id(id).await
459 }
460
461 async fn get_api_key_by_hash(&self, hash: &str) -> AuthResult<Option<Self::ApiKey>> {
462 self.inner.get_api_key_by_hash(hash).await
463 }
464
465 async fn list_api_keys_by_user(&self, user_id: &str) -> AuthResult<Vec<Self::ApiKey>> {
466 self.inner.list_api_keys_by_user(user_id).await
467 }
468
469 async fn update_api_key(&self, id: &str, update: UpdateApiKey) -> AuthResult<Self::ApiKey> {
470 self.inner.update_api_key(id, update).await
471 }
472
473 async fn delete_api_key(&self, id: &str) -> AuthResult<()> {
474 self.inner.delete_api_key(id).await
475 }
476}
477
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::MemoryDatabaseAdapter;
    use crate::types::{CreateUser, UpdateUser, User};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Hook that records how many times each user hook fired.
    #[derive(Default)]
    struct CountingHook {
        before_create_count: AtomicU32,
        after_create_count: AtomicU32,
        before_update_count: AtomicU32,
        after_update_count: AtomicU32,
        before_delete_count: AtomicU32,
        after_delete_count: AtomicU32,
    }

    impl CountingHook {
        fn new() -> Self {
            Self::default()
        }
    }

    #[async_trait]
    impl DatabaseHooks<MemoryDatabaseAdapter> for CountingHook {
        async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
            self.before_create_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_create_user(&self, _user: &User) -> AuthResult<()> {
            self.after_create_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn before_update_user(&self, _id: &str, _update: &mut UpdateUser) -> AuthResult<()> {
            self.before_update_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_update_user(&self, _user: &User) -> AuthResult<()> {
            self.after_update_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn before_delete_user(&self, _id: &str) -> AuthResult<()> {
            self.before_delete_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_delete_user(&self, _id: &str) -> AuthResult<()> {
            self.after_delete_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Hook that vetoes every user creation with a 403.
    struct RejectHook;

    #[async_trait]
    impl DatabaseHooks<MemoryDatabaseAdapter> for RejectHook {
        async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
            Err(crate::error::AuthError::forbidden("Hook rejected"))
        }
    }

    /// Hook whose `after_create_user` always fails.
    struct FailAfterCreateHook;

    #[async_trait]
    impl DatabaseHooks<MemoryDatabaseAdapter> for FailAfterCreateHook {
        async fn after_create_user(&self, _user: &User) -> AuthResult<()> {
            Err(crate::error::AuthError::forbidden("after hook failed"))
        }
    }

    /// The user payload shared by every test.
    fn sample_create_user() -> CreateUser {
        CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test")
    }

    #[tokio::test]
    async fn test_hooks_called_on_create_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        db.create_user(sample_create_user()).await.unwrap();

        assert_eq!(hook.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_update_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let user = db.create_user(sample_create_user()).await.unwrap();

        let update = UpdateUser {
            name: Some("Updated".to_string()),
            email: None,
            image: None,
            email_verified: None,
            username: None,
            display_username: None,
            role: None,
            banned: None,
            ban_reason: None,
            ban_expires: None,
            two_factor_enabled: None,
            metadata: None,
        };
        db.update_user(&user.id, update).await.unwrap();

        assert_eq!(hook.before_update_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_update_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_delete_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let user = db.create_user(sample_create_user()).await.unwrap();

        db.delete_user(&user.id).await.unwrap();

        assert_eq!(hook.before_delete_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_delete_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_before_hook_can_reject() {
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(Arc::new(RejectHook));

        let result = db.create_user(sample_create_user()).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().status_code(), 403);
    }

    #[tokio::test]
    async fn test_rejecting_hook_stops_later_hooks_and_write() {
        let counter = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(Arc::new(RejectHook))
            .with_hook(counter.clone());

        assert!(db.create_user(sample_create_user()).await.is_err());

        // The veto short-circuits: neither the later hook nor the store ran.
        assert_eq!(counter.before_create_count.load(Ordering::SeqCst), 0);
        assert_eq!(counter.after_create_count.load(Ordering::SeqCst), 0);
        assert!(db
            .get_user_by_email("test@example.com")
            .await
            .unwrap()
            .is_none());
    }

    #[tokio::test]
    async fn test_after_hook_error_is_returned_but_write_persists() {
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(Arc::new(FailAfterCreateHook));

        assert!(db.create_user(sample_create_user()).await.is_err());

        // After-hooks are not transactional: the inner adapter already stored
        // the user before the failing hook ran.
        assert!(db
            .get_user_by_email("test@example.com")
            .await
            .unwrap()
            .is_some());
    }

    #[tokio::test]
    async fn test_multiple_hooks() {
        let hook1 = Arc::new(CountingHook::new());
        let hook2 = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook1.clone())
            .with_hook(hook2.clone());

        db.create_user(sample_create_user()).await.unwrap();

        assert_eq!(hook1.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook2.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook1.after_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook2.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_passthrough_operations() {
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()));

        let result = db.get_user_by_email("nonexistent@test.com").await.unwrap();
        assert!(result.is_none());
    }
}