use std::sync::OnceLock;
use std::time::{Duration, Instant};
use umbral::ratelimit::{Rate, RateLimiter};
/// Keyed throttle that reduces the decision of an
/// `umbral::ratelimit::RateLimiter` to a plain allow/deny `bool`.
#[derive(Debug)]
pub struct Throttle {
    /// Underlying limiter; every method on `Throttle` delegates to it.
    inner: RateLimiter,
}
impl Throttle {
    /// Builds a throttle from a `Rate` of `max` per `window`.
    ///
    /// The underlying `Rate` takes a `u32`, so a `max` larger than `u32::MAX`
    /// is saturated to `u32::MAX` instead of being wrapped.
    pub fn new(max: usize, window: Duration) -> Self {
        // A bare `max as u32` truncates on 64-bit targets: 2^32 would become 0,
        // turning a "practically unlimited" setting into a zero limit.
        // Clamp first so the cast below is always lossless.
        let limit = max.min(u32::MAX as usize) as u32;
        Self {
            inner: RateLimiter::new(Rate::new(limit, window)),
        }
    }

    /// Checks `key` against the limiter and returns the `allowed` flag of its
    /// decision.
    ///
    /// NOTE(review): presumably evaluated against the current time; the clock
    /// source is decided by `RateLimiter::check` — confirm there.
    pub fn check(&self, key: &str) -> bool {
        self.inner.check(key).allowed
    }

    /// Forwards to `RateLimiter::clear` for `key`.
    pub fn clear(&self, key: &str) {
        self.inner.clear(key);
    }

    /// Same as [`Throttle::check`], but evaluated at the caller-supplied
    /// instant `now`, which makes the outcome deterministic in tests.
    pub fn check_at(&self, key: &str, now: Instant) -> bool {
        self.inner.check_at(key, now).allowed
    }

    /// Companion to [`Throttle::check_at`]. It takes no timestamp and performs
    /// exactly the same call as [`Throttle::clear`].
    pub fn clear_at(&self, key: &str) {
        self.inner.clear(key);
    }
}
/// Limits and windows for the three auth throttles, plus a global on/off
/// switch.
#[derive(Debug, Clone, Copy)]
pub struct ThrottleConfig {
    /// Limit handed to the login throttle.
    pub login_max: usize,
    /// Window handed to the login throttle.
    pub login_window: Duration,
    /// Limit handed to the registration throttle.
    pub register_max: usize,
    /// Window handed to the registration throttle.
    pub register_window: Duration,
    /// Limit handed to the email-action throttle.
    pub email_action_max: usize,
    /// Window handed to the email-action throttle.
    pub email_action_window: Duration,
    /// When `false`, the module-level `*_throttle_check` functions return
    /// `true` without consulting any throttle.
    pub enabled: bool,
}

impl Default for ThrottleConfig {
    /// Defaults: 5 logins per 5 minutes, 10 registrations per hour,
    /// 5 email actions per hour, throttling enabled.
    fn default() -> Self {
        const SECS_PER_MINUTE: u64 = 60;
        const SECS_PER_HOUR: u64 = 60 * SECS_PER_MINUTE;

        let five_minutes = Duration::from_secs(5 * SECS_PER_MINUTE);
        let one_hour = Duration::from_secs(SECS_PER_HOUR);

        Self {
            enabled: true,
            login_max: 5,
            login_window: five_minutes,
            register_max: 10,
            register_window: one_hour,
            email_action_max: 5,
            email_action_window: one_hour,
        }
    }
}
/// The set of throttles used by the auth endpoints, together with the
/// configuration they were built from.
#[derive(Debug)]
pub struct AuthThrottle {
    /// Configuration snapshot; its `enabled` flag gates the module-level
    /// `*_throttle_check` functions.
    config: ThrottleConfig,
    /// Throttle consulted by `login_throttle_check` / `login_throttle_clear`.
    login: Throttle,
    /// Throttle consulted by `register_throttle_check`.
    register: Throttle,
    /// Throttle consulted by `email_action_throttle_check`.
    email_action: Throttle,
}
impl AuthThrottle {
pub fn from_config(config: ThrottleConfig) -> Self {
Self {
login: Throttle::new(config.login_max, config.login_window),
register: Throttle::new(config.register_max, config.register_window),
email_action: Throttle::new(config.email_action_max, config.email_action_window),
config,
}
}
}
/// Process-wide throttle set. Written at most once by `install` and read by
/// `active`, which falls back to a default-configured instance while this is
/// still empty.
static AUTH_THROTTLE: OnceLock<AuthThrottle> = OnceLock::new();
/// Installs `throttle` as the process-wide auth throttle.
///
/// Only the first call takes effect: if a throttle is already installed, the
/// one passed here is silently discarded.
pub(crate) fn install(throttle: AuthThrottle) {
    // `set` hands the value back as `Err` when the cell is already populated;
    // dropping the result is the deliberate "first install wins" behaviour.
    drop(AUTH_THROTTLE.set(throttle));
}
/// Returns the throttle set currently in effect.
///
/// Prefers the instance stored in `AUTH_THROTTLE`; if none has been installed
/// yet, a lazily created instance built from `ThrottleConfig::default()` is
/// used instead.
fn active() -> &'static AuthThrottle {
    static FALLBACK: OnceLock<AuthThrottle> = OnceLock::new();

    match AUTH_THROTTLE.get() {
        Some(installed) => installed,
        None => FALLBACK.get_or_init(|| AuthThrottle::from_config(ThrottleConfig::default())),
    }
}
/// Builds the login throttle key: `ip`, a NUL byte, then `username`.
fn login_key(ip: &str, username: &str) -> String {
    [ip, username].join("\0")
}
/// Returns `true` if a login attempt for `username` from `ip` may proceed.
///
/// Always `true` when the active configuration has throttling disabled; in
/// that case the login throttle is not consulted at all.
pub fn login_throttle_check(ip: &str, username: &str) -> bool {
    let throttles = active();
    // `||` short-circuits, so a disabled config never touches the limiter.
    !throttles.config.enabled || throttles.login.check(&login_key(ip, username))
}
/// Clears the login throttle entry for the `ip` / `username` pair.
///
/// Runs regardless of whether throttling is enabled in the active config.
pub fn login_throttle_clear(ip: &str, username: &str) {
    let throttles = active();
    let key = login_key(ip, username);
    throttles.login.clear(&key);
}
/// Returns `true` if a registration attempt from `ip` may proceed.
///
/// The registration throttle is keyed by `ip` alone. Always `true` when the
/// active configuration has throttling disabled.
pub fn register_throttle_check(ip: &str) -> bool {
    let throttles = active();
    if throttles.config.enabled {
        throttles.register.check(ip)
    } else {
        true
    }
}
/// Builds the email-action throttle key: `ip`, a NUL byte, then `email`.
fn email_action_key(ip: &str, email: &str) -> String {
    // +1 accounts for the NUL separator.
    let mut key = String::with_capacity(ip.len() + 1 + email.len());
    key.push_str(ip);
    key.push('\0');
    key.push_str(email);
    key
}
/// Returns `true` if an email-triggering action for `email` from `ip` may
/// proceed.
///
/// Always `true` when the active configuration has throttling disabled.
pub fn email_action_throttle_check(ip: &str, email: &str) -> bool {
    let throttles = active();
    match throttles.config.enabled {
        false => true,
        true => {
            let key = email_action_key(ip, email);
            throttles.email_action.check(&key)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Window shared by every throttle built in these tests.
    const WINDOW: Duration = Duration::from_secs(60);

    #[test]
    fn third_attempt_in_window_is_denied() {
        let throttle = Throttle::new(2, WINDOW);
        let t0 = Instant::now();

        // The first two attempts fit in the budget; the third does not.
        assert!(throttle.check_at("k", t0));
        assert!(throttle.check_at("k", t0));
        assert!(!throttle.check_at("k", t0));
    }

    #[test]
    fn different_keys_are_independent() {
        let throttle = Throttle::new(1, WINDOW);
        let t0 = Instant::now();

        assert!(throttle.check_at("a", t0));
        assert!(!throttle.check_at("a", t0));
        // Exhausting "a" must not affect "b".
        assert!(throttle.check_at("b", t0));
    }

    #[test]
    fn window_elapse_re_allows() {
        let throttle = Throttle::new(1, WINDOW);
        let t0 = Instant::now();

        assert!(throttle.check_at("k", t0));
        assert!(!throttle.check_at("k", t0));

        // One second past the 60 s window.
        let after_window = t0 + Duration::from_secs(61);
        assert!(throttle.check_at("k", after_window));
    }

    #[test]
    fn clear_resets_a_key() {
        let throttle = Throttle::new(1, WINDOW);
        let t0 = Instant::now();

        assert!(throttle.check_at("k", t0));
        assert!(!throttle.check_at("k", t0));

        throttle.clear_at("k");
        assert!(throttle.check_at("k", t0));
    }

    #[test]
    fn max_zero_denies_everything() {
        let throttle = Throttle::new(0, WINDOW);
        assert!(!throttle.check_at("k", Instant::now()));
    }

    #[test]
    fn disabled_config_gate_short_circuits() {
        let cfg = ThrottleConfig {
            enabled: false,
            login_max: 1,
            ..ThrottleConfig::default()
        };
        let store = AuthThrottle::from_config(cfg);
        let t0 = Instant::now();

        // The inner throttle still counts attempts; the `enabled` flag is only
        // applied by the module-level check functions.
        assert!(store.login.check_at("k", t0));
        assert!(!store.login.check_at("k", t0));
        assert!(!cfg.enabled, "gate is open when enabled == false");
    }
}