use crate::error::Result;
use async_trait::async_trait;
use std::time::{Duration, SystemTime};
/// Default failure count at which an account is locked.
const DEFAULT_MAX_ATTEMPTS: u32 = 5;
/// Default lockout length (15 minutes).
const DEFAULT_LOCKOUT_DURATION: Duration = Duration::from_secs(15 * 60);
/// Upper bound, in bytes, on an IP string used as a store key.
///
/// Presumably chosen as the longest textual IPv6 form (45 characters, e.g. an
/// IPv4-mapped address) so real addresses are never shortened — confirm.
const MAX_IP_LENGTH: usize = 45;

/// Limits `ip` to at most [`MAX_IP_LENGTH`] bytes before it is used as a key.
///
/// The cut is moved back to the nearest UTF-8 character boundary. `ip` comes
/// from the caller and may contain multi-byte characters, and slicing a `str`
/// in the middle of one panics.
fn truncate_ip(ip: &str) -> &str {
    if ip.len() <= MAX_IP_LENGTH {
        return ip;
    }
    // Byte offset 0 is always a char boundary, so `find` cannot come up empty.
    let end = (0..=MAX_IP_LENGTH)
        .rev()
        .find(|&idx| ip.is_char_boundary(idx))
        .unwrap_or(0);
    &ip[..end]
}
/// Rules deciding when failed logins are delayed or locked out.
///
/// Start from [`LockoutPolicy::new`] (same as [`Default`]) or a preset such as
/// [`LockoutPolicy::strict`] / [`LockoutPolicy::lenient`], then refine it with
/// the builder-style setters.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LockoutPolicy {
    /// Failure count at which an account (and, with `track_by_ip`, an IP) is locked.
    pub max_attempts: u32,
    /// How long a lockout lasts once triggered.
    pub lockout_duration: Duration,
    /// Delay in seconds imposed after each failure below the lock threshold.
    ///
    /// Entry `i` applies to failed attempt `i + 1`. Attempts past the end of the
    /// list reuse the last entry, and an empty list means no delay.
    pub progressive_delays: Vec<u64>,
    /// Whether lock/unlock notifications are sent through the store.
    pub send_notifications: bool,
    /// Whether failures are additionally counted and locked per client IP.
    pub track_by_ip: bool,
}

impl Default for LockoutPolicy {
    /// `DEFAULT_MAX_ATTEMPTS` attempts, `DEFAULT_LOCKOUT_DURATION` lockout, no
    /// delays, no notifications, no IP tracking.
    fn default() -> Self {
        Self {
            max_attempts: DEFAULT_MAX_ATTEMPTS,
            lockout_duration: DEFAULT_LOCKOUT_DURATION,
            progressive_delays: Vec::new(),
            send_notifications: false,
            track_by_ip: false,
        }
    }
}

impl LockoutPolicy {
    /// Equivalent to [`LockoutPolicy::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Aggressive preset: lock after 3 failures for 30 minutes, 0/30/60 s
    /// progressive delays, notifications and IP tracking enabled.
    #[must_use]
    pub fn strict() -> Self {
        Self {
            max_attempts: 3,
            lockout_duration: Duration::from_secs(30 * 60),
            progressive_delays: vec![0, 30, 60],
            send_notifications: true,
            track_by_ip: true,
        }
    }

    /// Permissive preset: lock after 10 failures for 5 minutes, nothing else.
    #[must_use]
    pub fn lenient() -> Self {
        Self {
            max_attempts: 10,
            lockout_duration: Duration::from_secs(5 * 60),
            progressive_delays: vec![],
            send_notifications: false,
            track_by_ip: false,
        }
    }

    /// Sets the failure count at which a lockout is triggered.
    #[must_use]
    pub fn max_attempts(mut self, max: u32) -> Self {
        self.max_attempts = max;
        self
    }

    /// Sets how long a lockout lasts.
    #[must_use]
    pub fn lockout_duration(mut self, duration: Duration) -> Self {
        self.lockout_duration = duration;
        self
    }

    /// Replaces the per-attempt delay schedule (seconds).
    #[must_use]
    pub fn progressive_delays(mut self, delays: Vec<u64>) -> Self {
        self.progressive_delays = delays;
        self
    }

    /// Enables lock/unlock notifications.
    #[must_use]
    pub fn with_notifications(mut self) -> Self {
        self.send_notifications = true;
        self
    }

    /// Enables or disables per-IP failure tracking.
    #[must_use]
    pub fn track_by_ip(mut self, track: bool) -> Self {
        self.track_by_ip = track;
        self
    }

    /// Delay (in seconds) to impose after the `attempt`-th consecutive failure.
    ///
    /// Returns `None` once `attempt` reaches `max_attempts`, because that attempt
    /// triggers a lockout instead of a delay. Otherwise the result is one of:
    /// - the matching `progressive_delays` entry (attempt `n` maps to index
    ///   `n - 1`, and attempt `0` is treated like attempt `1`);
    /// - the last entry once the list is exhausted;
    /// - `0` when no delays are configured.
    #[must_use]
    pub fn get_delay_for_attempt(&self, attempt: u32) -> Option<u64> {
        if attempt >= self.max_attempts {
            return None;
        }
        // u32 -> usize is lossless on the 32/64-bit targets this runs on.
        let index = attempt.saturating_sub(1) as usize;
        let delays = &self.progressive_delays;
        Some(delays.get(index).or_else(|| delays.last()).copied().unwrap_or(0))
    }
}
/// Point-in-time view of lockout state for a user, or for a client IP when
/// produced by an IP check.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LockoutStatus {
    /// Number of failed attempts recorded.
    pub failed_attempts: u32,
    /// Whether a hard lockout is in effect.
    pub is_locked: bool,
    /// End of the lockout, when known.
    pub locked_until: Option<SystemTime>,
    /// Progressive delay, in seconds, applied after the latest failure.
    pub delay_seconds: Option<u64>,
    /// Earliest time the next attempt is allowed by the progressive delay.
    pub delay_until: Option<SystemTime>,
    /// Time of the most recent failed attempt, when known.
    pub last_attempt_at: Option<SystemTime>,
}

impl LockoutStatus {
    /// Returns `true` if an attempt is permitted at the current system time.
    ///
    /// When `is_locked` is set, only `locked_until` matters, and a lock without
    /// an end time never expires. Otherwise any `delay_until` must have passed.
    #[must_use]
    pub fn can_attempt_now(&self) -> bool {
        if self.is_locked {
            return match self.locked_until {
                Some(until) => SystemTime::now() >= until,
                None => false,
            };
        }
        match self.delay_until {
            Some(until) => SystemTime::now() >= until,
            None => true,
        }
    }

    /// Whole seconds until the active restriction ends, rounded **up**; `0` if none.
    ///
    /// The lock deadline is used while `is_locked` is set and it is still in the
    /// future. Otherwise the delay deadline is used. A lock without
    /// `locked_until` has no end time and contributes nothing here.
    ///
    /// Rounding up (rather than truncating) means a non-zero value is reported
    /// whenever the chosen deadline has not passed yet. Waiting the returned
    /// number of seconds is therefore enough to clear it.
    #[must_use]
    pub fn remaining_wait_seconds(&self) -> u64 {
        /// Converts `d` to whole seconds, rounding any fractional part up.
        fn ceil_secs(d: Duration) -> u64 {
            d.as_secs().saturating_add(u64::from(d.subsec_nanos() > 0))
        }

        let now = SystemTime::now();
        // Time left until `deadline`, or `None` if absent or already passed.
        let pending = |deadline: Option<SystemTime>| {
            deadline.and_then(|until| until.duration_since(now).ok())
        };
        let lock_wait = if self.is_locked {
            pending(self.locked_until)
        } else {
            None
        };
        lock_wait
            .or_else(|| pending(self.delay_until))
            .map_or(0, ceil_secs)
    }
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FailedAttemptResult {
pub status: LockoutStatus,
pub just_locked: bool,
pub notification_sent: bool,
}
impl FailedAttemptResult {
#[must_use]
pub fn can_retry_now(&self) -> bool {
self.status.can_attempt_now()
}
#[must_use]
pub fn wait_seconds(&self) -> u64 {
self.status.remaining_wait_seconds()
}
}
/// Persistence backend used by [`LockoutManager`].
///
/// Only the per-user counter/lock/delay methods are required. Every other
/// method has a no-op default. With those defaults, no notification is ever
/// sent (no email is found), and per-IP failures never lock anything (the
/// increment reports `0`), even when the [`LockoutPolicy`] enables them.
#[async_trait]
pub trait LockoutStore: Send + Sync {
/// Current failed-attempt count for `user_id`.
/// NOTE(review): not called by `LockoutManager` in this file.
async fn get_failed_attempts(&self, user_id: &str) -> Result<u32>;
/// Current lockout state for `user_id`, or `None` if nothing is recorded.
///
/// `LockoutManager` decides whether to block via
/// `LockoutStatus::can_attempt_now`, so `is_locked`, `locked_until` and
/// `delay_until` should reflect the stored deadlines.
async fn get_lockout_status(&self, user_id: &str) -> Result<Option<LockoutStatus>>;
/// Increments the user's failure counter and returns the new value; the
/// manager compares it against `max_attempts`.
async fn increment_failed_attempts(&self, user_id: &str) -> Result<u32>;
/// Records that `user_id` is locked until `until`.
async fn set_lockout(&self, user_id: &str, until: SystemTime) -> Result<()>;
/// Records that the next attempt for `user_id` must wait until `until`.
async fn set_delay(&self, user_id: &str, until: SystemTime) -> Result<()>;
/// Clears the user's lockout state; called after a successful login and by
/// an admin unlock.
async fn clear_lockout(&self, user_id: &str) -> Result<()>;
/// Address for lock/unlock notifications. The default returns `Ok(None)`,
/// which makes the manager skip notifications entirely.
async fn get_user_email(&self, user_id: &str) -> Result<Option<String>> {
let _ = user_id;
Ok(None)
}
/// Delivers a "locked until `locked_until`" notice. The default is a no-op
/// that reports success. An `Ok` here is what sets `notification_sent` on the
/// manager's result.
async fn send_lockout_notification(
&self,
user_id: &str,
email: &str,
locked_until: SystemTime,
) -> Result<()> {
let _ = (user_id, email, locked_until);
Ok(())
}
/// Delivers an "account unlocked" notice; default no-op.
async fn send_unlock_notification(&self, user_id: &str, email: &str) -> Result<()> {
let _ = (user_id, email);
Ok(())
}
/// Current failure count for `ip`; default `0`.
/// NOTE(review): not called by `LockoutManager` in this file.
async fn get_failed_attempts_by_ip(&self, ip: &str) -> Result<u32> {
let _ = ip;
Ok(0)
}
/// Increments and returns the failure count for `ip`. The default always
/// returns `0`, so IP lockouts never trigger unless this is overridden
/// (or `max_attempts` is `0`).
async fn increment_failed_attempts_by_ip(&self, ip: &str) -> Result<u32> {
let _ = ip;
Ok(0)
}
/// Records that `ip` is locked until `until`; default no-op.
async fn set_lockout_by_ip(&self, ip: &str, until: SystemTime) -> Result<()> {
let _ = (ip, until);
Ok(())
}
/// Lock deadline for `ip`, if any. The deadline may already be in the past;
/// the manager only blocks while `now < deadline`. Default: `None`.
async fn is_ip_locked(&self, ip: &str) -> Result<Option<SystemTime>> {
let _ = ip;
Ok(None)
}
/// Clears IP lockout state; called after a successful login from `ip` when
/// IP tracking is enabled. Default no-op.
async fn clear_ip_lockout(&self, ip: &str) -> Result<()> {
let _ = ip;
Ok(())
}
}
pub struct LockoutManager<S: LockoutStore> {
store: S,
policy: LockoutPolicy,
}
impl<S: LockoutStore> LockoutManager<S> {
    /// Creates a manager enforcing `policy` on top of `store`.
    #[must_use]
    pub fn new(store: S, policy: LockoutPolicy) -> Self {
        Self { store, policy }
    }

    /// Creates a manager using [`LockoutPolicy::default`].
    #[must_use]
    pub fn with_defaults(store: S) -> Self {
        Self::new(store, LockoutPolicy::default())
    }

    /// Checks whether `user_id` may attempt to log in right now. When IP
    /// tracking is enabled and `ip` is given, the IP is checked as well.
    ///
    /// Returns `Ok(None)` when the attempt may proceed, or `Ok(Some(status))`
    /// describing what blocks it. The per-user state is checked first. An IP
    /// lockout is reported with `failed_attempts: 0`, since the per-IP count is
    /// not part of [`LockoutStatus`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the store.
    pub async fn check_can_attempt(
        &self,
        user_id: &str,
        ip: Option<&str>,
    ) -> Result<Option<LockoutStatus>> {
        if let Some(status) = self.store.get_lockout_status(user_id).await? {
            if !status.can_attempt_now() {
                tracing::debug!(
                    target: "auth.lockout.blocked",
                    user_id = %user_id,
                    is_locked = status.is_locked,
                    remaining_seconds = status.remaining_wait_seconds(),
                    "Login attempt blocked by lockout"
                );
                return Ok(Some(status));
            }
        }
        if self.policy.track_by_ip {
            if let Some(ip) = ip.map(truncate_ip) {
                if let Some(locked_until) = self.store.is_ip_locked(ip).await? {
                    if SystemTime::now() < locked_until {
                        tracing::debug!(
                            target: "auth.lockout.ip_blocked",
                            ip = %ip,
                            "Login attempt blocked by IP lockout"
                        );
                        return Ok(Some(LockoutStatus {
                            failed_attempts: 0,
                            is_locked: true,
                            locked_until: Some(locked_until),
                            delay_seconds: None,
                            delay_until: None,
                            last_attempt_at: None,
                        }));
                    }
                }
            }
        }
        Ok(None)
    }

    /// Records a failed login for `user_id` (and for `ip` when IP tracking is on).
    ///
    /// Once the user's count reaches `max_attempts`, the account is locked for
    /// `lockout_duration` and, if enabled, a lockout notification is attempted.
    /// Below the threshold, the policy's progressive delay is applied; a
    /// zero-second delay is reported but not stored.
    ///
    /// NOTE(review): the count is not reset when a lockout expires, so the next
    /// failure after expiry locks again immediately. Confirm this is intended.
    ///
    /// # Errors
    ///
    /// Propagates store errors from counting, locking or delay bookkeeping.
    /// Notification failures are logged and reflected in `notification_sent`,
    /// never returned as errors.
    pub async fn record_failed_attempt(
        &self,
        user_id: &str,
        ip: Option<&str>,
    ) -> Result<FailedAttemptResult> {
        let new_count = self.store.increment_failed_attempts(user_id).await?;
        if self.policy.track_by_ip {
            if let Some(ip) = ip.map(truncate_ip) {
                self.record_ip_failure(ip).await?;
            }
        }

        let now = SystemTime::now();
        if new_count >= self.policy.max_attempts {
            // NOTE(review): `SystemTime + Duration` panics on overflow, so an
            // extremely large `lockout_duration` would panic here.
            let until = now + self.policy.lockout_duration;
            self.store.set_lockout(user_id, until).await?;
            tracing::warn!(
                target: "auth.lockout.account_locked",
                user_id = %user_id,
                attempts = new_count,
                duration_secs = self.policy.lockout_duration.as_secs(),
                "Account locked due to failed attempts"
            );
            let notification_sent =
                self.policy.send_notifications && self.notify_lockout(user_id, until).await;
            return Ok(FailedAttemptResult {
                status: LockoutStatus {
                    failed_attempts: new_count,
                    is_locked: true,
                    locked_until: Some(until),
                    delay_seconds: None,
                    delay_until: None,
                    last_attempt_at: Some(now),
                },
                just_locked: true,
                notification_sent,
            });
        }

        let delay_seconds = self.policy.get_delay_for_attempt(new_count);
        // A zero delay still yields `delay_until == now`, i.e. already elapsed.
        let delay_until = delay_seconds.map(|secs| now + Duration::from_secs(secs));
        if let (Some(secs), Some(until)) = (delay_seconds, delay_until) {
            if secs > 0 {
                self.store.set_delay(user_id, until).await?;
                tracing::info!(
                    target: "auth.lockout.delay_applied",
                    user_id = %user_id,
                    attempts = new_count,
                    delay_seconds = secs,
                    "Progressive delay applied"
                );
            }
        }
        Ok(FailedAttemptResult {
            status: LockoutStatus {
                failed_attempts: new_count,
                is_locked: false,
                locked_until: None,
                delay_seconds,
                delay_until,
                last_attempt_at: Some(now),
            },
            just_locked: false,
            notification_sent: false,
        })
    }

    /// Counts a failure against `ip` and locks the IP once its count reaches
    /// `max_attempts` (the same threshold used for accounts).
    async fn record_ip_failure(&self, ip: &str) -> Result<()> {
        let ip_count = self.store.increment_failed_attempts_by_ip(ip).await?;
        if ip_count >= self.policy.max_attempts {
            let until = SystemTime::now() + self.policy.lockout_duration;
            self.store.set_lockout_by_ip(ip, until).await?;
            tracing::warn!(
                target: "auth.lockout.ip_locked",
                ip = %ip,
                attempts = ip_count,
                duration_secs = self.policy.lockout_duration.as_secs(),
                "IP address locked out"
            );
        }
        Ok(())
    }

    /// Looks up the notification address. Lookup errors are logged and treated
    /// as "no address", so notifications stay best-effort.
    async fn notification_email(&self, user_id: &str) -> Option<String> {
        match self.store.get_user_email(user_id).await {
            Ok(email) => email,
            Err(err) => {
                tracing::warn!(
                    target: "auth.lockout.notification_failed",
                    user_id = %user_id,
                    error = ?err,
                    "Could not look up email for lockout notification"
                );
                None
            }
        }
    }

    /// Best-effort lockout notice. Returns `true` only if the store accepted it.
    async fn notify_lockout(&self, user_id: &str, until: SystemTime) -> bool {
        let email = match self.notification_email(user_id).await {
            Some(email) => email,
            None => return false,
        };
        match self
            .store
            .send_lockout_notification(user_id, &email, until)
            .await
        {
            Ok(()) => {
                tracing::info!(
                    target: "auth.lockout.notification_sent",
                    user_id = %user_id,
                    email = %email,
                    "Lockout notification email sent"
                );
                true
            }
            Err(err) => {
                tracing::warn!(
                    target: "auth.lockout.notification_failed",
                    user_id = %user_id,
                    error = ?err,
                    "Lockout notification email failed"
                );
                false
            }
        }
    }

    /// Best-effort unlock notice; failures are logged only.
    async fn notify_unlock(&self, user_id: &str) {
        if let Some(email) = self.notification_email(user_id).await {
            if let Err(err) = self.store.send_unlock_notification(user_id, &email).await {
                tracing::warn!(
                    target: "auth.lockout.notification_failed",
                    user_id = %user_id,
                    error = ?err,
                    "Unlock notification email failed"
                );
            }
        }
    }

    /// Clears the user's lockout state after a successful login. When IP
    /// tracking is enabled and `ip` is given, that IP's lockout state is
    /// cleared too.
    ///
    /// NOTE(review): a successful login to *any* account from `ip` clears that
    /// IP's state. If the store also resets the IP counter here (the in-memory
    /// store does), an attacker with one valid account could keep the IP below
    /// the threshold while probing others. Confirm this is intended.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the store.
    pub async fn record_successful_login(&self, user_id: &str, ip: Option<&str>) -> Result<()> {
        self.store.clear_lockout(user_id).await?;
        if self.policy.track_by_ip {
            if let Some(ip) = ip.map(truncate_ip) {
                self.store.clear_ip_lockout(ip).await?;
            }
        }
        tracing::debug!(
            target: "auth.lockout.cleared",
            user_id = %user_id,
            "Lockout state cleared on successful login"
        );
        Ok(())
    }

    /// Clears all lockout state for `user_id` on behalf of `admin_id`.
    ///
    /// Returns whether the store reported any status (`Some`) before clearing.
    /// With the in-memory store, that includes users who merely had failed
    /// attempts. An unlock notification is sent only if the account was
    /// actually locked (`is_locked`). A user who was never locked is not told
    /// their account was unlocked.
    ///
    /// # Errors
    ///
    /// Propagates store errors from the status lookup and the clear.
    /// Notification failures are logged only.
    pub async fn admin_unlock(&self, user_id: &str, admin_id: &str) -> Result<bool> {
        let previous = self.store.get_lockout_status(user_id).await?;
        let had_lockout = previous.is_some();
        let was_locked = matches!(&previous, Some(LockoutStatus { is_locked: true, .. }));
        self.store.clear_lockout(user_id).await?;
        tracing::warn!(
            target: "auth.lockout.admin_unlock",
            user_id = %user_id,
            admin_id = %admin_id,
            had_lockout = had_lockout,
            "Account unlocked by admin"
        );
        if self.policy.send_notifications && was_locked {
            self.notify_unlock(user_id).await;
        }
        Ok(had_lockout)
    }

    /// Returns the store's current lockout status for `user_id`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the store.
    pub async fn get_status(&self, user_id: &str) -> Result<Option<LockoutStatus>> {
        self.store.get_lockout_status(user_id).await
    }

    /// The policy this manager enforces.
    #[must_use]
    pub fn policy(&self) -> &LockoutPolicy {
        &self.policy
    }

    /// The underlying store.
    #[must_use]
    pub fn store(&self) -> &S {
        &self.store
    }
}
#[cfg(any(test, feature = "test-auth-bypass"))]
pub mod test {
    //! In-memory [`LockoutStore`] used by the unit tests and by builds with the
    //! `test-auth-bypass` feature.

    use super::*;
    use std::collections::HashMap;
    use std::sync::RwLock;

    /// Everything recorded about one user.
    #[derive(Default)]
    struct UserLockoutState {
        failed_attempts: u32,
        locked_until: Option<SystemTime>,
        delay_until: Option<SystemTime>,
        last_attempt_at: Option<SystemTime>,
    }

    /// `RwLock`-guarded maps implementing [`LockoutStore`]. Notifications are
    /// captured in a list instead of being delivered.
    #[derive(Default)]
    pub struct InMemoryLockoutStore {
        /// Lockout state keyed by user id.
        users: RwLock<HashMap<String, UserLockoutState>>,
        /// Per-IP `(failure count, lock deadline)` keyed by IP string.
        ips: RwLock<HashMap<String, (u32, Option<SystemTime>)>>,
        /// Notification addresses keyed by user id.
        emails: RwLock<HashMap<String, String>>,
        /// Captured notifications as `(user_id, email, kind)`, where kind is
        /// `"locked"` or `"unlocked"`.
        notifications: RwLock<Vec<(String, String, String)>>,
    }

    impl InMemoryLockoutStore {
        /// Creates an empty store.
        #[must_use]
        pub fn new() -> Self {
            Self::default()
        }

        /// Registers `email` as the notification address of `user_id`.
        pub fn set_email(&self, user_id: &str, email: &str) {
            let mut emails = self.emails.write().unwrap();
            emails.insert(user_id.to_owned(), email.to_owned());
        }

        /// Returns a copy of every captured notification, in the order sent.
        pub fn get_notifications(&self) -> Vec<(String, String, String)> {
            let captured = self.notifications.read().unwrap();
            captured.to_vec()
        }

        /// Appends one `(user_id, email, kind)` entry to the captured list.
        fn capture_notification(&self, user_id: &str, email: &str, kind: &str) {
            let entry = (user_id.to_owned(), email.to_owned(), kind.to_owned());
            self.notifications.write().unwrap().push(entry);
        }
    }

    #[async_trait]
    impl LockoutStore for InMemoryLockoutStore {
        async fn get_failed_attempts(&self, user_id: &str) -> Result<u32> {
            let users = self.users.read().unwrap();
            Ok(users.get(user_id).map_or(0, |state| state.failed_attempts))
        }

        async fn get_lockout_status(&self, user_id: &str) -> Result<Option<LockoutStatus>> {
            let users = self.users.read().unwrap();
            // A user with no failures and no deadlines reports no status at all.
            let status = users
                .get(user_id)
                .filter(|state| {
                    state.failed_attempts != 0
                        || state.locked_until.is_some()
                        || state.delay_until.is_some()
                })
                .map(|state| LockoutStatus {
                    failed_attempts: state.failed_attempts,
                    is_locked: matches!(
                        state.locked_until,
                        Some(until) if SystemTime::now() < until
                    ),
                    locked_until: state.locked_until,
                    delay_seconds: None,
                    delay_until: state.delay_until,
                    last_attempt_at: state.last_attempt_at,
                });
            Ok(status)
        }

        async fn increment_failed_attempts(&self, user_id: &str) -> Result<u32> {
            let mut users = self.users.write().unwrap();
            let UserLockoutState {
                failed_attempts,
                last_attempt_at,
                ..
            } = users.entry(user_id.to_owned()).or_default();
            *failed_attempts += 1;
            *last_attempt_at = Some(SystemTime::now());
            Ok(*failed_attempts)
        }

        async fn set_lockout(&self, user_id: &str, until: SystemTime) -> Result<()> {
            let mut users = self.users.write().unwrap();
            users.entry(user_id.to_owned()).or_default().locked_until = Some(until);
            Ok(())
        }

        async fn set_delay(&self, user_id: &str, until: SystemTime) -> Result<()> {
            let mut users = self.users.write().unwrap();
            users.entry(user_id.to_owned()).or_default().delay_until = Some(until);
            Ok(())
        }

        async fn clear_lockout(&self, user_id: &str) -> Result<()> {
            self.users.write().unwrap().remove(user_id);
            Ok(())
        }

        async fn get_user_email(&self, user_id: &str) -> Result<Option<String>> {
            let emails = self.emails.read().unwrap();
            Ok(emails.get(user_id).cloned())
        }

        async fn send_lockout_notification(
            &self,
            user_id: &str,
            email: &str,
            _locked_until: SystemTime,
        ) -> Result<()> {
            self.capture_notification(user_id, email, "locked");
            Ok(())
        }

        async fn send_unlock_notification(&self, user_id: &str, email: &str) -> Result<()> {
            self.capture_notification(user_id, email, "unlocked");
            Ok(())
        }

        async fn get_failed_attempts_by_ip(&self, ip: &str) -> Result<u32> {
            let ips = self.ips.read().unwrap();
            Ok(ips.get(ip).map_or(0, |&(count, _)| count))
        }

        async fn increment_failed_attempts_by_ip(&self, ip: &str) -> Result<u32> {
            let mut ips = self.ips.write().unwrap();
            let (count, _) = ips.entry(ip.to_owned()).or_default();
            *count += 1;
            Ok(*count)
        }

        async fn set_lockout_by_ip(&self, ip: &str, until: SystemTime) -> Result<()> {
            let mut ips = self.ips.write().unwrap();
            ips.entry(ip.to_owned()).or_default().1 = Some(until);
            Ok(())
        }

        async fn is_ip_locked(&self, ip: &str) -> Result<Option<SystemTime>> {
            let ips = self.ips.read().unwrap();
            Ok(ips.get(ip).and_then(|&(_, until)| until))
        }

        async fn clear_ip_lockout(&self, ip: &str) -> Result<()> {
            let mut ips = self.ips.write().unwrap();
            ips.remove(ip);
            Ok(())
        }
    }
}
// Unit tests for the policy, status and manager, backed by `InMemoryLockoutStore`.
#[cfg(test)]
mod tests {
use super::*;
use test::InMemoryLockoutStore;
// Default policy: 5 attempts, 15-minute lockout, no delays, notifications or IP tracking.
#[test]
fn test_policy_defaults() {
let policy = LockoutPolicy::new();
assert_eq!(policy.max_attempts, 5);
assert_eq!(policy.lockout_duration, Duration::from_secs(15 * 60));
assert!(policy.progressive_delays.is_empty());
assert!(!policy.send_notifications);
assert!(!policy.track_by_ip);
}
// The strict preset tightens every knob.
#[test]
fn test_policy_strict() {
let policy = LockoutPolicy::strict();
assert_eq!(policy.max_attempts, 3);
assert_eq!(policy.lockout_duration, Duration::from_secs(30 * 60));
assert!(!policy.progressive_delays.is_empty());
assert!(policy.send_notifications);
assert!(policy.track_by_ip);
}
// Each builder setter overrides exactly its own field.
#[test]
fn test_policy_builder() {
let policy = LockoutPolicy::new()
.max_attempts(10)
.lockout_duration(Duration::from_secs(60))
.progressive_delays(vec![0, 30, 60])
.with_notifications()
.track_by_ip(true);
assert_eq!(policy.max_attempts, 10);
assert_eq!(policy.lockout_duration, Duration::from_secs(60));
assert_eq!(policy.progressive_delays, vec![0, 30, 60]);
assert!(policy.send_notifications);
assert!(policy.track_by_ip);
}
// Attempt n uses delays[n - 1]; attempt 5 reaches max_attempts and locks
// instead of delaying, hence `None`.
#[test]
fn test_get_delay_for_attempt() {
let policy = LockoutPolicy::new()
.max_attempts(5)
.progressive_delays(vec![0, 0, 30, 60]);
assert_eq!(policy.get_delay_for_attempt(1), Some(0));
assert_eq!(policy.get_delay_for_attempt(2), Some(0));
assert_eq!(policy.get_delay_for_attempt(3), Some(30));
assert_eq!(policy.get_delay_for_attempt(4), Some(60));
assert_eq!(policy.get_delay_for_attempt(5), None); }
// Past the end of the delay list, the last entry is reused.
#[test]
fn test_get_delay_extends_last() {
let policy = LockoutPolicy::new()
.max_attempts(10)
.progressive_delays(vec![0, 30]);
assert_eq!(policy.get_delay_for_attempt(3), Some(30));
assert_eq!(policy.get_delay_for_attempt(9), Some(30));
}
// Failures below max_attempts(3) only count; the third one locks.
#[tokio::test]
async fn test_record_failed_attempts() {
let store = InMemoryLockoutStore::new();
let policy = LockoutPolicy::new().max_attempts(3);
let manager = LockoutManager::new(store, policy);
let result = manager.record_failed_attempt("user-1", None).await.unwrap();
assert_eq!(result.status.failed_attempts, 1);
assert!(!result.status.is_locked);
assert!(!result.just_locked);
let result = manager.record_failed_attempt("user-1", None).await.unwrap();
assert_eq!(result.status.failed_attempts, 2);
assert!(!result.status.is_locked);
let result = manager.record_failed_attempt("user-1", None).await.unwrap();
assert_eq!(result.status.failed_attempts, 3);
assert!(result.status.is_locked);
assert!(result.just_locked);
assert!(result.status.locked_until.is_some());
}
// With max_attempts(1) a single failure locks, and the next check is blocked.
#[tokio::test]
async fn test_check_can_attempt_when_locked() {
let store = InMemoryLockoutStore::new();
let policy = LockoutPolicy::new().max_attempts(1);
let manager = LockoutManager::new(store, policy);
manager.record_failed_attempt("user-1", None).await.unwrap();
let status = manager.check_can_attempt("user-1", None).await.unwrap();
assert!(status.is_some());
assert!(status.unwrap().is_locked);
}
// A successful login leaves nothing that blocks the next attempt.
// NOTE(review): with no delays configured, two failures below the threshold
// would not block anyway, so this passes even without the success call.
// Locking first would make it a stronger check.
#[tokio::test]
async fn test_successful_login_clears_lockout() {
let store = InMemoryLockoutStore::new();
let policy = LockoutPolicy::new().max_attempts(3);
let manager = LockoutManager::new(store, policy);
manager.record_failed_attempt("user-1", None).await.unwrap();
manager.record_failed_attempt("user-1", None).await.unwrap();
manager
.record_successful_login("user-1", None)
.await
.unwrap();
let status = manager.check_can_attempt("user-1", None).await.unwrap();
assert!(status.is_none());
}
// Admin unlock reports the prior lockout and lets the user attempt again.
#[tokio::test]
async fn test_admin_unlock() {
let store = InMemoryLockoutStore::new();
let policy = LockoutPolicy::new().max_attempts(1);
let manager = LockoutManager::new(store, policy);
manager.record_failed_attempt("user-1", None).await.unwrap();
let status = manager.check_can_attempt("user-1", None).await.unwrap();
assert!(status.is_some());
let had_lockout = manager.admin_unlock("user-1", "admin-1").await.unwrap();
assert!(had_lockout);
let status = manager.check_can_attempt("user-1", None).await.unwrap();
assert!(status.is_none());
}
// Each failure below the threshold reports the scheduled delay.
#[tokio::test]
async fn test_progressive_delays() {
let store = InMemoryLockoutStore::new();
let policy = LockoutPolicy::new()
.max_attempts(5)
.progressive_delays(vec![0, 0, 30, 60]);
let manager = LockoutManager::new(store, policy);
let result = manager.record_failed_attempt("user-1", None).await.unwrap();
assert_eq!(result.status.delay_seconds, Some(0));
let result = manager.record_failed_attempt("user-1", None).await.unwrap();
assert_eq!(result.status.delay_seconds, Some(0));
let result = manager.record_failed_attempt("user-1", None).await.unwrap();
assert_eq!(result.status.delay_seconds, Some(30));
let result = manager.record_failed_attempt("user-1", None).await.unwrap();
assert_eq!(result.status.delay_seconds, Some(60));
}
// Locking a user with a registered email captures exactly one "locked" notice.
#[tokio::test]
async fn test_notifications() {
let store = InMemoryLockoutStore::new();
store.set_email("user-1", "user@example.com");
let policy = LockoutPolicy::new().max_attempts(1).with_notifications();
let manager = LockoutManager::new(store, policy);
let result = manager.record_failed_attempt("user-1", None).await.unwrap();
assert!(result.notification_sent);
let notifications = manager.store.get_notifications();
assert_eq!(notifications.len(), 1);
assert_eq!(notifications[0].0, "user-1");
assert_eq!(notifications[0].1, "user@example.com");
assert_eq!(notifications[0].2, "locked");
}
// Failures for different users from one IP add up. Reaching max_attempts(2)
// blocks a third user from that IP, but not from another IP.
#[tokio::test]
async fn test_ip_tracking() {
let store = InMemoryLockoutStore::new();
let policy = LockoutPolicy::new().max_attempts(2).track_by_ip(true);
let manager = LockoutManager::new(store, policy);
manager
.record_failed_attempt("user-1", Some("1.2.3.4"))
.await
.unwrap();
manager
.record_failed_attempt("user-2", Some("1.2.3.4"))
.await
.unwrap();
let status = manager
.check_can_attempt("user-3", Some("1.2.3.4"))
.await
.unwrap();
assert!(status.is_some());
assert!(status.unwrap().is_locked);
let status = manager
.check_can_attempt("user-3", Some("5.6.7.8"))
.await
.unwrap();
assert!(status.is_none());
}
// Cases covered: an active lock blocks, a pending delay blocks, and no
// restriction allows.
#[test]
fn test_lockout_status_can_attempt() {
let now = SystemTime::now();
let status = LockoutStatus {
failed_attempts: 5,
is_locked: true,
locked_until: Some(now + Duration::from_secs(60)),
delay_seconds: None,
delay_until: None,
last_attempt_at: Some(now),
};
assert!(!status.can_attempt_now());
let status = LockoutStatus {
failed_attempts: 3,
is_locked: false,
locked_until: None,
delay_seconds: Some(30),
delay_until: Some(now + Duration::from_secs(30)),
last_attempt_at: Some(now),
};
assert!(!status.can_attempt_now());
let status = LockoutStatus {
failed_attempts: 2,
is_locked: false,
locked_until: None,
delay_seconds: None,
delay_until: None,
last_attempt_at: Some(now),
};
assert!(status.can_attempt_now());
}
}