Skip to main content

better_auth_core/
hooks.rs

1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use std::sync::Arc;
4
5use crate::adapters::DatabaseAdapter;
6use crate::adapters::database::{
7    AccountOps, ApiKeyOps, InvitationOps, MemberOps, OrganizationOps, SessionOps, TwoFactorOps,
8    UserOps, VerificationOps,
9};
10use crate::error::AuthResult;
11use crate::types::{
12    CreateAccount, CreateApiKey, CreateInvitation, CreateMember, CreateOrganization, CreateSession,
13    CreateTwoFactor, CreateUser, CreateVerification, InvitationStatus, UpdateAccount, UpdateApiKey,
14    UpdateOrganization, UpdateUser,
15};
16
17/// Database lifecycle hooks for intercepting operations.
18///
19/// All methods have default no-op implementations. Override only the hooks
20/// you need. Returning `Err` from a `before_*` hook aborts the operation.
21///
22/// The `DB` type parameter determines the concrete entity types used in
23/// `after_*` hooks (e.g., `after_create_user` receives `&DB::User`).
24#[async_trait]
25pub trait DatabaseHooks<DB: DatabaseAdapter>: Send + Sync {
26    async fn before_create_user(&self, user: &mut CreateUser) -> AuthResult<()> {
27        let _ = user;
28        Ok(())
29    }
30
31    async fn after_create_user(&self, user: &DB::User) -> AuthResult<()> {
32        let _ = user;
33        Ok(())
34    }
35
36    async fn before_update_user(&self, id: &str, update: &mut UpdateUser) -> AuthResult<()> {
37        let _ = (id, update);
38        Ok(())
39    }
40
41    async fn after_update_user(&self, user: &DB::User) -> AuthResult<()> {
42        let _ = user;
43        Ok(())
44    }
45
46    async fn before_delete_user(&self, id: &str) -> AuthResult<()> {
47        let _ = id;
48        Ok(())
49    }
50
51    async fn after_delete_user(&self, id: &str) -> AuthResult<()> {
52        let _ = id;
53        Ok(())
54    }
55
56    async fn before_create_session(&self, session: &mut CreateSession) -> AuthResult<()> {
57        let _ = session;
58        Ok(())
59    }
60
61    async fn after_create_session(&self, session: &DB::Session) -> AuthResult<()> {
62        let _ = session;
63        Ok(())
64    }
65
66    async fn before_delete_session(&self, token: &str) -> AuthResult<()> {
67        let _ = token;
68        Ok(())
69    }
70
71    async fn after_delete_session(&self, token: &str) -> AuthResult<()> {
72        let _ = token;
73        Ok(())
74    }
75}
76
77/// A database adapter wrapper that calls hooks around the inner adapter's operations.
78pub struct HookedDatabaseAdapter<DB: DatabaseAdapter> {
79    inner: Arc<DB>,
80    hooks: Vec<Arc<dyn DatabaseHooks<DB>>>,
81}
82
83impl<DB: DatabaseAdapter> HookedDatabaseAdapter<DB> {
84    pub fn new(inner: Arc<DB>) -> Self {
85        Self {
86            inner,
87            hooks: Vec::new(),
88        }
89    }
90
91    pub fn with_hook(mut self, hook: Arc<dyn DatabaseHooks<DB>>) -> Self {
92        self.hooks.push(hook);
93        self
94    }
95
96    pub fn add_hook(&mut self, hook: Arc<dyn DatabaseHooks<DB>>) {
97        self.hooks.push(hook);
98    }
99}
100
101#[async_trait]
102impl<DB: DatabaseAdapter> UserOps for HookedDatabaseAdapter<DB> {
103    type User = DB::User;
104
105    async fn create_user(&self, mut user: CreateUser) -> AuthResult<Self::User> {
106        for hook in &self.hooks {
107            hook.before_create_user(&mut user).await?;
108        }
109        let result = self.inner.create_user(user).await?;
110        for hook in &self.hooks {
111            hook.after_create_user(&result).await?;
112        }
113        Ok(result)
114    }
115
116    async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<Self::User>> {
117        self.inner.get_user_by_id(id).await
118    }
119
120    async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<Self::User>> {
121        self.inner.get_user_by_email(email).await
122    }
123
124    async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<Self::User>> {
125        self.inner.get_user_by_username(username).await
126    }
127
128    async fn update_user(&self, id: &str, mut update: UpdateUser) -> AuthResult<Self::User> {
129        for hook in &self.hooks {
130            hook.before_update_user(id, &mut update).await?;
131        }
132        let result = self.inner.update_user(id, update).await?;
133        for hook in &self.hooks {
134            hook.after_update_user(&result).await?;
135        }
136        Ok(result)
137    }
138
139    async fn delete_user(&self, id: &str) -> AuthResult<()> {
140        for hook in &self.hooks {
141            hook.before_delete_user(id).await?;
142        }
143        self.inner.delete_user(id).await?;
144        for hook in &self.hooks {
145            hook.after_delete_user(id).await?;
146        }
147        Ok(())
148    }
149}
150
151#[async_trait]
152impl<DB: DatabaseAdapter> SessionOps for HookedDatabaseAdapter<DB> {
153    type Session = DB::Session;
154
155    async fn create_session(&self, mut session: CreateSession) -> AuthResult<Self::Session> {
156        for hook in &self.hooks {
157            hook.before_create_session(&mut session).await?;
158        }
159        let result = self.inner.create_session(session).await?;
160        for hook in &self.hooks {
161            hook.after_create_session(&result).await?;
162        }
163        Ok(result)
164    }
165
166    async fn get_session(&self, token: &str) -> AuthResult<Option<Self::Session>> {
167        self.inner.get_session(token).await
168    }
169
170    async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<Self::Session>> {
171        self.inner.get_user_sessions(user_id).await
172    }
173
174    async fn update_session_expiry(
175        &self,
176        token: &str,
177        expires_at: DateTime<Utc>,
178    ) -> AuthResult<()> {
179        self.inner.update_session_expiry(token, expires_at).await
180    }
181
182    async fn delete_session(&self, token: &str) -> AuthResult<()> {
183        for hook in &self.hooks {
184            hook.before_delete_session(token).await?;
185        }
186        self.inner.delete_session(token).await?;
187        for hook in &self.hooks {
188            hook.after_delete_session(token).await?;
189        }
190        Ok(())
191    }
192
193    async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
194        self.inner.delete_user_sessions(user_id).await
195    }
196
197    async fn delete_expired_sessions(&self) -> AuthResult<usize> {
198        self.inner.delete_expired_sessions().await
199    }
200
201    async fn update_session_active_organization(
202        &self,
203        token: &str,
204        organization_id: Option<&str>,
205    ) -> AuthResult<Self::Session> {
206        self.inner
207            .update_session_active_organization(token, organization_id)
208            .await
209    }
210}
211
212#[async_trait]
213impl<DB: DatabaseAdapter> AccountOps for HookedDatabaseAdapter<DB> {
214    type Account = DB::Account;
215
216    async fn create_account(&self, account: CreateAccount) -> AuthResult<Self::Account> {
217        self.inner.create_account(account).await
218    }
219
220    async fn get_account(
221        &self,
222        provider: &str,
223        provider_account_id: &str,
224    ) -> AuthResult<Option<Self::Account>> {
225        self.inner.get_account(provider, provider_account_id).await
226    }
227
228    async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<Self::Account>> {
229        self.inner.get_user_accounts(user_id).await
230    }
231
232    async fn update_account(&self, id: &str, update: UpdateAccount) -> AuthResult<Self::Account> {
233        self.inner.update_account(id, update).await
234    }
235
236    async fn delete_account(&self, id: &str) -> AuthResult<()> {
237        self.inner.delete_account(id).await
238    }
239}
240
241#[async_trait]
242impl<DB: DatabaseAdapter> VerificationOps for HookedDatabaseAdapter<DB> {
243    type Verification = DB::Verification;
244
245    async fn create_verification(
246        &self,
247        verification: CreateVerification,
248    ) -> AuthResult<Self::Verification> {
249        self.inner.create_verification(verification).await
250    }
251
252    async fn get_verification(
253        &self,
254        identifier: &str,
255        value: &str,
256    ) -> AuthResult<Option<Self::Verification>> {
257        self.inner.get_verification(identifier, value).await
258    }
259
260    async fn get_verification_by_value(
261        &self,
262        value: &str,
263    ) -> AuthResult<Option<Self::Verification>> {
264        self.inner.get_verification_by_value(value).await
265    }
266
267    async fn get_verification_by_identifier(
268        &self,
269        identifier: &str,
270    ) -> AuthResult<Option<Self::Verification>> {
271        self.inner.get_verification_by_identifier(identifier).await
272    }
273
274    async fn consume_verification(
275        &self,
276        identifier: &str,
277        value: &str,
278    ) -> AuthResult<Option<Self::Verification>> {
279        self.inner.consume_verification(identifier, value).await
280    }
281
282    async fn delete_verification(&self, id: &str) -> AuthResult<()> {
283        self.inner.delete_verification(id).await
284    }
285
286    async fn delete_expired_verifications(&self) -> AuthResult<usize> {
287        self.inner.delete_expired_verifications().await
288    }
289}
290
291#[async_trait]
292impl<DB: DatabaseAdapter> OrganizationOps for HookedDatabaseAdapter<DB> {
293    type Organization = DB::Organization;
294
295    async fn create_organization(&self, org: CreateOrganization) -> AuthResult<Self::Organization> {
296        self.inner.create_organization(org).await
297    }
298
299    async fn get_organization_by_id(&self, id: &str) -> AuthResult<Option<Self::Organization>> {
300        self.inner.get_organization_by_id(id).await
301    }
302
303    async fn get_organization_by_slug(&self, slug: &str) -> AuthResult<Option<Self::Organization>> {
304        self.inner.get_organization_by_slug(slug).await
305    }
306
307    async fn update_organization(
308        &self,
309        id: &str,
310        update: UpdateOrganization,
311    ) -> AuthResult<Self::Organization> {
312        self.inner.update_organization(id, update).await
313    }
314
315    async fn delete_organization(&self, id: &str) -> AuthResult<()> {
316        self.inner.delete_organization(id).await
317    }
318
319    async fn list_user_organizations(&self, user_id: &str) -> AuthResult<Vec<Self::Organization>> {
320        self.inner.list_user_organizations(user_id).await
321    }
322}
323
324#[async_trait]
325impl<DB: DatabaseAdapter> MemberOps for HookedDatabaseAdapter<DB> {
326    type Member = DB::Member;
327
328    async fn create_member(&self, member: CreateMember) -> AuthResult<Self::Member> {
329        self.inner.create_member(member).await
330    }
331
332    async fn get_member(
333        &self,
334        organization_id: &str,
335        user_id: &str,
336    ) -> AuthResult<Option<Self::Member>> {
337        self.inner.get_member(organization_id, user_id).await
338    }
339
340    async fn get_member_by_id(&self, id: &str) -> AuthResult<Option<Self::Member>> {
341        self.inner.get_member_by_id(id).await
342    }
343
344    async fn update_member_role(&self, member_id: &str, role: &str) -> AuthResult<Self::Member> {
345        self.inner.update_member_role(member_id, role).await
346    }
347
348    async fn delete_member(&self, member_id: &str) -> AuthResult<()> {
349        self.inner.delete_member(member_id).await
350    }
351
352    async fn list_organization_members(
353        &self,
354        organization_id: &str,
355    ) -> AuthResult<Vec<Self::Member>> {
356        self.inner.list_organization_members(organization_id).await
357    }
358
359    async fn count_organization_members(&self, organization_id: &str) -> AuthResult<usize> {
360        self.inner.count_organization_members(organization_id).await
361    }
362
363    async fn count_organization_owners(&self, organization_id: &str) -> AuthResult<usize> {
364        self.inner.count_organization_owners(organization_id).await
365    }
366}
367
368#[async_trait]
369impl<DB: DatabaseAdapter> InvitationOps for HookedDatabaseAdapter<DB> {
370    type Invitation = DB::Invitation;
371
372    async fn create_invitation(
373        &self,
374        invitation: CreateInvitation,
375    ) -> AuthResult<Self::Invitation> {
376        self.inner.create_invitation(invitation).await
377    }
378
379    async fn get_invitation_by_id(&self, id: &str) -> AuthResult<Option<Self::Invitation>> {
380        self.inner.get_invitation_by_id(id).await
381    }
382
383    async fn get_pending_invitation(
384        &self,
385        organization_id: &str,
386        email: &str,
387    ) -> AuthResult<Option<Self::Invitation>> {
388        self.inner
389            .get_pending_invitation(organization_id, email)
390            .await
391    }
392
393    async fn update_invitation_status(
394        &self,
395        id: &str,
396        status: InvitationStatus,
397    ) -> AuthResult<Self::Invitation> {
398        self.inner.update_invitation_status(id, status).await
399    }
400
401    async fn list_organization_invitations(
402        &self,
403        organization_id: &str,
404    ) -> AuthResult<Vec<Self::Invitation>> {
405        self.inner
406            .list_organization_invitations(organization_id)
407            .await
408    }
409
410    async fn list_user_invitations(&self, email: &str) -> AuthResult<Vec<Self::Invitation>> {
411        self.inner.list_user_invitations(email).await
412    }
413}
414
415#[async_trait]
416impl<DB: DatabaseAdapter> TwoFactorOps for HookedDatabaseAdapter<DB> {
417    type TwoFactor = DB::TwoFactor;
418
419    async fn create_two_factor(&self, two_factor: CreateTwoFactor) -> AuthResult<Self::TwoFactor> {
420        self.inner.create_two_factor(two_factor).await
421    }
422
423    async fn get_two_factor_by_user_id(
424        &self,
425        user_id: &str,
426    ) -> AuthResult<Option<Self::TwoFactor>> {
427        self.inner.get_two_factor_by_user_id(user_id).await
428    }
429
430    async fn update_two_factor_backup_codes(
431        &self,
432        user_id: &str,
433        backup_codes: &str,
434    ) -> AuthResult<Self::TwoFactor> {
435        self.inner
436            .update_two_factor_backup_codes(user_id, backup_codes)
437            .await
438    }
439
440    async fn delete_two_factor(&self, user_id: &str) -> AuthResult<()> {
441        self.inner.delete_two_factor(user_id).await
442    }
443}
444
445#[async_trait]
446impl<DB: DatabaseAdapter> ApiKeyOps for HookedDatabaseAdapter<DB> {
447    type ApiKey = DB::ApiKey;
448
449    async fn create_api_key(&self, input: CreateApiKey) -> AuthResult<Self::ApiKey> {
450        self.inner.create_api_key(input).await
451    }
452
453    async fn get_api_key_by_id(&self, id: &str) -> AuthResult<Option<Self::ApiKey>> {
454        self.inner.get_api_key_by_id(id).await
455    }
456
457    async fn get_api_key_by_hash(&self, hash: &str) -> AuthResult<Option<Self::ApiKey>> {
458        self.inner.get_api_key_by_hash(hash).await
459    }
460
461    async fn list_api_keys_by_user(&self, user_id: &str) -> AuthResult<Vec<Self::ApiKey>> {
462        self.inner.list_api_keys_by_user(user_id).await
463    }
464
465    async fn update_api_key(&self, id: &str, update: UpdateApiKey) -> AuthResult<Self::ApiKey> {
466        self.inner.update_api_key(id, update).await
467    }
468
469    async fn delete_api_key(&self, id: &str) -> AuthResult<()> {
470        self.inner.delete_api_key(id).await
471    }
472}
473
#[cfg(test)]
mod tests {
    //! Tests for `HookedDatabaseAdapter`, run against the in-memory adapter.

    use super::*;
    use crate::adapters::MemoryDatabaseAdapter;
    use crate::types::{CreateUser, UpdateUser, User};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Hook that counts how often each user hook is invoked.
    struct CountingHook {
        before_create_count: AtomicU32,
        after_create_count: AtomicU32,
        before_update_count: AtomicU32,
        after_update_count: AtomicU32,
        before_delete_count: AtomicU32,
        after_delete_count: AtomicU32,
    }

    impl CountingHook {
        fn new() -> Self {
            Self {
                before_create_count: AtomicU32::new(0),
                after_create_count: AtomicU32::new(0),
                before_update_count: AtomicU32::new(0),
                after_update_count: AtomicU32::new(0),
                before_delete_count: AtomicU32::new(0),
                after_delete_count: AtomicU32::new(0),
            }
        }
    }

    #[async_trait]
    impl DatabaseHooks<MemoryDatabaseAdapter> for CountingHook {
        async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
            self.before_create_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_create_user(&self, _user: &User) -> AuthResult<()> {
            self.after_create_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn before_update_user(&self, _id: &str, _update: &mut UpdateUser) -> AuthResult<()> {
            self.before_update_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_update_user(&self, _user: &User) -> AuthResult<()> {
            self.after_update_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn before_delete_user(&self, _id: &str) -> AuthResult<()> {
            self.before_delete_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_delete_user(&self, _id: &str) -> AuthResult<()> {
            self.after_delete_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_hooks_called_on_create_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        db.create_user(create).await.unwrap();

        assert_eq!(hook.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_update_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        let user = db.create_user(create).await.unwrap();

        let update = UpdateUser {
            name: Some("Updated".to_string()),
            email: None,
            image: None,
            email_verified: None,
            username: None,
            display_username: None,
            role: None,
            banned: None,
            ban_reason: None,
            ban_expires: None,
            two_factor_enabled: None,
            metadata: None,
        };
        db.update_user(&user.id, update).await.unwrap();

        assert_eq!(hook.before_update_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_update_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_delete_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        let user = db.create_user(create).await.unwrap();

        db.delete_user(&user.id).await.unwrap();

        assert_eq!(hook.before_delete_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_delete_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_before_hook_can_reject() {
        struct RejectHook;

        #[async_trait]
        impl DatabaseHooks<MemoryDatabaseAdapter> for RejectHook {
            async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
                Err(crate::error::AuthError::forbidden("Hook rejected"))
            }
        }

        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(Arc::new(RejectHook));

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        let result = db.create_user(create).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().status_code(), 403);
    }

    #[tokio::test]
    async fn test_rejecting_hook_short_circuits_chain_and_write() {
        struct RejectHook;

        #[async_trait]
        impl DatabaseHooks<MemoryDatabaseAdapter> for RejectHook {
            async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
                Err(crate::error::AuthError::forbidden("Hook rejected"))
            }
        }

        let counter = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(Arc::new(RejectHook))
            .with_hook(counter.clone());

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        assert!(db.create_user(create).await.is_err());

        // The first `Err` stops the chain: later hooks never run and the
        // inner adapter is never asked to insert the user.
        assert_eq!(counter.before_create_count.load(Ordering::SeqCst), 0);
        assert_eq!(counter.after_create_count.load(Ordering::SeqCst), 0);
        let stored = db.get_user_by_email("test@example.com").await.unwrap();
        assert!(stored.is_none());
    }

    #[tokio::test]
    async fn test_after_hook_error_is_returned_after_write() {
        struct FailAfterCreate;

        #[async_trait]
        impl DatabaseHooks<MemoryDatabaseAdapter> for FailAfterCreate {
            async fn after_create_user(&self, _user: &User) -> AuthResult<()> {
                Err(crate::error::AuthError::forbidden("after hook failed"))
            }
        }

        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(Arc::new(FailAfterCreate));

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        assert!(db.create_user(create).await.is_err());

        // `after_*` hooks run only once the inner adapter has stored the row,
        // so the user exists even though the caller received an error.
        let stored = db.get_user_by_email("test@example.com").await.unwrap();
        assert!(stored.is_some());
    }

    #[tokio::test]
    async fn test_multiple_hooks() {
        let hook1 = Arc::new(CountingHook::new());
        let hook2 = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook1.clone())
            .with_hook(hook2.clone());

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        db.create_user(create).await.unwrap();

        assert_eq!(hook1.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook2.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook1.after_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook2.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_passthrough_operations() {
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()));

        let result = db.get_user_by_email("nonexistent@test.com").await.unwrap();
        assert!(result.is_none());
    }
}