use dashmap::DashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Details reported when a request is admitted by a [`RateLimitBackend`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitOutcome {
    /// How many more requests fit in the current window after this one.
    pub remaining: u32,
    /// The per-window limit the request was checked against.
    pub limit: u32,
    /// Seconds until the current window ends (the in-memory backend reports whole seconds,
    /// never less than 1).
    pub reset_after: u64,
}
/// Error returned by a [`RateLimitBackend`] when a key has exhausted its allowance.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitRejection {
    /// Seconds the caller should wait before retrying (the in-memory backend reports whole
    /// seconds, never less than 1).
    pub retry_after: u64,
    /// The per-window limit that was exceeded.
    pub limit: u32,
}

impl std::fmt::Display for RateLimitRejection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "rate limit of {} exceeded; retry after {}s",
            self.limit, self.retry_after
        )
    }
}

// Lets the rejection flow through `?` into `Box<dyn Error>` / error-reporting crates.
impl std::error::Error for RateLimitRejection {}
/// A pluggable store that decides whether one more request for a key is allowed.
///
/// Implementors must be shareable across threads (`Sync + Send`) and own all their data
/// (`'static`).
pub trait RateLimitBackend: Sync + Send + 'static {
    /// Checks `key` against an allowance of `limit` requests per `window`.
    ///
    /// The returned future resolves to [`RateLimitOutcome`] when the request is admitted and
    /// to [`RateLimitRejection`] when it is not. The future may borrow `self`, but its
    /// lifetime is not tied to `key`, so it cannot hold on to the borrowed key.
    fn check(
        &self,
        key: &str,
        limit: u32,
        window: Duration,
    ) -> Pin<Box<dyn Future<Output = Result<RateLimitOutcome, RateLimitRejection>> + Send + '_>>;
}
/// Process-local [`RateLimitBackend`] that keeps one fixed-window counter per key.
///
/// The counters live behind an [`Arc`], so cloning the backend is cheap and every clone
/// observes and updates the *same* counters.
#[derive(Debug, Clone)]
pub struct InMemoryBackend {
    /// Window state indexed by the `key` passed to `check`; `check` itself never removes keys.
    entries: Arc<DashMap<String, WindowEntry>>,
}
/// Counter state for a single key's current window.
#[derive(Debug, Clone)]
struct WindowEntry {
    /// Requests admitted since `window_start`.
    count: u32,
    /// Instant the current window began; it expires once a full window has elapsed from here.
    window_start: Instant,
}
impl InMemoryBackend {
pub fn new() -> Self {
Self {
entries: Arc::new(DashMap::new()),
}
}
}
impl Default for InMemoryBackend {
fn default() -> Self {
Self::new()
}
}
impl RateLimitBackend for InMemoryBackend {
    /// Fixed-window check for `key`: admits at most `limit` requests per `window`.
    ///
    /// The decision is made synchronously while the map entry for `key` is locked. The lock
    /// is released before returning, and the returned future is already complete; it only
    /// carries the result.
    ///
    /// On success, `remaining` is how many further requests fit in the current window. In both
    /// outcomes, the reset/retry delay is the time left in the current window in whole seconds.
    /// It is rounded up and never less than 1, so by the time that delay has passed the
    /// current window has ended.
    fn check(
        &self,
        key: &str,
        limit: u32,
        window: Duration,
    ) -> Pin<Box<dyn Future<Output = Result<RateLimitOutcome, RateLimitRejection>> + Send + '_>>
    {
        let now = Instant::now();
        let mut entry = self.entries.entry(key.to_string()).or_insert(WindowEntry {
            count: 0,
            window_start: now,
        });

        // `now` is sampled before the shard lock is taken, so a concurrent call may already
        // have restarted this window at a later instant. Saturate to zero explicitly rather
        // than depend on `Instant::duration_since`, whose behavior in that case is not
        // guaranteed to stay the same across Rust versions.
        let mut elapsed = now.saturating_duration_since(entry.window_start);
        if elapsed >= window {
            entry.count = 0;
            entry.window_start = now;
            elapsed = Duration::ZERO;
        }

        let left = window.saturating_sub(elapsed);
        // Round up: truncating (e.g. 59.5s -> 59) would tell clients to retry while the
        // current window is still open, earning them another rejection.
        let reset_after = left
            .as_secs()
            .saturating_add(u64::from(left.subsec_nanos() > 0))
            .max(1);

        let result = if entry.count >= limit {
            Err(RateLimitRejection {
                retry_after: reset_after,
                limit,
            })
        } else {
            entry.count += 1;
            Ok(RateLimitOutcome {
                // Cannot underflow: `count` was below `limit` before the increment.
                remaining: limit - entry.count,
                limit,
                reset_after,
            })
        };
        // Release the shard lock before allocating the boxed future.
        drop(entry);

        Box::pin(std::future::ready(result))
    }
}