use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::IpAddr;
use crate::error::{Error, Result, StorageError};
use async_trait::async_trait;
/// Tuning knobs for account lockout, progressive login delays and IP banning.
///
/// Consumed by [`LoginAttemptTracker`]. [`LoginAttemptTracker::new`] runs
/// [`AccountLockoutConfig::validate`] on it before use.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccountLockoutConfig {
    /// Failed-attempt count (reset by a successful login) at which the account is locked.
    pub max_failed_attempts: u32,
    /// How long an automatic lockout lasts.
    pub lockout_duration: Duration,
    /// Whether each failed attempt imposes a wait before the next attempt.
    pub progressive_delay: bool,
    /// Wait after the first failure; doubled for each further failure.
    /// Only the whole-second part is used.
    pub base_delay: Duration,
    /// Upper bound for the progressive delay.
    pub max_delay: Duration,
    /// Idle period after which unlocked account records and per-IP counters are discarded.
    pub attempt_window: Duration,
    /// Whether failed attempts are also counted per source IP address.
    pub track_ip: bool,
    /// Failed attempts from one IP at which it is treated as banned; `0` disables IP banning.
    pub max_ip_attempts: u32,
}
impl Default for AccountLockoutConfig {
    /// Balanced preset.
    ///
    /// - Lock after 5 failures, for 15 minutes.
    /// - Progressive delay starting at 1 second, capped at 5 minutes.
    /// - A 1-hour attempt window.
    /// - IP tracking with a ban after 10 failures.
    fn default() -> Self {
        Self {
            // Account lockout.
            max_failed_attempts: 5,
            lockout_duration: Duration::minutes(15),
            // Per-IP tracking.
            track_ip: true,
            max_ip_attempts: 10,
            // Progressive delay.
            progressive_delay: true,
            base_delay: Duration::seconds(1),
            max_delay: Duration::minutes(5),
            // Housekeeping.
            attempt_window: Duration::hours(1),
        }
    }
}
impl AccountLockoutConfig {
pub fn strict() -> Self {
Self {
max_failed_attempts: 3,
lockout_duration: Duration::minutes(30),
progressive_delay: true,
base_delay: Duration::seconds(2),
max_delay: Duration::minutes(10),
attempt_window: Duration::hours(24),
track_ip: true,
max_ip_attempts: 5,
}
}
pub fn relaxed() -> Self {
Self {
max_failed_attempts: 10,
lockout_duration: Duration::minutes(5),
progressive_delay: false,
base_delay: Duration::seconds(0),
max_delay: Duration::seconds(0),
attempt_window: Duration::minutes(30),
track_ip: false,
max_ip_attempts: 0,
}
}
pub fn validate(&self) -> Result<()> {
if self.max_failed_attempts == 0 {
return Err(Error::validation(
"max_failed_attempts must be greater than 0",
));
}
if self.lockout_duration.num_seconds() <= 0 {
return Err(Error::validation("lockout_duration must be greater than 0"));
}
Ok(())
}
}
/// A single login attempt, passed to [`LoginAttemptTracker::record_attempt`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoginAttempt {
    /// When the attempt happened; the constructors set it to the current time.
    pub timestamp: DateTime<Utc>,
    /// Whether the credentials were accepted.
    pub success: bool,
    /// Source address; when present, it feeds per-IP tracking if that is enabled.
    pub ip_address: Option<IpAddr>,
    /// Client user agent, if known.
    pub user_agent: Option<String>,
    /// Free-form key/value context; defaults to empty when absent during deserialization.
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}
impl LoginAttempt {
    /// Builds an attempt stamped with the current time, carrying only its outcome.
    fn stamped_now(success: bool) -> Self {
        Self {
            timestamp: Utc::now(),
            success,
            ip_address: None,
            user_agent: None,
            metadata: HashMap::new(),
        }
    }

    /// A failed attempt happening now, with no IP, user agent or metadata.
    pub fn failed() -> Self {
        Self::stamped_now(false)
    }

    /// A successful attempt happening now, with no IP, user agent or metadata.
    pub fn success() -> Self {
        Self::stamped_now(true)
    }

    /// Sets the source IP address (builder style).
    pub fn with_ip(self, ip: IpAddr) -> Self {
        Self {
            ip_address: Some(ip),
            ..self
        }
    }

    /// Sets the client user agent (builder style).
    pub fn with_user_agent(self, user_agent: impl Into<String>) -> Self {
        Self {
            user_agent: Some(user_agent.into()),
            ..self
        }
    }

    /// Adds or overwrites one metadata entry (builder style).
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, value) = (key.into(), value.into());
        self.metadata.insert(key, value);
        self
    }
}
/// Lock and attempt bookkeeping for one account.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccountLockStatus {
    /// Identifier of the account this record belongs to.
    pub account_id: String,
    /// Lock flag.
    ///
    /// Nothing clears it when `locked_until` passes, so use
    /// [`AccountLockStatus::is_currently_locked`] for the effective state.
    pub is_locked: bool,
    /// When the current lock was applied.
    pub locked_at: Option<DateTime<Utc>>,
    /// When the current lock expires; `None` while locked means no expiry.
    pub locked_until: Option<DateTime<Utc>>,
    /// Failed attempts since the last success, unlock or reset.
    pub failed_attempts: u32,
    /// Time of the most recent recorded attempt, successful or not.
    pub last_attempt_at: Option<DateTime<Utc>>,
    /// Time of the most recent successful attempt.
    pub last_success_at: Option<DateTime<Utc>>,
    /// Wait required after `last_attempt_at` before the next attempt.
    pub current_delay: Duration,
    /// Why the account is locked, if it is.
    pub lock_reason: Option<LockReason>,
}
/// Why an account was locked.
///
/// The tracker only sets `TooManyFailedAttempts` on its own; the other reasons
/// are supplied by callers through `LoginAttemptTracker::lock_account`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LockReason {
    /// Automatic lock after reaching `max_failed_attempts`.
    TooManyFailedAttempts,
    /// Suspicious activity was reported by the caller.
    SuspiciousActivity,
    /// Locked explicitly by an administrator.
    AdminAction,
    /// Locked because of an IP ban.
    IpBanned,
    /// Any other caller-defined reason.
    Other,
}
impl AccountLockStatus {
    /// A fresh record for `account_id`: unlocked, no attempts, no delay.
    pub fn new(account_id: impl Into<String>) -> Self {
        Self {
            account_id: account_id.into(),
            is_locked: false,
            locked_at: None,
            locked_until: None,
            failed_attempts: 0,
            last_attempt_at: None,
            last_success_at: None,
            current_delay: Duration::zero(),
            lock_reason: None,
        }
    }

    /// Whether the lock is in effect right now.
    ///
    /// A lock without `locked_until` never expires.
    pub fn is_currently_locked(&self) -> bool {
        self.is_locked
            && match self.locked_until {
                None => true,
                Some(until) => until > Utc::now(),
            }
    }

    /// Time left on an active lock.
    ///
    /// Returns `None` when unlocked, or when the lock has no expiry. A
    /// remaining time under one second is reported as zero.
    pub fn remaining_lockout_time(&self) -> Option<Duration> {
        if !self.is_currently_locked() {
            return None;
        }
        let until = self.locked_until?;
        let remaining = until - Utc::now();
        Some(if remaining.num_seconds() > 0 {
            remaining
        } else {
            Duration::zero()
        })
    }

    /// Earliest moment the next attempt is permitted.
    ///
    /// While locked this is the lock expiry. Otherwise it is
    /// `last_attempt_at + current_delay` when a delay of at least one second
    /// is pending, and `None` when no restriction applies.
    pub fn next_attempt_allowed_at(&self) -> Option<DateTime<Utc>> {
        if self.is_currently_locked() {
            self.locked_until
        } else if self.current_delay.num_seconds() > 0 {
            self.last_attempt_at.map(|last| last + self.current_delay)
        } else {
            None
        }
    }

    /// Whether an attempt may be made at this instant.
    pub fn can_attempt_now(&self) -> bool {
        !self.is_currently_locked()
            && match self.next_attempt_allowed_at() {
                Some(next_allowed) => Utc::now() >= next_allowed,
                None => true,
            }
    }
}
/// Outcome of [`LoginAttemptTracker::check_login_allowed`].
#[derive(Debug, Clone)]
pub enum LoginCheckResult {
    /// The attempt may proceed.
    Allowed,
    /// The account is locked.
    Locked {
        /// Why the account is locked.
        reason: LockReason,
        /// Time left on the lock; `None` when the lock has no expiry.
        remaining: Option<Duration>,
    },
    /// A progressive delay is still running.
    DelayRequired {
        /// How long to wait, in whole seconds.
        wait_time: Duration,
    },
    /// The source IP address is banned.
    IpBanned {
        /// The banned address.
        ip: IpAddr,
    },
}
/// In-memory tracker of login attempts, account locks and per-IP failure counts.
///
/// It contains no synchronization of its own; every mutating method takes `&mut self`.
#[derive(Debug)]
pub struct LoginAttemptTracker {
    config: AccountLockoutConfig,
    // Lock/attempt state keyed by account id.
    accounts: HashMap<String, AccountLockStatus>,
    // Per-IP (failed attempt count, time of the last failure or manual ban).
    ip_attempts: HashMap<IpAddr, (u32, DateTime<Utc>)>,
}
impl LoginAttemptTracker {
pub fn new(config: AccountLockoutConfig) -> Result<Self> {
config.validate()?;
Ok(Self {
config,
accounts: HashMap::new(),
ip_attempts: HashMap::new(),
})
}
pub fn with_default_config() -> Self {
Self {
config: AccountLockoutConfig::default(),
accounts: HashMap::new(),
ip_attempts: HashMap::new(),
}
}
pub fn config(&self) -> &AccountLockoutConfig {
&self.config
}
pub fn check_login_allowed(
&mut self,
account_id: &str,
ip: Option<IpAddr>,
) -> LoginCheckResult {
self.cleanup_expired();
if let Some(ip_addr) = ip
&& self.is_ip_banned(ip_addr)
{
return LoginCheckResult::IpBanned { ip: ip_addr };
}
let status = self
.accounts
.entry(account_id.to_string())
.or_insert_with(|| AccountLockStatus::new(account_id));
if status.is_currently_locked() {
return LoginCheckResult::Locked {
reason: status
.lock_reason
.unwrap_or(LockReason::TooManyFailedAttempts),
remaining: status.remaining_lockout_time(),
};
}
if !status.can_attempt_now()
&& let Some(next_allowed) = status.next_attempt_allowed_at()
{
let wait_time = next_allowed - Utc::now();
if wait_time.num_seconds() > 0 {
return LoginCheckResult::DelayRequired {
wait_time: Duration::seconds(wait_time.num_seconds()),
};
}
}
LoginCheckResult::Allowed
}
pub fn record_attempt(&mut self, account_id: &str, attempt: &LoginAttempt) {
let now = Utc::now();
if self.config.track_ip
&& let Some(ip) = attempt.ip_address
{
if !attempt.success {
let entry = self.ip_attempts.entry(ip).or_insert((0, now));
entry.0 += 1;
entry.1 = now;
} else {
self.ip_attempts.remove(&ip);
}
}
let current_failed_attempts = self
.accounts
.get(account_id)
.map(|s| s.failed_attempts)
.unwrap_or(0);
let progressive_delay = self.config.progressive_delay;
let max_failed_attempts = self.config.max_failed_attempts;
let lockout_duration = self.config.lockout_duration;
let (new_failed_attempts, new_delay) = if attempt.success {
(0, Duration::zero())
} else {
let new_count = current_failed_attempts + 1;
let delay = if progressive_delay {
self.calculate_delay(new_count)
} else {
Duration::zero()
};
(new_count, delay)
};
let status = self
.accounts
.entry(account_id.to_string())
.or_insert_with(|| AccountLockStatus::new(account_id));
status.last_attempt_at = Some(now);
if attempt.success {
status.failed_attempts = 0;
status.is_locked = false;
status.locked_at = None;
status.locked_until = None;
status.last_success_at = Some(now);
status.current_delay = Duration::zero();
status.lock_reason = None;
} else {
status.failed_attempts = new_failed_attempts;
status.current_delay = new_delay;
if status.failed_attempts >= max_failed_attempts {
status.is_locked = true;
status.locked_at = Some(now);
status.locked_until = Some(now + lockout_duration);
status.lock_reason = Some(LockReason::TooManyFailedAttempts);
}
}
}
pub fn record_failed_attempt(&mut self, account_id: &str, ip: Option<IpAddr>) {
let mut attempt = LoginAttempt::failed();
if let Some(ip_addr) = ip {
attempt = attempt.with_ip(ip_addr);
}
self.record_attempt(account_id, &attempt);
}
pub fn record_successful_login(&mut self, account_id: &str, ip: Option<IpAddr>) {
let mut attempt = LoginAttempt::success();
if let Some(ip_addr) = ip {
attempt = attempt.with_ip(ip_addr);
}
self.record_attempt(account_id, &attempt);
}
pub fn get_account_status(&self, account_id: &str) -> Option<&AccountLockStatus> {
self.accounts.get(account_id)
}
pub fn lock_account(
&mut self,
account_id: &str,
reason: LockReason,
duration: Option<Duration>,
) {
let now = Utc::now();
let status = self
.accounts
.entry(account_id.to_string())
.or_insert_with(|| AccountLockStatus::new(account_id));
status.is_locked = true;
status.locked_at = Some(now);
status.locked_until = duration.map(|d| now + d);
status.lock_reason = Some(reason);
}
pub fn unlock_account(&mut self, account_id: &str) {
if let Some(status) = self.accounts.get_mut(account_id) {
status.is_locked = false;
status.locked_at = None;
status.locked_until = None;
status.failed_attempts = 0;
status.current_delay = Duration::zero();
status.lock_reason = None;
}
}
pub fn reset_failed_attempts(&mut self, account_id: &str) {
if let Some(status) = self.accounts.get_mut(account_id) {
status.failed_attempts = 0;
status.current_delay = Duration::zero();
}
}
pub fn ban_ip(&mut self, ip: IpAddr) {
self.ip_attempts
.insert(ip, (self.config.max_ip_attempts + 1, Utc::now()));
}
pub fn unban_ip(&mut self, ip: &IpAddr) {
self.ip_attempts.remove(ip);
}
fn is_ip_banned(&self, ip: IpAddr) -> bool {
if self.config.max_ip_attempts == 0 {
return false;
}
if let Some((count, _)) = self.ip_attempts.get(&ip) {
*count >= self.config.max_ip_attempts
} else {
false
}
}
fn calculate_delay(&self, failed_attempts: u32) -> Duration {
if failed_attempts == 0 {
return Duration::zero();
}
let multiplier = 2_i64.pow(failed_attempts.saturating_sub(1));
let delay_seconds = self.config.base_delay.num_seconds() * multiplier;
let delay = Duration::seconds(delay_seconds);
if delay > self.config.max_delay {
self.config.max_delay
} else {
delay
}
}
fn cleanup_expired(&mut self) {
let now = Utc::now();
let window = self.config.attempt_window;
self.accounts.retain(|_, status| {
if status.is_currently_locked() {
return true;
}
if let Some(last_attempt) = status.last_attempt_at {
now - last_attempt < window
} else {
false
}
});
self.ip_attempts
.retain(|_, (_, timestamp)| now - *timestamp < window);
}
pub fn get_locked_accounts(&self) -> Vec<&AccountLockStatus> {
self.accounts
.values()
.filter(|s| s.is_currently_locked())
.collect()
}
pub fn stats(&self) -> TrackerStats {
let locked_count = self
.accounts
.values()
.filter(|s| s.is_currently_locked())
.count();
let total_accounts = self.accounts.len();
let banned_ips = self
.ip_attempts
.iter()
.filter(|(_, (count, _))| *count >= self.config.max_ip_attempts)
.count();
TrackerStats {
total_tracked_accounts: total_accounts,
currently_locked_accounts: locked_count,
banned_ip_addresses: banned_ips,
total_tracked_ips: self.ip_attempts.len(),
}
}
}
/// Counts reported by [`LoginAttemptTracker::stats`].
#[derive(Debug, Clone)]
pub struct TrackerStats {
    /// Account records currently held, locked or not.
    pub total_tracked_accounts: usize,
    /// Account records that are locked at the time of the snapshot.
    pub currently_locked_accounts: usize,
    /// Tracked IP addresses counted as banned.
    pub banned_ip_addresses: usize,
    /// IP address records currently held.
    pub total_tracked_ips: usize,
}
/// Async storage interface for [`AccountLockStatus`] records, keyed by account id.
#[async_trait]
pub trait AccountLockStore: Send + Sync {
    /// Persists `status` under `status.account_id`.
    async fn save(&mut self, status: &AccountLockStatus) -> Result<()>;
    /// Fetches the record for `account_id`; `Ok(None)` when absent.
    async fn load(&self, account_id: &str) -> Result<Option<AccountLockStatus>>;
    /// Removes the record for `account_id`.
    async fn delete(&mut self, account_id: &str) -> Result<()>;
    /// Returns every stored record that is currently locked.
    async fn list_locked(&self) -> Result<Vec<AccountLockStatus>>;
}
/// Non-persistent [`AccountLockStore`] backed by a `HashMap` keyed by account id.
#[derive(Debug, Default)]
pub struct InMemoryAccountLockStore {
    // account_id -> stored copy of its status.
    data: HashMap<String, AccountLockStatus>,
}
impl InMemoryAccountLockStore {
pub fn new() -> Self {
Self::default()
}
}
#[async_trait]
impl AccountLockStore for InMemoryAccountLockStore {
async fn save(&mut self, status: &AccountLockStatus) -> Result<()> {
self.data.insert(status.account_id.clone(), status.clone());
Ok(())
}
async fn load(&self, account_id: &str) -> Result<Option<AccountLockStatus>> {
Ok(self.data.get(account_id).cloned())
}
async fn delete(&mut self, account_id: &str) -> Result<()> {
self.data
.remove(account_id)
.ok_or_else(|| Error::Storage(StorageError::NotFound(account_id.to_string())))?;
Ok(())
}
async fn list_locked(&self) -> Result<Vec<AccountLockStatus>> {
Ok(self
.data
.values()
.filter(|s| s.is_currently_locked())
.cloned()
.collect())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Private-range address shared by the IP-related tests.
    fn sample_ip() -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))
    }

    #[test]
    fn test_default_config() {
        let defaults = AccountLockoutConfig::default();
        assert!(defaults.validate().is_ok());
        assert_eq!(defaults.max_failed_attempts, 5);
    }

    #[test]
    fn test_strict_config() {
        let strict = AccountLockoutConfig::strict();
        assert!(strict.validate().is_ok());
        assert_eq!(strict.max_failed_attempts, 3);
    }

    #[test]
    fn test_account_locking() {
        let mut tracker = LoginAttemptTracker::with_default_config();
        (0..5).for_each(|_| tracker.record_failed_attempt("user1", None));
        let verdict = tracker.check_login_allowed("user1", None);
        assert!(matches!(verdict, LoginCheckResult::Locked { .. }));
    }

    #[test]
    fn test_successful_login_resets() {
        let mut tracker = LoginAttemptTracker::with_default_config();
        (0..3).for_each(|_| tracker.record_failed_attempt("user1", None));
        tracker.record_successful_login("user1", None);
        let status = tracker
            .get_account_status("user1")
            .expect("account should be tracked");
        assert!(!status.is_locked);
        assert_eq!(status.failed_attempts, 0);
    }

    #[test]
    fn test_progressive_delay() {
        let config = AccountLockoutConfig {
            progressive_delay: true,
            base_delay: Duration::seconds(1),
            max_delay: Duration::minutes(5),
            ..Default::default()
        };
        let mut tracker = LoginAttemptTracker::new(config).unwrap();
        // The delay doubles with every consecutive failure.
        for expected_seconds in [1, 2, 4] {
            tracker.record_failed_attempt("user1", None);
            let delay = tracker.get_account_status("user1").unwrap().current_delay;
            assert_eq!(delay.num_seconds(), expected_seconds);
        }
    }

    #[test]
    fn test_ip_tracking() {
        let config = AccountLockoutConfig {
            track_ip: true,
            max_ip_attempts: 3,
            ..Default::default()
        };
        let mut tracker = LoginAttemptTracker::new(config).unwrap();
        let ip = sample_ip();
        (0..3).for_each(|_| tracker.record_failed_attempt("user1", Some(ip)));
        // The ban applies to every account attempted from the same address.
        let verdict = tracker.check_login_allowed("user2", Some(ip));
        assert!(matches!(verdict, LoginCheckResult::IpBanned { .. }));
    }

    #[test]
    fn test_manual_lock_unlock() {
        let mut tracker = LoginAttemptTracker::with_default_config();
        tracker.lock_account("user1", LockReason::AdminAction, None);
        assert!(matches!(
            tracker.check_login_allowed("user1", None),
            LoginCheckResult::Locked {
                reason: LockReason::AdminAction,
                ..
            }
        ));
        tracker.unlock_account("user1");
        assert!(matches!(
            tracker.check_login_allowed("user1", None),
            LoginCheckResult::Allowed
        ));
    }

    #[test]
    fn test_lock_status() {
        let mut status = AccountLockStatus::new("user1");
        assert!(!status.is_currently_locked());
        status.is_locked = true;
        status.locked_until = Some(Utc::now() + Duration::hours(1));
        assert!(status.is_currently_locked());
        assert!(status.remaining_lockout_time().is_some());
    }

    #[test]
    fn test_tracker_stats() {
        let mut tracker = LoginAttemptTracker::with_default_config();
        tracker.lock_account("user1", LockReason::AdminAction, None);
        let TrackerStats {
            total_tracked_accounts,
            currently_locked_accounts,
            ..
        } = tracker.stats();
        assert_eq!(total_tracked_accounts, 1);
        assert_eq!(currently_locked_accounts, 1);
    }

    #[tokio::test]
    async fn test_in_memory_store() {
        let mut store = InMemoryAccountLockStore::new();
        store.save(&AccountLockStatus::new("user1")).await.unwrap();
        let loaded = store
            .load("user1")
            .await
            .unwrap()
            .expect("saved record should load");
        assert_eq!(loaded.account_id, "user1");
        store.delete("user1").await.unwrap();
        assert!(store.load("user1").await.unwrap().is_none());
    }

    #[test]
    fn test_login_attempt_builder() {
        let ip = sample_ip();
        let attempt = LoginAttempt::failed()
            .with_ip(ip)
            .with_user_agent("Mozilla/5.0")
            .with_metadata("reason", "invalid_password");
        assert!(!attempt.success);
        assert_eq!(attempt.ip_address, Some(ip));
        assert_eq!(attempt.user_agent.as_deref(), Some("Mozilla/5.0"));
        assert_eq!(
            attempt.metadata.get("reason").map(String::as_str),
            Some("invalid_password")
        );
    }
}