1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use std::sync::Arc;
4
5use crate::adapters::DatabaseAdapter;
6use crate::adapters::database::{
7 AccountOps, InvitationOps, MemberOps, OrganizationOps, SessionOps, UserOps, VerificationOps,
8};
9use crate::error::AuthResult;
10use crate::types::{
11 CreateAccount, CreateInvitation, CreateMember, CreateOrganization, CreateSession, CreateUser,
12 CreateVerification, InvitationStatus, UpdateOrganization, UpdateUser,
13};
14
15#[async_trait]
23pub trait DatabaseHooks<DB: DatabaseAdapter>: Send + Sync {
24 async fn before_create_user(&self, user: &mut CreateUser) -> AuthResult<()> {
25 let _ = user;
26 Ok(())
27 }
28
29 async fn after_create_user(&self, user: &DB::User) -> AuthResult<()> {
30 let _ = user;
31 Ok(())
32 }
33
34 async fn before_update_user(&self, id: &str, update: &mut UpdateUser) -> AuthResult<()> {
35 let _ = (id, update);
36 Ok(())
37 }
38
39 async fn after_update_user(&self, user: &DB::User) -> AuthResult<()> {
40 let _ = user;
41 Ok(())
42 }
43
44 async fn before_delete_user(&self, id: &str) -> AuthResult<()> {
45 let _ = id;
46 Ok(())
47 }
48
49 async fn after_delete_user(&self, id: &str) -> AuthResult<()> {
50 let _ = id;
51 Ok(())
52 }
53
54 async fn before_create_session(&self, session: &mut CreateSession) -> AuthResult<()> {
55 let _ = session;
56 Ok(())
57 }
58
59 async fn after_create_session(&self, session: &DB::Session) -> AuthResult<()> {
60 let _ = session;
61 Ok(())
62 }
63
64 async fn before_delete_session(&self, token: &str) -> AuthResult<()> {
65 let _ = token;
66 Ok(())
67 }
68
69 async fn after_delete_session(&self, token: &str) -> AuthResult<()> {
70 let _ = token;
71 Ok(())
72 }
73}
74
75pub struct HookedDatabaseAdapter<DB: DatabaseAdapter> {
77 inner: Arc<DB>,
78 hooks: Vec<Arc<dyn DatabaseHooks<DB>>>,
79}
80
81impl<DB: DatabaseAdapter> HookedDatabaseAdapter<DB> {
82 pub fn new(inner: Arc<DB>) -> Self {
83 Self {
84 inner,
85 hooks: Vec::new(),
86 }
87 }
88
89 pub fn with_hook(mut self, hook: Arc<dyn DatabaseHooks<DB>>) -> Self {
90 self.hooks.push(hook);
91 self
92 }
93
94 pub fn add_hook(&mut self, hook: Arc<dyn DatabaseHooks<DB>>) {
95 self.hooks.push(hook);
96 }
97}
98
99#[async_trait]
100impl<DB: DatabaseAdapter> UserOps for HookedDatabaseAdapter<DB> {
101 type User = DB::User;
102
103 async fn create_user(&self, mut user: CreateUser) -> AuthResult<Self::User> {
104 for hook in &self.hooks {
105 hook.before_create_user(&mut user).await?;
106 }
107 let result = self.inner.create_user(user).await?;
108 for hook in &self.hooks {
109 hook.after_create_user(&result).await?;
110 }
111 Ok(result)
112 }
113
114 async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<Self::User>> {
115 self.inner.get_user_by_id(id).await
116 }
117
118 async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<Self::User>> {
119 self.inner.get_user_by_email(email).await
120 }
121
122 async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<Self::User>> {
123 self.inner.get_user_by_username(username).await
124 }
125
126 async fn update_user(&self, id: &str, mut update: UpdateUser) -> AuthResult<Self::User> {
127 for hook in &self.hooks {
128 hook.before_update_user(id, &mut update).await?;
129 }
130 let result = self.inner.update_user(id, update).await?;
131 for hook in &self.hooks {
132 hook.after_update_user(&result).await?;
133 }
134 Ok(result)
135 }
136
137 async fn delete_user(&self, id: &str) -> AuthResult<()> {
138 for hook in &self.hooks {
139 hook.before_delete_user(id).await?;
140 }
141 self.inner.delete_user(id).await?;
142 for hook in &self.hooks {
143 hook.after_delete_user(id).await?;
144 }
145 Ok(())
146 }
147}
148
149#[async_trait]
150impl<DB: DatabaseAdapter> SessionOps for HookedDatabaseAdapter<DB> {
151 type Session = DB::Session;
152
153 async fn create_session(&self, mut session: CreateSession) -> AuthResult<Self::Session> {
154 for hook in &self.hooks {
155 hook.before_create_session(&mut session).await?;
156 }
157 let result = self.inner.create_session(session).await?;
158 for hook in &self.hooks {
159 hook.after_create_session(&result).await?;
160 }
161 Ok(result)
162 }
163
164 async fn get_session(&self, token: &str) -> AuthResult<Option<Self::Session>> {
165 self.inner.get_session(token).await
166 }
167
168 async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<Self::Session>> {
169 self.inner.get_user_sessions(user_id).await
170 }
171
172 async fn update_session_expiry(
173 &self,
174 token: &str,
175 expires_at: DateTime<Utc>,
176 ) -> AuthResult<()> {
177 self.inner.update_session_expiry(token, expires_at).await
178 }
179
180 async fn delete_session(&self, token: &str) -> AuthResult<()> {
181 for hook in &self.hooks {
182 hook.before_delete_session(token).await?;
183 }
184 self.inner.delete_session(token).await?;
185 for hook in &self.hooks {
186 hook.after_delete_session(token).await?;
187 }
188 Ok(())
189 }
190
191 async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
192 self.inner.delete_user_sessions(user_id).await
193 }
194
195 async fn delete_expired_sessions(&self) -> AuthResult<usize> {
196 self.inner.delete_expired_sessions().await
197 }
198
199 async fn update_session_active_organization(
200 &self,
201 token: &str,
202 organization_id: Option<&str>,
203 ) -> AuthResult<Self::Session> {
204 self.inner
205 .update_session_active_organization(token, organization_id)
206 .await
207 }
208}
209
210#[async_trait]
211impl<DB: DatabaseAdapter> AccountOps for HookedDatabaseAdapter<DB> {
212 type Account = DB::Account;
213
214 async fn create_account(&self, account: CreateAccount) -> AuthResult<Self::Account> {
215 self.inner.create_account(account).await
216 }
217
218 async fn get_account(
219 &self,
220 provider: &str,
221 provider_account_id: &str,
222 ) -> AuthResult<Option<Self::Account>> {
223 self.inner.get_account(provider, provider_account_id).await
224 }
225
226 async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<Self::Account>> {
227 self.inner.get_user_accounts(user_id).await
228 }
229
230 async fn delete_account(&self, id: &str) -> AuthResult<()> {
231 self.inner.delete_account(id).await
232 }
233}
234
235#[async_trait]
236impl<DB: DatabaseAdapter> VerificationOps for HookedDatabaseAdapter<DB> {
237 type Verification = DB::Verification;
238
239 async fn create_verification(
240 &self,
241 verification: CreateVerification,
242 ) -> AuthResult<Self::Verification> {
243 self.inner.create_verification(verification).await
244 }
245
246 async fn get_verification(
247 &self,
248 identifier: &str,
249 value: &str,
250 ) -> AuthResult<Option<Self::Verification>> {
251 self.inner.get_verification(identifier, value).await
252 }
253
254 async fn get_verification_by_value(
255 &self,
256 value: &str,
257 ) -> AuthResult<Option<Self::Verification>> {
258 self.inner.get_verification_by_value(value).await
259 }
260
261 async fn delete_verification(&self, id: &str) -> AuthResult<()> {
262 self.inner.delete_verification(id).await
263 }
264
265 async fn delete_expired_verifications(&self) -> AuthResult<usize> {
266 self.inner.delete_expired_verifications().await
267 }
268}
269
270#[async_trait]
271impl<DB: DatabaseAdapter> OrganizationOps for HookedDatabaseAdapter<DB> {
272 type Organization = DB::Organization;
273
274 async fn create_organization(&self, org: CreateOrganization) -> AuthResult<Self::Organization> {
275 self.inner.create_organization(org).await
276 }
277
278 async fn get_organization_by_id(&self, id: &str) -> AuthResult<Option<Self::Organization>> {
279 self.inner.get_organization_by_id(id).await
280 }
281
282 async fn get_organization_by_slug(&self, slug: &str) -> AuthResult<Option<Self::Organization>> {
283 self.inner.get_organization_by_slug(slug).await
284 }
285
286 async fn update_organization(
287 &self,
288 id: &str,
289 update: UpdateOrganization,
290 ) -> AuthResult<Self::Organization> {
291 self.inner.update_organization(id, update).await
292 }
293
294 async fn delete_organization(&self, id: &str) -> AuthResult<()> {
295 self.inner.delete_organization(id).await
296 }
297
298 async fn list_user_organizations(&self, user_id: &str) -> AuthResult<Vec<Self::Organization>> {
299 self.inner.list_user_organizations(user_id).await
300 }
301}
302
303#[async_trait]
304impl<DB: DatabaseAdapter> MemberOps for HookedDatabaseAdapter<DB> {
305 type Member = DB::Member;
306
307 async fn create_member(&self, member: CreateMember) -> AuthResult<Self::Member> {
308 self.inner.create_member(member).await
309 }
310
311 async fn get_member(
312 &self,
313 organization_id: &str,
314 user_id: &str,
315 ) -> AuthResult<Option<Self::Member>> {
316 self.inner.get_member(organization_id, user_id).await
317 }
318
319 async fn get_member_by_id(&self, id: &str) -> AuthResult<Option<Self::Member>> {
320 self.inner.get_member_by_id(id).await
321 }
322
323 async fn update_member_role(&self, member_id: &str, role: &str) -> AuthResult<Self::Member> {
324 self.inner.update_member_role(member_id, role).await
325 }
326
327 async fn delete_member(&self, member_id: &str) -> AuthResult<()> {
328 self.inner.delete_member(member_id).await
329 }
330
331 async fn list_organization_members(
332 &self,
333 organization_id: &str,
334 ) -> AuthResult<Vec<Self::Member>> {
335 self.inner.list_organization_members(organization_id).await
336 }
337
338 async fn count_organization_members(&self, organization_id: &str) -> AuthResult<usize> {
339 self.inner.count_organization_members(organization_id).await
340 }
341
342 async fn count_organization_owners(&self, organization_id: &str) -> AuthResult<usize> {
343 self.inner.count_organization_owners(organization_id).await
344 }
345}
346
347#[async_trait]
348impl<DB: DatabaseAdapter> InvitationOps for HookedDatabaseAdapter<DB> {
349 type Invitation = DB::Invitation;
350
351 async fn create_invitation(
352 &self,
353 invitation: CreateInvitation,
354 ) -> AuthResult<Self::Invitation> {
355 self.inner.create_invitation(invitation).await
356 }
357
358 async fn get_invitation_by_id(&self, id: &str) -> AuthResult<Option<Self::Invitation>> {
359 self.inner.get_invitation_by_id(id).await
360 }
361
362 async fn get_pending_invitation(
363 &self,
364 organization_id: &str,
365 email: &str,
366 ) -> AuthResult<Option<Self::Invitation>> {
367 self.inner
368 .get_pending_invitation(organization_id, email)
369 .await
370 }
371
372 async fn update_invitation_status(
373 &self,
374 id: &str,
375 status: InvitationStatus,
376 ) -> AuthResult<Self::Invitation> {
377 self.inner.update_invitation_status(id, status).await
378 }
379
380 async fn list_organization_invitations(
381 &self,
382 organization_id: &str,
383 ) -> AuthResult<Vec<Self::Invitation>> {
384 self.inner
385 .list_organization_invitations(organization_id)
386 .await
387 }
388
389 async fn list_user_invitations(&self, email: &str) -> AuthResult<Vec<Self::Invitation>> {
390 self.inner.list_user_invitations(email).await
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::MemoryDatabaseAdapter;
    use crate::types::{CreateUser, UpdateUser, User};
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Hook that counts how many times each user hook fires.
    #[derive(Default)]
    struct CountingHook {
        before_create_count: AtomicU32,
        after_create_count: AtomicU32,
        before_update_count: AtomicU32,
        after_update_count: AtomicU32,
        before_delete_count: AtomicU32,
        after_delete_count: AtomicU32,
    }

    impl CountingHook {
        fn new() -> Self {
            Self::default()
        }
    }

    #[async_trait]
    impl DatabaseHooks<MemoryDatabaseAdapter> for CountingHook {
        async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
            self.before_create_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_create_user(&self, _user: &User) -> AuthResult<()> {
            self.after_create_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn before_update_user(&self, _id: &str, _update: &mut UpdateUser) -> AuthResult<()> {
            self.before_update_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_update_user(&self, _user: &User) -> AuthResult<()> {
            self.after_update_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn before_delete_user(&self, _id: &str) -> AuthResult<()> {
            self.before_delete_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_delete_user(&self, _id: &str) -> AuthResult<()> {
            self.after_delete_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// The user payload shared by every test.
    fn sample_user() -> CreateUser {
        CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test")
    }

    #[tokio::test]
    async fn test_hooks_called_on_create_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        db.create_user(sample_user()).await.unwrap();

        assert_eq!(hook.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_update_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let user = db.create_user(sample_user()).await.unwrap();

        let update = UpdateUser {
            name: Some("Updated".to_string()),
            email: None,
            image: None,
            email_verified: None,
            username: None,
            display_username: None,
            role: None,
            banned: None,
            ban_reason: None,
            ban_expires: None,
            two_factor_enabled: None,
            metadata: None,
        };
        db.update_user(&user.id, update).await.unwrap();

        assert_eq!(hook.before_update_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_update_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_delete_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let user = db.create_user(sample_user()).await.unwrap();

        db.delete_user(&user.id).await.unwrap();

        assert_eq!(hook.before_delete_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_delete_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_before_hook_can_reject() {
        struct RejectHook;

        #[async_trait]
        impl DatabaseHooks<MemoryDatabaseAdapter> for RejectHook {
            async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
                Err(crate::error::AuthError::forbidden("Hook rejected"))
            }
        }

        // Registered after the rejecting hook, so it must never be reached.
        let later = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(Arc::new(RejectHook))
            .with_hook(later.clone());

        let result = db.create_user(sample_user()).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().status_code(), 403);
        // A rejecting before-hook short-circuits: later hooks do not run and
        // the inner adapter never persists the user.
        assert_eq!(later.before_create_count.load(Ordering::SeqCst), 0);
        assert_eq!(later.after_create_count.load(Ordering::SeqCst), 0);
        assert!(
            db.get_user_by_email("test@example.com")
                .await
                .unwrap()
                .is_none()
        );
    }

    #[tokio::test]
    async fn test_multiple_hooks() {
        let hook1 = Arc::new(CountingHook::new());
        let hook2 = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook1.clone())
            .with_hook(hook2.clone());

        db.create_user(sample_user()).await.unwrap();

        assert_eq!(hook1.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook2.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook1.after_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook2.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_run_in_registration_order() {
        struct LoggingHook {
            label: &'static str,
            log: Arc<Mutex<Vec<&'static str>>>,
        }

        #[async_trait]
        impl DatabaseHooks<MemoryDatabaseAdapter> for LoggingHook {
            async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
                // The guard is a statement temporary; it is released before
                // the future could ever suspend.
                self.log.lock().unwrap().push(self.label);
                Ok(())
            }
        }

        let log = Arc::new(Mutex::new(Vec::new()));
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(Arc::new(LoggingHook {
                label: "first",
                log: log.clone(),
            }))
            .with_hook(Arc::new(LoggingHook {
                label: "second",
                log: log.clone(),
            }));

        db.create_user(sample_user()).await.unwrap();

        assert_eq!(*log.lock().unwrap(), vec!["first", "second"]);
    }

    #[tokio::test]
    async fn test_passthrough_operations() {
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()));

        let result = db.get_user_by_email("nonexistent@test.com").await.unwrap();
        assert!(result.is_none());
    }
}