use std::time::Duration;
use super::rate::Rate;
/// Outcome of recording a single hit against a rate limit.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Decision {
    /// `true` when the hit fits within the limit and should be let through.
    pub allowed: bool,
    /// The configured number of hits permitted per window.
    pub limit: u64,
    /// Hits still available in the current window; saturates at zero once the limit is hit.
    pub remaining: u64,
    /// For a denied hit, how long until the current window ends; `None` for allowed hits.
    pub retry_after: Option<Duration>,
}
/// Backend that counts hits per key and decides whether each hit fits within a [`Rate`].
///
/// Implementors must be `Send + Sync` so a single store can be shared across tasks.
#[async_trait::async_trait]
pub trait RateLimitStore
where
    Self: Send + Sync,
{
    /// Records one hit for `key` under `rate` and returns the resulting [`Decision`].
    async fn hit(&self, key: &str, rate: &Rate) -> Decision;
}
/// In-process rate-limit store that keeps one fixed window per key.
///
/// A window is reset by the first hit at or after its expiry. Expired windows are also
/// removed by a sweep that hits trigger at most once every `SWEEP_INTERVAL_MS` milliseconds.
#[derive(Debug)]
pub struct MemoryStore {
    /// The current window for every key seen so far (until swept).
    windows: dashmap::DashMap<String, Window>,
    /// Reference point from which `last_sweep_ms` is measured.
    base: std::time::Instant,
    /// Milliseconds after `base` at which the last sweep was claimed; starts at 0.
    last_sweep_ms: std::sync::atomic::AtomicU64,
}
/// Counter state for one key's current fixed window.
#[derive(Clone, Copy, Debug)]
struct Window {
    /// End of the window; a hit at or after this instant starts a fresh window.
    expires_at: std::time::Instant,
    /// Hits recorded in this window, denied ones included.
    count: u64,
}
/// Minimum gap, in milliseconds, between two sweeps of expired windows (one minute).
const SWEEP_INTERVAL_MS: u64 = 60 * 1_000;
impl Default for MemoryStore {
fn default() -> Self {
Self::new()
}
}
impl MemoryStore {
    /// Creates an empty store; the first sweep becomes due `SWEEP_INTERVAL_MS` after this call.
    pub fn new() -> Self {
        let windows = dashmap::DashMap::new();
        let base = std::time::Instant::now();
        let last_sweep_ms = std::sync::atomic::AtomicU64::new(0);
        Self {
            windows,
            base,
            last_sweep_ms,
        }
    }

    /// Drops every window that has ended at or before `now`.
    fn evict_expired(&self, now: std::time::Instant) {
        self.windows.retain(|_key, window| now < window.expires_at);
    }

    /// Runs `evict_expired` if at least `SWEEP_INTERVAL_MS` has passed since the last sweep.
    ///
    /// The sweep slot is claimed with a compare-and-swap, so when several callers find a
    /// sweep due at once, only the one whose CAS succeeds performs it.
    fn maybe_sweep(&self, now: std::time::Instant) {
        use std::sync::atomic::Ordering::Relaxed;

        let elapsed_ms = now.duration_since(self.base).as_millis() as u64;
        let previous = self.last_sweep_ms.load(Relaxed);
        let due = elapsed_ms.saturating_sub(previous) >= SWEEP_INTERVAL_MS;
        if due
            && self
                .last_sweep_ms
                .compare_exchange(previous, elapsed_ms, Relaxed, Relaxed)
                .is_ok()
        {
            self.evict_expired(now);
        }
    }
}
#[async_trait::async_trait]
impl RateLimitStore for MemoryStore {
    /// Counts one hit for `key` in its fixed window and decides whether it is allowed.
    ///
    /// A missing or ended window is (re)started at `now` with length `rate.window`. The hit
    /// is counted even when it is denied. May first trigger a periodic sweep of expired keys.
    async fn hit(&self, key: &str, rate: &Rate) -> Decision {
        let now = std::time::Instant::now();
        self.maybe_sweep(now);

        // Keep the map guard inside this block so the shard lock is released before
        // the decision is assembled.
        let (count, expires_at) = {
            let mut window = self.windows.entry(key.to_string()).or_insert(Window {
                expires_at: now + rate.window,
                count: 0,
            });
            if window.expires_at <= now {
                *window = Window {
                    expires_at: now + rate.window,
                    count: 0,
                };
            }
            window.count += 1;
            (window.count, window.expires_at)
        };

        let allowed = count <= rate.limit;
        let retry_after = if allowed {
            None
        } else {
            Some(expires_at.saturating_duration_since(now))
        };
        Decision {
            allowed,
            limit: rate.limit,
            remaining: rate.limit.saturating_sub(count),
            retry_after,
        }
    }
}
/// Rate-limit store backed by Redis, so every process using the same server shares limits.
///
/// Each key is a fixed-window counter maintained by a Lua script. The script `INCR`s the key
/// and, when the key has no TTL yet, sets a `PEXPIRE` of the rate's window. Any Redis error
/// is logged and the request is allowed (fail-open).
#[cfg(feature = "redis")]
pub struct RedisStore {
    client: redis::Client,
    /// Lazily opened, shared connection. Cleared after a connection-level error so the next
    /// request reconnects. A `MultiplexedConnection` does not reconnect by itself, and a dead
    /// cached one would make every later request fail open until restart.
    conn: std::sync::Mutex<Option<redis::aio::MultiplexedConnection>>,
    /// Built once, because `Script::new` hashes the source and that need not repeat per hit.
    script: redis::Script,
}

#[cfg(feature = "redis")]
impl RedisStore {
    /// Lua for one hit: increment the counter, start its expiry (ARGV[1] ms) if it has none,
    /// and return `{count, pttl_ms}`. Redis runs the whole script atomically.
    const HIT_SCRIPT: &'static str = r"local c = redis.call('INCR', KEYS[1])
if redis.call('PTTL', KEYS[1]) < 0 then
redis.call('PEXPIRE', KEYS[1], ARGV[1])
end
return {c, redis.call('PTTL', KEYS[1])}";

    /// Creates a store for the Redis server at `url`.
    ///
    /// No connection is opened here; the first hit establishes it.
    ///
    /// # Errors
    /// Returns the error from `redis::Client::open` if it rejects `url`.
    pub fn open(url: &str) -> redis::RedisResult<Self> {
        Ok(Self {
            client: redis::Client::open(url)?,
            conn: std::sync::Mutex::new(None),
            script: redis::Script::new(Self::HIT_SCRIPT),
        })
    }

    /// Locks the connection slot, tolerating poisoning. The slot only ever holds a complete
    /// `Option`, so a panic while it was locked cannot leave it half-written.
    fn slot(&self) -> std::sync::MutexGuard<'_, Option<redis::aio::MultiplexedConnection>> {
        self.conn
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns the cached connection, opening and caching a new one if there is none.
    ///
    /// The lock is never held across `.await`. If several tasks connect at once, the first
    /// to store its connection wins and the others drop theirs.
    async fn connection(&self) -> redis::RedisResult<redis::aio::MultiplexedConnection> {
        let cached = self.slot().clone();
        if let Some(conn) = cached {
            return Ok(conn);
        }
        let fresh = self.client.get_multiplexed_async_connection().await?;
        let conn = self.slot().get_or_insert(fresh).clone();
        Ok(conn)
    }

    /// Forgets the cached connection so the next `connection` call reconnects.
    fn reset_connection(&self) {
        *self.slot() = None;
    }

    /// Logs `err` and returns an "allowed" decision with the full limit remaining.
    fn fail_open(rate: &Rate, err: redis::RedisError) -> Decision {
        tracing::warn!("rate-limit store unavailable, allowing request: {err}");
        Decision {
            allowed: true,
            limit: rate.limit,
            remaining: rate.limit,
            retry_after: None,
        }
    }
}

#[cfg(feature = "redis")]
#[async_trait::async_trait]
impl RateLimitStore for RedisStore {
    /// Records one hit for `key` in Redis and decides whether it fits within `rate`.
    ///
    /// Fails open (allows the hit) when Redis cannot be reached or the script errors.
    async fn hit(&self, key: &str, rate: &Rate) -> Decision {
        // Saturate rather than truncate: `as u64` on an enormous window would keep only the
        // low 64 bits and could produce an arbitrary, possibly tiny, TTL.
        let window_ms = u64::try_from(rate.window.as_millis())
            .unwrap_or(u64::MAX)
            .max(1);
        let mut conn = match self.connection().await {
            Ok(c) => c,
            Err(e) => return Self::fail_open(rate, e),
        };
        let reply: redis::RedisResult<(i64, i64)> = self
            .script
            .key(key)
            .arg(window_ms)
            .invoke_async(&mut conn)
            .await;
        let (count, pttl) = match reply {
            Ok(v) => v,
            Err(e) => {
                // A broken connection stays broken; drop it so the next hit reconnects
                // instead of failing open indefinitely.
                if e.is_io_error() || e.is_connection_dropped() {
                    self.reset_connection();
                }
                return Self::fail_open(rate, e);
            }
        };
        let count = count.max(0) as u64;
        let allowed = count <= rate.limit;
        Decision {
            allowed,
            limit: rate.limit,
            remaining: rate.limit.saturating_sub(count),
            retry_after: (!allowed).then(|| Duration::from_millis(pttl.max(0) as u64)),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `Rate` of `limit` hits per `window`.
    fn rate(limit: u64, window: Duration) -> Rate {
        Rate { limit, window }
    }

    #[tokio::test]
    async fn memory_store_evicts_expired_entries() {
        let store = MemoryStore::new();
        let short = rate(5, Duration::from_millis(20));
        for key in ["a", "b"] {
            store.hit(key, &short).await;
        }
        assert_eq!(store.windows.len(), 2);

        tokio::time::sleep(Duration::from_millis(30)).await;
        store.evict_expired(std::time::Instant::now());
        assert_eq!(store.windows.len(), 0);
    }

    #[tokio::test]
    async fn memory_store_allows_then_blocks_within_window() {
        let store = MemoryStore::new();
        let per_minute = rate(2, Duration::from_secs(60));

        let first = store.hit("k", &per_minute).await;
        assert!(first.allowed);
        assert_eq!(first.remaining, 1);

        let second = store.hit("k", &per_minute).await;
        assert!(second.allowed);
        assert_eq!(second.remaining, 0);

        let third = store.hit("k", &per_minute).await;
        assert!(!third.allowed);
        assert_eq!(third.remaining, 0);
        assert!(third.retry_after.is_some());
    }

    #[tokio::test]
    async fn memory_store_resets_after_window() {
        let store = MemoryStore::new();
        let one_per_50ms = rate(1, Duration::from_millis(50));

        assert!(store.hit("k", &one_per_50ms).await.allowed);
        assert!(!store.hit("k", &one_per_50ms).await.allowed);

        tokio::time::sleep(Duration::from_millis(60)).await;
        assert!(store.hit("k", &one_per_50ms).await.allowed);
    }

    #[tokio::test]
    async fn memory_store_isolates_keys() {
        let store = MemoryStore::new();
        let once = rate(1, Duration::from_secs(60));

        assert!(store.hit("a", &once).await.allowed);
        assert!(store.hit("b", &once).await.allowed);
        assert!(!store.hit("a", &once).await.allowed);
    }
}