pub mod clock;
mod gcra;
mod nanos;
pub mod quota;
use std::{
fmt::Debug,
hash::Hash,
num::NonZeroU64,
sync::atomic::{AtomicU64, Ordering},
time::Duration,
};
use dashmap::DashMap;
use futures_util::StreamExt;
use self::{
clock::{Clock, FakeRelativeClock, MonotonicClock},
gcra::{Gcra, NotUntil},
nanos::Nanos,
quota::Quota,
};
/// Concurrent map from a key of type `K` to that key's [`InMemoryState`].
pub type DashMapStateStore<K> = DashMap<K, InMemoryState>;

/// Storage backend for per-key rate-limiter state.
pub trait StateStore {
    /// The type used to look up a particular piece of state.
    type Key;

    /// Reads the state stored for `key` and replaces it with the value produced by `f`.
    ///
    /// `f` receives the current state (`None` when nothing has been recorded) and returns
    /// either the caller-visible result paired with the replacement state, or an error that
    /// is handed back to the caller.
    fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
    where
        F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>;
}
/// Per-key limiter state held in a single atomic `u64`.
///
/// The raw value `0` means "nothing recorded yet" and is surfaced as `None`; any other raw
/// value is converted into [`Nanos`].
#[derive(Debug, Default)]
pub struct InMemoryState(AtomicU64);

impl InMemoryState {
    /// Runs `f` against the current state and atomically installs the state it returns.
    ///
    /// `f` is given the current value and yields either `Ok((result, replacement))` or an
    /// error. On success, `replacement` is published with a weak compare-and-swap. If the
    /// stored value changed in the meantime, or the weak CAS fails spuriously, `f` is re-run
    /// on the freshly observed value. `f` may therefore be called more than once.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `f`. The attempt that produced it does not modify
    /// the stored value.
    pub(crate) fn measure_and_replace_one<T, F, E>(&self, mut f: F) -> Result<T, E>
    where
        F: FnMut(Option<Nanos>) -> Result<(T, Nanos), E>,
    {
        // Zero doubles as the "empty" marker, so only non-zero raw values become `Some`.
        let decode =
            |raw: u64| -> Option<Nanos> { NonZeroU64::new(raw).map(|nz| nz.get().into()) };

        let mut observed = self.0.load(Ordering::Acquire);
        loop {
            let (result, replacement) = f(decode(observed))?;
            match self.0.compare_exchange_weak(
                observed,
                replacement.into(),
                Ordering::Release,
                Ordering::Relaxed,
            ) {
                Ok(_) => return Ok(result),
                // Another writer got there first (or the weak CAS failed spuriously):
                // retry against the value that is actually stored now.
                Err(current) => observed = current,
            }
        }
    }
}
impl<K: Hash + Eq + Clone> StateStore for DashMapStateStore<K> {
    type Key = K;

    /// Updates the state for `key`, creating a default (empty) entry on first use.
    ///
    /// Keys that already exist are served through a shared `get` lookup. Only a missing key
    /// takes the `entry` path, which clones `key` so it can be inserted.
    fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
    where
        F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>,
    {
        match self.get(key) {
            Some(existing) => existing.measure_and_replace_one(f),
            None => {
                let created = self.entry(key.clone()).or_default();
                created.measure_and_replace_one(f)
            }
        }
    }
}
/// Keyed rate limiter with optional per-key quotas and an optional fallback quota.
///
/// Generic over the key type `K` and the time source `C`.
pub struct RateLimiter<K, C: Clock> {
    /// Quota for keys with no entry in `gcra`; `None` leaves such keys unlimited.
    default_gcra: Option<Gcra>,
    /// Per-key limiter state shared by every quota.
    state: DashMapStateStore<K>,
    /// Per-key quotas that take precedence over `default_gcra`.
    gcra: DashMap<K, Gcra>,
    /// Time source used for every check.
    clock: C,
    /// Reference instant passed to every GCRA check.
    start: C::Instant,
}
impl<K: Eq + Hash> RateLimiter<K, MonotonicClock> {
    /// Creates a rate limiter driven by [`MonotonicClock`].
    ///
    /// `base_quota` is applied to keys without a quota of their own; `None` leaves those
    /// keys unlimited. Each `(key, quota)` pair in `keyed_quotas` registers a dedicated
    /// quota for that key. The clock is read once here to fix the limiter's start instant.
    #[must_use]
    pub fn new_with_quota(base_quota: Option<Quota>, keyed_quotas: Vec<(K, Quota)>) -> Self {
        let clock = MonotonicClock {};
        let start = clock.now();
        let per_key = keyed_quotas
            .into_iter()
            .map(|(key, quota)| (key, Gcra::new(quota)))
            .collect::<DashMap<_, _>>();

        Self {
            default_gcra: base_quota.map(Gcra::new),
            state: DashMapStateStore::default(),
            gcra: per_key,
            clock,
            start,
        }
    }
}
impl<K: Hash + Eq + Clone> RateLimiter<K, FakeRelativeClock> {
    /// Moves the limiter's fake clock forward by `by`.
    ///
    /// Available only on limiters driven by [`FakeRelativeClock`].
    pub fn advance_clock(&self, by: Duration) {
        let fake_clock = &self.clock;
        fake_clock.advance(by);
    }
}
impl<K, C> RateLimiter<K, C>
where
    K: Hash + Eq + Clone,
    C: Clock,
{
    /// Registers `value` as the quota for `key`, replacing any quota previously set for it.
    ///
    /// A key with its own quota is no longer checked against the default quota.
    pub fn add_quota_for_key(&self, key: K, value: Quota) {
        let limiter = Gcra::new(value);
        self.gcra.insert(key, limiter);
    }

    /// Checks a single request for `key` against the applicable quota.
    ///
    /// Uses the key's own quota if one is registered, otherwise the default quota. When
    /// neither exists, the request is always allowed.
    ///
    /// # Errors
    ///
    /// Returns [`NotUntil`] when the quota's `test_and_update` rejects the request.
    pub fn check_key(&self, key: &K) -> Result<(), NotUntil<C::Instant>> {
        if let Some(keyed) = self.gcra.get(key) {
            return keyed.test_and_update(self.start, key, &self.state, self.clock.now());
        }
        match &self.default_gcra {
            Some(fallback) => {
                fallback.test_and_update(self.start, key, &self.state, self.clock.now())
            }
            None => Ok(()),
        }
    }

    /// Waits until a single request for `key` is admitted.
    ///
    /// After each rejection it sleeps with `tokio::time::sleep` for the delay reported by the
    /// rejection, measured from this limiter's clock, then checks again. Returns immediately
    /// when `key` has no applicable quota.
    pub async fn until_key_ready(&self, key: &K) {
        while let Err(not_until) = self.check_key(key) {
            let delay = not_until.wait_time_from(self.clock.now());
            tokio::time::sleep(delay).await;
        }
    }

    /// Waits, concurrently across keys, until every key in `keys` has been admitted once.
    ///
    /// `None` is treated as an empty list and completes immediately.
    pub async fn await_keys_ready(&self, keys: Option<Vec<K>>) {
        let keys = keys.unwrap_or_default();
        let waiters = keys.iter().map(|key| self.until_key_ready(key));
        futures_util::stream::iter(waiters)
            .for_each_concurrent(None, |waiter| waiter)
            .await;
    }

    /// Checks a request of weight `n` for `key` against the applicable quota.
    ///
    /// Quota selection is the same as for [`Self::check_key`]. When no quota applies, the
    /// request is always allowed.
    ///
    /// # Errors
    ///
    /// Returns [`NotUntil`] when the quota's `test_and_update_n` rejects the request.
    pub fn check_key_n(&self, key: &K, n: u32) -> Result<(), NotUntil<C::Instant>> {
        if let Some(keyed) = self.gcra.get(key) {
            return keyed.test_and_update_n(self.start, key, &self.state, self.clock.now(), n);
        }
        match &self.default_gcra {
            Some(fallback) => {
                fallback.test_and_update_n(self.start, key, &self.state, self.clock.now(), n)
            }
            None => Ok(()),
        }
    }

    /// Waits until a request of weight `n` for `key` is admitted.
    ///
    /// The retry strategy is the same as [`Self::until_key_ready`], but uses
    /// [`Self::check_key_n`].
    pub async fn until_key_ready_n(&self, key: &K, n: u32) {
        while let Err(not_until) = self.check_key_n(key, n) {
            let delay = not_until.wait_time_from(self.clock.now());
            tokio::time::sleep(delay).await;
        }
    }

    /// Waits, concurrently across entries, until each `(key, n)` request in `keys` is admitted.
    pub async fn await_keys_ready_n(&self, keys: Vec<(K, u32)>) {
        let waiters = keys
            .iter()
            .map(|(key, n)| self.until_key_ready_n(key, *n));
        futures_util::stream::iter(waiters)
            .for_each_concurrent(None, |waiter| waiter)
            .await;
    }
}
/// Opaque `Debug` output: prints only the type name.
///
/// The key type is never formatted, so no `K: Debug` bound is required. Any limiter,
/// including one built by `new_with_quota` (which needs only `K: Eq + Hash`), can be
/// `{:?}`-printed or embedded in a `#[derive(Debug)]` type.
impl<K, C> Debug for RateLimiter<K, C>
where
    C: Clock,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Internal state (quotas, per-key state, clock) is not rendered.
        f.debug_struct(stringify!(RateLimiter)).finish()
    }
}
#[cfg(test)]
mod tests {
    use std::{num::NonZeroU32, time::Duration};

    use dashmap::DashMap;
    use rstest::rstest;

    use super::{
        DashMapStateStore, RateLimiter,
        clock::{Clock, FakeRelativeClock},
        gcra::Gcra,
        quota::Quota,
    };

    /// Builds a limiter on a fake clock with a default quota of 2 requests per second and
    /// no per-key quotas.
    fn initialize_mock_rate_limiter() -> RateLimiter<String, FakeRelativeClock> {
        let clock = FakeRelativeClock::default();
        let start = clock.now();
        let fallback = Quota::per_second(NonZeroU32::new(2).unwrap());
        RateLimiter {
            default_gcra: Some(Gcra::new(fallback)),
            state: DashMapStateStore::default(),
            gcra: DashMap::default(),
            clock,
            start,
        }
    }

    /// Returns `true` when `limiter` admits one request for `key` right now.
    fn admits(limiter: &RateLimiter<String, FakeRelativeClock>, key: &str) -> bool {
        limiter.check_key(&key.to_string()).is_ok()
    }

    /// Shorthand for a per-second quota of `count` requests.
    fn quota_per_second(count: u32) -> Quota {
        Quota::per_second(NonZeroU32::new(count).unwrap())
    }

    #[rstest]
    fn test_default_quota() {
        let limiter = initialize_mock_rate_limiter();
        assert!(admits(&limiter, "default"));
        assert!(admits(&limiter, "default"));
        assert!(!admits(&limiter, "default"));
        limiter.advance_clock(Duration::from_secs(1));
        assert!(admits(&limiter, "default"));
    }

    #[rstest]
    fn test_custom_key_quota() {
        let limiter = initialize_mock_rate_limiter();
        limiter.add_quota_for_key("custom".to_string(), quota_per_second(1));
        assert!(admits(&limiter, "custom"));
        assert!(!admits(&limiter, "custom"));
        assert!(admits(&limiter, "default"));
        assert!(admits(&limiter, "default"));
        assert!(!admits(&limiter, "default"));
    }

    #[rstest]
    fn test_multiple_keys() {
        let limiter = initialize_mock_rate_limiter();
        limiter.add_quota_for_key("key1".to_string(), quota_per_second(1));
        limiter.add_quota_for_key("key2".to_string(), quota_per_second(3));
        assert!(admits(&limiter, "key1"));
        assert!(!admits(&limiter, "key1"));
        for _ in 0..3 {
            assert!(admits(&limiter, "key2"));
        }
        assert!(!admits(&limiter, "key2"));
    }

    #[rstest]
    fn test_quota_reset() {
        let limiter = initialize_mock_rate_limiter();
        assert!(admits(&limiter, "reset"));
        assert!(admits(&limiter, "reset"));
        assert!(!admits(&limiter, "reset"));
        limiter.advance_clock(Duration::from_millis(499));
        assert!(!admits(&limiter, "reset"));
        limiter.advance_clock(Duration::from_millis(501));
        assert!(admits(&limiter, "reset"));
    }

    #[rstest]
    fn test_different_quotas() {
        let limiter = initialize_mock_rate_limiter();
        limiter.add_quota_for_key("per_second".to_string(), quota_per_second(2));
        limiter.add_quota_for_key(
            "per_minute".to_string(),
            Quota::per_minute(NonZeroU32::new(3).unwrap()),
        );
        for _ in 0..2 {
            assert!(admits(&limiter, "per_second"));
        }
        assert!(!admits(&limiter, "per_second"));
        for _ in 0..3 {
            assert!(admits(&limiter, "per_minute"));
        }
        assert!(!admits(&limiter, "per_minute"));
        limiter.advance_clock(Duration::from_secs(1));
        assert!(admits(&limiter, "per_second"));
        assert!(!admits(&limiter, "per_minute"));
    }

    #[tokio::test]
    async fn test_await_keys_ready() {
        let limiter = initialize_mock_rate_limiter();
        assert!(admits(&limiter, "default"));
        assert!(admits(&limiter, "default"));
        assert!(!admits(&limiter, "default"));
        limiter.advance_clock(Duration::from_secs(1));
        limiter
            .await_keys_ready(Some(vec!["default".to_string()]))
            .await;
        assert!(admits(&limiter, "default"));
    }

    #[rstest]
    fn test_gcra_boundary_exact_replenishment() {
        let limiter = initialize_mock_rate_limiter();
        let key = "boundary_test".to_string();
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_err());

        // Same 2/s quota the mock limiter uses by default.
        let replenish_interval = quota_per_second(2).replenish_interval();
        limiter.advance_clock(replenish_interval);
        assert!(
            limiter.check_key(&key).is_ok(),
            "Request at exact replenish boundary should be allowed"
        );
        assert!(
            limiter.check_key(&key).is_err(),
            "Immediate follow-up should be rate-limited"
        );
    }
}