1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4#[cfg(test)]
5mod tests;
6
7use async_trait::async_trait;
8use chrono::Utc;
9use uuid::Uuid;
10
11use authx_core::{
12 error::{AuthError, Result, StorageError},
13 models::{
14 ApiKey, AuditLog, CreateApiKey, CreateAuditLog, CreateCredential, CreateInvite, CreateOrg,
15 CreateSession, CreateUser, Credential, CredentialKind, Invite, Membership, OAuthAccount,
16 Organization, Role, Session, UpdateUser, UpsertOAuthAccount, User,
17 },
18};
19
20use crate::ports::{
21 ApiKeyRepository, AuditLogRepository, CredentialRepository, InviteRepository,
22 OAuthAccountRepository, OrgRepository, SessionRepository, UserRepository,
23};
24
/// Acquires a read guard on `$lock`.
///
/// If a previous holder panicked and poisoned the lock, an error tagged with
/// `$label` is logged and the guard is recovered from the `PoisonError`
/// instead of propagating the panic.
macro_rules! rlock {
    ($lock:expr, $label:literal) => {
        $lock.read().unwrap_or_else(|poisoned| {
            tracing::error!(concat!(
                "memory store read-lock poisoned (",
                $label,
                ") — recovering"
            ));
            poisoned.into_inner()
        })
    };
}
41
/// Acquires a write guard on `$lock`.
///
/// If a previous holder panicked and poisoned the lock, an error tagged with
/// `$label` is logged and the guard is recovered from the `PoisonError`
/// instead of propagating the panic.
macro_rules! wlock {
    ($lock:expr, $label:literal) => {
        $lock.write().unwrap_or_else(|poisoned| {
            tracing::error!(concat!(
                "memory store write-lock poisoned (",
                $label,
                ") — recovering"
            ));
            poisoned.into_inner()
        })
    };
}
58
59#[derive(Clone, Default)]
60pub struct MemoryStore {
61 users: Arc<RwLock<HashMap<Uuid, User>>>,
62 sessions: Arc<RwLock<HashMap<Uuid, Session>>>,
63 credentials: Arc<RwLock<Vec<Credential>>>,
64 audit_logs: Arc<RwLock<Vec<AuditLog>>>,
65 orgs: Arc<RwLock<HashMap<Uuid, Organization>>>,
66 roles: Arc<RwLock<HashMap<Uuid, Role>>>,
67 memberships: Arc<RwLock<Vec<Membership>>>,
68 api_keys: Arc<RwLock<Vec<ApiKey>>>,
69 oauth_accounts: Arc<RwLock<Vec<OAuthAccount>>>,
70 invites: Arc<RwLock<Vec<Invite>>>,
71}
72
73impl MemoryStore {
74 pub fn new() -> Self {
75 Self::default()
76 }
77}
78
79#[async_trait]
82impl UserRepository for MemoryStore {
83 async fn find_by_id(&self, id: Uuid) -> Result<Option<User>> {
84 Ok(rlock!(self.users, "users").get(&id).cloned())
85 }
86
87 async fn find_by_email(&self, email: &str) -> Result<Option<User>> {
88 Ok(rlock!(self.users, "users")
89 .values()
90 .find(|u| u.email == email)
91 .cloned())
92 }
93
94 async fn find_by_username(&self, username: &str) -> Result<Option<User>> {
95 Ok(rlock!(self.users, "users")
96 .values()
97 .find(|u| u.username.as_deref() == Some(username))
98 .cloned())
99 }
100
101 async fn list(&self, offset: u32, limit: u32) -> Result<Vec<User>> {
102 let users = rlock!(self.users, "users");
103 let mut sorted: Vec<User> = users.values().cloned().collect();
104 sorted.sort_by_key(|u| u.created_at);
105 Ok(sorted
106 .into_iter()
107 .skip(offset as usize)
108 .take(limit as usize)
109 .collect())
110 }
111
112 async fn create(&self, data: CreateUser) -> Result<User> {
113 let mut users = wlock!(self.users, "users");
114 if users.values().any(|u| u.email == data.email) {
115 return Err(AuthError::EmailTaken);
116 }
117 if let Some(ref uname) = data.username {
118 if users
119 .values()
120 .any(|u| u.username.as_deref() == Some(uname.as_str()))
121 {
122 return Err(AuthError::Storage(StorageError::Conflict(format!(
123 "username '{}' already taken",
124 uname
125 ))));
126 }
127 }
128 let user = User {
129 id: Uuid::new_v4(),
130 email: data.email,
131 email_verified: false,
132 username: data.username,
133 created_at: Utc::now(),
134 updated_at: Utc::now(),
135 metadata: data.metadata.unwrap_or(serde_json::Value::Null),
136 };
137 users.insert(user.id, user.clone());
138 Ok(user)
139 }
140
141 async fn update(&self, id: Uuid, data: UpdateUser) -> Result<User> {
142 let mut users = wlock!(self.users, "users");
143 let user = users.get_mut(&id).ok_or(AuthError::UserNotFound)?;
144 if let Some(email) = data.email {
145 user.email = email;
146 }
147 if let Some(verified) = data.email_verified {
148 user.email_verified = verified;
149 }
150 if let Some(uname) = data.username {
151 user.username = Some(uname);
152 }
153 if let Some(meta) = data.metadata {
154 user.metadata = meta;
155 }
156 user.updated_at = Utc::now();
157 Ok(user.clone())
158 }
159
160 async fn delete(&self, id: Uuid) -> Result<()> {
161 wlock!(self.users, "users")
162 .remove(&id)
163 .ok_or(AuthError::UserNotFound)?;
164 Ok(())
165 }
166}
167
168#[async_trait]
171impl SessionRepository for MemoryStore {
172 async fn create(&self, data: CreateSession) -> Result<Session> {
173 let session = Session {
174 id: Uuid::new_v4(),
175 user_id: data.user_id,
176 token_hash: data.token_hash,
177 device_info: data.device_info,
178 ip_address: data.ip_address,
179 org_id: data.org_id,
180 expires_at: data.expires_at,
181 created_at: Utc::now(),
182 };
183 wlock!(self.sessions, "sessions").insert(session.id, session.clone());
184 Ok(session)
185 }
186
187 async fn find_by_token_hash(&self, hash: &str) -> Result<Option<Session>> {
188 Ok(rlock!(self.sessions, "sessions")
189 .values()
190 .find(|s| s.token_hash == hash && s.expires_at > Utc::now())
191 .cloned())
192 }
193
194 async fn find_by_user(&self, user_id: Uuid) -> Result<Vec<Session>> {
195 Ok(rlock!(self.sessions, "sessions")
196 .values()
197 .filter(|s| s.user_id == user_id)
198 .cloned()
199 .collect())
200 }
201
202 async fn invalidate(&self, session_id: Uuid) -> Result<()> {
203 wlock!(self.sessions, "sessions")
204 .remove(&session_id)
205 .ok_or(AuthError::Storage(StorageError::NotFound))?;
206 Ok(())
207 }
208
209 async fn invalidate_all_for_user(&self, user_id: Uuid) -> Result<()> {
210 wlock!(self.sessions, "sessions").retain(|_, s| s.user_id != user_id);
211 Ok(())
212 }
213
214 async fn set_org(&self, session_id: Uuid, org_id: Option<Uuid>) -> Result<Session> {
215 let mut sessions = wlock!(self.sessions, "sessions");
216 let session = sessions
217 .get_mut(&session_id)
218 .ok_or(AuthError::Storage(StorageError::NotFound))?;
219 session.org_id = org_id;
220 Ok(session.clone())
221 }
222}
223
224#[async_trait]
227impl CredentialRepository for MemoryStore {
228 async fn create(&self, data: CreateCredential) -> Result<Credential> {
229 let cred = Credential {
230 id: Uuid::new_v4(),
231 user_id: data.user_id,
232 kind: data.kind,
233 credential_hash: data.credential_hash,
234 metadata: data.metadata.unwrap_or(serde_json::Value::Null),
235 };
236 wlock!(self.credentials, "credentials").push(cred.clone());
237 Ok(cred)
238 }
239
240 async fn find_password_hash(&self, user_id: Uuid) -> Result<Option<String>> {
241 Ok(rlock!(self.credentials, "credentials")
242 .iter()
243 .find(|c| c.user_id == user_id && c.kind == CredentialKind::Password)
244 .map(|c| c.credential_hash.clone()))
245 }
246
247 async fn find_by_user_and_kind(
248 &self,
249 user_id: Uuid,
250 kind: CredentialKind,
251 ) -> Result<Option<Credential>> {
252 Ok(rlock!(self.credentials, "credentials")
253 .iter()
254 .find(|c| c.user_id == user_id && c.kind == kind)
255 .cloned())
256 }
257
258 async fn delete_by_user_and_kind(&self, user_id: Uuid, kind: CredentialKind) -> Result<()> {
259 let mut creds = wlock!(self.credentials, "credentials");
260 let before = creds.len();
261 creds.retain(|c| !(c.user_id == user_id && c.kind == kind));
262 if creds.len() == before {
263 return Err(AuthError::Storage(StorageError::NotFound));
264 }
265 Ok(())
266 }
267}
268
269#[async_trait]
272impl OrgRepository for MemoryStore {
273 async fn create(&self, data: CreateOrg) -> Result<Organization> {
274 let mut orgs = wlock!(self.orgs, "orgs");
275 if orgs.values().any(|o| o.slug == data.slug) {
276 return Err(AuthError::Storage(StorageError::Conflict(format!(
277 "slug '{}' already taken",
278 data.slug
279 ))));
280 }
281 let org = Organization {
282 id: Uuid::new_v4(),
283 name: data.name,
284 slug: data.slug,
285 metadata: data.metadata.unwrap_or(serde_json::Value::Null),
286 created_at: Utc::now(),
287 };
288 orgs.insert(org.id, org.clone());
289 Ok(org)
290 }
291
292 async fn find_by_id(&self, id: Uuid) -> Result<Option<Organization>> {
293 Ok(rlock!(self.orgs, "orgs").get(&id).cloned())
294 }
295
296 async fn find_by_slug(&self, slug: &str) -> Result<Option<Organization>> {
297 Ok(rlock!(self.orgs, "orgs")
298 .values()
299 .find(|o| o.slug == slug)
300 .cloned())
301 }
302
303 async fn add_member(&self, org_id: Uuid, user_id: Uuid, role_id: Uuid) -> Result<Membership> {
304 let role = rlock!(self.roles, "roles")
305 .get(&role_id)
306 .cloned()
307 .ok_or(AuthError::Storage(StorageError::NotFound))?;
308 let membership = Membership {
309 id: Uuid::new_v4(),
310 user_id,
311 org_id,
312 role,
313 created_at: Utc::now(),
314 };
315 wlock!(self.memberships, "memberships").push(membership.clone());
316 Ok(membership)
317 }
318
319 async fn remove_member(&self, org_id: Uuid, user_id: Uuid) -> Result<()> {
320 let mut memberships = wlock!(self.memberships, "memberships");
321 let before = memberships.len();
322 memberships.retain(|m| !(m.org_id == org_id && m.user_id == user_id));
323 if memberships.len() == before {
324 return Err(AuthError::Storage(StorageError::NotFound));
325 }
326 Ok(())
327 }
328
329 async fn get_members(&self, org_id: Uuid) -> Result<Vec<Membership>> {
330 Ok(rlock!(self.memberships, "memberships")
331 .iter()
332 .filter(|m| m.org_id == org_id)
333 .cloned()
334 .collect())
335 }
336
337 async fn find_roles(&self, org_id: Uuid) -> Result<Vec<Role>> {
338 Ok(rlock!(self.roles, "roles")
339 .values()
340 .filter(|r| r.org_id == org_id)
341 .cloned()
342 .collect())
343 }
344
345 async fn create_role(
346 &self,
347 org_id: Uuid,
348 name: String,
349 permissions: Vec<String>,
350 ) -> Result<Role> {
351 let role = Role {
352 id: Uuid::new_v4(),
353 org_id,
354 name,
355 permissions,
356 };
357 wlock!(self.roles, "roles").insert(role.id, role.clone());
358 Ok(role)
359 }
360
361 async fn update_member_role(
362 &self,
363 org_id: Uuid,
364 user_id: Uuid,
365 role_id: Uuid,
366 ) -> Result<Membership> {
367 let role = rlock!(self.roles, "roles")
368 .get(&role_id)
369 .cloned()
370 .ok_or(AuthError::Storage(StorageError::NotFound))?;
371
372 let mut memberships = wlock!(self.memberships, "memberships");
373 let m = memberships
374 .iter_mut()
375 .find(|m| m.org_id == org_id && m.user_id == user_id)
376 .ok_or(AuthError::Storage(StorageError::NotFound))?;
377 m.role = role;
378 Ok(m.clone())
379 }
380}
381
382#[async_trait]
385impl AuditLogRepository for MemoryStore {
386 async fn append(&self, entry: CreateAuditLog) -> Result<AuditLog> {
387 let log = AuditLog {
388 id: Uuid::new_v4(),
389 user_id: entry.user_id,
390 org_id: entry.org_id,
391 action: entry.action,
392 resource_type: entry.resource_type,
393 resource_id: entry.resource_id,
394 ip_address: entry.ip_address,
395 metadata: entry.metadata.unwrap_or(serde_json::Value::Null),
396 created_at: Utc::now(),
397 };
398 wlock!(self.audit_logs, "audit_logs").push(log.clone());
399 Ok(log)
400 }
401
402 async fn find_by_user(&self, user_id: Uuid, limit: u32) -> Result<Vec<AuditLog>> {
403 Ok(rlock!(self.audit_logs, "audit_logs")
404 .iter()
405 .filter(|l| l.user_id == Some(user_id))
406 .take(limit as usize)
407 .cloned()
408 .collect())
409 }
410
411 async fn find_by_org(&self, org_id: Uuid, limit: u32) -> Result<Vec<AuditLog>> {
412 Ok(rlock!(self.audit_logs, "audit_logs")
413 .iter()
414 .filter(|l| l.org_id == Some(org_id))
415 .take(limit as usize)
416 .cloned()
417 .collect())
418 }
419}
420
421#[async_trait]
424impl ApiKeyRepository for MemoryStore {
425 async fn create(&self, data: CreateApiKey) -> Result<ApiKey> {
426 let key = ApiKey {
427 id: Uuid::new_v4(),
428 user_id: data.user_id,
429 org_id: data.org_id,
430 key_hash: data.key_hash,
431 prefix: data.prefix,
432 name: data.name,
433 scopes: data.scopes,
434 expires_at: data.expires_at,
435 last_used_at: None,
436 };
437 wlock!(self.api_keys, "api_keys").push(key.clone());
438 Ok(key)
439 }
440
441 async fn find_by_hash(&self, key_hash: &str) -> Result<Option<ApiKey>> {
442 Ok(rlock!(self.api_keys, "api_keys")
443 .iter()
444 .find(|k| k.key_hash == key_hash)
445 .cloned())
446 }
447
448 async fn find_by_user(&self, user_id: Uuid) -> Result<Vec<ApiKey>> {
449 Ok(rlock!(self.api_keys, "api_keys")
450 .iter()
451 .filter(|k| k.user_id == user_id)
452 .cloned()
453 .collect())
454 }
455
456 async fn revoke(&self, key_id: Uuid, user_id: Uuid) -> Result<()> {
457 let mut keys = wlock!(self.api_keys, "api_keys");
458 let before = keys.len();
459 keys.retain(|k| !(k.id == key_id && k.user_id == user_id));
460 if keys.len() == before {
461 return Err(AuthError::Storage(StorageError::NotFound));
462 }
463 Ok(())
464 }
465
466 async fn touch_last_used(&self, key_id: Uuid, at: chrono::DateTime<Utc>) -> Result<()> {
467 let mut keys = wlock!(self.api_keys, "api_keys");
468 if let Some(k) = keys.iter_mut().find(|k| k.id == key_id) {
469 k.last_used_at = Some(at);
470 }
471 Ok(())
472 }
473}
474
475#[async_trait]
478impl OAuthAccountRepository for MemoryStore {
479 async fn upsert(&self, data: UpsertOAuthAccount) -> Result<OAuthAccount> {
480 let mut accounts = wlock!(self.oauth_accounts, "oauth_accounts");
481 if let Some(existing) = accounts
482 .iter_mut()
483 .find(|a| a.provider == data.provider && a.provider_user_id == data.provider_user_id)
484 {
485 existing.access_token_enc = data.access_token_enc;
486 existing.refresh_token_enc = data.refresh_token_enc;
487 existing.expires_at = data.expires_at;
488 return Ok(existing.clone());
489 }
490 let account = OAuthAccount {
491 id: Uuid::new_v4(),
492 user_id: data.user_id,
493 provider: data.provider,
494 provider_user_id: data.provider_user_id,
495 access_token_enc: data.access_token_enc,
496 refresh_token_enc: data.refresh_token_enc,
497 expires_at: data.expires_at,
498 };
499 accounts.push(account.clone());
500 Ok(account)
501 }
502
503 async fn find_by_provider(
504 &self,
505 provider: &str,
506 provider_user_id: &str,
507 ) -> Result<Option<OAuthAccount>> {
508 Ok(rlock!(self.oauth_accounts, "oauth_accounts")
509 .iter()
510 .find(|a| a.provider == provider && a.provider_user_id == provider_user_id)
511 .cloned())
512 }
513
514 async fn find_by_user(&self, user_id: Uuid) -> Result<Vec<OAuthAccount>> {
515 Ok(rlock!(self.oauth_accounts, "oauth_accounts")
516 .iter()
517 .filter(|a| a.user_id == user_id)
518 .cloned()
519 .collect())
520 }
521
522 async fn delete(&self, id: Uuid) -> Result<()> {
523 let mut accounts = wlock!(self.oauth_accounts, "oauth_accounts");
524 let before = accounts.len();
525 accounts.retain(|a| a.id != id);
526 if accounts.len() == before {
527 return Err(AuthError::Storage(StorageError::NotFound));
528 }
529 Ok(())
530 }
531}
532
533#[async_trait]
536impl InviteRepository for MemoryStore {
537 async fn create(&self, data: CreateInvite) -> Result<Invite> {
538 let invite = Invite {
539 id: Uuid::new_v4(),
540 org_id: data.org_id,
541 email: data.email,
542 role_id: data.role_id,
543 token_hash: data.token_hash,
544 expires_at: data.expires_at,
545 accepted_at: None,
546 };
547 wlock!(self.invites, "invites").push(invite.clone());
548 Ok(invite)
549 }
550
551 async fn find_by_token_hash(&self, hash: &str) -> Result<Option<Invite>> {
552 Ok(rlock!(self.invites, "invites")
553 .iter()
554 .find(|i| i.token_hash == hash)
555 .cloned())
556 }
557
558 async fn accept(&self, invite_id: Uuid) -> Result<Invite> {
559 let mut invites = wlock!(self.invites, "invites");
560 let invite = invites
561 .iter_mut()
562 .find(|i| i.id == invite_id)
563 .ok_or(AuthError::Storage(StorageError::NotFound))?;
564 invite.accepted_at = Some(Utc::now());
565 Ok(invite.clone())
566 }
567
568 async fn delete_expired(&self) -> Result<u64> {
569 let mut invites = wlock!(self.invites, "invites");
570 let before = invites.len();
571 let now = Utc::now();
572 invites.retain(|i| i.accepted_at.is_some() || i.expires_at > now);
573 Ok((before - invites.len()) as u64)
574 }
575}