use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use std::collections::HashMap;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::errors::AppError;
/// A single recorded login attempt.
#[derive(Debug, Clone)]
pub struct LoginAttemptRecord {
    /// Unique identifier of this record.
    pub id: Uuid,
    /// Account the attempt was associated with, when the caller supplied one.
    pub user_id: Option<Uuid>,
    /// Email address the attempt was made with.
    pub email: String,
    /// Client address the attempt came from, when the caller supplied one.
    pub ip_address: Option<String>,
    /// `true` for a successful login, `false` for a failed one.
    pub successful: bool,
    /// Moment (UTC) the attempt was recorded.
    pub attempted_at: DateTime<Utc>,
}
/// Result of evaluating an email address against a [`LoginAttemptConfig`].
#[derive(Debug, Clone)]
pub struct LockoutStatus {
    /// `true` while a lockout is in force.
    pub is_locked: bool,
    /// Number of failed attempts counted inside the configured window.
    pub failed_attempts: u32,
    /// Moment (UTC) the lockout ends; `None` when not locked.
    pub lockout_expires_at: Option<DateTime<Utc>>,
    /// Whole seconds left until the lockout ends; `None` when not locked.
    pub lockout_remaining_secs: Option<i64>,
}
/// Tunable thresholds for login-attempt lockout.
#[derive(Debug, Clone)]
pub struct LoginAttemptConfig {
    /// Number of failed attempts within the window that triggers a lockout.
    pub max_attempts: u32,
    /// Length, in minutes, of the window over which failed attempts are counted.
    pub window_minutes: u32,
    /// Minutes a lockout lasts, measured from the latest failed attempt that
    /// contributed to it.
    pub lockout_minutes: u32,
}

impl Default for LoginAttemptConfig {
    /// Five failed attempts within fifteen minutes lock for thirty minutes.
    fn default() -> Self {
        let (max_attempts, window_minutes, lockout_minutes) = (5, 15, 30);
        LoginAttemptConfig {
            max_attempts,
            window_minutes,
            lockout_minutes,
        }
    }
}
/// Storage abstraction for login attempts and the lockout decisions derived
/// from them.
#[async_trait]
pub trait LoginAttemptRepository: Send + Sync {
    /// Records one login attempt (successful or not) for `email`.
    ///
    /// `user_id` and `ip_address` are stored alongside the attempt when given.
    async fn record_attempt(
        &self,
        user_id: Option<Uuid>,
        email: &str,
        ip_address: Option<&str>,
        successful: bool,
    ) -> Result<(), AppError>;

    /// Reports whether `email` is currently locked out under `config`.
    async fn get_lockout_status(
        &self,
        email: &str,
        config: &LoginAttemptConfig,
    ) -> Result<LockoutStatus, AppError>;

    /// Records a failed attempt for `email` and returns the lockout status
    /// that results from it, as a single operation.
    async fn record_failed_attempt_atomic(
        &self,
        user_id: Option<Uuid>,
        email: &str,
        ip_address: Option<&str>,
        config: &LoginAttemptConfig,
    ) -> Result<LockoutStatus, AppError>;

    /// Discards the failed attempts recorded for `email`.
    async fn clear_failed_attempts(&self, email: &str) -> Result<(), AppError>;

    /// Deletes records whose timestamp is earlier than `older_than` and
    /// returns how many were removed.
    async fn cleanup_old_records(&self, older_than: DateTime<Utc>) -> Result<u64, AppError>;
}
/// [`LoginAttemptRepository`] that keeps every record in process memory.
///
/// Records are grouped per email address behind a Tokio `RwLock`; nothing is
/// persisted.
pub struct InMemoryLoginAttemptRepository {
    /// Attempt history keyed by email address.
    attempts: RwLock<HashMap<String, Vec<LoginAttemptRecord>>>,
}

impl InMemoryLoginAttemptRepository {
    /// Creates a repository with no recorded attempts.
    pub fn new() -> Self {
        Self::default()
    }
}

impl Default for InMemoryLoginAttemptRepository {
    /// Builds an empty repository; equivalent to [`InMemoryLoginAttemptRepository::new`].
    fn default() -> Self {
        let attempts = RwLock::new(HashMap::new());
        Self { attempts }
    }
}
/// Age, in days, beyond which stored records are dropped when `record_attempt`
/// stores a new one.
const RETENTION_DAYS: i64 = 7;
/// Upper bound on the number of records kept per email address; when it is
/// exceeded the records at the front of the history are dropped first.
const MAX_ATTEMPTS_PER_EMAIL: usize = 1000;
#[async_trait]
impl LoginAttemptRepository for InMemoryLoginAttemptRepository {
async fn record_attempt(
&self,
user_id: Option<Uuid>,
email: &str,
ip_address: Option<&str>,
successful: bool,
) -> Result<(), AppError> {
let now = Utc::now();
let record = LoginAttemptRecord {
id: Uuid::new_v4(),
user_id,
email: email.to_lowercase(),
ip_address: ip_address.map(|s| s.to_string()),
successful,
attempted_at: now,
};
let mut attempts = self.attempts.write().await;
let retention_cutoff = now - Duration::days(RETENTION_DAYS);
for user_attempts in attempts.values_mut() {
user_attempts.retain(|a| a.attempted_at >= retention_cutoff);
}
attempts.retain(|_, v| !v.is_empty());
let user_attempts = attempts
.entry(email.to_lowercase())
.or_insert_with(Vec::new);
user_attempts.push(record);
if user_attempts.len() > MAX_ATTEMPTS_PER_EMAIL {
user_attempts.drain(0..(user_attempts.len() - MAX_ATTEMPTS_PER_EMAIL));
}
Ok(())
}
async fn get_lockout_status(
&self,
email: &str,
config: &LoginAttemptConfig,
) -> Result<LockoutStatus, AppError> {
let attempts = self.attempts.read().await;
let email_lower = email.to_lowercase();
let now = Utc::now();
let window_start = now - Duration::minutes(config.window_minutes as i64);
let user_attempts = match attempts.get(&email_lower) {
Some(a) => a,
None => {
return Ok(LockoutStatus {
is_locked: false,
failed_attempts: 0,
lockout_expires_at: None,
lockout_remaining_secs: None,
})
}
};
let failed_in_window: Vec<_> = user_attempts
.iter()
.filter(|a| !a.successful && a.attempted_at > window_start)
.collect();
let failed_attempts = failed_in_window.len() as u32;
if failed_attempts >= config.max_attempts {
if let Some(last_failed) = failed_in_window.iter().map(|a| a.attempted_at).max() {
let lockout_expires_at =
last_failed + Duration::minutes(config.lockout_minutes as i64);
if lockout_expires_at > now {
let remaining = (lockout_expires_at - now).num_seconds();
return Ok(LockoutStatus {
is_locked: true,
failed_attempts,
lockout_expires_at: Some(lockout_expires_at),
lockout_remaining_secs: Some(remaining),
});
}
}
}
Ok(LockoutStatus {
is_locked: false,
failed_attempts,
lockout_expires_at: None,
lockout_remaining_secs: None,
})
}
async fn clear_failed_attempts(&self, email: &str) -> Result<(), AppError> {
let mut attempts = self.attempts.write().await;
if let Some(user_attempts) = attempts.get_mut(&email.to_lowercase()) {
user_attempts.retain(|a| a.successful);
}
Ok(())
}
async fn record_failed_attempt_atomic(
&self,
user_id: Option<Uuid>,
email: &str,
ip_address: Option<&str>,
config: &LoginAttemptConfig,
) -> Result<LockoutStatus, AppError> {
let now = Utc::now();
let email_lower = email.to_lowercase();
let window_start = now - Duration::minutes(config.window_minutes as i64);
let mut attempts = self.attempts.write().await;
let record = LoginAttemptRecord {
id: Uuid::new_v4(),
user_id,
email: email_lower.clone(),
ip_address: ip_address.map(|s| s.to_string()),
successful: false,
attempted_at: now,
};
let user_attempts = attempts.entry(email_lower).or_insert_with(Vec::new);
user_attempts.push(record);
if user_attempts.len() > MAX_ATTEMPTS_PER_EMAIL {
user_attempts.drain(0..(user_attempts.len() - MAX_ATTEMPTS_PER_EMAIL));
}
let failed_in_window: Vec<_> = user_attempts
.iter()
.filter(|a| !a.successful && a.attempted_at > window_start)
.collect();
let failed_attempts = failed_in_window.len() as u32;
if failed_attempts >= config.max_attempts {
if let Some(last_failed) = failed_in_window.iter().map(|a| a.attempted_at).max() {
let lockout_expires_at =
last_failed + Duration::minutes(config.lockout_minutes as i64);
if lockout_expires_at > now {
let remaining = (lockout_expires_at - now).num_seconds();
return Ok(LockoutStatus {
is_locked: true,
failed_attempts,
lockout_expires_at: Some(lockout_expires_at),
lockout_remaining_secs: Some(remaining),
});
}
}
}
Ok(LockoutStatus {
is_locked: false,
failed_attempts,
lockout_expires_at: None,
lockout_remaining_secs: None,
})
}
async fn cleanup_old_records(&self, older_than: DateTime<Utc>) -> Result<u64, AppError> {
let mut attempts = self.attempts.write().await;
let mut removed = 0u64;
for user_attempts in attempts.values_mut() {
let before = user_attempts.len();
user_attempts.retain(|a| a.attempted_at >= older_than);
removed += (before - user_attempts.len()) as u64;
}
attempts.retain(|_, v| !v.is_empty());
Ok(removed)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Canonical (lower-case) address used throughout the tests.
    const EMAIL: &str = "test@example.com";

    /// Config with a 15-minute window and 30-minute lockout, varying only the
    /// failure threshold.
    fn config_with_max(max_attempts: u32) -> LoginAttemptConfig {
        LoginAttemptConfig {
            max_attempts,
            window_minutes: 15,
            lockout_minutes: 30,
        }
    }

    /// Records `count` attempts for `email`, all with the same outcome.
    async fn record_many(
        repo: &InMemoryLoginAttemptRepository,
        email: &str,
        count: usize,
        successful: bool,
    ) {
        for _ in 0..count {
            repo.record_attempt(None, email, None, successful)
                .await
                .unwrap();
        }
    }

    /// Fetches the lockout status for `email`, panicking on repository errors.
    async fn status_of(
        repo: &InMemoryLoginAttemptRepository,
        email: &str,
        config: &LoginAttemptConfig,
    ) -> LockoutStatus {
        repo.get_lockout_status(email, config).await.unwrap()
    }

    /// Reaching the failure threshold locks the address.
    #[tokio::test]
    async fn test_record_and_check_lockout() {
        let repo = InMemoryLoginAttemptRepository::new();
        let config = config_with_max(3);

        let before = status_of(&repo, EMAIL, &config).await;
        assert!(!before.is_locked);
        assert_eq!(before.failed_attempts, 0);

        record_many(&repo, EMAIL, 3, false).await;

        let after = status_of(&repo, EMAIL, &config).await;
        assert!(after.is_locked);
        assert_eq!(after.failed_attempts, 3);
        assert!(after.lockout_remaining_secs.is_some());
    }

    /// Clearing resets the failure count to zero.
    #[tokio::test]
    async fn test_clear_failed_attempts() {
        let repo = InMemoryLoginAttemptRepository::new();
        let config = LoginAttemptConfig::default();

        record_many(&repo, EMAIL, 3, false).await;
        assert_eq!(status_of(&repo, EMAIL, &config).await.failed_attempts, 3);

        repo.clear_failed_attempts(EMAIL).await.unwrap();
        assert_eq!(status_of(&repo, EMAIL, &config).await.failed_attempts, 0);
    }

    /// Successful logins never count towards a lockout.
    #[tokio::test]
    async fn test_successful_attempt_not_counted() {
        let repo = InMemoryLoginAttemptRepository::new();
        let config = LoginAttemptConfig::default();

        record_many(&repo, EMAIL, 10, true).await;

        let status = status_of(&repo, EMAIL, &config).await;
        assert!(!status.is_locked);
        assert_eq!(status.failed_attempts, 0);
    }

    /// Differently-cased spellings of one address share a history.
    #[tokio::test]
    async fn test_case_insensitive_email() {
        let repo = InMemoryLoginAttemptRepository::new();
        let config = config_with_max(2);

        for spelling in ["Test@Example.COM", "test@example.com"] {
            record_many(&repo, spelling, 1, false).await;
        }

        let status = status_of(&repo, "TEST@EXAMPLE.COM", &config).await;
        assert!(status.is_locked);
        assert_eq!(status.failed_attempts, 2);
    }
}