use std::sync::atomic::{AtomicU64, Ordering::Relaxed};
use std::time::{SystemTime, UNIX_EPOCH};
use crate::auth::guards::Guard;
use crate::web::{Error, RequestContext};
/// Fixed-window request limiter backed by a single atomic word.
///
/// Admits at most `max_per_window` requests per window of `window_secs` seconds.
/// The window start and the admitted count live together in `state`, so they can
/// be read and replaced atomically without a lock.
pub struct RateLimit {
    /// Packed window state: window start (seconds since the Unix epoch, as `u32`)
    /// in the high 32 bits, requests admitted in that window in the low 32 bits.
    state: AtomicU64,
    /// Length of one window, in seconds.
    window_secs: u32,
    /// Maximum number of requests admitted within one window.
    max_per_window: u32,
}
impl RateLimit {
    /// Creates a limiter that admits at most `max_per_window` requests per
    /// `window_secs`-second window.
    ///
    /// This is a `const fn`, so it can initialise a `static` limiter.
    pub const fn new(max_per_window: u32, window_secs: u32) -> Self {
        Self {
            max_per_window,
            window_secs,
            // Window start 0, count 0: the first request whose timestamp is at
            // least `window_secs` opens a fresh window.
            state: AtomicU64::new(0),
        }
    }

    /// Current wall-clock time in whole seconds since the Unix epoch.
    ///
    /// The value is truncated to `u32` (it wraps in the year 2106). It falls back
    /// to 0 when the system clock reports a time before the epoch.
    fn now_secs() -> u32 {
        match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_secs() as u32,
            Err(_) => 0,
        }
    }
}
impl RateLimit {
    /// Records one request at time `now` and reports whether it fits in the
    /// current window's budget.
    ///
    /// `now` is in seconds since the Unix epoch, truncated to `u32`. A rejected
    /// request leaves `state` untouched.
    ///
    /// A new window starts (count reset to 1) in two cases:
    /// * at least `window_secs` seconds have passed since the current window's
    ///   start, or
    /// * `now` is at least a full window *before* that start.
    ///
    /// The second case covers wall-clock steps backwards (`SystemTime` is not
    /// monotonic) and the `u32` seconds wrap. Without it, a backward jump would
    /// keep a full window closed until the clock caught back up. A smaller lag
    /// still counts in the current window. That lag typically comes from a caller
    /// that sampled the clock just before another thread rolled the window.
    ///
    /// With `window_secs == 0` every call opens a fresh window, so every request
    /// is admitted whenever `max_per_window >= 1`.
    fn admit_at(&self, now: u32) -> bool {
        self.state
            // Relaxed is sufficient: this word is the limiter's only shared state,
            // and no other memory is published through it. `fetch_update` re-runs
            // the closure until its compare-exchange succeeds, so the closure is
            // kept free of side effects.
            .fetch_update(Relaxed, Relaxed, |cur| {
                let (start, count) = ((cur >> 32) as u32, cur as u32);
                let window_expired = match now.checked_sub(start) {
                    Some(elapsed) => elapsed >= self.window_secs,
                    // `now` precedes the window start. Roll only when the gap is a
                    // whole window or more (clock stepped back or wrapped).
                    None => start - now >= self.window_secs,
                };
                let (new_start, new_count) = if window_expired {
                    (now, 1)
                } else {
                    (start, count.saturating_add(1))
                };
                // `None` aborts the update; `fetch_update` then returns `Err`.
                if new_count > self.max_per_window {
                    None
                } else {
                    Some((u64::from(new_start) << 32) | u64::from(new_count))
                }
            })
            .is_ok()
    }
}
impl Guard for RateLimit {
fn check(&self, _ctx: &RequestContext) -> Result<(), Error> {
if self.admit_at(Self::now_secs()) {
Ok(())
} else {
metrics::counter!("rate_limited_total").increment(1);
Err(Error::TooManyRequests)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn enforces_limit_and_resets_on_window_roll() {
        // Budget of two requests per 60-second window.
        let limiter = RateLimit::new(2, 60);

        // First window opens at t=1000 and is filled by two hits.
        assert!(limiter.admit_at(1_000));
        assert!(limiter.admit_at(1_000));
        assert!(!limiter.admit_at(1_030), "third hit in the window is rejected");

        // Exactly 60s later a fresh window opens, with the same two-hit budget.
        assert!(limiter.admit_at(1_060));
        assert!(limiter.admit_at(1_061));
        assert!(!limiter.admit_at(1_062));
    }
}