Skip to main content

better_auth_core/
hooks.rs

1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use std::sync::Arc;
4
5use crate::adapters::DatabaseAdapter;
6use crate::adapters::database::{
7    AccountOps, InvitationOps, MemberOps, OrganizationOps, SessionOps, UserOps, VerificationOps,
8};
9use crate::error::AuthResult;
10use crate::types::{
11    CreateAccount, CreateInvitation, CreateMember, CreateOrganization, CreateSession, CreateUser,
12    CreateVerification, InvitationStatus, UpdateOrganization, UpdateUser,
13};
14
15/// Database lifecycle hooks for intercepting operations.
16///
17/// All methods have default no-op implementations. Override only the hooks
18/// you need. Returning `Err` from a `before_*` hook aborts the operation.
19///
20/// The `DB` type parameter determines the concrete entity types used in
21/// `after_*` hooks (e.g., `after_create_user` receives `&DB::User`).
22#[async_trait]
23pub trait DatabaseHooks<DB: DatabaseAdapter>: Send + Sync {
24    async fn before_create_user(&self, user: &mut CreateUser) -> AuthResult<()> {
25        let _ = user;
26        Ok(())
27    }
28
29    async fn after_create_user(&self, user: &DB::User) -> AuthResult<()> {
30        let _ = user;
31        Ok(())
32    }
33
34    async fn before_update_user(&self, id: &str, update: &mut UpdateUser) -> AuthResult<()> {
35        let _ = (id, update);
36        Ok(())
37    }
38
39    async fn after_update_user(&self, user: &DB::User) -> AuthResult<()> {
40        let _ = user;
41        Ok(())
42    }
43
44    async fn before_delete_user(&self, id: &str) -> AuthResult<()> {
45        let _ = id;
46        Ok(())
47    }
48
49    async fn after_delete_user(&self, id: &str) -> AuthResult<()> {
50        let _ = id;
51        Ok(())
52    }
53
54    async fn before_create_session(&self, session: &mut CreateSession) -> AuthResult<()> {
55        let _ = session;
56        Ok(())
57    }
58
59    async fn after_create_session(&self, session: &DB::Session) -> AuthResult<()> {
60        let _ = session;
61        Ok(())
62    }
63
64    async fn before_delete_session(&self, token: &str) -> AuthResult<()> {
65        let _ = token;
66        Ok(())
67    }
68
69    async fn after_delete_session(&self, token: &str) -> AuthResult<()> {
70        let _ = token;
71        Ok(())
72    }
73}
74
75/// A database adapter wrapper that calls hooks around the inner adapter's operations.
76pub struct HookedDatabaseAdapter<DB: DatabaseAdapter> {
77    inner: Arc<DB>,
78    hooks: Vec<Arc<dyn DatabaseHooks<DB>>>,
79}
80
81impl<DB: DatabaseAdapter> HookedDatabaseAdapter<DB> {
82    pub fn new(inner: Arc<DB>) -> Self {
83        Self {
84            inner,
85            hooks: Vec::new(),
86        }
87    }
88
89    pub fn with_hook(mut self, hook: Arc<dyn DatabaseHooks<DB>>) -> Self {
90        self.hooks.push(hook);
91        self
92    }
93
94    pub fn add_hook(&mut self, hook: Arc<dyn DatabaseHooks<DB>>) {
95        self.hooks.push(hook);
96    }
97}
98
99#[async_trait]
100impl<DB: DatabaseAdapter> UserOps for HookedDatabaseAdapter<DB> {
101    type User = DB::User;
102
103    async fn create_user(&self, mut user: CreateUser) -> AuthResult<Self::User> {
104        for hook in &self.hooks {
105            hook.before_create_user(&mut user).await?;
106        }
107        let result = self.inner.create_user(user).await?;
108        for hook in &self.hooks {
109            hook.after_create_user(&result).await?;
110        }
111        Ok(result)
112    }
113
114    async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<Self::User>> {
115        self.inner.get_user_by_id(id).await
116    }
117
118    async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<Self::User>> {
119        self.inner.get_user_by_email(email).await
120    }
121
122    async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<Self::User>> {
123        self.inner.get_user_by_username(username).await
124    }
125
126    async fn update_user(&self, id: &str, mut update: UpdateUser) -> AuthResult<Self::User> {
127        for hook in &self.hooks {
128            hook.before_update_user(id, &mut update).await?;
129        }
130        let result = self.inner.update_user(id, update).await?;
131        for hook in &self.hooks {
132            hook.after_update_user(&result).await?;
133        }
134        Ok(result)
135    }
136
137    async fn delete_user(&self, id: &str) -> AuthResult<()> {
138        for hook in &self.hooks {
139            hook.before_delete_user(id).await?;
140        }
141        self.inner.delete_user(id).await?;
142        for hook in &self.hooks {
143            hook.after_delete_user(id).await?;
144        }
145        Ok(())
146    }
147}
148
149#[async_trait]
150impl<DB: DatabaseAdapter> SessionOps for HookedDatabaseAdapter<DB> {
151    type Session = DB::Session;
152
153    async fn create_session(&self, mut session: CreateSession) -> AuthResult<Self::Session> {
154        for hook in &self.hooks {
155            hook.before_create_session(&mut session).await?;
156        }
157        let result = self.inner.create_session(session).await?;
158        for hook in &self.hooks {
159            hook.after_create_session(&result).await?;
160        }
161        Ok(result)
162    }
163
164    async fn get_session(&self, token: &str) -> AuthResult<Option<Self::Session>> {
165        self.inner.get_session(token).await
166    }
167
168    async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<Self::Session>> {
169        self.inner.get_user_sessions(user_id).await
170    }
171
172    async fn update_session_expiry(
173        &self,
174        token: &str,
175        expires_at: DateTime<Utc>,
176    ) -> AuthResult<()> {
177        self.inner.update_session_expiry(token, expires_at).await
178    }
179
180    async fn delete_session(&self, token: &str) -> AuthResult<()> {
181        for hook in &self.hooks {
182            hook.before_delete_session(token).await?;
183        }
184        self.inner.delete_session(token).await?;
185        for hook in &self.hooks {
186            hook.after_delete_session(token).await?;
187        }
188        Ok(())
189    }
190
191    async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
192        self.inner.delete_user_sessions(user_id).await
193    }
194
195    async fn delete_expired_sessions(&self) -> AuthResult<usize> {
196        self.inner.delete_expired_sessions().await
197    }
198
199    async fn update_session_active_organization(
200        &self,
201        token: &str,
202        organization_id: Option<&str>,
203    ) -> AuthResult<Self::Session> {
204        self.inner
205            .update_session_active_organization(token, organization_id)
206            .await
207    }
208}
209
210#[async_trait]
211impl<DB: DatabaseAdapter> AccountOps for HookedDatabaseAdapter<DB> {
212    type Account = DB::Account;
213
214    async fn create_account(&self, account: CreateAccount) -> AuthResult<Self::Account> {
215        self.inner.create_account(account).await
216    }
217
218    async fn get_account(
219        &self,
220        provider: &str,
221        provider_account_id: &str,
222    ) -> AuthResult<Option<Self::Account>> {
223        self.inner.get_account(provider, provider_account_id).await
224    }
225
226    async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<Self::Account>> {
227        self.inner.get_user_accounts(user_id).await
228    }
229
230    async fn delete_account(&self, id: &str) -> AuthResult<()> {
231        self.inner.delete_account(id).await
232    }
233}
234
235#[async_trait]
236impl<DB: DatabaseAdapter> VerificationOps for HookedDatabaseAdapter<DB> {
237    type Verification = DB::Verification;
238
239    async fn create_verification(
240        &self,
241        verification: CreateVerification,
242    ) -> AuthResult<Self::Verification> {
243        self.inner.create_verification(verification).await
244    }
245
246    async fn get_verification(
247        &self,
248        identifier: &str,
249        value: &str,
250    ) -> AuthResult<Option<Self::Verification>> {
251        self.inner.get_verification(identifier, value).await
252    }
253
254    async fn get_verification_by_value(
255        &self,
256        value: &str,
257    ) -> AuthResult<Option<Self::Verification>> {
258        self.inner.get_verification_by_value(value).await
259    }
260
261    async fn delete_verification(&self, id: &str) -> AuthResult<()> {
262        self.inner.delete_verification(id).await
263    }
264
265    async fn delete_expired_verifications(&self) -> AuthResult<usize> {
266        self.inner.delete_expired_verifications().await
267    }
268}
269
270#[async_trait]
271impl<DB: DatabaseAdapter> OrganizationOps for HookedDatabaseAdapter<DB> {
272    type Organization = DB::Organization;
273
274    async fn create_organization(&self, org: CreateOrganization) -> AuthResult<Self::Organization> {
275        self.inner.create_organization(org).await
276    }
277
278    async fn get_organization_by_id(&self, id: &str) -> AuthResult<Option<Self::Organization>> {
279        self.inner.get_organization_by_id(id).await
280    }
281
282    async fn get_organization_by_slug(&self, slug: &str) -> AuthResult<Option<Self::Organization>> {
283        self.inner.get_organization_by_slug(slug).await
284    }
285
286    async fn update_organization(
287        &self,
288        id: &str,
289        update: UpdateOrganization,
290    ) -> AuthResult<Self::Organization> {
291        self.inner.update_organization(id, update).await
292    }
293
294    async fn delete_organization(&self, id: &str) -> AuthResult<()> {
295        self.inner.delete_organization(id).await
296    }
297
298    async fn list_user_organizations(&self, user_id: &str) -> AuthResult<Vec<Self::Organization>> {
299        self.inner.list_user_organizations(user_id).await
300    }
301}
302
303#[async_trait]
304impl<DB: DatabaseAdapter> MemberOps for HookedDatabaseAdapter<DB> {
305    type Member = DB::Member;
306
307    async fn create_member(&self, member: CreateMember) -> AuthResult<Self::Member> {
308        self.inner.create_member(member).await
309    }
310
311    async fn get_member(
312        &self,
313        organization_id: &str,
314        user_id: &str,
315    ) -> AuthResult<Option<Self::Member>> {
316        self.inner.get_member(organization_id, user_id).await
317    }
318
319    async fn get_member_by_id(&self, id: &str) -> AuthResult<Option<Self::Member>> {
320        self.inner.get_member_by_id(id).await
321    }
322
323    async fn update_member_role(&self, member_id: &str, role: &str) -> AuthResult<Self::Member> {
324        self.inner.update_member_role(member_id, role).await
325    }
326
327    async fn delete_member(&self, member_id: &str) -> AuthResult<()> {
328        self.inner.delete_member(member_id).await
329    }
330
331    async fn list_organization_members(
332        &self,
333        organization_id: &str,
334    ) -> AuthResult<Vec<Self::Member>> {
335        self.inner.list_organization_members(organization_id).await
336    }
337
338    async fn count_organization_members(&self, organization_id: &str) -> AuthResult<usize> {
339        self.inner.count_organization_members(organization_id).await
340    }
341
342    async fn count_organization_owners(&self, organization_id: &str) -> AuthResult<usize> {
343        self.inner.count_organization_owners(organization_id).await
344    }
345}
346
347#[async_trait]
348impl<DB: DatabaseAdapter> InvitationOps for HookedDatabaseAdapter<DB> {
349    type Invitation = DB::Invitation;
350
351    async fn create_invitation(
352        &self,
353        invitation: CreateInvitation,
354    ) -> AuthResult<Self::Invitation> {
355        self.inner.create_invitation(invitation).await
356    }
357
358    async fn get_invitation_by_id(&self, id: &str) -> AuthResult<Option<Self::Invitation>> {
359        self.inner.get_invitation_by_id(id).await
360    }
361
362    async fn get_pending_invitation(
363        &self,
364        organization_id: &str,
365        email: &str,
366    ) -> AuthResult<Option<Self::Invitation>> {
367        self.inner
368            .get_pending_invitation(organization_id, email)
369            .await
370    }
371
372    async fn update_invitation_status(
373        &self,
374        id: &str,
375        status: InvitationStatus,
376    ) -> AuthResult<Self::Invitation> {
377        self.inner.update_invitation_status(id, status).await
378    }
379
380    async fn list_organization_invitations(
381        &self,
382        organization_id: &str,
383    ) -> AuthResult<Vec<Self::Invitation>> {
384        self.inner
385            .list_organization_invitations(organization_id)
386            .await
387    }
388
389    async fn list_user_invitations(&self, email: &str) -> AuthResult<Vec<Self::Invitation>> {
390        self.inner.list_user_invitations(email).await
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::MemoryDatabaseAdapter;
    use crate::types::{CreateUser, UpdateUser, User};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Email used by every fixture user in this module.
    const TEST_EMAIL: &str = "test@example.com";

    /// Hook that records how many times each user lifecycle hook fired.
    #[derive(Default)]
    struct CountingHook {
        before_create_count: AtomicU32,
        after_create_count: AtomicU32,
        before_update_count: AtomicU32,
        after_update_count: AtomicU32,
        before_delete_count: AtomicU32,
        after_delete_count: AtomicU32,
    }

    impl CountingHook {
        /// All counters start at zero.
        fn new() -> Self {
            Self::default()
        }
    }

    /// Increments a counter by one.
    fn bump(counter: &AtomicU32) {
        counter.fetch_add(1, Ordering::SeqCst);
    }

    /// Reads the current value of a counter.
    fn count(counter: &AtomicU32) -> u32 {
        counter.load(Ordering::SeqCst)
    }

    /// Hooked adapter over a fresh in-memory database, with no hooks registered.
    fn memory_adapter() -> HookedDatabaseAdapter<MemoryDatabaseAdapter> {
        HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
    }

    /// The user payload shared by all tests.
    fn sample_user() -> CreateUser {
        CreateUser::new().with_email(TEST_EMAIL).with_name("Test")
    }

    #[async_trait]
    impl DatabaseHooks<MemoryDatabaseAdapter> for CountingHook {
        async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
            bump(&self.before_create_count);
            Ok(())
        }
        async fn after_create_user(&self, _user: &User) -> AuthResult<()> {
            bump(&self.after_create_count);
            Ok(())
        }
        async fn before_update_user(&self, _id: &str, _update: &mut UpdateUser) -> AuthResult<()> {
            bump(&self.before_update_count);
            Ok(())
        }
        async fn after_update_user(&self, _user: &User) -> AuthResult<()> {
            bump(&self.after_update_count);
            Ok(())
        }
        async fn before_delete_user(&self, _id: &str) -> AuthResult<()> {
            bump(&self.before_delete_count);
            Ok(())
        }
        async fn after_delete_user(&self, _id: &str) -> AuthResult<()> {
            bump(&self.after_delete_count);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_hooks_called_on_create_user() {
        let hook = Arc::new(CountingHook::new());
        let db = memory_adapter().with_hook(hook.clone());

        db.create_user(sample_user()).await.unwrap();

        assert_eq!(count(&hook.before_create_count), 1);
        assert_eq!(count(&hook.after_create_count), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_update_user() {
        let hook = Arc::new(CountingHook::new());
        let db = memory_adapter().with_hook(hook.clone());

        let user = db.create_user(sample_user()).await.unwrap();

        let update = UpdateUser {
            name: Some("Updated".to_string()),
            email: None,
            image: None,
            email_verified: None,
            username: None,
            display_username: None,
            role: None,
            banned: None,
            ban_reason: None,
            ban_expires: None,
            two_factor_enabled: None,
            metadata: None,
        };
        db.update_user(&user.id, update).await.unwrap();

        assert_eq!(count(&hook.before_update_count), 1);
        assert_eq!(count(&hook.after_update_count), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_delete_user() {
        let hook = Arc::new(CountingHook::new());
        let db = memory_adapter().with_hook(hook.clone());

        let user = db.create_user(sample_user()).await.unwrap();

        db.delete_user(&user.id).await.unwrap();

        assert_eq!(count(&hook.before_delete_count), 1);
        assert_eq!(count(&hook.after_delete_count), 1);
    }

    #[tokio::test]
    async fn test_before_hook_can_reject() {
        struct RejectHook;

        #[async_trait]
        impl DatabaseHooks<MemoryDatabaseAdapter> for RejectHook {
            async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
                Err(crate::error::AuthError::forbidden("Hook rejected"))
            }
        }

        let db = memory_adapter().with_hook(Arc::new(RejectHook));

        let result = db.create_user(sample_user()).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().status_code(), 403);

        // A rejected `before_*` hook must abort the operation, not just
        // report an error: nothing may have reached the inner adapter.
        let stored = db.get_user_by_email(TEST_EMAIL).await.unwrap();
        assert!(stored.is_none());
    }

    #[tokio::test]
    async fn test_multiple_hooks() {
        let hook1 = Arc::new(CountingHook::new());
        let hook2 = Arc::new(CountingHook::new());
        let db = memory_adapter()
            .with_hook(hook1.clone())
            .with_hook(hook2.clone());

        db.create_user(sample_user()).await.unwrap();

        assert_eq!(count(&hook1.before_create_count), 1);
        assert_eq!(count(&hook2.before_create_count), 1);
        assert_eq!(count(&hook1.after_create_count), 1);
        assert_eq!(count(&hook2.after_create_count), 1);
    }

    #[tokio::test]
    async fn test_passthrough_operations() {
        let db = memory_adapter();

        let result = db.get_user_by_email("nonexistent@test.com").await.unwrap();
        assert!(result.is_none());
    }
}