1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5use uuid::Uuid;
6
7use crate::error::{AuthError, AuthResult};
8use crate::types::{
9 Account, CreateAccount, CreateInvitation, CreateMember, CreateOrganization, CreateSession,
10 CreateUser, CreateVerification, Invitation, InvitationStatus, Member, Organization, Session,
11 UpdateOrganization, UpdateUser, User, Verification,
12};
13
14pub use super::memory_traits::{
15 MemoryAccount, MemoryInvitation, MemoryMember, MemoryOrganization, MemorySession, MemoryUser,
16 MemoryVerification,
17};
18
19use super::traits::{
20 AccountOps, InvitationOps, MemberOps, OrganizationOps, SessionOps, UserOps, VerificationOps,
21};
22
/// In-memory implementation of the database adapter traits.
///
/// Every table is an `Arc<Mutex<HashMap<String, _>>>`. The type parameters
/// select the entity type stored in each table and default to the crate's own
/// `User`, `Session`, `Account`, `Organization`, `Member`, `Invitation` and
/// `Verification` types.
pub struct MemoryDatabaseAdapter<
    U = User,
    S = Session,
    A = Account,
    O = Organization,
    M = Member,
    I = Invitation,
    V = Verification,
> {
    /// Users keyed by user id.
    users: Arc<Mutex<HashMap<String, U>>>,
    /// Sessions keyed by session token (not by session id).
    sessions: Arc<Mutex<HashMap<String, S>>>,
    /// Accounts keyed by account id.
    accounts: Arc<Mutex<HashMap<String, A>>>,
    /// Verification records keyed by verification id.
    verifications: Arc<Mutex<HashMap<String, V>>>,
    /// Secondary index: email -> user id.
    email_index: Arc<Mutex<HashMap<String, String>>>,
    /// Secondary index: username -> user id.
    username_index: Arc<Mutex<HashMap<String, String>>>,
    /// Organizations keyed by organization id.
    organizations: Arc<Mutex<HashMap<String, O>>>,
    /// Organization memberships keyed by member id.
    members: Arc<Mutex<HashMap<String, M>>>,
    /// Invitations keyed by invitation id.
    invitations: Arc<Mutex<HashMap<String, I>>>,
    /// Secondary index: organization slug -> organization id.
    slug_index: Arc<Mutex<HashMap<String, String>>>,
}
57
58impl MemoryDatabaseAdapter {
61 pub fn new() -> Self {
62 Self::default()
63 }
64}
65
66impl<U, S, A, O, M, I, V> Default for MemoryDatabaseAdapter<U, S, A, O, M, I, V> {
67 fn default() -> Self {
68 Self {
69 users: Arc::new(Mutex::new(HashMap::new())),
70 sessions: Arc::new(Mutex::new(HashMap::new())),
71 accounts: Arc::new(Mutex::new(HashMap::new())),
72 verifications: Arc::new(Mutex::new(HashMap::new())),
73 email_index: Arc::new(Mutex::new(HashMap::new())),
74 username_index: Arc::new(Mutex::new(HashMap::new())),
75 organizations: Arc::new(Mutex::new(HashMap::new())),
76 members: Arc::new(Mutex::new(HashMap::new())),
77 invitations: Arc::new(Mutex::new(HashMap::new())),
78 slug_index: Arc::new(Mutex::new(HashMap::new())),
79 }
80 }
81}
82
83#[async_trait]
86impl<U, S, A, O, M, I, V> UserOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V>
87where
88 U: MemoryUser,
89 S: MemorySession,
90 A: MemoryAccount,
91 O: MemoryOrganization,
92 M: MemoryMember,
93 I: MemoryInvitation,
94 V: MemoryVerification,
95{
96 type User = U;
97
98 async fn create_user(&self, create_user: CreateUser) -> AuthResult<U> {
99 let mut users = self.users.lock().unwrap();
100 let mut email_index = self.email_index.lock().unwrap();
101 let mut username_index = self.username_index.lock().unwrap();
102
103 let id = create_user
104 .id
105 .clone()
106 .unwrap_or_else(|| Uuid::new_v4().to_string());
107
108 if let Some(email) = &create_user.email
109 && email_index.contains_key(email)
110 {
111 return Err(AuthError::config("Email already exists"));
112 }
113
114 if let Some(username) = &create_user.username
115 && username_index.contains_key(username)
116 {
117 return Err(AuthError::conflict(
118 "A user with this username already exists",
119 ));
120 }
121
122 let now = Utc::now();
123 let user = U::from_create(id.clone(), &create_user, now);
124
125 users.insert(id.clone(), user.clone());
126
127 if let Some(email) = &create_user.email {
128 email_index.insert(email.clone(), id.clone());
129 }
130 if let Some(username) = &create_user.username {
131 username_index.insert(username.clone(), id);
132 }
133
134 Ok(user)
135 }
136
137 async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<U>> {
138 let users = self.users.lock().unwrap();
139 Ok(users.get(id).cloned())
140 }
141
142 async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<U>> {
143 let email_index = self.email_index.lock().unwrap();
144 let users = self.users.lock().unwrap();
145
146 if let Some(user_id) = email_index.get(email) {
147 Ok(users.get(user_id).cloned())
148 } else {
149 Ok(None)
150 }
151 }
152
153 async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<U>> {
154 let username_index = self.username_index.lock().unwrap();
155 let users = self.users.lock().unwrap();
156
157 if let Some(user_id) = username_index.get(username) {
158 Ok(users.get(user_id).cloned())
159 } else {
160 Ok(None)
161 }
162 }
163
164 async fn update_user(&self, id: &str, update: UpdateUser) -> AuthResult<U> {
165 let mut users = self.users.lock().unwrap();
166 let mut email_index = self.email_index.lock().unwrap();
167 let mut username_index = self.username_index.lock().unwrap();
168
169 let user = users.get_mut(id).ok_or(AuthError::UserNotFound)?;
170
171 if let Some(new_email) = &update.email {
173 if let Some(old_email) = user.email() {
174 email_index.remove(old_email);
175 }
176 email_index.insert(new_email.clone(), id.to_string());
177 }
178
179 if let Some(ref new_username) = update.username {
180 if let Some(old_username) = user.username() {
181 username_index.remove(old_username);
182 }
183 username_index.insert(new_username.clone(), id.to_string());
184 }
185
186 user.apply_update(&update);
187 Ok(user.clone())
188 }
189
190 async fn delete_user(&self, id: &str) -> AuthResult<()> {
191 let mut users = self.users.lock().unwrap();
192 let mut email_index = self.email_index.lock().unwrap();
193 let mut username_index = self.username_index.lock().unwrap();
194
195 if let Some(user) = users.remove(id) {
196 if let Some(email) = user.email() {
197 email_index.remove(email);
198 }
199 if let Some(username) = user.username() {
200 username_index.remove(username);
201 }
202 }
203
204 Ok(())
205 }
206}
207
208#[async_trait]
211impl<U, S, A, O, M, I, V> SessionOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V>
212where
213 U: MemoryUser,
214 S: MemorySession,
215 A: MemoryAccount,
216 O: MemoryOrganization,
217 M: MemoryMember,
218 I: MemoryInvitation,
219 V: MemoryVerification,
220{
221 type Session = S;
222
223 async fn create_session(&self, create_session: CreateSession) -> AuthResult<S> {
224 let mut sessions = self.sessions.lock().unwrap();
225
226 let id = Uuid::new_v4().to_string();
227 let token = format!("session_{}", Uuid::new_v4());
228 let now = Utc::now();
229 let session = S::from_create(id, token.clone(), &create_session, now);
230
231 sessions.insert(token, session.clone());
232 Ok(session)
233 }
234
235 async fn get_session(&self, token: &str) -> AuthResult<Option<S>> {
236 let sessions = self.sessions.lock().unwrap();
237 Ok(sessions.get(token).cloned())
238 }
239
240 async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<S>> {
241 let sessions = self.sessions.lock().unwrap();
242 Ok(sessions
243 .values()
244 .filter(|s| s.user_id() == user_id && s.active())
245 .cloned()
246 .collect())
247 }
248
249 async fn update_session_expiry(
250 &self,
251 token: &str,
252 expires_at: DateTime<Utc>,
253 ) -> AuthResult<()> {
254 let mut sessions = self.sessions.lock().unwrap();
255 if let Some(session) = sessions.get_mut(token) {
256 session.set_expires_at(expires_at);
257 }
258 Ok(())
259 }
260
261 async fn delete_session(&self, token: &str) -> AuthResult<()> {
262 let mut sessions = self.sessions.lock().unwrap();
263 sessions.remove(token);
264 Ok(())
265 }
266
267 async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
268 let mut sessions = self.sessions.lock().unwrap();
269 sessions.retain(|_, s| s.user_id() != user_id);
270 Ok(())
271 }
272
273 async fn delete_expired_sessions(&self) -> AuthResult<usize> {
274 let mut sessions = self.sessions.lock().unwrap();
275 let now = Utc::now();
276 let initial_count = sessions.len();
277 sessions.retain(|_, s| s.expires_at() > now && s.active());
278 Ok(initial_count - sessions.len())
279 }
280
281 async fn update_session_active_organization(
282 &self,
283 token: &str,
284 organization_id: Option<&str>,
285 ) -> AuthResult<S> {
286 let mut sessions = self.sessions.lock().unwrap();
287 let session = sessions.get_mut(token).ok_or(AuthError::SessionNotFound)?;
288 session.set_active_organization_id(organization_id.map(|s| s.to_string()));
289 session.set_updated_at(Utc::now());
290 Ok(session.clone())
291 }
292}
293
294#[async_trait]
297impl<U, S, A, O, M, I, V> AccountOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V>
298where
299 U: MemoryUser,
300 S: MemorySession,
301 A: MemoryAccount,
302 O: MemoryOrganization,
303 M: MemoryMember,
304 I: MemoryInvitation,
305 V: MemoryVerification,
306{
307 type Account = A;
308
309 async fn create_account(&self, create_account: CreateAccount) -> AuthResult<A> {
310 let mut accounts = self.accounts.lock().unwrap();
311
312 let id = Uuid::new_v4().to_string();
313 let now = Utc::now();
314 let account = A::from_create(id.clone(), &create_account, now);
315
316 accounts.insert(id, account.clone());
317 Ok(account)
318 }
319
320 async fn get_account(
321 &self,
322 provider: &str,
323 provider_account_id: &str,
324 ) -> AuthResult<Option<A>> {
325 let accounts = self.accounts.lock().unwrap();
326 Ok(accounts
327 .values()
328 .find(|acc| acc.provider_id() == provider && acc.account_id() == provider_account_id)
329 .cloned())
330 }
331
332 async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<A>> {
333 let accounts = self.accounts.lock().unwrap();
334 Ok(accounts
335 .values()
336 .filter(|acc| acc.user_id() == user_id)
337 .cloned()
338 .collect())
339 }
340
341 async fn delete_account(&self, id: &str) -> AuthResult<()> {
342 let mut accounts = self.accounts.lock().unwrap();
343 accounts.remove(id);
344 Ok(())
345 }
346}
347
348#[async_trait]
351impl<U, S, A, O, M, I, V> VerificationOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V>
352where
353 U: MemoryUser,
354 S: MemorySession,
355 A: MemoryAccount,
356 O: MemoryOrganization,
357 M: MemoryMember,
358 I: MemoryInvitation,
359 V: MemoryVerification,
360{
361 type Verification = V;
362
363 async fn create_verification(&self, create_verification: CreateVerification) -> AuthResult<V> {
364 let mut verifications = self.verifications.lock().unwrap();
365
366 let id = Uuid::new_v4().to_string();
367 let now = Utc::now();
368 let verification = V::from_create(id.clone(), &create_verification, now);
369
370 verifications.insert(id, verification.clone());
371 Ok(verification)
372 }
373
374 async fn get_verification(&self, identifier: &str, value: &str) -> AuthResult<Option<V>> {
375 let verifications = self.verifications.lock().unwrap();
376 let now = Utc::now();
377 Ok(verifications
378 .values()
379 .find(|v| v.identifier() == identifier && v.value() == value && v.expires_at() > now)
380 .cloned())
381 }
382
383 async fn get_verification_by_value(&self, value: &str) -> AuthResult<Option<V>> {
384 let verifications = self.verifications.lock().unwrap();
385 let now = Utc::now();
386 Ok(verifications
387 .values()
388 .find(|v| v.value() == value && v.expires_at() > now)
389 .cloned())
390 }
391
392 async fn delete_verification(&self, id: &str) -> AuthResult<()> {
393 let mut verifications = self.verifications.lock().unwrap();
394 verifications.remove(id);
395 Ok(())
396 }
397
398 async fn delete_expired_verifications(&self) -> AuthResult<usize> {
399 let mut verifications = self.verifications.lock().unwrap();
400 let now = Utc::now();
401 let initial_count = verifications.len();
402 verifications.retain(|_, v| v.expires_at() > now);
403 Ok(initial_count - verifications.len())
404 }
405}
406
407#[async_trait]
410impl<U, S, A, O, M, I, V> OrganizationOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V>
411where
412 U: MemoryUser,
413 S: MemorySession,
414 A: MemoryAccount,
415 O: MemoryOrganization,
416 M: MemoryMember,
417 I: MemoryInvitation,
418 V: MemoryVerification,
419{
420 type Organization = O;
421
422 async fn create_organization(&self, create_org: CreateOrganization) -> AuthResult<O> {
423 let mut organizations = self.organizations.lock().unwrap();
424 let mut slug_index = self.slug_index.lock().unwrap();
425
426 if slug_index.contains_key(&create_org.slug) {
427 return Err(AuthError::conflict("Organization slug already exists"));
428 }
429
430 let id = create_org
431 .id
432 .clone()
433 .unwrap_or_else(|| Uuid::new_v4().to_string());
434 let now = Utc::now();
435 let organization = O::from_create(id.clone(), &create_org, now);
436
437 organizations.insert(id.clone(), organization.clone());
438 slug_index.insert(create_org.slug.clone(), id);
439
440 Ok(organization)
441 }
442
443 async fn get_organization_by_id(&self, id: &str) -> AuthResult<Option<O>> {
444 let organizations = self.organizations.lock().unwrap();
445 Ok(organizations.get(id).cloned())
446 }
447
448 async fn get_organization_by_slug(&self, slug: &str) -> AuthResult<Option<O>> {
449 let slug_index = self.slug_index.lock().unwrap();
450 let organizations = self.organizations.lock().unwrap();
451
452 if let Some(org_id) = slug_index.get(slug) {
453 Ok(organizations.get(org_id).cloned())
454 } else {
455 Ok(None)
456 }
457 }
458
459 async fn update_organization(&self, id: &str, update: UpdateOrganization) -> AuthResult<O> {
460 let mut organizations = self.organizations.lock().unwrap();
461 let mut slug_index = self.slug_index.lock().unwrap();
462
463 let org = organizations
464 .get_mut(id)
465 .ok_or_else(|| AuthError::not_found("Organization not found"))?;
466
467 if let Some(new_slug) = &update.slug {
469 let current_slug = org.slug().to_string();
470 if *new_slug != current_slug {
471 if slug_index.contains_key(new_slug.as_str()) {
472 return Err(AuthError::conflict("Organization slug already exists"));
473 }
474 slug_index.remove(¤t_slug);
475 slug_index.insert(new_slug.clone(), id.to_string());
476 }
477 }
478
479 org.apply_update(&update);
480 Ok(org.clone())
481 }
482
483 async fn delete_organization(&self, id: &str) -> AuthResult<()> {
484 let mut organizations = self.organizations.lock().unwrap();
485 let mut slug_index = self.slug_index.lock().unwrap();
486 let mut members = self.members.lock().unwrap();
487 let mut invitations = self.invitations.lock().unwrap();
488
489 if let Some(org) = organizations.remove(id) {
490 slug_index.remove(org.slug());
491 }
492
493 members.retain(|_, m| m.organization_id() != id);
494 invitations.retain(|_, i| i.organization_id() != id);
495
496 Ok(())
497 }
498
499 async fn list_user_organizations(&self, user_id: &str) -> AuthResult<Vec<O>> {
500 let members = self.members.lock().unwrap();
501 let organizations = self.organizations.lock().unwrap();
502
503 let org_ids: Vec<String> = members
504 .values()
505 .filter(|m| m.user_id() == user_id)
506 .map(|m| m.organization_id().to_string())
507 .collect();
508
509 let orgs = org_ids
510 .iter()
511 .filter_map(|id| organizations.get(id).cloned())
512 .collect();
513
514 Ok(orgs)
515 }
516}
517
518#[async_trait]
521impl<U, S, A, O, M, I, V> MemberOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V>
522where
523 U: MemoryUser,
524 S: MemorySession,
525 A: MemoryAccount,
526 O: MemoryOrganization,
527 M: MemoryMember,
528 I: MemoryInvitation,
529 V: MemoryVerification,
530{
531 type Member = M;
532
533 async fn create_member(&self, create_member: CreateMember) -> AuthResult<M> {
534 let mut members = self.members.lock().unwrap();
535
536 let exists = members.values().any(|m| {
537 m.organization_id() == create_member.organization_id
538 && m.user_id() == create_member.user_id
539 });
540
541 if exists {
542 return Err(AuthError::conflict(
543 "User is already a member of this organization",
544 ));
545 }
546
547 let id = Uuid::new_v4().to_string();
548 let now = Utc::now();
549 let member = M::from_create(id.clone(), &create_member, now);
550
551 members.insert(id, member.clone());
552 Ok(member)
553 }
554
555 async fn get_member(&self, organization_id: &str, user_id: &str) -> AuthResult<Option<M>> {
556 let members = self.members.lock().unwrap();
557 Ok(members
558 .values()
559 .find(|m| m.organization_id() == organization_id && m.user_id() == user_id)
560 .cloned())
561 }
562
563 async fn get_member_by_id(&self, id: &str) -> AuthResult<Option<M>> {
564 let members = self.members.lock().unwrap();
565 Ok(members.get(id).cloned())
566 }
567
568 async fn update_member_role(&self, member_id: &str, role: &str) -> AuthResult<M> {
569 let mut members = self.members.lock().unwrap();
570 let member = members
571 .get_mut(member_id)
572 .ok_or_else(|| AuthError::not_found("Member not found"))?;
573 member.set_role(role.to_string());
574 Ok(member.clone())
575 }
576
577 async fn delete_member(&self, member_id: &str) -> AuthResult<()> {
578 let mut members = self.members.lock().unwrap();
579 members.remove(member_id);
580 Ok(())
581 }
582
583 async fn list_organization_members(&self, organization_id: &str) -> AuthResult<Vec<M>> {
584 let members = self.members.lock().unwrap();
585 Ok(members
586 .values()
587 .filter(|m| m.organization_id() == organization_id)
588 .cloned()
589 .collect())
590 }
591
592 async fn count_organization_members(&self, organization_id: &str) -> AuthResult<usize> {
593 let members = self.members.lock().unwrap();
594 Ok(members
595 .values()
596 .filter(|m| m.organization_id() == organization_id)
597 .count())
598 }
599
600 async fn count_organization_owners(&self, organization_id: &str) -> AuthResult<usize> {
601 let members = self.members.lock().unwrap();
602 Ok(members
603 .values()
604 .filter(|m| m.organization_id() == organization_id && m.role() == "owner")
605 .count())
606 }
607}
608
609#[async_trait]
612impl<U, S, A, O, M, I, V> InvitationOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V>
613where
614 U: MemoryUser,
615 S: MemorySession,
616 A: MemoryAccount,
617 O: MemoryOrganization,
618 M: MemoryMember,
619 I: MemoryInvitation,
620 V: MemoryVerification,
621{
622 type Invitation = I;
623
624 async fn create_invitation(&self, create_inv: CreateInvitation) -> AuthResult<I> {
625 let mut invitations = self.invitations.lock().unwrap();
626
627 let id = Uuid::new_v4().to_string();
628 let now = Utc::now();
629 let invitation = I::from_create(id.clone(), &create_inv, now);
630
631 invitations.insert(id, invitation.clone());
632 Ok(invitation)
633 }
634
635 async fn get_invitation_by_id(&self, id: &str) -> AuthResult<Option<I>> {
636 let invitations = self.invitations.lock().unwrap();
637 Ok(invitations.get(id).cloned())
638 }
639
640 async fn get_pending_invitation(
641 &self,
642 organization_id: &str,
643 email: &str,
644 ) -> AuthResult<Option<I>> {
645 let invitations = self.invitations.lock().unwrap();
646 Ok(invitations
647 .values()
648 .find(|i| {
649 i.organization_id() == organization_id
650 && i.email().to_lowercase() == email.to_lowercase()
651 && *i.status() == InvitationStatus::Pending
652 })
653 .cloned())
654 }
655
656 async fn update_invitation_status(&self, id: &str, status: InvitationStatus) -> AuthResult<I> {
657 let mut invitations = self.invitations.lock().unwrap();
658 let invitation = invitations
659 .get_mut(id)
660 .ok_or_else(|| AuthError::not_found("Invitation not found"))?;
661 invitation.set_status(status);
662 Ok(invitation.clone())
663 }
664
665 async fn list_organization_invitations(&self, organization_id: &str) -> AuthResult<Vec<I>> {
666 let invitations = self.invitations.lock().unwrap();
667 Ok(invitations
668 .values()
669 .filter(|i| i.organization_id() == organization_id)
670 .cloned()
671 .collect())
672 }
673
674 async fn list_user_invitations(&self, email: &str) -> AuthResult<Vec<I>> {
675 let invitations = self.invitations.lock().unwrap();
676 let now = Utc::now();
677 Ok(invitations
678 .values()
679 .filter(|i| {
680 i.email().to_lowercase() == email.to_lowercase()
681 && *i.status() == InvitationStatus::Pending
682 && i.expires_at() > now
683 })
684 .cloned()
685 .collect())
686 }
687}