Skip to main content

better_auth_core/adapters/
memory.rs

1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5use uuid::Uuid;
6
7use crate::error::{AuthError, AuthResult};
8use crate::types::{
9    Account, CreateAccount, CreateInvitation, CreateMember, CreateOrganization, CreateSession,
10    CreateUser, CreateVerification, Invitation, InvitationStatus, Member, Organization, Session,
11    UpdateOrganization, UpdateUser, User, Verification,
12};
13
14pub use super::memory_traits::{
15    MemoryAccount, MemoryInvitation, MemoryMember, MemoryOrganization, MemorySession, MemoryUser,
16    MemoryVerification,
17};
18
19use super::traits::{
20    AccountOps, InvitationOps, MemberOps, OrganizationOps, SessionOps, UserOps, VerificationOps,
21};
22
23/// In-memory database adapter for testing and development.
24///
25/// Generic over entity types — use default type parameters for the built-in
26/// types, or supply your own custom structs that implement the `Memory*`
27/// traits.
28///
29/// ```rust,ignore
30/// // Using built-in types (no turbofish needed):
31/// let adapter = MemoryDatabaseAdapter::new();
32///
33/// // Using custom types:
34/// let adapter = MemoryDatabaseAdapter::<MyUser, MySession, MyAccount,
35///     MyOrg, MyMember, MyInvitation, MyVerification>::new();
36/// ```
37pub struct MemoryDatabaseAdapter<
38    U = User,
39    S = Session,
40    A = Account,
41    O = Organization,
42    M = Member,
43    I = Invitation,
44    V = Verification,
45> {
46    users: Arc<Mutex<HashMap<String, U>>>,
47    sessions: Arc<Mutex<HashMap<String, S>>>,
48    accounts: Arc<Mutex<HashMap<String, A>>>,
49    verifications: Arc<Mutex<HashMap<String, V>>>,
50    email_index: Arc<Mutex<HashMap<String, String>>>,
51    username_index: Arc<Mutex<HashMap<String, String>>>,
52    organizations: Arc<Mutex<HashMap<String, O>>>,
53    members: Arc<Mutex<HashMap<String, M>>>,
54    invitations: Arc<Mutex<HashMap<String, I>>>,
55    slug_index: Arc<Mutex<HashMap<String, String>>>,
56}
57
58/// Constructor for the default (built-in) entity types.
59/// Use `Default::default()` for custom type parameterizations.
60impl MemoryDatabaseAdapter {
61    pub fn new() -> Self {
62        Self::default()
63    }
64}
65
66impl<U, S, A, O, M, I, V> Default for MemoryDatabaseAdapter<U, S, A, O, M, I, V> {
67    fn default() -> Self {
68        Self {
69            users: Arc::new(Mutex::new(HashMap::new())),
70            sessions: Arc::new(Mutex::new(HashMap::new())),
71            accounts: Arc::new(Mutex::new(HashMap::new())),
72            verifications: Arc::new(Mutex::new(HashMap::new())),
73            email_index: Arc::new(Mutex::new(HashMap::new())),
74            username_index: Arc::new(Mutex::new(HashMap::new())),
75            organizations: Arc::new(Mutex::new(HashMap::new())),
76            members: Arc::new(Mutex::new(HashMap::new())),
77            invitations: Arc::new(Mutex::new(HashMap::new())),
78            slug_index: Arc::new(Mutex::new(HashMap::new())),
79        }
80    }
81}
82
83// -- UserOps --
84
85#[async_trait]
86impl<U, S, A, O, M, I, V> UserOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V>
87where
88    U: MemoryUser,
89    S: MemorySession,
90    A: MemoryAccount,
91    O: MemoryOrganization,
92    M: MemoryMember,
93    I: MemoryInvitation,
94    V: MemoryVerification,
95{
96    type User = U;
97
98    async fn create_user(&self, create_user: CreateUser) -> AuthResult<U> {
99        let mut users = self.users.lock().unwrap();
100        let mut email_index = self.email_index.lock().unwrap();
101        let mut username_index = self.username_index.lock().unwrap();
102
103        let id = create_user
104            .id
105            .clone()
106            .unwrap_or_else(|| Uuid::new_v4().to_string());
107
108        if let Some(email) = &create_user.email
109            && email_index.contains_key(email)
110        {
111            return Err(AuthError::config("Email already exists"));
112        }
113
114        if let Some(username) = &create_user.username
115            && username_index.contains_key(username)
116        {
117            return Err(AuthError::conflict(
118                "A user with this username already exists",
119            ));
120        }
121
122        let now = Utc::now();
123        let user = U::from_create(id.clone(), &create_user, now);
124
125        users.insert(id.clone(), user.clone());
126
127        if let Some(email) = &create_user.email {
128            email_index.insert(email.clone(), id.clone());
129        }
130        if let Some(username) = &create_user.username {
131            username_index.insert(username.clone(), id);
132        }
133
134        Ok(user)
135    }
136
137    async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<U>> {
138        let users = self.users.lock().unwrap();
139        Ok(users.get(id).cloned())
140    }
141
142    async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<U>> {
143        let email_index = self.email_index.lock().unwrap();
144        let users = self.users.lock().unwrap();
145
146        if let Some(user_id) = email_index.get(email) {
147            Ok(users.get(user_id).cloned())
148        } else {
149            Ok(None)
150        }
151    }
152
153    async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<U>> {
154        let username_index = self.username_index.lock().unwrap();
155        let users = self.users.lock().unwrap();
156
157        if let Some(user_id) = username_index.get(username) {
158            Ok(users.get(user_id).cloned())
159        } else {
160            Ok(None)
161        }
162    }
163
164    async fn update_user(&self, id: &str, update: UpdateUser) -> AuthResult<U> {
165        let mut users = self.users.lock().unwrap();
166        let mut email_index = self.email_index.lock().unwrap();
167        let mut username_index = self.username_index.lock().unwrap();
168
169        let user = users.get_mut(id).ok_or(AuthError::UserNotFound)?;
170
171        // Update indices BEFORE mutation (read old values via trait getters)
172        if let Some(new_email) = &update.email {
173            if let Some(old_email) = user.email() {
174                email_index.remove(old_email);
175            }
176            email_index.insert(new_email.clone(), id.to_string());
177        }
178
179        if let Some(ref new_username) = update.username {
180            if let Some(old_username) = user.username() {
181                username_index.remove(old_username);
182            }
183            username_index.insert(new_username.clone(), id.to_string());
184        }
185
186        user.apply_update(&update);
187        Ok(user.clone())
188    }
189
190    async fn delete_user(&self, id: &str) -> AuthResult<()> {
191        let mut users = self.users.lock().unwrap();
192        let mut email_index = self.email_index.lock().unwrap();
193        let mut username_index = self.username_index.lock().unwrap();
194
195        if let Some(user) = users.remove(id) {
196            if let Some(email) = user.email() {
197                email_index.remove(email);
198            }
199            if let Some(username) = user.username() {
200                username_index.remove(username);
201            }
202        }
203
204        Ok(())
205    }
206}
207
208// -- SessionOps --
209
210#[async_trait]
211impl<U, S, A, O, M, I, V> SessionOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V>
212where
213    U: MemoryUser,
214    S: MemorySession,
215    A: MemoryAccount,
216    O: MemoryOrganization,
217    M: MemoryMember,
218    I: MemoryInvitation,
219    V: MemoryVerification,
220{
221    type Session = S;
222
223    async fn create_session(&self, create_session: CreateSession) -> AuthResult<S> {
224        let mut sessions = self.sessions.lock().unwrap();
225
226        let id = Uuid::new_v4().to_string();
227        let token = format!("session_{}", Uuid::new_v4());
228        let now = Utc::now();
229        let session = S::from_create(id, token.clone(), &create_session, now);
230
231        sessions.insert(token, session.clone());
232        Ok(session)
233    }
234
235    async fn get_session(&self, token: &str) -> AuthResult<Option<S>> {
236        let sessions = self.sessions.lock().unwrap();
237        Ok(sessions.get(token).cloned())
238    }
239
240    async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<S>> {
241        let sessions = self.sessions.lock().unwrap();
242        Ok(sessions
243            .values()
244            .filter(|s| s.user_id() == user_id && s.active())
245            .cloned()
246            .collect())
247    }
248
249    async fn update_session_expiry(
250        &self,
251        token: &str,
252        expires_at: DateTime<Utc>,
253    ) -> AuthResult<()> {
254        let mut sessions = self.sessions.lock().unwrap();
255        if let Some(session) = sessions.get_mut(token) {
256            session.set_expires_at(expires_at);
257        }
258        Ok(())
259    }
260
261    async fn delete_session(&self, token: &str) -> AuthResult<()> {
262        let mut sessions = self.sessions.lock().unwrap();
263        sessions.remove(token);
264        Ok(())
265    }
266
267    async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
268        let mut sessions = self.sessions.lock().unwrap();
269        sessions.retain(|_, s| s.user_id() != user_id);
270        Ok(())
271    }
272
273    async fn delete_expired_sessions(&self) -> AuthResult<usize> {
274        let mut sessions = self.sessions.lock().unwrap();
275        let now = Utc::now();
276        let initial_count = sessions.len();
277        sessions.retain(|_, s| s.expires_at() > now && s.active());
278        Ok(initial_count - sessions.len())
279    }
280
281    async fn update_session_active_organization(
282        &self,
283        token: &str,
284        organization_id: Option<&str>,
285    ) -> AuthResult<S> {
286        let mut sessions = self.sessions.lock().unwrap();
287        let session = sessions.get_mut(token).ok_or(AuthError::SessionNotFound)?;
288        session.set_active_organization_id(organization_id.map(|s| s.to_string()));
289        session.set_updated_at(Utc::now());
290        Ok(session.clone())
291    }
292}
293
294// -- AccountOps --
295
296#[async_trait]
297impl<U, S, A, O, M, I, V> AccountOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V>
298where
299    U: MemoryUser,
300    S: MemorySession,
301    A: MemoryAccount,
302    O: MemoryOrganization,
303    M: MemoryMember,
304    I: MemoryInvitation,
305    V: MemoryVerification,
306{
307    type Account = A;
308
309    async fn create_account(&self, create_account: CreateAccount) -> AuthResult<A> {
310        let mut accounts = self.accounts.lock().unwrap();
311
312        let id = Uuid::new_v4().to_string();
313        let now = Utc::now();
314        let account = A::from_create(id.clone(), &create_account, now);
315
316        accounts.insert(id, account.clone());
317        Ok(account)
318    }
319
320    async fn get_account(
321        &self,
322        provider: &str,
323        provider_account_id: &str,
324    ) -> AuthResult<Option<A>> {
325        let accounts = self.accounts.lock().unwrap();
326        Ok(accounts
327            .values()
328            .find(|acc| acc.provider_id() == provider && acc.account_id() == provider_account_id)
329            .cloned())
330    }
331
332    async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<A>> {
333        let accounts = self.accounts.lock().unwrap();
334        Ok(accounts
335            .values()
336            .filter(|acc| acc.user_id() == user_id)
337            .cloned()
338            .collect())
339    }
340
341    async fn delete_account(&self, id: &str) -> AuthResult<()> {
342        let mut accounts = self.accounts.lock().unwrap();
343        accounts.remove(id);
344        Ok(())
345    }
346}
347
348// -- VerificationOps --
349
350#[async_trait]
351impl<U, S, A, O, M, I, V> VerificationOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V>
352where
353    U: MemoryUser,
354    S: MemorySession,
355    A: MemoryAccount,
356    O: MemoryOrganization,
357    M: MemoryMember,
358    I: MemoryInvitation,
359    V: MemoryVerification,
360{
361    type Verification = V;
362
363    async fn create_verification(&self, create_verification: CreateVerification) -> AuthResult<V> {
364        let mut verifications = self.verifications.lock().unwrap();
365
366        let id = Uuid::new_v4().to_string();
367        let now = Utc::now();
368        let verification = V::from_create(id.clone(), &create_verification, now);
369
370        verifications.insert(id, verification.clone());
371        Ok(verification)
372    }
373
374    async fn get_verification(&self, identifier: &str, value: &str) -> AuthResult<Option<V>> {
375        let verifications = self.verifications.lock().unwrap();
376        let now = Utc::now();
377        Ok(verifications
378            .values()
379            .find(|v| v.identifier() == identifier && v.value() == value && v.expires_at() > now)
380            .cloned())
381    }
382
383    async fn get_verification_by_value(&self, value: &str) -> AuthResult<Option<V>> {
384        let verifications = self.verifications.lock().unwrap();
385        let now = Utc::now();
386        Ok(verifications
387            .values()
388            .find(|v| v.value() == value && v.expires_at() > now)
389            .cloned())
390    }
391
392    async fn delete_verification(&self, id: &str) -> AuthResult<()> {
393        let mut verifications = self.verifications.lock().unwrap();
394        verifications.remove(id);
395        Ok(())
396    }
397
398    async fn delete_expired_verifications(&self) -> AuthResult<usize> {
399        let mut verifications = self.verifications.lock().unwrap();
400        let now = Utc::now();
401        let initial_count = verifications.len();
402        verifications.retain(|_, v| v.expires_at() > now);
403        Ok(initial_count - verifications.len())
404    }
405}
406
407// -- OrganizationOps --
408
409#[async_trait]
410impl<U, S, A, O, M, I, V> OrganizationOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V>
411where
412    U: MemoryUser,
413    S: MemorySession,
414    A: MemoryAccount,
415    O: MemoryOrganization,
416    M: MemoryMember,
417    I: MemoryInvitation,
418    V: MemoryVerification,
419{
420    type Organization = O;
421
422    async fn create_organization(&self, create_org: CreateOrganization) -> AuthResult<O> {
423        let mut organizations = self.organizations.lock().unwrap();
424        let mut slug_index = self.slug_index.lock().unwrap();
425
426        if slug_index.contains_key(&create_org.slug) {
427            return Err(AuthError::conflict("Organization slug already exists"));
428        }
429
430        let id = create_org
431            .id
432            .clone()
433            .unwrap_or_else(|| Uuid::new_v4().to_string());
434        let now = Utc::now();
435        let organization = O::from_create(id.clone(), &create_org, now);
436
437        organizations.insert(id.clone(), organization.clone());
438        slug_index.insert(create_org.slug.clone(), id);
439
440        Ok(organization)
441    }
442
443    async fn get_organization_by_id(&self, id: &str) -> AuthResult<Option<O>> {
444        let organizations = self.organizations.lock().unwrap();
445        Ok(organizations.get(id).cloned())
446    }
447
448    async fn get_organization_by_slug(&self, slug: &str) -> AuthResult<Option<O>> {
449        let slug_index = self.slug_index.lock().unwrap();
450        let organizations = self.organizations.lock().unwrap();
451
452        if let Some(org_id) = slug_index.get(slug) {
453            Ok(organizations.get(org_id).cloned())
454        } else {
455            Ok(None)
456        }
457    }
458
459    async fn update_organization(&self, id: &str, update: UpdateOrganization) -> AuthResult<O> {
460        let mut organizations = self.organizations.lock().unwrap();
461        let mut slug_index = self.slug_index.lock().unwrap();
462
463        let org = organizations
464            .get_mut(id)
465            .ok_or_else(|| AuthError::not_found("Organization not found"))?;
466
467        // Update slug index BEFORE mutation
468        if let Some(new_slug) = &update.slug {
469            let current_slug = org.slug().to_string();
470            if *new_slug != current_slug {
471                if slug_index.contains_key(new_slug.as_str()) {
472                    return Err(AuthError::conflict("Organization slug already exists"));
473                }
474                slug_index.remove(&current_slug);
475                slug_index.insert(new_slug.clone(), id.to_string());
476            }
477        }
478
479        org.apply_update(&update);
480        Ok(org.clone())
481    }
482
483    async fn delete_organization(&self, id: &str) -> AuthResult<()> {
484        let mut organizations = self.organizations.lock().unwrap();
485        let mut slug_index = self.slug_index.lock().unwrap();
486        let mut members = self.members.lock().unwrap();
487        let mut invitations = self.invitations.lock().unwrap();
488
489        if let Some(org) = organizations.remove(id) {
490            slug_index.remove(org.slug());
491        }
492
493        members.retain(|_, m| m.organization_id() != id);
494        invitations.retain(|_, i| i.organization_id() != id);
495
496        Ok(())
497    }
498
499    async fn list_user_organizations(&self, user_id: &str) -> AuthResult<Vec<O>> {
500        let members = self.members.lock().unwrap();
501        let organizations = self.organizations.lock().unwrap();
502
503        let org_ids: Vec<String> = members
504            .values()
505            .filter(|m| m.user_id() == user_id)
506            .map(|m| m.organization_id().to_string())
507            .collect();
508
509        let orgs = org_ids
510            .iter()
511            .filter_map(|id| organizations.get(id).cloned())
512            .collect();
513
514        Ok(orgs)
515    }
516}
517
518// -- MemberOps --
519
520#[async_trait]
521impl<U, S, A, O, M, I, V> MemberOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V>
522where
523    U: MemoryUser,
524    S: MemorySession,
525    A: MemoryAccount,
526    O: MemoryOrganization,
527    M: MemoryMember,
528    I: MemoryInvitation,
529    V: MemoryVerification,
530{
531    type Member = M;
532
533    async fn create_member(&self, create_member: CreateMember) -> AuthResult<M> {
534        let mut members = self.members.lock().unwrap();
535
536        let exists = members.values().any(|m| {
537            m.organization_id() == create_member.organization_id
538                && m.user_id() == create_member.user_id
539        });
540
541        if exists {
542            return Err(AuthError::conflict(
543                "User is already a member of this organization",
544            ));
545        }
546
547        let id = Uuid::new_v4().to_string();
548        let now = Utc::now();
549        let member = M::from_create(id.clone(), &create_member, now);
550
551        members.insert(id, member.clone());
552        Ok(member)
553    }
554
555    async fn get_member(&self, organization_id: &str, user_id: &str) -> AuthResult<Option<M>> {
556        let members = self.members.lock().unwrap();
557        Ok(members
558            .values()
559            .find(|m| m.organization_id() == organization_id && m.user_id() == user_id)
560            .cloned())
561    }
562
563    async fn get_member_by_id(&self, id: &str) -> AuthResult<Option<M>> {
564        let members = self.members.lock().unwrap();
565        Ok(members.get(id).cloned())
566    }
567
568    async fn update_member_role(&self, member_id: &str, role: &str) -> AuthResult<M> {
569        let mut members = self.members.lock().unwrap();
570        let member = members
571            .get_mut(member_id)
572            .ok_or_else(|| AuthError::not_found("Member not found"))?;
573        member.set_role(role.to_string());
574        Ok(member.clone())
575    }
576
577    async fn delete_member(&self, member_id: &str) -> AuthResult<()> {
578        let mut members = self.members.lock().unwrap();
579        members.remove(member_id);
580        Ok(())
581    }
582
583    async fn list_organization_members(&self, organization_id: &str) -> AuthResult<Vec<M>> {
584        let members = self.members.lock().unwrap();
585        Ok(members
586            .values()
587            .filter(|m| m.organization_id() == organization_id)
588            .cloned()
589            .collect())
590    }
591
592    async fn count_organization_members(&self, organization_id: &str) -> AuthResult<usize> {
593        let members = self.members.lock().unwrap();
594        Ok(members
595            .values()
596            .filter(|m| m.organization_id() == organization_id)
597            .count())
598    }
599
600    async fn count_organization_owners(&self, organization_id: &str) -> AuthResult<usize> {
601        let members = self.members.lock().unwrap();
602        Ok(members
603            .values()
604            .filter(|m| m.organization_id() == organization_id && m.role() == "owner")
605            .count())
606    }
607}
608
609// -- InvitationOps --
610
611#[async_trait]
612impl<U, S, A, O, M, I, V> InvitationOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V>
613where
614    U: MemoryUser,
615    S: MemorySession,
616    A: MemoryAccount,
617    O: MemoryOrganization,
618    M: MemoryMember,
619    I: MemoryInvitation,
620    V: MemoryVerification,
621{
622    type Invitation = I;
623
624    async fn create_invitation(&self, create_inv: CreateInvitation) -> AuthResult<I> {
625        let mut invitations = self.invitations.lock().unwrap();
626
627        let id = Uuid::new_v4().to_string();
628        let now = Utc::now();
629        let invitation = I::from_create(id.clone(), &create_inv, now);
630
631        invitations.insert(id, invitation.clone());
632        Ok(invitation)
633    }
634
635    async fn get_invitation_by_id(&self, id: &str) -> AuthResult<Option<I>> {
636        let invitations = self.invitations.lock().unwrap();
637        Ok(invitations.get(id).cloned())
638    }
639
640    async fn get_pending_invitation(
641        &self,
642        organization_id: &str,
643        email: &str,
644    ) -> AuthResult<Option<I>> {
645        let invitations = self.invitations.lock().unwrap();
646        Ok(invitations
647            .values()
648            .find(|i| {
649                i.organization_id() == organization_id
650                    && i.email().to_lowercase() == email.to_lowercase()
651                    && *i.status() == InvitationStatus::Pending
652            })
653            .cloned())
654    }
655
656    async fn update_invitation_status(&self, id: &str, status: InvitationStatus) -> AuthResult<I> {
657        let mut invitations = self.invitations.lock().unwrap();
658        let invitation = invitations
659            .get_mut(id)
660            .ok_or_else(|| AuthError::not_found("Invitation not found"))?;
661        invitation.set_status(status);
662        Ok(invitation.clone())
663    }
664
665    async fn list_organization_invitations(&self, organization_id: &str) -> AuthResult<Vec<I>> {
666        let invitations = self.invitations.lock().unwrap();
667        Ok(invitations
668            .values()
669            .filter(|i| i.organization_id() == organization_id)
670            .cloned()
671            .collect())
672    }
673
674    async fn list_user_invitations(&self, email: &str) -> AuthResult<Vec<I>> {
675        let invitations = self.invitations.lock().unwrap();
676        let now = Utc::now();
677        Ok(invitations
678            .values()
679            .filter(|i| {
680                i.email().to_lowercase() == email.to_lowercase()
681                    && *i.status() == InvitationStatus::Pending
682                    && i.expires_at() > now
683            })
684            .cloned()
685            .collect())
686    }
687}