1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use std::sync::Arc;
4
5use crate::adapters::DatabaseAdapter;
6use crate::adapters::database::{
7 AccountOps, InvitationOps, MemberOps, OrganizationOps, SessionOps, TwoFactorOps, UserOps,
8 VerificationOps,
9};
10use crate::error::AuthResult;
11use crate::types::{
12 CreateAccount, CreateInvitation, CreateMember, CreateOrganization, CreateSession,
13 CreateTwoFactor, CreateUser, CreateVerification, InvitationStatus, UpdateAccount,
14 UpdateOrganization, UpdateUser,
15};
16
17#[async_trait]
25pub trait DatabaseHooks<DB: DatabaseAdapter>: Send + Sync {
26 async fn before_create_user(&self, user: &mut CreateUser) -> AuthResult<()> {
27 let _ = user;
28 Ok(())
29 }
30
31 async fn after_create_user(&self, user: &DB::User) -> AuthResult<()> {
32 let _ = user;
33 Ok(())
34 }
35
36 async fn before_update_user(&self, id: &str, update: &mut UpdateUser) -> AuthResult<()> {
37 let _ = (id, update);
38 Ok(())
39 }
40
41 async fn after_update_user(&self, user: &DB::User) -> AuthResult<()> {
42 let _ = user;
43 Ok(())
44 }
45
46 async fn before_delete_user(&self, id: &str) -> AuthResult<()> {
47 let _ = id;
48 Ok(())
49 }
50
51 async fn after_delete_user(&self, id: &str) -> AuthResult<()> {
52 let _ = id;
53 Ok(())
54 }
55
56 async fn before_create_session(&self, session: &mut CreateSession) -> AuthResult<()> {
57 let _ = session;
58 Ok(())
59 }
60
61 async fn after_create_session(&self, session: &DB::Session) -> AuthResult<()> {
62 let _ = session;
63 Ok(())
64 }
65
66 async fn before_delete_session(&self, token: &str) -> AuthResult<()> {
67 let _ = token;
68 Ok(())
69 }
70
71 async fn after_delete_session(&self, token: &str) -> AuthResult<()> {
72 let _ = token;
73 Ok(())
74 }
75}
76
77pub struct HookedDatabaseAdapter<DB: DatabaseAdapter> {
79 inner: Arc<DB>,
80 hooks: Vec<Arc<dyn DatabaseHooks<DB>>>,
81}
82
83impl<DB: DatabaseAdapter> HookedDatabaseAdapter<DB> {
84 pub fn new(inner: Arc<DB>) -> Self {
85 Self {
86 inner,
87 hooks: Vec::new(),
88 }
89 }
90
91 pub fn with_hook(mut self, hook: Arc<dyn DatabaseHooks<DB>>) -> Self {
92 self.hooks.push(hook);
93 self
94 }
95
96 pub fn add_hook(&mut self, hook: Arc<dyn DatabaseHooks<DB>>) {
97 self.hooks.push(hook);
98 }
99}
100
101#[async_trait]
102impl<DB: DatabaseAdapter> UserOps for HookedDatabaseAdapter<DB> {
103 type User = DB::User;
104
105 async fn create_user(&self, mut user: CreateUser) -> AuthResult<Self::User> {
106 for hook in &self.hooks {
107 hook.before_create_user(&mut user).await?;
108 }
109 let result = self.inner.create_user(user).await?;
110 for hook in &self.hooks {
111 hook.after_create_user(&result).await?;
112 }
113 Ok(result)
114 }
115
116 async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<Self::User>> {
117 self.inner.get_user_by_id(id).await
118 }
119
120 async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<Self::User>> {
121 self.inner.get_user_by_email(email).await
122 }
123
124 async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<Self::User>> {
125 self.inner.get_user_by_username(username).await
126 }
127
128 async fn update_user(&self, id: &str, mut update: UpdateUser) -> AuthResult<Self::User> {
129 for hook in &self.hooks {
130 hook.before_update_user(id, &mut update).await?;
131 }
132 let result = self.inner.update_user(id, update).await?;
133 for hook in &self.hooks {
134 hook.after_update_user(&result).await?;
135 }
136 Ok(result)
137 }
138
139 async fn delete_user(&self, id: &str) -> AuthResult<()> {
140 for hook in &self.hooks {
141 hook.before_delete_user(id).await?;
142 }
143 self.inner.delete_user(id).await?;
144 for hook in &self.hooks {
145 hook.after_delete_user(id).await?;
146 }
147 Ok(())
148 }
149}
150
151#[async_trait]
152impl<DB: DatabaseAdapter> SessionOps for HookedDatabaseAdapter<DB> {
153 type Session = DB::Session;
154
155 async fn create_session(&self, mut session: CreateSession) -> AuthResult<Self::Session> {
156 for hook in &self.hooks {
157 hook.before_create_session(&mut session).await?;
158 }
159 let result = self.inner.create_session(session).await?;
160 for hook in &self.hooks {
161 hook.after_create_session(&result).await?;
162 }
163 Ok(result)
164 }
165
166 async fn get_session(&self, token: &str) -> AuthResult<Option<Self::Session>> {
167 self.inner.get_session(token).await
168 }
169
170 async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<Self::Session>> {
171 self.inner.get_user_sessions(user_id).await
172 }
173
174 async fn update_session_expiry(
175 &self,
176 token: &str,
177 expires_at: DateTime<Utc>,
178 ) -> AuthResult<()> {
179 self.inner.update_session_expiry(token, expires_at).await
180 }
181
182 async fn delete_session(&self, token: &str) -> AuthResult<()> {
183 for hook in &self.hooks {
184 hook.before_delete_session(token).await?;
185 }
186 self.inner.delete_session(token).await?;
187 for hook in &self.hooks {
188 hook.after_delete_session(token).await?;
189 }
190 Ok(())
191 }
192
193 async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
194 self.inner.delete_user_sessions(user_id).await
195 }
196
197 async fn delete_expired_sessions(&self) -> AuthResult<usize> {
198 self.inner.delete_expired_sessions().await
199 }
200
201 async fn update_session_active_organization(
202 &self,
203 token: &str,
204 organization_id: Option<&str>,
205 ) -> AuthResult<Self::Session> {
206 self.inner
207 .update_session_active_organization(token, organization_id)
208 .await
209 }
210}
211
212#[async_trait]
213impl<DB: DatabaseAdapter> AccountOps for HookedDatabaseAdapter<DB> {
214 type Account = DB::Account;
215
216 async fn create_account(&self, account: CreateAccount) -> AuthResult<Self::Account> {
217 self.inner.create_account(account).await
218 }
219
220 async fn get_account(
221 &self,
222 provider: &str,
223 provider_account_id: &str,
224 ) -> AuthResult<Option<Self::Account>> {
225 self.inner.get_account(provider, provider_account_id).await
226 }
227
228 async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<Self::Account>> {
229 self.inner.get_user_accounts(user_id).await
230 }
231
232 async fn update_account(&self, id: &str, update: UpdateAccount) -> AuthResult<Self::Account> {
233 self.inner.update_account(id, update).await
234 }
235
236 async fn delete_account(&self, id: &str) -> AuthResult<()> {
237 self.inner.delete_account(id).await
238 }
239}
240
241#[async_trait]
242impl<DB: DatabaseAdapter> VerificationOps for HookedDatabaseAdapter<DB> {
243 type Verification = DB::Verification;
244
245 async fn create_verification(
246 &self,
247 verification: CreateVerification,
248 ) -> AuthResult<Self::Verification> {
249 self.inner.create_verification(verification).await
250 }
251
252 async fn get_verification(
253 &self,
254 identifier: &str,
255 value: &str,
256 ) -> AuthResult<Option<Self::Verification>> {
257 self.inner.get_verification(identifier, value).await
258 }
259
260 async fn get_verification_by_value(
261 &self,
262 value: &str,
263 ) -> AuthResult<Option<Self::Verification>> {
264 self.inner.get_verification_by_value(value).await
265 }
266
267 async fn get_verification_by_identifier(
268 &self,
269 identifier: &str,
270 ) -> AuthResult<Option<Self::Verification>> {
271 self.inner.get_verification_by_identifier(identifier).await
272 }
273
274 async fn delete_verification(&self, id: &str) -> AuthResult<()> {
275 self.inner.delete_verification(id).await
276 }
277
278 async fn delete_expired_verifications(&self) -> AuthResult<usize> {
279 self.inner.delete_expired_verifications().await
280 }
281}
282
283#[async_trait]
284impl<DB: DatabaseAdapter> OrganizationOps for HookedDatabaseAdapter<DB> {
285 type Organization = DB::Organization;
286
287 async fn create_organization(&self, org: CreateOrganization) -> AuthResult<Self::Organization> {
288 self.inner.create_organization(org).await
289 }
290
291 async fn get_organization_by_id(&self, id: &str) -> AuthResult<Option<Self::Organization>> {
292 self.inner.get_organization_by_id(id).await
293 }
294
295 async fn get_organization_by_slug(&self, slug: &str) -> AuthResult<Option<Self::Organization>> {
296 self.inner.get_organization_by_slug(slug).await
297 }
298
299 async fn update_organization(
300 &self,
301 id: &str,
302 update: UpdateOrganization,
303 ) -> AuthResult<Self::Organization> {
304 self.inner.update_organization(id, update).await
305 }
306
307 async fn delete_organization(&self, id: &str) -> AuthResult<()> {
308 self.inner.delete_organization(id).await
309 }
310
311 async fn list_user_organizations(&self, user_id: &str) -> AuthResult<Vec<Self::Organization>> {
312 self.inner.list_user_organizations(user_id).await
313 }
314}
315
316#[async_trait]
317impl<DB: DatabaseAdapter> MemberOps for HookedDatabaseAdapter<DB> {
318 type Member = DB::Member;
319
320 async fn create_member(&self, member: CreateMember) -> AuthResult<Self::Member> {
321 self.inner.create_member(member).await
322 }
323
324 async fn get_member(
325 &self,
326 organization_id: &str,
327 user_id: &str,
328 ) -> AuthResult<Option<Self::Member>> {
329 self.inner.get_member(organization_id, user_id).await
330 }
331
332 async fn get_member_by_id(&self, id: &str) -> AuthResult<Option<Self::Member>> {
333 self.inner.get_member_by_id(id).await
334 }
335
336 async fn update_member_role(&self, member_id: &str, role: &str) -> AuthResult<Self::Member> {
337 self.inner.update_member_role(member_id, role).await
338 }
339
340 async fn delete_member(&self, member_id: &str) -> AuthResult<()> {
341 self.inner.delete_member(member_id).await
342 }
343
344 async fn list_organization_members(
345 &self,
346 organization_id: &str,
347 ) -> AuthResult<Vec<Self::Member>> {
348 self.inner.list_organization_members(organization_id).await
349 }
350
351 async fn count_organization_members(&self, organization_id: &str) -> AuthResult<usize> {
352 self.inner.count_organization_members(organization_id).await
353 }
354
355 async fn count_organization_owners(&self, organization_id: &str) -> AuthResult<usize> {
356 self.inner.count_organization_owners(organization_id).await
357 }
358}
359
360#[async_trait]
361impl<DB: DatabaseAdapter> InvitationOps for HookedDatabaseAdapter<DB> {
362 type Invitation = DB::Invitation;
363
364 async fn create_invitation(
365 &self,
366 invitation: CreateInvitation,
367 ) -> AuthResult<Self::Invitation> {
368 self.inner.create_invitation(invitation).await
369 }
370
371 async fn get_invitation_by_id(&self, id: &str) -> AuthResult<Option<Self::Invitation>> {
372 self.inner.get_invitation_by_id(id).await
373 }
374
375 async fn get_pending_invitation(
376 &self,
377 organization_id: &str,
378 email: &str,
379 ) -> AuthResult<Option<Self::Invitation>> {
380 self.inner
381 .get_pending_invitation(organization_id, email)
382 .await
383 }
384
385 async fn update_invitation_status(
386 &self,
387 id: &str,
388 status: InvitationStatus,
389 ) -> AuthResult<Self::Invitation> {
390 self.inner.update_invitation_status(id, status).await
391 }
392
393 async fn list_organization_invitations(
394 &self,
395 organization_id: &str,
396 ) -> AuthResult<Vec<Self::Invitation>> {
397 self.inner
398 .list_organization_invitations(organization_id)
399 .await
400 }
401
402 async fn list_user_invitations(&self, email: &str) -> AuthResult<Vec<Self::Invitation>> {
403 self.inner.list_user_invitations(email).await
404 }
405}
406
407#[async_trait]
408impl<DB: DatabaseAdapter> TwoFactorOps for HookedDatabaseAdapter<DB> {
409 type TwoFactor = DB::TwoFactor;
410
411 async fn create_two_factor(&self, two_factor: CreateTwoFactor) -> AuthResult<Self::TwoFactor> {
412 self.inner.create_two_factor(two_factor).await
413 }
414
415 async fn get_two_factor_by_user_id(
416 &self,
417 user_id: &str,
418 ) -> AuthResult<Option<Self::TwoFactor>> {
419 self.inner.get_two_factor_by_user_id(user_id).await
420 }
421
422 async fn update_two_factor_backup_codes(
423 &self,
424 user_id: &str,
425 backup_codes: &str,
426 ) -> AuthResult<Self::TwoFactor> {
427 self.inner
428 .update_two_factor_backup_codes(user_id, backup_codes)
429 .await
430 }
431
432 async fn delete_two_factor(&self, user_id: &str) -> AuthResult<()> {
433 self.inner.delete_two_factor(user_id).await
434 }
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::MemoryDatabaseAdapter;
    use crate::types::{CreateUser, UpdateUser, User};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Hook that records how many times each user hook has fired.
    struct CountingHook {
        before_create_count: AtomicU32,
        after_create_count: AtomicU32,
        before_update_count: AtomicU32,
        after_update_count: AtomicU32,
        before_delete_count: AtomicU32,
        after_delete_count: AtomicU32,
    }

    impl CountingHook {
        fn new() -> Self {
            Self {
                before_create_count: AtomicU32::new(0),
                after_create_count: AtomicU32::new(0),
                before_update_count: AtomicU32::new(0),
                after_update_count: AtomicU32::new(0),
                before_delete_count: AtomicU32::new(0),
                after_delete_count: AtomicU32::new(0),
            }
        }
    }

    #[async_trait]
    impl DatabaseHooks<MemoryDatabaseAdapter> for CountingHook {
        async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
            self.before_create_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_create_user(&self, _user: &User) -> AuthResult<()> {
            self.after_create_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn before_update_user(&self, _id: &str, _update: &mut UpdateUser) -> AuthResult<()> {
            self.before_update_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_update_user(&self, _user: &User) -> AuthResult<()> {
            self.after_update_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn before_delete_user(&self, _id: &str) -> AuthResult<()> {
            self.before_delete_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_delete_user(&self, _id: &str) -> AuthResult<()> {
            self.after_delete_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// The `CreateUser` input shared by the tests in this module.
    fn sample_user() -> CreateUser {
        CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test")
    }

    #[tokio::test]
    async fn test_hooks_called_on_create_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        db.create_user(sample_user()).await.unwrap();

        assert_eq!(hook.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_update_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let user = db.create_user(sample_user()).await.unwrap();

        let update = UpdateUser {
            name: Some("Updated".to_string()),
            email: None,
            image: None,
            email_verified: None,
            username: None,
            display_username: None,
            role: None,
            banned: None,
            ban_reason: None,
            ban_expires: None,
            two_factor_enabled: None,
            metadata: None,
        };
        db.update_user(&user.id, update).await.unwrap();

        assert_eq!(hook.before_update_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_update_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_delete_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let user = db.create_user(sample_user()).await.unwrap();

        db.delete_user(&user.id).await.unwrap();

        assert_eq!(hook.before_delete_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_delete_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_before_hook_can_reject() {
        struct RejectHook;

        #[async_trait]
        impl DatabaseHooks<MemoryDatabaseAdapter> for RejectHook {
            async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
                Err(crate::error::AuthError::forbidden("Hook rejected"))
            }
        }

        // Registered after the rejecting hook, so it must never be reached.
        let later = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(Arc::new(RejectHook))
            .with_hook(later.clone());

        let result = db.create_user(sample_user()).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().status_code(), 403);

        // A rejecting before-hook short-circuits: later hooks are skipped and
        // the wrapped adapter never performs the write.
        assert_eq!(later.before_create_count.load(Ordering::SeqCst), 0);
        assert_eq!(later.after_create_count.load(Ordering::SeqCst), 0);
        assert!(db
            .get_user_by_email("test@example.com")
            .await
            .unwrap()
            .is_none());
    }

    #[tokio::test]
    async fn test_multiple_hooks() {
        let hook1 = Arc::new(CountingHook::new());
        let hook2 = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook1.clone())
            .with_hook(hook2.clone());

        db.create_user(sample_user()).await.unwrap();

        assert_eq!(hook1.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook2.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook1.after_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook2.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_add_hook_registers_hook() {
        let hook = Arc::new(CountingHook::new());
        let mut db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()));
        db.add_hook(hook.clone());

        db.create_user(sample_user()).await.unwrap();

        assert_eq!(hook.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_passthrough_operations() {
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()));

        let result = db.get_user_by_email("nonexistent@test.com").await.unwrap();
        assert!(result.is_none());
    }
}