1use crate::errors::Result;
4use crate::tokens::AuthToken;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9use std::time::Duration;
10
11#[cfg(feature = "redis-storage")]
12use crate::errors::StorageError;
13
14#[async_trait]
77pub trait AuthStorage: Send + Sync {
78 async fn store_tokens_bulk(&self, tokens: &[AuthToken]) -> Result<()> {
80 for token in tokens {
81 self.store_token(token).await?;
82 }
83 Ok(())
84 }
85
86 async fn delete_tokens_bulk(&self, token_ids: &[String]) -> Result<()> {
88 for token_id in token_ids {
89 self.delete_token(token_id).await?;
90 }
91 Ok(())
92 }
93
94 async fn store_sessions_bulk(&self, sessions: &[(String, SessionData)]) -> Result<()> {
96 for (session_id, data) in sessions {
97 self.store_session(session_id, data).await?;
98 }
99 Ok(())
100 }
101
102 async fn delete_sessions_bulk(&self, session_ids: &[String]) -> Result<()> {
104 for session_id in session_ids {
105 self.delete_session(session_id).await?;
106 }
107 Ok(())
108 }
109 async fn store_token(&self, token: &AuthToken) -> Result<()>;
111
112 async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>>;
114
115 async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>>;
117
118 async fn update_token(&self, token: &AuthToken) -> Result<()>;
120
121 async fn delete_token(&self, token_id: &str) -> Result<()>;
123
124 async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>>;
126
127 async fn store_session(&self, session_id: &str, data: &SessionData) -> Result<()>;
129
130 async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>>;
132
133 async fn delete_session(&self, session_id: &str) -> Result<()>;
135
136 async fn list_user_sessions(&self, user_id: &str) -> Result<Vec<SessionData>>;
138
139 async fn count_active_sessions(&self) -> Result<u64>;
141
142 async fn store_kv(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()>;
144
145 async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>>;
147
148 async fn delete_kv(&self, key: &str) -> Result<()>;
150
151 async fn list_kv_keys(&self, _prefix: &str) -> Result<Vec<String>> {
157 tracing::warn!(
158 "list_kv_keys called on a storage backend that does not override it — returning empty"
159 );
160 Ok(Vec::new())
161 }
162
163 async fn cleanup_expired(&self) -> Result<()>;
165}
166
/// A stored user session.
///
/// Serializable with serde; the Redis backend in this module stores it as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionData {
    /// Identifier of this session.
    pub session_id: String,

    /// Identifier of the user owning the session; `list_user_sessions`
    /// filters on this field.
    pub user_id: String,

    /// When the session was created (set to "now" by `SessionData::new`).
    pub created_at: chrono::DateTime<chrono::Utc>,

    /// Absolute expiry instant; the session counts as expired once the
    /// current time is past it.
    pub expires_at: chrono::DateTime<chrono::Utc>,

    /// Time of the last recorded activity; refreshed by `update_activity`.
    pub last_activity: chrono::DateTime<chrono::Utc>,

    /// IP address associated with the session, if one was recorded.
    pub ip_address: Option<String>,

    /// User-agent string associated with the session, if one was recorded.
    pub user_agent: Option<String>,

    /// Arbitrary application-defined JSON values attached to the session.
    pub data: HashMap<String, serde_json::Value>,
}
218
/// In-process storage backend.
///
/// Token, session and key/value operations are delegated to a
/// `DashMapMemoryStorage`; ABAC roles and user-role assignments live in
/// local `RwLock`-guarded collections.
///
/// Clones share the `roles` and `user_roles` collections because they sit
/// behind `Arc`. NOTE(review): whether clones also share `inner`'s data
/// depends on `DashMapMemoryStorage`'s `Clone` impl — confirm.
#[derive(Debug, Clone)]
pub struct MemoryStorage {
    /// Backend for tokens, sessions and key/value entries.
    inner: crate::storage::dashmap_memory::DashMapMemoryStorage,
    /// ABAC roles keyed by role id.
    roles: Arc<RwLock<HashMap<String, crate::authorization::AbacRole>>>,
    /// User-to-role assignments, kept in insertion order.
    user_roles: Arc<RwLock<Vec<crate::authorization::UserRole>>>,
}
241
/// Redis-backed storage (available with the `redis-storage` feature).
///
/// Every Redis key this backend touches is namespaced by prepending
/// `key_prefix`.
#[cfg(feature = "redis-storage")]
#[derive(Debug, Clone)]
pub struct RedisStorage {
    /// Client used to open multiplexed async connections on demand.
    client: redis::Client,
    /// String prepended to every key, e.g. `<key_prefix>token:<id>`.
    key_prefix: String,
}
249
250impl Default for MemoryStorage {
251 fn default() -> Self {
252 Self::new()
253 }
254}
255
256impl MemoryStorage {
257 pub fn new() -> Self {
259 Self {
260 inner: crate::storage::dashmap_memory::DashMapMemoryStorage::new(),
261 roles: Arc::new(RwLock::new(HashMap::new())),
262 user_roles: Arc::new(RwLock::new(Vec::new())),
263 }
264 }
265}
266#[async_trait::async_trait]
268impl crate::authorization::AuthorizationStorage for MemoryStorage {
269 async fn store_role(&self, role: &crate::authorization::AbacRole) -> crate::errors::Result<()> {
270 let mut roles = self.roles.write().map_err(|_| {
271 crate::errors::AuthError::internal(
272 "Lock poisoned — a prior thread panicked while holding this lock",
273 )
274 })?;
275 roles.insert(role.id.clone(), role.clone());
276 Ok(())
277 }
278
279 async fn get_role(
280 &self,
281 role_id: &str,
282 ) -> crate::errors::Result<Option<crate::authorization::AbacRole>> {
283 let roles = self.roles.read().map_err(|_| {
284 crate::errors::AuthError::internal(
285 "Lock poisoned — a prior thread panicked while holding this lock",
286 )
287 })?;
288 Ok(roles.get(role_id).cloned())
289 }
290
291 async fn update_role(
292 &self,
293 role: &crate::authorization::AbacRole,
294 ) -> crate::errors::Result<()> {
295 let mut roles = self.roles.write().map_err(|_| {
296 crate::errors::AuthError::internal(
297 "Lock poisoned — a prior thread panicked while holding this lock",
298 )
299 })?;
300 roles.insert(role.id.clone(), role.clone());
301 Ok(())
302 }
303
304 async fn delete_role(&self, role_id: &str) -> crate::errors::Result<()> {
305 let mut roles = self.roles.write().map_err(|_| {
306 crate::errors::AuthError::internal(
307 "Lock poisoned — a prior thread panicked while holding this lock",
308 )
309 })?;
310 roles.remove(role_id);
311 Ok(())
312 }
313
314 async fn list_roles(&self) -> crate::errors::Result<Vec<crate::authorization::AbacRole>> {
315 let roles = self.roles.read().map_err(|_| {
316 crate::errors::AuthError::internal(
317 "Lock poisoned — a prior thread panicked while holding this lock",
318 )
319 })?;
320 Ok(roles.values().cloned().collect())
321 }
322
323 async fn assign_role(
324 &self,
325 user_role: &crate::authorization::UserRole,
326 ) -> crate::errors::Result<()> {
327 let mut user_roles = self.user_roles.write().map_err(|_| {
328 crate::errors::AuthError::internal(
329 "Lock poisoned — a prior thread panicked while holding this lock",
330 )
331 })?;
332 user_roles.push(user_role.clone());
333 Ok(())
334 }
335
336 async fn remove_role(&self, user_id: &str, role_id: &str) -> crate::errors::Result<()> {
337 let mut user_roles = self.user_roles.write().map_err(|_| {
338 crate::errors::AuthError::internal(
339 "Lock poisoned — a prior thread panicked while holding this lock",
340 )
341 })?;
342 user_roles.retain(|ur| ur.user_id != user_id || ur.role_id != role_id);
343 Ok(())
344 }
345
346 async fn get_user_roles(
347 &self,
348 user_id: &str,
349 ) -> crate::errors::Result<Vec<crate::authorization::UserRole>> {
350 let user_roles = self.user_roles.read().map_err(|_| {
351 crate::errors::AuthError::internal(
352 "Lock poisoned — a prior thread panicked while holding this lock",
353 )
354 })?;
355 Ok(user_roles
356 .iter()
357 .filter(|ur| ur.user_id == user_id)
358 .cloned()
359 .collect())
360 }
361
362 async fn get_role_users(
363 &self,
364 role_id: &str,
365 ) -> crate::errors::Result<Vec<crate::authorization::UserRole>> {
366 let user_roles = self.user_roles.read().map_err(|_| {
367 crate::errors::AuthError::internal(
368 "Lock poisoned — a prior thread panicked while holding this lock",
369 )
370 })?;
371 Ok(user_roles
372 .iter()
373 .filter(|ur| ur.role_id == role_id)
374 .cloned()
375 .collect())
376 }
377}
378
379#[async_trait]
380impl AuthStorage for MemoryStorage {
381 async fn store_token(&self, token: &AuthToken) -> Result<()> {
382 self.inner.store_token(token).await
384 }
385
386 async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>> {
387 self.inner.get_token(token_id).await
389 }
390
391 async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>> {
392 self.inner.get_token_by_access_token(access_token).await
394 }
395
396 async fn update_token(&self, token: &AuthToken) -> Result<()> {
397 self.inner.update_token(token).await
399 }
400
401 async fn delete_token(&self, token_id: &str) -> Result<()> {
402 self.inner.delete_token(token_id).await
404 }
405
406 async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>> {
407 self.inner.list_user_tokens(user_id).await
409 }
410
411 async fn store_session(&self, session_id: &str, data: &SessionData) -> Result<()> {
412 self.inner.store_session(session_id, data).await
414 }
415
416 async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
417 self.inner.get_session(session_id).await
419 }
420
421 async fn delete_session(&self, session_id: &str) -> Result<()> {
422 self.inner.delete_session(session_id).await
424 }
425
426 async fn list_user_sessions(&self, user_id: &str) -> Result<Vec<SessionData>> {
427 self.inner.list_user_sessions(user_id).await
429 }
430
431 async fn count_active_sessions(&self) -> Result<u64> {
432 self.inner.count_active_sessions().await
434 }
435
436 async fn store_kv(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()> {
437 self.inner.store_kv(key, value, ttl).await
439 }
440
441 async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>> {
442 self.inner.get_kv(key).await
444 }
445
446 async fn delete_kv(&self, key: &str) -> Result<()> {
447 self.inner.delete_kv(key).await
449 }
450
451 async fn list_kv_keys(&self, prefix: &str) -> Result<Vec<String>> {
452 Ok(self.inner.list_kv_keys_by_prefix(prefix))
453 }
454
455 async fn cleanup_expired(&self) -> Result<()> {
456 self.inner.cleanup_expired().await
458 }
459}
460
#[cfg(feature = "redis-storage")]
impl RedisStorage {
    /// Creates a Redis backend from `redis_url`, namespacing every key with
    /// `key_prefix`.
    ///
    /// Connections are opened later, per operation, by `get_connection`.
    ///
    /// # Errors
    /// Returns a connection-failed storage error if the client cannot be
    /// created from `redis_url`.
    pub fn new(redis_url: &str, key_prefix: impl Into<String>) -> Result<Self> {
        match redis::Client::open(redis_url) {
            Ok(client) => Ok(Self {
                client,
                key_prefix: key_prefix.into(),
            }),
            Err(e) => Err(
                StorageError::connection_failed(format!("Redis connection failed: {e}")).into(),
            ),
        }
    }

    /// Opens a multiplexed async connection from the stored client.
    ///
    /// # Errors
    /// Returns a connection-failed storage error if no connection can be made.
    async fn get_connection(&self) -> Result<redis::aio::MultiplexedConnection> {
        let connection = self
            .client
            .get_multiplexed_async_connection()
            .await
            .map_err(|e| {
                StorageError::connection_failed(format!("Failed to get Redis connection: {e}"))
            })?;
        Ok(connection)
    }

    /// Builds the full Redis key: `key_prefix` followed by `suffix`.
    fn key(&self, suffix: &str) -> String {
        let mut full = String::with_capacity(self.key_prefix.len() + suffix.len());
        full.push_str(&self.key_prefix);
        full.push_str(suffix);
        full
    }
}
491
492#[cfg(feature = "redis-storage")]
493#[async_trait]
494impl AuthStorage for RedisStorage {
495 async fn store_token(&self, token: &AuthToken) -> Result<()> {
496 let mut conn = self.get_connection().await?;
497 let token_json = serde_json::to_string(token)
498 .map_err(|e| StorageError::serialization(format!("Token serialization failed: {e}")))?;
499
500 let token_key = self.key(&format!("token:{}", token.token_id));
501 let access_token_key = self.key(&format!("access_token:{}", token.access_token));
502 let user_tokens_key = self.key(&format!("user_tokens:{}", token.user_id));
503
504 let ttl = token.time_until_expiry().as_secs().max(1);
506
507 let _: () = redis::cmd("SETEX")
509 .arg(&token_key)
510 .arg(ttl)
511 .arg(&token_json)
512 .query_async(&mut conn)
513 .await
514 .map_err(|e| StorageError::operation_failed(format!("Failed to store token: {e}")))?;
515
516 let _: () = redis::cmd("SETEX")
518 .arg(&access_token_key)
519 .arg(ttl)
520 .arg(&token.token_id)
521 .query_async(&mut conn)
522 .await
523 .map_err(|e| {
524 StorageError::operation_failed(format!("Failed to store access token mapping: {e}"))
525 })?;
526
527 let _: () = redis::cmd("SADD")
529 .arg(&user_tokens_key)
530 .arg(&token.token_id)
531 .query_async(&mut conn)
532 .await
533 .map_err(|e| {
534 StorageError::operation_failed(format!("Failed to add token to user set: {e}"))
535 })?;
536
537 Ok(())
538 }
539
540 async fn get_token(&self, token_id: &str) -> Result<Option<AuthToken>> {
541 let mut conn = self.get_connection().await?;
542 let token_key = self.key(&format!("token:{token_id}"));
543
544 let token_json: Option<String> = redis::cmd("GET")
545 .arg(&token_key)
546 .query_async(&mut conn)
547 .await
548 .map_err(|e| StorageError::operation_failed(format!("Failed to get token: {e}")))?;
549
550 if let Some(json) = token_json {
551 let token: AuthToken = serde_json::from_str(&json).map_err(|e| {
552 StorageError::serialization(format!("Token deserialization failed: {e}"))
553 })?;
554 Ok(Some(token))
555 } else {
556 Ok(None)
557 }
558 }
559
560 async fn get_token_by_access_token(&self, access_token: &str) -> Result<Option<AuthToken>> {
561 let mut conn = self.get_connection().await?;
562 let access_token_key = self.key(&format!("access_token:{access_token}"));
563
564 let token_id: Option<String> = redis::cmd("GET")
565 .arg(&access_token_key)
566 .query_async(&mut conn)
567 .await
568 .map_err(|e| {
569 StorageError::operation_failed(format!("Failed to get access token mapping: {e}"))
570 })?;
571
572 if let Some(token_id) = token_id {
573 self.get_token(&token_id).await
574 } else {
575 Ok(None)
576 }
577 }
578
579 async fn update_token(&self, token: &AuthToken) -> Result<()> {
580 self.store_token(token).await
582 }
583
584 async fn delete_token(&self, token_id: &str) -> Result<()> {
585 let mut conn = self.get_connection().await?;
586
587 if let Some(token) = self.get_token(token_id).await? {
589 let token_key = self.key(&format!("token:{token_id}"));
590 let access_token_key = self.key(&format!("access_token:{}", token.access_token));
591 let user_tokens_key = self.key(&format!("user_tokens:{}", token.user_id));
592
593 let _: () = redis::cmd("DEL")
595 .arg(&token_key)
596 .query_async(&mut conn)
597 .await
598 .map_err(|e| {
599 StorageError::operation_failed(format!("Failed to delete token: {e}"))
600 })?;
601
602 let _: () = redis::cmd("DEL")
604 .arg(&access_token_key)
605 .query_async(&mut conn)
606 .await
607 .map_err(|e| {
608 StorageError::operation_failed(format!(
609 "Failed to delete access token mapping: {e}"
610 ))
611 })?;
612
613 let _: () = redis::cmd("SREM")
615 .arg(&user_tokens_key)
616 .arg(token_id)
617 .query_async(&mut conn)
618 .await
619 .map_err(|e| {
620 StorageError::operation_failed(format!(
621 "Failed to remove token from user set: {e}"
622 ))
623 })?;
624 }
625
626 Ok(())
627 }
628
629 async fn list_user_tokens(&self, user_id: &str) -> Result<Vec<AuthToken>> {
630 let mut conn = self.get_connection().await?;
631 let user_tokens_key = self.key(&format!("user_tokens:{user_id}"));
632
633 let token_ids: Vec<String> = redis::cmd("SMEMBERS")
634 .arg(&user_tokens_key)
635 .query_async(&mut conn)
636 .await
637 .map_err(|e| {
638 StorageError::operation_failed(format!("Failed to get user tokens: {e}"))
639 })?;
640
641 let mut tokens = Vec::new();
642 for token_id in token_ids {
643 if let Some(token) = self.get_token(&token_id).await? {
644 tokens.push(token);
645 }
646 }
647
648 Ok(tokens)
649 }
650
651 async fn store_session(&self, session_id: &str, data: &SessionData) -> Result<()> {
652 let mut conn = self.get_connection().await?;
653 let session_key = self.key(&format!("session:{session_id}"));
654
655 let session_json = serde_json::to_string(data).map_err(|e| {
656 StorageError::serialization(format!("Session serialization failed: {e}"))
657 })?;
658
659 let ttl = (data.expires_at - chrono::Utc::now()).num_seconds().max(1);
660
661 let _: () = redis::cmd("SETEX")
662 .arg(&session_key)
663 .arg(ttl)
664 .arg(&session_json)
665 .query_async(&mut conn)
666 .await
667 .map_err(|e| StorageError::operation_failed(format!("Failed to store session: {e}")))?;
668
669 Ok(())
670 }
671
672 async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
673 let mut conn = self.get_connection().await?;
674 let session_key = self.key(&format!("session:{session_id}"));
675
676 let session_json: Option<String> = redis::cmd("GET")
677 .arg(&session_key)
678 .query_async(&mut conn)
679 .await
680 .map_err(|e| StorageError::operation_failed(format!("Failed to get session: {e}")))?;
681
682 if let Some(json) = session_json {
683 let session: SessionData = serde_json::from_str(&json).map_err(|e| {
684 StorageError::serialization(format!("Session deserialization failed: {e}"))
685 })?;
686 Ok(Some(session))
687 } else {
688 Ok(None)
689 }
690 }
691
692 async fn delete_session(&self, session_id: &str) -> Result<()> {
693 let mut conn = self.get_connection().await?;
694 let session_key = self.key(&format!("session:{session_id}"));
695
696 let _: () = redis::cmd("DEL")
697 .arg(&session_key)
698 .query_async(&mut conn)
699 .await
700 .map_err(|e| {
701 StorageError::operation_failed(format!("Failed to delete session: {e}"))
702 })?;
703
704 Ok(())
705 }
706
707 async fn list_user_sessions(&self, user_id: &str) -> Result<Vec<SessionData>> {
708 let mut conn = self.get_connection().await?;
709 let pattern = self.key("session:*");
710
711 let keys: Vec<String> = redis::cmd("KEYS")
713 .arg(&pattern)
714 .query_async(&mut conn)
715 .await
716 .map_err(|e| StorageError::operation_failed(format!("Failed to scan sessions: {e}")))?;
717
718 let mut user_sessions = Vec::new();
719
720 for key in keys {
722 if let Ok(session_json) = redis::cmd("GET")
723 .arg(&key)
724 .query_async::<Option<String>>(&mut conn)
725 .await
726 && let Some(session_json) = session_json
727 && let Ok(session) = serde_json::from_str::<SessionData>(&session_json)
728 && session.user_id == user_id
729 && !session.is_expired()
730 {
731 user_sessions.push(session);
732 }
733 }
734
735 Ok(user_sessions)
736 }
737
738 async fn store_kv(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<()> {
739 let mut conn = self.get_connection().await?;
740 let storage_key = self.key(&format!("kv:{key}"));
741
742 if let Some(ttl) = ttl {
743 let _: () = redis::cmd("SETEX")
744 .arg(&storage_key)
745 .arg(ttl.as_secs())
746 .arg(value)
747 .query_async(&mut conn)
748 .await
749 .map_err(|e| {
750 StorageError::operation_failed(format!("Failed to store KV with TTL: {e}"))
751 })?;
752 } else {
753 let _: () = redis::cmd("SET")
754 .arg(&storage_key)
755 .arg(value)
756 .query_async(&mut conn)
757 .await
758 .map_err(|e| StorageError::operation_failed(format!("Failed to store KV: {e}")))?;
759 }
760
761 Ok(())
762 }
763
764 async fn get_kv(&self, key: &str) -> Result<Option<Vec<u8>>> {
765 let mut conn = self.get_connection().await?;
766 let storage_key = self.key(&format!("kv:{key}"));
767
768 let value: Option<Vec<u8>> = redis::cmd("GET")
769 .arg(&storage_key)
770 .query_async(&mut conn)
771 .await
772 .map_err(|e| StorageError::operation_failed(format!("Failed to get KV: {e}")))?;
773
774 Ok(value)
775 }
776
777 async fn delete_kv(&self, key: &str) -> Result<()> {
778 let mut conn = self.get_connection().await?;
779 let storage_key = self.key(&format!("kv:{key}"));
780
781 let _: () = redis::cmd("DEL")
782 .arg(&storage_key)
783 .query_async(&mut conn)
784 .await
785 .map_err(|e| StorageError::operation_failed(format!("Failed to delete KV: {e}")))?;
786
787 Ok(())
788 }
789
790 async fn list_kv_keys(&self, prefix: &str) -> Result<Vec<String>> {
791 let mut conn = self.get_connection().await?;
792 let pattern = self.key(&format!("kv:{prefix}*"));
793 let keys: Vec<String> = redis::cmd("KEYS")
794 .arg(&pattern)
795 .query_async(&mut conn)
796 .await
797 .map_err(|e| StorageError::operation_failed(format!("Failed to list KV keys: {e}")))?;
798
799 Ok(keys
800 .into_iter()
801 .filter_map(|key| key.strip_prefix(&self.key_prefix).map(str::to_string))
802 .filter_map(|key| key.strip_prefix("kv:").map(str::to_string))
803 .collect())
804 }
805
806 async fn cleanup_expired(&self) -> Result<()> {
807 Ok(())
809 }
810
811 async fn count_active_sessions(&self) -> Result<u64> {
812 let mut conn = self.get_connection().await?;
813 let pattern = self.key("session:*");
814
815 let keys: Vec<String> = redis::cmd("KEYS")
817 .arg(&pattern)
818 .query_async(&mut conn)
819 .await
820 .map_err(|e| StorageError::operation_failed(format!("Failed to scan sessions: {e}")))?;
821
822 let mut active_count = 0u64;
824 for key in keys {
825 let ttl: i64 = redis::cmd("TTL")
826 .arg(&key)
827 .query_async(&mut conn)
828 .await
829 .map_err(|e| StorageError::operation_failed(format!("Failed to check TTL: {e}")))?;
830
831 if ttl > 0 || ttl == -1 {
835 active_count += 1;
836 }
837 }
838
839 Ok(active_count)
840 }
841}
842
843impl SessionData {
844 pub fn new(
857 session_id: impl Into<String>,
858 user_id: impl Into<String>,
859 expires_in: Duration,
860 ) -> Self {
861 let now = chrono::Utc::now();
862
863 Self {
864 session_id: session_id.into(),
865 user_id: user_id.into(),
866 created_at: now,
867 expires_at: now
868 + chrono::Duration::from_std(expires_in).unwrap_or(chrono::Duration::hours(1)),
869 last_activity: now,
870 ip_address: None,
871 user_agent: None,
872 data: HashMap::new(),
873 }
874 }
875
876 pub fn is_expired(&self) -> bool {
878 chrono::Utc::now() > self.expires_at
879 }
880
881 pub fn time_until_expiry(&self) -> std::time::Duration {
886 let remaining = self.expires_at - chrono::Utc::now();
887 remaining
888 .to_std()
889 .unwrap_or(std::time::Duration::ZERO)
890 }
891
892 pub fn is_active(&self) -> bool {
894 !self.is_expired()
895 }
896
897 pub fn update_activity(&mut self) {
899 self.last_activity = chrono::Utc::now();
900 }
901
902 pub fn with_metadata(mut self, ip_address: Option<String>, user_agent: Option<String>) -> Self {
904 self.ip_address = ip_address;
905 self.user_agent = user_agent;
906 self
907 }
908
909 pub fn ip_address(mut self, ip: impl Into<String>) -> Self {
911 self.ip_address = Some(ip.into());
912 self
913 }
914
915 pub fn user_agent(mut self, ua: impl Into<String>) -> Self {
917 self.user_agent = Some(ua.into());
918 self
919 }
920
921 pub fn with_data(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
923 self.data.insert(key.into(), value);
924 self
925 }
926
927 pub fn set_data(&mut self, key: impl Into<String>, value: serde_json::Value) {
929 self.data.insert(key.into(), value);
930 }
931
932 pub fn get_data(&self, key: &str) -> Option<&serde_json::Value> {
934 self.data.get(key)
935 }
936}
937
938#[async_trait]
940impl crate::audit::AuditStorage for MemoryStorage {
941 async fn store_event(&self, event: &crate::audit::AuditEvent) -> Result<()> {
942 let json_data = serde_json::to_vec(event).map_err(|e| {
944 crate::errors::AuthError::internal(format!("Failed to serialize audit event: {}", e))
945 })?;
946
947 let key = format!("audit_event_{}", event.id);
948 self.store_kv(&key, &json_data, None).await
949 }
950
951 async fn query_events(
952 &self,
953 query: &crate::audit::AuditQuery,
954 ) -> Result<Vec<crate::audit::AuditEvent>> {
955 let all_keys = self.list_kv_keys("audit_event_").await?;
957 let mut events = Vec::new();
958
959 for key in all_keys {
960 if let Some(data) = self.get_kv(&key).await?
961 && let Ok(event) = serde_json::from_slice::<crate::audit::AuditEvent>(&data)
962 {
963 let mut include = true;
965
966 if let Some(ref time_range) = query.time_range
967 && (event.timestamp < time_range.start || event.timestamp > time_range.end)
968 {
969 include = false;
970 }
971
972 if let Some(ref event_types) = query.event_types
973 && !event_types.contains(&event.event_type)
974 {
975 include = false;
976 }
977
978 if let Some(ref user_id) = query.user_id
979 && event.user_id.as_ref() != Some(user_id)
980 {
981 include = false;
982 }
983
984 if include {
985 events.push(event);
986 }
987 }
988 }
989
990 events.sort_by(|a, b| match query.sort_order {
992 crate::audit::SortOrder::TimestampAsc => a.timestamp.cmp(&b.timestamp),
993 crate::audit::SortOrder::TimestampDesc => b.timestamp.cmp(&a.timestamp),
994 crate::audit::SortOrder::RiskLevelDesc => b.risk_level.cmp(&a.risk_level),
995 });
996
997 if let Some(limit) = query.limit {
998 events.truncate(limit as usize);
999 }
1000 Ok(events)
1001 }
1002
1003 async fn get_event(&self, event_id: &str) -> Result<Option<crate::audit::AuditEvent>> {
1004 let key = format!("audit_event_{}", event_id);
1005 if let Some(data) = self.get_kv(&key).await? {
1006 let event = serde_json::from_slice(&data).map_err(|e| {
1007 crate::errors::AuthError::internal(format!(
1008 "Failed to deserialize audit event: {}",
1009 e
1010 ))
1011 })?;
1012 Ok(Some(event))
1013 } else {
1014 Ok(None)
1015 }
1016 }
1017
1018 async fn count_events(&self, query: &crate::audit::AuditQuery) -> Result<u64> {
1019 let events = self.query_events(query).await?;
1020 Ok(events.len() as u64)
1021 }
1022
1023 async fn delete_old_events(&self, before: std::time::SystemTime) -> Result<u64> {
1024 let all_keys = self.list_kv_keys("audit_event_").await?;
1025 let mut deleted_count = 0;
1026
1027 for key in all_keys {
1028 if let Some(data) = self.get_kv(&key).await?
1029 && let Ok(event) = serde_json::from_slice::<crate::audit::AuditEvent>(&data)
1030 && event.timestamp < before
1031 {
1032 self.delete_kv(&key).await?;
1033 deleted_count += 1;
1034 }
1035 }
1036
1037 Ok(deleted_count)
1038 }
1039
1040 async fn get_statistics(
1041 &self,
1042 query: &crate::audit::StatsQuery,
1043 ) -> Result<crate::audit::AuditStatistics> {
1044 use std::collections::HashMap;
1045
1046 let audit_query = crate::audit::AuditQuery::builder()
1048 .time_range(query.time_range.start, query.time_range.end)
1049 .sort_order(crate::audit::SortOrder::TimestampAsc)
1050 .build();
1051 let events = self.query_events(&audit_query).await?;
1052 let total_events = events.len() as u64;
1053
1054 let mut event_type_counts: HashMap<String, u64> = HashMap::new();
1055 let mut risk_level_counts: HashMap<String, u64> = HashMap::new();
1056 let mut outcome_counts: HashMap<String, u64> = HashMap::new();
1057
1058 for event in &events {
1059 *event_type_counts
1060 .entry(format!("{:?}", event.event_type))
1061 .or_insert(0) += 1;
1062 *risk_level_counts
1063 .entry(format!("{:?}", event.risk_level))
1064 .or_insert(0) += 1;
1065 *outcome_counts
1066 .entry(format!("{:?}", event.outcome))
1067 .or_insert(0) += 1;
1068 }
1069
1070 Ok(crate::audit::AuditStatistics {
1071 total_events,
1072 event_type_counts,
1073 risk_level_counts,
1074 outcome_counts,
1075 time_series: Vec::new(),
1076 top_users: Vec::new(),
1077 top_ips: Vec::new(),
1078 })
1079 }
1080}
1081
1082impl MemoryStorage {
1083 async fn list_kv_keys(&self, _prefix: &str) -> Result<Vec<String>> {
1085 Ok(self.inner.list_kv_keys_by_prefix(_prefix))
1086 }
1087}
1088
1089#[async_trait]
1091impl crate::audit::AuditStorage for Arc<MemoryStorage> {
1092 async fn store_event(&self, event: &crate::audit::AuditEvent) -> Result<()> {
1093 self.as_ref().store_event(event).await
1094 }
1095
1096 async fn query_events(
1097 &self,
1098 query: &crate::audit::AuditQuery,
1099 ) -> Result<Vec<crate::audit::AuditEvent>> {
1100 self.as_ref().query_events(query).await
1101 }
1102
1103 async fn get_event(&self, event_id: &str) -> Result<Option<crate::audit::AuditEvent>> {
1104 self.as_ref().get_event(event_id).await
1105 }
1106
1107 async fn count_events(&self, query: &crate::audit::AuditQuery) -> Result<u64> {
1108 self.as_ref().count_events(query).await
1109 }
1110
1111 async fn delete_old_events(&self, before: std::time::SystemTime) -> Result<u64> {
1112 self.as_ref().delete_old_events(before).await
1113 }
1114
1115 async fn get_statistics(
1116 &self,
1117 query: &crate::audit::StatsQuery,
1118 ) -> Result<crate::audit::AuditStatistics> {
1119 self.as_ref().get_statistics(query).await
1120 }
1121}
1122
1123#[cfg(test)]
1124mod tests {
1125 use super::*;
1126 use crate::tokens::AuthToken;
1127
1128 #[tokio::test]
1131 async fn test_memory_storage() {
1132 let storage = MemoryStorage::new();
1133
1134 let token = AuthToken::new("user123", "token123", Duration::from_secs(3600), "test");
1135
1136 storage.store_token(&token).await.unwrap();
1137
1138 let retrieved = storage.get_token(&token.token_id).await.unwrap().unwrap();
1139 assert_eq!(retrieved.user_id, "user123");
1140
1141 let retrieved = storage
1142 .get_token_by_access_token(&token.access_token)
1143 .await
1144 .unwrap()
1145 .unwrap();
1146 assert_eq!(retrieved.token_id, token.token_id);
1147
1148 let user_tokens = storage.list_user_tokens("user123").await.unwrap();
1149 assert_eq!(user_tokens.len(), 1);
1150
1151 storage.delete_token(&token.token_id).await.unwrap();
1152 let retrieved = storage.get_token(&token.token_id).await.unwrap();
1153 assert!(retrieved.is_none());
1154 }
1155
1156 #[tokio::test]
1157 async fn test_token_update() {
1158 let storage = MemoryStorage::new();
1159 let mut token =
1160 AuthToken::new("user1", "access_original", Duration::from_secs(3600), "pw");
1161 storage.store_token(&token).await.unwrap();
1162
1163 token.auth_method = "mfa".to_string();
1164 storage.update_token(&token).await.unwrap();
1165
1166 let retrieved = storage.get_token(&token.token_id).await.unwrap().unwrap();
1167 assert_eq!(retrieved.auth_method, "mfa");
1168 }
1169
1170 #[tokio::test]
1171 async fn test_token_get_nonexistent() {
1172 let storage = MemoryStorage::new();
1173 let result = storage.get_token("nonexistent").await.unwrap();
1174 assert!(result.is_none());
1175
1176 let result = storage
1177 .get_token_by_access_token("nonexistent")
1178 .await
1179 .unwrap();
1180 assert!(result.is_none());
1181 }
1182
1183 #[tokio::test]
1184 async fn test_list_user_tokens_multiple() {
1185 let storage = MemoryStorage::new();
1186 let t1 = AuthToken::new("user_a", "at1", Duration::from_secs(3600), "pw");
1187 let t2 = AuthToken::new("user_a", "at2", Duration::from_secs(3600), "pw");
1188 let t3 = AuthToken::new("user_b", "at3", Duration::from_secs(3600), "pw");
1189
1190 storage.store_token(&t1).await.unwrap();
1191 storage.store_token(&t2).await.unwrap();
1192 storage.store_token(&t3).await.unwrap();
1193
1194 let user_a_tokens = storage.list_user_tokens("user_a").await.unwrap();
1195 assert_eq!(user_a_tokens.len(), 2);
1196
1197 let user_b_tokens = storage.list_user_tokens("user_b").await.unwrap();
1198 assert_eq!(user_b_tokens.len(), 1);
1199
1200 let empty = storage.list_user_tokens("nobody").await.unwrap();
1201 assert!(empty.is_empty());
1202 }
1203
1204 #[tokio::test]
1205 async fn test_store_tokens_bulk() {
1206 let storage = MemoryStorage::new();
1207 let tokens = vec![
1208 AuthToken::new("u1", "a1", Duration::from_secs(3600), "pw"),
1209 AuthToken::new("u2", "a2", Duration::from_secs(3600), "pw"),
1210 AuthToken::new("u3", "a3", Duration::from_secs(3600), "pw"),
1211 ];
1212 let ids: Vec<String> = tokens.iter().map(|t| t.token_id.clone()).collect();
1213
1214 storage.store_tokens_bulk(&tokens).await.unwrap();
1215
1216 for id in &ids {
1217 assert!(storage.get_token(id).await.unwrap().is_some());
1218 }
1219
1220 storage.delete_tokens_bulk(&ids).await.unwrap();
1221
1222 for id in &ids {
1223 assert!(storage.get_token(id).await.unwrap().is_none());
1224 }
1225 }
1226
1227 #[tokio::test]
1230 async fn test_session_storage() {
1231 let storage = MemoryStorage::new();
1232
1233 let session = SessionData::new("session123", "user123", Duration::from_secs(3600))
1234 .with_metadata(
1235 Some("192.168.1.1".to_string()),
1236 Some("Test Agent".to_string()),
1237 );
1238
1239 storage
1240 .store_session(&session.session_id, &session)
1241 .await
1242 .unwrap();
1243
1244 let retrieved = storage
1245 .get_session(&session.session_id)
1246 .await
1247 .unwrap()
1248 .unwrap();
1249 assert_eq!(retrieved.user_id, "user123");
1250 assert_eq!(retrieved.ip_address, Some("192.168.1.1".to_string()));
1251
1252 storage.delete_session(&session.session_id).await.unwrap();
1253 let retrieved = storage.get_session(&session.session_id).await.unwrap();
1254 assert!(retrieved.is_none());
1255 }
1256
1257 #[tokio::test]
1258 async fn test_session_get_nonexistent() {
1259 let storage = MemoryStorage::new();
1260 let result = storage.get_session("no_such_session").await.unwrap();
1261 assert!(result.is_none());
1262 }
1263
1264 #[tokio::test]
1265 async fn test_list_user_sessions() {
1266 let storage = MemoryStorage::new();
1267
1268 let s1 = SessionData::new("s1", "alice", Duration::from_secs(3600));
1269 let s2 = SessionData::new("s2", "alice", Duration::from_secs(3600));
1270 let s3 = SessionData::new("s3", "bob", Duration::from_secs(3600));
1271
1272 storage.store_session("s1", &s1).await.unwrap();
1273 storage.store_session("s2", &s2).await.unwrap();
1274 storage.store_session("s3", &s3).await.unwrap();
1275
1276 let alice_sessions = storage.list_user_sessions("alice").await.unwrap();
1277 assert_eq!(alice_sessions.len(), 2);
1278
1279 let bob_sessions = storage.list_user_sessions("bob").await.unwrap();
1280 assert_eq!(bob_sessions.len(), 1);
1281 }
1282
1283 #[tokio::test]
1284 async fn test_count_active_sessions() {
1285 let storage = MemoryStorage::new();
1286
1287 let s1 = SessionData::new("cs1", "u1", Duration::from_secs(3600));
1288 let s2 = SessionData::new("cs2", "u2", Duration::from_secs(3600));
1289 storage.store_session("cs1", &s1).await.unwrap();
1290 storage.store_session("cs2", &s2).await.unwrap();
1291
1292 let count = storage.count_active_sessions().await.unwrap();
1293 assert!(count >= 2);
1294 }
1295
1296 #[tokio::test]
1297 async fn test_store_sessions_bulk() {
1298 let storage = MemoryStorage::new();
1299 let sessions = vec![
1300 ("bs1".to_string(), SessionData::new("bs1", "u1", Duration::from_secs(3600))),
1301 ("bs2".to_string(), SessionData::new("bs2", "u2", Duration::from_secs(3600))),
1302 ];
1303
1304 storage.store_sessions_bulk(&sessions).await.unwrap();
1305 assert!(storage.get_session("bs1").await.unwrap().is_some());
1306 assert!(storage.get_session("bs2").await.unwrap().is_some());
1307
1308 let ids = vec!["bs1".to_string(), "bs2".to_string()];
1309 storage.delete_sessions_bulk(&ids).await.unwrap();
1310 assert!(storage.get_session("bs1").await.unwrap().is_none());
1311 assert!(storage.get_session("bs2").await.unwrap().is_none());
1312 }
1313
1314 #[tokio::test]
1315 async fn test_session_custom_data() {
1316 let storage = MemoryStorage::new();
1317 let mut session = SessionData::new("sd1", "u1", Duration::from_secs(3600));
1318 session.set_data("theme", serde_json::json!("dark"));
1319 session.set_data("lang", serde_json::json!("en"));
1320
1321 storage.store_session("sd1", &session).await.unwrap();
1322
1323 let retrieved = storage.get_session("sd1").await.unwrap().unwrap();
1324 assert_eq!(
1325 retrieved.get_data("theme"),
1326 Some(&serde_json::json!("dark"))
1327 );
1328 assert_eq!(retrieved.get_data("lang"), Some(&serde_json::json!("en")));
1329 assert_eq!(retrieved.get_data("missing"), None);
1330 }
1331
1332 #[tokio::test]
1335 async fn test_kv_storage() {
1336 let storage = MemoryStorage::new();
1337
1338 let key = "test_key";
1339 let value = b"test_value";
1340
1341 storage
1342 .store_kv(key, value, Some(Duration::from_secs(3600)))
1343 .await
1344 .unwrap();
1345
1346 let retrieved = storage.get_kv(key).await.unwrap().unwrap();
1347 assert_eq!(retrieved, value);
1348
1349 storage.delete_kv(key).await.unwrap();
1350 let retrieved = storage.get_kv(key).await.unwrap();
1351 assert!(retrieved.is_none());
1352 }
1353
1354 #[tokio::test]
1355 async fn test_kv_no_ttl() {
1356 let storage = MemoryStorage::new();
1357 storage
1358 .store_kv("persistent", b"forever", None)
1359 .await
1360 .unwrap();
1361
1362 let retrieved = storage.get_kv("persistent").await.unwrap().unwrap();
1363 assert_eq!(retrieved, b"forever");
1364 }
1365
1366 #[tokio::test]
1367 async fn test_kv_overwrite() {
1368 let storage = MemoryStorage::new();
1369 storage
1370 .store_kv("k1", b"v1", Some(Duration::from_secs(3600)))
1371 .await
1372 .unwrap();
1373 storage
1374 .store_kv("k1", b"v2", Some(Duration::from_secs(3600)))
1375 .await
1376 .unwrap();
1377
1378 let retrieved = storage.get_kv("k1").await.unwrap().unwrap();
1379 assert_eq!(retrieved, b"v2");
1380 }
1381
1382 #[tokio::test]
1383 async fn test_kv_get_nonexistent() {
1384 let storage = MemoryStorage::new();
1385 let result = storage.get_kv("no_such_key").await.unwrap();
1386 assert!(result.is_none());
1387 }
1388
1389 #[tokio::test]
1390 async fn test_kv_list_keys_prefix() {
1391 let storage = MemoryStorage::new();
1392 storage
1393 .store_kv("user:1:name", b"alice", None)
1394 .await
1395 .unwrap();
1396 storage
1397 .store_kv("user:1:email", b"alice@x.com", None)
1398 .await
1399 .unwrap();
1400 storage
1401 .store_kv("user:2:name", b"bob", None)
1402 .await
1403 .unwrap();
1404 storage
1405 .store_kv("session:abc", b"data", None)
1406 .await
1407 .unwrap();
1408
1409 let user1_keys = storage.list_kv_keys("user:1:").await.unwrap();
1410 assert_eq!(user1_keys.len(), 2);
1411
1412 let all_user_keys = storage.list_kv_keys("user:").await.unwrap();
1413 assert_eq!(all_user_keys.len(), 3);
1414
1415 let session_keys = storage.list_kv_keys("session:").await.unwrap();
1416 assert_eq!(session_keys.len(), 1);
1417
1418 let empty = storage.list_kv_keys("nonexistent:").await.unwrap();
1419 assert!(empty.is_empty());
1420 }
1421
1422 #[tokio::test]
1423 async fn test_kv_binary_data() {
1424 let storage = MemoryStorage::new();
1425 let binary = vec![0u8, 1, 2, 255, 254, 253, 0, 128];
1426 storage
1427 .store_kv("binary_key", &binary, None)
1428 .await
1429 .unwrap();
1430
1431 let retrieved = storage.get_kv("binary_key").await.unwrap().unwrap();
1432 assert_eq!(retrieved, binary);
1433 }
1434
1435 #[tokio::test]
1438 async fn test_cleanup_expired() {
1439 let storage = MemoryStorage::new();
1440
1441 storage
1447 .store_kv("expire_me", b"val", Some(Duration::from_secs(0)))
1448 .await
1449 .unwrap();
1450 storage
1451 .store_kv("keep_me", b"val", Some(Duration::from_secs(3600)))
1452 .await
1453 .unwrap();
1454
1455 tokio::time::sleep(Duration::from_millis(10)).await;
1457
1458 storage.cleanup_expired().await.unwrap();
1459
1460 assert!(storage.get_kv("expire_me").await.unwrap().is_none());
1462 assert!(storage.get_kv("keep_me").await.unwrap().is_some());
1463 }
1464
#[test]
fn test_session_is_expired() {
    // `is_expired` is false for a fresh one-hour session and becomes true once
    // `expires_at` is moved one second into the past. Nothing here is awaited,
    // so a plain synchronous `#[test]` is used instead of `#[tokio::test]`,
    // which would build a Tokio runtime for no reason.
    let mut session = SessionData::new("exp1", "u1", Duration::from_secs(3600));
    assert!(!session.is_expired());

    session.expires_at = chrono::Utc::now() - chrono::Duration::seconds(1);
    assert!(session.is_expired());
}
1475
#[test]
fn test_session_update_activity() {
    // After `update_activity`, `last_activity` must not be earlier than its
    // previous value. Nothing here is awaited, so a plain synchronous
    // `#[test]` is used instead of `#[tokio::test]`, which would build a
    // Tokio runtime for no reason.
    let mut session = SessionData::new("act1", "u1", Duration::from_secs(3600));
    let first_activity = session.last_activity;
    session.update_activity();
    assert!(session.last_activity >= first_activity);
}
1484
1485 #[tokio::test]
1488 async fn test_token_store_twice_overwrites_primary() {
1489 let storage = MemoryStorage::new();
1490 let token = AuthToken::new("u1", "at1", Duration::from_secs(3600), "pw");
1491 let token_id = token.token_id.clone();
1492
1493 storage.store_token(&token).await.unwrap();
1494 storage.store_token(&token).await.unwrap();
1495
1496 let got = storage.get_token(&token_id).await.unwrap();
1498 assert!(got.is_some());
1499 }
1500
1501 #[tokio::test]
1504 async fn test_delete_nonexistent_token() {
1505 let storage = MemoryStorage::new();
1506 let result = storage.delete_token("nope").await;
1508 assert!(result.is_ok());
1509 }
1510
1511 #[tokio::test]
1512 async fn test_delete_nonexistent_session() {
1513 let storage = MemoryStorage::new();
1514 let result = storage.delete_session("nope").await;
1515 assert!(result.is_ok());
1516 }
1517
1518 #[tokio::test]
1519 async fn test_delete_nonexistent_kv() {
1520 let storage = MemoryStorage::new();
1521 let result = storage.delete_kv("nope").await;
1522 assert!(result.is_ok());
1523 }
1524}