use dashmap::DashMap;
use parking_lot::RwLock;
use pingora_limits::rate::Rate;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tracing::{debug, trace, warn};
use zentinel_config::{RateLimitAction, RateLimitBackend, RateLimitKey};
#[cfg(feature = "distributed-rate-limit")]
use crate::distributed_rate_limit::{create_redis_rate_limiter, RedisRateLimiter};
/// Result of a single rate-limit observation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitOutcome {
    /// The request is within the configured rate.
    Allowed,
    /// The request exceeded the configured rate.
    Limited,
}
/// Outcome of a limiter check plus the metadata needed to populate
/// `X-RateLimit-*` style response headers.
#[derive(Debug, Clone)]
pub struct RateLimitCheckInfo {
    /// Whether the request was allowed or limited.
    pub outcome: RateLimitOutcome,
    /// Requests observed in the current window (may exceed `limit`).
    pub current_count: i64,
    /// Configured maximum requests per second.
    pub limit: u32,
    /// Requests left in the current window, clamped to zero.
    pub remaining: u32,
    /// Unix timestamp (seconds) when the current window resets.
    pub reset_at: u64,
}
/// Configuration for one rate limiter (global or per-route).
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum requests per second per key.
    pub max_rps: u32,
    /// Allowed burst above `max_rps`.
    // NOTE(review): `burst` is stored but never consulted by the local
    // limiter logic in this file — confirm whether a backend honors it.
    pub burst: u32,
    /// How the limiter bucket key is derived from the request.
    pub key: RateLimitKey,
    /// What to do when the limit is exceeded.
    pub action: RateLimitAction,
    /// HTTP status returned on rejection.
    pub status_code: u16,
    /// Optional custom rejection message.
    pub message: Option<String>,
    /// Local in-process vs distributed backend selection.
    // NOTE(review): `RateLimiterPool::new` always builds a local backend
    // regardless of this field; presumably the caller dispatches on it
    // (e.g. by calling `with_redis`) — confirm.
    pub backend: RateLimitBackend,
    /// Cap on the retry delay suggested to clients, in milliseconds.
    pub max_delay_ms: u64,
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
max_rps: 100,
burst: 10,
key: RateLimitKey::ClientIp,
action: RateLimitAction::Reject,
status_code: 429,
message: None,
backend: RateLimitBackend::Local,
max_delay_ms: 5000,
}
}
}
/// Sliding one-second rate counter for a single bucket key.
struct KeyRateLimiter {
    /// pingora rate estimator configured with a one-second window.
    rate: Rate,
    /// Maximum observations allowed per window.
    max_requests: isize,
}
impl KeyRateLimiter {
    /// Creates a limiter that allows `max_rps` observations per one-second window.
    fn new(max_rps: u32) -> Self {
        let rate = Rate::new(Duration::from_secs(1));
        Self {
            rate,
            max_requests: max_rps as isize,
        }
    }

    /// Records one request and reports whether the window's running total
    /// now exceeds the configured maximum.
    fn check(&self) -> RateLimitOutcome {
        let observed = self.rate.observe(&(), 1);
        if observed <= self.max_requests {
            RateLimitOutcome::Allowed
        } else {
            RateLimitOutcome::Limited
        }
    }
}
/// Storage backend for a limiter pool.
pub enum RateLimitBackendType {
    /// Purely in-process counters, one limiter per key.
    Local {
        limiters: DashMap<String, Arc<KeyRateLimiter>>,
    },
    /// Redis-coordinated counters, with an in-process fallback store used
    /// when Redis is unreachable.
    #[cfg(feature = "distributed-rate-limit")]
    Distributed {
        redis: Arc<RedisRateLimiter>,
        local_fallback: DashMap<String, Arc<KeyRateLimiter>>,
    },
}
/// A set of per-key rate limiters sharing one configuration.
pub struct RateLimiterPool {
    backend: RateLimitBackendType,
    // Guarded so `update_config` can swap settings at runtime.
    config: RwLock<RateLimitConfig>,
}
/// Seconds since the Unix epoch, saturating to zero if the system clock
/// reports a time before the epoch.
fn current_unix_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Unix timestamp at which the current one-second rate window rolls over.
fn calculate_reset_timestamp() -> u64 {
    1 + current_unix_timestamp()
}
impl RateLimiterPool {
/// Builds a pool backed by the in-process (per-instance) limiter store.
pub fn new(config: RateLimitConfig) -> Self {
    let limiters = DashMap::new();
    Self {
        backend: RateLimitBackendType::Local { limiters },
        config: RwLock::new(config),
    }
}
/// Builds a pool whose counters are coordinated through Redis, keeping an
/// in-process fallback store for use when Redis is unavailable.
#[cfg(feature = "distributed-rate-limit")]
pub fn with_redis(config: RateLimitConfig, redis: Arc<RedisRateLimiter>) -> Self {
    Self {
        backend: RateLimitBackendType::Distributed {
            redis,
            local_fallback: DashMap::new(),
        },
        config: RwLock::new(config),
    }
}
/// Records one request for `key` and returns the outcome together with
/// header-friendly metadata (limit, remaining, reset time).
///
/// Always consults the in-process limiter store; for the distributed
/// backend this is the synchronous fallback-store path (the Redis path is
/// `check_async`).
pub fn check(&self, key: &str) -> RateLimitCheckInfo {
    let max_rps = self.config.read().max_rps;
    let limiters = match &self.backend {
        RateLimitBackendType::Local { limiters } => limiters,
        #[cfg(feature = "distributed-rate-limit")]
        RateLimitBackendType::Distributed { local_fallback, .. } => local_fallback,
    };
    let limiter = limiters
        .entry(key.to_string())
        .or_insert_with(|| Arc::new(KeyRateLimiter::new(max_rps)))
        .clone();
    let outcome = limiter.check();
    // Observe zero events to read the current window count without bumping it.
    let count = limiter.rate.observe(&(), 0);
    let remaining = if count < max_rps as isize {
        (max_rps as isize - count) as u32
    } else {
        0
    };
    RateLimitCheckInfo {
        outcome,
        current_count: count as i64,
        limit: max_rps,
        remaining,
        reset_at: calculate_reset_timestamp(),
    }
}
/// Async rate-limit check used when the distributed backend is enabled.
///
/// * Local backend: delegates straight to the synchronous [`check`](Self::check).
/// * Distributed backend: asks Redis for the authoritative count. On a
///   Redis error the limiter is marked unhealthy and, if fallback is
///   enabled, the in-process store takes over; otherwise the request is
///   failed open (allowed, count 0).
#[cfg(feature = "distributed-rate-limit")]
pub async fn check_async(&self, key: &str) -> RateLimitCheckInfo {
    let max_rps = self.config.read().max_rps;
    match &self.backend {
        RateLimitBackendType::Local { .. } => self.check(key),
        RateLimitBackendType::Distributed {
            redis,
            local_fallback,
        } => {
            match redis.check(key).await {
                Ok((outcome, count)) => {
                    let remaining = if count >= max_rps as i64 {
                        0
                    } else {
                        (max_rps as i64 - count) as u32
                    };
                    RateLimitCheckInfo {
                        outcome,
                        current_count: count,
                        limit: max_rps,
                        remaining,
                        reset_at: calculate_reset_timestamp(),
                    }
                }
                Err(e) => {
                    warn!(
                        error = %e,
                        key = key,
                        "Redis rate limit check failed, falling back to local"
                    );
                    redis.mark_unhealthy();
                    if redis.fallback_enabled() {
                        // Mirror of the synchronous local path in `check`.
                        let limiter = local_fallback
                            .entry(key.to_string())
                            .or_insert_with(|| Arc::new(KeyRateLimiter::new(max_rps)))
                            .clone();
                        let outcome = limiter.check();
                        // Observe zero events to read the count without bumping it.
                        let count = limiter.rate.observe(&(), 0);
                        let remaining = if count >= max_rps as isize {
                            0
                        } else {
                            (max_rps as isize - count) as u32
                        };
                        RateLimitCheckInfo {
                            outcome,
                            current_count: count as i64,
                            limit: max_rps,
                            remaining,
                            reset_at: calculate_reset_timestamp(),
                        }
                    } else {
                        // Fail open: Redis is down and no local fallback is allowed.
                        RateLimitCheckInfo {
                            outcome: RateLimitOutcome::Allowed,
                            current_count: 0,
                            limit: max_rps,
                            remaining: max_rps,
                            reset_at: calculate_reset_timestamp(),
                        }
                    }
                }
            }
        }
    }
}
/// True when this pool coordinates limits through Redis rather than
/// purely in-process counters.
pub fn is_distributed(&self) -> bool {
    // With the feature disabled, `Local` is the only variant, so this is
    // always false — matching the original exhaustive match.
    !matches!(&self.backend, RateLimitBackendType::Local { .. })
}
/// Derives the limiter bucket key for a request according to the
/// configured key strategy. A missing header value buckets the request
/// under "unknown".
pub fn extract_key(
    &self,
    client_ip: &str,
    path: &str,
    route_id: &str,
    headers: Option<&impl HeaderAccessor>,
) -> String {
    let guard = self.config.read();
    match &guard.key {
        RateLimitKey::ClientIp => client_ip.to_owned(),
        RateLimitKey::Path => path.to_owned(),
        RateLimitKey::Route => route_id.to_owned(),
        RateLimitKey::ClientIpAndPath => format!("{}:{}", client_ip, path),
        RateLimitKey::Header(header_name) => {
            match headers.and_then(|h| h.get_header(header_name)) {
                Some(value) => value,
                None => "unknown".to_string(),
            }
        }
    }
}
/// The configured action to take when a request exceeds the limit.
pub fn action(&self) -> RateLimitAction {
    self.config.read().action.clone()
}
/// HTTP status code to return on rejection.
pub fn status_code(&self) -> u16 {
    self.config.read().status_code
}
/// Optional custom rejection message from the configuration.
pub fn message(&self) -> Option<String> {
    self.config.read().message.clone()
}
/// Configured cap on the retry delay suggested to clients, in milliseconds.
pub fn max_delay_ms(&self) -> u64 {
    self.config.read().max_delay_ms
}
/// Swaps in a new configuration and resets all in-process counters so the
/// new limits take effect from a clean window.
pub fn update_config(&self, config: RateLimitConfig) {
    *self.config.write() = config;
    self.clear_local_limiters();
}
/// Drops every in-process limiter so new requests start from fresh windows.
fn clear_local_limiters(&self) {
    let limiters = match &self.backend {
        RateLimitBackendType::Local { limiters } => limiters,
        #[cfg(feature = "distributed-rate-limit")]
        RateLimitBackendType::Distributed { local_fallback, .. } => local_fallback,
    };
    limiters.clear();
}
/// Number of per-key limiter entries currently held in process.
// NOTE(review): not referenced within this file; possibly used by tests
// or metrics elsewhere — confirm before removing.
fn local_limiter_count(&self) -> usize {
    match &self.backend {
        RateLimitBackendType::Local { limiters } => limiters.len(),
        #[cfg(feature = "distributed-rate-limit")]
        RateLimitBackendType::Distributed { local_fallback, .. } => local_fallback.len(),
    }
}
/// Bounds memory use of the per-key limiter map.
///
/// When the map exceeds `max_entries`, roughly half of the entries are
/// evicted (DashMap iteration order is arbitrary, so eviction is
/// effectively random). Evicted keys simply start a fresh rate window on
/// their next request, so this is safe but briefly loosens enforcement.
pub fn cleanup(&self) {
    let max_entries = 100_000;
    let limiters = match &self.backend {
        RateLimitBackendType::Local { limiters } => limiters,
        #[cfg(feature = "distributed-rate-limit")]
        RateLimitBackendType::Distributed { local_fallback, .. } => local_fallback,
    };
    // Fix: the log previously reported the threshold constant as
    // `entries_before`; capture the actual pre-cleanup size instead.
    let entries_before = limiters.len();
    if entries_before > max_entries {
        let to_remove: Vec<_> = limiters
            .iter()
            .take(max_entries / 2)
            .map(|e| e.key().clone())
            .collect();
        for key in to_remove {
            limiters.remove(&key);
        }
        debug!(
            entries_before,
            entries_after = limiters.len(),
            "Rate limiter pool cleanup completed"
        );
    }
}
}
/// Minimal header lookup used by key extraction, keeping the limiter
/// independent of any particular HTTP framework's types.
pub trait HeaderAccessor {
    /// Returns the header's value as a string, or `None` if absent.
    fn get_header(&self, name: &str) -> Option<String>;
}
/// Owns the optional global limiter and all per-route limiter pools.
pub struct RateLimitManager {
    /// Limiter pool per registered route id.
    route_limiters: DashMap<String, Arc<RateLimiterPool>>,
    /// Limiter applied to every request regardless of route, if set.
    global_limiter: Option<Arc<RateLimiterPool>>,
}
impl RateLimitManager {
pub fn new() -> Self {
Self {
route_limiters: DashMap::new(),
global_limiter: None,
}
}
pub fn with_global_limit(max_rps: u32, burst: u32) -> Self {
let config = RateLimitConfig {
max_rps,
burst,
key: RateLimitKey::ClientIp,
action: RateLimitAction::Reject,
status_code: 429,
message: None,
backend: RateLimitBackend::Local,
max_delay_ms: 5000,
};
Self {
route_limiters: DashMap::new(),
global_limiter: Some(Arc::new(RateLimiterPool::new(config))),
}
}
/// Installs (or replaces) the rate limiter for `route_id`.
pub fn register_route(&self, route_id: &str, config: RateLimitConfig) {
    trace!(
        route_id = route_id,
        max_rps = config.max_rps,
        burst = config.burst,
        key = ?config.key,
        "Registering rate limiter for route"
    );
    let pool = Arc::new(RateLimiterPool::new(config));
    self.route_limiters.insert(route_id.to_string(), pool);
}
pub fn check(
&self,
route_id: &str,
client_ip: &str,
path: &str,
headers: Option<&impl HeaderAccessor>,
) -> RateLimitResult {
let mut best_limit_info: Option<RateLimitCheckInfo> = None;
if let Some(ref global) = self.global_limiter {
let key = global.extract_key(client_ip, path, route_id, headers);
let check_info = global.check(&key);
if check_info.outcome == RateLimitOutcome::Limited {
warn!(
route_id = route_id,
client_ip = client_ip,
key = key,
count = check_info.current_count,
"Request rate limited by global limiter"
);
let suggested_delay_ms = if check_info.current_count > check_info.limit as i64 {
let excess = check_info.current_count - check_info.limit as i64;
Some((excess as u64 * 1000) / check_info.limit as u64)
} else {
None
};
return RateLimitResult {
allowed: false,
action: global.action(),
status_code: global.status_code(),
message: global.message(),
limiter: "global".to_string(),
limit: check_info.limit,
remaining: check_info.remaining,
reset_at: check_info.reset_at,
suggested_delay_ms,
max_delay_ms: global.max_delay_ms(),
};
}
best_limit_info = Some(check_info);
}
if let Some(pool) = self.route_limiters.get(route_id) {
let key = pool.extract_key(client_ip, path, route_id, headers);
let check_info = pool.check(&key);
if check_info.outcome == RateLimitOutcome::Limited {
warn!(
route_id = route_id,
client_ip = client_ip,
key = key,
count = check_info.current_count,
"Request rate limited by route limiter"
);
let suggested_delay_ms = if check_info.current_count > check_info.limit as i64 {
let excess = check_info.current_count - check_info.limit as i64;
Some((excess as u64 * 1000) / check_info.limit as u64)
} else {
None
};
return RateLimitResult {
allowed: false,
action: pool.action(),
status_code: pool.status_code(),
message: pool.message(),
limiter: route_id.to_string(),
limit: check_info.limit,
remaining: check_info.remaining,
reset_at: check_info.reset_at,
suggested_delay_ms,
max_delay_ms: pool.max_delay_ms(),
};
}
trace!(
route_id = route_id,
key = key,
count = check_info.current_count,
remaining = check_info.remaining,
"Request allowed by rate limiter"
);
if let Some(ref existing) = best_limit_info {
if check_info.remaining < existing.remaining {
best_limit_info = Some(check_info);
}
} else {
best_limit_info = Some(check_info);
}
}
let (limit, remaining, reset_at) = best_limit_info
.map(|info| (info.limit, info.remaining, info.reset_at))
.unwrap_or((0, 0, 0));
RateLimitResult {
allowed: true,
action: RateLimitAction::Reject,
status_code: 429,
message: None,
limiter: String::new(),
limit,
remaining,
reset_at,
suggested_delay_ms: None,
max_delay_ms: 5000, }
}
pub fn cleanup(&self) {
if let Some(ref global) = self.global_limiter {
global.cleanup();
}
for entry in self.route_limiters.iter() {
entry.value().cleanup();
}
}
/// Number of routes with a registered limiter.
pub fn route_count(&self) -> usize {
    self.route_limiters.len()
}
/// True when any rate limiting (global or per-route) is configured.
#[inline]
pub fn is_enabled(&self) -> bool {
    !self.route_limiters.is_empty() || self.global_limiter.is_some()
}
/// Returns true when rate limiting will apply to requests on `route_id`.
///
/// NOTE(review): despite the name, this also returns true when only a
/// global limiter is configured (the global limiter covers every route) —
/// confirm callers expect "any limiter applies" rather than "this route
/// has its own limiter".
#[inline]
pub fn has_route_limiter(&self, route_id: &str) -> bool {
    self.global_limiter.is_some() || self.route_limiters.contains_key(route_id)
}
}
impl Default for RateLimitManager {
    /// Equivalent to [`RateLimitManager::new`]: no limiters configured.
    fn default() -> Self {
        Self::new()
    }
}
/// Final decision for a request, combining the outcome with everything a
/// caller needs to build the response (status, headers, retry hints).
#[derive(Debug, Clone)]
pub struct RateLimitResult {
    /// Whether the request may proceed.
    pub allowed: bool,
    /// Configured action for rejected requests.
    pub action: RateLimitAction,
    /// HTTP status to use on rejection.
    pub status_code: u16,
    /// Optional custom rejection message.
    pub message: Option<String>,
    /// Which limiter produced this result ("global", a route id, or empty
    /// when no limiter tripped).
    pub limiter: String,
    /// Configured max requests per second of the reporting limiter.
    pub limit: u32,
    /// Requests left in the current window.
    pub remaining: u32,
    /// Unix timestamp (seconds) when the window resets.
    pub reset_at: u64,
    /// Estimated wait before retrying, when the request was limited.
    pub suggested_delay_ms: Option<u64>,
    /// Configured upper bound for any delay-based action.
    pub max_delay_ms: u64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every request under `max_rps` is allowed and `remaining` counts down.
    #[test]
    fn test_rate_limiter_allows_under_limit() {
        let config = RateLimitConfig {
            max_rps: 10,
            burst: 5,
            key: RateLimitKey::ClientIp,
            ..Default::default()
        };
        let pool = RateLimiterPool::new(config);
        for i in 0..10 {
            let info = pool.check("127.0.0.1");
            assert_eq!(info.outcome, RateLimitOutcome::Allowed);
            assert_eq!(info.limit, 10);
            assert_eq!(info.remaining, 10 - i - 1);
        }
    }

    /// The request after the limit is exhausted is rejected.
    #[test]
    fn test_rate_limiter_blocks_over_limit() {
        let config = RateLimitConfig {
            max_rps: 5,
            burst: 2,
            key: RateLimitKey::ClientIp,
            ..Default::default()
        };
        let pool = RateLimiterPool::new(config);
        for _ in 0..5 {
            let info = pool.check("127.0.0.1");
            assert_eq!(info.outcome, RateLimitOutcome::Allowed);
        }
        let info = pool.check("127.0.0.1");
        assert_eq!(info.outcome, RateLimitOutcome::Limited);
        assert_eq!(info.remaining, 0);
    }

    /// Each key gets its own counter; exhausting one does not affect another.
    #[test]
    fn test_rate_limiter_separate_keys() {
        let config = RateLimitConfig {
            max_rps: 2,
            burst: 1,
            key: RateLimitKey::ClientIp,
            ..Default::default()
        };
        let pool = RateLimiterPool::new(config);
        let info1 = pool.check("192.168.1.1");
        let info2 = pool.check("192.168.1.2");
        let info3 = pool.check("192.168.1.1");
        let info4 = pool.check("192.168.1.2");
        assert_eq!(info1.outcome, RateLimitOutcome::Allowed);
        assert_eq!(info2.outcome, RateLimitOutcome::Allowed);
        assert_eq!(info3.outcome, RateLimitOutcome::Allowed);
        assert_eq!(info4.outcome, RateLimitOutcome::Allowed);
        let info5 = pool.check("192.168.1.1");
        let info6 = pool.check("192.168.1.2");
        assert_eq!(info5.outcome, RateLimitOutcome::Limited);
        assert_eq!(info6.outcome, RateLimitOutcome::Limited);
    }

    /// Header metadata (limit/remaining/reset) is populated on an allowed check.
    #[test]
    fn test_rate_limit_info_fields() {
        let config = RateLimitConfig {
            max_rps: 5,
            burst: 2,
            key: RateLimitKey::ClientIp,
            ..Default::default()
        };
        let pool = RateLimiterPool::new(config);
        let info = pool.check("10.0.0.1");
        assert_eq!(info.limit, 5);
        assert_eq!(info.remaining, 4);
        assert!(info.reset_at > 0);
        assert_eq!(info.outcome, RateLimitOutcome::Allowed);
    }

    /// Routes without a limiter are unrestricted; registered routes enforce
    /// their configured limit and report 429 once it is exceeded.
    #[test]
    fn test_rate_limit_manager() {
        let manager = RateLimitManager::new();
        manager.register_route(
            "api",
            RateLimitConfig {
                max_rps: 5,
                burst: 2,
                key: RateLimitKey::ClientIp,
                ..Default::default()
            },
        );
        let result = manager.check("web", "127.0.0.1", "/", Option::<&NoHeaders>::None);
        assert!(result.allowed);
        assert_eq!(result.limit, 0);
        for i in 0..5 {
            let result = manager.check("api", "127.0.0.1", "/api/test", Option::<&NoHeaders>::None);
            assert!(result.allowed);
            assert_eq!(result.limit, 5);
            assert_eq!(result.remaining, 5 - i as u32 - 1);
        }
        let result = manager.check("api", "127.0.0.1", "/api/test", Option::<&NoHeaders>::None);
        assert!(!result.allowed);
        assert_eq!(result.status_code, 429);
        assert_eq!(result.limit, 5);
        assert_eq!(result.remaining, 0);
        assert!(result.reset_at > 0);
    }

    /// A rejected request carries a suggested retry delay.
    #[test]
    fn test_rate_limit_result_with_delay() {
        let manager = RateLimitManager::new();
        manager.register_route(
            "api",
            RateLimitConfig {
                max_rps: 2,
                burst: 1,
                key: RateLimitKey::ClientIp,
                ..Default::default()
            },
        );
        manager.check("api", "127.0.0.1", "/", Option::<&NoHeaders>::None);
        manager.check("api", "127.0.0.1", "/", Option::<&NoHeaders>::None);
        let result = manager.check("api", "127.0.0.1", "/", Option::<&NoHeaders>::None);
        assert!(!result.allowed);
        assert!(result.suggested_delay_ms.is_some());
    }

    /// `HeaderAccessor` stub that never yields a header value.
    struct NoHeaders;
    impl HeaderAccessor for NoHeaders {
        fn get_header(&self, _name: &str) -> Option<String> {
            None
        }
    }

    /// The global limiter applies across all routes and identifies itself
    /// as "global" when it rejects.
    #[test]
    fn test_global_rate_limiter() {
        let manager = RateLimitManager::with_global_limit(3, 1);
        for i in 0..3 {
            let result = manager.check("any-route", "127.0.0.1", "/", Option::<&NoHeaders>::None);
            assert!(result.allowed, "Request {} should be allowed", i);
            assert_eq!(result.limit, 3);
            assert_eq!(result.remaining, 3 - i as u32 - 1);
        }
        let result = manager.check(
            "different-route",
            "127.0.0.1",
            "/",
            Option::<&NoHeaders>::None,
        );
        assert!(!result.allowed);
        assert_eq!(result.limiter, "global");
    }

    /// A stricter route limiter trips before a looser global one; other
    /// routes remain governed only by the global limit.
    #[test]
    fn test_global_and_route_limiters() {
        let manager = RateLimitManager::with_global_limit(10, 5);
        manager.register_route(
            "strict-api",
            RateLimitConfig {
                max_rps: 2,
                burst: 1,
                key: RateLimitKey::ClientIp,
                ..Default::default()
            },
        );
        let result1 = manager.check("strict-api", "127.0.0.1", "/", Option::<&NoHeaders>::None);
        let result2 = manager.check("strict-api", "127.0.0.1", "/", Option::<&NoHeaders>::None);
        assert!(result1.allowed);
        assert!(result2.allowed);
        let result3 = manager.check("strict-api", "127.0.0.1", "/", Option::<&NoHeaders>::None);
        assert!(!result3.allowed);
        assert_eq!(result3.limiter, "strict-api");
        let result4 = manager.check("other-route", "127.0.0.1", "/", Option::<&NoHeaders>::None);
        assert!(result4.allowed);
    }

    /// The suggested delay is positive and bounded for a just-exceeded limit.
    #[test]
    fn test_suggested_delay_calculation() {
        let manager = RateLimitManager::new();
        manager.register_route(
            "api",
            RateLimitConfig {
                max_rps: 10,
                burst: 5,
                key: RateLimitKey::ClientIp,
                ..Default::default()
            },
        );
        for _ in 0..10 {
            manager.check("api", "127.0.0.1", "/", Option::<&NoHeaders>::None);
        }
        let result = manager.check("api", "127.0.0.1", "/", Option::<&NoHeaders>::None);
        assert!(!result.allowed);
        assert!(result.suggested_delay_ms.is_some());
        let delay = result.suggested_delay_ms.unwrap();
        assert!(delay > 0, "Delay should be positive");
        assert!(delay <= 1000, "Delay should be reasonable");
    }

    /// `reset_at` falls within the next couple of seconds of wall-clock time.
    #[test]
    fn test_reset_timestamp_is_future() {
        let config = RateLimitConfig {
            max_rps: 5,
            burst: 2,
            key: RateLimitKey::ClientIp,
            ..Default::default()
        };
        let pool = RateLimiterPool::new(config);
        let info = pool.check("10.0.0.1");
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        assert!(info.reset_at >= now, "Reset time should be >= now");
        assert!(
            info.reset_at <= now + 2,
            "Reset time should be within 2 seconds"
        );
    }

    /// Once over the limit, `remaining` clamps at zero rather than going negative.
    #[test]
    fn test_rate_limit_check_info_remaining_clamps_to_zero() {
        let config = RateLimitConfig {
            max_rps: 2,
            burst: 1,
            key: RateLimitKey::ClientIp,
            ..Default::default()
        };
        let pool = RateLimiterPool::new(config);
        pool.check("10.0.0.1");
        pool.check("10.0.0.1");
        let info = pool.check("10.0.0.1");
        assert_eq!(info.remaining, 0);
        assert_eq!(info.outcome, RateLimitOutcome::Limited);
    }

    /// Spot-checks result fields on both the allowed and blocked paths.
    #[test]
    fn test_rate_limit_result_fields() {
        let manager = RateLimitManager::new();
        manager.register_route(
            "test",
            RateLimitConfig {
                max_rps: 1,
                burst: 1,
                key: RateLimitKey::ClientIp,
                ..Default::default()
            },
        );
        let allowed_result = manager.check("test", "127.0.0.1", "/", Option::<&NoHeaders>::None);
        assert!(allowed_result.allowed);
        assert_eq!(allowed_result.limit, 1);
        assert!(allowed_result.reset_at > 0);
        let blocked_result = manager.check("test", "127.0.0.1", "/", Option::<&NoHeaders>::None);
        assert!(!blocked_result.allowed);
        assert_eq!(blocked_result.status_code, 429);
        assert_eq!(blocked_result.remaining, 0);
    }

    /// `has_route_limiter` reflects per-route registration (no global limiter here).
    #[test]
    fn test_has_route_limiter() {
        let manager = RateLimitManager::new();
        assert!(!manager.has_route_limiter("test-route"));
        manager.register_route(
            "test-route",
            RateLimitConfig {
                max_rps: 10,
                burst: 5,
                key: RateLimitKey::ClientIp,
                ..Default::default()
            },
        );
        assert!(manager.has_route_limiter("test-route"));
        assert!(!manager.has_route_limiter("other-route"));
    }

    /// A manager with only a global limiter counts as enabled.
    #[test]
    fn test_global_limiter_is_enabled() {
        let manager = RateLimitManager::with_global_limit(100, 50);
        assert!(manager.is_enabled());
    }

    /// `is_enabled` is true for a global or a route limiter, false for neither.
    #[test]
    fn test_is_enabled() {
        let empty_manager = RateLimitManager::new();
        assert!(!empty_manager.is_enabled());
        let global_manager = RateLimitManager::with_global_limit(100, 50);
        assert!(global_manager.is_enabled());
        let route_manager = RateLimitManager::new();
        route_manager.register_route(
            "test",
            RateLimitConfig {
                max_rps: 10,
                burst: 5,
                key: RateLimitKey::ClientIp,
                ..Default::default()
            },
        );
        assert!(route_manager.is_enabled());
    }
}