1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use std::sync::Arc;
4
5use crate::adapters::DatabaseAdapter;
6use crate::error::AuthResult;
7use crate::types::{
8 CreateAccount, CreateInvitation, CreateMember, CreateOrganization, CreateSession, CreateUser,
9 CreateVerification, InvitationStatus, UpdateOrganization, UpdateUser,
10};
11
12#[async_trait]
20pub trait DatabaseHooks<DB: DatabaseAdapter>: Send + Sync {
21 async fn before_create_user(&self, user: &mut CreateUser) -> AuthResult<()> {
24 let _ = user;
25 Ok(())
26 }
27
28 async fn after_create_user(&self, user: &DB::User) -> AuthResult<()> {
29 let _ = user;
30 Ok(())
31 }
32
33 async fn before_update_user(&self, id: &str, update: &mut UpdateUser) -> AuthResult<()> {
34 let _ = (id, update);
35 Ok(())
36 }
37
38 async fn after_update_user(&self, user: &DB::User) -> AuthResult<()> {
39 let _ = user;
40 Ok(())
41 }
42
43 async fn before_delete_user(&self, id: &str) -> AuthResult<()> {
44 let _ = id;
45 Ok(())
46 }
47
48 async fn after_delete_user(&self, id: &str) -> AuthResult<()> {
49 let _ = id;
50 Ok(())
51 }
52
53 async fn before_create_session(&self, session: &mut CreateSession) -> AuthResult<()> {
56 let _ = session;
57 Ok(())
58 }
59
60 async fn after_create_session(&self, session: &DB::Session) -> AuthResult<()> {
61 let _ = session;
62 Ok(())
63 }
64
65 async fn before_delete_session(&self, token: &str) -> AuthResult<()> {
66 let _ = token;
67 Ok(())
68 }
69
70 async fn after_delete_session(&self, token: &str) -> AuthResult<()> {
71 let _ = token;
72 Ok(())
73 }
74}
75
76pub struct HookedDatabaseAdapter<DB: DatabaseAdapter> {
78 inner: Arc<DB>,
79 hooks: Vec<Arc<dyn DatabaseHooks<DB>>>,
80}
81
82impl<DB: DatabaseAdapter> HookedDatabaseAdapter<DB> {
83 pub fn new(inner: Arc<DB>) -> Self {
84 Self {
85 inner,
86 hooks: Vec::new(),
87 }
88 }
89
90 pub fn with_hook(mut self, hook: Arc<dyn DatabaseHooks<DB>>) -> Self {
91 self.hooks.push(hook);
92 self
93 }
94
95 pub fn add_hook(&mut self, hook: Arc<dyn DatabaseHooks<DB>>) {
96 self.hooks.push(hook);
97 }
98}
99
100#[async_trait]
101impl<DB: DatabaseAdapter> DatabaseAdapter for HookedDatabaseAdapter<DB> {
102 type User = DB::User;
103 type Session = DB::Session;
104 type Account = DB::Account;
105 type Organization = DB::Organization;
106 type Member = DB::Member;
107 type Invitation = DB::Invitation;
108 type Verification = DB::Verification;
109
110 async fn create_user(&self, mut user: CreateUser) -> AuthResult<Self::User> {
113 for hook in &self.hooks {
114 hook.before_create_user(&mut user).await?;
115 }
116 let result = self.inner.create_user(user).await?;
117 for hook in &self.hooks {
118 hook.after_create_user(&result).await?;
119 }
120 Ok(result)
121 }
122
123 async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<Self::User>> {
124 self.inner.get_user_by_id(id).await
125 }
126
127 async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<Self::User>> {
128 self.inner.get_user_by_email(email).await
129 }
130
131 async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<Self::User>> {
132 self.inner.get_user_by_username(username).await
133 }
134
135 async fn update_user(&self, id: &str, mut update: UpdateUser) -> AuthResult<Self::User> {
136 for hook in &self.hooks {
137 hook.before_update_user(id, &mut update).await?;
138 }
139 let result = self.inner.update_user(id, update).await?;
140 for hook in &self.hooks {
141 hook.after_update_user(&result).await?;
142 }
143 Ok(result)
144 }
145
146 async fn delete_user(&self, id: &str) -> AuthResult<()> {
147 for hook in &self.hooks {
148 hook.before_delete_user(id).await?;
149 }
150 self.inner.delete_user(id).await?;
151 for hook in &self.hooks {
152 hook.after_delete_user(id).await?;
153 }
154 Ok(())
155 }
156
157 async fn create_session(&self, mut session: CreateSession) -> AuthResult<Self::Session> {
160 for hook in &self.hooks {
161 hook.before_create_session(&mut session).await?;
162 }
163 let result = self.inner.create_session(session).await?;
164 for hook in &self.hooks {
165 hook.after_create_session(&result).await?;
166 }
167 Ok(result)
168 }
169
170 async fn get_session(&self, token: &str) -> AuthResult<Option<Self::Session>> {
171 self.inner.get_session(token).await
172 }
173
174 async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<Self::Session>> {
175 self.inner.get_user_sessions(user_id).await
176 }
177
178 async fn update_session_expiry(
179 &self,
180 token: &str,
181 expires_at: DateTime<Utc>,
182 ) -> AuthResult<()> {
183 self.inner.update_session_expiry(token, expires_at).await
184 }
185
186 async fn delete_session(&self, token: &str) -> AuthResult<()> {
187 for hook in &self.hooks {
188 hook.before_delete_session(token).await?;
189 }
190 self.inner.delete_session(token).await?;
191 for hook in &self.hooks {
192 hook.after_delete_session(token).await?;
193 }
194 Ok(())
195 }
196
197 async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
198 self.inner.delete_user_sessions(user_id).await
199 }
200
201 async fn delete_expired_sessions(&self) -> AuthResult<usize> {
202 self.inner.delete_expired_sessions().await
203 }
204
205 async fn create_account(&self, account: CreateAccount) -> AuthResult<Self::Account> {
208 self.inner.create_account(account).await
209 }
210
211 async fn get_account(
212 &self,
213 provider: &str,
214 provider_account_id: &str,
215 ) -> AuthResult<Option<Self::Account>> {
216 self.inner.get_account(provider, provider_account_id).await
217 }
218
219 async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<Self::Account>> {
220 self.inner.get_user_accounts(user_id).await
221 }
222
223 async fn delete_account(&self, id: &str) -> AuthResult<()> {
224 self.inner.delete_account(id).await
225 }
226
227 async fn create_verification(
230 &self,
231 verification: CreateVerification,
232 ) -> AuthResult<Self::Verification> {
233 self.inner.create_verification(verification).await
234 }
235
236 async fn get_verification(
237 &self,
238 identifier: &str,
239 value: &str,
240 ) -> AuthResult<Option<Self::Verification>> {
241 self.inner.get_verification(identifier, value).await
242 }
243
244 async fn get_verification_by_value(
245 &self,
246 value: &str,
247 ) -> AuthResult<Option<Self::Verification>> {
248 self.inner.get_verification_by_value(value).await
249 }
250
251 async fn delete_verification(&self, id: &str) -> AuthResult<()> {
252 self.inner.delete_verification(id).await
253 }
254
255 async fn delete_expired_verifications(&self) -> AuthResult<usize> {
256 self.inner.delete_expired_verifications().await
257 }
258
259 async fn create_organization(&self, org: CreateOrganization) -> AuthResult<Self::Organization> {
262 self.inner.create_organization(org).await
263 }
264
265 async fn get_organization_by_id(&self, id: &str) -> AuthResult<Option<Self::Organization>> {
266 self.inner.get_organization_by_id(id).await
267 }
268
269 async fn get_organization_by_slug(&self, slug: &str) -> AuthResult<Option<Self::Organization>> {
270 self.inner.get_organization_by_slug(slug).await
271 }
272
273 async fn update_organization(
274 &self,
275 id: &str,
276 update: UpdateOrganization,
277 ) -> AuthResult<Self::Organization> {
278 self.inner.update_organization(id, update).await
279 }
280
281 async fn delete_organization(&self, id: &str) -> AuthResult<()> {
282 self.inner.delete_organization(id).await
283 }
284
285 async fn list_user_organizations(&self, user_id: &str) -> AuthResult<Vec<Self::Organization>> {
286 self.inner.list_user_organizations(user_id).await
287 }
288
289 async fn create_member(&self, member: CreateMember) -> AuthResult<Self::Member> {
292 self.inner.create_member(member).await
293 }
294
295 async fn get_member(
296 &self,
297 organization_id: &str,
298 user_id: &str,
299 ) -> AuthResult<Option<Self::Member>> {
300 self.inner.get_member(organization_id, user_id).await
301 }
302
303 async fn get_member_by_id(&self, id: &str) -> AuthResult<Option<Self::Member>> {
304 self.inner.get_member_by_id(id).await
305 }
306
307 async fn update_member_role(&self, member_id: &str, role: &str) -> AuthResult<Self::Member> {
308 self.inner.update_member_role(member_id, role).await
309 }
310
311 async fn delete_member(&self, member_id: &str) -> AuthResult<()> {
312 self.inner.delete_member(member_id).await
313 }
314
315 async fn list_organization_members(
316 &self,
317 organization_id: &str,
318 ) -> AuthResult<Vec<Self::Member>> {
319 self.inner.list_organization_members(organization_id).await
320 }
321
322 async fn count_organization_members(&self, organization_id: &str) -> AuthResult<usize> {
323 self.inner.count_organization_members(organization_id).await
324 }
325
326 async fn count_organization_owners(&self, organization_id: &str) -> AuthResult<usize> {
327 self.inner.count_organization_owners(organization_id).await
328 }
329
330 async fn create_invitation(
333 &self,
334 invitation: CreateInvitation,
335 ) -> AuthResult<Self::Invitation> {
336 self.inner.create_invitation(invitation).await
337 }
338
339 async fn get_invitation_by_id(&self, id: &str) -> AuthResult<Option<Self::Invitation>> {
340 self.inner.get_invitation_by_id(id).await
341 }
342
343 async fn get_pending_invitation(
344 &self,
345 organization_id: &str,
346 email: &str,
347 ) -> AuthResult<Option<Self::Invitation>> {
348 self.inner
349 .get_pending_invitation(organization_id, email)
350 .await
351 }
352
353 async fn update_invitation_status(
354 &self,
355 id: &str,
356 status: InvitationStatus,
357 ) -> AuthResult<Self::Invitation> {
358 self.inner.update_invitation_status(id, status).await
359 }
360
361 async fn list_organization_invitations(
362 &self,
363 organization_id: &str,
364 ) -> AuthResult<Vec<Self::Invitation>> {
365 self.inner
366 .list_organization_invitations(organization_id)
367 .await
368 }
369
370 async fn list_user_invitations(&self, email: &str) -> AuthResult<Vec<Self::Invitation>> {
371 self.inner.list_user_invitations(email).await
372 }
373
374 async fn update_session_active_organization(
377 &self,
378 token: &str,
379 organization_id: Option<&str>,
380 ) -> AuthResult<Self::Session> {
381 self.inner
382 .update_session_active_organization(token, organization_id)
383 .await
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::MemoryDatabaseAdapter;
    use crate::types::{CreateUser, UpdateUser, User};
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Mutex;

    /// The user payload shared by most tests.
    fn sample_user() -> CreateUser {
        CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test")
    }

    /// A fresh in-memory adapter wrapped with no hooks yet.
    fn memory_adapter() -> HookedDatabaseAdapter<MemoryDatabaseAdapter> {
        HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
    }

    /// Counts how many times each user hook fires (all counters start at zero).
    #[derive(Default)]
    struct CountingHook {
        before_create_count: AtomicU32,
        after_create_count: AtomicU32,
        before_update_count: AtomicU32,
        after_update_count: AtomicU32,
        before_delete_count: AtomicU32,
        after_delete_count: AtomicU32,
    }

    impl CountingHook {
        /// Sum of every counter; zero means no user hook fired at all.
        fn total(&self) -> u32 {
            [
                &self.before_create_count,
                &self.after_create_count,
                &self.before_update_count,
                &self.after_update_count,
                &self.before_delete_count,
                &self.after_delete_count,
            ]
            .iter()
            .map(|c| c.load(Ordering::SeqCst))
            .sum()
        }
    }

    #[async_trait]
    impl DatabaseHooks<MemoryDatabaseAdapter> for CountingHook {
        async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
            self.before_create_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_create_user(&self, _user: &User) -> AuthResult<()> {
            self.after_create_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn before_update_user(&self, _id: &str, _update: &mut UpdateUser) -> AuthResult<()> {
            self.before_update_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_update_user(&self, _user: &User) -> AuthResult<()> {
            self.after_update_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn before_delete_user(&self, _id: &str) -> AuthResult<()> {
            self.before_delete_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_delete_user(&self, _id: &str) -> AuthResult<()> {
            self.after_delete_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_hooks_called_on_create_user() {
        let hook = Arc::new(CountingHook::default());
        let db = memory_adapter().with_hook(hook.clone());

        db.create_user(sample_user()).await.unwrap();

        assert_eq!(hook.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_update_user() {
        let hook = Arc::new(CountingHook::default());
        let db = memory_adapter().with_hook(hook.clone());

        let user = db.create_user(sample_user()).await.unwrap();

        let update = UpdateUser {
            name: Some("Updated".to_string()),
            email: None,
            image: None,
            email_verified: None,
            username: None,
            display_username: None,
            role: None,
            banned: None,
            ban_reason: None,
            ban_expires: None,
            two_factor_enabled: None,
            metadata: None,
        };
        db.update_user(&user.id, update).await.unwrap();

        assert_eq!(hook.before_update_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_update_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_delete_user() {
        let hook = Arc::new(CountingHook::default());
        let db = memory_adapter().with_hook(hook.clone());

        let user = db.create_user(sample_user()).await.unwrap();
        db.delete_user(&user.id).await.unwrap();

        assert_eq!(hook.before_delete_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_delete_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_before_hook_can_reject() {
        struct RejectHook;

        #[async_trait]
        impl DatabaseHooks<MemoryDatabaseAdapter> for RejectHook {
            async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
                Err(crate::error::AuthError::forbidden("Hook rejected"))
            }
        }

        // A hook registered after the rejecting one must never run.
        let trailing = Arc::new(CountingHook::default());
        let db = memory_adapter()
            .with_hook(Arc::new(RejectHook))
            .with_hook(trailing.clone());

        let result = db.create_user(sample_user()).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().status_code(), 403);
        assert_eq!(trailing.before_create_count.load(Ordering::SeqCst), 0);
        assert_eq!(trailing.after_create_count.load(Ordering::SeqCst), 0);
        // The veto happens before the inner adapter, so nothing was stored.
        assert!(db
            .get_user_by_email("test@example.com")
            .await
            .unwrap()
            .is_none());
    }

    #[tokio::test]
    async fn test_before_delete_rejection_keeps_user() {
        struct RejectDeleteHook;

        #[async_trait]
        impl DatabaseHooks<MemoryDatabaseAdapter> for RejectDeleteHook {
            async fn before_delete_user(&self, _id: &str) -> AuthResult<()> {
                Err(crate::error::AuthError::forbidden("Hook rejected"))
            }
        }

        let db = memory_adapter().with_hook(Arc::new(RejectDeleteHook));
        let user = db.create_user(sample_user()).await.unwrap();

        assert!(db.delete_user(&user.id).await.is_err());
        assert!(db.get_user_by_id(&user.id).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_multiple_hooks() {
        let hook1 = Arc::new(CountingHook::default());
        let hook2 = Arc::new(CountingHook::default());
        let db = memory_adapter()
            .with_hook(hook1.clone())
            .with_hook(hook2.clone());

        db.create_user(sample_user()).await.unwrap();

        assert_eq!(hook1.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook2.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook1.after_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook2.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_run_in_registration_order() {
        /// Appends "<tag>:before" / "<tag>:after" to a shared log.
        struct RecordingHook {
            tag: &'static str,
            log: Arc<Mutex<Vec<String>>>,
        }

        #[async_trait]
        impl DatabaseHooks<MemoryDatabaseAdapter> for RecordingHook {
            async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
                self.log.lock().unwrap().push(format!("{}:before", self.tag));
                Ok(())
            }
            async fn after_create_user(&self, _user: &User) -> AuthResult<()> {
                self.log.lock().unwrap().push(format!("{}:after", self.tag));
                Ok(())
            }
        }

        let log = Arc::new(Mutex::new(Vec::new()));
        let db = memory_adapter()
            .with_hook(Arc::new(RecordingHook {
                tag: "first",
                log: log.clone(),
            }))
            .with_hook(Arc::new(RecordingHook {
                tag: "second",
                log: log.clone(),
            }));

        db.create_user(sample_user()).await.unwrap();

        assert_eq!(
            *log.lock().unwrap(),
            vec!["first:before", "second:before", "first:after", "second:after"]
        );
    }

    #[tokio::test]
    async fn test_passthrough_operations() {
        let hook = Arc::new(CountingHook::default());
        let db = memory_adapter().with_hook(hook.clone());

        let result = db.get_user_by_email("nonexistent@test.com").await.unwrap();
        assert!(result.is_none());
        // Reads are forwarded without touching any hook.
        assert_eq!(hook.total(), 0);
    }
}