auth_framework/storage/
core.rs

1//! Storage backends for authentication data.
2
3use crate::errors::Result;
4use crate::tokens::AuthToken;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9use std::time::Duration;
10
11#[cfg(feature = "redis-storage")]
12use crate::errors::StorageError;
13
14/// Trait for authentication data storage.
15#[async_trait]
16pub trait AuthStorage: Send + Sync {
17    /// Bulk store tokens.
18    async fn store_tokens_bulk(&self, tokens: &[AuthToken]) -> Result<()> {
19        for token in tokens {
20            self.store_token(token).await?;
21        }
22        Ok(())
23    }
24
25    /// Bulk delete tokens by ID.
26    async fn delete_tokens_bulk(&self, token_ids: &[String]) -> Result<()> {
27        for token_id in token_ids {
28            self.delete_token(token_id).await?;
29        }
30        Ok(())
31    }
32
33    /// Bulk store sessions.
34    async fn store_sessions_bulk(&self, sessions: &[(String, SessionData)]) -> Result<()> {
35        for (session_id, data) in sessions {
36            self.store_session(session_id, data).await?;
37        }
38        Ok(())
39    }
40
41    /// Bulk delete sessions by ID.
42    async fn delete_sessions_bulk(&self, session_ids: &[String]) -> Result<()> {
43        for session_id in session_ids {
44            self.delete_session(session_id).await?;
45        }
46        Ok(())
47    }
48    /// Store a token.
49    async fn store_token(&self, token: &AuthToken) -> Result<()>;
50
51    /// Retrieve a token by ID.
52    async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>>;
53
54    /// Retrieve a token by access token string.
55    async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>>;
56
57    /// Update a token.
58    async fn update_token(&self, token: &AuthToken) -> Result<()>;
59
60    /// Delete a token.
61    async fn delete_token(&self, token_id: &str) -> Result<()>;
62
63    /// List all tokens for a user.
64    async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>>;
65
66    /// Store session data.
67    async fn store_session(&self, session_id: &str, data: &SessionData) -> Result<()>;
68
69    /// Retrieve session data.
70    async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>>;
71
72    /// Delete session data.
73    async fn delete_session(&self, session_id: &str) -> Result<()>;
74
75    /// List all sessions for a user.
76    async fn list_user_sessions(&self, user_id: &str) -> Result<Vec<SessionData>>;
77
78    /// Count currently active sessions (non-expired)
79    async fn count_active_sessions(&self) -> Result<u64>;
80
81    /// Store arbitrary key-value data with expiration.
82    async fn store_kv(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()>;
83
84    /// Retrieve arbitrary key-value data.
85    async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>>;
86
87    /// Delete arbitrary key-value data.
88    async fn delete_kv(&self, key: &str) -> Result<()>;
89
90    /// Clean up expired data.
91    async fn cleanup_expired(&self) -> Result<()>;
92}
93
94/// Session data stored in the backend.
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct SessionData {
97    /// Session ID
98    pub session_id: String,
99
100    /// User ID associated with this session
101    pub user_id: String,
102
103    /// When the session was created
104    pub created_at: chrono::DateTime<chrono::Utc>,
105
106    /// When the session expires
107    pub expires_at: chrono::DateTime<chrono::Utc>,
108
109    /// Last activity timestamp
110    pub last_activity: chrono::DateTime<chrono::Utc>,
111
112    /// IP address of the session
113    pub ip_address: Option<String>,
114
115    /// User agent
116    pub user_agent: Option<String>,
117
118    /// Custom session data
119    pub data: HashMap<String, serde_json::Value>,
120}
121
122/// In-memory storage implementation (for development/testing).
123/// SECURITY UPDATE: Now uses DashMap for deadlock-free concurrent operations
124#[derive(Debug, Clone)]
125pub struct MemoryStorage {
126    // Primary storage using DashMap for deadlock-free operations
127    inner: crate::storage::dashmap_memory::DashMapMemoryStorage,
128    // RBAC storage still uses RwLock for compatibility (lower concurrency requirements)
129    roles: Arc<RwLock<HashMap<String, crate::authorization::Role>>>,
130    user_roles: Arc<RwLock<Vec<crate::authorization::UserRole>>>,
131}
132
/// Redis-backed storage, compiled only with the `redis-storage` feature.
#[cfg(feature = "redis-storage")]
#[derive(Clone, Debug)]
pub struct RedisStorage {
    // Client used to open connections on demand.
    client: redis::Client,
    // Prefix prepended to every key this storage reads or writes.
    key_prefix: String,
}
140
141impl Default for MemoryStorage {
142    fn default() -> Self {
143        Self::new()
144    }
145}
146
147impl MemoryStorage {
148    /// Create a new in-memory storage.
149    pub fn new() -> Self {
150        Self {
151            inner: crate::storage::dashmap_memory::DashMapMemoryStorage::new(),
152            roles: Arc::new(RwLock::new(HashMap::new())),
153            user_roles: Arc::new(RwLock::new(Vec::new())),
154        }
155    }
156}
157// In-memory AuthorizationStorage implementation for RBAC examples
158#[async_trait::async_trait]
159impl crate::authorization::AuthorizationStorage for MemoryStorage {
160    async fn store_role(&self, role: &crate::authorization::Role) -> crate::errors::Result<()> {
161        let mut roles = self.roles.write().unwrap();
162        roles.insert(role.id.clone(), role.clone());
163        Ok(())
164    }
165
166    async fn get_role(
167        &self,
168        role_id: &str,
169    ) -> crate::errors::Result<Option<crate::authorization::Role>> {
170        let roles = self.roles.read().unwrap();
171        Ok(roles.get(role_id).cloned())
172    }
173
174    async fn update_role(&self, role: &crate::authorization::Role) -> crate::errors::Result<()> {
175        let mut roles = self.roles.write().unwrap();
176        roles.insert(role.id.clone(), role.clone());
177        Ok(())
178    }
179
180    async fn delete_role(&self, role_id: &str) -> crate::errors::Result<()> {
181        let mut roles = self.roles.write().unwrap();
182        roles.remove(role_id);
183        Ok(())
184    }
185
186    async fn list_roles(&self) -> crate::errors::Result<Vec<crate::authorization::Role>> {
187        let roles = self.roles.read().unwrap();
188        Ok(roles.values().cloned().collect())
189    }
190
191    async fn assign_role(
192        &self,
193        user_role: &crate::authorization::UserRole,
194    ) -> crate::errors::Result<()> {
195        let mut user_roles = self.user_roles.write().unwrap();
196        user_roles.push(user_role.clone());
197        Ok(())
198    }
199
200    async fn remove_role(&self, user_id: &str, role_id: &str) -> crate::errors::Result<()> {
201        let mut user_roles = self.user_roles.write().unwrap();
202        user_roles.retain(|ur| ur.user_id != user_id || ur.role_id != role_id);
203        Ok(())
204    }
205
206    async fn get_user_roles(
207        &self,
208        user_id: &str,
209    ) -> crate::errors::Result<Vec<crate::authorization::UserRole>> {
210        let user_roles = self.user_roles.read().unwrap();
211        Ok(user_roles
212            .iter()
213            .filter(|ur| ur.user_id == user_id)
214            .cloned()
215            .collect())
216    }
217
218    async fn get_role_users(
219        &self,
220        role_id: &str,
221    ) -> crate::errors::Result<Vec<crate::authorization::UserRole>> {
222        let user_roles = self.user_roles.read().unwrap();
223        Ok(user_roles
224            .iter()
225            .filter(|ur| ur.role_id == role_id)
226            .cloned()
227            .collect())
228    }
229}
230
231#[async_trait]
232impl AuthStorage for MemoryStorage {
233    async fn store_token(&self, token: &AuthToken) -> Result<()> {
234        // Delegate to DashMap implementation for deadlock-free operations
235        self.inner.store_token(token).await
236    }
237
238    async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>> {
239        // Delegate to DashMap implementation for deadlock-free operations
240        self.inner.get_token(token_id).await
241    }
242
243    async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>> {
244        // Delegate to DashMap implementation for deadlock-free operations
245        self.inner.get_token_by_access_token(access_token).await
246    }
247
248    async fn update_token(&self, token: &AuthToken) -> Result<()> {
249        // Delegate to DashMap implementation for deadlock-free operations
250        self.inner.update_token(token).await
251    }
252
253    async fn delete_token(&self, token_id: &str) -> Result<()> {
254        // Delegate to DashMap implementation for deadlock-free operations
255        self.inner.delete_token(token_id).await
256    }
257
258    async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>> {
259        // Delegate to DashMap implementation for deadlock-free operations
260        self.inner.list_user_tokens(user_id).await
261    }
262
263    async fn store_session(&self, session_id: &str, data: &SessionData) -> Result<()> {
264        // Delegate to DashMap implementation for deadlock-free operations
265        self.inner.store_session(session_id, data).await
266    }
267
268    async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
269        // Delegate to DashMap implementation for deadlock-free operations
270        self.inner.get_session(session_id).await
271    }
272
273    async fn delete_session(&self, session_id: &str) -> Result<()> {
274        // Delegate to DashMap implementation for deadlock-free operations
275        self.inner.delete_session(session_id).await
276    }
277
278    async fn list_user_sessions(&self, user_id: &str) -> Result<Vec<SessionData>> {
279        // Delegate to DashMap implementation for deadlock-free operations
280        self.inner.list_user_sessions(user_id).await
281    }
282
283    async fn count_active_sessions(&self) -> Result<u64> {
284        // Delegate to DashMap implementation for deadlock-free operations
285        self.inner.count_active_sessions().await
286    }
287
288    async fn store_kv(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()> {
289        // Delegate to DashMap implementation for deadlock-free operations
290        self.inner.store_kv(key, value, ttl).await
291    }
292
293    async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>> {
294        // Delegate to DashMap implementation for deadlock-free operations
295        self.inner.get_kv(key).await
296    }
297
298    async fn delete_kv(&self, key: &str) -> Result<()> {
299        // Delegate to DashMap implementation for deadlock-free operations
300        self.inner.delete_kv(key).await
301    }
302
303    async fn cleanup_expired(&self) -> Result<()> {
304        // Delegate to DashMap implementation for deadlock-free operations
305        self.inner.cleanup_expired().await
306    }
307}
308
#[cfg(feature = "redis-storage")]
impl RedisStorage {
    /// Build a Redis storage from a connection URL and a key prefix.
    ///
    /// # Errors
    ///
    /// Returns a connection-failed storage error when `redis::Client::open`
    /// rejects `redis_url`.
    pub fn new(redis_url: &str, key_prefix: impl Into<String>) -> Result<Self> {
        let client = match redis::Client::open(redis_url) {
            Ok(client) => client,
            Err(e) => {
                let failure =
                    StorageError::connection_failed(format!("Redis connection failed: {e}"));
                return Err(failure.into());
            }
        };

        let key_prefix = key_prefix.into();
        Ok(Self { client, key_prefix })
    }

    /// Obtain a multiplexed Tokio connection from the client.
    ///
    /// # Errors
    ///
    /// Returns a connection-failed storage error when the connection cannot
    /// be obtained.
    async fn get_connection(&self) -> Result<redis::aio::MultiplexedConnection> {
        let attempt = self.client.get_multiplexed_tokio_connection().await;
        match attempt {
            Ok(conn) => Ok(conn),
            Err(e) => Err(StorageError::connection_failed(format!(
                "Failed to get Redis connection: {e}"
            ))
            .into()),
        }
    }

    /// Return `suffix` with the configured key prefix prepended verbatim
    /// (no separator is inserted).
    fn key(&self, suffix: &str) -> String {
        [self.key_prefix.as_str(), suffix].concat()
    }
}
339
340#[cfg(feature = "redis-storage")]
341#[async_trait]
342impl AuthStorage for RedisStorage {
343    async fn store_token(&self, token: &AuthToken) -> Result<()> {
344        let mut conn = self.get_connection().await?;
345        let token_json = serde_json::to_string(token)
346            .map_err(|e| StorageError::serialization(format!("Token serialization failed: {e}")))?;
347
348        let token_key = self.key(&format!("token:{}", token.token_id));
349        let access_token_key = self.key(&format!("access_token:{}", token.access_token));
350        let user_tokens_key = self.key(&format!("user_tokens:{}", token.user_id));
351
352        // Calculate TTL
353        let ttl = token.time_until_expiry().as_secs().max(1);
354
355        // Store token data
356        let _: () = redis::cmd("SETEX")
357            .arg(&token_key)
358            .arg(ttl)
359            .arg(&token_json)
360            .query_async(&mut conn)
361            .await
362            .map_err(|e| StorageError::operation_failed(format!("Failed to store token: {e}")))?;
363
364        // Store access token mapping
365        let _: () = redis::cmd("SETEX")
366            .arg(&access_token_key)
367            .arg(ttl)
368            .arg(&token.token_id)
369            .query_async(&mut conn)
370            .await
371            .map_err(|e| {
372                StorageError::operation_failed(format!("Failed to store access token mapping: {e}"))
373            })?;
374
375        // Add to user tokens set
376        let _: () = redis::cmd("SADD")
377            .arg(&user_tokens_key)
378            .arg(&token.token_id)
379            .query_async(&mut conn)
380            .await
381            .map_err(|e| {
382                StorageError::operation_failed(format!("Failed to add token to user set: {e}"))
383            })?;
384
385        Ok(())
386    }
387
388    async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>> {
389        let mut conn = self.get_connection().await?;
390        let token_key = self.key(&format!("token:{token_id}"));
391
392        let token_json: Option<String> = redis::cmd("GET")
393            .arg(&token_key)
394            .query_async(&mut conn)
395            .await
396            .map_err(|e| StorageError::operation_failed(format!("Failed to get token: {e}")))?;
397
398        if let Some(json) = token_json {
399            let token: AuthToken = serde_json::from_str(&json).map_err(|e| {
400                StorageError::serialization(format!("Token deserialization failed: {e}"))
401            })?;
402            Ok(Some(token))
403        } else {
404            Ok(None)
405        }
406    }
407
408    async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>> {
409        let mut conn = self.get_connection().await?;
410        let access_token_key = self.key(&format!("access_token:{access_token}"));
411
412        let token_id: Option<String> = redis::cmd("GET")
413            .arg(&access_token_key)
414            .query_async(&mut conn)
415            .await
416            .map_err(|e| {
417                StorageError::operation_failed(format!("Failed to get access token mapping: {e}"))
418            })?;
419
420        if let Some(token_id) = token_id {
421            self.get_token(&token_id).await
422        } else {
423            Ok(None)
424        }
425    }
426
427    async fn update_token(&self, token: &AuthToken) -> Result<()> {
428        // Same as store_token for Redis
429        self.store_token(token).await
430    }
431
432    async fn delete_token(&self, token_id: &str) -> Result<()> {
433        let mut conn = self.get_connection().await?;
434
435        // Get token first to get access token and user ID
436        if let Some(token) = self.get_token(token_id).await? {
437            let token_key = self.key(&format!("token:{token_id}"));
438            let access_token_key = self.key(&format!("access_token:{}", token.access_token));
439            let user_tokens_key = self.key(&format!("user_tokens:{}", token.user_id));
440
441            // Delete token data
442            let _: () = redis::cmd("DEL")
443                .arg(&token_key)
444                .query_async(&mut conn)
445                .await
446                .map_err(|e| {
447                    StorageError::operation_failed(format!("Failed to delete token: {e}"))
448                })?;
449
450            // Delete access token mapping
451            let _: () = redis::cmd("DEL")
452                .arg(&access_token_key)
453                .query_async(&mut conn)
454                .await
455                .map_err(|e| {
456                    StorageError::operation_failed(format!(
457                        "Failed to delete access token mapping: {e}"
458                    ))
459                })?;
460
461            // Remove from user tokens set
462            let _: () = redis::cmd("SREM")
463                .arg(&user_tokens_key)
464                .arg(token_id)
465                .query_async(&mut conn)
466                .await
467                .map_err(|e| {
468                    StorageError::operation_failed(format!(
469                        "Failed to remove token from user set: {e}"
470                    ))
471                })?;
472        }
473
474        Ok(())
475    }
476
477    async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>> {
478        let mut conn = self.get_connection().await?;
479        let user_tokens_key = self.key(&format!("user_tokens:{user_id}"));
480
481        let token_ids: Vec<String> = redis::cmd("SMEMBERS")
482            .arg(&user_tokens_key)
483            .query_async(&mut conn)
484            .await
485            .map_err(|e| {
486                StorageError::operation_failed(format!("Failed to get user tokens: {e}"))
487            })?;
488
489        let mut tokens = Vec::new();
490        for token_id in token_ids {
491            if let Some(token) = self.get_token(&token_id).await? {
492                tokens.push(token);
493            }
494        }
495
496        Ok(tokens)
497    }
498
499    async fn store_session(&self, session_id: &str, data: &SessionData) -> Result<()> {
500        let mut conn = self.get_connection().await?;
501        let session_key = self.key(&format!("session:{session_id}"));
502
503        let session_json = serde_json::to_string(data).map_err(|e| {
504            StorageError::serialization(format!("Session serialization failed: {e}"))
505        })?;
506
507        let ttl = (data.expires_at - chrono::Utc::now()).num_seconds().max(1);
508
509        let _: () = redis::cmd("SETEX")
510            .arg(&session_key)
511            .arg(ttl)
512            .arg(&session_json)
513            .query_async(&mut conn)
514            .await
515            .map_err(|e| StorageError::operation_failed(format!("Failed to store session: {e}")))?;
516
517        Ok(())
518    }
519
520    async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
521        let mut conn = self.get_connection().await?;
522        let session_key = self.key(&format!("session:{session_id}"));
523
524        let session_json: Option<String> = redis::cmd("GET")
525            .arg(&session_key)
526            .query_async(&mut conn)
527            .await
528            .map_err(|e| StorageError::operation_failed(format!("Failed to get session: {e}")))?;
529
530        if let Some(json) = session_json {
531            let session: SessionData = serde_json::from_str(&json).map_err(|e| {
532                StorageError::serialization(format!("Session deserialization failed: {e}"))
533            })?;
534            Ok(Some(session))
535        } else {
536            Ok(None)
537        }
538    }
539
540    async fn delete_session(&self, session_id: &str) -> Result<()> {
541        let mut conn = self.get_connection().await?;
542        let session_key = self.key(&format!("session:{session_id}"));
543
544        let _: () = redis::cmd("DEL")
545            .arg(&session_key)
546            .query_async(&mut conn)
547            .await
548            .map_err(|e| {
549                StorageError::operation_failed(format!("Failed to delete session: {e}"))
550            })?;
551
552        Ok(())
553    }
554
555    async fn list_user_sessions(&self, user_id: &str) -> Result<Vec<SessionData>> {
556        let mut conn = self.get_connection().await?;
557        let pattern = self.key("session:*");
558
559        // Use SCAN to find all session keys
560        let keys: Vec<String> = redis::cmd("KEYS")
561            .arg(&pattern)
562            .query_async(&mut conn)
563            .await
564            .map_err(|e| StorageError::operation_failed(format!("Failed to scan sessions: {e}")))?;
565
566        let mut user_sessions = Vec::new();
567
568        // Check each session to see if it belongs to the user
569        for key in keys {
570            if let Ok(session_json) = redis::cmd("GET")
571                .arg(&key)
572                .query_async::<Option<String>>(&mut conn)
573                .await
574                && let Some(session_json) = session_json
575                && let Ok(session) = serde_json::from_str::<SessionData>(&session_json)
576                && session.user_id == user_id
577                && !session.is_expired()
578            {
579                user_sessions.push(session);
580            }
581        }
582
583        Ok(user_sessions)
584    }
585
586    async fn store_kv(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()> {
587        let mut conn = self.get_connection().await?;
588        let storage_key = self.key(&format!("kv:{key}"));
589
590        if let Some(ttl) = ttl {
591            let _: () = redis::cmd("SETEX")
592                .arg(&storage_key)
593                .arg(ttl.as_secs())
594                .arg(value)
595                .query_async(&mut conn)
596                .await
597                .map_err(|e| {
598                    StorageError::operation_failed(format!("Failed to store KV with TTL: {e}"))
599                })?;
600        } else {
601            let _: () = redis::cmd("SET")
602                .arg(&storage_key)
603                .arg(value)
604                .query_async(&mut conn)
605                .await
606                .map_err(|e| StorageError::operation_failed(format!("Failed to store KV: {e}")))?;
607        }
608
609        Ok(())
610    }
611
612    async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>> {
613        let mut conn = self.get_connection().await?;
614        let storage_key = self.key(&format!("kv:{key}"));
615
616        let value: Option<Vec<u8>> = redis::cmd("GET")
617            .arg(&storage_key)
618            .query_async(&mut conn)
619            .await
620            .map_err(|e| StorageError::operation_failed(format!("Failed to get KV: {e}")))?;
621
622        Ok(value)
623    }
624
625    async fn delete_kv(&self, key: &str) -> Result<()> {
626        let mut conn = self.get_connection().await?;
627        let storage_key = self.key(&format!("kv:{key}"));
628
629        let _: () = redis::cmd("DEL")
630            .arg(&storage_key)
631            .query_async(&mut conn)
632            .await
633            .map_err(|e| StorageError::operation_failed(format!("Failed to delete KV: {e}")))?;
634
635        Ok(())
636    }
637
638    async fn cleanup_expired(&self) -> Result<()> {
639        // Redis handles expiration automatically, so this is a no-op
640        Ok(())
641    }
642
643    async fn count_active_sessions(&self) -> Result<u64> {
644        let mut conn = self.get_connection().await?;
645        let pattern = self.key("session:*");
646
647        // Use KEYS to find all session keys (consider SCAN for production with many keys)
648        let keys: Vec<String> = redis::cmd("KEYS")
649            .arg(&pattern)
650            .query_async(&mut conn)
651            .await
652            .map_err(|e| StorageError::operation_failed(format!("Failed to scan sessions: {e}")))?;
653
654        // Count only non-expired sessions by checking TTL
655        let mut active_count = 0u64;
656        for key in keys {
657            let ttl: i64 = redis::cmd("TTL")
658                .arg(&key)
659                .query_async(&mut conn)
660                .await
661                .map_err(|e| StorageError::operation_failed(format!("Failed to check TTL: {e}")))?;
662
663            // TTL > 0 means key has expiration and is still active
664            // TTL = -1 means key has no expiration (active)
665            // TTL = -2 means key doesn't exist (expired)
666            if ttl > 0 || ttl == -1 {
667                active_count += 1;
668            }
669        }
670
671        Ok(active_count)
672    }
673}
674
675impl SessionData {
676    /// Create a new session.
677    pub fn new(
678        session_id: impl Into<String>,
679        user_id: impl Into<String>,
680        expires_in: Duration,
681    ) -> Self {
682        let now = chrono::Utc::now();
683
684        Self {
685            session_id: session_id.into(),
686            user_id: user_id.into(),
687            created_at: now,
688            expires_at: now + chrono::Duration::from_std(expires_in).unwrap(),
689            last_activity: now,
690            ip_address: None,
691            user_agent: None,
692            data: HashMap::new(),
693        }
694    }
695
696    /// Check if the session has expired.
697    pub fn is_expired(&self) -> bool {
698        chrono::Utc::now() > self.expires_at
699    }
700
701    /// Update the last activity timestamp.
702    pub fn update_activity(&mut self) {
703        self.last_activity = chrono::Utc::now();
704    }
705
706    /// Set session metadata.
707    pub fn with_metadata(mut self, ip_address: Option<String>, user_agent: Option<String>) -> Self {
708        self.ip_address = ip_address;
709        self.user_agent = user_agent;
710        self
711    }
712
713    /// Add custom data to the session.
714    pub fn set_data(&mut self, key: impl Into<String>, value: serde_json::Value) {
715        self.data.insert(key.into(), value);
716    }
717
718    /// Get custom data from the session.
719    pub fn get_data(&self, key: &str) -> Option<&serde_json::Value> {
720        self.data.get(key)
721    }
722}
723
724/// Implementation of AuditStorage for MemoryStorage
725#[async_trait]
726impl crate::audit::AuditStorage for MemoryStorage {
727    async fn store_event(&self, event: &crate::audit::AuditEvent) -> Result<()> {
728        // Store audit event as JSON in KV storage
729        let json_data = serde_json::to_vec(event).map_err(|e| {
730            crate::errors::AuthError::internal(format!("Failed to serialize audit event: {}", e))
731        })?;
732
733        let key = format!("audit_event_{}", event.id);
734        self.store_kv(&key, &json_data, None).await
735    }
736
737    async fn query_events(
738        &self,
739        query: &crate::audit::AuditQuery,
740    ) -> Result<Vec<crate::audit::AuditEvent>> {
741        // Simple implementation - in production, this would be more efficient
742        let all_keys = self.list_kv_keys("audit_event_").await?;
743        let mut events = Vec::new();
744
745        for key in all_keys {
746            if let Some(data) = self.get_kv(&key).await?
747                && let Ok(event) = serde_json::from_slice::<crate::audit::AuditEvent>(&data)
748            {
749                // Apply filters
750                let mut include = true;
751
752                if let Some(ref time_range) = query.time_range
753                    && (event.timestamp < time_range.start || event.timestamp > time_range.end)
754                {
755                    include = false;
756                }
757
758                if let Some(ref event_types) = query.event_types
759                    && !event_types.contains(&event.event_type)
760                {
761                    include = false;
762                }
763
764                if let Some(ref user_id) = query.user_id
765                    && event.user_id.as_ref() != Some(user_id)
766                {
767                    include = false;
768                }
769
770                if include {
771                    events.push(event);
772                }
773            }
774        }
775
776        // Sort and limit
777        events.sort_by(|a, b| match query.sort_order {
778            crate::audit::SortOrder::TimestampAsc => a.timestamp.cmp(&b.timestamp),
779            crate::audit::SortOrder::TimestampDesc => b.timestamp.cmp(&a.timestamp),
780            crate::audit::SortOrder::RiskLevelDesc => b.risk_level.cmp(&a.risk_level),
781        });
782
783        if let Some(limit) = query.limit {
784            events.truncate(limit as usize);
785        }
786        Ok(events)
787    }
788
789    async fn get_event(&self, event_id: &str) -> Result<Option<crate::audit::AuditEvent>> {
790        let key = format!("audit_event_{}", event_id);
791        if let Some(data) = self.get_kv(&key).await? {
792            let event = serde_json::from_slice(&data).map_err(|e| {
793                crate::errors::AuthError::internal(format!(
794                    "Failed to deserialize audit event: {}",
795                    e
796                ))
797            })?;
798            Ok(Some(event))
799        } else {
800            Ok(None)
801        }
802    }
803
804    async fn count_events(&self, query: &crate::audit::AuditQuery) -> Result<u64> {
805        let events = self.query_events(query).await?;
806        Ok(events.len() as u64)
807    }
808
809    async fn delete_old_events(&self, before: std::time::SystemTime) -> Result<u64> {
810        let all_keys = self.list_kv_keys("audit_event_").await?;
811        let mut deleted_count = 0;
812
813        for key in all_keys {
814            if let Some(data) = self.get_kv(&key).await?
815                && let Ok(event) = serde_json::from_slice::<crate::audit::AuditEvent>(&data)
816                && event.timestamp < before
817            {
818                self.delete_kv(&key).await?;
819                deleted_count += 1;
820            }
821        }
822
823        Ok(deleted_count)
824    }
825
826    async fn get_statistics(
827        &self,
828        _query: &crate::audit::StatsQuery,
829    ) -> Result<crate::audit::AuditStatistics> {
830        // For now, return basic statistics
831        // PRODUCTION: Full audit statistics available with integrated audit storage
832
833        let total_events = 0; // Placeholder
834        let event_type_counts = std::collections::HashMap::new();
835        let risk_level_counts = std::collections::HashMap::new();
836        let outcome_counts = std::collections::HashMap::new();
837        let time_series = Vec::new();
838        let top_users = Vec::new();
839        let top_ips = Vec::new();
840
841        Ok(crate::audit::AuditStatistics {
842            total_events,
843            event_type_counts,
844            risk_level_counts,
845            outcome_counts,
846            time_series,
847            top_users,
848            top_ips,
849        })
850    }
851}
852
853impl MemoryStorage {
854    /// Helper method to list KV keys with a prefix
855    async fn list_kv_keys(&self, prefix: &str) -> Result<Vec<String>> {
856        // Simple implementation for memory storage
857        // In a real implementation, this would scan the internal key-value store
858        // For now, return empty as we don't have direct access to internal storage
859        let _prefix = prefix; // Acknowledge parameter
860        Ok(Vec::new())
861    }
862}
863
864/// Implementation of AuditStorage for Arc<MemoryStorage>
865#[async_trait]
866impl crate::audit::AuditStorage for Arc<MemoryStorage> {
867    async fn store_event(&self, event: &crate::audit::AuditEvent) -> Result<()> {
868        self.as_ref().store_event(event).await
869    }
870
871    async fn query_events(
872        &self,
873        query: &crate::audit::AuditQuery,
874    ) -> Result<Vec<crate::audit::AuditEvent>> {
875        self.as_ref().query_events(query).await
876    }
877
878    async fn get_event(&self, event_id: &str) -> Result<Option<crate::audit::AuditEvent>> {
879        self.as_ref().get_event(event_id).await
880    }
881
882    async fn count_events(&self, query: &crate::audit::AuditQuery) -> Result<u64> {
883        self.as_ref().count_events(query).await
884    }
885
886    async fn delete_old_events(&self, before: std::time::SystemTime) -> Result<u64> {
887        self.as_ref().delete_old_events(before).await
888    }
889
890    async fn get_statistics(
891        &self,
892        query: &crate::audit::StatsQuery,
893    ) -> Result<crate::audit::AuditStatistics> {
894        self.as_ref().get_statistics(query).await
895    }
896}
897
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokens::AuthToken;

    /// Token round trip: store, fetch by ID and by access token, list per
    /// user, then delete and confirm the token is gone.
    #[tokio::test]
    async fn test_memory_storage() {
        let store = MemoryStorage::new();
        let token = AuthToken::new("user123", "token123", Duration::from_secs(3600), "test");

        store.store_token(&token).await.unwrap();

        // Lookup by token ID returns the stored token.
        let by_id = store.get_token(&token.token_id).await.unwrap().unwrap();
        assert_eq!(by_id.user_id, "user123");

        // Lookup by access-token string resolves to the same token.
        let by_access = store
            .get_token_by_access_token(&token.access_token)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(by_access.token_id, token.token_id);

        // Exactly one token is listed for the user.
        let owned = store.list_user_tokens("user123").await.unwrap();
        assert_eq!(owned.len(), 1);

        // After deletion the ID no longer resolves.
        store.delete_token(&token.token_id).await.unwrap();
        let after_delete = store.get_token(&token.token_id).await.unwrap();
        assert!(after_delete.is_none());
    }

    /// Session round trip: store with metadata, fetch, then delete.
    #[tokio::test]
    async fn test_session_storage() {
        let store = MemoryStorage::new();

        let ip = Some("192.168.1.1".to_string());
        let agent = Some("Test Agent".to_string());
        let session = SessionData::new("session123", "user123", Duration::from_secs(3600))
            .with_metadata(ip, agent);

        store
            .store_session(&session.session_id, &session)
            .await
            .unwrap();

        // The fetched session carries the user and the IP metadata.
        let fetched = store
            .get_session(&session.session_id)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(fetched.user_id, "user123");
        assert_eq!(fetched.ip_address.as_deref(), Some("192.168.1.1"));

        // After deletion the session ID no longer resolves.
        store.delete_session(&session.session_id).await.unwrap();
        let after_delete = store.get_session(&session.session_id).await.unwrap();
        assert!(after_delete.is_none());
    }

    /// Key-value round trip: store with a TTL, fetch, then delete.
    #[tokio::test]
    async fn test_kv_storage() {
        let store = MemoryStorage::new();

        let key = "test_key";
        let payload = b"test_value";
        let ttl = Some(Duration::from_secs(3600));

        store.store_kv(key, payload, ttl).await.unwrap();

        // The stored bytes come back unchanged.
        let fetched = store.get_kv(key).await.unwrap().unwrap();
        assert_eq!(fetched, payload);

        // After deletion the key no longer resolves.
        store.delete_kv(key).await.unwrap();
        let after_delete = store.get_kv(key).await.unwrap();
        assert!(after_delete.is_none());
    }
}