pub mod clock;
mod gcra;
mod nanos;
pub mod quota;
use std::{
fmt::Debug,
hash::Hash,
num::NonZeroU64,
sync::atomic::{AtomicU64, Ordering},
time::Duration,
};
use dashmap::DashMap;
use futures_util::StreamExt;
use self::{
clock::{Clock, FakeRelativeClock, MonotonicClock},
gcra::{Gcra, NotUntil},
nanos::Nanos,
quota::Quota,
};
/// Per-key rate-limiter state kept in a single atomic word.
///
/// The wrapped `AtomicU64` holds the raw nanosecond value of the state; a raw value of `0`
/// is read back as "nothing recorded yet" (`None`).
#[derive(Debug, Default)]
pub struct InMemoryState(AtomicU64);

impl InMemoryState {
    /// Reads the current state, lets `f` compute a replacement, and publishes it atomically.
    ///
    /// `f` receives the current state (`None` when nothing has been stored yet) and returns
    /// either `(result, new_state)` or an error. If the stored value changed between the read
    /// and the compare-exchange (or the weak exchange fails spuriously), `f` is called again
    /// with the freshly observed value, so `f` may run several times per call.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `f`; the stored state is not modified by that
    /// attempt.
    pub(crate) fn measure_and_replace_one<T, F, E>(&self, mut f: F) -> Result<T, E>
    where
        F: FnMut(Option<Nanos>) -> Result<(T, Nanos), E>,
    {
        let mut current = self.0.load(Ordering::Acquire);
        loop {
            // Zero is the "unset" sentinel of the atomic; expose it to `f` as `None`.
            let observed: Option<Nanos> = NonZeroU64::new(current).map(|raw| raw.get().into());
            let (result, replacement) = f(observed)?;
            match self.0.compare_exchange_weak(
                current,
                replacement.into(),
                Ordering::Release,
                Ordering::Relaxed,
            ) {
                Ok(_) => return Ok(result),
                // Another writer won (or a spurious failure): retry against the value it left.
                Err(actual) => current = actual,
            }
        }
    }
}
/// Concurrent map holding the [`InMemoryState`] of each rate-limited key.
pub type DashMapStateStore<K> = DashMap<K, InMemoryState>;
/// Storage of per-key rate-limiter state that can be read and replaced through a callback.
pub trait StateStore {
/// Key type used to look up a state entry.
type Key;
/// Runs `f` against the state recorded for `key`.
///
/// `f` receives the current state (`None` if nothing is recorded) and returns either a
/// result together with the replacement state, or an error. The [`DashMapStateStore`]
/// implementation in this file stores the returned `Nanos` when `f` succeeds and returns the
/// error from `f` unchanged otherwise. It may also call `f` more than once under
/// contention. NOTE(review): confirm other implementations follow the same contract.
fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
where
F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>;
}
impl<K: Hash + Eq + Clone> StateStore for DashMapStateStore<K> {
    type Key = K;

    /// Applies `f` to the state stored for `key`, creating an empty state on first use.
    ///
    /// The common path only borrows the existing entry. The key is cloned and a default
    /// [`InMemoryState`] inserted only when the key has no state yet.
    fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
    where
        F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>,
    {
        // DashMap documents that `entry` may deadlock while another reference into the map
        // is held. Keeping the lookup in an `if let` with an early return ensures the read
        // guard from `get` is dropped before `entry` runs.
        if let Some(existing) = self.get(key) {
            return existing.measure_and_replace_one(f);
        }
        self.entry(key.clone())
            .or_default()
            .measure_and_replace_one(f)
    }
}
/// Keyed GCRA rate limiter.
///
/// A key is checked against its dedicated quota in `gcra` when one is registered, otherwise
/// against `default_gcra`. Keys with neither are never limited (see `check_key`).
pub struct RateLimiter<K, C>
where
C: Clock,
{
/// Quota applied to keys that have no dedicated entry in `gcra`; `None` disables limiting
/// for such keys.
default_gcra: Option<Gcra>,
/// Per-key limiter state, shared by the default and the dedicated quotas.
state: DashMapStateStore<K>,
/// Dedicated per-key quotas; they take precedence over `default_gcra`.
gcra: DashMap<K, Gcra>,
/// Time source consulted on every check.
clock: C,
/// Instant captured from `clock` when the limiter is built; passed to every GCRA check.
start: C::Instant,
}
// Opaque `Debug` output: only the type name is written and no internal state is shown.
// Because no fields are formatted, `K` does not need to implement `Debug`. The previous
// `K: Debug` bound was unused and only narrowed which limiters could be debug-printed.
impl<K, C> Debug for RateLimiter<K, C>
where
    C: Clock,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct(stringify!(RateLimiter)).finish()
    }
}
impl<K> RateLimiter<K, MonotonicClock>
where
K: Eq + Hash,
{
#[must_use]
pub fn new_with_quota(base_quota: Option<Quota>, keyed_quotas: Vec<(K, Quota)>) -> Self {
let clock = MonotonicClock {};
let start = MonotonicClock::now(&clock);
let gcra: DashMap<_, _> = keyed_quotas
.into_iter()
.map(|(k, q)| (k, Gcra::new(q)))
.collect();
Self {
default_gcra: base_quota.map(Gcra::new),
state: DashMapStateStore::new(),
gcra,
clock,
start,
}
}
}
impl<K> RateLimiter<K, FakeRelativeClock>
where
K: Hash + Eq + Clone,
{
pub fn advance_clock(&self, by: Duration) {
self.clock.advance(by);
}
}
impl<K, C> RateLimiter<K, C>
where
    K: Hash + Eq + Clone,
    C: Clock,
{
    /// Registers a dedicated quota for `key`, replacing any quota previously set for it.
    ///
    /// The state already recorded for `key` is left untouched.
    pub fn add_quota_for_key(&self, key: K, value: Quota) {
        let keyed_gcra = Gcra::new(value);
        self.gcra.insert(key, keyed_gcra);
    }

    /// Checks whether one request for `key` may proceed now and, if so, records it.
    ///
    /// A quota registered via [`Self::add_quota_for_key`] takes precedence. Otherwise the
    /// default quota is used, and when neither exists the request is always allowed.
    ///
    /// # Errors
    ///
    /// Returns [`NotUntil`] when the request is rate-limited.
    pub fn check_key(&self, key: &K) -> Result<(), NotUntil<C::Instant>> {
        if let Some(keyed) = self.gcra.get(key) {
            return keyed.test_and_update(self.start, key, &self.state, self.clock.now());
        }
        match &self.default_gcra {
            Some(fallback) => {
                fallback.test_and_update(self.start, key, &self.state, self.clock.now())
            }
            None => Ok(()),
        }
    }

    /// Waits until a request for `key` is accepted by [`Self::check_key`].
    ///
    /// After each rejection it sleeps on tokio's timer for the wait time the rejection reports
    /// relative to `self.clock`. On return, one request has been recorded for `key`.
    ///
    /// NOTE(review): the sleep uses real time while the wait is computed from `self.clock`.
    /// With a [`FakeRelativeClock`] that is never advanced, this loop may not terminate.
    pub async fn until_key_ready(&self, key: &K) {
        while let Err(not_until) = self.check_key(key) {
            let pause = not_until.wait_time_from(self.clock.now());
            tokio::time::sleep(pause).await;
        }
    }

    /// Waits until every key in `keys` is ready, driving the individual waits concurrently.
    ///
    /// Returns immediately for `None` or an empty slice. Every listed key goes through
    /// [`Self::until_key_ready`], so one request is recorded per entry; a duplicated key is
    /// recorded once per occurrence.
    pub async fn await_keys_ready(&self, keys: Option<&[K]>) {
        let keys = match keys {
            Some(keys) => keys,
            None => return,
        };
        match keys {
            [] => {}
            [only] => self.until_key_ready(only).await,
            [first, second] => {
                tokio::join!(self.until_key_ready(first), self.until_key_ready(second));
            }
            many => {
                let waits = many.iter().map(|key| self.until_key_ready(key));
                futures::stream::iter(waits)
                    .for_each_concurrent(None, |pending| async move {
                        pending.await;
                    })
                    .await;
            }
        }
    }
}
// Unit tests for `RateLimiter`, the `Quota` constructors and `StateSnapshot`. The limiter
// tests use a `FakeRelativeClock`, so time only moves when a test advances it.
#[cfg(test)]
mod tests {
use std::{
num::NonZeroU32,
sync::atomic::{AtomicU32, Ordering},
time::Duration,
};
use dashmap::DashMap;
use rstest::rstest;
use super::{
DashMapStateStore, RateLimiter,
clock::{Clock, FakeRelativeClock},
gcra::{Gcra, StateSnapshot},
nanos::Nanos,
quota::Quota,
};
// Builds a limiter on a fresh `FakeRelativeClock` with a default quota of 2 requests per
// second and no per-key quotas.
fn initialize_mock_rate_limiter() -> RateLimiter<String, FakeRelativeClock> {
let clock = FakeRelativeClock::default();
let start = clock.now();
let gcra = DashMap::new();
let base_quota = Quota::per_second(NonZeroU32::new(2).unwrap()).unwrap();
RateLimiter {
default_gcra: Some(Gcra::new(base_quota)),
state: DashMapStateStore::new(),
gcra,
clock,
start,
}
}
// Two requests fit the default quota; the third is rejected until the clock moves by 1 s.
#[rstest]
fn test_default_quota() {
let mock_limiter = initialize_mock_rate_limiter();
assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
assert!(mock_limiter.check_key(&"default".to_string()).is_err());
mock_limiter.advance_clock(Duration::from_secs(1));
assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
}
// A dedicated quota (1/s) applies only to its own key; other keys keep the default (2/s).
#[rstest]
fn test_custom_key_quota() {
let mock_limiter = initialize_mock_rate_limiter();
mock_limiter.add_quota_for_key(
"custom".to_string(),
Quota::per_second(NonZeroU32::new(1).unwrap()).unwrap(),
);
assert!(mock_limiter.check_key(&"custom".to_string()).is_ok());
assert!(mock_limiter.check_key(&"custom".to_string()).is_err());
assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
assert!(mock_limiter.check_key(&"default".to_string()).is_err());
}
// Keys with their own quotas are limited independently: key1 admits 1, key2 admits 3.
#[rstest]
fn test_multiple_keys() {
let mock_limiter = initialize_mock_rate_limiter();
mock_limiter.add_quota_for_key(
"key1".to_string(),
Quota::per_second(NonZeroU32::new(1).unwrap()).unwrap(),
);
mock_limiter.add_quota_for_key(
"key2".to_string(),
Quota::per_second(NonZeroU32::new(3).unwrap()).unwrap(),
);
assert!(mock_limiter.check_key(&"key1".to_string()).is_ok());
assert!(mock_limiter.check_key(&"key1".to_string()).is_err());
assert!(mock_limiter.check_key(&"key2".to_string()).is_ok());
assert!(mock_limiter.check_key(&"key2".to_string()).is_ok());
assert!(mock_limiter.check_key(&"key2".to_string()).is_ok());
assert!(mock_limiter.check_key(&"key2".to_string()).is_err());
}
// After the quota is exhausted, 499 ms is not enough to admit another request, but 1 s in
// total is.
#[rstest]
fn test_quota_reset() {
let mock_limiter = initialize_mock_rate_limiter();
assert!(mock_limiter.check_key(&"reset".to_string()).is_ok());
assert!(mock_limiter.check_key(&"reset".to_string()).is_ok());
assert!(mock_limiter.check_key(&"reset".to_string()).is_err());
mock_limiter.advance_clock(Duration::from_millis(499));
assert!(mock_limiter.check_key(&"reset".to_string()).is_err());
mock_limiter.advance_clock(Duration::from_millis(501));
assert!(mock_limiter.check_key(&"reset".to_string()).is_ok());
}
// Advancing 1 s replenishes the per-second key but not the per-minute key.
#[rstest]
fn test_different_quotas() {
let mock_limiter = initialize_mock_rate_limiter();
mock_limiter.add_quota_for_key(
"per_second".to_string(),
Quota::per_second(NonZeroU32::new(2).unwrap()).unwrap(),
);
mock_limiter.add_quota_for_key(
"per_minute".to_string(),
Quota::per_minute(NonZeroU32::new(3).unwrap()),
);
assert!(mock_limiter.check_key(&"per_second".to_string()).is_ok());
assert!(mock_limiter.check_key(&"per_second".to_string()).is_ok());
assert!(mock_limiter.check_key(&"per_second".to_string()).is_err());
assert!(mock_limiter.check_key(&"per_minute".to_string()).is_ok());
assert!(mock_limiter.check_key(&"per_minute".to_string()).is_ok());
assert!(mock_limiter.check_key(&"per_minute".to_string()).is_ok());
assert!(mock_limiter.check_key(&"per_minute".to_string()).is_err());
mock_limiter.advance_clock(Duration::from_secs(1));
assert!(mock_limiter.check_key(&"per_second".to_string()).is_ok());
assert!(mock_limiter.check_key(&"per_minute".to_string()).is_err());
}
// Once the fake clock has moved past the wait, `await_keys_ready` returns (consuming one
// request), and a further check on the same key still succeeds.
#[tokio::test]
async fn test_await_keys_ready() {
let mock_limiter = initialize_mock_rate_limiter();
assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
assert!(mock_limiter.check_key(&"default".to_string()).is_err());
mock_limiter.advance_clock(Duration::from_secs(1));
let keys = ["default".to_string()];
mock_limiter.await_keys_ready(Some(keys.as_slice())).await;
assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
}
// `StateSnapshot::new(t, tau, time_of_measurement, tat)` with `t == 0` reports zero
// remaining capacity instead of panicking.
#[rstest]
fn test_remaining_burst_capacity_zero_t() {
let snapshot = StateSnapshot::new(
Nanos::from(0u64),
Nanos::from(1_000_000u64),
Nanos::from(0u64),
Nanos::from(0u64),
);
assert_eq!(snapshot.remaining_burst_capacity(), 0);
}
// `u32::MAX` requests per second would need a sub-nanosecond interval, so `per_second`
// returns `None`.
#[rstest]
fn test_per_second_returns_none_on_zero_replenish_interval() {
assert!(Quota::per_second(NonZeroU32::new(u32::MAX).unwrap()).is_none());
}
// Longer periods keep the replenish interval above zero even for a burst of `u32::MAX`.
#[rstest]
fn test_per_minute_accepts_max_burst() {
let quota = Quota::per_minute(NonZeroU32::new(u32::MAX).unwrap());
assert!(quota.replenish_interval().as_nanos() > 0);
}
#[rstest]
fn test_per_hour_accepts_max_burst() {
let quota = Quota::per_hour(NonZeroU32::new(u32::MAX).unwrap());
assert!(quota.replenish_interval().as_nanos() > 0);
}
// Property test: `remaining_burst_capacity` must not panic for any combination of inputs
// in the range 0 ns to one hour.
mod property_tests {
use proptest::prelude::*;
use rstest::rstest;
use crate::ratelimiter::{gcra::StateSnapshot, nanos::Nanos};
// One hour, in nanoseconds.
const MAX_NANOS: u64 = 3_600_000_000_000;
proptest! {
#![proptest_config(ProptestConfig {
failure_persistence: Some(Box::new(
proptest::test_runner::FileFailurePersistence::WithSource("ratelimiter")
)),
..ProptestConfig::default()
})]
#[rstest]
fn remaining_burst_capacity_never_panics(
t in 0u64..=MAX_NANOS,
tau in 0u64..=MAX_NANOS,
time_of_measurement in 0u64..=MAX_NANOS,
tat in 0u64..=MAX_NANOS,
) {
let snapshot = StateSnapshot::new(
Nanos::from(t),
Nanos::from(tau),
Nanos::from(time_of_measurement),
Nanos::from(tat),
);
// Only the absence of a panic is checked; the value is discarded.
let _ = snapshot.remaining_burst_capacity();
}
}
}
// Advancing by exactly one replenish interval frees exactly one request.
#[rstest]
fn test_gcra_boundary_exact_replenishment() {
let mock_limiter = initialize_mock_rate_limiter();
let key = "boundary_test".to_string();
assert!(mock_limiter.check_key(&key).is_ok());
assert!(mock_limiter.check_key(&key).is_ok());
assert!(mock_limiter.check_key(&key).is_err());
// Same quota as the default in `initialize_mock_rate_limiter`.
let quota = Quota::per_second(NonZeroU32::new(2).unwrap()).unwrap();
let replenish_interval = quota.replenish_interval();
mock_limiter.advance_clock(replenish_interval);
assert!(
mock_limiter.check_key(&key).is_ok(),
"Request at exact replenish boundary should be allowed"
);
assert!(
mock_limiter.check_key(&key).is_err(),
"Immediate follow-up should be rate-limited"
);
}
// One billion requests per second is the finest representable rate: a 1 ns interval.
#[rstest]
fn test_per_second_boundary_exact_limit() {
let quota = Quota::per_second(NonZeroU32::new(1_000_000_000).unwrap()).unwrap();
assert_eq!(quota.replenish_interval().as_nanos(), 1);
}
// Anything above one billion per second would need a sub-nanosecond interval.
#[rstest]
fn test_per_second_returns_none_above_one_billion() {
assert!(Quota::per_second(NonZeroU32::new(1_000_000_001).unwrap()).is_none());
}
// 100 s x u32::MAX nanoseconds does not fit in a u64. This pins the current behaviour of
// `burst_size_replenished_in`, which truncates to the low 64 bits (an `as` cast).
// NOTE(review): confirm wrap-around truncation is intended rather than saturation.
#[rstest]
fn test_burst_size_replenished_in_truncation() {
let quota = Quota::with_period(Duration::from_secs(100))
.unwrap()
.allow_burst(NonZeroU32::new(u32::MAX).unwrap());
let replenished_in = quota.burst_size_replenished_in();
let full: u128 = 100_000_000_000u128 * u32::MAX as u128;
let truncated = full as u64;
assert_eq!(replenished_in, Duration::from_nanos(truncated));
assert_ne!(full, truncated as u128, "Truncation should have occurred");
}
// `Quota::from_gcra_parameters(t, tau)` panics on invalid parameters: a zero `t`, a `tau/t`
// of zero, or a `tau/t` above `u32::MAX`.
#[rstest]
#[should_panic(expected = "t cannot be zero")]
fn test_from_gcra_parameters_panics_on_zero_t() {
let _ = Quota::from_gcra_parameters(Nanos::from(0u64), Nanos::from(100u64));
}
#[rstest]
#[should_panic(expected = "tau/t results in zero burst capacity")]
fn test_from_gcra_parameters_panics_on_zero_division() {
let _ = Quota::from_gcra_parameters(Nanos::from(2u64), Nanos::from(1u64));
}
#[rstest]
#[should_panic(expected = "tau/t exceeds u32::MAX")]
fn test_from_gcra_parameters_panics_on_overflow() {
let _ = Quota::from_gcra_parameters(Nanos::from(1u64), Nanos::from(u64::MAX));
}
// 50 threads hit the same key at a single fake instant. The atomic state update must not
// admit more than the burst capacity (`rate`) of requests.
#[rstest]
fn test_concurrent_check_key_respects_burst() {
let rate = 10u32;
let clock = FakeRelativeClock::default();
let start = clock.now();
let limiter = RateLimiter {
default_gcra: Some(Gcra::new(
Quota::per_second(NonZeroU32::new(rate).unwrap()).unwrap(),
)),
state: DashMapStateStore::new(),
gcra: DashMap::new(),
clock,
start,
};
let accepted = AtomicU32::new(0);
let num_threads = 50;
std::thread::scope(|s| {
for _ in 0..num_threads {
s.spawn(|| {
if limiter.check_key(&"hot_key".to_string()).is_ok() {
accepted.fetch_add(1, Ordering::Relaxed);
}
});
}
});
// `thread::scope` joins every thread before returning, so the count is final here.
let total = accepted.load(Ordering::Relaxed);
assert!(total >= 1, "At least one request should be accepted");
assert!(
total <= rate,
"Accepted {total} but burst capacity is {rate}"
);
}
}