use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use parking_lot::RwLock;
use tracing::{debug, error, trace, warn};
#[cfg(feature = "distributed-rate-limit")]
use redis::aio::ConnectionManager;
use zentinel_config::RedisBackendConfig;
use crate::rate_limit::{RateLimitConfig, RateLimitOutcome};
/// Atomic counters describing the activity of the distributed rate limiter.
///
/// Every field is an `AtomicU64`, so the counters are updated through a shared
/// reference. All counters start at zero via the derived `Default`.
#[derive(Debug, Default)]
pub struct DistributedRateLimitStats {
/// Number of decisions passed to `record_check`, regardless of outcome.
pub total_checks: AtomicU64,
/// Number of recorded decisions whose outcome was `RateLimitOutcome::Allowed`.
pub allowed: AtomicU64,
/// Number of recorded decisions whose outcome was `RateLimitOutcome::Limited`.
pub limited: AtomicU64,
/// Number of calls to `record_redis_error`.
pub redis_errors: AtomicU64,
/// Number of calls to `record_local_fallback`.
pub local_fallbacks: AtomicU64,
}
impl DistributedRateLimitStats {
pub fn record_check(&self, outcome: RateLimitOutcome) {
self.total_checks.fetch_add(1, Ordering::Relaxed);
match outcome {
RateLimitOutcome::Allowed => {
self.allowed.fetch_add(1, Ordering::Relaxed);
}
RateLimitOutcome::Limited => {
self.limited.fetch_add(1, Ordering::Relaxed);
}
}
}
pub fn record_redis_error(&self) {
self.redis_errors.fetch_add(1, Ordering::Relaxed);
}
pub fn record_local_fallback(&self) {
self.local_fallbacks.fetch_add(1, Ordering::Relaxed);
}
}
/// Rate limiter that keeps its request counts in Redis.
///
/// Only compiled when the `distributed-rate-limit` feature is enabled.
#[cfg(feature = "distributed-rate-limit")]
pub struct RedisRateLimiter {
// Redis connection handle; `check` clones it for each call.
connection: ConnectionManager,
// Current settings; replaced field-by-field by `update_config`.
config: RwLock<RedisConfig>,
// Set to `true` after a successful `check`, `false` by `mark_unhealthy`.
healthy: AtomicBool,
/// Counters updated by `check` and `mark_unhealthy`.
pub stats: Arc<DistributedRateLimitStats>,
}
/// Snapshot of the settings used by `RedisRateLimiter::check`.
#[cfg(feature = "distributed-rate-limit")]
#[derive(Debug, Clone)]
struct RedisConfig {
// Prepended to the caller-supplied key to form the Redis key.
key_prefix: String,
// A request is limited once the in-window count exceeds this value.
max_rps: u32,
// Window length in seconds; initialised to 1 in `new` and not touched by `update_config`.
window_secs: u64,
// Upper bound on how long a single Redis pipeline may take.
timeout: Duration,
// Value reported by `fallback_enabled`.
fallback_local: bool,
}
#[cfg(feature = "distributed-rate-limit")]
impl RedisRateLimiter {
    /// Opens a Redis connection and builds a limiter from the given settings.
    ///
    /// The window length is fixed at one second, the limiter starts out marked
    /// healthy, and its statistics start at zero.
    ///
    /// # Errors
    ///
    /// Returns the [`redis::RedisError`] produced while opening the client or
    /// establishing the connection manager.
    pub async fn new(
        backend_config: &RedisBackendConfig,
        rate_config: &RateLimitConfig,
    ) -> Result<Self, redis::RedisError> {
        let client = redis::Client::open(backend_config.url.as_str())?;
        let connection = ConnectionManager::new(client).await?;
        // NOTE(review): a Redis URL can embed credentials; confirm the configured
        // URL is safe to emit at debug level or redact it before logging.
        debug!(
            url = %backend_config.url,
            prefix = %backend_config.key_prefix,
            max_rps = rate_config.max_rps,
            "Redis rate limiter initialized"
        );
        Ok(Self {
            connection,
            config: RwLock::new(RedisConfig {
                key_prefix: backend_config.key_prefix.clone(),
                max_rps: rate_config.max_rps,
                window_secs: 1,
                timeout: Duration::from_millis(backend_config.timeout_ms),
                fallback_local: backend_config.fallback_local,
            }),
            healthy: AtomicBool::new(true),
            stats: Arc::new(DistributedRateLimitStats::default()),
        })
    }

    /// Records a request for `key` and decides whether it is within the limit.
    ///
    /// One atomic pipeline trims entries older than the window from the key's
    /// sorted set, adds an entry for this request, refreshes the key's expiry
    /// to twice the window, and counts the entries in the window. The request
    /// is recorded even when the outcome is `Limited`.
    ///
    /// Returns the outcome together with the in-window count, which includes
    /// this request. On success the limiter is marked healthy and the outcome
    /// is added to `stats`.
    ///
    /// # Errors
    ///
    /// Returns a [`redis::RedisError`] when the pipeline fails, when it does
    /// not finish within the configured timeout, or when the system clock
    /// reports a time before the UNIX epoch. The health flag is not cleared
    /// here; `mark_unhealthy` does that.
    pub async fn check(&self, key: &str) -> Result<(RateLimitOutcome, i64), redis::RedisError> {
        // Process-wide sequence that keeps members distinct for requests that
        // observe the same clock reading.
        static MEMBER_SEQ: AtomicU64 = AtomicU64::new(0);

        // Clone the settings so the lock guard is released before any `.await`.
        let config = self.config.read().clone();
        let full_key = format!("{}{}", config.key_prefix, key);

        // A clock set before the epoch is reported as an error instead of
        // panicking on the request path.
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_err(|_| {
                redis::RedisError::from((
                    redis::ErrorKind::Io,
                    "System clock is before the UNIX epoch",
                ))
            })?;
        let now = since_epoch.as_millis() as f64;
        let window_start = now - (config.window_secs as f64 * 1000.0);

        // Sorted-set members are unique, so using the millisecond timestamp as
        // the member would merge every request made in the same millisecond
        // into one entry and undercount. The score stays the millisecond
        // timestamp; only the member is made distinct per request.
        let member = format!(
            "{}-{}-{}",
            since_epoch.as_nanos(),
            std::process::id(),
            MEMBER_SEQ.fetch_add(1, Ordering::Relaxed)
        );

        let mut conn = self.connection.clone();
        let result: Result<(i64,), _> = tokio::time::timeout(config.timeout, async {
            redis::pipe()
                .atomic()
                .zrembyscore(&full_key, 0.0, window_start)
                .ignore()
                // redis-rs takes (key, member, score) here.
                .zadd(&full_key, member, now)
                .ignore()
                .expire(&full_key, (config.window_secs * 2) as i64)
                .ignore()
                .zcount(&full_key, window_start, now)
                .query_async(&mut conn)
                .await
        })
        .await
        .map_err(|_| {
            redis::RedisError::from((redis::ErrorKind::Io, "Redis operation timed out"))
        })?;
        let (count,) = result?;

        self.healthy.store(true, Ordering::Relaxed);

        // `count` already includes this request, so exactly `max_rps` requests
        // per window are allowed.
        let outcome = if count > config.max_rps as i64 {
            RateLimitOutcome::Limited
        } else {
            RateLimitOutcome::Allowed
        };
        trace!(
            key = key,
            count = count,
            max_rps = config.max_rps,
            outcome = ?outcome,
            "Redis rate limit check"
        );
        self.stats.record_check(outcome);
        Ok((outcome, count))
    }

    /// Replaces the key prefix, request limit, timeout and fallback flag.
    ///
    /// The window length is left unchanged and the Redis connection is not
    /// re-established.
    pub fn update_config(
        &self,
        backend_config: &RedisBackendConfig,
        rate_config: &RateLimitConfig,
    ) {
        let mut config = self.config.write();
        config.key_prefix = backend_config.key_prefix.clone();
        config.max_rps = rate_config.max_rps;
        config.timeout = Duration::from_millis(backend_config.timeout_ms);
        config.fallback_local = backend_config.fallback_local;
    }

    /// Returns the current value of the health flag.
    pub fn is_healthy(&self) -> bool {
        self.healthy.load(Ordering::Relaxed)
    }

    /// Clears the health flag and counts one Redis error in `stats`.
    pub fn mark_unhealthy(&self) {
        self.healthy.store(false, Ordering::Relaxed);
        self.stats.record_redis_error();
    }

    /// Returns whether the configuration enables local fallback.
    pub fn fallback_enabled(&self) -> bool {
        self.config.read().fallback_local
    }
}
/// Placeholder used when the `distributed-rate-limit` feature is disabled.
///
/// It carries no state; its `new` constructor always returns an error.
#[cfg(not(feature = "distributed-rate-limit"))]
pub struct RedisRateLimiter;
#[cfg(not(feature = "distributed-rate-limit"))]
impl RedisRateLimiter {
    /// Always fails: this build does not include the Redis backend.
    ///
    /// Both arguments are ignored.
    ///
    /// # Errors
    ///
    /// Always returns a message naming the missing cargo feature.
    pub async fn new(
        _backend_config: &RedisBackendConfig,
        _rate_config: &RateLimitConfig,
    ) -> Result<Self, String> {
        let reason = "Distributed rate limiting requires the 'distributed-rate-limit' feature";
        Err(String::from(reason))
    }
}
/// Tries to build a [`RedisRateLimiter`], logging the result.
///
/// Returns `Some(limiter)` on success. On failure the error is logged, a
/// warning is added when `fallback_local` is set, and `None` is returned.
#[cfg(feature = "distributed-rate-limit")]
pub async fn create_redis_rate_limiter(
    backend_config: &RedisBackendConfig,
    rate_config: &RateLimitConfig,
) -> Option<RedisRateLimiter> {
    let attempt = RedisRateLimiter::new(backend_config, rate_config).await;

    // Handle the failure path first so the success path reads straight through.
    if let Err(e) = &attempt {
        error!(
            error = %e,
            url = %backend_config.url,
            "Failed to create Redis rate limiter"
        );
        if backend_config.fallback_local {
            warn!("Falling back to local rate limiting");
        }
        return None;
    }

    debug!(
        url = %backend_config.url,
        "Redis rate limiter created successfully"
    );
    attempt.ok()
}
/// Variant used when the `distributed-rate-limit` feature is disabled.
///
/// Ignores both arguments, logs a warning that distributed rate limiting is
/// unavailable, and always returns `None`.
#[cfg(not(feature = "distributed-rate-limit"))]
pub async fn create_redis_rate_limiter(
_backend_config: &RedisBackendConfig,
_rate_config: &RateLimitConfig,
) -> Option<RedisRateLimiter> {
warn!(
"Distributed rate limiting requested but feature is disabled. Using local rate limiting."
);
None
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads a counter with the same ordering the stats methods write with.
    fn load(counter: &AtomicU64) -> u64 {
        counter.load(Ordering::Relaxed)
    }

    /// Two allowed and one limited decision are tallied in total and per outcome.
    #[test]
    fn test_stats_recording() {
        let stats = DistributedRateLimitStats::default();
        let outcomes = [
            RateLimitOutcome::Allowed,
            RateLimitOutcome::Allowed,
            RateLimitOutcome::Limited,
        ];
        for outcome in outcomes {
            stats.record_check(outcome);
        }
        assert_eq!(load(&stats.total_checks), 3);
        assert_eq!(load(&stats.allowed), 2);
        assert_eq!(load(&stats.limited), 1);
    }

    /// Redis errors and local fallbacks are counted independently.
    #[test]
    fn test_stats_redis_errors() {
        let stats = DistributedRateLimitStats::default();
        for _ in 0..2 {
            stats.record_redis_error();
        }
        stats.record_local_fallback();
        assert_eq!(load(&stats.redis_errors), 2);
        assert_eq!(load(&stats.local_fallbacks), 1);
    }
}