Skip to main content

better_auth_core/
hooks.rs

1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use std::sync::Arc;
4
5use crate::adapters::DatabaseAdapter;
6use crate::adapters::database::{
7    AccountOps, InvitationOps, MemberOps, OrganizationOps, SessionOps, TwoFactorOps, UserOps,
8    VerificationOps,
9};
10use crate::error::AuthResult;
11use crate::types::{
12    CreateAccount, CreateInvitation, CreateMember, CreateOrganization, CreateSession,
13    CreateTwoFactor, CreateUser, CreateVerification, InvitationStatus, UpdateAccount,
14    UpdateOrganization, UpdateUser,
15};
16
17/// Database lifecycle hooks for intercepting operations.
18///
19/// All methods have default no-op implementations. Override only the hooks
20/// you need. Returning `Err` from a `before_*` hook aborts the operation.
21///
22/// The `DB` type parameter determines the concrete entity types used in
23/// `after_*` hooks (e.g., `after_create_user` receives `&DB::User`).
24#[async_trait]
25pub trait DatabaseHooks<DB: DatabaseAdapter>: Send + Sync {
26    async fn before_create_user(&self, user: &mut CreateUser) -> AuthResult<()> {
27        let _ = user;
28        Ok(())
29    }
30
31    async fn after_create_user(&self, user: &DB::User) -> AuthResult<()> {
32        let _ = user;
33        Ok(())
34    }
35
36    async fn before_update_user(&self, id: &str, update: &mut UpdateUser) -> AuthResult<()> {
37        let _ = (id, update);
38        Ok(())
39    }
40
41    async fn after_update_user(&self, user: &DB::User) -> AuthResult<()> {
42        let _ = user;
43        Ok(())
44    }
45
46    async fn before_delete_user(&self, id: &str) -> AuthResult<()> {
47        let _ = id;
48        Ok(())
49    }
50
51    async fn after_delete_user(&self, id: &str) -> AuthResult<()> {
52        let _ = id;
53        Ok(())
54    }
55
56    async fn before_create_session(&self, session: &mut CreateSession) -> AuthResult<()> {
57        let _ = session;
58        Ok(())
59    }
60
61    async fn after_create_session(&self, session: &DB::Session) -> AuthResult<()> {
62        let _ = session;
63        Ok(())
64    }
65
66    async fn before_delete_session(&self, token: &str) -> AuthResult<()> {
67        let _ = token;
68        Ok(())
69    }
70
71    async fn after_delete_session(&self, token: &str) -> AuthResult<()> {
72        let _ = token;
73        Ok(())
74    }
75}
76
77/// A database adapter wrapper that calls hooks around the inner adapter's operations.
78pub struct HookedDatabaseAdapter<DB: DatabaseAdapter> {
79    inner: Arc<DB>,
80    hooks: Vec<Arc<dyn DatabaseHooks<DB>>>,
81}
82
83impl<DB: DatabaseAdapter> HookedDatabaseAdapter<DB> {
84    pub fn new(inner: Arc<DB>) -> Self {
85        Self {
86            inner,
87            hooks: Vec::new(),
88        }
89    }
90
91    pub fn with_hook(mut self, hook: Arc<dyn DatabaseHooks<DB>>) -> Self {
92        self.hooks.push(hook);
93        self
94    }
95
96    pub fn add_hook(&mut self, hook: Arc<dyn DatabaseHooks<DB>>) {
97        self.hooks.push(hook);
98    }
99}
100
101#[async_trait]
102impl<DB: DatabaseAdapter> UserOps for HookedDatabaseAdapter<DB> {
103    type User = DB::User;
104
105    async fn create_user(&self, mut user: CreateUser) -> AuthResult<Self::User> {
106        for hook in &self.hooks {
107            hook.before_create_user(&mut user).await?;
108        }
109        let result = self.inner.create_user(user).await?;
110        for hook in &self.hooks {
111            hook.after_create_user(&result).await?;
112        }
113        Ok(result)
114    }
115
116    async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<Self::User>> {
117        self.inner.get_user_by_id(id).await
118    }
119
120    async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<Self::User>> {
121        self.inner.get_user_by_email(email).await
122    }
123
124    async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<Self::User>> {
125        self.inner.get_user_by_username(username).await
126    }
127
128    async fn update_user(&self, id: &str, mut update: UpdateUser) -> AuthResult<Self::User> {
129        for hook in &self.hooks {
130            hook.before_update_user(id, &mut update).await?;
131        }
132        let result = self.inner.update_user(id, update).await?;
133        for hook in &self.hooks {
134            hook.after_update_user(&result).await?;
135        }
136        Ok(result)
137    }
138
139    async fn delete_user(&self, id: &str) -> AuthResult<()> {
140        for hook in &self.hooks {
141            hook.before_delete_user(id).await?;
142        }
143        self.inner.delete_user(id).await?;
144        for hook in &self.hooks {
145            hook.after_delete_user(id).await?;
146        }
147        Ok(())
148    }
149}
150
151#[async_trait]
152impl<DB: DatabaseAdapter> SessionOps for HookedDatabaseAdapter<DB> {
153    type Session = DB::Session;
154
155    async fn create_session(&self, mut session: CreateSession) -> AuthResult<Self::Session> {
156        for hook in &self.hooks {
157            hook.before_create_session(&mut session).await?;
158        }
159        let result = self.inner.create_session(session).await?;
160        for hook in &self.hooks {
161            hook.after_create_session(&result).await?;
162        }
163        Ok(result)
164    }
165
166    async fn get_session(&self, token: &str) -> AuthResult<Option<Self::Session>> {
167        self.inner.get_session(token).await
168    }
169
170    async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<Self::Session>> {
171        self.inner.get_user_sessions(user_id).await
172    }
173
174    async fn update_session_expiry(
175        &self,
176        token: &str,
177        expires_at: DateTime<Utc>,
178    ) -> AuthResult<()> {
179        self.inner.update_session_expiry(token, expires_at).await
180    }
181
182    async fn delete_session(&self, token: &str) -> AuthResult<()> {
183        for hook in &self.hooks {
184            hook.before_delete_session(token).await?;
185        }
186        self.inner.delete_session(token).await?;
187        for hook in &self.hooks {
188            hook.after_delete_session(token).await?;
189        }
190        Ok(())
191    }
192
193    async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
194        self.inner.delete_user_sessions(user_id).await
195    }
196
197    async fn delete_expired_sessions(&self) -> AuthResult<usize> {
198        self.inner.delete_expired_sessions().await
199    }
200
201    async fn update_session_active_organization(
202        &self,
203        token: &str,
204        organization_id: Option<&str>,
205    ) -> AuthResult<Self::Session> {
206        self.inner
207            .update_session_active_organization(token, organization_id)
208            .await
209    }
210}
211
212#[async_trait]
213impl<DB: DatabaseAdapter> AccountOps for HookedDatabaseAdapter<DB> {
214    type Account = DB::Account;
215
216    async fn create_account(&self, account: CreateAccount) -> AuthResult<Self::Account> {
217        self.inner.create_account(account).await
218    }
219
220    async fn get_account(
221        &self,
222        provider: &str,
223        provider_account_id: &str,
224    ) -> AuthResult<Option<Self::Account>> {
225        self.inner.get_account(provider, provider_account_id).await
226    }
227
228    async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<Self::Account>> {
229        self.inner.get_user_accounts(user_id).await
230    }
231
232    async fn update_account(&self, id: &str, update: UpdateAccount) -> AuthResult<Self::Account> {
233        self.inner.update_account(id, update).await
234    }
235
236    async fn delete_account(&self, id: &str) -> AuthResult<()> {
237        self.inner.delete_account(id).await
238    }
239}
240
241#[async_trait]
242impl<DB: DatabaseAdapter> VerificationOps for HookedDatabaseAdapter<DB> {
243    type Verification = DB::Verification;
244
245    async fn create_verification(
246        &self,
247        verification: CreateVerification,
248    ) -> AuthResult<Self::Verification> {
249        self.inner.create_verification(verification).await
250    }
251
252    async fn get_verification(
253        &self,
254        identifier: &str,
255        value: &str,
256    ) -> AuthResult<Option<Self::Verification>> {
257        self.inner.get_verification(identifier, value).await
258    }
259
260    async fn get_verification_by_value(
261        &self,
262        value: &str,
263    ) -> AuthResult<Option<Self::Verification>> {
264        self.inner.get_verification_by_value(value).await
265    }
266
267    async fn get_verification_by_identifier(
268        &self,
269        identifier: &str,
270    ) -> AuthResult<Option<Self::Verification>> {
271        self.inner.get_verification_by_identifier(identifier).await
272    }
273
274    async fn delete_verification(&self, id: &str) -> AuthResult<()> {
275        self.inner.delete_verification(id).await
276    }
277
278    async fn delete_expired_verifications(&self) -> AuthResult<usize> {
279        self.inner.delete_expired_verifications().await
280    }
281}
282
283#[async_trait]
284impl<DB: DatabaseAdapter> OrganizationOps for HookedDatabaseAdapter<DB> {
285    type Organization = DB::Organization;
286
287    async fn create_organization(&self, org: CreateOrganization) -> AuthResult<Self::Organization> {
288        self.inner.create_organization(org).await
289    }
290
291    async fn get_organization_by_id(&self, id: &str) -> AuthResult<Option<Self::Organization>> {
292        self.inner.get_organization_by_id(id).await
293    }
294
295    async fn get_organization_by_slug(&self, slug: &str) -> AuthResult<Option<Self::Organization>> {
296        self.inner.get_organization_by_slug(slug).await
297    }
298
299    async fn update_organization(
300        &self,
301        id: &str,
302        update: UpdateOrganization,
303    ) -> AuthResult<Self::Organization> {
304        self.inner.update_organization(id, update).await
305    }
306
307    async fn delete_organization(&self, id: &str) -> AuthResult<()> {
308        self.inner.delete_organization(id).await
309    }
310
311    async fn list_user_organizations(&self, user_id: &str) -> AuthResult<Vec<Self::Organization>> {
312        self.inner.list_user_organizations(user_id).await
313    }
314}
315
316#[async_trait]
317impl<DB: DatabaseAdapter> MemberOps for HookedDatabaseAdapter<DB> {
318    type Member = DB::Member;
319
320    async fn create_member(&self, member: CreateMember) -> AuthResult<Self::Member> {
321        self.inner.create_member(member).await
322    }
323
324    async fn get_member(
325        &self,
326        organization_id: &str,
327        user_id: &str,
328    ) -> AuthResult<Option<Self::Member>> {
329        self.inner.get_member(organization_id, user_id).await
330    }
331
332    async fn get_member_by_id(&self, id: &str) -> AuthResult<Option<Self::Member>> {
333        self.inner.get_member_by_id(id).await
334    }
335
336    async fn update_member_role(&self, member_id: &str, role: &str) -> AuthResult<Self::Member> {
337        self.inner.update_member_role(member_id, role).await
338    }
339
340    async fn delete_member(&self, member_id: &str) -> AuthResult<()> {
341        self.inner.delete_member(member_id).await
342    }
343
344    async fn list_organization_members(
345        &self,
346        organization_id: &str,
347    ) -> AuthResult<Vec<Self::Member>> {
348        self.inner.list_organization_members(organization_id).await
349    }
350
351    async fn count_organization_members(&self, organization_id: &str) -> AuthResult<usize> {
352        self.inner.count_organization_members(organization_id).await
353    }
354
355    async fn count_organization_owners(&self, organization_id: &str) -> AuthResult<usize> {
356        self.inner.count_organization_owners(organization_id).await
357    }
358}
359
360#[async_trait]
361impl<DB: DatabaseAdapter> InvitationOps for HookedDatabaseAdapter<DB> {
362    type Invitation = DB::Invitation;
363
364    async fn create_invitation(
365        &self,
366        invitation: CreateInvitation,
367    ) -> AuthResult<Self::Invitation> {
368        self.inner.create_invitation(invitation).await
369    }
370
371    async fn get_invitation_by_id(&self, id: &str) -> AuthResult<Option<Self::Invitation>> {
372        self.inner.get_invitation_by_id(id).await
373    }
374
375    async fn get_pending_invitation(
376        &self,
377        organization_id: &str,
378        email: &str,
379    ) -> AuthResult<Option<Self::Invitation>> {
380        self.inner
381            .get_pending_invitation(organization_id, email)
382            .await
383    }
384
385    async fn update_invitation_status(
386        &self,
387        id: &str,
388        status: InvitationStatus,
389    ) -> AuthResult<Self::Invitation> {
390        self.inner.update_invitation_status(id, status).await
391    }
392
393    async fn list_organization_invitations(
394        &self,
395        organization_id: &str,
396    ) -> AuthResult<Vec<Self::Invitation>> {
397        self.inner
398            .list_organization_invitations(organization_id)
399            .await
400    }
401
402    async fn list_user_invitations(&self, email: &str) -> AuthResult<Vec<Self::Invitation>> {
403        self.inner.list_user_invitations(email).await
404    }
405}
406
407#[async_trait]
408impl<DB: DatabaseAdapter> TwoFactorOps for HookedDatabaseAdapter<DB> {
409    type TwoFactor = DB::TwoFactor;
410
411    async fn create_two_factor(&self, two_factor: CreateTwoFactor) -> AuthResult<Self::TwoFactor> {
412        self.inner.create_two_factor(two_factor).await
413    }
414
415    async fn get_two_factor_by_user_id(
416        &self,
417        user_id: &str,
418    ) -> AuthResult<Option<Self::TwoFactor>> {
419        self.inner.get_two_factor_by_user_id(user_id).await
420    }
421
422    async fn update_two_factor_backup_codes(
423        &self,
424        user_id: &str,
425        backup_codes: &str,
426    ) -> AuthResult<Self::TwoFactor> {
427        self.inner
428            .update_two_factor_backup_codes(user_id, backup_codes)
429            .await
430    }
431
432    async fn delete_two_factor(&self, user_id: &str) -> AuthResult<()> {
433        self.inner.delete_two_factor(user_id).await
434    }
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::MemoryDatabaseAdapter;
    use crate::types::{CreateUser, UpdateUser, User};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Hook that counts how many times each user-lifecycle callback fires.
    struct CountingHook {
        before_create_count: AtomicU32,
        after_create_count: AtomicU32,
        before_update_count: AtomicU32,
        after_update_count: AtomicU32,
        before_delete_count: AtomicU32,
        after_delete_count: AtomicU32,
    }

    impl CountingHook {
        fn new() -> Self {
            Self {
                before_create_count: AtomicU32::new(0),
                after_create_count: AtomicU32::new(0),
                before_update_count: AtomicU32::new(0),
                after_update_count: AtomicU32::new(0),
                before_delete_count: AtomicU32::new(0),
                after_delete_count: AtomicU32::new(0),
            }
        }
    }

    #[async_trait]
    impl DatabaseHooks<MemoryDatabaseAdapter> for CountingHook {
        async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
            self.before_create_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_create_user(&self, _user: &User) -> AuthResult<()> {
            self.after_create_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn before_update_user(&self, _id: &str, _update: &mut UpdateUser) -> AuthResult<()> {
            self.before_update_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_update_user(&self, _user: &User) -> AuthResult<()> {
            self.after_update_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn before_delete_user(&self, _id: &str) -> AuthResult<()> {
            self.before_delete_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_delete_user(&self, _id: &str) -> AuthResult<()> {
            self.after_delete_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_hooks_called_on_create_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        db.create_user(create).await.unwrap();

        assert_eq!(hook.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_update_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        let user = db.create_user(create).await.unwrap();

        let update = UpdateUser {
            name: Some("Updated".to_string()),
            email: None,
            image: None,
            email_verified: None,
            username: None,
            display_username: None,
            role: None,
            banned: None,
            ban_reason: None,
            ban_expires: None,
            two_factor_enabled: None,
            metadata: None,
        };
        db.update_user(&user.id, update).await.unwrap();

        assert_eq!(hook.before_update_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_update_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_delete_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        let user = db.create_user(create).await.unwrap();

        db.delete_user(&user.id).await.unwrap();

        assert_eq!(hook.before_delete_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_delete_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_before_hook_can_reject() {
        struct RejectHook;

        #[async_trait]
        impl DatabaseHooks<MemoryDatabaseAdapter> for RejectHook {
            async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
                Err(crate::error::AuthError::forbidden("Hook rejected"))
            }
        }

        // Registered after the rejecting hook, so it must never be reached.
        let later = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(Arc::new(RejectHook))
            .with_hook(later.clone());

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        let result = db.create_user(create).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().status_code(), 403);

        // The rejection must abort the whole operation: no later hooks run
        // and the inner adapter never stores the user.
        assert_eq!(later.before_create_count.load(Ordering::SeqCst), 0);
        assert_eq!(later.after_create_count.load(Ordering::SeqCst), 0);
        let stored = db.get_user_by_email("test@example.com").await.unwrap();
        assert!(stored.is_none());
    }

    #[tokio::test]
    async fn test_after_hook_error_is_returned_but_write_persists() {
        struct FailAfterHook;

        #[async_trait]
        impl DatabaseHooks<MemoryDatabaseAdapter> for FailAfterHook {
            async fn after_create_user(&self, _user: &User) -> AuthResult<()> {
                Err(crate::error::AuthError::forbidden("after hook failed"))
            }
        }

        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(Arc::new(FailAfterHook));

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        let result = db.create_user(create).await;

        // The caller sees the after-hook error...
        assert!(result.is_err());
        // ...but the inner insert already ran and is not rolled back.
        let stored = db.get_user_by_email("test@example.com").await.unwrap();
        assert!(stored.is_some());
    }

    #[tokio::test]
    async fn test_multiple_hooks() {
        let hook1 = Arc::new(CountingHook::new());
        let hook2 = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook1.clone())
            .with_hook(hook2.clone());

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        db.create_user(create).await.unwrap();

        assert_eq!(hook1.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook2.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook1.after_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook2.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_passthrough_operations() {
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()));

        let result = db.get_user_by_email("nonexistent@test.com").await.unwrap();
        assert!(result.is_none());
    }
}