auth_framework/
storage.rs

1//! Storage backends for authentication data.
2
3use crate::errors::{Result, StorageError};
4use crate::tokens::AuthToken;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9use std::time::Duration;
10
11/// Trait for authentication data storage.
12#[async_trait]
13pub trait AuthStorage: Send + Sync {
14    /// Store a token.
15    async fn store_token(&self, token: &AuthToken) -> Result<()>;
16    
17    /// Retrieve a token by ID.
18    async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>>;
19    
20    /// Retrieve a token by access token string.
21    async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>>;
22    
23    /// Update a token.
24    async fn update_token(&self, token: &AuthToken) -> Result<()>;
25    
26    /// Delete a token.
27    async fn delete_token(&self, token_id: &str) -> Result<()>;
28    
29    /// List all tokens for a user.
30    async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>>;
31    
32    /// Store session data.
33    async fn store_session(&self, session_id: &str, data: &SessionData) -> Result<()>;
34    
35    /// Retrieve session data.
36    async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>>;
37    
38    /// Delete session data.
39    async fn delete_session(&self, session_id: &str) -> Result<()>;
40    
41    /// Store arbitrary key-value data with expiration.
42    async fn store_kv(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()>;
43    
44    /// Retrieve arbitrary key-value data.
45    async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>>;
46    
47    /// Delete arbitrary key-value data.
48    async fn delete_kv(&self, key: &str) -> Result<()>;
49    
50    /// Clean up expired data.
51    async fn cleanup_expired(&self) -> Result<()>;
52}
53
54/// Session data stored in the backend.
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct SessionData {
57    /// Session ID
58    pub session_id: String,
59    
60    /// User ID associated with this session
61    pub user_id: String,
62    
63    /// When the session was created
64    pub created_at: chrono::DateTime<chrono::Utc>,
65    
66    /// When the session expires
67    pub expires_at: chrono::DateTime<chrono::Utc>,
68    
69    /// Last activity timestamp
70    pub last_activity: chrono::DateTime<chrono::Utc>,
71    
72    /// IP address of the session
73    pub ip_address: Option<String>,
74    
75    /// User agent
76    pub user_agent: Option<String>,
77    
78    /// Custom session data
79    pub data: HashMap<String, serde_json::Value>,
80}
81
82/// KV store value type: (data, optional_expiry)
83type KvValue = (Vec<u8>, Option<chrono::DateTime<chrono::Utc>>);
84
85/// In-memory storage implementation (for development/testing).
86#[derive(Debug, Clone)]
87pub struct MemoryStorage {
88    tokens: Arc<RwLock<HashMap<String, AuthToken>>>,
89    sessions: Arc<RwLock<HashMap<String, SessionData>>>,
90    kv_store: Arc<RwLock<HashMap<String, KvValue>>>,
91}
92
/// [`AuthStorage`] backend that keeps all data in Redis.
///
/// Only compiled with the `redis-storage` feature. Every key it touches is namespaced with
/// `key_prefix`.
#[cfg(feature = "redis-storage")]
#[derive(Debug, Clone)]
pub struct RedisStorage {
    /// Client used to open a connection per operation.
    client: redis::Client,
    /// String prepended verbatim to every Redis key.
    key_prefix: String,
}
100
101impl Default for MemoryStorage {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
107impl MemoryStorage {
108    /// Create a new in-memory storage.
109    pub fn new() -> Self {
110        Self {
111            tokens: Arc::new(RwLock::new(HashMap::new())),
112            sessions: Arc::new(RwLock::new(HashMap::new())),
113            kv_store: Arc::new(RwLock::new(HashMap::new())),
114        }
115    }
116}
117
118#[async_trait]
119impl AuthStorage for MemoryStorage {
120    async fn store_token(&self, token: &AuthToken) -> Result<()> {
121        let mut tokens = self.tokens.write().unwrap();
122        tokens.insert(token.token_id.clone(), token.clone());
123        Ok(())
124    }
125
126    async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>> {
127        let tokens = self.tokens.read().unwrap();
128        Ok(tokens.get(token_id).cloned())
129    }
130
131    async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>> {
132        let tokens = self.tokens.read().unwrap();
133        Ok(tokens.values()
134            .find(|token| token.access_token == access_token)
135            .cloned())
136    }
137
138    async fn update_token(&self, token: &AuthToken) -> Result<()> {
139        let mut tokens = self.tokens.write().unwrap();
140        tokens.insert(token.token_id.clone(), token.clone());
141        Ok(())
142    }
143
144    async fn delete_token(&self, token_id: &str) -> Result<()> {
145        let mut tokens = self.tokens.write().unwrap();
146        tokens.remove(token_id);
147        Ok(())
148    }
149
150    async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>> {
151        let tokens = self.tokens.read().unwrap();
152        Ok(tokens.values()
153            .filter(|token| token.user_id == user_id)
154            .cloned()
155            .collect())
156    }
157
158    async fn store_session(&self, session_id: &str, data: &SessionData) -> Result<()> {
159        let mut sessions = self.sessions.write().unwrap();
160        sessions.insert(session_id.to_string(), data.clone());
161        Ok(())
162    }
163
164    async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
165        let sessions = self.sessions.read().unwrap();
166        let session = sessions.get(session_id).cloned();
167        
168        // Check if session is expired
169        if let Some(ref session) = session {
170            if chrono::Utc::now() > session.expires_at {
171                return Ok(None);
172            }
173        }
174        
175        Ok(session)
176    }
177
178    async fn delete_session(&self, session_id: &str) -> Result<()> {
179        let mut sessions = self.sessions.write().unwrap();
180        sessions.remove(session_id);
181        Ok(())
182    }
183
184    async fn store_kv(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()> {
185        let mut kv_store = self.kv_store.write().unwrap();
186        let expires_at = ttl.map(|ttl| {
187            chrono::Utc::now() + chrono::Duration::from_std(ttl).unwrap()
188        });
189        kv_store.insert(key.to_string(), (value.to_vec(), expires_at));
190        Ok(())
191    }
192
193    async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>> {
194        let kv_store = self.kv_store.read().unwrap();
195        if let Some((value, expires_at)) = kv_store.get(key) {
196            // Check if expired
197            if let Some(expires_at) = expires_at {
198                if chrono::Utc::now() > *expires_at {
199                    return Ok(None);
200                }
201            }
202            Ok(Some(value.clone()))
203        } else {
204            Ok(None)
205        }
206    }
207
208    async fn delete_kv(&self, key: &str) -> Result<()> {
209        let mut kv_store = self.kv_store.write().unwrap();
210        kv_store.remove(key);
211        Ok(())
212    }
213
214    async fn cleanup_expired(&self) -> Result<()> {
215        let now = chrono::Utc::now();
216        
217        // Clean up expired tokens
218        {
219            let mut tokens = self.tokens.write().unwrap();
220            tokens.retain(|_, token| !token.is_expired());
221        }
222        
223        // Clean up expired sessions
224        {
225            let mut sessions = self.sessions.write().unwrap();
226            sessions.retain(|_, session| now <= session.expires_at);
227        }
228        
229        // Clean up expired KV pairs
230        {
231            let mut kv_store = self.kv_store.write().unwrap();
232            kv_store.retain(|_, (_, expires_at)| {
233                expires_at.is_none_or(|exp| now <= exp)
234            });
235        }
236        
237        Ok(())
238    }
239}
240
#[cfg(feature = "redis-storage")]
impl RedisStorage {
    /// Build a Redis-backed store.
    ///
    /// `redis_url` is handed to `redis::Client::open`; `key_prefix` is prepended verbatim to
    /// every key this store reads or writes.
    ///
    /// # Errors
    ///
    /// Returns a connection-failed storage error when `redis::Client::open` rejects the URL.
    pub fn new(redis_url: &str, key_prefix: impl Into<String>) -> Result<Self> {
        match redis::Client::open(redis_url) {
            Ok(client) => Ok(Self {
                client,
                key_prefix: key_prefix.into(),
            }),
            Err(e) => Err(
                StorageError::connection_failed(format!("Redis connection failed: {e}")).into(),
            ),
        }
    }

    /// Open a fresh async connection to Redis.
    ///
    /// # Errors
    ///
    /// Returns a connection-failed storage error if the connection cannot be established.
    async fn get_connection(&self) -> Result<redis::aio::Connection> {
        let attempt = self.client.get_async_connection().await;
        attempt.map_err(|e| {
            StorageError::connection_failed(format!("Failed to get Redis connection: {e}")).into()
        })
    }

    /// Return `suffix` namespaced with this store's key prefix.
    fn key(&self, suffix: &str) -> String {
        [self.key_prefix.as_str(), suffix].concat()
    }
}
265
#[cfg(feature = "redis-storage")]
#[async_trait]
impl AuthStorage for RedisStorage {
    /// Store `token` together with its lookup indexes.
    ///
    /// Three keys are written, all under the configured prefix:
    /// - `token:{token_id}`: the token as JSON, expiring with the token;
    /// - `access_token:{access_token}`: the token ID, expiring with the token;
    /// - `user_tokens:{user_id}`: a set of token IDs (no expiry).
    ///
    /// The writes run in one MULTI/EXEC transaction. A failure part-way through therefore cannot
    /// leave a stored token missing from its user's set, where `list_user_tokens` would miss it.
    async fn store_token(&self, token: &AuthToken) -> Result<()> {
        let mut conn = self.get_connection().await?;
        let token_json = serde_json::to_string(token)
            .map_err(|e| StorageError::serialization(format!("Token serialization failed: {e}")))?;

        let token_key = self.key(&format!("token:{}", token.token_id));
        let access_token_key = self.key(&format!("access_token:{}", token.access_token));
        let user_tokens_key = self.key(&format!("user_tokens:{}", token.user_id));

        // SETEX rejects a zero TTL, so clamp the remaining lifetime to at least one second.
        let ttl = token.time_until_expiry().as_secs().max(1);

        redis::pipe()
            .atomic()
            .cmd("SETEX")
            .arg(&token_key)
            .arg(ttl)
            .arg(&token_json)
            .ignore()
            .cmd("SETEX")
            .arg(&access_token_key)
            .arg(ttl)
            .arg(&token.token_id)
            .ignore()
            .cmd("SADD")
            .arg(&user_tokens_key)
            .arg(&token.token_id)
            .ignore()
            .query_async::<_, ()>(&mut conn)
            .await
            .map_err(|e| StorageError::operation_failed(format!("Failed to store token: {e}")))?;

        Ok(())
    }

    /// Load and deserialize the token stored under `token_id`, if present.
    async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>> {
        let mut conn = self.get_connection().await?;
        let token_key = self.key(&format!("token:{token_id}"));

        let token_json: Option<String> = redis::cmd("GET")
            .arg(&token_key)
            .query_async(&mut conn)
            .await
            .map_err(|e| StorageError::operation_failed(format!("Failed to get token: {e}")))?;

        if let Some(json) = token_json {
            let token: AuthToken = serde_json::from_str(&json)
                .map_err(|e| StorageError::serialization(format!("Token deserialization failed: {e}")))?;
            Ok(Some(token))
        } else {
            Ok(None)
        }
    }

    /// Resolve `access_token` to its token via the `access_token:` mapping.
    ///
    /// The resolved token is only returned if it still carries `access_token`. A stale mapping
    /// that points at a token whose access token has since changed resolves to `None`.
    async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>> {
        let mut conn = self.get_connection().await?;
        let access_token_key = self.key(&format!("access_token:{access_token}"));

        let token_id: Option<String> = redis::cmd("GET")
            .arg(&access_token_key)
            .query_async(&mut conn)
            .await
            .map_err(|e| StorageError::operation_failed(format!("Failed to get access token mapping: {e}")))?;

        let Some(token_id) = token_id else {
            return Ok(None);
        };

        Ok(self
            .get_token(&token_id)
            .await?
            .filter(|token| token.access_token == access_token))
    }

    /// Overwrite a stored token, keeping its index entries consistent.
    ///
    /// `store_token` only writes index entries for the token's current access token and user. If
    /// either differs from the stored copy (e.g. the access token was rotated), the entries for
    /// the previous values are removed first. Otherwise the old access token would keep resolving
    /// to this token until its TTL ran out. Removing them before writing makes a failure fail
    /// closed: the old access token stops resolving even if the new write does not happen.
    async fn update_token(&self, token: &AuthToken) -> Result<()> {
        if let Some(previous) = self.get_token(&token.token_id).await? {
            let access_token_changed = previous.access_token != token.access_token;
            let user_changed = previous.user_id != token.user_id;

            if access_token_changed || user_changed {
                let mut conn = self.get_connection().await?;

                if access_token_changed {
                    let stale_access_key =
                        self.key(&format!("access_token:{}", previous.access_token));
                    redis::cmd("DEL")
                        .arg(&stale_access_key)
                        .query_async::<_, ()>(&mut conn)
                        .await
                        .map_err(|e| StorageError::operation_failed(format!("Failed to delete access token mapping: {e}")))?;
                }

                if user_changed {
                    let stale_user_key = self.key(&format!("user_tokens:{}", previous.user_id));
                    redis::cmd("SREM")
                        .arg(&stale_user_key)
                        .arg(&token.token_id)
                        .query_async::<_, ()>(&mut conn)
                        .await
                        .map_err(|e| StorageError::operation_failed(format!("Failed to remove token from user set: {e}")))?;
                }
            }
        }

        self.store_token(token).await
    }

    /// Delete a token and its index entries.
    ///
    /// The token's access token and user ID are needed to locate its index entries, so it is
    /// loaded first. An unknown (or already expired) token ID is a no-op.
    async fn delete_token(&self, token_id: &str) -> Result<()> {
        let Some(token) = self.get_token(token_id).await? else {
            return Ok(());
        };

        let mut conn = self.get_connection().await?;
        let token_key = self.key(&format!("token:{token_id}"));
        let access_token_key = self.key(&format!("access_token:{}", token.access_token));
        let user_tokens_key = self.key(&format!("user_tokens:{}", token.user_id));

        redis::cmd("DEL")
            .arg(&token_key)
            .query_async::<_, ()>(&mut conn)
            .await
            .map_err(|e| StorageError::operation_failed(format!("Failed to delete token: {e}")))?;

        redis::cmd("DEL")
            .arg(&access_token_key)
            .query_async::<_, ()>(&mut conn)
            .await
            .map_err(|e| StorageError::operation_failed(format!("Failed to delete access token mapping: {e}")))?;

        redis::cmd("SREM")
            .arg(&user_tokens_key)
            .arg(token_id)
            .query_async::<_, ()>(&mut conn)
            .await
            .map_err(|e| StorageError::operation_failed(format!("Failed to remove token from user set: {e}")))?;

        Ok(())
    }

    /// Return every still-stored token whose ID is in the user's `user_tokens:` set.
    ///
    /// IDs whose token key no longer exists (expired or deleted) are skipped.
    // NOTE(review): the user set has no TTL, so IDs of tokens that expired naturally accumulate
    // until `delete_token` is called for them; consider pruning them.
    async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>> {
        let mut conn = self.get_connection().await?;
        let user_tokens_key = self.key(&format!("user_tokens:{user_id}"));

        let token_ids: Vec<String> = redis::cmd("SMEMBERS")
            .arg(&user_tokens_key)
            .query_async(&mut conn)
            .await
            .map_err(|e| StorageError::operation_failed(format!("Failed to get user tokens: {e}")))?;

        let mut tokens = Vec::with_capacity(token_ids.len());
        for token_id in token_ids {
            if let Some(token) = self.get_token(&token_id).await? {
                tokens.push(token);
            }
        }

        Ok(tokens)
    }

    /// Store `data` as JSON under `session:{session_id}`, expiring at `data.expires_at`.
    async fn store_session(&self, session_id: &str, data: &SessionData) -> Result<()> {
        let mut conn = self.get_connection().await?;
        let session_key = self.key(&format!("session:{session_id}"));

        let session_json = serde_json::to_string(data)
            .map_err(|e| StorageError::serialization(format!("Session serialization failed: {e}")))?;

        // SETEX rejects a non-positive TTL, so clamp to at least one second.
        let ttl = (data.expires_at - chrono::Utc::now()).num_seconds().max(1);

        redis::cmd("SETEX")
            .arg(&session_key)
            .arg(ttl)
            .arg(&session_json)
            .query_async::<_, ()>(&mut conn)
            .await
            .map_err(|e| StorageError::operation_failed(format!("Failed to store session: {e}")))?;

        Ok(())
    }

    /// Load the session stored under `session_id`, or `None` if it is absent or expired.
    async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
        let mut conn = self.get_connection().await?;
        let session_key = self.key(&format!("session:{session_id}"));

        let session_json: Option<String> = redis::cmd("GET")
            .arg(&session_key)
            .query_async(&mut conn)
            .await
            .map_err(|e| StorageError::operation_failed(format!("Failed to get session: {e}")))?;

        let Some(json) = session_json else {
            return Ok(None);
        };

        let session: SessionData = serde_json::from_str(&json)
            .map_err(|e| StorageError::serialization(format!("Session deserialization failed: {e}")))?;

        // `store_session` keeps an already-expired session alive for at least one second (the
        // TTL clamp), so hide it here; `MemoryStorage::get_session` applies the same rule.
        Ok(Some(session).filter(|session| !session.is_expired()))
    }

    /// Delete the session stored under `session_id`.
    async fn delete_session(&self, session_id: &str) -> Result<()> {
        let mut conn = self.get_connection().await?;
        let session_key = self.key(&format!("session:{session_id}"));

        redis::cmd("DEL")
            .arg(&session_key)
            .query_async::<_, ()>(&mut conn)
            .await
            .map_err(|e| StorageError::operation_failed(format!("Failed to delete session: {e}")))?;

        Ok(())
    }

    /// Store `value` under `kv:{key}`, with a TTL if one is given (`None` = no expiry).
    async fn store_kv(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()> {
        let mut conn = self.get_connection().await?;
        let storage_key = self.key(&format!("kv:{key}"));

        if let Some(ttl) = ttl {
            // SETEX rejects a zero TTL. Round sub-second TTLs up to one second, as the token
            // and session writers do, instead of failing the call.
            redis::cmd("SETEX")
                .arg(&storage_key)
                .arg(ttl.as_secs().max(1))
                .arg(value)
                .query_async::<_, ()>(&mut conn)
                .await
                .map_err(|e| StorageError::operation_failed(format!("Failed to store KV with TTL: {e}")))?;
        } else {
            redis::cmd("SET")
                .arg(&storage_key)
                .arg(value)
                .query_async::<_, ()>(&mut conn)
                .await
                .map_err(|e| StorageError::operation_failed(format!("Failed to store KV: {e}")))?;
        }

        Ok(())
    }

    /// Return the bytes stored under `kv:{key}`, if present.
    async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>> {
        let mut conn = self.get_connection().await?;
        let storage_key = self.key(&format!("kv:{key}"));

        let value: Option<Vec<u8>> = redis::cmd("GET")
            .arg(&storage_key)
            .query_async(&mut conn)
            .await
            .map_err(|e| StorageError::operation_failed(format!("Failed to get KV: {e}")))?;

        Ok(value)
    }

    /// Delete the entry stored under `kv:{key}`.
    async fn delete_kv(&self, key: &str) -> Result<()> {
        let mut conn = self.get_connection().await?;
        let storage_key = self.key(&format!("kv:{key}"));

        redis::cmd("DEL")
            .arg(&storage_key)
            .query_async::<_, ()>(&mut conn)
            .await
            .map_err(|e| StorageError::operation_failed(format!("Failed to delete KV: {e}")))?;

        Ok(())
    }

    /// No-op: the token, session and KV keys carry Redis TTLs and expire on their own.
    async fn cleanup_expired(&self) -> Result<()> {
        Ok(())
    }
}
513
514impl SessionData {
515    /// Create a new session.
516    pub fn new(
517        session_id: impl Into<String>,
518        user_id: impl Into<String>,
519        expires_in: Duration,
520    ) -> Self {
521        let now = chrono::Utc::now();
522        
523        Self {
524            session_id: session_id.into(),
525            user_id: user_id.into(),
526            created_at: now,
527            expires_at: now + chrono::Duration::from_std(expires_in).unwrap(),
528            last_activity: now,
529            ip_address: None,
530            user_agent: None,
531            data: HashMap::new(),
532        }
533    }
534
535    /// Check if the session has expired.
536    pub fn is_expired(&self) -> bool {
537        chrono::Utc::now() > self.expires_at
538    }
539
540    /// Update the last activity timestamp.
541    pub fn update_activity(&mut self) {
542        self.last_activity = chrono::Utc::now();
543    }
544
545    /// Set session metadata.
546    pub fn with_metadata(
547        mut self,
548        ip_address: Option<String>,
549        user_agent: Option<String>,
550    ) -> Self {
551        self.ip_address = ip_address;
552        self.user_agent = user_agent;
553        self
554    }
555
556    /// Add custom data to the session.
557    pub fn set_data(&mut self, key: impl Into<String>, value: serde_json::Value) {
558        self.data.insert(key.into(), value);
559    }
560
561    /// Get custom data from the session.
562    pub fn get_data(&self, key: &str) -> Option<&serde_json::Value> {
563        self.data.get(key)
564    }
565}
566
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokens::AuthToken;

    /// Exercises every token operation of the in-memory backend on a single token.
    #[tokio::test]
    async fn test_memory_storage() {
        let store = MemoryStorage::new();
        let token = AuthToken::new("user123", "token123", Duration::from_secs(3600), "test");

        store.store_token(&token).await.expect("store_token failed");

        let by_id = store
            .get_token(&token.token_id)
            .await
            .expect("get_token failed")
            .expect("token should be present");
        assert_eq!(by_id.user_id, "user123");

        let by_access = store
            .get_token_by_access_token(&token.access_token)
            .await
            .expect("get_token_by_access_token failed")
            .expect("token should resolve by access token");
        assert_eq!(by_access.token_id, token.token_id);

        let owned = store
            .list_user_tokens("user123")
            .await
            .expect("list_user_tokens failed");
        assert_eq!(owned.len(), 1);

        store
            .delete_token(&token.token_id)
            .await
            .expect("delete_token failed");
        let after_delete = store
            .get_token(&token.token_id)
            .await
            .expect("get_token failed");
        assert!(after_delete.is_none());
    }

    /// Stores, reads back and deletes one session.
    #[tokio::test]
    async fn test_session_storage() {
        let store = MemoryStorage::new();
        let session = SessionData::new("session123", "user123", Duration::from_secs(3600))
            .with_metadata(
                Some("192.168.1.1".to_string()),
                Some("Test Agent".to_string()),
            );
        let id = &session.session_id;

        store
            .store_session(id, &session)
            .await
            .expect("store_session failed");

        let loaded = store
            .get_session(id)
            .await
            .expect("get_session failed")
            .expect("session should be present");
        assert_eq!(loaded.user_id, "user123");
        assert_eq!(loaded.ip_address, Some("192.168.1.1".to_string()));

        store
            .delete_session(id)
            .await
            .expect("delete_session failed");
        let after_delete = store.get_session(id).await.expect("get_session failed");
        assert!(after_delete.is_none());
    }

    /// Stores, reads back and deletes one key/value entry with a TTL.
    #[tokio::test]
    async fn test_kv_storage() {
        let store = MemoryStorage::new();
        let (key, payload) = ("test_key", b"test_value");

        store
            .store_kv(key, payload, Some(Duration::from_secs(3600)))
            .await
            .expect("store_kv failed");

        let loaded = store
            .get_kv(key)
            .await
            .expect("get_kv failed")
            .expect("value should be present");
        assert_eq!(loaded, payload);

        store.delete_kv(key).await.expect("delete_kv failed");
        let after_delete = store.get_kv(key).await.expect("get_kv failed");
        assert!(after_delete.is_none());
    }
}