auth_framework/storage/
memory.rs

//! In-memory storage backend for auth-framework.
//!
//! This module provides a fast in-memory storage implementation
//! suitable for development, testing, and single-instance deployments.
//!
//! Every collection is a `HashMap` behind a `std::sync::RwLock`, shared through
//! an `Arc` so that all clones of the storage see the same data. Nothing is
//! persisted to disk. Expired entries are hidden from reads and are purged
//! periodically by a background cleanup task.
10
11use crate::{
12    errors::Result,
13    storage::{AuthStorage, SessionData},
14    tokens::AuthToken,
15};
16use async_trait::async_trait;
17use std::{
18    collections::HashMap,
19    sync::{Arc, RwLock},
20    time::{Duration, Instant},
21};
22use tokio::time;
23
24/// In-memory storage backend with automatic cleanup
25#[derive(Clone)]
26pub struct InMemoryStorage {
27    tokens: Arc<RwLock<HashMap<String, TimestampedToken>>>,
28    access_tokens: Arc<RwLock<HashMap<String, String>>>, // access_token -> token_id
29    user_tokens: Arc<RwLock<HashMap<String, Vec<String>>>>, // user_id -> token_ids
30    sessions: Arc<RwLock<HashMap<String, TimestampedSession>>>,
31    kv_store: Arc<RwLock<HashMap<String, TimestampedValue>>>,
32    cleanup_interval: Duration,
33    default_ttl: Duration,
34}
35
36#[derive(Clone)]
37struct TimestampedToken {
38    token: AuthToken,
39    expires_at: Instant,
40}
41
42#[derive(Clone)]
43struct TimestampedSession {
44    session: SessionData,
45    expires_at: Instant,
46}
47
/// Raw key/value payload paired with the storage-side deadline after which it
/// is treated as absent.
#[derive(Clone)]
struct TimestampedValue {
    /// The stored bytes.
    value: Vec<u8>,
    /// Reads return `None` and the cleanup sweep purges the entry once this
    /// instant has been reached.
    expires_at: Instant,
}
53
54impl InMemoryStorage {
55    /// Create a new in-memory storage backend
56    pub fn new() -> Self {
57        let storage = Self {
58            tokens: Arc::new(RwLock::new(HashMap::new())),
59            access_tokens: Arc::new(RwLock::new(HashMap::new())),
60            user_tokens: Arc::new(RwLock::new(HashMap::new())),
61            sessions: Arc::new(RwLock::new(HashMap::new())),
62            kv_store: Arc::new(RwLock::new(HashMap::new())),
63            cleanup_interval: Duration::from_secs(300), // 5 minutes
64            default_ttl: Duration::from_secs(3600),     // 1 hour
65        };
66
67        // Start background cleanup task
68        storage.start_cleanup_task();
69        storage
70    }
71
72    /// Create in-memory storage with custom configuration
73    pub fn with_config(cleanup_interval: Duration, default_ttl: Duration) -> Self {
74        let mut storage = Self::new();
75        storage.cleanup_interval = cleanup_interval;
76        storage.default_ttl = default_ttl;
77        storage
78    }
79
80    fn start_cleanup_task(&self) {
81        let tokens = self.tokens.clone();
82        let access_tokens = self.access_tokens.clone();
83        let user_tokens = self.user_tokens.clone();
84        let sessions = self.sessions.clone();
85        let kv_store = self.kv_store.clone();
86        let interval = self.cleanup_interval;
87
88        tokio::spawn(async move {
89            let mut cleanup_timer = time::interval(interval);
90
91            loop {
92                cleanup_timer.tick().await;
93                Self::cleanup_expired_data(
94                    &tokens,
95                    &access_tokens,
96                    &user_tokens,
97                    &sessions,
98                    &kv_store,
99                );
100            }
101        });
102    }
103
104    fn cleanup_expired_data(
105        tokens: &Arc<RwLock<HashMap<String, TimestampedToken>>>,
106        access_tokens: &Arc<RwLock<HashMap<String, String>>>,
107        user_tokens: &Arc<RwLock<HashMap<String, Vec<String>>>>,
108        sessions: &Arc<RwLock<HashMap<String, TimestampedSession>>>,
109        kv_store: &Arc<RwLock<HashMap<String, TimestampedValue>>>,
110    ) {
111        let now = Instant::now();
112
113        // Clean up expired tokens
114        {
115            let mut tokens_guard = tokens.write().unwrap();
116            let mut access_tokens_guard = access_tokens.write().unwrap();
117            let mut user_tokens_guard = user_tokens.write().unwrap();
118
119            let expired_tokens: Vec<String> = tokens_guard
120                .iter()
121                .filter(|(_, timestamped)| timestamped.expires_at <= now)
122                .map(|(id, _)| id.clone())
123                .collect();
124
125            for token_id in expired_tokens {
126                if let Some(timestamped) = tokens_guard.remove(&token_id) {
127                    // Remove access token lookup
128                    access_tokens_guard.remove(&timestamped.token.access_token);
129
130                    // Remove from user tokens
131                    if let Some(user_token_list) =
132                        user_tokens_guard.get_mut(&timestamped.token.user_id)
133                    {
134                        user_token_list.retain(|id| id != &token_id);
135                        if user_token_list.is_empty() {
136                            user_tokens_guard.remove(&timestamped.token.user_id);
137                        }
138                    }
139                }
140            }
141        }
142
143        // Clean up expired sessions
144        {
145            let mut sessions_guard = sessions.write().unwrap();
146            sessions_guard.retain(|_, timestamped| timestamped.expires_at > now);
147        }
148
149        // Clean up expired KV pairs
150        {
151            let mut kv_guard = kv_store.write().unwrap();
152            kv_guard.retain(|_, timestamped| timestamped.expires_at > now);
153        }
154    }
155
156    fn calculate_expiry(&self, token: &AuthToken) -> Instant {
157        let now = chrono::Utc::now();
158        if token.expires_at > now {
159            let duration = (token.expires_at - now).num_seconds() as u64;
160            Instant::now() + Duration::from_secs(duration)
161        } else {
162            Instant::now() + self.default_ttl
163        }
164    }
165}
166
167impl Default for InMemoryStorage {
168    fn default() -> Self {
169        Self::new()
170    }
171}
172
173#[async_trait]
174impl AuthStorage for InMemoryStorage {
175    async fn store_token(&self, token: &AuthToken) -> Result<()> {
176        let expires_at = self.calculate_expiry(token);
177        let timestamped_token = TimestampedToken {
178            token: token.clone(),
179            expires_at,
180        };
181
182        {
183            let mut tokens = self.tokens.write().unwrap();
184            tokens.insert(token.token_id.clone(), timestamped_token);
185        }
186
187        {
188            let mut access_tokens = self.access_tokens.write().unwrap();
189            access_tokens.insert(token.access_token.clone(), token.token_id.clone());
190        }
191
192        {
193            let mut user_tokens = self.user_tokens.write().unwrap();
194            user_tokens
195                .entry(token.user_id.clone())
196                .or_default()
197                .push(token.token_id.clone());
198        }
199
200        Ok(())
201    }
202
203    async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>> {
204        let tokens = self.tokens.read().unwrap();
205        if let Some(timestamped) = tokens.get(token_id) {
206            if timestamped.expires_at > Instant::now() {
207                Ok(Some(timestamped.token.clone()))
208            } else {
209                Ok(None)
210            }
211        } else {
212            Ok(None)
213        }
214    }
215
216    async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>> {
217        let token_id_opt = {
218            let access_tokens = self.access_tokens.read().unwrap();
219            access_tokens.get(access_token).cloned()
220        };
221        if let Some(token_id) = token_id_opt {
222            self.get_token(&token_id).await
223        } else {
224            Ok(None)
225        }
226    }
227
228    async fn update_token(&self, token: &AuthToken) -> Result<()> {
229        // For in-memory storage, update is the same as store
230        self.store_token(token).await
231    }
232
233    async fn delete_token(&self, token_id: &str) -> Result<()> {
234        let removed_token = {
235            let mut tokens = self.tokens.write().unwrap();
236            tokens.remove(token_id)
237        };
238
239        if let Some(timestamped) = removed_token {
240            // Remove access token lookup
241            {
242                let mut access_tokens = self.access_tokens.write().unwrap();
243                access_tokens.remove(&timestamped.token.access_token);
244            }
245
246            // Remove from user tokens
247            {
248                let mut user_tokens = self.user_tokens.write().unwrap();
249                if let Some(user_token_list) = user_tokens.get_mut(&timestamped.token.user_id) {
250                    user_token_list.retain(|id| id != token_id);
251                    if user_token_list.is_empty() {
252                        user_tokens.remove(&timestamped.token.user_id);
253                    }
254                }
255            }
256        }
257
258        Ok(())
259    }
260
261    async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>> {
262        let user_tokens = self.user_tokens.read().unwrap();
263        let tokens = self.tokens.read().unwrap();
264        let now = Instant::now();
265
266        match user_tokens.get(user_id) {
267            Some(token_ids) => {
268                let mut result = Vec::new();
269                for token_id in token_ids {
270                    if let Some(timestamped) = tokens.get(token_id)
271                        && timestamped.expires_at > now
272                    {
273                        result.push(timestamped.token.clone());
274                    }
275                }
276                Ok(result)
277            }
278            None => Ok(Vec::new()),
279        }
280    }
281
282    async fn store_session(&self, session_id: &str, data: &SessionData) -> Result<()> {
283        let expires_at = Instant::now() + self.default_ttl;
284        let timestamped_session = TimestampedSession {
285            session: data.clone(),
286            expires_at,
287        };
288
289        let mut sessions = self.sessions.write().unwrap();
290        sessions.insert(session_id.to_string(), timestamped_session);
291        Ok(())
292    }
293
294    async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
295        let sessions = self.sessions.read().unwrap();
296        if let Some(timestamped) = sessions.get(session_id) {
297            if timestamped.expires_at > Instant::now() {
298                Ok(Some(timestamped.session.clone()))
299            } else {
300                Ok(None)
301            }
302        } else {
303            Ok(None)
304        }
305    }
306
307    async fn delete_session(&self, session_id: &str) -> Result<()> {
308        let mut sessions = self.sessions.write().unwrap();
309        sessions.remove(session_id);
310        Ok(())
311    }
312
313    async fn list_user_sessions(&self, user_id: &str) -> Result<Vec<SessionData>> {
314        let sessions = self.sessions.read().unwrap();
315        let now = Instant::now();
316
317        let user_sessions: Vec<SessionData> = sessions
318            .values()
319            .filter_map(|timestamped| {
320                if timestamped.session.user_id == user_id && timestamped.expires_at > now {
321                    Some(timestamped.session.clone())
322                } else {
323                    None
324                }
325            })
326            .collect();
327
328        Ok(user_sessions)
329    }
330
331    async fn store_kv(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()> {
332        let expires_at = Instant::now() + ttl.unwrap_or(self.default_ttl);
333        let timestamped_value = TimestampedValue {
334            value: value.to_vec(),
335            expires_at,
336        };
337
338        let mut kv_store = self.kv_store.write().unwrap();
339        kv_store.insert(key.to_string(), timestamped_value);
340        Ok(())
341    }
342
343    async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>> {
344        let kv_store = self.kv_store.read().unwrap();
345        if let Some(timestamped) = kv_store.get(key) {
346            if timestamped.expires_at > Instant::now() {
347                Ok(Some(timestamped.value.clone()))
348            } else {
349                Ok(None)
350            }
351        } else {
352            Ok(None)
353        }
354    }
355
356    async fn delete_kv(&self, key: &str) -> Result<()> {
357        let mut kv_store = self.kv_store.write().unwrap();
358        kv_store.remove(key);
359        Ok(())
360    }
361
362    async fn cleanup_expired(&self) -> Result<()> {
363        Self::cleanup_expired_data(
364            &self.tokens,
365            &self.access_tokens,
366            &self.user_tokens,
367            &self.sessions,
368            &self.kv_store,
369        );
370        Ok(())
371    }
372
373    async fn count_active_sessions(&self) -> Result<u64> {
374        let sessions = self.sessions.read().unwrap();
375        let now = Instant::now();
376
377        let active_count = sessions
378            .values()
379            .filter(|timestamped| timestamped.expires_at > now)
380            .count() as u64;
381
382        Ok(active_count)
383    }
384}
385
/// Configuration for in-memory storage.
///
/// Build a storage instance from it with `InMemoryConfig::build`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InMemoryConfig {
    /// Period of the background sweep that purges expired entries.
    pub cleanup_interval: Duration,
    /// Retention used when no more specific expiry is available.
    pub default_ttl: Duration,
}
391
392impl Default for InMemoryConfig {
393    fn default() -> Self {
394        Self {
395            cleanup_interval: Duration::from_secs(300), // 5 minutes
396            default_ttl: Duration::from_secs(3600),     // 1 hour
397        }
398    }
399}
400
401impl InMemoryConfig {
402    pub fn new() -> Self {
403        Self::default()
404    }
405
406    pub fn with_cleanup_interval(mut self, interval: Duration) -> Self {
407        self.cleanup_interval = interval;
408        self
409    }
410
411    pub fn with_default_ttl(mut self, ttl: Duration) -> Self {
412        self.default_ttl = ttl;
413        self
414    }
415
416    pub fn build(self) -> InMemoryStorage {
417        InMemoryStorage::with_config(self.cleanup_interval, self.default_ttl)
418    }
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::helpers::create_test_token;

    #[tokio::test]
    async fn test_in_memory_token_operations() {
        let storage = InMemoryStorage::new();
        let token = create_test_token("test_user");

        storage.store_token(&token).await.unwrap();

        // Lookup by primary id returns the stored token.
        let found = storage.get_token(&token.token_id).await.unwrap();
        assert_eq!(found.map(|t| t.token_id), Some(token.token_id.clone()));

        // Lookup through the access-token index returns the same token.
        let found = storage
            .get_token_by_access_token(&token.access_token)
            .await
            .unwrap();
        assert_eq!(
            found.map(|t| t.access_token),
            Some(token.access_token.clone())
        );

        // The per-user index lists exactly this token.
        let listed = storage.list_user_tokens(&token.user_id).await.unwrap();
        assert_eq!(listed.len(), 1);

        // Once deleted, the token can no longer be fetched.
        storage.delete_token(&token.token_id).await.unwrap();
        assert!(storage.get_token(&token.token_id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_in_memory_expiration() {
        // Sweep every 100ms, with a 200ms default TTL.
        let storage =
            InMemoryStorage::with_config(Duration::from_millis(100), Duration::from_millis(200));
        let (key, value) = ("test_key", b"test_value");

        // An explicit 50ms TTL overrides the default.
        storage
            .store_kv(key, value, Some(Duration::from_millis(50)))
            .await
            .unwrap();
        assert!(storage.get_kv(key).await.unwrap().is_some());

        // Past the TTL, the value must be gone.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(storage.get_kv(key).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_in_memory_session_operations() {
        let storage = InMemoryStorage::new();
        let session_id = "test_session";
        let now = chrono::Utc::now();

        let session_data = SessionData {
            session_id: session_id.to_string(),
            user_id: "test_user".to_string(),
            created_at: now,
            expires_at: now + chrono::Duration::hours(1),
            last_activity: now,
            ip_address: None,
            user_agent: None,
            data: [("key".to_string(), serde_json::json!("value"))]
                .into_iter()
                .collect(),
        };

        storage
            .store_session(session_id, &session_data)
            .await
            .unwrap();

        // The stored session is returned with its owner intact.
        let found = storage.get_session(session_id).await.unwrap();
        assert_eq!(found.map(|s| s.user_id), Some(session_data.user_id.clone()));

        // Once deleted, the session can no longer be fetched.
        storage.delete_session(session_id).await.unwrap();
        assert!(storage.get_session(session_id).await.unwrap().is_none());
    }
}
520
521