use std::collections::{HashMap, HashSet};
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Tuning parameters for [`AuthRateLimiter`].
#[derive(Debug, Clone)]
pub struct AuthRateLimitConfig {
    /// Number of failures inside `window` at which an IP gets banned.
    pub max_attempts: u32,
    /// Counting window, measured from the first failure of the current window.
    /// Once it has elapsed, the next failure starts a fresh count.
    pub window: Duration,
    /// How long a ban lasts once imposed.
    pub ban_duration: Duration,
    /// Addresses whose failures are never counted and which are never reported as banned.
    pub whitelist: HashSet<IpAddr>,
    /// Upper bound on IPs with an open failure record. When the table is full,
    /// recording a failure for a new IP first evicts an existing record.
    pub max_tracked_ips: usize,
}

impl Default for AuthRateLimitConfig {
    /// 5 attempts per 300 s window, 300 s bans, empty whitelist, 10 000 tracked IPs.
    fn default() -> Self {
        Self::new(5, 300, 300)
    }
}

impl AuthRateLimitConfig {
    /// Tracking capacity shared by [`AuthRateLimitConfig::new`] and `Default`.
    const DEFAULT_MAX_TRACKED_IPS: usize = 10_000;

    /// Builds a config from whole-second durations, with an empty whitelist
    /// and the default tracking capacity.
    pub fn new(max_attempts: u32, window_secs: u64, ban_duration_secs: u64) -> Self {
        let window = Duration::from_secs(window_secs);
        let ban_duration = Duration::from_secs(ban_duration_secs);
        Self {
            max_attempts,
            window,
            ban_duration,
            whitelist: HashSet::new(),
            max_tracked_ips: Self::DEFAULT_MAX_TRACKED_IPS,
        }
    }

    /// Adds `ip` to the whitelist in place (no-op if already present).
    pub fn add_whitelist(&mut self, ip: IpAddr) {
        let _ = self.whitelist.insert(ip);
    }

    /// Replaces the whole whitelist with `whitelist` (duplicates collapse).
    pub fn with_whitelist(self, whitelist: Vec<IpAddr>) -> Self {
        Self {
            whitelist: whitelist.into_iter().collect(),
            ..self
        }
    }

    /// Overrides the maximum number of IPs tracked at once.
    pub fn with_max_tracked_ips(self, max: usize) -> Self {
        Self {
            max_tracked_ips: max,
            ..self
        }
    }
}
/// Failure bookkeeping for a single IP within its current counting window.
#[derive(Debug)]
struct FailureRecord {
    /// Failures counted in the current window.
    count: u32,
    /// When the current window opened; the window length is measured from here.
    first_failure: Instant,
    /// Most recent failure; used to choose an eviction victim when the table is full.
    last_failure: Instant,
}
/// Per-IP limiter for failed authentication attempts.
///
/// Counts failures per address and bans an address once it reaches
/// `max_attempts` failures inside the configured window. Both tables sit
/// behind `Arc`s, so clones of a limiter share the same state.
#[derive(Debug)]
pub struct AuthRateLimiter {
    /// Open failure counters, keyed by client address.
    failures: Arc<RwLock<HashMap<IpAddr, FailureRecord>>>,
    /// Ban expiry instants, keyed by client address (expired entries linger until cleanup/unban).
    bans: Arc<RwLock<HashMap<IpAddr, Instant>>>,
    /// Limits applied by this limiter.
    config: AuthRateLimitConfig,
}
impl AuthRateLimiter {
pub fn new(config: AuthRateLimitConfig) -> Self {
Self {
failures: Arc::new(RwLock::new(HashMap::new())),
bans: Arc::new(RwLock::new(HashMap::new())),
config,
}
}
pub fn try_is_banned(&self, ip: &IpAddr) -> Option<bool> {
if self.config.whitelist.contains(ip) {
return Some(false);
}
let bans = self.bans.try_read().ok()?;
if let Some(expiry) = bans.get(ip)
&& Instant::now() < *expiry
{
return Some(true);
}
Some(false)
}
pub async fn is_banned(&self, ip: &IpAddr) -> bool {
if self.config.whitelist.contains(ip) {
return false;
}
let bans = self.bans.read().await;
if let Some(expiry) = bans.get(ip)
&& Instant::now() < *expiry
{
return true;
}
false
}
pub async fn record_failure(&self, ip: IpAddr) -> bool {
if self.config.whitelist.contains(&ip) {
return false;
}
let should_ban;
{
let mut failures = self.failures.write().await;
let now = Instant::now();
if failures.len() >= self.config.max_tracked_ips && !failures.contains_key(&ip) {
if let Some(oldest_ip) = failures
.iter()
.min_by_key(|(_, record)| record.last_failure)
.map(|(ip, _)| *ip)
{
failures.remove(&oldest_ip);
tracing::debug!(
removed_ip = %oldest_ip,
capacity = self.config.max_tracked_ips,
"Removed oldest failure record due to capacity limit"
);
}
}
let record = failures.entry(ip).or_insert_with(|| FailureRecord {
count: 0,
first_failure: now,
last_failure: now,
});
if now.duration_since(record.first_failure) > self.config.window {
record.count = 1;
record.first_failure = now;
} else {
record.count += 1;
}
record.last_failure = now;
should_ban = record.count >= self.config.max_attempts;
if should_ban {
failures.remove(&ip);
}
}
if should_ban {
self.ban_internal(ip).await;
return true;
}
false
}
pub async fn record_success(&self, ip: &IpAddr) {
let mut failures = self.failures.write().await;
failures.remove(ip);
}
pub async fn ban(&self, ip: IpAddr) {
{
let mut failures = self.failures.write().await;
failures.remove(&ip);
}
self.ban_internal(ip).await;
}
async fn ban_internal(&self, ip: IpAddr) {
tracing::warn!(
ip = %ip,
duration_secs = self.config.ban_duration.as_secs(),
"Banning IP due to too many failed auth attempts"
);
let mut bans = self.bans.write().await;
let expiry = Instant::now() + self.config.ban_duration;
bans.insert(ip, expiry);
}
pub async fn unban(&self, ip: &IpAddr) {
let mut bans = self.bans.write().await;
if bans.remove(ip).is_some() {
tracing::info!(ip = %ip, "Manually unbanned IP");
}
}
pub async fn remaining_attempts(&self, ip: &IpAddr) -> u32 {
let failures = self.failures.read().await;
if let Some(record) = failures.get(ip) {
let now = Instant::now();
if now.duration_since(record.first_failure) > self.config.window {
return self.config.max_attempts;
}
self.config.max_attempts.saturating_sub(record.count)
} else {
self.config.max_attempts
}
}
pub async fn cleanup(&self) {
let now = Instant::now();
{
let mut bans = self.bans.write().await;
let before = bans.len();
bans.retain(|_, expiry| now < *expiry);
let after = bans.len();
if before > after {
tracing::debug!(
removed = before - after,
remaining = after,
"Cleaned up expired bans"
);
}
}
{
let mut failures = self.failures.write().await;
let before = failures.len();
failures
.retain(|_, record| now.duration_since(record.last_failure) < self.config.window);
let after = failures.len();
if before > after {
tracing::debug!(
removed = before - after,
remaining = after,
"Cleaned up expired failure records"
);
}
}
}
pub async fn get_bans(&self) -> Vec<(IpAddr, Duration)> {
let now = Instant::now();
let bans = self.bans.read().await;
bans.iter()
.filter_map(|(ip, expiry)| {
if now < *expiry {
Some((*ip, *expiry - now))
} else {
None
}
})
.collect()
}
pub async fn banned_count(&self) -> usize {
let now = Instant::now();
let bans = self.bans.read().await;
bans.values().filter(|expiry| now < **expiry).count()
}
pub async fn tracked_count(&self) -> usize {
self.failures.read().await.len()
}
pub fn config(&self) -> &AuthRateLimitConfig {
&self.config
}
pub fn is_whitelisted(&self, ip: &IpAddr) -> bool {
self.config.whitelist.contains(ip)
}
}
impl Clone for AuthRateLimiter {
    /// Produces a handle onto the same failure and ban tables: only the `Arc`
    /// pointers are duplicated, while the config is copied by value.
    fn clone(&self) -> Self {
        let Self {
            failures,
            bans,
            config,
        } = self;
        Self {
            failures: Arc::clone(failures),
            bans: Arc::clone(bans),
            config: config.clone(),
        }
    }
}
#[cfg(test)]
mod tests {
    // Behavioural tests for AuthRateLimiter. Several rely on short real-time
    // sleeps (tens of milliseconds) to let windows and bans expire.
    use super::*;
    use std::net::Ipv4Addr;

    fn test_ip() -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100))
    }

    fn test_ip2() -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(192, 168, 1, 101))
    }

    fn localhost() -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))
    }

    #[tokio::test]
    async fn test_failure_counting() {
        let config = AuthRateLimitConfig::new(5, 300, 300);
        let limiter = AuthRateLimiter::new(config);
        let ip = test_ip();
        for i in 1..5 {
            let banned = limiter.record_failure(ip).await;
            assert!(!banned, "Should not be banned after {i} failures");
            assert_eq!(
                limiter.remaining_attempts(&ip).await,
                5 - i,
                "Should have {} remaining attempts",
                5 - i
            );
        }
    }

    #[tokio::test]
    async fn test_ban_after_max_attempts() {
        let config = AuthRateLimitConfig::new(3, 300, 300);
        let limiter = AuthRateLimiter::new(config);
        let ip = test_ip();
        assert!(!limiter.record_failure(ip).await);
        assert!(!limiter.record_failure(ip).await);
        assert!(!limiter.is_banned(&ip).await);
        assert!(limiter.record_failure(ip).await);
        assert!(limiter.is_banned(&ip).await);
    }

    #[tokio::test]
    async fn test_ban_expiration() {
        // Zero-length ban: expires as soon as any time passes.
        let config = AuthRateLimitConfig::new(2, 300, 0);
        let limiter = AuthRateLimiter::new(config);
        let ip = test_ip();
        limiter.record_failure(ip).await;
        assert!(limiter.record_failure(ip).await);
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert!(!limiter.is_banned(&ip).await);
    }

    #[tokio::test]
    async fn test_whitelist_ips() {
        let config = AuthRateLimitConfig::new(1, 300, 300).with_whitelist(vec![localhost()]);
        let limiter = AuthRateLimiter::new(config);
        let whitelisted = localhost();
        let not_whitelisted = test_ip();
        assert!(!limiter.record_failure(whitelisted).await);
        assert!(!limiter.is_banned(&whitelisted).await);
        assert!(limiter.record_failure(not_whitelisted).await);
        assert!(limiter.is_banned(&not_whitelisted).await);
        assert!(limiter.is_whitelisted(&whitelisted));
        assert!(!limiter.is_whitelisted(&not_whitelisted));
    }

    #[tokio::test]
    async fn test_success_resets_failures() {
        let config = AuthRateLimitConfig::new(3, 300, 300);
        let limiter = AuthRateLimiter::new(config);
        let ip = test_ip();
        limiter.record_failure(ip).await;
        limiter.record_failure(ip).await;
        assert_eq!(limiter.remaining_attempts(&ip).await, 1);
        limiter.record_success(&ip).await;
        assert_eq!(limiter.remaining_attempts(&ip).await, 3);
        limiter.record_failure(ip).await;
        limiter.record_failure(ip).await;
        assert!(!limiter.is_banned(&ip).await);
        limiter.record_failure(ip).await;
        assert!(limiter.is_banned(&ip).await);
    }

    #[tokio::test]
    async fn test_window_expiration() {
        let config = AuthRateLimitConfig {
            max_attempts: 3,
            window: Duration::from_millis(50),
            ban_duration: Duration::from_secs(300),
            whitelist: HashSet::new(),
            max_tracked_ips: 10000,
        };
        let limiter = AuthRateLimiter::new(config);
        let ip = test_ip();
        limiter.record_failure(ip).await;
        limiter.record_failure(ip).await;
        assert_eq!(limiter.remaining_attempts(&ip).await, 1);
        tokio::time::sleep(Duration::from_millis(60)).await;
        assert_eq!(limiter.remaining_attempts(&ip).await, 3);
        assert!(!limiter.record_failure(ip).await);
    }

    #[tokio::test]
    async fn test_cleanup() {
        let config = AuthRateLimitConfig {
            max_attempts: 2,
            window: Duration::from_millis(10),
            ban_duration: Duration::from_millis(10),
            whitelist: HashSet::new(),
            max_tracked_ips: 10000,
        };
        let limiter = AuthRateLimiter::new(config);
        let ip1 = test_ip();
        let ip2 = test_ip2();
        limiter.record_failure(ip1).await;
        limiter.record_failure(ip2).await;
        limiter.record_failure(ip2).await;
        // ip2's record was dropped when it got banned; only ip1 is tracked.
        assert_eq!(limiter.tracked_count().await, 1);
        assert_eq!(limiter.banned_count().await, 1);
        tokio::time::sleep(Duration::from_millis(20)).await;
        limiter.cleanup().await;
        assert_eq!(limiter.tracked_count().await, 0);
        assert_eq!(limiter.banned_count().await, 0);
    }

    #[tokio::test]
    async fn test_manual_ban_unban() {
        let config = AuthRateLimitConfig::new(5, 300, 300);
        let limiter = AuthRateLimiter::new(config);
        let ip = test_ip();
        limiter.ban(ip).await;
        assert!(limiter.is_banned(&ip).await);
        limiter.unban(&ip).await;
        assert!(!limiter.is_banned(&ip).await);
    }

    #[tokio::test]
    async fn test_get_bans() {
        let config = AuthRateLimitConfig::new(1, 300, 300);
        let limiter = AuthRateLimiter::new(config);
        let ip1 = test_ip();
        let ip2 = test_ip2();
        limiter.record_failure(ip1).await;
        limiter.record_failure(ip2).await;
        let bans = limiter.get_bans().await;
        assert_eq!(bans.len(), 2);
        let ips: Vec<IpAddr> = bans.iter().map(|(ip, _)| *ip).collect();
        assert!(ips.contains(&ip1));
        assert!(ips.contains(&ip2));
        for (_, duration) in &bans {
            assert!(duration.as_secs() > 0);
        }
    }

    #[tokio::test]
    async fn test_clone_shares_state() {
        let config = AuthRateLimitConfig::new(3, 300, 300);
        let limiter1 = AuthRateLimiter::new(config);
        let limiter2 = limiter1.clone();
        let ip = test_ip();
        limiter1.record_failure(ip).await;
        limiter1.record_failure(ip).await;
        assert_eq!(limiter2.remaining_attempts(&ip).await, 1);
        limiter2.record_failure(ip).await;
        assert!(limiter1.is_banned(&ip).await);
    }

    #[tokio::test]
    async fn test_per_ip_isolation() {
        let config = AuthRateLimitConfig::new(2, 300, 300);
        let limiter = AuthRateLimiter::new(config);
        let ip1 = test_ip();
        let ip2 = test_ip2();
        limiter.record_failure(ip1).await;
        assert_eq!(limiter.remaining_attempts(&ip2).await, 2);
        assert!(!limiter.is_banned(&ip2).await);
        limiter.record_failure(ip1).await;
        assert!(limiter.is_banned(&ip1).await);
        assert!(!limiter.is_banned(&ip2).await);
    }

    #[tokio::test]
    async fn test_config_accessors() {
        let config = AuthRateLimitConfig::new(10, 600, 1800);
        let limiter = AuthRateLimiter::new(config);
        assert_eq!(limiter.config().max_attempts, 10);
        assert_eq!(limiter.config().window.as_secs(), 600);
        assert_eq!(limiter.config().ban_duration.as_secs(), 1800);
    }

    #[tokio::test]
    async fn test_capacity_limit() {
        let config = AuthRateLimitConfig::new(5, 300, 300).with_max_tracked_ips(3);
        let limiter = AuthRateLimiter::new(config);
        let ip1: IpAddr = "192.168.1.1".parse().unwrap();
        let ip2: IpAddr = "192.168.1.2".parse().unwrap();
        let ip3: IpAddr = "192.168.1.3".parse().unwrap();
        let ip4: IpAddr = "192.168.1.4".parse().unwrap();
        limiter.record_failure(ip1).await;
        limiter.record_failure(ip2).await;
        limiter.record_failure(ip3).await;
        assert_eq!(limiter.tracked_count().await, 3);
        limiter.record_failure(ip4).await;
        assert_eq!(limiter.tracked_count().await, 3);
        assert_eq!(limiter.remaining_attempts(&ip4).await, 4);
    }
}