use crate::error::{Result, TidewayError};
use governor::{
Quota, RateLimiter, clock::DefaultClock, middleware::NoOpMiddleware,
state::keyed::DashMapStateStore,
};
use std::{
num::NonZeroU32,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
time::Duration,
};
/// Every `SHRINK_INTERVAL`-th call to `LoginRateLimiter::check` prunes keys that have
/// returned to a fresh state from the limiter's state map.
const SHRINK_INTERVAL: u64 = 1_000;
/// Tuning parameters for the per-IP login rate limiter.
///
/// Plain data: equality compares both fields, so configurations can be asserted on and
/// compared directly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LoginRateLimitConfig {
    /// Number of attempts a single IP may make in a burst before being rejected.
    /// A value of 0 is treated as 1 by the limiter.
    pub max_attempts: u32,
    /// Window length in seconds; the limiter regains one attempt per window.
    /// A value of 0 is treated as 1 by the limiter.
    pub window_seconds: u64,
}
impl Default for LoginRateLimitConfig {
fn default() -> Self {
Self {
max_attempts: 5,
window_seconds: 900, }
}
}
impl LoginRateLimitConfig {
pub fn new(max_attempts: u32, window_seconds: u64) -> Self {
Self {
max_attempts,
window_seconds,
}
}
pub fn strict() -> Self {
Self {
max_attempts: 3,
window_seconds: 1800, }
}
pub fn lenient() -> Self {
Self {
max_attempts: 10,
window_seconds: 900, }
}
}
/// Keyed governor limiter: `String` keys (client IPs), DashMap-backed state,
/// the default clock, and no middleware.
type KeyedLimiter =
    RateLimiter<String, DashMapStateStore<String>, DefaultClock, NoOpMiddleware>;
/// Per-IP login attempt limiter backed by a keyed `governor` rate limiter.
///
/// Cloning is cheap and clones share state: the limiter and the request counter live
/// behind `Arc`s, while the configuration is copied by value.
#[derive(Clone)]
pub struct LoginRateLimiter {
    /// Shared keyed limiter; keys are the IP strings passed to `check`.
    limiter: Arc<KeyedLimiter>,
    /// Configuration this limiter was built from (exposed via `config()` and logged).
    config: LoginRateLimitConfig,
    /// Count of `check` calls so far; drives periodic pruning of idle keys.
    request_count: Arc<AtomicU64>,
}
impl LoginRateLimiter {
    /// Builds a limiter from `config`.
    ///
    /// `max_attempts` and `window_seconds` are each clamped to at least 1. The governor
    /// quota allows a burst of `max_attempts` and then replenishes one attempt every
    /// `window_seconds`. A fully exhausted IP therefore regains its whole budget only
    /// after `max_attempts * window_seconds` seconds.
    pub fn new(config: LoginRateLimitConfig) -> Self {
        let max_attempts = NonZeroU32::new(config.max_attempts.max(1)).unwrap_or(NonZeroU32::MIN);
        // `with_period` returns `None` only for a zero duration, which the clamp rules out;
        // the fallback just keeps this constructor free of panic paths.
        let quota = Quota::with_period(Duration::from_secs(config.window_seconds.max(1)))
            .unwrap_or_else(|| Quota::per_second(max_attempts))
            .allow_burst(max_attempts);
        Self {
            limiter: Arc::new(RateLimiter::keyed(quota)),
            config,
            request_count: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Records one login attempt for `ip` and reports whether it is allowed.
    ///
    /// # Errors
    ///
    /// Returns `Err(retry_after_secs)` when `ip` has no attempts left. The value is the
    /// wait until the next attempt is allowed, rounded **up** to whole seconds and never
    /// less than 1. Rounding up means a client that waits exactly this long will not be
    /// rejected again.
    pub fn check(&self, ip: &str) -> std::result::Result<(), u64> {
        let count = self.request_count.fetch_add(1, Ordering::Relaxed);
        // Every SHRINK_INTERVAL calls, drop keys whose state is back to fresh. Otherwise the
        // map would keep an entry for every IP ever seen.
        if count > 0 && count % SHRINK_INTERVAL == 0 {
            self.limiter.retain_recent();
        }
        match self.limiter.check_key(&ip.to_string()) {
            Ok(_) => Ok(()),
            Err(not_until) => {
                let now = governor::clock::Clock::now(&DefaultClock::default());
                let wait = not_until.wait_time_from(now);
                Err(Self::ceil_secs(wait).max(1))
            }
        }
    }

    /// Returns the configuration this limiter was built with.
    pub fn config(&self) -> &LoginRateLimitConfig {
        &self.config
    }

    /// Converts `wait` to whole seconds, rounding any fractional part up (saturating).
    fn ceil_secs(wait: Duration) -> u64 {
        let whole = wait.as_secs();
        if wait.subsec_nanos() > 0 {
            whole.saturating_add(1)
        } else {
            whole
        }
    }
}
/// Abstraction over "possibly rate-limit this login attempt".
///
/// This file implements it for `()` (never limits) and for `WithRateLimiter` (delegates to
/// a `LoginRateLimiter`). Presumably this lets callers be generic over whether limiting
/// is enabled.
pub trait OptionalRateLimiter: Clone + Send + Sync {
    /// Checks one login attempt from `ip`, which may be unknown (`None`).
    ///
    /// Returns `Ok(())` when the attempt may proceed and an error otherwise.
    fn check_rate_limit(&self, ip: Option<&str>) -> Result<()>;
}
/// The no-op limiter: `()` accepts every attempt, with or without a client address.
impl OptionalRateLimiter for () {
    fn check_rate_limit(&self, _: Option<&str>) -> Result<()> {
        // Nothing to enforce.
        Ok(())
    }
}
/// Newtype that plugs a `LoginRateLimiter` into the `OptionalRateLimiter` trait.
///
/// Cloning clones the inner limiter.
#[derive(Clone)]
pub struct WithRateLimiter(pub LoginRateLimiter);
impl OptionalRateLimiter for WithRateLimiter {
    /// Enforces the wrapped `LoginRateLimiter` for `ip`.
    ///
    /// When `ip` is `None` the attempt is allowed without consuming any quota.
    ///
    /// # Errors
    ///
    /// When `ip` has exhausted its attempts, emits a `warn` event on the
    /// `auth.login.rate_limited` target and returns `TidewayError::TooManyRequests`.
    /// The message uses a correctly pluralized retry delay ("1 second" / "N seconds").
    fn check_rate_limit(&self, ip: Option<&str>) -> Result<()> {
        // No client address means nothing to key on; such requests are let through.
        let Some(ip) = ip else {
            return Ok(());
        };
        match self.0.check(ip) {
            Ok(()) => Ok(()),
            Err(retry_after) => {
                let config = self.0.config();
                tracing::warn!(
                    target: "auth.login.rate_limited",
                    ip = %ip,
                    retry_after_secs = retry_after,
                    max_attempts = config.max_attempts,
                    window_secs = config.window_seconds,
                    "Login rate limited"
                );
                // `check` clamps the delay to >= 1, so "1 second" is a common case.
                let unit = if retry_after == 1 { "second" } else { "seconds" };
                Err(TidewayError::TooManyRequests(format!(
                    "Too many login attempts. Please try again in {retry_after} {unit}."
                )))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limit_allows_requests_under_limit() {
        let config = LoginRateLimitConfig::new(5, 60);
        let limiter = LoginRateLimiter::new(config);
        for i in 0..5 {
            let result = limiter.check("192.168.1.1");
            assert!(result.is_ok(), "Request {} should be allowed", i + 1);
        }
    }

    #[test]
    fn test_rate_limit_blocks_requests_over_limit() {
        let config = LoginRateLimitConfig::new(5, 60);
        let limiter = LoginRateLimiter::new(config);
        for _ in 0..5 {
            limiter.check("192.168.1.1").unwrap();
        }
        let result = limiter.check("192.168.1.1");
        assert!(result.is_err(), "6th request should be blocked");
    }

    #[test]
    fn test_rate_limit_per_ip_isolation() {
        let config = LoginRateLimitConfig::new(5, 60);
        let limiter = LoginRateLimiter::new(config);
        for _ in 0..5 {
            limiter.check("192.168.1.1").unwrap();
        }
        let result = limiter.check("192.168.1.2");
        assert!(result.is_ok(), "Different IP should have separate quota");
    }

    #[test]
    fn test_rate_limit_returns_retry_after() {
        let config = LoginRateLimitConfig::new(1, 60);
        let limiter = LoginRateLimiter::new(config);
        limiter.check("192.168.1.1").unwrap();
        // The second attempt must be rejected; `expect_err` both asserts that and yields
        // the retry delay.
        let retry_after = limiter
            .check("192.168.1.1")
            .expect_err("second request should be blocked");
        assert!(retry_after > 0, "Should return positive retry_after");
        assert!(retry_after <= 60, "retry_after should be within window");
    }

    #[test]
    fn test_optional_rate_limiter_noop() {
        // Call the trait method on the unit value directly instead of binding `()` to a
        // variable (clippy::let_unit_value).
        assert!(().check_rate_limit(Some("192.168.1.1")).is_ok());
        assert!(().check_rate_limit(None).is_ok());
    }

    #[test]
    fn test_optional_rate_limiter_with_limiter() {
        let config = LoginRateLimitConfig::new(2, 60);
        let limiter = WithRateLimiter(LoginRateLimiter::new(config));
        assert!(limiter.check_rate_limit(Some("192.168.1.1")).is_ok());
        assert!(limiter.check_rate_limit(Some("192.168.1.1")).is_ok());
        assert!(limiter.check_rate_limit(Some("192.168.1.1")).is_err());
        // An unknown client address bypasses limiting even after exhaustion.
        assert!(limiter.check_rate_limit(None).is_ok());
    }

    #[test]
    fn test_concurrent_access() {
        use std::thread;
        let config = LoginRateLimitConfig::new(100, 60);
        let limiter = LoginRateLimiter::new(config);
        // Hammer shared (cloned) limiter state from several threads; this must not panic,
        // and an untouched IP must still be admitted afterwards.
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let limiter = limiter.clone();
                thread::spawn(move || {
                    for j in 0..50 {
                        let ip = format!("192.168.{}.{}", i, j);
                        let _ = limiter.check(&ip);
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        let result = limiter.check("10.0.0.1");
        assert!(result.is_ok());
    }

    #[test]
    fn test_config_presets() {
        let default = LoginRateLimitConfig::default();
        assert_eq!(default.max_attempts, 5);
        assert_eq!(default.window_seconds, 900);
        let strict = LoginRateLimitConfig::strict();
        assert_eq!(strict.max_attempts, 3);
        assert_eq!(strict.window_seconds, 1800);
        let lenient = LoginRateLimitConfig::lenient();
        assert_eq!(lenient.max_attempts, 10);
        assert_eq!(lenient.window_seconds, 900);
    }
}