use std::{
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::RwLock;
use uuid::Uuid;
/// Number of recorded failures at which a user becomes locked out.
const MAX_FAILED_ATTEMPTS: u32 = 5;
/// Length of a lockout once it is triggered (15 minutes).
const LOCKOUT_DURATION: Duration = Duration::from_secs(15 * 60);
/// Minimum time between two sweeps of stale entries (300 s = 5 minutes).
const CLEANUP_INTERVAL: Duration = Duration::from_secs(300);
/// Map size at which recording a failure for a not-yet-tracked user first
/// evicts older entries.
const MAX_ENTRIES: usize = 50_000;
/// Per-user bookkeeping for failed MFA verification attempts.
#[derive(Debug, Clone)]
struct MfaAttemptEntry {
    /// Failures counted since the entry was created or last reset.
    failed_count: u32,
    /// Instant at which the lockout ends; `None` when no lockout has been set.
    locked_until: Option<Instant>,
    /// Time of the most recent recorded attempt (initially the creation time).
    last_attempt: Instant,
}

impl MfaAttemptEntry {
    /// Creates an entry with zero failures, no lockout, and `last_attempt`
    /// set to the current instant.
    fn new() -> Self {
        Self {
            failed_count: 0,
            locked_until: None,
            last_attempt: Instant::now(),
        }
    }

    /// Returns `true` while a lockout is set and its end lies in the future.
    fn is_locked_out(&self) -> bool {
        // Delegate so both queries are always answered from the same rule.
        self.lockout_remaining().is_some()
    }

    /// Returns the time left until the lockout ends.
    ///
    /// Yields `None` when no lockout is set or when it has already expired,
    /// so a returned duration is always strictly positive.
    fn lockout_remaining(&self) -> Option<Duration> {
        // Read the clock exactly once: comparing against one reading and
        // subtracting another could report a zero-length "remaining" lockout
        // if the deadline passed between the two reads.
        let now = Instant::now();
        match self.locked_until {
            Some(until) if now < until => Some(until.duration_since(now)),
            _ => None,
        }
    }
}
/// In-memory tracker that limits failed MFA verifications per user.
///
/// After `MAX_FAILED_ATTEMPTS` recorded failures a user is locked out for
/// `LOCKOUT_DURATION`. Cloning the service is cheap: all clones share the same
/// underlying state through `Arc`.
#[derive(Clone)]
pub struct MfaAttemptService {
    /// Attempt state keyed by user id.
    entries: Arc<RwLock<HashMap<Uuid, MfaAttemptEntry>>>,
    /// Instant at which the stale-entry sweep last ran.
    last_cleanup: Arc<RwLock<Instant>>,
}

impl MfaAttemptService {
    /// Idle time after which an entry that is not locked out is dropped by
    /// the periodic sweep.
    const STALE_THRESHOLD: Duration = Duration::from_secs(30 * 60);

    /// Creates an empty tracker whose cleanup timer starts now.
    pub fn new() -> Self {
        Self {
            entries: Arc::new(RwLock::new(HashMap::new())),
            last_cleanup: Arc::new(RwLock::new(Instant::now())),
        }
    }

    /// Checks whether `user_id` may attempt MFA verification right now.
    ///
    /// # Errors
    ///
    /// Returns `Err(remaining)` with the time left on the lockout while the
    /// user is locked out.
    pub async fn check_allowed(&self, user_id: Uuid) -> Result<(), Duration> {
        self.maybe_cleanup().await;
        let entries = self.entries.read().await;
        match entries
            .get(&user_id)
            .and_then(MfaAttemptEntry::lockout_remaining)
        {
            Some(remaining) => Err(remaining),
            None => Ok(()),
        }
    }

    /// Records one failed MFA verification for `user_id`.
    ///
    /// Returns `Ok(n)` with the number of attempts left before lockout.
    /// A failure recorded after a previous lockout has expired starts a fresh
    /// count; a failure recorded while a lockout is still active restarts the
    /// lockout window.
    ///
    /// # Errors
    ///
    /// Returns `Err(LOCKOUT_DURATION)` when this failure triggers (or
    /// re-arms) a lockout.
    pub async fn record_failed(&self, user_id: Uuid) -> Result<u32, Duration> {
        self.maybe_cleanup().await;
        let mut entries = self.entries.write().await;

        // Keep the map bounded: make room before tracking a new user.
        if entries.len() >= MAX_ENTRIES && !entries.contains_key(&user_id) {
            self.evict_oldest(&mut entries);
        }

        let entry = entries.entry(user_id).or_insert_with(MfaAttemptEntry::new);
        // One clock reading for the whole update keeps the timestamps consistent.
        let now = Instant::now();
        entry.last_attempt = now;

        // A lockout that has run out grants a clean slate.
        if matches!(entry.locked_until, Some(until) if now >= until) {
            entry.failed_count = 0;
            entry.locked_until = None;
        }

        // Saturate instead of overflowing if failures keep arriving during a lockout.
        entry.failed_count = entry.failed_count.saturating_add(1);
        tracing::debug!(
            user_id = %user_id,
            failed_count = entry.failed_count,
            max_attempts = MAX_FAILED_ATTEMPTS,
            "SEC-04: MFA attempt failed"
        );

        if entry.failed_count < MAX_FAILED_ATTEMPTS {
            return Ok(MAX_FAILED_ATTEMPTS - entry.failed_count);
        }

        entry.locked_until = Some(now + LOCKOUT_DURATION);
        tracing::warn!(
            user_id = %user_id,
            lockout_minutes = LOCKOUT_DURATION.as_secs() / 60,
            "SEC-04: User locked out due to too many failed MFA attempts"
        );
        Err(LOCKOUT_DURATION)
    }

    /// Records a successful MFA verification by dropping all attempt state
    /// for `user_id`.
    pub async fn record_success(&self, user_id: Uuid) {
        let mut entries = self.entries.write().await;
        entries.remove(&user_id);
        tracing::debug!(
            user_id = %user_id,
            "SEC-04: MFA verification successful, cleared attempt tracking"
        );
    }

    /// Sweeps stale entries if more than `CLEANUP_INTERVAL` has passed since
    /// the previous sweep.
    ///
    /// Entries that are locked out are always kept; other entries are kept
    /// only while their last attempt is newer than `STALE_THRESHOLD`.
    async fn maybe_cleanup(&self) {
        // Cheap check under the read lock; the guard is released at the end
        // of this statement.
        let due = self.last_cleanup.read().await.elapsed() > CLEANUP_INTERVAL;
        if !due {
            return;
        }

        {
            let mut last = self.last_cleanup.write().await;
            // Re-check under the write lock: several tasks may have seen the
            // stale timestamp at once, but only the first should sweep.
            if last.elapsed() <= CLEANUP_INTERVAL {
                return;
            }
            *last = Instant::now();
        }

        let mut entries = self.entries.write().await;
        // Taken after the lock is held so no entry can be newer than `now`.
        let now = Instant::now();
        entries.retain(|_, entry| {
            entry.is_locked_out()
                || now.saturating_duration_since(entry.last_attempt) < Self::STALE_THRESHOLD
        });
    }

    /// Evicts roughly a fifth of `MAX_ENTRIES` to make room for new users.
    ///
    /// Entries that are not locked out go first, oldest `last_attempt` first;
    /// locked-out entries are only evicted when nothing else is left, so an
    /// active lockout is not lifted merely because the map is full. If the
    /// map holds no more entries than the eviction quota it is cleared.
    fn evict_oldest(&self, entries: &mut HashMap<Uuid, MfaAttemptEntry>) {
        let evict_count = (MAX_ENTRIES / 5).max(1);
        if entries.len() <= evict_count {
            entries.clear();
            return;
        }

        let now = Instant::now();
        let mut candidates: Vec<(bool, Instant, Uuid)> = entries
            .iter()
            .map(|(id, entry)| {
                let locked = matches!(entry.locked_until, Some(until) if now < until);
                (locked, entry.last_attempt, *id)
            })
            .collect();

        // `false < true`, so non-locked entries sort first; within each group
        // the earliest `last_attempt` (the oldest entry) sorts first.
        candidates.sort_unstable_by_key(|&(locked, last_attempt, _)| (locked, last_attempt));

        for (_, _, id) in candidates.into_iter().take(evict_count) {
            entries.remove(&id);
        }
    }
}
impl Default for MfaAttemptService {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Records five consecutive failures for `user`, discarding each result.
    async fn fail_five_times(tracker: &MfaAttemptService, user: Uuid) {
        for _ in 0..5 {
            let _ = tracker.record_failed(user).await;
        }
    }

    /// A user with no recorded failures is allowed.
    #[tokio::test]
    async fn test_initially_allowed() {
        let tracker = MfaAttemptService::new();
        let outcome = tracker.check_allowed(Uuid::new_v4()).await;
        assert!(outcome.is_ok());
    }

    /// Four failures report a decreasing number of remaining attempts and do
    /// not lock the user out.
    #[tokio::test]
    async fn test_failed_attempts_under_limit() {
        let tracker = MfaAttemptService::new();
        let user = Uuid::new_v4();

        for attempt in 1..=4u32 {
            let outcome = tracker.record_failed(user).await;
            assert!(
                outcome.is_ok(),
                "Attempt {} should not trigger lockout",
                attempt
            );
            assert_eq!(outcome.unwrap(), MAX_FAILED_ATTEMPTS - attempt);
        }

        assert!(tracker.check_allowed(user).await.is_ok());
    }

    /// The user stays allowed through four failures and is locked out by the fifth.
    #[tokio::test]
    async fn test_lockout_after_max_attempts() {
        let tracker = MfaAttemptService::new();
        let user = Uuid::new_v4();

        for _ in 0..4 {
            let _ = tracker.record_failed(user).await;
            assert!(tracker.check_allowed(user).await.is_ok());
        }
        let _ = tracker.record_failed(user).await;

        let outcome = tracker.check_allowed(user).await;
        assert!(outcome.is_err());
        let remaining = outcome.unwrap_err();
        assert!(remaining.as_secs() > 0);
    }

    /// A success wipes earlier failures, so the next failure counts as the first.
    #[tokio::test]
    async fn test_success_clears_attempts() {
        let tracker = MfaAttemptService::new();
        let user = Uuid::new_v4();

        for _ in 0..2 {
            tracker.record_failed(user).await.ok();
        }
        tracker.record_success(user).await;

        let outcome = tracker.record_failed(user).await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), MAX_FAILED_ATTEMPTS - 1);
    }

    /// Locking out one user leaves another user unaffected.
    #[tokio::test]
    async fn test_different_users_independent() {
        let tracker = MfaAttemptService::new();
        let locked_user = Uuid::new_v4();
        let other_user = Uuid::new_v4();

        fail_five_times(&tracker, locked_user).await;

        assert!(tracker.check_allowed(locked_user).await.is_err());
        assert!(tracker.check_allowed(other_user).await.is_ok());
    }

    /// Right after lockout the reported remaining time is within (14, 15] minutes.
    #[tokio::test]
    async fn test_lockout_returns_duration() {
        let tracker = MfaAttemptService::new();
        let user = Uuid::new_v4();

        fail_five_times(&tracker, user).await;

        let outcome = tracker.check_allowed(user).await;
        assert!(outcome.is_err());
        let remaining_secs = outcome.unwrap_err().as_secs();
        assert!(remaining_secs > 14 * 60);
        assert!(remaining_secs <= 15 * 60);
    }
}