use async_trait::async_trait;
use chrono::{DateTime, Utc};
use std::collections::HashMap;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::errors::AppError;
use crate::services::TotpService;
/// A user's TOTP shared secret together with its enrollment state.
///
/// `Debug` is implemented by hand rather than derived so that the shared
/// secret is never emitted by `{:?}` formatting (logs, error context, panics).
#[derive(Clone)]
pub struct TotpSecret {
    /// Identifier of this secret record.
    pub id: Uuid,
    /// The user this secret belongs to.
    pub user_id: Uuid,
    /// The TOTP shared secret as stored.
    /// NOTE(review): test fixtures use base32 text (e.g. `JBSWY3DPEHPK3PXP`) — confirm encoding.
    pub secret: String,
    /// Whether MFA has been enabled (confirmed) for this secret.
    pub enabled: bool,
    /// When this secret record was created.
    pub created_at: DateTime<Utc>,
    /// When MFA was enabled; `None` while not yet enabled.
    pub enabled_at: Option<DateTime<Utc>>,
    /// The most recent TOTP time step accepted for this user, if any.
    pub last_used_time_step: Option<i64>,
}

impl std::fmt::Debug for TotpSecret {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TotpSecret")
            .field("id", &self.id)
            .field("user_id", &self.user_id)
            // Never print the shared secret: anyone holding it can mint valid codes.
            .field("secret", &"<redacted>")
            .field("enabled", &self.enabled)
            .field("created_at", &self.created_at)
            .field("enabled_at", &self.enabled_at)
            .field("last_used_time_step", &self.last_used_time_step)
            .finish()
    }
}
/// A single MFA recovery code, stored only as a hash.
///
/// `Debug` is implemented by hand so the stored hash is not emitted by `{:?}`
/// formatting; short recovery codes make leaked hashes useful for offline guessing.
#[derive(Clone)]
pub struct RecoveryCode {
    /// Identifier of this recovery-code record.
    pub id: Uuid,
    /// The user this recovery code belongs to.
    pub user_id: Uuid,
    /// Hash of the recovery code; the plaintext code is not stored.
    pub code_hash: String,
    /// Whether this code has already been consumed.
    pub used: bool,
    /// When this code was stored.
    pub created_at: DateTime<Utc>,
    /// When this code was consumed; `None` while unused.
    pub used_at: Option<DateTime<Utc>>,
}

impl std::fmt::Debug for RecoveryCode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RecoveryCode")
            .field("id", &self.id)
            .field("user_id", &self.user_id)
            .field("code_hash", &"<redacted>")
            .field("used", &self.used)
            .field("created_at", &self.created_at)
            .field("used_at", &self.used_at)
            .finish()
    }
}
/// Storage interface for per-user TOTP secrets and MFA recovery codes.
///
/// Implementors must be `Send + Sync` so one instance can be shared across tasks.
/// NOTE(review): the behavioral notes below describe `InMemoryTotpRepository`;
/// confirm that other backends honor the same contract.
#[async_trait]
pub trait TotpRepository: Send + Sync {
    /// Store `secret` as the user's TOTP secret and return the stored record.
    /// The in-memory implementation replaces any existing record with a fresh,
    /// disabled one.
    async fn upsert_secret(&self, user_id: Uuid, secret: &str) -> Result<TotpSecret, AppError>;
    /// Return the user's TOTP secret record, or `None` if there is none.
    async fn find_by_user(&self, user_id: Uuid) -> Result<Option<TotpSecret>, AppError>;
    /// Mark the user's TOTP secret as enabled.
    /// The in-memory implementation returns `AppError::NotFound` when no secret exists.
    async fn enable_mfa(&self, user_id: Uuid) -> Result<(), AppError>;
    /// Turn MFA off for the user.
    /// The in-memory implementation removes both the secret and all recovery codes.
    async fn disable_mfa(&self, user_id: Uuid) -> Result<(), AppError>;
    /// Whether the user has an enabled TOTP secret (`false` if none exists).
    async fn has_mfa_enabled(&self, user_id: Uuid) -> Result<bool, AppError>;
    /// Atomically record `time_step` as the last accepted TOTP time step if it is
    /// strictly newer than the stored one.
    ///
    /// Returns `true` if it was recorded, or `false` if it was equal to or older
    /// than the stored step (presumably so callers can reject replayed codes).
    async fn record_used_time_step_if_newer(
        &self,
        user_id: Uuid,
        time_step: i64,
    ) -> Result<bool, AppError>;
    /// Store the given recovery-code hashes for the user.
    /// The in-memory implementation replaces any previously stored set.
    async fn store_recovery_codes(
        &self,
        user_id: Uuid,
        code_hashes: Vec<String>,
    ) -> Result<(), AppError>;
    /// Return the user's recovery codes that have not been used yet.
    async fn get_recovery_codes(&self, user_id: Uuid) -> Result<Vec<RecoveryCode>, AppError>;
    /// Consume an unused recovery code matching the plaintext `code`.
    /// Returns `true` if a code was consumed.
    async fn use_recovery_code(&self, user_id: Uuid, code: &str) -> Result<bool, AppError>;
    /// Remove all of the user's recovery codes.
    async fn delete_recovery_codes(&self, user_id: Uuid) -> Result<(), AppError>;
}
/// A [`TotpRepository`] backed by in-process maps guarded by async `RwLock`s.
///
/// Data is lost when the value is dropped.
pub struct InMemoryTotpRepository {
    // One TOTP secret per user, keyed by `user_id`.
    secrets: RwLock<HashMap<Uuid, TotpSecret>>,
    // Each user's recovery codes (used and unused), keyed by `user_id`.
    recovery_codes: RwLock<HashMap<Uuid, Vec<RecoveryCode>>>,
}
impl InMemoryTotpRepository {
    /// Build an empty repository that holds no secrets and no recovery codes.
    pub fn new() -> Self {
        let secrets = RwLock::new(HashMap::new());
        let recovery_codes = RwLock::new(HashMap::new());
        Self {
            secrets,
            recovery_codes,
        }
    }
}
impl Default for InMemoryTotpRepository {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl TotpRepository for InMemoryTotpRepository {
    /// Store a fresh, disabled secret for `user_id`, replacing any existing record.
    ///
    /// Replacement also resets `enabled`, `enabled_at` and `last_used_time_step`.
    async fn upsert_secret(&self, user_id: Uuid, secret: &str) -> Result<TotpSecret, AppError> {
        // Build the record first so the write lock is held only for the insert.
        let totp_secret = TotpSecret {
            id: Uuid::new_v4(),
            user_id,
            secret: secret.to_string(),
            enabled: false,
            created_at: Utc::now(),
            enabled_at: None,
            last_used_time_step: None,
        };
        let mut secrets = self.secrets.write().await;
        secrets.insert(user_id, totp_secret.clone());
        Ok(totp_secret)
    }

    /// Return a copy of the user's secret record, if one exists.
    async fn find_by_user(&self, user_id: Uuid) -> Result<Option<TotpSecret>, AppError> {
        let secrets = self.secrets.read().await;
        Ok(secrets.get(&user_id).cloned())
    }

    /// Mark the user's secret enabled and stamp `enabled_at` with the current time.
    ///
    /// # Errors
    /// `AppError::NotFound` if the user has no stored secret.
    async fn enable_mfa(&self, user_id: Uuid) -> Result<(), AppError> {
        let mut secrets = self.secrets.write().await;
        let secret = secrets
            .get_mut(&user_id)
            .ok_or_else(|| AppError::NotFound("TOTP secret not found".into()))?;
        secret.enabled = true;
        secret.enabled_at = Some(Utc::now());
        Ok(())
    }

    /// Remove the user's secret and all of their recovery codes.
    ///
    /// Succeeds even when nothing was stored. The `secrets` lock is held while the
    /// codes are removed so both removals appear together to other callers.
    async fn disable_mfa(&self, user_id: Uuid) -> Result<(), AppError> {
        let mut secrets = self.secrets.write().await;
        secrets.remove(&user_id);
        let mut codes = self.recovery_codes.write().await;
        codes.remove(&user_id);
        Ok(())
    }

    /// `true` only if the user has a stored secret whose `enabled` flag is set.
    async fn has_mfa_enabled(&self, user_id: Uuid) -> Result<bool, AppError> {
        let secrets = self.secrets.read().await;
        Ok(secrets.get(&user_id).map(|s| s.enabled).unwrap_or(false))
    }

    /// Compare-and-set the last accepted time step under the write lock.
    ///
    /// Returns `true` (and stores `time_step`) when no step was recorded yet or
    /// `time_step` is strictly greater than the stored one. Otherwise returns
    /// `false` and leaves the stored step unchanged.
    ///
    /// # Errors
    /// `AppError::NotFound` if the user has no stored secret.
    async fn record_used_time_step_if_newer(
        &self,
        user_id: Uuid,
        time_step: i64,
    ) -> Result<bool, AppError> {
        // A write lock (not a read lock) keeps the check and the update atomic, so
        // two concurrent calls with the same step cannot both succeed.
        let mut secrets = self.secrets.write().await;
        let secret = secrets
            .get_mut(&user_id)
            .ok_or_else(|| AppError::NotFound("TOTP secret not found".into()))?;
        let is_newer = match secret.last_used_time_step {
            Some(last) => time_step > last,
            None => true,
        };
        if is_newer {
            secret.last_used_time_step = Some(time_step);
        }
        Ok(is_newer)
    }

    /// Replace the user's recovery codes with fresh unused entries, one per hash.
    ///
    /// All new entries share a single `created_at` timestamp.
    async fn store_recovery_codes(
        &self,
        user_id: Uuid,
        code_hashes: Vec<String>,
    ) -> Result<(), AppError> {
        let now = Utc::now();
        // Build the entries before locking so the write lock is held only for the insert.
        let recovery_codes: Vec<RecoveryCode> = code_hashes
            .into_iter()
            .map(|hash| RecoveryCode {
                id: Uuid::new_v4(),
                user_id,
                code_hash: hash,
                used: false,
                created_at: now,
                used_at: None,
            })
            .collect();
        let mut codes = self.recovery_codes.write().await;
        codes.insert(user_id, recovery_codes);
        Ok(())
    }

    /// Return copies of the user's unused recovery codes (empty if none are stored).
    async fn get_recovery_codes(&self, user_id: Uuid) -> Result<Vec<RecoveryCode>, AppError> {
        let codes = self.recovery_codes.read().await;
        Ok(codes
            .get(&user_id)
            .map(|c| c.iter().filter(|code| !code.used).cloned().collect())
            .unwrap_or_default())
    }

    /// Consume an unused recovery code whose hash matches the plaintext `code`.
    ///
    /// Returns `true` if a code was marked used, or `false` for an unknown user or
    /// when no unused code matches. Matching is delegated to
    /// `TotpService::verify_recovery_code`; the tests show it accepts a
    /// different-case input.
    async fn use_recovery_code(&self, user_id: Uuid, code: &str) -> Result<bool, AppError> {
        // The write lock is held across verification so a code cannot be consumed twice
        // by concurrent callers.
        // NOTE(review): if `verify_recovery_code` is a slow KDF, this blocks the executor
        // thread while the lock is held — confirm that is acceptable for this backend.
        let mut codes = self.recovery_codes.write().await;
        if let Some(user_codes) = codes.get_mut(&user_id) {
            // Every unused code is checked (no early exit); if several match, the last wins.
            let mut matched_idx: Option<usize> = None;
            for (idx, stored_code) in user_codes.iter().enumerate() {
                if !stored_code.used
                    && TotpService::verify_recovery_code(code, &stored_code.code_hash)
                {
                    matched_idx = Some(idx);
                }
            }
            if let Some(idx) = matched_idx {
                let matched = &mut user_codes[idx];
                matched.used = true;
                matched.used_at = Some(Utc::now());
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Remove every recovery code stored for the user; succeeds even if none exist.
    async fn delete_recovery_codes(&self, user_id: Uuid) -> Result<(), AppError> {
        let mut codes = self.recovery_codes.write().await;
        codes.remove(&user_id);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_create_and_enable_mfa() {
        let repo = InMemoryTotpRepository::new();
        let user_id = Uuid::new_v4();
        let secret = repo
            .upsert_secret(user_id, "JBSWY3DPEHPK3PXP")
            .await
            .unwrap();
        assert!(!secret.enabled);
        assert!(!repo.has_mfa_enabled(user_id).await.unwrap());
        repo.enable_mfa(user_id).await.unwrap();
        assert!(repo.has_mfa_enabled(user_id).await.unwrap());
    }

    #[tokio::test]
    async fn test_disable_mfa() {
        let repo = InMemoryTotpRepository::new();
        let user_id = Uuid::new_v4();
        repo.upsert_secret(user_id, "JBSWY3DPEHPK3PXP")
            .await
            .unwrap();
        repo.enable_mfa(user_id).await.unwrap();
        assert!(repo.has_mfa_enabled(user_id).await.unwrap());
        repo.disable_mfa(user_id).await.unwrap();
        assert!(!repo.has_mfa_enabled(user_id).await.unwrap());
    }

    /// Enabling MFA for a user without a stored secret is reported as NotFound.
    #[tokio::test]
    async fn test_enable_mfa_without_secret_is_not_found() {
        let repo = InMemoryTotpRepository::new();
        let user_id = Uuid::new_v4();
        assert!(matches!(
            repo.enable_mfa(user_id).await,
            Err(AppError::NotFound(_))
        ));
    }

    /// Re-upserting a secret discards the enabled state and the replay marker.
    #[tokio::test]
    async fn test_upsert_secret_resets_state() {
        let repo = InMemoryTotpRepository::new();
        let user_id = Uuid::new_v4();
        repo.upsert_secret(user_id, "JBSWY3DPEHPK3PXP")
            .await
            .unwrap();
        repo.enable_mfa(user_id).await.unwrap();
        assert!(repo
            .record_used_time_step_if_newer(user_id, 100)
            .await
            .unwrap());
        repo.upsert_secret(user_id, "KRSXG5CTMVRXEZLU")
            .await
            .unwrap();
        let stored = repo.find_by_user(user_id).await.unwrap().unwrap();
        assert!(!stored.enabled);
        assert!(stored.enabled_at.is_none());
        assert!(stored.last_used_time_step.is_none());
        assert!(!repo.has_mfa_enabled(user_id).await.unwrap());
    }

    /// A user with nothing stored has no MFA, no codes, and cannot consume a code.
    #[tokio::test]
    async fn test_unknown_user_defaults() {
        let repo = InMemoryTotpRepository::new();
        let user_id = Uuid::new_v4();
        assert!(repo.find_by_user(user_id).await.unwrap().is_none());
        assert!(!repo.has_mfa_enabled(user_id).await.unwrap());
        assert!(repo.get_recovery_codes(user_id).await.unwrap().is_empty());
        assert!(!repo.use_recovery_code(user_id, "ABCD-1234").await.unwrap());
    }

    #[tokio::test]
    async fn test_recovery_codes() {
        let repo = InMemoryTotpRepository::new();
        let user_id = Uuid::new_v4();
        let code1 = "ABCD-1234";
        let code2 = "EFGH-5678";
        let code3 = "IJKL-9012";
        let hashes = vec![
            TotpService::hash_recovery_code(code1).unwrap(),
            TotpService::hash_recovery_code(code2).unwrap(),
            TotpService::hash_recovery_code(code3).unwrap(),
        ];
        repo.store_recovery_codes(user_id, hashes).await.unwrap();
        let codes = repo.get_recovery_codes(user_id).await.unwrap();
        assert_eq!(codes.len(), 3);
        assert!(repo.use_recovery_code(user_id, code1).await.unwrap());
        assert!(!repo.use_recovery_code(user_id, code1).await.unwrap());
        let codes = repo.get_recovery_codes(user_id).await.unwrap();
        assert_eq!(codes.len(), 2);
        assert!(!repo.use_recovery_code(user_id, "WRONG-CODE").await.unwrap());
    }

    #[tokio::test]
    async fn test_recovery_code_case_insensitive() {
        let repo = InMemoryTotpRepository::new();
        let user_id = Uuid::new_v4();
        let code = "ABCD-1234";
        let hashes = vec![TotpService::hash_recovery_code(code).unwrap()];
        repo.store_recovery_codes(user_id, hashes).await.unwrap();
        assert!(repo.use_recovery_code(user_id, "abcd-1234").await.unwrap());
    }

    #[tokio::test]
    async fn test_disable_mfa_removes_recovery_codes() {
        let repo = InMemoryTotpRepository::new();
        let user_id = Uuid::new_v4();
        repo.upsert_secret(user_id, "JBSWY3DPEHPK3PXP")
            .await
            .unwrap();
        let hashes = vec![TotpService::hash_recovery_code("ABCD-1234").unwrap()];
        repo.store_recovery_codes(user_id, hashes).await.unwrap();
        assert_eq!(repo.get_recovery_codes(user_id).await.unwrap().len(), 1);
        repo.disable_mfa(user_id).await.unwrap();
        assert_eq!(repo.get_recovery_codes(user_id).await.unwrap().len(), 0);
    }

    #[tokio::test]
    async fn test_record_used_time_step_if_newer() {
        let repo = InMemoryTotpRepository::new();
        let user_id = Uuid::new_v4();
        repo.upsert_secret(user_id, "JBSWY3DPEHPK3PXP")
            .await
            .unwrap();
        assert!(repo
            .record_used_time_step_if_newer(user_id, 100)
            .await
            .unwrap());
        assert!(!repo
            .record_used_time_step_if_newer(user_id, 100)
            .await
            .unwrap());
        assert!(!repo
            .record_used_time_step_if_newer(user_id, 99)
            .await
            .unwrap());
        assert!(repo
            .record_used_time_step_if_newer(user_id, 101)
            .await
            .unwrap());
    }

    /// Recording a time step for a user without a stored secret is reported as NotFound.
    #[tokio::test]
    async fn test_record_used_time_step_without_secret_is_not_found() {
        let repo = InMemoryTotpRepository::new();
        let user_id = Uuid::new_v4();
        assert!(matches!(
            repo.record_used_time_step_if_newer(user_id, 1).await,
            Err(AppError::NotFound(_))
        ));
    }

    /// Two interleaved calls with the same step: exactly one may succeed.
    #[tokio::test]
    async fn test_record_used_time_step_if_newer_concurrent() {
        let repo = InMemoryTotpRepository::new();
        let user_id = Uuid::new_v4();
        repo.upsert_secret(user_id, "JBSWY3DPEHPK3PXP")
            .await
            .unwrap();
        let (first, second) = tokio::join!(
            repo.record_used_time_step_if_newer(user_id, 200),
            repo.record_used_time_step_if_newer(user_id, 200)
        );
        let first = first.unwrap();
        let second = second.unwrap();
        assert_ne!(first, second);
    }
}