use std::{
collections::VecDeque,
sync::atomic::{AtomicU64, Ordering},
time::{Duration, Instant},
};
use ahash::RandomState;
use dashmap::DashMap;
use crate::{
LocalRateLimiterOptions,
common::{InstantRate, RateGroupSizeMs, RateLimit, RateLimitDecision, WindowSizeSeconds},
};
/// Per-key state kept by the limiter: the limit the key was registered with,
/// its buckets of recorded hits, and the running sum of those buckets.
#[derive(Debug)]
pub(crate) struct RateLimitSeries {
    /// Limit supplied when this series was first created.
    pub limit: RateLimit,
    /// Buckets of recorded hits; new buckets are appended at the back.
    pub series: VecDeque<InstantRate>,
    /// Running sum of the bucket counts, maintained alongside `series`.
    pub total: AtomicU64,
}
impl RateLimitSeries {
    /// Builds an empty series (no buckets, zero total) for `limit`.
    pub fn new(limit: RateLimit) -> Self {
        let series = VecDeque::default();
        let total = AtomicU64::default();
        Self {
            limit,
            series,
            total,
        }
    }
}
/// In-memory rate limiter that tracks hits per string key over a window of
/// `window_size_seconds`.
///
/// Hits are grouped into buckets: a hit that arrives no later than
/// `rate_group_size_ms` after the newest bucket was opened is added to that
/// bucket instead of opening a new one.
#[derive(Debug)]
pub struct AbsoluteLocalRateLimiter {
    /// Window length, as configured.
    window_size_seconds: WindowSizeSeconds,
    /// Window length in milliseconds (`window_size_seconds * 1000`, saturating).
    window_size_ms: u128,
    /// Window length as a `Duration`, used when evicting expired buckets.
    window_duration: Duration,
    /// Maximum age (ms) of the newest bucket for a new hit to be merged into it.
    rate_group_size_ms: RateGroupSizeMs,
    /// Per-key hit series.
    series: DashMap<String, RateLimitSeries, RandomState>,
}
impl AbsoluteLocalRateLimiter {
pub(crate) fn new(options: LocalRateLimiterOptions) -> Self {
Self {
window_size_ms: (*options.window_size_seconds as u128).saturating_mul(1000),
window_duration: Duration::from_secs(*options.window_size_seconds),
window_size_seconds: options.window_size_seconds,
rate_group_size_ms: options.rate_group_size_ms,
series: DashMap::default(),
}
}
#[cfg(test)]
pub(crate) fn series(&self) -> &DashMap<String, RateLimitSeries, RandomState> {
&self.series
}
pub fn inc(&self, key: &str, rate_limit: &RateLimit, count: u64) -> RateLimitDecision {
let is_allowed = self.is_allowed(key);
if !matches!(is_allowed, RateLimitDecision::Allowed) {
return is_allowed;
}
let rate_limit_series = match self.series.get(key) {
Some(rate_limit_series) => rate_limit_series,
None => {
self.series
.entry(key.to_string())
.or_insert_with(|| RateLimitSeries::new(*rate_limit));
let Some(rate_limit_series) = self.series.get(key) else {
unreachable!("AbsoluteLocalRateLimiter::inc: key should be in map");
};
rate_limit_series
}
};
if let Some(last_entry) = rate_limit_series.series.back()
&& last_entry.timestamp.elapsed().as_millis() <= *self.rate_group_size_ms as u128
{
last_entry.count.fetch_add(count, Ordering::Relaxed);
rate_limit_series.total.fetch_add(count, Ordering::Relaxed);
} else {
drop(rate_limit_series);
let Some(mut rate_limit_series) = self.series.get_mut(key) else {
unreachable!("AbsoluteLocalRateLimiter::inc: key should be in map");
};
rate_limit_series.series.push_back(InstantRate {
count: count.into(),
timestamp: Instant::now(),
declined: AtomicU64::new(0),
});
rate_limit_series.total.fetch_add(count, Ordering::Relaxed);
}
RateLimitDecision::Allowed
}
pub fn is_allowed(&self, key: &str) -> RateLimitDecision {
let Some(rate_limit) = self.series.get(key) else {
return RateLimitDecision::Allowed;
};
let mut total_count = rate_limit.total.load(Ordering::Relaxed);
let window_limit = (*self.window_size_seconds as f64 * *rate_limit.limit) as u64;
if total_count < window_limit {
return RateLimitDecision::Allowed;
}
let rate_limit = match rate_limit.series.front() {
None => rate_limit,
Some(instant_rate)
if instant_rate.timestamp.elapsed().as_millis() <= self.window_size_ms =>
{
rate_limit
}
Some(_) => {
drop(rate_limit);
let Some(mut rate_limit) = self.series.get_mut(key) else {
unreachable!("AbsoluteLocalRateLimiter::is_allowed: key should be in map");
};
let now = Instant::now();
let split = rate_limit
.series
.partition_point(|r| now.duration_since(r.timestamp) > self.window_duration);
let total = rate_limit
.series
.drain(..split)
.map(|r| r.count.load(Ordering::Relaxed))
.sum::<u64>();
rate_limit.total.fetch_sub(total, Ordering::Relaxed);
total_count -= total;
drop(rate_limit);
let Some(rate_limit) = self.series.get(key) else {
unreachable!("AbsoluteLocalRateLimiter::is_allowed: key should be in map");
};
rate_limit
}
};
if total_count < window_limit {
return RateLimitDecision::Allowed;
}
let (retry_after_ms, remaining_after_waiting) = match rate_limit.series.front() {
None => (0, 0),
Some(instant_rate) => {
let elapsed_ms = instant_rate.timestamp.elapsed().as_millis();
let retry_after_ms = self.window_size_ms.saturating_sub(elapsed_ms);
let current_total = rate_limit.total.load(Ordering::Relaxed);
let oldest_count = instant_rate.count.load(Ordering::Relaxed);
let remaining_after_waiting = current_total.saturating_sub(oldest_count);
(retry_after_ms, remaining_after_waiting)
}
};
RateLimitDecision::Rejected {
window_size_seconds: *self.window_size_seconds,
retry_after_ms,
remaining_after_waiting,
}
}
pub(crate) fn cleanup(&self, stale_after_ms: u64) {
self.series.retain(
|_, rate_limit_series| match rate_limit_series.series.back() {
None => false,
Some(instant_rate)
if instant_rate.timestamp.elapsed().as_millis() > stale_after_ms as u128 =>
{
false
}
Some(_) => true,
},
);
} }