use dashmap::DashMap;
use rustc_hash::FxHasher;
use std::{
fmt::Display,
hash::{BuildHasher, BuildHasherDefault, Hash},
time::Instant,
};
use crate::{
clock::{Clock, InstantClock},
rate_limiter::entry::RateLimitEntry,
GcraError, RateLimit,
};
/// `BuildHasher` that creates `FxHasher` instances via `BuildHasherDefault`.
///
/// Used as the default hasher parameter of [`RateLimiter`].
pub type FxBuildHasher = BuildHasherDefault<FxHasher>;
/// Map key wrapper identifying the subject of a rate-limit check.
///
/// Holds the caller-supplied key of type `T`; equality, hashing, cloning and
/// debug formatting are all derived from that single field.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RateLimitRequest<T>
where
    T: Eq + Hash,
{
    /// The caller-supplied key this request is tracked under.
    key: T,
}
/// Renders the request as `RateLimitRequest=<key>`, using the key's own
/// `Display` output.
impl<T: Display + Eq + Hash> Display for RateLimitRequest<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same call `write!` expands to: the key is formatted with default
        // options, independent of any width/fill set on `f`.
        f.write_fmt(format_args!("RateLimitRequest={}", self.key))
    }
}
/// Keyed rate limiter storing one `RateLimitEntry` per key in a `DashMap`.
///
/// Type parameters:
/// * `T` - the caller's key type.
/// * `C` - the time source consulted by `check` and `prune_expired`
///   (defaults to `InstantClock`).
/// * `S` - the hasher builder used by the map (defaults to `FxBuildHasher`).
#[derive(Clone)]
pub struct RateLimiter<T, C = InstantClock, S = FxBuildHasher>
where
    T: Eq + Hash,
{
    /// Time source used whenever the caller does not supply an instant.
    clock: C,
    /// Per-key rate-limit state.
    map: DashMap<RateLimitRequest<T>, RateLimitEntry, S>,
}
/// Constructors for the default configuration: `InstantClock` as the time
/// source and `FxBuildHasher` as the map hasher.
impl<Key> RateLimiter<Key, InstantClock, FxBuildHasher>
where
    Key: Send + Clone + Hash + Eq + Display + 'static,
{
    /// Creates a rate limiter whose map is pre-sized for `max_data_capacity`
    /// entries.
    ///
    /// The value is handed to `DashMap` as a capacity; nothing in this type
    /// enforces it as an upper bound on the number of tracked keys.
    pub fn new(max_data_capacity: usize) -> Self {
        let hasher = FxBuildHasher::default();
        let map = DashMap::with_capacity_and_hasher(max_data_capacity, hasher);
        Self {
            clock: InstantClock,
            map,
        }
    }

    /// Like [`Self::new`], but also forwards an explicit shard count to the
    /// underlying `DashMap`.
    ///
    /// NOTE(review): `DashMap` documents that the shard amount must be a power
    /// of two and panics otherwise - confirm against the dashmap version in
    /// use before passing arbitrary values for `num_shards`.
    pub fn with_shards(max_data_capacity: usize, num_shards: usize) -> Self {
        let hasher = FxBuildHasher::default();
        let map = DashMap::with_capacity_and_hasher_and_shard_amount(
            max_data_capacity,
            hasher,
            num_shards,
        );
        Self {
            clock: InstantClock,
            map,
        }
    }
}
/// Core operations, available for any clock and hasher combination.
impl<Key, C, S> RateLimiter<Key, C, S>
where
    Key: Send + Clone + Hash + Eq + Display + 'static,
    C: Clock,
    S: Default + BuildHasher + Clone,
{
    /// Builds a rate limiter that reads the current time from `clock`.
    ///
    /// The backing map comes from `DashMap::default()`, so no capacity or
    /// shard count is requested explicitly.
    pub fn with_clock(clock: C) -> Self {
        Self {
            clock,
            map: DashMap::default(),
        }
    }

    /// Checks `key` against `rate_limit` at the instant reported by the
    /// limiter's clock.
    ///
    /// Thin wrapper over [`Self::check_at`]; return value and errors are the
    /// same.
    #[inline]
    pub async fn check(
        &self,
        key: Key,
        rate_limit: &RateLimit,
        cost: u32,
    ) -> Result<Instant, GcraError> {
        self.check_at(key, rate_limit, cost, self.clock.now()).await
    }

    /// Checks whether a request for `key` with the given `cost` is allowed at
    /// `arrived_at`, updating the key's stored state.
    ///
    /// The entry for `key` is created on demand and evaluated with
    /// `check_and_modify_at`.
    ///
    /// # Returns
    /// On success, the entry's expiration instant as refreshed by
    /// `update_expiration`.
    ///
    /// # Errors
    /// * `GcraError::DeniedUntil` - returned unchanged; the entry stays in the
    ///   map.
    /// * `GcraError::DeniedIndefinitely` - returned unchanged. The entry is
    ///   removed only if it has never passed a check (its `expires_at` is
    ///   still `None`), so a placeholder created by this very call does not
    ///   linger, while state accumulated by earlier allowed requests is kept.
    ///
    /// # Panics
    /// Panics if `update_expiration` leaves `expires_at` unset after an
    /// allowed check.
    pub async fn check_at(
        &self,
        key: Key,
        rate_limit: &RateLimit,
        cost: u32,
        arrived_at: Instant,
    ) -> Result<Instant, GcraError> {
        let mut entry = self.map.entry(RateLimitRequest { key }).or_default();
        match entry.check_and_modify_at(rate_limit, arrived_at, cost) {
            Ok(_) => {
                entry.update_expiration(rate_limit);
                let expires_at = entry
                    .expires_at
                    .expect("update_expiration must set expires_at after an allowed check");
                Ok(expires_at)
            }
            Err(e @ GcraError::DeniedUntil { .. }) => Err(e),
            Err(e @ GcraError::DeniedIndefinitely { .. }) => {
                // The key is cloned only on this rare path instead of on every call.
                let request_key = entry.key().clone();
                // Release the entry guard before touching the map again.
                drop(entry);
                // Conditional removal: never wipe state built up by earlier
                // allowed requests, and never delete an entry that another task
                // populated after the guard was released.
                // NOTE(review): assumes `RateLimitEntry::default()` leaves
                // `expires_at` as `None` - confirm against its definition.
                self.map
                    .remove_if(&request_key, |_, entry| entry.expires_at.is_none());
                Err(e)
            }
        }
    }

    /// Removes every entry whose expiration is not strictly after the
    /// clock's current time.
    ///
    /// Entries without an expiration (`expires_at == None`) are kept.
    pub fn prune_expired(&self) {
        let now = self.clock.now();
        self.map.retain(|_key, entry| match entry.expires_at {
            Some(expires_at) => expires_at > now,
            None => true,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::clock::tests::FakeClock;
    use core::panic;
    use futures::stream::{self, StreamExt};
    use std::{
        sync::Arc,
        time::{Duration, Instant},
    };

    /// Consumes the whole allowance for one key, then expects a timed denial.
    #[tokio::test]
    async fn rate_limiter_run_until_denied() {
        let limit = RateLimit::new(3, Duration::from_secs(3));
        let limiter = RateLimiter::with_shards(4, 2);

        for _ in 0..limit.resource_limit {
            let outcome = limiter.check("key", &limit, 1).await;
            assert!(outcome.is_ok(), "Shouldn't be rate limited yet");
        }

        let denied = limiter.check("key", &limit, 1).await;
        match denied {
            Err(GcraError::DeniedUntil { next_allowed_at }) => {
                assert!(next_allowed_at > Instant::now())
            }
            Err(_) => panic!("Unexpected error"),
            Ok(_) => panic!("We should be rate limited"),
        }
    }

    /// Same scenario as above, driven through a stream over an `Arc`-shared
    /// limiter.
    #[tokio::test]
    async fn rate_limiter_run_until_denied_concurrent_access() {
        let limit = RateLimit::new(3, Duration::from_secs(3));
        let shared = Arc::new(RateLimiter::with_shards(4, 2));

        let all_checked = stream::iter(0..limit.resource_limit)
            .then(|_| async {
                let handle = Arc::clone(&shared);
                handle.check("key", &limit, 1).await
            })
            .all(|outcome| async move { outcome.is_ok() })
            .await;
        assert!(
            all_checked,
            "All checks should have passed and not rate limited"
        );

        let denied = shared.check("key", &limit, 1).await;
        match denied {
            Err(GcraError::DeniedUntil { next_allowed_at }) => {
                assert!(next_allowed_at > Instant::now())
            }
            Err(_) => panic!("Unexpected error"),
            Ok(_) => panic!("We should be rate limited"),
        }
    }

    /// A cost of 9 against a limit of 3 must be reported as an indefinite
    /// denial that echoes the cost and the limit.
    #[tokio::test]
    async fn rate_limiter_indefinitly_denied() {
        let limit = RateLimit::new(3, Duration::from_secs(3));
        let limiter = RateLimiter::with_shards(4, 2);

        let outcome = limiter.check("key", &limit, 9).await;
        match outcome {
            Err(GcraError::DeniedIndefinitely {
                cost,
                rate_limit: reported_limit,
            }) => {
                assert_eq!(cost, 9);
                assert_eq!(reported_limit, limit);
            }
            Err(_) => panic!("Unexpected error"),
            Ok(_) => panic!("We should be rate limited"),
        }
    }

    /// Walks explicit arrival times to verify that used capacity is released
    /// again as time passes.
    #[tokio::test]
    async fn rate_limiter_leaks() {
        let limit = RateLimit::per_sec(2);
        let limiter = RateLimiter::with_shards(4, 2);
        let start = Instant::now();
        // Arrival instant `millis` milliseconds after the first request.
        let at = |millis: u64| start + Duration::from_millis(millis);

        assert!(limiter.check_at("key", &limit, 1, start).await.is_ok());

        let second = limiter.check_at("key", &limit, 1, at(250)).await;
        assert!(second.is_ok(), "delay the 2nd check");

        let denied_start = limiter.check_at("key", &limit, 1, at(251)).await;
        assert!(denied_start.is_err(), "check we are denied start");

        let denied_end = limiter.check_at("key", &limit, 1, at(499)).await;
        assert!(denied_end.is_err(), "check we are denied end");

        let released = limiter.check_at("key", &limit, 1, at(501)).await;
        assert!(released.is_ok(), "1st use should be released");
    }

    /// Pruning keeps live entries and removes all of them once the fake clock
    /// has moved past their expiration.
    #[tokio::test]
    async fn rate_limiter_prune_expired() {
        let clock = FakeClock::new();
        let limit = RateLimit::per_sec(3);
        let limiter: RateLimiter<_, _, FxBuildHasher> = RateLimiter::with_clock(clock.clone());

        // One distinct key per iteration, so each check creates its own entry.
        for index in 0..limit.resource_limit {
            let outcome = limiter.check(index, &limit, 1).await;
            assert!(outcome.is_ok(), "Shouldn't be rate limited yet");
        }

        let len_before = limiter.map.len();
        limiter.prune_expired();
        let len_after = limiter.map.len();
        assert_eq!(
            len_before, len_after,
            "Nothing has expired, no elements should be removed"
        );

        clock.advance_by(Duration::from_secs(10));
        limiter.prune_expired();
        let len_after = limiter.map.len();
        assert_eq!(
            0, len_after,
            "All entries have expired, no elements expected"
        );
    }
}