use std::collections::HashMap;
use std::sync::Mutex;
/// Category of authentication request being rate limited.
///
/// The bucket is part of every counter key, so each bucket is counted
/// separately and carries its own pair of limits (see [`AuthBucket::caps`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthBucket {
    /// Limits: 5 per IP, 30 per account.
    Login,
    /// Limits: 3 per IP, 10 per account.
    Send,
    /// Limits: 30 per IP, 100 per account.
    Verify,
}

impl AuthBucket {
    /// Returns `(per_ip_cap, per_account_cap)`: the maximum number of requests
    /// admitted within one counting window for a single IP and for a single
    /// account, respectively.
    fn caps(&self) -> (u32, u32) {
        let (per_ip, per_account) = match *self {
            AuthBucket::Login => (5, 30),
            AuthBucket::Send => (3, 10),
            AuthBucket::Verify => (30, 100),
        };
        (per_ip, per_account)
    }
}
/// Outcome of a rate-limit check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitDecision {
/// The request is within every limit that was consulted.
Allow,
/// A limit was hit. `retry_after_secs` is the number of seconds remaining in
/// the counting window that rejected the request.
Deny { retry_after_secs: u64 },
}
/// Fixed-window counter stored for one rate-limit key.
#[derive(Debug, Clone, Copy)]
struct Counter {
/// Timestamp (seconds, same clock as the `now` value passed to `bump`) at
/// which the current window opened.
window_start: u64,
/// Number of requests admitted so far in the current window; rejected
/// requests are not counted.
count: u32,
}
/// In-memory, fixed-window rate limiter for authentication requests.
///
/// Two independent counter tables are kept, each behind its own mutex: one
/// keyed by client IP and one keyed by account. The default value has both
/// tables empty.
///
/// NOTE(review): nothing visible in this file removes entries from either
/// table, so they grow with the number of distinct keys seen — confirm whether
/// eviction happens elsewhere.
#[derive(Default)]
pub struct AuthRateLimiter {
    /// Counters keyed by `(bucket, client IP string)`.
    per_ip: Mutex<HashMap<(AuthBucket, String), Counter>>,
    /// Counters keyed by `(bucket, ASCII-lowercased account key)`.
    per_account: Mutex<HashMap<(AuthBucket, String), Counter>>,
}
impl AuthRateLimiter {
    /// Creates a limiter with no recorded requests.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the process-wide limiter, creating it on first use.
    ///
    /// Every call yields a reference to the same instance.
    pub fn shared() -> &'static AuthRateLimiter {
        use std::sync::OnceLock;

        static INSTANCE: OnceLock<AuthRateLimiter> = OnceLock::new();
        INSTANCE.get_or_init(AuthRateLimiter::default)
    }

    /// Records one request and decides whether it may proceed.
    ///
    /// The per-IP counter (60-second window) is consulted first. If it rejects
    /// the request, the per-account counter is left untouched. Otherwise, when
    /// `account_key` is given, the per-account counter (3600-second window) is
    /// consulted, keyed by the ASCII-lowercased account key. Note that a
    /// request rejected at the account level has already been counted against
    /// the IP.
    ///
    /// Returns [`RateLimitDecision::Deny`] carrying the seconds left in the
    /// window that rejected the request, or [`RateLimitDecision::Allow`].
    pub fn check(
        &self,
        bucket: AuthBucket,
        ip: &str,
        account_key: Option<&str>,
    ) -> RateLimitDecision {
        let (ip_cap, acct_cap) = bucket.caps();
        let now = now_secs();
        let deny = |retry_after_secs: u64| RateLimitDecision::Deny { retry_after_secs };

        let ip_key = (bucket, ip.to_string());
        if let Some(wait) = bump(&self.per_ip, ip_key, 60, ip_cap, now) {
            return deny(wait);
        }

        let Some(account) = account_key else {
            return RateLimitDecision::Allow;
        };

        let account_key = (bucket, account.to_ascii_lowercase());
        match bump(&self.per_account, account_key, 3600, acct_cap, now) {
            Some(wait) => deny(wait),
            None => RateLimitDecision::Allow,
        }
    }
}
/// Counts one request against `key` in a fixed window and reports whether the
/// cap has been reached.
///
/// * `map` – counter table to update; locked for the duration of the call.
/// * `key` – `(bucket, identity)` pair selecting the counter; a counter is
///   created on first use with its window starting at `now`.
/// * `window_secs` – length of the counting window in seconds.
/// * `cap` – maximum number of requests admitted per window.
/// * `now` – current time in seconds.
///
/// Returns `None` when the request is admitted (and counted). Returns
/// `Some(secs)` when the cap is already reached, where `secs` is the time left
/// until the current window ends; rejected requests are not counted.
fn bump(
    map: &Mutex<HashMap<(AuthBucket, String), Counter>>,
    key: (AuthBucket, String),
    window_secs: u64,
    cap: u32,
    now: u64,
) -> Option<u64> {
    // A poisoned lock only means some thread panicked while holding it. The
    // table holds plain integers, so keep using it instead of panicking on
    // every subsequent check.
    let mut table = map
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);

    let entry = table.entry(key).or_insert(Counter {
        window_start: now,
        count: 0,
    });

    // The caller's clock may step backwards. Without this, the window would
    // sit in the future and the lockout (and reported retry time) could exceed
    // `window_secs` by the size of the step. Re-anchor the window but keep the
    // count so a clock step never grants extra attempts.
    if now < entry.window_start {
        entry.window_start = now;
    }

    // Window expired: start a fresh one.
    if now >= entry.window_start.saturating_add(window_secs) {
        entry.window_start = now;
        entry.count = 0;
    }

    if entry.count >= cap {
        // `window_start <= now < window_end` holds here (or the window was
        // just reset to start at `now`), so the subtraction cannot underflow.
        let window_end = entry.window_start.saturating_add(window_secs);
        return Some(window_end - now);
    }

    entry.count += 1;
    None
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields 0 when the system clock reads earlier than the epoch.
fn now_secs() -> u64 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    match since_epoch {
        Ok(elapsed) => elapsed.as_secs(),
        // Clock is set before 1970: fall back to zero rather than failing.
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Requests up to the per-IP cap are all admitted.
    #[test]
    fn allows_within_cap() {
        let rl = AuthRateLimiter::new();
        for _ in 0..5 {
            assert_eq!(
                rl.check(AuthBucket::Login, "1.2.3.4", Some("a@b.com")),
                RateLimitDecision::Allow
            );
        }
    }

    /// The request after the per-IP cap is denied with a retry time no longer
    /// than the 60-second IP window.
    #[test]
    fn denies_after_per_ip_cap() {
        let rl = AuthRateLimiter::new();
        let bucket = AuthBucket::Login;
        let (ip_cap, _) = bucket.caps();
        for _ in 0..ip_cap {
            assert_eq!(rl.check(bucket, "1.2.3.4", None), RateLimitDecision::Allow);
        }
        // Exhaustive match so a newly added variant forces this test to be revisited.
        match rl.check(bucket, "1.2.3.4", None) {
            RateLimitDecision::Deny { retry_after_secs } => assert!(retry_after_secs <= 60),
            RateLimitDecision::Allow => panic!("expected Deny"),
        }
    }

    /// The per-account cap applies even when every request comes from a new IP.
    #[test]
    fn per_account_cap_independent_of_ip() {
        let rl = AuthRateLimiter::new();
        let bucket = AuthBucket::Send;
        let (_, acct_cap) = bucket.caps();
        for i in 0..acct_cap {
            let ip = format!("10.0.0.{i}");
            assert_eq!(rl.check(bucket, &ip, Some("victim@x.com")), RateLimitDecision::Allow);
        }
        let result = rl.check(bucket, "10.0.0.99", Some("victim@x.com"));
        assert!(matches!(result, RateLimitDecision::Deny { .. }));
    }

    /// Account keys differing only in ASCII case share one counter.
    #[test]
    fn account_key_lowercased() {
        let rl = AuthRateLimiter::new();
        let bucket = AuthBucket::Send;
        let (_, acct_cap) = bucket.caps();
        for i in 0..acct_cap {
            let ip = format!("10.0.0.{i}");
            // Each warm-up request must be admitted, otherwise the final
            // assertion would not be exercising the account counter.
            assert_eq!(rl.check(bucket, &ip, Some("a@b.com")), RateLimitDecision::Allow);
        }
        let result = rl.check(bucket, "172.16.0.1", Some("A@B.COM"));
        assert!(matches!(result, RateLimitDecision::Deny { .. }));
    }

    /// Exhausting one bucket for an IP leaves the other buckets untouched.
    #[test]
    fn buckets_are_counted_separately() {
        let rl = AuthRateLimiter::new();
        let (ip_cap, _) = AuthBucket::Login.caps();
        for _ in 0..ip_cap {
            assert_eq!(
                rl.check(AuthBucket::Login, "1.2.3.4", None),
                RateLimitDecision::Allow
            );
        }
        assert!(matches!(
            rl.check(AuthBucket::Login, "1.2.3.4", None),
            RateLimitDecision::Deny { .. }
        ));
        assert_eq!(
            rl.check(AuthBucket::Send, "1.2.3.4", None),
            RateLimitDecision::Allow
        );
    }

    /// Drives `bump` with an explicit clock: the retry time counts down to the
    /// end of the window, and the counter resets once the window has elapsed.
    #[test]
    fn window_rolls_over_and_reports_remaining_time() {
        let map = std::sync::Mutex::new(std::collections::HashMap::new());
        let key = || (AuthBucket::Login, String::from("k"));
        assert_eq!(bump(&map, key(), 60, 1, 100), None);
        assert_eq!(bump(&map, key(), 60, 1, 130), Some(30));
        assert_eq!(bump(&map, key(), 60, 1, 159), Some(1));
        assert_eq!(bump(&map, key(), 60, 1, 160), None);
        assert_eq!(bump(&map, key(), 60, 1, 161), Some(59));
    }
}