mockforge_core/security/
mfa_tracking.rs1use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use uuid::Uuid;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14#[serde(rename_all = "lowercase")]
15pub enum MfaMethod {
16 Totp,
18 Sms,
20 Email,
22 HardwareKey,
24 Push,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct MfaStatus {
31 pub user_id: Uuid,
33 pub enabled: bool,
35 pub methods: Vec<MfaMethod>,
37 pub enabled_at: Option<DateTime<Utc>>,
39 pub last_verification: Option<DateTime<Utc>>,
41 pub backup_codes_remaining: u32,
43}
44
45#[async_trait::async_trait]
47pub trait MfaStorage: Send + Sync {
48 async fn get_mfa_status(&self, user_id: Uuid) -> Result<Option<MfaStatus>, crate::Error>;
50
51 async fn set_mfa_status(&self, status: MfaStatus) -> Result<(), crate::Error>;
53
54 async fn get_users_with_mfa(&self) -> Result<Vec<Uuid>, crate::Error>;
56
57 async fn get_privileged_users_without_mfa(&self, privileged_user_ids: &[Uuid]) -> Result<Vec<Uuid>, crate::Error>;
59}
60
61pub struct InMemoryMfaStorage {
63 mfa_statuses: Arc<RwLock<HashMap<Uuid, MfaStatus>>>,
64}
65
66impl InMemoryMfaStorage {
67 pub fn new() -> Self {
69 Self {
70 mfa_statuses: Arc::new(RwLock::new(HashMap::new())),
71 }
72 }
73}
74
75impl Default for InMemoryMfaStorage {
76 fn default() -> Self {
77 Self::new()
78 }
79}
80
81#[async_trait::async_trait]
82impl MfaStorage for InMemoryMfaStorage {
83 async fn get_mfa_status(&self, user_id: Uuid) -> Result<Option<MfaStatus>, crate::Error> {
84 let statuses = self.mfa_statuses.read().await;
85 Ok(statuses.get(&user_id).cloned())
86 }
87
88 async fn set_mfa_status(&self, status: MfaStatus) -> Result<(), crate::Error> {
89 let mut statuses = self.mfa_statuses.write().await;
90 statuses.insert(status.user_id, status);
91 Ok(())
92 }
93
94 async fn get_users_with_mfa(&self) -> Result<Vec<Uuid>, crate::Error> {
95 let statuses = self.mfa_statuses.read().await;
96 Ok(statuses
97 .iter()
98 .filter(|(_, status)| status.enabled)
99 .map(|(user_id, _)| *user_id)
100 .collect())
101 }
102
103 async fn get_privileged_users_without_mfa(&self, privileged_user_ids: &[Uuid]) -> Result<Vec<Uuid>, crate::Error> {
104 let statuses = self.mfa_statuses.read().await;
105 Ok(privileged_user_ids
106 .iter()
107 .filter(|user_id| {
108 statuses
109 .get(user_id)
110 .map(|s| !s.enabled)
111 .unwrap_or(true) })
113 .copied()
114 .collect())
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal record for `user_id` with MFA `enabled` or not.
    fn status_for(user_id: Uuid, enabled: bool) -> MfaStatus {
        MfaStatus {
            user_id,
            enabled,
            methods: if enabled { vec![MfaMethod::Totp] } else { Vec::new() },
            enabled_at: if enabled { Some(Utc::now()) } else { None },
            last_verification: None,
            backup_codes_remaining: 0,
        }
    }

    #[tokio::test]
    async fn test_mfa_storage() {
        let storage = InMemoryMfaStorage::new();
        let user_id = Uuid::new_v4();
        let status = MfaStatus {
            user_id,
            enabled: true,
            methods: vec![MfaMethod::Totp],
            enabled_at: Some(Utc::now()),
            last_verification: Some(Utc::now()),
            backup_codes_remaining: 5,
        };

        storage.set_mfa_status(status).await.unwrap();
        let retrieved = storage.get_mfa_status(user_id).await.unwrap();
        assert!(retrieved.is_some());
        assert!(retrieved.unwrap().enabled);
    }

    #[tokio::test]
    async fn test_get_mfa_status_unknown_user_is_none() {
        let storage = InMemoryMfaStorage::new();
        let retrieved = storage.get_mfa_status(Uuid::new_v4()).await.unwrap();
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_set_mfa_status_replaces_existing_record() {
        let storage = InMemoryMfaStorage::new();
        let user_id = Uuid::new_v4();

        storage.set_mfa_status(status_for(user_id, true)).await.unwrap();
        storage.set_mfa_status(status_for(user_id, false)).await.unwrap();

        let stored = storage
            .get_mfa_status(user_id)
            .await
            .unwrap()
            .expect("record was just stored");
        assert!(!stored.enabled);
    }

    #[tokio::test]
    async fn test_get_users_with_mfa_lists_only_enabled() {
        let storage = InMemoryMfaStorage::new();
        let enabled = Uuid::new_v4();
        let disabled = Uuid::new_v4();

        storage.set_mfa_status(status_for(enabled, true)).await.unwrap();
        storage.set_mfa_status(status_for(disabled, false)).await.unwrap();

        assert_eq!(storage.get_users_with_mfa().await.unwrap(), vec![enabled]);
    }

    #[tokio::test]
    async fn test_get_privileged_users_without_mfa() {
        let storage = InMemoryMfaStorage::new();
        let enabled = Uuid::new_v4();
        let disabled = Uuid::new_v4();
        // No record is ever stored for this user; it must count as "no MFA".
        let unknown = Uuid::new_v4();

        storage.set_mfa_status(status_for(enabled, true)).await.unwrap();
        storage.set_mfa_status(status_for(disabled, false)).await.unwrap();

        let missing = storage
            .get_privileged_users_without_mfa(&[enabled, disabled, unknown])
            .await
            .unwrap();
        assert_eq!(missing, vec![disabled, unknown]);
    }
}