use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use tracing::warn;
/// Number of recorded failures at which a client is locked out.
const MAX_ATTEMPTS: u32 = 5;
/// Length of the lockout applied once `MAX_ATTEMPTS` failures are recorded.
const LOCKOUT_DURATION: Duration = Duration::from_secs(5 * 60);
/// Backoff delay, in milliseconds, after the first failure; it doubles with each further failure.
const BASE_DELAY_MS: u64 = 1_000;
/// Retention period used by `RateLimiter::default()`.
const DEFAULT_RETENTION: Duration = Duration::from_secs(60 * 60);
/// Failure-tracking state kept for a single client key.
#[derive(Debug)]
struct ClientState {
    /// Failures recorded since this state was created or last reset.
    failed_attempts: u32,
    /// Time of the most recent recorded failure (the creation time until one is recorded).
    last_attempt: Instant,
    /// End of the lockout window, or `None` when no lockout has been set.
    locked_until: Option<Instant>,
}
impl ClientState {
    /// Creates a state with no failures and no lockout, stamped with the current time.
    fn new() -> Self {
        let created_at = Instant::now();
        Self {
            locked_until: None,
            last_attempt: created_at,
            failed_attempts: 0,
        }
    }
}
/// Tracks failed attempts per client and enforces exponential backoff and lockout.
///
/// Clients are identified by the IP part of their socket address; the port is ignored.
pub struct RateLimiter {
    /// Per-client state keyed by the client's IP address rendered as a string.
    clients: Mutex<HashMap<String, ClientState>>,
    /// Entries whose last attempt is older than this, and which are not currently
    /// locked out, are dropped whenever the map is swept.
    retention: Duration,
}
impl Default for RateLimiter {
fn default() -> Self {
Self::new(DEFAULT_RETENTION)
}
}
impl RateLimiter {
    /// Creates an empty limiter.
    ///
    /// `retention` is how long a client's state is kept after its last recorded
    /// attempt; state for a client that is still locked out is kept regardless.
    pub fn new(retention: Duration) -> Self {
        Self {
            clients: Mutex::new(HashMap::new()),
            retention,
        }
    }

    /// Returns the backoff delay for a client with `failed_attempts` recorded failures:
    /// `BASE_DELAY_MS` doubled once per failure after the first.
    ///
    /// The arithmetic saturates instead of overflowing, so raising `MAX_ATTEMPTS`
    /// can never make this panic or wrap around to a tiny delay.
    fn backoff_delay(failed_attempts: u32) -> Duration {
        let exponent = failed_attempts.saturating_sub(1);
        let delay_ms = 1u64
            .checked_shl(exponent)
            .map_or(u64::MAX, |factor| BASE_DELAY_MS.saturating_mul(factor));
        Duration::from_millis(delay_ms)
    }

    /// Drops every entry that is both older than the retention period and not
    /// currently locked out. Must be called with the `clients` lock held.
    fn evict_stale(&self, clients: &mut HashMap<String, ClientState>) {
        let now = Instant::now();
        let retention = self.retention;
        clients.retain(|_, state| {
            let within_retention = now.duration_since(state.last_attempt) <= retention;
            let still_locked = state
                .locked_until
                .is_some_and(|locked_until| now < locked_until);
            within_retention || still_locked
        });
    }

    /// Decides whether the client at `addr` may make an attempt right now.
    ///
    /// Clients are keyed by IP address only. A client with no stored state is
    /// always allowed, and checking never creates state for it.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when the client is locked out or is
    /// still inside the backoff window that follows its most recent failure.
    pub async fn check_rate_limit(&self, addr: &SocketAddr) -> Result<(), String> {
        let key = addr.ip().to_string();
        let mut clients = self.clients.lock().await;
        self.evict_stale(&mut clients);

        // Look the state up instead of inserting it: only recorded failures
        // should ever add entries to the map.
        let Some(state) = clients.get_mut(&key) else {
            return Ok(());
        };

        // One timestamp for the whole decision keeps the comparisons and the
        // reported remaining times consistent with each other.
        let now = Instant::now();

        if let Some(locked_until) = state.locked_until {
            if now < locked_until {
                let remaining = locked_until.duration_since(now);
                warn!(
                    "Rate limit: client {} is locked out for {} more seconds",
                    key,
                    remaining.as_secs()
                );
                return Err(format!(
                    "Too many failed attempts. Locked out for {} seconds.",
                    remaining.as_secs()
                ));
            }
            // The lockout has expired: the client starts over with a clean count.
            state.failed_attempts = 0;
            state.locked_until = None;
        }

        if state.failed_attempts > 0 {
            let delay = Self::backoff_delay(state.failed_attempts);
            // Comparing elapsed time against the delay avoids `Instant + Duration`,
            // which panics if the sum is not representable.
            let elapsed = now.saturating_duration_since(state.last_attempt);
            if elapsed < delay {
                // Round up so neither the log nor the client is told "0 seconds".
                let wait_secs = (delay - elapsed).as_secs() + 1;
                warn!(
                    "Rate limit: client {} is in backoff for {} more seconds",
                    key, wait_secs
                );
                return Err(format!(
                    "Rate limited. Try again in {} seconds.",
                    wait_secs
                ));
            }
        }

        Ok(())
    }

    /// Records a failed attempt for the client at `addr`.
    ///
    /// A lockout that has already expired is cleared first, so the failure
    /// starts a fresh count even if `check_rate_limit` was not called in
    /// between. Once the count reaches `MAX_ATTEMPTS` the client is locked out
    /// for `LOCKOUT_DURATION`; below that, this call sleeps for the current
    /// backoff delay (after releasing the lock) before returning.
    pub async fn record_failure(&self, addr: &SocketAddr) {
        let ip = addr.ip();
        let mut clients = self.clients.lock().await;
        self.evict_stale(&mut clients);

        let now = Instant::now();
        let state = clients
            .entry(ip.to_string())
            .or_insert_with(ClientState::new);

        // Mirror `check_rate_limit`: an expired lockout must not count towards
        // the next one, otherwise a single failure would re-lock the client.
        if state.locked_until.is_some_and(|locked_until| now >= locked_until) {
            state.failed_attempts = 0;
            state.locked_until = None;
        }

        // Failures during an active lockout keep counting, so saturate rather
        // than risk an overflow.
        state.failed_attempts = state.failed_attempts.saturating_add(1);
        state.last_attempt = now;

        if state.failed_attempts >= MAX_ATTEMPTS {
            state.locked_until = Some(now + LOCKOUT_DURATION);
            warn!(
                "Rate limit: client {} locked out for {} seconds after {} failed attempts",
                ip,
                LOCKOUT_DURATION.as_secs(),
                state.failed_attempts
            );
            return;
        }

        let delay = Self::backoff_delay(state.failed_attempts);
        warn!(
            "Rate limit: client {} failed attempt #{}, applying {}ms delay",
            ip,
            state.failed_attempts,
            delay.as_millis()
        );
        // Release the lock before sleeping so other clients are not stalled.
        drop(clients);
        tokio::time::sleep(delay).await;
    }

    /// Forgets all failure state for the client at `addr` after a successful attempt.
    pub async fn record_success(&self, addr: &SocketAddr) {
        let key = addr.ip().to_string();
        let mut clients = self.clients.lock().await;
        self.evict_stale(&mut clients);
        clients.remove(&key);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Returns `127.0.0.<last_octet>:50051` so each test can use its own client IP.
    fn test_addr(last_octet: u8) -> SocketAddr {
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, last_octet));
        SocketAddr::new(ip, 50051)
    }

    /// A client with no recorded failures passes the check.
    #[tokio::test]
    async fn test_first_attempt_allowed() {
        let limiter = RateLimiter::new(DEFAULT_RETENTION);
        let outcome = limiter.check_rate_limit(&test_addr(1)).await;
        assert!(outcome.is_ok());
    }

    /// Recording `MAX_ATTEMPTS` failures locks the client out.
    ///
    /// Each failure below the limit sleeps for its backoff delay, so this
    /// test takes several seconds of wall-clock time.
    #[tokio::test]
    async fn test_lockout_after_max_attempts() {
        let limiter = RateLimiter::new(DEFAULT_RETENTION);
        let client = test_addr(2);
        for _attempt in 1..=MAX_ATTEMPTS {
            limiter.record_failure(&client).await;
        }
        let outcome = limiter.check_rate_limit(&client).await;
        assert!(outcome.is_err());
    }

    /// A success after earlier failures clears the client's state.
    #[tokio::test]
    async fn test_success_resets_state() {
        let limiter = RateLimiter::new(DEFAULT_RETENTION);
        let client = test_addr(3);
        for _attempt in 0..2 {
            limiter.record_failure(&client).await;
        }
        limiter.record_success(&client).await;
        let outcome = limiter.check_rate_limit(&client).await;
        assert!(outcome.is_ok());
    }

    /// Locking out one IP does not affect another.
    #[tokio::test]
    async fn test_different_ips_independent() {
        let limiter = RateLimiter::new(DEFAULT_RETENTION);
        let offender = test_addr(4);
        let bystander = test_addr(5);
        for _attempt in 1..=MAX_ATTEMPTS {
            limiter.record_failure(&offender).await;
        }
        let offender_outcome = limiter.check_rate_limit(&offender).await;
        assert!(offender_outcome.is_err());
        let bystander_outcome = limiter.check_rate_limit(&bystander).await;
        assert!(bystander_outcome.is_ok());
    }

    /// With zero retention, a recorded failure no longer counts at the next check.
    #[tokio::test]
    async fn test_stale_entries_evicted() {
        let limiter = RateLimiter::new(Duration::ZERO);
        let client = test_addr(6);
        limiter.record_failure(&client).await;
        let outcome = limiter.check_rate_limit(&client).await;
        assert!(outcome.is_ok());

        // Either the entry is gone, or whatever remains carries no failures.
        let key = client.ip().to_string();
        let guard = limiter.clients.lock().await;
        if let Some(leftover) = guard.get(&key) {
            assert_eq!(leftover.failed_attempts, 0);
        }
    }
}