1use async_trait::async_trait;
2use std::sync::Arc;
3
4use crate::adapters::DatabaseAdapter;
5use crate::error::AuthResult;
6use crate::types::{
7 Account, CreateAccount, CreateSession, CreateUser, CreateVerification, Session, UpdateUser,
8 User, Verification,
9};
10use chrono::{DateTime, Utc};
11
12#[async_trait]
17pub trait DatabaseHooks: Send + Sync {
18 async fn before_create_user(&self, user: &mut CreateUser) -> AuthResult<()> {
22 let _ = user;
23 Ok(())
24 }
25
26 async fn after_create_user(&self, user: &User) -> AuthResult<()> {
28 let _ = user;
29 Ok(())
30 }
31
32 async fn before_update_user(&self, id: &str, update: &mut UpdateUser) -> AuthResult<()> {
34 let _ = (id, update);
35 Ok(())
36 }
37
38 async fn after_update_user(&self, user: &User) -> AuthResult<()> {
40 let _ = user;
41 Ok(())
42 }
43
44 async fn before_delete_user(&self, id: &str) -> AuthResult<()> {
46 let _ = id;
47 Ok(())
48 }
49
50 async fn after_delete_user(&self, id: &str) -> AuthResult<()> {
52 let _ = id;
53 Ok(())
54 }
55
56 async fn before_create_session(&self, session: &mut CreateSession) -> AuthResult<()> {
60 let _ = session;
61 Ok(())
62 }
63
64 async fn after_create_session(&self, session: &Session) -> AuthResult<()> {
66 let _ = session;
67 Ok(())
68 }
69
70 async fn before_delete_session(&self, token: &str) -> AuthResult<()> {
72 let _ = token;
73 Ok(())
74 }
75
76 async fn after_delete_session(&self, token: &str) -> AuthResult<()> {
78 let _ = token;
79 Ok(())
80 }
81}
82
83pub struct HookedDatabaseAdapter {
85 inner: Arc<dyn DatabaseAdapter>,
86 hooks: Vec<Arc<dyn DatabaseHooks>>,
87}
88
89impl HookedDatabaseAdapter {
90 pub fn new(inner: Arc<dyn DatabaseAdapter>) -> Self {
91 Self {
92 inner,
93 hooks: Vec::new(),
94 }
95 }
96
97 pub fn with_hook(mut self, hook: Arc<dyn DatabaseHooks>) -> Self {
98 self.hooks.push(hook);
99 self
100 }
101
102 pub fn add_hook(&mut self, hook: Arc<dyn DatabaseHooks>) {
103 self.hooks.push(hook);
104 }
105}
106
107#[async_trait]
108impl DatabaseAdapter for HookedDatabaseAdapter {
109 async fn create_user(&self, mut user: CreateUser) -> AuthResult<User> {
112 for hook in &self.hooks {
113 hook.before_create_user(&mut user).await?;
114 }
115 let result = self.inner.create_user(user).await?;
116 for hook in &self.hooks {
117 hook.after_create_user(&result).await?;
118 }
119 Ok(result)
120 }
121
122 async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<User>> {
123 self.inner.get_user_by_id(id).await
124 }
125
126 async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<User>> {
127 self.inner.get_user_by_email(email).await
128 }
129
130 async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<User>> {
131 self.inner.get_user_by_username(username).await
132 }
133
134 async fn update_user(&self, id: &str, mut update: UpdateUser) -> AuthResult<User> {
135 for hook in &self.hooks {
136 hook.before_update_user(id, &mut update).await?;
137 }
138 let result = self.inner.update_user(id, update).await?;
139 for hook in &self.hooks {
140 hook.after_update_user(&result).await?;
141 }
142 Ok(result)
143 }
144
145 async fn delete_user(&self, id: &str) -> AuthResult<()> {
146 for hook in &self.hooks {
147 hook.before_delete_user(id).await?;
148 }
149 self.inner.delete_user(id).await?;
150 for hook in &self.hooks {
151 hook.after_delete_user(id).await?;
152 }
153 Ok(())
154 }
155
156 async fn create_session(&self, mut session: CreateSession) -> AuthResult<Session> {
159 for hook in &self.hooks {
160 hook.before_create_session(&mut session).await?;
161 }
162 let result = self.inner.create_session(session).await?;
163 for hook in &self.hooks {
164 hook.after_create_session(&result).await?;
165 }
166 Ok(result)
167 }
168
169 async fn get_session(&self, token: &str) -> AuthResult<Option<Session>> {
170 self.inner.get_session(token).await
171 }
172
173 async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<Session>> {
174 self.inner.get_user_sessions(user_id).await
175 }
176
177 async fn update_session_expiry(
178 &self,
179 token: &str,
180 expires_at: DateTime<Utc>,
181 ) -> AuthResult<()> {
182 self.inner.update_session_expiry(token, expires_at).await
183 }
184
185 async fn delete_session(&self, token: &str) -> AuthResult<()> {
186 for hook in &self.hooks {
187 hook.before_delete_session(token).await?;
188 }
189 self.inner.delete_session(token).await?;
190 for hook in &self.hooks {
191 hook.after_delete_session(token).await?;
192 }
193 Ok(())
194 }
195
196 async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
197 self.inner.delete_user_sessions(user_id).await
198 }
199
200 async fn delete_expired_sessions(&self) -> AuthResult<usize> {
201 self.inner.delete_expired_sessions().await
202 }
203
204 async fn create_account(&self, account: CreateAccount) -> AuthResult<Account> {
207 self.inner.create_account(account).await
208 }
209
210 async fn get_account(
211 &self,
212 provider: &str,
213 provider_account_id: &str,
214 ) -> AuthResult<Option<Account>> {
215 self.inner.get_account(provider, provider_account_id).await
216 }
217
218 async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<Account>> {
219 self.inner.get_user_accounts(user_id).await
220 }
221
222 async fn delete_account(&self, id: &str) -> AuthResult<()> {
223 self.inner.delete_account(id).await
224 }
225
226 async fn create_verification(
229 &self,
230 verification: CreateVerification,
231 ) -> AuthResult<Verification> {
232 self.inner.create_verification(verification).await
233 }
234
235 async fn get_verification(
236 &self,
237 identifier: &str,
238 value: &str,
239 ) -> AuthResult<Option<Verification>> {
240 self.inner.get_verification(identifier, value).await
241 }
242
243 async fn get_verification_by_value(&self, value: &str) -> AuthResult<Option<Verification>> {
244 self.inner.get_verification_by_value(value).await
245 }
246
247 async fn delete_verification(&self, id: &str) -> AuthResult<()> {
248 self.inner.delete_verification(id).await
249 }
250
251 async fn delete_expired_verifications(&self) -> AuthResult<usize> {
252 self.inner.delete_expired_verifications().await
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::MemoryDatabaseAdapter;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Test hook that counts how often each user-lifecycle callback fires.
    /// Session callbacks keep their no-op defaults.
    #[derive(Default)]
    struct CountingHook {
        before_create_count: AtomicU32,
        after_create_count: AtomicU32,
        before_update_count: AtomicU32,
        after_update_count: AtomicU32,
        before_delete_count: AtomicU32,
        after_delete_count: AtomicU32,
    }

    /// Increments `counter` by one.
    fn bump(counter: &AtomicU32) {
        counter.fetch_add(1, Ordering::SeqCst);
    }

    /// Reads the current value of `counter`.
    fn read(counter: &AtomicU32) -> u32 {
        counter.load(Ordering::SeqCst)
    }

    #[async_trait]
    impl DatabaseHooks for CountingHook {
        async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
            bump(&self.before_create_count);
            Ok(())
        }
        async fn after_create_user(&self, _user: &User) -> AuthResult<()> {
            bump(&self.after_create_count);
            Ok(())
        }
        async fn before_update_user(&self, _id: &str, _update: &mut UpdateUser) -> AuthResult<()> {
            bump(&self.before_update_count);
            Ok(())
        }
        async fn after_update_user(&self, _user: &User) -> AuthResult<()> {
            bump(&self.after_update_count);
            Ok(())
        }
        async fn before_delete_user(&self, _id: &str) -> AuthResult<()> {
            bump(&self.before_delete_count);
            Ok(())
        }
        async fn after_delete_user(&self, _id: &str) -> AuthResult<()> {
            bump(&self.after_delete_count);
            Ok(())
        }
    }

    /// A hooked adapter over a fresh in-memory store, with no hooks yet.
    fn memory_backed() -> HookedDatabaseAdapter {
        HookedDatabaseAdapter::new(Arc::new(MemoryDatabaseAdapter::new()))
    }

    /// The user every test creates.
    fn sample_user() -> CreateUser {
        CreateUser::new()
            .with_email("test@example.com")
            .with_name("Test")
    }

    #[tokio::test]
    async fn test_hooks_called_on_create_user() {
        let hook = Arc::new(CountingHook::default());
        let db = memory_backed().with_hook(hook.clone());

        db.create_user(sample_user()).await.unwrap();

        assert_eq!(read(&hook.before_create_count), 1);
        assert_eq!(read(&hook.after_create_count), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_update_user() {
        let hook = Arc::new(CountingHook::default());
        let db = memory_backed().with_hook(hook.clone());
        let user = db.create_user(sample_user()).await.unwrap();

        let changes = UpdateUser {
            name: Some("Updated".to_string()),
            email: None,
            image: None,
            email_verified: None,
            username: None,
            display_username: None,
            role: None,
            banned: None,
            ban_reason: None,
            ban_expires: None,
            two_factor_enabled: None,
            metadata: None,
        };
        db.update_user(&user.id, changes).await.unwrap();

        assert_eq!(read(&hook.before_update_count), 1);
        assert_eq!(read(&hook.after_update_count), 1);
    }

    #[tokio::test]
    async fn test_hooks_called_on_delete_user() {
        let hook = Arc::new(CountingHook::default());
        let db = memory_backed().with_hook(hook.clone());
        let user = db.create_user(sample_user()).await.unwrap();

        db.delete_user(&user.id).await.unwrap();

        assert_eq!(read(&hook.before_delete_count), 1);
        assert_eq!(read(&hook.after_delete_count), 1);
    }

    #[tokio::test]
    async fn test_before_hook_can_reject() {
        /// Hook whose `before_create_user` always refuses.
        struct RejectHook;

        #[async_trait]
        impl DatabaseHooks for RejectHook {
            async fn before_create_user(&self, _user: &mut CreateUser) -> AuthResult<()> {
                Err(crate::error::AuthError::forbidden("Hook rejected"))
            }
        }

        let db = memory_backed().with_hook(Arc::new(RejectHook));

        let err = db
            .create_user(sample_user())
            .await
            .expect_err("RejectHook should block user creation");
        assert_eq!(err.status_code(), 403);
    }

    #[tokio::test]
    async fn test_multiple_hooks() {
        let first = Arc::new(CountingHook::default());
        let second = Arc::new(CountingHook::default());
        let db = memory_backed()
            .with_hook(first.clone())
            .with_hook(second.clone());

        db.create_user(sample_user()).await.unwrap();

        for hook in [&first, &second] {
            assert_eq!(read(&hook.before_create_count), 1);
            assert_eq!(read(&hook.after_create_count), 1);
        }
    }

    #[tokio::test]
    async fn test_passthrough_operations() {
        let db = memory_backed();

        let found = db.get_user_by_email("nonexistent@test.com").await.unwrap();
        assert!(found.is_none());
    }
}