use super::{
MICROS_PER_SEC, RateLimiter, SystemTimeSource, TimeSource,
store::{GcraParams, GcraStore},
};
use dashmap::DashMap;
use std::{
sync::{
Arc,
atomic::{AtomicU64, Ordering::*},
},
time::Duration,
};
/// Idle period after which a key's GCRA state is reset: sixty seconds, in microseconds.
const DEFAULT_EVICTION: u64 = MICROS_PER_SEC * 60;
/// Per-key GCRA state held by the in-memory store.
#[derive(Debug)]
struct Entry {
    /// Theoretical arrival time (TAT), in microseconds on the limiter's clock.
    tat_us: AtomicU64,
    /// Timestamp (µs) of the most recent check of this key, allowed or not;
    /// compared against the eviction grace period.
    last_seen_us: AtomicU64,
}
/// Process-local [`GcraStore`] that keeps one [`Entry`] per `u64` key in a concurrent map.
///
/// The map sits behind an [`Arc`], so cloning the store is cheap and every clone
/// observes and updates the same per-key state.
#[derive(Clone, Debug)]
pub struct InMemoryGcraStore {
    /// Shared per-key GCRA state.
    storage: Arc<DashMap<u64, Entry>>,
}
impl InMemoryGcraStore {
    /// Creates an empty store.
    ///
    /// The backing map lives behind an [`Arc`], so clones of the returned value share
    /// the same per-key state.
    pub fn new() -> Self {
        Self {
            storage: Arc::new(DashMap::new()),
        }
    }

    /// Removes every entry that has not been checked for more than `eviction_grace_us`
    /// microseconds as of `now_us`, returning how many entries were dropped.
    ///
    /// [`GcraStore::check_and_advance`] only deals with an idle entry when that same key
    /// is checked again. Keys that are never revisited would otherwise stay in memory
    /// indefinitely. Call this periodically (for example from a task that holds a clone
    /// of the store) to bound memory use.
    ///
    /// `now_us` must be on the same clock as the `now_us` values passed to
    /// `check_and_advance`.
    pub fn evict_idle(&self, now_us: u64, eviction_grace_us: u64) -> usize {
        let mut evicted = 0;
        self.storage.retain(|_, entry| {
            let idle_us = now_us.saturating_sub(entry.last_seen_us.load(Acquire));
            let stale = idle_us > eviction_grace_us;
            evicted += usize::from(stale);
            // `retain` keeps entries for which the closure returns `true`.
            !stale
        });
        evicted
    }
}
impl Default for InMemoryGcraStore {
fn default() -> Self {
Self::new()
}
}
impl GcraStore for InMemoryGcraStore {
    /// Makes one GCRA decision for `params.key` and advances the key's theoretical
    /// arrival time (TAT) when the request conforms.
    ///
    /// Returns `true` if the request is allowed. A request is rejected when `now_us` is
    /// still earlier than `TAT - burst_allowance_us`. On acceptance the TAT moves to
    /// `max(now_us, TAT) + emission_interval_us`.
    ///
    /// If the key has not been checked for more than `eviction_grace_us`, its state is
    /// reset exactly as if the key were new. This happens only when that key is checked
    /// again; the entry itself stays in the map until then.
    #[inline]
    fn check_and_advance(&self, params: GcraParams) -> bool {
        let GcraParams {
            key,
            now_us,
            emission_interval_us,
            burst_allowance_us,
            eviction_grace_us,
        } = params;

        // One lookup. DashMap's entry API keeps the shard write-locked while `entry` is
        // alive, so the idle check, the reset and the TAT update cannot interleave with
        // another caller for this key. A separate get/remove step would release the lock
        // in between and could discard a concurrent caller's fresh state, admitting an
        // extra request.
        let entry = self.storage.entry(key).or_insert_with(|| Entry {
            tat_us: AtomicU64::new(now_us),
            last_seen_us: AtomicU64::new(now_us),
        });

        // `fetch_max` keeps `last_seen_us` monotonic even if callers' timestamps arrive
        // slightly out of order; it returns the previous value for the idle check.
        let previous_seen = entry.last_seen_us.fetch_max(now_us, AcqRel);
        if now_us.saturating_sub(previous_seen) > eviction_grace_us {
            // Idle past the grace period: start over as a freshly inserted key would.
            entry.tat_us.store(now_us, Release);
        }

        let mut current_tat = entry.tat_us.load(Acquire);
        loop {
            let limit = current_tat.saturating_sub(burst_allowance_us);
            if now_us < limit {
                return false;
            }
            let base = now_us.max(current_tat);
            let next_tat = base.saturating_add(emission_interval_us);
            // The shard lock already serialises callers; the CAS keeps the update
            // correct without relying on that.
            match entry
                .tat_us
                .compare_exchange(current_tat, next_tat, AcqRel, Acquire)
            {
                Ok(_) => return true,
                Err(observed) => current_tat = observed,
            }
        }
    }
}
/// Keyed rate limiter implementing the generic cell rate algorithm (GCRA).
///
/// `T` supplies the current time in microseconds and `S` holds per-key state. By
/// default they are the system clock and an in-memory store.
#[derive(Debug)]
pub struct GcraRateLimiter<T = SystemTimeSource, S = InMemoryGcraStore>
where
    T: TimeSource,
    S: GcraStore,
{
    /// Backend holding per-key GCRA state.
    store: S,
    /// Microseconds between conforming requests at the steady rate.
    emission_interval_us: u64,
    /// Slack (µs) that lets up to `burst` requests through back-to-back.
    burst_allowance_us: u64,
    /// Burst size the limiter was built with.
    burst: u32,
    /// Idle time (µs) after which a key's state is reset.
    eviction_grace_us: u64,
    /// Clock used to timestamp each check.
    time_source: T,
}
impl<T: TimeSource, S: GcraStore> RateLimiter for GcraRateLimiter<T, S> {
    /// Reads the clock once, then lets the store decide whether `key` may proceed
    /// under this limiter's configuration.
    #[inline]
    fn check(&self, key: u64) -> bool {
        let now_us = self.time_source.now_micros();
        let request = GcraParams {
            key,
            now_us,
            emission_interval_us: self.emission_interval_us,
            burst_allowance_us: self.burst_allowance_us,
            eviction_grace_us: self.eviction_grace_us,
        };
        self.store.check_and_advance(request)
    }
}
impl GcraRateLimiter {
    /// Builds a limiter on the system clock with a fresh in-memory store.
    ///
    /// # Panics
    ///
    /// Panics if `rate_per_second` is not finite or not strictly positive, or if
    /// `burst` is zero.
    #[inline]
    pub fn new(rate_per_second: f64, burst: u32) -> Self {
        let clock = SystemTimeSource;
        Self::with_time_source(rate_per_second, burst, clock)
    }
}
impl<T: TimeSource> GcraRateLimiter<T> {
    /// Builds a limiter driven by a caller-supplied clock (for example a mock clock
    /// in tests), backed by a fresh in-memory store.
    ///
    /// # Panics
    ///
    /// Panics if `rate_per_second` is not finite or not strictly positive, or if
    /// `burst` is zero.
    #[inline]
    pub fn with_time_source(rate_per_second: f64, burst: u32, time_source: T) -> Self {
        let store = InMemoryGcraStore::new();
        Self::with_time_source_and_store(rate_per_second, burst, time_source, store)
    }
}
impl<S: GcraStore> GcraRateLimiter<SystemTimeSource, S> {
    /// Builds a limiter on the system clock that keeps its per-key state in `store`.
    ///
    /// # Panics
    ///
    /// Panics if `rate_per_second` is not finite or not strictly positive, or if
    /// `burst` is zero.
    #[inline]
    pub fn with_store(rate_per_second: f64, burst: u32, store: S) -> Self {
        let clock = SystemTimeSource;
        Self::with_time_source_and_store(rate_per_second, burst, clock, store)
    }
}
impl<T: TimeSource, S: GcraStore> GcraRateLimiter<T, S> {
    /// Builds a limiter from an explicit clock and store.
    ///
    /// `rate_per_second` becomes an emission interval in whole microseconds, rounded
    /// up. The effective rate can therefore be marginally below the requested one (see
    /// `rate_per_second()`). `burst` is how many requests one key may make at the same
    /// instant before being throttled.
    ///
    /// # Panics
    ///
    /// Panics if `rate_per_second` is not finite or not strictly positive, or if
    /// `burst` is zero.
    #[inline]
    pub fn with_time_source_and_store(
        rate_per_second: f64,
        burst: u32,
        time_source: T,
        store: S,
    ) -> Self {
        assert!(
            rate_per_second.is_finite(),
            "rate_per_second must be finite"
        );
        assert!(rate_per_second > 0.0, "rate_per_second must be > 0");
        assert!(burst >= 1, "burst must be >= 1");
        // Round up so the limiter never admits faster than requested. The float-to-int
        // `as` cast saturates, so absurdly low rates clamp to u64::MAX.
        let emission_interval_us = (MICROS_PER_SEC as f64 / rate_per_second).ceil() as u64;
        // The first request of a burst needs no slack; each further one needs one interval.
        let burst_allowance_us = emission_interval_us.saturating_mul(u64::from(burst - 1));
        Self {
            store,
            emission_interval_us,
            burst_allowance_us,
            burst,
            eviction_grace_us: DEFAULT_EVICTION,
            time_source,
        }
    }

    /// Sets how long a key may go unchecked before its state is reset.
    ///
    /// Durations that do not fit in `u64` microseconds saturate to `u64::MAX`. Keep this
    /// at least `burst` emission intervals long. A shorter period can reset a key before
    /// its burst allowance has fully refilled, granting extra requests.
    #[inline]
    pub fn set_eviction(&mut self, eviction: Duration) {
        self.eviction_grace_us = eviction.as_micros().try_into().unwrap_or(u64::MAX);
    }

    /// Effective steady-state rate in requests per second.
    ///
    /// Derived from the rounded-up emission interval, so it may be slightly below the
    /// value passed to the constructor.
    #[inline]
    pub fn rate_per_second(&self) -> f64 {
        // Float division: integer division truncated (3/s reported 2, 0.5/s reported 0).
        MICROS_PER_SEC as f64 / self.emission_interval_us as f64
    }

    /// Burst size the limiter was built with.
    #[inline]
    pub fn burst(&self) -> u32 {
        self.burst
    }

    /// Eviction period in whole seconds (fractional seconds are truncated).
    #[inline]
    pub fn eviction_grace_secs(&self) -> u64 {
        self.eviction_grace_us / MICROS_PER_SEC
    }
}
#[cfg(test)]
mod tests {
    use super::super::test_utils::MockTimeSource;
    use super::*;

    #[test]
    fn gcra_allows_burst_then_limits() {
        let limiter = GcraRateLimiter::with_time_source(1.0, 3, MockTimeSource::new(0));
        let key: u64 = 10;
        // The whole burst of three is admitted at the same instant...
        for _ in 0..3 {
            assert!(limiter.check(key));
        }
        // ...and the next request is throttled.
        assert!(!limiter.check(key));
    }

    #[test]
    fn gcra_refills_over_time() {
        let clock = MockTimeSource::new(100);
        let limiter = GcraRateLimiter::with_time_source(1.0, 1, clock.clone());
        let key: u64 = 5;
        assert!(limiter.check(key));
        assert!(!limiter.check(key));
        // Once the clock moves forward the key conforms again.
        clock.advance(1);
        assert!(limiter.check(key));
    }

    #[test]
    fn gcra_isolated_per_key() {
        let limiter = GcraRateLimiter::new(1.0, 1);
        let (first, second): (u64, u64) = (1, 2);
        assert!(limiter.check(first));
        assert!(!limiter.check(first));
        // Exhausting `first` must not affect `second`.
        assert!(limiter.check(second));
    }

    #[test]
    fn gcra_with_custom_store_delegates_to_store() {
        use crate::rate_limiter::store::{GcraParams, GcraStore};
        use std::sync::Arc;
        use std::sync::atomic::{AtomicU32, Ordering::Relaxed};

        // Wraps the in-memory store and counts how often it is consulted.
        struct CountingStore {
            inner: InMemoryGcraStore,
            hits: Arc<AtomicU32>,
        }

        impl GcraStore for CountingStore {
            fn check_and_advance(&self, params: GcraParams) -> bool {
                self.hits.fetch_add(1, Relaxed);
                self.inner.check_and_advance(params)
            }
        }

        let hits = Arc::new(AtomicU32::new(0));
        let limiter = GcraRateLimiter::with_store(
            1.0,
            3,
            CountingStore {
                inner: InMemoryGcraStore::new(),
                hits: Arc::clone(&hits),
            },
        );
        assert!(limiter.check(10));
        assert_eq!(hits.load(Relaxed), 1);
    }

    #[test]
    #[should_panic(expected = "rate_per_second must be finite")]
    fn panics_when_rate_is_nan() {
        let _limiter = GcraRateLimiter::with_time_source(f64::NAN, 1, SystemTimeSource);
    }

    #[test]
    #[should_panic(expected = "rate_per_second must be finite")]
    fn panics_when_rate_is_infinite() {
        let _limiter = GcraRateLimiter::with_time_source(f64::INFINITY, 1, SystemTimeSource);
    }

    #[test]
    #[should_panic(expected = "rate_per_second must be > 0")]
    fn panics_when_rate_is_zero() {
        let _limiter = GcraRateLimiter::with_time_source(0.0, 1, SystemTimeSource);
    }

    #[test]
    #[should_panic(expected = "rate_per_second must be > 0")]
    fn panics_when_rate_is_negative() {
        let _limiter = GcraRateLimiter::with_time_source(-1.0, 1, SystemTimeSource);
    }

    #[test]
    #[should_panic(expected = "burst must be >= 1")]
    fn panics_when_burst_is_zero() {
        let _limiter = GcraRateLimiter::with_time_source(1.0, 0, SystemTimeSource);
    }

    #[test]
    #[should_panic(expected = "rate_per_second must be > 0")]
    fn gcra_with_store_panics_on_zero_rate() {
        let _limiter = GcraRateLimiter::with_store(0.0, 1, InMemoryGcraStore::new());
    }

    #[test]
    #[should_panic(expected = "burst must be >= 1")]
    fn gcra_with_store_panics_on_zero_burst() {
        let _limiter = GcraRateLimiter::with_store(1.0, 0, InMemoryGcraStore::new());
    }
}