Skip to main content

authx_storage/memory/
mod.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4#[cfg(test)]
5mod tests;
6
7use async_trait::async_trait;
8use chrono::Utc;
9use uuid::Uuid;
10
11use authx_core::{
12    error::{AuthError, Result, StorageError},
13    models::{
14        ApiKey, AuditLog, CreateApiKey, CreateAuditLog, CreateCredential, CreateInvite, CreateOrg,
15        CreateSession, CreateUser, Credential, CredentialKind, Invite, Membership, OAuthAccount,
16        Organization, Role, Session, UpdateUser, UpsertOAuthAccount, User,
17    },
18};
19
20use crate::ports::{
21    ApiKeyRepository, AuditLogRepository, CredentialRepository, InviteRepository,
22    OAuthAccountRepository, OrgRepository, SessionRepository, UserRepository,
23};
24
/// Read-lock `$lock`, handing back the guard even if the lock is poisoned
/// (a writer panicked while holding it) instead of propagating the panic.
///
/// The poisoned case is reported through `tracing::error!`, with `$label`
/// naming the collection whose lock was affected.
macro_rules! rlock {
    ($lock:expr, $label:literal) => {
        $lock.read().unwrap_or_else(|poisoned| {
            tracing::error!(concat!(
                "memory store read-lock poisoned (",
                $label,
                ") — recovering"
            ));
            poisoned.into_inner()
        })
    };
}
41
/// Write-lock `$lock`, handing back the guard even if the lock is poisoned
/// (a writer panicked while holding it) instead of propagating the panic.
///
/// The poisoned case is reported through `tracing::error!`, with `$label`
/// naming the collection whose lock was affected.
macro_rules! wlock {
    ($lock:expr, $label:literal) => {
        $lock.write().unwrap_or_else(|poisoned| {
            tracing::error!(concat!(
                "memory store write-lock poisoned (",
                $label,
                ") — recovering"
            ));
            poisoned.into_inner()
        })
    };
}
58
59#[derive(Clone, Default)]
60pub struct MemoryStore {
61    users: Arc<RwLock<HashMap<Uuid, User>>>,
62    sessions: Arc<RwLock<HashMap<Uuid, Session>>>,
63    credentials: Arc<RwLock<Vec<Credential>>>,
64    audit_logs: Arc<RwLock<Vec<AuditLog>>>,
65    orgs: Arc<RwLock<HashMap<Uuid, Organization>>>,
66    roles: Arc<RwLock<HashMap<Uuid, Role>>>,
67    memberships: Arc<RwLock<Vec<Membership>>>,
68    api_keys: Arc<RwLock<Vec<ApiKey>>>,
69    oauth_accounts: Arc<RwLock<Vec<OAuthAccount>>>,
70    invites: Arc<RwLock<Vec<Invite>>>,
71}
72
73impl MemoryStore {
74    pub fn new() -> Self {
75        Self::default()
76    }
77}
78
79// ── UserRepository ────────────────────────────────────────────────────────────
80
81#[async_trait]
82impl UserRepository for MemoryStore {
83    async fn find_by_id(&self, id: Uuid) -> Result<Option<User>> {
84        Ok(rlock!(self.users, "users").get(&id).cloned())
85    }
86
87    async fn find_by_email(&self, email: &str) -> Result<Option<User>> {
88        Ok(rlock!(self.users, "users")
89            .values()
90            .find(|u| u.email == email)
91            .cloned())
92    }
93
94    async fn find_by_username(&self, username: &str) -> Result<Option<User>> {
95        Ok(rlock!(self.users, "users")
96            .values()
97            .find(|u| u.username.as_deref() == Some(username))
98            .cloned())
99    }
100
101    async fn list(&self, offset: u32, limit: u32) -> Result<Vec<User>> {
102        let users = rlock!(self.users, "users");
103        let mut sorted: Vec<User> = users.values().cloned().collect();
104        sorted.sort_by_key(|u| u.created_at);
105        Ok(sorted
106            .into_iter()
107            .skip(offset as usize)
108            .take(limit as usize)
109            .collect())
110    }
111
112    async fn create(&self, data: CreateUser) -> Result<User> {
113        let mut users = wlock!(self.users, "users");
114        if users.values().any(|u| u.email == data.email) {
115            return Err(AuthError::EmailTaken);
116        }
117        if let Some(ref uname) = data.username {
118            if users
119                .values()
120                .any(|u| u.username.as_deref() == Some(uname.as_str()))
121            {
122                return Err(AuthError::Storage(StorageError::Conflict(format!(
123                    "username '{}' already taken",
124                    uname
125                ))));
126            }
127        }
128        let user = User {
129            id: Uuid::new_v4(),
130            email: data.email,
131            email_verified: false,
132            username: data.username,
133            created_at: Utc::now(),
134            updated_at: Utc::now(),
135            metadata: data.metadata.unwrap_or(serde_json::Value::Null),
136        };
137        users.insert(user.id, user.clone());
138        Ok(user)
139    }
140
141    async fn update(&self, id: Uuid, data: UpdateUser) -> Result<User> {
142        let mut users = wlock!(self.users, "users");
143        let user = users.get_mut(&id).ok_or(AuthError::UserNotFound)?;
144        if let Some(email) = data.email {
145            user.email = email;
146        }
147        if let Some(verified) = data.email_verified {
148            user.email_verified = verified;
149        }
150        if let Some(uname) = data.username {
151            user.username = Some(uname);
152        }
153        if let Some(meta) = data.metadata {
154            user.metadata = meta;
155        }
156        user.updated_at = Utc::now();
157        Ok(user.clone())
158    }
159
160    async fn delete(&self, id: Uuid) -> Result<()> {
161        wlock!(self.users, "users")
162            .remove(&id)
163            .ok_or(AuthError::UserNotFound)?;
164        Ok(())
165    }
166}
167
168// ── SessionRepository ─────────────────────────────────────────────────────────
169
170#[async_trait]
171impl SessionRepository for MemoryStore {
172    async fn create(&self, data: CreateSession) -> Result<Session> {
173        let session = Session {
174            id: Uuid::new_v4(),
175            user_id: data.user_id,
176            token_hash: data.token_hash,
177            device_info: data.device_info,
178            ip_address: data.ip_address,
179            org_id: data.org_id,
180            expires_at: data.expires_at,
181            created_at: Utc::now(),
182        };
183        wlock!(self.sessions, "sessions").insert(session.id, session.clone());
184        Ok(session)
185    }
186
187    async fn find_by_token_hash(&self, hash: &str) -> Result<Option<Session>> {
188        Ok(rlock!(self.sessions, "sessions")
189            .values()
190            .find(|s| s.token_hash == hash && s.expires_at > Utc::now())
191            .cloned())
192    }
193
194    async fn find_by_user(&self, user_id: Uuid) -> Result<Vec<Session>> {
195        Ok(rlock!(self.sessions, "sessions")
196            .values()
197            .filter(|s| s.user_id == user_id)
198            .cloned()
199            .collect())
200    }
201
202    async fn invalidate(&self, session_id: Uuid) -> Result<()> {
203        wlock!(self.sessions, "sessions")
204            .remove(&session_id)
205            .ok_or(AuthError::Storage(StorageError::NotFound))?;
206        Ok(())
207    }
208
209    async fn invalidate_all_for_user(&self, user_id: Uuid) -> Result<()> {
210        wlock!(self.sessions, "sessions").retain(|_, s| s.user_id != user_id);
211        Ok(())
212    }
213
214    async fn set_org(&self, session_id: Uuid, org_id: Option<Uuid>) -> Result<Session> {
215        let mut sessions = wlock!(self.sessions, "sessions");
216        let session = sessions
217            .get_mut(&session_id)
218            .ok_or(AuthError::Storage(StorageError::NotFound))?;
219        session.org_id = org_id;
220        Ok(session.clone())
221    }
222}
223
224// ── CredentialRepository ──────────────────────────────────────────────────────
225
226#[async_trait]
227impl CredentialRepository for MemoryStore {
228    async fn create(&self, data: CreateCredential) -> Result<Credential> {
229        let cred = Credential {
230            id: Uuid::new_v4(),
231            user_id: data.user_id,
232            kind: data.kind,
233            credential_hash: data.credential_hash,
234            metadata: data.metadata.unwrap_or(serde_json::Value::Null),
235        };
236        wlock!(self.credentials, "credentials").push(cred.clone());
237        Ok(cred)
238    }
239
240    async fn find_password_hash(&self, user_id: Uuid) -> Result<Option<String>> {
241        Ok(rlock!(self.credentials, "credentials")
242            .iter()
243            .find(|c| c.user_id == user_id && c.kind == CredentialKind::Password)
244            .map(|c| c.credential_hash.clone()))
245    }
246
247    async fn find_by_user_and_kind(
248        &self,
249        user_id: Uuid,
250        kind: CredentialKind,
251    ) -> Result<Option<Credential>> {
252        Ok(rlock!(self.credentials, "credentials")
253            .iter()
254            .find(|c| c.user_id == user_id && c.kind == kind)
255            .cloned())
256    }
257
258    async fn delete_by_user_and_kind(&self, user_id: Uuid, kind: CredentialKind) -> Result<()> {
259        let mut creds = wlock!(self.credentials, "credentials");
260        let before = creds.len();
261        creds.retain(|c| !(c.user_id == user_id && c.kind == kind));
262        if creds.len() == before {
263            return Err(AuthError::Storage(StorageError::NotFound));
264        }
265        Ok(())
266    }
267}
268
269// ── OrgRepository ─────────────────────────────────────────────────────────────
270
271#[async_trait]
272impl OrgRepository for MemoryStore {
273    async fn create(&self, data: CreateOrg) -> Result<Organization> {
274        let mut orgs = wlock!(self.orgs, "orgs");
275        if orgs.values().any(|o| o.slug == data.slug) {
276            return Err(AuthError::Storage(StorageError::Conflict(format!(
277                "slug '{}' already taken",
278                data.slug
279            ))));
280        }
281        let org = Organization {
282            id: Uuid::new_v4(),
283            name: data.name,
284            slug: data.slug,
285            metadata: data.metadata.unwrap_or(serde_json::Value::Null),
286            created_at: Utc::now(),
287        };
288        orgs.insert(org.id, org.clone());
289        Ok(org)
290    }
291
292    async fn find_by_id(&self, id: Uuid) -> Result<Option<Organization>> {
293        Ok(rlock!(self.orgs, "orgs").get(&id).cloned())
294    }
295
296    async fn find_by_slug(&self, slug: &str) -> Result<Option<Organization>> {
297        Ok(rlock!(self.orgs, "orgs")
298            .values()
299            .find(|o| o.slug == slug)
300            .cloned())
301    }
302
303    async fn add_member(&self, org_id: Uuid, user_id: Uuid, role_id: Uuid) -> Result<Membership> {
304        let role = rlock!(self.roles, "roles")
305            .get(&role_id)
306            .cloned()
307            .ok_or(AuthError::Storage(StorageError::NotFound))?;
308        let membership = Membership {
309            id: Uuid::new_v4(),
310            user_id,
311            org_id,
312            role,
313            created_at: Utc::now(),
314        };
315        wlock!(self.memberships, "memberships").push(membership.clone());
316        Ok(membership)
317    }
318
319    async fn remove_member(&self, org_id: Uuid, user_id: Uuid) -> Result<()> {
320        let mut memberships = wlock!(self.memberships, "memberships");
321        let before = memberships.len();
322        memberships.retain(|m| !(m.org_id == org_id && m.user_id == user_id));
323        if memberships.len() == before {
324            return Err(AuthError::Storage(StorageError::NotFound));
325        }
326        Ok(())
327    }
328
329    async fn get_members(&self, org_id: Uuid) -> Result<Vec<Membership>> {
330        Ok(rlock!(self.memberships, "memberships")
331            .iter()
332            .filter(|m| m.org_id == org_id)
333            .cloned()
334            .collect())
335    }
336
337    async fn find_roles(&self, org_id: Uuid) -> Result<Vec<Role>> {
338        Ok(rlock!(self.roles, "roles")
339            .values()
340            .filter(|r| r.org_id == org_id)
341            .cloned()
342            .collect())
343    }
344
345    async fn create_role(
346        &self,
347        org_id: Uuid,
348        name: String,
349        permissions: Vec<String>,
350    ) -> Result<Role> {
351        let role = Role {
352            id: Uuid::new_v4(),
353            org_id,
354            name,
355            permissions,
356        };
357        wlock!(self.roles, "roles").insert(role.id, role.clone());
358        Ok(role)
359    }
360
361    async fn update_member_role(
362        &self,
363        org_id: Uuid,
364        user_id: Uuid,
365        role_id: Uuid,
366    ) -> Result<Membership> {
367        let role = rlock!(self.roles, "roles")
368            .get(&role_id)
369            .cloned()
370            .ok_or(AuthError::Storage(StorageError::NotFound))?;
371
372        let mut memberships = wlock!(self.memberships, "memberships");
373        let m = memberships
374            .iter_mut()
375            .find(|m| m.org_id == org_id && m.user_id == user_id)
376            .ok_or(AuthError::Storage(StorageError::NotFound))?;
377        m.role = role;
378        Ok(m.clone())
379    }
380}
381
382// ── AuditLogRepository ────────────────────────────────────────────────────────
383
384#[async_trait]
385impl AuditLogRepository for MemoryStore {
386    async fn append(&self, entry: CreateAuditLog) -> Result<AuditLog> {
387        let log = AuditLog {
388            id: Uuid::new_v4(),
389            user_id: entry.user_id,
390            org_id: entry.org_id,
391            action: entry.action,
392            resource_type: entry.resource_type,
393            resource_id: entry.resource_id,
394            ip_address: entry.ip_address,
395            metadata: entry.metadata.unwrap_or(serde_json::Value::Null),
396            created_at: Utc::now(),
397        };
398        wlock!(self.audit_logs, "audit_logs").push(log.clone());
399        Ok(log)
400    }
401
402    async fn find_by_user(&self, user_id: Uuid, limit: u32) -> Result<Vec<AuditLog>> {
403        Ok(rlock!(self.audit_logs, "audit_logs")
404            .iter()
405            .filter(|l| l.user_id == Some(user_id))
406            .take(limit as usize)
407            .cloned()
408            .collect())
409    }
410
411    async fn find_by_org(&self, org_id: Uuid, limit: u32) -> Result<Vec<AuditLog>> {
412        Ok(rlock!(self.audit_logs, "audit_logs")
413            .iter()
414            .filter(|l| l.org_id == Some(org_id))
415            .take(limit as usize)
416            .cloned()
417            .collect())
418    }
419}
420
421// ── ApiKeyRepository ──────────────────────────────────────────────────────────
422
423#[async_trait]
424impl ApiKeyRepository for MemoryStore {
425    async fn create(&self, data: CreateApiKey) -> Result<ApiKey> {
426        let key = ApiKey {
427            id: Uuid::new_v4(),
428            user_id: data.user_id,
429            org_id: data.org_id,
430            key_hash: data.key_hash,
431            prefix: data.prefix,
432            name: data.name,
433            scopes: data.scopes,
434            expires_at: data.expires_at,
435            last_used_at: None,
436        };
437        wlock!(self.api_keys, "api_keys").push(key.clone());
438        Ok(key)
439    }
440
441    async fn find_by_hash(&self, key_hash: &str) -> Result<Option<ApiKey>> {
442        Ok(rlock!(self.api_keys, "api_keys")
443            .iter()
444            .find(|k| k.key_hash == key_hash)
445            .cloned())
446    }
447
448    async fn find_by_user(&self, user_id: Uuid) -> Result<Vec<ApiKey>> {
449        Ok(rlock!(self.api_keys, "api_keys")
450            .iter()
451            .filter(|k| k.user_id == user_id)
452            .cloned()
453            .collect())
454    }
455
456    async fn revoke(&self, key_id: Uuid, user_id: Uuid) -> Result<()> {
457        let mut keys = wlock!(self.api_keys, "api_keys");
458        let before = keys.len();
459        keys.retain(|k| !(k.id == key_id && k.user_id == user_id));
460        if keys.len() == before {
461            return Err(AuthError::Storage(StorageError::NotFound));
462        }
463        Ok(())
464    }
465
466    async fn touch_last_used(&self, key_id: Uuid, at: chrono::DateTime<Utc>) -> Result<()> {
467        let mut keys = wlock!(self.api_keys, "api_keys");
468        if let Some(k) = keys.iter_mut().find(|k| k.id == key_id) {
469            k.last_used_at = Some(at);
470        }
471        Ok(())
472    }
473}
474
475// ── OAuthAccountRepository ────────────────────────────────────────────────────
476
477#[async_trait]
478impl OAuthAccountRepository for MemoryStore {
479    async fn upsert(&self, data: UpsertOAuthAccount) -> Result<OAuthAccount> {
480        let mut accounts = wlock!(self.oauth_accounts, "oauth_accounts");
481        if let Some(existing) = accounts
482            .iter_mut()
483            .find(|a| a.provider == data.provider && a.provider_user_id == data.provider_user_id)
484        {
485            existing.access_token_enc = data.access_token_enc;
486            existing.refresh_token_enc = data.refresh_token_enc;
487            existing.expires_at = data.expires_at;
488            return Ok(existing.clone());
489        }
490        let account = OAuthAccount {
491            id: Uuid::new_v4(),
492            user_id: data.user_id,
493            provider: data.provider,
494            provider_user_id: data.provider_user_id,
495            access_token_enc: data.access_token_enc,
496            refresh_token_enc: data.refresh_token_enc,
497            expires_at: data.expires_at,
498        };
499        accounts.push(account.clone());
500        Ok(account)
501    }
502
503    async fn find_by_provider(
504        &self,
505        provider: &str,
506        provider_user_id: &str,
507    ) -> Result<Option<OAuthAccount>> {
508        Ok(rlock!(self.oauth_accounts, "oauth_accounts")
509            .iter()
510            .find(|a| a.provider == provider && a.provider_user_id == provider_user_id)
511            .cloned())
512    }
513
514    async fn find_by_user(&self, user_id: Uuid) -> Result<Vec<OAuthAccount>> {
515        Ok(rlock!(self.oauth_accounts, "oauth_accounts")
516            .iter()
517            .filter(|a| a.user_id == user_id)
518            .cloned()
519            .collect())
520    }
521
522    async fn delete(&self, id: Uuid) -> Result<()> {
523        let mut accounts = wlock!(self.oauth_accounts, "oauth_accounts");
524        let before = accounts.len();
525        accounts.retain(|a| a.id != id);
526        if accounts.len() == before {
527            return Err(AuthError::Storage(StorageError::NotFound));
528        }
529        Ok(())
530    }
531}
532
533// ── InviteRepository ──────────────────────────────────────────────────────────
534
535#[async_trait]
536impl InviteRepository for MemoryStore {
537    async fn create(&self, data: CreateInvite) -> Result<Invite> {
538        let invite = Invite {
539            id: Uuid::new_v4(),
540            org_id: data.org_id,
541            email: data.email,
542            role_id: data.role_id,
543            token_hash: data.token_hash,
544            expires_at: data.expires_at,
545            accepted_at: None,
546        };
547        wlock!(self.invites, "invites").push(invite.clone());
548        Ok(invite)
549    }
550
551    async fn find_by_token_hash(&self, hash: &str) -> Result<Option<Invite>> {
552        Ok(rlock!(self.invites, "invites")
553            .iter()
554            .find(|i| i.token_hash == hash)
555            .cloned())
556    }
557
558    async fn accept(&self, invite_id: Uuid) -> Result<Invite> {
559        let mut invites = wlock!(self.invites, "invites");
560        let invite = invites
561            .iter_mut()
562            .find(|i| i.id == invite_id)
563            .ok_or(AuthError::Storage(StorageError::NotFound))?;
564        invite.accepted_at = Some(Utc::now());
565        Ok(invite.clone())
566    }
567
568    async fn delete_expired(&self) -> Result<u64> {
569        let mut invites = wlock!(self.invites, "invites");
570        let before = invites.len();
571        let now = Utc::now();
572        invites.retain(|i| i.accepted_at.is_some() || i.expires_at > now);
573        Ok((before - invites.len()) as u64)
574    }
575}