Skip to main content

better_auth_core/
hooks.rs

1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use std::sync::Arc;
4
5use crate::adapters::DatabaseAdapter;
6use crate::error::AuthResult;
7use crate::types::{
8    CreateAccount, CreateInvitation, CreateMember, CreateOrganization, CreateSession, CreateUser,
9    CreateVerification, InvitationStatus, UpdateOrganization, UpdateUser,
10};
11
12/// Database lifecycle hooks for intercepting operations.
13///
14/// All methods have default no-op implementations. Override only the hooks
15/// you need. Returning `Err` from a `before_*` hook aborts the operation.
16///
17/// The `DB` type parameter determines the concrete entity types used in
18/// `after_*` hooks (e.g., `after_create_user` receives `&DB::User`).
19#[async_trait]
20pub trait DatabaseHooks<DB: DatabaseAdapter>: Send + Sync {
21    // --- User hooks ---
22
23    async fn before_create_user(&self, user: &mut CreateUser) -> AuthResult<()> {
24        let _ = user;
25        Ok(())
26    }
27
28    async fn after_create_user(&self, user: &DB::User) -> AuthResult<()> {
29        let _ = user;
30        Ok(())
31    }
32
33    async fn before_update_user(&self, id: &str, update: &mut UpdateUser) -> AuthResult<()> {
34        let _ = (id, update);
35        Ok(())
36    }
37
38    async fn after_update_user(&self, user: &DB::User) -> AuthResult<()> {
39        let _ = user;
40        Ok(())
41    }
42
43    async fn before_delete_user(&self, id: &str) -> AuthResult<()> {
44        let _ = id;
45        Ok(())
46    }
47
48    async fn after_delete_user(&self, id: &str) -> AuthResult<()> {
49        let _ = id;
50        Ok(())
51    }
52
53    // --- Session hooks ---
54
55    async fn before_create_session(&self, session: &mut CreateSession) -> AuthResult<()> {
56        let _ = session;
57        Ok(())
58    }
59
60    async fn after_create_session(&self, session: &DB::Session) -> AuthResult<()> {
61        let _ = session;
62        Ok(())
63    }
64
65    async fn before_delete_session(&self, token: &str) -> AuthResult<()> {
66        let _ = token;
67        Ok(())
68    }
69
70    async fn after_delete_session(&self, token: &str) -> AuthResult<()> {
71        let _ = token;
72        Ok(())
73    }
74}
75
76/// A database adapter wrapper that calls hooks around the inner adapter's operations.
77pub struct HookedDatabaseAdapter<DB: DatabaseAdapter> {
78    inner: Arc<DB>,
79    hooks: Vec<Arc<dyn DatabaseHooks<DB>>>,
80}
81
82impl<DB: DatabaseAdapter> HookedDatabaseAdapter<DB> {
83    pub fn new(inner: Arc<DB>) -> Self {
84        Self {
85            inner,
86            hooks: Vec::new(),
87        }
88    }
89
90    pub fn with_hook(mut self, hook: Arc<dyn DatabaseHooks<DB>>) -> Self {
91        self.hooks.push(hook);
92        self
93    }
94
95    pub fn add_hook(&mut self, hook: Arc<dyn DatabaseHooks<DB>>) {
96        self.hooks.push(hook);
97    }
98}
99
100#[async_trait]
101impl<DB: DatabaseAdapter> DatabaseAdapter for HookedDatabaseAdapter<DB> {
102    type User = DB::User;
103    type Session = DB::Session;
104    type Account = DB::Account;
105    type Organization = DB::Organization;
106    type Member = DB::Member;
107    type Invitation = DB::Invitation;
108    type Verification = DB::Verification;
109
110    // --- User operations ---
111
112    async fn create_user(&self, mut user: CreateUser) -> AuthResult<Self::User> {
113        for hook in &self.hooks {
114            hook.before_create_user(&mut user).await?;
115        }
116        let result = self.inner.create_user(user).await?;
117        for hook in &self.hooks {
118            hook.after_create_user(&result).await?;
119        }
120        Ok(result)
121    }
122
123    async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<Self::User>> {
124        self.inner.get_user_by_id(id).await
125    }
126
127    async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<Self::User>> {
128        self.inner.get_user_by_email(email).await
129    }
130
131    async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<Self::User>> {
132        self.inner.get_user_by_username(username).await
133    }
134
135    async fn update_user(&self, id: &str, mut update: UpdateUser) -> AuthResult<Self::User> {
136        for hook in &self.hooks {
137            hook.before_update_user(id, &mut update).await?;
138        }
139        let result = self.inner.update_user(id, update).await?;
140        for hook in &self.hooks {
141            hook.after_update_user(&result).await?;
142        }
143        Ok(result)
144    }
145
146    async fn delete_user(&self, id: &str) -> AuthResult<()> {
147        for hook in &self.hooks {
148            hook.before_delete_user(id).await?;
149        }
150        self.inner.delete_user(id).await?;
151        for hook in &self.hooks {
152            hook.after_delete_user(id).await?;
153        }
154        Ok(())
155    }
156
157    // --- Session operations ---
158
159    async fn create_session(&self, mut session: CreateSession) -> AuthResult<Self::Session> {
160        for hook in &self.hooks {
161            hook.before_create_session(&mut session).await?;
162        }
163        let result = self.inner.create_session(session).await?;
164        for hook in &self.hooks {
165            hook.after_create_session(&result).await?;
166        }
167        Ok(result)
168    }
169
170    async fn get_session(&self, token: &str) -> AuthResult<Option<Self::Session>> {
171        self.inner.get_session(token).await
172    }
173
174    async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<Self::Session>> {
175        self.inner.get_user_sessions(user_id).await
176    }
177
178    async fn update_session_expiry(
179        &self,
180        token: &str,
181        expires_at: DateTime<Utc>,
182    ) -> AuthResult<()> {
183        self.inner.update_session_expiry(token, expires_at).await
184    }
185
186    async fn delete_session(&self, token: &str) -> AuthResult<()> {
187        for hook in &self.hooks {
188            hook.before_delete_session(token).await?;
189        }
190        self.inner.delete_session(token).await?;
191        for hook in &self.hooks {
192            hook.after_delete_session(token).await?;
193        }
194        Ok(())
195    }
196
197    async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
198        self.inner.delete_user_sessions(user_id).await
199    }
200
201    async fn delete_expired_sessions(&self) -> AuthResult<usize> {
202        self.inner.delete_expired_sessions().await
203    }
204
205    // --- Account operations (pass-through) ---
206
207    async fn create_account(&self, account: CreateAccount) -> AuthResult<Self::Account> {
208        self.inner.create_account(account).await
209    }
210
211    async fn get_account(
212        &self,
213        provider: &str,
214        provider_account_id: &str,
215    ) -> AuthResult<Option<Self::Account>> {
216        self.inner.get_account(provider, provider_account_id).await
217    }
218
219    async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<Self::Account>> {
220        self.inner.get_user_accounts(user_id).await
221    }
222
223    async fn delete_account(&self, id: &str) -> AuthResult<()> {
224        self.inner.delete_account(id).await
225    }
226
227    // --- Verification operations (pass-through) ---
228
229    async fn create_verification(
230        &self,
231        verification: CreateVerification,
232    ) -> AuthResult<Self::Verification> {
233        self.inner.create_verification(verification).await
234    }
235
236    async fn get_verification(
237        &self,
238        identifier: &str,
239        value: &str,
240    ) -> AuthResult<Option<Self::Verification>> {
241        self.inner.get_verification(identifier, value).await
242    }
243
244    async fn get_verification_by_value(
245        &self,
246        value: &str,
247    ) -> AuthResult<Option<Self::Verification>> {
248        self.inner.get_verification_by_value(value).await
249    }
250
251    async fn delete_verification(&self, id: &str) -> AuthResult<()> {
252        self.inner.delete_verification(id).await
253    }
254
255    async fn delete_expired_verifications(&self) -> AuthResult<usize> {
256        self.inner.delete_expired_verifications().await
257    }
258
259    // --- Organization operations (pass-through) ---
260
261    async fn create_organization(&self, org: CreateOrganization) -> AuthResult<Self::Organization> {
262        self.inner.create_organization(org).await
263    }
264
265    async fn get_organization_by_id(&self, id: &str) -> AuthResult<Option<Self::Organization>> {
266        self.inner.get_organization_by_id(id).await
267    }
268
269    async fn get_organization_by_slug(&self, slug: &str) -> AuthResult<Option<Self::Organization>> {
270        self.inner.get_organization_by_slug(slug).await
271    }
272
273    async fn update_organization(
274        &self,
275        id: &str,
276        update: UpdateOrganization,
277    ) -> AuthResult<Self::Organization> {
278        self.inner.update_organization(id, update).await
279    }
280
281    async fn delete_organization(&self, id: &str) -> AuthResult<()> {
282        self.inner.delete_organization(id).await
283    }
284
285    async fn list_user_organizations(&self, user_id: &str) -> AuthResult<Vec<Self::Organization>> {
286        self.inner.list_user_organizations(user_id).await
287    }
288
289    // --- Member operations (pass-through) ---
290
291    async fn create_member(&self, member: CreateMember) -> AuthResult<Self::Member> {
292        self.inner.create_member(member).await
293    }
294
295    async fn get_member(
296        &self,
297        organization_id: &str,
298        user_id: &str,
299    ) -> AuthResult<Option<Self::Member>> {
300        self.inner.get_member(organization_id, user_id).await
301    }
302
303    async fn get_member_by_id(&self, id: &str) -> AuthResult<Option<Self::Member>> {
304        self.inner.get_member_by_id(id).await
305    }
306
307    async fn update_member_role(&self, member_id: &str, role: &str) -> AuthResult<Self::Member> {
308        self.inner.update_member_role(member_id, role).await
309    }
310
311    async fn delete_member(&self, member_id: &str) -> AuthResult<()> {
312        self.inner.delete_member(member_id).await
313    }
314
315    async fn list_organization_members(
316        &self,
317        organization_id: &str,
318    ) -> AuthResult<Vec<Self::Member>> {
319        self.inner.list_organization_members(organization_id).await
320    }
321
322    async fn count_organization_members(&self, organization_id: &str) -> AuthResult<usize> {
323        self.inner.count_organization_members(organization_id).await
324    }
325
326    async fn count_organization_owners(&self, organization_id: &str) -> AuthResult<usize> {
327        self.inner.count_organization_owners(organization_id).await
328    }
329
330    // --- Invitation operations (pass-through) ---
331
332    async fn create_invitation(
333        &self,
334        invitation: CreateInvitation,
335    ) -> AuthResult<Self::Invitation> {
336        self.inner.create_invitation(invitation).await
337    }
338
339    async fn get_invitation_by_id(&self, id: &str) -> AuthResult<Option<Self::Invitation>> {
340        self.inner.get_invitation_by_id(id).await
341    }
342
343    async fn get_pending_invitation(
344        &self,
345        organization_id: &str,
346        email: &str,
347    ) -> AuthResult<Option<Self::Invitation>> {
348        self.inner
349            .get_pending_invitation(organization_id, email)
350            .await
351    }
352
353    async fn update_invitation_status(
354        &self,
355        id: &str,
356        status: InvitationStatus,
357    ) -> AuthResult<Self::Invitation> {
358        self.inner.update_invitation_status(id, status).await
359    }
360
361    async fn list_organization_invitations(
362        &self,
363        organization_id: &str,
364    ) -> AuthResult<Vec<Self::Invitation>> {
365        self.inner
366            .list_organization_invitations(organization_id)
367            .await
368    }
369
370    async fn list_user_invitations(&self, email: &str) -> AuthResult<Vec<Self::Invitation>> {
371        self.inner.list_user_invitations(email).await
372    }
373
374    // --- Session organization support ---
375
376    async fn update_session_active_organization(
377        &self,
378        token: &str,
379        organization_id: Option<&str>,
380    ) -> AuthResult<Self::Session> {
381        self.inner
382            .update_session_active_organization(token, organization_id)
383            .await
384    }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::MemoryDatabaseAdapter;
    use crate::types::{CreateUser, UpdateUser, User};
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Mutex;

    /// Hook that counts how many times each user hook fires.
    struct CountingHook {
        before_create_count: AtomicU32,
        after_create_count: AtomicU32,
        before_update_count: AtomicU32,
        after_update_count: AtomicU32,
        before_delete_count: AtomicU32,
        after_delete_count: AtomicU32,
    }

    impl CountingHook {
        fn new() -> Self {
            Self {
                before_create_count: AtomicU32::new(0),
                after_create_count: AtomicU32::new(0),
                before_update_count: AtomicU32::new(0),
                after_update_count: AtomicU32::new(0),
                before_delete_count: AtomicU32::new(0),
                after_delete_count: AtomicU32::new(0),
            }
        }
    }

    #[async_trait]
    impl DatabaseHooks<MemoryDatabaseAdapter> for CountingHook {
        async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
            self.before_create_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_create_user(&self, _user: &User) -> AuthResult<()> {
            self.after_create_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn before_update_user(&self, _id: &str, _update: &mut UpdateUser) -> AuthResult<()> {
            self.before_update_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_update_user(&self, _user: &User) -> AuthResult<()> {
            self.after_update_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn before_delete_user(&self, _id: &str) -> AuthResult<()> {
            self.before_delete_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_delete_user(&self, _id: &str) -> AuthResult<()> {
            self.after_delete_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Hook whose `before_create_user` always fails with a 403.
    struct RejectHook;

    #[async_trait]
    impl DatabaseHooks<MemoryDatabaseAdapter> for RejectHook {
        async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
            Err(crate::error::AuthError::forbidden("Hook rejected"))
        }
    }

    /// Hook that appends `"<phase>:<label>"` to a shared log on user creation.
    struct OrderHook {
        label: &'static str,
        log: Arc<Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl DatabaseHooks<MemoryDatabaseAdapter> for OrderHook {
        async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
            // The guard is a temporary dropped at the end of this statement,
            // so it is never held across an `.await`.
            self.log
                .lock()
                .unwrap()
                .push(format!("before:{}", self.label));
            Ok(())
        }
        async fn after_create_user(&self, _user: &User) -> AuthResult<()> {
            self.log
                .lock()
                .unwrap()
                .push(format!("after:{}", self.label));
            Ok(())
        }
    }

    /// The user fixture shared by every test.
    fn sample_user() -> CreateUser {
        CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test")
    }

    #[tokio::test]
    async fn test_hooks_called_on_create_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        db.create_user(sample_user()).await.unwrap();

        assert_eq!(hook.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_update_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let user = db.create_user(sample_user()).await.unwrap();

        let update = UpdateUser {
            name: Some("Updated".to_string()),
            email: None,
            image: None,
            email_verified: None,
            username: None,
            display_username: None,
            role: None,
            banned: None,
            ban_reason: None,
            ban_expires: None,
            two_factor_enabled: None,
            metadata: None,
        };
        db.update_user(&user.id, update).await.unwrap();

        assert_eq!(hook.before_update_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_update_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_delete_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let user = db.create_user(sample_user()).await.unwrap();

        db.delete_user(&user.id).await.unwrap();

        assert_eq!(hook.before_delete_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_delete_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_before_hook_can_reject() {
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(Arc::new(RejectHook));

        let result = db.create_user(sample_user()).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().status_code(), 403);
        // A rejection must abort the write, not merely report an error.
        assert!(db
            .get_user_by_email("test@example.com")
            .await
            .unwrap()
            .is_none());
    }

    #[tokio::test]
    async fn test_rejection_skips_remaining_hooks() {
        let counter = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(Arc::new(RejectHook))
            .with_hook(counter.clone());

        assert!(db.create_user(sample_user()).await.is_err());

        // Later before-hooks and all after-hooks must not run.
        assert_eq!(counter.before_create_count.load(Ordering::SeqCst), 0);
        assert_eq!(counter.after_create_count.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_multiple_hooks() {
        let hook1 = Arc::new(CountingHook::new());
        let hook2 = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook1.clone())
            .with_hook(hook2.clone());

        db.create_user(sample_user()).await.unwrap();

        assert_eq!(hook1.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook2.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook1.after_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook2.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_run_in_registration_order() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(Arc::new(OrderHook {
                label: "first",
                log: log.clone(),
            }))
            .with_hook(Arc::new(OrderHook {
                label: "second",
                log: log.clone(),
            }));

        db.create_user(sample_user()).await.unwrap();

        assert_eq!(
            *log.lock().unwrap(),
            vec!["before:first", "before:second", "after:first", "after:second"]
        );
    }

    #[tokio::test]
    async fn test_passthrough_operations() {
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()));

        let result = db.get_user_by_email("nonexistent@test.com").await.unwrap();
        assert!(result.is_none());
    }
}