1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use std::sync::Arc;
4
5use crate::adapters::DatabaseAdapter;
6use crate::adapters::database::{
7 AccountOps, ApiKeyOps, InvitationOps, MemberOps, OrganizationOps, SessionOps, TwoFactorOps,
8 UserOps, VerificationOps,
9};
10use crate::error::AuthResult;
11use crate::types::{
12 CreateAccount, CreateApiKey, CreateInvitation, CreateMember, CreateOrganization, CreateSession,
13 CreateTwoFactor, CreateUser, CreateVerification, InvitationStatus, UpdateAccount, UpdateApiKey,
14 UpdateOrganization, UpdateUser,
15};
16
17#[async_trait]
25pub trait DatabaseHooks<DB: DatabaseAdapter>: Send + Sync {
26 async fn before_create_user(&self, user: &mut CreateUser) -> AuthResult<()> {
27 let _ = user;
28 Ok(())
29 }
30
31 async fn after_create_user(&self, user: &DB::User) -> AuthResult<()> {
32 let _ = user;
33 Ok(())
34 }
35
36 async fn before_update_user(&self, id: &str, update: &mut UpdateUser) -> AuthResult<()> {
37 let _ = (id, update);
38 Ok(())
39 }
40
41 async fn after_update_user(&self, user: &DB::User) -> AuthResult<()> {
42 let _ = user;
43 Ok(())
44 }
45
46 async fn before_delete_user(&self, id: &str) -> AuthResult<()> {
47 let _ = id;
48 Ok(())
49 }
50
51 async fn after_delete_user(&self, id: &str) -> AuthResult<()> {
52 let _ = id;
53 Ok(())
54 }
55
56 async fn before_create_session(&self, session: &mut CreateSession) -> AuthResult<()> {
57 let _ = session;
58 Ok(())
59 }
60
61 async fn after_create_session(&self, session: &DB::Session) -> AuthResult<()> {
62 let _ = session;
63 Ok(())
64 }
65
66 async fn before_delete_session(&self, token: &str) -> AuthResult<()> {
67 let _ = token;
68 Ok(())
69 }
70
71 async fn after_delete_session(&self, token: &str) -> AuthResult<()> {
72 let _ = token;
73 Ok(())
74 }
75}
76
77pub struct HookedDatabaseAdapter<DB: DatabaseAdapter> {
79 inner: Arc<DB>,
80 hooks: Vec<Arc<dyn DatabaseHooks<DB>>>,
81}
82
83impl<DB: DatabaseAdapter> HookedDatabaseAdapter<DB> {
84 pub fn new(inner: Arc<DB>) -> Self {
85 Self {
86 inner,
87 hooks: Vec::new(),
88 }
89 }
90
91 pub fn with_hook(mut self, hook: Arc<dyn DatabaseHooks<DB>>) -> Self {
92 self.hooks.push(hook);
93 self
94 }
95
96 pub fn add_hook(&mut self, hook: Arc<dyn DatabaseHooks<DB>>) {
97 self.hooks.push(hook);
98 }
99}
100
101#[async_trait]
102impl<DB: DatabaseAdapter> UserOps for HookedDatabaseAdapter<DB> {
103 type User = DB::User;
104
105 async fn create_user(&self, mut user: CreateUser) -> AuthResult<Self::User> {
106 for hook in &self.hooks {
107 hook.before_create_user(&mut user).await?;
108 }
109 let result = self.inner.create_user(user).await?;
110 for hook in &self.hooks {
111 hook.after_create_user(&result).await?;
112 }
113 Ok(result)
114 }
115
116 async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<Self::User>> {
117 self.inner.get_user_by_id(id).await
118 }
119
120 async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<Self::User>> {
121 self.inner.get_user_by_email(email).await
122 }
123
124 async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<Self::User>> {
125 self.inner.get_user_by_username(username).await
126 }
127
128 async fn update_user(&self, id: &str, mut update: UpdateUser) -> AuthResult<Self::User> {
129 for hook in &self.hooks {
130 hook.before_update_user(id, &mut update).await?;
131 }
132 let result = self.inner.update_user(id, update).await?;
133 for hook in &self.hooks {
134 hook.after_update_user(&result).await?;
135 }
136 Ok(result)
137 }
138
139 async fn delete_user(&self, id: &str) -> AuthResult<()> {
140 for hook in &self.hooks {
141 hook.before_delete_user(id).await?;
142 }
143 self.inner.delete_user(id).await?;
144 for hook in &self.hooks {
145 hook.after_delete_user(id).await?;
146 }
147 Ok(())
148 }
149}
150
151#[async_trait]
152impl<DB: DatabaseAdapter> SessionOps for HookedDatabaseAdapter<DB> {
153 type Session = DB::Session;
154
155 async fn create_session(&self, mut session: CreateSession) -> AuthResult<Self::Session> {
156 for hook in &self.hooks {
157 hook.before_create_session(&mut session).await?;
158 }
159 let result = self.inner.create_session(session).await?;
160 for hook in &self.hooks {
161 hook.after_create_session(&result).await?;
162 }
163 Ok(result)
164 }
165
166 async fn get_session(&self, token: &str) -> AuthResult<Option<Self::Session>> {
167 self.inner.get_session(token).await
168 }
169
170 async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<Self::Session>> {
171 self.inner.get_user_sessions(user_id).await
172 }
173
174 async fn update_session_expiry(
175 &self,
176 token: &str,
177 expires_at: DateTime<Utc>,
178 ) -> AuthResult<()> {
179 self.inner.update_session_expiry(token, expires_at).await
180 }
181
182 async fn delete_session(&self, token: &str) -> AuthResult<()> {
183 for hook in &self.hooks {
184 hook.before_delete_session(token).await?;
185 }
186 self.inner.delete_session(token).await?;
187 for hook in &self.hooks {
188 hook.after_delete_session(token).await?;
189 }
190 Ok(())
191 }
192
193 async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
194 self.inner.delete_user_sessions(user_id).await
195 }
196
197 async fn delete_expired_sessions(&self) -> AuthResult<usize> {
198 self.inner.delete_expired_sessions().await
199 }
200
201 async fn update_session_active_organization(
202 &self,
203 token: &str,
204 organization_id: Option<&str>,
205 ) -> AuthResult<Self::Session> {
206 self.inner
207 .update_session_active_organization(token, organization_id)
208 .await
209 }
210}
211
212#[async_trait]
213impl<DB: DatabaseAdapter> AccountOps for HookedDatabaseAdapter<DB> {
214 type Account = DB::Account;
215
216 async fn create_account(&self, account: CreateAccount) -> AuthResult<Self::Account> {
217 self.inner.create_account(account).await
218 }
219
220 async fn get_account(
221 &self,
222 provider: &str,
223 provider_account_id: &str,
224 ) -> AuthResult<Option<Self::Account>> {
225 self.inner.get_account(provider, provider_account_id).await
226 }
227
228 async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<Self::Account>> {
229 self.inner.get_user_accounts(user_id).await
230 }
231
232 async fn update_account(&self, id: &str, update: UpdateAccount) -> AuthResult<Self::Account> {
233 self.inner.update_account(id, update).await
234 }
235
236 async fn delete_account(&self, id: &str) -> AuthResult<()> {
237 self.inner.delete_account(id).await
238 }
239}
240
241#[async_trait]
242impl<DB: DatabaseAdapter> VerificationOps for HookedDatabaseAdapter<DB> {
243 type Verification = DB::Verification;
244
245 async fn create_verification(
246 &self,
247 verification: CreateVerification,
248 ) -> AuthResult<Self::Verification> {
249 self.inner.create_verification(verification).await
250 }
251
252 async fn get_verification(
253 &self,
254 identifier: &str,
255 value: &str,
256 ) -> AuthResult<Option<Self::Verification>> {
257 self.inner.get_verification(identifier, value).await
258 }
259
260 async fn get_verification_by_value(
261 &self,
262 value: &str,
263 ) -> AuthResult<Option<Self::Verification>> {
264 self.inner.get_verification_by_value(value).await
265 }
266
267 async fn get_verification_by_identifier(
268 &self,
269 identifier: &str,
270 ) -> AuthResult<Option<Self::Verification>> {
271 self.inner.get_verification_by_identifier(identifier).await
272 }
273
274 async fn consume_verification(
275 &self,
276 identifier: &str,
277 value: &str,
278 ) -> AuthResult<Option<Self::Verification>> {
279 self.inner.consume_verification(identifier, value).await
280 }
281
282 async fn delete_verification(&self, id: &str) -> AuthResult<()> {
283 self.inner.delete_verification(id).await
284 }
285
286 async fn delete_expired_verifications(&self) -> AuthResult<usize> {
287 self.inner.delete_expired_verifications().await
288 }
289}
290
291#[async_trait]
292impl<DB: DatabaseAdapter> OrganizationOps for HookedDatabaseAdapter<DB> {
293 type Organization = DB::Organization;
294
295 async fn create_organization(&self, org: CreateOrganization) -> AuthResult<Self::Organization> {
296 self.inner.create_organization(org).await
297 }
298
299 async fn get_organization_by_id(&self, id: &str) -> AuthResult<Option<Self::Organization>> {
300 self.inner.get_organization_by_id(id).await
301 }
302
303 async fn get_organization_by_slug(&self, slug: &str) -> AuthResult<Option<Self::Organization>> {
304 self.inner.get_organization_by_slug(slug).await
305 }
306
307 async fn update_organization(
308 &self,
309 id: &str,
310 update: UpdateOrganization,
311 ) -> AuthResult<Self::Organization> {
312 self.inner.update_organization(id, update).await
313 }
314
315 async fn delete_organization(&self, id: &str) -> AuthResult<()> {
316 self.inner.delete_organization(id).await
317 }
318
319 async fn list_user_organizations(&self, user_id: &str) -> AuthResult<Vec<Self::Organization>> {
320 self.inner.list_user_organizations(user_id).await
321 }
322}
323
324#[async_trait]
325impl<DB: DatabaseAdapter> MemberOps for HookedDatabaseAdapter<DB> {
326 type Member = DB::Member;
327
328 async fn create_member(&self, member: CreateMember) -> AuthResult<Self::Member> {
329 self.inner.create_member(member).await
330 }
331
332 async fn get_member(
333 &self,
334 organization_id: &str,
335 user_id: &str,
336 ) -> AuthResult<Option<Self::Member>> {
337 self.inner.get_member(organization_id, user_id).await
338 }
339
340 async fn get_member_by_id(&self, id: &str) -> AuthResult<Option<Self::Member>> {
341 self.inner.get_member_by_id(id).await
342 }
343
344 async fn update_member_role(&self, member_id: &str, role: &str) -> AuthResult<Self::Member> {
345 self.inner.update_member_role(member_id, role).await
346 }
347
348 async fn delete_member(&self, member_id: &str) -> AuthResult<()> {
349 self.inner.delete_member(member_id).await
350 }
351
352 async fn list_organization_members(
353 &self,
354 organization_id: &str,
355 ) -> AuthResult<Vec<Self::Member>> {
356 self.inner.list_organization_members(organization_id).await
357 }
358
359 async fn count_organization_members(&self, organization_id: &str) -> AuthResult<usize> {
360 self.inner.count_organization_members(organization_id).await
361 }
362
363 async fn count_organization_owners(&self, organization_id: &str) -> AuthResult<usize> {
364 self.inner.count_organization_owners(organization_id).await
365 }
366}
367
368#[async_trait]
369impl<DB: DatabaseAdapter> InvitationOps for HookedDatabaseAdapter<DB> {
370 type Invitation = DB::Invitation;
371
372 async fn create_invitation(
373 &self,
374 invitation: CreateInvitation,
375 ) -> AuthResult<Self::Invitation> {
376 self.inner.create_invitation(invitation).await
377 }
378
379 async fn get_invitation_by_id(&self, id: &str) -> AuthResult<Option<Self::Invitation>> {
380 self.inner.get_invitation_by_id(id).await
381 }
382
383 async fn get_pending_invitation(
384 &self,
385 organization_id: &str,
386 email: &str,
387 ) -> AuthResult<Option<Self::Invitation>> {
388 self.inner
389 .get_pending_invitation(organization_id, email)
390 .await
391 }
392
393 async fn update_invitation_status(
394 &self,
395 id: &str,
396 status: InvitationStatus,
397 ) -> AuthResult<Self::Invitation> {
398 self.inner.update_invitation_status(id, status).await
399 }
400
401 async fn list_organization_invitations(
402 &self,
403 organization_id: &str,
404 ) -> AuthResult<Vec<Self::Invitation>> {
405 self.inner
406 .list_organization_invitations(organization_id)
407 .await
408 }
409
410 async fn list_user_invitations(&self, email: &str) -> AuthResult<Vec<Self::Invitation>> {
411 self.inner.list_user_invitations(email).await
412 }
413}
414
415#[async_trait]
416impl<DB: DatabaseAdapter> TwoFactorOps for HookedDatabaseAdapter<DB> {
417 type TwoFactor = DB::TwoFactor;
418
419 async fn create_two_factor(&self, two_factor: CreateTwoFactor) -> AuthResult<Self::TwoFactor> {
420 self.inner.create_two_factor(two_factor).await
421 }
422
423 async fn get_two_factor_by_user_id(
424 &self,
425 user_id: &str,
426 ) -> AuthResult<Option<Self::TwoFactor>> {
427 self.inner.get_two_factor_by_user_id(user_id).await
428 }
429
430 async fn update_two_factor_backup_codes(
431 &self,
432 user_id: &str,
433 backup_codes: &str,
434 ) -> AuthResult<Self::TwoFactor> {
435 self.inner
436 .update_two_factor_backup_codes(user_id, backup_codes)
437 .await
438 }
439
440 async fn delete_two_factor(&self, user_id: &str) -> AuthResult<()> {
441 self.inner.delete_two_factor(user_id).await
442 }
443}
444
445#[async_trait]
446impl<DB: DatabaseAdapter> ApiKeyOps for HookedDatabaseAdapter<DB> {
447 type ApiKey = DB::ApiKey;
448
449 async fn create_api_key(&self, input: CreateApiKey) -> AuthResult<Self::ApiKey> {
450 self.inner.create_api_key(input).await
451 }
452
453 async fn get_api_key_by_id(&self, id: &str) -> AuthResult<Option<Self::ApiKey>> {
454 self.inner.get_api_key_by_id(id).await
455 }
456
457 async fn get_api_key_by_hash(&self, hash: &str) -> AuthResult<Option<Self::ApiKey>> {
458 self.inner.get_api_key_by_hash(hash).await
459 }
460
461 async fn list_api_keys_by_user(&self, user_id: &str) -> AuthResult<Vec<Self::ApiKey>> {
462 self.inner.list_api_keys_by_user(user_id).await
463 }
464
465 async fn update_api_key(&self, id: &str, update: UpdateApiKey) -> AuthResult<Self::ApiKey> {
466 self.inner.update_api_key(id, update).await
467 }
468
469 async fn delete_api_key(&self, id: &str) -> AuthResult<()> {
470 self.inner.delete_api_key(id).await
471 }
472}
473
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::MemoryDatabaseAdapter;
    use crate::types::{CreateUser, UpdateUser, User};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Hook that records how many times each user hook fired.
    struct CountingHook {
        before_create_count: AtomicU32,
        after_create_count: AtomicU32,
        before_update_count: AtomicU32,
        after_update_count: AtomicU32,
        before_delete_count: AtomicU32,
        after_delete_count: AtomicU32,
    }

    impl CountingHook {
        fn new() -> Self {
            Self {
                before_create_count: AtomicU32::new(0),
                after_create_count: AtomicU32::new(0),
                before_update_count: AtomicU32::new(0),
                after_update_count: AtomicU32::new(0),
                before_delete_count: AtomicU32::new(0),
                after_delete_count: AtomicU32::new(0),
            }
        }

        /// Sum of all counters; used to assert that no hook fired at all.
        fn total_calls(&self) -> u32 {
            [
                &self.before_create_count,
                &self.after_create_count,
                &self.before_update_count,
                &self.after_update_count,
                &self.before_delete_count,
                &self.after_delete_count,
            ]
            .iter()
            .map(|counter| counter.load(Ordering::SeqCst))
            .sum()
        }
    }

    #[async_trait]
    impl DatabaseHooks<MemoryDatabaseAdapter> for CountingHook {
        async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
            self.before_create_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_create_user(&self, _user: &User) -> AuthResult<()> {
            self.after_create_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn before_update_user(&self, _id: &str, _update: &mut UpdateUser) -> AuthResult<()> {
            self.before_update_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_update_user(&self, _user: &User) -> AuthResult<()> {
            self.after_update_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn before_delete_user(&self, _id: &str) -> AuthResult<()> {
            self.before_delete_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_delete_user(&self, _id: &str) -> AuthResult<()> {
            self.after_delete_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_hooks_called_on_create_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        db.create_user(create).await.unwrap();

        assert_eq!(hook.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_update_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        let user = db.create_user(create).await.unwrap();

        let update = UpdateUser {
            name: Some("Updated".to_string()),
            email: None,
            image: None,
            email_verified: None,
            username: None,
            display_username: None,
            role: None,
            banned: None,
            ban_reason: None,
            ban_expires: None,
            two_factor_enabled: None,
            metadata: None,
        };
        db.update_user(&user.id, update).await.unwrap();

        assert_eq!(hook.before_update_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_update_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_delete_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        let user = db.create_user(create).await.unwrap();

        db.delete_user(&user.id).await.unwrap();

        assert_eq!(hook.before_delete_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_delete_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_before_hook_can_reject() {
        struct RejectHook;

        #[async_trait]
        impl DatabaseHooks<MemoryDatabaseAdapter> for RejectHook {
            async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
                Err(crate::error::AuthError::forbidden("Hook rejected"))
            }
        }

        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(Arc::new(RejectHook));

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        let result = db.create_user(create).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().status_code(), 403);

        // A rejecting before-hook must short-circuit: the wrapped adapter is
        // never called, so nothing may have been persisted.
        let stored = db.get_user_by_email("test@example.com").await.unwrap();
        assert!(stored.is_none());
    }

    #[tokio::test]
    async fn test_after_hook_error_does_not_roll_back() {
        struct FailAfterHook;

        #[async_trait]
        impl DatabaseHooks<MemoryDatabaseAdapter> for FailAfterHook {
            async fn after_create_user(&self, _user: &User) -> AuthResult<()> {
                Err(crate::error::AuthError::forbidden("After hook failed"))
            }
        }

        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(Arc::new(FailAfterHook));

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");

        // The after-hook error reaches the caller...
        assert!(db.create_user(create).await.is_err());
        // ...but the write already happened and is not undone.
        let stored = db.get_user_by_email("test@example.com").await.unwrap();
        assert!(stored.is_some());
    }

    #[tokio::test]
    async fn test_multiple_hooks() {
        let hook1 = Arc::new(CountingHook::new());
        let hook2 = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook1.clone())
            .with_hook(hook2.clone());

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        db.create_user(create).await.unwrap();

        assert_eq!(hook1.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook2.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook1.after_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook2.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_passthrough_operations() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let result = db.get_user_by_email("nonexistent@test.com").await.unwrap();
        assert!(result.is_none());

        // Read-only operations must not trigger any hook.
        assert_eq!(hook.total_calls(), 0);
    }
}