#![allow(
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
clippy::cast_precision_loss,
clippy::cast_possible_wrap
)]
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
/// Settings shared by the rate limiters in this file.
///
/// The type derives `Serialize`/`Deserialize`, so it can be stored or sent
/// through any `serde` format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
/// Permitted executions per second. Token buckets refill at this rate;
/// sliding windows allow `rate * window_size` executions per window.
pub rate: f64,
/// Token-bucket capacity. When `None`, `effective_burst` falls back to
/// `rate` rounded up.
pub burst: Option<u32>,
/// When `true`, `create_rate_limiter` builds a sliding-window limiter
/// instead of a token bucket.
pub sliding_window: bool,
/// Length of the sliding window in seconds.
pub window_size: u64,
}
impl RateLimitConfig {
    /// Creates a configuration allowing `rate` executions per second.
    ///
    /// No explicit burst is set, sliding-window mode is off and the window
    /// size is one second.
    #[must_use]
    pub fn new(rate: f64) -> Self {
        let burst = None;
        let sliding_window = false;
        let window_size = 1;
        Self {
            rate,
            burst,
            sliding_window,
            window_size,
        }
    }

    /// Returns the configuration with an explicit burst capacity.
    #[must_use]
    pub fn with_burst(self, burst: u32) -> Self {
        Self {
            burst: Some(burst),
            ..self
        }
    }

    /// Returns the configuration switched to sliding-window mode with a
    /// window of `window_size` seconds.
    #[must_use]
    pub fn with_sliding_window(self, window_size: u64) -> Self {
        Self {
            sliding_window: true,
            window_size,
            ..self
        }
    }

    /// Burst capacity actually in effect: the explicit `burst` if one was
    /// set, otherwise `rate` rounded up to the next whole number.
    #[must_use]
    #[inline]
    pub fn effective_burst(&self) -> u32 {
        match self.burst {
            Some(explicit) => explicit,
            None => self.rate.ceil() as u32,
        }
    }
}
impl Default for RateLimitConfig {
    /// Defaults to 100 executions per second with no explicit burst and
    /// sliding-window mode disabled (window size of one second).
    fn default() -> Self {
        let (rate, window_size) = (100.0, 1);
        Self {
            rate,
            window_size,
            burst: None,
            sliding_window: false,
        }
    }
}
/// Common interface of the in-process rate limiters defined in this file.
pub trait RateLimiter: Send + Sync {
/// Tries to take one permit without waiting; returns `true` on success.
fn try_acquire(&mut self) -> bool;
/// Takes one permit, waiting if necessary, and returns how long the call took.
/// The implementations in this file wait with `std::thread::sleep`.
fn acquire(&mut self) -> Duration;
/// Estimated wait until a permit becomes available; zero when one is available.
fn time_until_available(&self) -> Duration;
/// Number of permits that could be taken right now.
fn available_permits(&self) -> u32;
/// Returns the limiter to its initial, unused state.
fn reset(&mut self);
/// Replaces the configured rate.
fn set_rate(&mut self, rate: f64);
/// Returns the limiter's configuration.
fn config(&self) -> &RateLimitConfig;
}
/// Token-bucket limiter: tokens accumulate at `config.rate` per second up to
/// the effective burst, and each granted permit consumes one token.
#[derive(Debug)]
pub struct TokenBucket {
    /// Rate and burst settings.
    config: RateLimitConfig,
    /// Tokens currently in the bucket; fractional while refilling.
    tokens: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full (at the effective burst).
    #[must_use]
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            tokens: f64::from(config.effective_burst()),
            last_refill: Instant::now(),
            config,
        }
    }

    /// Creates a bucket from a bare rate, using the defaults of
    /// `RateLimitConfig::new` for everything else.
    #[must_use]
    pub fn with_rate(rate: f64) -> Self {
        let config = RateLimitConfig::new(rate);
        Self::new(config)
    }

    /// Credits the tokens earned since the last refill, capped at the
    /// effective burst, and moves the refill timestamp to now.
    #[inline]
    fn refill(&mut self) {
        let now = Instant::now();
        let capacity = f64::from(self.config.effective_burst());
        let earned = now.duration_since(self.last_refill).as_secs_f64() * self.config.rate;
        self.tokens = (self.tokens + earned).min(capacity);
        self.last_refill = now;
    }
}
/// Returns the number of tokens `bucket` would hold if it were refilled now.
///
/// Uses the same formula as `TokenBucket::refill` but leaves the bucket
/// untouched, so the `&self` accessors below report up-to-date values instead
/// of the count recorded at the last mutation.
fn projected_tokens(bucket: &TokenBucket) -> f64 {
    let elapsed = bucket.last_refill.elapsed().as_secs_f64();
    let capacity = f64::from(bucket.config.effective_burst());
    (bucket.tokens + elapsed * bucket.config.rate).min(capacity)
}

impl RateLimiter for TokenBucket {
    /// Refills the bucket and consumes one token if at least one is present.
    fn try_acquire(&mut self) -> bool {
        self.refill();
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Blocks the current thread (via `std::thread::sleep`) until a token can
    /// be consumed and returns the time spent waiting.
    ///
    /// If the bucket can never yield a token (rate of zero or less, or a
    /// burst of zero) the wait is `Duration::MAX`, i.e. the call does not
    /// return.
    fn acquire(&mut self) -> Duration {
        let start = Instant::now();
        while !self.try_acquire() {
            let wait_time = self.time_until_available();
            if wait_time > Duration::ZERO {
                std::thread::sleep(wait_time);
            }
        }
        start.elapsed()
    }

    /// Time until one whole token is available, taking into account the time
    /// elapsed since the last refill.
    ///
    /// Returns `Duration::MAX` when no token will ever become available or
    /// when the wait cannot be represented (zero, negative or non-finite
    /// rate, or a burst below one) instead of panicking.
    fn time_until_available(&self) -> Duration {
        let tokens = projected_tokens(self);
        if tokens >= 1.0 {
            return Duration::ZERO;
        }
        // A bucket whose capacity is below one token can never grant a permit.
        if f64::from(self.config.effective_burst()) < 1.0 {
            return Duration::MAX;
        }
        let seconds = (1.0 - tokens) / self.config.rate;
        // `try_from_secs_f64` rejects negative, NaN, infinite and overflowing
        // values, all of which mean "not in any representable future" here.
        Duration::try_from_secs_f64(seconds).unwrap_or(Duration::MAX)
    }

    /// Whole tokens available right now, including those earned since the
    /// last refill.
    fn available_permits(&self) -> u32 {
        projected_tokens(self).floor() as u32
    }

    /// Fills the bucket to its effective burst and restarts the refill clock.
    fn reset(&mut self) {
        self.tokens = f64::from(self.config.effective_burst());
        self.last_refill = Instant::now();
    }

    /// Changes the refill rate.
    ///
    /// The bucket is refilled first so that time already elapsed is credited
    /// at the old rate rather than the new one.
    fn set_rate(&mut self, rate: f64) {
        self.refill();
        self.config.rate = rate;
    }

    fn config(&self) -> &RateLimitConfig {
        &self.config
    }
}
/// Sliding-window limiter: remembers when each permit was granted and allows
/// at most `rate * window_size` grants inside any window of `window_size`
/// seconds.
#[derive(Debug)]
pub struct SlidingWindow {
    /// Rate and window settings.
    config: RateLimitConfig,
    /// Grant times still considered part of the window, oldest first.
    timestamps: Vec<Instant>,
}

impl SlidingWindow {
    /// Creates an empty window with the given configuration.
    #[must_use]
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            config,
            timestamps: Vec::new(),
        }
    }

    /// Creates a limiter allowing `rate` executions per second over a window
    /// of `window_size` seconds.
    #[must_use]
    pub fn with_rate(rate: f64, window_size: u64) -> Self {
        let config = RateLimitConfig::new(rate).with_sliding_window(window_size);
        Self::new(config)
    }

    /// Drops grant times that are no longer inside the window.
    ///
    /// When the window reaches further back than the clock can represent
    /// (`checked_sub` yields `None`), nothing can have expired yet, so every
    /// timestamp is kept rather than panicking.
    #[inline]
    fn cleanup(&mut self) {
        let window = Duration::from_secs(self.config.window_size);
        if let Some(cutoff) = Instant::now().checked_sub(window) {
            self.timestamps.retain(|&t| t > cutoff);
        }
    }

    /// Maximum number of grants per window: `rate * window_size`, rounded up.
    #[inline]
    fn max_executions(&self) -> usize {
        (self.config.rate * self.config.window_size as f64).ceil() as usize
    }
}
impl RateLimiter for SlidingWindow {
    /// Expires old grants and records a new one if the window has room.
    fn try_acquire(&mut self) -> bool {
        self.cleanup();
        if self.timestamps.len() < self.max_executions() {
            self.timestamps.push(Instant::now());
            true
        } else {
            false
        }
    }

    /// Blocks the current thread (via `std::thread::sleep`) until a grant can
    /// be recorded and returns the time spent waiting.
    ///
    /// If the window can never admit a grant (capacity of zero) the wait is
    /// `Duration::MAX`, i.e. the call does not return.
    fn acquire(&mut self) -> Duration {
        let start = Instant::now();
        while !self.try_acquire() {
            let wait_time = self.time_until_available();
            if wait_time > Duration::ZERO {
                std::thread::sleep(wait_time);
            }
        }
        start.elapsed()
    }

    /// Time until the window has room for another grant.
    ///
    /// Returns `Duration::MAX` when the capacity is zero or when the expiry
    /// instant cannot be represented, instead of reporting zero (which made
    /// `acquire` spin) or panicking on `Instant` overflow.
    fn time_until_available(&self) -> Duration {
        let capacity = self.max_executions();
        if capacity == 0 {
            return Duration::MAX;
        }
        let recorded = self.timestamps.len();
        if recorded < capacity {
            return Duration::ZERO;
        }
        // Timestamps are stored oldest first. Room appears once the entry at
        // `recorded - capacity` has expired; that is the oldest entry unless
        // the capacity was lowered after grants were recorded.
        match self.timestamps.get(recorded - capacity) {
            None => Duration::ZERO,
            Some(&blocking) => {
                let window = Duration::from_secs(self.config.window_size);
                match blocking.checked_add(window) {
                    Some(expires) => expires.saturating_duration_since(Instant::now()),
                    None => Duration::MAX,
                }
            }
        }
    }

    /// Grants still possible in the current window.
    ///
    /// Entries that have already expired are not counted even if `cleanup`
    /// has not run yet, and the result saturates at `u32::MAX` instead of
    /// being truncated.
    fn available_permits(&self) -> u32 {
        let window = Duration::from_secs(self.config.window_size);
        let in_window = match Instant::now().checked_sub(window) {
            Some(cutoff) => self.timestamps.iter().filter(|&&t| t > cutoff).count(),
            None => self.timestamps.len(),
        };
        let free = self.max_executions().saturating_sub(in_window);
        u32::try_from(free).unwrap_or(u32::MAX)
    }

    /// Forgets every recorded grant.
    fn reset(&mut self) {
        self.timestamps.clear();
    }

    /// Changes the rate; the new capacity applies from the next call on.
    fn set_rate(&mut self, rate: f64) {
        self.config.rate = rate;
    }

    fn config(&self) -> &RateLimitConfig {
        &self.config
    }
}
/// Per-task token buckets keyed by task name, with an optional default
/// configuration for tasks that were never configured explicitly.
#[derive(Debug)]
pub struct TaskRateLimiter {
    /// One bucket per task name.
    limiters: HashMap<String, TokenBucket>,
    /// Configuration used to create a bucket for unknown tasks, if any.
    default_config: Option<RateLimitConfig>,
}

impl TaskRateLimiter {
    /// Creates a limiter with no per-task limits and no default.
    #[must_use]
    pub fn new() -> Self {
        Self {
            default_config: None,
            limiters: HashMap::new(),
        }
    }

    /// Creates a limiter that applies `config` to every task without an
    /// explicit limit.
    #[must_use]
    pub fn with_default(config: RateLimitConfig) -> Self {
        Self {
            default_config: Some(config),
            limiters: HashMap::new(),
        }
    }

    /// Installs a fresh, full bucket for `task_name`, replacing any existing
    /// bucket for that task.
    pub fn set_task_rate(&mut self, task_name: impl Into<String>, config: RateLimitConfig) {
        let owned_name = task_name.into();
        let bucket = TokenBucket::new(config);
        self.limiters.insert(owned_name, bucket);
    }

    /// Removes the bucket registered for `task_name`, if there is one.
    pub fn remove_task_rate(&mut self, task_name: &str) {
        drop(self.limiters.remove(task_name));
    }

    /// Tries to take a permit for `task_name`.
    ///
    /// Tasks without a bucket get one created from the default configuration
    /// on first use; when there is no default either, the task is unlimited
    /// and the call returns `true`.
    pub fn try_acquire(&mut self, task_name: &str) -> bool {
        if let Some(existing) = self.limiters.get_mut(task_name) {
            return existing.try_acquire();
        }
        match self.default_config.as_ref() {
            None => true,
            Some(template) => {
                let mut bucket = TokenBucket::new(template.clone());
                let granted = bucket.try_acquire();
                self.limiters.insert(task_name.to_string(), bucket);
                granted
            }
        }
    }

    /// Wait reported by the task's bucket, or zero when the task has no
    /// bucket yet.
    #[must_use]
    pub fn time_until_available(&self, task_name: &str) -> Duration {
        match self.limiters.get(task_name) {
            Some(bucket) => bucket.time_until_available(),
            None => Duration::ZERO,
        }
    }

    /// Whether `task_name` is limited, either by its own bucket or by the
    /// default configuration.
    #[inline]
    #[must_use]
    pub fn has_rate_limit(&self, task_name: &str) -> bool {
        self.default_config.is_some() || self.limiters.contains_key(task_name)
    }

    /// Configuration that applies to `task_name`: its bucket's configuration
    /// if it has one, otherwise the default (if any).
    #[inline]
    pub fn get_rate_limit(&self, task_name: &str) -> Option<&RateLimitConfig> {
        match self.limiters.get(task_name) {
            Some(bucket) => Some(bucket.config()),
            None => self.default_config.as_ref(),
        }
    }

    /// Resets every registered bucket.
    pub fn reset_all(&mut self) {
        self.limiters.values_mut().for_each(|bucket| bucket.reset());
    }
}
impl Default for TaskRateLimiter {
fn default() -> Self {
Self::new()
}
}
/// Cloneable handle to a [`TaskRateLimiter`] kept behind `Arc<RwLock<..>>`;
/// every clone operates on the same underlying limiter.
///
/// If the lock is poisoned, mutating calls become no-ops and queries fall
/// back to permissive answers (see each method).
#[derive(Debug, Clone)]
pub struct WorkerRateLimiter {
    inner: Arc<RwLock<TaskRateLimiter>>,
}

impl WorkerRateLimiter {
    /// Wraps `limiter` in the shared lock.
    fn wrapping(limiter: TaskRateLimiter) -> Self {
        Self {
            inner: Arc::new(RwLock::new(limiter)),
        }
    }

    /// Creates a handle with no per-task limits and no default.
    #[must_use]
    pub fn new() -> Self {
        Self::wrapping(TaskRateLimiter::new())
    }

    /// Creates a handle whose limiter applies `config` to unconfigured tasks.
    #[must_use]
    pub fn with_default(config: RateLimitConfig) -> Self {
        Self::wrapping(TaskRateLimiter::with_default(config))
    }

    /// Sets the limit for `task_name`; does nothing if the lock is poisoned.
    pub fn set_task_rate(&self, task_name: impl Into<String>, config: RateLimitConfig) {
        match self.inner.write() {
            Ok(mut limiter) => limiter.set_task_rate(task_name, config),
            Err(_poisoned) => {}
        }
    }

    /// Removes the limit for `task_name`; does nothing if the lock is poisoned.
    pub fn remove_task_rate(&self, task_name: &str) {
        match self.inner.write() {
            Ok(mut limiter) => limiter.remove_task_rate(task_name),
            Err(_poisoned) => {}
        }
    }

    /// Tries to take a permit for `task_name`; returns `true` if the lock is
    /// poisoned.
    #[must_use]
    pub fn try_acquire(&self, task_name: &str) -> bool {
        self.inner
            .write()
            .map_or(true, |mut limiter| limiter.try_acquire(task_name))
    }

    /// Wait reported for `task_name`; zero if the lock is poisoned.
    #[must_use]
    pub fn time_until_available(&self, task_name: &str) -> Duration {
        match self.inner.read() {
            Ok(limiter) => limiter.time_until_available(task_name),
            Err(_poisoned) => Duration::ZERO,
        }
    }

    /// Whether `task_name` is limited; `false` if the lock is poisoned.
    #[inline]
    #[must_use]
    pub fn has_rate_limit(&self, task_name: &str) -> bool {
        self.inner
            .read()
            .map_or(false, |limiter| limiter.has_rate_limit(task_name))
    }

    /// Resets every bucket; does nothing if the lock is poisoned.
    pub fn reset_all(&self) {
        match self.inner.write() {
            Ok(mut limiter) => limiter.reset_all(),
            Err(_poisoned) => {}
        }
    }
}
impl Default for WorkerRateLimiter {
fn default() -> Self {
Self::new()
}
}
/// Builds the limiter selected by `config`: a [`SlidingWindow`] when
/// `config.sliding_window` is set, a [`TokenBucket`] otherwise.
#[must_use]
pub fn create_rate_limiter(config: RateLimitConfig) -> Box<dyn RateLimiter> {
    if !config.sliding_window {
        return Box::new(TokenBucket::new(config));
    }
    Box::new(SlidingWindow::new(config))
}
use async_trait::async_trait;
/// Asynchronous, fallible counterpart of `RateLimiter` whose methods take
/// `&self` and return `crate::Result`.
///
/// No implementation is visible in this file.
/// NOTE(review): the key helpers and Lua scripts further down suggest a Redis
/// backend — confirm against the implementors.
#[async_trait]
pub trait DistributedRateLimiter: Send + Sync {
/// Tries to take one permit; `Ok(true)` means it was granted.
async fn try_acquire(&self) -> crate::Result<bool>;
/// Wait until a permit is expected to be available.
async fn time_until_available(&self) -> crate::Result<Duration>;
/// Number of permits currently available.
async fn available_permits(&self) -> crate::Result<u32>;
/// Resets the limiter's state.
async fn reset(&self) -> crate::Result<()>;
/// Changes the configured rate.
async fn set_rate(&self, rate: f64) -> crate::Result<()>;
/// Returns the limiter's configuration.
fn config(&self) -> &RateLimitConfig;
/// Name identifying the backend.
fn backend_name(&self) -> &str;
}
/// State shared by the distributed limiter specs: the base key, the
/// configuration, and a local token bucket used as a fallback.
///
/// Clones share the same fallback bucket through the `Arc`.
#[derive(Debug, Clone)]
pub struct DistributedRateLimiterState {
    /// Base key from which the derived keys are built.
    pub key: String,
    /// Rate-limit settings for this key.
    pub config: RateLimitConfig,
    /// In-process bucket consulted by `try_acquire_fallback`.
    pub fallback: Arc<RwLock<TokenBucket>>,
}

impl DistributedRateLimiterState {
    /// Creates the state for `key`, with a full fallback bucket built from a
    /// copy of `config`.
    #[must_use]
    pub fn new(key: String, config: RateLimitConfig) -> Self {
        let local_bucket = TokenBucket::new(config.clone());
        Self {
            fallback: Arc::new(RwLock::new(local_bucket)),
            key,
            config,
        }
    }

    /// Joins the base key and `suffix` with a colon.
    fn suffixed(&self, suffix: &str) -> String {
        [self.key.as_str(), ":", suffix].concat()
    }

    /// Key `"<key>:tokens"`.
    #[inline]
    #[must_use]
    pub fn token_key(&self) -> String {
        self.suffixed("tokens")
    }

    /// Key `"<key>:refill"`.
    #[inline]
    #[must_use]
    pub fn refill_key(&self) -> String {
        self.suffixed("refill")
    }

    /// Key `"<key>:window"`.
    #[inline]
    #[must_use]
    pub fn window_key(&self) -> String {
        self.suffixed("window")
    }

    /// Tries to take a permit from the local fallback bucket; returns `true`
    /// when the bucket's lock is poisoned.
    fn try_acquire_fallback(&self) -> bool {
        self.fallback
            .write()
            .map_or(true, |mut bucket| bucket.try_acquire())
    }
}
/// Description of a distributed token bucket: shared state plus the Lua
/// scripts that implement the bucket.
#[derive(Debug, Clone)]
pub struct DistributedTokenBucketSpec {
    state: DistributedRateLimiterState,
}

impl DistributedTokenBucketSpec {
    /// Creates the spec for `key` with the given configuration.
    #[must_use]
    pub fn new(key: String, config: RateLimitConfig) -> Self {
        let state = DistributedRateLimiterState::new(key, config);
        Self { state }
    }

    /// Lua script that refills the bucket and tries to take one token.
    ///
    /// Reads `KEYS[1]` (tokens) and `KEYS[2]` (last refill), and `ARGV[1..4]`
    /// as rate, burst, current time and key TTL. The elapsed time is divided
    /// by 1000 before being multiplied by the rate, and the TTL is passed to
    /// `SET ... EX`. Returns `{1, tokens}` when a token was taken and
    /// `{0, tokens}` otherwise.
    #[must_use]
    pub fn lua_acquire_script() -> &'static str {
        const ACQUIRE: &str = r"
local tokens_key = KEYS[1]
local refill_key = KEYS[2]
local rate = tonumber(ARGV[1])
local burst = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local ttl = tonumber(ARGV[4])
local last_refill = redis.call('GET', refill_key)
local tokens = redis.call('GET', tokens_key)
if not tokens then
tokens = burst
else
tokens = tonumber(tokens)
end
if last_refill then
local elapsed = (now - tonumber(last_refill)) / 1000.0
tokens = math.min(tokens + elapsed * rate, burst)
end
if tokens >= 1.0 then
tokens = tokens - 1.0
redis.call('SET', tokens_key, tostring(tokens), 'EX', ttl)
redis.call('SET', refill_key, tostring(now), 'EX', ttl)
return {1, tokens}
else
redis.call('SET', tokens_key, tostring(tokens), 'EX', ttl)
redis.call('SET', refill_key, tostring(now), 'EX', ttl)
return {0, tokens}
end
";
        ACQUIRE
    }

    /// Lua script that reports the whole tokens currently available without
    /// writing anything.
    ///
    /// Takes the same two keys and `ARGV[1..3]` as rate, burst and current
    /// time; returns `burst` when no token count is stored.
    #[must_use]
    pub fn lua_available_script() -> &'static str {
        const AVAILABLE: &str = r"
local tokens_key = KEYS[1]
local refill_key = KEYS[2]
local rate = tonumber(ARGV[1])
local burst = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local last_refill = redis.call('GET', refill_key)
local tokens = redis.call('GET', tokens_key)
if not tokens then
return burst
else
tokens = tonumber(tokens)
end
if last_refill then
local elapsed = (now - tonumber(last_refill)) / 1000.0
tokens = math.min(tokens + elapsed * rate, burst)
end
return math.floor(tokens)
";
        AVAILABLE
    }

    /// Shared state (key, configuration, fallback bucket) of this spec.
    #[inline]
    #[must_use]
    pub fn state(&self) -> &DistributedRateLimiterState {
        &self.state
    }

    /// Tries to take a permit from the local fallback bucket.
    #[must_use]
    pub fn try_acquire_fallback(&self) -> bool {
        let state = &self.state;
        state.try_acquire_fallback()
    }
}
/// Description of a distributed sliding window: shared state plus the Lua
/// scripts that implement the window on a sorted set.
#[derive(Debug, Clone)]
pub struct DistributedSlidingWindowSpec {
    state: DistributedRateLimiterState,
}

impl DistributedSlidingWindowSpec {
    /// Creates the spec for `key` with the given configuration.
    #[must_use]
    pub fn new(key: String, config: RateLimitConfig) -> Self {
        Self {
            state: DistributedRateLimiterState::new(key, config),
        }
    }

    /// Lua script that expires old entries and records a new execution when
    /// the window has room.
    ///
    /// Reads `KEYS[1]` (the sorted set) and `ARGV[1..4]` as current time,
    /// window size, maximum count and a unique member id. The window size is
    /// multiplied by 1000 before being subtracted from the current time.
    /// Returns `{1, remaining}` on success and `{0, 0}` when the window is
    /// full.
    #[must_use]
    pub fn lua_acquire_script() -> &'static str {
        r"
local window_key = KEYS[1]
local now = tonumber(ARGV[1])
local window_size = tonumber(ARGV[2])
local max_count = tonumber(ARGV[3])
local uuid = ARGV[4]
local cutoff = now - window_size * 1000
redis.call('ZREMRANGEBYSCORE', window_key, '-inf', cutoff)
local count = redis.call('ZCARD', window_key)
if count < max_count then
redis.call('ZADD', window_key, now, uuid)
redis.call('EXPIRE', window_key, window_size * 2)
return {1, max_count - count - 1}
else
return {0, 0}
end
"
    }

    /// Lua script that expires old entries and returns how many executions
    /// are still allowed in the current window (never negative).
    #[must_use]
    pub fn lua_available_script() -> &'static str {
        r"
local window_key = KEYS[1]
local now = tonumber(ARGV[1])
local window_size = tonumber(ARGV[2])
local max_count = tonumber(ARGV[3])
local cutoff = now - window_size * 1000
redis.call('ZREMRANGEBYSCORE', window_key, '-inf', cutoff)
local count = redis.call('ZCARD', window_key)
return math.max(0, max_count - count)
"
    }

    /// Lua script that returns how long until the window has room again, in
    /// the same unit as the `now` argument; `0` when there is room already.
    ///
    /// The script previously ended with a stray Rust attribute line where the
    /// closing `end` of the outer `if` belongs, which made it invalid Lua;
    /// the `end` is restored here.
    #[must_use]
    pub fn lua_time_until_script() -> &'static str {
        r"
local window_key = KEYS[1]
local now = tonumber(ARGV[1])
local window_size = tonumber(ARGV[2])
local max_count = tonumber(ARGV[3])
local cutoff = now - window_size * 1000
redis.call('ZREMRANGEBYSCORE', window_key, '-inf', cutoff)
local count = redis.call('ZCARD', window_key)
if count < max_count then
return 0
else
local oldest = redis.call('ZRANGE', window_key, 0, 0, 'WITHSCORES')
if #oldest >= 2 then
local oldest_timestamp = tonumber(oldest[2])
local expires = oldest_timestamp + window_size * 1000
return math.max(0, expires - now)
else
return 0
end
end
"
    }

    /// Maximum executions per window: `rate * window_size`, rounded up.
    #[must_use]
    #[inline]
    pub fn max_executions(&self) -> usize {
        (self.state.config.rate * self.state.config.window_size as f64).ceil() as usize
    }

    /// Shared state (key, configuration, fallback bucket) of this spec.
    #[inline]
    #[must_use]
    pub fn state(&self) -> &DistributedRateLimiterState {
        &self.state
    }

    /// Tries to take a permit from the local fallback bucket.
    #[must_use]
    pub fn try_acquire_fallback(&self) -> bool {
        self.state.try_acquire_fallback()
    }
}
/// Registry of distributed rate-limit specs per task name, kept in shared
/// maps so that clones of the coordinator see the same registrations.
///
/// A task has at most one explicit spec: registering one kind removes a spec
/// of the other kind for the same task.
#[derive(Debug, Clone)]
pub struct DistributedRateLimiterCoordinator {
    /// Prefix of every key built by `redis_key`.
    namespace: String,
    /// Explicit token-bucket specs by task name.
    token_buckets: Arc<RwLock<HashMap<String, DistributedTokenBucketSpec>>>,
    /// Explicit sliding-window specs by task name.
    sliding_windows: Arc<RwLock<HashMap<String, DistributedSlidingWindowSpec>>>,
    /// Configuration applied to tasks without an explicit spec, if any.
    default_config: Option<RateLimitConfig>,
    /// Fallback specs created from `default_config`, one per task name, so
    /// that repeated fallback acquisitions drain the same local bucket. Kept
    /// apart from the explicit maps so the `get_*_spec` accessors keep
    /// reporting explicit registrations only. Grows by one entry per distinct
    /// task name that falls back to the default.
    default_fallbacks: Arc<RwLock<HashMap<String, DistributedTokenBucketSpec>>>,
}

impl DistributedRateLimiterCoordinator {
    /// Creates a coordinator for `namespace` without a default configuration.
    #[must_use]
    pub fn new(namespace: impl Into<String>) -> Self {
        Self {
            namespace: namespace.into(),
            token_buckets: Arc::new(RwLock::new(HashMap::new())),
            sliding_windows: Arc::new(RwLock::new(HashMap::new())),
            default_config: None,
            default_fallbacks: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Creates a coordinator for `namespace` that applies `config` to tasks
    /// without an explicit spec.
    #[must_use]
    pub fn with_default(namespace: impl Into<String>, config: RateLimitConfig) -> Self {
        Self {
            namespace: namespace.into(),
            token_buckets: Arc::new(RwLock::new(HashMap::new())),
            sliding_windows: Arc::new(RwLock::new(HashMap::new())),
            default_config: Some(config),
            default_fallbacks: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Registers the spec for `task_name`, choosing the kind from
    /// `config.sliding_window`, and removes any spec of the other kind so a
    /// stale registration cannot shadow the new one.
    ///
    /// Nothing is changed if the target map's lock is poisoned.
    pub fn set_task_rate(&self, task_name: impl Into<String>, config: RateLimitConfig) {
        let name = task_name.into();
        let key = self.redis_key(&name);
        if config.sliding_window {
            if let Ok(mut windows) = self.sliding_windows.write() {
                windows.insert(name.clone(), DistributedSlidingWindowSpec::new(key, config));
            } else {
                return;
            }
            if let Ok(mut buckets) = self.token_buckets.write() {
                buckets.remove(&name);
            }
        } else {
            if let Ok(mut buckets) = self.token_buckets.write() {
                buckets.insert(name.clone(), DistributedTokenBucketSpec::new(key, config));
            } else {
                return;
            }
            if let Ok(mut windows) = self.sliding_windows.write() {
                windows.remove(&name);
            }
        }
    }

    /// Removes the explicit spec of either kind registered for `task_name`.
    pub fn remove_task_rate(&self, task_name: &str) {
        if let Ok(mut guard) = self.token_buckets.write() {
            guard.remove(task_name);
        }
        if let Ok(mut guard) = self.sliding_windows.write() {
            guard.remove(task_name);
        }
    }

    /// Copy of the explicit token-bucket spec for `task_name`, if any.
    /// Returns `None` when the lock is poisoned.
    #[inline]
    #[must_use]
    pub fn get_token_bucket_spec(&self, task_name: &str) -> Option<DistributedTokenBucketSpec> {
        if let Ok(guard) = self.token_buckets.read() {
            guard.get(task_name).cloned()
        } else {
            None
        }
    }

    /// Copy of the explicit sliding-window spec for `task_name`, if any.
    /// Returns `None` when the lock is poisoned.
    #[inline]
    #[must_use]
    pub fn get_sliding_window_spec(&self, task_name: &str) -> Option<DistributedSlidingWindowSpec> {
        if let Ok(guard) = self.sliding_windows.read() {
            guard.get(task_name).cloned()
        } else {
            None
        }
    }

    /// Whether `task_name` is limited by an explicit spec or by the default
    /// configuration.
    #[inline]
    #[must_use]
    pub fn has_rate_limit(&self, task_name: &str) -> bool {
        let has_bucket = if let Ok(guard) = self.token_buckets.read() {
            guard.contains_key(task_name)
        } else {
            false
        };
        let has_window = if let Ok(guard) = self.sliding_windows.read() {
            guard.contains_key(task_name)
        } else {
            false
        };
        has_bucket || has_window || self.default_config.is_some()
    }

    /// Tries to take a permit from the local fallback bucket of `task_name`.
    ///
    /// Explicit specs are consulted first (token bucket, then sliding
    /// window). Otherwise the default configuration is used; its fallback
    /// spec is created on first use and reused afterwards, so the default
    /// limit is actually enforced instead of starting from a full bucket on
    /// every call. Returns `true` when the task is unlimited or when a lock
    /// is poisoned.
    #[must_use]
    pub fn try_acquire_fallback(&self, task_name: &str) -> bool {
        if let Some(spec) = self.get_token_bucket_spec(task_name) {
            return spec.try_acquire_fallback();
        }
        if let Some(spec) = self.get_sliding_window_spec(task_name) {
            return spec.try_acquire_fallback();
        }
        let config = match self.default_config.as_ref() {
            Some(config) => config,
            None => return true,
        };
        match self.default_fallbacks.write() {
            Ok(mut cache) => {
                if let Some(spec) = cache.get(task_name) {
                    return spec.try_acquire_fallback();
                }
                let spec =
                    DistributedTokenBucketSpec::new(self.redis_key(task_name), config.clone());
                let granted = spec.try_acquire_fallback();
                cache.insert(task_name.to_owned(), spec);
                granted
            }
            Err(_poisoned) => true,
        }
    }

    /// Key for `task_name`: `"<namespace>:ratelimit:<task_name>"`.
    #[must_use]
    pub fn redis_key(&self, task_name: &str) -> String {
        format!("{}:ratelimit:{}", self.namespace, task_name)
    }
}
// Unit tests for the in-process limiters and the configuration type.
// Several of them depend on wall-clock timing (refill rates, sleeps).
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
use std::time::Duration;
// A bucket with burst 5 grants exactly five immediate permits.
#[test]
fn test_token_bucket_basic() {
let config = RateLimitConfig::new(10.0).with_burst(5);
let mut limiter = TokenBucket::new(config);
for _ in 0..5 {
assert!(limiter.try_acquire());
}
assert!(!limiter.try_acquire());
}
// After draining the bucket, 15 ms at 100 tokens/s is enough for one more permit.
#[test]
fn test_token_bucket_refill() {
let config = RateLimitConfig::new(100.0).with_burst(10);
let mut limiter = TokenBucket::new(config);
for _ in 0..10 {
assert!(limiter.try_acquire());
}
assert!(!limiter.try_acquire());
thread::sleep(Duration::from_millis(15));
assert!(limiter.try_acquire());
}
// Rate 5 over a one-second window allows five grants, then refuses.
#[test]
fn test_sliding_window_basic() {
let config = RateLimitConfig::new(5.0).with_sliding_window(1);
let mut limiter = SlidingWindow::new(config);
for _ in 0..5 {
assert!(limiter.try_acquire());
}
assert!(!limiter.try_acquire());
}
// Only the configured task is limited; without a default, other tasks are unlimited.
#[test]
fn test_task_rate_limiter() {
let mut manager = TaskRateLimiter::new();
manager.set_task_rate("task_a", RateLimitConfig::new(10.0).with_burst(2));
assert!(manager.try_acquire("task_a"));
assert!(manager.try_acquire("task_a"));
assert!(!manager.try_acquire("task_a"));
assert!(manager.try_acquire("task_b"));
assert!(manager.try_acquire("task_b"));
assert!(manager.try_acquire("task_b"));
}
// With a default configuration every task gets its own bucket of burst 2.
#[test]
fn test_task_rate_limiter_default() {
let mut manager = TaskRateLimiter::with_default(RateLimitConfig::new(10.0).with_burst(2));
assert!(manager.try_acquire("task_a"));
assert!(manager.try_acquire("task_a"));
assert!(!manager.try_acquire("task_a"));
assert!(manager.try_acquire("task_b"));
assert!(manager.try_acquire("task_b"));
assert!(!manager.try_acquire("task_b"));
}
// Four threads make 20 attempts in total against a shared bucket of burst 10;
// the slow rate (0.1/s) keeps refills negligible, so at most 10 may succeed.
#[test]
fn test_worker_rate_limiter_thread_safe() {
let limiter = WorkerRateLimiter::new();
limiter.set_task_rate("task_a", RateLimitConfig::new(0.1).with_burst(10));
let limiter_clone = limiter.clone();
let handles: Vec<_> = (0..4)
.map(|_| {
let l = limiter_clone.clone();
thread::spawn(move || {
let mut count = 0;
for _ in 0..5 {
if l.try_acquire("task_a") {
count += 1;
}
}
count
})
})
.collect();
let total: usize = handles.into_iter().map(|h| h.join().unwrap()).sum();
assert!(total <= 10);
}
// The configuration survives a JSON round trip unchanged.
#[test]
fn test_rate_limit_config_serialization() {
let config = RateLimitConfig::new(50.0)
.with_burst(100)
.with_sliding_window(10);
let json = serde_json::to_string(&config).unwrap();
let parsed: RateLimitConfig = serde_json::from_str(&json).unwrap();
assert!((parsed.rate - 50.0).abs() < f64::EPSILON);
assert_eq!(parsed.burst, Some(100));
assert!(parsed.sliding_window);
assert_eq!(parsed.window_size, 10);
}
// At 10 tokens/s an empty bucket needs about 100 ms for the next token;
// 150 ms is the upper bound accepted here.
#[test]
fn test_time_until_available() {
let config = RateLimitConfig::new(10.0).with_burst(1);
let mut limiter = TokenBucket::new(config);
assert!(limiter.try_acquire());
let wait_time = limiter.time_until_available();
assert!(wait_time > Duration::ZERO);
assert!(wait_time <= Duration::from_millis(150));
}
// `reset` refills a drained bucket.
#[test]
fn test_reset() {
let config = RateLimitConfig::new(10.0).with_burst(5);
let mut limiter = TokenBucket::new(config);
for _ in 0..5 {
limiter.try_acquire();
}
assert!(!limiter.try_acquire());
limiter.reset();
assert!(limiter.try_acquire());
}
// `set_rate` is reflected in the configuration returned by `config`.
#[test]
fn test_set_rate() {
let config = RateLimitConfig::new(10.0).with_burst(10);
let mut limiter = TokenBucket::new(config);
limiter.set_rate(100.0);
assert!((limiter.config().rate - 100.0).abs() < f64::EPSILON);
}
// The factory returns a usable limiter for both kinds of configuration.
#[test]
fn test_create_rate_limiter() {
let config = RateLimitConfig::new(10.0);
let mut limiter = create_rate_limiter(config);
assert!(limiter.try_acquire());
let config = RateLimitConfig::new(10.0).with_sliding_window(1);
let mut limiter = create_rate_limiter(config);
assert!(limiter.try_acquire());
}
}