use std::ops::DerefMut;
use std::sync::Arc;
use deadpool_redis::Pool as RedisPool;
use tracing::{debug, info, warn};
use super::config::LockoutConfig;
use super::notification::{LockoutEvent, LockoutNotification, UnlockReason};
use crate::error::{Error, Result};
/// Snapshot of an identity's lockout state, as returned by the `LoginLockout` operations.
///
/// Equality is derived so callers (and tests) can compare statuses directly.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct LockoutStatus {
    /// `true` while a lock is in effect for the identity.
    pub locked: bool,
    /// Failed attempts currently counted for the identity (0 when lockout is disabled).
    pub attempt_count: u32,
    /// Configured number of failures that triggers a lock.
    pub max_attempts: u32,
    /// Seconds until the lock expires; 0 when the identity is not locked.
    pub lockout_remaining_secs: u64,
    /// Progressive delay in milliseconds computed for `attempt_count`; 0 while locked or when
    /// progressive delay is disabled.
    pub delay_ms: u64,
}
/// Tracks failed login attempts per identity in Redis and locks identities that reach the
/// configured failure threshold.
///
/// Cloning is cheap relative to the pool itself: notification handlers are shared via `Arc`.
#[derive(Clone)]
pub struct LoginLockout {
    /// Thresholds, windows, delays and key prefix that drive the lockout policy.
    config: LockoutConfig,

    /// Pool from which every Redis operation checks out a connection.
    redis_pool: RedisPool,

    /// Handlers that receive a `LockoutEvent` for each recorded failure, lock and unlock.
    notifications: Vec<Arc<dyn LockoutNotification>>,
}
impl LoginLockout {
    /// Lua script that increments a failure counter and guarantees it carries an expiry.
    ///
    /// Sending `INCR` and `EXPIRE` as two separate commands leaves a window (for example a
    /// dropped connection between them) in which the counter exists without a TTL. Such a
    /// counter never resets, so once it reaches `max_attempts` every later failure re-locks the
    /// account until a successful login clears it. Redis executes a script atomically, which
    /// closes that window; checking for a missing TTL (`-1`) instead of `count == 1` also
    /// repairs any counter that was orphaned that way earlier.
    ///
    /// `KEYS[1]` is the attempts key and `ARGV[1]` the window length in seconds.
    const INCR_WITH_EXPIRY_SCRIPT: &'static str = r"
local count = redis.call('INCR', KEYS[1])
if redis.call('TTL', KEYS[1]) == -1 then
    redis.call('EXPIRE', KEYS[1], ARGV[1])
end
return count
";

    /// Creates a lockout tracker backed by `redis_pool`, with no notification handlers.
    pub fn new(config: LockoutConfig, redis_pool: RedisPool) -> Self {
        Self {
            config,
            redis_pool,
            notifications: Vec::new(),
        }
    }

    /// Registers an additional handler that receives every [`LockoutEvent`] this tracker emits.
    ///
    /// Each handler is run on its own spawned Tokio task per event.
    pub fn with_notification(mut self, handler: Arc<dyn LockoutNotification>) -> Self {
        self.notifications.push(handler);
        self
    }

    /// Registers a `super::AuditLockoutNotification` built from `audit_logger` as a
    /// notification handler.
    #[cfg(feature = "audit")]
    pub fn with_audit(self, audit_logger: crate::audit::AuditLogger) -> Self {
        let handler = Arc::new(super::AuditLockoutNotification::new(audit_logger));
        self.with_notification(handler)
    }

    /// Reports the current lockout state of `identity` without recording an attempt.
    ///
    /// When lockout is disabled in the config, returns an unlocked status with zero attempts
    /// and does not contact Redis.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Redis`] if no pooled connection can be obtained. Failures of the
    /// individual Redis commands are logged and treated as "not locked" / "zero attempts"
    /// (fail open).
    pub async fn check(&self, identity: &str) -> Result<LockoutStatus> {
        if !self.config.enabled {
            return Ok(self.unlocked_status(0));
        }
        let mut conn = self.get_connection().await?;
        let locked_key = self.locked_key(identity);

        // TTL replies -2 for a missing key and -1 for a key without expiry; only a positive
        // TTL is treated as an active lock.
        let ttl_reply: redis::RedisResult<i64> = redis::cmd("TTL")
            .arg(&locked_key)
            .query_async(conn.deref_mut())
            .await;
        let locked_ttl = ttl_reply.unwrap_or_else(|err| {
            warn!(
                identity = identity,
                error = %err,
                "Failed to read lockout TTL; treating account as not locked"
            );
            -2
        });

        // Reuse the connection already held instead of checking out a second one.
        let attempt_count = self
            .get_attempt_count(&mut conn, identity)
            .await
            .unwrap_or_else(|err| {
                warn!(
                    identity = identity,
                    error = %err,
                    "Failed to read login attempt count; assuming zero"
                );
                0
            });

        if locked_ttl > 0 {
            return Ok(self.locked_status(attempt_count, locked_ttl as u64));
        }
        Ok(self.unlocked_status(attempt_count))
    }

    /// Records a failed login for `identity` and locks the account once the threshold is hit.
    ///
    /// The failure counter is incremented and, atomically, given an expiry of
    /// `config.window_secs` if it has none, so it expires that long after the first failure of
    /// the window. Emits [`LockoutEvent::FailedAttempt`] on every call,
    /// [`LockoutEvent::ApproachingThreshold`] when the count equals a non-zero
    /// `warning_threshold` that is below `max_attempts`, and [`LockoutEvent::AccountLocked`]
    /// when the count reaches `max_attempts`. Every failure at or above the threshold rewrites
    /// the lock key with a TTL of `lockout_duration_secs`, so failures while locked restart the
    /// lockout period.
    ///
    /// When lockout is disabled, returns an unlocked status without contacting Redis.
    ///
    /// # Errors
    ///
    /// Returns an error if a pooled connection cannot be obtained or a Redis command fails.
    pub async fn record_failure(&self, identity: &str) -> Result<LockoutStatus> {
        if !self.config.enabled {
            return Ok(self.unlocked_status(0));
        }
        let mut conn = self.get_connection().await?;
        let attempts_key = self.attempts_key(identity);
        let count: u32 = redis::cmd("EVAL")
            .arg(Self::INCR_WITH_EXPIRY_SCRIPT)
            .arg(1)
            .arg(&attempts_key)
            .arg(self.config.window_secs as i64)
            .query_async(conn.deref_mut())
            .await?;
        debug!(
            identity = identity,
            attempt_count = count,
            max_attempts = self.config.max_attempts,
            "Login failure recorded"
        );
        self.notify(LockoutEvent::FailedAttempt {
            identity: identity.to_string(),
            attempt_count: count,
            max_attempts: self.config.max_attempts,
        });
        if self.config.warning_threshold > 0
            && count == self.config.warning_threshold
            && count < self.config.max_attempts
        {
            let remaining = self.config.max_attempts - count;
            self.notify(LockoutEvent::ApproachingThreshold {
                identity: identity.to_string(),
                attempt_count: count,
                remaining_attempts: remaining,
            });
        }
        if count >= self.config.max_attempts {
            let locked_key = self.locked_key(identity);
            // The stored value is the lock time (Unix seconds); the TTL is what enforces it.
            let _: () = redis::cmd("SET")
                .arg(&locked_key)
                .arg(chrono::Utc::now().timestamp().to_string())
                .arg("EX")
                .arg(self.config.lockout_duration_secs as i64)
                .query_async(conn.deref_mut())
                .await?;
            warn!(
                identity = identity,
                attempt_count = count,
                lockout_duration_secs = self.config.lockout_duration_secs,
                "Account locked due to repeated login failures"
            );
            self.notify(LockoutEvent::AccountLocked {
                identity: identity.to_string(),
                attempt_count: count,
                lockout_duration_secs: self.config.lockout_duration_secs,
            });
            return Ok(self.locked_status(count, self.config.lockout_duration_secs));
        }
        Ok(self.unlocked_status(count))
    }

    /// Clears the failure counter and any lock for `identity` after a successful login.
    ///
    /// Emits [`LockoutEvent::AccountUnlocked`] with [`UnlockReason::SuccessfulLogin`] when a
    /// lock key was present. Does nothing when lockout is disabled.
    ///
    /// # Errors
    ///
    /// Returns an error if a pooled connection cannot be obtained or the `DEL` fails. A failed
    /// lock-existence probe is logged and only suppresses the unlock event.
    pub async fn record_success(&self, identity: &str) -> Result<()> {
        if !self.config.enabled {
            return Ok(());
        }
        let mut conn = self.get_connection().await?;
        let attempts_key = self.attempts_key(identity);
        let locked_key = self.locked_key(identity);
        let exists_reply: redis::RedisResult<bool> = redis::cmd("EXISTS")
            .arg(&locked_key)
            .query_async(conn.deref_mut())
            .await;
        let was_locked = exists_reply.unwrap_or_else(|err| {
            warn!(
                identity = identity,
                error = %err,
                "Failed to check lock state before clearing it"
            );
            false
        });
        let _: () = redis::cmd("DEL")
            .arg(&attempts_key)
            .arg(&locked_key)
            .query_async(conn.deref_mut())
            .await?;
        if was_locked {
            info!(identity = identity, "Account unlocked via successful login");
            self.notify(LockoutEvent::AccountUnlocked {
                identity: identity.to_string(),
                reason: UnlockReason::SuccessfulLogin,
            });
        }
        Ok(())
    }

    /// Administratively clears the failure counter and any lock for `identity`.
    ///
    /// Unlike the other operations this ignores `config.enabled`, and it always emits
    /// [`LockoutEvent::AccountUnlocked`] with [`UnlockReason::AdminAction`], even if the
    /// identity was not locked.
    ///
    /// # Errors
    ///
    /// Returns an error if a pooled connection cannot be obtained or the `DEL` fails.
    pub async fn unlock(&self, identity: &str) -> Result<()> {
        let mut conn = self.get_connection().await?;
        let attempts_key = self.attempts_key(identity);
        let locked_key = self.locked_key(identity);
        let _: () = redis::cmd("DEL")
            .arg(&attempts_key)
            .arg(&locked_key)
            .query_async(conn.deref_mut())
            .await?;
        info!(identity = identity, "Account manually unlocked (admin)");
        self.notify(LockoutEvent::AccountUnlocked {
            identity: identity.to_string(),
            reason: UnlockReason::AdminAction,
        });
        Ok(())
    }

    /// Progressive delay in milliseconds after `attempt_count` failures.
    ///
    /// Computes `base_delay_ms * delay_multiplier^(attempt_count - 1)`, capped at
    /// `max_delay_ms`. Returns 0 when progressive delay is disabled or no attempts were made.
    /// A non-finite product (overflow to infinity, NaN) yields `max_delay_ms`; a negative
    /// finite product saturates to 0 through the float-to-integer cast.
    fn compute_delay(&self, attempt_count: u32) -> u64 {
        if !self.config.progressive_delay_enabled || attempt_count == 0 {
            return 0;
        }
        let exponent = (attempt_count - 1) as f64;
        let delay = self.config.base_delay_ms as f64 * self.config.delay_multiplier.powf(exponent);
        if delay.is_finite() {
            (delay as u64).min(self.config.max_delay_ms)
        } else {
            self.config.max_delay_ms
        }
    }

    /// Checks out a connection from the pool, mapping pool failures to [`Error::Redis`].
    async fn get_connection(&self) -> Result<deadpool_redis::Connection> {
        self.redis_pool.get().await.map_err(|e| {
            let redis_err = redis::RedisError::from((
                redis::ErrorKind::IoError,
                "Failed to get Redis connection for lockout",
                e.to_string(),
            ));
            Error::Redis(Box::new(redis_err))
        })
    }

    /// Reads the failure counter for `identity` on the caller's connection; a missing key
    /// counts as 0.
    async fn get_attempt_count(
        &self,
        conn: &mut deadpool_redis::Connection,
        identity: &str,
    ) -> redis::RedisResult<u32> {
        let attempts_key = self.attempts_key(identity);
        let count: Option<u32> = redis::cmd("GET")
            .arg(&attempts_key)
            .query_async(conn.deref_mut())
            .await?;
        Ok(count.unwrap_or(0))
    }

    /// Builds the status of an unlocked identity, including its progressive delay.
    fn unlocked_status(&self, attempt_count: u32) -> LockoutStatus {
        LockoutStatus {
            locked: false,
            attempt_count,
            max_attempts: self.config.max_attempts,
            lockout_remaining_secs: 0,
            delay_ms: self.compute_delay(attempt_count),
        }
    }

    /// Builds the status of a locked identity; no progressive delay applies while locked.
    fn locked_status(&self, attempt_count: u32, lockout_remaining_secs: u64) -> LockoutStatus {
        LockoutStatus {
            locked: true,
            attempt_count,
            max_attempts: self.config.max_attempts,
            lockout_remaining_secs,
            delay_ms: 0,
        }
    }

    /// Redis key holding the failure counter: `{key_prefix}:attempts:{identity}`.
    fn attempts_key(&self, identity: &str) -> String {
        format!("{}:attempts:{}", self.config.key_prefix, identity)
    }

    /// Redis key whose presence (with a TTL) marks a lock: `{key_prefix}:locked:{identity}`.
    fn locked_key(&self, identity: &str) -> String {
        format!("{}:locked:{}", self.config.key_prefix, identity)
    }

    /// Delivers `event` to every registered handler on a detached Tokio task per handler.
    ///
    /// The caller does not wait for handlers to finish. `tokio::spawn` panics when called
    /// outside a Tokio runtime, so this must run inside one.
    fn notify(&self, event: LockoutEvent) {
        for handler in &self.notifications {
            let handler = Arc::clone(handler);
            let event = event.clone();
            tokio::spawn(async move {
                handler.on_event(event).await;
            });
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default lockout configuration; individual tests override single fields.
    fn default_config() -> LockoutConfig {
        LockoutConfig::default()
    }

    /// Pool pointing at a local Redis; building it does not open a connection.
    fn dummy_pool() -> RedisPool {
        deadpool_redis::Config::from_url("redis://localhost:6379")
            .create_pool(Some(deadpool_redis::Runtime::Tokio1))
            .expect("Failed to create dummy pool")
    }

    /// Wraps `config` in a tracker with a dummy pool and no notification handlers.
    fn lockout_from(config: LockoutConfig) -> LoginLockout {
        LoginLockout {
            config,
            redis_pool: dummy_pool(),
            notifications: Vec::new(),
        }
    }

    #[test]
    fn test_compute_delay_zero_attempts() {
        let lockout = lockout_from(default_config());
        assert_eq!(lockout.compute_delay(0), 0);
    }

    #[test]
    fn test_compute_delay_first_attempt() {
        let lockout = lockout_from(default_config());
        assert_eq!(lockout.compute_delay(1), 1000);
    }

    #[test]
    fn test_compute_delay_progressive() {
        let lockout = lockout_from(default_config());
        let expected = [(2, 2000), (3, 4000), (4, 8000), (5, 16000)];
        for (attempts, delay) in expected {
            assert_eq!(lockout.compute_delay(attempts), delay);
        }
    }

    #[test]
    fn test_compute_delay_caps_at_max() {
        let lockout = lockout_from(default_config());
        assert_eq!(lockout.compute_delay(6), 30000);
        assert_eq!(lockout.compute_delay(100), 30000);
    }

    #[test]
    fn test_compute_delay_disabled() {
        let lockout = lockout_from(LockoutConfig {
            progressive_delay_enabled: false,
            ..default_config()
        });
        assert_eq!(lockout.compute_delay(1), 0);
        assert_eq!(lockout.compute_delay(5), 0);
    }

    #[test]
    fn test_compute_delay_multiplier_one() {
        let lockout = lockout_from(LockoutConfig {
            delay_multiplier: 1.0,
            ..default_config()
        });
        for attempts in [1, 5, 100] {
            assert_eq!(lockout.compute_delay(attempts), 1000);
        }
    }

    #[test]
    fn test_compute_delay_overflow_protection() {
        let lockout = lockout_from(LockoutConfig {
            delay_multiplier: 10.0,
            max_delay_ms: 30000,
            ..default_config()
        });
        assert_eq!(lockout.compute_delay(100), 30000);
    }

    #[test]
    fn test_redis_key_format() {
        let lockout = lockout_from(default_config());
        assert_eq!(
            lockout.attempts_key("user@example.com"),
            "lockout:attempts:user@example.com"
        );
        assert_eq!(
            lockout.locked_key("user@example.com"),
            "lockout:locked:user@example.com"
        );
    }

    #[test]
    fn test_redis_key_custom_prefix() {
        let lockout = lockout_from(LockoutConfig {
            key_prefix: "myapp".to_string(),
            ..default_config()
        });
        assert_eq!(lockout.attempts_key("alice"), "myapp:attempts:alice");
        assert_eq!(lockout.locked_key("alice"), "myapp:locked:alice");
    }
}