use dashmap::DashMap;
use std::{
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::Mutex;
use super::types::{RateLimitConfig, RateLimitEntry, RateLimitResult};
/// Default cap on the number of distinct rate-limit keys tracked in memory.
const DEFAULT_MAX_ENTRIES: usize = 100_000;
/// Fraction of `max_entries` at which an early capacity warning is logged.
const CAPACITY_WARNING_THRESHOLD: f64 = 0.90;
/// Log an eviction at `warn` level every N eviction passes; the rest log at
/// `debug` to avoid flooding the log under sustained pressure.
const EVICTION_WARNING_INTERVAL: u64 = 10;
/// Upper bound on how many entries a single eviction pass may remove.
const MAX_EVICT_PER_CYCLE: usize = 2_000;
/// `evict_oldest` samples up to this many times the eviction target when
/// choosing the oldest entries (sampling instead of a full-map sort).
const EVICT_SAMPLE_MULTIPLIER: usize = 10;
/// In-memory, per-process store of sliding-window rate-limit counters.
///
/// `Clone` is shallow: every field is behind an `Arc`, so cloned handles
/// share the same map, cleanup timestamp, and eviction counter — useful for
/// handing the store to concurrent tasks (see `test_rate_limit_store_clone`).
/// State is NOT shared across processes/instances; `with_max_entries` warns
/// about this in multi-instance environments.
#[derive(Debug, Clone)]
pub struct RateLimitStore {
    // Concurrent map: client key -> sliding-window counters.
    entries: Arc<DashMap<String, RateLimitEntry>>,
    // Minimum interval between stale-entry sweeps in `maybe_cleanup`.
    cleanup_interval: Duration,
    // When the last sweep ran; async mutex so it can be awaited on.
    last_cleanup: Arc<Mutex<Instant>>,
    // Hard cap on tracked keys; reaching it triggers `evict_oldest`.
    max_entries: usize,
    // Total eviction passes ever run, used only to throttle warn logging.
    eviction_count: Arc<std::sync::atomic::AtomicU64>,
}
impl RateLimitStore {
    /// Parses a `REPLICAS`-style hint into a replica count.
    ///
    /// Missing, empty, non-numeric, or zero values fall back to `1`
    /// (a single instance) so a malformed hint can never yield zero.
    fn parse_replicas_hint(value: Option<&str>) -> u32 {
        value
            .and_then(|v| v.parse::<u32>().ok())
            .filter(|v| *v > 0)
            .unwrap_or(1)
    }

    /// Reads the `REPLICAS` environment variable as a replica-count hint.
    pub(crate) fn replicas_hint() -> u32 {
        Self::parse_replicas_hint(std::env::var("REPLICAS").ok().as_deref())
    }

    /// Creates a store bounded to [`DEFAULT_MAX_ENTRIES`] keys.
    pub fn new() -> Self {
        Self::with_max_entries(DEFAULT_MAX_ENTRIES)
    }

    /// Creates a store bounded to `max_entries` keys.
    ///
    /// Warns when the process appears to run alongside other instances
    /// (Kubernetes, Heroku, Fly.io, Render, or `REPLICAS > 1`): the counters
    /// are per-process, so the effective allowed rate is multiplied by the
    /// instance count.
    pub fn with_max_entries(max_entries: usize) -> Self {
        tracing::info!(
            max_entries = max_entries,
            "Rate limiter initialized with in-memory store (not shared across instances)"
        );
        let replicas_hint = Self::replicas_hint();
        // Heuristics for common multi-instance platforms.
        let is_multi_instance = replicas_hint > 1
            || std::env::var("KUBERNETES_SERVICE_HOST").is_ok()
            || std::env::var("DYNO").is_ok()
            || std::env::var("FLY_APP_NAME").is_ok()
            || std::env::var("RENDER").is_ok();
        if is_multi_instance {
            tracing::warn!(
                replicas = replicas_hint,
                "Multi-instance environment detected! Rate limiter uses in-memory store. \
                 Each instance has independent counters, effectively multiplying allowed \
                 requests by instance count. Consider Redis backend or adjust limits."
            );
        }
        Self {
            entries: Arc::new(DashMap::new()),
            // Stale-entry sweep runs at most once every 5 minutes.
            cleanup_interval: Duration::from_secs(300),
            last_cleanup: Arc::new(Mutex::new(Instant::now())),
            max_entries,
            eviction_count: Arc::new(std::sync::atomic::AtomicU64::new(0)),
        }
    }

    /// Checks whether `key` may perform another request under `config`,
    /// recording the request when it is allowed.
    ///
    /// Uses a sliding-window approximation: the previous window's count is
    /// weighted by how much of it still overlaps the trailing window, then
    /// added to the current window's count. Returns the decision together
    /// with the remaining budget and the seconds until the window resets.
    pub async fn check_and_record(&self, key: &str, config: &RateLimitConfig) -> RateLimitResult {
        let now = Instant::now();
        // Clamp to a 1-second minimum: a zero-length window would otherwise
        // panic with a divide-by-zero in the window arithmetic below.
        let window_secs = config.window_secs.max(1);
        let window = Duration::from_secs(window_secs);
        self.maybe_cleanup(now).await;
        let entries_len = self.entries.len();
        let warning_threshold = (self.max_entries as f64 * CAPACITY_WARNING_THRESHOLD) as usize;
        // Early warning while approaching capacity; gated on "new key" so
        // requests for existing keys don't spam the log.
        if entries_len >= warning_threshold
            && entries_len < self.max_entries
            && !self.entries.contains_key(key)
        {
            tracing::warn!(
                current_entries = entries_len,
                max_entries = self.max_entries,
                threshold_pct = CAPACITY_WARNING_THRESHOLD * 100.0,
                "Rate limit store approaching capacity - consider scaling or investigating traffic"
            );
        }
        // At capacity and about to insert a new key: make room first.
        if entries_len >= self.max_entries && !self.entries.contains_key(key) {
            self.evict_oldest(now);
        }
        let mut entry = self.entries.entry(key.to_string()).or_default();
        entry.last_access = now;
        let elapsed = now.duration_since(entry.window_start);
        if elapsed >= window {
            if elapsed >= window * 2 {
                // A full window of silence has passed: nothing from the
                // previous window overlaps the trailing window anymore.
                entry.prev_count = 0;
                entry.curr_count = 0;
            } else {
                // Rotate: the current window becomes the previous one.
                entry.prev_count = entry.curr_count;
                entry.curr_count = 0;
            }
            // Advance window_start by whole windows to keep it aligned.
            let windows_passed = elapsed.as_secs() / window_secs;
            entry.window_start += Duration::from_secs(windows_passed * window_secs);
        }
        let elapsed_in_window = now.duration_since(entry.window_start);
        let elapsed_ratio =
            (elapsed_in_window.as_secs_f64() / window_secs as f64).clamp(0.0, 1.0);
        // Weighted trailing-window estimate; `ceil` keeps it conservative
        // (never under-counts a partial request).
        let estimated_count = (entry.prev_count as f64 * (1.0 - elapsed_ratio)
            + entry.curr_count as f64)
            .ceil() as u32;
        let time_until_reset = window.saturating_sub(elapsed_in_window);
        let reset_secs = time_until_reset.as_secs();
        if estimated_count >= config.limit {
            RateLimitResult {
                allowed: false,
                limit: config.limit,
                remaining: 0,
                reset_secs,
            }
        } else {
            // Saturate rather than overflow-panic under extreme volume.
            entry.curr_count = entry.curr_count.saturating_add(1);
            RateLimitResult {
                allowed: true,
                limit: config.limit,
                // estimated_count < limit here, so `+ 1` cannot overflow.
                remaining: config.limit.saturating_sub(estimated_count + 1),
                reset_secs,
            }
        }
    }

    /// Removes entries idle longer than the stale threshold, at most once per
    /// `cleanup_interval`.
    ///
    /// The due-check and the timestamp update happen under a single lock
    /// acquisition, so two concurrent callers cannot both decide the sweep is
    /// due and both run the full-map `retain` (the previous check-then-relock
    /// pattern allowed exactly that).
    async fn maybe_cleanup(&self, now: Instant) {
        let should_cleanup = {
            let mut last = self.last_cleanup.lock().await;
            if now.duration_since(*last) > self.cleanup_interval {
                *last = now;
                true
            } else {
                false
            }
        };
        if should_cleanup {
            // Entries untouched for 10 minutes are considered abandoned.
            let stale_threshold = Duration::from_secs(600);
            self.entries
                .retain(|_, entry| now.duration_since(entry.last_access) < stale_threshold);
        }
    }

    /// Evicts the least-recently-accessed entries to make room for new keys.
    ///
    /// Samples up to `EVICT_SAMPLE_MULTIPLIER *` the eviction target and
    /// removes roughly 20% of capacity (capped at `MAX_EVICT_PER_CYCLE`) in
    /// one pass, amortizing eviction cost instead of evicting one-by-one.
    fn evict_oldest(&self, now: Instant) {
        let entries_len = self.entries.len();
        let evict_count = std::cmp::max(1, self.max_entries / 5)
            .min(entries_len)
            .min(MAX_EVICT_PER_CYCLE);
        let prev_count = self
            .eviction_count
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let total_evictions = prev_count + 1;
        // Warn on the first pass and then every EVICTION_WARNING_INTERVAL
        // passes; everything else logs at debug to keep the log quiet.
        if total_evictions % EVICTION_WARNING_INTERVAL == 1 || total_evictions == 1 {
            tracing::warn!(
                current_entries = entries_len,
                max_entries = self.max_entries,
                evicting = evict_count,
                total_evictions = total_evictions,
                "Rate limit store at capacity, evicting oldest entries (sustained pressure)"
            );
        } else {
            tracing::debug!(
                current_entries = entries_len,
                evicting = evict_count,
                total_evictions = total_evictions,
                "Rate limit store eviction"
            );
        }
        // Sample a bounded prefix of the map rather than scanning everything.
        let sample_size = (evict_count.saturating_mul(EVICT_SAMPLE_MULTIPLIER)).min(entries_len);
        let mut by_age: Vec<_> = self
            .entries
            .iter()
            .take(sample_size)
            .map(|entry| {
                (
                    entry.key().clone(),
                    now.duration_since(entry.value().last_access),
                )
            })
            .collect();
        // Degenerate case: the sample is no larger than the eviction target
        // (only possible when the whole map fits in the sample), so drop it all.
        if by_age.len() <= evict_count {
            self.entries.clear();
            return;
        }
        // Partition so the `evict_count` entries with the largest idle time
        // come first — O(n), no full sort needed.
        let (oldest, nth, _) = by_age.select_nth_unstable_by(evict_count - 1, |a, b| b.1.cmp(&a.1));
        for (key, _) in oldest.iter().chain(std::iter::once(&*nth)) {
            self.entries.remove(key.as_str());
        }
    }
}
impl Default for RateLimitStore {
    /// Equivalent to [`RateLimitStore::new`]: default capacity, not shared
    /// across instances.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A missing, empty, garbage, or zero REPLICAS hint defaults to one replica.
    #[test]
    fn test_parse_replicas_hint_defaults_to_one() {
        assert_eq!(RateLimitStore::parse_replicas_hint(None), 1);
        assert_eq!(RateLimitStore::parse_replicas_hint(Some("")), 1);
        assert_eq!(RateLimitStore::parse_replicas_hint(Some("not-a-number")), 1);
        assert_eq!(RateLimitStore::parse_replicas_hint(Some("0")), 1);
    }

    // Positive integers parse through unchanged.
    #[test]
    fn test_parse_replicas_hint_parses_positive_int() {
        assert_eq!(RateLimitStore::parse_replicas_hint(Some("1")), 1);
        assert_eq!(RateLimitStore::parse_replicas_hint(Some("2")), 2);
    }

    // Within the limit every request is allowed and `remaining` decrements by
    // exactly one per recorded request.
    #[tokio::test]
    async fn test_rate_limit_allows_within_limit() {
        let store = RateLimitStore::new();
        let config = RateLimitConfig {
            limit: 5,
            window_secs: 60,
        };
        for i in 0..5 {
            let result = store.check_and_record("test-key", &config).await;
            assert!(result.allowed, "Request {} should be allowed", i);
            assert_eq!(result.remaining, 4 - i);
        }
    }

    // The request immediately past the limit is rejected with zero remaining.
    #[tokio::test]
    async fn test_rate_limit_blocks_over_limit() {
        let store = RateLimitStore::new();
        let config = RateLimitConfig {
            limit: 3,
            window_secs: 60,
        };
        for _ in 0..3 {
            let result = store.check_and_record("test-key", &config).await;
            assert!(result.allowed);
        }
        let result = store.check_and_record("test-key", &config).await;
        assert!(!result.allowed);
        assert_eq!(result.remaining, 0);
    }

    // Exhausting one key must not affect the budget of another key.
    #[tokio::test]
    async fn test_rate_limit_separate_keys() {
        let store = RateLimitStore::new();
        let config = RateLimitConfig {
            limit: 2,
            window_secs: 60,
        };
        for _ in 0..2 {
            store.check_and_record("key1", &config).await;
        }
        let result = store.check_and_record("key1", &config).await;
        assert!(!result.allowed);
        let result = store.check_and_record("key2", &config).await;
        assert!(result.allowed);
        assert_eq!(result.remaining, 1);
    }

    // Sanity-check every field of the result struct on a first request.
    #[tokio::test]
    async fn test_rate_limit_result_fields() {
        let store = RateLimitStore::new();
        let config = RateLimitConfig {
            limit: 10,
            window_secs: 60,
        };
        let result = store.check_and_record("test-key", &config).await;
        assert!(result.allowed);
        assert_eq!(result.limit, 10);
        assert_eq!(result.remaining, 9);
        assert!(result.reset_secs <= 60);
    }

    // A blocked request reports remaining == 0 and echoes the configured limit.
    #[tokio::test]
    async fn test_rate_limit_blocked_has_zero_remaining() {
        let store = RateLimitStore::new();
        let config = RateLimitConfig {
            limit: 1,
            window_secs: 60,
        };
        let result = store.check_and_record("test-key", &config).await;
        assert!(result.allowed);
        assert_eq!(result.remaining, 0);
        let result = store.check_and_record("test-key", &config).await;
        assert!(!result.allowed);
        assert_eq!(result.remaining, 0);
        assert_eq!(result.limit, 1);
    }

    // Clones share state via Arc: a request through one handle consumes
    // budget visible through the other.
    #[tokio::test]
    async fn test_rate_limit_store_clone() {
        let store1 = RateLimitStore::new();
        let store2 = store1.clone();
        let config = RateLimitConfig {
            limit: 2,
            window_secs: 60,
        };
        store1.check_and_record("shared-key", &config).await;
        let result = store2.check_and_record("shared-key", &config).await;
        assert!(result.allowed);
        assert_eq!(result.remaining, 0);
    }

    // Counters are per-key, not per-config: the same key checked against a
    // larger limit still counts all prior requests.
    #[tokio::test]
    async fn test_rate_limit_multiple_configs() {
        let store = RateLimitStore::new();
        let strict_config = RateLimitConfig {
            limit: 2,
            window_secs: 60,
        };
        let relaxed_config = RateLimitConfig {
            limit: 100,
            window_secs: 60,
        };
        store.check_and_record("key", &strict_config).await;
        store.check_and_record("key", &strict_config).await;
        let result = store.check_and_record("key", &strict_config).await;
        assert!(!result.allowed);
        let result = store.check_and_record("key", &relaxed_config).await;
        assert!(result.allowed);
        assert_eq!(result.remaining, 97);
    }

    // Default::default() constructs a usable store.
    #[test]
    fn test_rate_limit_store_default() {
        let store = RateLimitStore::default();
        assert!(std::sync::Arc::strong_count(&store.entries) >= 1);
    }

    // Concurrent tasks hammering a few shared keys all succeed under a
    // generous limit (exercises DashMap's concurrent access path).
    #[tokio::test]
    async fn test_rate_limit_concurrent_access() {
        let store = RateLimitStore::new();
        let config = RateLimitConfig {
            limit: 100,
            window_secs: 60,
        };
        let mut handles = vec![];
        for i in 0..10 {
            let store = store.clone();
            let config = config.clone();
            handles.push(tokio::spawn(async move {
                store
                    .check_and_record(&format!("concurrent-{}", i % 3), &config)
                    .await
            }));
        }
        for handle in handles {
            let result = handle.await.unwrap();
            assert!(result.allowed);
        }
    }

    // Inserting past max_entries triggers eviction; the store stays bounded
    // and the newly inserted key survives.
    #[tokio::test]
    async fn test_rate_limit_max_entries_eviction() {
        let store = RateLimitStore::with_max_entries(10);
        let config = RateLimitConfig {
            limit: 100,
            window_secs: 60,
        };
        for i in 0..10 {
            store.check_and_record(&format!("key-{}", i), &config).await;
        }
        assert_eq!(store.entries.len(), 10);
        store.check_and_record("key-new", &config).await;
        assert!(store.entries.len() <= 10, "Should not exceed max entries");
        assert!(
            store.entries.contains_key("key-new"),
            "New key should be present"
        );
    }

    // Eviction prefers the least-recently-accessed entries: touching key-4
    // right before the eviction keeps it alive.
    #[tokio::test]
    async fn test_rate_limit_eviction_removes_oldest() {
        let store = RateLimitStore::with_max_entries(5);
        let config = RateLimitConfig {
            limit: 100,
            window_secs: 60,
        };
        for i in 0..5 {
            store.check_and_record(&format!("key-{}", i), &config).await;
        }
        store.check_and_record("key-4", &config).await;
        store.check_and_record("key-new", &config).await;
        assert!(
            store.entries.contains_key("key-4"),
            "Recently accessed key should remain"
        );
        assert!(
            store.entries.contains_key("key-new"),
            "New key should be present"
        );
    }

    // Simulated key-flood attack: the store must stay bounded and still
    // serve a legitimate user afterwards with a fresh budget.
    #[tokio::test]
    async fn test_rate_limit_eviction_stress_under_attack() {
        let store = RateLimitStore::with_max_entries(100);
        let config = RateLimitConfig {
            limit: 10,
            window_secs: 60,
        };
        for i in 0..500 {
            store
                .check_and_record(&format!("attacker-ip-{}", i), &config)
                .await;
        }
        let entries_len = store.entries.len();
        assert!(
            entries_len <= 100,
            "Store should not exceed max_entries under attack. Got: {}",
            entries_len
        );
        let result = store.check_and_record("legitimate-user", &config).await;
        assert!(
            result.allowed,
            "Legitimate user should still be allowed after eviction"
        );
        assert_eq!(result.remaining, 9);
        assert!(
            store.entries.len() <= 100,
            "Store should remain bounded after legitimate request"
        );
    }

    // Filling to exactly 90% then 100% then one more: the first over-capacity
    // insert evicts, leaving the store strictly below max with the new key in.
    #[tokio::test]
    async fn test_rate_limit_early_warning_at_ninety_percent() {
        let store = RateLimitStore::with_max_entries(10);
        let config = RateLimitConfig {
            limit: 100,
            window_secs: 60,
        };
        for i in 0..9 {
            store.check_and_record(&format!("key-{}", i), &config).await;
        }
        assert_eq!(store.entries.len(), 9);
        store.check_and_record("key-9", &config).await;
        assert_eq!(store.entries.len(), 10);
        store.check_and_record("key-10", &config).await;
        assert!(store.entries.len() < 10, "Should have evicted entries");
        assert!(
            store.entries.contains_key("key-10"),
            "New key should be present"
        );
    }

    // A single eviction pass removes ~20% of capacity (max_entries / 5),
    // leaving roughly 80% plus the trigger key.
    #[tokio::test]
    async fn test_rate_limit_eviction_evicts_twenty_percent() {
        let store = RateLimitStore::with_max_entries(100);
        let config = RateLimitConfig {
            limit: 100,
            window_secs: 60,
        };
        for i in 0..100 {
            store.check_and_record(&format!("key-{}", i), &config).await;
        }
        assert_eq!(store.entries.len(), 100);
        store.check_and_record("trigger-eviction", &config).await;
        let entries_len = store.entries.len();
        assert!(
            (75..=85).contains(&entries_len),
            "After eviction, should have ~80% of max entries. Got: {}",
            entries_len
        );
        assert!(
            store.entries.contains_key("trigger-eviction"),
            "Newly added key should be present"
        );
    }
}