Skip to main content

better_auth_core/
hooks.rs

1use async_trait::async_trait;
2use std::sync::Arc;
3
4use crate::adapters::DatabaseAdapter;
5use crate::error::AuthResult;
6use crate::types::{
7    Account, CreateAccount, CreateSession, CreateUser, CreateVerification, Session, UpdateUser,
8    User, Verification,
9};
10use chrono::{DateTime, Utc};
11
12/// Database lifecycle hooks for intercepting operations.
13///
14/// All methods have default no-op implementations. Override only the hooks
15/// you need. Returning `Err` from a `before_*` hook aborts the operation.
16#[async_trait]
17pub trait DatabaseHooks: Send + Sync {
18    // --- User hooks ---
19
20    /// Called before a user is created. Can modify the `CreateUser` or reject the operation.
21    async fn before_create_user(&self, user: &mut CreateUser) -> AuthResult<()> {
22        let _ = user;
23        Ok(())
24    }
25
26    /// Called after a user is created.
27    async fn after_create_user(&self, user: &User) -> AuthResult<()> {
28        let _ = user;
29        Ok(())
30    }
31
32    /// Called before a user is updated. Can modify the `UpdateUser` or reject the operation.
33    async fn before_update_user(&self, id: &str, update: &mut UpdateUser) -> AuthResult<()> {
34        let _ = (id, update);
35        Ok(())
36    }
37
38    /// Called after a user is updated.
39    async fn after_update_user(&self, user: &User) -> AuthResult<()> {
40        let _ = user;
41        Ok(())
42    }
43
44    /// Called before a user is deleted.
45    async fn before_delete_user(&self, id: &str) -> AuthResult<()> {
46        let _ = id;
47        Ok(())
48    }
49
50    /// Called after a user is deleted.
51    async fn after_delete_user(&self, id: &str) -> AuthResult<()> {
52        let _ = id;
53        Ok(())
54    }
55
56    // --- Session hooks ---
57
58    /// Called before a session is created. Can modify the `CreateSession` or reject.
59    async fn before_create_session(&self, session: &mut CreateSession) -> AuthResult<()> {
60        let _ = session;
61        Ok(())
62    }
63
64    /// Called after a session is created.
65    async fn after_create_session(&self, session: &Session) -> AuthResult<()> {
66        let _ = session;
67        Ok(())
68    }
69
70    /// Called before a session is deleted.
71    async fn before_delete_session(&self, token: &str) -> AuthResult<()> {
72        let _ = token;
73        Ok(())
74    }
75
76    /// Called after a session is deleted.
77    async fn after_delete_session(&self, token: &str) -> AuthResult<()> {
78        let _ = token;
79        Ok(())
80    }
81}
82
83/// A database adapter wrapper that calls hooks around the inner adapter's operations.
84pub struct HookedDatabaseAdapter {
85    inner: Arc<dyn DatabaseAdapter>,
86    hooks: Vec<Arc<dyn DatabaseHooks>>,
87}
88
89impl HookedDatabaseAdapter {
90    pub fn new(inner: Arc<dyn DatabaseAdapter>) -> Self {
91        Self {
92            inner,
93            hooks: Vec::new(),
94        }
95    }
96
97    pub fn with_hook(mut self, hook: Arc<dyn DatabaseHooks>) -> Self {
98        self.hooks.push(hook);
99        self
100    }
101
102    pub fn add_hook(&mut self, hook: Arc<dyn DatabaseHooks>) {
103        self.hooks.push(hook);
104    }
105}
106
107#[async_trait]
108impl DatabaseAdapter for HookedDatabaseAdapter {
109    // --- User operations ---
110
111    async fn create_user(&self, mut user: CreateUser) -> AuthResult<User> {
112        for hook in &self.hooks {
113            hook.before_create_user(&mut user).await?;
114        }
115        let result = self.inner.create_user(user).await?;
116        for hook in &self.hooks {
117            hook.after_create_user(&result).await?;
118        }
119        Ok(result)
120    }
121
122    async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<User>> {
123        self.inner.get_user_by_id(id).await
124    }
125
126    async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<User>> {
127        self.inner.get_user_by_email(email).await
128    }
129
130    async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<User>> {
131        self.inner.get_user_by_username(username).await
132    }
133
134    async fn update_user(&self, id: &str, mut update: UpdateUser) -> AuthResult<User> {
135        for hook in &self.hooks {
136            hook.before_update_user(id, &mut update).await?;
137        }
138        let result = self.inner.update_user(id, update).await?;
139        for hook in &self.hooks {
140            hook.after_update_user(&result).await?;
141        }
142        Ok(result)
143    }
144
145    async fn delete_user(&self, id: &str) -> AuthResult<()> {
146        for hook in &self.hooks {
147            hook.before_delete_user(id).await?;
148        }
149        self.inner.delete_user(id).await?;
150        for hook in &self.hooks {
151            hook.after_delete_user(id).await?;
152        }
153        Ok(())
154    }
155
156    // --- Session operations ---
157
158    async fn create_session(&self, mut session: CreateSession) -> AuthResult<Session> {
159        for hook in &self.hooks {
160            hook.before_create_session(&mut session).await?;
161        }
162        let result = self.inner.create_session(session).await?;
163        for hook in &self.hooks {
164            hook.after_create_session(&result).await?;
165        }
166        Ok(result)
167    }
168
169    async fn get_session(&self, token: &str) -> AuthResult<Option<Session>> {
170        self.inner.get_session(token).await
171    }
172
173    async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<Session>> {
174        self.inner.get_user_sessions(user_id).await
175    }
176
177    async fn update_session_expiry(
178        &self,
179        token: &str,
180        expires_at: DateTime<Utc>,
181    ) -> AuthResult<()> {
182        self.inner.update_session_expiry(token, expires_at).await
183    }
184
185    async fn delete_session(&self, token: &str) -> AuthResult<()> {
186        for hook in &self.hooks {
187            hook.before_delete_session(token).await?;
188        }
189        self.inner.delete_session(token).await?;
190        for hook in &self.hooks {
191            hook.after_delete_session(token).await?;
192        }
193        Ok(())
194    }
195
196    async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
197        self.inner.delete_user_sessions(user_id).await
198    }
199
200    async fn delete_expired_sessions(&self) -> AuthResult<usize> {
201        self.inner.delete_expired_sessions().await
202    }
203
204    // --- Account operations (pass-through, no hooks) ---
205
206    async fn create_account(&self, account: CreateAccount) -> AuthResult<Account> {
207        self.inner.create_account(account).await
208    }
209
210    async fn get_account(
211        &self,
212        provider: &str,
213        provider_account_id: &str,
214    ) -> AuthResult<Option<Account>> {
215        self.inner.get_account(provider, provider_account_id).await
216    }
217
218    async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<Account>> {
219        self.inner.get_user_accounts(user_id).await
220    }
221
222    async fn delete_account(&self, id: &str) -> AuthResult<()> {
223        self.inner.delete_account(id).await
224    }
225
226    // --- Verification operations (pass-through, no hooks) ---
227
228    async fn create_verification(
229        &self,
230        verification: CreateVerification,
231    ) -> AuthResult<Verification> {
232        self.inner.create_verification(verification).await
233    }
234
235    async fn get_verification(
236        &self,
237        identifier: &str,
238        value: &str,
239    ) -> AuthResult<Option<Verification>> {
240        self.inner.get_verification(identifier, value).await
241    }
242
243    async fn get_verification_by_value(&self, value: &str) -> AuthResult<Option<Verification>> {
244        self.inner.get_verification_by_value(value).await
245    }
246
247    async fn delete_verification(&self, id: &str) -> AuthResult<()> {
248        self.inner.delete_verification(id).await
249    }
250
251    async fn delete_expired_verifications(&self) -> AuthResult<usize> {
252        self.inner.delete_expired_verifications().await
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::MemoryDatabaseAdapter;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Hook that counts how often each *user* hook is invoked.
    /// (The delete counters track user deletions, not session deletions.)
    struct CountingHook {
        before_create_count: AtomicU32,
        after_create_count: AtomicU32,
        before_update_count: AtomicU32,
        after_update_count: AtomicU32,
        before_delete_count: AtomicU32,
        after_delete_count: AtomicU32,
    }

    impl CountingHook {
        fn new() -> Self {
            Self {
                before_create_count: AtomicU32::new(0),
                after_create_count: AtomicU32::new(0),
                before_update_count: AtomicU32::new(0),
                after_update_count: AtomicU32::new(0),
                before_delete_count: AtomicU32::new(0),
                after_delete_count: AtomicU32::new(0),
            }
        }
    }

    #[async_trait]
    impl DatabaseHooks for CountingHook {
        async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
            self.before_create_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_create_user(&self, _user: &User) -> AuthResult<()> {
            self.after_create_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn before_update_user(&self, _id: &str, _update: &mut UpdateUser) -> AuthResult<()> {
            self.before_update_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_update_user(&self, _user: &User) -> AuthResult<()> {
            self.after_update_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn before_delete_user(&self, _id: &str) -> AuthResult<()> {
            self.before_delete_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn after_delete_user(&self, _id: &str) -> AuthResult<()> {
            self.after_delete_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_hooks_called_on_create_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        db.create_user(create).await.unwrap();

        assert_eq!(hook.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_update_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        let user = db.create_user(create).await.unwrap();

        let update = UpdateUser {
            name: Some("Updated".to_string()),
            email: None,
            image: None,
            email_verified: None,
            username: None,
            display_username: None,
            role: None,
            banned: None,
            ban_reason: None,
            ban_expires: None,
            two_factor_enabled: None,
            metadata: None,
        };
        db.update_user(&user.id, update).await.unwrap();

        assert_eq!(hook.before_update_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_update_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_delete_user() {
        let hook = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook.clone());

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        let user = db.create_user(create).await.unwrap();

        db.delete_user(&user.id).await.unwrap();

        assert_eq!(hook.before_delete_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_delete_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_before_hook_can_reject() {
        struct RejectHook;

        #[async_trait]
        impl DatabaseHooks for RejectHook {
            async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
                Err(crate::error::AuthError::forbidden("Hook rejected"))
            }
        }

        // Registered after the rejecting hook: it must never be reached.
        let counter = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(Arc::new(RejectHook))
            .with_hook(counter.clone());

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        let result = db.create_user(create).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().status_code(), 403);

        // The rejection must short-circuit the remaining before-hooks, skip every
        // after-hook, and prevent the write from reaching the inner adapter.
        assert_eq!(counter.before_create_count.load(Ordering::SeqCst), 0);
        assert_eq!(counter.after_create_count.load(Ordering::SeqCst), 0);
        assert!(db
            .get_user_by_email("test@example.com")
            .await
            .unwrap()
            .is_none());
    }

    #[tokio::test]
    async fn test_multiple_hooks() {
        let hook1 = Arc::new(CountingHook::new());
        let hook2 = Arc::new(CountingHook::new());
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
            .with_hook(hook1.clone())
            .with_hook(hook2.clone());

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        db.create_user(create).await.unwrap();

        assert_eq!(hook1.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook2.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook1.after_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook2.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_add_hook_registers_in_place() {
        // Same effect as `with_hook`, but through the `&mut self` registration path.
        let hook = Arc::new(CountingHook::new());
        let mut db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()));
        db.add_hook(hook.clone());

        let create = CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test");
        db.create_user(create).await.unwrap();

        assert_eq!(hook.before_create_count.load(Ordering::SeqCst), 1);
        assert_eq!(hook.after_create_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_passthrough_operations() {
        let db = HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()));

        // get_user_by_email should work without hooks
        let result = db.get_user_by_email("nonexistent@test.com").await.unwrap();
        assert!(result.is_none());
    }
}