Skip to main content

better_auth_core/adapters/
memory.rs

1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5use uuid::Uuid;
6
7use crate::error::{AuthError, AuthResult};
8use crate::types::{
9    Account, ApiKey, CreateAccount, CreateApiKey, CreateInvitation, CreateMember,
10    CreateOrganization, CreatePasskey, CreateSession, CreateTwoFactor, CreateUser,
11    CreateVerification, Invitation, InvitationStatus, ListUsersParams, Member, Organization,
12    Passkey, Session, TwoFactor, UpdateAccount, UpdateApiKey, UpdateOrganization, UpdateUser, User,
13    Verification,
14};
15
16pub use super::memory_traits::{
17    MemoryAccount, MemoryApiKey, MemoryInvitation, MemoryMember, MemoryOrganization, MemoryPasskey,
18    MemorySession, MemoryTwoFactor, MemoryUser, MemoryVerification,
19};
20
21use super::traits::{
22    AccountOps, ApiKeyOps, InvitationOps, MemberOps, OrganizationOps, PasskeyOps, SessionOps,
23    TwoFactorOps, UserOps, VerificationOps,
24};
25
26/// In-memory database adapter for testing and development.
27///
28/// Generic over entity types — use default type parameters for the built-in
29/// types, or supply your own custom structs that implement the `Memory*`
30/// traits.
31///
32/// ```rust,ignore
33/// // Using built-in types (no turbofish needed):
34/// let adapter = MemoryDatabaseAdapter::new();
35///
36/// // Using custom types:
37/// let adapter = MemoryDatabaseAdapter::<MyUser, MySession, MyAccount,
38///     MyOrg, MyMember, MyInvitation, MyVerification>::new();
39/// ```
40pub struct MemoryDatabaseAdapter<
41    U = User,
42    S = Session,
43    A = Account,
44    O = Organization,
45    M = Member,
46    I = Invitation,
47    V = Verification,
48    P = Passkey,
49> {
50    users: Arc<Mutex<HashMap<String, U>>>,
51    sessions: Arc<Mutex<HashMap<String, S>>>,
52    accounts: Arc<Mutex<HashMap<String, A>>>,
53    verifications: Arc<Mutex<HashMap<String, V>>>,
54    email_index: Arc<Mutex<HashMap<String, String>>>,
55    username_index: Arc<Mutex<HashMap<String, String>>>,
56    organizations: Arc<Mutex<HashMap<String, O>>>,
57    members: Arc<Mutex<HashMap<String, M>>>,
58    invitations: Arc<Mutex<HashMap<String, I>>>,
59    slug_index: Arc<Mutex<HashMap<String, String>>>,
60    two_factors: Arc<Mutex<HashMap<String, TwoFactor>>>,
61    api_keys: Arc<Mutex<HashMap<String, ApiKey>>>,
62    passkeys: Arc<Mutex<HashMap<String, P>>>,
63    passkey_credential_index: Arc<Mutex<HashMap<String, String>>>,
64}
65
66/// Constructor for the default (built-in) entity types.
67/// Use `Default::default()` for custom type parameterizations.
68impl MemoryDatabaseAdapter {
69    pub fn new() -> Self {
70        Self::default()
71    }
72}
73
74impl<U, S, A, O, M, I, V, P> Default for MemoryDatabaseAdapter<U, S, A, O, M, I, V, P> {
75    fn default() -> Self {
76        Self {
77            users: Arc::new(Mutex::new(HashMap::new())),
78            sessions: Arc::new(Mutex::new(HashMap::new())),
79            accounts: Arc::new(Mutex::new(HashMap::new())),
80            verifications: Arc::new(Mutex::new(HashMap::new())),
81            email_index: Arc::new(Mutex::new(HashMap::new())),
82            username_index: Arc::new(Mutex::new(HashMap::new())),
83            organizations: Arc::new(Mutex::new(HashMap::new())),
84            members: Arc::new(Mutex::new(HashMap::new())),
85            invitations: Arc::new(Mutex::new(HashMap::new())),
86            slug_index: Arc::new(Mutex::new(HashMap::new())),
87            two_factors: Arc::new(Mutex::new(HashMap::new())),
88            api_keys: Arc::new(Mutex::new(HashMap::new())),
89            passkeys: Arc::new(Mutex::new(HashMap::new())),
90            passkey_credential_index: Arc::new(Mutex::new(HashMap::new())),
91        }
92    }
93}
94
95// -- UserOps --
96
97#[async_trait]
98impl<U, S, A, O, M, I, V, P> UserOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V, P>
99where
100    U: MemoryUser,
101    S: MemorySession,
102    A: MemoryAccount,
103    O: MemoryOrganization,
104    M: MemoryMember,
105    I: MemoryInvitation,
106    V: MemoryVerification,
107    P: MemoryPasskey,
108{
109    type User = U;
110
111    async fn create_user(&self, create_user: CreateUser) -> AuthResult<U> {
112        let mut users = self.users.lock().unwrap();
113        let mut email_index = self.email_index.lock().unwrap();
114        let mut username_index = self.username_index.lock().unwrap();
115
116        let id = create_user
117            .id
118            .clone()
119            .unwrap_or_else(|| Uuid::new_v4().to_string());
120
121        if let Some(email) = &create_user.email
122            && email_index.contains_key(email)
123        {
124            return Err(AuthError::config("Email already exists"));
125        }
126
127        if let Some(username) = &create_user.username
128            && username_index.contains_key(username)
129        {
130            return Err(AuthError::conflict(
131                "A user with this username already exists",
132            ));
133        }
134
135        let now = Utc::now();
136        let user = U::from_create(id.clone(), &create_user, now);
137
138        users.insert(id.clone(), user.clone());
139
140        if let Some(email) = &create_user.email {
141            email_index.insert(email.clone(), id.clone());
142        }
143        if let Some(username) = &create_user.username {
144            username_index.insert(username.clone(), id);
145        }
146
147        Ok(user)
148    }
149
150    async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<U>> {
151        let users = self.users.lock().unwrap();
152        Ok(users.get(id).cloned())
153    }
154
155    async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<U>> {
156        let email_index = self.email_index.lock().unwrap();
157        let users = self.users.lock().unwrap();
158
159        if let Some(user_id) = email_index.get(email) {
160            Ok(users.get(user_id).cloned())
161        } else {
162            Ok(None)
163        }
164    }
165
166    async fn get_user_by_username(&self, username: &str) -> AuthResult<Option<U>> {
167        let username_index = self.username_index.lock().unwrap();
168        let users = self.users.lock().unwrap();
169
170        if let Some(user_id) = username_index.get(username) {
171            Ok(users.get(user_id).cloned())
172        } else {
173            Ok(None)
174        }
175    }
176
177    async fn update_user(&self, id: &str, update: UpdateUser) -> AuthResult<U> {
178        let mut users = self.users.lock().unwrap();
179        let mut email_index = self.email_index.lock().unwrap();
180        let mut username_index = self.username_index.lock().unwrap();
181
182        let user = users.get_mut(id).ok_or(AuthError::UserNotFound)?;
183
184        // Update indices BEFORE mutation (read old values via trait getters)
185        if let Some(new_email) = &update.email {
186            if let Some(old_email) = user.email() {
187                email_index.remove(old_email);
188            }
189            email_index.insert(new_email.clone(), id.to_string());
190        }
191
192        if let Some(ref new_username) = update.username {
193            if let Some(old_username) = user.username() {
194                username_index.remove(old_username);
195            }
196            username_index.insert(new_username.clone(), id.to_string());
197        }
198
199        user.apply_update(&update);
200        Ok(user.clone())
201    }
202
203    async fn delete_user(&self, id: &str) -> AuthResult<()> {
204        let mut users = self.users.lock().unwrap();
205        let mut email_index = self.email_index.lock().unwrap();
206        let mut username_index = self.username_index.lock().unwrap();
207
208        if let Some(user) = users.remove(id) {
209            if let Some(email) = user.email() {
210                email_index.remove(email);
211            }
212            if let Some(username) = user.username() {
213                username_index.remove(username);
214            }
215        }
216
217        Ok(())
218    }
219
220    async fn list_users(&self, params: ListUsersParams) -> AuthResult<(Vec<U>, usize)> {
221        let users = self.users.lock().unwrap();
222        let mut result: Vec<U> = users.values().cloned().collect();
223
224        // Apply search filter
225        if let Some(search_value) = &params.search_value {
226            let field = params.search_field.as_deref().unwrap_or("email");
227            let op = params.search_operator.as_deref().unwrap_or("contains");
228            let sv = search_value.to_lowercase();
229            result.retain(|u| {
230                let field_val = match field {
231                    "name" => u.name().unwrap_or("").to_lowercase(),
232                    _ => u.email().unwrap_or("").to_lowercase(),
233                };
234                match op {
235                    "starts_with" => field_val.starts_with(&sv),
236                    "ends_with" => field_val.ends_with(&sv),
237                    _ => field_val.contains(&sv),
238                }
239            });
240        }
241
242        // Apply filter
243        if let Some(filter_value) = &params.filter_value {
244            let field = params.filter_field.as_deref().unwrap_or("email");
245            let op = params.filter_operator.as_deref().unwrap_or("eq");
246            let fv = filter_value.to_lowercase();
247            result.retain(|u| {
248                let field_val = match field {
249                    "name" => u.name().unwrap_or("").to_lowercase(),
250                    "role" => u.role().unwrap_or("").to_lowercase(),
251                    _ => u.email().unwrap_or("").to_lowercase(),
252                };
253                match op {
254                    "contains" => field_val.contains(&fv),
255                    "starts_with" => field_val.starts_with(&fv),
256                    "ends_with" => field_val.ends_with(&fv),
257                    "ne" => field_val != fv,
258                    _ => field_val == fv,
259                }
260            });
261        }
262
263        // Apply sort
264        if let Some(sort_by) = &params.sort_by {
265            let desc = params.sort_direction.as_deref() == Some("desc");
266            result.sort_by(|a, b| {
267                let av = match sort_by.as_str() {
268                    "name" => a.name().unwrap_or("").to_string(),
269                    "createdAt" => a.created_at().to_rfc3339(),
270                    _ => a.email().unwrap_or("").to_string(),
271                };
272                let bv = match sort_by.as_str() {
273                    "name" => b.name().unwrap_or("").to_string(),
274                    "createdAt" => b.created_at().to_rfc3339(),
275                    _ => b.email().unwrap_or("").to_string(),
276                };
277                if desc { bv.cmp(&av) } else { av.cmp(&bv) }
278            });
279        }
280
281        let total = result.len();
282        let offset = params.offset.unwrap_or(0);
283        let limit = params.limit.unwrap_or(100);
284        let paged: Vec<U> = result.into_iter().skip(offset).take(limit).collect();
285
286        Ok((paged, total))
287    }
288}
289
290// -- SessionOps --
291
292#[async_trait]
293impl<U, S, A, O, M, I, V, P> SessionOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V, P>
294where
295    U: MemoryUser,
296    S: MemorySession,
297    A: MemoryAccount,
298    O: MemoryOrganization,
299    M: MemoryMember,
300    I: MemoryInvitation,
301    V: MemoryVerification,
302    P: MemoryPasskey,
303{
304    type Session = S;
305
306    async fn create_session(&self, create_session: CreateSession) -> AuthResult<S> {
307        let mut sessions = self.sessions.lock().unwrap();
308
309        let id = Uuid::new_v4().to_string();
310        let token = format!("session_{}", Uuid::new_v4());
311        let now = Utc::now();
312        let session = S::from_create(id, token.clone(), &create_session, now);
313
314        sessions.insert(token, session.clone());
315        Ok(session)
316    }
317
318    async fn get_session(&self, token: &str) -> AuthResult<Option<S>> {
319        let sessions = self.sessions.lock().unwrap();
320        Ok(sessions.get(token).cloned())
321    }
322
323    async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<S>> {
324        let sessions = self.sessions.lock().unwrap();
325        Ok(sessions
326            .values()
327            .filter(|s| s.user_id() == user_id && s.active())
328            .cloned()
329            .collect())
330    }
331
332    async fn update_session_expiry(
333        &self,
334        token: &str,
335        expires_at: DateTime<Utc>,
336    ) -> AuthResult<()> {
337        let mut sessions = self.sessions.lock().unwrap();
338        if let Some(session) = sessions.get_mut(token) {
339            session.set_expires_at(expires_at);
340            session.set_updated_at(Utc::now());
341        }
342        Ok(())
343    }
344
345    async fn delete_session(&self, token: &str) -> AuthResult<()> {
346        let mut sessions = self.sessions.lock().unwrap();
347        sessions.remove(token);
348        Ok(())
349    }
350
351    async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
352        let mut sessions = self.sessions.lock().unwrap();
353        sessions.retain(|_, s| s.user_id() != user_id);
354        Ok(())
355    }
356
357    async fn delete_expired_sessions(&self) -> AuthResult<usize> {
358        let mut sessions = self.sessions.lock().unwrap();
359        let now = Utc::now();
360        let initial_count = sessions.len();
361        sessions.retain(|_, s| s.expires_at() > now && s.active());
362        Ok(initial_count - sessions.len())
363    }
364
365    async fn update_session_active_organization(
366        &self,
367        token: &str,
368        organization_id: Option<&str>,
369    ) -> AuthResult<S> {
370        let mut sessions = self.sessions.lock().unwrap();
371        let session = sessions.get_mut(token).ok_or(AuthError::SessionNotFound)?;
372        session.set_active_organization_id(organization_id.map(|s| s.to_string()));
373        session.set_updated_at(Utc::now());
374        Ok(session.clone())
375    }
376}
377
378// -- AccountOps --
379
380#[async_trait]
381impl<U, S, A, O, M, I, V, P> AccountOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V, P>
382where
383    U: MemoryUser,
384    S: MemorySession,
385    A: MemoryAccount,
386    O: MemoryOrganization,
387    M: MemoryMember,
388    I: MemoryInvitation,
389    V: MemoryVerification,
390    P: MemoryPasskey,
391{
392    type Account = A;
393
394    async fn create_account(&self, create_account: CreateAccount) -> AuthResult<A> {
395        let mut accounts = self.accounts.lock().unwrap();
396
397        let id = Uuid::new_v4().to_string();
398        let now = Utc::now();
399        let account = A::from_create(id.clone(), &create_account, now);
400
401        accounts.insert(id, account.clone());
402        Ok(account)
403    }
404
405    async fn get_account(
406        &self,
407        provider: &str,
408        provider_account_id: &str,
409    ) -> AuthResult<Option<A>> {
410        let accounts = self.accounts.lock().unwrap();
411        Ok(accounts
412            .values()
413            .find(|acc| acc.provider_id() == provider && acc.account_id() == provider_account_id)
414            .cloned())
415    }
416
417    async fn get_user_accounts(&self, user_id: &str) -> AuthResult<Vec<A>> {
418        let accounts = self.accounts.lock().unwrap();
419        Ok(accounts
420            .values()
421            .filter(|acc| acc.user_id() == user_id)
422            .cloned()
423            .collect())
424    }
425
426    async fn update_account(&self, id: &str, update: UpdateAccount) -> AuthResult<A> {
427        let mut accounts = self.accounts.lock().unwrap();
428        let account = accounts
429            .get_mut(id)
430            .ok_or_else(|| AuthError::not_found("Account not found"))?;
431        account.apply_update(&update);
432        Ok(account.clone())
433    }
434
435    async fn delete_account(&self, id: &str) -> AuthResult<()> {
436        let mut accounts = self.accounts.lock().unwrap();
437        accounts.remove(id);
438        Ok(())
439    }
440}
441
442// -- VerificationOps --
443
444#[async_trait]
445impl<U, S, A, O, M, I, V, P> VerificationOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V, P>
446where
447    U: MemoryUser,
448    S: MemorySession,
449    A: MemoryAccount,
450    O: MemoryOrganization,
451    M: MemoryMember,
452    I: MemoryInvitation,
453    V: MemoryVerification,
454    P: MemoryPasskey,
455{
456    type Verification = V;
457
458    async fn create_verification(&self, create_verification: CreateVerification) -> AuthResult<V> {
459        let mut verifications = self.verifications.lock().unwrap();
460
461        let id = Uuid::new_v4().to_string();
462        let now = Utc::now();
463        let verification = V::from_create(id.clone(), &create_verification, now);
464
465        verifications.insert(id, verification.clone());
466        Ok(verification)
467    }
468
469    async fn get_verification(&self, identifier: &str, value: &str) -> AuthResult<Option<V>> {
470        let verifications = self.verifications.lock().unwrap();
471        let now = Utc::now();
472        Ok(verifications
473            .values()
474            .find(|v| v.identifier() == identifier && v.value() == value && v.expires_at() > now)
475            .cloned())
476    }
477
478    async fn get_verification_by_value(&self, value: &str) -> AuthResult<Option<V>> {
479        let verifications = self.verifications.lock().unwrap();
480        let now = Utc::now();
481        Ok(verifications
482            .values()
483            .find(|v| v.value() == value && v.expires_at() > now)
484            .cloned())
485    }
486
487    async fn get_verification_by_identifier(&self, identifier: &str) -> AuthResult<Option<V>> {
488        let verifications = self.verifications.lock().unwrap();
489        let now = Utc::now();
490        Ok(verifications
491            .values()
492            .find(|v| v.identifier() == identifier && v.expires_at() > now)
493            .cloned())
494    }
495
496    async fn consume_verification(&self, identifier: &str, value: &str) -> AuthResult<Option<V>> {
497        let mut verifications = self.verifications.lock().unwrap();
498        let now = Utc::now();
499
500        let matched_id = verifications
501            .iter()
502            .filter_map(|(id, verification)| {
503                if verification.identifier() == identifier
504                    && verification.value() == value
505                    && verification.expires_at() > now
506                {
507                    Some((id, verification.created_at()))
508                } else {
509                    None
510                }
511            })
512            .max_by_key(|(_, created_at)| *created_at)
513            .map(|(id, _)| id.clone());
514
515        if let Some(id) = matched_id {
516            Ok(verifications.remove(&id))
517        } else {
518            Ok(None)
519        }
520    }
521
522    async fn delete_verification(&self, id: &str) -> AuthResult<()> {
523        let mut verifications = self.verifications.lock().unwrap();
524        verifications.remove(id);
525        Ok(())
526    }
527
528    async fn delete_expired_verifications(&self) -> AuthResult<usize> {
529        let mut verifications = self.verifications.lock().unwrap();
530        let now = Utc::now();
531        let initial_count = verifications.len();
532        verifications.retain(|_, v| v.expires_at() > now);
533        Ok(initial_count - verifications.len())
534    }
535}
536
537// -- OrganizationOps --
538
539#[async_trait]
540impl<U, S, A, O, M, I, V, P> OrganizationOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V, P>
541where
542    U: MemoryUser,
543    S: MemorySession,
544    A: MemoryAccount,
545    O: MemoryOrganization,
546    M: MemoryMember,
547    I: MemoryInvitation,
548    V: MemoryVerification,
549    P: MemoryPasskey,
550{
551    type Organization = O;
552
553    async fn create_organization(&self, create_org: CreateOrganization) -> AuthResult<O> {
554        let mut organizations = self.organizations.lock().unwrap();
555        let mut slug_index = self.slug_index.lock().unwrap();
556
557        if slug_index.contains_key(&create_org.slug) {
558            return Err(AuthError::conflict("Organization slug already exists"));
559        }
560
561        let id = create_org
562            .id
563            .clone()
564            .unwrap_or_else(|| Uuid::new_v4().to_string());
565        let now = Utc::now();
566        let organization = O::from_create(id.clone(), &create_org, now);
567
568        organizations.insert(id.clone(), organization.clone());
569        slug_index.insert(create_org.slug.clone(), id);
570
571        Ok(organization)
572    }
573
574    async fn get_organization_by_id(&self, id: &str) -> AuthResult<Option<O>> {
575        let organizations = self.organizations.lock().unwrap();
576        Ok(organizations.get(id).cloned())
577    }
578
579    async fn get_organization_by_slug(&self, slug: &str) -> AuthResult<Option<O>> {
580        let slug_index = self.slug_index.lock().unwrap();
581        let organizations = self.organizations.lock().unwrap();
582
583        if let Some(org_id) = slug_index.get(slug) {
584            Ok(organizations.get(org_id).cloned())
585        } else {
586            Ok(None)
587        }
588    }
589
590    async fn update_organization(&self, id: &str, update: UpdateOrganization) -> AuthResult<O> {
591        let mut organizations = self.organizations.lock().unwrap();
592        let mut slug_index = self.slug_index.lock().unwrap();
593
594        let org = organizations
595            .get_mut(id)
596            .ok_or_else(|| AuthError::not_found("Organization not found"))?;
597
598        // Update slug index BEFORE mutation
599        if let Some(new_slug) = &update.slug {
600            let current_slug = org.slug().to_string();
601            if *new_slug != current_slug {
602                if slug_index.contains_key(new_slug.as_str()) {
603                    return Err(AuthError::conflict("Organization slug already exists"));
604                }
605                slug_index.remove(&current_slug);
606                slug_index.insert(new_slug.clone(), id.to_string());
607            }
608        }
609
610        org.apply_update(&update);
611        Ok(org.clone())
612    }
613
614    async fn delete_organization(&self, id: &str) -> AuthResult<()> {
615        let mut organizations = self.organizations.lock().unwrap();
616        let mut slug_index = self.slug_index.lock().unwrap();
617        let mut members = self.members.lock().unwrap();
618        let mut invitations = self.invitations.lock().unwrap();
619
620        if let Some(org) = organizations.remove(id) {
621            slug_index.remove(org.slug());
622        }
623
624        members.retain(|_, m| m.organization_id() != id);
625        invitations.retain(|_, i| i.organization_id() != id);
626
627        Ok(())
628    }
629
630    async fn list_user_organizations(&self, user_id: &str) -> AuthResult<Vec<O>> {
631        let members = self.members.lock().unwrap();
632        let organizations = self.organizations.lock().unwrap();
633
634        let org_ids: Vec<String> = members
635            .values()
636            .filter(|m| m.user_id() == user_id)
637            .map(|m| m.organization_id().to_string())
638            .collect();
639
640        let orgs = org_ids
641            .iter()
642            .filter_map(|id| organizations.get(id).cloned())
643            .collect();
644
645        Ok(orgs)
646    }
647}
648
649// -- MemberOps --
650
651#[async_trait]
652impl<U, S, A, O, M, I, V, P> MemberOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V, P>
653where
654    U: MemoryUser,
655    S: MemorySession,
656    A: MemoryAccount,
657    O: MemoryOrganization,
658    M: MemoryMember,
659    I: MemoryInvitation,
660    V: MemoryVerification,
661    P: MemoryPasskey,
662{
663    type Member = M;
664
665    async fn create_member(&self, create_member: CreateMember) -> AuthResult<M> {
666        let mut members = self.members.lock().unwrap();
667
668        let exists = members.values().any(|m| {
669            m.organization_id() == create_member.organization_id
670                && m.user_id() == create_member.user_id
671        });
672
673        if exists {
674            return Err(AuthError::conflict(
675                "User is already a member of this organization",
676            ));
677        }
678
679        let id = Uuid::new_v4().to_string();
680        let now = Utc::now();
681        let member = M::from_create(id.clone(), &create_member, now);
682
683        members.insert(id, member.clone());
684        Ok(member)
685    }
686
687    async fn get_member(&self, organization_id: &str, user_id: &str) -> AuthResult<Option<M>> {
688        let members = self.members.lock().unwrap();
689        Ok(members
690            .values()
691            .find(|m| m.organization_id() == organization_id && m.user_id() == user_id)
692            .cloned())
693    }
694
695    async fn get_member_by_id(&self, id: &str) -> AuthResult<Option<M>> {
696        let members = self.members.lock().unwrap();
697        Ok(members.get(id).cloned())
698    }
699
700    async fn update_member_role(&self, member_id: &str, role: &str) -> AuthResult<M> {
701        let mut members = self.members.lock().unwrap();
702        let member = members
703            .get_mut(member_id)
704            .ok_or_else(|| AuthError::not_found("Member not found"))?;
705        member.set_role(role.to_string());
706        Ok(member.clone())
707    }
708
709    async fn delete_member(&self, member_id: &str) -> AuthResult<()> {
710        let mut members = self.members.lock().unwrap();
711        members.remove(member_id);
712        Ok(())
713    }
714
715    async fn list_organization_members(&self, organization_id: &str) -> AuthResult<Vec<M>> {
716        let members = self.members.lock().unwrap();
717        Ok(members
718            .values()
719            .filter(|m| m.organization_id() == organization_id)
720            .cloned()
721            .collect())
722    }
723
724    async fn count_organization_members(&self, organization_id: &str) -> AuthResult<usize> {
725        let members = self.members.lock().unwrap();
726        Ok(members
727            .values()
728            .filter(|m| m.organization_id() == organization_id)
729            .count())
730    }
731
732    async fn count_organization_owners(&self, organization_id: &str) -> AuthResult<usize> {
733        let members = self.members.lock().unwrap();
734        Ok(members
735            .values()
736            .filter(|m| m.organization_id() == organization_id && m.role() == "owner")
737            .count())
738    }
739}
740
741// -- InvitationOps --
742
743#[async_trait]
744impl<U, S, A, O, M, I, V, P> InvitationOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V, P>
745where
746    U: MemoryUser,
747    S: MemorySession,
748    A: MemoryAccount,
749    O: MemoryOrganization,
750    M: MemoryMember,
751    I: MemoryInvitation,
752    V: MemoryVerification,
753    P: MemoryPasskey,
754{
755    type Invitation = I;
756
757    async fn create_invitation(&self, create_inv: CreateInvitation) -> AuthResult<I> {
758        let mut invitations = self.invitations.lock().unwrap();
759
760        let id = Uuid::new_v4().to_string();
761        let now = Utc::now();
762        let invitation = I::from_create(id.clone(), &create_inv, now);
763
764        invitations.insert(id, invitation.clone());
765        Ok(invitation)
766    }
767
768    async fn get_invitation_by_id(&self, id: &str) -> AuthResult<Option<I>> {
769        let invitations = self.invitations.lock().unwrap();
770        Ok(invitations.get(id).cloned())
771    }
772
773    async fn get_pending_invitation(
774        &self,
775        organization_id: &str,
776        email: &str,
777    ) -> AuthResult<Option<I>> {
778        let invitations = self.invitations.lock().unwrap();
779        Ok(invitations
780            .values()
781            .find(|i| {
782                i.organization_id() == organization_id
783                    && i.email().to_lowercase() == email.to_lowercase()
784                    && *i.status() == InvitationStatus::Pending
785            })
786            .cloned())
787    }
788
789    async fn update_invitation_status(&self, id: &str, status: InvitationStatus) -> AuthResult<I> {
790        let mut invitations = self.invitations.lock().unwrap();
791        let invitation = invitations
792            .get_mut(id)
793            .ok_or_else(|| AuthError::not_found("Invitation not found"))?;
794        invitation.set_status(status);
795        Ok(invitation.clone())
796    }
797
798    async fn list_organization_invitations(&self, organization_id: &str) -> AuthResult<Vec<I>> {
799        let invitations = self.invitations.lock().unwrap();
800        Ok(invitations
801            .values()
802            .filter(|i| i.organization_id() == organization_id)
803            .cloned()
804            .collect())
805    }
806
807    async fn list_user_invitations(&self, email: &str) -> AuthResult<Vec<I>> {
808        let invitations = self.invitations.lock().unwrap();
809        let now = Utc::now();
810        Ok(invitations
811            .values()
812            .filter(|i| {
813                i.email().to_lowercase() == email.to_lowercase()
814                    && *i.status() == InvitationStatus::Pending
815                    && i.expires_at() > now
816            })
817            .cloned()
818            .collect())
819    }
820}
821
822// -- TwoFactorOps --
823
824#[async_trait]
825impl<U, S, A, O, M, I, V, P> TwoFactorOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V, P>
826where
827    U: MemoryUser,
828    S: MemorySession,
829    A: MemoryAccount,
830    O: MemoryOrganization,
831    M: MemoryMember,
832    I: MemoryInvitation,
833    V: MemoryVerification,
834    P: MemoryPasskey,
835{
836    type TwoFactor = TwoFactor;
837
838    async fn create_two_factor(&self, create: CreateTwoFactor) -> AuthResult<TwoFactor> {
839        let mut two_factors = self.two_factors.lock().unwrap();
840
841        // Check if user already has 2FA
842        if two_factors.values().any(|tf| tf.user_id == create.user_id) {
843            return Err(AuthError::conflict(
844                "Two-factor authentication already enabled for this user",
845            ));
846        }
847
848        let id = Uuid::new_v4().to_string();
849        let now = Utc::now();
850        let two_factor: TwoFactor = MemoryTwoFactor::from_create(id.clone(), &create, now);
851
852        two_factors.insert(id, two_factor.clone());
853        Ok(two_factor)
854    }
855
856    async fn get_two_factor_by_user_id(&self, user_id: &str) -> AuthResult<Option<TwoFactor>> {
857        let two_factors = self.two_factors.lock().unwrap();
858        Ok(two_factors
859            .values()
860            .find(|tf| tf.user_id == user_id)
861            .cloned())
862    }
863
864    async fn update_two_factor_backup_codes(
865        &self,
866        user_id: &str,
867        backup_codes: &str,
868    ) -> AuthResult<TwoFactor> {
869        let mut two_factors = self.two_factors.lock().unwrap();
870        let two_factor = two_factors
871            .values_mut()
872            .find(|tf| tf.user_id == user_id)
873            .ok_or_else(|| AuthError::not_found("Two-factor record not found"))?;
874        two_factor.set_backup_codes(backup_codes.to_string());
875        Ok(two_factor.clone())
876    }
877
878    async fn delete_two_factor(&self, user_id: &str) -> AuthResult<()> {
879        let mut two_factors = self.two_factors.lock().unwrap();
880        two_factors.retain(|_, tf| tf.user_id != user_id);
881        Ok(())
882    }
883}
884
885// -- ApiKeyOps --
886
887#[async_trait]
888impl<U, S, A, O, M, I, V, P> ApiKeyOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V, P>
889where
890    U: MemoryUser,
891    S: MemorySession,
892    A: MemoryAccount,
893    O: MemoryOrganization,
894    M: MemoryMember,
895    I: MemoryInvitation,
896    V: MemoryVerification,
897    P: MemoryPasskey,
898{
899    type ApiKey = ApiKey;
900
901    async fn create_api_key(&self, input: CreateApiKey) -> AuthResult<ApiKey> {
902        let mut api_keys = self.api_keys.lock().unwrap();
903
904        if api_keys.values().any(|k| k.key_hash == input.key_hash) {
905            return Err(AuthError::conflict("API key already exists"));
906        }
907
908        let id = Uuid::new_v4().to_string();
909        let now = Utc::now();
910        let api_key: ApiKey = MemoryApiKey::from_create(id.clone(), &input, now);
911
912        api_keys.insert(id, api_key.clone());
913        Ok(api_key)
914    }
915
916    async fn get_api_key_by_id(&self, id: &str) -> AuthResult<Option<ApiKey>> {
917        let api_keys = self.api_keys.lock().unwrap();
918        Ok(api_keys.get(id).cloned())
919    }
920
921    async fn get_api_key_by_hash(&self, hash: &str) -> AuthResult<Option<ApiKey>> {
922        let api_keys = self.api_keys.lock().unwrap();
923        Ok(api_keys.values().find(|k| k.key_hash == hash).cloned())
924    }
925
926    async fn list_api_keys_by_user(&self, user_id: &str) -> AuthResult<Vec<ApiKey>> {
927        let api_keys = self.api_keys.lock().unwrap();
928        let mut keys: Vec<ApiKey> = api_keys
929            .values()
930            .filter(|k| k.user_id == user_id)
931            .cloned()
932            .collect();
933        keys.sort_by(|a, b| b.created_at.cmp(&a.created_at));
934        Ok(keys)
935    }
936
937    async fn update_api_key(&self, id: &str, update: UpdateApiKey) -> AuthResult<ApiKey> {
938        let mut api_keys = self.api_keys.lock().unwrap();
939        let api_key = api_keys
940            .get_mut(id)
941            .ok_or_else(|| AuthError::not_found("API key not found"))?;
942        api_key.apply_update(&update);
943        Ok(api_key.clone())
944    }
945
946    async fn delete_api_key(&self, id: &str) -> AuthResult<()> {
947        let mut api_keys = self.api_keys.lock().unwrap();
948        api_keys.remove(id);
949        Ok(())
950    }
951
952    async fn delete_expired_api_keys(&self) -> AuthResult<usize> {
953        let mut api_keys = self.api_keys.lock().unwrap();
954        let now = Utc::now();
955        let initial_count = api_keys.len();
956        api_keys.retain(|_, k| {
957            if let Some(expires_at) = &k.expires_at
958                && let Ok(exp) = chrono::DateTime::parse_from_rfc3339(expires_at)
959            {
960                return exp > now;
961            }
962            true // keep keys without expiration
963        });
964        Ok(initial_count - api_keys.len())
965    }
966}
967
968// -- PasskeyOps --
969
970#[async_trait]
971impl<U, S, A, O, M, I, V, P> PasskeyOps for MemoryDatabaseAdapter<U, S, A, O, M, I, V, P>
972where
973    U: MemoryUser,
974    S: MemorySession,
975    A: MemoryAccount,
976    O: MemoryOrganization,
977    M: MemoryMember,
978    I: MemoryInvitation,
979    V: MemoryVerification,
980    P: MemoryPasskey,
981{
982    type Passkey = P;
983
984    async fn create_passkey(&self, input: CreatePasskey) -> AuthResult<P> {
985        let mut credential_index = self.passkey_credential_index.lock().unwrap();
986        let mut passkeys = self.passkeys.lock().unwrap();
987
988        if credential_index.contains_key(&input.credential_id) {
989            return Err(AuthError::conflict(
990                "A passkey with this credential ID already exists",
991            ));
992        }
993
994        let id = Uuid::new_v4().to_string();
995        let now = Utc::now();
996        let passkey = P::from_create(id.clone(), &input, now);
997
998        credential_index.insert(input.credential_id.clone(), id.clone());
999        passkeys.insert(id, passkey.clone());
1000        Ok(passkey)
1001    }
1002
1003    async fn get_passkey_by_id(&self, id: &str) -> AuthResult<Option<P>> {
1004        let passkeys = self.passkeys.lock().unwrap();
1005        Ok(passkeys.get(id).cloned())
1006    }
1007
1008    async fn get_passkey_by_credential_id(&self, credential_id: &str) -> AuthResult<Option<P>> {
1009        let passkey_id = {
1010            let credential_index = self.passkey_credential_index.lock().unwrap();
1011            credential_index.get(credential_id).cloned()
1012        };
1013
1014        let passkeys = self.passkeys.lock().unwrap();
1015
1016        if let Some(id) = passkey_id {
1017            Ok(passkeys.get(&id).cloned())
1018        } else {
1019            Ok(None)
1020        }
1021    }
1022
1023    async fn list_passkeys_by_user(&self, user_id: &str) -> AuthResult<Vec<P>> {
1024        let passkeys = self.passkeys.lock().unwrap();
1025        let mut matched: Vec<P> = passkeys
1026            .values()
1027            .filter(|p| p.user_id() == user_id)
1028            .cloned()
1029            .collect();
1030        matched.sort_by_key(|p| std::cmp::Reverse(p.created_at()));
1031        Ok(matched)
1032    }
1033
1034    async fn update_passkey_counter(&self, id: &str, counter: u64) -> AuthResult<P> {
1035        let mut passkeys = self.passkeys.lock().unwrap();
1036        let passkey = passkeys
1037            .get_mut(id)
1038            .ok_or_else(|| AuthError::not_found("Passkey not found"))?;
1039        passkey.set_counter(counter);
1040        Ok(passkey.clone())
1041    }
1042
1043    async fn update_passkey_name(&self, id: &str, name: &str) -> AuthResult<P> {
1044        let mut passkeys = self.passkeys.lock().unwrap();
1045        let passkey = passkeys
1046            .get_mut(id)
1047            .ok_or_else(|| AuthError::not_found("Passkey not found"))?;
1048        passkey.set_name(name.to_string());
1049        Ok(passkey.clone())
1050    }
1051
1052    async fn delete_passkey(&self, id: &str) -> AuthResult<()> {
1053        let mut credential_index = self.passkey_credential_index.lock().unwrap();
1054        let mut passkeys = self.passkeys.lock().unwrap();
1055
1056        if let Some(passkey) = passkeys.remove(id) {
1057            credential_index.remove(passkey.credential_id());
1058        }
1059        Ok(())
1060    }
1061}