use super::{
RateLimiter, SystemTimeSource, TimeSource,
store::{FixedWindowParams, FixedWindowStore},
};
use dashmap::DashMap;
use std::sync::{
Arc,
atomic::{AtomicU32, AtomicU64, Ordering::Relaxed},
};
use std::time::Duration;
/// Per-key bookkeeping kept by the in-memory fixed-window store.
#[derive(Debug)]
struct Entry {
    /// Requests counted so far for the window that starts at `window_start`.
    count: AtomicU32,
    /// Start (in seconds) of the window that `count` refers to.
    window_start: AtomicU64,
}
/// Process-local [`FixedWindowStore`] keeping one counter per key in a concurrent map.
///
/// Cloning is cheap: the map sits behind an `Arc`, so every clone observes and updates
/// the same counters.
#[derive(Debug, Clone)]
pub struct InMemoryFixedWindowStore {
    /// Per-key counters, shared by all clones of this store.
    storage: Arc<DashMap<u64, Entry>>,
}
impl InMemoryFixedWindowStore {
pub fn new() -> Self {
Self {
storage: Arc::new(DashMap::new()),
}
}
}
impl Default for InMemoryFixedWindowStore {
fn default() -> Self {
Self::new()
}
}
impl FixedWindowStore for InMemoryFixedWindowStore {
    /// Counts one request for `params.key` and returns whether it is within
    /// `params.max_requests` for the window starting at `params.window`.
    ///
    /// The whole read-modify-write runs while holding the guard returned by
    /// `DashMap::entry(..).or_insert_with(..)`, which is exclusive for the key's shard. Concurrent
    /// calls for the same key are therefore serialized and no count can be lost.
    ///
    /// An entry left over from an earlier window, including one that has been idle for a long
    /// time, is reset in place. That is equivalent to evicting and re-inserting it, without
    /// releasing the lock in between. `now` and `grace_secs` are not consulted here, so a grace
    /// period shorter than the window cannot reset a key's count in the middle of its window.
    ///
    /// The counter saturates at `u32::MAX` instead of wrapping. A very busy key therefore
    /// cannot wrap back below `max_requests` and be admitted again.
    #[inline]
    fn check_and_count(&self, params: FixedWindowParams) -> bool {
        let FixedWindowParams {
            key,
            window,
            max_requests,
            ..
        } = params;
        let entry = self.storage.entry(key).or_insert_with(|| Entry {
            window_start: AtomicU64::new(window),
            count: AtomicU32::new(0),
        });
        // `entry` holds the shard's write lock until it is dropped at the end of this function.
        // Separate loads and stores are therefore not racy here.
        if entry.window_start.load(Relaxed) != window {
            entry.window_start.store(window, Relaxed);
            entry.count.store(0, Relaxed);
        }
        let prev = entry.count.load(Relaxed);
        entry.count.store(prev.saturating_add(1), Relaxed);
        prev < max_requests
    }
}
/// Fixed-window rate limiter: each key may make at most `max_requests` calls per window of
/// `window_size_secs` seconds. Windows are aligned to multiples of the window size on the time
/// source's clock.
///
/// By default it reads the system clock and keeps counts in an in-memory store.
#[derive(Debug)]
pub struct FixedWindowRateLimiter<
    T: TimeSource = SystemTimeSource,
    S: FixedWindowStore = InMemoryFixedWindowStore,
> {
    /// Backend holding the per-key counters.
    store: S,
    /// Maximum number of allowed calls per key per window.
    max_requests: u32,
    /// Window length in whole seconds; the constructor rejects 0.
    window_size_secs: u64,
    /// Grace period (seconds) handed to the store as `grace_secs` on every check.
    eviction_grace_secs: u64,
    /// Clock used to determine the current window.
    time_source: T,
}
impl<T: TimeSource, S: FixedWindowStore> RateLimiter for FixedWindowRateLimiter<T, S> {
    /// Records one request for `key` and returns whether it is within the limit for the
    /// current window.
    ///
    /// Reads the clock once, aligns that instant to its window start, and delegates the
    /// counting decision to the configured store.
    #[inline]
    fn check(&self, key: u64) -> bool {
        let now_secs = self.time_source.now_secs();
        let request = FixedWindowParams {
            key,
            window: self.current_window(now_secs),
            max_requests: self.max_requests,
            now: now_secs,
            grace_secs: self.eviction_grace_secs,
        };
        self.store.check_and_count(request)
    }
}
impl FixedWindowRateLimiter {
    /// Creates a limiter allowing `max_requests` calls per key per `window_size`. It uses the
    /// system clock and a fresh in-memory store.
    ///
    /// # Panics
    ///
    /// Panics if `window_size` is shorter than one second.
    #[inline]
    pub fn new(max_requests: u32, window_size: Duration) -> Self {
        let clock = SystemTimeSource;
        Self::with_time_source(max_requests, window_size, clock)
    }
}
impl<T: TimeSource> FixedWindowRateLimiter<T> {
    /// Creates a limiter that reads time from `time_source` and keeps counts in a fresh
    /// in-memory store.
    ///
    /// # Panics
    ///
    /// Panics if `window_size` is shorter than one second.
    #[inline]
    pub fn with_time_source(max_requests: u32, window_size: Duration, time_source: T) -> Self {
        let store = InMemoryFixedWindowStore::new();
        Self::with_time_source_and_store(max_requests, window_size, time_source, store)
    }
}
impl<S: FixedWindowStore> FixedWindowRateLimiter<SystemTimeSource, S> {
    /// Creates a limiter that keeps counts in `store` and reads the system clock.
    ///
    /// # Panics
    ///
    /// Panics if `window_size` is shorter than one second.
    #[inline]
    pub fn with_store(max_requests: u32, window_size: Duration, store: S) -> Self {
        let clock = SystemTimeSource;
        Self::with_time_source_and_store(max_requests, window_size, clock, store)
    }
}
impl<T: TimeSource, S: FixedWindowStore> FixedWindowRateLimiter<T, S> {
    /// Creates a limiter allowing `max_requests` calls per key per `window_size`. It reads time
    /// from `time_source` and keeps counts in `store`.
    ///
    /// `window_size` is truncated to whole seconds (`Duration::as_secs`). The eviction grace
    /// period starts at two windows and can be changed with [`Self::set_eviction`].
    ///
    /// # Panics
    ///
    /// Panics if `window_size` is shorter than one second.
    #[inline]
    pub fn with_time_source_and_store(
        max_requests: u32,
        window_size: Duration,
        time_source: T,
        store: S,
    ) -> Self {
        let window_size_secs = window_size.as_secs();
        assert!(
            window_size_secs > 0,
            "window_size must be at least 1 second"
        );
        Self {
            store,
            max_requests,
            window_size_secs,
            eviction_grace_secs: window_size_secs.saturating_mul(2),
            time_source,
        }
    }

    /// Sets the grace period, truncated to whole seconds. The store receives it as
    /// `grace_secs` on every check.
    ///
    /// The value is clamped to at least one window. A store that evicts entries older than
    /// `grace_secs` could otherwise drop a key while its window is still running. That resets
    /// the key's count and lets it exceed `max_requests`.
    #[inline]
    pub fn set_eviction(&mut self, eviction: Duration) {
        self.eviction_grace_secs = eviction.as_secs().max(self.window_size_secs);
    }

    /// Maximum number of calls allowed per key per window.
    #[inline]
    pub fn max_requests(&self) -> u32 {
        self.max_requests
    }

    /// Window length in whole seconds (always at least 1).
    #[inline]
    pub fn window_size_secs(&self) -> u64 {
        self.window_size_secs
    }

    /// Grace period in seconds that is passed to the store on every check.
    #[inline]
    pub fn eviction_grace_secs(&self) -> u64 {
        self.eviction_grace_secs
    }

    /// Start of the window containing `now`: `now` rounded down to a multiple of the
    /// window size. `window_size_secs` is non-zero (enforced by the constructor), so the
    /// division cannot panic.
    #[inline]
    fn current_window(&self, now: u64) -> u64 {
        (now / self.window_size_secs) * self.window_size_secs
    }
}
#[cfg(test)]
mod tests {
    use super::super::test_utils::MockTimeSource;
    use super::*;

    /// Calls `limiter.check(key)` `calls` times and returns each call's outcome, in order.
    fn outcomes<R: RateLimiter>(limiter: &R, key: u64, calls: usize) -> Vec<bool> {
        (0..calls).map(|_| limiter.check(key)).collect()
    }

    #[test]
    fn fixed_window_allows_within_limit() {
        let limiter = FixedWindowRateLimiter::new(3, Duration::from_secs(10));
        assert_eq!(outcomes(&limiter, 42, 4), [true, true, true, false]);
    }

    #[test]
    fn fixed_window_resets_after_window() {
        let clock = MockTimeSource::new(1000);
        let limiter =
            FixedWindowRateLimiter::with_time_source(2, Duration::from_secs(1), clock.clone());
        assert_eq!(outcomes(&limiter, 1, 3), [true, true, false]);
        clock.advance(1);
        assert!(limiter.check(1));
    }

    #[test]
    fn fixed_window_isolated_per_key() {
        let limiter = FixedWindowRateLimiter::new(1, Duration::from_secs(10));
        assert_eq!(outcomes(&limiter, 1, 2), [true, false]);
        assert_eq!(outcomes(&limiter, 2, 1), [true]);
    }

    #[test]
    fn fixed_window_with_custom_store_allows_within_limit() {
        use crate::rate_limiter::store::{FixedWindowParams, FixedWindowStore};
        use std::sync::Arc;
        use std::sync::atomic::{AtomicU32, Ordering::Relaxed};

        /// Store wrapper that records how many times the limiter consulted it.
        struct CountingStore {
            inner: InMemoryFixedWindowStore,
            calls: Arc<AtomicU32>,
        }

        impl FixedWindowStore for CountingStore {
            fn check_and_count(&self, params: FixedWindowParams) -> bool {
                self.calls.fetch_add(1, Relaxed);
                self.inner.check_and_count(params)
            }
        }

        let calls = Arc::new(AtomicU32::new(0));
        let store = CountingStore {
            inner: InMemoryFixedWindowStore::new(),
            calls: Arc::clone(&calls),
        };
        let limiter = FixedWindowRateLimiter::with_store(3, Duration::from_secs(10), store);
        assert_eq!(outcomes(&limiter, 1, 4), [true, true, true, false]);
        assert_eq!(calls.load(Relaxed), 4);
    }

    #[test]
    #[should_panic(expected = "window_size must be at least 1 second")]
    fn fixed_window_panics_on_zero_window_size() {
        let _ = FixedWindowRateLimiter::new(10, Duration::ZERO);
    }

    #[test]
    fn fixed_window_is_thread_safe() {
        use std::sync::Arc;
        use std::thread;

        let limiter = Arc::new(FixedWindowRateLimiter::new(1000, Duration::from_secs(10)));
        let key = 123;
        let workers: Vec<_> = (0..8)
            .map(|_| {
                let limiter = Arc::clone(&limiter);
                thread::spawn(move || (0..200).filter(|_| limiter.check(key)).count())
            })
            .collect();
        let total: usize = workers
            .into_iter()
            .map(|worker| worker.join().unwrap())
            .sum();
        assert!(total <= 1000 + 8);
    }
}