Skip to main content

better_auth_core/
hooks.rs

1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use std::sync::Arc;
4
5use crate::adapters::DatabaseAdapter;
6use crate::adapters::database::{
7    AccountOps, ApiKeyOps, InvitationOps, MemberOps, OrganizationOps, SessionOps, TwoFactorOps,
8    UserOps, VerificationOps,
9};
10use crate::error::AuthResult;
11use crate::types::{
12    CreateAccount, CreateApiKey, CreateInvitation, CreateMember, CreateOrganization, CreateSession,
13    CreateTwoFactor, CreateUser, CreateVerification, InvitationStatus, ListUsersParams,
14    UpdateAccount, UpdateApiKey, UpdateOrganization, UpdateUser,
15};
16
17/// Database lifecycle hooks for intercepting operations.
18///
19/// All methods have default no-op implementations. Override only the hooks
20/// you need. Returning `Err` from a `before_*` hook aborts the operation.
21///
22/// The `DB` type parameter determines the concrete entity types used in
23/// `after_*` hooks (e.g., `after_create_user` receives `&DB::User`).
24#[async_trait]
25pub trait DatabaseHooks<DB: DatabaseAdapter>: Send + Sync {
26    async fn before_create_user(&self, user: &mut CreateUser) -> AuthResult<()> {
27        let _ = user;
28        Ok(())
29    }
30
31    async fn after_create_user(&self, user: &DB::User) -> AuthResult<()> {
32        let _ = user;
33        Ok(())
34    }
35
36    async fn before_update_user(&self, id: &str, update: &mut UpdateUser) -> AuthResult<()> {
37        let _ = (id, update);
38        Ok(())
39    }
40
41    async fn after_update_user(&self, user: &DB::User) -> AuthResult<()> {
42        let _ = user;
43        Ok(())
44    }
45
46    async fn before_delete_user(&self, id: &str) -> AuthResult<()> {
47        let _ = id;
48        Ok(())
49    }
50
51    async fn after_delete_user(&self, id: &str) -> AuthResult<()> {
52        let _ = id;
53        Ok(())
54    }
55
56    async fn before_create_session(&self, session: &mut CreateSession) -> AuthResult<()> {
57        let _ = session;
58        Ok(())
59    }
60
61    async fn after_create_session(&self, session: &DB::Session) -> AuthResult<()> {
62        let _ = session;
63        Ok(())
64    }
65
66    async fn before_delete_session(&self, token: &str) -> AuthResult<()> {
67        let _ = token;
68        Ok(())
69    }
70
71    async fn after_delete_session(&self, token: &str) -> AuthResult<()> {
72        let _ = token;
73        Ok(())
74    }
75}
76
77/// A database adapter wrapper that calls hooks around the inner adapter's operations.
78pub struct HookedDatabaseAdapter<DB: DatabaseAdapter> {
79    inner: Arc<DB>,
80    hooks: Vec<Arc<dyn DatabaseHooks<DB>>>,
81}
82
83impl<DB: DatabaseAdapter> HookedDatabaseAdapter<DB> {
84    pub fn new(inner: Arc<DB>) -> Self {
85        Self {
86            inner,
87            hooks: Vec::new(),
88        }
89    }
90
91    pub fn with_hook(mut self, hook: Arc<dyn DatabaseHooks<DB>>) -> Self {
92        self.hooks.push(hook);
93        self
94    }
95
96    pub fn add_hook(&mut self, hook: Arc<dyn DatabaseHooks<DB>>) {
97        self.hooks.push(hook);
98    }
99}
100
101#[async_trait]
102impl<DB: DatabaseAdapter> UserOps for HookedDatabaseAdapter<DB> {
103    type User = DB::User;
104
105    async fn create_user(&self, mut user: CreateUser) -> AuthResult<Self::User> {
106        for hook in &self.hooks {
107            hook.before_create_user(&mut user).await?;
108        }
109        let result = self.inner.create_user(user).await?;
110        for hook in &self.hooks {
111            hook.after_create_user(&result).await?;
112        }
113        Ok(result)
114    }
115
116    async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<Self::User>> {
117        self.inner.get_user_by_id(id).await
118    }
119
120    async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<Self::User>> {
121        self.inner.get_user_by_email(email).await
122    }
123
124    async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<Self::User>> {
125        self.inner.get_user_by_username(username).await
126    }
127
128    async fn update_user(&self, id: &str, mut update: UpdateUser) -> AuthResult<Self::User> {
129        for hook in &self.hooks {
130            hook.before_update_user(id, &mut update).await?;
131        }
132        let result = self.inner.update_user(id, update).await?;
133        for hook in &self.hooks {
134            hook.after_update_user(&result).await?;
135        }
136        Ok(result)
137    }
138
139    async fn delete_user(&self, id: &str) -> AuthResult<()> {
140        for hook in &self.hooks {
141            hook.before_delete_user(id).await?;
142        }
143        self.inner.delete_user(id).await?;
144        for hook in &self.hooks {
145            hook.after_delete_user(id).await?;
146        }
147        Ok(())
148    }
149
150    async fn list_users(&self, params: ListUsersParams) -> AuthResult<(Vec<Self::User>, usize)> {
151        self.inner.list_users(params).await
152    }
153}
154
155#[async_trait]
156impl<DB: DatabaseAdapter> SessionOps for HookedDatabaseAdapter<DB> {
157    type Session = DB::Session;
158
159    async fn create_session(&self, mut session: CreateSession) -> AuthResult<Self::Session> {
160        for hook in &self.hooks {
161            hook.before_create_session(&mut session).await?;
162        }
163        let result = self.inner.create_session(session).await?;
164        for hook in &self.hooks {
165            hook.after_create_session(&result).await?;
166        }
167        Ok(result)
168    }
169
170    async fn get_session(&self, token: &str) -> AuthResult<Option<Self::Session>> {
171        self.inner.get_session(token).await
172    }
173
174    async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<Self::Session>> {
175        self.inner.get_user_sessions(user_id).await
176    }
177
178    async fn update_session_expiry(
179        &self,
180        token: &str,
181        expires_at: DateTime<Utc>,
182    ) -> AuthResult<()> {
183        self.inner.update_session_expiry(token, expires_at).await
184    }
185
186    async fn delete_session(&self, token: &str) -> AuthResult<()> {
187        for hook in &self.hooks {
188            hook.before_delete_session(token).await?;
189        }
190        self.inner.delete_session(token).await?;
191        for hook in &self.hooks {
192            hook.after_delete_session(token).await?;
193        }
194        Ok(())
195    }
196
197    async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
198        self.inner.delete_user_sessions(user_id).await
199    }
200
201    async fn delete_expired_sessions(&self) -> AuthResult<usize> {
202        self.inner.delete_expired_sessions().await
203    }
204
205    async fn update_session_active_organization(
206        &self,
207        token: &str,
208        organization_id: Option<&str>,
209    ) -> AuthResult<Self::Session> {
210        self.inner
211            .update_session_active_organization(token, organization_id)
212            .await
213    }
214}
215
216#[async_trait]
217impl<DB: DatabaseAdapter> AccountOps for HookedDatabaseAdapter<DB> {
218    type Account = DB::Account;
219
220    async fn create_account(&self, account: CreateAccount) -> AuthResult<Self::Account> {
221        self.inner.create_account(account).await
222    }
223
224    async fn get_account(
225        &self,
226        provider: &str,
227        provider_account_id: &str,
228    ) -> AuthResult<Option<Self::Account>> {
229        self.inner.get_account(provider, provider_account_id).await
230    }
231
232    async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<Self::Account>> {
233        self.inner.get_user_accounts(user_id).await
234    }
235
236    async fn update_account(&self, id: &str, update: UpdateAccount) -> AuthResult<Self::Account> {
237        self.inner.update_account(id, update).await
238    }
239
240    async fn delete_account(&self, id: &str) -> AuthResult<()> {
241        self.inner.delete_account(id).await
242    }
243}
244
245#[async_trait]
246impl<DB: DatabaseAdapter> VerificationOps for HookedDatabaseAdapter<DB> {
247    type Verification = DB::Verification;
248
249    async fn create_verification(
250        &self,
251        verification: CreateVerification,
252    ) -> AuthResult<Self::Verification> {
253        self.inner.create_verification(verification).await
254    }
255
256    async fn get_verification(
257        &self,
258        identifier: &str,
259        value: &str,
260    ) -> AuthResult<Option<Self::Verification>> {
261        self.inner.get_verification(identifier, value).await
262    }
263
264    async fn get_verification_by_value(
265        &self,
266        value: &str,
267    ) -> AuthResult<Option<Self::Verification>> {
268        self.inner.get_verification_by_value(value).await
269    }
270
271    async fn get_verification_by_identifier(
272        &self,
273        identifier: &str,
274    ) -> AuthResult<Option<Self::Verification>> {
275        self.inner.get_verification_by_identifier(identifier).await
276    }
277
278    async fn consume_verification(
279        &self,
280        identifier: &str,
281        value: &str,
282    ) -> AuthResult<Option<Self::Verification>> {
283        self.inner.consume_verification(identifier, value).await
284    }
285
286    async fn delete_verification(&self, id: &str) -> AuthResult<()> {
287        self.inner.delete_verification(id).await
288    }
289
290    async fn delete_expired_verifications(&self) -> AuthResult<usize> {
291        self.inner.delete_expired_verifications().await
292    }
293}
294
295#[async_trait]
296impl<DB: DatabaseAdapter> OrganizationOps for HookedDatabaseAdapter<DB> {
297    type Organization = DB::Organization;
298
299    async fn create_organization(&self, org: CreateOrganization) -> AuthResult<Self::Organization> {
300        self.inner.create_organization(org).await
301    }
302
303    async fn get_organization_by_id(&self, id: &str) -> AuthResult<Option<Self::Organization>> {
304        self.inner.get_organization_by_id(id).await
305    }
306
307    async fn get_organization_by_slug(&self, slug: &str) -> AuthResult<Option<Self::Organization>> {
308        self.inner.get_organization_by_slug(slug).await
309    }
310
311    async fn update_organization(
312        &self,
313        id: &str,
314        update: UpdateOrganization,
315    ) -> AuthResult<Self::Organization> {
316        self.inner.update_organization(id, update).await
317    }
318
319    async fn delete_organization(&self, id: &str) -> AuthResult<()> {
320        self.inner.delete_organization(id).await
321    }
322
323    async fn list_user_organizations(&self, user_id: &str) -> AuthResult<Vec<Self::Organization>> {
324        self.inner.list_user_organizations(user_id).await
325    }
326}
327
328#[async_trait]
329impl<DB: DatabaseAdapter> MemberOps for HookedDatabaseAdapter<DB> {
330    type Member = DB::Member;
331
332    async fn create_member(&self, member: CreateMember) -> AuthResult<Self::Member> {
333        self.inner.create_member(member).await
334    }
335
336    async fn get_member(
337        &self,
338        organization_id: &str,
339        user_id: &str,
340    ) -> AuthResult<Option<Self::Member>> {
341        self.inner.get_member(organization_id, user_id).await
342    }
343
344    async fn get_member_by_id(&self, id: &str) -> AuthResult<Option<Self::Member>> {
345        self.inner.get_member_by_id(id).await
346    }
347
348    async fn update_member_role(&self, member_id: &str, role: &str) -> AuthResult<Self::Member> {
349        self.inner.update_member_role(member_id, role).await
350    }
351
352    async fn delete_member(&self, member_id: &str) -> AuthResult<()> {
353        self.inner.delete_member(member_id).await
354    }
355
356    async fn list_organization_members(
357        &self,
358        organization_id: &str,
359    ) -> AuthResult<Vec<Self::Member>> {
360        self.inner.list_organization_members(organization_id).await
361    }
362
363    async fn count_organization_members(&self, organization_id: &str) -> AuthResult<usize> {
364        self.inner.count_organization_members(organization_id).await
365    }
366
367    async fn count_organization_owners(&self, organization_id: &str) -> AuthResult<usize> {
368        self.inner.count_organization_owners(organization_id).await
369    }
370}
371
372#[async_trait]
373impl<DB: DatabaseAdapter> InvitationOps for HookedDatabaseAdapter<DB> {
374    type Invitation = DB::Invitation;
375
376    async fn create_invitation(
377        &self,
378        invitation: CreateInvitation,
379    ) -> AuthResult<Self::Invitation> {
380        self.inner.create_invitation(invitation).await
381    }
382
383    async fn get_invitation_by_id(&self, id: &str) -> AuthResult<Option<Self::Invitation>> {
384        self.inner.get_invitation_by_id(id).await
385    }
386
387    async fn get_pending_invitation(
388        &self,
389        organization_id: &str,
390        email: &str,
391    ) -> AuthResult<Option<Self::Invitation>> {
392        self.inner
393            .get_pending_invitation(organization_id, email)
394            .await
395    }
396
397    async fn update_invitation_status(
398        &self,
399        id: &str,
400        status: InvitationStatus,
401    ) -> AuthResult<Self::Invitation> {
402        self.inner.update_invitation_status(id, status).await
403    }
404
405    async fn list_organization_invitations(
406        &self,
407        organization_id: &str,
408    ) -> AuthResult<Vec<Self::Invitation>> {
409        self.inner
410            .list_organization_invitations(organization_id)
411            .await
412    }
413
414    async fn list_user_invitations(&self, email: &str) -> AuthResult<Vec<Self::Invitation>> {
415        self.inner.list_user_invitations(email).await
416    }
417}
418
419#[async_trait]
420impl<DB: DatabaseAdapter> TwoFactorOps for HookedDatabaseAdapter<DB> {
421    type TwoFactor = DB::TwoFactor;
422
423    async fn create_two_factor(&self, two_factor: CreateTwoFactor) -> AuthResult<Self::TwoFactor> {
424        self.inner.create_two_factor(two_factor).await
425    }
426
427    async fn get_two_factor_by_user_id(
428        &self,
429        user_id: &str,
430    ) -> AuthResult<Option<Self::TwoFactor>> {
431        self.inner.get_two_factor_by_user_id(user_id).await
432    }
433
434    async fn update_two_factor_backup_codes(
435        &self,
436        user_id: &str,
437        backup_codes: &str,
438    ) -> AuthResult<Self::TwoFactor> {
439        self.inner
440            .update_two_factor_backup_codes(user_id, backup_codes)
441            .await
442    }
443
444    async fn delete_two_factor(&self, user_id: &str) -> AuthResult<()> {
445        self.inner.delete_two_factor(user_id).await
446    }
447}
448
449#[async_trait]
450impl<DB: DatabaseAdapter> ApiKeyOps for HookedDatabaseAdapter<DB> {
451    type ApiKey = DB::ApiKey;
452
453    async fn create_api_key(&self, input: CreateApiKey) -> AuthResult<Self::ApiKey> {
454        self.inner.create_api_key(input).await
455    }
456
457    async fn get_api_key_by_id(&self, id: &str) -> AuthResult<Option<Self::ApiKey>> {
458        self.inner.get_api_key_by_id(id).await
459    }
460
461    async fn get_api_key_by_hash(&self, hash: &str) -> AuthResult<Option<Self::ApiKey>> {
462        self.inner.get_api_key_by_hash(hash).await
463    }
464
465    async fn list_api_keys_by_user(&self, user_id: &str) -> AuthResult<Vec<Self::ApiKey>> {
466        self.inner.list_api_keys_by_user(user_id).await
467    }
468
469    async fn update_api_key(&self, id: &str, update: UpdateApiKey) -> AuthResult<Self::ApiKey> {
470        self.inner.update_api_key(id, update).await
471    }
472
473    async fn delete_api_key(&self, id: &str) -> AuthResult<()> {
474        self.inner.delete_api_key(id).await
475    }
476}
477
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::MemoryDatabaseAdapter;
    use crate::types::{CreateUser, UpdateUser, User};
    use std::sync::atomic::{AtomicU32, Ordering};

    const TEST_EMAIL: &str = "test@example.com";

    /// Builds the `CreateUser` input shared by the tests below.
    fn new_user() -> CreateUser {
        CreateUser::new().with_email(TEST_EMAIL).with_name("Test")
    }

    /// Hook that counts how many times each user hook fires.
    struct CountingHook {
        before_create_count: AtomicU32,
        after_create_count: AtomicU32,
        before_update_count: AtomicU32,
        after_update_count: AtomicU32,
        before_delete_count: AtomicU32,
        after_delete_count: AtomicU32,
    }

    impl CountingHook {
        fn new() -> Self {
            Self {
                before_create_count: AtomicU32::new(0),
                after_create_count: AtomicU32::new(0),
                before_update_count: AtomicU32::new(0),
                after_update_count: AtomicU32::new(0),
                before_delete_count: AtomicU32::new(0),
                after_delete_count: AtomicU32::new(0),
            }
        }
    }

    #[async_trait]
    impl DatabaseHooks<MemoryDatabaseAdapter> for CountingHook {
        async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
            self.before_create_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_create_user(&self, _user: &User) -> AuthResult<()> {
            self.after_create_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn before_update_user(&self, _id: &str, _update: &mut UpdateUser) -> AuthResult<()> {
            self.before_update_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_update_user(&self, _user: &User) -> AuthResult<()> {
            self.after_update_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn before_delete_user(&self, _id: &str) -> AuthResult<()> {
            self.before_delete_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_delete_user(&self, _id: &str) -> AuthResult<()> {
            self.after_delete_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_hooks_called_on_create_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        db.create_user(new_user()).await.unwrap();

        assert_eq!(hook.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_update_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let user = db.create_user(new_user()).await.unwrap();

        let update = UpdateUser {
            name: Some("Updated".to_string()),
            email: None,
            image: None,
            email_verified: None,
            username: None,
            display_username: None,
            role: None,
            banned: None,
            ban_reason: None,
            ban_expires: None,
            two_factor_enabled: None,
            metadata: None,
        };
        db.update_user(&user.id, update).await.unwrap();

        assert_eq!(hook.before_update_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_update_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_delete_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let user = db.create_user(new_user()).await.unwrap();

        db.delete_user(&user.id).await.unwrap();

        assert_eq!(hook.before_delete_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_delete_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_before_hook_can_reject() {
        struct RejectHook;

        #[async_trait]
        impl DatabaseHooks<MemoryDatabaseAdapter> for RejectHook {
            async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
                Err(crate::error::AuthError::forbidden("Hook rejected"))
            }
        }

        // Registered after the rejecting hook, so none of its hooks may run.
        let later = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(Arc::new(RejectHook))
            .with_hook(later.clone());

        let result = db.create_user(new_user()).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().status_code(), 403);

        // The rejection must abort the whole operation: nothing persisted...
        assert!(db.get_user_by_email(TEST_EMAIL).await.unwrap().is_none());
        // ...and the remaining before-hooks and all after-hooks are skipped.
        assert_eq!(later.before_create_count.load(Ordering::SeqCst), 0);
        assert_eq!(later.after_create_count.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_after_hook_error_is_returned_after_write() {
        struct FailAfterCreate;

        #[async_trait]
        impl DatabaseHooks<MemoryDatabaseAdapter> for FailAfterCreate {
            async fn after_create_user(&self, _user: &User) -> AuthResult<()> {
                Err(crate::error::AuthError::forbidden("after hook failed"))
            }
        }

        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(Arc::new(FailAfterCreate));

        let result = db.create_user(new_user()).await;
        assert_eq!(result.unwrap_err().status_code(), 403);

        // Pins current semantics: the inner write already happened and is not
        // rolled back when an after-hook fails.
        assert!(db.get_user_by_email(TEST_EMAIL).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_multiple_hooks() {
        let hook1 = Arc::new(CountingHook::new());
        let hook2 = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook1.clone())
            .with_hook(hook2.clone());

        db.create_user(new_user()).await.unwrap();

        assert_eq!(hook1.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook2.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook1.after_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook2.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_passthrough_operations() {
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()));

        let result = db.get_user_by_email("nonexistent@test.com").await.unwrap();
        assert!(result.is_none());
    }
}
647}